mirror of https://github.com/aljazceru/goose.git
fix: large sessions summarize/truncate (#2846)
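The patch below combines two ideas: a rough character-based token estimate (about 4 characters per token) and a guard that only appends the truncation/summarization notice message when it still fits under the target context limit. As a minimal standalone sketch of that guard, with simplified types (plain strings instead of the real Message and TokenCounter; all names here are illustrative, not from the commit):

    // Rough heuristic also used by the patch: ~1 token per 4 characters, minimum 1.
    fn estimate_tokens(text: &str) -> usize {
        (text.len() / 4).max(1)
    }

    // Append a notice message only if it fits, so the notice itself cannot
    // re-trigger a context-length-exceeded error.
    fn push_notice_if_it_fits(
        messages: &mut Vec<String>,
        token_counts: &mut Vec<usize>,
        notice: &str,
        target_context_limit: usize,
    ) {
        let notice_tokens = estimate_tokens(notice);
        let current_total: usize = token_counts.iter().sum();
        if current_total + notice_tokens <= target_context_limit {
            messages.push(notice.to_string());
            token_counts.push(notice_tokens);
        }
    }

    fn main() {
        let mut messages = vec!["hello".to_string()];
        let mut counts = vec![2];
        push_notice_if_it_fits(&mut messages, &mut counts, "conversation truncated", 100);
        assert_eq!(messages.len(), 2); // the notice fit, so it was appended
    }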
@@ -27,11 +27,20 @@ impl Agent {
             &OldestFirstTruncation,
         )?;
 
-        // Add an assistant message to the truncated messages
-        // to ensure the assistant's response is included in the context.
+        // Only add an assistant message if we have room for it and it won't cause another overflow
         let assistant_message = Message::assistant().with_text("I had run into a context length exceeded error so I truncated some of the oldest messages in our conversation.");
-        new_messages.push(assistant_message.clone());
-        new_token_counts.push(token_counter.count_chat_tokens("", &[assistant_message], &[]));
+        let assistant_tokens =
+            token_counter.count_chat_tokens("", &[assistant_message.clone()], &[]);
+
+        let current_total: usize = new_token_counts.iter().sum();
+        if current_total + assistant_tokens <= target_context_limit {
+            new_messages.push(assistant_message);
+            new_token_counts.push(assistant_tokens);
+        } else {
+            // If we can't fit the assistant message, at least log what happened
+            tracing::warn!("Cannot add truncation notice message due to context limits. Current: {}, Assistant: {}, Limit: {}",
+                current_total, assistant_tokens, target_context_limit);
+        }
 
         Ok((new_messages, new_token_counts))
     }
@@ -54,8 +63,18 @@ impl Agent {
         let assistant_message = Message::assistant().with_text(
             "I had run into a context length exceeded error so I summarized our conversation.",
         );
-        new_messages.push(assistant_message.clone());
-        new_token_counts.push(token_counter.count_chat_tokens("", &[assistant_message], &[]));
+        let assistant_tokens =
+            token_counter.count_chat_tokens("", &[assistant_message.clone()], &[]);
+
+        let current_total: usize = new_token_counts.iter().sum();
+        if current_total + assistant_tokens <= target_context_limit {
+            new_messages.push(assistant_message);
+            new_token_counts.push(assistant_tokens);
+        } else {
+            // If we can't fit the assistant message, at least log what happened
+            tracing::warn!("Cannot add summarization notice message due to context limits. Current: {}, Assistant: {}, Limit: {}",
+                current_total, assistant_tokens, target_context_limit);
+        }
     }
 
     Ok((new_messages, new_token_counts))
@@ -1,8 +1,166 @@
-use crate::message::Message;
+use crate::message::{Message, MessageContent};
 use anyhow::{anyhow, Result};
-use mcp_core::Role;
+use mcp_core::{Content, ResourceContents, Role};
 use std::collections::HashSet;
-use tracing::debug;
+use tracing::{debug, warn};
 
+/// Maximum size for truncated content in characters
+const MAX_TRUNCATED_CONTENT_SIZE: usize = 5000;
+
+/// Handles messages that are individually larger than the context limit
+/// by truncating their content rather than removing them entirely
+fn handle_oversized_messages(
+    messages: &[Message],
+    token_counts: &[usize],
+    context_limit: usize,
+    strategy: &dyn TruncationStrategy,
+) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
+    let mut truncated_messages = Vec::new();
+    let mut truncated_token_counts = Vec::new();
+    let mut any_truncated = false;
+
+    // Create a basic token counter for re-estimating truncated content
+    // Note: This is a rough approximation since we don't have access to the actual tokenizer here
+    let estimate_tokens = |text: &str| -> usize {
+        // Rough approximation: 1 token per 4 characters for English text
+        (text.len() / 4).max(1)
+    };
+
+    for (i, (message, &original_tokens)) in messages.iter().zip(token_counts.iter()).enumerate() {
+        if original_tokens > context_limit {
+            warn!(
+                "Message {} has {} tokens, exceeding context limit of {}",
+                i, original_tokens, context_limit
+            );
+
+            // Try to truncate the message content
+            let truncated_message = truncate_message_content(message, MAX_TRUNCATED_CONTENT_SIZE)?;
+            let estimated_new_tokens =
+                estimate_message_tokens(&truncated_message, &estimate_tokens);
+
+            if estimated_new_tokens > context_limit {
+                // Even truncated message is too large, skip it entirely
+                warn!("Skipping message {} as even truncated version ({} tokens) exceeds context limit", i, estimated_new_tokens);
+                any_truncated = true;
+                continue;
+            }
+
+            truncated_messages.push(truncated_message);
+            truncated_token_counts.push(estimated_new_tokens);
+            any_truncated = true;
+        } else {
+            truncated_messages.push(message.clone());
+            truncated_token_counts.push(original_tokens);
+        }
+    }
+
+    if any_truncated {
+        debug!("Truncated large message content, now attempting normal truncation");
+        // After content truncation, try normal truncation if still needed
+        return truncate_messages(
+            &truncated_messages,
+            &truncated_token_counts,
+            context_limit,
+            strategy,
+        );
+    }
+
+    Ok((truncated_messages, truncated_token_counts))
+}
+
+/// Truncates the content within a message while preserving its structure
+fn truncate_message_content(message: &Message, max_content_size: usize) -> Result<Message> {
+    let mut new_message = message.clone();
+
+    for content in &mut new_message.content {
+        match content {
+            MessageContent::Text(text_content) => {
+                if text_content.text.len() > max_content_size {
+                    let truncated = format!(
+                        "{}\n\n[... content truncated from {} to {} characters ...]",
+                        &text_content.text[..max_content_size.min(text_content.text.len())],
+                        text_content.text.len(),
+                        max_content_size
+                    );
+                    text_content.text = truncated;
+                }
+            }
+            MessageContent::ToolResponse(tool_response) => {
+                if let Ok(ref mut result) = tool_response.tool_result {
+                    for content_item in result {
+                        if let Content::Text(ref mut text_content) = content_item {
+                            if text_content.text.len() > max_content_size {
+                                let truncated = format!(
+                                    "{}\n\n[... tool response truncated from {} to {} characters ...]",
+                                    &text_content.text[..max_content_size.min(text_content.text.len())],
+                                    text_content.text.len(),
+                                    max_content_size
+                                );
+                                text_content.text = truncated;
+                            }
+                        }
+                        // Handle Resource content which might contain large text
+                        else if let Content::Resource(ref mut resource_content) = content_item {
+                            if let ResourceContents::TextResourceContents { text, .. } =
+                                &mut resource_content.resource
+                            {
+                                if text.len() > max_content_size {
+                                    let truncated = format!(
+                                        "{}\n\n[... resource content truncated from {} to {} characters ...]",
+                                        &text[..max_content_size.min(text.len())],
+                                        text.len(),
+                                        max_content_size
+                                    );
+                                    *text = truncated;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            // Other content types are typically smaller, but we could extend this if needed
+            _ => {}
+        }
+    }
+
+    Ok(new_message)
+}
+
+/// Estimates token count for a message using a simple heuristic
+fn estimate_message_tokens(message: &Message, estimate_fn: &dyn Fn(&str) -> usize) -> usize {
+    let mut total_tokens = 10; // Base overhead for message structure
+
+    for content in &message.content {
+        match content {
+            MessageContent::Text(text_content) => {
+                total_tokens += estimate_fn(&text_content.text);
+            }
+            MessageContent::ToolResponse(tool_response) => {
+                if let Ok(ref result) = tool_response.tool_result {
+                    for content_item in result {
+                        match content_item {
+                            Content::Text(text_content) => {
+                                total_tokens += estimate_fn(&text_content.text);
+                            }
+                            Content::Resource(resource_content) => {
+                                match &resource_content.resource {
+                                    ResourceContents::TextResourceContents { text, .. } => {
+                                        total_tokens += estimate_fn(text);
+                                    }
+                                    _ => total_tokens += 5, // Small overhead for other resource types
+                                }
+                            }
+                            _ => total_tokens += 5, // Small overhead for other content types
+                        }
+                    }
+                }
+            }
+            _ => total_tokens += 5, // Small overhead for other content types
+        }
+    }
+
+    total_tokens
+}
+
 /// Truncates the messages to fit within the model's context window.
 /// Mutates the input messages and token counts in place.
@@ -31,6 +189,17 @@ pub fn truncate_messages(
     debug!("Total tokens before truncation: {}", total_tokens);
 
+    // Check if any individual message is larger than the context limit
+    // First, check for any message that's too large
+    let max_message_tokens = token_counts.iter().max().copied().unwrap_or(0);
+    if max_message_tokens > context_limit {
+        // Try to handle large messages by truncating their content
+        debug!(
+            "Found oversized message with {} tokens, attempting content truncation",
+            max_message_tokens
+        );
+        return handle_oversized_messages(&messages, &token_counts, context_limit, strategy);
+    }
+
     let min_user_msg_tokens = messages
         .iter()
         .zip(token_counts.iter())
@@ -41,7 +210,7 @@ pub fn truncate_messages(
     // If there are no valid user messages, or the smallest one is too big for the context
     if min_user_msg_tokens.is_none() || min_user_msg_tokens.unwrap() > context_limit {
         return Err(anyhow!(
-            "Not possible to truncate messages within context limit"
+            "Not possible to truncate messages within context limit: no suitable user messages found"
         ));
     }
 
@@ -53,6 +222,28 @@ pub fn truncate_messages(
     let indices_to_remove =
         strategy.determine_indices_to_remove(&messages, &token_counts, context_limit)?;
 
+    // Circuit breaker: if we can't remove enough messages, fail gracefully
+    let tokens_to_remove: usize = indices_to_remove
+        .iter()
+        .map(|&i| token_counts.get(i).copied().unwrap_or(0))
+        .sum();
+
+    if total_tokens - tokens_to_remove > context_limit && !indices_to_remove.is_empty() {
+        debug!(
+            "Standard truncation insufficient: {} tokens remain after removing {} tokens",
+            total_tokens - tokens_to_remove,
+            tokens_to_remove
+        );
+        // Try more aggressive truncation or content truncation
+        return handle_oversized_messages(&messages, &token_counts, context_limit, strategy);
+    }
+
+    if indices_to_remove.is_empty() && total_tokens > context_limit {
+        return Err(anyhow!(
+            "Cannot truncate any messages: all messages may be essential or too large individually"
+        ));
+    }
+
     // Step 3: Remove the marked messages
     // Vectorize the set and sort in reverse order to avoid shifting indices when removing
     let mut indices_to_remove = indices_to_remove.iter().cloned().collect::<Vec<usize>>();
@@ -212,6 +403,14 @@ mod tests {
         (Message::user().with_tool_response(id, Ok(result)), tokens)
     }
 
+    // Helper function to create a large tool response with massive content
+    fn large_tool_response(id: &str, large_text: String, tokens: usize) -> (Message, usize) {
+        (
+            Message::user().with_tool_response(id, Ok(vec![Content::text(large_text)])),
+            tokens,
+        )
+    }
+
     // Helper function to create messages with alternating user and assistant
     // text messages of a fixed token count
     fn create_messages_with_counts(
@@ -237,22 +436,70 @@ mod tests {
         (messages, token_counts)
     }
 
+    #[test]
+    fn test_handle_oversized_single_message() -> Result<()> {
+        // Create a scenario similar to the real issue: one very large tool response
+        let large_content = "A".repeat(50000); // Very large content
+        let messages = vec![
+            user_text(1, 10).0,
+            assistant_tool_request(
+                "tool1",
+                ToolCall::new("read_file", json!({"path": "large_file.txt"})),
+                20,
+            )
+            .0,
+            large_tool_response("tool1", large_content, 100000).0, // Massive tool response
+            user_text(2, 10).0,
+        ];
+        let token_counts = vec![10, 20, 100000, 10]; // One message is huge
+        let context_limit = 5000; // Much smaller than the large message
+
+        let result = truncate_messages(
+            &messages,
+            &token_counts,
+            context_limit,
+            &OldestFirstTruncation,
+        );
+
+        // Should succeed by truncating the large content
+        assert!(
+            result.is_ok(),
+            "Should handle oversized message by content truncation"
+        );
+        let (truncated_messages, truncated_counts) = result.unwrap();
+
+        // Should have some messages remaining
+        assert!(
+            !truncated_messages.is_empty(),
+            "Should have some messages left"
+        );
+
+        // Total should be within limit
+        let total_tokens: usize = truncated_counts.iter().sum();
+        assert!(
+            total_tokens <= context_limit,
+            "Total tokens {} should be <= context limit {}",
+            total_tokens,
+            context_limit
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn test_oldest_first_no_truncation() -> Result<()> {
         let (messages, token_counts) = create_messages_with_counts(1, 10, false);
         let context_limit = 25;
 
-        let mut messages_clone = messages.clone();
-        let mut token_counts_clone = token_counts.clone();
-        truncate_messages(
-            &mut messages_clone,
-            &mut token_counts_clone,
+        let result = truncate_messages(
+            &messages,
+            &token_counts,
             context_limit,
             &OldestFirstTruncation,
         )?;
 
-        assert_eq!(messages_clone, messages);
-        assert_eq!(token_counts_clone, token_counts);
+        assert_eq!(result.0, messages);
+        assert_eq!(result.1, token_counts);
         Ok(())
     }
@@ -287,28 +534,27 @@ mod tests {
         let token_counts = vec![15, 20, 10, 25, 10, 30, 20, 35, 5];
         let context_limit = 100; // Force truncation while preserving some tool interactions
 
-        let mut messages_clone = messages.clone();
-        let mut token_counts_clone = token_counts.clone();
-        truncate_messages(
-            &mut messages_clone,
-            &mut token_counts_clone,
+        let result = truncate_messages(
+            &messages,
+            &token_counts,
             context_limit,
             &OldestFirstTruncation,
         )?;
+        let (truncated_messages, truncated_counts) = result;
 
         // Verify that tool pairs are kept together and the conversation remains coherent
-        assert!(messages_clone.len() >= 3); // At least one complete interaction should remain
-        assert!(messages_clone.last().unwrap().role == Role::User); // Last message should be from user
+        assert!(truncated_messages.len() >= 3); // At least one complete interaction should remain
+        assert!(truncated_messages.last().unwrap().role == Role::User); // Last message should be from user
 
         // Verify tool pairs are either both present or both removed
-        let tool_ids: HashSet<_> = messages_clone
+        let tool_ids: HashSet<_> = truncated_messages
             .iter()
             .flat_map(|m| m.get_tool_ids())
             .collect();
 
         // Each tool ID should appear 0 or 2 times (request + response)
         for id in tool_ids {
-            let count = messages_clone
+            let count = truncated_messages
                 .iter()
                 .flat_map(|m| m.get_tool_ids().into_iter())
                 .filter(|&tool_id| tool_id == id)
@@ -316,21 +562,26 @@ mod tests {
             assert!(count == 0 || count == 2, "Tool pair was split: {}", id);
         }
 
+        // Total should be within limit
+        let total_tokens: usize = truncated_counts.iter().sum();
+        assert!(total_tokens <= context_limit);
+
         Ok(())
     }
 
     #[test]
     fn test_edge_case_context_window() -> Result<()> {
         // Test case where we're exactly at the context limit
-        let (mut messages, mut token_counts) = create_messages_with_counts(2, 25, false);
+        let (messages, token_counts) = create_messages_with_counts(2, 25, false);
         let context_limit = 100; // Exactly matches total tokens
 
-        (messages, token_counts) = truncate_messages(
-            &mut messages,
-            &mut token_counts,
+        let result = truncate_messages(
+            &messages,
+            &token_counts,
             context_limit,
             &OldestFirstTruncation,
         )?;
+        let (mut messages, mut token_counts) = result;
 
         assert_eq!(messages.len(), 4); // No truncation needed
         assert_eq!(token_counts.iter().sum::<usize>(), 100);
@@ -339,12 +590,13 @@ mod tests {
         messages.push(user_text(5, 1).0);
         token_counts.push(1);
 
-        (messages, token_counts) = truncate_messages(
-            &mut messages,
-            &mut token_counts,
+        let result = truncate_messages(
+            &messages,
+            &token_counts,
             context_limit,
             &OldestFirstTruncation,
         )?;
+        let (messages, token_counts) = result;
 
         assert!(token_counts.iter().sum::<usize>() <= context_limit);
         assert!(messages.last().unwrap().role == Role::User);
@@ -376,30 +628,29 @@ mod tests {
         }
 
         let context_limit = 50; // Force partial truncation
-        let mut messages_clone = messages.clone();
-        let mut token_counts_clone = token_counts.clone();
 
-        (messages_clone, _) = truncate_messages(
-            &mut messages_clone,
-            &mut token_counts_clone,
+        let result = truncate_messages(
+            &messages,
+            &token_counts,
             context_limit,
             &OldestFirstTruncation,
         )?;
+        let (truncated_messages, _) = result;
 
         // Verify that remaining tool chains are complete
-        let remaining_tool_ids: HashSet<_> = messages_clone
+        let remaining_tool_ids: HashSet<_> = truncated_messages
            .iter()
            .flat_map(|m| m.get_tool_ids())
            .collect();
 
         for _id in remaining_tool_ids {
             // Count request/response pairs
-            let requests = messages_clone
+            let requests = truncated_messages
                 .iter()
                 .flat_map(|m| m.get_tool_request_ids().into_iter())
                 .count();
 
-            let responses = messages_clone
+            let responses = truncated_messages
                 .iter()
                 .flat_map(|m| m.get_tool_response_ids().into_iter())
                 .count();
@@ -414,22 +665,23 @@ mod tests {
     #[test]
     fn test_truncation_with_image_content() -> Result<()> {
         // Create a conversation with image content mixed in
-        let mut messages = vec![
+        let messages = vec![
            Message::user().with_image("base64_data", "image/png"), // 50 tokens
            Message::assistant().with_text("I see the image"), // 10 tokens
            Message::user().with_text("Can you describe it?"), // 10 tokens
            Message::assistant().with_text("It shows..."), // 20 tokens
            Message::user().with_text("Thanks!"), // 5 tokens
         ];
-        let mut token_counts = vec![50, 10, 10, 20, 5];
+        let token_counts = vec![50, 10, 10, 20, 5];
         let context_limit = 45; // Force truncation
 
-        (messages, token_counts) = truncate_messages(
-            &mut messages,
-            &mut token_counts,
+        let result = truncate_messages(
+            &messages,
+            &token_counts,
             context_limit,
             &OldestFirstTruncation,
         )?;
+        let (messages, token_counts) = result;
 
         // Verify the conversation still makes sense
         assert!(messages.len() >= 1);
@@ -442,24 +694,19 @@ mod tests {
     #[test]
     fn test_error_cases() -> Result<()> {
         // Test impossibly small context window
-        let (mut messages, mut token_counts) = create_messages_with_counts(1, 10, false);
+        let (messages, token_counts) = create_messages_with_counts(1, 10, false);
         let result = truncate_messages(
-            &mut messages,
-            &mut token_counts,
+            &messages,
+            &token_counts,
             5, // Impossibly small context
             &OldestFirstTruncation,
         );
         assert!(result.is_err());
 
         // Test unmatched token counts
-        let mut messages = vec![user_text(1, 10).0];
-        let mut token_counts = vec![10, 10]; // Mismatched length
-        let result = truncate_messages(
-            &mut messages,
-            &mut token_counts,
-            100,
-            &OldestFirstTruncation,
-        );
+        let messages = vec![user_text(1, 10).0];
+        let token_counts = vec![10, 10]; // Mismatched length
+        let result = truncate_messages(&messages, &token_counts, 100, &OldestFirstTruncation);
         assert!(result.is_err());
 
         Ok(())
@@ -211,7 +211,20 @@ pub fn generate_session_id() -> String {
 ///
 /// Creates the file if it doesn't exist, reads and deserializes all messages if it does.
 /// The first line of the file is expected to be metadata, and the rest are messages.
+/// Large messages are automatically truncated to prevent memory issues.
 pub fn read_messages(session_file: &Path) -> Result<Vec<Message>> {
+    read_messages_with_truncation(session_file, Some(50000)) // 50KB limit per message content
+}
+
+/// Read messages from a session file with optional content truncation
+///
+/// Creates the file if it doesn't exist, reads and deserializes all messages if it does.
+/// The first line of the file is expected to be metadata, and the rest are messages.
+/// If max_content_size is Some, large message content will be truncated during loading.
+pub fn read_messages_with_truncation(
+    session_file: &Path,
+    max_content_size: Option<usize>,
+) -> Result<Vec<Message>> {
     let file = fs::OpenOptions::new()
         .read(true)
         .write(true)
@@ -231,18 +244,163 @@ pub fn read_messages(session_file: &Path) -> Result<Vec<Message>> {
             // Metadata successfully parsed, continue with the rest of the lines as messages
         } else {
             // This is not metadata, it's a message
-            messages.push(serde_json::from_str::<Message>(&line)?);
+            let message = parse_message_with_truncation(&line, max_content_size)?;
+            messages.push(message);
         }
     }
 
     // Read the rest of the lines as messages
     for line in lines {
-        messages.push(serde_json::from_str::<Message>(&line?)?);
+        let line = line?;
+        let message = parse_message_with_truncation(&line, max_content_size)?;
+        messages.push(message);
     }
 
     Ok(messages)
 }
 
+/// Parse a message from JSON string with optional content truncation
+fn parse_message_with_truncation(
+    json_str: &str,
+    max_content_size: Option<usize>,
+) -> Result<Message> {
+    // First try to parse normally
+    match serde_json::from_str::<Message>(json_str) {
+        Ok(mut message) => {
+            // If we have a size limit, check and truncate if needed
+            if let Some(max_size) = max_content_size {
+                truncate_message_content_in_place(&mut message, max_size);
+            }
+            Ok(message)
+        }
+        Err(e) => {
+            // If parsing fails and the string is very long, it might be due to size
+            if json_str.len() > 100000 {
+                tracing::warn!(
+                    "Failed to parse very large message ({}KB), attempting truncation",
+                    json_str.len() / 1024
+                );
+
+                // Try to truncate the JSON string itself before parsing
+                let truncated_json = if let Some(max_size) = max_content_size {
+                    truncate_json_string(json_str, max_size)
+                } else {
+                    json_str.to_string()
+                };
+
+                match serde_json::from_str::<Message>(&truncated_json) {
+                    Ok(message) => {
+                        tracing::info!("Successfully parsed message after JSON truncation");
+                        Ok(message)
+                    }
+                    Err(_) => {
+                        tracing::error!("Failed to parse message even after truncation, skipping");
+                        // Return a placeholder message indicating the issue
+                        Ok(Message::user()
+                            .with_text("[Message too large to load - content truncated]"))
+                    }
+                }
+            } else {
+                Err(e.into())
+            }
+        }
+    }
+}
+
+/// Truncate content within a message in place
+fn truncate_message_content_in_place(message: &mut Message, max_content_size: usize) {
+    use crate::message::MessageContent;
+    use mcp_core::{Content, ResourceContents};
+
+    for content in &mut message.content {
+        match content {
+            MessageContent::Text(text_content) => {
+                if text_content.text.len() > max_content_size {
+                    let truncated = format!(
+                        "{}\n\n[... content truncated during session loading from {} to {} characters ...]",
+                        &text_content.text[..max_content_size.min(text_content.text.len())],
+                        text_content.text.len(),
+                        max_content_size
+                    );
+                    text_content.text = truncated;
+                }
+            }
+            MessageContent::ToolResponse(tool_response) => {
+                if let Ok(ref mut result) = tool_response.tool_result {
+                    for content_item in result {
+                        match content_item {
+                            Content::Text(ref mut text_content) => {
+                                if text_content.text.len() > max_content_size {
+                                    let truncated = format!(
+                                        "{}\n\n[... tool response truncated during session loading from {} to {} characters ...]",
+                                        &text_content.text[..max_content_size.min(text_content.text.len())],
+                                        text_content.text.len(),
+                                        max_content_size
+                                    );
+                                    text_content.text = truncated;
+                                }
+                            }
+                            Content::Resource(ref mut resource_content) => {
+                                if let ResourceContents::TextResourceContents { text, .. } =
+                                    &mut resource_content.resource
+                                {
+                                    if text.len() > max_content_size {
+                                        let truncated = format!(
+                                            "{}\n\n[... resource content truncated during session loading from {} to {} characters ...]",
+                                            &text[..max_content_size.min(text.len())],
+                                            text.len(),
+                                            max_content_size
+                                        );
+                                        *text = truncated;
+                                    }
+                                }
+                            }
+                            _ => {} // Other content types are typically smaller
+                        }
+                    }
+                }
+            }
+            _ => {} // Other content types are typically smaller
+        }
+    }
+}
+
+/// Attempt to truncate a JSON string by finding and truncating large text values
+fn truncate_json_string(json_str: &str, max_content_size: usize) -> String {
+    // This is a heuristic approach - look for large text values in the JSON
+    // and truncate them. This is not perfect but should handle the common case
+    // of large tool responses.
+
+    if json_str.len() <= max_content_size * 2 {
+        return json_str.to_string();
+    }
+
+    // Try to find patterns that look like large text content
+    // Look for "text":"..." patterns and truncate the content
+    let mut result = json_str.to_string();
+
+    // Simple regex-like approach to find and truncate large text values
+    if let Some(start) = result.find("\"text\":\"") {
+        let text_start = start + 8; // Length of "text":"
+        if let Some(end) = result[text_start..].find("\",") {
+            let text_end = text_start + end;
+            let text_content = &result[text_start..text_end];
+
+            if text_content.len() > max_content_size {
+                let truncated_text = format!(
+                    "{}\n\n[... content truncated during JSON parsing from {} to {} characters ...]",
+                    &text_content[..max_content_size.min(text_content.len())],
+                    text_content.len(),
+                    max_content_size
+                );
+                result.replace_range(text_start..text_end, &truncated_text);
+            }
+        }
+    }
+
+    result
+}
+
 /// Read session metadata from a session file
 ///
 /// Returns default empty metadata if the file doesn't exist or has no metadata.
@@ -462,12 +620,13 @@ mod tests {
         let dir = tempdir()?;
         let file_path = dir.path().join("special.jsonl");
 
-        // Insert some problematic JSON-like content between long text
+        // Insert some problematic JSON-like content between moderately long text
+        // (keeping under truncation limit to test serialization/deserialization)
         let long_text = format!(
             "Start_of_message\n{}{}SOME_MIDDLE_TEXT{}End_of_message",
-            "A".repeat(100_000),
+            "A".repeat(10_000), // Reduced from 100_000 to stay under 50KB limit
             "\"}]\n",
-            "A".repeat(100_000)
+            "A".repeat(10_000) // Reduced from 100_000 to stay under 50KB limit
         );
 
         let special_chars = vec![
@@ -562,6 +721,56 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_large_content_truncation() -> Result<()> {
+        let dir = tempdir()?;
+        let file_path = dir.path().join("large_content.jsonl");
+
+        // Create a message with content larger than the 50KB truncation limit
+        let very_large_text = "A".repeat(100_000); // 100KB of text
+        let messages = vec![
+            Message::user().with_text(&very_large_text),
+            Message::assistant().with_text("Small response"),
+        ];
+
+        // Write messages
+        persist_messages(&file_path, &messages, None).await?;
+
+        // Read them back - should be truncated
+        let read_messages = read_messages(&file_path)?;
+
+        assert_eq!(messages.len(), read_messages.len());
+
+        // First message should be truncated
+        if let Some(MessageContent::Text(read_text)) = read_messages[0].content.first() {
+            assert!(
+                read_text.text.len() < very_large_text.len(),
+                "Content should be truncated"
+            );
+            assert!(
+                read_text
+                    .text
+                    .contains("content truncated during session loading"),
+                "Should contain truncation notice"
+            );
+            assert!(
+                read_text.text.starts_with("AAAA"),
+                "Should start with original content"
+            );
+        } else {
+            panic!("Expected text content in first message");
+        }
+
+        // Second message should be unchanged
+        if let Some(MessageContent::Text(read_text)) = read_messages[1].content.first() {
+            assert_eq!(read_text.text, "Small response");
+        } else {
+            panic!("Expected text content in second message");
+        }
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_metadata_special_chars() -> Result<()> {
         let dir = tempdir()?;
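For reference, a hedged usage sketch of the read_messages_with_truncation API introduced above, assuming the caller lives in the same session module; the session file path and function name load_session_capped are hypothetical:

    use std::path::Path;

    fn load_session_capped() -> anyhow::Result<()> {
        // Cap each message's text content at 10KB instead of the 50KB default
        // that read_messages() applies.
        let session_file = Path::new("sessions/example.jsonl"); // made-up path
        let messages = read_messages_with_truncation(session_file, Some(10_000))?;
        println!("loaded {} messages", messages.len());
        Ok(())
    }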