feat: Handle MCP server notification messages (#2613)

Co-authored-by: Michael Neale <michael.neale@gmail.com>
This commit is contained in:
Jack Amadeo
2025-05-30 11:50:14 -04:00
committed by GitHub
parent eeb61ace22
commit 03e5549b54
40 changed files with 1186 additions and 443 deletions

2
Cargo.lock generated
View File

@@ -3435,6 +3435,7 @@ dependencies = [
"tokenizers", "tokenizers",
"tokio", "tokio",
"tokio-cron-scheduler", "tokio-cron-scheduler",
"tokio-stream",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"url", "url",
@@ -3486,6 +3487,7 @@ dependencies = [
"goose", "goose",
"goose-bench", "goose-bench",
"goose-mcp", "goose-mcp",
"indicatif",
"mcp-client", "mcp-client",
"mcp-core", "mcp-core",
"mcp-server", "mcp-server",

View File

@@ -55,6 +55,7 @@ regex = "1.11.1"
minijinja = "2.8.0" minijinja = "2.8.0"
nix = { version = "0.30.1", features = ["process", "signal"] } nix = { version = "0.30.1", features = ["process", "signal"] }
tar = "0.4" tar = "0.4"
indicatif = "0.17.11"
[target.'cfg(target_os = "windows")'.dependencies] [target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] } winapi = { version = "0.3", features = ["wincred"] }

View File

@@ -7,6 +7,7 @@ mod thinking;
pub use builder::{build_session, SessionBuilderConfig}; pub use builder::{build_session, SessionBuilderConfig};
use console::Color; use console::Color;
use goose::agents::AgentEvent;
use goose::permission::permission_confirmation::PrincipalType; use goose::permission::permission_confirmation::PrincipalType;
use goose::permission::Permission; use goose::permission::Permission;
use goose::permission::PermissionConfirmation; use goose::permission::PermissionConfirmation;
@@ -26,6 +27,8 @@ use input::InputResult;
use mcp_core::handler::ToolError; use mcp_core::handler::ToolError;
use mcp_core::prompt::PromptMessage; use mcp_core::prompt::PromptMessage;
use mcp_core::protocol::JsonRpcMessage;
use mcp_core::protocol::JsonRpcNotification;
use rand::{distributions::Alphanumeric, Rng}; use rand::{distributions::Alphanumeric, Rng};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
@@ -713,12 +716,15 @@ impl Session {
) )
.await?; .await?;
let mut progress_bars = output::McpSpinners::new();
use futures::StreamExt; use futures::StreamExt;
loop { loop {
tokio::select! { tokio::select! {
result = stream.next() => { result = stream.next() => {
let _ = progress_bars.hide();
match result { match result {
Some(Ok(message)) => { Some(Ok(AgentEvent::Message(message))) => {
// If it's a confirmation request, get approval but otherwise do not render/persist // If it's a confirmation request, get approval but otherwise do not render/persist
if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() { if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() {
output::hide_thinking(); output::hide_thinking();
@@ -846,6 +852,51 @@ impl Session {
if interactive {output::show_thinking()}; if interactive {output::show_thinking()};
} }
} }
Some(Ok(AgentEvent::McpNotification((_id, message)))) => {
if let JsonRpcMessage::Notification(JsonRpcNotification{
method,
params: Some(Value::Object(o)),
..
}) = message {
match method.as_str() {
"notifications/message" => {
let data = o.get("data").unwrap_or(&Value::Null);
let message = match data {
Value::String(s) => s.clone(),
Value::Object(o) => {
if let Some(Value::String(output)) = o.get("output") {
output.to_owned()
} else {
data.to_string()
}
},
v => {
v.to_string()
},
};
// output::render_text_no_newlines(&message, None, true);
progress_bars.log(&message);
},
"notifications/progress" => {
let progress = o.get("progress").and_then(|v| v.as_f64());
let token = o.get("progressToken").map(|v| v.to_string());
let message = o.get("message").and_then(|v| v.as_str());
let total = o
.get("total")
.and_then(|v| v.as_f64());
if let (Some(progress), Some(token)) = (progress, token) {
progress_bars.update(
token.as_str(),
progress,
total,
message,
);
}
},
_ => (),
}
}
}
Some(Err(e)) => { Some(Err(e)) => {
eprintln!("Error: {}", e); eprintln!("Error: {}", e);
drop(stream); drop(stream);
@@ -872,6 +923,7 @@ impl Session {
} }
} }
} }
Ok(()) Ok(())
} }

View File

@@ -2,12 +2,15 @@ use bat::WrappingMode;
use console::{style, Color}; use console::{style, Color};
use goose::config::Config; use goose::config::Config;
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use mcp_core::prompt::PromptArgument; use mcp_core::prompt::PromptArgument;
use mcp_core::tool::ToolCall; use mcp_core::tool::ToolCall;
use serde_json::Value; use serde_json::Value;
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::Error;
use std::path::Path; use std::path::Path;
use std::time::Duration;
// Re-export theme for use in main // Re-export theme for use in main
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
@@ -144,6 +147,10 @@ pub fn render_message(message: &Message, debug: bool) {
} }
pub fn render_text(text: &str, color: Option<Color>, dim: bool) { pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
render_text_no_newlines(format!("\n{}\n\n", text).as_str(), color, dim);
}
pub fn render_text_no_newlines(text: &str, color: Option<Color>, dim: bool) {
let mut styled_text = style(text); let mut styled_text = style(text);
if dim { if dim {
styled_text = styled_text.dim(); styled_text = styled_text.dim();
@@ -153,7 +160,7 @@ pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
} else { } else {
styled_text = styled_text.green(); styled_text = styled_text.green();
} }
println!("\n{}\n", styled_text); print!("{}", styled_text);
} }
pub fn render_enter_plan_mode() { pub fn render_enter_plan_mode() {
@@ -359,7 +366,6 @@ fn render_shell_request(call: &ToolCall, debug: bool) {
} }
_ => print_params(&call.arguments, 0, debug), _ => print_params(&call.arguments, 0, debug),
} }
println!();
} }
fn render_default_request(call: &ToolCall, debug: bool) { fn render_default_request(call: &ToolCall, debug: bool) {
@@ -568,6 +574,64 @@ pub fn display_greeting() {
println!("\nGoose is running! Enter your instructions, or try asking what goose can do.\n"); println!("\nGoose is running! Enter your instructions, or try asking what goose can do.\n");
} }
pub struct McpSpinners {
bars: HashMap<String, ProgressBar>,
log_spinner: Option<ProgressBar>,
multi_bar: MultiProgress,
}
impl McpSpinners {
pub fn new() -> Self {
McpSpinners {
bars: HashMap::new(),
log_spinner: None,
multi_bar: MultiProgress::new(),
}
}
pub fn log(&mut self, message: &str) {
let spinner = self.log_spinner.get_or_insert_with(|| {
let bar = self.multi_bar.add(
ProgressBar::new_spinner()
.with_style(
ProgressStyle::with_template("{spinner:.green} {msg}")
.unwrap()
.tick_chars("⠋⠙⠚⠛⠓⠒⠊⠉"),
)
.with_message(message.to_string()),
);
bar.enable_steady_tick(Duration::from_millis(100));
bar
});
spinner.set_message(message.to_string());
}
pub fn update(&mut self, token: &str, value: f64, total: Option<f64>, message: Option<&str>) {
let bar = self.bars.entry(token.to_string()).or_insert_with(|| {
if let Some(total) = total {
self.multi_bar.add(
ProgressBar::new((total * 100.0) as u64).with_style(
ProgressStyle::with_template("[{elapsed}] {bar:40} {pos:>3}/{len:3} {msg}")
.unwrap(),
),
)
} else {
self.multi_bar.add(ProgressBar::new_spinner())
}
});
bar.set_position((value * 100.0) as u64);
if let Some(msg) = message {
bar.set_message(msg.to_string());
}
}
pub fn hide(&mut self) -> Result<(), Error> {
self.multi_bar.clear()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@@ -3,7 +3,7 @@ use std::ptr;
use std::sync::Arc; use std::sync::Arc;
use futures::StreamExt; use futures::StreamExt;
use goose::agents::Agent; use goose::agents::{Agent, AgentEvent};
use goose::message::Message; use goose::message::Message;
use goose::model::ModelConfig; use goose::model::ModelConfig;
use goose::providers::databricks::DatabricksProvider; use goose::providers::databricks::DatabricksProvider;
@@ -256,13 +256,16 @@ pub unsafe extern "C" fn goose_agent_send_message(
while let Some(message_result) = stream.next().await { while let Some(message_result) = stream.next().await {
match message_result { match message_result {
Ok(message) => { Ok(AgentEvent::Message(message)) => {
// Get text or serialize to JSON // Get text or serialize to JSON
// Note: Message doesn't have as_text method, we'll serialize to JSON // Note: Message doesn't have as_text method, we'll serialize to JSON
if let Ok(json) = serde_json::to_string(&message) { if let Ok(json) = serde_json::to_string(&message) {
full_response.push_str(&json); full_response.push_str(&json);
} }
} }
Ok(AgentEvent::McpNotification(_)) => {
// TODO: Handle MCP notifications.
}
Err(e) => { Err(e) => {
full_response.push_str(&format!("\nError in message stream: {}", e)); full_response.push_str(&format!("\nError in message stream: {}", e));
} }

View File

@@ -6,7 +6,7 @@ use serde_json::{json, Value};
use std::{ use std::{
collections::HashMap, fs, future::Future, path::PathBuf, pin::Pin, sync::Arc, sync::Mutex, collections::HashMap, fs, future::Future, path::PathBuf, pin::Pin, sync::Arc, sync::Mutex,
}; };
use tokio::process::Command; use tokio::{process::Command, sync::mpsc};
#[cfg(unix)] #[cfg(unix)]
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
@@ -14,7 +14,7 @@ use std::os::unix::fs::PermissionsExt;
use mcp_core::{ use mcp_core::{
handler::{PromptError, ResourceError, ToolError}, handler::{PromptError, ResourceError, ToolError},
prompt::Prompt, prompt::Prompt,
protocol::ServerCapabilities, protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource, resource::Resource,
tool::{Tool, ToolAnnotations}, tool::{Tool, ToolAnnotations},
Content, Content,
@@ -1155,6 +1155,7 @@ impl Router for ComputerControllerRouter {
&self, &self,
tool_name: &str, tool_name: &str,
arguments: Value, arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> { ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone(); let this = self.clone();
let tool_name = tool_name.to_string(); let tool_name = tool_name.to_string();

View File

@@ -13,13 +13,17 @@ use std::{
path::{Path, PathBuf}, path::{Path, PathBuf},
pin::Pin, pin::Pin,
}; };
use tokio::process::Command; use tokio::{
io::{AsyncBufReadExt, BufReader},
process::Command,
sync::mpsc,
};
use url::Url; use url::Url;
use include_dir::{include_dir, Dir}; use include_dir::{include_dir, Dir};
use mcp_core::{ use mcp_core::{
handler::{PromptError, ResourceError, ToolError}, handler::{PromptError, ResourceError, ToolError},
protocol::ServerCapabilities, protocol::{JsonRpcMessage, JsonRpcNotification, ServerCapabilities},
resource::Resource, resource::Resource,
tool::Tool, tool::Tool,
Content, Content,
@@ -456,7 +460,11 @@ impl DeveloperRouter {
} }
// Shell command execution with platform-specific handling // Shell command execution with platform-specific handling
async fn bash(&self, params: Value) -> Result<Vec<Content>, ToolError> { async fn bash(
&self,
params: Value,
notifier: mpsc::Sender<JsonRpcMessage>,
) -> Result<Vec<Content>, ToolError> {
let command = let command =
params params
.get("command") .get("command")
@@ -488,27 +496,92 @@ impl DeveloperRouter {
// Get platform-specific shell configuration // Get platform-specific shell configuration
let shell_config = get_shell_config(); let shell_config = get_shell_config();
let cmd_with_redirect = format_command_for_platform(command); let cmd_str = format_command_for_platform(command);
// Execute the command using platform-specific shell // Execute the command using platform-specific shell
let child = Command::new(&shell_config.executable) let mut child = Command::new(&shell_config.executable)
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.stdin(Stdio::null()) .stdin(Stdio::null())
.kill_on_drop(true) .kill_on_drop(true)
.arg(&shell_config.arg) .arg(&shell_config.arg)
.arg(cmd_with_redirect) .arg(cmd_str)
.spawn() .spawn()
.map_err(|e| ToolError::ExecutionError(e.to_string()))?; .map_err(|e| ToolError::ExecutionError(e.to_string()))?;
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
let mut stdout_reader = BufReader::new(stdout);
let mut stderr_reader = BufReader::new(stderr);
let output_task = tokio::spawn(async move {
let mut combined_output = String::new();
let mut stdout_buf = Vec::new();
let mut stderr_buf = Vec::new();
loop {
tokio::select! {
n = stdout_reader.read_until(b'\n', &mut stdout_buf) => {
if n? == 0 {
break;
}
let line = String::from_utf8_lossy(&stdout_buf);
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: "notifications/message".to_string(),
params: Some(json!({
"data": {
"type": "shell",
"stream": "stdout",
"output": line.to_string(),
}
})),
}))
.ok();
combined_output.push_str(&line);
stdout_buf.clear();
}
n = stderr_reader.read_until(b'\n', &mut stderr_buf) => {
if n? == 0 {
break;
}
let line = String::from_utf8_lossy(&stderr_buf);
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: "notifications/message".to_string(),
params: Some(json!({
"data": {
"type": "shell",
"stream": "stderr",
"output": line.to_string(),
}
})),
}))
.ok();
combined_output.push_str(&line);
stderr_buf.clear();
}
}
}
Ok::<_, std::io::Error>(combined_output)
});
// Wait for the command to complete and get output // Wait for the command to complete and get output
let output = child child
.wait_with_output() .wait()
.await .await
.map_err(|e| ToolError::ExecutionError(e.to_string()))?; .map_err(|e| ToolError::ExecutionError(e.to_string()))?;
let stdout_str = String::from_utf8_lossy(&output.stdout); let output_str = match output_task.await {
let output_str = stdout_str; Ok(result) => result.map_err(|e| ToolError::ExecutionError(e.to_string()))?,
Err(e) => return Err(ToolError::ExecutionError(e.to_string())),
};
// Check the character count of the output // Check the character count of the output
const MAX_CHAR_COUNT: usize = 400_000; // 409600 chars = 400KB const MAX_CHAR_COUNT: usize = 400_000; // 409600 chars = 400KB
@@ -1048,12 +1121,13 @@ impl Router for DeveloperRouter {
&self, &self,
tool_name: &str, tool_name: &str,
arguments: Value, arguments: Value,
notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> { ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone(); let this = self.clone();
let tool_name = tool_name.to_string(); let tool_name = tool_name.to_string();
Box::pin(async move { Box::pin(async move {
match tool_name.as_str() { match tool_name.as_str() {
"shell" => this.bash(arguments).await, "shell" => this.bash(arguments, notifier).await,
"text_editor" => this.text_editor(arguments).await, "text_editor" => this.text_editor(arguments).await,
"list_windows" => this.list_windows(arguments).await, "list_windows" => this.list_windows(arguments).await,
"screen_capture" => this.screen_capture(arguments).await, "screen_capture" => this.screen_capture(arguments).await,
@@ -1195,6 +1269,10 @@ mod tests {
.await .await
} }
fn dummy_sender() -> mpsc::Sender<JsonRpcMessage> {
mpsc::channel(1).0
}
#[tokio::test] #[tokio::test]
#[serial] #[serial]
async fn test_shell_missing_parameters() { async fn test_shell_missing_parameters() {
@@ -1202,7 +1280,7 @@ mod tests {
std::env::set_current_dir(&temp_dir).unwrap(); std::env::set_current_dir(&temp_dir).unwrap();
let router = get_router().await; let router = get_router().await;
let result = router.call_tool("shell", json!({})).await; let result = router.call_tool("shell", json!({}), dummy_sender()).await;
assert!(result.is_err()); assert!(result.is_err());
let err = result.err().unwrap(); let err = result.err().unwrap();
@@ -1263,6 +1341,7 @@ mod tests {
"command": "view", "command": "view",
"path": large_file_str "path": large_file_str
}), }),
dummy_sender(),
) )
.await; .await;
@@ -1288,6 +1367,7 @@ mod tests {
"command": "view", "command": "view",
"path": many_chars_str "path": many_chars_str
}), }),
dummy_sender(),
) )
.await; .await;
@@ -1319,6 +1399,7 @@ mod tests {
"path": file_path_str, "path": file_path_str,
"file_text": "Hello, world!" "file_text": "Hello, world!"
}), }),
dummy_sender(),
) )
.await .await
.unwrap(); .unwrap();
@@ -1331,6 +1412,7 @@ mod tests {
"command": "view", "command": "view",
"path": file_path_str "path": file_path_str
}), }),
dummy_sender(),
) )
.await .await
.unwrap(); .unwrap();
@@ -1369,6 +1451,7 @@ mod tests {
"path": file_path_str, "path": file_path_str,
"file_text": "Hello, world!" "file_text": "Hello, world!"
}), }),
dummy_sender(),
) )
.await .await
.unwrap(); .unwrap();
@@ -1383,6 +1466,7 @@ mod tests {
"old_str": "world", "old_str": "world",
"new_str": "Rust" "new_str": "Rust"
}), }),
dummy_sender(),
) )
.await .await
.unwrap(); .unwrap();
@@ -1407,6 +1491,7 @@ mod tests {
"command": "view", "command": "view",
"path": file_path_str "path": file_path_str
}), }),
dummy_sender(),
) )
.await .await
.unwrap(); .unwrap();
@@ -1444,6 +1529,7 @@ mod tests {
"path": file_path_str, "path": file_path_str,
"file_text": "First line" "file_text": "First line"
}), }),
dummy_sender(),
) )
.await .await
.unwrap(); .unwrap();
@@ -1458,6 +1544,7 @@ mod tests {
"old_str": "First line", "old_str": "First line",
"new_str": "Second line" "new_str": "Second line"
}), }),
dummy_sender(),
) )
.await .await
.unwrap(); .unwrap();
@@ -1470,6 +1557,7 @@ mod tests {
"command": "undo_edit", "command": "undo_edit",
"path": file_path_str "path": file_path_str
}), }),
dummy_sender(),
) )
.await .await
.unwrap(); .unwrap();
@@ -1485,6 +1573,7 @@ mod tests {
"command": "view", "command": "view",
"path": file_path_str "path": file_path_str
}), }),
dummy_sender(),
) )
.await .await
.unwrap(); .unwrap();
@@ -1583,6 +1672,7 @@ mod tests {
"path": temp_dir.path().join("secret.txt").to_str().unwrap(), "path": temp_dir.path().join("secret.txt").to_str().unwrap(),
"file_text": "test content" "file_text": "test content"
}), }),
dummy_sender(),
) )
.await; .await;
@@ -1601,6 +1691,7 @@ mod tests {
"path": temp_dir.path().join("allowed.txt").to_str().unwrap(), "path": temp_dir.path().join("allowed.txt").to_str().unwrap(),
"file_text": "test content" "file_text": "test content"
}), }),
dummy_sender(),
) )
.await; .await;
@@ -1642,6 +1733,7 @@ mod tests {
json!({ json!({
"command": format!("cat {}", secret_file_path.to_str().unwrap()) "command": format!("cat {}", secret_file_path.to_str().unwrap())
}), }),
dummy_sender(),
) )
.await; .await;
@@ -1658,6 +1750,7 @@ mod tests {
json!({ json!({
"command": format!("cat {}", allowed_file_path.to_str().unwrap()) "command": format!("cat {}", allowed_file_path.to_str().unwrap())
}), }),
dummy_sender(),
) )
.await; .await;

View File

@@ -4,7 +4,6 @@ use std::env;
pub struct ShellConfig { pub struct ShellConfig {
pub executable: String, pub executable: String,
pub arg: String, pub arg: String,
pub redirect_syntax: String,
} }
impl Default for ShellConfig { impl Default for ShellConfig {
@@ -14,13 +13,11 @@ impl Default for ShellConfig {
Self { Self {
executable: "powershell.exe".to_string(), executable: "powershell.exe".to_string(),
arg: "-NoProfile -NonInteractive -Command".to_string(), arg: "-NoProfile -NonInteractive -Command".to_string(),
redirect_syntax: "2>&1".to_string(),
} }
} else { } else {
Self { Self {
executable: "bash".to_string(), executable: "bash".to_string(),
arg: "-c".to_string(), arg: "-c".to_string(),
redirect_syntax: "2>&1".to_string(),
} }
} }
} }
@@ -31,13 +28,12 @@ pub fn get_shell_config() -> ShellConfig {
} }
pub fn format_command_for_platform(command: &str) -> String { pub fn format_command_for_platform(command: &str) -> String {
let config = get_shell_config();
if cfg!(windows) { if cfg!(windows) {
// For PowerShell, wrap the command in braces to handle special characters // For PowerShell, wrap the command in braces to handle special characters
format!("{{ {} }} {}", command, config.redirect_syntax) format!("{{ {} }}", command)
} else { } else {
// For other shells, no braces needed // For other shells, no braces needed
format!("{} {}", command, config.redirect_syntax) command.to_string()
} }
} }

View File

@@ -7,6 +7,7 @@ use base64::Engine;
use chrono::NaiveDate; use chrono::NaiveDate;
use indoc::indoc; use indoc::indoc;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use mcp_core::protocol::JsonRpcMessage;
use mcp_core::tool::ToolAnnotations; use mcp_core::tool::ToolAnnotations;
use oauth_pkce::PkceOAuth2Client; use oauth_pkce::PkceOAuth2Client;
use regex::Regex; use regex::Regex;
@@ -14,6 +15,7 @@ use serde_json::{json, Value};
use std::io::Cursor; use std::io::Cursor;
use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc};
use storage::CredentialsManager; use storage::CredentialsManager;
use tokio::sync::mpsc;
use mcp_core::content::Content; use mcp_core::content::Content;
use mcp_core::{ use mcp_core::{
@@ -3281,6 +3283,7 @@ impl Router for GoogleDriveRouter {
&self, &self,
tool_name: &str, tool_name: &str,
arguments: Value, arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> { ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone(); let this = self.clone();
let tool_name = tool_name.to_string(); let tool_name = tool_name.to_string();

View File

@@ -5,7 +5,7 @@ use mcp_core::{
content::Content, content::Content,
handler::{PromptError, ResourceError, ToolError}, handler::{PromptError, ResourceError, ToolError},
prompt::Prompt, prompt::Prompt,
protocol::ServerCapabilities, protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource, resource::Resource,
role::Role, role::Role,
tool::Tool, tool::Tool,
@@ -16,7 +16,7 @@ use serde_json::Value;
use std::future::Future; use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::{mpsc, Mutex};
use tokio::time::{sleep, Duration}; use tokio::time::{sleep, Duration};
use tracing::error; use tracing::error;
@@ -158,6 +158,7 @@ impl Router for JetBrainsRouter {
&self, &self,
tool_name: &str, tool_name: &str,
arguments: Value, arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> { ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone(); let this = self.clone();
let tool_name = tool_name.to_string(); let tool_name = tool_name.to_string();

View File

@@ -10,11 +10,12 @@ use std::{
path::PathBuf, path::PathBuf,
pin::Pin, pin::Pin,
}; };
use tokio::sync::mpsc;
use mcp_core::{ use mcp_core::{
handler::{PromptError, ResourceError, ToolError}, handler::{PromptError, ResourceError, ToolError},
prompt::Prompt, prompt::Prompt,
protocol::ServerCapabilities, protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource, resource::Resource,
tool::{Tool, ToolAnnotations, ToolCall}, tool::{Tool, ToolAnnotations, ToolCall},
Content, Content,
@@ -520,6 +521,7 @@ impl Router for MemoryRouter {
&self, &self,
tool_name: &str, tool_name: &str,
arguments: Value, arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> { ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone(); let this = self.clone();
let tool_name = tool_name.to_string(); let tool_name = tool_name.to_string();

View File

@@ -3,11 +3,12 @@ use include_dir::{include_dir, Dir};
use indoc::formatdoc; use indoc::formatdoc;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::{future::Future, pin::Pin}; use std::{future::Future, pin::Pin};
use tokio::sync::mpsc;
use mcp_core::{ use mcp_core::{
handler::{PromptError, ResourceError, ToolError}, handler::{PromptError, ResourceError, ToolError},
prompt::Prompt, prompt::Prompt,
protocol::ServerCapabilities, protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource, resource::Resource,
role::Role, role::Role,
tool::{Tool, ToolAnnotations}, tool::{Tool, ToolAnnotations},
@@ -130,6 +131,7 @@ impl Router for TutorialRouter {
&self, &self,
tool_name: &str, tool_name: &str,
arguments: Value, arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> { ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone(); let this = self.clone();
let tool_name = tool_name.to_string(); let tool_name = tool_name.to_string();

View File

@@ -10,7 +10,7 @@ use axum::{
use bytes::Bytes; use bytes::Bytes;
use futures::{stream::StreamExt, Stream}; use futures::{stream::StreamExt, Stream};
use goose::{ use goose::{
agents::SessionConfig, agents::{AgentEvent, SessionConfig},
message::{Message, MessageContent}, message::{Message, MessageContent},
permission::permission_confirmation::PrincipalType, permission::permission_confirmation::PrincipalType,
}; };
@@ -18,7 +18,7 @@ use goose::{
permission::{Permission, PermissionConfirmation}, permission::{Permission, PermissionConfirmation},
session, session,
}; };
use mcp_core::{role::Role, Content, ToolResult}; use mcp_core::{protocol::JsonRpcMessage, role::Role, Content, ToolResult};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use serde_json::Value; use serde_json::Value;
@@ -79,9 +79,19 @@ impl IntoResponse for SseResponse {
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
enum MessageEvent { enum MessageEvent {
Message { message: Message }, Message {
Error { error: String }, message: Message,
Finish { reason: String }, },
Error {
error: String,
},
Finish {
reason: String,
},
Notification {
request_id: String,
message: JsonRpcMessage,
},
} }
async fn stream_event( async fn stream_event(
@@ -200,7 +210,7 @@ async fn handler(
tokio::select! { tokio::select! {
response = timeout(Duration::from_millis(500), stream.next()) => { response = timeout(Duration::from_millis(500), stream.next()) => {
match response { match response {
Ok(Some(Ok(message))) => { Ok(Some(Ok(AgentEvent::Message(message)))) => {
all_messages.push(message.clone()); all_messages.push(message.clone());
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await { if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
tracing::error!("Error sending message through channel: {}", e); tracing::error!("Error sending message through channel: {}", e);
@@ -223,6 +233,20 @@ async fn handler(
} }
}); });
} }
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
if let Err(e) = stream_event(MessageEvent::Notification{
request_id: request_id.clone(),
message: n,
}, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = stream_event(
MessageEvent::Error {
error: e.to_string(),
},
&tx,
).await;
}
}
Ok(Some(Err(e))) => { Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e); tracing::error!("Error processing message: {}", e);
let _ = stream_event( let _ = stream_event(
@@ -317,7 +341,7 @@ async fn ask_handler(
while let Some(response) = stream.next().await { while let Some(response) = stream.next().await {
match response { match response {
Ok(message) => { Ok(AgentEvent::Message(message)) => {
if message.role == Role::Assistant { if message.role == Role::Assistant {
for content in &message.content { for content in &message.content {
if let MessageContent::Text(text) = content { if let MessageContent::Text(text) = content {
@@ -328,6 +352,10 @@ async fn ask_handler(
} }
} }
} }
Ok(AgentEvent::McpNotification(n)) => {
// Handle notifications if needed
tracing::info!("Received notification: {:?}", n);
}
Err(e) => { Err(e) => {
tracing::error!("Error processing as_ai message: {}", e); tracing::error!("Error processing as_ai message: {}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR); return Err(StatusCode::INTERNAL_SERVER_ERROR);

View File

@@ -71,10 +71,10 @@ aws-sdk-bedrockruntime = "1.74.0"
# For GCP Vertex AI provider auth # For GCP Vertex AI provider auth
jsonwebtoken = "9.3.1" jsonwebtoken = "9.3.1"
# Added blake3 hashing library as a dependency
blake3 = "1.5" blake3 = "1.5"
fs2 = "0.4.3" fs2 = "0.4.3"
futures-util = "0.3.31" futures-util = "0.3.31"
tokio-stream = "0.1.17"
# Vector database for tool selection # Vector database for tool selection
lancedb = "0.13" lancedb = "0.13"

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use dotenv::dotenv; use dotenv::dotenv;
use futures::StreamExt; use futures::StreamExt;
use goose::agents::{Agent, ExtensionConfig}; use goose::agents::{Agent, AgentEvent, ExtensionConfig};
use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT}; use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT};
use goose::message::Message; use goose::message::Message;
use goose::providers::databricks::DatabricksProvider; use goose::providers::databricks::DatabricksProvider;
@@ -20,10 +20,11 @@ async fn main() {
let config = ExtensionConfig::stdio( let config = ExtensionConfig::stdio(
"developer", "developer",
"./target/debug/developer", "./target/debug/goose",
DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_DESCRIPTION,
DEFAULT_EXTENSION_TIMEOUT, DEFAULT_EXTENSION_TIMEOUT,
); )
.with_args(vec!["mcp", "developer"]);
agent.add_extension(config).await.unwrap(); agent.add_extension(config).await.unwrap();
println!("Extensions:"); println!("Extensions:");
@@ -35,11 +36,8 @@ async fn main() {
.with_text("can you summarize the readme.md in this dir using just a haiku?")]; .with_text("can you summarize the readme.md in this dir using just a haiku?")];
let mut stream = agent.reply(&messages, None).await.unwrap(); let mut stream = agent.reply(&messages, None).await.unwrap();
while let Some(message) = stream.next().await { while let Some(Ok(AgentEvent::Message(message))) = stream.next().await {
println!( println!("{}", serde_json::to_string_pretty(&message).unwrap());
"{}",
serde_json::to_string_pretty(&message.unwrap()).unwrap()
);
println!("\n"); println!("\n");
} }
} }

View File

@@ -1,9 +1,14 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use futures::stream::BoxStream; use futures::stream::BoxStream;
use futures::TryStreamExt; use futures::{FutureExt, Stream, TryStreamExt};
use futures_util::stream;
use futures_util::stream::StreamExt;
use mcp_core::protocol::JsonRpcMessage;
use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::config::{Config, ExtensionConfigManager, PermissionManager};
use crate::message::Message; use crate::message::Message;
@@ -39,7 +44,7 @@ use mcp_core::{
use super::platform_tools; use super::platform_tools;
use super::router_tools; use super::router_tools;
use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
/// The main goose Agent /// The main goose Agent
pub struct Agent { pub struct Agent {
@@ -56,6 +61,12 @@ pub struct Agent {
pub(super) router_tool_selector: Mutex<Option<Arc<Box<dyn RouterToolSelector>>>>, pub(super) router_tool_selector: Mutex<Option<Arc<Box<dyn RouterToolSelector>>>>,
} }
#[derive(Clone, Debug)]
pub enum AgentEvent {
Message(Message),
McpNotification((String, JsonRpcMessage)),
}
impl Agent { impl Agent {
pub fn new() -> Self { pub fn new() -> Self {
// Create channels with buffer size 32 (adjust if needed) // Create channels with buffer size 32 (adjust if needed)
@@ -100,6 +111,40 @@ impl Default for Agent {
} }
} }
pub enum ToolStreamItem<T> {
Message(JsonRpcMessage),
Result(T),
}
pub type ToolStream = Pin<Box<dyn Stream<Item = ToolStreamItem<ToolResult<Vec<Content>>>> + Send>>;
// tool_stream combines a stream of JsonRpcMessages with a future representing the
// final result of the tool call. MCP notifications are not request-scoped, but
// this lets us capture all notifications emitted during the tool call for
// simpler consumption
pub fn tool_stream<S, F>(rx: S, done: F) -> ToolStream
where
S: Stream<Item = JsonRpcMessage> + Send + Unpin + 'static,
F: Future<Output = ToolResult<Vec<Content>>> + Send + 'static,
{
Box::pin(async_stream::stream! {
tokio::pin!(done);
let mut rx = rx;
loop {
tokio::select! {
Some(msg) = rx.next() => {
yield ToolStreamItem::Message(msg);
}
r = &mut done => {
yield ToolStreamItem::Result(r);
break;
}
}
}
})
}
impl Agent { impl Agent {
/// Get a reference count clone to the provider /// Get a reference count clone to the provider
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> { pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
@@ -143,7 +188,7 @@ impl Agent {
&self, &self,
tool_call: mcp_core::tool::ToolCall, tool_call: mcp_core::tool::ToolCall,
request_id: String, request_id: String,
) -> (String, Result<Vec<Content>, ToolError>) { ) -> (String, Result<ToolCallResult, ToolError>) {
// Check if this tool call should be allowed based on repetition monitoring // Check if this tool call should be allowed based on repetition monitoring
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() { if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone()); let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone());
@@ -171,52 +216,65 @@ impl Agent {
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("") .unwrap_or("")
.to_string(); .to_string();
return self let (request_id, result) = self
.manage_extensions(action, extension_name, request_id) .manage_extensions(action, extension_name, request_id)
.await; .await;
return (request_id, Ok(ToolCallResult::from(result)));
} }
let extension_manager = self.extension_manager.lock().await; let extension_manager = self.extension_manager.lock().await;
let result = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { let result: ToolCallResult = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
// Check if the tool is read_resource and handle it separately // Check if the tool is read_resource and handle it separately
extension_manager ToolCallResult::from(
.read_resource(tool_call.arguments.clone()) extension_manager
.await .read_resource(tool_call.arguments.clone())
.await,
)
} else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME { } else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME {
extension_manager ToolCallResult::from(
.list_resources(tool_call.arguments.clone()) extension_manager
.await .list_resources(tool_call.arguments.clone())
.await,
)
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME { } else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
extension_manager.search_available_extensions().await ToolCallResult::from(extension_manager.search_available_extensions().await)
} else if self.is_frontend_tool(&tool_call.name).await { } else if self.is_frontend_tool(&tool_call.name).await {
// For frontend tools, return an error indicating we need frontend execution // For frontend tools, return an error indicating we need frontend execution
Err(ToolError::ExecutionError( ToolCallResult::from(Err(ToolError::ExecutionError(
"Frontend tool execution required".to_string(), "Frontend tool execution required".to_string(),
)) )))
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME { } else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
let selector = self.router_tool_selector.lock().await.clone(); let selector = self.router_tool_selector.lock().await.clone();
if let Some(selector) = selector { ToolCallResult::from(if let Some(selector) = selector {
selector.select_tools(tool_call.arguments.clone()).await selector.select_tools(tool_call.arguments.clone()).await
} else { } else {
Err(ToolError::ExecutionError( Err(ToolError::ExecutionError(
"Encountered vector search error.".to_string(), "Encountered vector search error.".to_string(),
)) ))
} })
} else { } else {
extension_manager // Clone the result to ensure no references to extension_manager are returned
let result = extension_manager
.dispatch_tool_call(tool_call.clone()) .dispatch_tool_call(tool_call.clone())
.await .await;
match result {
Ok(call_result) => call_result,
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
}
}; };
debug!( (
"input" = serde_json::to_string(&tool_call).unwrap(), request_id,
"output" = serde_json::to_string(&result).unwrap(), Ok(ToolCallResult {
); notification_stream: result.notification_stream,
result: Box::new(
// Process the response to handle large text content result
let processed_result = super::large_response_handler::process_tool_response(result); .result
.map(super::large_response_handler::process_tool_response),
(request_id, processed_result) ),
}),
)
} }
pub(super) async fn manage_extensions( pub(super) async fn manage_extensions(
@@ -466,7 +524,7 @@ impl Agent {
&self, &self,
messages: &[Message], messages: &[Message],
session: Option<SessionConfig>, session: Option<SessionConfig>,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> { ) -> anyhow::Result<BoxStream<'_, anyhow::Result<AgentEvent>>> {
let mut messages = messages.to_vec(); let mut messages = messages.to_vec();
let reply_span = tracing::Span::current(); let reply_span = tracing::Span::current();
@@ -532,9 +590,8 @@ impl Agent {
} }
} }
} }
// Yield the assistant's response with frontend tool requests filtered out // Yield the assistant's response with frontend tool requests filtered out
yield filtered_response.clone(); yield AgentEvent::Message(filtered_response.clone());
tokio::task::yield_now().await; tokio::task::yield_now().await;
@@ -556,7 +613,7 @@ impl Agent {
// execution is yeield back to this reply loop, and is of the same Message // execution is yeield back to this reply loop, and is of the same Message
// type, so we can yield that back up to be handled // type, so we can yield that back up to be handled
while let Some(msg) = frontend_tool_stream.try_next().await? { while let Some(msg) = frontend_tool_stream.try_next().await? {
yield msg; yield AgentEvent::Message(msg);
} }
// Clone goose_mode once before the match to avoid move issues // Clone goose_mode once before the match to avoid move issues
@@ -584,13 +641,23 @@ impl Agent {
self.provider().await?).await; self.provider().await?).await;
// Handle pre-approved and read-only tools in parallel // Handle pre-approved and read-only tools in parallel
let mut tool_futures: Vec<ToolFuture> = Vec::new(); let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();
// Skip the confirmation for approved tools // Skip the confirmation for approved tools
for request in &permission_check_result.approved { for request in &permission_check_result.approved {
if let Ok(tool_call) = request.tool_call.clone() { if let Ok(tool_call) = request.tool_call.clone() {
let tool_future = self.dispatch_tool_call(tool_call, request.id.clone()); let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;
tool_futures.push(Box::pin(tool_future));
tool_futures.push((req_id, match tool_result {
Ok(result) => tool_stream(
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
result.result,
),
Err(e) => tool_stream(
Box::new(stream::empty()),
futures::future::ready(Err(e)),
),
}));
} }
} }
@@ -618,7 +685,7 @@ impl Agent {
// type, so we can yield the Message back up to be handled and grab any // type, so we can yield the Message back up to be handled and grab any
// confirmations or denials // confirmations or denials
while let Some(msg) = tool_approval_stream.try_next().await? { while let Some(msg) = tool_approval_stream.try_next().await? {
yield msg; yield AgentEvent::Message(msg);
} }
tool_futures = { tool_futures = {
@@ -628,16 +695,30 @@ impl Agent {
futures_lock.drain(..).collect::<Vec<_>>() futures_lock.drain(..).collect::<Vec<_>>()
}; };
// Wait for all tool calls to complete let with_id = tool_futures
let results = futures::future::join_all(tool_futures).await; .into_iter()
.map(|(request_id, stream)| {
stream.map(move |item| (request_id.clone(), item))
})
.collect::<Vec<_>>();
let mut combined = stream::select_all(with_id);
let mut all_install_successful = true; let mut all_install_successful = true;
for (request_id, output) in results.into_iter() { while let Some((request_id, item)) = combined.next().await {
if enable_extension_request_ids.contains(&request_id) && output.is_err(){ match item {
all_install_successful = false; ToolStreamItem::Result(output) => {
if enable_extension_request_ids.contains(&request_id) && output.is_err(){
all_install_successful = false;
}
let mut response = message_tool_response.lock().await;
*response = response.clone().with_tool_response(request_id, output);
},
ToolStreamItem::Message(msg) => {
yield AgentEvent::McpNotification((request_id, msg))
}
} }
let mut response = message_tool_response.lock().await;
*response = response.clone().with_tool_response(request_id, output);
} }
// Update system prompt and tools if installations were successful // Update system prompt and tools if installations were successful
@@ -647,7 +728,7 @@ impl Agent {
} }
let final_message_tool_resp = message_tool_response.lock().await.clone(); let final_message_tool_resp = message_tool_response.lock().await.clone();
yield final_message_tool_resp.clone(); yield AgentEvent::Message(final_message_tool_resp.clone());
messages.push(response); messages.push(response);
messages.push(final_message_tool_resp); messages.push(final_message_tool_resp);
@@ -656,15 +737,15 @@ impl Agent {
// At this point, the last message should be a user message // At this point, the last message should be a user message
// because call to provider led to context length exceeded error // because call to provider led to context length exceeded error
// Immediately yield a special message and break // Immediately yield a special message and break
yield Message::assistant().with_context_length_exceeded( yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
"The context length of the model has been exceeded. Please start a new session and try again.", "The context length of the model has been exceeded. Please start a new session and try again.",
); ));
break; break;
}, },
Err(e) => { Err(e) => {
// Create an error message & terminate the stream // Create an error message & terminate the stream
error!("Error: {}", e); error!("Error: {}", e);
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")); yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
break; break;
} }
} }

View File

@@ -1,8 +1,7 @@
use anyhow::Result; use anyhow::Result;
use chrono::{DateTime, TimeZone, Utc}; use chrono::{DateTime, TimeZone, Utc};
use futures::future;
use futures::stream::{FuturesUnordered, StreamExt}; use futures::stream::{FuturesUnordered, StreamExt};
use mcp_client::McpService; use futures::{future, FutureExt};
use mcp_core::protocol::GetPromptResult; use mcp_core::protocol::GetPromptResult;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
@@ -10,15 +9,17 @@ use std::sync::LazyLock;
use std::time::Duration; use std::time::Duration;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task; use tokio::task;
use tracing::{debug, error, warn}; use tokio_stream::wrappers::ReceiverStream;
use tracing::{error, warn};
use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo}; use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
use super::tool_execution::ToolCallResult;
use crate::agents::extension::Envs; use crate::agents::extension::Envs;
use crate::config::{Config, ExtensionConfigManager}; use crate::config::{Config, ExtensionConfigManager};
use crate::prompt_template; use crate::prompt_template;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, StdioTransport, Transport}; use mcp_client::transport::{SseTransport, StdioTransport, Transport};
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult}; use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError};
use serde_json::Value; use serde_json::Value;
// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp // By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
@@ -113,7 +114,8 @@ impl ExtensionManager {
/// Add a new MCP extension based on the provided client type /// Add a new MCP extension based on the provided client type
// TODO IMPORTANT need to ensure this times out if the extension command is broken! // TODO IMPORTANT need to ensure this times out if the extension command is broken!
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> { pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
let sanitized_name = normalize(config.key().to_string()); let config_name = config.key().to_string();
let sanitized_name = normalize(config_name.clone());
/// Helper function to merge environment variables from direct envs and keychain-stored env_keys /// Helper function to merge environment variables from direct envs and keychain-stored env_keys
async fn merge_environments( async fn merge_environments(
@@ -183,13 +185,15 @@ impl ExtensionManager {
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
let transport = SseTransport::new(uri, all_envs); let transport = SseTransport::new(uri, all_envs);
let handle = transport.start().await?; let handle = transport.start().await?;
let service = McpService::with_timeout( Box::new(
handle, McpClient::connect(
Duration::from_secs( handle,
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), Duration::from_secs(
), timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
); ),
Box::new(McpClient::new(service)) )
.await?,
)
} }
ExtensionConfig::Stdio { ExtensionConfig::Stdio {
cmd, cmd,
@@ -202,13 +206,15 @@ impl ExtensionManager {
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
let transport = StdioTransport::new(cmd, args.to_vec(), all_envs); let transport = StdioTransport::new(cmd, args.to_vec(), all_envs);
let handle = transport.start().await?; let handle = transport.start().await?;
let service = McpService::with_timeout( Box::new(
handle, McpClient::connect(
Duration::from_secs( handle,
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), Duration::from_secs(
), timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
); ),
Box::new(McpClient::new(service)) )
.await?,
)
} }
ExtensionConfig::Builtin { ExtensionConfig::Builtin {
name, name,
@@ -227,13 +233,15 @@ impl ExtensionManager {
HashMap::new(), HashMap::new(),
); );
let handle = transport.start().await?; let handle = transport.start().await?;
let service = McpService::with_timeout( Box::new(
handle, McpClient::connect(
Duration::from_secs( handle,
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), Duration::from_secs(
), timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
); ),
Box::new(McpClient::new(service)) )
.await?,
)
} }
_ => unreachable!(), _ => unreachable!(),
}; };
@@ -609,7 +617,7 @@ impl ExtensionManager {
} }
} }
pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> ToolResult<Vec<Content>> { pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> Result<ToolCallResult> {
// Dispatch tool call based on the prefix naming convention // Dispatch tool call based on the prefix naming convention
let (client_name, client) = self let (client_name, client) = self
.get_client_for_tool(&tool_call.name) .get_client_for_tool(&tool_call.name)
@@ -620,22 +628,26 @@ impl ExtensionManager {
.name .name
.strip_prefix(client_name) .strip_prefix(client_name)
.and_then(|s| s.strip_prefix("__")) .and_then(|s| s.strip_prefix("__"))
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?; .ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?
.to_string();
let client_guard = client.lock().await; let arguments = tool_call.arguments.clone();
let client = client.clone();
let notifications_receiver = client.lock().await.subscribe().await;
let result = client_guard let fut = async move {
.call_tool(tool_name, tool_call.clone().arguments) let client_guard = client.lock().await;
.await client_guard
.map(|result| result.content) .call_tool(&tool_name, arguments)
.map_err(|e| ToolError::ExecutionError(e.to_string())); .await
.map(|call| call.content)
.map_err(|e| ToolError::ExecutionError(e.to_string()))
};
debug!( Ok(ToolCallResult {
"input" = serde_json::to_string(&tool_call).unwrap(), result: Box::new(fut.boxed()),
"output" = serde_json::to_string(&result).unwrap(), notification_stream: Some(Box::new(ReceiverStream::new(notifications_receiver))),
); })
result
} }
pub async fn list_prompts_from_extension( pub async fn list_prompts_from_extension(
@@ -793,10 +805,11 @@ mod tests {
use mcp_client::client::Error; use mcp_client::client::Error;
use mcp_client::client::McpClientTrait; use mcp_client::client::McpClientTrait;
use mcp_core::protocol::{ use mcp_core::protocol::{
CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult, CallToolResult, GetPromptResult, InitializeResult, JsonRpcMessage, ListPromptsResult,
ListToolsResult, ReadResourceResult, ListResourcesResult, ListToolsResult, ReadResourceResult,
}; };
use serde_json::json; use serde_json::json;
use tokio::sync::mpsc;
struct MockClient {} struct MockClient {}
@@ -849,6 +862,10 @@ mod tests {
) -> Result<GetPromptResult, Error> { ) -> Result<GetPromptResult, Error> {
Err(Error::NotInitialized) Err(Error::NotInitialized)
} }
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
mpsc::channel(1).1
}
} }
#[test] #[test]
@@ -970,6 +987,9 @@ mod tests {
let result = extension_manager let result = extension_manager
.dispatch_tool_call(invalid_tool_call) .dispatch_tool_call(invalid_tool_call)
.await
.unwrap()
.result
.await; .await;
assert!(matches!( assert!(matches!(
result.err().unwrap(), result.err().unwrap(),
@@ -986,6 +1006,11 @@ mod tests {
let result = extension_manager let result = extension_manager
.dispatch_tool_call(invalid_tool_call) .dispatch_tool_call(invalid_tool_call)
.await; .await;
assert!(matches!(result.err().unwrap(), ToolError::NotFound(_))); if let Err(err) = result {
let tool_err = err.downcast_ref::<ToolError>().expect("Expected ToolError");
assert!(matches!(tool_err, ToolError::NotFound(_)));
} else {
panic!("Expected ToolError::NotFound");
}
} }
} }

View File

@@ -13,7 +13,7 @@ mod tool_router_index_manager;
pub(crate) mod tool_vectordb; pub(crate) mod tool_vectordb;
mod types; mod types;
pub use agent::Agent; pub use agent::{Agent, AgentEvent};
pub use extension::ExtensionConfig; pub use extension::ExtensionConfig;
pub use extension_manager::ExtensionManager; pub use extension_manager::ExtensionManager;
pub use prompt_manager::PromptManager; pub use prompt_manager::PromptManager;

View File

@@ -1,23 +1,35 @@
use std::future::Future; use std::future::Future;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use async_stream::try_stream; use async_stream::try_stream;
use futures::stream::BoxStream; use futures::stream::{self, BoxStream};
use futures::StreamExt; use futures::{Stream, StreamExt};
use mcp_core::protocol::JsonRpcMessage;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::config::permission::PermissionLevel; use crate::config::permission::PermissionLevel;
use crate::config::PermissionManager; use crate::config::PermissionManager;
use crate::message::{Message, ToolRequest}; use crate::message::{Message, ToolRequest};
use crate::permission::Permission; use crate::permission::Permission;
use mcp_core::{Content, ToolError}; use mcp_core::{Content, ToolResult};
// Type alias for ToolFutures - used in the agent loop to join all futures together // ToolCallResult combines the result of a tool call with an optional notification stream that
pub(crate) type ToolFuture<'a> = // can be used to receive notifications from the tool.
Pin<Box<dyn Future<Output = (String, Result<Vec<Content>, ToolError>)> + Send + 'a>>; pub struct ToolCallResult {
pub(crate) type ToolFuturesVec<'a> = Arc<Mutex<Vec<ToolFuture<'a>>>>; pub result: Box<dyn Future<Output = ToolResult<Vec<Content>>> + Send + Unpin>,
pub notification_stream: Option<Box<dyn Stream<Item = JsonRpcMessage> + Send + Unpin>>,
}
impl From<ToolResult<Vec<Content>>> for ToolCallResult {
fn from(result: ToolResult<Vec<Content>>) -> Self {
Self {
result: Box::new(futures::future::ready(result)),
notification_stream: None,
}
}
}
use super::agent::{tool_stream, ToolStream};
use crate::agents::Agent; use crate::agents::Agent;
pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \ pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \
@@ -37,7 +49,7 @@ impl Agent {
pub(crate) fn handle_approval_tool_requests<'a>( pub(crate) fn handle_approval_tool_requests<'a>(
&'a self, &'a self,
tool_requests: &'a [ToolRequest], tool_requests: &'a [ToolRequest],
tool_futures: ToolFuturesVec<'a>, tool_futures: Arc<Mutex<Vec<(String, ToolStream)>>>,
permission_manager: &'a mut PermissionManager, permission_manager: &'a mut PermissionManager,
message_tool_response: Arc<Mutex<Message>>, message_tool_response: Arc<Mutex<Message>>,
) -> BoxStream<'a, anyhow::Result<Message>> { ) -> BoxStream<'a, anyhow::Result<Message>> {
@@ -56,9 +68,19 @@ impl Agent {
while let Some((req_id, confirmation)) = rx.recv().await { while let Some((req_id, confirmation)) = rx.recv().await {
if req_id == request.id { if req_id == request.id {
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow { if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone()); let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone()).await;
let mut futures = tool_futures.lock().await; let mut futures = tool_futures.lock().await;
futures.push(Box::pin(tool_future));
futures.push((req_id, match tool_result {
Ok(result) => tool_stream(
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
result.result,
),
Err(e) => tool_stream(
Box::new(stream::empty()),
futures::future::ready(Err(e)),
),
}));
if confirmation.permission == Permission::AlwaysAllow { if confirmation.permission == Permission::AlwaysAllow {
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow); permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);

View File

@@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio_cron_scheduler::{job::JobId, Job, JobScheduler as TokioJobScheduler}; use tokio_cron_scheduler::{job::JobId, Job, JobScheduler as TokioJobScheduler};
use crate::agents::AgentEvent;
use crate::agents::{Agent, SessionConfig}; use crate::agents::{Agent, SessionConfig};
use crate::config::{self, Config}; use crate::config::{self, Config};
use crate::message::Message; use crate::message::Message;
@@ -1102,12 +1103,15 @@ async fn run_scheduled_job_internal(
tokio::task::yield_now().await; tokio::task::yield_now().await;
match message_result { match message_result {
Ok(msg) => { Ok(AgentEvent::Message(msg)) => {
if msg.role == mcp_core::role::Role::Assistant { if msg.role == mcp_core::role::Role::Assistant {
tracing::info!("[Job {}] Assistant: {:?}", job.id, msg.content); tracing::info!("[Job {}] Assistant: {:?}", job.id, msg.content);
} }
all_session_messages.push(msg); all_session_messages.push(msg);
} }
Ok(AgentEvent::McpNotification(_)) => {
// Handle notifications if needed
}
Err(e) => { Err(e) => {
tracing::error!( tracing::error!(
"[Job {}] Error receiving message from agent: {}", "[Job {}] Error receiving message from agent: {}",

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use futures::StreamExt; use futures::StreamExt;
use goose::agents::Agent; use goose::agents::{Agent, AgentEvent};
use goose::message::Message; use goose::message::Message;
use goose::model::ModelConfig; use goose::model::ModelConfig;
use goose::providers::base::Provider; use goose::providers::base::Provider;
@@ -132,7 +132,10 @@ async fn run_truncate_test(
let mut responses = Vec::new(); let mut responses = Vec::new();
while let Some(response_result) = reply_stream.next().await { while let Some(response_result) = reply_stream.next().await {
match response_result { match response_result {
Ok(response) => responses.push(response), Ok(AgentEvent::Message(response)) => responses.push(response),
Ok(AgentEvent::McpNotification(n)) => {
println!("MCP Notification: {n:?}");
}
Err(e) => { Err(e) => {
println!("Error: {:?}", e); println!("Error: {:?}", e);
return Err(e); return Err(e);

View File

@@ -1,7 +1,6 @@
use mcp_client::{ use mcp_client::{
client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}, client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait},
transport::{SseTransport, StdioTransport, Transport}, transport::{SseTransport, StdioTransport, Transport},
McpService,
}; };
use rand::Rng; use rand::Rng;
use rand::SeedableRng; use rand::SeedableRng;
@@ -20,18 +19,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
let handle1 = transport1.start().await?; let handle1 = transport1.start().await?;
let service1 = McpService::with_timeout(handle1, Duration::from_secs(30)); let client1 = McpClient::connect(handle1, Duration::from_secs(30)).await?;
let client1 = McpClient::new(service1);
let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
let handle2 = transport2.start().await?; let handle2 = transport2.start().await?;
let service2 = McpService::with_timeout(handle2, Duration::from_secs(30)); let client2 = McpClient::connect(handle2, Duration::from_secs(30)).await?;
let client2 = McpClient::new(service2);
let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new()); let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new());
let handle3 = transport3.start().await?; let handle3 = transport3.start().await?;
let service3 = McpService::with_timeout(handle3, Duration::from_secs(10)); let client3 = McpClient::connect(handle3, Duration::from_secs(10)).await?;
let client3 = McpClient::new(service3);
// Initialize both clients // Initialize both clients
let mut clients: Vec<Box<dyn McpClientTrait>> = let mut clients: Vec<Box<dyn McpClientTrait>> =

View File

@@ -0,0 +1,122 @@
use anyhow::Result;
use futures::lock::Mutex;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, Transport};
use mcp_client::StdioTransport;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing_subscriber::EnvFilter;
#[tokio::main]
async fn main() -> Result<()> {
// Initialize logging
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::from_default_env()
.add_directive("mcp_client=debug".parse().unwrap())
.add_directive("eventsource_client=info".parse().unwrap()),
)
.init();
test_transport(sse_transport().await?).await?;
test_transport(stdio_transport().await?).await?;
Ok(())
}
async fn sse_transport() -> Result<SseTransport> {
let port = "60053";
tokio::process::Command::new("npx")
.env("PORT", port)
.arg("@modelcontextprotocol/server-everything")
.arg("sse")
.spawn()?;
tokio::time::sleep(Duration::from_secs(1)).await;
Ok(SseTransport::new(
format!("http://localhost:{}/sse", port),
HashMap::new(),
))
}
async fn stdio_transport() -> Result<StdioTransport> {
Ok(StdioTransport::new(
"npx",
vec!["@modelcontextprotocol/server-everything"]
.into_iter()
.map(|s| s.to_string())
.collect(),
HashMap::new(),
))
}
async fn test_transport<T>(transport: T) -> Result<()>
where
T: Transport + Send + 'static,
{
// Start transport
let handle = transport.start().await?;
// Create client
let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?;
println!("Client created\n");
let mut receiver = client.subscribe().await;
let events = Arc::new(Mutex::new(Vec::new()));
let events_clone = events.clone();
tokio::spawn(async move {
while let Some(event) = receiver.recv().await {
println!("Received event: {event:?}");
events_clone.lock().await.push(event);
}
});
// Initialize
let server_info = client
.initialize(
ClientInfo {
name: "test-client".into(),
version: "1.0.0".into(),
},
ClientCapabilities::default(),
)
.await?;
println!("Connected to server: {server_info:?}\n");
// Sleep for 100ms to allow the server to start - surprisingly this is required!
tokio::time::sleep(Duration::from_millis(500)).await;
// List tools
let tools = client.list_tools(None).await?;
println!("Available tools: {tools:#?}\n");
// Call tool
let tool_result = client
.call_tool("echo", serde_json::json!({ "message": "honk" }))
.await?;
println!("Tool result: {tool_result:#?}\n");
let collected_eventes_before = events.lock().await.len();
let n_steps = 5;
let long_op = client
.call_tool(
"longRunningOperation",
serde_json::json!({ "duration": 3, "steps": n_steps }),
)
.await?;
println!("Long op result: {long_op:#?}\n");
let collected_events_after = events.lock().await.len();
assert_eq!(collected_events_after - collected_eventes_before, n_steps);
// List resources
let resources = client.list_resources(None).await?;
println!("Resources: {resources:#?}\n");
// Read resource
let resource = client.read_resource("test://static/resource/1").await?;
println!("Resource: {resource:#?}\n");
Ok(())
}

View File

@@ -1,7 +1,6 @@
use anyhow::Result; use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, Transport}; use mcp_client::transport::{SseTransport, Transport};
use mcp_client::McpService;
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration; use std::time::Duration;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
@@ -23,11 +22,8 @@ async fn main() -> Result<()> {
// Start transport // Start transport
let handle = transport.start().await?; let handle = transport.start().await?;
// Create the service with timeout middleware
let service = McpService::with_timeout(handle, Duration::from_secs(3));
// Create client // Create client
let mut client = McpClient::new(service); let mut client = McpClient::connect(handle, Duration::from_secs(3)).await?;
println!("Client created\n"); println!("Client created\n");
// Initialize // Initialize

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use anyhow::Result; use anyhow::Result;
use mcp_client::{ use mcp_client::{
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, McpService, ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
StdioTransport, Transport, StdioTransport, Transport,
}; };
use std::time::Duration; use std::time::Duration;
@@ -25,11 +25,8 @@ async fn main() -> Result<(), ClientError> {
// 2) Start the transport to get a handle // 2) Start the transport to get a handle
let transport_handle = transport.start().await?; let transport_handle = transport.start().await?;
// 3) Create the service with timeout middleware // 3) Create the client with the middleware-wrapped service
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10)); let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?;
// 4) Create the client with the middleware-wrapped service
let mut client = McpClient::new(service);
// Initialize // Initialize
let server_info = client let server_info = client

View File

@@ -5,7 +5,6 @@ use mcp_client::client::{
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
}; };
use mcp_client::transport::{StdioTransport, Transport}; use mcp_client::transport::{StdioTransport, Transport};
use mcp_client::McpService;
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration; use std::time::Duration;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
@@ -34,11 +33,8 @@ async fn main() -> Result<(), ClientError> {
// Start the transport to get a handle // Start the transport to get a handle
let transport_handle = transport.start().await.unwrap(); let transport_handle = transport.start().await.unwrap();
// Create the service with timeout middleware
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));
// Create client // Create client
let mut client = McpClient::new(service); let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?;
// Initialize // Initialize
let server_info = client let server_info = client

View File

@@ -4,11 +4,16 @@ use mcp_core::protocol::{
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::{json, Value};
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use thiserror::Error; use thiserror::Error;
use tokio::sync::Mutex; use tokio::sync::{mpsc, Mutex};
use tower::{Service, ServiceExt}; // for Service::ready() use tower::{timeout::TimeoutLayer, Layer, Service, ServiceExt};
use crate::{McpService, TransportHandle};
pub type BoxError = Box<dyn std::error::Error + Sync + Send>; pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
@@ -97,34 +102,67 @@ pub trait McpClientTrait: Send + Sync {
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>; async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>; async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage>;
} }
/// The MCP client is the interface for MCP operations. /// The MCP client is the interface for MCP operations.
pub struct McpClient<S> pub struct McpClient<T>
where where
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static, T: TransportHandle + Send + Sync + 'static,
S::Error: Into<Error>,
S::Future: Send,
{ {
service: Mutex<S>, service: Mutex<tower::timeout::Timeout<McpService<T>>>,
next_id: AtomicU64, next_id: AtomicU64,
server_capabilities: Option<ServerCapabilities>, server_capabilities: Option<ServerCapabilities>,
server_info: Option<Implementation>, server_info: Option<Implementation>,
notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<JsonRpcMessage>>>>,
} }
impl<S> McpClient<S> impl<T> McpClient<T>
where where
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static, T: TransportHandle + Send + Sync + 'static,
S::Error: Into<Error>,
S::Future: Send,
{ {
pub fn new(service: S) -> Self { pub async fn connect(transport: T, timeout: std::time::Duration) -> Result<Self, Error> {
Self { let service = McpService::new(transport.clone());
service: Mutex::new(service), let service_ptr = service.clone();
let notification_subscribers =
Arc::new(Mutex::new(Vec::<mpsc::Sender<JsonRpcMessage>>::new()));
let subscribers_ptr = notification_subscribers.clone();
tokio::spawn(async move {
loop {
match transport.receive().await {
Ok(message) => {
tracing::info!("Received message: {:?}", message);
match message {
JsonRpcMessage::Response(JsonRpcResponse { id: Some(id), .. }) => {
service_ptr.respond(&id.to_string(), Ok(message)).await;
}
_ => {
let mut subs = subscribers_ptr.lock().await;
subs.retain(|sub| sub.try_send(message.clone()).is_ok());
}
}
}
Err(e) => {
tracing::error!("transport error: {:?}", e);
service_ptr.hangup().await;
subscribers_ptr.lock().await.clear();
break;
}
}
}
});
let middleware = TimeoutLayer::new(timeout);
Ok(Self {
service: Mutex::new(middleware.layer(service)),
next_id: AtomicU64::new(1), next_id: AtomicU64::new(1),
server_capabilities: None, server_capabilities: None,
server_info: None, server_info: None,
} notification_subscribers,
})
} }
/// Send a JSON-RPC request and check we don't get an error response. /// Send a JSON-RPC request and check we don't get an error response.
@@ -134,13 +172,18 @@ where
{ {
let mut service = self.service.lock().await; let mut service = self.service.lock().await;
service.ready().await.map_err(|_| Error::NotReady)?; service.ready().await.map_err(|_| Error::NotReady)?;
let id = self.next_id.fetch_add(1, Ordering::SeqCst); let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let mut params = params.clone();
params["_meta"] = json!({
"progressToken": format!("prog-{}", id),
});
let request = JsonRpcMessage::Request(JsonRpcRequest { let request = JsonRpcMessage::Request(JsonRpcRequest {
jsonrpc: "2.0".to_string(), jsonrpc: "2.0".to_string(),
id: Some(id), id: Some(id),
method: method.to_string(), method: method.to_string(),
params: Some(params.clone()), params: Some(params),
}); });
let response_msg = service let response_msg = service
@@ -154,7 +197,7 @@ where
.unwrap_or("".to_string()), .unwrap_or("".to_string()),
method: method.to_string(), method: method.to_string(),
// we don't need include params because it can be really large // we don't need include params because it can be really large
source: Box::new(e.into()), source: Box::<Error>::new(e.into()),
})?; })?;
match response_msg { match response_msg {
@@ -220,7 +263,7 @@ where
.unwrap_or("".to_string()), .unwrap_or("".to_string()),
method: method.to_string(), method: method.to_string(),
// we don't need include params because it can be really large // we don't need include params because it can be really large
source: Box::new(e.into()), source: Box::<Error>::new(e.into()),
})?; })?;
Ok(()) Ok(())
@@ -233,11 +276,9 @@ where
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl<S> McpClientTrait for McpClient<S> impl<T> McpClientTrait for McpClient<T>
where where
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static, T: TransportHandle + Send + Sync + 'static,
S::Error: Into<Error>,
S::Future: Send,
{ {
async fn initialize( async fn initialize(
&mut self, &mut self,
@@ -388,4 +429,10 @@ where
self.send_request("prompts/get", params).await self.send_request("prompts/get", params).await
} }
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
let (tx, rx) = mpsc::channel(16);
self.notification_subscribers.lock().await.push(tx);
rx
}
} }

View File

@@ -1,7 +1,9 @@
use futures::future::BoxFuture; use futures::future::BoxFuture;
use mcp_core::protocol::JsonRpcMessage; use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest};
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::sync::{oneshot, RwLock};
use tower::{timeout::Timeout, Service, ServiceBuilder}; use tower::{timeout::Timeout, Service, ServiceBuilder};
use crate::transport::{Error, TransportHandle}; use crate::transport::{Error, TransportHandle};
@@ -10,14 +12,24 @@ use crate::transport::{Error, TransportHandle};
#[derive(Clone)] #[derive(Clone)]
pub struct McpService<T: TransportHandle> { pub struct McpService<T: TransportHandle> {
inner: Arc<T>, inner: Arc<T>,
pending_requests: Arc<PendingRequests>,
} }
impl<T: TransportHandle> McpService<T> { impl<T: TransportHandle> McpService<T> {
pub fn new(transport: T) -> Self { pub fn new(transport: T) -> Self {
Self { Self {
inner: Arc::new(transport), inner: Arc::new(transport),
pending_requests: Arc::new(PendingRequests::default()),
} }
} }
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
self.pending_requests.respond(id, response).await
}
pub async fn hangup(&self) {
self.pending_requests.broadcast_close().await
}
} }
impl<T> Service<JsonRpcMessage> for McpService<T> impl<T> Service<JsonRpcMessage> for McpService<T>
@@ -35,7 +47,31 @@ where
fn call(&mut self, request: JsonRpcMessage) -> Self::Future { fn call(&mut self, request: JsonRpcMessage) -> Self::Future {
let transport = self.inner.clone(); let transport = self.inner.clone();
Box::pin(async move { transport.send(request).await }) let pending_requests = self.pending_requests.clone();
Box::pin(async move {
match request {
JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) => {
// Create a channel to receive the response
let (sender, receiver) = oneshot::channel();
pending_requests.insert(id.to_string(), sender).await;
transport.send(request).await?;
receiver.await.map_err(|_| Error::ChannelClosed)?
}
JsonRpcMessage::Request(_) => {
// Handle notifications without waiting for a response
transport.send(request).await?;
Ok(JsonRpcMessage::Nil)
}
JsonRpcMessage::Notification(_) => {
// Handle notifications without waiting for a response
transport.send(request).await?;
Ok(JsonRpcMessage::Nil)
}
_ => Err(Error::UnsupportedMessage),
}
})
} }
} }
@@ -50,3 +86,50 @@ where
.service(McpService::new(transport)) .service(McpService::new(transport))
} }
} }
// A data structure to store pending requests and their response channels
pub struct PendingRequests {
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
}
impl Default for PendingRequests {
fn default() -> Self {
Self::new()
}
}
impl PendingRequests {
pub fn new() -> Self {
Self {
requests: RwLock::new(HashMap::new()),
}
}
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
self.requests.write().await.insert(id, sender);
}
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
if let Some(tx) = self.requests.write().await.remove(id) {
let _ = tx.send(response);
}
}
pub async fn broadcast_close(&self) {
for (_, tx) in self.requests.write().await.drain() {
let _ = tx.send(Err(Error::ChannelClosed));
}
}
pub async fn clear(&self) {
self.requests.write().await.clear();
}
pub async fn len(&self) -> usize {
self.requests.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.len().await == 0
}
}

View File

@@ -1,8 +1,7 @@
use async_trait::async_trait; use async_trait::async_trait;
use mcp_core::protocol::JsonRpcMessage; use mcp_core::protocol::JsonRpcMessage;
use std::collections::HashMap;
use thiserror::Error; use thiserror::Error;
use tokio::sync::{mpsc, oneshot, RwLock}; use tokio::sync::{mpsc, oneshot};
pub type BoxError = Box<dyn std::error::Error + Sync + Send>; pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
/// A generic error type for transport operations. /// A generic error type for transport operations.
@@ -57,74 +56,20 @@ pub trait Transport {
#[async_trait] #[async_trait]
pub trait TransportHandle: Send + Sync + Clone + 'static { pub trait TransportHandle: Send + Sync + Clone + 'static {
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error>; async fn send(&self, message: JsonRpcMessage) -> Result<(), Error>;
async fn receive(&self) -> Result<JsonRpcMessage, Error>;
} }
// Helper function that contains the common send implementation pub async fn serialize_and_send(
pub async fn send_message( sender: &mpsc::Sender<String>,
sender: &mpsc::Sender<TransportMessage>,
message: JsonRpcMessage, message: JsonRpcMessage,
) -> Result<JsonRpcMessage, Error> { ) -> Result<(), Error> {
match message { match serde_json::to_string(&message).map_err(Error::Serialization) {
JsonRpcMessage::Request(request) => { Ok(msg) => sender.send(msg).await.map_err(|_| Error::ChannelClosed),
let (respond_to, response) = oneshot::channel(); Err(e) => {
let msg = TransportMessage { tracing::error!(error = ?e, "Error serializing message");
message: JsonRpcMessage::Request(request), Err(e)
response_tx: Some(respond_to),
};
sender.send(msg).await.map_err(|_| Error::ChannelClosed)?;
Ok(response.await.map_err(|_| Error::ChannelClosed)??)
} }
JsonRpcMessage::Notification(notification) => {
let msg = TransportMessage {
message: JsonRpcMessage::Notification(notification),
response_tx: None,
};
sender.send(msg).await.map_err(|_| Error::ChannelClosed)?;
Ok(JsonRpcMessage::Nil)
}
_ => Err(Error::UnsupportedMessage),
}
}
// A data structure to store pending requests and their response channels
pub struct PendingRequests {
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
}
impl Default for PendingRequests {
fn default() -> Self {
Self::new()
}
}
impl PendingRequests {
pub fn new() -> Self {
Self {
requests: RwLock::new(HashMap::new()),
}
}
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
self.requests.write().await.insert(id, sender);
}
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
if let Some(tx) = self.requests.write().await.remove(id) {
let _ = tx.send(response);
}
}
pub async fn clear(&self) {
self.requests.write().await.clear();
}
pub async fn len(&self) -> usize {
self.requests.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.len().await == 0
} }
} }

View File

@@ -1,17 +1,17 @@
use crate::transport::{Error, PendingRequests, TransportMessage}; use crate::transport::Error;
use async_trait::async_trait; use async_trait::async_trait;
use eventsource_client::{Client, SSE}; use eventsource_client::{Client, SSE};
use futures::TryStreamExt; use futures::TryStreamExt;
use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; use mcp_core::protocol::JsonRpcMessage;
use reqwest::Client as HttpClient; use reqwest::Client as HttpClient;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{mpsc, RwLock}; use tokio::sync::{mpsc, Mutex, RwLock};
use tokio::time::{timeout, Duration}; use tokio::time::{timeout, Duration};
use tracing::warn; use tracing::warn;
use url::Url; use url::Url;
use super::{send_message, Transport, TransportHandle}; use super::{serialize_and_send, Transport, TransportHandle};
// Timeout for the endpoint discovery // Timeout for the endpoint discovery
const ENDPOINT_TIMEOUT_SECS: u64 = 5; const ENDPOINT_TIMEOUT_SECS: u64 = 5;
@@ -21,9 +21,9 @@ const ENDPOINT_TIMEOUT_SECS: u64 = 5;
/// - Sends outgoing messages via HTTP POST (once the post endpoint is known). /// - Sends outgoing messages via HTTP POST (once the post endpoint is known).
pub struct SseActor { pub struct SseActor {
/// Receives messages (requests/notifications) from the handle /// Receives messages (requests/notifications) from the handle
receiver: mpsc::Receiver<TransportMessage>, receiver: mpsc::Receiver<String>,
/// Map of request-id -> oneshot sender /// Sends messages (responses) back to the handle
pending_requests: Arc<PendingRequests>, sender: mpsc::Sender<JsonRpcMessage>,
/// Base SSE URL /// Base SSE URL
sse_url: String, sse_url: String,
/// For sending HTTP POST requests /// For sending HTTP POST requests
@@ -34,14 +34,14 @@ pub struct SseActor {
impl SseActor { impl SseActor {
pub fn new( pub fn new(
receiver: mpsc::Receiver<TransportMessage>, receiver: mpsc::Receiver<String>,
pending_requests: Arc<PendingRequests>, sender: mpsc::Sender<JsonRpcMessage>,
sse_url: String, sse_url: String,
post_endpoint: Arc<RwLock<Option<String>>>, post_endpoint: Arc<RwLock<Option<String>>>,
) -> Self { ) -> Self {
Self { Self {
receiver, receiver,
pending_requests, sender,
sse_url, sse_url,
post_endpoint, post_endpoint,
http_client: HttpClient::new(), http_client: HttpClient::new(),
@@ -54,15 +54,14 @@ impl SseActor {
pub async fn run(self) { pub async fn run(self) {
tokio::join!( tokio::join!(
Self::handle_incoming_messages( Self::handle_incoming_messages(
self.sender,
self.sse_url.clone(), self.sse_url.clone(),
Arc::clone(&self.pending_requests),
Arc::clone(&self.post_endpoint) Arc::clone(&self.post_endpoint)
), ),
Self::handle_outgoing_messages( Self::handle_outgoing_messages(
self.receiver, self.receiver,
self.http_client.clone(), self.http_client.clone(),
Arc::clone(&self.post_endpoint), Arc::clone(&self.post_endpoint),
Arc::clone(&self.pending_requests),
) )
); );
} }
@@ -72,14 +71,13 @@ impl SseActor {
/// - If a `message` event is received, parse it as `JsonRpcMessage` /// - If a `message` event is received, parse it as `JsonRpcMessage`
/// and respond to pending requests if it's a `Response`. /// and respond to pending requests if it's a `Response`.
async fn handle_incoming_messages( async fn handle_incoming_messages(
sender: mpsc::Sender<JsonRpcMessage>,
sse_url: String, sse_url: String,
pending_requests: Arc<PendingRequests>,
post_endpoint: Arc<RwLock<Option<String>>>, post_endpoint: Arc<RwLock<Option<String>>>,
) { ) {
let client = match eventsource_client::ClientBuilder::for_url(&sse_url) { let client = match eventsource_client::ClientBuilder::for_url(&sse_url) {
Ok(builder) => builder.build(), Ok(builder) => builder.build(),
Err(e) => { Err(e) => {
pending_requests.clear().await;
warn!("Failed to connect SSE client: {}", e); warn!("Failed to connect SSE client: {}", e);
return; return;
} }
@@ -105,84 +103,54 @@ impl SseActor {
} }
// Now handle subsequent events // Now handle subsequent events
while let Ok(Some(event)) = stream.try_next().await { loop {
match event { match stream.try_next().await {
SSE::Event(e) if e.event_type == "message" => { Ok(Some(event)) => {
// Attempt to parse the SSE data as a JsonRpcMessage match event {
match serde_json::from_str::<JsonRpcMessage>(&e.data) { SSE::Event(e) if e.event_type == "message" => {
Ok(message) => { // Attempt to parse the SSE data as a JsonRpcMessage
match &message { match serde_json::from_str::<JsonRpcMessage>(&e.data) {
JsonRpcMessage::Response(response) => { Ok(message) => {
if let Some(id) = &response.id { let _ = sender.send(message).await;
pending_requests
.respond(&id.to_string(), Ok(message))
.await;
}
} }
JsonRpcMessage::Error(error) => { Err(err) => {
if let Some(id) = &error.id { warn!("Failed to parse SSE message: {err}");
pending_requests
.respond(&id.to_string(), Ok(message))
.await;
}
} }
_ => {} // TODO: Handle other variants (Request, etc.)
} }
} }
Err(err) => { _ => { /* ignore other events */ }
warn!("Failed to parse SSE message: {err}");
}
} }
} }
_ => { /* ignore other events */ } Ok(None) => {
// Stream ended
tracing::info!("SSE stream ended.");
break;
}
Err(e) => {
warn!("Error reading SSE stream: {e}");
break;
}
} }
} }
// SSE stream ended or errored; signal any pending requests tracing::error!("SSE stream ended or encountered an error.");
tracing::error!("SSE stream ended or encountered an error; clearing pending requests.");
pending_requests.clear().await;
} }
/// Continuously receives messages from the `mpsc::Receiver`.
/// - If it's a request, store the oneshot in `pending_requests`.
/// - POST the message to the discovered endpoint (once known).
async fn handle_outgoing_messages( async fn handle_outgoing_messages(
mut receiver: mpsc::Receiver<TransportMessage>, mut receiver: mpsc::Receiver<String>,
http_client: HttpClient, http_client: HttpClient,
post_endpoint: Arc<RwLock<Option<String>>>, post_endpoint: Arc<RwLock<Option<String>>>,
pending_requests: Arc<PendingRequests>,
) { ) {
while let Some(transport_msg) = receiver.recv().await { while let Some(message_str) = receiver.recv().await {
let post_url = match post_endpoint.read().await.as_ref() { let post_url = match post_endpoint.read().await.as_ref() {
Some(url) => url.clone(), Some(url) => url.clone(),
None => { None => {
if let Some(response_tx) = transport_msg.response_tx { // TODO: the endpoint isn't discovered yet. This shouldn't happen -- we only return the handle
let _ = response_tx.send(Err(Error::NotConnected)); // after the endpoint is set.
}
continue; continue;
} }
}; };
// Serialize the JSON-RPC message
let message_str = match serde_json::to_string(&transport_msg.message) {
Ok(s) => s,
Err(e) => {
if let Some(tx) = transport_msg.response_tx {
let _ = tx.send(Err(Error::Serialization(e)));
}
continue;
}
};
// If it's a request, store the channel so we can respond later
if let Some(response_tx) = transport_msg.response_tx {
if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) =
&transport_msg.message
{
pending_requests.insert(id.to_string(), response_tx).await;
}
}
// Perform the HTTP POST // Perform the HTTP POST
match http_client match http_client
.post(&post_url) .post(&post_url)
@@ -209,26 +177,25 @@ impl SseActor {
} }
} }
// mpsc channel closed => no more outgoing messages tracing::info!("SseActor shut down.");
let pending = pending_requests.len().await;
if pending > 0 {
tracing::error!("SSE stream ended or encountered an error with {pending} unfulfilled pending requests.");
pending_requests.clear().await;
} else {
tracing::info!("SseActor shutdown cleanly. No pending requests.");
}
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub struct SseTransportHandle { pub struct SseTransportHandle {
sender: mpsc::Sender<TransportMessage>, sender: mpsc::Sender<String>,
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>,
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl TransportHandle for SseTransportHandle { impl TransportHandle for SseTransportHandle {
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> { async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> {
send_message(&self.sender, message).await serialize_and_send(&self.sender, message).await
}
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
let mut receiver = self.receiver.lock().await;
receiver.recv().await.ok_or(Error::ChannelClosed)
} }
} }
@@ -279,17 +246,13 @@ impl Transport for SseTransport {
// Create a channel for outgoing TransportMessages // Create a channel for outgoing TransportMessages
let (tx, rx) = mpsc::channel(32); let (tx, rx) = mpsc::channel(32);
let (otx, orx) = mpsc::channel(32);
let post_endpoint: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None)); let post_endpoint: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
let post_endpoint_clone = Arc::clone(&post_endpoint); let post_endpoint_clone = Arc::clone(&post_endpoint);
// Build the actor // Build the actor
let actor = SseActor::new( let actor = SseActor::new(rx, otx, self.sse_url.clone(), post_endpoint);
rx,
Arc::new(PendingRequests::new()),
self.sse_url.clone(),
post_endpoint,
);
// Spawn the actor task // Spawn the actor task
tokio::spawn(actor.run()); tokio::spawn(actor.run());
@@ -301,7 +264,10 @@ impl Transport for SseTransport {
) )
.await .await
{ {
Ok(_) => Ok(SseTransportHandle { sender: tx }), Ok(_) => Ok(SseTransportHandle {
sender: tx,
receiver: Arc::new(Mutex::new(orx)),
}),
Err(e) => Err(Error::SseConnection(e.to_string())), Err(e) => Err(Error::SseConnection(e.to_string())),
} }
} }

View File

@@ -14,7 +14,7 @@ use nix::sys::signal::{kill, Signal};
#[cfg(unix)] #[cfg(unix)]
use nix::unistd::{getpgid, Pid}; use nix::unistd::{getpgid, Pid};
use super::{send_message, Error, PendingRequests, Transport, TransportHandle, TransportMessage}; use super::{serialize_and_send, Error, Transport, TransportHandle};
// Global to track process groups we've created // Global to track process groups we've created
static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1); static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
@@ -23,8 +23,8 @@ static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
/// ///
/// It uses channels for message passing and handles responses asynchronously through a background task. /// It uses channels for message passing and handles responses asynchronously through a background task.
pub struct StdioActor { pub struct StdioActor {
receiver: Option<mpsc::Receiver<TransportMessage>>, receiver: Option<mpsc::Receiver<String>>,
pending_requests: Arc<PendingRequests>, sender: Option<mpsc::Sender<JsonRpcMessage>>,
process: Child, // we store the process to keep it alive process: Child, // we store the process to keep it alive
error_sender: mpsc::Sender<Error>, error_sender: mpsc::Sender<Error>,
stdin: Option<ChildStdin>, stdin: Option<ChildStdin>,
@@ -55,11 +55,11 @@ impl StdioActor {
let stdout = self.stdout.take().expect("stdout should be available"); let stdout = self.stdout.take().expect("stdout should be available");
let stdin = self.stdin.take().expect("stdin should be available"); let stdin = self.stdin.take().expect("stdin should be available");
let receiver = self.receiver.take().expect("receiver should be available"); let msg_inbox = self.receiver.take().expect("receiver should be available");
let msg_outbox = self.sender.take().expect("sender should be available");
let incoming = Self::handle_incoming_messages(stdout, self.pending_requests.clone()); let incoming = Self::handle_proc_output(stdout, msg_outbox);
let outgoing = let outgoing = Self::handle_proc_input(stdin, msg_inbox);
Self::handle_outgoing_messages(receiver, stdin, self.pending_requests.clone());
// take ownership of futures for tokio::select // take ownership of futures for tokio::select
pin!(incoming); pin!(incoming);
@@ -96,12 +96,9 @@ impl StdioActor {
.await; .await;
} }
} }
// Clean up regardless of which path we took
self.pending_requests.clear().await;
} }
async fn handle_incoming_messages(stdout: ChildStdout, pending_requests: Arc<PendingRequests>) { async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender<JsonRpcMessage>) {
let mut reader = BufReader::new(stdout); let mut reader = BufReader::new(stdout);
let mut line = String::new(); let mut line = String::new();
loop { loop {
@@ -116,20 +113,12 @@ impl StdioActor {
message = ?message, message = ?message,
"Received incoming message" "Received incoming message"
); );
let _ = sender.send(message).await;
match &message { } else {
JsonRpcMessage::Response(response) => { tracing::warn!(
if let Some(id) = &response.id { message = ?line,
pending_requests.respond(&id.to_string(), Ok(message)).await; "Failed to parse incoming message"
} );
}
JsonRpcMessage::Error(error) => {
if let Some(id) = &error.id {
pending_requests.respond(&id.to_string(), Ok(message)).await;
}
}
_ => {} // TODO: Handle other variants (Request, etc.)
}
} }
line.clear(); line.clear();
} }
@@ -141,44 +130,20 @@ impl StdioActor {
} }
} }
async fn handle_outgoing_messages( async fn handle_proc_input(mut stdin: ChildStdin, mut receiver: mpsc::Receiver<String>) {
mut receiver: mpsc::Receiver<TransportMessage>, while let Some(message_str) = receiver.recv().await {
mut stdin: ChildStdin, tracing::debug!(message = ?message_str, "Sending outgoing message");
pending_requests: Arc<PendingRequests>,
) {
while let Some(mut transport_msg) = receiver.recv().await {
let message_str = match serde_json::to_string(&transport_msg.message) {
Ok(s) => s,
Err(e) => {
if let Some(tx) = transport_msg.response_tx.take() {
let _ = tx.send(Err(Error::Serialization(e)));
}
continue;
}
};
tracing::debug!(message = ?transport_msg.message, "Sending outgoing message");
if let Some(response_tx) = transport_msg.response_tx.take() {
if let JsonRpcMessage::Request(request) = &transport_msg.message {
if let Some(id) = &request.id {
pending_requests.insert(id.to_string(), response_tx).await;
}
}
}
if let Err(e) = stdin if let Err(e) = stdin
.write_all(format!("{}\n", message_str).as_bytes()) .write_all(format!("{}\n", message_str).as_bytes())
.await .await
{ {
tracing::error!(error = ?e, "Error writing message to child process"); tracing::error!(error = ?e, "Error writing message to child process");
pending_requests.clear().await;
break; break;
} }
if let Err(e) = stdin.flush().await { if let Err(e) = stdin.flush().await {
tracing::error!(error = ?e, "Error flushing message to child process"); tracing::error!(error = ?e, "Error flushing message to child process");
pending_requests.clear().await;
break; break;
} }
} }
@@ -187,18 +152,24 @@ impl StdioActor {
#[derive(Clone)] #[derive(Clone)]
pub struct StdioTransportHandle { pub struct StdioTransportHandle {
sender: mpsc::Sender<TransportMessage>, sender: mpsc::Sender<String>, // to process
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>, // from process
error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>, error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl TransportHandle for StdioTransportHandle { impl TransportHandle for StdioTransportHandle {
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> { async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> {
let result = send_message(&self.sender, message).await; let result = serialize_and_send(&self.sender, message).await;
// Check for any pending errors even if send is successful // Check for any pending errors even if send is successful
self.check_for_errors().await?; self.check_for_errors().await?;
result result
} }
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
let mut receiver = self.receiver.lock().await;
receiver.recv().await.ok_or(Error::ChannelClosed)
}
} }
impl StdioTransportHandle { impl StdioTransportHandle {
@@ -289,12 +260,13 @@ impl Transport for StdioTransport {
async fn start(&self) -> Result<Self::Handle, Error> { async fn start(&self) -> Result<Self::Handle, Error> {
let (process, stdin, stdout, stderr) = self.spawn_process().await?; let (process, stdin, stdout, stderr) = self.spawn_process().await?;
let (message_tx, message_rx) = mpsc::channel(32); let (outbox_tx, outbox_rx) = mpsc::channel(32);
let (inbox_tx, inbox_rx) = mpsc::channel(32);
let (error_tx, error_rx) = mpsc::channel(1); let (error_tx, error_rx) = mpsc::channel(1);
let actor = StdioActor { let actor = StdioActor {
receiver: Some(message_rx), receiver: Some(outbox_rx), // client to process
pending_requests: Arc::new(PendingRequests::new()), sender: Some(inbox_tx), // process to client
process, process,
error_sender: error_tx, error_sender: error_tx,
stdin: Some(stdin), stdin: Some(stdin),
@@ -305,7 +277,8 @@ impl Transport for StdioTransport {
tokio::spawn(actor.run()); tokio::spawn(actor.run());
let handle = StdioTransportHandle { let handle = StdioTransportHandle {
sender: message_tx, sender: outbox_tx, // client to process
receiver: Arc::new(Mutex::new(inbox_rx)), // process to client
error_receiver: Arc::new(Mutex::new(error_rx)), error_receiver: Arc::new(Mutex::new(error_rx)),
}; };
Ok(handle) Ok(handle)

View File

@@ -4,9 +4,13 @@ use std::{
}; };
use futures::{Future, Stream}; use futures::{Future, Stream};
use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse}; use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcResponse};
use pin_project::pin_project; use pin_project::pin_project;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use router::McpRequest;
use tokio::{
io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader},
sync::mpsc,
};
use tower_service::Service; use tower_service::Service;
mod errors; mod errors;
@@ -123,7 +127,7 @@ pub struct Server<S> {
impl<S> Server<S> impl<S> Server<S>
where where
S: Service<JsonRpcRequest, Response = JsonRpcResponse> + Send, S: Service<McpRequest, Response = JsonRpcResponse> + Send,
S::Error: Into<BoxError>, S::Error: Into<BoxError>,
S::Future: Send, S::Future: Send,
{ {
@@ -134,8 +138,8 @@ where
// TODO transport trait instead of byte transport if we implement others // TODO transport trait instead of byte transport if we implement others
pub async fn run<R, W>(self, mut transport: ByteTransport<R, W>) -> Result<(), ServerError> pub async fn run<R, W>(self, mut transport: ByteTransport<R, W>) -> Result<(), ServerError>
where where
R: AsyncRead + Unpin, R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin, W: AsyncWrite + Unpin + Send + 'static,
{ {
use futures::StreamExt; use futures::StreamExt;
let mut service = self.service; let mut service = self.service;
@@ -160,7 +164,22 @@ where
); );
// Process the request using our service // Process the request using our service
let response = match service.call(request).await { let (notify_tx, mut notify_rx) = mpsc::channel(256);
let mcp_request = McpRequest {
request,
notifier: notify_tx,
};
let transport_fut = tokio::spawn(async move {
while let Some(notification) = notify_rx.recv().await {
if transport.write_message(notification).await.is_err() {
break;
}
}
transport
});
let response = match service.call(mcp_request).await {
Ok(resp) => resp, Ok(resp) => resp,
Err(e) => { Err(e) => {
let error_msg = e.into().to_string(); let error_msg = e.into().to_string();
@@ -178,6 +197,16 @@ where
} }
}; };
transport = match transport_fut.await {
Ok(transport) => transport,
Err(e) => {
tracing::error!(error = %e, "Failed to spawn transport task");
return Err(ServerError::Transport(TransportError::Io(
e.into(),
)));
}
};
// Serialize response for logging // Serialize response for logging
let response_json = serde_json::to_string(&response) let response_json = serde_json::to_string(&response)
.unwrap_or_else(|_| "Failed to serialize response".to_string()); .unwrap_or_else(|_| "Failed to serialize response".to_string());
@@ -247,7 +276,7 @@ where
// Any router implements this // Any router implements this
pub trait BoundedService: pub trait BoundedService:
Service< Service<
JsonRpcRequest, McpRequest,
Response = JsonRpcResponse, Response = JsonRpcResponse,
Error = BoxError, Error = BoxError,
Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>, Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
@@ -259,7 +288,7 @@ pub trait BoundedService:
// Implement it for any type that meets the bounds // Implement it for any type that meets the bounds
impl<T> BoundedService for T where impl<T> BoundedService for T where
T: Service< T: Service<
JsonRpcRequest, McpRequest,
Response = JsonRpcResponse, Response = JsonRpcResponse,
Error = BoxError, Error = BoxError,
Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>, Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,

View File

@@ -2,12 +2,14 @@ use anyhow::Result;
use mcp_core::content::Content; use mcp_core::content::Content;
use mcp_core::handler::{PromptError, ResourceError}; use mcp_core::handler::{PromptError, ResourceError};
use mcp_core::prompt::{Prompt, PromptArgument}; use mcp_core::prompt::{Prompt, PromptArgument};
use mcp_core::protocol::JsonRpcMessage;
use mcp_core::tool::ToolAnnotations; use mcp_core::tool::ToolAnnotations;
use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool};
use mcp_server::router::{CapabilitiesBuilder, RouterService}; use mcp_server::router::{CapabilitiesBuilder, RouterService};
use mcp_server::{ByteTransport, Router, Server}; use mcp_server::{ByteTransport, Router, Server};
use serde_json::Value; use serde_json::Value;
use std::{future::Future, pin::Pin, sync::Arc}; use std::{future::Future, pin::Pin, sync::Arc};
use tokio::sync::mpsc;
use tokio::{ use tokio::{
io::{stdin, stdout}, io::{stdin, stdout},
sync::Mutex, sync::Mutex,
@@ -124,6 +126,7 @@ impl Router for CounterRouter {
&self, &self,
tool_name: &str, tool_name: &str,
_arguments: Value, _arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> { ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone(); let this = self.clone();
let tool_name = tool_name.to_string(); let tool_name = tool_name.to_string();

View File

@@ -11,14 +11,15 @@ use mcp_core::{
handler::{PromptError, ResourceError, ToolError}, handler::{PromptError, ResourceError, ToolError},
prompt::{Prompt, PromptMessage, PromptMessageRole}, prompt::{Prompt, PromptMessage, PromptMessageRole},
protocol::{ protocol::{
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcRequest, CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcMessage,
JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult, JsonRpcRequest, JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult,
PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities, PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities,
ToolsCapability, ToolsCapability,
}, },
ResourceContents, ResourceContents,
}; };
use serde_json::Value; use serde_json::Value;
use tokio::sync::mpsc;
use tower_service::Service; use tower_service::Service;
use crate::{BoxError, RouterError}; use crate::{BoxError, RouterError};
@@ -91,6 +92,7 @@ pub trait Router: Send + Sync + 'static {
&self, &self,
tool_name: &str, tool_name: &str,
arguments: Value, arguments: Value,
notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>>; ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>>;
fn list_resources(&self) -> Vec<mcp_core::resource::Resource>; fn list_resources(&self) -> Vec<mcp_core::resource::Resource>;
fn read_resource( fn read_resource(
@@ -159,6 +161,7 @@ pub trait Router: Send + Sync + 'static {
fn handle_tools_call( fn handle_tools_call(
&self, &self,
req: JsonRpcRequest, req: JsonRpcRequest,
notifier: mpsc::Sender<JsonRpcMessage>,
) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send { ) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
async move { async move {
let params = req let params = req
@@ -172,7 +175,7 @@ pub trait Router: Send + Sync + 'static {
let arguments = params.get("arguments").cloned().unwrap_or(Value::Null); let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
let result = match self.call_tool(name, arguments).await { let result = match self.call_tool(name, arguments, notifier).await {
Ok(result) => CallToolResult { Ok(result) => CallToolResult {
content: result, content: result,
is_error: None, is_error: None,
@@ -394,7 +397,12 @@ pub trait Router: Send + Sync + 'static {
pub struct RouterService<T>(pub T); pub struct RouterService<T>(pub T);
impl<T> Service<JsonRpcRequest> for RouterService<T> pub struct McpRequest {
pub request: JsonRpcRequest,
pub notifier: mpsc::Sender<JsonRpcMessage>,
}
impl<T> Service<McpRequest> for RouterService<T>
where where
T: Router + Clone + Send + Sync + 'static, T: Router + Clone + Send + Sync + 'static,
{ {
@@ -406,21 +414,21 @@ where
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: JsonRpcRequest) -> Self::Future { fn call(&mut self, req: McpRequest) -> Self::Future {
let this = self.0.clone(); let this = self.0.clone();
Box::pin(async move { Box::pin(async move {
let result = match req.method.as_str() { let result = match req.request.method.as_str() {
"initialize" => this.handle_initialize(req).await, "initialize" => this.handle_initialize(req.request).await,
"tools/list" => this.handle_tools_list(req).await, "tools/list" => this.handle_tools_list(req.request).await,
"tools/call" => this.handle_tools_call(req).await, "tools/call" => this.handle_tools_call(req.request, req.notifier).await,
"resources/list" => this.handle_resources_list(req).await, "resources/list" => this.handle_resources_list(req.request).await,
"resources/read" => this.handle_resources_read(req).await, "resources/read" => this.handle_resources_read(req.request).await,
"prompts/list" => this.handle_prompts_list(req).await, "prompts/list" => this.handle_prompts_list(req.request).await,
"prompts/get" => this.handle_prompts_get(req).await, "prompts/get" => this.handle_prompts_get(req.request).await,
_ => { _ => {
let mut response = this.create_response(req.id); let mut response = this.create_response(req.request.id);
response.error = Some(RouterError::MethodNotFound(req.method).into()); response.error = Some(RouterError::MethodNotFound(req.request.method).into());
Ok(response) Ok(response)
} }
}; };

View File

@@ -148,6 +148,7 @@ function ChatContent({
handleInputChange: _handleInputChange, handleInputChange: _handleInputChange,
handleSubmit: _submitMessage, handleSubmit: _submitMessage,
updateMessageStreamBody, updateMessageStreamBody,
notifications,
} = useMessageStream({ } = useMessageStream({
api: getApiUrl('/reply'), api: getApiUrl('/reply'),
initialMessages: chat.messages, initialMessages: chat.messages,
@@ -492,6 +493,16 @@ function ChatContent({
const handleDragOver = (e: React.DragEvent<HTMLDivElement>) => { const handleDragOver = (e: React.DragEvent<HTMLDivElement>) => {
e.preventDefault(); e.preventDefault();
}; };
const toolCallNotifications = notifications.reduce((map, item) => {
const key = item.request_id;
if (!map.has(key)) {
map.set(key, []);
}
map.get(key).push(item);
return map;
}, new Map());
return ( return (
<div className="flex flex-col w-full h-screen items-center justify-center"> <div className="flex flex-col w-full h-screen items-center justify-center">
{/* Loader when generating recipe */} {/* Loader when generating recipe */}
@@ -571,6 +582,7 @@ function ChatContent({
const updatedMessages = [...messages, newMessage]; const updatedMessages = [...messages, newMessage];
setMessages(updatedMessages); setMessages(updatedMessages);
}} }}
toolCallNotifications={toolCallNotifications}
/> />
)} )}
</> </>
@@ -578,6 +590,7 @@ function ChatContent({
</div> </div>
))} ))}
</SearchView> </SearchView>
{error && ( {error && (
<div className="flex flex-col items-center justify-center p-4"> <div className="flex flex-col items-center justify-center p-4">
<div className="text-red-700 dark:text-red-300 bg-red-400/50 p-3 rounded-lg mb-2"> <div className="text-red-700 dark:text-red-300 bg-red-400/50 p-3 rounded-lg mb-2">

View File

@@ -17,6 +17,7 @@ import {
} from '../types/message'; } from '../types/message';
import ToolCallConfirmation from './ToolCallConfirmation'; import ToolCallConfirmation from './ToolCallConfirmation';
import MessageCopyLink from './MessageCopyLink'; import MessageCopyLink from './MessageCopyLink';
import { NotificationEvent } from '../hooks/useMessageStream';
interface GooseMessageProps { interface GooseMessageProps {
// messages up to this index are presumed to be "history" from a resumed session, this is used to track older tool confirmation requests // messages up to this index are presumed to be "history" from a resumed session, this is used to track older tool confirmation requests
@@ -25,6 +26,7 @@ interface GooseMessageProps {
message: Message; message: Message;
messages: Message[]; messages: Message[];
metadata?: string[]; metadata?: string[];
toolCallNotifications: Map<string, NotificationEvent[]>;
append: (value: string) => void; append: (value: string) => void;
appendMessage: (message: Message) => void; appendMessage: (message: Message) => void;
} }
@@ -34,6 +36,7 @@ export default function GooseMessage({
message, message,
metadata, metadata,
messages, messages,
toolCallNotifications,
append, append,
appendMessage, appendMessage,
}: GooseMessageProps) { }: GooseMessageProps) {
@@ -158,6 +161,7 @@ export default function GooseMessage({
} }
toolRequest={toolRequest} toolRequest={toolRequest}
toolResponse={toolResponsesMap.get(toolRequest.id)} toolResponse={toolResponsesMap.get(toolRequest.id)}
notifications={toolCallNotifications.get(toolRequest.id)}
/> />
</div> </div>
))} ))}

View File

@@ -1,4 +1,4 @@
import React from 'react'; import React, { useEffect, useRef } from 'react';
import { Card } from './ui/card'; import { Card } from './ui/card';
import { ToolCallArguments, ToolCallArgumentValue } from './ToolCallArguments'; import { ToolCallArguments, ToolCallArgumentValue } from './ToolCallArguments';
import MarkdownContent from './MarkdownContent'; import MarkdownContent from './MarkdownContent';
@@ -6,17 +6,20 @@ import { Content, ToolRequestMessageContent, ToolResponseMessageContent } from '
import { snakeToTitleCase } from '../utils'; import { snakeToTitleCase } from '../utils';
import Dot, { LoadingStatus } from './ui/Dot'; import Dot, { LoadingStatus } from './ui/Dot';
import Expand from './ui/Expand'; import Expand from './ui/Expand';
import { NotificationEvent } from '../hooks/useMessageStream';
interface ToolCallWithResponseProps { interface ToolCallWithResponseProps {
isCancelledMessage: boolean; isCancelledMessage: boolean;
toolRequest: ToolRequestMessageContent; toolRequest: ToolRequestMessageContent;
toolResponse?: ToolResponseMessageContent; toolResponse?: ToolResponseMessageContent;
notifications?: NotificationEvent[];
} }
export default function ToolCallWithResponse({ export default function ToolCallWithResponse({
isCancelledMessage, isCancelledMessage,
toolRequest, toolRequest,
toolResponse, toolResponse,
notifications,
}: ToolCallWithResponseProps) { }: ToolCallWithResponseProps) {
const toolCall = toolRequest.toolCall.status === 'success' ? toolRequest.toolCall.value : null; const toolCall = toolRequest.toolCall.status === 'success' ? toolRequest.toolCall.value : null;
if (!toolCall) { if (!toolCall) {
@@ -26,7 +29,7 @@ export default function ToolCallWithResponse({
return ( return (
<div className={'w-full text-textSubtle text-sm'}> <div className={'w-full text-textSubtle text-sm'}>
<Card className=""> <Card className="">
<ToolCallView {...{ isCancelledMessage, toolCall, toolResponse }} /> <ToolCallView {...{ isCancelledMessage, toolCall, toolResponse, notifications }} />
</Card> </Card>
</div> </div>
); );
@@ -47,8 +50,9 @@ function ToolCallExpandable({
children, children,
className = '', className = '',
}: ToolCallExpandableProps) { }: ToolCallExpandableProps) {
const [isExpanded, setIsExpanded] = React.useState(isStartExpanded); const [isExpandedState, setIsExpanded] = React.useState<boolean | null>(null);
const toggleExpand = () => setIsExpanded((prev) => !prev); const isExpanded = isExpandedState === null ? isStartExpanded : isExpandedState;
const toggleExpand = () => setIsExpanded(!isExpanded);
React.useEffect(() => { React.useEffect(() => {
if (isForceExpand) setIsExpanded(true); if (isForceExpand) setIsExpanded(true);
}, [isForceExpand]); }, [isForceExpand]);
@@ -71,9 +75,42 @@ interface ToolCallViewProps {
arguments: Record<string, unknown>; arguments: Record<string, unknown>;
}; };
toolResponse?: ToolResponseMessageContent; toolResponse?: ToolResponseMessageContent;
notifications?: NotificationEvent[];
} }
function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallViewProps) { interface Progress {
progress: number;
progressToken: string;
total?: number;
message?: string;
}
const logToString = (logMessage: NotificationEvent) => {
const params = logMessage.message.params;
// Special case for the developer system shell logs
if (
params &&
params.data &&
typeof params.data === 'object' &&
'output' in params.data &&
'stream' in params.data
) {
return `[${params.data.stream}] ${params.data.output}`;
}
return typeof params.data === 'string' ? params.data : JSON.stringify(params.data);
};
const notificationToProgress = (notification: NotificationEvent): Progress =>
notification.message.params as unknown as Progress;
function ToolCallView({
isCancelledMessage,
toolCall,
toolResponse,
notifications,
}: ToolCallViewProps) {
const responseStyle = localStorage.getItem('response_style'); const responseStyle = localStorage.getItem('response_style');
const isExpandToolDetails = (() => { const isExpandToolDetails = (() => {
switch (responseStyle) { switch (responseStyle) {
@@ -103,6 +140,29 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi
})) }))
: []; : [];
const logs = notifications
?.filter((notification) => notification.message.method === 'notifications/message')
.map(logToString);
const progress = notifications
?.filter((notification) => notification.message.method === 'notifications/progress')
.map(notificationToProgress)
.reduce((map, item) => {
const key = item.progressToken;
if (!map.has(key)) {
map.set(key, []);
}
map.get(key)!.push(item);
return map;
}, new Map<string, Progress[]>());
const progressEntries = [...(progress?.values() || [])].map(
(entries) => entries.sort((a, b) => b.progress - a.progress)[0]
);
const isRenderingProgress =
loadingStatus === 'loading' && (progressEntries.length > 0 || (logs || []).length > 0);
const isShouldExpand = isExpandToolDetails || toolResults.some((v) => v.isExpandToolResults); const isShouldExpand = isExpandToolDetails || toolResults.some((v) => v.isExpandToolResults);
// Function to create a compact representation of arguments // Function to create a compact representation of arguments
@@ -136,7 +196,7 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi
return ( return (
<ToolCallExpandable <ToolCallExpandable
isStartExpanded={isShouldExpand} isStartExpanded={isShouldExpand || isRenderingProgress}
isForceExpand={isShouldExpand} isForceExpand={isShouldExpand}
label={ label={
<> <>
@@ -156,6 +216,24 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi
</div> </div>
)} )}
{logs && logs.length > 0 && (
<div className="bg-bgStandard mt-1">
<ToolLogsView
logs={logs}
working={toolResults.length === 0}
isStartExpanded={toolResults.length === 0}
/>
</div>
)}
{toolResults.length === 0 &&
progressEntries.length > 0 &&
progressEntries.map((entry, index) => (
<div className="p-2" key={index}>
<ProgressBar progress={entry.progress} total={entry.total} message={entry.message} />
</div>
))}
{/* Tool Output */} {/* Tool Output */}
{!isCancelledMessage && ( {!isCancelledMessage && (
<> <>
@@ -234,3 +312,76 @@ function ToolResultView({ result, isStartExpanded }: ToolResultViewProps) {
</ToolCallExpandable> </ToolCallExpandable>
); );
} }
function ToolLogsView({
logs,
working,
isStartExpanded,
}: {
logs: string[];
working: boolean;
isStartExpanded?: boolean;
}) {
const boxRef = useRef(null);
// Whenever logs update, jump to the newest entry
useEffect(() => {
if (boxRef.current) {
boxRef.current.scrollTop = boxRef.current.scrollHeight;
}
}, [logs]);
return (
<ToolCallExpandable
label={
<span className="pl-[19px] py-1">
<span>Logs</span>
{working && (
<div className="mx-2 inline-block">
<span
className="inline-block animate-spin rounded-full border-2 border-t-transparent border-current"
style={{ width: 8, height: 8 }}
role="status"
aria-label="Loading spinner"
/>
</div>
)}
</span>
}
isStartExpanded={isStartExpanded}
>
<div
ref={boxRef}
className={`flex flex-col items-start space-y-2 overflow-y-auto ${working ? 'max-h-[4rem]' : 'max-h-[20rem]'} bg-bgApp`}
>
{logs.map((log, i) => (
<span key={i} className="font-mono text-sm text-textSubtle">
{log}
</span>
))}
</div>
</ToolCallExpandable>
);
}
const ProgressBar = ({ progress, total, message }: Omit<Progress, 'progressToken'>) => {
const isDeterminate = typeof total === 'number';
const percent = isDeterminate ? Math.min((progress / total!) * 100, 100) : 0;
return (
<div className="w-full space-y-2">
{message && <div className="text-sm text-gray-700">{message}</div>}
<div className="w-full bg-gray-200 rounded-full h-4 overflow-hidden relative">
{isDeterminate ? (
<div
className="bg-blue-500 h-full transition-all duration-300"
style={{ width: `${percent}%` }}
/>
) : (
<div className="absolute inset-0 animate-indeterminate bg-blue-500" />
)}
</div>
</div>
);
};

View File

@@ -6,11 +6,25 @@ import { Message, createUserMessage, hasCompletedToolCalls } from '../types/mess
// Ensure TextDecoder is available in the global scope // Ensure TextDecoder is available in the global scope
const TextDecoder = globalThis.TextDecoder; const TextDecoder = globalThis.TextDecoder;
type JsonValue = string | number | boolean | null | JsonValue[] | { [key: string]: JsonValue };
export interface NotificationEvent {
type: 'Notification';
request_id: string;
message: {
method: string;
params: {
[key: string]: JsonValue;
};
};
}
// Event types for SSE stream // Event types for SSE stream
type MessageEvent = type MessageEvent =
| { type: 'Message'; message: Message } | { type: 'Message'; message: Message }
| { type: 'Error'; error: string } | { type: 'Error'; error: string }
| { type: 'Finish'; reason: string }; | { type: 'Finish'; reason: string }
| NotificationEvent;
export interface UseMessageStreamOptions { export interface UseMessageStreamOptions {
/** /**
@@ -124,6 +138,8 @@ export interface UseMessageStreamHelpers {
/** Modify body (session id and/or work dir mid-stream) **/ /** Modify body (session id and/or work dir mid-stream) **/
updateMessageStreamBody?: (newBody: object) => void; updateMessageStreamBody?: (newBody: object) => void;
notifications: NotificationEvent[];
} }
/** /**
@@ -151,6 +167,8 @@ export function useMessageStream({
fallbackData: initialMessages, fallbackData: initialMessages,
}); });
const [notifications, setNotifications] = useState<NotificationEvent[]>([]);
// expose a way to update the body so we can update the session id when CLE occurs // expose a way to update the body so we can update the session id when CLE occurs
const updateMessageStreamBody = useCallback((newBody: object) => { const updateMessageStreamBody = useCallback((newBody: object) => {
extraMetadataRef.current.body = { extraMetadataRef.current.body = {
@@ -247,6 +265,14 @@ export function useMessageStream({
break; break;
} }
case 'Notification': {
const newNotification = {
...parsedEvent,
};
setNotifications((prev) => [...prev, newNotification]);
break;
}
case 'Error': case 'Error':
throw new Error(parsedEvent.error); throw new Error(parsedEvent.error);
@@ -516,5 +542,6 @@ export function useMessageStream({
isLoading: isLoading || false, isLoading: isLoading || false,
addToolResult, addToolResult,
updateMessageStreamBody, updateMessageStreamBody,
notifications,
}; };
} }

View File

@@ -44,10 +44,16 @@ export default {
'0%': { transform: 'rotate(0deg)' }, '0%': { transform: 'rotate(0deg)' },
'100%': { transform: 'rotate(360deg)' }, '100%': { transform: 'rotate(360deg)' },
}, },
indeterminate: {
'0%': { left: '-40%', width: '40%' },
'50%': { left: '20%', width: '60%' },
'100%': { left: '100%', width: '80%' },
},
}, },
animation: { animation: {
'shimmer-pulse': 'shimmer 4s ease-in-out infinite', 'shimmer-pulse': 'shimmer 4s ease-in-out infinite',
'gradient-loader': 'loader 750ms ease-in-out infinite', 'gradient-loader': 'loader 750ms ease-in-out infinite',
indeterminate: 'indeterminate 1.5s infinite linear',
}, },
colors: { colors: {
bgApp: 'var(--background-app)', bgApp: 'var(--background-app)',