From 03e5549b5466e754df15c4d5370c7514c628341f Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Fri, 30 May 2025 11:50:14 -0400 Subject: [PATCH] feat: Handle MCP server notification messages (#2613) Co-authored-by: Michael Neale --- Cargo.lock | 2 + crates/goose-cli/Cargo.toml | 1 + crates/goose-cli/src/session/mod.rs | 54 +++++- crates/goose-cli/src/session/output.rs | 68 ++++++- crates/goose-ffi/src/lib.rs | 7 +- .../goose-mcp/src/computercontroller/mod.rs | 5 +- crates/goose-mcp/src/developer/mod.rs | 117 ++++++++++-- crates/goose-mcp/src/developer/shell.rs | 8 +- crates/goose-mcp/src/google_drive/mod.rs | 3 + crates/goose-mcp/src/jetbrains/mod.rs | 5 +- crates/goose-mcp/src/memory/mod.rs | 4 +- crates/goose-mcp/src/tutorial/mod.rs | 4 +- crates/goose-server/src/routes/reply.rs | 42 ++++- crates/goose/Cargo.toml | 2 +- crates/goose/examples/agent.rs | 14 +- crates/goose/src/agents/agent.rs | 173 +++++++++++++----- crates/goose/src/agents/extension_manager.rs | 111 ++++++----- crates/goose/src/agents/mod.rs | 2 +- crates/goose/src/agents/tool_execution.rs | 44 +++-- crates/goose/src/scheduler.rs | 6 +- crates/goose/tests/agent.rs | 7 +- crates/mcp-client/examples/clients.rs | 10 +- .../mcp-client/examples/integration_test.rs | 122 ++++++++++++ crates/mcp-client/examples/sse.rs | 6 +- crates/mcp-client/examples/stdio.rs | 9 +- .../mcp-client/examples/stdio_integration.rs | 6 +- crates/mcp-client/src/client.rs | 97 +++++++--- crates/mcp-client/src/service.rs | 87 ++++++++- crates/mcp-client/src/transport/mod.rs | 77 ++------ crates/mcp-client/src/transport/sse.rs | 142 ++++++-------- crates/mcp-client/src/transport/stdio.rs | 91 ++++----- crates/mcp-server/src/lib.rs | 45 ++++- crates/mcp-server/src/main.rs | 3 + crates/mcp-server/src/router.rs | 38 ++-- ui/desktop/.env | 2 +- ui/desktop/src/components/ChatView.tsx | 13 ++ ui/desktop/src/components/GooseMessage.tsx | 4 + .../src/components/ToolCallWithResponse.tsx | 163 ++++++++++++++++- ui/desktop/src/hooks/useMessageStream.ts | 29 ++- ui/desktop/tailwind.config.ts | 6 + 40 files changed, 1186 insertions(+), 443 deletions(-) create mode 100644 crates/mcp-client/examples/integration_test.rs diff --git a/Cargo.lock b/Cargo.lock index 49cb9f55..5c8f7acf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3435,6 +3435,7 @@ dependencies = [ "tokenizers", "tokio", "tokio-cron-scheduler", + "tokio-stream", "tracing", "tracing-subscriber", "url", @@ -3486,6 +3487,7 @@ dependencies = [ "goose", "goose-bench", "goose-mcp", + "indicatif", "mcp-client", "mcp-core", "mcp-server", diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index dfe73e48..3a73c44b 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -55,6 +55,7 @@ regex = "1.11.1" minijinja = "2.8.0" nix = { version = "0.30.1", features = ["process", "signal"] } tar = "0.4" +indicatif = "0.17.11" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 273ec979..2695b366 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -7,6 +7,7 @@ mod thinking; pub use builder::{build_session, SessionBuilderConfig}; use console::Color; +use goose::agents::AgentEvent; use goose::permission::permission_confirmation::PrincipalType; use goose::permission::Permission; use goose::permission::PermissionConfirmation; @@ -26,6 +27,8 @@ use input::InputResult; use mcp_core::handler::ToolError; use mcp_core::prompt::PromptMessage; +use mcp_core::protocol::JsonRpcMessage; +use mcp_core::protocol::JsonRpcNotification; use rand::{distributions::Alphanumeric, Rng}; use serde_json::Value; use std::collections::HashMap; @@ -713,12 +716,15 @@ impl Session { ) .await?; + let mut progress_bars = output::McpSpinners::new(); + use futures::StreamExt; loop { tokio::select! { result = stream.next() => { + let _ = progress_bars.hide(); match result { - Some(Ok(message)) => { + Some(Ok(AgentEvent::Message(message))) => { // If it's a confirmation request, get approval but otherwise do not render/persist if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() { output::hide_thinking(); @@ -846,6 +852,51 @@ impl Session { if interactive {output::show_thinking()}; } } + Some(Ok(AgentEvent::McpNotification((_id, message)))) => { + if let JsonRpcMessage::Notification(JsonRpcNotification{ + method, + params: Some(Value::Object(o)), + .. + }) = message { + match method.as_str() { + "notifications/message" => { + let data = o.get("data").unwrap_or(&Value::Null); + let message = match data { + Value::String(s) => s.clone(), + Value::Object(o) => { + if let Some(Value::String(output)) = o.get("output") { + output.to_owned() + } else { + data.to_string() + } + }, + v => { + v.to_string() + }, + }; + // output::render_text_no_newlines(&message, None, true); + progress_bars.log(&message); + }, + "notifications/progress" => { + let progress = o.get("progress").and_then(|v| v.as_f64()); + let token = o.get("progressToken").map(|v| v.to_string()); + let message = o.get("message").and_then(|v| v.as_str()); + let total = o + .get("total") + .and_then(|v| v.as_f64()); + if let (Some(progress), Some(token)) = (progress, token) { + progress_bars.update( + token.as_str(), + progress, + total, + message, + ); + } + }, + _ => (), + } + } + } Some(Err(e)) => { eprintln!("Error: {}", e); drop(stream); @@ -872,6 +923,7 @@ impl Session { } } } + Ok(()) } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 873eec01..78b4eb76 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -2,12 +2,15 @@ use bat::WrappingMode; use console::{style, Color}; use goose::config::Config; use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use mcp_core::prompt::PromptArgument; use mcp_core::tool::ToolCall; use serde_json::Value; use std::cell::RefCell; use std::collections::HashMap; +use std::io::Error; use std::path::Path; +use std::time::Duration; // Re-export theme for use in main #[derive(Clone, Copy)] @@ -144,6 +147,10 @@ pub fn render_message(message: &Message, debug: bool) { } pub fn render_text(text: &str, color: Option, dim: bool) { + render_text_no_newlines(format!("\n{}\n\n", text).as_str(), color, dim); +} + +pub fn render_text_no_newlines(text: &str, color: Option, dim: bool) { let mut styled_text = style(text); if dim { styled_text = styled_text.dim(); @@ -153,7 +160,7 @@ pub fn render_text(text: &str, color: Option, dim: bool) { } else { styled_text = styled_text.green(); } - println!("\n{}\n", styled_text); + print!("{}", styled_text); } pub fn render_enter_plan_mode() { @@ -359,7 +366,6 @@ fn render_shell_request(call: &ToolCall, debug: bool) { } _ => print_params(&call.arguments, 0, debug), } - println!(); } fn render_default_request(call: &ToolCall, debug: bool) { @@ -568,6 +574,64 @@ pub fn display_greeting() { println!("\nGoose is running! Enter your instructions, or try asking what goose can do.\n"); } +pub struct McpSpinners { + bars: HashMap, + log_spinner: Option, + + multi_bar: MultiProgress, +} + +impl McpSpinners { + pub fn new() -> Self { + McpSpinners { + bars: HashMap::new(), + log_spinner: None, + multi_bar: MultiProgress::new(), + } + } + + pub fn log(&mut self, message: &str) { + let spinner = self.log_spinner.get_or_insert_with(|| { + let bar = self.multi_bar.add( + ProgressBar::new_spinner() + .with_style( + ProgressStyle::with_template("{spinner:.green} {msg}") + .unwrap() + .tick_chars("⠋⠙⠚⠛⠓⠒⠊⠉"), + ) + .with_message(message.to_string()), + ); + bar.enable_steady_tick(Duration::from_millis(100)); + bar + }); + + spinner.set_message(message.to_string()); + } + + pub fn update(&mut self, token: &str, value: f64, total: Option, message: Option<&str>) { + let bar = self.bars.entry(token.to_string()).or_insert_with(|| { + if let Some(total) = total { + self.multi_bar.add( + ProgressBar::new((total * 100.0) as u64).with_style( + ProgressStyle::with_template("[{elapsed}] {bar:40} {pos:>3}/{len:3} {msg}") + .unwrap(), + ), + ) + } else { + self.multi_bar.add(ProgressBar::new_spinner()) + } + }); + bar.set_position((value * 100.0) as u64); + if let Some(msg) = message { + bar.set_message(msg.to_string()); + } + } + + pub fn hide(&mut self) -> Result<(), Error> { + self.multi_bar.clear() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose-ffi/src/lib.rs b/crates/goose-ffi/src/lib.rs index bd2237d7..1afc97e9 100644 --- a/crates/goose-ffi/src/lib.rs +++ b/crates/goose-ffi/src/lib.rs @@ -3,7 +3,7 @@ use std::ptr; use std::sync::Arc; use futures::StreamExt; -use goose::agents::Agent; +use goose::agents::{Agent, AgentEvent}; use goose::message::Message; use goose::model::ModelConfig; use goose::providers::databricks::DatabricksProvider; @@ -256,13 +256,16 @@ pub unsafe extern "C" fn goose_agent_send_message( while let Some(message_result) = stream.next().await { match message_result { - Ok(message) => { + Ok(AgentEvent::Message(message)) => { // Get text or serialize to JSON // Note: Message doesn't have as_text method, we'll serialize to JSON if let Ok(json) = serde_json::to_string(&message) { full_response.push_str(&json); } } + Ok(AgentEvent::McpNotification(_)) => { + // TODO: Handle MCP notifications. + } Err(e) => { full_response.push_str(&format!("\nError in message stream: {}", e)); } diff --git a/crates/goose-mcp/src/computercontroller/mod.rs b/crates/goose-mcp/src/computercontroller/mod.rs index a2751852..ec8d5f61 100644 --- a/crates/goose-mcp/src/computercontroller/mod.rs +++ b/crates/goose-mcp/src/computercontroller/mod.rs @@ -6,7 +6,7 @@ use serde_json::{json, Value}; use std::{ collections::HashMap, fs, future::Future, path::PathBuf, pin::Pin, sync::Arc, sync::Mutex, }; -use tokio::process::Command; +use tokio::{process::Command, sync::mpsc}; #[cfg(unix)] use std::os::unix::fs::PermissionsExt; @@ -14,7 +14,7 @@ use std::os::unix::fs::PermissionsExt; use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, prompt::Prompt, - protocol::ServerCapabilities, + protocol::{JsonRpcMessage, ServerCapabilities}, resource::Resource, tool::{Tool, ToolAnnotations}, Content, @@ -1155,6 +1155,7 @@ impl Router for ComputerControllerRouter { &self, tool_name: &str, arguments: Value, + _notifier: mpsc::Sender, ) -> Pin, ToolError>> + Send + 'static>> { let this = self.clone(); let tool_name = tool_name.to_string(); diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index 5d7a9696..f5a12f1f 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -13,13 +13,17 @@ use std::{ path::{Path, PathBuf}, pin::Pin, }; -use tokio::process::Command; +use tokio::{ + io::{AsyncBufReadExt, BufReader}, + process::Command, + sync::mpsc, +}; use url::Url; use include_dir::{include_dir, Dir}; use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, - protocol::ServerCapabilities, + protocol::{JsonRpcMessage, JsonRpcNotification, ServerCapabilities}, resource::Resource, tool::Tool, Content, @@ -456,7 +460,11 @@ impl DeveloperRouter { } // Shell command execution with platform-specific handling - async fn bash(&self, params: Value) -> Result, ToolError> { + async fn bash( + &self, + params: Value, + notifier: mpsc::Sender, + ) -> Result, ToolError> { let command = params .get("command") @@ -488,27 +496,92 @@ impl DeveloperRouter { // Get platform-specific shell configuration let shell_config = get_shell_config(); - let cmd_with_redirect = format_command_for_platform(command); + let cmd_str = format_command_for_platform(command); // Execute the command using platform-specific shell - let child = Command::new(&shell_config.executable) + let mut child = Command::new(&shell_config.executable) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .stdin(Stdio::null()) .kill_on_drop(true) .arg(&shell_config.arg) - .arg(cmd_with_redirect) + .arg(cmd_str) .spawn() .map_err(|e| ToolError::ExecutionError(e.to_string()))?; + let stdout = child.stdout.take().unwrap(); + let stderr = child.stderr.take().unwrap(); + + let mut stdout_reader = BufReader::new(stdout); + let mut stderr_reader = BufReader::new(stderr); + + let output_task = tokio::spawn(async move { + let mut combined_output = String::new(); + + let mut stdout_buf = Vec::new(); + let mut stderr_buf = Vec::new(); + + loop { + tokio::select! { + n = stdout_reader.read_until(b'\n', &mut stdout_buf) => { + if n? == 0 { + break; + } + let line = String::from_utf8_lossy(&stdout_buf); + + notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "shell", + "stream": "stdout", + "output": line.to_string(), + } + })), + })) + .ok(); + + combined_output.push_str(&line); + stdout_buf.clear(); + } + n = stderr_reader.read_until(b'\n', &mut stderr_buf) => { + if n? == 0 { + break; + } + let line = String::from_utf8_lossy(&stderr_buf); + + notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": { + "type": "shell", + "stream": "stderr", + "output": line.to_string(), + } + })), + })) + .ok(); + + combined_output.push_str(&line); + stderr_buf.clear(); + } + } + } + Ok::<_, std::io::Error>(combined_output) + }); + // Wait for the command to complete and get output - let output = child - .wait_with_output() + child + .wait() .await .map_err(|e| ToolError::ExecutionError(e.to_string()))?; - let stdout_str = String::from_utf8_lossy(&output.stdout); - let output_str = stdout_str; + let output_str = match output_task.await { + Ok(result) => result.map_err(|e| ToolError::ExecutionError(e.to_string()))?, + Err(e) => return Err(ToolError::ExecutionError(e.to_string())), + }; // Check the character count of the output const MAX_CHAR_COUNT: usize = 400_000; // 409600 chars = 400KB @@ -1048,12 +1121,13 @@ impl Router for DeveloperRouter { &self, tool_name: &str, arguments: Value, + notifier: mpsc::Sender, ) -> Pin, ToolError>> + Send + 'static>> { let this = self.clone(); let tool_name = tool_name.to_string(); Box::pin(async move { match tool_name.as_str() { - "shell" => this.bash(arguments).await, + "shell" => this.bash(arguments, notifier).await, "text_editor" => this.text_editor(arguments).await, "list_windows" => this.list_windows(arguments).await, "screen_capture" => this.screen_capture(arguments).await, @@ -1195,6 +1269,10 @@ mod tests { .await } + fn dummy_sender() -> mpsc::Sender { + mpsc::channel(1).0 + } + #[tokio::test] #[serial] async fn test_shell_missing_parameters() { @@ -1202,7 +1280,7 @@ mod tests { std::env::set_current_dir(&temp_dir).unwrap(); let router = get_router().await; - let result = router.call_tool("shell", json!({})).await; + let result = router.call_tool("shell", json!({}), dummy_sender()).await; assert!(result.is_err()); let err = result.err().unwrap(); @@ -1263,6 +1341,7 @@ mod tests { "command": "view", "path": large_file_str }), + dummy_sender(), ) .await; @@ -1288,6 +1367,7 @@ mod tests { "command": "view", "path": many_chars_str }), + dummy_sender(), ) .await; @@ -1319,6 +1399,7 @@ mod tests { "path": file_path_str, "file_text": "Hello, world!" }), + dummy_sender(), ) .await .unwrap(); @@ -1331,6 +1412,7 @@ mod tests { "command": "view", "path": file_path_str }), + dummy_sender(), ) .await .unwrap(); @@ -1369,6 +1451,7 @@ mod tests { "path": file_path_str, "file_text": "Hello, world!" }), + dummy_sender(), ) .await .unwrap(); @@ -1383,6 +1466,7 @@ mod tests { "old_str": "world", "new_str": "Rust" }), + dummy_sender(), ) .await .unwrap(); @@ -1407,6 +1491,7 @@ mod tests { "command": "view", "path": file_path_str }), + dummy_sender(), ) .await .unwrap(); @@ -1444,6 +1529,7 @@ mod tests { "path": file_path_str, "file_text": "First line" }), + dummy_sender(), ) .await .unwrap(); @@ -1458,6 +1544,7 @@ mod tests { "old_str": "First line", "new_str": "Second line" }), + dummy_sender(), ) .await .unwrap(); @@ -1470,6 +1557,7 @@ mod tests { "command": "undo_edit", "path": file_path_str }), + dummy_sender(), ) .await .unwrap(); @@ -1485,6 +1573,7 @@ mod tests { "command": "view", "path": file_path_str }), + dummy_sender(), ) .await .unwrap(); @@ -1583,6 +1672,7 @@ mod tests { "path": temp_dir.path().join("secret.txt").to_str().unwrap(), "file_text": "test content" }), + dummy_sender(), ) .await; @@ -1601,6 +1691,7 @@ mod tests { "path": temp_dir.path().join("allowed.txt").to_str().unwrap(), "file_text": "test content" }), + dummy_sender(), ) .await; @@ -1642,6 +1733,7 @@ mod tests { json!({ "command": format!("cat {}", secret_file_path.to_str().unwrap()) }), + dummy_sender(), ) .await; @@ -1658,6 +1750,7 @@ mod tests { json!({ "command": format!("cat {}", allowed_file_path.to_str().unwrap()) }), + dummy_sender(), ) .await; diff --git a/crates/goose-mcp/src/developer/shell.rs b/crates/goose-mcp/src/developer/shell.rs index 34e531f2..cb60f9ba 100644 --- a/crates/goose-mcp/src/developer/shell.rs +++ b/crates/goose-mcp/src/developer/shell.rs @@ -4,7 +4,6 @@ use std::env; pub struct ShellConfig { pub executable: String, pub arg: String, - pub redirect_syntax: String, } impl Default for ShellConfig { @@ -14,13 +13,11 @@ impl Default for ShellConfig { Self { executable: "powershell.exe".to_string(), arg: "-NoProfile -NonInteractive -Command".to_string(), - redirect_syntax: "2>&1".to_string(), } } else { Self { executable: "bash".to_string(), arg: "-c".to_string(), - redirect_syntax: "2>&1".to_string(), } } } @@ -31,13 +28,12 @@ pub fn get_shell_config() -> ShellConfig { } pub fn format_command_for_platform(command: &str) -> String { - let config = get_shell_config(); if cfg!(windows) { // For PowerShell, wrap the command in braces to handle special characters - format!("{{ {} }} {}", command, config.redirect_syntax) + format!("{{ {} }}", command) } else { // For other shells, no braces needed - format!("{} {}", command, config.redirect_syntax) + command.to_string() } } diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 710f42ae..1f1aeae7 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -7,6 +7,7 @@ use base64::Engine; use chrono::NaiveDate; use indoc::indoc; use lazy_static::lazy_static; +use mcp_core::protocol::JsonRpcMessage; use mcp_core::tool::ToolAnnotations; use oauth_pkce::PkceOAuth2Client; use regex::Regex; @@ -14,6 +15,7 @@ use serde_json::{json, Value}; use std::io::Cursor; use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; use storage::CredentialsManager; +use tokio::sync::mpsc; use mcp_core::content::Content; use mcp_core::{ @@ -3281,6 +3283,7 @@ impl Router for GoogleDriveRouter { &self, tool_name: &str, arguments: Value, + _notifier: mpsc::Sender, ) -> Pin, ToolError>> + Send + 'static>> { let this = self.clone(); let tool_name = tool_name.to_string(); diff --git a/crates/goose-mcp/src/jetbrains/mod.rs b/crates/goose-mcp/src/jetbrains/mod.rs index 0cdf8018..c015b9de 100644 --- a/crates/goose-mcp/src/jetbrains/mod.rs +++ b/crates/goose-mcp/src/jetbrains/mod.rs @@ -5,7 +5,7 @@ use mcp_core::{ content::Content, handler::{PromptError, ResourceError, ToolError}, prompt::Prompt, - protocol::ServerCapabilities, + protocol::{JsonRpcMessage, ServerCapabilities}, resource::Resource, role::Role, tool::Tool, @@ -16,7 +16,7 @@ use serde_json::Value; use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, Mutex}; use tokio::time::{sleep, Duration}; use tracing::error; @@ -158,6 +158,7 @@ impl Router for JetBrainsRouter { &self, tool_name: &str, arguments: Value, + _notifier: mpsc::Sender, ) -> Pin, ToolError>> + Send + 'static>> { let this = self.clone(); let tool_name = tool_name.to_string(); diff --git a/crates/goose-mcp/src/memory/mod.rs b/crates/goose-mcp/src/memory/mod.rs index 24dff4f1..8c814478 100644 --- a/crates/goose-mcp/src/memory/mod.rs +++ b/crates/goose-mcp/src/memory/mod.rs @@ -10,11 +10,12 @@ use std::{ path::PathBuf, pin::Pin, }; +use tokio::sync::mpsc; use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, prompt::Prompt, - protocol::ServerCapabilities, + protocol::{JsonRpcMessage, ServerCapabilities}, resource::Resource, tool::{Tool, ToolAnnotations, ToolCall}, Content, @@ -520,6 +521,7 @@ impl Router for MemoryRouter { &self, tool_name: &str, arguments: Value, + _notifier: mpsc::Sender, ) -> Pin, ToolError>> + Send + 'static>> { let this = self.clone(); let tool_name = tool_name.to_string(); diff --git a/crates/goose-mcp/src/tutorial/mod.rs b/crates/goose-mcp/src/tutorial/mod.rs index b2c26906..ea9e32f0 100644 --- a/crates/goose-mcp/src/tutorial/mod.rs +++ b/crates/goose-mcp/src/tutorial/mod.rs @@ -3,11 +3,12 @@ use include_dir::{include_dir, Dir}; use indoc::formatdoc; use serde_json::{json, Value}; use std::{future::Future, pin::Pin}; +use tokio::sync::mpsc; use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, prompt::Prompt, - protocol::ServerCapabilities, + protocol::{JsonRpcMessage, ServerCapabilities}, resource::Resource, role::Role, tool::{Tool, ToolAnnotations}, @@ -130,6 +131,7 @@ impl Router for TutorialRouter { &self, tool_name: &str, arguments: Value, + _notifier: mpsc::Sender, ) -> Pin, ToolError>> + Send + 'static>> { let this = self.clone(); let tool_name = tool_name.to_string(); diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index ef84dc58..ed92834d 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -10,7 +10,7 @@ use axum::{ use bytes::Bytes; use futures::{stream::StreamExt, Stream}; use goose::{ - agents::SessionConfig, + agents::{AgentEvent, SessionConfig}, message::{Message, MessageContent}, permission::permission_confirmation::PrincipalType, }; @@ -18,7 +18,7 @@ use goose::{ permission::{Permission, PermissionConfirmation}, session, }; -use mcp_core::{role::Role, Content, ToolResult}; +use mcp_core::{protocol::JsonRpcMessage, role::Role, Content, ToolResult}; use serde::{Deserialize, Serialize}; use serde_json::json; use serde_json::Value; @@ -79,9 +79,19 @@ impl IntoResponse for SseResponse { #[derive(Debug, Serialize)] #[serde(tag = "type")] enum MessageEvent { - Message { message: Message }, - Error { error: String }, - Finish { reason: String }, + Message { + message: Message, + }, + Error { + error: String, + }, + Finish { + reason: String, + }, + Notification { + request_id: String, + message: JsonRpcMessage, + }, } async fn stream_event( @@ -200,7 +210,7 @@ async fn handler( tokio::select! { response = timeout(Duration::from_millis(500), stream.next()) => { match response { - Ok(Some(Ok(message))) => { + Ok(Some(Ok(AgentEvent::Message(message)))) => { all_messages.push(message.clone()); if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await { tracing::error!("Error sending message through channel: {}", e); @@ -223,6 +233,20 @@ async fn handler( } }); } + Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => { + if let Err(e) = stream_event(MessageEvent::Notification{ + request_id: request_id.clone(), + message: n, + }, &tx).await { + tracing::error!("Error sending message through channel: {}", e); + let _ = stream_event( + MessageEvent::Error { + error: e.to_string(), + }, + &tx, + ).await; + } + } Ok(Some(Err(e))) => { tracing::error!("Error processing message: {}", e); let _ = stream_event( @@ -317,7 +341,7 @@ async fn ask_handler( while let Some(response) = stream.next().await { match response { - Ok(message) => { + Ok(AgentEvent::Message(message)) => { if message.role == Role::Assistant { for content in &message.content { if let MessageContent::Text(text) = content { @@ -328,6 +352,10 @@ async fn ask_handler( } } } + Ok(AgentEvent::McpNotification(n)) => { + // Handle notifications if needed + tracing::info!("Received notification: {:?}", n); + } Err(e) => { tracing::error!("Error processing as_ai message: {}", e); return Err(StatusCode::INTERNAL_SERVER_ERROR); diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 6809dd73..4aa3f129 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -71,10 +71,10 @@ aws-sdk-bedrockruntime = "1.74.0" # For GCP Vertex AI provider auth jsonwebtoken = "9.3.1" -# Added blake3 hashing library as a dependency blake3 = "1.5" fs2 = "0.4.3" futures-util = "0.3.31" +tokio-stream = "0.1.17" # Vector database for tool selection lancedb = "0.13" diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index bc3badac..c5ac11cb 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use dotenv::dotenv; use futures::StreamExt; -use goose::agents::{Agent, ExtensionConfig}; +use goose::agents::{Agent, AgentEvent, ExtensionConfig}; use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT}; use goose::message::Message; use goose::providers::databricks::DatabricksProvider; @@ -20,10 +20,11 @@ async fn main() { let config = ExtensionConfig::stdio( "developer", - "./target/debug/developer", + "./target/debug/goose", DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT, - ); + ) + .with_args(vec!["mcp", "developer"]); agent.add_extension(config).await.unwrap(); println!("Extensions:"); @@ -35,11 +36,8 @@ async fn main() { .with_text("can you summarize the readme.md in this dir using just a haiku?")]; let mut stream = agent.reply(&messages, None).await.unwrap(); - while let Some(message) = stream.next().await { - println!( - "{}", - serde_json::to_string_pretty(&message.unwrap()).unwrap() - ); + while let Some(Ok(AgentEvent::Message(message))) = stream.next().await { + println!("{}", serde_json::to_string_pretty(&message).unwrap()); println!("\n"); } } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 02bda342..99160300 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1,9 +1,14 @@ use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use anyhow::{anyhow, Result}; use futures::stream::BoxStream; -use futures::TryStreamExt; +use futures::{FutureExt, Stream, TryStreamExt}; +use futures_util::stream; +use futures_util::stream::StreamExt; +use mcp_core::protocol::JsonRpcMessage; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::Message; @@ -39,7 +44,7 @@ use mcp_core::{ use super::platform_tools; use super::router_tools; -use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; +use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; /// The main goose Agent pub struct Agent { @@ -56,6 +61,12 @@ pub struct Agent { pub(super) router_tool_selector: Mutex>>>, } +#[derive(Clone, Debug)] +pub enum AgentEvent { + Message(Message), + McpNotification((String, JsonRpcMessage)), +} + impl Agent { pub fn new() -> Self { // Create channels with buffer size 32 (adjust if needed) @@ -100,6 +111,40 @@ impl Default for Agent { } } +pub enum ToolStreamItem { + Message(JsonRpcMessage), + Result(T), +} + +pub type ToolStream = Pin>>> + Send>>; + +// tool_stream combines a stream of JsonRpcMessages with a future representing the +// final result of the tool call. MCP notifications are not request-scoped, but +// this lets us capture all notifications emitted during the tool call for +// simpler consumption +pub fn tool_stream(rx: S, done: F) -> ToolStream +where + S: Stream + Send + Unpin + 'static, + F: Future>> + Send + 'static, +{ + Box::pin(async_stream::stream! { + tokio::pin!(done); + let mut rx = rx; + + loop { + tokio::select! { + Some(msg) = rx.next() => { + yield ToolStreamItem::Message(msg); + } + r = &mut done => { + yield ToolStreamItem::Result(r); + break; + } + } + } + }) +} + impl Agent { /// Get a reference count clone to the provider pub async fn provider(&self) -> Result, anyhow::Error> { @@ -143,7 +188,7 @@ impl Agent { &self, tool_call: mcp_core::tool::ToolCall, request_id: String, - ) -> (String, Result, ToolError>) { + ) -> (String, Result) { // Check if this tool call should be allowed based on repetition monitoring if let Some(monitor) = self.tool_monitor.lock().await.as_mut() { let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone()); @@ -171,52 +216,65 @@ impl Agent { .and_then(|v| v.as_str()) .unwrap_or("") .to_string(); - return self + let (request_id, result) = self .manage_extensions(action, extension_name, request_id) .await; + + return (request_id, Ok(ToolCallResult::from(result))); } let extension_manager = self.extension_manager.lock().await; - let result = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { + let result: ToolCallResult = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately - extension_manager - .read_resource(tool_call.arguments.clone()) - .await + ToolCallResult::from( + extension_manager + .read_resource(tool_call.arguments.clone()) + .await, + ) } else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME { - extension_manager - .list_resources(tool_call.arguments.clone()) - .await + ToolCallResult::from( + extension_manager + .list_resources(tool_call.arguments.clone()) + .await, + ) } else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME { - extension_manager.search_available_extensions().await + ToolCallResult::from(extension_manager.search_available_extensions().await) } else if self.is_frontend_tool(&tool_call.name).await { // For frontend tools, return an error indicating we need frontend execution - Err(ToolError::ExecutionError( + ToolCallResult::from(Err(ToolError::ExecutionError( "Frontend tool execution required".to_string(), - )) + ))) } else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME { let selector = self.router_tool_selector.lock().await.clone(); - if let Some(selector) = selector { + ToolCallResult::from(if let Some(selector) = selector { selector.select_tools(tool_call.arguments.clone()).await } else { Err(ToolError::ExecutionError( "Encountered vector search error.".to_string(), )) - } + }) } else { - extension_manager + // Clone the result to ensure no references to extension_manager are returned + let result = extension_manager .dispatch_tool_call(tool_call.clone()) - .await + .await; + match result { + Ok(call_result) => call_result, + Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))), + } }; - debug!( - "input" = serde_json::to_string(&tool_call).unwrap(), - "output" = serde_json::to_string(&result).unwrap(), - ); - - // Process the response to handle large text content - let processed_result = super::large_response_handler::process_tool_response(result); - - (request_id, processed_result) + ( + request_id, + Ok(ToolCallResult { + notification_stream: result.notification_stream, + result: Box::new( + result + .result + .map(super::large_response_handler::process_tool_response), + ), + }), + ) } pub(super) async fn manage_extensions( @@ -466,7 +524,7 @@ impl Agent { &self, messages: &[Message], session: Option, - ) -> anyhow::Result>> { + ) -> anyhow::Result>> { let mut messages = messages.to_vec(); let reply_span = tracing::Span::current(); @@ -532,9 +590,8 @@ impl Agent { } } } - // Yield the assistant's response with frontend tool requests filtered out - yield filtered_response.clone(); + yield AgentEvent::Message(filtered_response.clone()); tokio::task::yield_now().await; @@ -556,7 +613,7 @@ impl Agent { // execution is yeield back to this reply loop, and is of the same Message // type, so we can yield that back up to be handled while let Some(msg) = frontend_tool_stream.try_next().await? { - yield msg; + yield AgentEvent::Message(msg); } // Clone goose_mode once before the match to avoid move issues @@ -584,13 +641,23 @@ impl Agent { self.provider().await?).await; // Handle pre-approved and read-only tools in parallel - let mut tool_futures: Vec = Vec::new(); + let mut tool_futures: Vec<(String, ToolStream)> = Vec::new(); // Skip the confirmation for approved tools for request in &permission_check_result.approved { if let Ok(tool_call) = request.tool_call.clone() { - let tool_future = self.dispatch_tool_call(tool_call, request.id.clone()); - tool_futures.push(Box::pin(tool_future)); + let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await; + + tool_futures.push((req_id, match tool_result { + Ok(result) => tool_stream( + result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())), + result.result, + ), + Err(e) => tool_stream( + Box::new(stream::empty()), + futures::future::ready(Err(e)), + ), + })); } } @@ -618,7 +685,7 @@ impl Agent { // type, so we can yield the Message back up to be handled and grab any // confirmations or denials while let Some(msg) = tool_approval_stream.try_next().await? { - yield msg; + yield AgentEvent::Message(msg); } tool_futures = { @@ -628,16 +695,30 @@ impl Agent { futures_lock.drain(..).collect::>() }; - // Wait for all tool calls to complete - let results = futures::future::join_all(tool_futures).await; + let with_id = tool_futures + .into_iter() + .map(|(request_id, stream)| { + stream.map(move |item| (request_id.clone(), item)) + }) + .collect::>(); + + let mut combined = stream::select_all(with_id); + let mut all_install_successful = true; - for (request_id, output) in results.into_iter() { - if enable_extension_request_ids.contains(&request_id) && output.is_err(){ - all_install_successful = false; + while let Some((request_id, item)) = combined.next().await { + match item { + ToolStreamItem::Result(output) => { + if enable_extension_request_ids.contains(&request_id) && output.is_err(){ + all_install_successful = false; + } + let mut response = message_tool_response.lock().await; + *response = response.clone().with_tool_response(request_id, output); + }, + ToolStreamItem::Message(msg) => { + yield AgentEvent::McpNotification((request_id, msg)) + } } - let mut response = message_tool_response.lock().await; - *response = response.clone().with_tool_response(request_id, output); } // Update system prompt and tools if installations were successful @@ -647,7 +728,7 @@ impl Agent { } let final_message_tool_resp = message_tool_response.lock().await.clone(); - yield final_message_tool_resp.clone(); + yield AgentEvent::Message(final_message_tool_resp.clone()); messages.push(response); messages.push(final_message_tool_resp); @@ -656,15 +737,15 @@ impl Agent { // At this point, the last message should be a user message // because call to provider led to context length exceeded error // Immediately yield a special message and break - yield Message::assistant().with_context_length_exceeded( + yield AgentEvent::Message(Message::assistant().with_context_length_exceeded( "The context length of the model has been exceeded. Please start a new session and try again.", - ); + )); break; }, Err(e) => { // Create an error message & terminate the stream error!("Error: {}", e); - yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")); + yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."))); break; } } diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 4bc4d746..aa8d1172 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -1,8 +1,7 @@ use anyhow::Result; use chrono::{DateTime, TimeZone, Utc}; -use futures::future; use futures::stream::{FuturesUnordered, StreamExt}; -use mcp_client::McpService; +use futures::{future, FutureExt}; use mcp_core::protocol::GetPromptResult; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -10,15 +9,17 @@ use std::sync::LazyLock; use std::time::Duration; use tokio::sync::Mutex; use tokio::task; -use tracing::{debug, error, warn}; +use tokio_stream::wrappers::ReceiverStream; +use tracing::{error, warn}; use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo}; +use super::tool_execution::ToolCallResult; use crate::agents::extension::Envs; use crate::config::{Config, ExtensionConfigManager}; use crate::prompt_template; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; use mcp_client::transport::{SseTransport, StdioTransport, Transport}; -use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult}; +use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError}; use serde_json::Value; // By default, we set it to Jan 1, 2020 if the resource does not have a timestamp @@ -113,7 +114,8 @@ impl ExtensionManager { /// Add a new MCP extension based on the provided client type // TODO IMPORTANT need to ensure this times out if the extension command is broken! pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> { - let sanitized_name = normalize(config.key().to_string()); + let config_name = config.key().to_string(); + let sanitized_name = normalize(config_name.clone()); /// Helper function to merge environment variables from direct envs and keychain-stored env_keys async fn merge_environments( @@ -183,13 +185,15 @@ impl ExtensionManager { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let transport = SseTransport::new(uri, all_envs); let handle = transport.start().await?; - let service = McpService::with_timeout( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ); - Box::new(McpClient::new(service)) + Box::new( + McpClient::connect( + handle, + Duration::from_secs( + timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), + ), + ) + .await?, + ) } ExtensionConfig::Stdio { cmd, @@ -202,13 +206,15 @@ impl ExtensionManager { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let transport = StdioTransport::new(cmd, args.to_vec(), all_envs); let handle = transport.start().await?; - let service = McpService::with_timeout( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ); - Box::new(McpClient::new(service)) + Box::new( + McpClient::connect( + handle, + Duration::from_secs( + timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), + ), + ) + .await?, + ) } ExtensionConfig::Builtin { name, @@ -227,13 +233,15 @@ impl ExtensionManager { HashMap::new(), ); let handle = transport.start().await?; - let service = McpService::with_timeout( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ); - Box::new(McpClient::new(service)) + Box::new( + McpClient::connect( + handle, + Duration::from_secs( + timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), + ), + ) + .await?, + ) } _ => unreachable!(), }; @@ -609,7 +617,7 @@ impl ExtensionManager { } } - pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> ToolResult> { + pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> Result { // Dispatch tool call based on the prefix naming convention let (client_name, client) = self .get_client_for_tool(&tool_call.name) @@ -620,22 +628,26 @@ impl ExtensionManager { .name .strip_prefix(client_name) .and_then(|s| s.strip_prefix("__")) - .ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?; + .ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))? + .to_string(); - let client_guard = client.lock().await; + let arguments = tool_call.arguments.clone(); + let client = client.clone(); + let notifications_receiver = client.lock().await.subscribe().await; - let result = client_guard - .call_tool(tool_name, tool_call.clone().arguments) - .await - .map(|result| result.content) - .map_err(|e| ToolError::ExecutionError(e.to_string())); + let fut = async move { + let client_guard = client.lock().await; + client_guard + .call_tool(&tool_name, arguments) + .await + .map(|call| call.content) + .map_err(|e| ToolError::ExecutionError(e.to_string())) + }; - debug!( - "input" = serde_json::to_string(&tool_call).unwrap(), - "output" = serde_json::to_string(&result).unwrap(), - ); - - result + Ok(ToolCallResult { + result: Box::new(fut.boxed()), + notification_stream: Some(Box::new(ReceiverStream::new(notifications_receiver))), + }) } pub async fn list_prompts_from_extension( @@ -793,10 +805,11 @@ mod tests { use mcp_client::client::Error; use mcp_client::client::McpClientTrait; use mcp_core::protocol::{ - CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult, - ListToolsResult, ReadResourceResult, + CallToolResult, GetPromptResult, InitializeResult, JsonRpcMessage, ListPromptsResult, + ListResourcesResult, ListToolsResult, ReadResourceResult, }; use serde_json::json; + use tokio::sync::mpsc; struct MockClient {} @@ -849,6 +862,10 @@ mod tests { ) -> Result { Err(Error::NotInitialized) } + + async fn subscribe(&self) -> mpsc::Receiver { + mpsc::channel(1).1 + } } #[test] @@ -970,6 +987,9 @@ mod tests { let result = extension_manager .dispatch_tool_call(invalid_tool_call) + .await + .unwrap() + .result .await; assert!(matches!( result.err().unwrap(), @@ -986,6 +1006,11 @@ mod tests { let result = extension_manager .dispatch_tool_call(invalid_tool_call) .await; - assert!(matches!(result.err().unwrap(), ToolError::NotFound(_))); + if let Err(err) = result { + let tool_err = err.downcast_ref::().expect("Expected ToolError"); + assert!(matches!(tool_err, ToolError::NotFound(_))); + } else { + panic!("Expected ToolError::NotFound"); + } } } diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 54326285..24511ac6 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -13,7 +13,7 @@ mod tool_router_index_manager; pub(crate) mod tool_vectordb; mod types; -pub use agent::Agent; +pub use agent::{Agent, AgentEvent}; pub use extension::ExtensionConfig; pub use extension_manager::ExtensionManager; pub use prompt_manager::PromptManager; diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index beffae66..446d1f58 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -1,23 +1,35 @@ use std::future::Future; -use std::pin::Pin; use std::sync::Arc; use async_stream::try_stream; -use futures::stream::BoxStream; -use futures::StreamExt; +use futures::stream::{self, BoxStream}; +use futures::{Stream, StreamExt}; +use mcp_core::protocol::JsonRpcMessage; use tokio::sync::Mutex; use crate::config::permission::PermissionLevel; use crate::config::PermissionManager; use crate::message::{Message, ToolRequest}; use crate::permission::Permission; -use mcp_core::{Content, ToolError}; +use mcp_core::{Content, ToolResult}; -// Type alias for ToolFutures - used in the agent loop to join all futures together -pub(crate) type ToolFuture<'a> = - Pin, ToolError>)> + Send + 'a>>; -pub(crate) type ToolFuturesVec<'a> = Arc>>>; +// ToolCallResult combines the result of a tool call with an optional notification stream that +// can be used to receive notifications from the tool. +pub struct ToolCallResult { + pub result: Box>> + Send + Unpin>, + pub notification_stream: Option + Send + Unpin>>, +} +impl From>> for ToolCallResult { + fn from(result: ToolResult>) -> Self { + Self { + result: Box::new(futures::future::ready(result)), + notification_stream: None, + } + } +} + +use super::agent::{tool_stream, ToolStream}; use crate::agents::Agent; pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \ @@ -37,7 +49,7 @@ impl Agent { pub(crate) fn handle_approval_tool_requests<'a>( &'a self, tool_requests: &'a [ToolRequest], - tool_futures: ToolFuturesVec<'a>, + tool_futures: Arc>>, permission_manager: &'a mut PermissionManager, message_tool_response: Arc>, ) -> BoxStream<'a, anyhow::Result> { @@ -56,9 +68,19 @@ impl Agent { while let Some((req_id, confirmation)) = rx.recv().await { if req_id == request.id { if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow { - let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone()); + let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone()).await; let mut futures = tool_futures.lock().await; - futures.push(Box::pin(tool_future)); + + futures.push((req_id, match tool_result { + Ok(result) => tool_stream( + result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())), + result.result, + ), + Err(e) => tool_stream( + Box::new(stream::empty()), + futures::future::ready(Err(e)), + ), + })); if confirmation.permission == Permission::AlwaysAllow { permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow); diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 6c8bc856..e6b0e356 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize}; use tokio::sync::Mutex; use tokio_cron_scheduler::{job::JobId, Job, JobScheduler as TokioJobScheduler}; +use crate::agents::AgentEvent; use crate::agents::{Agent, SessionConfig}; use crate::config::{self, Config}; use crate::message::Message; @@ -1102,12 +1103,15 @@ async fn run_scheduled_job_internal( tokio::task::yield_now().await; match message_result { - Ok(msg) => { + Ok(AgentEvent::Message(msg)) => { if msg.role == mcp_core::role::Role::Assistant { tracing::info!("[Job {}] Assistant: {:?}", job.id, msg.content); } all_session_messages.push(msg); } + Ok(AgentEvent::McpNotification(_)) => { + // Handle notifications if needed + } Err(e) => { tracing::error!( "[Job {}] Error receiving message from agent: {}", diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index bb851ab4..8f474ed8 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use anyhow::Result; use futures::StreamExt; -use goose::agents::Agent; +use goose::agents::{Agent, AgentEvent}; use goose::message::Message; use goose::model::ModelConfig; use goose::providers::base::Provider; @@ -132,7 +132,10 @@ async fn run_truncate_test( let mut responses = Vec::new(); while let Some(response_result) = reply_stream.next().await { match response_result { - Ok(response) => responses.push(response), + Ok(AgentEvent::Message(response)) => responses.push(response), + Ok(AgentEvent::McpNotification(n)) => { + println!("MCP Notification: {n:?}"); + } Err(e) => { println!("Error: {:?}", e); return Err(e); diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index 4913b952..e36abeba 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -1,7 +1,6 @@ use mcp_client::{ client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}, transport::{SseTransport, StdioTransport, Transport}, - McpService, }; use rand::Rng; use rand::SeedableRng; @@ -20,18 +19,15 @@ async fn main() -> Result<(), Box> { let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); let handle1 = transport1.start().await?; - let service1 = McpService::with_timeout(handle1, Duration::from_secs(30)); - let client1 = McpClient::new(service1); + let client1 = McpClient::connect(handle1, Duration::from_secs(30)).await?; let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); let handle2 = transport2.start().await?; - let service2 = McpService::with_timeout(handle2, Duration::from_secs(30)); - let client2 = McpClient::new(service2); + let client2 = McpClient::connect(handle2, Duration::from_secs(30)).await?; let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new()); let handle3 = transport3.start().await?; - let service3 = McpService::with_timeout(handle3, Duration::from_secs(10)); - let client3 = McpClient::new(service3); + let client3 = McpClient::connect(handle3, Duration::from_secs(10)).await?; // Initialize both clients let mut clients: Vec> = diff --git a/crates/mcp-client/examples/integration_test.rs b/crates/mcp-client/examples/integration_test.rs new file mode 100644 index 00000000..b16af1be --- /dev/null +++ b/crates/mcp-client/examples/integration_test.rs @@ -0,0 +1,122 @@ +use anyhow::Result; +use futures::lock::Mutex; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; +use mcp_client::transport::{SseTransport, Transport}; +use mcp_client::StdioTransport; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("eventsource_client=info".parse().unwrap()), + ) + .init(); + + test_transport(sse_transport().await?).await?; + test_transport(stdio_transport().await?).await?; + + Ok(()) +} + +async fn sse_transport() -> Result { + let port = "60053"; + + tokio::process::Command::new("npx") + .env("PORT", port) + .arg("@modelcontextprotocol/server-everything") + .arg("sse") + .spawn()?; + tokio::time::sleep(Duration::from_secs(1)).await; + + Ok(SseTransport::new( + format!("http://localhost:{}/sse", port), + HashMap::new(), + )) +} + +async fn stdio_transport() -> Result { + Ok(StdioTransport::new( + "npx", + vec!["@modelcontextprotocol/server-everything"] + .into_iter() + .map(|s| s.to_string()) + .collect(), + HashMap::new(), + )) +} + +async fn test_transport(transport: T) -> Result<()> +where + T: Transport + Send + 'static, +{ + // Start transport + let handle = transport.start().await?; + + // Create client + let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?; + println!("Client created\n"); + + let mut receiver = client.subscribe().await; + let events = Arc::new(Mutex::new(Vec::new())); + let events_clone = events.clone(); + tokio::spawn(async move { + while let Some(event) = receiver.recv().await { + println!("Received event: {event:?}"); + events_clone.lock().await.push(event); + } + }); + + // Initialize + let server_info = client + .initialize( + ClientInfo { + name: "test-client".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + println!("Connected to server: {server_info:?}\n"); + + // Sleep for 100ms to allow the server to start - surprisingly this is required! + tokio::time::sleep(Duration::from_millis(500)).await; + + // List tools + let tools = client.list_tools(None).await?; + println!("Available tools: {tools:#?}\n"); + + // Call tool + let tool_result = client + .call_tool("echo", serde_json::json!({ "message": "honk" })) + .await?; + println!("Tool result: {tool_result:#?}\n"); + + let collected_eventes_before = events.lock().await.len(); + let n_steps = 5; + let long_op = client + .call_tool( + "longRunningOperation", + serde_json::json!({ "duration": 3, "steps": n_steps }), + ) + .await?; + println!("Long op result: {long_op:#?}\n"); + let collected_events_after = events.lock().await.len(); + assert_eq!(collected_events_after - collected_eventes_before, n_steps); + + // List resources + let resources = client.list_resources(None).await?; + println!("Resources: {resources:#?}\n"); + + // Read resource + let resource = client.read_resource("test://static/resource/1").await?; + println!("Resource: {resource:#?}\n"); + + Ok(()) +} diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index 360a2bbc..6e97a0a6 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -1,7 +1,6 @@ use anyhow::Result; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; use mcp_client::transport::{SseTransport, Transport}; -use mcp_client::McpService; use std::collections::HashMap; use std::time::Duration; use tracing_subscriber::EnvFilter; @@ -23,11 +22,8 @@ async fn main() -> Result<()> { // Start transport let handle = transport.start().await?; - // Create the service with timeout middleware - let service = McpService::with_timeout(handle, Duration::from_secs(3)); - // Create client - let mut client = McpClient::new(service); + let mut client = McpClient::connect(handle, Duration::from_secs(3)).await?; println!("Client created\n"); // Initialize diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index e43f036c..98793597 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use anyhow::Result; use mcp_client::{ - ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, McpService, + ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, StdioTransport, Transport, }; use std::time::Duration; @@ -25,11 +25,8 @@ async fn main() -> Result<(), ClientError> { // 2) Start the transport to get a handle let transport_handle = transport.start().await?; - // 3) Create the service with timeout middleware - let service = McpService::with_timeout(transport_handle, Duration::from_secs(10)); - - // 4) Create the client with the middleware-wrapped service - let mut client = McpClient::new(service); + // 3) Create the client with the middleware-wrapped service + let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?; // Initialize let server_info = client diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs index ffdcc10c..9b367d25 100644 --- a/crates/mcp-client/examples/stdio_integration.rs +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -5,7 +5,6 @@ use mcp_client::client::{ ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, }; use mcp_client::transport::{StdioTransport, Transport}; -use mcp_client::McpService; use std::collections::HashMap; use std::time::Duration; use tracing_subscriber::EnvFilter; @@ -34,11 +33,8 @@ async fn main() -> Result<(), ClientError> { // Start the transport to get a handle let transport_handle = transport.start().await.unwrap(); - // Create the service with timeout middleware - let service = McpService::with_timeout(transport_handle, Duration::from_secs(10)); - // Create client - let mut client = McpClient::new(service); + let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?; // Initialize let server_info = client diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index cb8247b9..474592fd 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -4,11 +4,16 @@ use mcp_core::protocol::{ ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, }; use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::sync::atomic::{AtomicU64, Ordering}; +use serde_json::{json, Value}; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, +}; use thiserror::Error; -use tokio::sync::Mutex; -use tower::{Service, ServiceExt}; // for Service::ready() +use tokio::sync::{mpsc, Mutex}; +use tower::{timeout::TimeoutLayer, Layer, Service, ServiceExt}; + +use crate::{McpService, TransportHandle}; pub type BoxError = Box; @@ -97,34 +102,67 @@ pub trait McpClientTrait: Send + Sync { async fn list_prompts(&self, next_cursor: Option) -> Result; async fn get_prompt(&self, name: &str, arguments: Value) -> Result; + + async fn subscribe(&self) -> mpsc::Receiver; } /// The MCP client is the interface for MCP operations. -pub struct McpClient +pub struct McpClient where - S: Service + Clone + Send + Sync + 'static, - S::Error: Into, - S::Future: Send, + T: TransportHandle + Send + Sync + 'static, { - service: Mutex, + service: Mutex>>, next_id: AtomicU64, server_capabilities: Option, server_info: Option, + notification_subscribers: Arc>>>, } -impl McpClient +impl McpClient where - S: Service + Clone + Send + Sync + 'static, - S::Error: Into, - S::Future: Send, + T: TransportHandle + Send + Sync + 'static, { - pub fn new(service: S) -> Self { - Self { - service: Mutex::new(service), + pub async fn connect(transport: T, timeout: std::time::Duration) -> Result { + let service = McpService::new(transport.clone()); + let service_ptr = service.clone(); + let notification_subscribers = + Arc::new(Mutex::new(Vec::>::new())); + let subscribers_ptr = notification_subscribers.clone(); + + tokio::spawn(async move { + loop { + match transport.receive().await { + Ok(message) => { + tracing::info!("Received message: {:?}", message); + match message { + JsonRpcMessage::Response(JsonRpcResponse { id: Some(id), .. }) => { + service_ptr.respond(&id.to_string(), Ok(message)).await; + } + _ => { + let mut subs = subscribers_ptr.lock().await; + subs.retain(|sub| sub.try_send(message.clone()).is_ok()); + } + } + } + Err(e) => { + tracing::error!("transport error: {:?}", e); + service_ptr.hangup().await; + subscribers_ptr.lock().await.clear(); + break; + } + } + } + }); + + let middleware = TimeoutLayer::new(timeout); + + Ok(Self { + service: Mutex::new(middleware.layer(service)), next_id: AtomicU64::new(1), server_capabilities: None, server_info: None, - } + notification_subscribers, + }) } /// Send a JSON-RPC request and check we don't get an error response. @@ -134,13 +172,18 @@ where { let mut service = self.service.lock().await; service.ready().await.map_err(|_| Error::NotReady)?; - let id = self.next_id.fetch_add(1, Ordering::SeqCst); + + let mut params = params.clone(); + params["_meta"] = json!({ + "progressToken": format!("prog-{}", id), + }); + let request = JsonRpcMessage::Request(JsonRpcRequest { jsonrpc: "2.0".to_string(), id: Some(id), method: method.to_string(), - params: Some(params.clone()), + params: Some(params), }); let response_msg = service @@ -154,7 +197,7 @@ where .unwrap_or("".to_string()), method: method.to_string(), // we don't need include params because it can be really large - source: Box::new(e.into()), + source: Box::::new(e.into()), })?; match response_msg { @@ -220,7 +263,7 @@ where .unwrap_or("".to_string()), method: method.to_string(), // we don't need include params because it can be really large - source: Box::new(e.into()), + source: Box::::new(e.into()), })?; Ok(()) @@ -233,11 +276,9 @@ where } #[async_trait::async_trait] -impl McpClientTrait for McpClient +impl McpClientTrait for McpClient where - S: Service + Clone + Send + Sync + 'static, - S::Error: Into, - S::Future: Send, + T: TransportHandle + Send + Sync + 'static, { async fn initialize( &mut self, @@ -388,4 +429,10 @@ where self.send_request("prompts/get", params).await } + + async fn subscribe(&self) -> mpsc::Receiver { + let (tx, rx) = mpsc::channel(16); + self.notification_subscribers.lock().await.push(tx); + rx + } } diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs index 00aa95be..b2ea82cf 100644 --- a/crates/mcp-client/src/service.rs +++ b/crates/mcp-client/src/service.rs @@ -1,7 +1,9 @@ use futures::future::BoxFuture; -use mcp_core::protocol::JsonRpcMessage; +use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; +use std::collections::HashMap; use std::sync::Arc; use std::task::{Context, Poll}; +use tokio::sync::{oneshot, RwLock}; use tower::{timeout::Timeout, Service, ServiceBuilder}; use crate::transport::{Error, TransportHandle}; @@ -10,14 +12,24 @@ use crate::transport::{Error, TransportHandle}; #[derive(Clone)] pub struct McpService { inner: Arc, + pending_requests: Arc, } impl McpService { pub fn new(transport: T) -> Self { Self { inner: Arc::new(transport), + pending_requests: Arc::new(PendingRequests::default()), } } + + pub async fn respond(&self, id: &str, response: Result) { + self.pending_requests.respond(id, response).await + } + + pub async fn hangup(&self) { + self.pending_requests.broadcast_close().await + } } impl Service for McpService @@ -35,7 +47,31 @@ where fn call(&mut self, request: JsonRpcMessage) -> Self::Future { let transport = self.inner.clone(); - Box::pin(async move { transport.send(request).await }) + let pending_requests = self.pending_requests.clone(); + + Box::pin(async move { + match request { + JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) => { + // Create a channel to receive the response + let (sender, receiver) = oneshot::channel(); + pending_requests.insert(id.to_string(), sender).await; + + transport.send(request).await?; + receiver.await.map_err(|_| Error::ChannelClosed)? + } + JsonRpcMessage::Request(_) => { + // Handle notifications without waiting for a response + transport.send(request).await?; + Ok(JsonRpcMessage::Nil) + } + JsonRpcMessage::Notification(_) => { + // Handle notifications without waiting for a response + transport.send(request).await?; + Ok(JsonRpcMessage::Nil) + } + _ => Err(Error::UnsupportedMessage), + } + }) } } @@ -50,3 +86,50 @@ where .service(McpService::new(transport)) } } + +// A data structure to store pending requests and their response channels +pub struct PendingRequests { + requests: RwLock>>>, +} + +impl Default for PendingRequests { + fn default() -> Self { + Self::new() + } +} + +impl PendingRequests { + pub fn new() -> Self { + Self { + requests: RwLock::new(HashMap::new()), + } + } + + pub async fn insert(&self, id: String, sender: oneshot::Sender>) { + self.requests.write().await.insert(id, sender); + } + + pub async fn respond(&self, id: &str, response: Result) { + if let Some(tx) = self.requests.write().await.remove(id) { + let _ = tx.send(response); + } + } + + pub async fn broadcast_close(&self) { + for (_, tx) in self.requests.write().await.drain() { + let _ = tx.send(Err(Error::ChannelClosed)); + } + } + + pub async fn clear(&self) { + self.requests.write().await.clear(); + } + + pub async fn len(&self) -> usize { + self.requests.read().await.len() + } + + pub async fn is_empty(&self) -> bool { + self.len().await == 0 + } +} diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index e2a66b26..28e6d929 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -1,8 +1,7 @@ use async_trait::async_trait; use mcp_core::protocol::JsonRpcMessage; -use std::collections::HashMap; use thiserror::Error; -use tokio::sync::{mpsc, oneshot, RwLock}; +use tokio::sync::{mpsc, oneshot}; pub type BoxError = Box; /// A generic error type for transport operations. @@ -57,74 +56,20 @@ pub trait Transport { #[async_trait] pub trait TransportHandle: Send + Sync + Clone + 'static { - async fn send(&self, message: JsonRpcMessage) -> Result; + async fn send(&self, message: JsonRpcMessage) -> Result<(), Error>; + async fn receive(&self) -> Result; } -// Helper function that contains the common send implementation -pub async fn send_message( - sender: &mpsc::Sender, +pub async fn serialize_and_send( + sender: &mpsc::Sender, message: JsonRpcMessage, -) -> Result { - match message { - JsonRpcMessage::Request(request) => { - let (respond_to, response) = oneshot::channel(); - let msg = TransportMessage { - message: JsonRpcMessage::Request(request), - response_tx: Some(respond_to), - }; - sender.send(msg).await.map_err(|_| Error::ChannelClosed)?; - Ok(response.await.map_err(|_| Error::ChannelClosed)??) +) -> Result<(), Error> { + match serde_json::to_string(&message).map_err(Error::Serialization) { + Ok(msg) => sender.send(msg).await.map_err(|_| Error::ChannelClosed), + Err(e) => { + tracing::error!(error = ?e, "Error serializing message"); + Err(e) } - JsonRpcMessage::Notification(notification) => { - let msg = TransportMessage { - message: JsonRpcMessage::Notification(notification), - response_tx: None, - }; - sender.send(msg).await.map_err(|_| Error::ChannelClosed)?; - Ok(JsonRpcMessage::Nil) - } - _ => Err(Error::UnsupportedMessage), - } -} - -// A data structure to store pending requests and their response channels -pub struct PendingRequests { - requests: RwLock>>>, -} - -impl Default for PendingRequests { - fn default() -> Self { - Self::new() - } -} - -impl PendingRequests { - pub fn new() -> Self { - Self { - requests: RwLock::new(HashMap::new()), - } - } - - pub async fn insert(&self, id: String, sender: oneshot::Sender>) { - self.requests.write().await.insert(id, sender); - } - - pub async fn respond(&self, id: &str, response: Result) { - if let Some(tx) = self.requests.write().await.remove(id) { - let _ = tx.send(response); - } - } - - pub async fn clear(&self) { - self.requests.write().await.clear(); - } - - pub async fn len(&self) -> usize { - self.requests.read().await.len() - } - - pub async fn is_empty(&self) -> bool { - self.len().await == 0 } } diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 8a564708..7a38aca9 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -1,17 +1,17 @@ -use crate::transport::{Error, PendingRequests, TransportMessage}; +use crate::transport::Error; use async_trait::async_trait; use eventsource_client::{Client, SSE}; use futures::TryStreamExt; -use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; +use mcp_core::protocol::JsonRpcMessage; use reqwest::Client as HttpClient; use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::{mpsc, RwLock}; +use tokio::sync::{mpsc, Mutex, RwLock}; use tokio::time::{timeout, Duration}; use tracing::warn; use url::Url; -use super::{send_message, Transport, TransportHandle}; +use super::{serialize_and_send, Transport, TransportHandle}; // Timeout for the endpoint discovery const ENDPOINT_TIMEOUT_SECS: u64 = 5; @@ -21,9 +21,9 @@ const ENDPOINT_TIMEOUT_SECS: u64 = 5; /// - Sends outgoing messages via HTTP POST (once the post endpoint is known). pub struct SseActor { /// Receives messages (requests/notifications) from the handle - receiver: mpsc::Receiver, - /// Map of request-id -> oneshot sender - pending_requests: Arc, + receiver: mpsc::Receiver, + /// Sends messages (responses) back to the handle + sender: mpsc::Sender, /// Base SSE URL sse_url: String, /// For sending HTTP POST requests @@ -34,14 +34,14 @@ pub struct SseActor { impl SseActor { pub fn new( - receiver: mpsc::Receiver, - pending_requests: Arc, + receiver: mpsc::Receiver, + sender: mpsc::Sender, sse_url: String, post_endpoint: Arc>>, ) -> Self { Self { receiver, - pending_requests, + sender, sse_url, post_endpoint, http_client: HttpClient::new(), @@ -54,15 +54,14 @@ impl SseActor { pub async fn run(self) { tokio::join!( Self::handle_incoming_messages( + self.sender, self.sse_url.clone(), - Arc::clone(&self.pending_requests), Arc::clone(&self.post_endpoint) ), Self::handle_outgoing_messages( self.receiver, self.http_client.clone(), Arc::clone(&self.post_endpoint), - Arc::clone(&self.pending_requests), ) ); } @@ -72,14 +71,13 @@ impl SseActor { /// - If a `message` event is received, parse it as `JsonRpcMessage` /// and respond to pending requests if it's a `Response`. async fn handle_incoming_messages( + sender: mpsc::Sender, sse_url: String, - pending_requests: Arc, post_endpoint: Arc>>, ) { let client = match eventsource_client::ClientBuilder::for_url(&sse_url) { Ok(builder) => builder.build(), Err(e) => { - pending_requests.clear().await; warn!("Failed to connect SSE client: {}", e); return; } @@ -105,84 +103,54 @@ impl SseActor { } // Now handle subsequent events - while let Ok(Some(event)) = stream.try_next().await { - match event { - SSE::Event(e) if e.event_type == "message" => { - // Attempt to parse the SSE data as a JsonRpcMessage - match serde_json::from_str::(&e.data) { - Ok(message) => { - match &message { - JsonRpcMessage::Response(response) => { - if let Some(id) = &response.id { - pending_requests - .respond(&id.to_string(), Ok(message)) - .await; - } + loop { + match stream.try_next().await { + Ok(Some(event)) => { + match event { + SSE::Event(e) if e.event_type == "message" => { + // Attempt to parse the SSE data as a JsonRpcMessage + match serde_json::from_str::(&e.data) { + Ok(message) => { + let _ = sender.send(message).await; } - JsonRpcMessage::Error(error) => { - if let Some(id) = &error.id { - pending_requests - .respond(&id.to_string(), Ok(message)) - .await; - } + Err(err) => { + warn!("Failed to parse SSE message: {err}"); } - _ => {} // TODO: Handle other variants (Request, etc.) } } - Err(err) => { - warn!("Failed to parse SSE message: {err}"); - } + _ => { /* ignore other events */ } } } - _ => { /* ignore other events */ } + Ok(None) => { + // Stream ended + tracing::info!("SSE stream ended."); + break; + } + Err(e) => { + warn!("Error reading SSE stream: {e}"); + break; + } } } - // SSE stream ended or errored; signal any pending requests - tracing::error!("SSE stream ended or encountered an error; clearing pending requests."); - pending_requests.clear().await; + tracing::error!("SSE stream ended or encountered an error."); } - /// Continuously receives messages from the `mpsc::Receiver`. - /// - If it's a request, store the oneshot in `pending_requests`. - /// - POST the message to the discovered endpoint (once known). async fn handle_outgoing_messages( - mut receiver: mpsc::Receiver, + mut receiver: mpsc::Receiver, http_client: HttpClient, post_endpoint: Arc>>, - pending_requests: Arc, ) { - while let Some(transport_msg) = receiver.recv().await { + while let Some(message_str) = receiver.recv().await { let post_url = match post_endpoint.read().await.as_ref() { Some(url) => url.clone(), None => { - if let Some(response_tx) = transport_msg.response_tx { - let _ = response_tx.send(Err(Error::NotConnected)); - } + // TODO: the endpoint isn't discovered yet. This shouldn't happen -- we only return the handle + // after the endpoint is set. continue; } }; - // Serialize the JSON-RPC message - let message_str = match serde_json::to_string(&transport_msg.message) { - Ok(s) => s, - Err(e) => { - if let Some(tx) = transport_msg.response_tx { - let _ = tx.send(Err(Error::Serialization(e))); - } - continue; - } - }; - - // If it's a request, store the channel so we can respond later - if let Some(response_tx) = transport_msg.response_tx { - if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) = - &transport_msg.message - { - pending_requests.insert(id.to_string(), response_tx).await; - } - } - // Perform the HTTP POST match http_client .post(&post_url) @@ -209,26 +177,25 @@ impl SseActor { } } - // mpsc channel closed => no more outgoing messages - let pending = pending_requests.len().await; - if pending > 0 { - tracing::error!("SSE stream ended or encountered an error with {pending} unfulfilled pending requests."); - pending_requests.clear().await; - } else { - tracing::info!("SseActor shutdown cleanly. No pending requests."); - } + tracing::info!("SseActor shut down."); } } #[derive(Clone)] pub struct SseTransportHandle { - sender: mpsc::Sender, + sender: mpsc::Sender, + receiver: Arc>>, } #[async_trait::async_trait] impl TransportHandle for SseTransportHandle { - async fn send(&self, message: JsonRpcMessage) -> Result { - send_message(&self.sender, message).await + async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> { + serialize_and_send(&self.sender, message).await + } + + async fn receive(&self) -> Result { + let mut receiver = self.receiver.lock().await; + receiver.recv().await.ok_or(Error::ChannelClosed) } } @@ -279,17 +246,13 @@ impl Transport for SseTransport { // Create a channel for outgoing TransportMessages let (tx, rx) = mpsc::channel(32); + let (otx, orx) = mpsc::channel(32); let post_endpoint: Arc>> = Arc::new(RwLock::new(None)); let post_endpoint_clone = Arc::clone(&post_endpoint); // Build the actor - let actor = SseActor::new( - rx, - Arc::new(PendingRequests::new()), - self.sse_url.clone(), - post_endpoint, - ); + let actor = SseActor::new(rx, otx, self.sse_url.clone(), post_endpoint); // Spawn the actor task tokio::spawn(actor.run()); @@ -301,7 +264,10 @@ impl Transport for SseTransport { ) .await { - Ok(_) => Ok(SseTransportHandle { sender: tx }), + Ok(_) => Ok(SseTransportHandle { + sender: tx, + receiver: Arc::new(Mutex::new(orx)), + }), Err(e) => Err(Error::SseConnection(e.to_string())), } } diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 5895e83e..94b51aff 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -14,7 +14,7 @@ use nix::sys::signal::{kill, Signal}; #[cfg(unix)] use nix::unistd::{getpgid, Pid}; -use super::{send_message, Error, PendingRequests, Transport, TransportHandle, TransportMessage}; +use super::{serialize_and_send, Error, Transport, TransportHandle}; // Global to track process groups we've created static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1); @@ -23,8 +23,8 @@ static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1); /// /// It uses channels for message passing and handles responses asynchronously through a background task. pub struct StdioActor { - receiver: Option>, - pending_requests: Arc, + receiver: Option>, + sender: Option>, process: Child, // we store the process to keep it alive error_sender: mpsc::Sender, stdin: Option, @@ -55,11 +55,11 @@ impl StdioActor { let stdout = self.stdout.take().expect("stdout should be available"); let stdin = self.stdin.take().expect("stdin should be available"); - let receiver = self.receiver.take().expect("receiver should be available"); + let msg_inbox = self.receiver.take().expect("receiver should be available"); + let msg_outbox = self.sender.take().expect("sender should be available"); - let incoming = Self::handle_incoming_messages(stdout, self.pending_requests.clone()); - let outgoing = - Self::handle_outgoing_messages(receiver, stdin, self.pending_requests.clone()); + let incoming = Self::handle_proc_output(stdout, msg_outbox); + let outgoing = Self::handle_proc_input(stdin, msg_inbox); // take ownership of futures for tokio::select pin!(incoming); @@ -96,12 +96,9 @@ impl StdioActor { .await; } } - - // Clean up regardless of which path we took - self.pending_requests.clear().await; } - async fn handle_incoming_messages(stdout: ChildStdout, pending_requests: Arc) { + async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender) { let mut reader = BufReader::new(stdout); let mut line = String::new(); loop { @@ -116,20 +113,12 @@ impl StdioActor { message = ?message, "Received incoming message" ); - - match &message { - JsonRpcMessage::Response(response) => { - if let Some(id) = &response.id { - pending_requests.respond(&id.to_string(), Ok(message)).await; - } - } - JsonRpcMessage::Error(error) => { - if let Some(id) = &error.id { - pending_requests.respond(&id.to_string(), Ok(message)).await; - } - } - _ => {} // TODO: Handle other variants (Request, etc.) - } + let _ = sender.send(message).await; + } else { + tracing::warn!( + message = ?line, + "Failed to parse incoming message" + ); } line.clear(); } @@ -141,44 +130,20 @@ impl StdioActor { } } - async fn handle_outgoing_messages( - mut receiver: mpsc::Receiver, - mut stdin: ChildStdin, - pending_requests: Arc, - ) { - while let Some(mut transport_msg) = receiver.recv().await { - let message_str = match serde_json::to_string(&transport_msg.message) { - Ok(s) => s, - Err(e) => { - if let Some(tx) = transport_msg.response_tx.take() { - let _ = tx.send(Err(Error::Serialization(e))); - } - continue; - } - }; - - tracing::debug!(message = ?transport_msg.message, "Sending outgoing message"); - - if let Some(response_tx) = transport_msg.response_tx.take() { - if let JsonRpcMessage::Request(request) = &transport_msg.message { - if let Some(id) = &request.id { - pending_requests.insert(id.to_string(), response_tx).await; - } - } - } + async fn handle_proc_input(mut stdin: ChildStdin, mut receiver: mpsc::Receiver) { + while let Some(message_str) = receiver.recv().await { + tracing::debug!(message = ?message_str, "Sending outgoing message"); if let Err(e) = stdin .write_all(format!("{}\n", message_str).as_bytes()) .await { tracing::error!(error = ?e, "Error writing message to child process"); - pending_requests.clear().await; break; } if let Err(e) = stdin.flush().await { tracing::error!(error = ?e, "Error flushing message to child process"); - pending_requests.clear().await; break; } } @@ -187,18 +152,24 @@ impl StdioActor { #[derive(Clone)] pub struct StdioTransportHandle { - sender: mpsc::Sender, + sender: mpsc::Sender, // to process + receiver: Arc>>, // from process error_receiver: Arc>>, } #[async_trait::async_trait] impl TransportHandle for StdioTransportHandle { - async fn send(&self, message: JsonRpcMessage) -> Result { - let result = send_message(&self.sender, message).await; + async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> { + let result = serialize_and_send(&self.sender, message).await; // Check for any pending errors even if send is successful self.check_for_errors().await?; result } + + async fn receive(&self) -> Result { + let mut receiver = self.receiver.lock().await; + receiver.recv().await.ok_or(Error::ChannelClosed) + } } impl StdioTransportHandle { @@ -289,12 +260,13 @@ impl Transport for StdioTransport { async fn start(&self) -> Result { let (process, stdin, stdout, stderr) = self.spawn_process().await?; - let (message_tx, message_rx) = mpsc::channel(32); + let (outbox_tx, outbox_rx) = mpsc::channel(32); + let (inbox_tx, inbox_rx) = mpsc::channel(32); let (error_tx, error_rx) = mpsc::channel(1); let actor = StdioActor { - receiver: Some(message_rx), - pending_requests: Arc::new(PendingRequests::new()), + receiver: Some(outbox_rx), // client to process + sender: Some(inbox_tx), // process to client process, error_sender: error_tx, stdin: Some(stdin), @@ -305,7 +277,8 @@ impl Transport for StdioTransport { tokio::spawn(actor.run()); let handle = StdioTransportHandle { - sender: message_tx, + sender: outbox_tx, // client to process + receiver: Arc::new(Mutex::new(inbox_rx)), // process to client error_receiver: Arc::new(Mutex::new(error_rx)), }; Ok(handle) diff --git a/crates/mcp-server/src/lib.rs b/crates/mcp-server/src/lib.rs index 97594523..413d01d9 100644 --- a/crates/mcp-server/src/lib.rs +++ b/crates/mcp-server/src/lib.rs @@ -4,9 +4,13 @@ use std::{ }; use futures::{Future, Stream}; -use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse}; +use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcResponse}; use pin_project::pin_project; -use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use router::McpRequest; +use tokio::{ + io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}, + sync::mpsc, +}; use tower_service::Service; mod errors; @@ -123,7 +127,7 @@ pub struct Server { impl Server where - S: Service + Send, + S: Service + Send, S::Error: Into, S::Future: Send, { @@ -134,8 +138,8 @@ where // TODO transport trait instead of byte transport if we implement others pub async fn run(self, mut transport: ByteTransport) -> Result<(), ServerError> where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, { use futures::StreamExt; let mut service = self.service; @@ -160,7 +164,22 @@ where ); // Process the request using our service - let response = match service.call(request).await { + let (notify_tx, mut notify_rx) = mpsc::channel(256); + let mcp_request = McpRequest { + request, + notifier: notify_tx, + }; + + let transport_fut = tokio::spawn(async move { + while let Some(notification) = notify_rx.recv().await { + if transport.write_message(notification).await.is_err() { + break; + } + } + transport + }); + + let response = match service.call(mcp_request).await { Ok(resp) => resp, Err(e) => { let error_msg = e.into().to_string(); @@ -178,6 +197,16 @@ where } }; + transport = match transport_fut.await { + Ok(transport) => transport, + Err(e) => { + tracing::error!(error = %e, "Failed to spawn transport task"); + return Err(ServerError::Transport(TransportError::Io( + e.into(), + ))); + } + }; + // Serialize response for logging let response_json = serde_json::to_string(&response) .unwrap_or_else(|_| "Failed to serialize response".to_string()); @@ -247,7 +276,7 @@ where // Any router implements this pub trait BoundedService: Service< - JsonRpcRequest, + McpRequest, Response = JsonRpcResponse, Error = BoxError, Future = Pin> + Send>>, @@ -259,7 +288,7 @@ pub trait BoundedService: // Implement it for any type that meets the bounds impl BoundedService for T where T: Service< - JsonRpcRequest, + McpRequest, Response = JsonRpcResponse, Error = BoxError, Future = Pin> + Send>>, diff --git a/crates/mcp-server/src/main.rs b/crates/mcp-server/src/main.rs index ff2b7adf..f09032b0 100644 --- a/crates/mcp-server/src/main.rs +++ b/crates/mcp-server/src/main.rs @@ -2,12 +2,14 @@ use anyhow::Result; use mcp_core::content::Content; use mcp_core::handler::{PromptError, ResourceError}; use mcp_core::prompt::{Prompt, PromptArgument}; +use mcp_core::protocol::JsonRpcMessage; use mcp_core::tool::ToolAnnotations; use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; use mcp_server::router::{CapabilitiesBuilder, RouterService}; use mcp_server::{ByteTransport, Router, Server}; use serde_json::Value; use std::{future::Future, pin::Pin, sync::Arc}; +use tokio::sync::mpsc; use tokio::{ io::{stdin, stdout}, sync::Mutex, @@ -124,6 +126,7 @@ impl Router for CounterRouter { &self, tool_name: &str, _arguments: Value, + _notifier: mpsc::Sender, ) -> Pin, ToolError>> + Send + 'static>> { let this = self.clone(); let tool_name = tool_name.to_string(); diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs index 0931966f..6370bd1f 100644 --- a/crates/mcp-server/src/router.rs +++ b/crates/mcp-server/src/router.rs @@ -11,14 +11,15 @@ use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, prompt::{Prompt, PromptMessage, PromptMessageRole}, protocol::{ - CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcRequest, - JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult, + CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcMessage, + JsonRpcRequest, JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult, PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities, ToolsCapability, }, ResourceContents, }; use serde_json::Value; +use tokio::sync::mpsc; use tower_service::Service; use crate::{BoxError, RouterError}; @@ -91,6 +92,7 @@ pub trait Router: Send + Sync + 'static { &self, tool_name: &str, arguments: Value, + notifier: mpsc::Sender, ) -> Pin, ToolError>> + Send + 'static>>; fn list_resources(&self) -> Vec; fn read_resource( @@ -159,6 +161,7 @@ pub trait Router: Send + Sync + 'static { fn handle_tools_call( &self, req: JsonRpcRequest, + notifier: mpsc::Sender, ) -> impl Future> + Send { async move { let params = req @@ -172,7 +175,7 @@ pub trait Router: Send + Sync + 'static { let arguments = params.get("arguments").cloned().unwrap_or(Value::Null); - let result = match self.call_tool(name, arguments).await { + let result = match self.call_tool(name, arguments, notifier).await { Ok(result) => CallToolResult { content: result, is_error: None, @@ -394,7 +397,12 @@ pub trait Router: Send + Sync + 'static { pub struct RouterService(pub T); -impl Service for RouterService +pub struct McpRequest { + pub request: JsonRpcRequest, + pub notifier: mpsc::Sender, +} + +impl Service for RouterService where T: Router + Clone + Send + Sync + 'static, { @@ -406,21 +414,21 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, req: JsonRpcRequest) -> Self::Future { + fn call(&mut self, req: McpRequest) -> Self::Future { let this = self.0.clone(); Box::pin(async move { - let result = match req.method.as_str() { - "initialize" => this.handle_initialize(req).await, - "tools/list" => this.handle_tools_list(req).await, - "tools/call" => this.handle_tools_call(req).await, - "resources/list" => this.handle_resources_list(req).await, - "resources/read" => this.handle_resources_read(req).await, - "prompts/list" => this.handle_prompts_list(req).await, - "prompts/get" => this.handle_prompts_get(req).await, + let result = match req.request.method.as_str() { + "initialize" => this.handle_initialize(req.request).await, + "tools/list" => this.handle_tools_list(req.request).await, + "tools/call" => this.handle_tools_call(req.request, req.notifier).await, + "resources/list" => this.handle_resources_list(req.request).await, + "resources/read" => this.handle_resources_read(req.request).await, + "prompts/list" => this.handle_prompts_list(req.request).await, + "prompts/get" => this.handle_prompts_get(req.request).await, _ => { - let mut response = this.create_response(req.id); - response.error = Some(RouterError::MethodNotFound(req.method).into()); + let mut response = this.create_response(req.request.id); + response.error = Some(RouterError::MethodNotFound(req.request.method).into()); Ok(response) } }; diff --git a/ui/desktop/.env b/ui/desktop/.env index 68502576..e408e241 100644 --- a/ui/desktop/.env +++ b/ui/desktop/.env @@ -1,4 +1,4 @@ VITE_START_EMBEDDED_SERVER=yes GOOSE_PROVIDER__TYPE=openai GOOSE_PROVIDER__HOST=https://api.openai.com -GOOSE_PROVIDER__MODEL=gpt-4o \ No newline at end of file +GOOSE_PROVIDER__MODEL=gpt-4o diff --git a/ui/desktop/src/components/ChatView.tsx b/ui/desktop/src/components/ChatView.tsx index 0e54d0a1..a1aea580 100644 --- a/ui/desktop/src/components/ChatView.tsx +++ b/ui/desktop/src/components/ChatView.tsx @@ -148,6 +148,7 @@ function ChatContent({ handleInputChange: _handleInputChange, handleSubmit: _submitMessage, updateMessageStreamBody, + notifications, } = useMessageStream({ api: getApiUrl('/reply'), initialMessages: chat.messages, @@ -492,6 +493,16 @@ function ChatContent({ const handleDragOver = (e: React.DragEvent) => { e.preventDefault(); }; + + const toolCallNotifications = notifications.reduce((map, item) => { + const key = item.request_id; + if (!map.has(key)) { + map.set(key, []); + } + map.get(key).push(item); + return map; + }, new Map()); + return (
{/* Loader when generating recipe */} @@ -571,6 +582,7 @@ function ChatContent({ const updatedMessages = [...messages, newMessage]; setMessages(updatedMessages); }} + toolCallNotifications={toolCallNotifications} /> )} @@ -578,6 +590,7 @@ function ChatContent({
))} + {error && (
diff --git a/ui/desktop/src/components/GooseMessage.tsx b/ui/desktop/src/components/GooseMessage.tsx index 6066386b..5b7058d9 100644 --- a/ui/desktop/src/components/GooseMessage.tsx +++ b/ui/desktop/src/components/GooseMessage.tsx @@ -17,6 +17,7 @@ import { } from '../types/message'; import ToolCallConfirmation from './ToolCallConfirmation'; import MessageCopyLink from './MessageCopyLink'; +import { NotificationEvent } from '../hooks/useMessageStream'; interface GooseMessageProps { // messages up to this index are presumed to be "history" from a resumed session, this is used to track older tool confirmation requests @@ -25,6 +26,7 @@ interface GooseMessageProps { message: Message; messages: Message[]; metadata?: string[]; + toolCallNotifications: Map; append: (value: string) => void; appendMessage: (message: Message) => void; } @@ -34,6 +36,7 @@ export default function GooseMessage({ message, metadata, messages, + toolCallNotifications, append, appendMessage, }: GooseMessageProps) { @@ -158,6 +161,7 @@ export default function GooseMessage({ } toolRequest={toolRequest} toolResponse={toolResponsesMap.get(toolRequest.id)} + notifications={toolCallNotifications.get(toolRequest.id)} />
))} diff --git a/ui/desktop/src/components/ToolCallWithResponse.tsx b/ui/desktop/src/components/ToolCallWithResponse.tsx index 721a39de..e33082ff 100644 --- a/ui/desktop/src/components/ToolCallWithResponse.tsx +++ b/ui/desktop/src/components/ToolCallWithResponse.tsx @@ -1,4 +1,4 @@ -import React from 'react'; +import React, { useEffect, useRef } from 'react'; import { Card } from './ui/card'; import { ToolCallArguments, ToolCallArgumentValue } from './ToolCallArguments'; import MarkdownContent from './MarkdownContent'; @@ -6,17 +6,20 @@ import { Content, ToolRequestMessageContent, ToolResponseMessageContent } from ' import { snakeToTitleCase } from '../utils'; import Dot, { LoadingStatus } from './ui/Dot'; import Expand from './ui/Expand'; +import { NotificationEvent } from '../hooks/useMessageStream'; interface ToolCallWithResponseProps { isCancelledMessage: boolean; toolRequest: ToolRequestMessageContent; toolResponse?: ToolResponseMessageContent; + notifications?: NotificationEvent[]; } export default function ToolCallWithResponse({ isCancelledMessage, toolRequest, toolResponse, + notifications, }: ToolCallWithResponseProps) { const toolCall = toolRequest.toolCall.status === 'success' ? toolRequest.toolCall.value : null; if (!toolCall) { @@ -26,7 +29,7 @@ export default function ToolCallWithResponse({ return (
- +
); @@ -47,8 +50,9 @@ function ToolCallExpandable({ children, className = '', }: ToolCallExpandableProps) { - const [isExpanded, setIsExpanded] = React.useState(isStartExpanded); - const toggleExpand = () => setIsExpanded((prev) => !prev); + const [isExpandedState, setIsExpanded] = React.useState(null); + const isExpanded = isExpandedState === null ? isStartExpanded : isExpandedState; + const toggleExpand = () => setIsExpanded(!isExpanded); React.useEffect(() => { if (isForceExpand) setIsExpanded(true); }, [isForceExpand]); @@ -71,9 +75,42 @@ interface ToolCallViewProps { arguments: Record; }; toolResponse?: ToolResponseMessageContent; + notifications?: NotificationEvent[]; } -function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallViewProps) { +interface Progress { + progress: number; + progressToken: string; + total?: number; + message?: string; +} + +const logToString = (logMessage: NotificationEvent) => { + const params = logMessage.message.params; + + // Special case for the developer system shell logs + if ( + params && + params.data && + typeof params.data === 'object' && + 'output' in params.data && + 'stream' in params.data + ) { + return `[${params.data.stream}] ${params.data.output}`; + } + + return typeof params.data === 'string' ? params.data : JSON.stringify(params.data); +}; + +const notificationToProgress = (notification: NotificationEvent): Progress => + notification.message.params as unknown as Progress; + +function ToolCallView({ + isCancelledMessage, + toolCall, + toolResponse, + notifications, +}: ToolCallViewProps) { const responseStyle = localStorage.getItem('response_style'); const isExpandToolDetails = (() => { switch (responseStyle) { @@ -103,6 +140,29 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi })) : []; + const logs = notifications + ?.filter((notification) => notification.message.method === 'notifications/message') + .map(logToString); + + const progress = notifications + ?.filter((notification) => notification.message.method === 'notifications/progress') + .map(notificationToProgress) + .reduce((map, item) => { + const key = item.progressToken; + if (!map.has(key)) { + map.set(key, []); + } + map.get(key)!.push(item); + return map; + }, new Map()); + + const progressEntries = [...(progress?.values() || [])].map( + (entries) => entries.sort((a, b) => b.progress - a.progress)[0] + ); + + const isRenderingProgress = + loadingStatus === 'loading' && (progressEntries.length > 0 || (logs || []).length > 0); + const isShouldExpand = isExpandToolDetails || toolResults.some((v) => v.isExpandToolResults); // Function to create a compact representation of arguments @@ -136,7 +196,7 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi return ( @@ -156,6 +216,24 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi
)} + {logs && logs.length > 0 && ( +
+ +
+ )} + + {toolResults.length === 0 && + progressEntries.length > 0 && + progressEntries.map((entry, index) => ( +
+ +
+ ))} + {/* Tool Output */} {!isCancelledMessage && ( <> @@ -234,3 +312,76 @@ function ToolResultView({ result, isStartExpanded }: ToolResultViewProps) { ); } + +function ToolLogsView({ + logs, + working, + isStartExpanded, +}: { + logs: string[]; + working: boolean; + isStartExpanded?: boolean; +}) { + const boxRef = useRef(null); + + // Whenever logs update, jump to the newest entry + useEffect(() => { + if (boxRef.current) { + boxRef.current.scrollTop = boxRef.current.scrollHeight; + } + }, [logs]); + + return ( + + Logs + {working && ( +
+ +
+ )} + + } + isStartExpanded={isStartExpanded} + > +
+ {logs.map((log, i) => ( + + {log} + + ))} +
+
+ ); +} + +const ProgressBar = ({ progress, total, message }: Omit) => { + const isDeterminate = typeof total === 'number'; + const percent = isDeterminate ? Math.min((progress / total!) * 100, 100) : 0; + + return ( +
+ {message &&
{message}
} + +
+ {isDeterminate ? ( +
+ ) : ( +
+ )} +
+
+ ); +}; diff --git a/ui/desktop/src/hooks/useMessageStream.ts b/ui/desktop/src/hooks/useMessageStream.ts index e7bb9718..87d311ac 100644 --- a/ui/desktop/src/hooks/useMessageStream.ts +++ b/ui/desktop/src/hooks/useMessageStream.ts @@ -6,11 +6,25 @@ import { Message, createUserMessage, hasCompletedToolCalls } from '../types/mess // Ensure TextDecoder is available in the global scope const TextDecoder = globalThis.TextDecoder; +type JsonValue = string | number | boolean | null | JsonValue[] | { [key: string]: JsonValue }; + +export interface NotificationEvent { + type: 'Notification'; + request_id: string; + message: { + method: string; + params: { + [key: string]: JsonValue; + }; + }; +} + // Event types for SSE stream type MessageEvent = | { type: 'Message'; message: Message } | { type: 'Error'; error: string } - | { type: 'Finish'; reason: string }; + | { type: 'Finish'; reason: string } + | NotificationEvent; export interface UseMessageStreamOptions { /** @@ -124,6 +138,8 @@ export interface UseMessageStreamHelpers { /** Modify body (session id and/or work dir mid-stream) **/ updateMessageStreamBody?: (newBody: object) => void; + + notifications: NotificationEvent[]; } /** @@ -151,6 +167,8 @@ export function useMessageStream({ fallbackData: initialMessages, }); + const [notifications, setNotifications] = useState([]); + // expose a way to update the body so we can update the session id when CLE occurs const updateMessageStreamBody = useCallback((newBody: object) => { extraMetadataRef.current.body = { @@ -247,6 +265,14 @@ export function useMessageStream({ break; } + case 'Notification': { + const newNotification = { + ...parsedEvent, + }; + setNotifications((prev) => [...prev, newNotification]); + break; + } + case 'Error': throw new Error(parsedEvent.error); @@ -516,5 +542,6 @@ export function useMessageStream({ isLoading: isLoading || false, addToolResult, updateMessageStreamBody, + notifications, }; } diff --git a/ui/desktop/tailwind.config.ts b/ui/desktop/tailwind.config.ts index 6955a3d4..1a62cc50 100644 --- a/ui/desktop/tailwind.config.ts +++ b/ui/desktop/tailwind.config.ts @@ -44,10 +44,16 @@ export default { '0%': { transform: 'rotate(0deg)' }, '100%': { transform: 'rotate(360deg)' }, }, + indeterminate: { + '0%': { left: '-40%', width: '40%' }, + '50%': { left: '20%', width: '60%' }, + '100%': { left: '100%', width: '80%' }, + }, }, animation: { 'shimmer-pulse': 'shimmer 4s ease-in-out infinite', 'gradient-loader': 'loader 750ms ease-in-out infinite', + indeterminate: 'indeterminate 1.5s infinite linear', }, colors: { bgApp: 'var(--background-app)',