mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-08 16:14:24 +01:00
chore: use typed notifications from rmcp (#3653)
This commit is contained in:
@@ -36,7 +36,8 @@ use goose::providers::pricing::initialize_pricing_cache;
|
||||
use goose::session;
|
||||
use input::InputResult;
|
||||
use mcp_core::handler::ToolError;
|
||||
use rmcp::model::{JsonRpcMessage, JsonRpcNotification, Notification, PromptMessage};
|
||||
use rmcp::model::PromptMessage;
|
||||
use rmcp::model::ServerNotification;
|
||||
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use rustyline::EditMode;
|
||||
@@ -1023,126 +1024,115 @@ impl Session {
|
||||
}
|
||||
}
|
||||
Some(Ok(AgentEvent::McpNotification((_id, message)))) => {
|
||||
if let JsonRpcMessage::Notification( JsonRpcNotification {
|
||||
notification: Notification {
|
||||
method,
|
||||
params: o,..
|
||||
},..
|
||||
}) = message {
|
||||
match method.as_str() {
|
||||
"notifications/message" => {
|
||||
let data = o.get("data").unwrap_or(&Value::Null);
|
||||
let (formatted_message, subagent_id, message_notification_type) = match data {
|
||||
Value::String(s) => (s.clone(), None, None),
|
||||
Value::Object(o) => {
|
||||
// Check for subagent notification structure first
|
||||
if let Some(Value::String(msg)) = o.get("message") {
|
||||
// Extract subagent info for better display
|
||||
let subagent_id = o.get("subagent_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
let notification_type = o.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
match &message {
|
||||
ServerNotification::LoggingMessageNotification(notification) => {
|
||||
let data = ¬ification.params.data;
|
||||
let (formatted_message, subagent_id, message_notification_type) = match data {
|
||||
Value::String(s) => (s.clone(), None, None),
|
||||
Value::Object(o) => {
|
||||
// Check for subagent notification structure first
|
||||
if let Some(Value::String(msg)) = o.get("message") {
|
||||
// Extract subagent info for better display
|
||||
let subagent_id = o.get("subagent_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
let notification_type = o.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let formatted = match notification_type {
|
||||
"subagent_created" | "completed" | "terminated" => {
|
||||
format!("🤖 {}", msg)
|
||||
}
|
||||
"tool_usage" | "tool_completed" | "tool_error" => {
|
||||
format!("🔧 {}", msg)
|
||||
}
|
||||
"message_processing" | "turn_progress" => {
|
||||
format!("💭 {}", msg)
|
||||
}
|
||||
"response_generated" => {
|
||||
// Check verbosity setting for subagent response content
|
||||
let config = Config::global();
|
||||
let min_priority = config
|
||||
.get_param::<f32>("GOOSE_CLI_MIN_PRIORITY")
|
||||
.ok()
|
||||
.unwrap_or(0.5);
|
||||
let formatted = match notification_type {
|
||||
"subagent_created" | "completed" | "terminated" => {
|
||||
format!("🤖 {}", msg)
|
||||
}
|
||||
"tool_usage" | "tool_completed" | "tool_error" => {
|
||||
format!("🔧 {}", msg)
|
||||
}
|
||||
"message_processing" | "turn_progress" => {
|
||||
format!("💭 {}", msg)
|
||||
}
|
||||
"response_generated" => {
|
||||
// Check verbosity setting for subagent response content
|
||||
let config = Config::global();
|
||||
let min_priority = config
|
||||
.get_param::<f32>("GOOSE_CLI_MIN_PRIORITY")
|
||||
.ok()
|
||||
.unwrap_or(0.5);
|
||||
|
||||
if min_priority > 0.1 && !self.debug {
|
||||
// High/Medium verbosity: show truncated response
|
||||
if let Some(response_content) = msg.strip_prefix("Responded: ") {
|
||||
format!("🤖 Responded: {}", safe_truncate(response_content, 100))
|
||||
} else {
|
||||
format!("🤖 {}", msg)
|
||||
}
|
||||
if min_priority > 0.1 && !self.debug {
|
||||
// High/Medium verbosity: show truncated response
|
||||
if let Some(response_content) = msg.strip_prefix("Responded: ") {
|
||||
format!("🤖 Responded: {}", safe_truncate(response_content, 100))
|
||||
} else {
|
||||
// All verbosity or debug: show full response
|
||||
format!("🤖 {}", msg)
|
||||
}
|
||||
} else {
|
||||
// All verbosity or debug: show full response
|
||||
format!("🤖 {}", msg)
|
||||
}
|
||||
_ => {
|
||||
msg.to_string()
|
||||
}
|
||||
};
|
||||
(formatted, Some(subagent_id.to_string()), Some(notification_type.to_string()))
|
||||
} else if let Some(Value::String(output)) = o.get("output") {
|
||||
// Fallback for other MCP notification types
|
||||
(output.to_owned(), None, None)
|
||||
} else if let Some(result) = format_task_execution_notification(data) {
|
||||
result
|
||||
} else {
|
||||
(data.to_string(), None, None)
|
||||
}
|
||||
},
|
||||
v => {
|
||||
(v.to_string(), None, None)
|
||||
},
|
||||
};
|
||||
}
|
||||
_ => {
|
||||
msg.to_string()
|
||||
}
|
||||
};
|
||||
(formatted, Some(subagent_id.to_string()), Some(notification_type.to_string()))
|
||||
} else if let Some(Value::String(output)) = o.get("output") {
|
||||
// Fallback for other MCP notification types
|
||||
(output.to_owned(), None, None)
|
||||
} else if let Some(result) = format_task_execution_notification(data) {
|
||||
result
|
||||
} else {
|
||||
(data.to_string(), None, None)
|
||||
}
|
||||
},
|
||||
v => {
|
||||
(v.to_string(), None, None)
|
||||
},
|
||||
};
|
||||
|
||||
// Handle subagent notifications - show immediately
|
||||
if let Some(_id) = subagent_id {
|
||||
// TODO: proper display for subagent notifications
|
||||
// Handle subagent notifications - show immediately
|
||||
if let Some(_id) = subagent_id {
|
||||
// TODO: proper display for subagent notifications
|
||||
if interactive {
|
||||
let _ = progress_bars.hide();
|
||||
println!("{}", console::style(&formatted_message).green().dim());
|
||||
} else {
|
||||
progress_bars.log(&formatted_message);
|
||||
}
|
||||
} else if let Some(ref notification_type) = message_notification_type {
|
||||
if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE {
|
||||
if interactive {
|
||||
let _ = progress_bars.hide();
|
||||
println!("{}", console::style(&formatted_message).green().dim());
|
||||
print!("{}", formatted_message);
|
||||
std::io::stdout().flush().unwrap();
|
||||
} else {
|
||||
progress_bars.log(&formatted_message);
|
||||
}
|
||||
} else if let Some(ref notification_type) = message_notification_type {
|
||||
if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE {
|
||||
if interactive {
|
||||
let _ = progress_bars.hide();
|
||||
print!("{}", formatted_message);
|
||||
std::io::stdout().flush().unwrap();
|
||||
} else {
|
||||
print!("{}", formatted_message);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
print!("{}", formatted_message);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Non-subagent notification, display immediately with compact spacing
|
||||
if interactive {
|
||||
let _ = progress_bars.hide();
|
||||
println!("{}", console::style(&formatted_message).green().dim());
|
||||
} else {
|
||||
progress_bars.log(&formatted_message);
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Non-subagent notification, display immediately with compact spacing
|
||||
if interactive {
|
||||
let _ = progress_bars.hide();
|
||||
println!("{}", console::style(&formatted_message).green().dim());
|
||||
} else {
|
||||
progress_bars.log(&formatted_message);
|
||||
}
|
||||
},
|
||||
"notifications/progress" => {
|
||||
let progress = o.get("progress").and_then(|v| v.as_f64());
|
||||
let token = o.get("progressToken").map(|v| v.to_string());
|
||||
let message = o.get("message").and_then(|v| v.as_str());
|
||||
let total = o
|
||||
.get("total")
|
||||
.and_then(|v| v.as_f64());
|
||||
if let (Some(progress), Some(token)) = (progress, token) {
|
||||
progress_bars.update(
|
||||
token.as_str(),
|
||||
progress,
|
||||
total,
|
||||
message,
|
||||
);
|
||||
}
|
||||
},
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
},
|
||||
ServerNotification::ProgressNotification(notification) => {
|
||||
let progress = notification.params.progress;
|
||||
let text = notification.params.message.as_deref();
|
||||
let total = notification.params.total;
|
||||
let token = ¬ification.params.progress_token;
|
||||
progress_bars.update(
|
||||
&token.0.to_string(),
|
||||
progress,
|
||||
total,
|
||||
text,
|
||||
);
|
||||
},
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
Some(Ok(AgentEvent::ModelChange { model, mode })) => {
|
||||
|
||||
@@ -799,11 +799,11 @@ impl McpSpinners {
|
||||
spinner.set_message(message.to_string());
|
||||
}
|
||||
|
||||
pub fn update(&mut self, token: &str, value: f64, total: Option<f64>, message: Option<&str>) {
|
||||
pub fn update(&mut self, token: &str, value: u32, total: Option<u32>, message: Option<&str>) {
|
||||
let bar = self.bars.entry(token.to_string()).or_insert_with(|| {
|
||||
if let Some(total) = total {
|
||||
self.multi_bar.add(
|
||||
ProgressBar::new((total * 100.0) as u64).with_style(
|
||||
ProgressBar::new((total * 100) as u64).with_style(
|
||||
ProgressStyle::with_template("[{elapsed}] {bar:40} {pos:>3}/{len:3} {msg}")
|
||||
.unwrap(),
|
||||
),
|
||||
@@ -812,7 +812,7 @@ impl McpSpinners {
|
||||
self.multi_bar.add(ProgressBar::new_spinner())
|
||||
}
|
||||
});
|
||||
bar.set_position((value * 100.0) as u64);
|
||||
bar.set_position((value * 100) as u64);
|
||||
if let Some(msg) = message {
|
||||
bar.set_message(msg.to_string());
|
||||
}
|
||||
|
||||
@@ -672,6 +672,7 @@ impl DeveloperRouter {
|
||||
notification: Notification {
|
||||
method: "notifications/message".to_string(),
|
||||
params: object!({
|
||||
"level": "info",
|
||||
"data": {
|
||||
"type": "shell",
|
||||
"stream": "stdout",
|
||||
@@ -698,6 +699,7 @@ impl DeveloperRouter {
|
||||
notification: Notification {
|
||||
method: "notifications/message".to_string(),
|
||||
params: object!({
|
||||
"level": "info",
|
||||
"data": {
|
||||
"type": "shell",
|
||||
"stream": "stderr",
|
||||
|
||||
@@ -19,7 +19,7 @@ use goose::{
|
||||
session,
|
||||
};
|
||||
use mcp_core::ToolResult;
|
||||
use rmcp::model::{Content, JsonRpcMessage};
|
||||
use rmcp::model::{Content, ServerNotification};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use serde_json::Value;
|
||||
@@ -97,7 +97,7 @@ enum MessageEvent {
|
||||
},
|
||||
Notification {
|
||||
request_id: String,
|
||||
message: JsonRpcMessage,
|
||||
message: ServerNotification,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -45,8 +45,7 @@ use crate::tool_monitor::{ToolCall, ToolMonitor};
|
||||
use crate::utils::is_token_cancelled;
|
||||
use mcp_core::{ToolError, ToolResult};
|
||||
use regex::Regex;
|
||||
use rmcp::model::Tool;
|
||||
use rmcp::model::{Content, GetPromptResult, JsonRpcMessage, Prompt};
|
||||
use rmcp::model::{Content, GetPromptResult, Prompt, ServerNotification, Tool};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -83,7 +82,7 @@ pub struct Agent {
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum AgentEvent {
|
||||
Message(Message),
|
||||
McpNotification((String, JsonRpcMessage)),
|
||||
McpNotification((String, ServerNotification)),
|
||||
ModelChange { model: String, mode: String },
|
||||
}
|
||||
|
||||
@@ -94,19 +93,19 @@ impl Default for Agent {
|
||||
}
|
||||
|
||||
pub enum ToolStreamItem<T> {
|
||||
Message(JsonRpcMessage),
|
||||
Message(ServerNotification),
|
||||
Result(T),
|
||||
}
|
||||
|
||||
pub type ToolStream = Pin<Box<dyn Stream<Item = ToolStreamItem<ToolResult<Vec<Content>>>> + Send>>;
|
||||
|
||||
// tool_stream combines a stream of JsonRpcMessages with a future representing the
|
||||
// tool_stream combines a stream of ServerNotifications with a future representing the
|
||||
// final result of the tool call. MCP notifications are not request-scoped, but
|
||||
// this lets us capture all notifications emitted during the tool call for
|
||||
// simpler consumption
|
||||
pub fn tool_stream<S, F>(rx: S, done: F) -> ToolStream
|
||||
where
|
||||
S: Stream<Item = JsonRpcMessage> + Send + Unpin + 'static,
|
||||
S: Stream<Item = ServerNotification> + Send + Unpin + 'static,
|
||||
F: Future<Output = ToolResult<Vec<Content>>> + Send + 'static,
|
||||
{
|
||||
Box::pin(async_stream::stream! {
|
||||
|
||||
@@ -835,7 +835,7 @@ mod tests {
|
||||
CallToolResult, InitializeResult, ListPromptsResult, ListResourcesResult, ListToolsResult,
|
||||
ReadResourceResult,
|
||||
};
|
||||
use rmcp::model::{GetPromptResult, JsonRpcMessage};
|
||||
use rmcp::model::{GetPromptResult, ServerNotification};
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
@@ -891,7 +891,7 @@ mod tests {
|
||||
Err(Error::NotInitialized)
|
||||
}
|
||||
|
||||
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
|
||||
async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
|
||||
mpsc::channel(1).1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::agents::subagent_execution_tool::task_execution_tracker::{
|
||||
use crate::agents::subagent_execution_tool::tasks::process_task;
|
||||
use crate::agents::subagent_execution_tool::workers::spawn_worker;
|
||||
use crate::agents::subagent_task_config::TaskConfig;
|
||||
use rmcp::model::JsonRpcMessage;
|
||||
use rmcp::model::ServerNotification;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -20,7 +20,7 @@ const DEFAULT_MAX_WORKERS: usize = 10;
|
||||
|
||||
pub async fn execute_single_task(
|
||||
task: &Task,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
notifier: mpsc::Sender<ServerNotification>,
|
||||
task_config: TaskConfig,
|
||||
cancellation_token: Option<CancellationToken>,
|
||||
) -> ExecutionResponse {
|
||||
@@ -56,7 +56,7 @@ pub async fn execute_single_task(
|
||||
|
||||
pub async fn execute_tasks_in_parallel(
|
||||
tasks: Vec<Task>,
|
||||
notifier: Sender<JsonRpcMessage>,
|
||||
notifier: Sender<ServerNotification>,
|
||||
task_config: TaskConfig,
|
||||
cancellation_token: Option<CancellationToken>,
|
||||
) -> ExecutionResponse {
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::agents::subagent_execution_tool::{
|
||||
tasks_manager::TasksManager,
|
||||
};
|
||||
use crate::agents::subagent_task_config::TaskConfig;
|
||||
use rmcp::model::JsonRpcMessage;
|
||||
use rmcp::model::ServerNotification;
|
||||
use serde_json::{json, Value};
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -14,7 +14,7 @@ use tokio_util::sync::CancellationToken;
|
||||
pub async fn execute_tasks(
|
||||
input: Value,
|
||||
execution_mode: ExecutionMode,
|
||||
notifier: Sender<JsonRpcMessage>,
|
||||
notifier: Sender<ServerNotification>,
|
||||
task_config: TaskConfig,
|
||||
tasks_manager: &TasksManager,
|
||||
cancellation_token: Option<CancellationToken>,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use mcp_core::ToolError;
|
||||
use rmcp::model::{Content, Tool, ToolAnnotations};
|
||||
use rmcp::model::{Content, ServerNotification, Tool, ToolAnnotations};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::agents::subagent_task_config::TaskConfig;
|
||||
@@ -8,7 +8,6 @@ use crate::agents::{
|
||||
subagent_execution_tool::task_types::ExecutionMode,
|
||||
subagent_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult,
|
||||
};
|
||||
use rmcp::model::JsonRpcMessage;
|
||||
use rmcp::object;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream;
|
||||
@@ -67,7 +66,7 @@ pub async fn run_tasks(
|
||||
tasks_manager: &TasksManager,
|
||||
cancellation_token: Option<CancellationToken>,
|
||||
) -> ToolCallResult {
|
||||
let (notification_tx, notification_rx) = mpsc::channel::<JsonRpcMessage>(100);
|
||||
let (notification_tx, notification_rx) = mpsc::channel::<ServerNotification>(100);
|
||||
|
||||
let tasks_manager_clone = tasks_manager.clone();
|
||||
let result_future = async move {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use rmcp::model::{JsonRpcMessage, JsonRpcNotification, JsonRpcVersion2_0, Notification};
|
||||
use rmcp::object;
|
||||
use rmcp::model::{
|
||||
LoggingLevel, LoggingMessageNotification, LoggingMessageNotificationMethod,
|
||||
LoggingMessageNotificationParam, ServerNotification,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
@@ -52,7 +54,7 @@ fn format_task_metadata(task_info: &TaskInfo) -> String {
|
||||
pub struct TaskExecutionTracker {
|
||||
tasks: Arc<RwLock<HashMap<String, TaskInfo>>>,
|
||||
last_refresh: Arc<RwLock<Instant>>,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
notifier: mpsc::Sender<ServerNotification>,
|
||||
display_mode: DisplayMode,
|
||||
cancellation_token: Option<CancellationToken>,
|
||||
}
|
||||
@@ -61,7 +63,7 @@ impl TaskExecutionTracker {
|
||||
pub fn new(
|
||||
tasks: Vec<Task>,
|
||||
display_mode: DisplayMode,
|
||||
notifier: Sender<JsonRpcMessage>,
|
||||
notifier: Sender<ServerNotification>,
|
||||
cancellation_token: Option<CancellationToken>,
|
||||
) -> Self {
|
||||
let task_map = tasks
|
||||
@@ -97,7 +99,7 @@ impl TaskExecutionTracker {
|
||||
|
||||
fn log_notification_error(
|
||||
&self,
|
||||
error: &mpsc::error::TrySendError<JsonRpcMessage>,
|
||||
error: &mpsc::error::TrySendError<ServerNotification>,
|
||||
context: &str,
|
||||
) {
|
||||
if !self.is_cancelled() {
|
||||
@@ -108,16 +110,17 @@ impl TaskExecutionTracker {
|
||||
fn try_send_notification(&self, event: TaskExecutionNotificationEvent, context: &str) {
|
||||
if let Err(e) = self
|
||||
.notifier
|
||||
.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
|
||||
jsonrpc: JsonRpcVersion2_0,
|
||||
notification: Notification {
|
||||
method: "notifications/message".to_string(),
|
||||
params: object!({
|
||||
"data": event.to_notification_data()
|
||||
}),
|
||||
.try_send(ServerNotification::LoggingMessageNotification(
|
||||
LoggingMessageNotification {
|
||||
method: LoggingMessageNotificationMethod,
|
||||
params: LoggingMessageNotificationParam {
|
||||
data: event.to_notification_data(),
|
||||
level: LoggingLevel::Info,
|
||||
logger: None,
|
||||
},
|
||||
extensions: Default::default(),
|
||||
},
|
||||
}))
|
||||
))
|
||||
{
|
||||
self.log_notification_error(&e, context);
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::sync::Arc;
|
||||
use async_stream::try_stream;
|
||||
use futures::stream::{self, BoxStream};
|
||||
use futures::{Stream, StreamExt};
|
||||
use rmcp::model::JsonRpcMessage;
|
||||
use rmcp::model::ServerNotification;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
@@ -19,7 +19,7 @@ use rmcp::model::Content;
|
||||
// can be used to receive notifications from the tool.
|
||||
pub struct ToolCallResult {
|
||||
pub result: Box<dyn Future<Output = ToolResult<Vec<Content>>> + Send + Unpin>,
|
||||
pub notification_stream: Option<Box<dyn Stream<Item = JsonRpcMessage> + Send + Unpin>>,
|
||||
pub notification_stream: Option<Box<dyn Stream<Item = ServerNotification> + Send + Unpin>>,
|
||||
}
|
||||
|
||||
impl From<ToolResult<Vec<Content>>> for ToolCallResult {
|
||||
|
||||
@@ -5,6 +5,7 @@ use mcp_core::protocol::{
|
||||
use rmcp::model::{
|
||||
GetPromptResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest,
|
||||
JsonRpcResponse, JsonRpcVersion2_0, Notification, NumberOrString, Request, RequestId,
|
||||
ServerNotification,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
@@ -106,7 +107,7 @@ pub trait McpClientTrait: Send + Sync {
|
||||
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
|
||||
|
||||
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage>;
|
||||
async fn subscribe(&self) -> mpsc::Receiver<ServerNotification>;
|
||||
}
|
||||
|
||||
/// The MCP client is the interface for MCP operations.
|
||||
@@ -118,7 +119,7 @@ where
|
||||
next_id_counter: AtomicU64, // Added for atomic ID generation
|
||||
server_capabilities: Option<ServerCapabilities>,
|
||||
server_info: Option<Implementation>,
|
||||
notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<JsonRpcMessage>>>>,
|
||||
notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<ServerNotification>>>>,
|
||||
}
|
||||
|
||||
impl<T> McpClient<T>
|
||||
@@ -129,7 +130,7 @@ where
|
||||
let service = McpService::new(transport.clone());
|
||||
let service_ptr = service.clone();
|
||||
let notification_subscribers =
|
||||
Arc::new(Mutex::new(Vec::<mpsc::Sender<JsonRpcMessage>>::new()));
|
||||
Arc::new(Mutex::new(Vec::<mpsc::Sender<ServerNotification>>::new()));
|
||||
let subscribers_ptr = notification_subscribers.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
@@ -148,9 +149,22 @@ where
|
||||
}) => {
|
||||
service_ptr.respond(&id.to_string(), Ok(message)).await;
|
||||
}
|
||||
_ => {
|
||||
JsonRpcMessage::Notification(JsonRpcNotification {
|
||||
notification,
|
||||
..
|
||||
}) => {
|
||||
let mut subs = subscribers_ptr.lock().await;
|
||||
subs.retain(|sub| sub.try_send(message.clone()).is_ok());
|
||||
if let Some(server_notification) = notification.into() {
|
||||
subs.retain(|sub| {
|
||||
sub.try_send(server_notification.clone()).is_ok()
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!(
|
||||
"Received unexpected received message type: {:?}",
|
||||
message
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -437,7 +451,7 @@ where
|
||||
self.send_request("prompts/get", params).await
|
||||
}
|
||||
|
||||
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
|
||||
async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
|
||||
let (tx, rx) = mpsc::channel(16);
|
||||
self.notification_subscribers.lock().await.push(tx);
|
||||
rx
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::task::{Context, Poll};
|
||||
use tokio::sync::{oneshot, RwLock};
|
||||
use tower::{timeout::Timeout, Service, ServiceBuilder};
|
||||
|
||||
use crate::transport::{Error, TransportHandle};
|
||||
use crate::transport::{Error, TransportHandle, TransportMessageRecv};
|
||||
|
||||
/// A wrapper service that implements Tower's Service trait for MCP transport
|
||||
#[derive(Clone)]
|
||||
@@ -23,7 +23,7 @@ impl<T: TransportHandle> McpService<T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
|
||||
pub async fn respond(&self, id: &str, response: Result<TransportMessageRecv, Error>) {
|
||||
self.pending_requests.respond(id, response).await
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ impl<T> Service<JsonRpcMessage> for McpService<T>
|
||||
where
|
||||
T: TransportHandle + Send + Sync + 'static,
|
||||
{
|
||||
type Response = JsonRpcMessage;
|
||||
type Response = TransportMessageRecv;
|
||||
type Error = Error;
|
||||
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
@@ -63,7 +63,7 @@ where
|
||||
// Handle notifications without waiting for a response
|
||||
transport.send(request).await?;
|
||||
// Return a dummy response for notifications
|
||||
let dummy_response: JsonRpcMessage =
|
||||
let dummy_response: Self::Response =
|
||||
JsonRpcMessage::Response(rmcp::model::JsonRpcResponse {
|
||||
jsonrpc: rmcp::model::JsonRpcVersion2_0,
|
||||
id: rmcp::model::RequestId::Number(0),
|
||||
@@ -91,7 +91,7 @@ where
|
||||
|
||||
// A data structure to store pending requests and their response channels
|
||||
pub struct PendingRequests {
|
||||
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
|
||||
requests: RwLock<HashMap<String, oneshot::Sender<Result<TransportMessageRecv, Error>>>>,
|
||||
}
|
||||
|
||||
impl Default for PendingRequests {
|
||||
@@ -107,11 +107,15 @@ impl PendingRequests {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
|
||||
pub async fn insert(
|
||||
&self,
|
||||
id: String,
|
||||
sender: oneshot::Sender<Result<TransportMessageRecv, Error>>,
|
||||
) {
|
||||
self.requests.write().await.insert(id, sender);
|
||||
}
|
||||
|
||||
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
|
||||
pub async fn respond(&self, id: &str, response: Result<TransportMessageRecv, Error>) {
|
||||
if let Some(tx) = self.requests.write().await.remove(id) {
|
||||
let _ = tx.send(response);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use async_trait::async_trait;
|
||||
use rmcp::model::JsonRpcMessage;
|
||||
use rmcp::model::{JsonObject, JsonRpcMessage, Request, ServerNotification};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
|
||||
/// A generic error type for transport operations.
|
||||
@@ -38,15 +38,6 @@ pub enum Error {
|
||||
SessionError(String),
|
||||
}
|
||||
|
||||
/// A message that can be sent through the transport
|
||||
#[derive(Debug)]
|
||||
pub struct TransportMessage {
|
||||
/// The JSON-RPC message to send
|
||||
pub message: JsonRpcMessage,
|
||||
/// Channel to receive the response on (None for notifications)
|
||||
pub response_tx: Option<oneshot::Sender<Result<JsonRpcMessage, Error>>>,
|
||||
}
|
||||
|
||||
/// A generic asynchronous transport trait with channel-based communication
|
||||
#[async_trait]
|
||||
pub trait Transport {
|
||||
@@ -60,10 +51,12 @@ pub trait Transport {
|
||||
async fn close(&self) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
pub type TransportMessageRecv = JsonRpcMessage<Request, JsonObject, ServerNotification>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait TransportHandle: Send + Sync + Clone + 'static {
|
||||
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error>;
|
||||
async fn receive(&self) -> Result<JsonRpcMessage, Error>;
|
||||
async fn receive(&self) -> Result<TransportMessageRecv, Error>;
|
||||
}
|
||||
|
||||
pub async fn serialize_and_send(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::transport::Error;
|
||||
use crate::transport::{Error, TransportMessageRecv};
|
||||
use async_trait::async_trait;
|
||||
use eventsource_client::{Client, SSE};
|
||||
use futures::TryStreamExt;
|
||||
@@ -23,7 +23,7 @@ pub struct SseActor {
|
||||
/// Receives messages (requests/notifications) from the handle
|
||||
receiver: mpsc::Receiver<String>,
|
||||
/// Sends messages (responses) back to the handle
|
||||
sender: mpsc::Sender<JsonRpcMessage>,
|
||||
sender: mpsc::Sender<TransportMessageRecv>,
|
||||
/// Base SSE URL
|
||||
sse_url: String,
|
||||
/// For sending HTTP POST requests
|
||||
@@ -35,7 +35,7 @@ pub struct SseActor {
|
||||
impl SseActor {
|
||||
pub fn new(
|
||||
receiver: mpsc::Receiver<String>,
|
||||
sender: mpsc::Sender<JsonRpcMessage>,
|
||||
sender: mpsc::Sender<TransportMessageRecv>,
|
||||
sse_url: String,
|
||||
post_endpoint: Arc<RwLock<Option<String>>>,
|
||||
) -> Self {
|
||||
@@ -71,7 +71,7 @@ impl SseActor {
|
||||
/// - If a `message` event is received, parse it as `JsonRpcMessage`
|
||||
/// and respond to pending requests if it's a `Response`.
|
||||
async fn handle_incoming_messages(
|
||||
sender: mpsc::Sender<JsonRpcMessage>,
|
||||
sender: mpsc::Sender<TransportMessageRecv>,
|
||||
sse_url: String,
|
||||
post_endpoint: Arc<RwLock<Option<String>>>,
|
||||
) {
|
||||
@@ -109,7 +109,7 @@ impl SseActor {
|
||||
match event {
|
||||
SSE::Event(e) if e.event_type == "message" => {
|
||||
// Attempt to parse the SSE data as a JsonRpcMessage
|
||||
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
|
||||
match serde_json::from_str::<TransportMessageRecv>(&e.data) {
|
||||
Ok(message) => {
|
||||
let _ = sender.send(message).await;
|
||||
}
|
||||
@@ -184,7 +184,7 @@ impl SseActor {
|
||||
#[derive(Clone)]
|
||||
pub struct SseTransportHandle {
|
||||
sender: mpsc::Sender<String>,
|
||||
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>,
|
||||
receiver: Arc<Mutex<mpsc::Receiver<TransportMessageRecv>>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -193,7 +193,7 @@ impl TransportHandle for SseTransportHandle {
|
||||
serialize_and_send(&self.sender, message).await
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
|
||||
async fn receive(&self) -> Result<TransportMessageRecv, Error> {
|
||||
let mut receiver = self.receiver.lock().await;
|
||||
receiver.recv().await.ok_or(Error::ChannelClosed)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ use nix::sys::signal::{kill, Signal};
|
||||
#[cfg(unix)]
|
||||
use nix::unistd::{getpgid, Pid};
|
||||
|
||||
use crate::transport::TransportMessageRecv;
|
||||
|
||||
use super::{serialize_and_send, Error, Transport, TransportHandle};
|
||||
|
||||
// Global to track process groups we've created
|
||||
@@ -24,7 +26,7 @@ static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
|
||||
/// It uses channels for message passing and handles responses asynchronously through a background task.
|
||||
pub struct StdioActor {
|
||||
receiver: Option<mpsc::Receiver<String>>,
|
||||
sender: Option<mpsc::Sender<JsonRpcMessage>>,
|
||||
sender: Option<mpsc::Sender<TransportMessageRecv>>,
|
||||
process: Child, // we store the process to keep it alive
|
||||
error_sender: mpsc::Sender<Error>,
|
||||
stdin: Option<ChildStdin>,
|
||||
@@ -98,7 +100,7 @@ impl StdioActor {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender<JsonRpcMessage>) {
|
||||
async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender<TransportMessageRecv>) {
|
||||
let mut reader = BufReader::new(stdout);
|
||||
let mut line = String::new();
|
||||
loop {
|
||||
@@ -108,7 +110,7 @@ impl StdioActor {
|
||||
break;
|
||||
} // EOF
|
||||
Ok(_) => {
|
||||
if let Ok(message) = serde_json::from_str::<JsonRpcMessage>(&line) {
|
||||
if let Ok(message) = serde_json::from_str::<TransportMessageRecv>(&line) {
|
||||
tracing::debug!(
|
||||
message = ?message,
|
||||
"Received incoming message"
|
||||
@@ -149,8 +151,8 @@ impl StdioActor {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StdioTransportHandle {
|
||||
sender: mpsc::Sender<String>, // to process
|
||||
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>, // from process
|
||||
sender: mpsc::Sender<String>, // to process
|
||||
receiver: Arc<Mutex<mpsc::Receiver<TransportMessageRecv>>>, // from process
|
||||
error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
|
||||
}
|
||||
|
||||
@@ -163,7 +165,7 @@ impl TransportHandle for StdioTransportHandle {
|
||||
result
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
|
||||
async fn receive(&self) -> Result<TransportMessageRecv, Error> {
|
||||
let mut receiver = self.receiver.lock().await;
|
||||
match receiver.recv().await {
|
||||
Some(message) => Ok(message),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::oauth::{authenticate_service, ServiceConfig};
|
||||
use crate::transport::Error;
|
||||
use crate::transport::{Error, TransportMessageRecv};
|
||||
use async_trait::async_trait;
|
||||
use eventsource_client::{Client, SSE};
|
||||
use futures::TryStreamExt;
|
||||
@@ -25,7 +25,7 @@ pub struct StreamableHttpActor {
|
||||
/// Receives messages (requests/notifications) from the handle
|
||||
receiver: mpsc::Receiver<String>,
|
||||
/// Sends messages (responses) back to the handle
|
||||
sender: mpsc::Sender<JsonRpcMessage>,
|
||||
sender: mpsc::Sender<TransportMessageRecv>,
|
||||
/// MCP endpoint URL
|
||||
mcp_endpoint: String,
|
||||
/// HTTP client for sending requests
|
||||
@@ -41,7 +41,7 @@ pub struct StreamableHttpActor {
|
||||
impl StreamableHttpActor {
|
||||
pub fn new(
|
||||
receiver: mpsc::Receiver<String>,
|
||||
sender: mpsc::Sender<JsonRpcMessage>,
|
||||
sender: mpsc::Sender<TransportMessageRecv>,
|
||||
mcp_endpoint: String,
|
||||
session_id: Arc<RwLock<Option<String>>>,
|
||||
env: HashMap<String, String>,
|
||||
@@ -84,8 +84,8 @@ impl StreamableHttpActor {
|
||||
debug!("Sending message to MCP endpoint: {}", message_str);
|
||||
|
||||
// Parse the message to determine if it's a request that expects a response
|
||||
let parsed_message: JsonRpcMessage =
|
||||
serde_json::from_str(&message_str).map_err(Error::Serialization)?;
|
||||
let parsed_message = serde_json::from_str::<TransportMessageRecv>(&message_str)
|
||||
.map_err(Error::Serialization)?;
|
||||
|
||||
let expects_response = matches!(
|
||||
parsed_message,
|
||||
@@ -196,8 +196,8 @@ impl StreamableHttpActor {
|
||||
})?;
|
||||
|
||||
if !response_text.is_empty() {
|
||||
let json_message: JsonRpcMessage =
|
||||
serde_json::from_str(&response_text).map_err(Error::Serialization)?;
|
||||
let json_message = serde_json::from_str::<TransportMessageRecv>(&response_text)
|
||||
.map_err(Error::Serialization)?;
|
||||
|
||||
let _ = self.sender.send(json_message).await;
|
||||
}
|
||||
@@ -267,7 +267,7 @@ impl StreamableHttpActor {
|
||||
// Empty line indicates end of event
|
||||
if !event_data.is_empty() {
|
||||
// Parse the streamed data as JSON-RPC message
|
||||
match serde_json::from_str::<JsonRpcMessage>(&event_data) {
|
||||
match serde_json::from_str::<TransportMessageRecv>(&event_data) {
|
||||
Ok(message) => {
|
||||
debug!("Received streaming HTTP response message: {:?}", message);
|
||||
let _ = self.sender.send(message).await;
|
||||
@@ -301,7 +301,7 @@ impl StreamableHttpActor {
|
||||
#[derive(Clone)]
|
||||
pub struct StreamableHttpTransportHandle {
|
||||
sender: mpsc::Sender<String>,
|
||||
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>,
|
||||
receiver: Arc<Mutex<mpsc::Receiver<TransportMessageRecv>>>,
|
||||
session_id: Arc<RwLock<Option<String>>>,
|
||||
mcp_endpoint: String,
|
||||
http_client: HttpClient,
|
||||
@@ -314,7 +314,7 @@ impl TransportHandle for StreamableHttpTransportHandle {
|
||||
serialize_and_send(&self.sender, message).await
|
||||
}
|
||||
|
||||
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
|
||||
async fn receive(&self) -> Result<TransportMessageRecv, Error> {
|
||||
let mut receiver = self.receiver.lock().await;
|
||||
receiver.recv().await.ok_or(Error::ChannelClosed)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user