mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 06:34:26 +01:00
feat: Handle MCP server notification messages (#2613)
Co-authored-by: Michael Neale <michael.neale@gmail.com>
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -3435,6 +3435,7 @@ dependencies = [
|
|||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-cron-scheduler",
|
"tokio-cron-scheduler",
|
||||||
|
"tokio-stream",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
"url",
|
"url",
|
||||||
@@ -3486,6 +3487,7 @@ dependencies = [
|
|||||||
"goose",
|
"goose",
|
||||||
"goose-bench",
|
"goose-bench",
|
||||||
"goose-mcp",
|
"goose-mcp",
|
||||||
|
"indicatif",
|
||||||
"mcp-client",
|
"mcp-client",
|
||||||
"mcp-core",
|
"mcp-core",
|
||||||
"mcp-server",
|
"mcp-server",
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ regex = "1.11.1"
|
|||||||
minijinja = "2.8.0"
|
minijinja = "2.8.0"
|
||||||
nix = { version = "0.30.1", features = ["process", "signal"] }
|
nix = { version = "0.30.1", features = ["process", "signal"] }
|
||||||
tar = "0.4"
|
tar = "0.4"
|
||||||
|
indicatif = "0.17.11"
|
||||||
|
|
||||||
[target.'cfg(target_os = "windows")'.dependencies]
|
[target.'cfg(target_os = "windows")'.dependencies]
|
||||||
winapi = { version = "0.3", features = ["wincred"] }
|
winapi = { version = "0.3", features = ["wincred"] }
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ mod thinking;
|
|||||||
|
|
||||||
pub use builder::{build_session, SessionBuilderConfig};
|
pub use builder::{build_session, SessionBuilderConfig};
|
||||||
use console::Color;
|
use console::Color;
|
||||||
|
use goose::agents::AgentEvent;
|
||||||
use goose::permission::permission_confirmation::PrincipalType;
|
use goose::permission::permission_confirmation::PrincipalType;
|
||||||
use goose::permission::Permission;
|
use goose::permission::Permission;
|
||||||
use goose::permission::PermissionConfirmation;
|
use goose::permission::PermissionConfirmation;
|
||||||
@@ -26,6 +27,8 @@ use input::InputResult;
|
|||||||
use mcp_core::handler::ToolError;
|
use mcp_core::handler::ToolError;
|
||||||
use mcp_core::prompt::PromptMessage;
|
use mcp_core::prompt::PromptMessage;
|
||||||
|
|
||||||
|
use mcp_core::protocol::JsonRpcMessage;
|
||||||
|
use mcp_core::protocol::JsonRpcNotification;
|
||||||
use rand::{distributions::Alphanumeric, Rng};
|
use rand::{distributions::Alphanumeric, Rng};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@@ -713,12 +716,15 @@ impl Session {
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let mut progress_bars = output::McpSpinners::new();
|
||||||
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
result = stream.next() => {
|
result = stream.next() => {
|
||||||
|
let _ = progress_bars.hide();
|
||||||
match result {
|
match result {
|
||||||
Some(Ok(message)) => {
|
Some(Ok(AgentEvent::Message(message))) => {
|
||||||
// If it's a confirmation request, get approval but otherwise do not render/persist
|
// If it's a confirmation request, get approval but otherwise do not render/persist
|
||||||
if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() {
|
if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() {
|
||||||
output::hide_thinking();
|
output::hide_thinking();
|
||||||
@@ -846,6 +852,51 @@ impl Session {
|
|||||||
if interactive {output::show_thinking()};
|
if interactive {output::show_thinking()};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Some(Ok(AgentEvent::McpNotification((_id, message)))) => {
|
||||||
|
if let JsonRpcMessage::Notification(JsonRpcNotification{
|
||||||
|
method,
|
||||||
|
params: Some(Value::Object(o)),
|
||||||
|
..
|
||||||
|
}) = message {
|
||||||
|
match method.as_str() {
|
||||||
|
"notifications/message" => {
|
||||||
|
let data = o.get("data").unwrap_or(&Value::Null);
|
||||||
|
let message = match data {
|
||||||
|
Value::String(s) => s.clone(),
|
||||||
|
Value::Object(o) => {
|
||||||
|
if let Some(Value::String(output)) = o.get("output") {
|
||||||
|
output.to_owned()
|
||||||
|
} else {
|
||||||
|
data.to_string()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
v => {
|
||||||
|
v.to_string()
|
||||||
|
},
|
||||||
|
};
|
||||||
|
// output::render_text_no_newlines(&message, None, true);
|
||||||
|
progress_bars.log(&message);
|
||||||
|
},
|
||||||
|
"notifications/progress" => {
|
||||||
|
let progress = o.get("progress").and_then(|v| v.as_f64());
|
||||||
|
let token = o.get("progressToken").map(|v| v.to_string());
|
||||||
|
let message = o.get("message").and_then(|v| v.as_str());
|
||||||
|
let total = o
|
||||||
|
.get("total")
|
||||||
|
.and_then(|v| v.as_f64());
|
||||||
|
if let (Some(progress), Some(token)) = (progress, token) {
|
||||||
|
progress_bars.update(
|
||||||
|
token.as_str(),
|
||||||
|
progress,
|
||||||
|
total,
|
||||||
|
message,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Some(Err(e)) => {
|
Some(Err(e)) => {
|
||||||
eprintln!("Error: {}", e);
|
eprintln!("Error: {}", e);
|
||||||
drop(stream);
|
drop(stream);
|
||||||
@@ -872,6 +923,7 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,12 +2,15 @@ use bat::WrappingMode;
|
|||||||
use console::{style, Color};
|
use console::{style, Color};
|
||||||
use goose::config::Config;
|
use goose::config::Config;
|
||||||
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
|
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
|
||||||
|
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
|
||||||
use mcp_core::prompt::PromptArgument;
|
use mcp_core::prompt::PromptArgument;
|
||||||
use mcp_core::tool::ToolCall;
|
use mcp_core::tool::ToolCall;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::io::Error;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
// Re-export theme for use in main
|
// Re-export theme for use in main
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
@@ -144,6 +147,10 @@ pub fn render_message(message: &Message, debug: bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
|
pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
|
||||||
|
render_text_no_newlines(format!("\n{}\n\n", text).as_str(), color, dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn render_text_no_newlines(text: &str, color: Option<Color>, dim: bool) {
|
||||||
let mut styled_text = style(text);
|
let mut styled_text = style(text);
|
||||||
if dim {
|
if dim {
|
||||||
styled_text = styled_text.dim();
|
styled_text = styled_text.dim();
|
||||||
@@ -153,7 +160,7 @@ pub fn render_text(text: &str, color: Option<Color>, dim: bool) {
|
|||||||
} else {
|
} else {
|
||||||
styled_text = styled_text.green();
|
styled_text = styled_text.green();
|
||||||
}
|
}
|
||||||
println!("\n{}\n", styled_text);
|
print!("{}", styled_text);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn render_enter_plan_mode() {
|
pub fn render_enter_plan_mode() {
|
||||||
@@ -359,7 +366,6 @@ fn render_shell_request(call: &ToolCall, debug: bool) {
|
|||||||
}
|
}
|
||||||
_ => print_params(&call.arguments, 0, debug),
|
_ => print_params(&call.arguments, 0, debug),
|
||||||
}
|
}
|
||||||
println!();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn render_default_request(call: &ToolCall, debug: bool) {
|
fn render_default_request(call: &ToolCall, debug: bool) {
|
||||||
@@ -568,6 +574,64 @@ pub fn display_greeting() {
|
|||||||
println!("\nGoose is running! Enter your instructions, or try asking what goose can do.\n");
|
println!("\nGoose is running! Enter your instructions, or try asking what goose can do.\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct McpSpinners {
|
||||||
|
bars: HashMap<String, ProgressBar>,
|
||||||
|
log_spinner: Option<ProgressBar>,
|
||||||
|
|
||||||
|
multi_bar: MultiProgress,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl McpSpinners {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
McpSpinners {
|
||||||
|
bars: HashMap::new(),
|
||||||
|
log_spinner: None,
|
||||||
|
multi_bar: MultiProgress::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn log(&mut self, message: &str) {
|
||||||
|
let spinner = self.log_spinner.get_or_insert_with(|| {
|
||||||
|
let bar = self.multi_bar.add(
|
||||||
|
ProgressBar::new_spinner()
|
||||||
|
.with_style(
|
||||||
|
ProgressStyle::with_template("{spinner:.green} {msg}")
|
||||||
|
.unwrap()
|
||||||
|
.tick_chars("⠋⠙⠚⠛⠓⠒⠊⠉"),
|
||||||
|
)
|
||||||
|
.with_message(message.to_string()),
|
||||||
|
);
|
||||||
|
bar.enable_steady_tick(Duration::from_millis(100));
|
||||||
|
bar
|
||||||
|
});
|
||||||
|
|
||||||
|
spinner.set_message(message.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update(&mut self, token: &str, value: f64, total: Option<f64>, message: Option<&str>) {
|
||||||
|
let bar = self.bars.entry(token.to_string()).or_insert_with(|| {
|
||||||
|
if let Some(total) = total {
|
||||||
|
self.multi_bar.add(
|
||||||
|
ProgressBar::new((total * 100.0) as u64).with_style(
|
||||||
|
ProgressStyle::with_template("[{elapsed}] {bar:40} {pos:>3}/{len:3} {msg}")
|
||||||
|
.unwrap(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
self.multi_bar.add(ProgressBar::new_spinner())
|
||||||
|
}
|
||||||
|
});
|
||||||
|
bar.set_position((value * 100.0) as u64);
|
||||||
|
if let Some(msg) = message {
|
||||||
|
bar.set_message(msg.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn hide(&mut self) -> Result<(), Error> {
|
||||||
|
self.multi_bar.clear()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use std::ptr;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use goose::agents::Agent;
|
use goose::agents::{Agent, AgentEvent};
|
||||||
use goose::message::Message;
|
use goose::message::Message;
|
||||||
use goose::model::ModelConfig;
|
use goose::model::ModelConfig;
|
||||||
use goose::providers::databricks::DatabricksProvider;
|
use goose::providers::databricks::DatabricksProvider;
|
||||||
@@ -256,13 +256,16 @@ pub unsafe extern "C" fn goose_agent_send_message(
|
|||||||
|
|
||||||
while let Some(message_result) = stream.next().await {
|
while let Some(message_result) = stream.next().await {
|
||||||
match message_result {
|
match message_result {
|
||||||
Ok(message) => {
|
Ok(AgentEvent::Message(message)) => {
|
||||||
// Get text or serialize to JSON
|
// Get text or serialize to JSON
|
||||||
// Note: Message doesn't have as_text method, we'll serialize to JSON
|
// Note: Message doesn't have as_text method, we'll serialize to JSON
|
||||||
if let Ok(json) = serde_json::to_string(&message) {
|
if let Ok(json) = serde_json::to_string(&message) {
|
||||||
full_response.push_str(&json);
|
full_response.push_str(&json);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Ok(AgentEvent::McpNotification(_)) => {
|
||||||
|
// TODO: Handle MCP notifications.
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
full_response.push_str(&format!("\nError in message stream: {}", e));
|
full_response.push_str(&format!("\nError in message stream: {}", e));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ use serde_json::{json, Value};
|
|||||||
use std::{
|
use std::{
|
||||||
collections::HashMap, fs, future::Future, path::PathBuf, pin::Pin, sync::Arc, sync::Mutex,
|
collections::HashMap, fs, future::Future, path::PathBuf, pin::Pin, sync::Arc, sync::Mutex,
|
||||||
};
|
};
|
||||||
use tokio::process::Command;
|
use tokio::{process::Command, sync::mpsc};
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use std::os::unix::fs::PermissionsExt;
|
use std::os::unix::fs::PermissionsExt;
|
||||||
@@ -14,7 +14,7 @@ use std::os::unix::fs::PermissionsExt;
|
|||||||
use mcp_core::{
|
use mcp_core::{
|
||||||
handler::{PromptError, ResourceError, ToolError},
|
handler::{PromptError, ResourceError, ToolError},
|
||||||
prompt::Prompt,
|
prompt::Prompt,
|
||||||
protocol::ServerCapabilities,
|
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||||
resource::Resource,
|
resource::Resource,
|
||||||
tool::{Tool, ToolAnnotations},
|
tool::{Tool, ToolAnnotations},
|
||||||
Content,
|
Content,
|
||||||
@@ -1155,6 +1155,7 @@ impl Router for ComputerControllerRouter {
|
|||||||
&self,
|
&self,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
arguments: Value,
|
arguments: Value,
|
||||||
|
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
let tool_name = tool_name.to_string();
|
let tool_name = tool_name.to_string();
|
||||||
|
|||||||
@@ -13,13 +13,17 @@ use std::{
|
|||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
};
|
};
|
||||||
use tokio::process::Command;
|
use tokio::{
|
||||||
|
io::{AsyncBufReadExt, BufReader},
|
||||||
|
process::Command,
|
||||||
|
sync::mpsc,
|
||||||
|
};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use include_dir::{include_dir, Dir};
|
use include_dir::{include_dir, Dir};
|
||||||
use mcp_core::{
|
use mcp_core::{
|
||||||
handler::{PromptError, ResourceError, ToolError},
|
handler::{PromptError, ResourceError, ToolError},
|
||||||
protocol::ServerCapabilities,
|
protocol::{JsonRpcMessage, JsonRpcNotification, ServerCapabilities},
|
||||||
resource::Resource,
|
resource::Resource,
|
||||||
tool::Tool,
|
tool::Tool,
|
||||||
Content,
|
Content,
|
||||||
@@ -456,7 +460,11 @@ impl DeveloperRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Shell command execution with platform-specific handling
|
// Shell command execution with platform-specific handling
|
||||||
async fn bash(&self, params: Value) -> Result<Vec<Content>, ToolError> {
|
async fn bash(
|
||||||
|
&self,
|
||||||
|
params: Value,
|
||||||
|
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
|
) -> Result<Vec<Content>, ToolError> {
|
||||||
let command =
|
let command =
|
||||||
params
|
params
|
||||||
.get("command")
|
.get("command")
|
||||||
@@ -488,27 +496,92 @@ impl DeveloperRouter {
|
|||||||
|
|
||||||
// Get platform-specific shell configuration
|
// Get platform-specific shell configuration
|
||||||
let shell_config = get_shell_config();
|
let shell_config = get_shell_config();
|
||||||
let cmd_with_redirect = format_command_for_platform(command);
|
let cmd_str = format_command_for_platform(command);
|
||||||
|
|
||||||
// Execute the command using platform-specific shell
|
// Execute the command using platform-specific shell
|
||||||
let child = Command::new(&shell_config.executable)
|
let mut child = Command::new(&shell_config.executable)
|
||||||
.stdout(Stdio::piped())
|
.stdout(Stdio::piped())
|
||||||
.stderr(Stdio::piped())
|
.stderr(Stdio::piped())
|
||||||
.stdin(Stdio::null())
|
.stdin(Stdio::null())
|
||||||
.kill_on_drop(true)
|
.kill_on_drop(true)
|
||||||
.arg(&shell_config.arg)
|
.arg(&shell_config.arg)
|
||||||
.arg(cmd_with_redirect)
|
.arg(cmd_str)
|
||||||
.spawn()
|
.spawn()
|
||||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||||
|
|
||||||
|
let stdout = child.stdout.take().unwrap();
|
||||||
|
let stderr = child.stderr.take().unwrap();
|
||||||
|
|
||||||
|
let mut stdout_reader = BufReader::new(stdout);
|
||||||
|
let mut stderr_reader = BufReader::new(stderr);
|
||||||
|
|
||||||
|
let output_task = tokio::spawn(async move {
|
||||||
|
let mut combined_output = String::new();
|
||||||
|
|
||||||
|
let mut stdout_buf = Vec::new();
|
||||||
|
let mut stderr_buf = Vec::new();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
n = stdout_reader.read_until(b'\n', &mut stdout_buf) => {
|
||||||
|
if n? == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let line = String::from_utf8_lossy(&stdout_buf);
|
||||||
|
|
||||||
|
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
|
||||||
|
jsonrpc: "2.0".to_string(),
|
||||||
|
method: "notifications/message".to_string(),
|
||||||
|
params: Some(json!({
|
||||||
|
"data": {
|
||||||
|
"type": "shell",
|
||||||
|
"stream": "stdout",
|
||||||
|
"output": line.to_string(),
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
}))
|
||||||
|
.ok();
|
||||||
|
|
||||||
|
combined_output.push_str(&line);
|
||||||
|
stdout_buf.clear();
|
||||||
|
}
|
||||||
|
n = stderr_reader.read_until(b'\n', &mut stderr_buf) => {
|
||||||
|
if n? == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let line = String::from_utf8_lossy(&stderr_buf);
|
||||||
|
|
||||||
|
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
|
||||||
|
jsonrpc: "2.0".to_string(),
|
||||||
|
method: "notifications/message".to_string(),
|
||||||
|
params: Some(json!({
|
||||||
|
"data": {
|
||||||
|
"type": "shell",
|
||||||
|
"stream": "stderr",
|
||||||
|
"output": line.to_string(),
|
||||||
|
}
|
||||||
|
})),
|
||||||
|
}))
|
||||||
|
.ok();
|
||||||
|
|
||||||
|
combined_output.push_str(&line);
|
||||||
|
stderr_buf.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok::<_, std::io::Error>(combined_output)
|
||||||
|
});
|
||||||
|
|
||||||
// Wait for the command to complete and get output
|
// Wait for the command to complete and get output
|
||||||
let output = child
|
child
|
||||||
.wait_with_output()
|
.wait()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||||
|
|
||||||
let stdout_str = String::from_utf8_lossy(&output.stdout);
|
let output_str = match output_task.await {
|
||||||
let output_str = stdout_str;
|
Ok(result) => result.map_err(|e| ToolError::ExecutionError(e.to_string()))?,
|
||||||
|
Err(e) => return Err(ToolError::ExecutionError(e.to_string())),
|
||||||
|
};
|
||||||
|
|
||||||
// Check the character count of the output
|
// Check the character count of the output
|
||||||
const MAX_CHAR_COUNT: usize = 400_000; // 409600 chars = 400KB
|
const MAX_CHAR_COUNT: usize = 400_000; // 409600 chars = 400KB
|
||||||
@@ -1048,12 +1121,13 @@ impl Router for DeveloperRouter {
|
|||||||
&self,
|
&self,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
arguments: Value,
|
arguments: Value,
|
||||||
|
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
let tool_name = tool_name.to_string();
|
let tool_name = tool_name.to_string();
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
match tool_name.as_str() {
|
match tool_name.as_str() {
|
||||||
"shell" => this.bash(arguments).await,
|
"shell" => this.bash(arguments, notifier).await,
|
||||||
"text_editor" => this.text_editor(arguments).await,
|
"text_editor" => this.text_editor(arguments).await,
|
||||||
"list_windows" => this.list_windows(arguments).await,
|
"list_windows" => this.list_windows(arguments).await,
|
||||||
"screen_capture" => this.screen_capture(arguments).await,
|
"screen_capture" => this.screen_capture(arguments).await,
|
||||||
@@ -1195,6 +1269,10 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn dummy_sender() -> mpsc::Sender<JsonRpcMessage> {
|
||||||
|
mpsc::channel(1).0
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
#[serial]
|
#[serial]
|
||||||
async fn test_shell_missing_parameters() {
|
async fn test_shell_missing_parameters() {
|
||||||
@@ -1202,7 +1280,7 @@ mod tests {
|
|||||||
std::env::set_current_dir(&temp_dir).unwrap();
|
std::env::set_current_dir(&temp_dir).unwrap();
|
||||||
|
|
||||||
let router = get_router().await;
|
let router = get_router().await;
|
||||||
let result = router.call_tool("shell", json!({})).await;
|
let result = router.call_tool("shell", json!({}), dummy_sender()).await;
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
let err = result.err().unwrap();
|
let err = result.err().unwrap();
|
||||||
@@ -1263,6 +1341,7 @@ mod tests {
|
|||||||
"command": "view",
|
"command": "view",
|
||||||
"path": large_file_str
|
"path": large_file_str
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -1288,6 +1367,7 @@ mod tests {
|
|||||||
"command": "view",
|
"command": "view",
|
||||||
"path": many_chars_str
|
"path": many_chars_str
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -1319,6 +1399,7 @@ mod tests {
|
|||||||
"path": file_path_str,
|
"path": file_path_str,
|
||||||
"file_text": "Hello, world!"
|
"file_text": "Hello, world!"
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1331,6 +1412,7 @@ mod tests {
|
|||||||
"command": "view",
|
"command": "view",
|
||||||
"path": file_path_str
|
"path": file_path_str
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1369,6 +1451,7 @@ mod tests {
|
|||||||
"path": file_path_str,
|
"path": file_path_str,
|
||||||
"file_text": "Hello, world!"
|
"file_text": "Hello, world!"
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1383,6 +1466,7 @@ mod tests {
|
|||||||
"old_str": "world",
|
"old_str": "world",
|
||||||
"new_str": "Rust"
|
"new_str": "Rust"
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1407,6 +1491,7 @@ mod tests {
|
|||||||
"command": "view",
|
"command": "view",
|
||||||
"path": file_path_str
|
"path": file_path_str
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1444,6 +1529,7 @@ mod tests {
|
|||||||
"path": file_path_str,
|
"path": file_path_str,
|
||||||
"file_text": "First line"
|
"file_text": "First line"
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1458,6 +1544,7 @@ mod tests {
|
|||||||
"old_str": "First line",
|
"old_str": "First line",
|
||||||
"new_str": "Second line"
|
"new_str": "Second line"
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1470,6 +1557,7 @@ mod tests {
|
|||||||
"command": "undo_edit",
|
"command": "undo_edit",
|
||||||
"path": file_path_str
|
"path": file_path_str
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1485,6 +1573,7 @@ mod tests {
|
|||||||
"command": "view",
|
"command": "view",
|
||||||
"path": file_path_str
|
"path": file_path_str
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1583,6 +1672,7 @@ mod tests {
|
|||||||
"path": temp_dir.path().join("secret.txt").to_str().unwrap(),
|
"path": temp_dir.path().join("secret.txt").to_str().unwrap(),
|
||||||
"file_text": "test content"
|
"file_text": "test content"
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -1601,6 +1691,7 @@ mod tests {
|
|||||||
"path": temp_dir.path().join("allowed.txt").to_str().unwrap(),
|
"path": temp_dir.path().join("allowed.txt").to_str().unwrap(),
|
||||||
"file_text": "test content"
|
"file_text": "test content"
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -1642,6 +1733,7 @@ mod tests {
|
|||||||
json!({
|
json!({
|
||||||
"command": format!("cat {}", secret_file_path.to_str().unwrap())
|
"command": format!("cat {}", secret_file_path.to_str().unwrap())
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -1658,6 +1750,7 @@ mod tests {
|
|||||||
json!({
|
json!({
|
||||||
"command": format!("cat {}", allowed_file_path.to_str().unwrap())
|
"command": format!("cat {}", allowed_file_path.to_str().unwrap())
|
||||||
}),
|
}),
|
||||||
|
dummy_sender(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ use std::env;
|
|||||||
pub struct ShellConfig {
|
pub struct ShellConfig {
|
||||||
pub executable: String,
|
pub executable: String,
|
||||||
pub arg: String,
|
pub arg: String,
|
||||||
pub redirect_syntax: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ShellConfig {
|
impl Default for ShellConfig {
|
||||||
@@ -14,13 +13,11 @@ impl Default for ShellConfig {
|
|||||||
Self {
|
Self {
|
||||||
executable: "powershell.exe".to_string(),
|
executable: "powershell.exe".to_string(),
|
||||||
arg: "-NoProfile -NonInteractive -Command".to_string(),
|
arg: "-NoProfile -NonInteractive -Command".to_string(),
|
||||||
redirect_syntax: "2>&1".to_string(),
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Self {
|
Self {
|
||||||
executable: "bash".to_string(),
|
executable: "bash".to_string(),
|
||||||
arg: "-c".to_string(),
|
arg: "-c".to_string(),
|
||||||
redirect_syntax: "2>&1".to_string(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -31,13 +28,12 @@ pub fn get_shell_config() -> ShellConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn format_command_for_platform(command: &str) -> String {
|
pub fn format_command_for_platform(command: &str) -> String {
|
||||||
let config = get_shell_config();
|
|
||||||
if cfg!(windows) {
|
if cfg!(windows) {
|
||||||
// For PowerShell, wrap the command in braces to handle special characters
|
// For PowerShell, wrap the command in braces to handle special characters
|
||||||
format!("{{ {} }} {}", command, config.redirect_syntax)
|
format!("{{ {} }}", command)
|
||||||
} else {
|
} else {
|
||||||
// For other shells, no braces needed
|
// For other shells, no braces needed
|
||||||
format!("{} {}", command, config.redirect_syntax)
|
command.to_string()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ use base64::Engine;
|
|||||||
use chrono::NaiveDate;
|
use chrono::NaiveDate;
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
|
use mcp_core::protocol::JsonRpcMessage;
|
||||||
use mcp_core::tool::ToolAnnotations;
|
use mcp_core::tool::ToolAnnotations;
|
||||||
use oauth_pkce::PkceOAuth2Client;
|
use oauth_pkce::PkceOAuth2Client;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
@@ -14,6 +15,7 @@ use serde_json::{json, Value};
|
|||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc};
|
use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc};
|
||||||
use storage::CredentialsManager;
|
use storage::CredentialsManager;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
use mcp_core::content::Content;
|
use mcp_core::content::Content;
|
||||||
use mcp_core::{
|
use mcp_core::{
|
||||||
@@ -3281,6 +3283,7 @@ impl Router for GoogleDriveRouter {
|
|||||||
&self,
|
&self,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
arguments: Value,
|
arguments: Value,
|
||||||
|
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
let tool_name = tool_name.to_string();
|
let tool_name = tool_name.to_string();
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ use mcp_core::{
|
|||||||
content::Content,
|
content::Content,
|
||||||
handler::{PromptError, ResourceError, ToolError},
|
handler::{PromptError, ResourceError, ToolError},
|
||||||
prompt::Prompt,
|
prompt::Prompt,
|
||||||
protocol::ServerCapabilities,
|
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||||
resource::Resource,
|
resource::Resource,
|
||||||
role::Role,
|
role::Role,
|
||||||
tool::Tool,
|
tool::Tool,
|
||||||
@@ -16,7 +16,7 @@ use serde_json::Value;
|
|||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::{mpsc, Mutex};
|
||||||
use tokio::time::{sleep, Duration};
|
use tokio::time::{sleep, Duration};
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
@@ -158,6 +158,7 @@ impl Router for JetBrainsRouter {
|
|||||||
&self,
|
&self,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
arguments: Value,
|
arguments: Value,
|
||||||
|
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
let tool_name = tool_name.to_string();
|
let tool_name = tool_name.to_string();
|
||||||
|
|||||||
@@ -10,11 +10,12 @@ use std::{
|
|||||||
path::PathBuf,
|
path::PathBuf,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
};
|
};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
use mcp_core::{
|
use mcp_core::{
|
||||||
handler::{PromptError, ResourceError, ToolError},
|
handler::{PromptError, ResourceError, ToolError},
|
||||||
prompt::Prompt,
|
prompt::Prompt,
|
||||||
protocol::ServerCapabilities,
|
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||||
resource::Resource,
|
resource::Resource,
|
||||||
tool::{Tool, ToolAnnotations, ToolCall},
|
tool::{Tool, ToolAnnotations, ToolCall},
|
||||||
Content,
|
Content,
|
||||||
@@ -520,6 +521,7 @@ impl Router for MemoryRouter {
|
|||||||
&self,
|
&self,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
arguments: Value,
|
arguments: Value,
|
||||||
|
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
let tool_name = tool_name.to_string();
|
let tool_name = tool_name.to_string();
|
||||||
|
|||||||
@@ -3,11 +3,12 @@ use include_dir::{include_dir, Dir};
|
|||||||
use indoc::formatdoc;
|
use indoc::formatdoc;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use std::{future::Future, pin::Pin};
|
use std::{future::Future, pin::Pin};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
use mcp_core::{
|
use mcp_core::{
|
||||||
handler::{PromptError, ResourceError, ToolError},
|
handler::{PromptError, ResourceError, ToolError},
|
||||||
prompt::Prompt,
|
prompt::Prompt,
|
||||||
protocol::ServerCapabilities,
|
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||||
resource::Resource,
|
resource::Resource,
|
||||||
role::Role,
|
role::Role,
|
||||||
tool::{Tool, ToolAnnotations},
|
tool::{Tool, ToolAnnotations},
|
||||||
@@ -130,6 +131,7 @@ impl Router for TutorialRouter {
|
|||||||
&self,
|
&self,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
arguments: Value,
|
arguments: Value,
|
||||||
|
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
let tool_name = tool_name.to_string();
|
let tool_name = tool_name.to_string();
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ use axum::{
|
|||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures::{stream::StreamExt, Stream};
|
use futures::{stream::StreamExt, Stream};
|
||||||
use goose::{
|
use goose::{
|
||||||
agents::SessionConfig,
|
agents::{AgentEvent, SessionConfig},
|
||||||
message::{Message, MessageContent},
|
message::{Message, MessageContent},
|
||||||
permission::permission_confirmation::PrincipalType,
|
permission::permission_confirmation::PrincipalType,
|
||||||
};
|
};
|
||||||
@@ -18,7 +18,7 @@ use goose::{
|
|||||||
permission::{Permission, PermissionConfirmation},
|
permission::{Permission, PermissionConfirmation},
|
||||||
session,
|
session,
|
||||||
};
|
};
|
||||||
use mcp_core::{role::Role, Content, ToolResult};
|
use mcp_core::{protocol::JsonRpcMessage, role::Role, Content, ToolResult};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
@@ -79,9 +79,19 @@ impl IntoResponse for SseResponse {
|
|||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
enum MessageEvent {
|
enum MessageEvent {
|
||||||
Message { message: Message },
|
Message {
|
||||||
Error { error: String },
|
message: Message,
|
||||||
Finish { reason: String },
|
},
|
||||||
|
Error {
|
||||||
|
error: String,
|
||||||
|
},
|
||||||
|
Finish {
|
||||||
|
reason: String,
|
||||||
|
},
|
||||||
|
Notification {
|
||||||
|
request_id: String,
|
||||||
|
message: JsonRpcMessage,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stream_event(
|
async fn stream_event(
|
||||||
@@ -200,7 +210,7 @@ async fn handler(
|
|||||||
tokio::select! {
|
tokio::select! {
|
||||||
response = timeout(Duration::from_millis(500), stream.next()) => {
|
response = timeout(Duration::from_millis(500), stream.next()) => {
|
||||||
match response {
|
match response {
|
||||||
Ok(Some(Ok(message))) => {
|
Ok(Some(Ok(AgentEvent::Message(message)))) => {
|
||||||
all_messages.push(message.clone());
|
all_messages.push(message.clone());
|
||||||
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
|
if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await {
|
||||||
tracing::error!("Error sending message through channel: {}", e);
|
tracing::error!("Error sending message through channel: {}", e);
|
||||||
@@ -223,6 +233,20 @@ async fn handler(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => {
|
||||||
|
if let Err(e) = stream_event(MessageEvent::Notification{
|
||||||
|
request_id: request_id.clone(),
|
||||||
|
message: n,
|
||||||
|
}, &tx).await {
|
||||||
|
tracing::error!("Error sending message through channel: {}", e);
|
||||||
|
let _ = stream_event(
|
||||||
|
MessageEvent::Error {
|
||||||
|
error: e.to_string(),
|
||||||
|
},
|
||||||
|
&tx,
|
||||||
|
).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(Some(Err(e))) => {
|
Ok(Some(Err(e))) => {
|
||||||
tracing::error!("Error processing message: {}", e);
|
tracing::error!("Error processing message: {}", e);
|
||||||
let _ = stream_event(
|
let _ = stream_event(
|
||||||
@@ -317,7 +341,7 @@ async fn ask_handler(
|
|||||||
|
|
||||||
while let Some(response) = stream.next().await {
|
while let Some(response) = stream.next().await {
|
||||||
match response {
|
match response {
|
||||||
Ok(message) => {
|
Ok(AgentEvent::Message(message)) => {
|
||||||
if message.role == Role::Assistant {
|
if message.role == Role::Assistant {
|
||||||
for content in &message.content {
|
for content in &message.content {
|
||||||
if let MessageContent::Text(text) = content {
|
if let MessageContent::Text(text) = content {
|
||||||
@@ -328,6 +352,10 @@ async fn ask_handler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Ok(AgentEvent::McpNotification(n)) => {
|
||||||
|
// Handle notifications if needed
|
||||||
|
tracing::info!("Received notification: {:?}", n);
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("Error processing as_ai message: {}", e);
|
tracing::error!("Error processing as_ai message: {}", e);
|
||||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|||||||
@@ -71,10 +71,10 @@ aws-sdk-bedrockruntime = "1.74.0"
|
|||||||
# For GCP Vertex AI provider auth
|
# For GCP Vertex AI provider auth
|
||||||
jsonwebtoken = "9.3.1"
|
jsonwebtoken = "9.3.1"
|
||||||
|
|
||||||
# Added blake3 hashing library as a dependency
|
|
||||||
blake3 = "1.5"
|
blake3 = "1.5"
|
||||||
fs2 = "0.4.3"
|
fs2 = "0.4.3"
|
||||||
futures-util = "0.3.31"
|
futures-util = "0.3.31"
|
||||||
|
tokio-stream = "0.1.17"
|
||||||
|
|
||||||
# Vector database for tool selection
|
# Vector database for tool selection
|
||||||
lancedb = "0.13"
|
lancedb = "0.13"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use dotenv::dotenv;
|
use dotenv::dotenv;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use goose::agents::{Agent, ExtensionConfig};
|
use goose::agents::{Agent, AgentEvent, ExtensionConfig};
|
||||||
use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT};
|
use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT};
|
||||||
use goose::message::Message;
|
use goose::message::Message;
|
||||||
use goose::providers::databricks::DatabricksProvider;
|
use goose::providers::databricks::DatabricksProvider;
|
||||||
@@ -20,10 +20,11 @@ async fn main() {
|
|||||||
|
|
||||||
let config = ExtensionConfig::stdio(
|
let config = ExtensionConfig::stdio(
|
||||||
"developer",
|
"developer",
|
||||||
"./target/debug/developer",
|
"./target/debug/goose",
|
||||||
DEFAULT_EXTENSION_DESCRIPTION,
|
DEFAULT_EXTENSION_DESCRIPTION,
|
||||||
DEFAULT_EXTENSION_TIMEOUT,
|
DEFAULT_EXTENSION_TIMEOUT,
|
||||||
);
|
)
|
||||||
|
.with_args(vec!["mcp", "developer"]);
|
||||||
agent.add_extension(config).await.unwrap();
|
agent.add_extension(config).await.unwrap();
|
||||||
|
|
||||||
println!("Extensions:");
|
println!("Extensions:");
|
||||||
@@ -35,11 +36,8 @@ async fn main() {
|
|||||||
.with_text("can you summarize the readme.md in this dir using just a haiku?")];
|
.with_text("can you summarize the readme.md in this dir using just a haiku?")];
|
||||||
|
|
||||||
let mut stream = agent.reply(&messages, None).await.unwrap();
|
let mut stream = agent.reply(&messages, None).await.unwrap();
|
||||||
while let Some(message) = stream.next().await {
|
while let Some(Ok(AgentEvent::Message(message))) = stream.next().await {
|
||||||
println!(
|
println!("{}", serde_json::to_string_pretty(&message).unwrap());
|
||||||
"{}",
|
|
||||||
serde_json::to_string_pretty(&message.unwrap()).unwrap()
|
|
||||||
);
|
|
||||||
println!("\n");
|
println!("\n");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
use futures::TryStreamExt;
|
use futures::{FutureExt, Stream, TryStreamExt};
|
||||||
|
use futures_util::stream;
|
||||||
|
use futures_util::stream::StreamExt;
|
||||||
|
use mcp_core::protocol::JsonRpcMessage;
|
||||||
|
|
||||||
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
|
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
|
||||||
use crate::message::Message;
|
use crate::message::Message;
|
||||||
@@ -39,7 +44,7 @@ use mcp_core::{
|
|||||||
|
|
||||||
use super::platform_tools;
|
use super::platform_tools;
|
||||||
use super::router_tools;
|
use super::router_tools;
|
||||||
use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
|
use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
|
||||||
|
|
||||||
/// The main goose Agent
|
/// The main goose Agent
|
||||||
pub struct Agent {
|
pub struct Agent {
|
||||||
@@ -56,6 +61,12 @@ pub struct Agent {
|
|||||||
pub(super) router_tool_selector: Mutex<Option<Arc<Box<dyn RouterToolSelector>>>>,
|
pub(super) router_tool_selector: Mutex<Option<Arc<Box<dyn RouterToolSelector>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub enum AgentEvent {
|
||||||
|
Message(Message),
|
||||||
|
McpNotification((String, JsonRpcMessage)),
|
||||||
|
}
|
||||||
|
|
||||||
impl Agent {
|
impl Agent {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
// Create channels with buffer size 32 (adjust if needed)
|
// Create channels with buffer size 32 (adjust if needed)
|
||||||
@@ -100,6 +111,40 @@ impl Default for Agent {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub enum ToolStreamItem<T> {
|
||||||
|
Message(JsonRpcMessage),
|
||||||
|
Result(T),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type ToolStream = Pin<Box<dyn Stream<Item = ToolStreamItem<ToolResult<Vec<Content>>>> + Send>>;
|
||||||
|
|
||||||
|
// tool_stream combines a stream of JsonRpcMessages with a future representing the
|
||||||
|
// final result of the tool call. MCP notifications are not request-scoped, but
|
||||||
|
// this lets us capture all notifications emitted during the tool call for
|
||||||
|
// simpler consumption
|
||||||
|
pub fn tool_stream<S, F>(rx: S, done: F) -> ToolStream
|
||||||
|
where
|
||||||
|
S: Stream<Item = JsonRpcMessage> + Send + Unpin + 'static,
|
||||||
|
F: Future<Output = ToolResult<Vec<Content>>> + Send + 'static,
|
||||||
|
{
|
||||||
|
Box::pin(async_stream::stream! {
|
||||||
|
tokio::pin!(done);
|
||||||
|
let mut rx = rx;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
Some(msg) = rx.next() => {
|
||||||
|
yield ToolStreamItem::Message(msg);
|
||||||
|
}
|
||||||
|
r = &mut done => {
|
||||||
|
yield ToolStreamItem::Result(r);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
impl Agent {
|
impl Agent {
|
||||||
/// Get a reference count clone to the provider
|
/// Get a reference count clone to the provider
|
||||||
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
|
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
|
||||||
@@ -143,7 +188,7 @@ impl Agent {
|
|||||||
&self,
|
&self,
|
||||||
tool_call: mcp_core::tool::ToolCall,
|
tool_call: mcp_core::tool::ToolCall,
|
||||||
request_id: String,
|
request_id: String,
|
||||||
) -> (String, Result<Vec<Content>, ToolError>) {
|
) -> (String, Result<ToolCallResult, ToolError>) {
|
||||||
// Check if this tool call should be allowed based on repetition monitoring
|
// Check if this tool call should be allowed based on repetition monitoring
|
||||||
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
|
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
|
||||||
let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone());
|
let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone());
|
||||||
@@ -171,52 +216,65 @@ impl Agent {
|
|||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or("")
|
.unwrap_or("")
|
||||||
.to_string();
|
.to_string();
|
||||||
return self
|
let (request_id, result) = self
|
||||||
.manage_extensions(action, extension_name, request_id)
|
.manage_extensions(action, extension_name, request_id)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
return (request_id, Ok(ToolCallResult::from(result)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let extension_manager = self.extension_manager.lock().await;
|
let extension_manager = self.extension_manager.lock().await;
|
||||||
let result = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
|
let result: ToolCallResult = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
|
||||||
// Check if the tool is read_resource and handle it separately
|
// Check if the tool is read_resource and handle it separately
|
||||||
extension_manager
|
ToolCallResult::from(
|
||||||
.read_resource(tool_call.arguments.clone())
|
extension_manager
|
||||||
.await
|
.read_resource(tool_call.arguments.clone())
|
||||||
|
.await,
|
||||||
|
)
|
||||||
} else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME {
|
} else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME {
|
||||||
extension_manager
|
ToolCallResult::from(
|
||||||
.list_resources(tool_call.arguments.clone())
|
extension_manager
|
||||||
.await
|
.list_resources(tool_call.arguments.clone())
|
||||||
|
.await,
|
||||||
|
)
|
||||||
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
|
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
|
||||||
extension_manager.search_available_extensions().await
|
ToolCallResult::from(extension_manager.search_available_extensions().await)
|
||||||
} else if self.is_frontend_tool(&tool_call.name).await {
|
} else if self.is_frontend_tool(&tool_call.name).await {
|
||||||
// For frontend tools, return an error indicating we need frontend execution
|
// For frontend tools, return an error indicating we need frontend execution
|
||||||
Err(ToolError::ExecutionError(
|
ToolCallResult::from(Err(ToolError::ExecutionError(
|
||||||
"Frontend tool execution required".to_string(),
|
"Frontend tool execution required".to_string(),
|
||||||
))
|
)))
|
||||||
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
|
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
|
||||||
let selector = self.router_tool_selector.lock().await.clone();
|
let selector = self.router_tool_selector.lock().await.clone();
|
||||||
if let Some(selector) = selector {
|
ToolCallResult::from(if let Some(selector) = selector {
|
||||||
selector.select_tools(tool_call.arguments.clone()).await
|
selector.select_tools(tool_call.arguments.clone()).await
|
||||||
} else {
|
} else {
|
||||||
Err(ToolError::ExecutionError(
|
Err(ToolError::ExecutionError(
|
||||||
"Encountered vector search error.".to_string(),
|
"Encountered vector search error.".to_string(),
|
||||||
))
|
))
|
||||||
}
|
})
|
||||||
} else {
|
} else {
|
||||||
extension_manager
|
// Clone the result to ensure no references to extension_manager are returned
|
||||||
|
let result = extension_manager
|
||||||
.dispatch_tool_call(tool_call.clone())
|
.dispatch_tool_call(tool_call.clone())
|
||||||
.await
|
.await;
|
||||||
|
match result {
|
||||||
|
Ok(call_result) => call_result,
|
||||||
|
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
debug!(
|
(
|
||||||
"input" = serde_json::to_string(&tool_call).unwrap(),
|
request_id,
|
||||||
"output" = serde_json::to_string(&result).unwrap(),
|
Ok(ToolCallResult {
|
||||||
);
|
notification_stream: result.notification_stream,
|
||||||
|
result: Box::new(
|
||||||
// Process the response to handle large text content
|
result
|
||||||
let processed_result = super::large_response_handler::process_tool_response(result);
|
.result
|
||||||
|
.map(super::large_response_handler::process_tool_response),
|
||||||
(request_id, processed_result)
|
),
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) async fn manage_extensions(
|
pub(super) async fn manage_extensions(
|
||||||
@@ -466,7 +524,7 @@ impl Agent {
|
|||||||
&self,
|
&self,
|
||||||
messages: &[Message],
|
messages: &[Message],
|
||||||
session: Option<SessionConfig>,
|
session: Option<SessionConfig>,
|
||||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
|
) -> anyhow::Result<BoxStream<'_, anyhow::Result<AgentEvent>>> {
|
||||||
let mut messages = messages.to_vec();
|
let mut messages = messages.to_vec();
|
||||||
let reply_span = tracing::Span::current();
|
let reply_span = tracing::Span::current();
|
||||||
|
|
||||||
@@ -532,9 +590,8 @@ impl Agent {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Yield the assistant's response with frontend tool requests filtered out
|
// Yield the assistant's response with frontend tool requests filtered out
|
||||||
yield filtered_response.clone();
|
yield AgentEvent::Message(filtered_response.clone());
|
||||||
|
|
||||||
tokio::task::yield_now().await;
|
tokio::task::yield_now().await;
|
||||||
|
|
||||||
@@ -556,7 +613,7 @@ impl Agent {
|
|||||||
// execution is yeield back to this reply loop, and is of the same Message
|
// execution is yeield back to this reply loop, and is of the same Message
|
||||||
// type, so we can yield that back up to be handled
|
// type, so we can yield that back up to be handled
|
||||||
while let Some(msg) = frontend_tool_stream.try_next().await? {
|
while let Some(msg) = frontend_tool_stream.try_next().await? {
|
||||||
yield msg;
|
yield AgentEvent::Message(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clone goose_mode once before the match to avoid move issues
|
// Clone goose_mode once before the match to avoid move issues
|
||||||
@@ -584,13 +641,23 @@ impl Agent {
|
|||||||
self.provider().await?).await;
|
self.provider().await?).await;
|
||||||
|
|
||||||
// Handle pre-approved and read-only tools in parallel
|
// Handle pre-approved and read-only tools in parallel
|
||||||
let mut tool_futures: Vec<ToolFuture> = Vec::new();
|
let mut tool_futures: Vec<(String, ToolStream)> = Vec::new();
|
||||||
|
|
||||||
// Skip the confirmation for approved tools
|
// Skip the confirmation for approved tools
|
||||||
for request in &permission_check_result.approved {
|
for request in &permission_check_result.approved {
|
||||||
if let Ok(tool_call) = request.tool_call.clone() {
|
if let Ok(tool_call) = request.tool_call.clone() {
|
||||||
let tool_future = self.dispatch_tool_call(tool_call, request.id.clone());
|
let (req_id, tool_result) = self.dispatch_tool_call(tool_call, request.id.clone()).await;
|
||||||
tool_futures.push(Box::pin(tool_future));
|
|
||||||
|
tool_futures.push((req_id, match tool_result {
|
||||||
|
Ok(result) => tool_stream(
|
||||||
|
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
|
||||||
|
result.result,
|
||||||
|
),
|
||||||
|
Err(e) => tool_stream(
|
||||||
|
Box::new(stream::empty()),
|
||||||
|
futures::future::ready(Err(e)),
|
||||||
|
),
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -618,7 +685,7 @@ impl Agent {
|
|||||||
// type, so we can yield the Message back up to be handled and grab any
|
// type, so we can yield the Message back up to be handled and grab any
|
||||||
// confirmations or denials
|
// confirmations or denials
|
||||||
while let Some(msg) = tool_approval_stream.try_next().await? {
|
while let Some(msg) = tool_approval_stream.try_next().await? {
|
||||||
yield msg;
|
yield AgentEvent::Message(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
tool_futures = {
|
tool_futures = {
|
||||||
@@ -628,16 +695,30 @@ impl Agent {
|
|||||||
futures_lock.drain(..).collect::<Vec<_>>()
|
futures_lock.drain(..).collect::<Vec<_>>()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Wait for all tool calls to complete
|
let with_id = tool_futures
|
||||||
let results = futures::future::join_all(tool_futures).await;
|
.into_iter()
|
||||||
|
.map(|(request_id, stream)| {
|
||||||
|
stream.map(move |item| (request_id.clone(), item))
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let mut combined = stream::select_all(with_id);
|
||||||
|
|
||||||
let mut all_install_successful = true;
|
let mut all_install_successful = true;
|
||||||
|
|
||||||
for (request_id, output) in results.into_iter() {
|
while let Some((request_id, item)) = combined.next().await {
|
||||||
if enable_extension_request_ids.contains(&request_id) && output.is_err(){
|
match item {
|
||||||
all_install_successful = false;
|
ToolStreamItem::Result(output) => {
|
||||||
|
if enable_extension_request_ids.contains(&request_id) && output.is_err(){
|
||||||
|
all_install_successful = false;
|
||||||
|
}
|
||||||
|
let mut response = message_tool_response.lock().await;
|
||||||
|
*response = response.clone().with_tool_response(request_id, output);
|
||||||
|
},
|
||||||
|
ToolStreamItem::Message(msg) => {
|
||||||
|
yield AgentEvent::McpNotification((request_id, msg))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let mut response = message_tool_response.lock().await;
|
|
||||||
*response = response.clone().with_tool_response(request_id, output);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update system prompt and tools if installations were successful
|
// Update system prompt and tools if installations were successful
|
||||||
@@ -647,7 +728,7 @@ impl Agent {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let final_message_tool_resp = message_tool_response.lock().await.clone();
|
let final_message_tool_resp = message_tool_response.lock().await.clone();
|
||||||
yield final_message_tool_resp.clone();
|
yield AgentEvent::Message(final_message_tool_resp.clone());
|
||||||
|
|
||||||
messages.push(response);
|
messages.push(response);
|
||||||
messages.push(final_message_tool_resp);
|
messages.push(final_message_tool_resp);
|
||||||
@@ -656,15 +737,15 @@ impl Agent {
|
|||||||
// At this point, the last message should be a user message
|
// At this point, the last message should be a user message
|
||||||
// because call to provider led to context length exceeded error
|
// because call to provider led to context length exceeded error
|
||||||
// Immediately yield a special message and break
|
// Immediately yield a special message and break
|
||||||
yield Message::assistant().with_context_length_exceeded(
|
yield AgentEvent::Message(Message::assistant().with_context_length_exceeded(
|
||||||
"The context length of the model has been exceeded. Please start a new session and try again.",
|
"The context length of the model has been exceeded. Please start a new session and try again.",
|
||||||
);
|
));
|
||||||
break;
|
break;
|
||||||
},
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Create an error message & terminate the stream
|
// Create an error message & terminate the stream
|
||||||
error!("Error: {}", e);
|
error!("Error: {}", e);
|
||||||
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
|
yield AgentEvent::Message(Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error.")));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use chrono::{DateTime, TimeZone, Utc};
|
use chrono::{DateTime, TimeZone, Utc};
|
||||||
use futures::future;
|
|
||||||
use futures::stream::{FuturesUnordered, StreamExt};
|
use futures::stream::{FuturesUnordered, StreamExt};
|
||||||
use mcp_client::McpService;
|
use futures::{future, FutureExt};
|
||||||
use mcp_core::protocol::GetPromptResult;
|
use mcp_core::protocol::GetPromptResult;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -10,15 +9,17 @@ use std::sync::LazyLock;
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio::task;
|
use tokio::task;
|
||||||
use tracing::{debug, error, warn};
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
use tracing::{error, warn};
|
||||||
|
|
||||||
use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
|
use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
|
||||||
|
use super::tool_execution::ToolCallResult;
|
||||||
use crate::agents::extension::Envs;
|
use crate::agents::extension::Envs;
|
||||||
use crate::config::{Config, ExtensionConfigManager};
|
use crate::config::{Config, ExtensionConfigManager};
|
||||||
use crate::prompt_template;
|
use crate::prompt_template;
|
||||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||||
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
|
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
|
||||||
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
|
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
|
// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
|
||||||
@@ -113,7 +114,8 @@ impl ExtensionManager {
|
|||||||
/// Add a new MCP extension based on the provided client type
|
/// Add a new MCP extension based on the provided client type
|
||||||
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
|
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
|
||||||
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
|
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
|
||||||
let sanitized_name = normalize(config.key().to_string());
|
let config_name = config.key().to_string();
|
||||||
|
let sanitized_name = normalize(config_name.clone());
|
||||||
|
|
||||||
/// Helper function to merge environment variables from direct envs and keychain-stored env_keys
|
/// Helper function to merge environment variables from direct envs and keychain-stored env_keys
|
||||||
async fn merge_environments(
|
async fn merge_environments(
|
||||||
@@ -183,13 +185,15 @@ impl ExtensionManager {
|
|||||||
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
||||||
let transport = SseTransport::new(uri, all_envs);
|
let transport = SseTransport::new(uri, all_envs);
|
||||||
let handle = transport.start().await?;
|
let handle = transport.start().await?;
|
||||||
let service = McpService::with_timeout(
|
Box::new(
|
||||||
handle,
|
McpClient::connect(
|
||||||
Duration::from_secs(
|
handle,
|
||||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
Duration::from_secs(
|
||||||
),
|
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||||
);
|
),
|
||||||
Box::new(McpClient::new(service))
|
)
|
||||||
|
.await?,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
ExtensionConfig::Stdio {
|
ExtensionConfig::Stdio {
|
||||||
cmd,
|
cmd,
|
||||||
@@ -202,13 +206,15 @@ impl ExtensionManager {
|
|||||||
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
||||||
let transport = StdioTransport::new(cmd, args.to_vec(), all_envs);
|
let transport = StdioTransport::new(cmd, args.to_vec(), all_envs);
|
||||||
let handle = transport.start().await?;
|
let handle = transport.start().await?;
|
||||||
let service = McpService::with_timeout(
|
Box::new(
|
||||||
handle,
|
McpClient::connect(
|
||||||
Duration::from_secs(
|
handle,
|
||||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
Duration::from_secs(
|
||||||
),
|
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||||
);
|
),
|
||||||
Box::new(McpClient::new(service))
|
)
|
||||||
|
.await?,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
ExtensionConfig::Builtin {
|
ExtensionConfig::Builtin {
|
||||||
name,
|
name,
|
||||||
@@ -227,13 +233,15 @@ impl ExtensionManager {
|
|||||||
HashMap::new(),
|
HashMap::new(),
|
||||||
);
|
);
|
||||||
let handle = transport.start().await?;
|
let handle = transport.start().await?;
|
||||||
let service = McpService::with_timeout(
|
Box::new(
|
||||||
handle,
|
McpClient::connect(
|
||||||
Duration::from_secs(
|
handle,
|
||||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
Duration::from_secs(
|
||||||
),
|
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||||
);
|
),
|
||||||
Box::new(McpClient::new(service))
|
)
|
||||||
|
.await?,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
@@ -609,7 +617,7 @@ impl ExtensionManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> ToolResult<Vec<Content>> {
|
pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> Result<ToolCallResult> {
|
||||||
// Dispatch tool call based on the prefix naming convention
|
// Dispatch tool call based on the prefix naming convention
|
||||||
let (client_name, client) = self
|
let (client_name, client) = self
|
||||||
.get_client_for_tool(&tool_call.name)
|
.get_client_for_tool(&tool_call.name)
|
||||||
@@ -620,22 +628,26 @@ impl ExtensionManager {
|
|||||||
.name
|
.name
|
||||||
.strip_prefix(client_name)
|
.strip_prefix(client_name)
|
||||||
.and_then(|s| s.strip_prefix("__"))
|
.and_then(|s| s.strip_prefix("__"))
|
||||||
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
|
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
let client_guard = client.lock().await;
|
let arguments = tool_call.arguments.clone();
|
||||||
|
let client = client.clone();
|
||||||
|
let notifications_receiver = client.lock().await.subscribe().await;
|
||||||
|
|
||||||
let result = client_guard
|
let fut = async move {
|
||||||
.call_tool(tool_name, tool_call.clone().arguments)
|
let client_guard = client.lock().await;
|
||||||
.await
|
client_guard
|
||||||
.map(|result| result.content)
|
.call_tool(&tool_name, arguments)
|
||||||
.map_err(|e| ToolError::ExecutionError(e.to_string()));
|
.await
|
||||||
|
.map(|call| call.content)
|
||||||
|
.map_err(|e| ToolError::ExecutionError(e.to_string()))
|
||||||
|
};
|
||||||
|
|
||||||
debug!(
|
Ok(ToolCallResult {
|
||||||
"input" = serde_json::to_string(&tool_call).unwrap(),
|
result: Box::new(fut.boxed()),
|
||||||
"output" = serde_json::to_string(&result).unwrap(),
|
notification_stream: Some(Box::new(ReceiverStream::new(notifications_receiver))),
|
||||||
);
|
})
|
||||||
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn list_prompts_from_extension(
|
pub async fn list_prompts_from_extension(
|
||||||
@@ -793,10 +805,11 @@ mod tests {
|
|||||||
use mcp_client::client::Error;
|
use mcp_client::client::Error;
|
||||||
use mcp_client::client::McpClientTrait;
|
use mcp_client::client::McpClientTrait;
|
||||||
use mcp_core::protocol::{
|
use mcp_core::protocol::{
|
||||||
CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult,
|
CallToolResult, GetPromptResult, InitializeResult, JsonRpcMessage, ListPromptsResult,
|
||||||
ListToolsResult, ReadResourceResult,
|
ListResourcesResult, ListToolsResult, ReadResourceResult,
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
struct MockClient {}
|
struct MockClient {}
|
||||||
|
|
||||||
@@ -849,6 +862,10 @@ mod tests {
|
|||||||
) -> Result<GetPromptResult, Error> {
|
) -> Result<GetPromptResult, Error> {
|
||||||
Err(Error::NotInitialized)
|
Err(Error::NotInitialized)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
|
||||||
|
mpsc::channel(1).1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -970,6 +987,9 @@ mod tests {
|
|||||||
|
|
||||||
let result = extension_manager
|
let result = extension_manager
|
||||||
.dispatch_tool_call(invalid_tool_call)
|
.dispatch_tool_call(invalid_tool_call)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.result
|
||||||
.await;
|
.await;
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
result.err().unwrap(),
|
result.err().unwrap(),
|
||||||
@@ -986,6 +1006,11 @@ mod tests {
|
|||||||
let result = extension_manager
|
let result = extension_manager
|
||||||
.dispatch_tool_call(invalid_tool_call)
|
.dispatch_tool_call(invalid_tool_call)
|
||||||
.await;
|
.await;
|
||||||
assert!(matches!(result.err().unwrap(), ToolError::NotFound(_)));
|
if let Err(err) = result {
|
||||||
|
let tool_err = err.downcast_ref::<ToolError>().expect("Expected ToolError");
|
||||||
|
assert!(matches!(tool_err, ToolError::NotFound(_)));
|
||||||
|
} else {
|
||||||
|
panic!("Expected ToolError::NotFound");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ mod tool_router_index_manager;
|
|||||||
pub(crate) mod tool_vectordb;
|
pub(crate) mod tool_vectordb;
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
pub use agent::Agent;
|
pub use agent::{Agent, AgentEvent};
|
||||||
pub use extension::ExtensionConfig;
|
pub use extension::ExtensionConfig;
|
||||||
pub use extension_manager::ExtensionManager;
|
pub use extension_manager::ExtensionManager;
|
||||||
pub use prompt_manager::PromptManager;
|
pub use prompt_manager::PromptManager;
|
||||||
|
|||||||
@@ -1,23 +1,35 @@
|
|||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::pin::Pin;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_stream::try_stream;
|
use async_stream::try_stream;
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::{self, BoxStream};
|
||||||
use futures::StreamExt;
|
use futures::{Stream, StreamExt};
|
||||||
|
use mcp_core::protocol::JsonRpcMessage;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
use crate::config::permission::PermissionLevel;
|
use crate::config::permission::PermissionLevel;
|
||||||
use crate::config::PermissionManager;
|
use crate::config::PermissionManager;
|
||||||
use crate::message::{Message, ToolRequest};
|
use crate::message::{Message, ToolRequest};
|
||||||
use crate::permission::Permission;
|
use crate::permission::Permission;
|
||||||
use mcp_core::{Content, ToolError};
|
use mcp_core::{Content, ToolResult};
|
||||||
|
|
||||||
// Type alias for ToolFutures - used in the agent loop to join all futures together
|
// ToolCallResult combines the result of a tool call with an optional notification stream that
|
||||||
pub(crate) type ToolFuture<'a> =
|
// can be used to receive notifications from the tool.
|
||||||
Pin<Box<dyn Future<Output = (String, Result<Vec<Content>, ToolError>)> + Send + 'a>>;
|
pub struct ToolCallResult {
|
||||||
pub(crate) type ToolFuturesVec<'a> = Arc<Mutex<Vec<ToolFuture<'a>>>>;
|
pub result: Box<dyn Future<Output = ToolResult<Vec<Content>>> + Send + Unpin>,
|
||||||
|
pub notification_stream: Option<Box<dyn Stream<Item = JsonRpcMessage> + Send + Unpin>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ToolResult<Vec<Content>>> for ToolCallResult {
|
||||||
|
fn from(result: ToolResult<Vec<Content>>) -> Self {
|
||||||
|
Self {
|
||||||
|
result: Box::new(futures::future::ready(result)),
|
||||||
|
notification_stream: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use super::agent::{tool_stream, ToolStream};
|
||||||
use crate::agents::Agent;
|
use crate::agents::Agent;
|
||||||
|
|
||||||
pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \
|
pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \
|
||||||
@@ -37,7 +49,7 @@ impl Agent {
|
|||||||
pub(crate) fn handle_approval_tool_requests<'a>(
|
pub(crate) fn handle_approval_tool_requests<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
tool_requests: &'a [ToolRequest],
|
tool_requests: &'a [ToolRequest],
|
||||||
tool_futures: ToolFuturesVec<'a>,
|
tool_futures: Arc<Mutex<Vec<(String, ToolStream)>>>,
|
||||||
permission_manager: &'a mut PermissionManager,
|
permission_manager: &'a mut PermissionManager,
|
||||||
message_tool_response: Arc<Mutex<Message>>,
|
message_tool_response: Arc<Mutex<Message>>,
|
||||||
) -> BoxStream<'a, anyhow::Result<Message>> {
|
) -> BoxStream<'a, anyhow::Result<Message>> {
|
||||||
@@ -56,9 +68,19 @@ impl Agent {
|
|||||||
while let Some((req_id, confirmation)) = rx.recv().await {
|
while let Some((req_id, confirmation)) = rx.recv().await {
|
||||||
if req_id == request.id {
|
if req_id == request.id {
|
||||||
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
|
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
|
||||||
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone());
|
let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone()).await;
|
||||||
let mut futures = tool_futures.lock().await;
|
let mut futures = tool_futures.lock().await;
|
||||||
futures.push(Box::pin(tool_future));
|
|
||||||
|
futures.push((req_id, match tool_result {
|
||||||
|
Ok(result) => tool_stream(
|
||||||
|
result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
|
||||||
|
result.result,
|
||||||
|
),
|
||||||
|
Err(e) => tool_stream(
|
||||||
|
Box::new(stream::empty()),
|
||||||
|
futures::future::ready(Err(e)),
|
||||||
|
),
|
||||||
|
}));
|
||||||
|
|
||||||
if confirmation.permission == Permission::AlwaysAllow {
|
if confirmation.permission == Permission::AlwaysAllow {
|
||||||
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);
|
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio_cron_scheduler::{job::JobId, Job, JobScheduler as TokioJobScheduler};
|
use tokio_cron_scheduler::{job::JobId, Job, JobScheduler as TokioJobScheduler};
|
||||||
|
|
||||||
|
use crate::agents::AgentEvent;
|
||||||
use crate::agents::{Agent, SessionConfig};
|
use crate::agents::{Agent, SessionConfig};
|
||||||
use crate::config::{self, Config};
|
use crate::config::{self, Config};
|
||||||
use crate::message::Message;
|
use crate::message::Message;
|
||||||
@@ -1102,12 +1103,15 @@ async fn run_scheduled_job_internal(
|
|||||||
tokio::task::yield_now().await;
|
tokio::task::yield_now().await;
|
||||||
|
|
||||||
match message_result {
|
match message_result {
|
||||||
Ok(msg) => {
|
Ok(AgentEvent::Message(msg)) => {
|
||||||
if msg.role == mcp_core::role::Role::Assistant {
|
if msg.role == mcp_core::role::Role::Assistant {
|
||||||
tracing::info!("[Job {}] Assistant: {:?}", job.id, msg.content);
|
tracing::info!("[Job {}] Assistant: {:?}", job.id, msg.content);
|
||||||
}
|
}
|
||||||
all_session_messages.push(msg);
|
all_session_messages.push(msg);
|
||||||
}
|
}
|
||||||
|
Ok(AgentEvent::McpNotification(_)) => {
|
||||||
|
// Handle notifications if needed
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
"[Job {}] Error receiving message from agent: {}",
|
"[Job {}] Error receiving message from agent: {}",
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use goose::agents::Agent;
|
use goose::agents::{Agent, AgentEvent};
|
||||||
use goose::message::Message;
|
use goose::message::Message;
|
||||||
use goose::model::ModelConfig;
|
use goose::model::ModelConfig;
|
||||||
use goose::providers::base::Provider;
|
use goose::providers::base::Provider;
|
||||||
@@ -132,7 +132,10 @@ async fn run_truncate_test(
|
|||||||
let mut responses = Vec::new();
|
let mut responses = Vec::new();
|
||||||
while let Some(response_result) = reply_stream.next().await {
|
while let Some(response_result) = reply_stream.next().await {
|
||||||
match response_result {
|
match response_result {
|
||||||
Ok(response) => responses.push(response),
|
Ok(AgentEvent::Message(response)) => responses.push(response),
|
||||||
|
Ok(AgentEvent::McpNotification(n)) => {
|
||||||
|
println!("MCP Notification: {n:?}");
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
println!("Error: {:?}", e);
|
println!("Error: {:?}", e);
|
||||||
return Err(e);
|
return Err(e);
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
use mcp_client::{
|
use mcp_client::{
|
||||||
client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait},
|
client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait},
|
||||||
transport::{SseTransport, StdioTransport, Transport},
|
transport::{SseTransport, StdioTransport, Transport},
|
||||||
McpService,
|
|
||||||
};
|
};
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use rand::SeedableRng;
|
use rand::SeedableRng;
|
||||||
@@ -20,18 +19,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
|
let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
|
||||||
let handle1 = transport1.start().await?;
|
let handle1 = transport1.start().await?;
|
||||||
let service1 = McpService::with_timeout(handle1, Duration::from_secs(30));
|
let client1 = McpClient::connect(handle1, Duration::from_secs(30)).await?;
|
||||||
let client1 = McpClient::new(service1);
|
|
||||||
|
|
||||||
let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
|
let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
|
||||||
let handle2 = transport2.start().await?;
|
let handle2 = transport2.start().await?;
|
||||||
let service2 = McpService::with_timeout(handle2, Duration::from_secs(30));
|
let client2 = McpClient::connect(handle2, Duration::from_secs(30)).await?;
|
||||||
let client2 = McpClient::new(service2);
|
|
||||||
|
|
||||||
let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new());
|
let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new());
|
||||||
let handle3 = transport3.start().await?;
|
let handle3 = transport3.start().await?;
|
||||||
let service3 = McpService::with_timeout(handle3, Duration::from_secs(10));
|
let client3 = McpClient::connect(handle3, Duration::from_secs(10)).await?;
|
||||||
let client3 = McpClient::new(service3);
|
|
||||||
|
|
||||||
// Initialize both clients
|
// Initialize both clients
|
||||||
let mut clients: Vec<Box<dyn McpClientTrait>> =
|
let mut clients: Vec<Box<dyn McpClientTrait>> =
|
||||||
|
|||||||
122
crates/mcp-client/examples/integration_test.rs
Normal file
122
crates/mcp-client/examples/integration_test.rs
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use futures::lock::Mutex;
|
||||||
|
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||||
|
use mcp_client::transport::{SseTransport, Transport};
|
||||||
|
use mcp_client::StdioTransport;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
// Initialize logging
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(
|
||||||
|
EnvFilter::from_default_env()
|
||||||
|
.add_directive("mcp_client=debug".parse().unwrap())
|
||||||
|
.add_directive("eventsource_client=info".parse().unwrap()),
|
||||||
|
)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
test_transport(sse_transport().await?).await?;
|
||||||
|
test_transport(stdio_transport().await?).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn sse_transport() -> Result<SseTransport> {
|
||||||
|
let port = "60053";
|
||||||
|
|
||||||
|
tokio::process::Command::new("npx")
|
||||||
|
.env("PORT", port)
|
||||||
|
.arg("@modelcontextprotocol/server-everything")
|
||||||
|
.arg("sse")
|
||||||
|
.spawn()?;
|
||||||
|
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||||
|
|
||||||
|
Ok(SseTransport::new(
|
||||||
|
format!("http://localhost:{}/sse", port),
|
||||||
|
HashMap::new(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stdio_transport() -> Result<StdioTransport> {
|
||||||
|
Ok(StdioTransport::new(
|
||||||
|
"npx",
|
||||||
|
vec!["@modelcontextprotocol/server-everything"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect(),
|
||||||
|
HashMap::new(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn test_transport<T>(transport: T) -> Result<()>
|
||||||
|
where
|
||||||
|
T: Transport + Send + 'static,
|
||||||
|
{
|
||||||
|
// Start transport
|
||||||
|
let handle = transport.start().await?;
|
||||||
|
|
||||||
|
// Create client
|
||||||
|
let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?;
|
||||||
|
println!("Client created\n");
|
||||||
|
|
||||||
|
let mut receiver = client.subscribe().await;
|
||||||
|
let events = Arc::new(Mutex::new(Vec::new()));
|
||||||
|
let events_clone = events.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Some(event) = receiver.recv().await {
|
||||||
|
println!("Received event: {event:?}");
|
||||||
|
events_clone.lock().await.push(event);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
let server_info = client
|
||||||
|
.initialize(
|
||||||
|
ClientInfo {
|
||||||
|
name: "test-client".into(),
|
||||||
|
version: "1.0.0".into(),
|
||||||
|
},
|
||||||
|
ClientCapabilities::default(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
println!("Connected to server: {server_info:?}\n");
|
||||||
|
|
||||||
|
// Sleep for 100ms to allow the server to start - surprisingly this is required!
|
||||||
|
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||||
|
|
||||||
|
// List tools
|
||||||
|
let tools = client.list_tools(None).await?;
|
||||||
|
println!("Available tools: {tools:#?}\n");
|
||||||
|
|
||||||
|
// Call tool
|
||||||
|
let tool_result = client
|
||||||
|
.call_tool("echo", serde_json::json!({ "message": "honk" }))
|
||||||
|
.await?;
|
||||||
|
println!("Tool result: {tool_result:#?}\n");
|
||||||
|
|
||||||
|
let collected_eventes_before = events.lock().await.len();
|
||||||
|
let n_steps = 5;
|
||||||
|
let long_op = client
|
||||||
|
.call_tool(
|
||||||
|
"longRunningOperation",
|
||||||
|
serde_json::json!({ "duration": 3, "steps": n_steps }),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
println!("Long op result: {long_op:#?}\n");
|
||||||
|
let collected_events_after = events.lock().await.len();
|
||||||
|
assert_eq!(collected_events_after - collected_eventes_before, n_steps);
|
||||||
|
|
||||||
|
// List resources
|
||||||
|
let resources = client.list_resources(None).await?;
|
||||||
|
println!("Resources: {resources:#?}\n");
|
||||||
|
|
||||||
|
// Read resource
|
||||||
|
let resource = client.read_resource("test://static/resource/1").await?;
|
||||||
|
println!("Resource: {resource:#?}\n");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||||
use mcp_client::transport::{SseTransport, Transport};
|
use mcp_client::transport::{SseTransport, Transport};
|
||||||
use mcp_client::McpService;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
@@ -23,11 +22,8 @@ async fn main() -> Result<()> {
|
|||||||
// Start transport
|
// Start transport
|
||||||
let handle = transport.start().await?;
|
let handle = transport.start().await?;
|
||||||
|
|
||||||
// Create the service with timeout middleware
|
|
||||||
let service = McpService::with_timeout(handle, Duration::from_secs(3));
|
|
||||||
|
|
||||||
// Create client
|
// Create client
|
||||||
let mut client = McpClient::new(service);
|
let mut client = McpClient::connect(handle, Duration::from_secs(3)).await?;
|
||||||
println!("Client created\n");
|
println!("Client created\n");
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use mcp_client::{
|
use mcp_client::{
|
||||||
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, McpService,
|
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
|
||||||
StdioTransport, Transport,
|
StdioTransport, Transport,
|
||||||
};
|
};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
@@ -25,11 +25,8 @@ async fn main() -> Result<(), ClientError> {
|
|||||||
// 2) Start the transport to get a handle
|
// 2) Start the transport to get a handle
|
||||||
let transport_handle = transport.start().await?;
|
let transport_handle = transport.start().await?;
|
||||||
|
|
||||||
// 3) Create the service with timeout middleware
|
// 3) Create the client with the middleware-wrapped service
|
||||||
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));
|
let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?;
|
||||||
|
|
||||||
// 4) Create the client with the middleware-wrapped service
|
|
||||||
let mut client = McpClient::new(service);
|
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
let server_info = client
|
let server_info = client
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ use mcp_client::client::{
|
|||||||
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
|
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
|
||||||
};
|
};
|
||||||
use mcp_client::transport::{StdioTransport, Transport};
|
use mcp_client::transport::{StdioTransport, Transport};
|
||||||
use mcp_client::McpService;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
@@ -34,11 +33,8 @@ async fn main() -> Result<(), ClientError> {
|
|||||||
// Start the transport to get a handle
|
// Start the transport to get a handle
|
||||||
let transport_handle = transport.start().await.unwrap();
|
let transport_handle = transport.start().await.unwrap();
|
||||||
|
|
||||||
// Create the service with timeout middleware
|
|
||||||
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));
|
|
||||||
|
|
||||||
// Create client
|
// Create client
|
||||||
let mut client = McpClient::new(service);
|
let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?;
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
let server_info = client
|
let server_info = client
|
||||||
|
|||||||
@@ -4,11 +4,16 @@ use mcp_core::protocol::{
|
|||||||
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
|
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::{json, Value};
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::{
|
||||||
|
atomic::{AtomicU64, Ordering},
|
||||||
|
Arc,
|
||||||
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::{mpsc, Mutex};
|
||||||
use tower::{Service, ServiceExt}; // for Service::ready()
|
use tower::{timeout::TimeoutLayer, Layer, Service, ServiceExt};
|
||||||
|
|
||||||
|
use crate::{McpService, TransportHandle};
|
||||||
|
|
||||||
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
|
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
|
||||||
|
|
||||||
@@ -97,34 +102,67 @@ pub trait McpClientTrait: Send + Sync {
|
|||||||
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
|
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
|
||||||
|
|
||||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
|
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
|
||||||
|
|
||||||
|
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The MCP client is the interface for MCP operations.
|
/// The MCP client is the interface for MCP operations.
|
||||||
pub struct McpClient<S>
|
pub struct McpClient<T>
|
||||||
where
|
where
|
||||||
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
|
T: TransportHandle + Send + Sync + 'static,
|
||||||
S::Error: Into<Error>,
|
|
||||||
S::Future: Send,
|
|
||||||
{
|
{
|
||||||
service: Mutex<S>,
|
service: Mutex<tower::timeout::Timeout<McpService<T>>>,
|
||||||
next_id: AtomicU64,
|
next_id: AtomicU64,
|
||||||
server_capabilities: Option<ServerCapabilities>,
|
server_capabilities: Option<ServerCapabilities>,
|
||||||
server_info: Option<Implementation>,
|
server_info: Option<Implementation>,
|
||||||
|
notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<JsonRpcMessage>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> McpClient<S>
|
impl<T> McpClient<T>
|
||||||
where
|
where
|
||||||
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
|
T: TransportHandle + Send + Sync + 'static,
|
||||||
S::Error: Into<Error>,
|
|
||||||
S::Future: Send,
|
|
||||||
{
|
{
|
||||||
pub fn new(service: S) -> Self {
|
pub async fn connect(transport: T, timeout: std::time::Duration) -> Result<Self, Error> {
|
||||||
Self {
|
let service = McpService::new(transport.clone());
|
||||||
service: Mutex::new(service),
|
let service_ptr = service.clone();
|
||||||
|
let notification_subscribers =
|
||||||
|
Arc::new(Mutex::new(Vec::<mpsc::Sender<JsonRpcMessage>>::new()));
|
||||||
|
let subscribers_ptr = notification_subscribers.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
match transport.receive().await {
|
||||||
|
Ok(message) => {
|
||||||
|
tracing::info!("Received message: {:?}", message);
|
||||||
|
match message {
|
||||||
|
JsonRpcMessage::Response(JsonRpcResponse { id: Some(id), .. }) => {
|
||||||
|
service_ptr.respond(&id.to_string(), Ok(message)).await;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let mut subs = subscribers_ptr.lock().await;
|
||||||
|
subs.retain(|sub| sub.try_send(message.clone()).is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("transport error: {:?}", e);
|
||||||
|
service_ptr.hangup().await;
|
||||||
|
subscribers_ptr.lock().await.clear();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let middleware = TimeoutLayer::new(timeout);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
service: Mutex::new(middleware.layer(service)),
|
||||||
next_id: AtomicU64::new(1),
|
next_id: AtomicU64::new(1),
|
||||||
server_capabilities: None,
|
server_capabilities: None,
|
||||||
server_info: None,
|
server_info: None,
|
||||||
}
|
notification_subscribers,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send a JSON-RPC request and check we don't get an error response.
|
/// Send a JSON-RPC request and check we don't get an error response.
|
||||||
@@ -134,13 +172,18 @@ where
|
|||||||
{
|
{
|
||||||
let mut service = self.service.lock().await;
|
let mut service = self.service.lock().await;
|
||||||
service.ready().await.map_err(|_| Error::NotReady)?;
|
service.ready().await.map_err(|_| Error::NotReady)?;
|
||||||
|
|
||||||
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
|
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
|
||||||
|
|
||||||
|
let mut params = params.clone();
|
||||||
|
params["_meta"] = json!({
|
||||||
|
"progressToken": format!("prog-{}", id),
|
||||||
|
});
|
||||||
|
|
||||||
let request = JsonRpcMessage::Request(JsonRpcRequest {
|
let request = JsonRpcMessage::Request(JsonRpcRequest {
|
||||||
jsonrpc: "2.0".to_string(),
|
jsonrpc: "2.0".to_string(),
|
||||||
id: Some(id),
|
id: Some(id),
|
||||||
method: method.to_string(),
|
method: method.to_string(),
|
||||||
params: Some(params.clone()),
|
params: Some(params),
|
||||||
});
|
});
|
||||||
|
|
||||||
let response_msg = service
|
let response_msg = service
|
||||||
@@ -154,7 +197,7 @@ where
|
|||||||
.unwrap_or("".to_string()),
|
.unwrap_or("".to_string()),
|
||||||
method: method.to_string(),
|
method: method.to_string(),
|
||||||
// we don't need include params because it can be really large
|
// we don't need include params because it can be really large
|
||||||
source: Box::new(e.into()),
|
source: Box::<Error>::new(e.into()),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
match response_msg {
|
match response_msg {
|
||||||
@@ -220,7 +263,7 @@ where
|
|||||||
.unwrap_or("".to_string()),
|
.unwrap_or("".to_string()),
|
||||||
method: method.to_string(),
|
method: method.to_string(),
|
||||||
// we don't need include params because it can be really large
|
// we don't need include params because it can be really large
|
||||||
source: Box::new(e.into()),
|
source: Box::<Error>::new(e.into()),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -233,11 +276,9 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl<S> McpClientTrait for McpClient<S>
|
impl<T> McpClientTrait for McpClient<T>
|
||||||
where
|
where
|
||||||
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
|
T: TransportHandle + Send + Sync + 'static,
|
||||||
S::Error: Into<Error>,
|
|
||||||
S::Future: Send,
|
|
||||||
{
|
{
|
||||||
async fn initialize(
|
async fn initialize(
|
||||||
&mut self,
|
&mut self,
|
||||||
@@ -388,4 +429,10 @@ where
|
|||||||
|
|
||||||
self.send_request("prompts/get", params).await
|
self.send_request("prompts/get", params).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
|
||||||
|
let (tx, rx) = mpsc::channel(16);
|
||||||
|
self.notification_subscribers.lock().await.push(tx);
|
||||||
|
rx
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
use futures::future::BoxFuture;
|
use futures::future::BoxFuture;
|
||||||
use mcp_core::protocol::JsonRpcMessage;
|
use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
|
use tokio::sync::{oneshot, RwLock};
|
||||||
use tower::{timeout::Timeout, Service, ServiceBuilder};
|
use tower::{timeout::Timeout, Service, ServiceBuilder};
|
||||||
|
|
||||||
use crate::transport::{Error, TransportHandle};
|
use crate::transport::{Error, TransportHandle};
|
||||||
@@ -10,14 +12,24 @@ use crate::transport::{Error, TransportHandle};
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct McpService<T: TransportHandle> {
|
pub struct McpService<T: TransportHandle> {
|
||||||
inner: Arc<T>,
|
inner: Arc<T>,
|
||||||
|
pending_requests: Arc<PendingRequests>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: TransportHandle> McpService<T> {
|
impl<T: TransportHandle> McpService<T> {
|
||||||
pub fn new(transport: T) -> Self {
|
pub fn new(transport: T) -> Self {
|
||||||
Self {
|
Self {
|
||||||
inner: Arc::new(transport),
|
inner: Arc::new(transport),
|
||||||
|
pending_requests: Arc::new(PendingRequests::default()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
|
||||||
|
self.pending_requests.respond(id, response).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn hangup(&self) {
|
||||||
|
self.pending_requests.broadcast_close().await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Service<JsonRpcMessage> for McpService<T>
|
impl<T> Service<JsonRpcMessage> for McpService<T>
|
||||||
@@ -35,7 +47,31 @@ where
|
|||||||
|
|
||||||
fn call(&mut self, request: JsonRpcMessage) -> Self::Future {
|
fn call(&mut self, request: JsonRpcMessage) -> Self::Future {
|
||||||
let transport = self.inner.clone();
|
let transport = self.inner.clone();
|
||||||
Box::pin(async move { transport.send(request).await })
|
let pending_requests = self.pending_requests.clone();
|
||||||
|
|
||||||
|
Box::pin(async move {
|
||||||
|
match request {
|
||||||
|
JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) => {
|
||||||
|
// Create a channel to receive the response
|
||||||
|
let (sender, receiver) = oneshot::channel();
|
||||||
|
pending_requests.insert(id.to_string(), sender).await;
|
||||||
|
|
||||||
|
transport.send(request).await?;
|
||||||
|
receiver.await.map_err(|_| Error::ChannelClosed)?
|
||||||
|
}
|
||||||
|
JsonRpcMessage::Request(_) => {
|
||||||
|
// Handle notifications without waiting for a response
|
||||||
|
transport.send(request).await?;
|
||||||
|
Ok(JsonRpcMessage::Nil)
|
||||||
|
}
|
||||||
|
JsonRpcMessage::Notification(_) => {
|
||||||
|
// Handle notifications without waiting for a response
|
||||||
|
transport.send(request).await?;
|
||||||
|
Ok(JsonRpcMessage::Nil)
|
||||||
|
}
|
||||||
|
_ => Err(Error::UnsupportedMessage),
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,3 +86,50 @@ where
|
|||||||
.service(McpService::new(transport))
|
.service(McpService::new(transport))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A data structure to store pending requests and their response channels
|
||||||
|
pub struct PendingRequests {
|
||||||
|
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PendingRequests {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PendingRequests {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
requests: RwLock::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
|
||||||
|
self.requests.write().await.insert(id, sender);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
|
||||||
|
if let Some(tx) = self.requests.write().await.remove(id) {
|
||||||
|
let _ = tx.send(response);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn broadcast_close(&self) {
|
||||||
|
for (_, tx) in self.requests.write().await.drain() {
|
||||||
|
let _ = tx.send(Err(Error::ChannelClosed));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn clear(&self) {
|
||||||
|
self.requests.write().await.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn len(&self) -> usize {
|
||||||
|
self.requests.read().await.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn is_empty(&self) -> bool {
|
||||||
|
self.len().await == 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use mcp_core::protocol::JsonRpcMessage;
|
use mcp_core::protocol::JsonRpcMessage;
|
||||||
use std::collections::HashMap;
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::{mpsc, oneshot, RwLock};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
|
|
||||||
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
|
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
|
||||||
/// A generic error type for transport operations.
|
/// A generic error type for transport operations.
|
||||||
@@ -57,74 +56,20 @@ pub trait Transport {
|
|||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait TransportHandle: Send + Sync + Clone + 'static {
|
pub trait TransportHandle: Send + Sync + Clone + 'static {
|
||||||
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error>;
|
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error>;
|
||||||
|
async fn receive(&self) -> Result<JsonRpcMessage, Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function that contains the common send implementation
|
pub async fn serialize_and_send(
|
||||||
pub async fn send_message(
|
sender: &mpsc::Sender<String>,
|
||||||
sender: &mpsc::Sender<TransportMessage>,
|
|
||||||
message: JsonRpcMessage,
|
message: JsonRpcMessage,
|
||||||
) -> Result<JsonRpcMessage, Error> {
|
) -> Result<(), Error> {
|
||||||
match message {
|
match serde_json::to_string(&message).map_err(Error::Serialization) {
|
||||||
JsonRpcMessage::Request(request) => {
|
Ok(msg) => sender.send(msg).await.map_err(|_| Error::ChannelClosed),
|
||||||
let (respond_to, response) = oneshot::channel();
|
Err(e) => {
|
||||||
let msg = TransportMessage {
|
tracing::error!(error = ?e, "Error serializing message");
|
||||||
message: JsonRpcMessage::Request(request),
|
Err(e)
|
||||||
response_tx: Some(respond_to),
|
|
||||||
};
|
|
||||||
sender.send(msg).await.map_err(|_| Error::ChannelClosed)?;
|
|
||||||
Ok(response.await.map_err(|_| Error::ChannelClosed)??)
|
|
||||||
}
|
}
|
||||||
JsonRpcMessage::Notification(notification) => {
|
|
||||||
let msg = TransportMessage {
|
|
||||||
message: JsonRpcMessage::Notification(notification),
|
|
||||||
response_tx: None,
|
|
||||||
};
|
|
||||||
sender.send(msg).await.map_err(|_| Error::ChannelClosed)?;
|
|
||||||
Ok(JsonRpcMessage::Nil)
|
|
||||||
}
|
|
||||||
_ => Err(Error::UnsupportedMessage),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A data structure to store pending requests and their response channels
|
|
||||||
pub struct PendingRequests {
|
|
||||||
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for PendingRequests {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PendingRequests {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
requests: RwLock::new(HashMap::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
|
|
||||||
self.requests.write().await.insert(id, sender);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
|
|
||||||
if let Some(tx) = self.requests.write().await.remove(id) {
|
|
||||||
let _ = tx.send(response);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn clear(&self) {
|
|
||||||
self.requests.write().await.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn len(&self) -> usize {
|
|
||||||
self.requests.read().await.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn is_empty(&self) -> bool {
|
|
||||||
self.len().await == 0
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
use crate::transport::{Error, PendingRequests, TransportMessage};
|
use crate::transport::Error;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use eventsource_client::{Client, SSE};
|
use eventsource_client::{Client, SSE};
|
||||||
use futures::TryStreamExt;
|
use futures::TryStreamExt;
|
||||||
use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest};
|
use mcp_core::protocol::JsonRpcMessage;
|
||||||
use reqwest::Client as HttpClient;
|
use reqwest::Client as HttpClient;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{mpsc, RwLock};
|
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||||
use tokio::time::{timeout, Duration};
|
use tokio::time::{timeout, Duration};
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use super::{send_message, Transport, TransportHandle};
|
use super::{serialize_and_send, Transport, TransportHandle};
|
||||||
|
|
||||||
// Timeout for the endpoint discovery
|
// Timeout for the endpoint discovery
|
||||||
const ENDPOINT_TIMEOUT_SECS: u64 = 5;
|
const ENDPOINT_TIMEOUT_SECS: u64 = 5;
|
||||||
@@ -21,9 +21,9 @@ const ENDPOINT_TIMEOUT_SECS: u64 = 5;
|
|||||||
/// - Sends outgoing messages via HTTP POST (once the post endpoint is known).
|
/// - Sends outgoing messages via HTTP POST (once the post endpoint is known).
|
||||||
pub struct SseActor {
|
pub struct SseActor {
|
||||||
/// Receives messages (requests/notifications) from the handle
|
/// Receives messages (requests/notifications) from the handle
|
||||||
receiver: mpsc::Receiver<TransportMessage>,
|
receiver: mpsc::Receiver<String>,
|
||||||
/// Map of request-id -> oneshot sender
|
/// Sends messages (responses) back to the handle
|
||||||
pending_requests: Arc<PendingRequests>,
|
sender: mpsc::Sender<JsonRpcMessage>,
|
||||||
/// Base SSE URL
|
/// Base SSE URL
|
||||||
sse_url: String,
|
sse_url: String,
|
||||||
/// For sending HTTP POST requests
|
/// For sending HTTP POST requests
|
||||||
@@ -34,14 +34,14 @@ pub struct SseActor {
|
|||||||
|
|
||||||
impl SseActor {
|
impl SseActor {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
receiver: mpsc::Receiver<TransportMessage>,
|
receiver: mpsc::Receiver<String>,
|
||||||
pending_requests: Arc<PendingRequests>,
|
sender: mpsc::Sender<JsonRpcMessage>,
|
||||||
sse_url: String,
|
sse_url: String,
|
||||||
post_endpoint: Arc<RwLock<Option<String>>>,
|
post_endpoint: Arc<RwLock<Option<String>>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
receiver,
|
receiver,
|
||||||
pending_requests,
|
sender,
|
||||||
sse_url,
|
sse_url,
|
||||||
post_endpoint,
|
post_endpoint,
|
||||||
http_client: HttpClient::new(),
|
http_client: HttpClient::new(),
|
||||||
@@ -54,15 +54,14 @@ impl SseActor {
|
|||||||
pub async fn run(self) {
|
pub async fn run(self) {
|
||||||
tokio::join!(
|
tokio::join!(
|
||||||
Self::handle_incoming_messages(
|
Self::handle_incoming_messages(
|
||||||
|
self.sender,
|
||||||
self.sse_url.clone(),
|
self.sse_url.clone(),
|
||||||
Arc::clone(&self.pending_requests),
|
|
||||||
Arc::clone(&self.post_endpoint)
|
Arc::clone(&self.post_endpoint)
|
||||||
),
|
),
|
||||||
Self::handle_outgoing_messages(
|
Self::handle_outgoing_messages(
|
||||||
self.receiver,
|
self.receiver,
|
||||||
self.http_client.clone(),
|
self.http_client.clone(),
|
||||||
Arc::clone(&self.post_endpoint),
|
Arc::clone(&self.post_endpoint),
|
||||||
Arc::clone(&self.pending_requests),
|
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -72,14 +71,13 @@ impl SseActor {
|
|||||||
/// - If a `message` event is received, parse it as `JsonRpcMessage`
|
/// - If a `message` event is received, parse it as `JsonRpcMessage`
|
||||||
/// and respond to pending requests if it's a `Response`.
|
/// and respond to pending requests if it's a `Response`.
|
||||||
async fn handle_incoming_messages(
|
async fn handle_incoming_messages(
|
||||||
|
sender: mpsc::Sender<JsonRpcMessage>,
|
||||||
sse_url: String,
|
sse_url: String,
|
||||||
pending_requests: Arc<PendingRequests>,
|
|
||||||
post_endpoint: Arc<RwLock<Option<String>>>,
|
post_endpoint: Arc<RwLock<Option<String>>>,
|
||||||
) {
|
) {
|
||||||
let client = match eventsource_client::ClientBuilder::for_url(&sse_url) {
|
let client = match eventsource_client::ClientBuilder::for_url(&sse_url) {
|
||||||
Ok(builder) => builder.build(),
|
Ok(builder) => builder.build(),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
pending_requests.clear().await;
|
|
||||||
warn!("Failed to connect SSE client: {}", e);
|
warn!("Failed to connect SSE client: {}", e);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -105,84 +103,54 @@ impl SseActor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Now handle subsequent events
|
// Now handle subsequent events
|
||||||
while let Ok(Some(event)) = stream.try_next().await {
|
loop {
|
||||||
match event {
|
match stream.try_next().await {
|
||||||
SSE::Event(e) if e.event_type == "message" => {
|
Ok(Some(event)) => {
|
||||||
// Attempt to parse the SSE data as a JsonRpcMessage
|
match event {
|
||||||
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
|
SSE::Event(e) if e.event_type == "message" => {
|
||||||
Ok(message) => {
|
// Attempt to parse the SSE data as a JsonRpcMessage
|
||||||
match &message {
|
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
|
||||||
JsonRpcMessage::Response(response) => {
|
Ok(message) => {
|
||||||
if let Some(id) = &response.id {
|
let _ = sender.send(message).await;
|
||||||
pending_requests
|
|
||||||
.respond(&id.to_string(), Ok(message))
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
JsonRpcMessage::Error(error) => {
|
Err(err) => {
|
||||||
if let Some(id) = &error.id {
|
warn!("Failed to parse SSE message: {err}");
|
||||||
pending_requests
|
|
||||||
.respond(&id.to_string(), Ok(message))
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
_ => {} // TODO: Handle other variants (Request, etc.)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(err) => {
|
_ => { /* ignore other events */ }
|
||||||
warn!("Failed to parse SSE message: {err}");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => { /* ignore other events */ }
|
Ok(None) => {
|
||||||
|
// Stream ended
|
||||||
|
tracing::info!("SSE stream ended.");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Error reading SSE stream: {e}");
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSE stream ended or errored; signal any pending requests
|
tracing::error!("SSE stream ended or encountered an error.");
|
||||||
tracing::error!("SSE stream ended or encountered an error; clearing pending requests.");
|
|
||||||
pending_requests.clear().await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Continuously receives messages from the `mpsc::Receiver`.
|
|
||||||
/// - If it's a request, store the oneshot in `pending_requests`.
|
|
||||||
/// - POST the message to the discovered endpoint (once known).
|
|
||||||
async fn handle_outgoing_messages(
|
async fn handle_outgoing_messages(
|
||||||
mut receiver: mpsc::Receiver<TransportMessage>,
|
mut receiver: mpsc::Receiver<String>,
|
||||||
http_client: HttpClient,
|
http_client: HttpClient,
|
||||||
post_endpoint: Arc<RwLock<Option<String>>>,
|
post_endpoint: Arc<RwLock<Option<String>>>,
|
||||||
pending_requests: Arc<PendingRequests>,
|
|
||||||
) {
|
) {
|
||||||
while let Some(transport_msg) = receiver.recv().await {
|
while let Some(message_str) = receiver.recv().await {
|
||||||
let post_url = match post_endpoint.read().await.as_ref() {
|
let post_url = match post_endpoint.read().await.as_ref() {
|
||||||
Some(url) => url.clone(),
|
Some(url) => url.clone(),
|
||||||
None => {
|
None => {
|
||||||
if let Some(response_tx) = transport_msg.response_tx {
|
// TODO: the endpoint isn't discovered yet. This shouldn't happen -- we only return the handle
|
||||||
let _ = response_tx.send(Err(Error::NotConnected));
|
// after the endpoint is set.
|
||||||
}
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Serialize the JSON-RPC message
|
|
||||||
let message_str = match serde_json::to_string(&transport_msg.message) {
|
|
||||||
Ok(s) => s,
|
|
||||||
Err(e) => {
|
|
||||||
if let Some(tx) = transport_msg.response_tx {
|
|
||||||
let _ = tx.send(Err(Error::Serialization(e)));
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// If it's a request, store the channel so we can respond later
|
|
||||||
if let Some(response_tx) = transport_msg.response_tx {
|
|
||||||
if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) =
|
|
||||||
&transport_msg.message
|
|
||||||
{
|
|
||||||
pending_requests.insert(id.to_string(), response_tx).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perform the HTTP POST
|
// Perform the HTTP POST
|
||||||
match http_client
|
match http_client
|
||||||
.post(&post_url)
|
.post(&post_url)
|
||||||
@@ -209,26 +177,25 @@ impl SseActor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mpsc channel closed => no more outgoing messages
|
tracing::info!("SseActor shut down.");
|
||||||
let pending = pending_requests.len().await;
|
|
||||||
if pending > 0 {
|
|
||||||
tracing::error!("SSE stream ended or encountered an error with {pending} unfulfilled pending requests.");
|
|
||||||
pending_requests.clear().await;
|
|
||||||
} else {
|
|
||||||
tracing::info!("SseActor shutdown cleanly. No pending requests.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct SseTransportHandle {
|
pub struct SseTransportHandle {
|
||||||
sender: mpsc::Sender<TransportMessage>,
|
sender: mpsc::Sender<String>,
|
||||||
|
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl TransportHandle for SseTransportHandle {
|
impl TransportHandle for SseTransportHandle {
|
||||||
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
|
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> {
|
||||||
send_message(&self.sender, message).await
|
serialize_and_send(&self.sender, message).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
|
||||||
|
let mut receiver = self.receiver.lock().await;
|
||||||
|
receiver.recv().await.ok_or(Error::ChannelClosed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,17 +246,13 @@ impl Transport for SseTransport {
|
|||||||
|
|
||||||
// Create a channel for outgoing TransportMessages
|
// Create a channel for outgoing TransportMessages
|
||||||
let (tx, rx) = mpsc::channel(32);
|
let (tx, rx) = mpsc::channel(32);
|
||||||
|
let (otx, orx) = mpsc::channel(32);
|
||||||
|
|
||||||
let post_endpoint: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
|
let post_endpoint: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
|
||||||
let post_endpoint_clone = Arc::clone(&post_endpoint);
|
let post_endpoint_clone = Arc::clone(&post_endpoint);
|
||||||
|
|
||||||
// Build the actor
|
// Build the actor
|
||||||
let actor = SseActor::new(
|
let actor = SseActor::new(rx, otx, self.sse_url.clone(), post_endpoint);
|
||||||
rx,
|
|
||||||
Arc::new(PendingRequests::new()),
|
|
||||||
self.sse_url.clone(),
|
|
||||||
post_endpoint,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Spawn the actor task
|
// Spawn the actor task
|
||||||
tokio::spawn(actor.run());
|
tokio::spawn(actor.run());
|
||||||
@@ -301,7 +264,10 @@ impl Transport for SseTransport {
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(_) => Ok(SseTransportHandle { sender: tx }),
|
Ok(_) => Ok(SseTransportHandle {
|
||||||
|
sender: tx,
|
||||||
|
receiver: Arc::new(Mutex::new(orx)),
|
||||||
|
}),
|
||||||
Err(e) => Err(Error::SseConnection(e.to_string())),
|
Err(e) => Err(Error::SseConnection(e.to_string())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ use nix::sys::signal::{kill, Signal};
|
|||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use nix::unistd::{getpgid, Pid};
|
use nix::unistd::{getpgid, Pid};
|
||||||
|
|
||||||
use super::{send_message, Error, PendingRequests, Transport, TransportHandle, TransportMessage};
|
use super::{serialize_and_send, Error, Transport, TransportHandle};
|
||||||
|
|
||||||
// Global to track process groups we've created
|
// Global to track process groups we've created
|
||||||
static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
|
static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
|
||||||
@@ -23,8 +23,8 @@ static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
|
|||||||
///
|
///
|
||||||
/// It uses channels for message passing and handles responses asynchronously through a background task.
|
/// It uses channels for message passing and handles responses asynchronously through a background task.
|
||||||
pub struct StdioActor {
|
pub struct StdioActor {
|
||||||
receiver: Option<mpsc::Receiver<TransportMessage>>,
|
receiver: Option<mpsc::Receiver<String>>,
|
||||||
pending_requests: Arc<PendingRequests>,
|
sender: Option<mpsc::Sender<JsonRpcMessage>>,
|
||||||
process: Child, // we store the process to keep it alive
|
process: Child, // we store the process to keep it alive
|
||||||
error_sender: mpsc::Sender<Error>,
|
error_sender: mpsc::Sender<Error>,
|
||||||
stdin: Option<ChildStdin>,
|
stdin: Option<ChildStdin>,
|
||||||
@@ -55,11 +55,11 @@ impl StdioActor {
|
|||||||
|
|
||||||
let stdout = self.stdout.take().expect("stdout should be available");
|
let stdout = self.stdout.take().expect("stdout should be available");
|
||||||
let stdin = self.stdin.take().expect("stdin should be available");
|
let stdin = self.stdin.take().expect("stdin should be available");
|
||||||
let receiver = self.receiver.take().expect("receiver should be available");
|
let msg_inbox = self.receiver.take().expect("receiver should be available");
|
||||||
|
let msg_outbox = self.sender.take().expect("sender should be available");
|
||||||
|
|
||||||
let incoming = Self::handle_incoming_messages(stdout, self.pending_requests.clone());
|
let incoming = Self::handle_proc_output(stdout, msg_outbox);
|
||||||
let outgoing =
|
let outgoing = Self::handle_proc_input(stdin, msg_inbox);
|
||||||
Self::handle_outgoing_messages(receiver, stdin, self.pending_requests.clone());
|
|
||||||
|
|
||||||
// take ownership of futures for tokio::select
|
// take ownership of futures for tokio::select
|
||||||
pin!(incoming);
|
pin!(incoming);
|
||||||
@@ -96,12 +96,9 @@ impl StdioActor {
|
|||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up regardless of which path we took
|
|
||||||
self.pending_requests.clear().await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_incoming_messages(stdout: ChildStdout, pending_requests: Arc<PendingRequests>) {
|
async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender<JsonRpcMessage>) {
|
||||||
let mut reader = BufReader::new(stdout);
|
let mut reader = BufReader::new(stdout);
|
||||||
let mut line = String::new();
|
let mut line = String::new();
|
||||||
loop {
|
loop {
|
||||||
@@ -116,20 +113,12 @@ impl StdioActor {
|
|||||||
message = ?message,
|
message = ?message,
|
||||||
"Received incoming message"
|
"Received incoming message"
|
||||||
);
|
);
|
||||||
|
let _ = sender.send(message).await;
|
||||||
match &message {
|
} else {
|
||||||
JsonRpcMessage::Response(response) => {
|
tracing::warn!(
|
||||||
if let Some(id) = &response.id {
|
message = ?line,
|
||||||
pending_requests.respond(&id.to_string(), Ok(message)).await;
|
"Failed to parse incoming message"
|
||||||
}
|
);
|
||||||
}
|
|
||||||
JsonRpcMessage::Error(error) => {
|
|
||||||
if let Some(id) = &error.id {
|
|
||||||
pending_requests.respond(&id.to_string(), Ok(message)).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {} // TODO: Handle other variants (Request, etc.)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
line.clear();
|
line.clear();
|
||||||
}
|
}
|
||||||
@@ -141,44 +130,20 @@ impl StdioActor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_outgoing_messages(
|
async fn handle_proc_input(mut stdin: ChildStdin, mut receiver: mpsc::Receiver<String>) {
|
||||||
mut receiver: mpsc::Receiver<TransportMessage>,
|
while let Some(message_str) = receiver.recv().await {
|
||||||
mut stdin: ChildStdin,
|
tracing::debug!(message = ?message_str, "Sending outgoing message");
|
||||||
pending_requests: Arc<PendingRequests>,
|
|
||||||
) {
|
|
||||||
while let Some(mut transport_msg) = receiver.recv().await {
|
|
||||||
let message_str = match serde_json::to_string(&transport_msg.message) {
|
|
||||||
Ok(s) => s,
|
|
||||||
Err(e) => {
|
|
||||||
if let Some(tx) = transport_msg.response_tx.take() {
|
|
||||||
let _ = tx.send(Err(Error::Serialization(e)));
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
tracing::debug!(message = ?transport_msg.message, "Sending outgoing message");
|
|
||||||
|
|
||||||
if let Some(response_tx) = transport_msg.response_tx.take() {
|
|
||||||
if let JsonRpcMessage::Request(request) = &transport_msg.message {
|
|
||||||
if let Some(id) = &request.id {
|
|
||||||
pending_requests.insert(id.to_string(), response_tx).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(e) = stdin
|
if let Err(e) = stdin
|
||||||
.write_all(format!("{}\n", message_str).as_bytes())
|
.write_all(format!("{}\n", message_str).as_bytes())
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
tracing::error!(error = ?e, "Error writing message to child process");
|
tracing::error!(error = ?e, "Error writing message to child process");
|
||||||
pending_requests.clear().await;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e) = stdin.flush().await {
|
if let Err(e) = stdin.flush().await {
|
||||||
tracing::error!(error = ?e, "Error flushing message to child process");
|
tracing::error!(error = ?e, "Error flushing message to child process");
|
||||||
pending_requests.clear().await;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -187,18 +152,24 @@ impl StdioActor {
|
|||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct StdioTransportHandle {
|
pub struct StdioTransportHandle {
|
||||||
sender: mpsc::Sender<TransportMessage>,
|
sender: mpsc::Sender<String>, // to process
|
||||||
|
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>, // from process
|
||||||
error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
|
error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl TransportHandle for StdioTransportHandle {
|
impl TransportHandle for StdioTransportHandle {
|
||||||
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
|
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> {
|
||||||
let result = send_message(&self.sender, message).await;
|
let result = serialize_and_send(&self.sender, message).await;
|
||||||
// Check for any pending errors even if send is successful
|
// Check for any pending errors even if send is successful
|
||||||
self.check_for_errors().await?;
|
self.check_for_errors().await?;
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
|
||||||
|
let mut receiver = self.receiver.lock().await;
|
||||||
|
receiver.recv().await.ok_or(Error::ChannelClosed)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StdioTransportHandle {
|
impl StdioTransportHandle {
|
||||||
@@ -289,12 +260,13 @@ impl Transport for StdioTransport {
|
|||||||
|
|
||||||
async fn start(&self) -> Result<Self::Handle, Error> {
|
async fn start(&self) -> Result<Self::Handle, Error> {
|
||||||
let (process, stdin, stdout, stderr) = self.spawn_process().await?;
|
let (process, stdin, stdout, stderr) = self.spawn_process().await?;
|
||||||
let (message_tx, message_rx) = mpsc::channel(32);
|
let (outbox_tx, outbox_rx) = mpsc::channel(32);
|
||||||
|
let (inbox_tx, inbox_rx) = mpsc::channel(32);
|
||||||
let (error_tx, error_rx) = mpsc::channel(1);
|
let (error_tx, error_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
let actor = StdioActor {
|
let actor = StdioActor {
|
||||||
receiver: Some(message_rx),
|
receiver: Some(outbox_rx), // client to process
|
||||||
pending_requests: Arc::new(PendingRequests::new()),
|
sender: Some(inbox_tx), // process to client
|
||||||
process,
|
process,
|
||||||
error_sender: error_tx,
|
error_sender: error_tx,
|
||||||
stdin: Some(stdin),
|
stdin: Some(stdin),
|
||||||
@@ -305,7 +277,8 @@ impl Transport for StdioTransport {
|
|||||||
tokio::spawn(actor.run());
|
tokio::spawn(actor.run());
|
||||||
|
|
||||||
let handle = StdioTransportHandle {
|
let handle = StdioTransportHandle {
|
||||||
sender: message_tx,
|
sender: outbox_tx, // client to process
|
||||||
|
receiver: Arc::new(Mutex::new(inbox_rx)), // process to client
|
||||||
error_receiver: Arc::new(Mutex::new(error_rx)),
|
error_receiver: Arc::new(Mutex::new(error_rx)),
|
||||||
};
|
};
|
||||||
Ok(handle)
|
Ok(handle)
|
||||||
|
|||||||
@@ -4,9 +4,13 @@ use std::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use futures::{Future, Stream};
|
use futures::{Future, Stream};
|
||||||
use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
|
use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcResponse};
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
use router::McpRequest;
|
||||||
|
use tokio::{
|
||||||
|
io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader},
|
||||||
|
sync::mpsc,
|
||||||
|
};
|
||||||
use tower_service::Service;
|
use tower_service::Service;
|
||||||
|
|
||||||
mod errors;
|
mod errors;
|
||||||
@@ -123,7 +127,7 @@ pub struct Server<S> {
|
|||||||
|
|
||||||
impl<S> Server<S>
|
impl<S> Server<S>
|
||||||
where
|
where
|
||||||
S: Service<JsonRpcRequest, Response = JsonRpcResponse> + Send,
|
S: Service<McpRequest, Response = JsonRpcResponse> + Send,
|
||||||
S::Error: Into<BoxError>,
|
S::Error: Into<BoxError>,
|
||||||
S::Future: Send,
|
S::Future: Send,
|
||||||
{
|
{
|
||||||
@@ -134,8 +138,8 @@ where
|
|||||||
// TODO transport trait instead of byte transport if we implement others
|
// TODO transport trait instead of byte transport if we implement others
|
||||||
pub async fn run<R, W>(self, mut transport: ByteTransport<R, W>) -> Result<(), ServerError>
|
pub async fn run<R, W>(self, mut transport: ByteTransport<R, W>) -> Result<(), ServerError>
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin,
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
W: AsyncWrite + Unpin,
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
let mut service = self.service;
|
let mut service = self.service;
|
||||||
@@ -160,7 +164,22 @@ where
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Process the request using our service
|
// Process the request using our service
|
||||||
let response = match service.call(request).await {
|
let (notify_tx, mut notify_rx) = mpsc::channel(256);
|
||||||
|
let mcp_request = McpRequest {
|
||||||
|
request,
|
||||||
|
notifier: notify_tx,
|
||||||
|
};
|
||||||
|
|
||||||
|
let transport_fut = tokio::spawn(async move {
|
||||||
|
while let Some(notification) = notify_rx.recv().await {
|
||||||
|
if transport.write_message(notification).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
transport
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = match service.call(mcp_request).await {
|
||||||
Ok(resp) => resp,
|
Ok(resp) => resp,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let error_msg = e.into().to_string();
|
let error_msg = e.into().to_string();
|
||||||
@@ -178,6 +197,16 @@ where
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
transport = match transport_fut.await {
|
||||||
|
Ok(transport) => transport,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(error = %e, "Failed to spawn transport task");
|
||||||
|
return Err(ServerError::Transport(TransportError::Io(
|
||||||
|
e.into(),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Serialize response for logging
|
// Serialize response for logging
|
||||||
let response_json = serde_json::to_string(&response)
|
let response_json = serde_json::to_string(&response)
|
||||||
.unwrap_or_else(|_| "Failed to serialize response".to_string());
|
.unwrap_or_else(|_| "Failed to serialize response".to_string());
|
||||||
@@ -247,7 +276,7 @@ where
|
|||||||
// Any router implements this
|
// Any router implements this
|
||||||
pub trait BoundedService:
|
pub trait BoundedService:
|
||||||
Service<
|
Service<
|
||||||
JsonRpcRequest,
|
McpRequest,
|
||||||
Response = JsonRpcResponse,
|
Response = JsonRpcResponse,
|
||||||
Error = BoxError,
|
Error = BoxError,
|
||||||
Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
|
Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
|
||||||
@@ -259,7 +288,7 @@ pub trait BoundedService:
|
|||||||
// Implement it for any type that meets the bounds
|
// Implement it for any type that meets the bounds
|
||||||
impl<T> BoundedService for T where
|
impl<T> BoundedService for T where
|
||||||
T: Service<
|
T: Service<
|
||||||
JsonRpcRequest,
|
McpRequest,
|
||||||
Response = JsonRpcResponse,
|
Response = JsonRpcResponse,
|
||||||
Error = BoxError,
|
Error = BoxError,
|
||||||
Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
|
Future = Pin<Box<dyn Future<Output = Result<JsonRpcResponse, BoxError>> + Send>>,
|
||||||
|
|||||||
@@ -2,12 +2,14 @@ use anyhow::Result;
|
|||||||
use mcp_core::content::Content;
|
use mcp_core::content::Content;
|
||||||
use mcp_core::handler::{PromptError, ResourceError};
|
use mcp_core::handler::{PromptError, ResourceError};
|
||||||
use mcp_core::prompt::{Prompt, PromptArgument};
|
use mcp_core::prompt::{Prompt, PromptArgument};
|
||||||
|
use mcp_core::protocol::JsonRpcMessage;
|
||||||
use mcp_core::tool::ToolAnnotations;
|
use mcp_core::tool::ToolAnnotations;
|
||||||
use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool};
|
use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool};
|
||||||
use mcp_server::router::{CapabilitiesBuilder, RouterService};
|
use mcp_server::router::{CapabilitiesBuilder, RouterService};
|
||||||
use mcp_server::{ByteTransport, Router, Server};
|
use mcp_server::{ByteTransport, Router, Server};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::{future::Future, pin::Pin, sync::Arc};
|
use std::{future::Future, pin::Pin, sync::Arc};
|
||||||
|
use tokio::sync::mpsc;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{stdin, stdout},
|
io::{stdin, stdout},
|
||||||
sync::Mutex,
|
sync::Mutex,
|
||||||
@@ -124,6 +126,7 @@ impl Router for CounterRouter {
|
|||||||
&self,
|
&self,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
_arguments: Value,
|
_arguments: Value,
|
||||||
|
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||||
let this = self.clone();
|
let this = self.clone();
|
||||||
let tool_name = tool_name.to_string();
|
let tool_name = tool_name.to_string();
|
||||||
|
|||||||
@@ -11,14 +11,15 @@ use mcp_core::{
|
|||||||
handler::{PromptError, ResourceError, ToolError},
|
handler::{PromptError, ResourceError, ToolError},
|
||||||
prompt::{Prompt, PromptMessage, PromptMessageRole},
|
prompt::{Prompt, PromptMessage, PromptMessageRole},
|
||||||
protocol::{
|
protocol::{
|
||||||
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcRequest,
|
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcMessage,
|
||||||
JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult,
|
JsonRpcRequest, JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult,
|
||||||
PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities,
|
PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities,
|
||||||
ToolsCapability,
|
ToolsCapability,
|
||||||
},
|
},
|
||||||
ResourceContents,
|
ResourceContents,
|
||||||
};
|
};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
use tower_service::Service;
|
use tower_service::Service;
|
||||||
|
|
||||||
use crate::{BoxError, RouterError};
|
use crate::{BoxError, RouterError};
|
||||||
@@ -91,6 +92,7 @@ pub trait Router: Send + Sync + 'static {
|
|||||||
&self,
|
&self,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
arguments: Value,
|
arguments: Value,
|
||||||
|
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>>;
|
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>>;
|
||||||
fn list_resources(&self) -> Vec<mcp_core::resource::Resource>;
|
fn list_resources(&self) -> Vec<mcp_core::resource::Resource>;
|
||||||
fn read_resource(
|
fn read_resource(
|
||||||
@@ -159,6 +161,7 @@ pub trait Router: Send + Sync + 'static {
|
|||||||
fn handle_tools_call(
|
fn handle_tools_call(
|
||||||
&self,
|
&self,
|
||||||
req: JsonRpcRequest,
|
req: JsonRpcRequest,
|
||||||
|
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
|
) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
|
||||||
async move {
|
async move {
|
||||||
let params = req
|
let params = req
|
||||||
@@ -172,7 +175,7 @@ pub trait Router: Send + Sync + 'static {
|
|||||||
|
|
||||||
let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
|
let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
|
||||||
|
|
||||||
let result = match self.call_tool(name, arguments).await {
|
let result = match self.call_tool(name, arguments, notifier).await {
|
||||||
Ok(result) => CallToolResult {
|
Ok(result) => CallToolResult {
|
||||||
content: result,
|
content: result,
|
||||||
is_error: None,
|
is_error: None,
|
||||||
@@ -394,7 +397,12 @@ pub trait Router: Send + Sync + 'static {
|
|||||||
|
|
||||||
pub struct RouterService<T>(pub T);
|
pub struct RouterService<T>(pub T);
|
||||||
|
|
||||||
impl<T> Service<JsonRpcRequest> for RouterService<T>
|
pub struct McpRequest {
|
||||||
|
pub request: JsonRpcRequest,
|
||||||
|
pub notifier: mpsc::Sender<JsonRpcMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Service<McpRequest> for RouterService<T>
|
||||||
where
|
where
|
||||||
T: Router + Clone + Send + Sync + 'static,
|
T: Router + Clone + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
@@ -406,21 +414,21 @@ where
|
|||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call(&mut self, req: JsonRpcRequest) -> Self::Future {
|
fn call(&mut self, req: McpRequest) -> Self::Future {
|
||||||
let this = self.0.clone();
|
let this = self.0.clone();
|
||||||
|
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let result = match req.method.as_str() {
|
let result = match req.request.method.as_str() {
|
||||||
"initialize" => this.handle_initialize(req).await,
|
"initialize" => this.handle_initialize(req.request).await,
|
||||||
"tools/list" => this.handle_tools_list(req).await,
|
"tools/list" => this.handle_tools_list(req.request).await,
|
||||||
"tools/call" => this.handle_tools_call(req).await,
|
"tools/call" => this.handle_tools_call(req.request, req.notifier).await,
|
||||||
"resources/list" => this.handle_resources_list(req).await,
|
"resources/list" => this.handle_resources_list(req.request).await,
|
||||||
"resources/read" => this.handle_resources_read(req).await,
|
"resources/read" => this.handle_resources_read(req.request).await,
|
||||||
"prompts/list" => this.handle_prompts_list(req).await,
|
"prompts/list" => this.handle_prompts_list(req.request).await,
|
||||||
"prompts/get" => this.handle_prompts_get(req).await,
|
"prompts/get" => this.handle_prompts_get(req.request).await,
|
||||||
_ => {
|
_ => {
|
||||||
let mut response = this.create_response(req.id);
|
let mut response = this.create_response(req.request.id);
|
||||||
response.error = Some(RouterError::MethodNotFound(req.method).into());
|
response.error = Some(RouterError::MethodNotFound(req.request.method).into());
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
VITE_START_EMBEDDED_SERVER=yes
|
VITE_START_EMBEDDED_SERVER=yes
|
||||||
GOOSE_PROVIDER__TYPE=openai
|
GOOSE_PROVIDER__TYPE=openai
|
||||||
GOOSE_PROVIDER__HOST=https://api.openai.com
|
GOOSE_PROVIDER__HOST=https://api.openai.com
|
||||||
GOOSE_PROVIDER__MODEL=gpt-4o
|
GOOSE_PROVIDER__MODEL=gpt-4o
|
||||||
|
|||||||
@@ -148,6 +148,7 @@ function ChatContent({
|
|||||||
handleInputChange: _handleInputChange,
|
handleInputChange: _handleInputChange,
|
||||||
handleSubmit: _submitMessage,
|
handleSubmit: _submitMessage,
|
||||||
updateMessageStreamBody,
|
updateMessageStreamBody,
|
||||||
|
notifications,
|
||||||
} = useMessageStream({
|
} = useMessageStream({
|
||||||
api: getApiUrl('/reply'),
|
api: getApiUrl('/reply'),
|
||||||
initialMessages: chat.messages,
|
initialMessages: chat.messages,
|
||||||
@@ -492,6 +493,16 @@ function ChatContent({
|
|||||||
const handleDragOver = (e: React.DragEvent<HTMLDivElement>) => {
|
const handleDragOver = (e: React.DragEvent<HTMLDivElement>) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const toolCallNotifications = notifications.reduce((map, item) => {
|
||||||
|
const key = item.request_id;
|
||||||
|
if (!map.has(key)) {
|
||||||
|
map.set(key, []);
|
||||||
|
}
|
||||||
|
map.get(key).push(item);
|
||||||
|
return map;
|
||||||
|
}, new Map());
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col w-full h-screen items-center justify-center">
|
<div className="flex flex-col w-full h-screen items-center justify-center">
|
||||||
{/* Loader when generating recipe */}
|
{/* Loader when generating recipe */}
|
||||||
@@ -571,6 +582,7 @@ function ChatContent({
|
|||||||
const updatedMessages = [...messages, newMessage];
|
const updatedMessages = [...messages, newMessage];
|
||||||
setMessages(updatedMessages);
|
setMessages(updatedMessages);
|
||||||
}}
|
}}
|
||||||
|
toolCallNotifications={toolCallNotifications}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
@@ -578,6 +590,7 @@ function ChatContent({
|
|||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
</SearchView>
|
</SearchView>
|
||||||
|
|
||||||
{error && (
|
{error && (
|
||||||
<div className="flex flex-col items-center justify-center p-4">
|
<div className="flex flex-col items-center justify-center p-4">
|
||||||
<div className="text-red-700 dark:text-red-300 bg-red-400/50 p-3 rounded-lg mb-2">
|
<div className="text-red-700 dark:text-red-300 bg-red-400/50 p-3 rounded-lg mb-2">
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import {
|
|||||||
} from '../types/message';
|
} from '../types/message';
|
||||||
import ToolCallConfirmation from './ToolCallConfirmation';
|
import ToolCallConfirmation from './ToolCallConfirmation';
|
||||||
import MessageCopyLink from './MessageCopyLink';
|
import MessageCopyLink from './MessageCopyLink';
|
||||||
|
import { NotificationEvent } from '../hooks/useMessageStream';
|
||||||
|
|
||||||
interface GooseMessageProps {
|
interface GooseMessageProps {
|
||||||
// messages up to this index are presumed to be "history" from a resumed session, this is used to track older tool confirmation requests
|
// messages up to this index are presumed to be "history" from a resumed session, this is used to track older tool confirmation requests
|
||||||
@@ -25,6 +26,7 @@ interface GooseMessageProps {
|
|||||||
message: Message;
|
message: Message;
|
||||||
messages: Message[];
|
messages: Message[];
|
||||||
metadata?: string[];
|
metadata?: string[];
|
||||||
|
toolCallNotifications: Map<string, NotificationEvent[]>;
|
||||||
append: (value: string) => void;
|
append: (value: string) => void;
|
||||||
appendMessage: (message: Message) => void;
|
appendMessage: (message: Message) => void;
|
||||||
}
|
}
|
||||||
@@ -34,6 +36,7 @@ export default function GooseMessage({
|
|||||||
message,
|
message,
|
||||||
metadata,
|
metadata,
|
||||||
messages,
|
messages,
|
||||||
|
toolCallNotifications,
|
||||||
append,
|
append,
|
||||||
appendMessage,
|
appendMessage,
|
||||||
}: GooseMessageProps) {
|
}: GooseMessageProps) {
|
||||||
@@ -158,6 +161,7 @@ export default function GooseMessage({
|
|||||||
}
|
}
|
||||||
toolRequest={toolRequest}
|
toolRequest={toolRequest}
|
||||||
toolResponse={toolResponsesMap.get(toolRequest.id)}
|
toolResponse={toolResponsesMap.get(toolRequest.id)}
|
||||||
|
notifications={toolCallNotifications.get(toolRequest.id)}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import React from 'react';
|
import React, { useEffect, useRef } from 'react';
|
||||||
import { Card } from './ui/card';
|
import { Card } from './ui/card';
|
||||||
import { ToolCallArguments, ToolCallArgumentValue } from './ToolCallArguments';
|
import { ToolCallArguments, ToolCallArgumentValue } from './ToolCallArguments';
|
||||||
import MarkdownContent from './MarkdownContent';
|
import MarkdownContent from './MarkdownContent';
|
||||||
@@ -6,17 +6,20 @@ import { Content, ToolRequestMessageContent, ToolResponseMessageContent } from '
|
|||||||
import { snakeToTitleCase } from '../utils';
|
import { snakeToTitleCase } from '../utils';
|
||||||
import Dot, { LoadingStatus } from './ui/Dot';
|
import Dot, { LoadingStatus } from './ui/Dot';
|
||||||
import Expand from './ui/Expand';
|
import Expand from './ui/Expand';
|
||||||
|
import { NotificationEvent } from '../hooks/useMessageStream';
|
||||||
|
|
||||||
interface ToolCallWithResponseProps {
|
interface ToolCallWithResponseProps {
|
||||||
isCancelledMessage: boolean;
|
isCancelledMessage: boolean;
|
||||||
toolRequest: ToolRequestMessageContent;
|
toolRequest: ToolRequestMessageContent;
|
||||||
toolResponse?: ToolResponseMessageContent;
|
toolResponse?: ToolResponseMessageContent;
|
||||||
|
notifications?: NotificationEvent[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function ToolCallWithResponse({
|
export default function ToolCallWithResponse({
|
||||||
isCancelledMessage,
|
isCancelledMessage,
|
||||||
toolRequest,
|
toolRequest,
|
||||||
toolResponse,
|
toolResponse,
|
||||||
|
notifications,
|
||||||
}: ToolCallWithResponseProps) {
|
}: ToolCallWithResponseProps) {
|
||||||
const toolCall = toolRequest.toolCall.status === 'success' ? toolRequest.toolCall.value : null;
|
const toolCall = toolRequest.toolCall.status === 'success' ? toolRequest.toolCall.value : null;
|
||||||
if (!toolCall) {
|
if (!toolCall) {
|
||||||
@@ -26,7 +29,7 @@ export default function ToolCallWithResponse({
|
|||||||
return (
|
return (
|
||||||
<div className={'w-full text-textSubtle text-sm'}>
|
<div className={'w-full text-textSubtle text-sm'}>
|
||||||
<Card className="">
|
<Card className="">
|
||||||
<ToolCallView {...{ isCancelledMessage, toolCall, toolResponse }} />
|
<ToolCallView {...{ isCancelledMessage, toolCall, toolResponse, notifications }} />
|
||||||
</Card>
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
@@ -47,8 +50,9 @@ function ToolCallExpandable({
|
|||||||
children,
|
children,
|
||||||
className = '',
|
className = '',
|
||||||
}: ToolCallExpandableProps) {
|
}: ToolCallExpandableProps) {
|
||||||
const [isExpanded, setIsExpanded] = React.useState(isStartExpanded);
|
const [isExpandedState, setIsExpanded] = React.useState<boolean | null>(null);
|
||||||
const toggleExpand = () => setIsExpanded((prev) => !prev);
|
const isExpanded = isExpandedState === null ? isStartExpanded : isExpandedState;
|
||||||
|
const toggleExpand = () => setIsExpanded(!isExpanded);
|
||||||
React.useEffect(() => {
|
React.useEffect(() => {
|
||||||
if (isForceExpand) setIsExpanded(true);
|
if (isForceExpand) setIsExpanded(true);
|
||||||
}, [isForceExpand]);
|
}, [isForceExpand]);
|
||||||
@@ -71,9 +75,42 @@ interface ToolCallViewProps {
|
|||||||
arguments: Record<string, unknown>;
|
arguments: Record<string, unknown>;
|
||||||
};
|
};
|
||||||
toolResponse?: ToolResponseMessageContent;
|
toolResponse?: ToolResponseMessageContent;
|
||||||
|
notifications?: NotificationEvent[];
|
||||||
}
|
}
|
||||||
|
|
||||||
function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallViewProps) {
|
interface Progress {
|
||||||
|
progress: number;
|
||||||
|
progressToken: string;
|
||||||
|
total?: number;
|
||||||
|
message?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const logToString = (logMessage: NotificationEvent) => {
|
||||||
|
const params = logMessage.message.params;
|
||||||
|
|
||||||
|
// Special case for the developer system shell logs
|
||||||
|
if (
|
||||||
|
params &&
|
||||||
|
params.data &&
|
||||||
|
typeof params.data === 'object' &&
|
||||||
|
'output' in params.data &&
|
||||||
|
'stream' in params.data
|
||||||
|
) {
|
||||||
|
return `[${params.data.stream}] ${params.data.output}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
return typeof params.data === 'string' ? params.data : JSON.stringify(params.data);
|
||||||
|
};
|
||||||
|
|
||||||
|
const notificationToProgress = (notification: NotificationEvent): Progress =>
|
||||||
|
notification.message.params as unknown as Progress;
|
||||||
|
|
||||||
|
function ToolCallView({
|
||||||
|
isCancelledMessage,
|
||||||
|
toolCall,
|
||||||
|
toolResponse,
|
||||||
|
notifications,
|
||||||
|
}: ToolCallViewProps) {
|
||||||
const responseStyle = localStorage.getItem('response_style');
|
const responseStyle = localStorage.getItem('response_style');
|
||||||
const isExpandToolDetails = (() => {
|
const isExpandToolDetails = (() => {
|
||||||
switch (responseStyle) {
|
switch (responseStyle) {
|
||||||
@@ -103,6 +140,29 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi
|
|||||||
}))
|
}))
|
||||||
: [];
|
: [];
|
||||||
|
|
||||||
|
const logs = notifications
|
||||||
|
?.filter((notification) => notification.message.method === 'notifications/message')
|
||||||
|
.map(logToString);
|
||||||
|
|
||||||
|
const progress = notifications
|
||||||
|
?.filter((notification) => notification.message.method === 'notifications/progress')
|
||||||
|
.map(notificationToProgress)
|
||||||
|
.reduce((map, item) => {
|
||||||
|
const key = item.progressToken;
|
||||||
|
if (!map.has(key)) {
|
||||||
|
map.set(key, []);
|
||||||
|
}
|
||||||
|
map.get(key)!.push(item);
|
||||||
|
return map;
|
||||||
|
}, new Map<string, Progress[]>());
|
||||||
|
|
||||||
|
const progressEntries = [...(progress?.values() || [])].map(
|
||||||
|
(entries) => entries.sort((a, b) => b.progress - a.progress)[0]
|
||||||
|
);
|
||||||
|
|
||||||
|
const isRenderingProgress =
|
||||||
|
loadingStatus === 'loading' && (progressEntries.length > 0 || (logs || []).length > 0);
|
||||||
|
|
||||||
const isShouldExpand = isExpandToolDetails || toolResults.some((v) => v.isExpandToolResults);
|
const isShouldExpand = isExpandToolDetails || toolResults.some((v) => v.isExpandToolResults);
|
||||||
|
|
||||||
// Function to create a compact representation of arguments
|
// Function to create a compact representation of arguments
|
||||||
@@ -136,7 +196,7 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<ToolCallExpandable
|
<ToolCallExpandable
|
||||||
isStartExpanded={isShouldExpand}
|
isStartExpanded={isShouldExpand || isRenderingProgress}
|
||||||
isForceExpand={isShouldExpand}
|
isForceExpand={isShouldExpand}
|
||||||
label={
|
label={
|
||||||
<>
|
<>
|
||||||
@@ -156,6 +216,24 @@ function ToolCallView({ isCancelledMessage, toolCall, toolResponse }: ToolCallVi
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{logs && logs.length > 0 && (
|
||||||
|
<div className="bg-bgStandard mt-1">
|
||||||
|
<ToolLogsView
|
||||||
|
logs={logs}
|
||||||
|
working={toolResults.length === 0}
|
||||||
|
isStartExpanded={toolResults.length === 0}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{toolResults.length === 0 &&
|
||||||
|
progressEntries.length > 0 &&
|
||||||
|
progressEntries.map((entry, index) => (
|
||||||
|
<div className="p-2" key={index}>
|
||||||
|
<ProgressBar progress={entry.progress} total={entry.total} message={entry.message} />
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
|
||||||
{/* Tool Output */}
|
{/* Tool Output */}
|
||||||
{!isCancelledMessage && (
|
{!isCancelledMessage && (
|
||||||
<>
|
<>
|
||||||
@@ -234,3 +312,76 @@ function ToolResultView({ result, isStartExpanded }: ToolResultViewProps) {
|
|||||||
</ToolCallExpandable>
|
</ToolCallExpandable>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function ToolLogsView({
|
||||||
|
logs,
|
||||||
|
working,
|
||||||
|
isStartExpanded,
|
||||||
|
}: {
|
||||||
|
logs: string[];
|
||||||
|
working: boolean;
|
||||||
|
isStartExpanded?: boolean;
|
||||||
|
}) {
|
||||||
|
const boxRef = useRef(null);
|
||||||
|
|
||||||
|
// Whenever logs update, jump to the newest entry
|
||||||
|
useEffect(() => {
|
||||||
|
if (boxRef.current) {
|
||||||
|
boxRef.current.scrollTop = boxRef.current.scrollHeight;
|
||||||
|
}
|
||||||
|
}, [logs]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ToolCallExpandable
|
||||||
|
label={
|
||||||
|
<span className="pl-[19px] py-1">
|
||||||
|
<span>Logs</span>
|
||||||
|
{working && (
|
||||||
|
<div className="mx-2 inline-block">
|
||||||
|
<span
|
||||||
|
className="inline-block animate-spin rounded-full border-2 border-t-transparent border-current"
|
||||||
|
style={{ width: 8, height: 8 }}
|
||||||
|
role="status"
|
||||||
|
aria-label="Loading spinner"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</span>
|
||||||
|
}
|
||||||
|
isStartExpanded={isStartExpanded}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
ref={boxRef}
|
||||||
|
className={`flex flex-col items-start space-y-2 overflow-y-auto ${working ? 'max-h-[4rem]' : 'max-h-[20rem]'} bg-bgApp`}
|
||||||
|
>
|
||||||
|
{logs.map((log, i) => (
|
||||||
|
<span key={i} className="font-mono text-sm text-textSubtle">
|
||||||
|
{log}
|
||||||
|
</span>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</ToolCallExpandable>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const ProgressBar = ({ progress, total, message }: Omit<Progress, 'progressToken'>) => {
|
||||||
|
const isDeterminate = typeof total === 'number';
|
||||||
|
const percent = isDeterminate ? Math.min((progress / total!) * 100, 100) : 0;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="w-full space-y-2">
|
||||||
|
{message && <div className="text-sm text-gray-700">{message}</div>}
|
||||||
|
|
||||||
|
<div className="w-full bg-gray-200 rounded-full h-4 overflow-hidden relative">
|
||||||
|
{isDeterminate ? (
|
||||||
|
<div
|
||||||
|
className="bg-blue-500 h-full transition-all duration-300"
|
||||||
|
style={{ width: `${percent}%` }}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<div className="absolute inset-0 animate-indeterminate bg-blue-500" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|||||||
@@ -6,11 +6,25 @@ import { Message, createUserMessage, hasCompletedToolCalls } from '../types/mess
|
|||||||
// Ensure TextDecoder is available in the global scope
|
// Ensure TextDecoder is available in the global scope
|
||||||
const TextDecoder = globalThis.TextDecoder;
|
const TextDecoder = globalThis.TextDecoder;
|
||||||
|
|
||||||
|
type JsonValue = string | number | boolean | null | JsonValue[] | { [key: string]: JsonValue };
|
||||||
|
|
||||||
|
export interface NotificationEvent {
|
||||||
|
type: 'Notification';
|
||||||
|
request_id: string;
|
||||||
|
message: {
|
||||||
|
method: string;
|
||||||
|
params: {
|
||||||
|
[key: string]: JsonValue;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// Event types for SSE stream
|
// Event types for SSE stream
|
||||||
type MessageEvent =
|
type MessageEvent =
|
||||||
| { type: 'Message'; message: Message }
|
| { type: 'Message'; message: Message }
|
||||||
| { type: 'Error'; error: string }
|
| { type: 'Error'; error: string }
|
||||||
| { type: 'Finish'; reason: string };
|
| { type: 'Finish'; reason: string }
|
||||||
|
| NotificationEvent;
|
||||||
|
|
||||||
export interface UseMessageStreamOptions {
|
export interface UseMessageStreamOptions {
|
||||||
/**
|
/**
|
||||||
@@ -124,6 +138,8 @@ export interface UseMessageStreamHelpers {
|
|||||||
|
|
||||||
/** Modify body (session id and/or work dir mid-stream) **/
|
/** Modify body (session id and/or work dir mid-stream) **/
|
||||||
updateMessageStreamBody?: (newBody: object) => void;
|
updateMessageStreamBody?: (newBody: object) => void;
|
||||||
|
|
||||||
|
notifications: NotificationEvent[];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -151,6 +167,8 @@ export function useMessageStream({
|
|||||||
fallbackData: initialMessages,
|
fallbackData: initialMessages,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const [notifications, setNotifications] = useState<NotificationEvent[]>([]);
|
||||||
|
|
||||||
// expose a way to update the body so we can update the session id when CLE occurs
|
// expose a way to update the body so we can update the session id when CLE occurs
|
||||||
const updateMessageStreamBody = useCallback((newBody: object) => {
|
const updateMessageStreamBody = useCallback((newBody: object) => {
|
||||||
extraMetadataRef.current.body = {
|
extraMetadataRef.current.body = {
|
||||||
@@ -247,6 +265,14 @@ export function useMessageStream({
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case 'Notification': {
|
||||||
|
const newNotification = {
|
||||||
|
...parsedEvent,
|
||||||
|
};
|
||||||
|
setNotifications((prev) => [...prev, newNotification]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case 'Error':
|
case 'Error':
|
||||||
throw new Error(parsedEvent.error);
|
throw new Error(parsedEvent.error);
|
||||||
|
|
||||||
@@ -516,5 +542,6 @@ export function useMessageStream({
|
|||||||
isLoading: isLoading || false,
|
isLoading: isLoading || false,
|
||||||
addToolResult,
|
addToolResult,
|
||||||
updateMessageStreamBody,
|
updateMessageStreamBody,
|
||||||
|
notifications,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,10 +44,16 @@ export default {
|
|||||||
'0%': { transform: 'rotate(0deg)' },
|
'0%': { transform: 'rotate(0deg)' },
|
||||||
'100%': { transform: 'rotate(360deg)' },
|
'100%': { transform: 'rotate(360deg)' },
|
||||||
},
|
},
|
||||||
|
indeterminate: {
|
||||||
|
'0%': { left: '-40%', width: '40%' },
|
||||||
|
'50%': { left: '20%', width: '60%' },
|
||||||
|
'100%': { left: '100%', width: '80%' },
|
||||||
|
},
|
||||||
},
|
},
|
||||||
animation: {
|
animation: {
|
||||||
'shimmer-pulse': 'shimmer 4s ease-in-out infinite',
|
'shimmer-pulse': 'shimmer 4s ease-in-out infinite',
|
||||||
'gradient-loader': 'loader 750ms ease-in-out infinite',
|
'gradient-loader': 'loader 750ms ease-in-out infinite',
|
||||||
|
indeterminate: 'indeterminate 1.5s infinite linear',
|
||||||
},
|
},
|
||||||
colors: {
|
colors: {
|
||||||
bgApp: 'var(--background-app)',
|
bgApp: 'var(--background-app)',
|
||||||
|
|||||||
Reference in New Issue
Block a user