From ff71de422b1d58cb29ef20f7a5e762eb3714a63b Mon Sep 17 00:00:00 2001 From: Kalvin C Date: Thu, 30 Jan 2025 07:50:19 -0800 Subject: [PATCH] fix: truncation agent token calculations (#915) --- crates/goose/src/agents/truncate.rs | 39 +++++++++++++++++++++++------ crates/goose/src/token_counter.rs | 2 +- crates/goose/src/truncate.rs | 3 +-- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index cef34c60..fa55b096 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -43,6 +43,8 @@ impl TruncateAgent { &self, messages: &mut Vec<Message>, estimate_factor: f32, + system_prompt: &str, + tools: &mut Vec<Tool>, ) -> anyhow::Result<()> { // Model's actual context limit let context_limit = self @@ -57,20 +59,37 @@ impl TruncateAgent { // Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude) let context_limit = (context_limit as f32 * estimate_factor) as usize; - // Calculate current token count + // Take into account the system prompt, and our tools input and subtract that from the + // remaining context limit + let system_prompt_token_count = self.token_counter.count_tokens(system_prompt); + let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice()); + + // Check if system prompt + tools exceed our context limit + let remaining_tokens = context_limit + .checked_sub(system_prompt_token_count) + .and_then(|remaining| remaining.checked_sub(tools_token_count)) + .ok_or_else(|| { + anyhow::anyhow!("System prompt and tools exceed estimated context limit") + })?; + + let context_limit = remaining_tokens; + + // Calculate current token count of each message, use count_chat_tokens to ensure we + // capture the full content of the message, include ToolRequests and ToolResponses let mut token_counts: Vec<usize> = messages .iter() - .map(|msg| self.token_counter.count_tokens(&msg.as_concat_text())) + 
.map(|msg| { + self.token_counter + .count_chat_tokens("", std::slice::from_ref(msg), &[]) + }) .collect(); - let _ = truncate_messages( + truncate_messages( messages, &mut token_counts, context_limit, &OldestFirstTruncation, - ); - - Ok(()) + ) } } @@ -229,7 +248,7 @@ impl Agent for TruncateAgent { // Create an error message & terminate the stream // the previous message would have been a user message (e.g. before any tool calls, this is just after the input message. // at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user) - yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate."); + yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again."); break; } @@ -243,7 +262,11 @@ impl Agent for TruncateAgent { // release the lock before truncation to prevent deadlock drop(capabilities); - self.truncate_messages(&mut messages, estimate_factor).await?; + if let Err(err) = self.truncate_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await { + yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. 
\n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err)); + break; + } + // Re-acquire the lock capabilities = self.capabilities.lock().await; diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index b719433f..1650df25 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -112,7 +112,7 @@ impl TokenCounter { encoding.len() } - fn count_tokens_for_tools(&self, tools: &[Tool]) -> usize { + pub fn count_tokens_for_tools(&self, tools: &[Tool]) -> usize { // Token counts for different function components let func_init = 7; // Tokens for function initialization let prop_init = 3; // Tokens for properties initialization diff --git a/crates/goose/src/truncate.rs b/crates/goose/src/truncate.rs index d8375689..47e01a12 100644 --- a/crates/goose/src/truncate.rs +++ b/crates/goose/src/truncate.rs @@ -63,7 +63,6 @@ impl TruncationStrategy for OldestFirstTruncation { for (message_idx, tool_id) in &tool_ids_to_remove { if message_idx != &i && message_tool_ids.contains(tool_id.as_str()) { indices_to_remove.insert(i); - total_tokens -= token_counts[i]; // No need to check other tool_ids for this message since it's already marked break; } @@ -86,7 +85,7 @@ pub fn truncate_messages( token_counts: &mut Vec<usize>, context_limit: usize, strategy: &dyn TruncationStrategy, -) -> Result<()> { +) -> Result<(), anyhow::Error> { if messages.len() != token_counts.len() { return Err(anyhow!( "The vector for messages and token_counts must have same length"