fix: truncation agent token calculations (#915)

Kalvin C authored 2025-01-30 07:50:19 -08:00, committed by GitHub
parent e8ced5a385
commit ff71de422b
3 changed files with 33 additions and 11 deletions


@@ -43,6 +43,8 @@ impl TruncateAgent {
        &self,
        messages: &mut Vec<Message>,
        estimate_factor: f32,
+       system_prompt: &str,
+       tools: &mut Vec<Tool>,
    ) -> anyhow::Result<()> {
        // Model's actual context limit
        let context_limit = self
@@ -57,20 +59,37 @@ impl TruncateAgent {
        // Our token count is an estimate since model providers often don't provide the tokenizer (e.g. Claude)
        let context_limit = (context_limit as f32 * estimate_factor) as usize;

-       // Calculate current token count
+       // Take into account the system prompt and our tools input, and subtract that from the
+       // remaining context limit
+       let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
+       let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());
+
+       // Check if system prompt + tools exceed our context limit
+       let remaining_tokens = context_limit
+           .checked_sub(system_prompt_token_count)
+           .and_then(|remaining| remaining.checked_sub(tools_token_count))
+           .ok_or_else(|| {
+               anyhow::anyhow!("System prompt and tools exceed estimated context limit")
+           })?;
+       let context_limit = remaining_tokens;
+
+       // Calculate the current token count of each message; use count_chat_tokens to ensure we
+       // capture the full content of the message, including ToolRequests and ToolResponses
        let mut token_counts: Vec<usize> = messages
            .iter()
-           .map(|msg| self.token_counter.count_tokens(&msg.as_concat_text()))
+           .map(|msg| {
+               self.token_counter
+                   .count_chat_tokens("", std::slice::from_ref(msg), &[])
+           })
            .collect();

-       let _ = truncate_messages(
+       truncate_messages(
            messages,
            &mut token_counts,
            context_limit,
            &OldestFirstTruncation,
-       );
-
-       Ok(())
+       )
    }
}
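
The budgeting added above subtracts the system prompt and tool tokens from the estimated context limit before any message is counted. Below is a minimal sketch of that arithmetic on its own, assuming only the anyhow crate; remaining_context and the token numbers are hypothetical and not part of this commit.

// Underflow-safe budget: checked_sub returns None instead of wrapping, so an
// oversized system prompt + tool set becomes an error rather than a huge budget.
fn remaining_context(
    context_limit: usize,
    system_prompt_tokens: usize,
    tools_tokens: usize,
) -> anyhow::Result<usize> {
    context_limit
        .checked_sub(system_prompt_tokens)
        .and_then(|rest| rest.checked_sub(tools_tokens))
        .ok_or_else(|| anyhow::anyhow!("System prompt and tools exceed estimated context limit"))
}

fn main() -> anyhow::Result<()> {
    // Hypothetical numbers: a 128k-token model scaled by an estimate factor of 0.7.
    let estimated_limit = (128_000 as f32 * 0.7) as usize;
    let budget = remaining_context(estimated_limit, 1_200, 3_400)?;
    println!("tokens left for conversation history: {budget}");
    Ok(())
}

Chaining checked_sub through and_then keeps the whole subtraction inside Option, which is why a single ok_or_else at the end is enough to report either overflow case.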
@@ -229,7 +248,7 @@ impl Agent for TruncateAgent {
                // Create an error message & terminate the stream.
                // The previous message would have been a user message (e.g. before any tool calls, this is just after the input message;
                // at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user).
-               yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate.");
+               yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
                break;
            }
@@ -243,7 +262,11 @@ impl Agent for TruncateAgent {
                // release the lock before truncation to prevent deadlock
                drop(capabilities);

-               self.truncate_messages(&mut messages, estimate_factor).await?;
+               if let Err(err) = self.truncate_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
+                   yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
+                   break;
+               }

                // Re-acquire the lock
                capabilities = self.capabilities.lock().await;
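
The comments in this hunk call out releasing the capabilities lock before truncation and re-acquiring it afterwards. A minimal sketch of that drop-before-await pattern with tokio's async Mutex; Capabilities and long_running_truncation are hypothetical stand-ins rather than the crate's real types.

use std::sync::Arc;
use tokio::sync::Mutex;

// Hypothetical stand-in for the agent's shared state.
struct Capabilities {
    calls: u32,
}

// Placeholder for work that should not run while the lock is held.
async fn long_running_truncation() {
    tokio::task::yield_now().await;
}

#[tokio::main]
async fn main() {
    let shared = Arc::new(Mutex::new(Capabilities { calls: 0 }));

    let mut guard = shared.lock().await;
    guard.calls += 1;

    // Release the lock before the long await so no other holder of `shared`
    // (or this task itself) can deadlock waiting on it.
    drop(guard);
    long_running_truncation().await;

    // Re-acquire the lock once truncation is done.
    let guard = shared.lock().await;
    println!("calls so far: {}", guard.calls);
}

Even when no deadlock is possible, holding a guard across a long await blocks every other task that needs the same lock, so keeping the critical section tight is good hygiene.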


@@ -112,7 +112,7 @@ impl TokenCounter {
        encoding.len()
    }

-   fn count_tokens_for_tools(&self, tools: &[Tool]) -> usize {
+   pub fn count_tokens_for_tools(&self, tools: &[Tool]) -> usize {
        // Token counts for different function components
        let func_init = 7; // Tokens for function initialization
        let prop_init = 3; // Tokens for properties initialization
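
The only change in the token counter is visibility: count_tokens_for_tools becomes pub so the truncate agent can budget for tool definitions. Its body is not part of this diff; the sketch below only illustrates the fixed-overhead idea suggested by func_init and prop_init, with a hypothetical ToolSpec type and a crude whitespace count standing in for the real tokenizer.

// Illustrative only: not the crate's implementation.
struct ToolSpec {
    name: String,
    description: String,
    property_names: Vec<String>,
}

fn rough_token_count(text: &str) -> usize {
    text.split_whitespace().count()
}

fn estimate_tools_tokens(tools: &[ToolSpec]) -> usize {
    let func_init = 7; // flat overhead per tool/function definition
    let prop_init = 3; // flat overhead for its properties block
    tools
        .iter()
        .map(|tool| {
            func_init
                + rough_token_count(&tool.name)
                + rough_token_count(&tool.description)
                + prop_init
                + tool.property_names.len()
        })
        .sum()
}

fn main() {
    let tools = vec![ToolSpec {
        name: "read_file".into(),
        description: "Read a file from disk".into(),
        property_names: vec!["path".into()],
    }];
    println!("estimated tool tokens: {}", estimate_tools_tokens(&tools));
}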


@@ -63,7 +63,6 @@ impl TruncationStrategy for OldestFirstTruncation {
            for (message_idx, tool_id) in &tool_ids_to_remove {
                if message_idx != &i && message_tool_ids.contains(tool_id.as_str()) {
                    indices_to_remove.insert(i);
-                   total_tokens -= token_counts[i];
                    // No need to check other tool_ids for this message since it's already marked
                    break;
                }
@@ -86,7 +85,7 @@ pub fn truncate_messages(
    token_counts: &mut Vec<usize>,
    context_limit: usize,
    strategy: &dyn TruncationStrategy,
-) -> Result<()> {
+) -> Result<(), anyhow::Error> {
    if messages.len() != token_counts.len() {
        return Err(anyhow!(
            "The vector for messages and token_counts must have same length"