From 89130c206783b8d0dc58e65aae824a79228f6667 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 14:53:48 +0200 Subject: [PATCH] Add pending request cleanup and config --- crates/goose-api/src/handlers.rs | 47 +++++++++++++++++-- crates/goose-cli/src/commands/configure.rs | 4 ++ crates/goose-cli/src/session/mod.rs | 6 +++ crates/goose-server/src/routes/extension.rs | 24 ++++++++++ crates/goose/src/agents/extension.rs | 22 +++++++++ crates/goose/src/agents/extension_manager.rs | 23 ++++++++- crates/goose/src/config/extensions.rs | 4 ++ crates/goose/src/config/mod.rs | 2 + crates/mcp-client/examples/clients.rs | 23 +++++++-- crates/mcp-client/examples/sse.rs | 7 ++- crates/mcp-client/examples/stdio.rs | 8 +++- .../mcp-client/examples/stdio_integration.rs | 2 + crates/mcp-client/src/transport/mod.rs | 43 +++++++++++++++-- crates/mcp-client/src/transport/sse.rs | 25 ++++++++-- crates/mcp-client/src/transport/stdio.rs | 21 ++++++++- 15 files changed, 242 insertions(+), 19 deletions(-) diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 096e7193..954069e4 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -72,6 +72,10 @@ pub enum ExtensionConfigRequest { #[serde(default)] env_keys: Vec, timeout: Option, + #[serde(default)] + max_pending_requests: Option, + #[serde(default)] + pending_request_timeout: Option, }, #[serde(rename = "stdio")] Stdio { @@ -84,12 +88,20 @@ pub enum ExtensionConfigRequest { #[serde(default)] env_keys: Vec, timeout: Option, + #[serde(default)] + max_pending_requests: Option, + #[serde(default)] + pending_request_timeout: Option, }, #[serde(rename = "builtin")] Builtin { name: String, display_name: Option, timeout: Option, + #[serde(default)] + max_pending_requests: Option, + #[serde(default)] + pending_request_timeout: Option, }, #[serde(rename = "frontend")] Frontend { @@ -370,7 +382,15 @@ pub async fn add_extension_handler( } let extension = match req { - ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => { + ExtensionConfigRequest::Sse { + name, + uri, + envs, + env_keys, + timeout, + max_pending_requests, + pending_request_timeout, + } => { ExtensionConfig::Sse { name, uri, @@ -378,10 +398,21 @@ pub async fn add_extension_handler( env_keys, description: None, timeout, + max_pending_requests, + pending_request_timeout, bundled: None, } } - ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => { + ExtensionConfigRequest::Stdio { + name, + cmd, + args, + envs, + env_keys, + timeout, + max_pending_requests, + pending_request_timeout, + } => { ExtensionConfig::Stdio { name, cmd, @@ -389,15 +420,25 @@ pub async fn add_extension_handler( envs, env_keys, timeout, + max_pending_requests, + pending_request_timeout, description: None, bundled: None, } } - ExtensionConfigRequest::Builtin { name, display_name, timeout } => { + ExtensionConfigRequest::Builtin { + name, + display_name, + timeout, + max_pending_requests, + pending_request_timeout, + } => { ExtensionConfig::Builtin { name, display_name, timeout, + max_pending_requests, + pending_request_timeout, bundled: None, } } diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 7e49c50c..03d56284 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -79,6 +79,8 @@ pub async fn handle_configure() -> Result<(), Box> { name: "developer".to_string(), display_name: Some(goose::config::DEFAULT_DISPLAY_NAME.to_string()), timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), + max_pending_requests: Some(goose::config::DEFAULT_MAX_PENDING_REQUESTS), + pending_request_timeout: Some(goose::config::DEFAULT_PENDING_REQUEST_TIMEOUT), bundled: Some(true), }, })?; @@ -548,6 +550,8 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { name: extension.clone(), display_name: Some(display_name), timeout: Some(timeout), + max_pending_requests: Some(goose::config::DEFAULT_MAX_PENDING_REQUESTS), + pending_request_timeout: Some(goose::config::DEFAULT_PENDING_REQUEST_TIMEOUT), bundled: Some(true), }, })?; diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 273ec979..f1e10fc5 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -177,6 +177,8 @@ impl Session { description: Some(goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string()), // TODO: should set timeout timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), + max_pending_requests: Some(goose::config::DEFAULT_MAX_PENDING_REQUESTS), + pending_request_timeout: Some(goose::config::DEFAULT_PENDING_REQUEST_TIMEOUT), bundled: None, }; @@ -210,6 +212,8 @@ impl Session { description: Some(goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string()), // TODO: should set timeout timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), + max_pending_requests: Some(goose::config::DEFAULT_MAX_PENDING_REQUESTS), + pending_request_timeout: Some(goose::config::DEFAULT_PENDING_REQUEST_TIMEOUT), bundled: None, }; @@ -235,6 +239,8 @@ impl Session { display_name: None, // TODO: should set a timeout timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), + max_pending_requests: Some(goose::config::DEFAULT_MAX_PENDING_REQUESTS), + pending_request_timeout: Some(goose::config::DEFAULT_PENDING_REQUEST_TIMEOUT), bundled: None, }; self.agent diff --git a/crates/goose-server/src/routes/extension.rs b/crates/goose-server/src/routes/extension.rs index 3b8f8dbc..d535a7fb 100644 --- a/crates/goose-server/src/routes/extension.rs +++ b/crates/goose-server/src/routes/extension.rs @@ -29,6 +29,10 @@ enum ExtensionConfigRequest { #[serde(default)] env_keys: Vec, timeout: Option, + #[serde(default)] + max_pending_requests: Option, + #[serde(default)] + pending_request_timeout: Option, }, /// Standard I/O (stdio) extension. #[serde(rename = "stdio")] @@ -47,6 +51,10 @@ enum ExtensionConfigRequest { #[serde(default)] env_keys: Vec, timeout: Option, + #[serde(default)] + max_pending_requests: Option, + #[serde(default)] + pending_request_timeout: Option, }, /// Built-in extension that is part of the goose binary. #[serde(rename = "builtin")] @@ -55,6 +63,10 @@ enum ExtensionConfigRequest { name: String, display_name: Option, timeout: Option, + #[serde(default)] + max_pending_requests: Option, + #[serde(default)] + pending_request_timeout: Option, }, /// Frontend extension that provides tools to be executed by the frontend. #[serde(rename = "frontend")] @@ -167,6 +179,8 @@ async fn add_extension( envs, env_keys, timeout, + max_pending_requests, + pending_request_timeout, } => ExtensionConfig::Sse { name, uri, @@ -174,6 +188,8 @@ async fn add_extension( env_keys, description: None, timeout, + max_pending_requests, + pending_request_timeout, bundled: None, }, ExtensionConfigRequest::Stdio { @@ -183,6 +199,8 @@ async fn add_extension( envs, env_keys, timeout, + max_pending_requests, + pending_request_timeout, } => { // TODO: We can uncomment once bugs are fixed. Check allowlist for Stdio extensions // if !is_command_allowed(&cmd, &args) { @@ -204,6 +222,8 @@ async fn add_extension( envs, env_keys, timeout, + max_pending_requests, + pending_request_timeout, bundled: None, } } @@ -211,10 +231,14 @@ async fn add_extension( name, display_name, timeout, + max_pending_requests, + pending_request_timeout, } => ExtensionConfig::Builtin { name, display_name, timeout, + max_pending_requests, + pending_request_timeout, bundled: None, }, ExtensionConfigRequest::Frontend { diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index ae00d479..2b3fc2d6 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -136,6 +136,10 @@ pub enum ExtensionConfig { // NOTE: set timeout to be optional for compatibility. // However, new configurations should include this field. timeout: Option, + #[serde(default)] + max_pending_requests: Option, + #[serde(default)] + pending_request_timeout: Option, /// Whether this extension is bundled with Goose #[serde(default)] bundled: Option, @@ -153,6 +157,10 @@ pub enum ExtensionConfig { env_keys: Vec, timeout: Option, description: Option, + #[serde(default)] + max_pending_requests: Option, + #[serde(default)] + pending_request_timeout: Option, /// Whether this extension is bundled with Goose #[serde(default)] bundled: Option, @@ -164,6 +172,10 @@ pub enum ExtensionConfig { name: String, display_name: Option, // needed for the UI timeout: Option, + #[serde(default)] + max_pending_requests: Option, + #[serde(default)] + pending_request_timeout: Option, /// Whether this extension is bundled with Goose #[serde(default)] bundled: Option, @@ -189,6 +201,8 @@ impl Default for ExtensionConfig { name: config::DEFAULT_EXTENSION.to_string(), display_name: Some(config::DEFAULT_DISPLAY_NAME.to_string()), timeout: Some(config::DEFAULT_EXTENSION_TIMEOUT), + max_pending_requests: Some(config::DEFAULT_MAX_PENDING_REQUESTS), + pending_request_timeout: Some(config::DEFAULT_PENDING_REQUEST_TIMEOUT), bundled: Some(true), } } @@ -203,6 +217,8 @@ impl ExtensionConfig { env_keys: Vec::new(), description: Some(description.into()), timeout: Some(timeout.into()), + max_pending_requests: Some(config::DEFAULT_MAX_PENDING_REQUESTS), + pending_request_timeout: Some(config::DEFAULT_PENDING_REQUEST_TIMEOUT), bundled: None, } } @@ -221,6 +237,8 @@ impl ExtensionConfig { env_keys: Vec::new(), description: Some(description.into()), timeout: Some(timeout.into()), + max_pending_requests: Some(config::DEFAULT_MAX_PENDING_REQUESTS), + pending_request_timeout: Some(config::DEFAULT_PENDING_REQUEST_TIMEOUT), bundled: None, } } @@ -237,6 +255,8 @@ impl ExtensionConfig { envs, env_keys, timeout, + max_pending_requests, + pending_request_timeout, description, bundled, .. @@ -248,6 +268,8 @@ impl ExtensionConfig { args: args.into_iter().map(Into::into).collect(), description, timeout, + max_pending_requests, + pending_request_timeout, bundled, }, other => other, diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 4bc4d746..ebbe1c10 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -178,10 +178,17 @@ impl ExtensionManager { envs, env_keys, timeout, + max_pending_requests, + pending_request_timeout, .. } => { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; - let transport = SseTransport::new(uri, all_envs); + let transport = SseTransport::new( + uri, + all_envs, + *max_pending_requests, + pending_request_timeout.map(Duration::from_secs), + ); let handle = transport.start().await?; let service = McpService::with_timeout( handle, @@ -197,10 +204,18 @@ impl ExtensionManager { envs, env_keys, timeout, + max_pending_requests, + pending_request_timeout, .. } => { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; - let transport = StdioTransport::new(cmd, args.to_vec(), all_envs); + let transport = StdioTransport::new( + cmd, + args.to_vec(), + all_envs, + *max_pending_requests, + pending_request_timeout.map(Duration::from_secs), + ); let handle = transport.start().await?; let service = McpService::with_timeout( handle, @@ -214,6 +229,8 @@ impl ExtensionManager { name, display_name: _, timeout, + max_pending_requests, + pending_request_timeout, bundled: _, } => { let cmd = std::env::current_exe() @@ -225,6 +242,8 @@ impl ExtensionManager { &cmd, vec!["mcp".to_string(), name.clone()], HashMap::new(), + *max_pending_requests, + pending_request_timeout.map(Duration::from_secs), ); let handle = transport.start().await?; let service = McpService::with_timeout( diff --git a/crates/goose/src/config/extensions.rs b/crates/goose/src/config/extensions.rs index 9bf964ac..f079f5ff 100644 --- a/crates/goose/src/config/extensions.rs +++ b/crates/goose/src/config/extensions.rs @@ -7,6 +7,8 @@ use utoipa::ToSchema; pub const DEFAULT_EXTENSION: &str = "developer"; pub const DEFAULT_EXTENSION_TIMEOUT: u64 = 300; +pub const DEFAULT_MAX_PENDING_REQUESTS: usize = 128; +pub const DEFAULT_PENDING_REQUEST_TIMEOUT: u64 = 30; pub const DEFAULT_EXTENSION_DESCRIPTION: &str = ""; pub const DEFAULT_DISPLAY_NAME: &str = "Developer"; @@ -45,6 +47,8 @@ impl ExtensionConfigManager { name: DEFAULT_EXTENSION.to_string(), display_name: Some(DEFAULT_DISPLAY_NAME.to_string()), timeout: Some(DEFAULT_EXTENSION_TIMEOUT), + max_pending_requests: Some(DEFAULT_MAX_PENDING_REQUESTS), + pending_request_timeout: Some(DEFAULT_PENDING_REQUEST_TIMEOUT), bundled: Some(true), }, }, diff --git a/crates/goose/src/config/mod.rs b/crates/goose/src/config/mod.rs index ca9abb01..1d954358 100644 --- a/crates/goose/src/config/mod.rs +++ b/crates/goose/src/config/mod.rs @@ -13,3 +13,5 @@ pub use extensions::DEFAULT_DISPLAY_NAME; pub use extensions::DEFAULT_EXTENSION; pub use extensions::DEFAULT_EXTENSION_DESCRIPTION; pub use extensions::DEFAULT_EXTENSION_TIMEOUT; +pub use extensions::DEFAULT_MAX_PENDING_REQUESTS; +pub use extensions::DEFAULT_PENDING_REQUEST_TIMEOUT; diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index 4913b952..603721ca 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -18,17 +18,34 @@ async fn main() -> Result<(), Box> { ) .init(); - let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); + let transport1 = StdioTransport::new( + "uvx", + vec!["mcp-server-git".to_string()], + HashMap::new(), + None, + None, + ); let handle1 = transport1.start().await?; let service1 = McpService::with_timeout(handle1, Duration::from_secs(30)); let client1 = McpClient::new(service1); - let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); + let transport2 = StdioTransport::new( + "uvx", + vec!["mcp-server-git".to_string()], + HashMap::new(), + None, + None, + ); let handle2 = transport2.start().await?; let service2 = McpService::with_timeout(handle2, Duration::from_secs(30)); let client2 = McpClient::new(service2); - let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new()); + let transport3 = SseTransport::new( + "http://localhost:8000/sse", + HashMap::new(), + None, + None, + ); let handle3 = transport3.start().await?; let service3 = McpService::with_timeout(handle3, Duration::from_secs(10)); let client3 = McpClient::new(service3); diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index 360a2bbc..57b8132b 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -18,7 +18,12 @@ async fn main() -> Result<()> { .init(); // Create the base transport - let transport = SseTransport::new("http://localhost:8000/sse", HashMap::new()); + let transport = SseTransport::new( + "http://localhost:8000/sse", + HashMap::new(), + None, + None, + ); // Start transport let handle = transport.start().await?; diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index e43f036c..bd46972f 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -20,7 +20,13 @@ async fn main() -> Result<(), ClientError> { .init(); // 1) Create the transport - let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); + let transport = StdioTransport::new( + "uvx", + vec!["mcp-server-git".to_string()], + HashMap::new(), + None, + None, + ); // 2) Start the transport to get a handle let transport_handle = transport.start().await?; diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs index ffdcc10c..86ca3cf5 100644 --- a/crates/mcp-client/examples/stdio_integration.rs +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -29,6 +29,8 @@ async fn main() -> Result<(), ClientError> { .map(|s| s.to_string()) .collect(), HashMap::new(), + None, + None, ); // Start the transport to get a handle diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index e2a66b26..8d51fecf 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use mcp_core::protocol::JsonRpcMessage; use std::collections::HashMap; +use std::time::{Duration, Instant}; use thiserror::Error; use tokio::sync::{mpsc, oneshot, RwLock}; @@ -31,6 +32,9 @@ pub enum Error { #[error("HTTP error: {status} - {message}")] HttpError { status: u16, message: String }, + + #[error("Too many pending requests")] + PendingRequestsFull, } /// A message that can be sent through the transport @@ -89,7 +93,10 @@ pub async fn send_message( // A data structure to store pending requests and their response channels pub struct PendingRequests { - requests: RwLock>>>, + requests: + RwLock>, Instant)>>, + max_size: Option, + timeout: Option, } impl Default for PendingRequests { @@ -100,21 +107,49 @@ impl Default for PendingRequests { impl PendingRequests { pub fn new() -> Self { + Self::with_limits(None, None) + } + + pub fn with_limits(max_size: Option, timeout: Option) -> Self { Self { requests: RwLock::new(HashMap::new()), + max_size, + timeout, } } - pub async fn insert(&self, id: String, sender: oneshot::Sender>) { - self.requests.write().await.insert(id, sender); + pub async fn insert( + &self, + id: String, + sender: oneshot::Sender>, + ) -> Result<(), Error> { + self.cleanup().await; + let mut map = self.requests.write().await; + if let Some(max) = self.max_size { + if map.len() >= max { + return Err(Error::PendingRequestsFull); + } + } + map.insert(id, (sender, Instant::now())); + Ok(()) } pub async fn respond(&self, id: &str, response: Result) { - if let Some(tx) = self.requests.write().await.remove(id) { + if let Some((tx, _)) = self.requests.write().await.remove(id) { let _ = tx.send(response); } } + pub async fn cleanup(&self) { + if let Some(timeout) = self.timeout { + let now = Instant::now(); + self.requests + .write() + .await + .retain(|_, (_, ts)| now.duration_since(*ts) <= timeout); + } + } + pub async fn clear(&self) { self.requests.write().await.clear(); } diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 8a564708..dc94ce6a 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -179,7 +179,14 @@ impl SseActor { if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) = &transport_msg.message { - pending_requests.insert(id.to_string(), response_tx).await; + pending_requests.cleanup().await; + if let Err(e) = pending_requests + .insert(id.to_string(), response_tx) + .await + { + let _ = response_tx.send(Err(e)); + continue; + } } } @@ -236,14 +243,23 @@ impl TransportHandle for SseTransportHandle { pub struct SseTransport { sse_url: String, env: HashMap, + max_pending: Option, + pending_timeout: Option, } /// The SSE transport spawns an `SseActor` on `start()`. impl SseTransport { - pub fn new>(sse_url: S, env: HashMap) -> Self { + pub fn new>( + sse_url: S, + env: HashMap, + max_pending: Option, + pending_timeout: Option, + ) -> Self { Self { sse_url: sse_url.into(), env, + max_pending, + pending_timeout, } } @@ -286,7 +302,10 @@ impl Transport for SseTransport { // Build the actor let actor = SseActor::new( rx, - Arc::new(PendingRequests::new()), + Arc::new(PendingRequests::with_limits( + self.max_pending, + self.pending_timeout, + )), self.sse_url.clone(), post_endpoint, ); diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 5895e83e..884ed10e 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::time::Duration; use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::Arc; use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; @@ -162,7 +163,14 @@ impl StdioActor { if let Some(response_tx) = transport_msg.response_tx.take() { if let JsonRpcMessage::Request(request) = &transport_msg.message { if let Some(id) = &request.id { - pending_requests.insert(id.to_string(), response_tx).await; + pending_requests.cleanup().await; + if let Err(e) = pending_requests + .insert(id.to_string(), response_tx) + .await + { + let _ = response_tx.send(Err(e)); + continue; + } } } } @@ -218,6 +226,8 @@ pub struct StdioTransport { command: String, args: Vec, env: HashMap, + max_pending: Option, + pending_timeout: Option, } impl StdioTransport { @@ -225,11 +235,15 @@ impl StdioTransport { command: S, args: Vec, env: HashMap, + max_pending: Option, + pending_timeout: Option, ) -> Self { Self { command: command.into(), args, env, + max_pending, + pending_timeout, } } @@ -294,7 +308,10 @@ impl Transport for StdioTransport { let actor = StdioActor { receiver: Some(message_rx), - pending_requests: Arc::new(PendingRequests::new()), + pending_requests: Arc::new(PendingRequests::with_limits( + self.max_pending, + self.pending_timeout, + )), process, error_sender: error_tx, stdin: Some(stdin),