mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
feat(agent): memory condensation for longer context (#1457)
This commit is contained in:
@@ -13,10 +13,15 @@ impl AgentCommand {
|
|||||||
|
|
||||||
let versions = AgentFactory::available_versions();
|
let versions = AgentFactory::available_versions();
|
||||||
let default_version = AgentFactory::default_version();
|
let default_version = AgentFactory::default_version();
|
||||||
|
let configured_version = AgentFactory::configured_version();
|
||||||
|
|
||||||
for version in versions {
|
for version in versions {
|
||||||
if version == default_version {
|
if version == default_version && version == configured_version {
|
||||||
writeln!(output, "* {} (default)", version)?;
|
writeln!(output, "* {} (default)", version)?;
|
||||||
|
} else if version == default_version {
|
||||||
|
writeln!(output, " {} (default)", version)?;
|
||||||
|
} else if version == configured_version {
|
||||||
|
writeln!(output, "* {}", version)?;
|
||||||
} else {
|
} else {
|
||||||
writeln!(output, " {}", version)?;
|
writeln!(output, " {}", version)?;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,11 +31,7 @@ pub async fn build_session(
|
|||||||
goose::providers::create(&provider_name, model_config).expect("Failed to create provider");
|
goose::providers::create(&provider_name, model_config).expect("Failed to create provider");
|
||||||
|
|
||||||
// Create the agent
|
// Create the agent
|
||||||
let agent_version: Option<String> = config.get("GOOSE_AGENT").ok();
|
let mut agent = AgentFactory::create(&AgentFactory::configured_version(), provider)
|
||||||
let mut agent = match agent_version {
|
|
||||||
Some(version) => AgentFactory::create(&version, provider),
|
|
||||||
None => AgentFactory::create(AgentFactory::default_version(), provider),
|
|
||||||
}
|
|
||||||
.expect("Failed to create agent");
|
.expect("Failed to create agent");
|
||||||
|
|
||||||
// Setup extensions for the agent
|
// Setup extensions for the agent
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use std::collections::HashMap;
|
|||||||
use std::sync::{OnceLock, RwLock};
|
use std::sync::{OnceLock, RwLock};
|
||||||
|
|
||||||
pub use super::Agent;
|
pub use super::Agent;
|
||||||
|
use crate::config::Config;
|
||||||
use crate::providers::base::Provider;
|
use crate::providers::base::Provider;
|
||||||
|
|
||||||
type AgentConstructor = Box<dyn Fn(Box<dyn Provider>) -> Box<dyn Agent> + Send + Sync>;
|
type AgentConstructor = Box<dyn Fn(Box<dyn Provider>) -> Box<dyn Agent> + Send + Sync>;
|
||||||
@@ -46,6 +47,13 @@ impl AgentFactory {
|
|||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn configured_version() -> String {
|
||||||
|
let config = Config::global();
|
||||||
|
config
|
||||||
|
.get::<String>("GOOSE_AGENT")
|
||||||
|
.unwrap_or_else(|_| Self::default_version().to_string())
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the default version name
|
/// Get the default version name
|
||||||
pub fn default_version() -> &'static str {
|
pub fn default_version() -> &'static str {
|
||||||
"truncate"
|
"truncate"
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ pub mod extension;
|
|||||||
mod factory;
|
mod factory;
|
||||||
mod permission_judge;
|
mod permission_judge;
|
||||||
mod reference;
|
mod reference;
|
||||||
|
mod summarize;
|
||||||
mod truncate;
|
mod truncate;
|
||||||
|
|
||||||
pub use agent::Agent;
|
pub use agent::Agent;
|
||||||
|
|||||||
457
crates/goose/src/agents/summarize.rs
Normal file
457
crates/goose/src/agents/summarize.rs
Normal file
@@ -0,0 +1,457 @@
|
|||||||
|
/// A summarize agent that lets the model summarize the conversation when the history exceeds the
|
||||||
|
/// model's context limit. If the model fails to summarize, then it falls back to the legacy
|
||||||
|
/// truncation method. Still cannot read resources.
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::stream::BoxStream;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::{debug, error, instrument, warn};
|
||||||
|
|
||||||
|
use super::detect_read_only_tools;
|
||||||
|
use super::Agent;
|
||||||
|
use crate::agents::capabilities::Capabilities;
|
||||||
|
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::config::ExperimentManager;
|
||||||
|
use crate::memory_condense::condense_messages;
|
||||||
|
use crate::message::{Message, ToolRequest};
|
||||||
|
use crate::providers::base::Provider;
|
||||||
|
use crate::providers::base::ProviderUsage;
|
||||||
|
use crate::providers::errors::ProviderError;
|
||||||
|
use crate::register_agent;
|
||||||
|
use crate::token_counter::TokenCounter;
|
||||||
|
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use indoc::indoc;
|
||||||
|
use mcp_core::prompt::Prompt;
|
||||||
|
use mcp_core::protocol::GetPromptResult;
|
||||||
|
use mcp_core::{tool::Tool, Content};
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
|
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
||||||
|
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
|
||||||
|
|
||||||
|
/// Summarize implementation of an Agent
|
||||||
|
pub struct SummarizeAgent {
|
||||||
|
capabilities: Mutex<Capabilities>,
|
||||||
|
token_counter: TokenCounter,
|
||||||
|
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
|
||||||
|
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SummarizeAgent {
|
||||||
|
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||||
|
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||||
|
// Create channel with buffer size 32 (adjust if needed)
|
||||||
|
let (tx, rx) = mpsc::channel(32);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
capabilities: Mutex::new(Capabilities::new(provider)),
|
||||||
|
token_counter,
|
||||||
|
confirmation_tx: tx,
|
||||||
|
confirmation_rx: Mutex::new(rx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Truncates the messages to fit within the model's context window
|
||||||
|
/// Ensures the last message is a user message and removes tool call-response pairs
|
||||||
|
async fn summarize_messages(
|
||||||
|
&self,
|
||||||
|
messages: &mut Vec<Message>,
|
||||||
|
estimate_factor: f32,
|
||||||
|
system_prompt: &str,
|
||||||
|
tools: &mut Vec<Tool>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
// Model's actual context limit
|
||||||
|
let context_limit = self
|
||||||
|
.capabilities
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.provider()
|
||||||
|
.get_model_config()
|
||||||
|
.context_limit();
|
||||||
|
|
||||||
|
// Our conservative estimate of the **target** context limit
|
||||||
|
// Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
|
||||||
|
let context_limit = (context_limit as f32 * estimate_factor) as usize;
|
||||||
|
|
||||||
|
// Take into account the system prompt, and our tools input and subtract that from the
|
||||||
|
// remaining context limit
|
||||||
|
let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
|
||||||
|
let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());
|
||||||
|
|
||||||
|
// Check if system prompt + tools exceed our context limit
|
||||||
|
let remaining_tokens = context_limit
|
||||||
|
.checked_sub(system_prompt_token_count)
|
||||||
|
.and_then(|remaining| remaining.checked_sub(tools_token_count))
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("System prompt and tools exceed estimated context limit")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let context_limit = remaining_tokens;
|
||||||
|
|
||||||
|
// Calculate current token count of each message, use count_chat_tokens to ensure we
|
||||||
|
// capture the full content of the message, include ToolRequests and ToolResponses
|
||||||
|
let mut token_counts: Vec<usize> = messages
|
||||||
|
.iter()
|
||||||
|
.map(|msg| {
|
||||||
|
self.token_counter
|
||||||
|
.count_chat_tokens("", std::slice::from_ref(msg), &[])
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let capabilities_guard = self.capabilities.lock().await;
|
||||||
|
if condense_messages(
|
||||||
|
&capabilities_guard,
|
||||||
|
&self.token_counter,
|
||||||
|
messages,
|
||||||
|
&mut token_counts,
|
||||||
|
context_limit,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
// Fallback to the legacy truncator if the model fails to condense the messages.
|
||||||
|
truncate_messages(
|
||||||
|
messages,
|
||||||
|
&mut token_counts,
|
||||||
|
context_limit,
|
||||||
|
&OldestFirstTruncation,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Agent for SummarizeAgent {
|
||||||
|
async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
|
||||||
|
let mut capabilities = self.capabilities.lock().await;
|
||||||
|
capabilities.add_extension(extension).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn remove_extension(&mut self, name: &str) {
|
||||||
|
let mut capabilities = self.capabilities.lock().await;
|
||||||
|
capabilities
|
||||||
|
.remove_extension(name)
|
||||||
|
.await
|
||||||
|
.expect("Failed to remove extension");
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_extensions(&self) -> Vec<String> {
|
||||||
|
let capabilities = self.capabilities.lock().await;
|
||||||
|
capabilities
|
||||||
|
.list_extensions()
|
||||||
|
.await
|
||||||
|
.expect("Failed to list extensions")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn passthrough(&self, _extension: &str, _request: Value) -> ExtensionResult<Value> {
|
||||||
|
// TODO implement
|
||||||
|
Ok(Value::Null)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle a confirmation response for a tool request
|
||||||
|
async fn handle_confirmation(&self, request_id: String, confirmed: bool) {
|
||||||
|
if let Err(e) = self.confirmation_tx.send((request_id, confirmed)).await {
|
||||||
|
error!("Failed to send confirmation: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip(self, messages), fields(user_message))]
|
||||||
|
async fn reply(
|
||||||
|
&self,
|
||||||
|
messages: &[Message],
|
||||||
|
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
|
||||||
|
let mut messages = messages.to_vec();
|
||||||
|
let reply_span = tracing::Span::current();
|
||||||
|
let mut capabilities = self.capabilities.lock().await;
|
||||||
|
let mut tools = capabilities.get_prefixed_tools().await?;
|
||||||
|
let mut truncation_attempt: usize = 0;
|
||||||
|
|
||||||
|
// Load settings from config
|
||||||
|
let config = Config::global();
|
||||||
|
let goose_mode = config.get("GOOSE_MODE").unwrap_or("auto".to_string());
|
||||||
|
|
||||||
|
// we add in the 2 resource tools if any extensions support resources
|
||||||
|
// TODO: make sure there is no collision with another extension's tool name
|
||||||
|
let read_resource_tool = Tool::new(
|
||||||
|
"platform__read_resource".to_string(),
|
||||||
|
indoc! {r#"
|
||||||
|
Read a resource from an extension.
|
||||||
|
|
||||||
|
Resources allow extensions to share data that provide context to LLMs, such as
|
||||||
|
files, database schemas, or application-specific information. This tool searches for the
|
||||||
|
resource URI in the provided extension, and reads in the resource content. If no extension
|
||||||
|
is provided, the tool will search all extensions for the resource.
|
||||||
|
"#}.to_string(),
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"required": ["uri"],
|
||||||
|
"properties": {
|
||||||
|
"uri": {"type": "string", "description": "Resource URI"},
|
||||||
|
"extension_name": {"type": "string", "description": "Optional extension name"}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
let list_resources_tool = Tool::new(
|
||||||
|
"platform__list_resources".to_string(),
|
||||||
|
indoc! {r#"
|
||||||
|
List resources from an extension(s).
|
||||||
|
|
||||||
|
Resources allow extensions to share data that provide context to LLMs, such as
|
||||||
|
files, database schemas, or application-specific information. This tool lists resources
|
||||||
|
in the provided extension, and returns a list for the user to browse. If no extension
|
||||||
|
is provided, the tool will search all extensions for the resource.
|
||||||
|
"#}.to_string(),
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"extension_name": {"type": "string", "description": "Optional extension name"}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
if capabilities.supports_resources() {
|
||||||
|
tools.push(read_resource_tool);
|
||||||
|
tools.push(list_resources_tool);
|
||||||
|
}
|
||||||
|
|
||||||
|
let system_prompt = capabilities.get_system_prompt().await;
|
||||||
|
|
||||||
|
// Set the user_message field in the span instead of creating a new event
|
||||||
|
if let Some(content) = messages
|
||||||
|
.last()
|
||||||
|
.and_then(|msg| msg.content.first())
|
||||||
|
.and_then(|c| c.as_text())
|
||||||
|
{
|
||||||
|
debug!("user_message" = &content);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Box::pin(async_stream::try_stream! {
|
||||||
|
let _reply_guard = reply_span.enter();
|
||||||
|
loop {
|
||||||
|
match capabilities.provider().complete(
|
||||||
|
&system_prompt,
|
||||||
|
&messages,
|
||||||
|
&tools,
|
||||||
|
).await {
|
||||||
|
Ok((response, usage)) => {
|
||||||
|
capabilities.record_usage(usage).await;
|
||||||
|
|
||||||
|
// Reset truncation attempt
|
||||||
|
truncation_attempt = 0;
|
||||||
|
|
||||||
|
// Yield the assistant's response
|
||||||
|
yield response.clone();
|
||||||
|
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
|
||||||
|
// First collect any tool requests
|
||||||
|
let tool_requests: Vec<&ToolRequest> = response.content
|
||||||
|
.iter()
|
||||||
|
.filter_map(|content| content.as_tool_request())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if tool_requests.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process tool requests depending on goose_mode
|
||||||
|
let mut message_tool_response = Message::user();
|
||||||
|
// Clone goose_mode once before the match to avoid move issues
|
||||||
|
let mode = goose_mode.clone();
|
||||||
|
match mode.as_str() {
|
||||||
|
"approve" => {
|
||||||
|
let mut read_only_tools = Vec::new();
|
||||||
|
// Process each tool request sequentially with confirmation
|
||||||
|
if ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? {
|
||||||
|
read_only_tools = detect_read_only_tools(&capabilities, tool_requests.clone()).await;
|
||||||
|
}
|
||||||
|
for request in &tool_requests {
|
||||||
|
if let Ok(tool_call) = request.tool_call.clone() {
|
||||||
|
// Skip confirmation if the tool_call.name is in the read_only_tools list
|
||||||
|
if read_only_tools.contains(&tool_call.name) {
|
||||||
|
let output = capabilities.dispatch_tool_call(tool_call).await;
|
||||||
|
message_tool_response = message_tool_response.with_tool_response(
|
||||||
|
request.id.clone(),
|
||||||
|
output,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
let confirmation = Message::user().with_tool_confirmation_request(
|
||||||
|
request.id.clone(),
|
||||||
|
tool_call.name.clone(),
|
||||||
|
tool_call.arguments.clone(),
|
||||||
|
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
|
||||||
|
);
|
||||||
|
yield confirmation;
|
||||||
|
|
||||||
|
// Wait for confirmation response through the channel
|
||||||
|
let mut rx = self.confirmation_rx.lock().await;
|
||||||
|
// Loop the recv until we have a matched req_id due to potential duplicate messages.
|
||||||
|
while let Some((req_id, confirmed)) = rx.recv().await {
|
||||||
|
if req_id == request.id {
|
||||||
|
if confirmed {
|
||||||
|
// User approved - dispatch the tool call
|
||||||
|
let output = capabilities.dispatch_tool_call(tool_call).await;
|
||||||
|
message_tool_response = message_tool_response.with_tool_response(
|
||||||
|
request.id.clone(),
|
||||||
|
output,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// User declined - add declined response
|
||||||
|
message_tool_response = message_tool_response.with_tool_response(
|
||||||
|
request.id.clone(),
|
||||||
|
Ok(vec![Content::text("User declined to run this tool.")]),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
break; // Exit the loop once the matching `req_id` is found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"chat" => {
|
||||||
|
// Skip all tool calls in chat mode
|
||||||
|
for request in &tool_requests {
|
||||||
|
message_tool_response = message_tool_response.with_tool_response(
|
||||||
|
request.id.clone(),
|
||||||
|
Ok(vec![Content::text(
|
||||||
|
"The following tool call was skipped in Goose chat mode. \
|
||||||
|
In chat mode, you cannot run tool calls, instead, you can \
|
||||||
|
only provide a detailed plan to the user. Provide an \
|
||||||
|
explanation of the proposed tool call as if it were a plan. \
|
||||||
|
Only if the user asks, provide a short explanation to the \
|
||||||
|
user that they could consider running the tool above on \
|
||||||
|
their own or with a different goose mode."
|
||||||
|
)]),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
if mode != "auto" {
|
||||||
|
warn!("Unknown GOOSE_MODE: {mode:?}. Defaulting to 'auto' mode.");
|
||||||
|
}
|
||||||
|
// Process tool requests in parallel
|
||||||
|
let mut tool_futures = Vec::new();
|
||||||
|
for request in &tool_requests {
|
||||||
|
if let Ok(tool_call) = request.tool_call.clone() {
|
||||||
|
tool_futures.push(async {
|
||||||
|
let output = capabilities.dispatch_tool_call(tool_call).await;
|
||||||
|
(request.id.clone(), output)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Wait for all tool calls to complete
|
||||||
|
let results = futures::future::join_all(tool_futures).await;
|
||||||
|
for (request_id, output) in results {
|
||||||
|
message_tool_response = message_tool_response.with_tool_response(
|
||||||
|
request_id,
|
||||||
|
output,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
yield message_tool_response.clone();
|
||||||
|
|
||||||
|
messages.push(response);
|
||||||
|
messages.push(message_tool_response);
|
||||||
|
},
|
||||||
|
Err(ProviderError::ContextLengthExceeded(_)) => {
|
||||||
|
if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
|
||||||
|
// Create an error message & terminate the stream
|
||||||
|
// the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
|
||||||
|
// at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
|
||||||
|
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
truncation_attempt += 1;
|
||||||
|
warn!("Context length exceeded. Truncation Attempt: {}/{}.", truncation_attempt, MAX_TRUNCATION_ATTEMPTS);
|
||||||
|
|
||||||
|
// Decay the estimate factor as we make more truncation attempts
|
||||||
|
// Estimate factor decays like this over time: 0.9, 0.81, 0.729, ...
|
||||||
|
let estimate_factor: f32 = ESTIMATE_FACTOR_DECAY.powi(truncation_attempt as i32);
|
||||||
|
|
||||||
|
// release the lock before truncation to prevent deadlock
|
||||||
|
drop(capabilities);
|
||||||
|
|
||||||
|
if let Err(err) = self.summarize_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
|
||||||
|
yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Re-acquire the lock
|
||||||
|
capabilities = self.capabilities.lock().await;
|
||||||
|
|
||||||
|
// Retry the loop after truncation
|
||||||
|
continue;
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
// Create an error message & terminate the stream
|
||||||
|
error!("Error: {}", e);
|
||||||
|
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Yield control back to the scheduler to prevent blocking
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn usage(&self) -> Vec<ProviderUsage> {
|
||||||
|
let capabilities = self.capabilities.lock().await;
|
||||||
|
capabilities.get_usage().await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn extend_system_prompt(&mut self, extension: String) {
|
||||||
|
let mut capabilities = self.capabilities.lock().await;
|
||||||
|
capabilities.add_system_prompt_extension(extension);
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn override_system_prompt(&mut self, template: String) {
|
||||||
|
let mut capabilities = self.capabilities.lock().await;
|
||||||
|
capabilities.set_system_prompt_override(template);
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
|
||||||
|
let capabilities = self.capabilities.lock().await;
|
||||||
|
capabilities
|
||||||
|
.list_prompts()
|
||||||
|
.await
|
||||||
|
.expect("Failed to list prompts")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
|
||||||
|
let capabilities = self.capabilities.lock().await;
|
||||||
|
|
||||||
|
// First find which extension has this prompt
|
||||||
|
let prompts = capabilities
|
||||||
|
.list_prompts()
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
|
||||||
|
|
||||||
|
if let Some(extension) = prompts
|
||||||
|
.iter()
|
||||||
|
.find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name))
|
||||||
|
.map(|(extension, _)| extension)
|
||||||
|
{
|
||||||
|
return capabilities
|
||||||
|
.get_prompt(extension, name, arguments)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow!("Failed to get prompt: {}", e));
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(anyhow!("Prompt '{}' not found", name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
register_agent!("summarize", SummarizeAgent);
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
pub mod agents;
|
pub mod agents;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod memory_condense;
|
||||||
pub mod message;
|
pub mod message;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
pub mod prompt_template;
|
pub mod prompt_template;
|
||||||
|
|||||||
130
crates/goose/src/memory_condense.rs
Normal file
130
crates/goose/src/memory_condense.rs
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
use crate::agents::Capabilities;
|
||||||
|
use crate::message::Message;
|
||||||
|
use crate::token_counter::TokenCounter;
|
||||||
|
use anyhow::{anyhow, Result};
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
const SYSTEM_PROMPT: &str = "You are good at summarizing.";
|
||||||
|
|
||||||
|
fn create_summarize_request(messages: &[Message]) -> Vec<Message> {
|
||||||
|
vec![
|
||||||
|
Message::user().with_text(format!("Please use a few concise sentences to summarize this chat, while keeping the important information.\n\n```\n{:?}```", messages)),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
async fn single_request(
|
||||||
|
capabilities: &Capabilities,
|
||||||
|
messages: &[Message],
|
||||||
|
) -> Result<Message, anyhow::Error> {
|
||||||
|
Ok(capabilities
|
||||||
|
.provider()
|
||||||
|
.complete(SYSTEM_PROMPT, messages, &[])
|
||||||
|
.await?
|
||||||
|
.0)
|
||||||
|
}
|
||||||
|
async fn memory_condense(
|
||||||
|
capabilities: &Capabilities,
|
||||||
|
token_counter: &TokenCounter,
|
||||||
|
messages: &mut Vec<Message>,
|
||||||
|
token_counts: &mut Vec<usize>,
|
||||||
|
context_limit: usize,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
let system_prompt_tokens = token_counter.count_tokens(SYSTEM_PROMPT);
|
||||||
|
|
||||||
|
// Since the process will run multiple times, we should avoid expensive operations like random access.
|
||||||
|
let mut message_stack = messages.iter().cloned().rev().collect::<Vec<_>>();
|
||||||
|
let mut count_stack = token_counts.iter().copied().rev().collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// Tracks the number of remaining tokens in the stack
|
||||||
|
let mut total_tokens = count_stack.iter().sum::<usize>();
|
||||||
|
|
||||||
|
// Tracks the change of total_tokens in the previous loop.
|
||||||
|
// If diff <= 0, then the model cannot summarize any further. We set it to 1 before the process
|
||||||
|
// to ensure that the process starts.
|
||||||
|
let mut diff = 1;
|
||||||
|
|
||||||
|
while total_tokens > context_limit && diff > 0 {
|
||||||
|
let mut batch = Vec::new();
|
||||||
|
let mut current_tokens = 0;
|
||||||
|
|
||||||
|
// Extracts the beginning messages (which appears in the front of the message stack) to
|
||||||
|
// summarize.
|
||||||
|
while total_tokens > current_tokens + context_limit
|
||||||
|
&& current_tokens + system_prompt_tokens <= context_limit
|
||||||
|
{
|
||||||
|
batch.push(message_stack.pop().unwrap());
|
||||||
|
current_tokens += count_stack.pop().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// It could happen that the extracted messages are always the previous summary when the
|
||||||
|
// context limit is very small. We should force it to consume more messages.
|
||||||
|
if !batch.is_empty()
|
||||||
|
&& !message_stack.is_empty()
|
||||||
|
&& current_tokens + system_prompt_tokens <= context_limit
|
||||||
|
{
|
||||||
|
batch.push(message_stack.pop().unwrap());
|
||||||
|
current_tokens += count_stack.pop().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
diff = -(current_tokens as isize);
|
||||||
|
let request = create_summarize_request(&batch);
|
||||||
|
let response_text = single_request(capabilities, &request)
|
||||||
|
.await?
|
||||||
|
.as_concat_text();
|
||||||
|
|
||||||
|
// Ensure the conversation starts with a User message
|
||||||
|
let curr_messages = vec![
|
||||||
|
// shoule be in reversed order
|
||||||
|
Message::assistant().with_text(&response_text),
|
||||||
|
Message::user().with_text("Hello! How are we progressing?"),
|
||||||
|
];
|
||||||
|
let curr_tokens = token_counter.count_chat_tokens("", &curr_messages, &[]);
|
||||||
|
diff += curr_tokens as isize;
|
||||||
|
count_stack.push(curr_tokens);
|
||||||
|
message_stack.extend(curr_messages);
|
||||||
|
|
||||||
|
// Update the counter
|
||||||
|
total_tokens = total_tokens.checked_add_signed(diff).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
if total_tokens <= context_limit {
|
||||||
|
*messages = message_stack.into_iter().rev().collect();
|
||||||
|
*token_counts = count_stack.into_iter().rev().collect();
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("Cannot compress the messages anymore"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn condense_messages(
|
||||||
|
capabilities: &Capabilities,
|
||||||
|
token_counter: &TokenCounter,
|
||||||
|
messages: &mut Vec<Message>,
|
||||||
|
token_counts: &mut Vec<usize>,
|
||||||
|
context_limit: usize,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
let total_tokens: usize = token_counts.iter().sum();
|
||||||
|
debug!("Total tokens before memory condensation: {}", total_tokens);
|
||||||
|
|
||||||
|
// The compressor should determine whether we need to compress the messages or not. This
|
||||||
|
// function just checks if the limit is satisfied.
|
||||||
|
memory_condense(
|
||||||
|
capabilities,
|
||||||
|
token_counter,
|
||||||
|
messages,
|
||||||
|
token_counts,
|
||||||
|
context_limit,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let total_tokens: usize = token_counts.iter().sum();
|
||||||
|
debug!("Total tokens after memory condensation: {}", total_tokens);
|
||||||
|
|
||||||
|
// Compressor should handle this case.
|
||||||
|
assert!(total_tokens <= context_limit, "Illegal compression result from the compressor: the number of tokens is greater than the limit.");
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Memory condensation complete. Total tokens: {}",
|
||||||
|
total_tokens
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user