feat: present options to user when context length is exceeded (#2207)

Co-authored-by: Yingjie He <yingjiehe@squareup.com>
This commit is contained in:
Salman Mohammed
2025-04-23 21:19:32 -03:00
committed by GitHub
parent 1b9699cca3
commit fd1f829751
19 changed files with 825 additions and 303 deletions

View File

@@ -6,6 +6,7 @@ mod prompt;
mod thinking;
pub use builder::{build_session, SessionBuilderConfig};
use console::Color;
use goose::permission::permission_confirmation::PrincipalType;
use goose::permission::Permission;
use goose::permission::PermissionConfirmation;
@@ -592,7 +593,7 @@ impl Session {
.reply(
&self.messages,
Some(SessionConfig {
id: session_id,
id: session_id.clone(),
working_dir: std::env::current_dir()
.expect("failed to get current session working directory"),
}),
@@ -622,6 +623,54 @@ impl Session {
principal_type: PrincipalType::Tool,
permission,
},).await;
} else if let Some(MessageContent::ContextLengthExceeded(_)) = message.content.first() {
output::hide_thinking();
let prompt = "The model's context length is maxed out. You will need to reduce the # msgs. Do you want to?".to_string();
let selected = cliclack::select(prompt)
.item("clear", "Clear Session", "Removes all messages from Goose's memory")
.item("truncate", "Truncate Messages", "Removes old messages till context is within limits")
.item("summarize", "Summarize Session", "Summarize the session to reduce context length")
.interact()?;
match selected {
"clear" => {
self.messages.clear();
let msg = format!("Session cleared.\n{}", "-".repeat(50));
output::render_text(&msg, Some(Color::Yellow), true);
break; // exit the loop to hand back control to the user
}
"truncate" => {
// Truncate messages to fit within context length
let (truncated_messages, _) = self.agent.truncate_context(&self.messages).await?;
let msg = format!("Context maxed out\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50));
output::render_text("", Some(Color::Yellow), true);
output::render_text(&msg, Some(Color::Yellow), true);
self.messages = truncated_messages;
}
"summarize" => {
// Summarize messages to fit within context length
let (summarized_messages, _) = self.agent.summarize_context(&self.messages).await?;
let msg = format!("Context maxed out\n{}\nGoose summarized messages for you.", "-".repeat(50));
output::render_text(&msg, Some(Color::Yellow), true);
self.messages = summarized_messages;
}
_ => {
unreachable!()
}
}
// Restart the stream after handling ContextLengthExceeded
stream = self
.agent
.reply(
&self.messages,
Some(SessionConfig {
id: session_id.clone(),
working_dir: std::env::current_dir()
.expect("failed to get current session working directory"),
}),
)
.await?;
}
// otherwise we have a model/tool to render
else {

View File

@@ -1,5 +1,5 @@
use bat::WrappingMode;
use console::style;
use console::{style, Color};
use goose::config::Config;
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
use mcp_core::prompt::PromptArgument;
@@ -143,6 +143,19 @@ pub fn render_message(message: &Message, debug: bool) {
println!();
}
/// Print `text` on its own line, surrounded by blank lines, with optional
/// color and dimming. When no color is given, falls back to green.
pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
    let base = style(text);
    let dimmed = if dim { base.dim() } else { base };
    let styled = match color {
        Some(c) => dimmed.fg(c),
        None => dimmed.green(),
    };
    println!("\n{}\n", styled);
}
pub fn render_enter_plan_mode() {
println!(
"\n{} {}\n",

View File

@@ -0,0 +1,64 @@
use super::utils::verify_secret_key;
use crate::state::AppState;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
routing::post,
Json, Router,
};
use goose::message::Message;
use serde::{Deserialize, Serialize};
/// Request body for the context-management endpoint.
///
/// Serialized in camelCase on the wire (`manageAction`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ContextManageRequest {
    /// Conversation history to reduce.
    messages: Vec<Message>,
    /// Name of the reduction strategy; see `manage_context` for the
    /// recognized values.
    manage_action: String,
}
/// Response body for the context-management endpoint: the reduced message
/// list plus per-message token counts.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ContextManageResponse {
    /// Messages after truncation/summarization.
    messages: Vec<Message>,
    /// Token count of each entry in `messages`, in the same order.
    token_counts: Vec<usize>,
}
/// Apply a context-reduction strategy to the supplied messages and return
/// the reduced history together with per-message token counts.
///
/// Errors:
/// - propagates the status from `verify_secret_key` on auth failure
/// - `PRECONDITION_REQUIRED` when no agent has been configured
/// - `BAD_REQUEST` for an unrecognized `manage_action`
/// - `INTERNAL_SERVER_ERROR` when the agent fails to reduce the context
async fn manage_context(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(request): Json<ContextManageRequest>,
) -> Result<Json<ContextManageResponse>, StatusCode> {
    verify_secret_key(&headers, &state)?;

    // Get a lock on the shared agent
    let agent = state.agent.read().await;
    let agent = agent.as_ref().ok_or(StatusCode::PRECONDITION_REQUIRED)?;

    let (messages, token_counts) = match request.manage_action.as_str() {
        // "trunction" was the original (misspelled) action name; keep
        // accepting it for backward compatibility with existing clients.
        "truncation" | "trunction" => agent
            .truncate_context(&request.messages)
            .await
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
        "summarize" => agent
            .summarize_context(&request.messages)
            .await
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
        // Previously an unknown action silently returned empty vectors,
        // which looks like a successfully cleared session to the caller.
        // Surface it as a client error instead.
        _ => return Err(StatusCode::BAD_REQUEST),
    };

    Ok(Json(ContextManageResponse {
        messages,
        token_counts,
    }))
}
/// Configure routes for this module.
///
/// Exposes `POST /context/manage`, handled by `manage_context`.
pub fn routes(state: AppState) -> Router {
    Router::new()
        .route("/context/manage", post(manage_context))
        .with_state(state)
}

View File

@@ -2,6 +2,7 @@
pub mod agent;
pub mod config_management;
pub mod configs;
pub mod context;
pub mod extension;
pub mod health;
pub mod recipe;
@@ -16,6 +17,7 @@ pub fn configure(state: crate::state::AppState) -> Router {
.merge(health::routes())
.merge(reply::routes(state.clone()))
.merge(agent::routes(state.clone()))
.merge(context::routes(state.clone()))
.merge(extension::routes(state.clone()))
.merge(configs::routes(state.clone()))
.merge(config_management::routes(state.clone()))

View File

@@ -12,12 +12,10 @@ use crate::permission::PermissionConfirmation;
use crate::providers::base::Provider;
use crate::providers::errors::ProviderError;
use crate::recipe::{Author, Recipe};
use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation};
use regex::Regex;
use serde_json::Value;
use tokio::sync::{mpsc, Mutex};
use tracing::{debug, error, instrument, warn};
use tracing::{debug, error, instrument};
use crate::agents::extension::{ExtensionConfig, ExtensionResult, ToolInfo};
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
@@ -35,9 +33,6 @@ use mcp_core::{
use super::platform_tools;
use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
/// The main goose Agent
pub struct Agent {
pub(super) provider: Arc<dyn Provider>,
@@ -45,7 +40,7 @@ pub struct Agent {
pub(super) frontend_tools: HashMap<String, FrontendTool>,
pub(super) frontend_instructions: Option<String>,
pub(super) prompt_manager: PromptManager,
pub(super) token_counter: TokenCounter,
// Channels for tool results and confirmations
pub(super) confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
pub(super) confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
@@ -54,7 +49,6 @@ pub struct Agent {
impl Agent {
pub fn new(provider: Arc<dyn Provider>) -> Self {
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
// Create channels with buffer size 32 (adjust if needed)
let (confirm_tx, confirm_rx) = mpsc::channel(32);
let (tool_tx, tool_rx) = mpsc::channel(32);
@@ -65,7 +59,6 @@ impl Agent {
frontend_tools: HashMap::new(),
frontend_instructions: None,
prompt_manager: PromptManager::new(),
token_counter,
confirmation_tx: confirm_tx,
confirmation_rx: Mutex::new(confirm_rx),
tool_result_tx: tool_tx,
@@ -161,55 +154,6 @@ impl Agent {
(request_id, result)
}
/// Truncates the messages to fit within the model's context window
/// Ensures the last message is a user message and removes tool call-response pairs
async fn truncate_messages(
&self,
messages: &mut Vec<Message>,
estimate_factor: f32,
system_prompt: &str,
tools: &mut Vec<Tool>,
) -> anyhow::Result<()> {
// Model's actual context limit
let context_limit = self.provider.get_model_config().context_limit();
// Our conservative estimate of the **target** context limit
// Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
let context_limit = (context_limit as f32 * estimate_factor) as usize;
// Take into account the system prompt, and our tools input and subtract that from the
// remaining context limit
let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());
// Check if system prompt + tools exceed our context limit
let remaining_tokens = context_limit
.checked_sub(system_prompt_token_count)
.and_then(|remaining| remaining.checked_sub(tools_token_count))
.ok_or_else(|| {
anyhow::anyhow!("System prompt and tools exceed estimated context limit")
})?;
let context_limit = remaining_tokens;
// Calculate current token count of each message, use count_chat_tokens to ensure we
// capture the full content of the message, include ToolRequests and ToolResponses
let mut token_counts: Vec<usize> = messages
.iter()
.map(|msg| {
self.token_counter
.count_chat_tokens("", std::slice::from_ref(msg), &[])
})
.collect();
truncate_messages(
messages,
&mut token_counts,
context_limit,
&OldestFirstTruncation,
)
}
pub(super) async fn manage_extensions(
&self,
action: String,
@@ -360,7 +304,6 @@ impl Agent {
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
let mut messages = messages.to_vec();
let reply_span = tracing::Span::current();
let mut truncation_attempt: usize = 0;
// Load settings from config
let config = Config::global();
@@ -398,9 +341,6 @@ impl Agent {
Self::update_session_metrics(session_config, &usage, messages.len()).await?;
}
// Reset truncation attempt
truncation_attempt = 0;
// categorize the type of requests we need to handle
let (frontend_requests,
remaining_requests,
@@ -529,24 +469,13 @@ impl Agent {
messages.push(final_message_tool_resp);
},
Err(ProviderError::ContextLengthExceeded(_)) => {
if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
// Create an error message & terminate the stream
// the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
// at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
break;
}
truncation_attempt += 1;
warn!("Context length exceeded. Truncation Attempt: {}/{}.", truncation_attempt, MAX_TRUNCATION_ATTEMPTS);
// Decay the estimate factor as we make more truncation attempts
// Estimate factor decays like this over time: 0.9, 0.81, 0.729, ...
let estimate_factor: f32 = ESTIMATE_FACTOR_DECAY.powi(truncation_attempt as i32);
if let Err(err) = self.truncate_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
break;
}
// Retry the loop after truncation
continue;
// At this point, the last message should be a user message
// because call to provider led to context length exceeded error
// Immediately yield a special message and break
yield Message::assistant().with_context_length_exceeded(
"The context length of the model has been exceeded. Please start a new session and try again.",
);
break;
},
Err(e) => {
// Create an error message & terminate the stream

View File

@@ -0,0 +1,63 @@
use anyhow::Ok;
use crate::message::Message;
use crate::token_counter::TokenCounter;
use crate::context_mgmt::summarize::summarize_messages;
use crate::context_mgmt::truncate::{truncate_messages, OldestFirstTruncation};
use crate::context_mgmt::{estimate_target_context_limit, get_messages_token_counts};
use super::super::agents::Agent;
impl Agent {
/// Public API to truncate oldest messages so that the conversation's token count is within the allowed context limit.
pub async fn truncate_context(
&self,
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
let provider = self.provider.clone();
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
let target_context_limit = estimate_target_context_limit(provider);
let token_counts = get_messages_token_counts(&token_counter, messages);
let (mut new_messages, mut new_token_counts) = truncate_messages(
messages,
&token_counts,
target_context_limit,
&OldestFirstTruncation,
)?;
// Add an assistant message to the truncated messages
// to ensure the assistant's response is included in the context.
let assistant_message = Message::assistant().with_text("I had run into a context length exceeded error so I truncated some of the oldest messages in our conversation.");
new_messages.push(assistant_message.clone());
new_token_counts.push(token_counter.count_chat_tokens("", &[assistant_message], &[]));
Ok((new_messages, new_token_counts))
}
/// Public API to summarize the conversation so that its token count is within the allowed context limit.
pub async fn summarize_context(
&self,
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
let provider = self.provider.clone();
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
let target_context_limit = estimate_target_context_limit(provider.clone());
let (mut new_messages, mut new_token_counts) =
summarize_messages(provider, messages, &token_counter, target_context_limit).await?;
// If the summarized messages only contains one message, it means no tool request and response message in the summarized messages,
// Add an assistant message to the summarized messages to ensure the assistant's response is included in the context.
if new_messages.len() == 1 {
let assistant_message = Message::assistant().with_text(
"I had run into a context length exceeded error so I summarized our conversation.",
);
new_messages.push(assistant_message.clone());
new_token_counts.push(token_counter.count_chat_tokens("", &[assistant_message], &[]));
}
Ok((new_messages, new_token_counts))
}
}

View File

@@ -1,4 +1,5 @@
mod agent;
mod context;
pub mod extension;
pub mod extension_manager;
pub mod platform_tools;

View File

@@ -0,0 +1,57 @@
use std::sync::Arc;
use mcp_core::Tool;
use crate::{message::Message, providers::base::Provider, token_counter::TokenCounter};
const ESTIMATE_FACTOR: f32 = 0.7;
const SYSTEM_PROMPT_TOKEN_OVERHEAD: usize = 3_000;
const TOOLS_TOKEN_OVERHEAD: usize = 5_000;
/// Conservative estimate of how many tokens of conversation we can fit in
/// the model's context window.
///
/// We scale the model's advertised limit by `ESTIMATE_FACTOR` (token counts
/// are approximate since providers often don't publish their tokenizer,
/// e.g. Claude) and then reserve fixed overhead for the system prompt and
/// tool schemas.
pub fn estimate_target_context_limit(provider: Arc<dyn Provider>) -> usize {
    let model_context_limit = provider.get_model_config().context_limit();

    // Our conservative estimate of the **target** context limit
    let target_limit = (model_context_limit as f32 * ESTIMATE_FACTOR) as usize;

    // Subtract overhead for system prompt and tools. `saturating_sub` guards
    // against underflow for small context windows (e.g. an 8k model yields
    // 5_600 < 8_000 overhead), where plain `-` would panic in debug builds
    // or wrap to a huge value in release builds.
    target_limit.saturating_sub(SYSTEM_PROMPT_TOKEN_OVERHEAD + TOOLS_TOKEN_OVERHEAD)
}
/// Token count of each message individually, in conversation order.
///
/// Counts via `count_chat_tokens` on a one-message slice so that embedded
/// ToolRequests and ToolResponses are included in each count.
pub fn get_messages_token_counts(token_counter: &TokenCounter, messages: &[Message]) -> Vec<usize> {
    let mut counts = Vec::with_capacity(messages.len());
    for msg in messages {
        counts.push(token_counter.count_chat_tokens("", std::slice::from_ref(msg), &[]));
    }
    counts
}
// These are not being used now but could be useful in the future
/// Token counts for the three components of a chat request.
#[allow(dead_code)]
pub struct ChatTokenCounts {
    /// Tokens in the system prompt.
    pub system: usize,
    /// Tokens consumed by the tool schemas.
    pub tools: usize,
    /// Per-message token counts, parallel to the message list.
    pub messages: Vec<usize>,
}
/// Compute token counts for the system prompt, tool schemas, and each
/// message of a chat request. Currently unused; kept for future use.
#[allow(dead_code)]
pub fn get_token_counts(
    token_counter: &TokenCounter,
    messages: &mut [Message],
    system_prompt: &str,
    tools: &mut Vec<Tool>,
) -> ChatTokenCounts {
    // The system prompt here includes goosehints.
    ChatTokenCounts {
        system: token_counter.count_tokens(system_prompt),
        tools: token_counter.count_tokens_for_tools(tools.as_slice()),
        messages: get_messages_token_counts(token_counter, messages),
    }
}

View File

@@ -0,0 +1,5 @@
mod common;
pub mod summarize;
pub mod truncate;
pub use common::*;

View File

@@ -0,0 +1,422 @@
use super::common::get_messages_token_counts;
use crate::message::{Message, MessageContent};
use crate::providers::base::Provider;
use crate::token_counter::TokenCounter;
use anyhow::Result;
use mcp_core::Role;
use std::sync::Arc;
// Constants for the summarization prompt and a follow-up user message.
const SUMMARY_PROMPT: &str = "You are good at summarizing conversations";
/// Summarize the combined messages from the accumulated summary and the current chunk.
///
/// Builds the summarization request, sends it to the provider, and returns
/// the provider's summary as the new single-message accumulated summary.
async fn summarize_combined_messages(
    provider: &Arc<dyn Provider>,
    accumulated_summary: &[Message],
    current_chunk: &[Message],
) -> Result<Vec<Message>, anyhow::Error> {
    // Merge the running summary with the chunk being folded into it.
    let mut combined_messages: Vec<Message> =
        Vec::with_capacity(accumulated_summary.len() + current_chunk.len());
    combined_messages.extend_from_slice(accumulated_summary);
    combined_messages.extend_from_slice(current_chunk);

    // Format the batch as a summarization request.
    let request_text = format!(
        "Please summarize the following conversation history, preserving the key points. This summarization will be used for the later conversations.\n\n```\n{:?}\n```",
        combined_messages
    );
    let summarization_request = vec![Message::user().with_text(&request_text)];

    // Ask the provider for the summary.
    let (mut response, _usage) = provider
        .complete(SUMMARY_PROMPT, &summarization_request, &[])
        .await?;

    // The summary is fed back into the following conversation as user content.
    response.role = Role::User;

    Ok(vec![response])
}
/// Preprocesses the messages to handle edge cases involving tool responses.
///
/// Separates messages into two groups:
/// 1. Messages to be summarized (`preprocessed_messages`)
/// 2. Messages to be temporarily removed (`removed_messages`), which include:
///    - The last tool response message.
///    - The corresponding tool request message that immediately precedes the
///      last tool response message (if present).
///
/// Only the last tool response message and its pair are considered for
/// removal; earlier tool exchanges are summarized along with everything else.
fn preprocess_messages(messages: &[Message]) -> (Vec<Message>, Vec<Message>) {
    let mut preprocessed_messages = messages.to_owned();
    let mut removed_messages = Vec::new();

    // Scan from the end for the most recent message carrying a ToolResponse.
    if let Some((last_index, last_message)) = messages.iter().enumerate().rev().find(|(_, m)| {
        m.content
            .iter()
            .any(|c| matches!(c, MessageContent::ToolResponse(_)))
    }) {
        // Check whether the immediately preceding message is the paired
        // ToolRequest; if so, it is removed together with the response.
        if last_index > 0 {
            if let Some(previous_message) = messages.get(last_index - 1) {
                if previous_message
                    .content
                    .iter()
                    .any(|c| matches!(c, MessageContent::ToolRequest(_)))
                {
                    // Add the tool request message to removed_messages
                    removed_messages.push(previous_message.clone());
                }
            }
        }

        // Add the last tool response message to removed_messages
        removed_messages.push(last_message.clone());

        // Start of the removal range: `last_index` when only the response is
        // removed (len == 1), `last_index - 1` when its request is too (len == 2).
        let start_index = last_index + 1 - removed_messages.len();

        // Remove the tool response and its paired tool request from preprocessed_messages
        preprocessed_messages.drain(start_index..=last_index);
    }

    (preprocessed_messages, removed_messages)
}
/// Reinserts removed messages into the summarized output.
///
/// Appends the messages that were set aside during preprocessing to the end
/// of the summarized list, so context such as the last tool exchange is not
/// lost.
fn reintegrate_removed_messages(
    summarized_messages: &[Message],
    removed_messages: &[Message],
) -> Vec<Message> {
    summarized_messages
        .iter()
        .chain(removed_messages.iter())
        .cloned()
        .collect()
}
// Summarization steps:
// 1. Break down large text into smaller chunks (roughly 30% of the model's context window).
// 2. For each chunk:
//    a. Combine it with the previous summary (or leave blank for the first iteration).
//    b. Summarize the combined text, focusing on extracting only the information we need.
// 3. Reattach the tool exchange that was set aside during preprocessing.
/// Iteratively summarize `messages` so their token count fits within
/// `context_limit`. Returns the summarized messages and their token counts.
pub async fn summarize_messages(
    provider: Arc<dyn Provider>,
    messages: &[Message],
    token_counter: &TokenCounter,
    context_limit: usize,
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
    let chunk_size = context_limit / 3; // ~33% of the context window per chunk.
    let summary_prompt_tokens = token_counter.count_tokens(SUMMARY_PROMPT);
    // Per-chunk budget after reserving room for the summarization prompt.
    // `saturating_sub` prevents underflow (panic/wrap) when context_limit is
    // smaller than three times the prompt size.
    let chunk_budget = chunk_size.saturating_sub(summary_prompt_tokens);
    let mut accumulated_summary = Vec::new();

    // Preprocess messages to handle the trailing tool response edge case.
    let (preprocessed_messages, removed_messages) = preprocess_messages(messages);

    // Get token counts for each message.
    let token_counts = get_messages_token_counts(token_counter, &preprocessed_messages);

    // Break messages into chunks and fold each chunk into the running summary.
    let mut current_chunk: Vec<Message> = Vec::new();
    let mut current_chunk_tokens = 0;

    for (message, message_tokens) in preprocessed_messages.iter().zip(token_counts.iter()) {
        // Flush the chunk before it would overflow the budget. The
        // `!is_empty` guard avoids a pointless provider call when a single
        // message alone exceeds the budget (the original code would
        // "summarize" an empty chunk in that case).
        if current_chunk_tokens + message_tokens > chunk_budget && !current_chunk.is_empty() {
            accumulated_summary =
                summarize_combined_messages(&provider, &accumulated_summary, &current_chunk)
                    .await?;
            current_chunk.clear();
            current_chunk_tokens = 0;
        }

        current_chunk.push(message.clone());
        current_chunk_tokens += message_tokens;
    }

    // Summarize the final chunk if it exists.
    if !current_chunk.is_empty() {
        accumulated_summary =
            summarize_combined_messages(&provider, &accumulated_summary, &current_chunk).await?;
    }

    // Add back removed messages.
    let final_summary = reintegrate_removed_messages(&accumulated_summary, &removed_messages);
    let final_counts = get_messages_token_counts(token_counter, &final_summary);

    Ok((final_summary, final_counts))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{Message, MessageContent};
    use crate::model::{ModelConfig, GPT_4O_TOKENIZER};
    use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
    use crate::providers::errors::ProviderError;
    use chrono::Utc;
    use mcp_core::{tool::Tool, Role};
    use mcp_core::{Content, TextContent, ToolCall};
    use serde_json::json;
    use std::sync::Arc;

    /// Provider stub whose `complete` always returns "Summarized content".
    #[derive(Clone)]
    struct MockProvider {
        model_config: ModelConfig,
    }

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::empty()
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        async fn complete(
            &self,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            Ok((
                Message {
                    role: Role::Assistant,
                    created: Utc::now().timestamp(),
                    content: vec![MessageContent::Text(TextContent {
                        text: "Summarized content".to_string(),
                        annotations: None,
                    })],
                },
                ProviderUsage::new("mock".to_string(), Usage::default()),
            ))
        }
    }

    fn create_mock_provider() -> Arc<dyn Provider> {
        let mock_model_config =
            ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
        Arc::new(MockProvider {
            model_config: mock_model_config,
        })
    }

    fn create_test_messages() -> Vec<Message> {
        vec![
            set_up_text_message("Message 1", Role::User),
            set_up_text_message("Message 2", Role::Assistant),
            set_up_text_message("Message 3", Role::User),
        ]
    }

    fn set_up_text_message(text: &str, role: Role) -> Message {
        Message {
            role,
            created: 0,
            content: vec![MessageContent::text(text.to_string())],
        }
    }

    fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message {
        Message {
            role: Role::Assistant,
            created: 0,
            content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
        }
    }

    fn set_up_tool_response_message(id: &str, tool_response: Vec<Content>) -> Message {
        Message {
            role: Role::User,
            created: 0,
            content: vec![MessageContent::tool_response(
                id.to_string(),
                Ok(tool_response),
            )],
        }
    }

    #[tokio::test]
    async fn test_summarize_messages_single_chunk() {
        let provider = create_mock_provider();
        let token_counter = TokenCounter::new(GPT_4O_TOKENIZER);
        let context_limit = 100; // Set a high enough limit to avoid chunking.
        let messages = create_test_messages();

        let result = summarize_messages(
            Arc::clone(&provider),
            &messages,
            &token_counter,
            context_limit,
        )
        .await;

        assert!(result.is_ok(), "The function should return Ok.");
        let (summarized_messages, token_counts) = result.unwrap();

        assert_eq!(
            summarized_messages.len(),
            1,
            "The summary should contain one message."
        );
        assert_eq!(
            summarized_messages[0].role,
            Role::User,
            "The summarized message should be from the user."
        );
        assert_eq!(
            token_counts.len(),
            1,
            "Token counts should match the number of summarized messages."
        );
    }

    #[tokio::test]
    async fn test_summarize_messages_multiple_chunks() {
        let provider = create_mock_provider();
        let token_counter = TokenCounter::new(GPT_4O_TOKENIZER);
        let context_limit = 30; // Small limit to force chunked summarization.
        let messages = create_test_messages();

        let result = summarize_messages(
            Arc::clone(&provider),
            &messages,
            &token_counter,
            context_limit,
        )
        .await;

        assert!(result.is_ok(), "The function should return Ok.");
        let (summarized_messages, token_counts) = result.unwrap();

        assert_eq!(
            summarized_messages.len(),
            1,
            "There should be one final summarized message."
        );
        assert_eq!(
            summarized_messages[0].role,
            Role::User,
            "The summarized message should be from the user."
        );
        assert_eq!(
            token_counts.len(),
            1,
            "Token counts should match the number of summarized messages."
        );
    }

    #[tokio::test]
    async fn test_summarize_messages_empty_input() {
        let provider = create_mock_provider();
        let token_counter = TokenCounter::new(GPT_4O_TOKENIZER);
        let context_limit = 100;
        let messages: Vec<Message> = Vec::new();

        let result = summarize_messages(
            Arc::clone(&provider),
            &messages,
            &token_counter,
            context_limit,
        )
        .await;

        assert!(result.is_ok(), "The function should return Ok.");
        let (summarized_messages, token_counts) = result.unwrap();

        assert_eq!(
            summarized_messages.len(),
            0,
            "The summary should be empty for an empty input."
        );
        assert!(
            token_counts.is_empty(),
            "Token counts should be empty for an empty input."
        );
    }

    #[tokio::test]
    async fn test_preprocess_messages_without_tool_response() {
        let messages = create_test_messages();

        let (preprocessed_messages, removed_messages) = preprocess_messages(&messages);

        // Fixed assertion messages: the originals were copy-pasted from the
        // with-tool-response test and contradicted the expected values.
        assert_eq!(
            preprocessed_messages.len(),
            3,
            "All messages should remain when there is no tool response."
        );
        assert_eq!(
            removed_messages.len(),
            0,
            "Nothing should be removed when there is no tool response."
        );
    }

    #[tokio::test]
    async fn test_preprocess_messages_with_tool_response() {
        let arguments = json!({
            "param1": "value1"
        });
        let messages = vec![
            set_up_text_message("Message 1", Role::User),
            set_up_tool_request_message("id", ToolCall::new("tool_name", json!(arguments))),
            set_up_tool_response_message("id", vec![Content::text("tool done")]),
        ];

        let (preprocessed_messages, removed_messages) = preprocess_messages(&messages);

        assert_eq!(
            preprocessed_messages.len(),
            1,
            "Only the user message should remain after preprocessing."
        );
        assert_eq!(
            removed_messages.len(),
            2,
            "The tool request and tool response messages should be removed."
        );
    }

    #[tokio::test]
    async fn test_reintegrate_removed_messages() {
        let summarized_messages = vec![Message {
            role: Role::Assistant,
            created: Utc::now().timestamp(),
            content: vec![MessageContent::Text(TextContent {
                text: "Summary".to_string(),
                annotations: None,
            })],
        }];
        let arguments = json!({
            "param1": "value1"
        });
        let removed_messages = vec![
            set_up_tool_request_message("id", ToolCall::new("tool_name", json!(arguments))),
            set_up_tool_response_message("id", vec![Content::text("tool done")]),
        ];

        let final_messages = reintegrate_removed_messages(&summarized_messages, &removed_messages);

        assert_eq!(
            final_messages.len(),
            3,
            "The final message list should include the summary and removed messages."
        );
    }
}

View File

@@ -4,77 +4,6 @@ use mcp_core::Role;
use std::collections::HashSet;
use tracing::debug;
/// Trait representing a truncation strategy
pub trait TruncationStrategy {
/// Determines the indices of messages to remove to fit within the context limit.
///
/// - `messages`: The list of messages in the conversation.
/// - `token_counts`: A parallel array containing the token count for each message.
/// - `context_limit`: The maximum allowed context length in tokens.
///
/// Returns a vector of indices to remove.
fn determine_indices_to_remove(
&self,
messages: &[Message],
token_counts: &[usize],
context_limit: usize,
) -> Result<HashSet<usize>>;
}
/// Strategy to truncate messages by removing the oldest first
pub struct OldestFirstTruncation;
/// Strategy to truncate messages explicitly
pub struct ExplicitTruncation;
impl TruncationStrategy for OldestFirstTruncation {
fn determine_indices_to_remove(
&self,
messages: &[Message],
token_counts: &[usize],
context_limit: usize,
) -> Result<HashSet<usize>> {
let mut indices_to_remove = HashSet::new();
let mut total_tokens: usize = token_counts.iter().sum();
let mut tool_ids_to_remove = HashSet::new();
for (i, message) in messages.iter().enumerate() {
if total_tokens <= context_limit {
break;
}
// Remove the message
indices_to_remove.insert(i);
total_tokens -= token_counts[i];
debug!(
"OldestFirst: Removing message at index {}. Tokens removed: {}",
i, token_counts[i]
);
// If it's a ToolRequest or ToolResponse, mark its pair for removal
if message.is_tool_call() || message.is_tool_response() {
message.get_tool_ids().iter().for_each(|id| {
tool_ids_to_remove.insert((i, id.to_string()));
});
}
}
// Now, find and remove paired ToolResponses or ToolRequests
for (i, message) in messages.iter().enumerate() {
let message_tool_ids = message.get_tool_ids();
// Find the other part of the pair - same tool_id but different message index
for (message_idx, tool_id) in &tool_ids_to_remove {
if message_idx != &i && message_tool_ids.contains(tool_id.as_str()) {
indices_to_remove.insert(i);
// No need to check other tool_ids for this message since it's already marked
break;
}
}
}
Ok(indices_to_remove)
}
}
/// Truncates the messages to fit within the model's context window.
/// Mutates the input messages and token counts in place.
/// Returns an error if it's impossible to truncate the messages within the context limit.
@@ -83,11 +12,14 @@ impl TruncationStrategy for OldestFirstTruncation {
/// - context_limit: The maximum allowed context length in tokens.
/// - strategy: The truncation strategy to use. Only option is OldestFirstTruncation.
pub fn truncate_messages(
messages: &mut Vec<Message>,
token_counts: &mut Vec<usize>,
messages: &[Message],
token_counts: &[usize],
context_limit: usize,
strategy: &dyn TruncationStrategy,
) -> Result<(), anyhow::Error> {
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
let mut messages = messages.to_owned();
let mut token_counts = token_counts.to_owned();
if messages.len() != token_counts.len() {
return Err(anyhow!(
"The vector for messages and token_counts must have same length"
@@ -114,12 +46,12 @@ pub fn truncate_messages(
}
if total_tokens <= context_limit {
return Ok(()); // No truncation needed
return Ok((messages, token_counts)); // No truncation needed
}
// Step 2: Determine indices to remove based on strategy
let indices_to_remove =
strategy.determine_indices_to_remove(messages, token_counts, context_limit)?;
strategy.determine_indices_to_remove(&messages, &token_counts, context_limit)?;
// Step 3: Remove the marked messages
// Vectorize the set and sort in reverse order to avoid shifting indices when removing
@@ -174,10 +106,77 @@ pub fn truncate_messages(
}
debug!("Truncation complete. Total tokens: {}", total_tokens);
Ok(())
Ok((messages, token_counts))
}
// truncate.rs
/// Trait representing a truncation strategy.
///
/// Implementations decide *which* messages to drop so that the conversation
/// fits within a token budget; the caller performs the actual removal.
pub trait TruncationStrategy {
    /// Determines the indices of messages to remove to fit within the context limit.
    ///
    /// - `messages`: The list of messages in the conversation.
    /// - `token_counts`: A parallel array containing the token count for each message.
    ///   Must be the same length as `messages` — TODO confirm callers enforce this.
    /// - `context_limit`: The maximum allowed context length in tokens.
    ///
    /// Returns the set of indices to remove (a `HashSet`, so order is unspecified).
    fn determine_indices_to_remove(
        &self,
        messages: &[Message],
        token_counts: &[usize],
        context_limit: usize,
    ) -> Result<HashSet<usize>>;
}
/// Strategy to truncate messages by removing the oldest first
pub struct OldestFirstTruncation;

impl TruncationStrategy for OldestFirstTruncation {
    /// Walks the conversation oldest-to-newest, marking messages for removal
    /// until the running token total fits within `context_limit`. A second
    /// pass then marks the partner of every removed tool request/response
    /// pair so no orphaned half of a pair survives truncation.
    fn determine_indices_to_remove(
        &self,
        messages: &[Message],
        token_counts: &[usize],
        context_limit: usize,
    ) -> Result<HashSet<usize>> {
        let mut indices_to_remove = HashSet::new();
        let mut tool_ids_to_remove = HashSet::new();
        let mut remaining_tokens: usize = token_counts.iter().sum();

        // Pass 1: drop the oldest messages until we are within the limit.
        for (i, (message, &tokens)) in messages.iter().zip(token_counts).enumerate() {
            if remaining_tokens <= context_limit {
                break;
            }
            indices_to_remove.insert(i);
            remaining_tokens -= tokens;
            debug!(
                "OldestFirst: Removing message at index {}. Tokens removed: {}",
                i, tokens
            );
            // Remember tool ids from removed tool calls/responses so the
            // other half of each pair can be located in pass 2.
            if message.is_tool_call() || message.is_tool_response() {
                for id in message.get_tool_ids() {
                    tool_ids_to_remove.insert((i, id.to_string()));
                }
            }
        }

        // Pass 2: mark the partner of every removed tool pair — same tool id
        // but a different message index.
        for (i, message) in messages.iter().enumerate() {
            let message_tool_ids = message.get_tool_ids();
            let is_orphaned_pair = tool_ids_to_remove
                .iter()
                .any(|(idx, id)| *idx != i && message_tool_ids.contains(id.as_str()));
            if is_orphaned_pair {
                indices_to_remove.insert(i);
            }
        }

        Ok(indices_to_remove)
    }
}
#[cfg(test)]
mod tests {
@@ -326,7 +325,7 @@ mod tests {
let (mut messages, mut token_counts) = create_messages_with_counts(2, 25, false);
let context_limit = 100; // Exactly matches total tokens
truncate_messages(
(messages, token_counts) = truncate_messages(
&mut messages,
&mut token_counts,
context_limit,
@@ -340,7 +339,7 @@ mod tests {
messages.push(user_text(5, 1).0);
token_counts.push(1);
truncate_messages(
(messages, token_counts) = truncate_messages(
&mut messages,
&mut token_counts,
context_limit,
@@ -380,7 +379,7 @@ mod tests {
let mut messages_clone = messages.clone();
let mut token_counts_clone = token_counts.clone();
truncate_messages(
(messages_clone, _) = truncate_messages(
&mut messages_clone,
&mut token_counts_clone,
context_limit,
@@ -425,7 +424,7 @@ mod tests {
let mut token_counts = vec![50, 10, 10, 20, 5];
let context_limit = 45; // Force truncation
truncate_messages(
(messages, token_counts) = truncate_messages(
&mut messages,
&mut token_counts,
context_limit,

View File

@@ -1,6 +1,6 @@
pub mod agents;
pub mod config;
pub mod memory_condense;
pub mod context_mgmt;
pub mod message;
pub mod model;
pub mod permission;
@@ -10,4 +10,3 @@ pub mod recipe;
pub mod session;
pub mod token_counter;
pub mod tracing;
pub mod truncate;

View File

@@ -1,126 +0,0 @@
use crate::message::Message;
use crate::providers::base::Provider;
use crate::token_counter::TokenCounter;
use anyhow::{anyhow, Result};
use std::sync::Arc;
use tracing::debug;
const SYSTEM_PROMPT: &str = "You are good at summarizing.";
/// Builds the single-message request asking the model to summarize `messages`.
/// The conversation is embedded in the prompt via its `Debug` representation.
fn create_summarize_request(messages: &[Message]) -> Vec<Message> {
    let prompt = format!(
        "Please use a few concise sentences to summarize this chat, while keeping the important information.\n\n```\n{:?}```",
        messages
    );
    vec![Message::user().with_text(prompt)]
}
/// Sends `messages` to the provider under the summarization system prompt and
/// returns the model's reply message (usage data is discarded).
async fn single_request(
    provider: &Arc<dyn Provider>,
    messages: &[Message],
) -> Result<Message, anyhow::Error> {
    let response = provider.complete(SYSTEM_PROMPT, messages, &[]).await?;
    Ok(response.0)
}
/// Repeatedly summarizes the oldest messages until the total token count fits
/// within `context_limit`.
///
/// Each round pops a batch of the oldest messages, asks the provider to
/// summarize them, and pushes back an (assistant-summary, user-follow-up)
/// pair in their place. The loop ends when the limit is met or when a round
/// fails to shrink the total (`diff <= 0`), in which case an error is
/// returned. On success, `messages` and `token_counts` are overwritten in
/// place with the condensed conversation.
async fn memory_condense(
    provider: Arc<dyn Provider>,
    token_counter: &TokenCounter,
    messages: &mut Vec<Message>,
    token_counts: &mut Vec<usize>,
    context_limit: usize,
) -> Result<(), anyhow::Error> {
    let system_prompt_tokens = token_counter.count_tokens(SYSTEM_PROMPT);
    // Since the process will run multiple times, we should avoid expensive operations like random access.
    // Stacks are stored newest-first so the oldest message sits on top (pop side).
    let mut message_stack = messages.iter().cloned().rev().collect::<Vec<_>>();
    let mut count_stack = token_counts.iter().copied().rev().collect::<Vec<_>>();
    // Tracks the number of remaining tokens in the stack
    let mut total_tokens = count_stack.iter().sum::<usize>();
    // Tracks the change of total_tokens in the previous loop.
    // If diff <= 0, then the model cannot summarize any further. We set it to 1 before the process
    // to ensure that the process starts.
    let mut diff = 1;
    while total_tokens > context_limit && diff > 0 {
        let mut batch = Vec::new();
        let mut current_tokens = 0;
        // Extracts the beginning messages (which appears in the front of the message stack) to
        // summarize.
        while total_tokens > current_tokens + context_limit
            && current_tokens + system_prompt_tokens <= context_limit
        {
            batch.push(message_stack.pop().unwrap());
            current_tokens += count_stack.pop().unwrap();
        }
        // It could happen that the extracted messages are always the previous summary when the
        // context limit is very small. We should force it to consume more messages.
        if !batch.is_empty()
            && !message_stack.is_empty()
            && current_tokens + system_prompt_tokens <= context_limit
        {
            batch.push(message_stack.pop().unwrap());
            current_tokens += count_stack.pop().unwrap();
        }
        // diff = (tokens added by the summary pair) - (tokens removed in the batch).
        diff = -(current_tokens as isize);
        let request = create_summarize_request(&batch);
        let response_text = single_request(&provider, &request).await?.as_concat_text();
        // Ensure the conversation starts with a User message
        let curr_messages = vec![
            // should be in reversed order
            Message::assistant().with_text(&response_text),
            Message::user().with_text("Hello! How are we progressing?"),
        ];
        let curr_tokens = token_counter.count_chat_tokens("", &curr_messages, &[]);
        diff += curr_tokens as isize;
        count_stack.push(curr_tokens);
        message_stack.extend(curr_messages);
        // Update the counter
        total_tokens = total_tokens.checked_add_signed(diff).unwrap();
    }
    if total_tokens <= context_limit {
        // Un-reverse the stacks back into chronological order before writing out.
        *messages = message_stack.into_iter().rev().collect();
        *token_counts = count_stack.into_iter().rev().collect();
        Ok(())
    } else {
        Err(anyhow!("Cannot compress the messages anymore"))
    }
}
/// TODO: currently not used; will be added behind a feature flag under context mgmt.
///
/// Condenses the conversation via [`memory_condense`] and verifies the result
/// fits within `context_limit`. Mutates `messages` and `token_counts` in place.
pub async fn condense_messages(
    provider: Arc<dyn Provider>,
    token_counter: &TokenCounter,
    messages: &mut Vec<Message>,
    token_counts: &mut Vec<usize>,
    context_limit: usize,
) -> Result<(), anyhow::Error> {
    let tokens_before: usize = token_counts.iter().sum();
    debug!("Total tokens before memory condensation: {}", tokens_before);

    // The compressor should determine whether we need to compress the messages or not. This
    // function just checks if the limit is satisfied.
    memory_condense(
        provider,
        token_counter,
        messages,
        token_counts,
        context_limit,
    )
    .await?;

    let tokens_after: usize = token_counts.iter().sum();
    debug!("Total tokens after memory condensation: {}", tokens_after);
    // Compressor should handle this case.
    assert!(
        tokens_after <= context_limit,
        "Illegal compression result from the compressor: the number of tokens is greater than the limit."
    );
    debug!(
        "Memory condensation complete. Total tokens: {}",
        tokens_after
    );
    Ok(())
}

View File

@@ -14,11 +14,12 @@ use mcp_core::prompt::{PromptMessage, PromptMessageContent, PromptMessageRole};
use mcp_core::resource::ResourceContents;
use mcp_core::role::Role;
use mcp_core::tool::ToolCall;
use serde::{Deserialize, Serialize};
use serde_json::Value;
mod tool_result_serde;
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolRequest {
pub id: String,
@@ -42,7 +43,7 @@ impl ToolRequest {
}
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResponse {
pub id: String,
@@ -50,7 +51,7 @@ pub struct ToolResponse {
pub tool_result: ToolResult<Vec<Content>>,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfirmationRequest {
pub id: String,
@@ -59,18 +60,18 @@ pub struct ToolConfirmationRequest {
pub prompt: Option<String>,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ThinkingContent {
pub thinking: String,
pub signature: String,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RedactedThinkingContent {
pub data: String,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FrontendToolRequest {
pub id: String,
@@ -78,7 +79,12 @@ pub struct FrontendToolRequest {
pub tool_call: ToolResult<ToolCall>,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ContextLengthExceeded {
pub msg: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
/// Content passed inside a message, which can be both simple content and tool content
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MessageContent {
@@ -90,6 +96,7 @@ pub enum MessageContent {
FrontendToolRequest(FrontendToolRequest),
Thinking(ThinkingContent),
RedactedThinking(RedactedThinkingContent),
ContextLengthExceeded(ContextLengthExceeded),
}
impl MessageContent {
@@ -153,6 +160,11 @@ impl MessageContent {
tool_call,
})
}
pub fn context_length_exceeded<S: Into<String>>(msg: S) -> Self {
MessageContent::ContextLengthExceeded(ContextLengthExceeded { msg: msg.into() })
}
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
if let MessageContent::ToolRequest(ref tool_request) = self {
Some(tool_request)
@@ -261,7 +273,7 @@ impl From<PromptMessage> for Message {
}
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
pub struct Message {
@@ -358,6 +370,11 @@ impl Message {
self.with_content(MessageContent::redacted_thinking(data))
}
/// Add context length exceeded content to the message
pub fn with_context_length_exceeded<S: Into<String>>(self, msg: S) -> Self {
self.with_content(MessageContent::context_length_exceeded(msg))
}
/// Get the concatenated text content of the message, separated by newlines
pub fn as_concat_text(&self) -> String {
self.content

View File

@@ -60,6 +60,9 @@ pub fn format_messages(messages: &[Message]) -> Vec<Value> {
MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => {
// Skip tool confirmation requests
}
MessageContent::ContextLengthExceeded(_) => {
// Skip
}
MessageContent::Thinking(thinking) => {
content.push(json!({
"type": "thinking",

View File

@@ -42,6 +42,9 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C
// Redacted thinking blocks are not supported in Bedrock - skip
bedrock::ContentBlock::Text("".to_string())
}
MessageContent::ContextLengthExceeded(_) => {
bail!("ContextLengthExceeded should not get passed to the provider")
}
MessageContent::ToolRequest(tool_req) => {
let tool_use_id = tool_req.id.to_string();
let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() {

View File

@@ -110,6 +110,9 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
}
}
}
MessageContent::ContextLengthExceeded(_) => {
continue;
}
MessageContent::ToolResponse(response) => {
match &response.tool_result {
Ok(contents) => {

View File

@@ -52,6 +52,9 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
// Redacted thinking blocks are not directly used in OpenAI format
continue;
}
MessageContent::ContextLengthExceeded(_) => {
continue;
}
MessageContent::ToolRequest(request) => match &request.tool_call {
Ok(tool_call) => {
let sanitized_name = sanitize_function_name(&tool_call.name);

View File

@@ -73,12 +73,18 @@ export interface ExtensionCallResult<T> {
error?: string;
}
export interface ContextLengthExceededContent {
type: 'contextLengthExceeded';
msg: string;
}
export type MessageContent =
| TextContent
| ImageContent
| ToolRequestMessageContent
| ToolResponseMessageContent
| ToolConfirmationRequestMessageContent;
| ToolConfirmationRequestMessageContent
| ContextLengthExceededContent;
export interface Message {
id?: string;
@@ -175,8 +181,18 @@ function generateId(): string {
// Helper functions to extract content from messages

/**
 * Concatenates the human-readable text of a message, one piece per line.
 * Includes plain text content and the `msg` of contextLengthExceeded
 * content; all other content kinds are skipped.
 */
export function getTextContent(message: Message): string {
  return message.content
    .filter(
      (content): content is TextContent | ContextLengthExceededContent =>
        content.type === 'text' || content.type === 'contextLengthExceeded'
    )
    .map((content) => {
      if (content.type === 'text') {
        return content.text;
      } else if (content.type === 'contextLengthExceeded') {
        return content.msg;
      }
      return '';
    })
    .join('\n');
}