Add pending request cleanup and config

This commit is contained in:
2025-05-29 14:53:48 +02:00
parent 493ba53c08
commit 89130c2067
15 changed files with 242 additions and 19 deletions

View File

@@ -72,6 +72,10 @@ pub enum ExtensionConfigRequest {
#[serde(default)] #[serde(default)]
env_keys: Vec<String>, env_keys: Vec<String>,
timeout: Option<u64>, timeout: Option<u64>,
#[serde(default)]
max_pending_requests: Option<usize>,
#[serde(default)]
pending_request_timeout: Option<u64>,
}, },
#[serde(rename = "stdio")] #[serde(rename = "stdio")]
Stdio { Stdio {
@@ -84,12 +88,20 @@ pub enum ExtensionConfigRequest {
#[serde(default)] #[serde(default)]
env_keys: Vec<String>, env_keys: Vec<String>,
timeout: Option<u64>, timeout: Option<u64>,
#[serde(default)]
max_pending_requests: Option<usize>,
#[serde(default)]
pending_request_timeout: Option<u64>,
}, },
#[serde(rename = "builtin")] #[serde(rename = "builtin")]
Builtin { Builtin {
name: String, name: String,
display_name: Option<String>, display_name: Option<String>,
timeout: Option<u64>, timeout: Option<u64>,
#[serde(default)]
max_pending_requests: Option<usize>,
#[serde(default)]
pending_request_timeout: Option<u64>,
}, },
#[serde(rename = "frontend")] #[serde(rename = "frontend")]
Frontend { Frontend {
@@ -370,7 +382,15 @@ pub async fn add_extension_handler(
} }
let extension = match req { let extension = match req {
ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => { ExtensionConfigRequest::Sse {
name,
uri,
envs,
env_keys,
timeout,
max_pending_requests,
pending_request_timeout,
} => {
ExtensionConfig::Sse { ExtensionConfig::Sse {
name, name,
uri, uri,
@@ -378,10 +398,21 @@ pub async fn add_extension_handler(
env_keys, env_keys,
description: None, description: None,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
bundled: None, bundled: None,
} }
} }
ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => { ExtensionConfigRequest::Stdio {
name,
cmd,
args,
envs,
env_keys,
timeout,
max_pending_requests,
pending_request_timeout,
} => {
ExtensionConfig::Stdio { ExtensionConfig::Stdio {
name, name,
cmd, cmd,
@@ -389,15 +420,25 @@ pub async fn add_extension_handler(
envs, envs,
env_keys, env_keys,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
description: None, description: None,
bundled: None, bundled: None,
} }
} }
ExtensionConfigRequest::Builtin { name, display_name, timeout } => { ExtensionConfigRequest::Builtin {
name,
display_name,
timeout,
max_pending_requests,
pending_request_timeout,
} => {
ExtensionConfig::Builtin { ExtensionConfig::Builtin {
name, name,
display_name, display_name,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
bundled: None, bundled: None,
} }
} }

View File

@@ -79,6 +79,8 @@ pub async fn handle_configure() -> Result<(), Box<dyn Error>> {
name: "developer".to_string(), name: "developer".to_string(),
display_name: Some(goose::config::DEFAULT_DISPLAY_NAME.to_string()), display_name: Some(goose::config::DEFAULT_DISPLAY_NAME.to_string()),
timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT),
max_pending_requests: Some(goose::config::DEFAULT_MAX_PENDING_REQUESTS),
pending_request_timeout: Some(goose::config::DEFAULT_PENDING_REQUEST_TIMEOUT),
bundled: Some(true), bundled: Some(true),
}, },
})?; })?;
@@ -548,6 +550,8 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
name: extension.clone(), name: extension.clone(),
display_name: Some(display_name), display_name: Some(display_name),
timeout: Some(timeout), timeout: Some(timeout),
max_pending_requests: Some(goose::config::DEFAULT_MAX_PENDING_REQUESTS),
pending_request_timeout: Some(goose::config::DEFAULT_PENDING_REQUEST_TIMEOUT),
bundled: Some(true), bundled: Some(true),
}, },
})?; })?;

View File

@@ -177,6 +177,8 @@ impl Session {
description: Some(goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string()), description: Some(goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string()),
// TODO: should set timeout // TODO: should set timeout
timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT),
max_pending_requests: Some(goose::config::DEFAULT_MAX_PENDING_REQUESTS),
pending_request_timeout: Some(goose::config::DEFAULT_PENDING_REQUEST_TIMEOUT),
bundled: None, bundled: None,
}; };
@@ -210,6 +212,8 @@ impl Session {
description: Some(goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string()), description: Some(goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string()),
// TODO: should set timeout // TODO: should set timeout
timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT),
max_pending_requests: Some(goose::config::DEFAULT_MAX_PENDING_REQUESTS),
pending_request_timeout: Some(goose::config::DEFAULT_PENDING_REQUEST_TIMEOUT),
bundled: None, bundled: None,
}; };
@@ -235,6 +239,8 @@ impl Session {
display_name: None, display_name: None,
// TODO: should set a timeout // TODO: should set a timeout
timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT),
max_pending_requests: Some(goose::config::DEFAULT_MAX_PENDING_REQUESTS),
pending_request_timeout: Some(goose::config::DEFAULT_PENDING_REQUEST_TIMEOUT),
bundled: None, bundled: None,
}; };
self.agent self.agent

View File

@@ -29,6 +29,10 @@ enum ExtensionConfigRequest {
#[serde(default)] #[serde(default)]
env_keys: Vec<String>, env_keys: Vec<String>,
timeout: Option<u64>, timeout: Option<u64>,
#[serde(default)]
max_pending_requests: Option<usize>,
#[serde(default)]
pending_request_timeout: Option<u64>,
}, },
/// Standard I/O (stdio) extension. /// Standard I/O (stdio) extension.
#[serde(rename = "stdio")] #[serde(rename = "stdio")]
@@ -47,6 +51,10 @@ enum ExtensionConfigRequest {
#[serde(default)] #[serde(default)]
env_keys: Vec<String>, env_keys: Vec<String>,
timeout: Option<u64>, timeout: Option<u64>,
#[serde(default)]
max_pending_requests: Option<usize>,
#[serde(default)]
pending_request_timeout: Option<u64>,
}, },
/// Built-in extension that is part of the goose binary. /// Built-in extension that is part of the goose binary.
#[serde(rename = "builtin")] #[serde(rename = "builtin")]
@@ -55,6 +63,10 @@ enum ExtensionConfigRequest {
name: String, name: String,
display_name: Option<String>, display_name: Option<String>,
timeout: Option<u64>, timeout: Option<u64>,
#[serde(default)]
max_pending_requests: Option<usize>,
#[serde(default)]
pending_request_timeout: Option<u64>,
}, },
/// Frontend extension that provides tools to be executed by the frontend. /// Frontend extension that provides tools to be executed by the frontend.
#[serde(rename = "frontend")] #[serde(rename = "frontend")]
@@ -167,6 +179,8 @@ async fn add_extension(
envs, envs,
env_keys, env_keys,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
} => ExtensionConfig::Sse { } => ExtensionConfig::Sse {
name, name,
uri, uri,
@@ -174,6 +188,8 @@ async fn add_extension(
env_keys, env_keys,
description: None, description: None,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
bundled: None, bundled: None,
}, },
ExtensionConfigRequest::Stdio { ExtensionConfigRequest::Stdio {
@@ -183,6 +199,8 @@ async fn add_extension(
envs, envs,
env_keys, env_keys,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
} => { } => {
// TODO: We can uncomment once bugs are fixed. Check allowlist for Stdio extensions // TODO: We can uncomment once bugs are fixed. Check allowlist for Stdio extensions
// if !is_command_allowed(&cmd, &args) { // if !is_command_allowed(&cmd, &args) {
@@ -204,6 +222,8 @@ async fn add_extension(
envs, envs,
env_keys, env_keys,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
bundled: None, bundled: None,
} }
} }
@@ -211,10 +231,14 @@ async fn add_extension(
name, name,
display_name, display_name,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
} => ExtensionConfig::Builtin { } => ExtensionConfig::Builtin {
name, name,
display_name, display_name,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
bundled: None, bundled: None,
}, },
ExtensionConfigRequest::Frontend { ExtensionConfigRequest::Frontend {

View File

@@ -136,6 +136,10 @@ pub enum ExtensionConfig {
// NOTE: set timeout to be optional for compatibility. // NOTE: set timeout to be optional for compatibility.
// However, new configurations should include this field. // However, new configurations should include this field.
timeout: Option<u64>, timeout: Option<u64>,
#[serde(default)]
max_pending_requests: Option<usize>,
#[serde(default)]
pending_request_timeout: Option<u64>,
/// Whether this extension is bundled with Goose /// Whether this extension is bundled with Goose
#[serde(default)] #[serde(default)]
bundled: Option<bool>, bundled: Option<bool>,
@@ -153,6 +157,10 @@ pub enum ExtensionConfig {
env_keys: Vec<String>, env_keys: Vec<String>,
timeout: Option<u64>, timeout: Option<u64>,
description: Option<String>, description: Option<String>,
#[serde(default)]
max_pending_requests: Option<usize>,
#[serde(default)]
pending_request_timeout: Option<u64>,
/// Whether this extension is bundled with Goose /// Whether this extension is bundled with Goose
#[serde(default)] #[serde(default)]
bundled: Option<bool>, bundled: Option<bool>,
@@ -164,6 +172,10 @@ pub enum ExtensionConfig {
name: String, name: String,
display_name: Option<String>, // needed for the UI display_name: Option<String>, // needed for the UI
timeout: Option<u64>, timeout: Option<u64>,
#[serde(default)]
max_pending_requests: Option<usize>,
#[serde(default)]
pending_request_timeout: Option<u64>,
/// Whether this extension is bundled with Goose /// Whether this extension is bundled with Goose
#[serde(default)] #[serde(default)]
bundled: Option<bool>, bundled: Option<bool>,
@@ -189,6 +201,8 @@ impl Default for ExtensionConfig {
name: config::DEFAULT_EXTENSION.to_string(), name: config::DEFAULT_EXTENSION.to_string(),
display_name: Some(config::DEFAULT_DISPLAY_NAME.to_string()), display_name: Some(config::DEFAULT_DISPLAY_NAME.to_string()),
timeout: Some(config::DEFAULT_EXTENSION_TIMEOUT), timeout: Some(config::DEFAULT_EXTENSION_TIMEOUT),
max_pending_requests: Some(config::DEFAULT_MAX_PENDING_REQUESTS),
pending_request_timeout: Some(config::DEFAULT_PENDING_REQUEST_TIMEOUT),
bundled: Some(true), bundled: Some(true),
} }
} }
@@ -203,6 +217,8 @@ impl ExtensionConfig {
env_keys: Vec::new(), env_keys: Vec::new(),
description: Some(description.into()), description: Some(description.into()),
timeout: Some(timeout.into()), timeout: Some(timeout.into()),
max_pending_requests: Some(config::DEFAULT_MAX_PENDING_REQUESTS),
pending_request_timeout: Some(config::DEFAULT_PENDING_REQUEST_TIMEOUT),
bundled: None, bundled: None,
} }
} }
@@ -221,6 +237,8 @@ impl ExtensionConfig {
env_keys: Vec::new(), env_keys: Vec::new(),
description: Some(description.into()), description: Some(description.into()),
timeout: Some(timeout.into()), timeout: Some(timeout.into()),
max_pending_requests: Some(config::DEFAULT_MAX_PENDING_REQUESTS),
pending_request_timeout: Some(config::DEFAULT_PENDING_REQUEST_TIMEOUT),
bundled: None, bundled: None,
} }
} }
@@ -237,6 +255,8 @@ impl ExtensionConfig {
envs, envs,
env_keys, env_keys,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
description, description,
bundled, bundled,
.. ..
@@ -248,6 +268,8 @@ impl ExtensionConfig {
args: args.into_iter().map(Into::into).collect(), args: args.into_iter().map(Into::into).collect(),
description, description,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
bundled, bundled,
}, },
other => other, other => other,

View File

@@ -178,10 +178,17 @@ impl ExtensionManager {
envs, envs,
env_keys, env_keys,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
.. ..
} => { } => {
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
let transport = SseTransport::new(uri, all_envs); let transport = SseTransport::new(
uri,
all_envs,
*max_pending_requests,
pending_request_timeout.map(Duration::from_secs),
);
let handle = transport.start().await?; let handle = transport.start().await?;
let service = McpService::with_timeout( let service = McpService::with_timeout(
handle, handle,
@@ -197,10 +204,18 @@ impl ExtensionManager {
envs, envs,
env_keys, env_keys,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
.. ..
} => { } => {
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
let transport = StdioTransport::new(cmd, args.to_vec(), all_envs); let transport = StdioTransport::new(
cmd,
args.to_vec(),
all_envs,
*max_pending_requests,
pending_request_timeout.map(Duration::from_secs),
);
let handle = transport.start().await?; let handle = transport.start().await?;
let service = McpService::with_timeout( let service = McpService::with_timeout(
handle, handle,
@@ -214,6 +229,8 @@ impl ExtensionManager {
name, name,
display_name: _, display_name: _,
timeout, timeout,
max_pending_requests,
pending_request_timeout,
bundled: _, bundled: _,
} => { } => {
let cmd = std::env::current_exe() let cmd = std::env::current_exe()
@@ -225,6 +242,8 @@ impl ExtensionManager {
&cmd, &cmd,
vec!["mcp".to_string(), name.clone()], vec!["mcp".to_string(), name.clone()],
HashMap::new(), HashMap::new(),
*max_pending_requests,
pending_request_timeout.map(Duration::from_secs),
); );
let handle = transport.start().await?; let handle = transport.start().await?;
let service = McpService::with_timeout( let service = McpService::with_timeout(

View File

@@ -7,6 +7,8 @@ use utoipa::ToSchema;
pub const DEFAULT_EXTENSION: &str = "developer"; pub const DEFAULT_EXTENSION: &str = "developer";
pub const DEFAULT_EXTENSION_TIMEOUT: u64 = 300; pub const DEFAULT_EXTENSION_TIMEOUT: u64 = 300;
pub const DEFAULT_MAX_PENDING_REQUESTS: usize = 128;
pub const DEFAULT_PENDING_REQUEST_TIMEOUT: u64 = 30;
pub const DEFAULT_EXTENSION_DESCRIPTION: &str = ""; pub const DEFAULT_EXTENSION_DESCRIPTION: &str = "";
pub const DEFAULT_DISPLAY_NAME: &str = "Developer"; pub const DEFAULT_DISPLAY_NAME: &str = "Developer";
@@ -45,6 +47,8 @@ impl ExtensionConfigManager {
name: DEFAULT_EXTENSION.to_string(), name: DEFAULT_EXTENSION.to_string(),
display_name: Some(DEFAULT_DISPLAY_NAME.to_string()), display_name: Some(DEFAULT_DISPLAY_NAME.to_string()),
timeout: Some(DEFAULT_EXTENSION_TIMEOUT), timeout: Some(DEFAULT_EXTENSION_TIMEOUT),
max_pending_requests: Some(DEFAULT_MAX_PENDING_REQUESTS),
pending_request_timeout: Some(DEFAULT_PENDING_REQUEST_TIMEOUT),
bundled: Some(true), bundled: Some(true),
}, },
}, },

View File

@@ -13,3 +13,5 @@ pub use extensions::DEFAULT_DISPLAY_NAME;
pub use extensions::DEFAULT_EXTENSION; pub use extensions::DEFAULT_EXTENSION;
pub use extensions::DEFAULT_EXTENSION_DESCRIPTION; pub use extensions::DEFAULT_EXTENSION_DESCRIPTION;
pub use extensions::DEFAULT_EXTENSION_TIMEOUT; pub use extensions::DEFAULT_EXTENSION_TIMEOUT;
pub use extensions::DEFAULT_MAX_PENDING_REQUESTS;
pub use extensions::DEFAULT_PENDING_REQUEST_TIMEOUT;

View File

@@ -18,17 +18,34 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
) )
.init(); .init();
let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); let transport1 = StdioTransport::new(
"uvx",
vec!["mcp-server-git".to_string()],
HashMap::new(),
None,
None,
);
let handle1 = transport1.start().await?; let handle1 = transport1.start().await?;
let service1 = McpService::with_timeout(handle1, Duration::from_secs(30)); let service1 = McpService::with_timeout(handle1, Duration::from_secs(30));
let client1 = McpClient::new(service1); let client1 = McpClient::new(service1);
let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); let transport2 = StdioTransport::new(
"uvx",
vec!["mcp-server-git".to_string()],
HashMap::new(),
None,
None,
);
let handle2 = transport2.start().await?; let handle2 = transport2.start().await?;
let service2 = McpService::with_timeout(handle2, Duration::from_secs(30)); let service2 = McpService::with_timeout(handle2, Duration::from_secs(30));
let client2 = McpClient::new(service2); let client2 = McpClient::new(service2);
let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new()); let transport3 = SseTransport::new(
"http://localhost:8000/sse",
HashMap::new(),
None,
None,
);
let handle3 = transport3.start().await?; let handle3 = transport3.start().await?;
let service3 = McpService::with_timeout(handle3, Duration::from_secs(10)); let service3 = McpService::with_timeout(handle3, Duration::from_secs(10));
let client3 = McpClient::new(service3); let client3 = McpClient::new(service3);

View File

@@ -18,7 +18,12 @@ async fn main() -> Result<()> {
.init(); .init();
// Create the base transport // Create the base transport
let transport = SseTransport::new("http://localhost:8000/sse", HashMap::new()); let transport = SseTransport::new(
"http://localhost:8000/sse",
HashMap::new(),
None,
None,
);
// Start transport // Start transport
let handle = transport.start().await?; let handle = transport.start().await?;

View File

@@ -20,7 +20,13 @@ async fn main() -> Result<(), ClientError> {
.init(); .init();
// 1) Create the transport // 1) Create the transport
let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); let transport = StdioTransport::new(
"uvx",
vec!["mcp-server-git".to_string()],
HashMap::new(),
None,
None,
);
// 2) Start the transport to get a handle // 2) Start the transport to get a handle
let transport_handle = transport.start().await?; let transport_handle = transport.start().await?;

View File

@@ -29,6 +29,8 @@ async fn main() -> Result<(), ClientError> {
.map(|s| s.to_string()) .map(|s| s.to_string())
.collect(), .collect(),
HashMap::new(), HashMap::new(),
None,
None,
); );
// Start the transport to get a handle // Start the transport to get a handle

View File

@@ -1,6 +1,7 @@
use async_trait::async_trait; use async_trait::async_trait;
use mcp_core::protocol::JsonRpcMessage; use mcp_core::protocol::JsonRpcMessage;
use std::collections::HashMap; use std::collections::HashMap;
use std::time::{Duration, Instant};
use thiserror::Error; use thiserror::Error;
use tokio::sync::{mpsc, oneshot, RwLock}; use tokio::sync::{mpsc, oneshot, RwLock};
@@ -31,6 +32,9 @@ pub enum Error {
#[error("HTTP error: {status} - {message}")] #[error("HTTP error: {status} - {message}")]
HttpError { status: u16, message: String }, HttpError { status: u16, message: String },
#[error("Too many pending requests")]
PendingRequestsFull,
} }
/// A message that can be sent through the transport /// A message that can be sent through the transport
@@ -89,7 +93,10 @@ pub async fn send_message(
// A data structure to store pending requests and their response channels // A data structure to store pending requests and their response channels
pub struct PendingRequests { pub struct PendingRequests {
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>, requests:
RwLock<HashMap<String, (oneshot::Sender<Result<JsonRpcMessage, Error>>, Instant)>>,
max_size: Option<usize>,
timeout: Option<Duration>,
} }
impl Default for PendingRequests { impl Default for PendingRequests {
@@ -100,21 +107,49 @@ impl Default for PendingRequests {
impl PendingRequests { impl PendingRequests {
pub fn new() -> Self { pub fn new() -> Self {
Self::with_limits(None, None)
}
pub fn with_limits(max_size: Option<usize>, timeout: Option<Duration>) -> Self {
Self { Self {
requests: RwLock::new(HashMap::new()), requests: RwLock::new(HashMap::new()),
max_size,
timeout,
} }
} }
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) { pub async fn insert(
self.requests.write().await.insert(id, sender); &self,
id: String,
sender: oneshot::Sender<Result<JsonRpcMessage, Error>>,
) -> Result<(), Error> {
self.cleanup().await;
let mut map = self.requests.write().await;
if let Some(max) = self.max_size {
if map.len() >= max {
return Err(Error::PendingRequestsFull);
}
}
map.insert(id, (sender, Instant::now()));
Ok(())
} }
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) { pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
if let Some(tx) = self.requests.write().await.remove(id) { if let Some((tx, _)) = self.requests.write().await.remove(id) {
let _ = tx.send(response); let _ = tx.send(response);
} }
} }
pub async fn cleanup(&self) {
if let Some(timeout) = self.timeout {
let now = Instant::now();
self.requests
.write()
.await
.retain(|_, (_, ts)| now.duration_since(*ts) <= timeout);
}
}
pub async fn clear(&self) { pub async fn clear(&self) {
self.requests.write().await.clear(); self.requests.write().await.clear();
} }

View File

@@ -179,7 +179,14 @@ impl SseActor {
if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) = if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) =
&transport_msg.message &transport_msg.message
{ {
pending_requests.insert(id.to_string(), response_tx).await; pending_requests.cleanup().await;
if let Err(e) = pending_requests
.insert(id.to_string(), response_tx)
.await
{
let _ = response_tx.send(Err(e));
continue;
}
} }
} }
@@ -236,14 +243,23 @@ impl TransportHandle for SseTransportHandle {
pub struct SseTransport { pub struct SseTransport {
sse_url: String, sse_url: String,
env: HashMap<String, String>, env: HashMap<String, String>,
max_pending: Option<usize>,
pending_timeout: Option<Duration>,
} }
/// The SSE transport spawns an `SseActor` on `start()`. /// The SSE transport spawns an `SseActor` on `start()`.
impl SseTransport { impl SseTransport {
pub fn new<S: Into<String>>(sse_url: S, env: HashMap<String, String>) -> Self { pub fn new<S: Into<String>>(
sse_url: S,
env: HashMap<String, String>,
max_pending: Option<usize>,
pending_timeout: Option<Duration>,
) -> Self {
Self { Self {
sse_url: sse_url.into(), sse_url: sse_url.into(),
env, env,
max_pending,
pending_timeout,
} }
} }
@@ -286,7 +302,10 @@ impl Transport for SseTransport {
// Build the actor // Build the actor
let actor = SseActor::new( let actor = SseActor::new(
rx, rx,
Arc::new(PendingRequests::new()), Arc::new(PendingRequests::with_limits(
self.max_pending,
self.pending_timeout,
)),
self.sse_url.clone(), self.sse_url.clone(),
post_endpoint, post_endpoint,
); );

View File

@@ -1,4 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration;
use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::Arc; use std::sync::Arc;
use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
@@ -162,7 +163,14 @@ impl StdioActor {
if let Some(response_tx) = transport_msg.response_tx.take() { if let Some(response_tx) = transport_msg.response_tx.take() {
if let JsonRpcMessage::Request(request) = &transport_msg.message { if let JsonRpcMessage::Request(request) = &transport_msg.message {
if let Some(id) = &request.id { if let Some(id) = &request.id {
pending_requests.insert(id.to_string(), response_tx).await; pending_requests.cleanup().await;
if let Err(e) = pending_requests
.insert(id.to_string(), response_tx)
.await
{
let _ = response_tx.send(Err(e));
continue;
}
} }
} }
} }
@@ -218,6 +226,8 @@ pub struct StdioTransport {
command: String, command: String,
args: Vec<String>, args: Vec<String>,
env: HashMap<String, String>, env: HashMap<String, String>,
max_pending: Option<usize>,
pending_timeout: Option<Duration>,
} }
impl StdioTransport { impl StdioTransport {
@@ -225,11 +235,15 @@ impl StdioTransport {
command: S, command: S,
args: Vec<String>, args: Vec<String>,
env: HashMap<String, String>, env: HashMap<String, String>,
max_pending: Option<usize>,
pending_timeout: Option<Duration>,
) -> Self { ) -> Self {
Self { Self {
command: command.into(), command: command.into(),
args, args,
env, env,
max_pending,
pending_timeout,
} }
} }
@@ -294,7 +308,10 @@ impl Transport for StdioTransport {
let actor = StdioActor { let actor = StdioActor {
receiver: Some(message_rx), receiver: Some(message_rx),
pending_requests: Arc::new(PendingRequests::new()), pending_requests: Arc::new(PendingRequests::with_limits(
self.max_pending,
self.pending_timeout,
)),
process, process,
error_sender: error_tx, error_sender: error_tx,
stdin: Some(stdin), stdin: Some(stdin),