mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-19 05:24:28 +01:00
fix: truncation agent token calculations (#915)
This commit is contained in:
@@ -43,6 +43,8 @@ impl TruncateAgent {
|
||||
&self,
|
||||
messages: &mut Vec<Message>,
|
||||
estimate_factor: f32,
|
||||
system_prompt: &str,
|
||||
tools: &mut Vec<Tool>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Model's actual context limit
|
||||
let context_limit = self
|
||||
@@ -57,20 +59,37 @@ impl TruncateAgent {
|
||||
// Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
|
||||
let context_limit = (context_limit as f32 * estimate_factor) as usize;
|
||||
|
||||
// Calculate current token count
|
||||
// Take into account the system prompt, and our tools input and subtract that from the
|
||||
// remaining context limit
|
||||
let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
|
||||
let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());
|
||||
|
||||
// Check if system prompt + tools exceed our context limit
|
||||
let remaining_tokens = context_limit
|
||||
.checked_sub(system_prompt_token_count)
|
||||
.and_then(|remaining| remaining.checked_sub(tools_token_count))
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("System prompt and tools exceed estimated context limit")
|
||||
})?;
|
||||
|
||||
let context_limit = remaining_tokens;
|
||||
|
||||
// Calculate current token count of each message, use count_chat_tokens to ensure we
|
||||
// capture the full content of the message, include ToolRequests and ToolResponses
|
||||
let mut token_counts: Vec<usize> = messages
|
||||
.iter()
|
||||
.map(|msg| self.token_counter.count_tokens(&msg.as_concat_text()))
|
||||
.map(|msg| {
|
||||
self.token_counter
|
||||
.count_chat_tokens("", std::slice::from_ref(msg), &[])
|
||||
})
|
||||
.collect();
|
||||
|
||||
let _ = truncate_messages(
|
||||
truncate_messages(
|
||||
messages,
|
||||
&mut token_counts,
|
||||
context_limit,
|
||||
&OldestFirstTruncation,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,7 +248,7 @@ impl Agent for TruncateAgent {
|
||||
// Create an error message & terminate the stream
|
||||
// the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
|
||||
// at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
|
||||
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate.");
|
||||
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -243,7 +262,11 @@ impl Agent for TruncateAgent {
|
||||
// release the lock before truncation to prevent deadlock
|
||||
drop(capabilities);
|
||||
|
||||
self.truncate_messages(&mut messages, estimate_factor).await?;
|
||||
if let Err(err) = self.truncate_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
|
||||
yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Re-acquire the lock
|
||||
capabilities = self.capabilities.lock().await;
|
||||
|
||||
@@ -112,7 +112,7 @@ impl TokenCounter {
|
||||
encoding.len()
|
||||
}
|
||||
|
||||
fn count_tokens_for_tools(&self, tools: &[Tool]) -> usize {
|
||||
pub fn count_tokens_for_tools(&self, tools: &[Tool]) -> usize {
|
||||
// Token counts for different function components
|
||||
let func_init = 7; // Tokens for function initialization
|
||||
let prop_init = 3; // Tokens for properties initialization
|
||||
|
||||
@@ -63,7 +63,6 @@ impl TruncationStrategy for OldestFirstTruncation {
|
||||
for (message_idx, tool_id) in &tool_ids_to_remove {
|
||||
if message_idx != &i && message_tool_ids.contains(tool_id.as_str()) {
|
||||
indices_to_remove.insert(i);
|
||||
total_tokens -= token_counts[i];
|
||||
// No need to check other tool_ids for this message since it's already marked
|
||||
break;
|
||||
}
|
||||
@@ -86,7 +85,7 @@ pub fn truncate_messages(
|
||||
token_counts: &mut Vec<usize>,
|
||||
context_limit: usize,
|
||||
strategy: &dyn TruncationStrategy,
|
||||
) -> Result<()> {
|
||||
) -> Result<(), anyhow::Error> {
|
||||
if messages.len() != token_counts.len() {
|
||||
return Err(anyhow!(
|
||||
"The vector for messages and token_counts must have same length"
|
||||
|
||||
Reference in New Issue
Block a user