From 26da43ae077c084710d233795ba1e7c75810aa36 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 14:53:37 +0200 Subject: [PATCH 1/2] Handle context length errors with summarization in API --- crates/goose-api/Cargo.toml | 4 ++ crates/goose-api/src/handlers.rs | 48 +++++++++++++++- crates/goose-api/src/tests.rs | 99 +++++++++++++++++++++++++++++++- 3 files changed, 149 insertions(+), 2 deletions(-) diff --git a/crates/goose-api/Cargo.toml b/crates/goose-api/Cargo.toml index 1caa2ffe..0db532b7 100644 --- a/crates/goose-api/Cargo.toml +++ b/crates/goose-api/Cargo.toml @@ -22,3 +22,7 @@ futures-util = "0.3" # For session IDs uuid = { version = "1", features = ["serde", "v4"] } # Add dynamic-library for extension loading + +[dev-dependencies] +tempfile = "3" +async-trait = "0.1" diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 096e7193..b2019ba2 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -6,7 +6,7 @@ use futures_util::TryStreamExt; use tracing::{info, warn, error}; use mcp_core::tool::Tool; use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig}; -use goose::message::Message; +use goose::message::{Message, MessageContent}; use goose::session::{self, Identifier}; use goose::config::Config; use std::sync::LazyLock; @@ -127,6 +127,30 @@ pub async fn start_session_handler( match result { Ok(mut stream) => { if let Ok(Some(response)) = stream.try_next().await { + if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) { + match agent.summarize_context(&messages).await { + Ok((summarized, _)) => { + messages = summarized; + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + + let api_response = StartSessionResponse { + message: "Conversation summarized to fit context window".to_string(), + status: "warning".to_string(), + session_id, + }; + return Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )); + } + Err(e) => { + warn!("Failed to summarize context: {}", e); + } + } + } + let response_text = response.as_concat_text(); messages.push(response); if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { @@ -215,6 +239,28 @@ pub async fn reply_session_handler( match result { Ok(mut stream) => { if let Ok(Some(response)) = stream.try_next().await { + if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) { + match agent.summarize_context(&messages).await { + Ok((summarized, _)) => { + messages = summarized; + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + let api_response = ApiResponse { + message: "Conversation summarized to fit context window".to_string(), + status: "warning".to_string(), + }; + return Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )); + } + Err(e) => { + warn!("Failed to summarize context: {}", e); + } + } + } + let response_text = response.as_concat_text(); messages.push(response); if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { diff --git a/crates/goose-api/src/tests.rs b/crates/goose-api/src/tests.rs index 302cf8c3..607064d4 100644 --- a/crates/goose-api/src/tests.rs +++ b/crates/goose-api/src/tests.rs @@ -1,10 +1,107 @@ #[cfg(test)] mod tests { use super::*; + use goose::message::{Message, MessageContent}; + use goose::model::ModelConfig; + use goose::providers::{ + base::{Provider, ProviderMetadata, ProviderUsage, Usage}, + errors::ProviderError, + }; + use mcp_core::tool::Tool; + use std::sync::Arc; + use tempfile::TempDir; + use warp::reply::Reply; + use goose::session::{self, Identifier}; + use uuid::Uuid; + use hyper::body; + + #[derive(Clone)] + struct ContextProvider { + model_config: ModelConfig, + } + + #[async_trait::async_trait] + impl Provider for ContextProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::empty() + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + async fn complete( + &self, + system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + if system.contains("summarizing") { + Ok(( + Message::user().with_text("summary"), + ProviderUsage::new("mock".to_string(), Usage::default()), + )) + } else { + Err(ProviderError::ContextLengthExceeded("too long".to_string())) + } + } + } + + async fn setup() -> (TempDir, Uuid) { + let tmp = tempfile::tempdir().unwrap(); + std::env::set_var("HOME", tmp.path()); + + let provider = Arc::new(ContextProvider { + model_config: ModelConfig::new("test".to_string()), + }); + let agent = AGENT.lock().await; + agent.update_provider(provider).await.unwrap(); + drop(agent); + + let req = SessionRequest { + prompt: "start".repeat(1000), + }; + let reply = start_session_handler(req, "key".to_string()).await.unwrap(); + let resp = reply.into_response(); + let body = body::to_bytes(resp.into_body()).await.unwrap(); + let start: StartSessionResponse = serde_json::from_slice(&body).unwrap(); + (tmp, start.session_id) + } #[tokio::test] async fn build_routes_compiles() { let _routes = build_routes("test-key".to_string()); - // Just ensure building routes doesn't panic + } + + #[tokio::test] + async fn summarizes_large_history_on_start() { + let (tmp, session_id) = setup().await; + + let session_path = session::get_path(Identifier::Name(session_id.to_string())); + let messages = session::read_messages(&session_path).unwrap(); + assert!(messages.iter().any(|m| m.as_concat_text().contains("summary"))); + drop(tmp); + } + + #[tokio::test] + async fn summarizes_large_history_on_reply() { + let (tmp, session_id) = setup().await; + + let req = SessionReplyRequest { + session_id, + prompt: "reply".repeat(1000), + }; + let reply = reply_session_handler(req, "key".to_string()).await.unwrap(); + let resp = reply.into_response(); + let body = body::to_bytes(resp.into_body()).await.unwrap(); + let api: ApiResponse = serde_json::from_slice(&body).unwrap(); + assert_eq!(api.status, "warning"); + + let session_path = session::get_path(Identifier::Name(session_id.to_string())); + let messages = session::read_messages(&session_path).unwrap(); + assert!(messages + .iter() + .all(|m| !matches!(m.content.first(), Some(MessageContent::ContextLengthExceeded(_))))); + drop(tmp); } } From 9fb798052c9fa8514b2713486d76cd2e82ac2047 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 14:53:59 +0200 Subject: [PATCH 2/2] feat(api): add metrics endpoint --- crates/goose-api/README.md | 24 ++++++++++++++ crates/goose-api/src/handlers.rs | 35 ++++++++++++++++++++ crates/goose-api/src/routes.rs | 7 +++- crates/goose/src/agents/extension_manager.rs | 20 ++++++++++- crates/mcp-client/src/transport/sse.rs | 12 +++++-- crates/mcp-client/src/transport/stdio.rs | 9 ++++- 6 files changed, 102 insertions(+), 5 deletions(-) diff --git a/crates/goose-api/README.md b/crates/goose-api/README.md index b7fb63a5..d44af846 100644 --- a/crates/goose-api/README.md +++ b/crates/goose-api/README.md @@ -236,6 +236,30 @@ By default, the server runs on `127.0.0.1:8080`. You can modify this using confi } ``` +### 7. Metrics + +**Endpoint**: `GET /metrics` + +**Description**: Returns runtime metrics about stored sessions and extensions. + +**Request**: +- Headers: + - `x-api-key: [your-api-key]` + +**Response** (example): +```json +{ + "session_messages": { + "20240605_001234": 3, + "20240605_010000": 5 + }, + "active_sessions": 2, + "pending_requests": { + "mcp_say": 0 + } +} +``` + ## Session Management Sessions created via the API are stored in the same location as the CLI diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 096e7193..f731ea88 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -10,6 +10,7 @@ use goose::message::Message; use goose::session::{self, Identifier}; use goose::config::Config; use std::sync::LazyLock; +use std::collections::HashMap; pub static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); pub static AGENT: LazyLock> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new())); @@ -60,6 +61,13 @@ pub struct ExtensionResponse { pub message: Option, } +#[derive(Debug, Serialize)] +pub struct MetricsResponse { + pub session_messages: HashMap, + pub active_sessions: usize, + pub pending_requests: HashMap, +} + #[derive(Debug, Deserialize)] #[serde(tag = "type")] pub enum ExtensionConfigRequest { @@ -442,6 +450,33 @@ pub async fn remove_extension_handler( Ok(warp::reply::json(&resp)) } +pub async fn metrics_handler() -> Result { + // Gather session message counts + let mut session_messages = HashMap::new(); + if let Ok(sessions) = session::list_sessions() { + for (name, path) in sessions { + if let Ok(messages) = session::read_messages(&path) { + session_messages.insert(name, messages.len()); + } + } + } + + let active_sessions = session_messages.len(); + + // Gather pending request sizes for each extension + let pending_requests = EXTENSION_MANAGER + .pending_request_sizes() + .await; + + let resp = MetricsResponse { + session_messages, + active_sessions, + pending_requests, + }; + + Ok(warp::reply::json(&resp)) +} + pub fn with_api_key(api_key: String) -> impl Filter + Clone { warp::header::value("x-api-key") .and_then(move |header_api_key: HeaderValue| { diff --git a/crates/goose-api/src/routes.rs b/crates/goose-api/src/routes.rs index 759786c3..ea5680ea 100644 --- a/crates/goose-api/src/routes.rs +++ b/crates/goose-api/src/routes.rs @@ -4,7 +4,7 @@ use tracing::{info, warn, error}; use crate::handlers::{ add_extension_handler, end_session_handler, get_provider_config_handler, list_extensions_handler, remove_extension_handler, reply_session_handler, - start_session_handler, with_api_key, + start_session_handler, metrics_handler, with_api_key, }; use crate::config::{ initialize_extensions, initialize_provider_config, load_configuration, @@ -57,6 +57,10 @@ pub fn build_routes(api_key: String) -> impl Filter impl Filter Result<(), anyhow::Error> { diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 4bc4d746..4b03a99a 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -17,7 +17,7 @@ use crate::agents::extension::Envs; use crate::config::{Config, ExtensionConfigManager}; use crate::prompt_template; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; -use mcp_client::transport::{SseTransport, StdioTransport, Transport}; +use mcp_client::transport::{PendingRequests, SseTransport, StdioTransport, Transport}; use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult}; use serde_json::Value; @@ -33,6 +33,7 @@ pub struct ExtensionManager { clients: HashMap, instructions: HashMap, resource_capable_extensions: HashSet, + pending_requests: HashMap>, // track pending requests per extension } /// A flattened representation of a resource used by the agent to prepare inference @@ -103,6 +104,7 @@ impl ExtensionManager { clients: HashMap::new(), instructions: HashMap::new(), resource_capable_extensions: HashSet::new(), + pending_requests: HashMap::new(), } } @@ -183,12 +185,14 @@ impl ExtensionManager { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let transport = SseTransport::new(uri, all_envs); let handle = transport.start().await?; + let pending = handle.pending_requests(); let service = McpService::with_timeout( handle, Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), ); + self.pending_requests.insert(sanitized_name.clone(), pending); Box::new(McpClient::new(service)) } ExtensionConfig::Stdio { @@ -202,12 +206,14 @@ impl ExtensionManager { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let transport = StdioTransport::new(cmd, args.to_vec(), all_envs); let handle = transport.start().await?; + let pending = handle.pending_requests(); let service = McpService::with_timeout( handle, Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), ); + self.pending_requests.insert(sanitized_name.clone(), pending); Box::new(McpClient::new(service)) } ExtensionConfig::Builtin { @@ -227,12 +233,14 @@ impl ExtensionManager { HashMap::new(), ); let handle = transport.start().await?; + let pending = handle.pending_requests(); let service = McpService::with_timeout( handle, Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), ); + self.pending_requests.insert(sanitized_name.clone(), pending); Box::new(McpClient::new(service)) } _ => unreachable!(), @@ -285,9 +293,19 @@ impl ExtensionManager { self.clients.remove(&sanitized_name); self.instructions.remove(&sanitized_name); self.resource_capable_extensions.remove(&sanitized_name); + self.pending_requests.remove(&sanitized_name); Ok(()) } + /// Get the size of each extension's pending request map + pub async fn pending_request_sizes(&self) -> HashMap { + let mut result = HashMap::new(); + for (name, pending) in &self.pending_requests { + result.insert(name.clone(), pending.len().await); + } + result + } + pub async fn suggest_disable_extensions_prompt(&self) -> Value { let enabled_extensions_count = self.clients.len(); diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 8a564708..0e15f168 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -223,6 +223,7 @@ impl SseActor { #[derive(Clone)] pub struct SseTransportHandle { sender: mpsc::Sender, + pending_requests: Arc, } #[async_trait::async_trait] @@ -232,6 +233,12 @@ impl TransportHandle for SseTransportHandle { } } +impl SseTransportHandle { + pub fn pending_requests(&self) -> Arc { + Arc::clone(&self.pending_requests) + } +} + #[derive(Clone)] pub struct SseTransport { sse_url: String, @@ -284,9 +291,10 @@ impl Transport for SseTransport { let post_endpoint_clone = Arc::clone(&post_endpoint); // Build the actor + let pending_requests = Arc::new(PendingRequests::new()); let actor = SseActor::new( rx, - Arc::new(PendingRequests::new()), + pending_requests.clone(), self.sse_url.clone(), post_endpoint, ); @@ -301,7 +309,7 @@ impl Transport for SseTransport { ) .await { - Ok(_) => Ok(SseTransportHandle { sender: tx }), + Ok(_) => Ok(SseTransportHandle { sender: tx, pending_requests }), Err(e) => Err(Error::SseConnection(e.to_string())), } } diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 5895e83e..76a48487 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -189,6 +189,7 @@ impl StdioActor { pub struct StdioTransportHandle { sender: mpsc::Sender, error_receiver: Arc>>, + pending_requests: Arc, } #[async_trait::async_trait] @@ -212,6 +213,10 @@ impl StdioTransportHandle { Err(_) => Ok(()), } } + + pub fn pending_requests(&self) -> Arc { + Arc::clone(&self.pending_requests) + } } pub struct StdioTransport { @@ -292,9 +297,10 @@ impl Transport for StdioTransport { let (message_tx, message_rx) = mpsc::channel(32); let (error_tx, error_rx) = mpsc::channel(1); + let pending_requests = Arc::new(PendingRequests::new()); let actor = StdioActor { receiver: Some(message_rx), - pending_requests: Arc::new(PendingRequests::new()), + pending_requests: pending_requests.clone(), process, error_sender: error_tx, stdin: Some(stdin), @@ -307,6 +313,7 @@ impl Transport for StdioTransport { let handle = StdioTransportHandle { sender: message_tx, error_receiver: Arc::new(Mutex::new(error_rx)), + pending_requests, }; Ok(handle) }