NOTE(review): this patch was reconstructed from a copy whose line breaks and
indentation had been collapsed into running text. Line counts of every hunk were
re-checked against its @@ header. Indentation of the Rust lines is assumed to
follow rustfmt defaults, and the nesting depth of the handlers.rs context lines
as well as the position of the single blank context line in the tests.rs hunk
are inferred -- TODO confirm against the target tree, or apply with
`git apply --ignore-whitespace` if strict context matching fails.

diff --git a/crates/goose-api/Cargo.toml b/crates/goose-api/Cargo.toml
index 1caa2ffe..0db532b7 100644
--- a/crates/goose-api/Cargo.toml
+++ b/crates/goose-api/Cargo.toml
@@ -22,3 +22,7 @@ futures-util = "0.3"
 # For session IDs
 uuid = { version = "1", features = ["serde", "v4"] }
 # Add dynamic-library for extension loading
+
+[dev-dependencies]
+tempfile = "3"
+async-trait = "0.1"
diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs
index f731ea88..a1337579 100644
--- a/crates/goose-api/src/handlers.rs
+++ b/crates/goose-api/src/handlers.rs
@@ -6,7 +6,7 @@ use futures_util::TryStreamExt;
 use tracing::{info, warn, error};
 use mcp_core::tool::Tool;
 use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig};
-use goose::message::Message;
+use goose::message::{Message, MessageContent};
 use goose::session::{self, Identifier};
 use goose::config::Config;
 use std::sync::LazyLock;
@@ -135,6 +135,30 @@ pub async fn start_session_handler(
     match result {
         Ok(mut stream) => {
             if let Ok(Some(response)) = stream.try_next().await {
+                if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) {
+                    match agent.summarize_context(&messages).await {
+                        Ok((summarized, _)) => {
+                            messages = summarized;
+                            if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
+                                warn!("Failed to persist session {}: {}", session_name, e);
+                            }
+
+                            let api_response = StartSessionResponse {
+                                message: "Conversation summarized to fit context window".to_string(),
+                                status: "warning".to_string(),
+                                session_id,
+                            };
+                            return Ok(warp::reply::with_status(
+                                warp::reply::json(&api_response),
+                                warp::http::StatusCode::OK,
+                            ));
+                        }
+                        Err(e) => {
+                            warn!("Failed to summarize context: {}", e);
+                        }
+                    }
+                }
+
                 let response_text = response.as_concat_text();
                 messages.push(response);
                 if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
@@ -223,6 +247,28 @@ pub async fn reply_session_handler(
     match result {
         Ok(mut stream) => {
             if let Ok(Some(response)) = stream.try_next().await {
+                if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) {
+                    match agent.summarize_context(&messages).await {
+                        Ok((summarized, _)) => {
+                            messages = summarized;
+                            if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
+                                warn!("Failed to persist session {}: {}", session_name, e);
+                            }
+                            let api_response = ApiResponse {
+                                message: "Conversation summarized to fit context window".to_string(),
+                                status: "warning".to_string(),
+                            };
+                            return Ok(warp::reply::with_status(
+                                warp::reply::json(&api_response),
+                                warp::http::StatusCode::OK,
+                            ));
+                        }
+                        Err(e) => {
+                            warn!("Failed to summarize context: {}", e);
+                        }
+                    }
+                }
+
                 let response_text = response.as_concat_text();
                 messages.push(response);
                 if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
diff --git a/crates/goose-api/src/tests.rs b/crates/goose-api/src/tests.rs
index 302cf8c3..607064d4 100644
--- a/crates/goose-api/src/tests.rs
+++ b/crates/goose-api/src/tests.rs
@@ -1,10 +1,107 @@
 #[cfg(test)]
 mod tests {
     use super::*;
+    use goose::message::{Message, MessageContent};
+    use goose::model::ModelConfig;
+    use goose::providers::{
+        base::{Provider, ProviderMetadata, ProviderUsage, Usage},
+        errors::ProviderError,
+    };
+    use mcp_core::tool::Tool;
+    use std::sync::Arc;
+    use tempfile::TempDir;
+    use warp::reply::Reply;
+    use goose::session::{self, Identifier};
+    use uuid::Uuid;
+    use hyper::body;
+
+    #[derive(Clone)]
+    struct ContextProvider {
+        model_config: ModelConfig,
+    }
+
+    #[async_trait::async_trait]
+    impl Provider for ContextProvider {
+        fn metadata() -> ProviderMetadata {
+            ProviderMetadata::empty()
+        }
+
+        fn get_model_config(&self) -> ModelConfig {
+            self.model_config.clone()
+        }
+
+        async fn complete(
+            &self,
+            system: &str,
+            _messages: &[Message],
+            _tools: &[Tool],
+        ) -> Result<(Message, ProviderUsage), ProviderError> {
+            if system.contains("summarizing") {
+                Ok((
+                    Message::user().with_text("summary"),
+                    ProviderUsage::new("mock".to_string(), Usage::default()),
+                ))
+            } else {
+                Err(ProviderError::ContextLengthExceeded("too long".to_string()))
+            }
+        }
+    }
+
+    async fn setup() -> (TempDir, Uuid) {
+        let tmp = tempfile::tempdir().unwrap();
+        std::env::set_var("HOME", tmp.path());
+
+        let provider = Arc::new(ContextProvider {
+            model_config: ModelConfig::new("test".to_string()),
+        });
+        let agent = AGENT.lock().await;
+        agent.update_provider(provider).await.unwrap();
+        drop(agent);
+
+        let req = SessionRequest {
+            prompt: "start".repeat(1000),
+        };
+        let reply = start_session_handler(req, "key".to_string()).await.unwrap();
+        let resp = reply.into_response();
+        let body = body::to_bytes(resp.into_body()).await.unwrap();
+        let start: StartSessionResponse = serde_json::from_slice(&body).unwrap();
+        (tmp, start.session_id)
+    }
 
     #[tokio::test]
     async fn build_routes_compiles() {
         let _routes = build_routes("test-key".to_string());
-        // Just ensure building routes doesn't panic
+    }
+
+    #[tokio::test]
+    async fn summarizes_large_history_on_start() {
+        let (tmp, session_id) = setup().await;
+
+        let session_path = session::get_path(Identifier::Name(session_id.to_string()));
+        let messages = session::read_messages(&session_path).unwrap();
+        assert!(messages.iter().any(|m| m.as_concat_text().contains("summary")));
+        drop(tmp);
+    }
+
+    #[tokio::test]
+    async fn summarizes_large_history_on_reply() {
+        let (tmp, session_id) = setup().await;
+
+        let req = SessionReplyRequest {
+            session_id,
+            prompt: "reply".repeat(1000),
+        };
+        let reply = reply_session_handler(req, "key".to_string()).await.unwrap();
+        let resp = reply.into_response();
+        let body = body::to_bytes(resp.into_body()).await.unwrap();
+        let api: ApiResponse = serde_json::from_slice(&body).unwrap();
+        assert_eq!(api.status, "warning");
+
+        let session_path = session::get_path(Identifier::Name(session_id.to_string()));
+        let messages = session::read_messages(&session_path).unwrap();
+        assert!(messages
+            .iter()
+            .all(|m| !matches!(m.content.first(), Some(MessageContent::ContextLengthExceeded(_)))));
+        drop(tmp);
     }
 }