mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-23 07:24:24 +01:00
feat: present options to user when context length is exceeded (#2207)
Co-authored-by: Yingjie He <yingjiehe@squareup.com>
This commit is contained in:
@@ -6,6 +6,7 @@ mod prompt;
|
||||
mod thinking;
|
||||
|
||||
pub use builder::{build_session, SessionBuilderConfig};
|
||||
use console::Color;
|
||||
use goose::permission::permission_confirmation::PrincipalType;
|
||||
use goose::permission::Permission;
|
||||
use goose::permission::PermissionConfirmation;
|
||||
@@ -592,7 +593,7 @@ impl Session {
|
||||
.reply(
|
||||
&self.messages,
|
||||
Some(SessionConfig {
|
||||
id: session_id,
|
||||
id: session_id.clone(),
|
||||
working_dir: std::env::current_dir()
|
||||
.expect("failed to get current session working directory"),
|
||||
}),
|
||||
@@ -622,6 +623,54 @@ impl Session {
|
||||
principal_type: PrincipalType::Tool,
|
||||
permission,
|
||||
},).await;
|
||||
} else if let Some(MessageContent::ContextLengthExceeded(_)) = message.content.first() {
|
||||
output::hide_thinking();
|
||||
|
||||
let prompt = "The model's context length is maxed out. You will need to reduce the # msgs. Do you want to?".to_string();
|
||||
let selected = cliclack::select(prompt)
|
||||
.item("clear", "Clear Session", "Removes all messages from Goose's memory")
|
||||
.item("truncate", "Truncate Messages", "Removes old messages till context is within limits")
|
||||
.item("summarize", "Summarize Session", "Summarize the session to reduce context length")
|
||||
.interact()?;
|
||||
|
||||
match selected {
|
||||
"clear" => {
|
||||
self.messages.clear();
|
||||
let msg = format!("Session cleared.\n{}", "-".repeat(50));
|
||||
output::render_text(&msg, Some(Color::Yellow), true);
|
||||
break; // exit the loop to hand back control to the user
|
||||
}
|
||||
"truncate" => {
|
||||
// Truncate messages to fit within context length
|
||||
let (truncated_messages, _) = self.agent.truncate_context(&self.messages).await?;
|
||||
let msg = format!("Context maxed out\n{}\nGoose tried its best to truncate messages for you.", "-".repeat(50));
|
||||
output::render_text("", Some(Color::Yellow), true);
|
||||
output::render_text(&msg, Some(Color::Yellow), true);
|
||||
self.messages = truncated_messages;
|
||||
}
|
||||
"summarize" => {
|
||||
// Summarize messages to fit within context length
|
||||
let (summarized_messages, _) = self.agent.summarize_context(&self.messages).await?;
|
||||
let msg = format!("Context maxed out\n{}\nGoose summarized messages for you.", "-".repeat(50));
|
||||
output::render_text(&msg, Some(Color::Yellow), true);
|
||||
self.messages = summarized_messages;
|
||||
}
|
||||
_ => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
// Restart the stream after handling ContextLengthExceeded
|
||||
stream = self
|
||||
.agent
|
||||
.reply(
|
||||
&self.messages,
|
||||
Some(SessionConfig {
|
||||
id: session_id.clone(),
|
||||
working_dir: std::env::current_dir()
|
||||
.expect("failed to get current session working directory"),
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
// otherwise we have a model/tool to render
|
||||
else {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use bat::WrappingMode;
|
||||
use console::style;
|
||||
use console::{style, Color};
|
||||
use goose::config::Config;
|
||||
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
|
||||
use mcp_core::prompt::PromptArgument;
|
||||
@@ -143,6 +143,19 @@ pub fn render_message(message: &Message, debug: bool) {
|
||||
println!();
|
||||
}
|
||||
|
||||
pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
|
||||
let mut styled_text = style(text);
|
||||
if dim {
|
||||
styled_text = styled_text.dim();
|
||||
}
|
||||
if let Some(color) = color {
|
||||
styled_text = styled_text.fg(color);
|
||||
} else {
|
||||
styled_text = styled_text.green();
|
||||
}
|
||||
println!("\n{}\n", styled_text);
|
||||
}
|
||||
|
||||
pub fn render_enter_plan_mode() {
|
||||
println!(
|
||||
"\n{} {}\n",
|
||||
|
||||
64
crates/goose-server/src/routes/context.rs
Normal file
64
crates/goose-server/src/routes/context.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use super::utils::verify_secret_key;
|
||||
use crate::state::AppState;
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::{HeaderMap, StatusCode},
|
||||
routing::post,
|
||||
Json, Router,
|
||||
};
|
||||
use goose::message::Message;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Direct message serialization for context mgmt request
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ContextManageRequest {
|
||||
messages: Vec<Message>,
|
||||
manage_action: String,
|
||||
}
|
||||
|
||||
// Direct message serialization for context mgmt request
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ContextManageResponse {
|
||||
messages: Vec<Message>,
|
||||
token_counts: Vec<usize>,
|
||||
}
|
||||
|
||||
async fn manage_context(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<ContextManageRequest>,
|
||||
) -> Result<Json<ContextManageResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Get a lock on the shared agent
|
||||
let agent = state.agent.read().await;
|
||||
let agent = agent.as_ref().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
|
||||
|
||||
let mut processed_messages: Vec<Message> = vec![];
|
||||
let mut token_counts: Vec<usize> = vec![];
|
||||
if request.manage_action == "trunction" {
|
||||
(processed_messages, token_counts) = agent
|
||||
.truncate_context(&request.messages)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
} else if request.manage_action == "summarize" {
|
||||
(processed_messages, token_counts) = agent
|
||||
.summarize_context(&request.messages)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
}
|
||||
|
||||
Ok(Json(ContextManageResponse {
|
||||
messages: processed_messages,
|
||||
token_counts,
|
||||
}))
|
||||
}
|
||||
|
||||
// Configure routes for this module
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
Router::new()
|
||||
.route("/context/manage", post(manage_context))
|
||||
.with_state(state)
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
pub mod agent;
|
||||
pub mod config_management;
|
||||
pub mod configs;
|
||||
pub mod context;
|
||||
pub mod extension;
|
||||
pub mod health;
|
||||
pub mod recipe;
|
||||
@@ -16,6 +17,7 @@ pub fn configure(state: crate::state::AppState) -> Router {
|
||||
.merge(health::routes())
|
||||
.merge(reply::routes(state.clone()))
|
||||
.merge(agent::routes(state.clone()))
|
||||
.merge(context::routes(state.clone()))
|
||||
.merge(extension::routes(state.clone()))
|
||||
.merge(configs::routes(state.clone()))
|
||||
.merge(config_management::routes(state.clone()))
|
||||
|
||||
@@ -12,12 +12,10 @@ use crate::permission::PermissionConfirmation;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::recipe::{Author, Recipe};
|
||||
use crate::token_counter::TokenCounter;
|
||||
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tracing::{debug, error, instrument, warn};
|
||||
use tracing::{debug, error, instrument};
|
||||
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult, ToolInfo};
|
||||
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
|
||||
@@ -35,9 +33,6 @@ use mcp_core::{
|
||||
use super::platform_tools;
|
||||
use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
|
||||
|
||||
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
||||
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
|
||||
|
||||
/// The main goose Agent
|
||||
pub struct Agent {
|
||||
pub(super) provider: Arc<dyn Provider>,
|
||||
@@ -45,7 +40,7 @@ pub struct Agent {
|
||||
pub(super) frontend_tools: HashMap<String, FrontendTool>,
|
||||
pub(super) frontend_instructions: Option<String>,
|
||||
pub(super) prompt_manager: PromptManager,
|
||||
pub(super) token_counter: TokenCounter,
|
||||
// Channels for tool results and confirmations
|
||||
pub(super) confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
|
||||
pub(super) confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
|
||||
pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
@@ -54,7 +49,6 @@ pub struct Agent {
|
||||
|
||||
impl Agent {
|
||||
pub fn new(provider: Arc<dyn Provider>) -> Self {
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
// Create channels with buffer size 32 (adjust if needed)
|
||||
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
||||
let (tool_tx, tool_rx) = mpsc::channel(32);
|
||||
@@ -65,7 +59,6 @@ impl Agent {
|
||||
frontend_tools: HashMap::new(),
|
||||
frontend_instructions: None,
|
||||
prompt_manager: PromptManager::new(),
|
||||
token_counter,
|
||||
confirmation_tx: confirm_tx,
|
||||
confirmation_rx: Mutex::new(confirm_rx),
|
||||
tool_result_tx: tool_tx,
|
||||
@@ -161,55 +154,6 @@ impl Agent {
|
||||
(request_id, result)
|
||||
}
|
||||
|
||||
/// Truncates the messages to fit within the model's context window
|
||||
/// Ensures the last message is a user message and removes tool call-response pairs
|
||||
async fn truncate_messages(
|
||||
&self,
|
||||
messages: &mut Vec<Message>,
|
||||
estimate_factor: f32,
|
||||
system_prompt: &str,
|
||||
tools: &mut Vec<Tool>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Model's actual context limit
|
||||
let context_limit = self.provider.get_model_config().context_limit();
|
||||
|
||||
// Our conservative estimate of the **target** context limit
|
||||
// Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
|
||||
let context_limit = (context_limit as f32 * estimate_factor) as usize;
|
||||
|
||||
// Take into account the system prompt, and our tools input and subtract that from the
|
||||
// remaining context limit
|
||||
let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
|
||||
let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());
|
||||
|
||||
// Check if system prompt + tools exceed our context limit
|
||||
let remaining_tokens = context_limit
|
||||
.checked_sub(system_prompt_token_count)
|
||||
.and_then(|remaining| remaining.checked_sub(tools_token_count))
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("System prompt and tools exceed estimated context limit")
|
||||
})?;
|
||||
|
||||
let context_limit = remaining_tokens;
|
||||
|
||||
// Calculate current token count of each message, use count_chat_tokens to ensure we
|
||||
// capture the full content of the message, include ToolRequests and ToolResponses
|
||||
let mut token_counts: Vec<usize> = messages
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
self.token_counter
|
||||
.count_chat_tokens("", std::slice::from_ref(msg), &[])
|
||||
})
|
||||
.collect();
|
||||
|
||||
truncate_messages(
|
||||
messages,
|
||||
&mut token_counts,
|
||||
context_limit,
|
||||
&OldestFirstTruncation,
|
||||
)
|
||||
}
|
||||
|
||||
pub(super) async fn manage_extensions(
|
||||
&self,
|
||||
action: String,
|
||||
@@ -360,7 +304,6 @@ impl Agent {
|
||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
|
||||
let mut messages = messages.to_vec();
|
||||
let reply_span = tracing::Span::current();
|
||||
let mut truncation_attempt: usize = 0;
|
||||
|
||||
// Load settings from config
|
||||
let config = Config::global();
|
||||
@@ -398,9 +341,6 @@ impl Agent {
|
||||
Self::update_session_metrics(session_config, &usage, messages.len()).await?;
|
||||
}
|
||||
|
||||
// Reset truncation attempt
|
||||
truncation_attempt = 0;
|
||||
|
||||
// categorize the type of requests we need to handle
|
||||
let (frontend_requests,
|
||||
remaining_requests,
|
||||
@@ -529,24 +469,13 @@ impl Agent {
|
||||
messages.push(final_message_tool_resp);
|
||||
},
|
||||
Err(ProviderError::ContextLengthExceeded(_)) => {
|
||||
if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
|
||||
// Create an error message & terminate the stream
|
||||
// the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
|
||||
// at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
|
||||
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
|
||||
break;
|
||||
}
|
||||
truncation_attempt += 1;
|
||||
warn!("Context length exceeded. Truncation Attempt: {}/{}.", truncation_attempt, MAX_TRUNCATION_ATTEMPTS);
|
||||
// Decay the estimate factor as we make more truncation attempts
|
||||
// Estimate factor decays like this over time: 0.9, 0.81, 0.729, ...
|
||||
let estimate_factor: f32 = ESTIMATE_FACTOR_DECAY.powi(truncation_attempt as i32);
|
||||
if let Err(err) = self.truncate_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
|
||||
yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
|
||||
break;
|
||||
}
|
||||
// Retry the loop after truncation
|
||||
continue;
|
||||
// At this point, the last message should be a user message
|
||||
// because call to provider led to context length exceeded error
|
||||
// Immediately yield a special message and break
|
||||
yield Message::assistant().with_context_length_exceeded(
|
||||
"The context length of the model has been exceeded. Please start a new session and try again.",
|
||||
);
|
||||
break;
|
||||
},
|
||||
Err(e) => {
|
||||
// Create an error message & terminate the stream
|
||||
|
||||
63
crates/goose/src/agents/context.rs
Normal file
63
crates/goose/src/agents/context.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use anyhow::Ok;
|
||||
|
||||
use crate::message::Message;
|
||||
use crate::token_counter::TokenCounter;
|
||||
|
||||
use crate::context_mgmt::summarize::summarize_messages;
|
||||
use crate::context_mgmt::truncate::{truncate_messages, OldestFirstTruncation};
|
||||
use crate::context_mgmt::{estimate_target_context_limit, get_messages_token_counts};
|
||||
|
||||
use super::super::agents::Agent;
|
||||
|
||||
impl Agent {
|
||||
/// Public API to truncate oldest messages so that the conversation's token count is within the allowed context limit.
|
||||
pub async fn truncate_context(
|
||||
&self,
|
||||
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
|
||||
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
|
||||
let provider = self.provider.clone();
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
let target_context_limit = estimate_target_context_limit(provider);
|
||||
let token_counts = get_messages_token_counts(&token_counter, messages);
|
||||
|
||||
let (mut new_messages, mut new_token_counts) = truncate_messages(
|
||||
messages,
|
||||
&token_counts,
|
||||
target_context_limit,
|
||||
&OldestFirstTruncation,
|
||||
)?;
|
||||
|
||||
// Add an assistant message to the truncated messages
|
||||
// to ensure the assistant's response is included in the context.
|
||||
let assistant_message = Message::assistant().with_text("I had run into a context length exceeded error so I truncated some of the oldest messages in our conversation.");
|
||||
new_messages.push(assistant_message.clone());
|
||||
new_token_counts.push(token_counter.count_chat_tokens("", &[assistant_message], &[]));
|
||||
|
||||
Ok((new_messages, new_token_counts))
|
||||
}
|
||||
|
||||
/// Public API to summarize the conversation so that its token count is within the allowed context limit.
|
||||
pub async fn summarize_context(
|
||||
&self,
|
||||
messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded
|
||||
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
|
||||
let provider = self.provider.clone();
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
let target_context_limit = estimate_target_context_limit(provider.clone());
|
||||
|
||||
let (mut new_messages, mut new_token_counts) =
|
||||
summarize_messages(provider, messages, &token_counter, target_context_limit).await?;
|
||||
|
||||
// If the summarized messages only contains one message, it means no tool request and response message in the summarized messages,
|
||||
// Add an assistant message to the summarized messages to ensure the assistant's response is included in the context.
|
||||
if new_messages.len() == 1 {
|
||||
let assistant_message = Message::assistant().with_text(
|
||||
"I had run into a context length exceeded error so I summarized our conversation.",
|
||||
);
|
||||
new_messages.push(assistant_message.clone());
|
||||
new_token_counts.push(token_counter.count_chat_tokens("", &[assistant_message], &[]));
|
||||
}
|
||||
|
||||
Ok((new_messages, new_token_counts))
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
mod agent;
|
||||
mod context;
|
||||
pub mod extension;
|
||||
pub mod extension_manager;
|
||||
pub mod platform_tools;
|
||||
|
||||
57
crates/goose/src/context_mgmt/common.rs
Normal file
57
crates/goose/src/context_mgmt/common.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use mcp_core::Tool;
|
||||
|
||||
use crate::{message::Message, providers::base::Provider, token_counter::TokenCounter};
|
||||
|
||||
const ESTIMATE_FACTOR: f32 = 0.7;
|
||||
const SYSTEM_PROMPT_TOKEN_OVERHEAD: usize = 3_000;
|
||||
const TOOLS_TOKEN_OVERHEAD: usize = 5_000;
|
||||
|
||||
pub fn estimate_target_context_limit(provider: Arc<dyn Provider>) -> usize {
|
||||
let model_context_limit = provider.get_model_config().context_limit();
|
||||
|
||||
// Our conservative estimate of the **target** context limit
|
||||
// Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
|
||||
let target_limit = (model_context_limit as f32 * ESTIMATE_FACTOR) as usize;
|
||||
|
||||
// subtract out overhead for system prompt and tools
|
||||
target_limit - (SYSTEM_PROMPT_TOKEN_OVERHEAD + TOOLS_TOKEN_OVERHEAD)
|
||||
}
|
||||
|
||||
pub fn get_messages_token_counts(token_counter: &TokenCounter, messages: &[Message]) -> Vec<usize> {
|
||||
// Calculate current token count of each message, use count_chat_tokens to ensure we
|
||||
// capture the full content of the message, include ToolRequests and ToolResponses
|
||||
messages
|
||||
.iter()
|
||||
.map(|msg| token_counter.count_chat_tokens("", std::slice::from_ref(msg), &[]))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// These are not being used now but could be useful in the future
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct ChatTokenCounts {
|
||||
pub system: usize,
|
||||
pub tools: usize,
|
||||
pub messages: Vec<usize>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_token_counts(
|
||||
token_counter: &TokenCounter,
|
||||
messages: &mut [Message],
|
||||
system_prompt: &str,
|
||||
tools: &mut Vec<Tool>,
|
||||
) -> ChatTokenCounts {
|
||||
// Take into account the system prompt (includes goosehints), and our tools input
|
||||
let system_prompt_token_count = token_counter.count_tokens(system_prompt);
|
||||
let tools_token_count = token_counter.count_tokens_for_tools(tools.as_slice());
|
||||
let messages_token_count = get_messages_token_counts(token_counter, messages);
|
||||
|
||||
ChatTokenCounts {
|
||||
system: system_prompt_token_count,
|
||||
tools: tools_token_count,
|
||||
messages: messages_token_count,
|
||||
}
|
||||
}
|
||||
5
crates/goose/src/context_mgmt/mod.rs
Normal file
5
crates/goose/src/context_mgmt/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod common;
|
||||
pub mod summarize;
|
||||
pub mod truncate;
|
||||
|
||||
pub use common::*;
|
||||
422
crates/goose/src/context_mgmt/summarize.rs
Normal file
422
crates/goose/src/context_mgmt/summarize.rs
Normal file
@@ -0,0 +1,422 @@
|
||||
use super::common::get_messages_token_counts;
|
||||
use crate::message::{Message, MessageContent};
|
||||
use crate::providers::base::Provider;
|
||||
use crate::token_counter::TokenCounter;
|
||||
use anyhow::Result;
|
||||
use mcp_core::Role;
|
||||
use std::sync::Arc;
|
||||
|
||||
// Constants for the summarization prompt and a follow-up user message.
|
||||
const SUMMARY_PROMPT: &str = "You are good at summarizing conversations";
|
||||
|
||||
/// Summarize the combined messages from the accumulated summary and the current chunk.
|
||||
///
|
||||
/// This method builds the summarization request, sends it to the provider, and returns the summarized response.
|
||||
async fn summarize_combined_messages(
|
||||
provider: &Arc<dyn Provider>,
|
||||
accumulated_summary: &[Message],
|
||||
current_chunk: &[Message],
|
||||
) -> Result<Vec<Message>, anyhow::Error> {
|
||||
// Combine the accumulated summary and current chunk into a single batch.
|
||||
let combined_messages: Vec<Message> = accumulated_summary
|
||||
.iter()
|
||||
.cloned()
|
||||
.chain(current_chunk.iter().cloned())
|
||||
.collect();
|
||||
|
||||
// Format the batch as a summarization request.
|
||||
let request_text = format!(
|
||||
"Please summarize the following conversation history, preserving the key points. This summarization will be used for the later conversations.\n\n```\n{:?}\n```",
|
||||
combined_messages
|
||||
);
|
||||
let summarization_request = vec![Message::user().with_text(&request_text)];
|
||||
|
||||
// Send the request to the provider and fetch the response.
|
||||
let mut response = provider
|
||||
.complete(SUMMARY_PROMPT, &summarization_request, &[])
|
||||
.await?
|
||||
.0;
|
||||
// Set role to user as it will be used in following conversation as user content.
|
||||
response.role = Role::User;
|
||||
|
||||
// Return the summary as the new accumulated summary.
|
||||
Ok(vec![response])
|
||||
}
|
||||
|
||||
/// Preprocesses the messages to handle edge cases involving tool responses.
|
||||
///
|
||||
/// This function separates messages into two groups:
|
||||
/// 1. Messages to be summarized (`preprocessed_messages`)
|
||||
/// 2. Messages to be temporarily removed (`removed_messages`), which include:
|
||||
/// - The last tool response message.
|
||||
/// - The corresponding tool request message that immediately precedes the last tool response message (if present).
|
||||
///
|
||||
/// The function only considers the last tool response message and its pair for removal.
|
||||
fn preprocess_messages(messages: &[Message]) -> (Vec<Message>, Vec<Message>) {
|
||||
let mut preprocessed_messages = messages.to_owned();
|
||||
let mut removed_messages = Vec::new();
|
||||
|
||||
if let Some((last_index, last_message)) = messages.iter().enumerate().rev().find(|(_, m)| {
|
||||
m.content
|
||||
.iter()
|
||||
.any(|c| matches!(c, MessageContent::ToolResponse(_)))
|
||||
}) {
|
||||
// Check for the corresponding tool request message
|
||||
if last_index > 0 {
|
||||
if let Some(previous_message) = messages.get(last_index - 1) {
|
||||
if previous_message
|
||||
.content
|
||||
.iter()
|
||||
.any(|c| matches!(c, MessageContent::ToolRequest(_)))
|
||||
{
|
||||
// Add the tool request message to removed_messages
|
||||
removed_messages.push(previous_message.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add the last tool response message to removed_messages
|
||||
removed_messages.push(last_message.clone());
|
||||
|
||||
// Calculate the correct start index for removal
|
||||
let start_index = last_index + 1 - removed_messages.len();
|
||||
|
||||
// Remove the tool response and its paired tool request from preprocessed_messages
|
||||
preprocessed_messages.drain(start_index..=last_index);
|
||||
}
|
||||
|
||||
(preprocessed_messages, removed_messages)
|
||||
}
|
||||
|
||||
/// Reinserts removed messages into the summarized output.
|
||||
///
|
||||
/// This function appends messages that were temporarily removed during preprocessing
|
||||
/// back into the summarized message list. This ensures that important context,
|
||||
/// such as tool responses, is not lost.
|
||||
fn reintegrate_removed_messages(
|
||||
summarized_messages: &[Message],
|
||||
removed_messages: &[Message],
|
||||
) -> Vec<Message> {
|
||||
let mut final_messages = summarized_messages.to_owned();
|
||||
final_messages.extend_from_slice(removed_messages);
|
||||
final_messages
|
||||
}
|
||||
|
||||
// Summarization steps:
|
||||
// 1. Break down large text into smaller chunks (roughly 30% of the model’s context window).
|
||||
// 2. For each chunk:
|
||||
// a. Combine it with the previous summary (or leave blank for the first iteration).
|
||||
// b. Summarize the combined text, focusing on extracting only the information we need.
|
||||
// 3. Generate a final summary using a tailored prompt.
|
||||
pub async fn summarize_messages(
|
||||
provider: Arc<dyn Provider>,
|
||||
messages: &[Message],
|
||||
token_counter: &TokenCounter,
|
||||
context_limit: usize,
|
||||
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
|
||||
let chunk_size = context_limit / 3; // 33% of the context window.
|
||||
let summary_prompt_tokens = token_counter.count_tokens(SUMMARY_PROMPT);
|
||||
let mut accumulated_summary = Vec::new();
|
||||
|
||||
// Preprocess messages to handle tool response edge case.
|
||||
let (preprocessed_messages, removed_messages) = preprocess_messages(messages);
|
||||
|
||||
// Get token counts for each message.
|
||||
let token_counts = get_messages_token_counts(token_counter, &preprocessed_messages);
|
||||
|
||||
// Tokenize and break messages into chunks.
|
||||
let mut current_chunk: Vec<Message> = Vec::new();
|
||||
let mut current_chunk_tokens = 0;
|
||||
|
||||
for (message, message_tokens) in preprocessed_messages.iter().zip(token_counts.iter()) {
|
||||
if current_chunk_tokens + message_tokens > chunk_size - summary_prompt_tokens {
|
||||
// Summarize the current chunk with the accumulated summary.
|
||||
accumulated_summary =
|
||||
summarize_combined_messages(&provider, &accumulated_summary, ¤t_chunk)
|
||||
.await?;
|
||||
|
||||
// Reset for the next chunk.
|
||||
current_chunk.clear();
|
||||
current_chunk_tokens = 0;
|
||||
}
|
||||
|
||||
// Add message to the current chunk.
|
||||
current_chunk.push(message.clone());
|
||||
current_chunk_tokens += message_tokens;
|
||||
}
|
||||
|
||||
// Summarize the final chunk if it exists.
|
||||
if !current_chunk.is_empty() {
|
||||
accumulated_summary =
|
||||
summarize_combined_messages(&provider, &accumulated_summary, ¤t_chunk).await?;
|
||||
}
|
||||
|
||||
// Add back removed messages.
|
||||
let final_summary = reintegrate_removed_messages(&accumulated_summary, &removed_messages);
|
||||
|
||||
Ok((
|
||||
final_summary.clone(),
|
||||
get_messages_token_counts(token_counter, &final_summary),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::message::{Message, MessageContent};
|
||||
use crate::model::{ModelConfig, GPT_4O_TOKENIZER};
|
||||
use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
|
||||
use crate::providers::errors::ProviderError;
|
||||
use chrono::Utc;
|
||||
use mcp_core::{tool::Tool, Role};
|
||||
use mcp_core::{Content, TextContent, ToolCall};
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MockProvider {
|
||||
model_config: ModelConfig,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Provider for MockProvider {
|
||||
fn metadata() -> ProviderMetadata {
|
||||
ProviderMetadata::empty()
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model_config.clone()
|
||||
}
|
||||
|
||||
async fn complete(
|
||||
&self,
|
||||
_system: &str,
|
||||
_messages: &[Message],
|
||||
_tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
Ok((
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
text: "Summarized content".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
},
|
||||
ProviderUsage::new("mock".to_string(), Usage::default()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn create_mock_provider() -> Arc<dyn Provider> {
|
||||
let mock_model_config =
|
||||
ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
|
||||
Arc::new(MockProvider {
|
||||
model_config: mock_model_config,
|
||||
})
|
||||
}
|
||||
|
||||
fn create_test_messages() -> Vec<Message> {
|
||||
vec![
|
||||
set_up_text_message("Message 1", Role::User),
|
||||
set_up_text_message("Message 2", Role::Assistant),
|
||||
set_up_text_message("Message 3", Role::User),
|
||||
]
|
||||
}
|
||||
|
||||
fn set_up_text_message(text: &str, role: Role) -> Message {
|
||||
Message {
|
||||
role,
|
||||
created: 0,
|
||||
content: vec![MessageContent::text(text.to_string())],
|
||||
}
|
||||
}
|
||||
|
||||
fn set_up_tool_request_message(id: &str, tool_call: ToolCall) -> Message {
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
created: 0,
|
||||
content: vec![MessageContent::tool_request(id.to_string(), Ok(tool_call))],
|
||||
}
|
||||
}
|
||||
|
||||
fn set_up_tool_response_message(id: &str, tool_response: Vec<Content>) -> Message {
|
||||
Message {
|
||||
role: Role::User,
|
||||
created: 0,
|
||||
content: vec![MessageContent::tool_response(
|
||||
id.to_string(),
|
||||
Ok(tool_response),
|
||||
)],
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_summarize_messages_single_chunk() {
|
||||
let provider = create_mock_provider();
|
||||
let token_counter = TokenCounter::new(GPT_4O_TOKENIZER);
|
||||
let context_limit = 100; // Set a high enough limit to avoid chunking.
|
||||
let messages = create_test_messages();
|
||||
|
||||
let result = summarize_messages(
|
||||
Arc::clone(&provider),
|
||||
&messages,
|
||||
&token_counter,
|
||||
context_limit,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok(), "The function should return Ok.");
|
||||
let (summarized_messages, token_counts) = result.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
summarized_messages.len(),
|
||||
1,
|
||||
"The summary should contain one message."
|
||||
);
|
||||
assert_eq!(
|
||||
summarized_messages[0].role,
|
||||
Role::User,
|
||||
"The summarized message should be from the user."
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
token_counts.len(),
|
||||
1,
|
||||
"Token counts should match the number of summarized messages."
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_summarize_messages_multiple_chunks() {
|
||||
let provider = create_mock_provider();
|
||||
let token_counter = TokenCounter::new(GPT_4O_TOKENIZER);
|
||||
let context_limit = 30;
|
||||
let messages = create_test_messages();
|
||||
|
||||
let result = summarize_messages(
|
||||
Arc::clone(&provider),
|
||||
&messages,
|
||||
&token_counter,
|
||||
context_limit,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok(), "The function should return Ok.");
|
||||
let (summarized_messages, token_counts) = result.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
summarized_messages.len(),
|
||||
1,
|
||||
"There should be one final summarized message."
|
||||
);
|
||||
assert_eq!(
|
||||
summarized_messages[0].role,
|
||||
Role::User,
|
||||
"The summarized message should be from the user."
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
token_counts.len(),
|
||||
1,
|
||||
"Token counts should match the number of summarized messages."
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_summarize_messages_empty_input() {
|
||||
let provider = create_mock_provider();
|
||||
let token_counter = TokenCounter::new(GPT_4O_TOKENIZER);
|
||||
let context_limit = 100;
|
||||
let messages: Vec<Message> = Vec::new();
|
||||
|
||||
let result = summarize_messages(
|
||||
Arc::clone(&provider),
|
||||
&messages,
|
||||
&token_counter,
|
||||
context_limit,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok(), "The function should return Ok.");
|
||||
let (summarized_messages, token_counts) = result.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
summarized_messages.len(),
|
||||
0,
|
||||
"The summary should be empty for an empty input."
|
||||
);
|
||||
assert!(
|
||||
token_counts.is_empty(),
|
||||
"Token counts should be empty for an empty input."
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_preprocess_messages_without_tool_response() {
|
||||
let messages = create_test_messages();
|
||||
let (preprocessed_messages, removed_messages) = preprocess_messages(&messages);
|
||||
|
||||
assert_eq!(
|
||||
preprocessed_messages.len(),
|
||||
3,
|
||||
"Only the user message should remain after preprocessing."
|
||||
);
|
||||
assert_eq!(
|
||||
removed_messages.len(),
|
||||
0,
|
||||
"The tool request and tool response messages should be removed."
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_preprocess_messages_with_tool_response() {
|
||||
let arguments = json!({
|
||||
"param1": "value1"
|
||||
});
|
||||
let messages = vec![
|
||||
set_up_text_message("Message 1", Role::User),
|
||||
set_up_tool_request_message("id", ToolCall::new("tool_name", json!(arguments))),
|
||||
set_up_tool_response_message("id", vec![Content::text("tool done")]),
|
||||
];
|
||||
|
||||
let (preprocessed_messages, removed_messages) = preprocess_messages(&messages);
|
||||
|
||||
assert_eq!(
|
||||
preprocessed_messages.len(),
|
||||
1,
|
||||
"Only the user message should remain after preprocessing."
|
||||
);
|
||||
assert_eq!(
|
||||
removed_messages.len(),
|
||||
2,
|
||||
"The tool request and tool response messages should be removed."
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reintegrate_removed_messages() {
|
||||
let summarized_messages = vec![Message {
|
||||
role: Role::Assistant,
|
||||
created: Utc::now().timestamp(),
|
||||
content: vec![MessageContent::Text(TextContent {
|
||||
text: "Summary".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
}];
|
||||
let arguments = json!({
|
||||
"param1": "value1"
|
||||
});
|
||||
let removed_messages = vec![
|
||||
set_up_tool_request_message("id", ToolCall::new("tool_name", json!(arguments))),
|
||||
set_up_tool_response_message("id", vec![Content::text("tool done")]),
|
||||
];
|
||||
|
||||
let final_messages = reintegrate_removed_messages(&summarized_messages, &removed_messages);
|
||||
|
||||
assert_eq!(
|
||||
final_messages.len(),
|
||||
3,
|
||||
"The final message list should include the summary and removed messages."
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -4,77 +4,6 @@ use mcp_core::Role;
|
||||
use std::collections::HashSet;
|
||||
use tracing::debug;
|
||||
|
||||
/// Trait representing a truncation strategy
|
||||
pub trait TruncationStrategy {
|
||||
/// Determines the indices of messages to remove to fit within the context limit.
|
||||
///
|
||||
/// - `messages`: The list of messages in the conversation.
|
||||
/// - `token_counts`: A parallel array containing the token count for each message.
|
||||
/// - `context_limit`: The maximum allowed context length in tokens.
|
||||
///
|
||||
/// Returns a vector of indices to remove.
|
||||
fn determine_indices_to_remove(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
token_counts: &[usize],
|
||||
context_limit: usize,
|
||||
) -> Result<HashSet<usize>>;
|
||||
}
|
||||
|
||||
/// Strategy to truncate messages by removing the oldest first
|
||||
pub struct OldestFirstTruncation;
|
||||
/// Strategy to truncate messages explicitly
|
||||
pub struct ExplicitTruncation;
|
||||
|
||||
impl TruncationStrategy for OldestFirstTruncation {
|
||||
fn determine_indices_to_remove(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
token_counts: &[usize],
|
||||
context_limit: usize,
|
||||
) -> Result<HashSet<usize>> {
|
||||
let mut indices_to_remove = HashSet::new();
|
||||
let mut total_tokens: usize = token_counts.iter().sum();
|
||||
let mut tool_ids_to_remove = HashSet::new();
|
||||
|
||||
for (i, message) in messages.iter().enumerate() {
|
||||
if total_tokens <= context_limit {
|
||||
break;
|
||||
}
|
||||
|
||||
// Remove the message
|
||||
indices_to_remove.insert(i);
|
||||
total_tokens -= token_counts[i];
|
||||
debug!(
|
||||
"OldestFirst: Removing message at index {}. Tokens removed: {}",
|
||||
i, token_counts[i]
|
||||
);
|
||||
|
||||
// If it's a ToolRequest or ToolResponse, mark its pair for removal
|
||||
if message.is_tool_call() || message.is_tool_response() {
|
||||
message.get_tool_ids().iter().for_each(|id| {
|
||||
tool_ids_to_remove.insert((i, id.to_string()));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Now, find and remove paired ToolResponses or ToolRequests
|
||||
for (i, message) in messages.iter().enumerate() {
|
||||
let message_tool_ids = message.get_tool_ids();
|
||||
// Find the other part of the pair - same tool_id but different message index
|
||||
for (message_idx, tool_id) in &tool_ids_to_remove {
|
||||
if message_idx != &i && message_tool_ids.contains(tool_id.as_str()) {
|
||||
indices_to_remove.insert(i);
|
||||
// No need to check other tool_ids for this message since it's already marked
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(indices_to_remove)
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates the messages to fit within the model's context window.
|
||||
/// Mutates the input messages and token counts in place.
|
||||
/// Returns an error if it's impossible to truncate the messages within the context limit.
|
||||
@@ -83,11 +12,14 @@ impl TruncationStrategy for OldestFirstTruncation {
|
||||
/// - context_limit: The maximum allowed context length in tokens.
|
||||
/// - strategy: The truncation strategy to use. Only option is OldestFirstTruncation.
|
||||
pub fn truncate_messages(
|
||||
messages: &mut Vec<Message>,
|
||||
token_counts: &mut Vec<usize>,
|
||||
messages: &[Message],
|
||||
token_counts: &[usize],
|
||||
context_limit: usize,
|
||||
strategy: &dyn TruncationStrategy,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<(Vec<Message>, Vec<usize>), anyhow::Error> {
|
||||
let mut messages = messages.to_owned();
|
||||
let mut token_counts = token_counts.to_owned();
|
||||
|
||||
if messages.len() != token_counts.len() {
|
||||
return Err(anyhow!(
|
||||
"The vector for messages and token_counts must have same length"
|
||||
@@ -114,12 +46,12 @@ pub fn truncate_messages(
|
||||
}
|
||||
|
||||
if total_tokens <= context_limit {
|
||||
return Ok(()); // No truncation needed
|
||||
return Ok((messages, token_counts)); // No truncation needed
|
||||
}
|
||||
|
||||
// Step 2: Determine indices to remove based on strategy
|
||||
let indices_to_remove =
|
||||
strategy.determine_indices_to_remove(messages, token_counts, context_limit)?;
|
||||
strategy.determine_indices_to_remove(&messages, &token_counts, context_limit)?;
|
||||
|
||||
// Step 3: Remove the marked messages
|
||||
// Vectorize the set and sort in reverse order to avoid shifting indices when removing
|
||||
@@ -174,10 +106,77 @@ pub fn truncate_messages(
|
||||
}
|
||||
|
||||
debug!("Truncation complete. Total tokens: {}", total_tokens);
|
||||
Ok(())
|
||||
Ok((messages, token_counts))
|
||||
}
|
||||
|
||||
// truncate.rs
|
||||
/// Trait representing a truncation strategy
|
||||
pub trait TruncationStrategy {
|
||||
/// Determines the indices of messages to remove to fit within the context limit.
|
||||
///
|
||||
/// - `messages`: The list of messages in the conversation.
|
||||
/// - `token_counts`: A parallel array containing the token count for each message.
|
||||
/// - `context_limit`: The maximum allowed context length in tokens.
|
||||
///
|
||||
/// Returns a vector of indices to remove.
|
||||
fn determine_indices_to_remove(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
token_counts: &[usize],
|
||||
context_limit: usize,
|
||||
) -> Result<HashSet<usize>>;
|
||||
}
|
||||
|
||||
/// Strategy to truncate messages by removing the oldest first
|
||||
pub struct OldestFirstTruncation;
|
||||
|
||||
impl TruncationStrategy for OldestFirstTruncation {
|
||||
fn determine_indices_to_remove(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
token_counts: &[usize],
|
||||
context_limit: usize,
|
||||
) -> Result<HashSet<usize>> {
|
||||
let mut indices_to_remove = HashSet::new();
|
||||
let mut total_tokens: usize = token_counts.iter().sum();
|
||||
let mut tool_ids_to_remove = HashSet::new();
|
||||
|
||||
for (i, message) in messages.iter().enumerate() {
|
||||
if total_tokens <= context_limit {
|
||||
break;
|
||||
}
|
||||
|
||||
// Remove the message
|
||||
indices_to_remove.insert(i);
|
||||
total_tokens -= token_counts[i];
|
||||
debug!(
|
||||
"OldestFirst: Removing message at index {}. Tokens removed: {}",
|
||||
i, token_counts[i]
|
||||
);
|
||||
|
||||
// If it's a ToolRequest or ToolResponse, mark its pair for removal
|
||||
if message.is_tool_call() || message.is_tool_response() {
|
||||
message.get_tool_ids().iter().for_each(|id| {
|
||||
tool_ids_to_remove.insert((i, id.to_string()));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Now, find and remove paired ToolResponses or ToolRequests
|
||||
for (i, message) in messages.iter().enumerate() {
|
||||
let message_tool_ids = message.get_tool_ids();
|
||||
// Find the other part of the pair - same tool_id but different message index
|
||||
for (message_idx, tool_id) in &tool_ids_to_remove {
|
||||
if message_idx != &i && message_tool_ids.contains(tool_id.as_str()) {
|
||||
indices_to_remove.insert(i);
|
||||
// No need to check other tool_ids for this message since it's already marked
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(indices_to_remove)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@@ -326,7 +325,7 @@ mod tests {
|
||||
let (mut messages, mut token_counts) = create_messages_with_counts(2, 25, false);
|
||||
let context_limit = 100; // Exactly matches total tokens
|
||||
|
||||
truncate_messages(
|
||||
(messages, token_counts) = truncate_messages(
|
||||
&mut messages,
|
||||
&mut token_counts,
|
||||
context_limit,
|
||||
@@ -340,7 +339,7 @@ mod tests {
|
||||
messages.push(user_text(5, 1).0);
|
||||
token_counts.push(1);
|
||||
|
||||
truncate_messages(
|
||||
(messages, token_counts) = truncate_messages(
|
||||
&mut messages,
|
||||
&mut token_counts,
|
||||
context_limit,
|
||||
@@ -380,7 +379,7 @@ mod tests {
|
||||
let mut messages_clone = messages.clone();
|
||||
let mut token_counts_clone = token_counts.clone();
|
||||
|
||||
truncate_messages(
|
||||
(messages_clone, _) = truncate_messages(
|
||||
&mut messages_clone,
|
||||
&mut token_counts_clone,
|
||||
context_limit,
|
||||
@@ -425,7 +424,7 @@ mod tests {
|
||||
let mut token_counts = vec![50, 10, 10, 20, 5];
|
||||
let context_limit = 45; // Force truncation
|
||||
|
||||
truncate_messages(
|
||||
(messages, token_counts) = truncate_messages(
|
||||
&mut messages,
|
||||
&mut token_counts,
|
||||
context_limit,
|
||||
@@ -1,6 +1,6 @@
|
||||
pub mod agents;
|
||||
pub mod config;
|
||||
pub mod memory_condense;
|
||||
pub mod context_mgmt;
|
||||
pub mod message;
|
||||
pub mod model;
|
||||
pub mod permission;
|
||||
@@ -10,4 +10,3 @@ pub mod recipe;
|
||||
pub mod session;
|
||||
pub mod token_counter;
|
||||
pub mod tracing;
|
||||
pub mod truncate;
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
use crate::message::Message;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::token_counter::TokenCounter;
|
||||
use anyhow::{anyhow, Result};
|
||||
use std::sync::Arc;
|
||||
use tracing::debug;
|
||||
|
||||
const SYSTEM_PROMPT: &str = "You are good at summarizing.";
|
||||
|
||||
fn create_summarize_request(messages: &[Message]) -> Vec<Message> {
|
||||
vec![
|
||||
Message::user().with_text(format!("Please use a few concise sentences to summarize this chat, while keeping the important information.\n\n```\n{:?}```", messages)),
|
||||
]
|
||||
}
|
||||
async fn single_request(
|
||||
provider: &Arc<dyn Provider>,
|
||||
messages: &[Message],
|
||||
) -> Result<Message, anyhow::Error> {
|
||||
Ok(provider.complete(SYSTEM_PROMPT, messages, &[]).await?.0)
|
||||
}
|
||||
async fn memory_condense(
|
||||
provider: Arc<dyn Provider>,
|
||||
token_counter: &TokenCounter,
|
||||
messages: &mut Vec<Message>,
|
||||
token_counts: &mut Vec<usize>,
|
||||
context_limit: usize,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let system_prompt_tokens = token_counter.count_tokens(SYSTEM_PROMPT);
|
||||
|
||||
// Since the process will run multiple times, we should avoid expensive operations like random access.
|
||||
let mut message_stack = messages.iter().cloned().rev().collect::<Vec<_>>();
|
||||
let mut count_stack = token_counts.iter().copied().rev().collect::<Vec<_>>();
|
||||
|
||||
// Tracks the number of remaining tokens in the stack
|
||||
let mut total_tokens = count_stack.iter().sum::<usize>();
|
||||
|
||||
// Tracks the change of total_tokens in the previous loop.
|
||||
// If diff <= 0, then the model cannot summarize any further. We set it to 1 before the process
|
||||
// to ensure that the process starts.
|
||||
let mut diff = 1;
|
||||
|
||||
while total_tokens > context_limit && diff > 0 {
|
||||
let mut batch = Vec::new();
|
||||
let mut current_tokens = 0;
|
||||
|
||||
// Extracts the beginning messages (which appears in the front of the message stack) to
|
||||
// summarize.
|
||||
while total_tokens > current_tokens + context_limit
|
||||
&& current_tokens + system_prompt_tokens <= context_limit
|
||||
{
|
||||
batch.push(message_stack.pop().unwrap());
|
||||
current_tokens += count_stack.pop().unwrap();
|
||||
}
|
||||
|
||||
// It could happen that the extracted messages are always the previous summary when the
|
||||
// context limit is very small. We should force it to consume more messages.
|
||||
if !batch.is_empty()
|
||||
&& !message_stack.is_empty()
|
||||
&& current_tokens + system_prompt_tokens <= context_limit
|
||||
{
|
||||
batch.push(message_stack.pop().unwrap());
|
||||
current_tokens += count_stack.pop().unwrap();
|
||||
}
|
||||
|
||||
diff = -(current_tokens as isize);
|
||||
let request = create_summarize_request(&batch);
|
||||
let response_text = single_request(&provider, &request).await?.as_concat_text();
|
||||
|
||||
// Ensure the conversation starts with a User message
|
||||
let curr_messages = vec![
|
||||
// shoule be in reversed order
|
||||
Message::assistant().with_text(&response_text),
|
||||
Message::user().with_text("Hello! How are we progressing?"),
|
||||
];
|
||||
let curr_tokens = token_counter.count_chat_tokens("", &curr_messages, &[]);
|
||||
diff += curr_tokens as isize;
|
||||
count_stack.push(curr_tokens);
|
||||
message_stack.extend(curr_messages);
|
||||
|
||||
// Update the counter
|
||||
total_tokens = total_tokens.checked_add_signed(diff).unwrap();
|
||||
}
|
||||
|
||||
if total_tokens <= context_limit {
|
||||
*messages = message_stack.into_iter().rev().collect();
|
||||
*token_counts = count_stack.into_iter().rev().collect();
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("Cannot compress the messages anymore"))
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO: currently not used. we will add this is a feature flag under context mgmt
|
||||
pub async fn condense_messages(
|
||||
provider: Arc<dyn Provider>,
|
||||
token_counter: &TokenCounter,
|
||||
messages: &mut Vec<Message>,
|
||||
token_counts: &mut Vec<usize>,
|
||||
context_limit: usize,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let total_tokens: usize = token_counts.iter().sum();
|
||||
debug!("Total tokens before memory condensation: {}", total_tokens);
|
||||
|
||||
// The compressor should determine whether we need to compress the messages or not. This
|
||||
// function just checks if the limit is satisfied.
|
||||
memory_condense(
|
||||
provider,
|
||||
token_counter,
|
||||
messages,
|
||||
token_counts,
|
||||
context_limit,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let total_tokens: usize = token_counts.iter().sum();
|
||||
debug!("Total tokens after memory condensation: {}", total_tokens);
|
||||
|
||||
// Compressor should handle this case.
|
||||
assert!(total_tokens <= context_limit, "Illegal compression result from the compressor: the number of tokens is greater than the limit.");
|
||||
|
||||
debug!(
|
||||
"Memory condensation complete. Total tokens: {}",
|
||||
total_tokens
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
@@ -14,11 +14,12 @@ use mcp_core::prompt::{PromptMessage, PromptMessageContent, PromptMessageRole};
|
||||
use mcp_core::resource::ResourceContents;
|
||||
use mcp_core::role::Role;
|
||||
use mcp_core::tool::ToolCall;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
mod tool_result_serde;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolRequest {
|
||||
pub id: String,
|
||||
@@ -42,7 +43,7 @@ impl ToolRequest {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolResponse {
|
||||
pub id: String,
|
||||
@@ -50,7 +51,7 @@ pub struct ToolResponse {
|
||||
pub tool_result: ToolResult<Vec<Content>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolConfirmationRequest {
|
||||
pub id: String,
|
||||
@@ -59,18 +60,18 @@ pub struct ToolConfirmationRequest {
|
||||
pub prompt: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ThinkingContent {
|
||||
pub thinking: String,
|
||||
pub signature: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct RedactedThinkingContent {
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FrontendToolRequest {
|
||||
pub id: String,
|
||||
@@ -78,7 +79,12 @@ pub struct FrontendToolRequest {
|
||||
pub tool_call: ToolResult<ToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ContextLengthExceeded {
|
||||
pub msg: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
/// Content passed inside a message, which can be both simple content and tool content
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
pub enum MessageContent {
|
||||
@@ -90,6 +96,7 @@ pub enum MessageContent {
|
||||
FrontendToolRequest(FrontendToolRequest),
|
||||
Thinking(ThinkingContent),
|
||||
RedactedThinking(RedactedThinkingContent),
|
||||
ContextLengthExceeded(ContextLengthExceeded),
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
@@ -153,6 +160,11 @@ impl MessageContent {
|
||||
tool_call,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn context_length_exceeded<S: Into<String>>(msg: S) -> Self {
|
||||
MessageContent::ContextLengthExceeded(ContextLengthExceeded { msg: msg.into() })
|
||||
}
|
||||
|
||||
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
|
||||
if let MessageContent::ToolRequest(ref tool_request) = self {
|
||||
Some(tool_request)
|
||||
@@ -261,7 +273,7 @@ impl From<PromptMessage> for Message {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
/// A message to or from an LLM
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Message {
|
||||
@@ -358,6 +370,11 @@ impl Message {
|
||||
self.with_content(MessageContent::redacted_thinking(data))
|
||||
}
|
||||
|
||||
/// Add context length exceeded content to the message
|
||||
pub fn with_context_length_exceeded<S: Into<String>>(self, msg: S) -> Self {
|
||||
self.with_content(MessageContent::context_length_exceeded(msg))
|
||||
}
|
||||
|
||||
/// Get the concatenated text content of the message, separated by newlines
|
||||
pub fn as_concat_text(&self) -> String {
|
||||
self.content
|
||||
|
||||
@@ -60,6 +60,9 @@ pub fn format_messages(messages: &[Message]) -> Vec<Value> {
|
||||
MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => {
|
||||
// Skip tool confirmation requests
|
||||
}
|
||||
MessageContent::ContextLengthExceeded(_) => {
|
||||
// Skip
|
||||
}
|
||||
MessageContent::Thinking(thinking) => {
|
||||
content.push(json!({
|
||||
"type": "thinking",
|
||||
|
||||
@@ -42,6 +42,9 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C
|
||||
// Redacted thinking blocks are not supported in Bedrock - skip
|
||||
bedrock::ContentBlock::Text("".to_string())
|
||||
}
|
||||
MessageContent::ContextLengthExceeded(_) => {
|
||||
bail!("ContextLengthExceeded should not get passed to the provider")
|
||||
}
|
||||
MessageContent::ToolRequest(tool_req) => {
|
||||
let tool_use_id = tool_req.id.to_string();
|
||||
let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() {
|
||||
|
||||
@@ -110,6 +110,9 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
}
|
||||
}
|
||||
}
|
||||
MessageContent::ContextLengthExceeded(_) => {
|
||||
continue;
|
||||
}
|
||||
MessageContent::ToolResponse(response) => {
|
||||
match &response.tool_result {
|
||||
Ok(contents) => {
|
||||
|
||||
@@ -52,6 +52,9 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
// Redacted thinking blocks are not directly used in OpenAI format
|
||||
continue;
|
||||
}
|
||||
MessageContent::ContextLengthExceeded(_) => {
|
||||
continue;
|
||||
}
|
||||
MessageContent::ToolRequest(request) => match &request.tool_call {
|
||||
Ok(tool_call) => {
|
||||
let sanitized_name = sanitize_function_name(&tool_call.name);
|
||||
|
||||
@@ -73,12 +73,18 @@ export interface ExtensionCallResult<T> {
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface ContextLengthExceededContent {
|
||||
type: 'contextLengthExceeded';
|
||||
msg: string;
|
||||
}
|
||||
|
||||
export type MessageContent =
|
||||
| TextContent
|
||||
| ImageContent
|
||||
| ToolRequestMessageContent
|
||||
| ToolResponseMessageContent
|
||||
| ToolConfirmationRequestMessageContent;
|
||||
| ToolConfirmationRequestMessageContent
|
||||
| ContextLengthExceededContent;
|
||||
|
||||
export interface Message {
|
||||
id?: string;
|
||||
@@ -175,8 +181,18 @@ function generateId(): string {
|
||||
// Helper functions to extract content from messages
|
||||
export function getTextContent(message: Message): string {
|
||||
return message.content
|
||||
.filter((content): content is TextContent => content.type === 'text')
|
||||
.map((content) => content.text)
|
||||
.filter(
|
||||
(content): content is TextContent | ContextLengthExceededContent =>
|
||||
content.type === 'text' || content.type === 'contextLengthExceeded'
|
||||
)
|
||||
.map((content) => {
|
||||
if (content.type === 'text') {
|
||||
return content.text;
|
||||
} else if (content.type === 'contextLengthExceeded') {
|
||||
return content.msg;
|
||||
}
|
||||
return '';
|
||||
})
|
||||
.join('\n');
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user