mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 22:24:21 +01:00
feat: Handle MCP server notification messages (#2613)
Co-authored-by: Michael Neale <michael.neale@gmail.com>
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -3435,6 +3435,7 @@ dependencies = [
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-cron-scheduler",
|
||||
"tokio-stream",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
@@ -3486,6 +3487,7 @@ dependencies = [
|
||||
"goose",
|
||||
"goose-bench",
|
||||
"goose-mcp",
|
||||
"indicatif",
|
||||
"mcp-client",
|
||||
"mcp-core",
|
||||
"mcp-server",
|
||||
|
||||
@@ -55,6 +55,7 @@ regex = "1.11.1"
|
||||
minijinja = "2.8.0"
|
||||
nix = { version = "0.30.1", features = ["process", "signal"] }
|
||||
tar = "0.4"
|
||||
indicatif = "0.17.11"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
winapi = { version = "0.3", features = ["wincred"] }
|
||||
|
||||
@@ -7,6 +7,7 @@ mod thinking;
|
||||
|
||||
pub use builder::{build_session, SessionBuilderConfig};
|
||||
use console::Color;
|
||||
use goose::agents::AgentEvent;
|
||||
use goose::permission::permission_confirmation::PrincipalType;
|
||||
use goose::permission::Permission;
|
||||
use goose::permission::PermissionConfirmation;
|
||||
@@ -26,6 +27,8 @@ use input::InputResult;
|
||||
use mcp_core::handler::ToolError;
|
||||
use mcp_core::prompt::PromptMessage;
|
||||
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use mcp_core::protocol::JsonRpcNotification;
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
@@ -713,12 +716,15 @@ impl Session {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut progress_bars = output::McpSpinners::new();
|
||||
|
||||
use futures::StreamExt;
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = stream.next() => {
|
||||
let _ = progress_bars.hide();
|
||||
match result {
|
||||
Some(Ok(message)) => {
|
||||
Some(Ok(AgentEvent::Message(message))) => {
|
||||
// If it's a confirmation request, get approval but otherwise do not render/persist
|
||||
if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() {
|
||||
output::hide_thinking();
|
||||
@@ -846,6 +852,51 @@ impl Session {
|
||||
if interactive {output::show_thinking()};
|
||||
}
|
||||
}
|
||||
Some(Ok(AgentEvent::McpNotification((_id, message)))) => {
|
||||
if let JsonRpcMessage::Notification(JsonRpcNotification{
|
||||
method,
|
||||
params: Some(Value::Object(o)),
|
||||
..
|
||||
}) = message {
|
||||
match method.as_str() {
|
||||
"notifications/message" => {
|
||||
let data = o.get("data").unwrap_or(&Value::Null);
|
||||
let message = match data {
|
||||
Value::String(s) => s.clone(),
|
||||
Value::Object(o) => {
|
||||
if let Some(Value::String(output)) = o.get("output") {
|
||||
output.to_owned()
|
||||
} else {
|
||||
data.to_string()
|
||||
}
|
||||
},
|
||||
v => {
|
||||
v.to_string()
|
||||
},
|
||||
};
|
||||
// output::render_text_no_newlines(&message, None, true);
|
||||
progress_bars.log(&message);
|
||||
},
|
||||
"notifications/progress" => {
|
||||
let progress = o.get("progress").and_then(|v| v.as_f64());
|
||||
let token = o.get("progressToken").map(|v| v.to_string());
|
||||
let message = o.get("message").and_then(|v| v.as_str());
|
||||
let total = o
|
||||
.get("total")
|
||||
.and_then(|v| v.as_f64());
|
||||
if let (Some(progress), Some(token)) = (progress, token) {
|
||||
progress_bars.update(
|
||||
token.as_str(),
|
||||
progress,
|
||||
total,
|
||||
message,
|
||||
);
|
||||
}
|
||||
},
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
eprintln!("Error: {}", e);
|
||||
drop(stream);
|
||||
@@ -872,6 +923,7 @@ impl Session {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -2,12 +2,15 @@ use bat::WrappingMode;
|
||||
use console::{style, Color};
|
||||
use goose::config::Config;
|
||||
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
|
||||
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
|
||||
use mcp_core::prompt::PromptArgument;
|
||||
use mcp_core::tool::ToolCall;
|
||||
use serde_json::Value;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Error;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
// Re-export theme for use in main
|
||||
#[derive(Clone, Copy)]
|
||||
@@ -144,6 +147,10 @@ pub fn render_message(message: &Message, debug: bool) {
|
||||
}
|
||||
|
||||
pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
|
||||
render_text_no_newlines(format!("\n{}\n\n", text).as_str(), color, dim);
|
||||
}
|
||||
|
||||
pub fn render_text_no_newlines(text: &str, color: Option<Color>, dim: bool) {
|
||||
let mut styled_text = style(text);
|
||||
if dim {
|
||||
styled_text = styled_text.dim();
|
||||
@@ -153,7 +160,7 @@ pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
|
||||
} else {
|
||||
styled_text = styled_text.green();
|
||||
}
|
||||
println!("\n{}\n", styled_text);
|
||||
print!("{}", styled_text);
|
||||
}
|
||||
|
||||
pub fn render_enter_plan_mode() {
|
||||
@@ -359,7 +366,6 @@ fn render_shell_request(call: &ToolCall, debug: bool) {
|
||||
}
|
||||
_ => print_params(&call.arguments, 0, debug),
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
fn render_default_request(call: &ToolCall, debug: bool) {
|
||||
@@ -568,6 +574,64 @@ pub fn display_greeting() {
|
||||
println!("\nGoose is running! Enter your instructions, or try asking what goose can do.\n");
|
||||
}
|
||||
|
||||
pub struct McpSpinners {
|
||||
bars: HashMap<String, ProgressBar>,
|
||||
log_spinner: Option<ProgressBar>,
|
||||
|
||||
multi_bar: MultiProgress,
|
||||
}
|
||||
|
||||
impl McpSpinners {
|
||||
pub fn new() -> Self {
|
||||
McpSpinners {
|
||||
bars: HashMap::new(),
|
||||
log_spinner: None,
|
||||
multi_bar: MultiProgress::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn log(&mut self, message: &str) {
|
||||
let spinner = self.log_spinner.get_or_insert_with(|| {
|
||||
let bar = self.multi_bar.add(
|
||||
ProgressBar::new_spinner()
|
||||
.with_style(
|
||||
ProgressStyle::with_template("{spinner:.green} {msg}")
|
||||
.unwrap()
|
||||
.tick_chars("⠋⠙⠚⠛⠓⠒⠊⠉"),
|
||||
)
|
||||
.with_message(message.to_string()),
|
||||
);
|
||||
bar.enable_steady_tick(Duration::from_millis(100));
|
||||
bar
|
||||
});
|
||||
|
||||
spinner.set_message(message.to_string());
|
||||
}
|
||||
|
||||
pub fn update(&mut self, token: &str, value: f64, total: Option<f64>, message: Option<&str>) {
|
||||
let bar = self.bars.entry(token.to_string()).or_insert_with(|| {
|
||||
if let Some(total) = total {
|
||||
self.multi_bar.add(
|
||||
ProgressBar::new((total * 100.0) as u64).with_style(
|
||||
ProgressStyle::with_template("[{elapsed}] {bar:40} {pos:>3}/{len:3} {msg}")
|
||||
.unwrap(),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
self.multi_bar.add(ProgressBar::new_spinner())
|
||||
}
|
||||
});
|
||||
bar.set_position((value * 100.0) as u64);
|
||||
if let Some(msg) = message {
|
||||
bar.set_message(msg.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn hide(&mut self) -> Result<(), Error> {
|
||||
self.multi_bar.clear()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::ptr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
use goose::agents::Agent;
|
||||
use goose::agents::{Agent, AgentEvent};
|
||||
use goose::message::Message;
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::databricks::DatabricksProvider;
|
||||
@@ -256,13 +256,16 @@ pub unsafe extern "C" fn goose_agent_send_message(
|
||||
|
||||
while let Some(message_result) = stream.next().await {
|
||||
match message_result {
|
||||
Ok(message) => {
|
||||
Ok(AgentEvent::Message(message)) => {
|
||||
// Get text or serialize to JSON
|
||||
// Note: Message doesn't have as_text method, we'll serialize to JSON
|
||||
if let Ok(json) = serde_json::to_string(&message) {
|
||||
full_response.push_str(&json);
|
||||
}
|
||||
}
|
||||
Ok(AgentEvent::McpNotification(_)) => {
|
||||
// TODO: Handle MCP notifications.
|
||||
}
|
||||
Err(e) => {
|
||||
full_response.push_str(&format!("\nError in message stream: {}", e));
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use serde_json::{json, Value};
|
||||
use std::{
|
||||
collections::HashMap, fs, future::Future, path::PathBuf, pin::Pin, sync::Arc, sync::Mutex,
|
||||
};
|
||||
use tokio::process::Command;
|
||||
use tokio::{process::Command, sync::mpsc};
|
||||
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
@@ -14,7 +14,7 @@ use std::os::unix::fs::PermissionsExt;
|
||||
use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||
resource::Resource,
|
||||
tool::{Tool, ToolAnnotations},
|
||||
Content,
|
||||
@@ -1155,6 +1155,7 @@ impl Router for ComputerControllerRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
@@ -13,13 +13,17 @@ use std::{
|
||||
path::{Path, PathBuf},
|
||||
pin::Pin,
|
||||
};
|
||||
use tokio::process::Command;
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, BufReader},
|
||||
process::Command,
|
||||
sync::mpsc,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use include_dir::{include_dir, Dir};
|
||||
use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
protocol::ServerCapabilities,
|
||||
protocol::{JsonRpcMessage, JsonRpcNotification, ServerCapabilities},
|
||||
resource::Resource,
|
||||
tool::Tool,
|
||||
Content,
|
||||
@@ -456,7 +460,11 @@ impl DeveloperRouter {
|
||||
}
|
||||
|
||||
// Shell command execution with platform-specific handling
|
||||
async fn bash(&self, params: Value) -> Result<Vec<Content>, ToolError> {
|
||||
async fn bash(
|
||||
&self,
|
||||
params: Value,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Result<Vec<Content>, ToolError> {
|
||||
let command =
|
||||
params
|
||||
.get("command")
|
||||
@@ -488,27 +496,92 @@ impl DeveloperRouter {
|
||||
|
||||
// Get platform-specific shell configuration
|
||||
let shell_config = get_shell_config();
|
||||
let cmd_with_redirect = format_command_for_platform(command);
|
||||
let cmd_str = format_command_for_platform(command);
|
||||
|
||||
// Execute the command using platform-specific shell
|
||||
let child = Command::new(&shell_config.executable)
|
||||
let mut child = Command::new(&shell_config.executable)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.stdin(Stdio::null())
|
||||
.kill_on_drop(true)
|
||||
.arg(&shell_config.arg)
|
||||
.arg(cmd_with_redirect)
|
||||
.arg(cmd_str)
|
||||
.spawn()
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
let stderr = child.stderr.take().unwrap();
|
||||
|
||||
let mut stdout_reader = BufReader::new(stdout);
|
||||
let mut stderr_reader = BufReader::new(stderr);
|
||||
|
||||
let output_task = tokio::spawn(async move {
|
||||
let mut combined_output = String::new();
|
||||
|
||||
let mut stdout_buf = Vec::new();
|
||||
let mut stderr_buf = Vec::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
n = stdout_reader.read_until(b'\n', &mut stdout_buf) => {
|
||||
if n? == 0 {
|
||||
break;
|
||||
}
|
||||
let line = String::from_utf8_lossy(&stdout_buf);
|
||||
|
||||
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: "notifications/message".to_string(),
|
||||
params: Some(json!({
|
||||
"data": {
|
||||
"type": "shell",
|
||||
"stream": "stdout",
|
||||
"output": line.to_string(),
|
||||
}
|
||||
})),
|
||||
}))
|
||||
.ok();
|
||||
|
||||
combined_output.push_str(&line);
|
||||
stdout_buf.clear();
|
||||
}
|
||||
n = stderr_reader.read_until(b'\n', &mut stderr_buf) => {
|
||||
if n? == 0 {
|
||||
break;
|
||||
}
|
||||
let line = String::from_utf8_lossy(&stderr_buf);
|
||||
|
||||
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: "notifications/message".to_string(),
|
||||
params: Some(json!({
|
||||
"data": {
|
||||
"type": "shell",
|
||||
"stream": "stderr",
|
||||
"output": line.to_string(),
|
||||
}
|
||||
})),
|
||||
}))
|
||||
.ok();
|
||||
|
||||
combined_output.push_str(&line);
|
||||
stderr_buf.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok::<_, std::io::Error>(combined_output)
|
||||
});
|
||||
|
||||
// Wait for the command to complete and get output
|
||||
let output = child
|
||||
.wait_with_output()
|
||||
child
|
||||
.wait()
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
|
||||
let stdout_str = String::from_utf8_lossy(&output.stdout);
|
||||
let output_str = stdout_str;
|
||||
let output_str = match output_task.await {
|
||||
Ok(result) => result.map_err(|e| ToolError::ExecutionError(e.to_string()))?,
|
||||
Err(e) => return Err(ToolError::ExecutionError(e.to_string())),
|
||||
};
|
||||
|
||||
// Check the character count of the output
|
||||
const MAX_CHAR_COUNT: usize = 400_000; // 409600 chars = 400KB
|
||||
@@ -1048,12 +1121,13 @@ impl Router for DeveloperRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
Box::pin(async move {
|
||||
match tool_name.as_str() {
|
||||
"shell" => this.bash(arguments).await,
|
||||
"shell" => this.bash(arguments, notifier).await,
|
||||
"text_editor" => this.text_editor(arguments).await,
|
||||
"list_windows" => this.list_windows(arguments).await,
|
||||
"screen_capture" => this.screen_capture(arguments).await,
|
||||
@@ -1195,6 +1269,10 @@ mod tests {
|
||||
.await
|
||||
}
|
||||
|
||||
fn dummy_sender() -> mpsc::Sender<JsonRpcMessage> {
|
||||
mpsc::channel(1).0
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_shell_missing_parameters() {
|
||||
@@ -1202,7 +1280,7 @@ mod tests {
|
||||
std::env::set_current_dir(&temp_dir).unwrap();
|
||||
|
||||
let router = get_router().await;
|
||||
let result = router.call_tool("shell", json!({})).await;
|
||||
let result = router.call_tool("shell", json!({}), dummy_sender()).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
let err = result.err().unwrap();
|
||||
@@ -1263,6 +1341,7 @@ mod tests {
|
||||
"command": "view",
|
||||
"path": large_file_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1288,6 +1367,7 @@ mod tests {
|
||||
"command": "view",
|
||||
"path": many_chars_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1319,6 +1399,7 @@ mod tests {
|
||||
"path": file_path_str,
|
||||
"file_text": "Hello, world!"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1331,6 +1412,7 @@ mod tests {
|
||||
"command": "view",
|
||||
"path": file_path_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1369,6 +1451,7 @@ mod tests {
|
||||
"path": file_path_str,
|
||||
"file_text": "Hello, world!"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1383,6 +1466,7 @@ mod tests {
|
||||
"old_str": "world",
|
||||
"new_str": "Rust"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1407,6 +1491,7 @@ mod tests {
|
||||
"command": "view",
|
||||
"path": file_path_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1444,6 +1529,7 @@ mod tests {
|
||||
"path": file_path_str,
|
||||
"file_text": "First line"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1458,6 +1544,7 @@ mod tests {
|
||||
"old_str": "First line",
|
||||
"new_str": "Second line"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1470,6 +1557,7 @@ mod tests {
|
||||
"command": "undo_edit",
|
||||
"path": file_path_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1485,6 +1573,7 @@ mod tests {
|
||||
"command": "view",
|
||||
"path": file_path_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1583,6 +1672,7 @@ mod tests {
|
||||
"path": temp_dir.path().join("secret.txt").to_str().unwrap(),
|
||||
"file_text": "test content"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1601,6 +1691,7 @@ mod tests {
|
||||
"path": temp_dir.path().join("allowed.txt").to_str().unwrap(),
|
||||
"file_text": "test content"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1642,6 +1733,7 @@ mod tests {
|
||||
json!({
|
||||
"command": format!("cat {}", secret_file_path.to_str().unwrap())
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1658,6 +1750,7 @@ mod tests {
|
||||
json!({
|
||||
"command": format!("cat {}", allowed_file_path.to_str().unwrap())
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ use std::env;
|
||||
pub struct ShellConfig {
|
||||
pub executable: String,
|
||||
pub arg: String,
|
||||
pub redirect_syntax: String,
|
||||
}
|
||||
|
||||
impl Default for ShellConfig {
|
||||
@@ -14,13 +13,11 @@ impl Default for ShellConfig {
|
||||
Self {
|
||||
executable: "powershell.exe".to_string(),
|
||||
arg: "-NoProfile -NonInteractive -Command".to_string(),
|
||||
redirect_syntax: "2>&1".to_string(),
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
executable: "bash".to_string(),
|
||||
arg: "-c".to_string(),
|
||||
redirect_syntax: "2>&1".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -31,13 +28,12 @@ pub fn get_shell_config() -> ShellConfig {
|
||||
}
|
||||
|
||||
pub fn format_command_for_platform(command: &str) -> String {
|
||||
let config = get_shell_config();
|
||||
if cfg!(windows) {
|
||||
// For PowerShell, wrap the command in braces to handle special characters
|
||||
format!("{{ {} }} {}", command, config.redirect_syntax)
|
||||
format!("{{ {} }}", command)
|
||||
} else {
|
||||
// For other shells, no braces needed
|
||||
format!("{} {}", command, config.redirect_syntax)
|
||||
command.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ use base64::Engine;
|
||||
use chrono::NaiveDate;
|
||||
use indoc::indoc;
|
||||
use lazy_static::lazy_static;
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use mcp_core::tool::ToolAnnotations;
|
||||
use oauth_pkce::PkceOAuth2Client;
|
||||
use regex::Regex;
|
||||
@@ -14,6 +15,7 @@ use serde_json::{json, Value};
|
||||
use std::io::Cursor;
|
||||
use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc};
|
||||
use storage::CredentialsManager;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use mcp_core::content::Content;
|
||||
use mcp_core::{
|
||||
@@ -3281,6 +3283,7 @@ impl Router for GoogleDriveRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
@@ -5,7 +5,7 @@ use mcp_core::{
|
||||
content::Content,
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||
resource::Resource,
|
||||
role::Role,
|
||||
tool::Tool,
|
||||
@@ -16,7 +16,7 @@ use serde_json::Value;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tracing::error;
|
||||
|
||||
@@ -158,6 +158,7 @@ impl Router for JetBrainsRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
@@ -10,11 +10,12 @@ use std::{
|
||||
path::PathBuf,
|
||||
pin::Pin,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||
resource::Resource,
|
||||
tool::{Tool, ToolAnnotations, ToolCall},
|
||||
Content,
|
||||
@@ -520,6 +521,7 @@ impl Router for MemoryRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
@@ -3,11 +3,12 @@ use include_dir::{include_dir, Dir};
|
||||
use indoc::formatdoc;
|
||||
use serde_json::{json, Value};
|
||||
use std::{future::Future, pin::Pin};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||
resource::Resource,
|
||||
role::Role,
|
||||
tool::{Tool, ToolAnnotations},
|
||||
@@ -130,6 +131,7 @@ impl Router for TutorialRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
@@ -10,7 +10,7 @@ use axum::{
|
||||
use bytes::Bytes;
|
||||
use futures::{stream::StreamExt, Stream};
|
||||
use goose::{
|
||||
agents::SessionConfig,
|
||||
agents::{AgentEvent, SessionConfig},
|
||||
message::{Message, MessageContent},
|
||||
permission::permission_confirmation::PrincipalType,
|
||||
};
|
||||
@@ -18,7 +18,7 @@ use goose::{
|
||||
permission::{Permission, PermissionConfirmation},
|
||||
session,
|
||||
};
|
||||
use mcp_core::{role::Role, Content, ToolResult};
|
||||
use mcp_core::{protocol::JsonRpcMessage, role::Role, Content, ToolResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use serde_json::Value;
|
||||
@@ -79,9 +79,19 @@ impl IntoResponse for SseResponse {
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum MessageEvent {
|
||||
Message { message: Message },
|
||||
Error { error: String },
|
||||
Finish { reason: String },
|
||||
Message {
|
||||
message: Message,
|
||||
},
|
||||
Error {
|
||||
error: String,
|
||||
},
|
||||
Finish {
|
||||
reason: String,
|
||||
},
|
||||
Notification {
|
||||
request_id: String,
|
||||
message: JsonRpcMessage,
|
||||
},
|
||||
}
|
||||
|
||||
async fn stream_event(
|
||||
@@ -200,7 +210,7 @@ async fn handler(
|
||||
tokio::select! {
|
||||
response = timeout(Duration::from_millis(500), stream.next()) => {
|
||||
match response {
|
||||
Ok(Some(Ok(message))) => {
|
||||
Ok(Some(Ok(AgentEvent::Message(message)))) => {
|
||||
all_messages.push(message.clone());
|
||||
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
|
||||
tracing::error!("Error sending message through channel: {}", e);
|
||||
@@ -223,6 +233,20 @@ async fn handler(
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
|
||||
if let Err(e) = stream_event(MessageEvent::Notification{
|
||||
request_id: request_id.clone(),
|
||||
message: n,
|
||||
}, &tx).await {
|
||||
tracing::error!("Error sending message through channel: {}", e);
|
||||
let _ = stream_event(
|
||||
MessageEvent::Error {
|
||||
error: e.to_string(),
|
||||
},
|
||||
&tx,
|
||||
).await;
|
||||
}
|
||||
}
|
||||
Ok(Some(Err(e))) => {
|
||||
tracing::error!("Error processing message: {}", e);
|
||||
let _ = stream_event(
|
||||
@@ -317,7 +341,7 @@ async fn ask_handler(
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
match response {
|
||||
Ok(message) => {
|
||||
Ok(AgentEvent::Message(message)) => {
|
||||
if message.role == Role::Assistant {
|
||||
for content in &message.content {
|
||||
if let MessageContent::Text(text) = content {
|
||||
@@ -328,6 +352,10 @@ async fn ask_handler(
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(AgentEvent::McpNotification(n)) => {
|
||||
// Handle notifications if needed
|
||||
tracing::info!("Received notification: {:?}", n);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error processing as_ai message: {}", e);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
@@ -71,10 +71,10 @@ aws-sdk-bedrockruntime = "1.74.0"
|
||||
# For GCP Vertex AI provider auth
|
||||
jsonwebtoken = "9.3.1"
|
||||
|
||||
# Added blake3 hashing library as a dependency
|
||||
blake3 = "1.5"
|
||||
fs2 = "0.4.3"
|
||||
futures-util = "0.3.31"
|
||||
tokio-stream = "0.1.17"
|
||||
|
||||
# Vector database for tool selection
|
||||
lancedb = "0.13"
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use dotenv::dotenv;
|
||||
use futures::StreamExt;
|
||||
use goose::agents::{Agent, ExtensionConfig};
|
||||
use goose::agents::{Agent, AgentEvent, ExtensionConfig};
|
||||
use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT};
|
||||
use goose::message::Message;
|
||||
use goose::providers::databricks::DatabricksProvider;
|
||||
@@ -20,10 +20,11 @@ async fn main() {
|
||||
|
||||
let config = ExtensionConfig::stdio(
|
||||
"developer",
|
||||
"./target/debug/developer",
|
||||
"./target/debug/goose",
|
||||
DEFAULT_EXTENSION_DESCRIPTION,
|
||||
DEFAULT_EXTENSION_TIMEOUT,
|
||||
);
|
||||
)
|
||||
.with_args(vec!["mcp", "developer"]);
|
||||
agent.add_extension(config).await.unwrap();
|
||||
|
||||
println!("Extensions:");
|
||||
@@ -35,11 +36,8 @@ async fn main() {
|
||||
.with_text("can you summarize the readme.md in this dir using just a haiku?")];
|
||||
|
||||
let mut stream = agent.reply(&messages, None).await.unwrap();
|
||||
while let Some(message) = stream.next().await {
|
||||
println!(
|
||||
"{}",
|
||||
serde_json::to_string_pretty(&message.unwrap()).unwrap()
|
||||
);
|
||||
while let Some(Ok(AgentEvent::Message(message))) = stream.next().await {
|
||||
println!("{}", serde_json::to_string_pretty(&message).unwrap());
|
||||
println!("\n");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::stream::BoxStream;
|
||||
use futures::TryStreamExt;
|
||||
use futures::{FutureExt, Stream, TryStreamExt};
|
||||
use futures_util::stream;
|
||||
use futures_util::stream::StreamExt;
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
|
||||
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
|
||||
use crate::message::Message;
|
||||
@@ -39,7 +44,7 @@ use mcp_core::{
|
||||
|
||||
use super::platform_tools;
|
||||
use super::router_tools;
|
||||
use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
|
||||
use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
|
||||
|
||||
/// The main goose Agent
|
||||
pub struct Agent {
|
||||
@@ -56,6 +61,12 @@ pub struct Agent {
|
||||
pub(super) router_tool_selector: Mutex<Option<Arc<Box<dyn RouterToolSelector>>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum AgentEvent {
|
||||
Message(Message),
|
||||
McpNotification((String, JsonRpcMessage)),
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
pub fn new() -> Self {
|
||||
// Create channels with buffer size 32 (adjust if needed)
|
||||
@@ -100,6 +111,40 @@ impl Default for Agent {
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ToolStreamItem<T> {
|
||||
Message(JsonRpcMessage),
|
||||
Result(T),
|
||||
}
|
||||
|
||||
pub type ToolStream = Pin<Box<dyn Stream<Item = ToolStreamItem<ToolResult<Vec<Content>>>> + Send>>;
|
||||
|
||||
// tool_stream combines a stream of JsonRpcMessages with a future representing the
|
||||
// final result of the tool call. MCP notifications are not request-scoped, but
|
||||
// this lets us capture all notifications emitted during the tool call for
|
||||
// simpler consumption
|
||||
pub fn tool_stream<S, F>(rx: S, done: F) -> ToolStream
|
||||
where
|
||||
S: Stream<Item = JsonRpcMessage> + Send + Unpin + 'static,
|
||||
F: Future<Output = ToolResult<Vec<Content>>> + Send + 'static,
|
||||
{
|
||||
Box::pin(async_stream::stream! {
|
||||
tokio::pin!(done);
|
||||
let mut rx = rx;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(msg) = rx.next() => {
|
||||
yield ToolStreamItem::Message(msg);
|
||||
}
|
||||
r = &mut done => {
|
||||
yield ToolStreamItem::Result(r);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
/// Get a reference count clone to the provider
|
||||
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
|
||||
@@ -143,7 +188,7 @@ impl Agent {
|
||||
&self,
|
||||
tool_call: mcp_core::tool::ToolCall,
|
||||
request_id: String,
|
||||
) -> (String, Result<Vec<Content>, ToolError>) {
|
||||
) -> (String, Result<ToolCallResult, ToolError>) {
|
||||
// Check if this tool call should be allowed based on repetition monitoring
|
||||
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
|
||||
let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone());
|
||||
@@ -171,52 +216,65 @@ impl Agent {
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
return self
|
||||
let (request_id, result) = self
|
||||
.manage_extensions(action, extension_name, request_id)
|
||||
.await;
|
||||
|
||||
return (request_id, Ok(ToolCallResult::from(result)));
|
||||
}
|
||||
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
let result = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
|
||||
let result: ToolCallResult = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
|
||||
// Check if the tool is read_resource and handle it separately
|
||||
extension_manager
|
||||
.read_resource(tool_call.arguments.clone())
|
||||
.await
|
||||
ToolCallResult::from(
|
||||
extension_manager
|
||||
.read_resource(tool_call.arguments.clone())
|
||||
.await,
|
||||
)
|
||||
} else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME {
|
||||
extension_manager
|
||||
.list_resources(tool_call.arguments.clone())
|
||||
.await
|
||||
ToolCallResult::from(
|
||||
extension_manager
|
||||
.list_resources(tool_call.arguments.clone())
|
||||
.await,
|
||||
)
|
||||
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
|
||||
extension_manager.search_available_extensions().await
|
||||
ToolCallResult::from(extension_manager.search_available_extensions().await)
|
||||
} else if self.is_frontend_tool(&tool_call.name).await {
|
||||
// For frontend tools, return an error indicating we need frontend execution
|
||||
Err(ToolError::ExecutionError(
|
||||
ToolCallResult::from(Err(ToolError::ExecutionError(
|
||||
"Frontend tool execution required".to_string(),
|
||||
))
|
||||
)))
|
||||
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
if let Some(selector) = selector {
|
||||
ToolCallResult::from(if let Some(selector) = selector {
|
||||
selector.select_tools(tool_call.arguments.clone()).await
|
||||
} else {
|
||||
Err(ToolError::ExecutionError(
|
||||
"Encountered vector search error.".to_string(),
|
||||
))
|
||||
}
|
||||
})
|
||||
} else {
|
||||
extension_manager
|
||||
// Clone the result to ensure no references to extension_manager are returned
|
||||
let result = extension_manager
|
||||
.dispatch_tool_call(tool_call.clone())
|
||||
.await
|
||||
.await;
|
||||
match result {
|
||||
Ok(call_result) => call_result,
|
||||
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"input" = serde_json::to_string(&tool_call).unwrap(),
|
||||
"output" = serde_json::to_string(&result).unwrap(),
|
||||
);
|
||||
|
||||
// Process the response to handle large text content
|
||||
let processed_result = super::large_response_handler::process_tool_response(result);
|
||||
|
||||
(request_id, processed_result)
|
||||
(
|
||||
request_id,
|
||||
Ok(ToolCallResult {
|
||||
notification_stream: result.notification_stream,
|
||||
result: Box::new(
|
||||
result
|
||||
.result
|
||||
.map(super::large_response_handler::process_tool_response),
|
||||
),
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub(super) async fn manage_extensions(
|
||||
@@ -466,7 +524,7 @@ impl Agent {
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session: Option<SessionConfig>,
|
||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
|
||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<AgentEvent>>> {
|
||||
let mut messages = messages.to_vec();
|
||||
let reply_span = tracing::Span::current();
|
||||
|
||||
@@ -532,9 +590,8 @@ impl Agent {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Yield the assistant's response with frontend tool requests filtered out
|
||||
yield filtered_response.clone();
|
||||
yield AgentEvent::Message(filtered_response.clone());
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
@@ -556,7 +613,7 @@ impl Agent {
|
||||
// execution is yeield back to this reply loop, and is of the same Message
|
||||
// type, so we can yield that back up to be handled
|
||||
while let Some(msg) = frontend_tool_stream.try_next().await? {
|
||||
yield msg;
|
||||
yield AgentEvent::Message(msg);
|
||||
}
|
||||
|
||||
// Clone goose_mode once before the match to avoid move issues
|
||||
@@ -584,13 +641,23 @@ impl Agent {
|
||||
self.provider().await?).await;
|
||||
|
||||
// Handle pre-approved and read-only tools in parallel
|
||||
let mut tool_futures: Vec<ToolFuture> = Vec::new();
|
||||
let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();
|
||||
|
||||
// Skip the confirmation for approved tools
|
||||
for request in &permission_check_result.approved {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let tool_future = self.dispatch_tool_call(tool_call, request.id.clone());
|
||||
tool_futures.push(Box::pin(tool_future));
|
||||
let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;
|
||||
|
||||
tool_futures.push((req_id, match tool_result {
|
||||
Ok(result) => tool_stream(
|
||||
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
|
||||
result.result,
|
||||
),
|
||||
Err(e) => tool_stream(
|
||||
Box::new(stream::empty()),
|
||||
futures::future::ready(Err(e)),
|
||||
),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -618,7 +685,7 @@ impl Agent {
|
||||
// type, so we can yield the Message back up to be handled and grab any
|
||||
// confirmations or denials
|
||||
while let Some(msg) = tool_approval_stream.try_next().await? {
|
||||
yield msg;
|
||||
yield AgentEvent::Message(msg);
|
||||
}
|
||||
|
||||
tool_futures = {
|
||||
@@ -628,16 +695,30 @@ impl Agent {
|
||||
futures_lock.drain(..).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
// Wait for all tool calls to complete
|
||||
let results = futures::future::join_all(tool_futures).await;
|
||||
let with_id = tool_futures
|
||||
.into_iter()
|
||||
.map(|(request_id, stream)| {
|
||||
stream.map(move |item| (request_id.clone(), item))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut combined = stream::select_all(with_id);
|
||||
|
||||
let mut all_install_successful = true;
|
||||
|
||||
for (request_id, output) in results.into_iter() {
|
||||
if enable_extension_request_ids.contains(&request_id) && output.is_err(){
|
||||
all_install_successful = false;
|
||||
while let Some((request_id, item)) = combined.next().await {
|
||||
match item {
|
||||
ToolStreamItem::Result(output) => {
|
||||
if enable_extension_request_ids.contains(&request_id) && output.is_err(){
|
||||
all_install_successful = false;
|
||||
}
|
||||
let mut response = message_tool_response.lock().await;
|
||||
*response = response.clone().with_tool_response(request_id, output);
|
||||
},
|
||||
ToolStreamItem::Message(msg) => {
|
||||
yield AgentEvent::McpNotification((request_id, msg))
|
||||
}
|
||||
}
|
||||
let mut response = message_tool_response.lock().await;
|
||||
*response = response.clone().with_tool_response(request_id, output);
|
||||
}
|
||||
|
||||
// Update system prompt and tools if installations were successful
|
||||
@@ -647,7 +728,7 @@ impl Agent {
|
||||
}
|
||||
|
||||
let final_message_tool_resp = message_tool_response.lock().await.clone();
|
||||
yield final_message_tool_resp.clone();
|
||||
yield AgentEvent::Message(final_message_tool_resp.clone());
|
||||
|
||||
messages.push(response);
|
||||
messages.push(final_message_tool_resp);
|
||||
@@ -656,15 +737,15 @@ impl Agent {
|
||||
// At this point, the last message should be a user message
|
||||
// because call to provider led to context length exceeded error
|
||||
// Immediately yield a special message and break
|
||||
yield Message::assistant().with_context_length_exceeded(
|
||||
yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
|
||||
"The context length of the model has been exceeded. Please start a new session and try again.",
|
||||
);
|
||||
));
|
||||
break;
|
||||
},
|
||||
Err(e) => {
|
||||
// Create an error message & terminate the stream
|
||||
error!("Error: {}", e);
|
||||
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
|
||||
yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, TimeZone, Utc};
|
||||
use futures::future;
|
||||
use futures::stream::{FuturesUnordered, StreamExt};
|
||||
use mcp_client::McpService;
|
||||
use futures::{future, FutureExt};
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
@@ -10,15 +9,17 @@ use std::sync::LazyLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task;
|
||||
use tracing::{debug, error, warn};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tracing::{error, warn};
|
||||
|
||||
use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
|
||||
use super::tool_execution::ToolCallResult;
|
||||
use crate::agents::extension::Envs;
|
||||
use crate::config::{Config, ExtensionConfigManager};
|
||||
use crate::prompt_template;
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
|
||||
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
|
||||
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError};
|
||||
use serde_json::Value;
|
||||
|
||||
// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
|
||||
@@ -113,7 +114,8 @@ impl ExtensionManager {
|
||||
/// Add a new MCP extension based on the provided client type
|
||||
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
|
||||
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
|
||||
let sanitized_name = normalize(config.key().to_string());
|
||||
let config_name = config.key().to_string();
|
||||
let sanitized_name = normalize(config_name.clone());
|
||||
|
||||
/// Helper function to merge environment variables from direct envs and keychain-stored env_keys
|
||||
async fn merge_environments(
|
||||
@@ -183,13 +185,15 @@ impl ExtensionManager {
|
||||
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
||||
let transport = SseTransport::new(uri, all_envs);
|
||||
let handle = transport.start().await?;
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
Box::new(
|
||||
McpClient::connect(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
ExtensionConfig::Stdio {
|
||||
cmd,
|
||||
@@ -202,13 +206,15 @@ impl ExtensionManager {
|
||||
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
||||
let transport = StdioTransport::new(cmd, args.to_vec(), all_envs);
|
||||
let handle = transport.start().await?;
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
Box::new(
|
||||
McpClient::connect(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
ExtensionConfig::Builtin {
|
||||
name,
|
||||
@@ -227,13 +233,15 @@ impl ExtensionManager {
|
||||
HashMap::new(),
|
||||
);
|
||||
let handle = transport.start().await?;
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
Box::new(
|
||||
McpClient::connect(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
@@ -609,7 +617,7 @@ impl ExtensionManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> ToolResult<Vec<Content>> {
|
||||
pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> Result<ToolCallResult> {
|
||||
// Dispatch tool call based on the prefix naming convention
|
||||
let (client_name, client) = self
|
||||
.get_client_for_tool(&tool_call.name)
|
||||
@@ -620,22 +628,26 @@ impl ExtensionManager {
|
||||
.name
|
||||
.strip_prefix(client_name)
|
||||
.and_then(|s| s.strip_prefix("__"))
|
||||
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
|
||||
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?
|
||||
.to_string();
|
||||
|
||||
let client_guard = client.lock().await;
|
||||
let arguments = tool_call.arguments.clone();
|
||||
let client = client.clone();
|
||||
let notifications_receiver = client.lock().await.subscribe().await;
|
||||
|
||||
let result = client_guard
|
||||
.call_tool(tool_name, tool_call.clone().arguments)
|
||||
.await
|
||||
.map(|result| result.content)
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()));
|
||||
let fut = async move {
|
||||
let client_guard = client.lock().await;
|
||||
client_guard
|
||||
.call_tool(&tool_name, arguments)
|
||||
.await
|
||||
.map(|call| call.content)
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))
|
||||
};
|
||||
|
||||
debug!(
|
||||
"input" = serde_json::to_string(&tool_call).unwrap(),
|
||||
"output" = serde_json::to_string(&result).unwrap(),
|
||||
);
|
||||
|
||||
result
|
||||
Ok(ToolCallResult {
|
||||
result: Box::new(fut.boxed()),
|
||||
notification_stream: Some(Box::new(ReceiverStream::new(notifications_receiver))),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn list_prompts_from_extension(
|
||||
@@ -793,10 +805,11 @@ mod tests {
|
||||
use mcp_client::client::Error;
|
||||
use mcp_client::client::McpClientTrait;
|
||||
use mcp_core::protocol::{
|
||||
CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult,
|
||||
ListToolsResult, ReadResourceResult,
|
||||
CallToolResult, GetPromptResult, InitializeResult, JsonRpcMessage, ListPromptsResult,
|
||||
ListResourcesResult, ListToolsResult, ReadResourceResult,
|
||||
};
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
struct MockClient {}
|
||||
|
||||
@@ -849,6 +862,10 @@ mod tests {
|
||||
) -> Result<GetPromptResult, Error> {
|
||||
Err(Error::NotInitialized)
|
||||
}
|
||||
|
||||
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
|
||||
mpsc::channel(1).1
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -970,6 +987,9 @@ mod tests {
|
||||
|
||||
let result = extension_manager
|
||||
.dispatch_tool_call(invalid_tool_call)
|
||||
.await
|
||||
.unwrap()
|
||||
.result
|
||||
.await;
|
||||
assert!(matches!(
|
||||
result.err().unwrap(),
|
||||
@@ -986,6 +1006,11 @@ mod tests {
|
||||
let result = extension_manager
|
||||
.dispatch_tool_call(invalid_tool_call)
|
||||
.await;
|
||||
assert!(matches!(result.err().unwrap(), ToolError::NotFound(_)));
|
||||
if let Err(err) = result {
|
||||
let tool_err = err.downcast_ref::<ToolError>().expect("Expected ToolError");
|
||||
assert!(matches!(tool_err, ToolError::NotFound(_)));
|
||||
} else {
|
||||
panic!("Expected ToolError::NotFound");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ mod tool_router_index_manager;
|
||||
pub(crate) mod tool_vectordb;
|
||||
mod types;
|
||||
|
||||
pub use agent::Agent;
|
||||
pub use agent::{Agent, AgentEvent};
|
||||
pub use extension::ExtensionConfig;
|
||||
pub use extension_manager::ExtensionManager;
|
||||
pub use prompt_manager::PromptManager;
|
||||
|
||||
@@ -1,23 +1,35 @@
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_stream::try_stream;
|
||||
use futures::stream::BoxStream;
|
||||
use futures::StreamExt;
|
||||
use futures::stream::{self, BoxStream};
|
||||
use futures::{Stream, StreamExt};
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::config::permission::PermissionLevel;
|
||||
use crate::config::PermissionManager;
|
||||
use crate::message::{Message, ToolRequest};
|
||||
use crate::permission::Permission;
|
||||
use mcp_core::{Content, ToolError};
|
||||
use mcp_core::{Content, ToolResult};
|
||||
|
||||
// Type alias for ToolFutures - used in the agent loop to join all futures together
|
||||
pub(crate) type ToolFuture<'a> =
|
||||
Pin<Box<dyn Future<Output = (String, Result<Vec<Content>, ToolError>)> + Send + 'a>>;
|
||||
pub(crate) type ToolFuturesVec<'a> = Arc<Mutex<Vec<ToolFuture<'a>>>>;
|
||||
// ToolCallResult combines the result of a tool call with an optional notification stream that
|
||||
// can be used to receive notifications from the tool.
|
||||
pub struct ToolCallResult {
|
||||
pub result: Box<dyn Future<Output = ToolResult<Vec<Content>>> + Send + Unpin>,
|
||||
pub notification_stream: Option<Box<dyn Stream<Item = JsonRpcMessage> + Send + Unpin>>,
|
||||
}
|
||||
|
||||
impl From<ToolResult<Vec<Content>>> for ToolCallResult {
|
||||
fn from(result: ToolResult<Vec<Content>>) -> Self {
|
||||
Self {
|
||||
result: Box::new(futures::future::ready(result)),
|
||||
notification_stream: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use super::agent::{tool_stream, ToolStream};
|
||||
use crate::agents::Agent;
|
||||
|
||||
pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \
|
||||
@@ -37,7 +49,7 @@ impl Agent {
|
||||
pub(crate) fn handle_approval_tool_requests<'a>(
|
||||
&'a self,
|
||||
tool_requests: &'a [ToolRequest],
|
||||
tool_futures: ToolFuturesVec<'a>,
|
||||
tool_futures: Arc<Mutex<Vec<(String, ToolStream)>>>,
|
||||
permission_manager: &'a mut PermissionManager,
|
||||
message_tool_response: Arc<Mutex<Message>>,
|
||||
) -> BoxStream<'a, anyhow::Result<Message>> {
|
||||
@@ -56,9 +68,19 @@ impl Agent {
|
||||
while let Some((req_id, confirmation)) = rx.recv().await {
|
||||
if req_id == request.id {
|
||||
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
|
||||
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone());
|
||||
let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone()).await;
|
||||
let mut futures = tool_futures.lock().await;
|
||||
futures.push(Box::pin(tool_future));
|
||||
|
||||
futures.push((req_id, match tool_result {
|
||||
Ok(result) => tool_stream(
|
||||
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
|
||||
result.result,
|
||||
),
|
||||
Err(e) => tool_stream(
|
||||
Box::new(stream::empty()),
|
||||
futures::future::ready(Err(e)),
|
||||
),
|
||||
}));
|
||||
|
||||
if confirmation.permission == Permission::AlwaysAllow {
|
||||
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);
|
||||
|
||||
@@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_cron_scheduler::{job::JobId, Job, JobScheduler as TokioJobScheduler};
|
||||
|
||||
use crate::agents::AgentEvent;
|
||||
use crate::agents::{Agent, SessionConfig};
|
||||
use crate::config::{self, Config};
|
||||
use crate::message::Message;
|
||||
@@ -1102,12 +1103,15 @@ async fn run_scheduled_job_internal(
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
match message_result {
|
||||
Ok(msg) => {
|
||||
Ok(AgentEvent::Message(msg)) => {
|
||||
if msg.role == mcp_core::role::Role::Assistant {
|
||||
tracing::info!("[Job {}] Assistant: {:?}", job.id, msg.content);
|
||||
}
|
||||
all_session_messages.push(msg);
|
||||
}
|
||||
Ok(AgentEvent::McpNotification(_)) => {
|
||||
// Handle notifications if needed
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"[Job {}] Error receiving message from agent: {}",
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use futures::StreamExt;
|
||||
use goose::agents::Agent;
|
||||
use goose::agents::{Agent, AgentEvent};
|
||||
use goose::message::Message;
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::base::Provider;
|
||||
@@ -132,7 +132,10 @@ async fn run_truncate_test(
|
||||
let mut responses = Vec::new();
|
||||
while let Some(response_result) = reply_stream.next().await {
|
||||
match response_result {
|
||||
Ok(response) => responses.push(response),
|
||||
Ok(AgentEvent::Message(response)) => responses.push(response),
|
||||
Ok(AgentEvent::McpNotification(n)) => {
|
||||
println!("MCP Notification: {n:?}");
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Error: {:?}", e);
|
||||
return Err(e);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use mcp_client::{
|
||||
client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait},
|
||||
transport::{SseTransport, StdioTransport, Transport},
|
||||
McpService,
|
||||
};
|
||||
use rand::Rng;
|
||||
use rand::SeedableRng;
|
||||
@@ -20,18 +19,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
|
||||
let handle1 = transport1.start().await?;
|
||||
let service1 = McpService::with_timeout(handle1, Duration::from_secs(30));
|
||||
let client1 = McpClient::new(service1);
|
||||
let client1 = McpClient::connect(handle1, Duration::from_secs(30)).await?;
|
||||
|
||||
let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
|
||||
let handle2 = transport2.start().await?;
|
||||
let service2 = McpService::with_timeout(handle2, Duration::from_secs(30));
|
||||
let client2 = McpClient::new(service2);
|
||||
let client2 = McpClient::connect(handle2, Duration::from_secs(30)).await?;
|
||||
|
||||
let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new());
|
||||
let handle3 = transport3.start().await?;
|
||||
let service3 = McpService::with_timeout(handle3, Duration::from_secs(10));
|
||||
let client3 = McpClient::new(service3);
|
||||
let client3 = McpClient::connect(handle3, Duration::from_secs(10)).await?;
|
||||
|
||||
// Initialize both clients
|
||||
let mut clients: Vec<Box<dyn McpClientTrait>> =
|
||||
|
||||
122
crates/mcp-client/examples/integration_test.rs
Normal file
122
crates/mcp-client/examples/integration_test.rs
Normal file
@@ -0,0 +1,122 @@
|
||||
use anyhow::Result;
|
||||
use futures::lock::Mutex;
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
use mcp_client::transport::{SseTransport, Transport};
|
||||
use mcp_client::StdioTransport;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize logging
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
EnvFilter::from_default_env()
|
||||
.add_directive("mcp_client=debug".parse().unwrap())
|
||||
.add_directive("eventsource_client=info".parse().unwrap()),
|
||||
)
|
||||
.init();
|
||||
|
||||
test_transport(sse_transport().await?).await?;
|
||||
test_transport(stdio_transport().await?).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn sse_transport() -> Result<SseTransport> {
|
||||
let port = "60053";
|
||||
|
||||
tokio::process::Command::new("npx")
|
||||
.env("PORT", port)
|
||||
.arg("@modelcontextprotocol/server-everything")
|
||||
.arg("sse")
|
||||
.spawn()?;
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
|
||||
Ok(SseTransport::new(
|
||||
format!("http://localhost:{}/sse", port),
|
||||
HashMap::new(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn stdio_transport() -> Result<StdioTransport> {
|
||||
Ok(StdioTransport::new(
|
||||
"npx",
|
||||
vec!["@modelcontextprotocol/server-everything"]
|
||||
.into_iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
HashMap::new(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn test_transport<T>(transport: T) -> Result<()>
|
||||
where
|
||||
T: Transport + Send + 'static,
|
||||
{
|
||||
// Start transport
|
||||
let handle = transport.start().await?;
|
||||
|
||||
// Create client
|
||||
let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?;
|
||||
println!("Client created\n");
|
||||
|
||||
let mut receiver = client.subscribe().await;
|
||||
let events = Arc::new(Mutex::new(Vec::new()));
|
||||
let events_clone = events.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(event) = receiver.recv().await {
|
||||
println!("Received event: {event:?}");
|
||||
events_clone.lock().await.push(event);
|
||||
}
|
||||
});
|
||||
|
||||
// Initialize
|
||||
let server_info = client
|
||||
.initialize(
|
||||
ClientInfo {
|
||||
name: "test-client".into(),
|
||||
version: "1.0.0".into(),
|
||||
},
|
||||
ClientCapabilities::default(),
|
||||
)
|
||||
.await?;
|
||||
println!("Connected to server: {server_info:?}\n");
|
||||
|
||||
// Sleep for 100ms to allow the server to start - surprisingly this is required!
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// List tools
|
||||
let tools = client.list_tools(None).await?;
|
||||
println!("Available tools: {tools:#?}\n");
|
||||
|
||||
// Call tool
|
||||
let tool_result = client
|
||||
.call_tool("echo", serde_json::json!({ "message": "honk" }))
|
||||
.await?;
|
||||
println!("Tool result: {tool_result:#?}\n");
|
||||
|
||||
let collected_eventes_before = events.lock().await.len();
|
||||
let n_steps = 5;
|
||||
let long_op = client
|
||||
.call_tool(
|
||||
"longRunningOperation",
|
||||
serde_json::json!({ "duration": 3, "steps": n_steps }),
|
||||
)
|
||||
.await?;
|
||||
println!("Long op result: {long_op:#?}\n");
|
||||
let collected_events_after = events.lock().await.len();
|
||||
assert_eq!(collected_events_after - collected_eventes_before, n_steps);
|
||||
|
||||
// List resources
|
||||
let resources = client.list_resources(None).await?;
|
||||
println!("Resources: {resources:#?}\n");
|
||||
|
||||
// Read resource
|
||||
let resource = client.read_resource("test://static/resource/1").await?;
|
||||
println!("Resource: {resource:#?}\n");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
use anyhow::Result;
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
use mcp_client::transport::{SseTransport, Transport};
|
||||
use mcp_client::McpService;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
@@ -23,11 +22,8 @@ async fn main() -> Result<()> {
|
||||
// Start transport
|
||||
let handle = transport.start().await?;
|
||||
|
||||
// Create the service with timeout middleware
|
||||
let service = McpService::with_timeout(handle, Duration::from_secs(3));
|
||||
|
||||
// Create client
|
||||
let mut client = McpClient::new(service);
|
||||
let mut client = McpClient::connect(handle, Duration::from_secs(3)).await?;
|
||||
println!("Client created\n");
|
||||
|
||||
// Initialize
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::HashMap;
|
||||
|
||||
use anyhow::Result;
|
||||
use mcp_client::{
|
||||
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, McpService,
|
||||
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
|
||||
StdioTransport, Transport,
|
||||
};
|
||||
use std::time::Duration;
|
||||
@@ -25,11 +25,8 @@ async fn main() -> Result<(), ClientError> {
|
||||
// 2) Start the transport to get a handle
|
||||
let transport_handle = transport.start().await?;
|
||||
|
||||
// 3) Create the service with timeout middleware
|
||||
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));
|
||||
|
||||
// 4) Create the client with the middleware-wrapped service
|
||||
let mut client = McpClient::new(service);
|
||||
// 3) Create the client with the middleware-wrapped service
|
||||
let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?;
|
||||
|
||||
// Initialize
|
||||
let server_info = client
|
||||
|
||||
@@ -5,7 +5,6 @@ use mcp_client::client::{
|
||||
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
|
||||
};
|
||||
use mcp_client::transport::{StdioTransport, Transport};
|
||||
use mcp_client::McpService;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
@@ -34,11 +33,8 @@ async fn main() -> Result<(), ClientError> {
|
||||
// Start the transport to get a handle
|
||||
let transport_handle = transport.start().await.unwrap();
|
||||
|
||||
// Create the service with timeout middleware
|
||||
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));
|
||||
|
||||
// Create client
|
||||
let mut client = McpClient::new(service);
|
||||
let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?;
|
||||
|
||||
// Initialize
|
||||
let server_info = client
|
||||
|
||||
@@ -4,11 +4,16 @@ use mcp_core::protocol::{
|
||||
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::Mutex;
|
||||
use tower::{Service, ServiceExt}; // for Service::ready()
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tower::{timeout::TimeoutLayer, Layer, Service, ServiceExt};
|
||||
|
||||
use crate::{McpService, TransportHandle};
|
||||
|
||||
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
|
||||
|
||||
@@ -97,34 +102,67 @@ pub trait McpClientTrait: Send + Sync {
|
||||
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
|
||||
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
|
||||
|
||||
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage>;
|
||||
}
|
||||
|
||||
/// The MCP client is the interface for MCP operations.
|
||||
pub struct McpClient<S>
|
||||
pub struct McpClient<T>
|
||||
where
|
||||
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
|
||||
S::Error: Into<Error>,
|
||||
S::Future: Send,
|
||||
T: TransportHandle + Send + Sync + 'static,
|
||||
{
|
||||
service: Mutex<S>,
|
||||
service: Mutex<tower::timeout::Timeout<McpService<T>>>,
|
||||
next_id: AtomicU64,
|
||||
server_capabilities: Option<ServerCapabilities>,
|
||||
server_info: Option<Implementation>,
|
||||
notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<JsonRpcMessage>>>>,
|
||||
}
|
||||
|
||||
impl<S> McpClient<S>
|
||||
impl<T> McpClient<T>
|
||||
where
|
||||
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
|
||||
S::Error: Into<Error>,
|
||||
S::Future: Send,
|
||||
T: TransportHandle + Send + Sync + 'static,
|
||||
{
|
||||
pub fn new(service: S) -> Self {
|
||||
Self {
|
||||
service: Mutex::new(service),
|
||||
pub async fn connect(transport: T, timeout: std::time::Duration) -> Result<Self, Error> {
|
||||
let service = McpService::new(transport.clone());
|
||||
let service_ptr = service.clone();
|
||||
let notification_subscribers =
|
||||
Arc::new(Mutex::new(Vec::<mpsc::Sender<JsonRpcMessage>>::new()));
|
||||
let subscribers_ptr = notification_subscribers.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
match transport.receive().await {
|
||||
Ok(message) => {
|
||||
tracing::info!("Received message: {:?}", message);
|
||||
match message {
|
||||
JsonRpcMessage::Response(JsonRpcResponse { id: Some(id), .. }) => {
|
||||
service_ptr.respond(&id.to_string(), Ok(message)).await;
|
||||
}
|
||||
_ => {
|
||||
let mut subs = subscribers_ptr.lock().await;
|
||||
subs.retain(|sub| sub.try_send(message.clone()).is_ok());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("transport error: {:?}", e);
|
||||
service_ptr.hangup().await;
|
||||
subscribers_ptr.lock().await.clear();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let middleware = TimeoutLayer::new(timeout);
|
||||
|
||||
Ok(Self {
|
||||
service: Mutex::new(middleware.layer(service)),
|
||||
next_id: AtomicU64::new(1),
|
||||
server_capabilities: None,
|
||||
server_info: None,
|
||||
}
|
||||
notification_subscribers,
|
||||
})
|
||||
}
|
||||
|
||||
/// Send a JSON-RPC request and check we don't get an error response.
|
||||
@@ -134,13 +172,18 @@ where
|
||||
{
|
||||
let mut service = self.service.lock().await;
|
||||
service.ready().await.map_err(|_| Error::NotReady)?;
|
||||
|
||||
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
let mut params = params.clone();
|
||||
params["_meta"] = json!({
|
||||
"progressToken": format!("prog-{}", id),
|
||||
});
|
||||
|
||||
let request = JsonRpcMessage::Request(JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(id),
|
||||
method: method.to_string(),
|
||||
params: Some(params.clone()),
|
||||
params: Some(params),
|
||||
});
|
||||
|
||||
let response_msg = service
|
||||
@@ -154,7 +197,7 @@ where
|
||||
.unwrap_or("".to_string()),
|
||||
method: method.to_string(),
|
||||
// we don't need include params because it can be really large
|
||||
source: Box::new(e.into()),
|
||||
source: Box::<Error>::new(e.into()),
|
||||
})?;
|
||||
|
||||
match response_msg {
|
||||
@@ -220,7 +263,7 @@ where
|
||||
.unwrap_or("".to_string()),
|
||||
method: method.to_string(),
|
||||
// we don't need include params because it can be really large
|
||||
source: Box::new(e.into()),
|
||||
source: Box::<Error>::new(e.into()),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
@@ -233,11 +276,9 @@ where
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<S> McpClientTrait for McpClient<S>
|
||||
impl<T> McpClientTrait for McpClient<T>
|
||||
where
|
||||
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
|
||||
S::Error: Into<Error>,
|
||||
S::Future: Send,
|
||||
T: TransportHandle + Send + Sync + 'static,
|
||||
{
|
||||
async fn initialize(
|
||||
&mut self,
|
||||
@@ -388,4 +429,10 @@ where
|
||||
|
||||
self.send_request("prompts/get", params).await
|
||||
}
|
||||
|
||||
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
|
||||
let (tx, rx) = mpsc::channel(16);
|
||||
self.notification_subscribers.lock().await.push(tx);
|
||||
rx
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use futures::future::BoxFuture;
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::sync::{oneshot, RwLock};
|
||||
use tower::{timeout::Timeout, Service, ServiceBuilder};
|
||||
|
||||
use crate::transport::{Error, TransportHandle};
|
||||
@@ -10,14 +12,24 @@ use crate::transport::{Error, TransportHandle};
|
||||
#[derive(Clone)]
|
||||
pub struct McpService<T: TransportHandle> {
|
||||
inner: Arc<T>,
|
||||
pending_requests: Arc<PendingRequests>,
|
||||
}
|
||||
|
||||
impl<T: TransportHandle> McpService<T> {
|
||||
pub fn new(transport: T) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(transport),
|
||||
pending_requests: Arc::new(PendingRequests::default()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
|
||||
self.pending_requests.respond(id, response).await
|
||||
}
|
||||
|
||||
pub async fn hangup(&self) {
|
||||
self.pending_requests.broadcast_close().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Service<JsonRpcMessage> for McpService<T>
|
||||
@@ -35,7 +47,31 @@ where
|
||||
|
||||
fn call(&mut self, request: JsonRpcMessage) -> Self::Future {
|
||||
let transport = self.inner.clone();
|
||||
Box::pin(async move { transport.send(request).await })
|
||||
let pending_requests = self.pending_requests.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
match request {
|
||||
JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) => {
|
||||
// Create a channel to receive the response
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
pending_requests.insert(id.to_string(), sender).await;
|
||||
|
||||
transport.send(request).await?;
|
||||
receiver.await.map_err(|_| Error::ChannelClosed)?
|
||||
}
|
||||
JsonRpcMessage::Request(_) => {
|
||||
// Handle notifications without waiting for a response
|
||||
transport.send(request).await?;
|
||||
Ok(JsonRpcMessage::Nil)
|
||||
}
|
||||
JsonRpcMessage::Notification(_) => {
|
||||
// Handle notifications without waiting for a response
|
||||
transport.send(request).await?;
|
||||
Ok(JsonRpcMessage::Nil)
|
||||
}
|
||||
_ => Err(Error::UnsupportedMessage),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,3 +86,50 @@ where
|
||||
.service(McpService::new(transport))
|
||||
}
|
||||
}
|
||||
|
||||
// A data structure to store pending requests and their response channels
|
||||
pub struct PendingRequests {
|
||||
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
|
||||
}
|
||||
|
||||
impl Default for PendingRequests {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PendingRequests {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
requests: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
|
||||
self.requests.write().await.insert(id, sender);
|
||||
}
|
||||
|
||||
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
|
||||
if let Some(tx) = self.requests.write().await.remove(id) {
|
||||
let _ = tx.send(response);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn broadcast_close(&self) {
|
||||
for (_, tx) in self.requests.write().await.drain() {
|
||||
let _ = tx.send(Err(Error::ChannelClosed));
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn clear(&self) {
|
||||
self.requests.write().await.clear();
|
||||
}
|
||||
|
||||
pub async fn len(&self) -> usize {
|
||||
self.requests.read().await.len()
|
||||
}
|
||||
|
||||
pub async fn is_empty(&self) -> bool {
|
||||
self.len().await == 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use async_trait::async_trait;
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{mpsc, oneshot, RwLock};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
|
||||
/// A generic error type for transport operations.
|
||||
@@ -57,74 +56,20 @@ pub trait Transport {
|
||||
|
||||
#[async_trait]
|
||||
pub trait TransportHandle: Send + Sync + Clone + 'static {
|
||||
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error>;
|
||||
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error>;
|
||||
async fn receive(&self) -> Result<JsonRpcMessage, Error>;
|
||||
}
|
||||
|
||||
// Helper function that contains the common send implementation
|
||||
pub async fn send_message(
|
||||
sender: &mpsc::Sender<TransportMessage>,
|
||||
pub async fn serialize_and_send(
|
||||
sender: &mpsc::Sender<String>,
|
||||
message: JsonRpcMessage,
|
||||
) -> Result<JsonRpcMessage, Error> {
|
||||
match message {
|
||||
JsonRpcMessage::Request(request) => {
|
||||
let (respond_to, response) = oneshot::channel();
|
||||
let msg = TransportMessage {
|
||||
message: JsonRpcMessage::Request(request),
|
||||
response_tx: Some(respond_to),
|
||||
};
|
||||
sender.send(msg).await.map_err(|_| Error::ChannelClosed)?;
|
||||
Ok(response.await.map_err(|_| Error::ChannelClosed)??)
|
||||
) -> Result<(), Error> {
|
||||
match serde_json::to_string(&message).map_err(Error::Serialization) {
|
||||
Ok(msg) => sender.send(msg).await.map_err(|_| Error::ChannelClosed),
|
||||
Err(e) => {
|
||||
tracing::error!(error = ?e, "Error serializing message");
|
||||
Err(e)
|
||||
}
|
||||
JsonRpcMessage::Notification(notification) => {
|
||||
let msg = TransportMessage {
|
||||
message: JsonRpcMessage::Notification(notification),
|
||||
response_tx: None,
|
||||
};
|
||||
sender.send(msg).await.map_err(|_| Error::ChannelClosed)?;
|
||||
Ok(JsonRpcMessage::Nil)
|
||||
}
|
||||
_ => Err(Error::UnsupportedMessage),
|
||||
}
|
||||
}
|
||||
|
||||
// A data structure to store pending requests and their response channels
|
||||
pub struct PendingRequests {
|
||||
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
|
||||
}
|
||||
|
||||
impl Default for PendingRequests {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PendingRequests {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
requests: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
|
||||
self.requests.write().await.insert(id, sender);
|
||||
}
|
||||
|
||||
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
|
||||
if let Some(tx) = self.requests.write().await.remove(id) {
|
||||
let _ = tx.send(response);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn clear(&self) {
|
||||
self.requests.write().await.clear();
|
||||
}
|
||||
|
||||
pub async fn len(&self) -> usize {
|
||||
self.requests.read().await.len()
|
||||
}
|
||||
|
||||
pub async fn is_empty(&self) -> bool {
|
||||
self.len().await == 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
use crate::transport::{Error, PendingRequests, TransportMessage};
|
||||
use crate::transport::Error;
|
||||
use async_trait::async_trait;
|
||||
use eventsource_client::{Client, SSE};
|
||||
use futures::TryStreamExt;
|
||||
use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest};
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use reqwest::Client as HttpClient;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tracing::warn;
|
||||
use url::Url;
|
||||
|
||||
use super::{send_message, Transport, TransportHandle};
|
||||
use super::{serialize_and_send, Transport, TransportHandle};
|
||||
|
||||
// Timeout for the endpoint discovery
|
||||
const ENDPOINT_TIMEOUT_SECS: u64 = 5;
|
||||
@@ -21,9 +21,9 @@ const ENDPOINT_TIMEOUT_SECS: u64 = 5;
|
||||
/// - Sends outgoing messages via HTTP POST (once the post endpoint is known).
|
||||
pub struct SseActor {
|
||||
/// Receives messages (requests/notifications) from the handle
|
||||
receiver: mpsc::Receiver<TransportMessage>,
|
||||
/// Map of request-id -> oneshot sender
|
||||
pending_requests: Arc<PendingRequests>,
|
||||
receiver: mpsc::Receiver<String>,
|
||||
/// Sends messages (responses) back to the handle
|
||||
sender: mpsc::Sender<JsonRpcMessage>,
|
||||
/// Base SSE URL
|
||||
sse_url: String,
|
||||
/// For sending HTTP POST requests
|
||||
@@ -34,14 +34,14 @@ pub struct SseActor {
|
||||
|
||||
impl SseActor {
|
||||
pub fn new(
|
||||
receiver: mpsc::Receiver<TransportMessage>,
|
||||
pending_requests: Arc<PendingRequests>,
|
||||
receiver: mpsc::Receiver<String>,
|
||||
sender: mpsc::Sender<JsonRpcMessage>,
|
||||
sse_url: String,
|
||||
post_endpoint: Arc<RwLock<Option<String>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
receiver,
|
||||
pending_requests,
|
||||
sender,
|
||||
sse_url,
|
||||
post_endpoint,
|
||||
http_client: HttpClient::new(),
|
||||
@@ -54,15 +54,14 @@ impl SseActor {
|
||||
pub async fn run(self) {
|
||||
tokio::join!(
|
||||
Self::handle_incoming_messages(
|
||||
self.sender,
|
||||
self.sse_url.clone(),
|
||||
Arc::clone(&self.pending_requests),
|
||||
Arc::clone(&self.post_endpoint)
|
||||
),
|
||||
Self::handle_outgoing_messages(
|
||||
self.receiver,
|
||||
self.http_client.clone(),
|
||||
Arc::clone(&self.post_endpoint),
|
||||
Arc::clone(&self.pending_requests),
|
||||
)
|
||||
);
|
||||
}
|
||||
@@ -72,14 +71,13 @@ impl SseActor {
|
||||
/// - If a `message` event is received, parse it as `JsonRpcMessage`
|
||||
/// and respond to pending requests if it's a `Response`.
|
||||
async fn handle_incoming_messages(
|
||||
sender: mpsc::Sender<JsonRpcMessage>,
|
||||
sse_url: String,
|
||||
pending_requests: Arc<PendingRequests>,
|
||||
post_endpoint: Arc<RwLock<Option<String>>>,
|
||||
) {
|
||||
let client = match eventsource_client::ClientBuilder::for_url(&sse_url) {
|
||||
Ok(builder) => builder.build(),
|
||||
Err(e) => {
|
||||
pending_requests.clear().await;
|
||||
warn!("Failed to connect SSE client: {}", e);
|
||||
return;
|
||||
}
|
||||
@@ -105,84 +103,54 @@ impl SseActor {
|
||||
}
|
||||
|
||||
// Now handle subsequent events
|
||||
while let Ok(Some(event)) = stream.try_next().await {
|
||||
match event {
|
||||
SSE::Event(e) if e.event_type == "message" => {
|
||||
// Attempt to parse the SSE data as a JsonRpcMessage
|
||||
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
|
||||
Ok(message) => {
|
||||
match &message {
|
||||
JsonRpcMessage::Response(response) => {
|
||||
if let Some(id) = &response.id {
|
||||
pending_requests
|
||||
.respond(&id.to_string(), Ok(message))
|
||||
.await;
|
||||
}
|
||||
loop {
|
||||
match stream.try_next().await {
|
||||
Ok(Some(event)) => {
|
||||
match event {
|
||||
SSE::Event(e) if e.event_type == "message" => {
|
||||
// Attempt to parse the SSE data as a JsonRpcMessage
|
||||
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
|
||||
Ok(message) => {
|
||||
let _ = sender.send(message).await;
|
||||
}
|
||||
JsonRpcMessage::Error(error) => {
|
||||
if let Some(id) = &error.id {
|
||||
pending_requests
|
||||
.respond(&id.to_string(), Ok(message))
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to parse SSE message: {err}");
|
||||
}
|
||||
_ => {} // TODO: Handle other variants (Request, etc.)
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to parse SSE message: {err}");
|
||||
}
|
||||
_ => { /* ignore other events */ }
|
||||
}
|
||||
}
|
||||
_ => { /* ignore other events */ }
|
||||
Ok(None) => {
|
||||
// Stream ended
|
||||
tracing::info!("SSE stream ended.");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error reading SSE stream: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SSE stream ended or errored; signal any pending requests
|
||||
tracing::error!("SSE stream ended or encountered an error; clearing pending requests.");
|
||||
pending_requests.clear().await;
|
||||
tracing::error!("SSE stream ended or encountered an error.");
|
||||
}
|
||||
|
||||
/// Continuously receives messages from the `mpsc::Receiver`.
|
||||
/// - If it's a request, store the oneshot in `pending_requests`.
|
||||
/// - POST the message to the discovered endpoint (once known).
|
||||
async fn handle_outgoing_messages(
|
||||
mut receiver: mpsc::Receiver<TransportMessage>,
|
||||
mut receiver: mpsc::Receiver<String>,
|
||||
http_client: HttpClient,
|
||||
post_endpoint: Arc<RwLock<Option<String>>>,
|
||||
pending_requests: Arc<PendingRequests>,
|
||||
) {
|
||||
while let Some(transport_msg) = receiver.recv().await {
|
||||
while let Some(message_str) = receiver.recv().await {
|
||||
let post_url = match post_endpoint.read().await.as_ref() {
|
||||
Some(url) => url.clone(),
|
||||
None => {
|
||||
if let Some(response_tx) = transport_msg.response_tx {
|
||||
let _ = response_tx.send(Err(Error::NotConnected));
|
||||
}
|
||||
// TODO: the endpoint isn't discovered yet. This shouldn't happen -- we only return the handle
|
||||
// after the endpoint is set.
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Serialize the JSON-RPC message
|
||||
let message_str = match serde_json::to_string(&transport_msg.message) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
if let Some(tx) = transport_msg.response_tx {
|
||||
let _ = tx.send(Err(Error::Serialization(e)));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// If it's a request, store the channel so we can respond later
|
||||
if let Some(response_tx) = transport_msg.response_tx {
|
||||
if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) =
|
||||
&transport_msg.message
|
||||
{
|
||||
pending_requests.insert(id.to_string(), response_tx).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Perform the HTTP POST
|
||||
match http_client
|
||||
.post(&post_url)
|
||||
@@ -209,26 +177,25 @@ impl SseActor {
|
||||
}
|
||||
}
|
||||
|
||||
// mpsc channel closed => no more outgoing messages
|
||||
let pending = pending_requests.len().await;
|
||||
if pending > 0 {
|
||||
tracing::error!("SSE stream ended or encountered an error with {pending} unfulfilled pending requests.");
|
||||
pending_requests.clear().await;
|
||||
} else {
|
||||
tracing::info!("SseActor shutdown cleanly. No pending requests.");
|
||||
}
|
||||
tracing::info!("SseActor shut down.");
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SseTransportHandle {
|
||||
sender: mpsc::Sender<TransportMessage>,
|
||||
sender: mpsc::Sender<String>,
|
||||
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TransportHandle for SseTransportHandle {
|
||||
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
|
||||
send_message(&self.sender, message).await
|
||||
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> {
|
||||
serialize_and_send(&self.sender, message).await
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
|
||||
let mut receiver = self.receiver.lock().await;
|
||||
receiver.recv().await.ok_or(Error::ChannelClosed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,17 +246,13 @@ impl Transport for SseTransport {
|
||||
|
||||
// Create a channel for outgoing TransportMessages
|
||||
let (tx, rx) = mpsc::channel(32);
|
||||
let (otx, orx) = mpsc::channel(32);
|
||||
|
||||
let post_endpoint: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
|
||||
let post_endpoint_clone = Arc::clone(&post_endpoint);
|
||||
|
||||
// Build the actor
|
||||
let actor = SseActor::new(
|
||||
rx,
|
||||
Arc::new(PendingRequests::new()),
|
||||
self.sse_url.clone(),
|
||||
post_endpoint,
|
||||
);
|
||||
let actor = SseActor::new(rx, otx, self.sse_url.clone(), post_endpoint);
|
||||
|
||||
// Spawn the actor task
|
||||
tokio::spawn(actor.run());
|
||||
@@ -301,7 +264,10 @@ impl Transport for SseTransport {
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(SseTransportHandle { sender: tx }),
|
||||
Ok(_) => Ok(SseTransportHandle {
|
||||
sender: tx,
|
||||
receiver: Arc::new(Mutex::new(orx)),
|
||||
}),
|
||||
Err(e) => Err(Error::SseConnection(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ use nix::sys::signal::{kill, Signal};
|
||||
#[cfg(unix)]
|
||||
use nix::unistd::{getpgid, Pid};
|
||||
|
||||
use super::{send_message, Error, PendingRequests, Transport, TransportHandle, TransportMessage};
|
||||
use super::{serialize_and_send, Error, Transport, TransportHandle};
|
||||
|
||||
// Global to track process groups we've created
|
||||
static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
|
||||
@@ -23,8 +23,8 @@ static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
|
||||
///
|
||||
/// It uses channels for message passing and handles responses asynchronously through a background task.
|
||||
pub struct StdioActor {
|
||||
receiver: Option<mpsc::Receiver<TransportMessage>>,
|
||||
pending_requests: Arc<PendingRequests>,
|
||||
receiver: Option<mpsc::Receiver<String>>,
|
||||
sender: Option<mpsc::Sender<JsonRpcMessage>>,
|
||||
process: Child, // we store the process to keep it alive
|
||||
error_sender: mpsc::Sender<Error>,
|
||||
stdin: Option<ChildStdin>,
|
||||
@@ -55,11 +55,11 @@ impl StdioActor {
|
||||
|
||||
let stdout = self.stdout.take().expect("stdout should be available");
|
||||
let stdin = self.stdin.take().expect("stdin should be available");
|
||||
let receiver = self.receiver.take().expect("receiver should be available");
|
||||
let msg_inbox = self.receiver.take().expect("receiver should be available");
|
||||
let msg_outbox = self.sender.take().expect("sender should be available");
|
||||
|
||||
let incoming = Self::handle_incoming_messages(stdout, self.pending_requests.clone());
|
||||
let outgoing =
|
||||
Self::handle_outgoing_messages(receiver, stdin, self.pending_requests.clone());
|
||||
let incoming = Self::handle_proc_output(stdout, msg_outbox);
|
||||
let outgoing = Self::handle_proc_input(stdin, msg_inbox);
|
||||
|
||||
// take ownership of futures for tokio::select
|
||||
pin!(incoming);
|
||||
@@ -96,12 +96,9 @@ impl StdioActor {
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up regardless of which path we took
|
||||
self.pending_requests.clear().await;
|
||||
}
|
||||
|
||||
async fn handle_incoming_messages(stdout: ChildStdout, pending_requests: Arc<PendingRequests>) {
|
||||
async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender<JsonRpcMessage>) {
|
||||
let mut reader = BufReader::new(stdout);
|
||||
let mut line = String::new();
|
||||
loop {
|
||||
@@ -116,20 +113,12 @@ impl StdioActor {
|
||||
message = ?message,
|
||||
"Received incoming message"
|
||||
);
|
||||
|
||||
match &message {
|
||||
JsonRpcMessage::Response(response) => {
|
||||
if let Some(id) = &response.id {
|
||||
pending_requests.respond(&id.to_string(), Ok(message)).await;
|
||||
}
|
||||
}
|
||||
JsonRpcMessage::Error(error) => {
|
||||
if let Some(id) = &error.id {
|
||||
pending_requests.respond(&id.to_string(), Ok(message)).await;
|
||||
}
|
||||
}
|
||||
_ => {} // TODO: Handle other variants (Request, etc.)
|
||||
}
|
||||
let _ = sender.send(message).await;
|
||||
} else {
|
||||
tracing::warn!(
|
||||
message = ?line,
|
||||
"Failed to parse incoming message"
|
||||
);
|
||||
}
|
||||
line.clear();
|
||||
}
|
||||
@@ -141,44 +130,20 @@ impl StdioActor {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_outgoing_messages(
|
||||
mut receiver: mpsc::Receiver<TransportMessage>,
|
||||
mut stdin: ChildStdin,
|
||||
pending_requests: Arc<PendingRequests>,
|
||||
) {
|
||||
while let Some(mut transport_msg) = receiver.recv().await {
|
||||
let message_str = match serde_json::to_string(&transport_msg.message) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
if let Some(tx) = transport_msg.response_tx.take() {
|
||||
let _ = tx.send(Err(Error::Serialization(e)));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
tracing::debug!(message = ?transport_msg.message, "Sending outgoing message");
|
||||
|
||||
if let Some(response_tx) = transport_msg.response_tx.take() {
|
||||
if let JsonRpcMessage::Request(request) = &transport_msg.message {
|
||||
if let Some(id) = &request.id {
|
||||
pending_requests.insert(id.to_string(), response_tx).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
async fn handle_proc_input(mut stdin: ChildStdin, mut receiver: mpsc::Receiver<String>) {
|
||||
while let Some(message_str) = receiver.recv().await {
|
||||
tracing::debug!(message = ?message_str, "Sending outgoing message");
|
||||
|
||||
if let Err(e) = stdin
|
||||
.write_all(format!("{}\n", message_str).as_bytes())
|
||||
.await
|
||||
{
|
||||
tracing::error!(error = ?e, "Error writing message to child process");
|
||||
pending_requests.clear().await;
|
||||
break;
|
||||
}
|
||||
|
||||
if let Err(e) = stdin.flush().await {
|
||||
tracing::error!(error = ?e, "Error flushing message to child process");
|
||||
pending_requests.clear().await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -187,18 +152,24 @@ impl StdioActor {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StdioTransportHandle {
|
||||
sender: mpsc::Sender<TransportMessage>,
|
||||
sender: mpsc::Sender<String>, // to process
|
||||
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>, // from process
|
||||
error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TransportHandle for StdioTransportHandle {
|
||||
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
|
||||
let result = send_message(&self.sender, message).await;
|
||||
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> {
|
||||
let result = serialize_and_send(&self.sender, message).await;
|
||||
// Check for any pending errors even if send is successful
|
||||
self.check_for_errors().await?;
|
||||
result
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
|
||||
let mut receiver = self.receiver.lock().await;
|
||||
receiver.recv().await.ok_or(Error::ChannelClosed)
|
||||
}
|
||||
}
|
||||
|
||||
impl StdioTransportHandle {
|
||||
@@ -289,12 +260,13 @@ impl Transport for StdioTransport {
|
||||
|
||||
async fn start(&self) -> Result<Self::Handle, Error> {
|
||||
let (process, stdin, stdout, stderr) = self.spawn_process().await?;
|
||||
let (message_tx, message_rx) = mpsc::channel(32);
|
||||
let (outbox_tx, outbox_rx) = mpsc::channel(32);
|
||||
let (inbox_tx, inbox_rx) = mpsc::channel(32);
|
||||
let (error_tx, error_rx) = mpsc::channel(1);
|
||||
|
||||
let actor = StdioActor {
|
||||
receiver: Some(message_rx),
|
||||
pending_requests: Arc::new(PendingRequests::new()),
|
||||
receiver: Some(outbox_rx), // client to process
|
||||
sender: Some(inbox_tx), // process to client
|
||||
process,
|
||||
error_sender: error_tx,
|
||||
stdin: Some(stdin),
|
||||
@@ -305,7 +277,8 @@ impl Transport for StdioTransport {
|
||||
tokio::spawn(actor.run());
|
||||
|
||||
let handle = StdioTransportHandle {
|
||||
sender: message_tx,
|
||||
sender: outbox_tx, // client to process
|
||||
receiver: Arc::new(Mutex::new(inbox_rx)), // process to client
|
||||
error_receiver: Arc::new(Mutex::new(error_rx)),
|
||||
};
|
||||
Ok(handle)
|
||||
|
||||
@@ -4,9 +4,13 @@ use std::{
|
||||
};
|
||||
|
||||
use futures::{Future, Stream};
|
||||
use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
|
||||
use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcResponse};
|
||||
use pin_project::pin_project;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
use router::McpRequest;
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader},
|
||||
sync::mpsc,
|
||||
};
|
||||
use tower_service::Service;
|
||||
|
||||
mod errors;
|
||||
@@ -123,7 +127,7 @@ pub struct Server<S> {
|
||||
|
||||
impl<S> Server<S>
|
||||
where
|
||||
S: Service<JsonRpcRequest, Response = JsonRpcResponse> + Send,
|
||||
S: Service<McpRequest, Response = JsonRpcResponse> + Send,
|
||||
S::Error: Into<BoxError>,
|
||||
S::Future: Send,
|
||||
{
|
||||
@@ -134,8 +138,8 @@ where
|
||||
// TODO transport trait instead of byte transport if we implement others
|
||||
pub async fn run<R, W>(self, mut transport: ByteTransport<R, W>) -> Result<(), ServerError>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
W: AsyncWrite + Unpin,
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
use futures::StreamExt;
|
||||
let mut service = self.service;
|
||||
@@ -160,7 +164,22 @@ where
|
||||
);
|
||||
|
||||
// Process the request using our service
|
||||
let response = match service.call(request).await {
|
||||
let (notify_tx, mut notify_rx) = mpsc::channel(256);
|
||||
let mcp_request = McpRequest {
|
||||
request,
|
||||
notifier: notify_tx,
|
||||
};
|
||||
|
||||
let transport_fut = tokio::spawn(async move {
|
||||
while let Some(notification) = notify_rx.recv().await {
|
||||
if transport.write_message(notification).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
transport
|
||||
});
|
||||
|
||||
let response = match service.call(mcp_request).await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
let error_msg = e.into().to_string();
|
||||
@@ -178,6 +197,16 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
transport = match transport_fut.await {
|
||||
Ok(transport) => transport,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to spawn transport task");
|
||||
return Err(ServerError::Transport(TransportError::Io(
|
||||
e.into(),
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Serialize response for logging
|
||||
let response_json = serde_json::to_string(&response)
|
||||
.unwrap_or_else(|_| "Failed to serialize response".to_string());
|
||||
@@ -247,7 +276,7 @@ where
|
||||
// Any router implements this
|
||||
pub trait BoundedService:
|
||||
Service<
|
||||
JsonRpcRequest,
|
||||
McpRequest,
|
||||
Response = JsonRpcResponse,
|
||||
Error = BoxError,
|
||||
Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
|
||||
@@ -259,7 +288,7 @@ pub trait BoundedService:
|
||||
// Implement it for any type that meets the bounds
|
||||
impl<T> BoundedService for T where
|
||||
T: Service<
|
||||
JsonRpcRequest,
|
||||
McpRequest,
|
||||
Response = JsonRpcResponse,
|
||||
Error = BoxError,
|
||||
Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
|
||||
|
||||
@@ -2,12 +2,14 @@ use anyhow::Result;
|
||||
use mcp_core::content::Content;
|
||||
use mcp_core::handler::{PromptError, ResourceError};
|
||||
use mcp_core::prompt::{Prompt, PromptArgument};
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use mcp_core::tool::ToolAnnotations;
|
||||
use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool};
|
||||
use mcp_server::router::{CapabilitiesBuilder, RouterService};
|
||||
use mcp_server::{ByteTransport, Router, Server};
|
||||
use serde_json::Value;
|
||||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::{
|
||||
io::{stdin, stdout},
|
||||
sync::Mutex,
|
||||
@@ -124,6 +126,7 @@ impl Router for CounterRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
_arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
@@ -11,14 +11,15 @@ use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::{Prompt, PromptMessage, PromptMessageRole},
|
||||
protocol::{
|
||||
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcRequest,
|
||||
JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult,
|
||||
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcMessage,
|
||||
JsonRpcRequest, JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult,
|
||||
PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities,
|
||||
ToolsCapability,
|
||||
},
|
||||
ResourceContents,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
use tower_service::Service;
|
||||
|
||||
use crate::{BoxError, RouterError};
|
||||
@@ -91,6 +92,7 @@ pub trait Router: Send + Sync + 'static {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>>;
|
||||
fn list_resources(&self) -> Vec<mcp_core::resource::Resource>;
|
||||
fn read_resource(
|
||||
@@ -159,6 +161,7 @@ pub trait Router: Send + Sync + 'static {
|
||||
fn handle_tools_call(
|
||||
&self,
|
||||
req: JsonRpcRequest,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
|
||||
async move {
|
||||
let params = req
|
||||
@@ -172,7 +175,7 @@ pub trait Router: Send + Sync + 'static {
|
||||
|
||||
let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
|
||||
|
||||
let result = match self.call_tool(name, arguments).await {
|
||||
let result = match self.call_tool(name, arguments, notifier).await {
|
||||
Ok(result) => CallToolResult {
|
||||
content: result,
|
||||
is_error: None,
|
||||
@@ -394,7 +397,12 @@ pub trait Router: Send + Sync + 'static {
|
||||
|
||||
pub struct RouterService<T>(pub T);
|
||||
|
||||
impl<T> Service<JsonRpcRequest> for RouterService<T>
|
||||
pub struct McpRequest {
|
||||
pub request: JsonRpcRequest,
|
||||
pub notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
}
|
||||
|
||||
impl<T> Service<McpRequest> for RouterService<T>
|
||||
where
|
||||
T: Router + Clone + Send + Sync + 'static,
|
||||
{
|
||||
@@ -406,21 +414,21 @@ where
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
|
||||
fn call(&mut self, req: McpRequest) -> Self::Future {
|
||||
let this = self.0.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let result = match req.method.as_str() {
|
||||
"initialize" => this.handle_initialize(req).await,
|
||||
"tools/list" => this.handle_tools_list(req).await,
|
||||
"tools/call" => this.handle_tools_call(req).await,
|
||||
"resources/list" => this.handle_resources_list(req).await,
|
||||
"resources/read" => this.handle_resources_read(req).await,
|
||||
"prompts/list" => this.handle_prompts_list(req).await,
|
||||
"prompts/get" => this.handle_prompts_get(req).await,
|
||||
let result = match req.request.method.as_str() {
|
||||
"initialize" => this.handle_initialize(req.request).await,
|
||||
"tools/list" => this.handle_tools_list(req.request).await,
|
||||
"tools/call" => this.handle_tools_call(req.request, req.notifier).await,
|
||||
"resources/list" => this.handle_resources_list(req.request).await,
|
||||
"resources/read" => this.handle_resources_read(req.request).await,
|
||||
"prompts/list" => this.handle_prompts_list(req.request).await,
|
||||
"prompts/get" => this.handle_prompts_get(req.request).await,
|
||||
_ => {
|
||||
let mut response = this.create_response(req.id);
|
||||
response.error = Some(RouterError::MethodNotFound(req.method).into());
|
||||
let mut response = this.create_response(req.request.id);
|
||||
response.error = Some(RouterError::MethodNotFound(req.request.method).into());
|
||||
Ok(response)
|
||||
}
|
||||
};
|
||||
|
||||
@@ -148,6 +148,7 @@ function ChatContent({
|
||||
handleInputChange: _handleInputChange,
|
||||
handleSubmit: _submitMessage,
|
||||
updateMessageStreamBody,
|
||||
notifications,
|
||||
} = useMessageStream({
|
||||
api: getApiUrl('/reply'),
|
||||
initialMessages: chat.messages,
|
||||
@@ -492,6 +493,16 @@ function ChatContent({
|
||||
const handleDragOver = (e: React.DragEvent<HTMLDivElement>) => {
|
||||
e.preventDefault();
|
||||
};
|
||||
|
||||
const toolCallNotifications = notifications.reduce((map, item) => {
|
||||
const key = item.request_id;
|
||||
if (!map.has(key)) {
|
||||
map.set(key, []);
|
||||
}
|
||||
map.get(key).push(item);
|
||||
return map;
|
||||
}, new Map());
|
||||
|
||||
return (
|
||||
<div className="flex flex-col w-full h-screen items-center justify-center">
|
||||
{/* Loader when generating recipe */}
|
||||
@@ -571,6 +582,7 @@ function ChatContent({
|
||||
const updatedMessages = [...messages, newMessage];
|
||||
setMessages(updatedMessages);
|
||||
}}
|
||||
toolCallNotifications={toolCallNotifications}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
@@ -578,6 +590,7 @@ function ChatContent({
|
||||
</div>
|
||||
))}
|
||||
</SearchView>
|
||||
|
||||
{error && (
|
||||
<div className="flex flex-col items-center justify-center p-4">
|
||||
<div className="text-red-700 dark:text-red-300 bg-red-400/50 p-3 rounded-lg mb-2">
|
||||
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
} from '../types/message';
|
||||
import ToolCallConfirmation from './ToolCallConfirmation';
|
||||
import MessageCopyLink from './MessageCopyLink';
|
||||
import { NotificationEvent } from '../hooks/useMessageStream';
|
||||
|
||||
interface GooseMessageProps {
|
||||
// messages up to this index are presumed to be "history" from a resumed session, this is used to track older tool confirmation requests
|
||||
@@ -25,6 +26,7 @@ interface GooseMessageProps {
|
||||
message: Message;
|
||||
messages: Message[];
|
||||
metadata?: string[];
|
||||
toolCallNotifications: Map<string, NotificationEvent[]>;
|
||||
append: (value: string) => void;
|
||||
appendMessage: (message: Message) => void;
|
||||
}
|
||||
@@ -34,6 +36,7 @@ export default function GooseMessage({
|
||||
message,
|
||||
metadata,
|
||||
messages,
|
||||
toolCallNotifications,
|
||||
append,
|
||||
appendMessage,
|
||||
}: GooseMessageProps) {
|
||||
@@ -158,6 +161,7 @@ export default function GooseMessage({
|
||||
}
|
||||
toolRequest={toolRequest}
|
||||
toolResponse={toolResponsesMap.get(toolRequest.id)}
|
||||
notifications={toolCallNotifications.get(toolRequest.id)}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React from 'react';
|
||||
import React, { useEffect, useRef } from 'react';
|
||||
import { Card } from './ui/card';
|
||||
import { ToolCallArguments, ToolCallArgumentValue } from './ToolCallArguments';
|
||||
import MarkdownContent from './MarkdownContent';
|
||||
@@ -6,17 +6,20 @@ import { Content, ToolRequestMessageContent, ToolResponseMessageContent } from '
|
||||
import { snakeToTitleCase } from '../utils';
|
||||
import Dot, { LoadingStatus } from './ui/Dot';
|
||||
import Expand from './ui/Expand';
|
||||
import { NotificationEvent } from '../hooks/useMessageStream';
|
||||
|
||||
interface ToolCallWithResponseProps {
|
||||
isCancelledMessage: boolean;
|
||||
toolRequest: ToolRequestMessageContent;
|
||||
toolResponse?: ToolResponseMessageContent;
|
||||
notifications?: NotificationEvent[];
|
||||
}
|
||||
|
||||
export default function ToolCallWithResponse({
|
||||
isCancelledMessage,
|
||||
toolRequest,
|
||||
toolResponse,
|
||||
notifications,
|
||||
}: ToolCallWithResponseProps) {
|
||||
const toolCall = toolRequest.toolCall.status === 'success' ? toolRequest.toolCall.value : null;
|
||||
if (!toolCall) {
|
||||
@@ -26,7 +29,7 @@ export default function ToolCallWithResponse({
|
||||
return (
|
||||
<div className={'w-full text-textSubtle text-sm'}>
|
||||
<Card className="">
|
||||
<ToolCallView {...{ isCancelledMessage, toolCall, toolResponse }} />
|
||||
<ToolCallView {...{ isCancelledMessage, toolCall, toolResponse, notifications }} />
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
@@ -47,8 +50,9 @@ function ToolCallExpandable({
|
||||
children,
|
||||
className = '',
|
||||
}: ToolCallExpandableProps) {
|
||||
const [isExpanded, setIsExpanded] = React.useState(isStartExpanded);
|
||||
const toggleExpand = () => setIsExpanded((prev) => !prev);
|
||||
const [isExpandedState, setIsExpanded] = React.useState<boolean | null>(null);
|
||||
const isExpanded = isExpandedState === null ? isStartExpanded : isExpandedState;
|
||||
const toggleExpand = () => setIsExpanded(!isExpanded);
|
||||
React.useEffect(() => {
|
||||
if (isForceExpand) setIsExpanded(true);
|
||||
}, [isForceExpand]);
|
||||
@@ -71,9 +75,42 @@ interface ToolCallViewProps {
|
||||
arguments: Record<string, unknown>;
|
||||
};
|
||||
toolResponse?: ToolResponseMessageContent;
|
||||
notifications?: NotificationEvent[];
|
||||
}
|
||||
|
||||
function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallViewProps) {
|
||||
interface Progress {
|
||||
progress: number;
|
||||
progressToken: string;
|
||||
total?: number;
|
||||
message?: string;
|
||||
}
|
||||
|
||||
const logToString = (logMessage: NotificationEvent) => {
|
||||
const params = logMessage.message.params;
|
||||
|
||||
// Special case for the developer system shell logs
|
||||
if (
|
||||
params &&
|
||||
params.data &&
|
||||
typeof params.data === 'object' &&
|
||||
'output' in params.data &&
|
||||
'stream' in params.data
|
||||
) {
|
||||
return `[${params.data.stream}] ${params.data.output}`;
|
||||
}
|
||||
|
||||
return typeof params.data === 'string' ? params.data : JSON.stringify(params.data);
|
||||
};
|
||||
|
||||
const notificationToProgress = (notification: NotificationEvent): Progress =>
|
||||
notification.message.params as unknown as Progress;
|
||||
|
||||
function ToolCallView({
|
||||
isCancelledMessage,
|
||||
toolCall,
|
||||
toolResponse,
|
||||
notifications,
|
||||
}: ToolCallViewProps) {
|
||||
const responseStyle = localStorage.getItem('response_style');
|
||||
const isExpandToolDetails = (() => {
|
||||
switch (responseStyle) {
|
||||
@@ -103,6 +140,29 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi
|
||||
}))
|
||||
: [];
|
||||
|
||||
const logs = notifications
|
||||
?.filter((notification) => notification.message.method === 'notifications/message')
|
||||
.map(logToString);
|
||||
|
||||
const progress = notifications
|
||||
?.filter((notification) => notification.message.method === 'notifications/progress')
|
||||
.map(notificationToProgress)
|
||||
.reduce((map, item) => {
|
||||
const key = item.progressToken;
|
||||
if (!map.has(key)) {
|
||||
map.set(key, []);
|
||||
}
|
||||
map.get(key)!.push(item);
|
||||
return map;
|
||||
}, new Map<string, Progress[]>());
|
||||
|
||||
const progressEntries = [...(progress?.values() || [])].map(
|
||||
(entries) => entries.sort((a, b) => b.progress - a.progress)[0]
|
||||
);
|
||||
|
||||
const isRenderingProgress =
|
||||
loadingStatus === 'loading' && (progressEntries.length > 0 || (logs || []).length > 0);
|
||||
|
||||
const isShouldExpand = isExpandToolDetails || toolResults.some((v) => v.isExpandToolResults);
|
||||
|
||||
// Function to create a compact representation of arguments
|
||||
@@ -136,7 +196,7 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi
|
||||
|
||||
return (
|
||||
<ToolCallExpandable
|
||||
isStartExpanded={isShouldExpand}
|
||||
isStartExpanded={isShouldExpand || isRenderingProgress}
|
||||
isForceExpand={isShouldExpand}
|
||||
label={
|
||||
<>
|
||||
@@ -156,6 +216,24 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi
|
||||
</div>
|
||||
)}
|
||||
|
||||
{logs && logs.length > 0 && (
|
||||
<div className="bg-bgStandard mt-1">
|
||||
<ToolLogsView
|
||||
logs={logs}
|
||||
working={toolResults.length === 0}
|
||||
isStartExpanded={toolResults.length === 0}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{toolResults.length === 0 &&
|
||||
progressEntries.length > 0 &&
|
||||
progressEntries.map((entry, index) => (
|
||||
<div className="p-2" key={index}>
|
||||
<ProgressBar progress={entry.progress} total={entry.total} message={entry.message} />
|
||||
</div>
|
||||
))}
|
||||
|
||||
{/* Tool Output */}
|
||||
{!isCancelledMessage && (
|
||||
<>
|
||||
@@ -234,3 +312,76 @@ function ToolResultView({ result, isStartExpanded }: ToolResultViewProps) {
|
||||
</ToolCallExpandable>
|
||||
);
|
||||
}
|
||||
|
||||
function ToolLogsView({
|
||||
logs,
|
||||
working,
|
||||
isStartExpanded,
|
||||
}: {
|
||||
logs: string[];
|
||||
working: boolean;
|
||||
isStartExpanded?: boolean;
|
||||
}) {
|
||||
const boxRef = useRef(null);
|
||||
|
||||
// Whenever logs update, jump to the newest entry
|
||||
useEffect(() => {
|
||||
if (boxRef.current) {
|
||||
boxRef.current.scrollTop = boxRef.current.scrollHeight;
|
||||
}
|
||||
}, [logs]);
|
||||
|
||||
return (
|
||||
<ToolCallExpandable
|
||||
label={
|
||||
<span className="pl-[19px] py-1">
|
||||
<span>Logs</span>
|
||||
{working && (
|
||||
<div className="mx-2 inline-block">
|
||||
<span
|
||||
className="inline-block animate-spin rounded-full border-2 border-t-transparent border-current"
|
||||
style={{ width: 8, height: 8 }}
|
||||
role="status"
|
||||
aria-label="Loading spinner"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</span>
|
||||
}
|
||||
isStartExpanded={isStartExpanded}
|
||||
>
|
||||
<div
|
||||
ref={boxRef}
|
||||
className={`flex flex-col items-start space-y-2 overflow-y-auto ${working ? 'max-h-[4rem]' : 'max-h-[20rem]'} bg-bgApp`}
|
||||
>
|
||||
{logs.map((log, i) => (
|
||||
<span key={i} className="font-mono text-sm text-textSubtle">
|
||||
{log}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</ToolCallExpandable>
|
||||
);
|
||||
}
|
||||
|
||||
const ProgressBar = ({ progress, total, message }: Omit<Progress, 'progressToken'>) => {
|
||||
const isDeterminate = typeof total === 'number';
|
||||
const percent = isDeterminate ? Math.min((progress / total!) * 100, 100) : 0;
|
||||
|
||||
return (
|
||||
<div className="w-full space-y-2">
|
||||
{message && <div className="text-sm text-gray-700">{message}</div>}
|
||||
|
||||
<div className="w-full bg-gray-200 rounded-full h-4 overflow-hidden relative">
|
||||
{isDeterminate ? (
|
||||
<div
|
||||
className="bg-blue-500 h-full transition-all duration-300"
|
||||
style={{ width: `${percent}%` }}
|
||||
/>
|
||||
) : (
|
||||
<div className="absolute inset-0 animate-indeterminate bg-blue-500" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -6,11 +6,25 @@ import { Message, createUserMessage, hasCompletedToolCalls } from '../types/mess
|
||||
// Ensure TextDecoder is available in the global scope
|
||||
const TextDecoder = globalThis.TextDecoder;
|
||||
|
||||
type JsonValue = string | number | boolean | null | JsonValue[] | { [key: string]: JsonValue };
|
||||
|
||||
export interface NotificationEvent {
|
||||
type: 'Notification';
|
||||
request_id: string;
|
||||
message: {
|
||||
method: string;
|
||||
params: {
|
||||
[key: string]: JsonValue;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
// Event types for SSE stream
|
||||
type MessageEvent =
|
||||
| { type: 'Message'; message: Message }
|
||||
| { type: 'Error'; error: string }
|
||||
| { type: 'Finish'; reason: string };
|
||||
| { type: 'Finish'; reason: string }
|
||||
| NotificationEvent;
|
||||
|
||||
export interface UseMessageStreamOptions {
|
||||
/**
|
||||
@@ -124,6 +138,8 @@ export interface UseMessageStreamHelpers {
|
||||
|
||||
/** Modify body (session id and/or work dir mid-stream) **/
|
||||
updateMessageStreamBody?: (newBody: object) => void;
|
||||
|
||||
notifications: NotificationEvent[];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -151,6 +167,8 @@ export function useMessageStream({
|
||||
fallbackData: initialMessages,
|
||||
});
|
||||
|
||||
const [notifications, setNotifications] = useState<NotificationEvent[]>([]);
|
||||
|
||||
// expose a way to update the body so we can update the session id when CLE occurs
|
||||
const updateMessageStreamBody = useCallback((newBody: object) => {
|
||||
extraMetadataRef.current.body = {
|
||||
@@ -247,6 +265,14 @@ export function useMessageStream({
|
||||
break;
|
||||
}
|
||||
|
||||
case 'Notification': {
|
||||
const newNotification = {
|
||||
...parsedEvent,
|
||||
};
|
||||
setNotifications((prev) => [...prev, newNotification]);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'Error':
|
||||
throw new Error(parsedEvent.error);
|
||||
|
||||
@@ -516,5 +542,6 @@ export function useMessageStream({
|
||||
isLoading: isLoading || false,
|
||||
addToolResult,
|
||||
updateMessageStreamBody,
|
||||
notifications,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -44,10 +44,16 @@ export default {
|
||||
'0%': { transform: 'rotate(0deg)' },
|
||||
'100%': { transform: 'rotate(360deg)' },
|
||||
},
|
||||
indeterminate: {
|
||||
'0%': { left: '-40%', width: '40%' },
|
||||
'50%': { left: '20%', width: '60%' },
|
||||
'100%': { left: '100%', width: '80%' },
|
||||
},
|
||||
},
|
||||
animation: {
|
||||
'shimmer-pulse': 'shimmer 4s ease-in-out infinite',
|
||||
'gradient-loader': 'loader 750ms ease-in-out infinite',
|
||||
indeterminate: 'indeterminate 1.5s infinite linear',
|
||||
},
|
||||
colors: {
|
||||
bgApp: 'var(--background-app)',
|
||||
|
||||
Reference in New Issue
Block a user