From 0ef38c665828ec7707b073bdfcbdb6f77d8afbb3 Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Fri, 25 Jul 2025 14:04:18 -0400 Subject: [PATCH] chore: use typed notifications from rmcp (#3653) --- crates/goose-cli/src/session/mod.rs | 206 +++++++++--------- crates/goose-cli/src/session/output.rs | 6 +- crates/goose-mcp/src/developer/mod.rs | 2 + crates/goose-server/src/routes/reply.rs | 4 +- crates/goose/src/agents/agent.rs | 11 +- crates/goose/src/agents/extension_manager.rs | 4 +- .../subagent_execution_tool/executor/mod.rs | 6 +- .../agents/subagent_execution_tool/lib/mod.rs | 4 +- .../subagent_execute_task_tool.rs | 5 +- .../task_execution_tracker.rs | 29 +-- crates/goose/src/agents/tool_execution.rs | 4 +- crates/mcp-client/src/client.rs | 26 ++- crates/mcp-client/src/service.rs | 18 +- crates/mcp-client/src/transport/mod.rs | 17 +- crates/mcp-client/src/transport/sse.rs | 14 +- crates/mcp-client/src/transport/stdio.rs | 14 +- .../src/transport/streamable_http.rs | 20 +- 17 files changed, 198 insertions(+), 192 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 6ebc7954..d0940823 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -36,7 +36,8 @@ use goose::providers::pricing::initialize_pricing_cache; use goose::session; use input::InputResult; use mcp_core::handler::ToolError; -use rmcp::model::{JsonRpcMessage, JsonRpcNotification, Notification, PromptMessage}; +use rmcp::model::PromptMessage; +use rmcp::model::ServerNotification; use rand::{distributions::Alphanumeric, Rng}; use rustyline::EditMode; @@ -1023,126 +1024,115 @@ impl Session { } } Some(Ok(AgentEvent::McpNotification((_id, message)))) => { - if let JsonRpcMessage::Notification( JsonRpcNotification { - notification: Notification { - method, - params: o,.. - },.. - }) = message { - match method.as_str() { - "notifications/message" => { - let data = o.get("data").unwrap_or(&Value::Null); - let (formatted_message, subagent_id, message_notification_type) = match data { - Value::String(s) => (s.clone(), None, None), - Value::Object(o) => { - // Check for subagent notification structure first - if let Some(Value::String(msg)) = o.get("message") { - // Extract subagent info for better display - let subagent_id = o.get("subagent_id") - .and_then(|v| v.as_str()) - .unwrap_or("unknown"); - let notification_type = o.get("type") - .and_then(|v| v.as_str()) - .unwrap_or(""); + match &message { + ServerNotification::LoggingMessageNotification(notification) => { + let data = ¬ification.params.data; + let (formatted_message, subagent_id, message_notification_type) = match data { + Value::String(s) => (s.clone(), None, None), + Value::Object(o) => { + // Check for subagent notification structure first + if let Some(Value::String(msg)) = o.get("message") { + // Extract subagent info for better display + let subagent_id = o.get("subagent_id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let notification_type = o.get("type") + .and_then(|v| v.as_str()) + .unwrap_or(""); - let formatted = match notification_type { - "subagent_created" | "completed" | "terminated" => { - format!("🤖 {}", msg) - } - "tool_usage" | "tool_completed" | "tool_error" => { - format!("🔧 {}", msg) - } - "message_processing" | "turn_progress" => { - format!("💭 {}", msg) - } - "response_generated" => { - // Check verbosity setting for subagent response content - let config = Config::global(); - let min_priority = config - .get_param::("GOOSE_CLI_MIN_PRIORITY") - .ok() - .unwrap_or(0.5); + let formatted = match notification_type { + "subagent_created" | "completed" | "terminated" => { + format!("🤖 {}", msg) + } + "tool_usage" | "tool_completed" | "tool_error" => { + format!("🔧 {}", msg) + } + "message_processing" | "turn_progress" => { + format!("💭 {}", msg) + } + "response_generated" => { + // Check verbosity setting for subagent response content + let config = Config::global(); + let min_priority = config + .get_param::("GOOSE_CLI_MIN_PRIORITY") + .ok() + .unwrap_or(0.5); - if min_priority > 0.1 && !self.debug { - // High/Medium verbosity: show truncated response - if let Some(response_content) = msg.strip_prefix("Responded: ") { - format!("🤖 Responded: {}", safe_truncate(response_content, 100)) - } else { - format!("🤖 {}", msg) - } + if min_priority > 0.1 && !self.debug { + // High/Medium verbosity: show truncated response + if let Some(response_content) = msg.strip_prefix("Responded: ") { + format!("🤖 Responded: {}", safe_truncate(response_content, 100)) } else { - // All verbosity or debug: show full response format!("🤖 {}", msg) } + } else { + // All verbosity or debug: show full response + format!("🤖 {}", msg) } - _ => { - msg.to_string() - } - }; - (formatted, Some(subagent_id.to_string()), Some(notification_type.to_string())) - } else if let Some(Value::String(output)) = o.get("output") { - // Fallback for other MCP notification types - (output.to_owned(), None, None) - } else if let Some(result) = format_task_execution_notification(data) { - result - } else { - (data.to_string(), None, None) - } - }, - v => { - (v.to_string(), None, None) - }, - }; + } + _ => { + msg.to_string() + } + }; + (formatted, Some(subagent_id.to_string()), Some(notification_type.to_string())) + } else if let Some(Value::String(output)) = o.get("output") { + // Fallback for other MCP notification types + (output.to_owned(), None, None) + } else if let Some(result) = format_task_execution_notification(data) { + result + } else { + (data.to_string(), None, None) + } + }, + v => { + (v.to_string(), None, None) + }, + }; - // Handle subagent notifications - show immediately - if let Some(_id) = subagent_id { - // TODO: proper display for subagent notifications + // Handle subagent notifications - show immediately + if let Some(_id) = subagent_id { + // TODO: proper display for subagent notifications + if interactive { + let _ = progress_bars.hide(); + println!("{}", console::style(&formatted_message).green().dim()); + } else { + progress_bars.log(&formatted_message); + } + } else if let Some(ref notification_type) = message_notification_type { + if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE { if interactive { let _ = progress_bars.hide(); - println!("{}", console::style(&formatted_message).green().dim()); + print!("{}", formatted_message); + std::io::stdout().flush().unwrap(); } else { - progress_bars.log(&formatted_message); - } - } else if let Some(ref notification_type) = message_notification_type { - if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE { - if interactive { - let _ = progress_bars.hide(); - print!("{}", formatted_message); - std::io::stdout().flush().unwrap(); - } else { - print!("{}", formatted_message); - std::io::stdout().flush().unwrap(); - } + print!("{}", formatted_message); + std::io::stdout().flush().unwrap(); } } - else { - // Non-subagent notification, display immediately with compact spacing - if interactive { - let _ = progress_bars.hide(); - println!("{}", console::style(&formatted_message).green().dim()); - } else { - progress_bars.log(&formatted_message); - } + } + else { + // Non-subagent notification, display immediately with compact spacing + if interactive { + let _ = progress_bars.hide(); + println!("{}", console::style(&formatted_message).green().dim()); + } else { + progress_bars.log(&formatted_message); } - }, - "notifications/progress" => { - let progress = o.get("progress").and_then(|v| v.as_f64()); - let token = o.get("progressToken").map(|v| v.to_string()); - let message = o.get("message").and_then(|v| v.as_str()); - let total = o - .get("total") - .and_then(|v| v.as_f64()); - if let (Some(progress), Some(token)) = (progress, token) { - progress_bars.update( - token.as_str(), - progress, - total, - message, - ); - } - }, - _ => (), - } + } + }, + ServerNotification::ProgressNotification(notification) => { + let progress = notification.params.progress; + let text = notification.params.message.as_deref(); + let total = notification.params.total; + let token = ¬ification.params.progress_token; + progress_bars.update( + &token.0.to_string(), + progress, + total, + text, + ); + }, + _ => (), } } Some(Ok(AgentEvent::ModelChange { model, mode })) => { diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index c6265f74..b72b5a11 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -799,11 +799,11 @@ impl McpSpinners { spinner.set_message(message.to_string()); } - pub fn update(&mut self, token: &str, value: f64, total: Option, message: Option<&str>) { + pub fn update(&mut self, token: &str, value: u32, total: Option, message: Option<&str>) { let bar = self.bars.entry(token.to_string()).or_insert_with(|| { if let Some(total) = total { self.multi_bar.add( - ProgressBar::new((total * 100.0) as u64).with_style( + ProgressBar::new((total * 100) as u64).with_style( ProgressStyle::with_template("[{elapsed}] {bar:40} {pos:>3}/{len:3} {msg}") .unwrap(), ), @@ -812,7 +812,7 @@ impl McpSpinners { self.multi_bar.add(ProgressBar::new_spinner()) } }); - bar.set_position((value * 100.0) as u64); + bar.set_position((value * 100) as u64); if let Some(msg) = message { bar.set_message(msg.to_string()); } diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index 814a16b9..f77749c0 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -672,6 +672,7 @@ impl DeveloperRouter { notification: Notification { method: "notifications/message".to_string(), params: object!({ + "level": "info", "data": { "type": "shell", "stream": "stdout", @@ -698,6 +699,7 @@ impl DeveloperRouter { notification: Notification { method: "notifications/message".to_string(), params: object!({ + "level": "info", "data": { "type": "shell", "stream": "stderr", diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 5e173314..12cfae3a 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -19,7 +19,7 @@ use goose::{ session, }; use mcp_core::ToolResult; -use rmcp::model::{Content, JsonRpcMessage}; +use rmcp::model::{Content, ServerNotification}; use serde::{Deserialize, Serialize}; use serde_json::json; use serde_json::Value; @@ -97,7 +97,7 @@ enum MessageEvent { }, Notification { request_id: String, - message: JsonRpcMessage, + message: ServerNotification, }, } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 6f3b83e4..e4ee8e9d 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -45,8 +45,7 @@ use crate::tool_monitor::{ToolCall, ToolMonitor}; use crate::utils::is_token_cancelled; use mcp_core::{ToolError, ToolResult}; use regex::Regex; -use rmcp::model::Tool; -use rmcp::model::{Content, GetPromptResult, JsonRpcMessage, Prompt}; +use rmcp::model::{Content, GetPromptResult, Prompt, ServerNotification, Tool}; use serde_json::Value; use tokio::sync::{mpsc, Mutex, RwLock}; use tokio_util::sync::CancellationToken; @@ -83,7 +82,7 @@ pub struct Agent { #[derive(Clone, Debug)] pub enum AgentEvent { Message(Message), - McpNotification((String, JsonRpcMessage)), + McpNotification((String, ServerNotification)), ModelChange { model: String, mode: String }, } @@ -94,19 +93,19 @@ impl Default for Agent { } pub enum ToolStreamItem { - Message(JsonRpcMessage), + Message(ServerNotification), Result(T), } pub type ToolStream = Pin>>> + Send>>; -// tool_stream combines a stream of JsonRpcMessages with a future representing the +// tool_stream combines a stream of ServerNotifications with a future representing the // final result of the tool call. MCP notifications are not request-scoped, but // this lets us capture all notifications emitted during the tool call for // simpler consumption pub fn tool_stream(rx: S, done: F) -> ToolStream where - S: Stream + Send + Unpin + 'static, + S: Stream + Send + Unpin + 'static, F: Future>> + Send + 'static, { Box::pin(async_stream::stream! { diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 1898dcda..a3b89a07 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -835,7 +835,7 @@ mod tests { CallToolResult, InitializeResult, ListPromptsResult, ListResourcesResult, ListToolsResult, ReadResourceResult, }; - use rmcp::model::{GetPromptResult, JsonRpcMessage}; + use rmcp::model::{GetPromptResult, ServerNotification}; use serde_json::json; use tokio::sync::mpsc; @@ -891,7 +891,7 @@ mod tests { Err(Error::NotInitialized) } - async fn subscribe(&self) -> mpsc::Receiver { + async fn subscribe(&self) -> mpsc::Receiver { mpsc::channel(1).1 } } diff --git a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs index 7aaa14c7..daa00a96 100644 --- a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs @@ -7,7 +7,7 @@ use crate::agents::subagent_execution_tool::task_execution_tracker::{ use crate::agents::subagent_execution_tool::tasks::process_task; use crate::agents::subagent_execution_tool::workers::spawn_worker; use crate::agents::subagent_task_config::TaskConfig; -use rmcp::model::JsonRpcMessage; +use rmcp::model::ServerNotification; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tokio::sync::mpsc; @@ -20,7 +20,7 @@ const DEFAULT_MAX_WORKERS: usize = 10; pub async fn execute_single_task( task: &Task, - notifier: mpsc::Sender, + notifier: mpsc::Sender, task_config: TaskConfig, cancellation_token: Option, ) -> ExecutionResponse { @@ -56,7 +56,7 @@ pub async fn execute_single_task( pub async fn execute_tasks_in_parallel( tasks: Vec, - notifier: Sender, + notifier: Sender, task_config: TaskConfig, cancellation_token: Option, ) -> ExecutionResponse { diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index 172ad03c..695d8ddb 100644 --- a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -6,7 +6,7 @@ use crate::agents::subagent_execution_tool::{ tasks_manager::TasksManager, }; use crate::agents::subagent_task_config::TaskConfig; -use rmcp::model::JsonRpcMessage; +use rmcp::model::ServerNotification; use serde_json::{json, Value}; use tokio::sync::mpsc::Sender; use tokio_util::sync::CancellationToken; @@ -14,7 +14,7 @@ use tokio_util::sync::CancellationToken; pub async fn execute_tasks( input: Value, execution_mode: ExecutionMode, - notifier: Sender, + notifier: Sender, task_config: TaskConfig, tasks_manager: &TasksManager, cancellation_token: Option, diff --git a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs index 2527d558..430d3942 100644 --- a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs +++ b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs @@ -1,5 +1,5 @@ use mcp_core::ToolError; -use rmcp::model::{Content, Tool, ToolAnnotations}; +use rmcp::model::{Content, ServerNotification, Tool, ToolAnnotations}; use serde_json::Value; use crate::agents::subagent_task_config::TaskConfig; @@ -8,7 +8,6 @@ use crate::agents::{ subagent_execution_tool::task_types::ExecutionMode, subagent_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, }; -use rmcp::model::JsonRpcMessage; use rmcp::object; use tokio::sync::mpsc; use tokio_stream; @@ -67,7 +66,7 @@ pub async fn run_tasks( tasks_manager: &TasksManager, cancellation_token: Option, ) -> ToolCallResult { - let (notification_tx, notification_rx) = mpsc::channel::(100); + let (notification_tx, notification_rx) = mpsc::channel::(100); let tasks_manager_clone = tasks_manager.clone(); let result_future = async move { diff --git a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs index 908bb465..18157eec 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs @@ -1,5 +1,7 @@ -use rmcp::model::{JsonRpcMessage, JsonRpcNotification, JsonRpcVersion2_0, Notification}; -use rmcp::object; +use rmcp::model::{ + LoggingLevel, LoggingMessageNotification, LoggingMessageNotificationMethod, + LoggingMessageNotificationParam, ServerNotification, +}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; @@ -52,7 +54,7 @@ fn format_task_metadata(task_info: &TaskInfo) -> String { pub struct TaskExecutionTracker { tasks: Arc>>, last_refresh: Arc>, - notifier: mpsc::Sender, + notifier: mpsc::Sender, display_mode: DisplayMode, cancellation_token: Option, } @@ -61,7 +63,7 @@ impl TaskExecutionTracker { pub fn new( tasks: Vec, display_mode: DisplayMode, - notifier: Sender, + notifier: Sender, cancellation_token: Option, ) -> Self { let task_map = tasks @@ -97,7 +99,7 @@ impl TaskExecutionTracker { fn log_notification_error( &self, - error: &mpsc::error::TrySendError, + error: &mpsc::error::TrySendError, context: &str, ) { if !self.is_cancelled() { @@ -108,16 +110,17 @@ impl TaskExecutionTracker { fn try_send_notification(&self, event: TaskExecutionNotificationEvent, context: &str) { if let Err(e) = self .notifier - .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: JsonRpcVersion2_0, - notification: Notification { - method: "notifications/message".to_string(), - params: object!({ - "data": event.to_notification_data() - }), + .try_send(ServerNotification::LoggingMessageNotification( + LoggingMessageNotification { + method: LoggingMessageNotificationMethod, + params: LoggingMessageNotificationParam { + data: event.to_notification_data(), + level: LoggingLevel::Info, + logger: None, + }, extensions: Default::default(), }, - })) + )) { self.log_notification_error(&e, context); } diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index bc9f4292..9be0d9b8 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use async_stream::try_stream; use futures::stream::{self, BoxStream}; use futures::{Stream, StreamExt}; -use rmcp::model::JsonRpcMessage; +use rmcp::model::ServerNotification; use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; @@ -19,7 +19,7 @@ use rmcp::model::Content; // can be used to receive notifications from the tool. pub struct ToolCallResult { pub result: Box>> + Send + Unpin>, - pub notification_stream: Option + Send + Unpin>>, + pub notification_stream: Option + Send + Unpin>>, } impl From>> for ToolCallResult { diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 59f64a70..2c2de1f4 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -5,6 +5,7 @@ use mcp_core::protocol::{ use rmcp::model::{ GetPromptResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, JsonRpcVersion2_0, Notification, NumberOrString, Request, RequestId, + ServerNotification, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -106,7 +107,7 @@ pub trait McpClientTrait: Send + Sync { async fn get_prompt(&self, name: &str, arguments: Value) -> Result; - async fn subscribe(&self) -> mpsc::Receiver; + async fn subscribe(&self) -> mpsc::Receiver; } /// The MCP client is the interface for MCP operations. @@ -118,7 +119,7 @@ where next_id_counter: AtomicU64, // Added for atomic ID generation server_capabilities: Option, server_info: Option, - notification_subscribers: Arc>>>, + notification_subscribers: Arc>>>, } impl McpClient @@ -129,7 +130,7 @@ where let service = McpService::new(transport.clone()); let service_ptr = service.clone(); let notification_subscribers = - Arc::new(Mutex::new(Vec::>::new())); + Arc::new(Mutex::new(Vec::>::new())); let subscribers_ptr = notification_subscribers.clone(); tokio::spawn(async move { @@ -148,9 +149,22 @@ where }) => { service_ptr.respond(&id.to_string(), Ok(message)).await; } - _ => { + JsonRpcMessage::Notification(JsonRpcNotification { + notification, + .. + }) => { let mut subs = subscribers_ptr.lock().await; - subs.retain(|sub| sub.try_send(message.clone()).is_ok()); + if let Some(server_notification) = notification.into() { + subs.retain(|sub| { + sub.try_send(server_notification.clone()).is_ok() + }); + } + } + _ => { + tracing::warn!( + "Received unexpected received message type: {:?}", + message + ); } } } @@ -437,7 +451,7 @@ where self.send_request("prompts/get", params).await } - async fn subscribe(&self) -> mpsc::Receiver { + async fn subscribe(&self) -> mpsc::Receiver { let (tx, rx) = mpsc::channel(16); self.notification_subscribers.lock().await.push(tx); rx diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 0bdc680c..0a995c5d 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -6,7 +6,7 @@ use std::task::{Context, Poll}; use tokio::sync::{oneshot, RwLock}; use tower::{timeout::Timeout, Service, ServiceBuilder}; -use crate::transport::{Error, TransportHandle}; +use crate::transport::{Error, TransportHandle, TransportMessageRecv}; /// A wrapper service that implements Tower's Service trait for MCP transport #[derive(Clone)] @@ -23,7 +23,7 @@ impl McpService { } } - pub async fn respond(&self, id: &str, response: Result) { + pub async fn respond(&self, id: &str, response: Result) { self.pending_requests.respond(id, response).await } @@ -36,7 +36,7 @@ impl Service for McpService where T: TransportHandle + Send + Sync + 'static, { - type Response = JsonRpcMessage; + type Response = TransportMessageRecv; type Error = Error; type Future = BoxFuture<'static, Result>; @@ -63,7 +63,7 @@ where // Handle notifications without waiting for a response transport.send(request).await?; // Return a dummy response for notifications - let dummy_response: JsonRpcMessage = + let dummy_response: Self::Response = JsonRpcMessage::Response(rmcp::model::JsonRpcResponse { jsonrpc: rmcp::model::JsonRpcVersion2_0, id: rmcp::model::RequestId::Number(0), @@ -91,7 +91,7 @@ where // A data structure to store pending requests and their response channels pub struct PendingRequests { - requests: RwLock>>>, + requests: RwLock>>>, } impl Default for PendingRequests { @@ -107,11 +107,15 @@ impl PendingRequests { } } - pub async fn insert(&self, id: String, sender: oneshot::Sender>) { + pub async fn insert( + &self, + id: String, + sender: oneshot::Sender>, + ) { self.requests.write().await.insert(id, sender); } - pub async fn respond(&self, id: &str, response: Result) { + pub async fn respond(&self, id: &str, response: Result) { if let Some(tx) = self.requests.write().await.remove(id) { let _ = tx.send(response); } diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index 1a4c8e1b..c36cb079 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; -use rmcp::model::JsonRpcMessage; +use rmcp::model::{JsonObject, JsonRpcMessage, Request, ServerNotification}; use thiserror::Error; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::mpsc; pub type BoxError = Box; /// A generic error type for transport operations. @@ -38,15 +38,6 @@ pub enum Error { SessionError(String), } -/// A message that can be sent through the transport -#[derive(Debug)] -pub struct TransportMessage { - /// The JSON-RPC message to send - pub message: JsonRpcMessage, - /// Channel to receive the response on (None for notifications) - pub response_tx: Option>>, -} - /// A generic asynchronous transport trait with channel-based communication #[async_trait] pub trait Transport { @@ -60,10 +51,12 @@ pub trait Transport { async fn close(&self) -> Result<(), Error>; } +pub type TransportMessageRecv = JsonRpcMessage; + #[async_trait] pub trait TransportHandle: Send + Sync + Clone + 'static { async fn send(&self, message: JsonRpcMessage) -> Result<(), Error>; - async fn receive(&self) -> Result; + async fn receive(&self) -> Result; } pub async fn serialize_and_send( diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 56a00fcf..6ef2f7d8 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -1,4 +1,4 @@ -use crate::transport::Error; +use crate::transport::{Error, TransportMessageRecv}; use async_trait::async_trait; use eventsource_client::{Client, SSE}; use futures::TryStreamExt; @@ -23,7 +23,7 @@ pub struct SseActor { /// Receives messages (requests/notifications) from the handle receiver: mpsc::Receiver, /// Sends messages (responses) back to the handle - sender: mpsc::Sender, + sender: mpsc::Sender, /// Base SSE URL sse_url: String, /// For sending HTTP POST requests @@ -35,7 +35,7 @@ pub struct SseActor { impl SseActor { pub fn new( receiver: mpsc::Receiver, - sender: mpsc::Sender, + sender: mpsc::Sender, sse_url: String, post_endpoint: Arc>>, ) -> Self { @@ -71,7 +71,7 @@ impl SseActor { /// - If a `message` event is received, parse it as `JsonRpcMessage` /// and respond to pending requests if it's a `Response`. async fn handle_incoming_messages( - sender: mpsc::Sender, + sender: mpsc::Sender, sse_url: String, post_endpoint: Arc>>, ) { @@ -109,7 +109,7 @@ impl SseActor { match event { SSE::Event(e) if e.event_type == "message" => { // Attempt to parse the SSE data as a JsonRpcMessage - match serde_json::from_str::(&e.data) { + match serde_json::from_str::(&e.data) { Ok(message) => { let _ = sender.send(message).await; } @@ -184,7 +184,7 @@ impl SseActor { #[derive(Clone)] pub struct SseTransportHandle { sender: mpsc::Sender, - receiver: Arc>>, + receiver: Arc>>, } #[async_trait::async_trait] @@ -193,7 +193,7 @@ impl TransportHandle for SseTransportHandle { serialize_and_send(&self.sender, message).await } - async fn receive(&self) -> Result { + async fn receive(&self) -> Result { let mut receiver = self.receiver.lock().await; receiver.recv().await.ok_or(Error::ChannelClosed) } diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index e721f0e5..225f245b 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -14,6 +14,8 @@ use nix::sys::signal::{kill, Signal}; #[cfg(unix)] use nix::unistd::{getpgid, Pid}; +use crate::transport::TransportMessageRecv; + use super::{serialize_and_send, Error, Transport, TransportHandle}; // Global to track process groups we've created @@ -24,7 +26,7 @@ static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1); /// It uses channels for message passing and handles responses asynchronously through a background task. pub struct StdioActor { receiver: Option>, - sender: Option>, + sender: Option>, process: Child, // we store the process to keep it alive error_sender: mpsc::Sender, stdin: Option, @@ -98,7 +100,7 @@ impl StdioActor { } } - async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender) { + async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender) { let mut reader = BufReader::new(stdout); let mut line = String::new(); loop { @@ -108,7 +110,7 @@ impl StdioActor { break; } // EOF Ok(_) => { - if let Ok(message) = serde_json::from_str::(&line) { + if let Ok(message) = serde_json::from_str::(&line) { tracing::debug!( message = ?message, "Received incoming message" @@ -149,8 +151,8 @@ impl StdioActor { #[derive(Clone)] pub struct StdioTransportHandle { - sender: mpsc::Sender, // to process - receiver: Arc>>, // from process + sender: mpsc::Sender, // to process + receiver: Arc>>, // from process error_receiver: Arc>>, } @@ -163,7 +165,7 @@ impl TransportHandle for StdioTransportHandle { result } - async fn receive(&self) -> Result { + async fn receive(&self) -> Result { let mut receiver = self.receiver.lock().await; match receiver.recv().await { Some(message) => Ok(message), diff --git a/crates/mcp-client/src/transport/streamable_http.rs b/crates/mcp-client/src/transport/streamable_http.rs index 8f3380da..760d861f 100644 --- a/crates/mcp-client/src/transport/streamable_http.rs +++ b/crates/mcp-client/src/transport/streamable_http.rs @@ -1,5 +1,5 @@ use crate::oauth::{authenticate_service, ServiceConfig}; -use crate::transport::Error; +use crate::transport::{Error, TransportMessageRecv}; use async_trait::async_trait; use eventsource_client::{Client, SSE}; use futures::TryStreamExt; @@ -25,7 +25,7 @@ pub struct StreamableHttpActor { /// Receives messages (requests/notifications) from the handle receiver: mpsc::Receiver, /// Sends messages (responses) back to the handle - sender: mpsc::Sender, + sender: mpsc::Sender, /// MCP endpoint URL mcp_endpoint: String, /// HTTP client for sending requests @@ -41,7 +41,7 @@ pub struct StreamableHttpActor { impl StreamableHttpActor { pub fn new( receiver: mpsc::Receiver, - sender: mpsc::Sender, + sender: mpsc::Sender, mcp_endpoint: String, session_id: Arc>>, env: HashMap, @@ -84,8 +84,8 @@ impl StreamableHttpActor { debug!("Sending message to MCP endpoint: {}", message_str); // Parse the message to determine if it's a request that expects a response - let parsed_message: JsonRpcMessage = - serde_json::from_str(&message_str).map_err(Error::Serialization)?; + let parsed_message = serde_json::from_str::(&message_str) + .map_err(Error::Serialization)?; let expects_response = matches!( parsed_message, @@ -196,8 +196,8 @@ impl StreamableHttpActor { })?; if !response_text.is_empty() { - let json_message: JsonRpcMessage = - serde_json::from_str(&response_text).map_err(Error::Serialization)?; + let json_message = serde_json::from_str::(&response_text) + .map_err(Error::Serialization)?; let _ = self.sender.send(json_message).await; } @@ -267,7 +267,7 @@ impl StreamableHttpActor { // Empty line indicates end of event if !event_data.is_empty() { // Parse the streamed data as JSON-RPC message - match serde_json::from_str::(&event_data) { + match serde_json::from_str::(&event_data) { Ok(message) => { debug!("Received streaming HTTP response message: {:?}", message); let _ = self.sender.send(message).await; @@ -301,7 +301,7 @@ impl StreamableHttpActor { #[derive(Clone)] pub struct StreamableHttpTransportHandle { sender: mpsc::Sender, - receiver: Arc>>, + receiver: Arc>>, session_id: Arc>>, mcp_endpoint: String, http_client: HttpClient, @@ -314,7 +314,7 @@ impl TransportHandle for StreamableHttpTransportHandle { serialize_and_send(&self.sender, message).await } - async fn receive(&self) -> Result { + async fn receive(&self) -> Result { let mut receiver = self.receiver.lock().await; receiver.recv().await.ok_or(Error::ChannelClosed) }