chore: use typed notifications from rmcp (#3653)

This commit is contained in:
Jack Amadeo
2025-07-25 14:04:18 -04:00
committed by GitHub
parent 31a5f9cbbc
commit 0ef38c6658
17 changed files with 198 additions and 192 deletions

View File

@@ -36,7 +36,8 @@ use goose::providers::pricing::initialize_pricing_cache;
use goose::session;
use input::InputResult;
use mcp_core::handler::ToolError;
use rmcp::model::{JsonRpcMessage, JsonRpcNotification, Notification, PromptMessage};
use rmcp::model::PromptMessage;
use rmcp::model::ServerNotification;
use rand::{distributions::Alphanumeric, Rng};
use rustyline::EditMode;
@@ -1023,126 +1024,115 @@ impl Session {
}
}
Some(Ok(AgentEvent::McpNotification((_id, message)))) => {
if let JsonRpcMessage::Notification( JsonRpcNotification {
notification: Notification {
method,
params: o,..
},..
}) = message {
match method.as_str() {
"notifications/message" => {
let data = o.get("data").unwrap_or(&Value::Null);
let (formatted_message, subagent_id, message_notification_type) = match data {
Value::String(s) => (s.clone(), None, None),
Value::Object(o) => {
// Check for subagent notification structure first
if let Some(Value::String(msg)) = o.get("message") {
// Extract subagent info for better display
let subagent_id = o.get("subagent_id")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let notification_type = o.get("type")
.and_then(|v| v.as_str())
.unwrap_or("");
match &message {
ServerNotification::LoggingMessageNotification(notification) => {
let data = &notification.params.data;
let (formatted_message, subagent_id, message_notification_type) = match data {
Value::String(s) => (s.clone(), None, None),
Value::Object(o) => {
// Check for subagent notification structure first
if let Some(Value::String(msg)) = o.get("message") {
// Extract subagent info for better display
let subagent_id = o.get("subagent_id")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let notification_type = o.get("type")
.and_then(|v| v.as_str())
.unwrap_or("");
let formatted = match notification_type {
"subagent_created" | "completed" | "terminated" => {
format!("🤖 {}", msg)
}
"tool_usage" | "tool_completed" | "tool_error" => {
format!("🔧 {}", msg)
}
"message_processing" | "turn_progress" => {
format!("💭 {}", msg)
}
"response_generated" => {
// Check verbosity setting for subagent response content
let config = Config::global();
let min_priority = config
.get_param::<f32>("GOOSE_CLI_MIN_PRIORITY")
.ok()
.unwrap_or(0.5);
let formatted = match notification_type {
"subagent_created" | "completed" | "terminated" => {
format!("🤖 {}", msg)
}
"tool_usage" | "tool_completed" | "tool_error" => {
format!("🔧 {}", msg)
}
"message_processing" | "turn_progress" => {
format!("💭 {}", msg)
}
"response_generated" => {
// Check verbosity setting for subagent response content
let config = Config::global();
let min_priority = config
.get_param::<f32>("GOOSE_CLI_MIN_PRIORITY")
.ok()
.unwrap_or(0.5);
if min_priority > 0.1 && !self.debug {
// High/Medium verbosity: show truncated response
if let Some(response_content) = msg.strip_prefix("Responded: ") {
format!("🤖 Responded: {}", safe_truncate(response_content, 100))
} else {
format!("🤖 {}", msg)
}
if min_priority > 0.1 && !self.debug {
// High/Medium verbosity: show truncated response
if let Some(response_content) = msg.strip_prefix("Responded: ") {
format!("🤖 Responded: {}", safe_truncate(response_content, 100))
} else {
// All verbosity or debug: show full response
format!("🤖 {}", msg)
}
} else {
// All verbosity or debug: show full response
format!("🤖 {}", msg)
}
_ => {
msg.to_string()
}
};
(formatted, Some(subagent_id.to_string()), Some(notification_type.to_string()))
} else if let Some(Value::String(output)) = o.get("output") {
// Fallback for other MCP notification types
(output.to_owned(), None, None)
} else if let Some(result) = format_task_execution_notification(data) {
result
} else {
(data.to_string(), None, None)
}
},
v => {
(v.to_string(), None, None)
},
};
}
_ => {
msg.to_string()
}
};
(formatted, Some(subagent_id.to_string()), Some(notification_type.to_string()))
} else if let Some(Value::String(output)) = o.get("output") {
// Fallback for other MCP notification types
(output.to_owned(), None, None)
} else if let Some(result) = format_task_execution_notification(data) {
result
} else {
(data.to_string(), None, None)
}
},
v => {
(v.to_string(), None, None)
},
};
// Handle subagent notifications - show immediately
if let Some(_id) = subagent_id {
// TODO: proper display for subagent notifications
// Handle subagent notifications - show immediately
if let Some(_id) = subagent_id {
// TODO: proper display for subagent notifications
if interactive {
let _ = progress_bars.hide();
println!("{}", console::style(&formatted_message).green().dim());
} else {
progress_bars.log(&formatted_message);
}
} else if let Some(ref notification_type) = message_notification_type {
if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE {
if interactive {
let _ = progress_bars.hide();
println!("{}", console::style(&formatted_message).green().dim());
print!("{}", formatted_message);
std::io::stdout().flush().unwrap();
} else {
progress_bars.log(&formatted_message);
}
} else if let Some(ref notification_type) = message_notification_type {
if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE {
if interactive {
let _ = progress_bars.hide();
print!("{}", formatted_message);
std::io::stdout().flush().unwrap();
} else {
print!("{}", formatted_message);
std::io::stdout().flush().unwrap();
}
print!("{}", formatted_message);
std::io::stdout().flush().unwrap();
}
}
else {
// Non-subagent notification, display immediately with compact spacing
if interactive {
let _ = progress_bars.hide();
println!("{}", console::style(&formatted_message).green().dim());
} else {
progress_bars.log(&formatted_message);
}
}
else {
// Non-subagent notification, display immediately with compact spacing
if interactive {
let _ = progress_bars.hide();
println!("{}", console::style(&formatted_message).green().dim());
} else {
progress_bars.log(&formatted_message);
}
},
"notifications/progress" => {
let progress = o.get("progress").and_then(|v| v.as_f64());
let token = o.get("progressToken").map(|v| v.to_string());
let message = o.get("message").and_then(|v| v.as_str());
let total = o
.get("total")
.and_then(|v| v.as_f64());
if let (Some(progress), Some(token)) = (progress, token) {
progress_bars.update(
token.as_str(),
progress,
total,
message,
);
}
},
_ => (),
}
}
},
ServerNotification::ProgressNotification(notification) => {
let progress = notification.params.progress;
let text = notification.params.message.as_deref();
let total = notification.params.total;
let token = &notification.params.progress_token;
progress_bars.update(
&token.0.to_string(),
progress,
total,
text,
);
},
_ => (),
}
}
Some(Ok(AgentEvent::ModelChange { model, mode })) => {

View File

@@ -799,11 +799,11 @@ impl McpSpinners {
spinner.set_message(message.to_string());
}
pub fn update(&mut self, token: &str, value: f64, total: Option<f64>, message: Option<&str>) {
pub fn update(&mut self, token: &str, value: u32, total: Option<u32>, message: Option<&str>) {
let bar = self.bars.entry(token.to_string()).or_insert_with(|| {
if let Some(total) = total {
self.multi_bar.add(
ProgressBar::new((total * 100.0) as u64).with_style(
ProgressBar::new((total * 100) as u64).with_style(
ProgressStyle::with_template("[{elapsed}] {bar:40} {pos:>3}/{len:3} {msg}")
.unwrap(),
),
@@ -812,7 +812,7 @@ impl McpSpinners {
self.multi_bar.add(ProgressBar::new_spinner())
}
});
bar.set_position((value * 100.0) as u64);
bar.set_position((value * 100) as u64);
if let Some(msg) = message {
bar.set_message(msg.to_string());
}

View File

@@ -672,6 +672,7 @@ impl DeveloperRouter {
notification: Notification {
method: "notifications/message".to_string(),
params: object!({
"level": "info",
"data": {
"type": "shell",
"stream": "stdout",
@@ -698,6 +699,7 @@ impl DeveloperRouter {
notification: Notification {
method: "notifications/message".to_string(),
params: object!({
"level": "info",
"data": {
"type": "shell",
"stream": "stderr",

View File

@@ -19,7 +19,7 @@ use goose::{
session,
};
use mcp_core::ToolResult;
use rmcp::model::{Content, JsonRpcMessage};
use rmcp::model::{Content, ServerNotification};
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
@@ -97,7 +97,7 @@ enum MessageEvent {
},
Notification {
request_id: String,
message: JsonRpcMessage,
message: ServerNotification,
},
}

View File

@@ -45,8 +45,7 @@ use crate::tool_monitor::{ToolCall, ToolMonitor};
use crate::utils::is_token_cancelled;
use mcp_core::{ToolError, ToolResult};
use regex::Regex;
use rmcp::model::Tool;
use rmcp::model::{Content, GetPromptResult, JsonRpcMessage, Prompt};
use rmcp::model::{Content, GetPromptResult, Prompt, ServerNotification, Tool};
use serde_json::Value;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::sync::CancellationToken;
@@ -83,7 +82,7 @@ pub struct Agent {
#[derive(Clone, Debug)]
pub enum AgentEvent {
Message(Message),
McpNotification((String, JsonRpcMessage)),
McpNotification((String, ServerNotification)),
ModelChange { model: String, mode: String },
}
@@ -94,19 +93,19 @@ impl Default for Agent {
}
pub enum ToolStreamItem<T> {
Message(JsonRpcMessage),
Message(ServerNotification),
Result(T),
}
pub type ToolStream = Pin<Box<dyn Stream<Item = ToolStreamItem<ToolResult<Vec<Content>>>> + Send>>;
// tool_stream combines a stream of JsonRpcMessages with a future representing the
// tool_stream combines a stream of ServerNotifications with a future representing the
// final result of the tool call. MCP notifications are not request-scoped, but
// this lets us capture all notifications emitted during the tool call for
// simpler consumption
pub fn tool_stream<S, F>(rx: S, done: F) -> ToolStream
where
S: Stream<Item = JsonRpcMessage> + Send + Unpin + 'static,
S: Stream<Item = ServerNotification> + Send + Unpin + 'static,
F: Future<Output = ToolResult<Vec<Content>>> + Send + 'static,
{
Box::pin(async_stream::stream! {

View File

@@ -835,7 +835,7 @@ mod tests {
CallToolResult, InitializeResult, ListPromptsResult, ListResourcesResult, ListToolsResult,
ReadResourceResult,
};
use rmcp::model::{GetPromptResult, JsonRpcMessage};
use rmcp::model::{GetPromptResult, ServerNotification};
use serde_json::json;
use tokio::sync::mpsc;
@@ -891,7 +891,7 @@ mod tests {
Err(Error::NotInitialized)
}
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
mpsc::channel(1).1
}
}

View File

@@ -7,7 +7,7 @@ use crate::agents::subagent_execution_tool::task_execution_tracker::{
use crate::agents::subagent_execution_tool::tasks::process_task;
use crate::agents::subagent_execution_tool::workers::spawn_worker;
use crate::agents::subagent_task_config::TaskConfig;
use rmcp::model::JsonRpcMessage;
use rmcp::model::ServerNotification;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use tokio::sync::mpsc;
@@ -20,7 +20,7 @@ const DEFAULT_MAX_WORKERS: usize = 10;
pub async fn execute_single_task(
task: &Task,
notifier: mpsc::Sender<JsonRpcMessage>,
notifier: mpsc::Sender<ServerNotification>,
task_config: TaskConfig,
cancellation_token: Option<CancellationToken>,
) -> ExecutionResponse {
@@ -56,7 +56,7 @@ pub async fn execute_single_task(
pub async fn execute_tasks_in_parallel(
tasks: Vec<Task>,
notifier: Sender<JsonRpcMessage>,
notifier: Sender<ServerNotification>,
task_config: TaskConfig,
cancellation_token: Option<CancellationToken>,
) -> ExecutionResponse {

View File

@@ -6,7 +6,7 @@ use crate::agents::subagent_execution_tool::{
tasks_manager::TasksManager,
};
use crate::agents::subagent_task_config::TaskConfig;
use rmcp::model::JsonRpcMessage;
use rmcp::model::ServerNotification;
use serde_json::{json, Value};
use tokio::sync::mpsc::Sender;
use tokio_util::sync::CancellationToken;
@@ -14,7 +14,7 @@ use tokio_util::sync::CancellationToken;
pub async fn execute_tasks(
input: Value,
execution_mode: ExecutionMode,
notifier: Sender<JsonRpcMessage>,
notifier: Sender<ServerNotification>,
task_config: TaskConfig,
tasks_manager: &TasksManager,
cancellation_token: Option<CancellationToken>,

View File

@@ -1,5 +1,5 @@
use mcp_core::ToolError;
use rmcp::model::{Content, Tool, ToolAnnotations};
use rmcp::model::{Content, ServerNotification, Tool, ToolAnnotations};
use serde_json::Value;
use crate::agents::subagent_task_config::TaskConfig;
@@ -8,7 +8,6 @@ use crate::agents::{
subagent_execution_tool::task_types::ExecutionMode,
subagent_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult,
};
use rmcp::model::JsonRpcMessage;
use rmcp::object;
use tokio::sync::mpsc;
use tokio_stream;
@@ -67,7 +66,7 @@ pub async fn run_tasks(
tasks_manager: &TasksManager,
cancellation_token: Option<CancellationToken>,
) -> ToolCallResult {
let (notification_tx, notification_rx) = mpsc::channel::<JsonRpcMessage>(100);
let (notification_tx, notification_rx) = mpsc::channel::<ServerNotification>(100);
let tasks_manager_clone = tasks_manager.clone();
let result_future = async move {

View File

@@ -1,5 +1,7 @@
use rmcp::model::{JsonRpcMessage, JsonRpcNotification, JsonRpcVersion2_0, Notification};
use rmcp::object;
use rmcp::model::{
LoggingLevel, LoggingMessageNotification, LoggingMessageNotificationMethod,
LoggingMessageNotificationParam, ServerNotification,
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
@@ -52,7 +54,7 @@ fn format_task_metadata(task_info: &TaskInfo) -> String {
pub struct TaskExecutionTracker {
tasks: Arc<RwLock<HashMap<String, TaskInfo>>>,
last_refresh: Arc<RwLock<Instant>>,
notifier: mpsc::Sender<JsonRpcMessage>,
notifier: mpsc::Sender<ServerNotification>,
display_mode: DisplayMode,
cancellation_token: Option<CancellationToken>,
}
@@ -61,7 +63,7 @@ impl TaskExecutionTracker {
pub fn new(
tasks: Vec<Task>,
display_mode: DisplayMode,
notifier: Sender<JsonRpcMessage>,
notifier: Sender<ServerNotification>,
cancellation_token: Option<CancellationToken>,
) -> Self {
let task_map = tasks
@@ -97,7 +99,7 @@ impl TaskExecutionTracker {
fn log_notification_error(
&self,
error: &mpsc::error::TrySendError<JsonRpcMessage>,
error: &mpsc::error::TrySendError<ServerNotification>,
context: &str,
) {
if !self.is_cancelled() {
@@ -108,16 +110,17 @@ impl TaskExecutionTracker {
fn try_send_notification(&self, event: TaskExecutionNotificationEvent, context: &str) {
if let Err(e) = self
.notifier
.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: JsonRpcVersion2_0,
notification: Notification {
method: "notifications/message".to_string(),
params: object!({
"data": event.to_notification_data()
}),
.try_send(ServerNotification::LoggingMessageNotification(
LoggingMessageNotification {
method: LoggingMessageNotificationMethod,
params: LoggingMessageNotificationParam {
data: event.to_notification_data(),
level: LoggingLevel::Info,
logger: None,
},
extensions: Default::default(),
},
}))
))
{
self.log_notification_error(&e, context);
}

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use async_stream::try_stream;
use futures::stream::{self, BoxStream};
use futures::{Stream, StreamExt};
use rmcp::model::JsonRpcMessage;
use rmcp::model::ServerNotification;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
@@ -19,7 +19,7 @@ use rmcp::model::Content;
// can be used to receive notifications from the tool.
pub struct ToolCallResult {
pub result: Box<dyn Future<Output = ToolResult<Vec<Content>>> + Send + Unpin>,
pub notification_stream: Option<Box<dyn Stream<Item = JsonRpcMessage> + Send + Unpin>>,
pub notification_stream: Option<Box<dyn Stream<Item = ServerNotification> + Send + Unpin>>,
}
impl From<ToolResult<Vec<Content>>> for ToolCallResult {

View File

@@ -5,6 +5,7 @@ use mcp_core::protocol::{
use rmcp::model::{
GetPromptResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest,
JsonRpcResponse, JsonRpcVersion2_0, Notification, NumberOrString, Request, RequestId,
ServerNotification,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
@@ -106,7 +107,7 @@ pub trait McpClientTrait: Send + Sync {
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage>;
async fn subscribe(&self) -> mpsc::Receiver<ServerNotification>;
}
/// The MCP client is the interface for MCP operations.
@@ -118,7 +119,7 @@ where
next_id_counter: AtomicU64, // Added for atomic ID generation
server_capabilities: Option<ServerCapabilities>,
server_info: Option<Implementation>,
notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<JsonRpcMessage>>>>,
notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<ServerNotification>>>>,
}
impl<T> McpClient<T>
@@ -129,7 +130,7 @@ where
let service = McpService::new(transport.clone());
let service_ptr = service.clone();
let notification_subscribers =
Arc::new(Mutex::new(Vec::<mpsc::Sender<JsonRpcMessage>>::new()));
Arc::new(Mutex::new(Vec::<mpsc::Sender<ServerNotification>>::new()));
let subscribers_ptr = notification_subscribers.clone();
tokio::spawn(async move {
@@ -148,9 +149,22 @@ where
}) => {
service_ptr.respond(&id.to_string(), Ok(message)).await;
}
_ => {
JsonRpcMessage::Notification(JsonRpcNotification {
notification,
..
}) => {
let mut subs = subscribers_ptr.lock().await;
subs.retain(|sub| sub.try_send(message.clone()).is_ok());
if let Some(server_notification) = notification.into() {
subs.retain(|sub| {
sub.try_send(server_notification.clone()).is_ok()
});
}
}
_ => {
tracing::warn!(
"Received unexpected received message type: {:?}",
message
);
}
}
}
@@ -437,7 +451,7 @@ where
self.send_request("prompts/get", params).await
}
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
async fn subscribe(&self) -> mpsc::Receiver<ServerNotification> {
let (tx, rx) = mpsc::channel(16);
self.notification_subscribers.lock().await.push(tx);
rx

View File

@@ -6,7 +6,7 @@ use std::task::{Context, Poll};
use tokio::sync::{oneshot, RwLock};
use tower::{timeout::Timeout, Service, ServiceBuilder};
use crate::transport::{Error, TransportHandle};
use crate::transport::{Error, TransportHandle, TransportMessageRecv};
/// A wrapper service that implements Tower's Service trait for MCP transport
#[derive(Clone)]
@@ -23,7 +23,7 @@ impl<T: TransportHandle> McpService<T> {
}
}
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
pub async fn respond(&self, id: &str, response: Result<TransportMessageRecv, Error>) {
self.pending_requests.respond(id, response).await
}
@@ -36,7 +36,7 @@ impl<T> Service<JsonRpcMessage> for McpService<T>
where
T: TransportHandle + Send + Sync + 'static,
{
type Response = JsonRpcMessage;
type Response = TransportMessageRecv;
type Error = Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
@@ -63,7 +63,7 @@ where
// Handle notifications without waiting for a response
transport.send(request).await?;
// Return a dummy response for notifications
let dummy_response: JsonRpcMessage =
let dummy_response: Self::Response =
JsonRpcMessage::Response(rmcp::model::JsonRpcResponse {
jsonrpc: rmcp::model::JsonRpcVersion2_0,
id: rmcp::model::RequestId::Number(0),
@@ -91,7 +91,7 @@ where
// A data structure to store pending requests and their response channels
pub struct PendingRequests {
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
requests: RwLock<HashMap<String, oneshot::Sender<Result<TransportMessageRecv, Error>>>>,
}
impl Default for PendingRequests {
@@ -107,11 +107,15 @@ impl PendingRequests {
}
}
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
pub async fn insert(
&self,
id: String,
sender: oneshot::Sender<Result<TransportMessageRecv, Error>>,
) {
self.requests.write().await.insert(id, sender);
}
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
pub async fn respond(&self, id: &str, response: Result<TransportMessageRecv, Error>) {
if let Some(tx) = self.requests.write().await.remove(id) {
let _ = tx.send(response);
}

View File

@@ -1,7 +1,7 @@
use async_trait::async_trait;
use rmcp::model::JsonRpcMessage;
use rmcp::model::{JsonObject, JsonRpcMessage, Request, ServerNotification};
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::mpsc;
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
/// A generic error type for transport operations.
@@ -38,15 +38,6 @@ pub enum Error {
SessionError(String),
}
/// A message that can be sent through the transport
#[derive(Debug)]
pub struct TransportMessage {
/// The JSON-RPC message to send
pub message: JsonRpcMessage,
/// Channel to receive the response on (None for notifications)
pub response_tx: Option<oneshot::Sender<Result<JsonRpcMessage, Error>>>,
}
/// A generic asynchronous transport trait with channel-based communication
#[async_trait]
pub trait Transport {
@@ -60,10 +51,12 @@ pub trait Transport {
async fn close(&self) -> Result<(), Error>;
}
pub type TransportMessageRecv = JsonRpcMessage<Request, JsonObject, ServerNotification>;
#[async_trait]
pub trait TransportHandle: Send + Sync + Clone + 'static {
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error>;
async fn receive(&self) -> Result<JsonRpcMessage, Error>;
async fn receive(&self) -> Result<TransportMessageRecv, Error>;
}
pub async fn serialize_and_send(

View File

@@ -1,4 +1,4 @@
use crate::transport::Error;
use crate::transport::{Error, TransportMessageRecv};
use async_trait::async_trait;
use eventsource_client::{Client, SSE};
use futures::TryStreamExt;
@@ -23,7 +23,7 @@ pub struct SseActor {
/// Receives messages (requests/notifications) from the handle
receiver: mpsc::Receiver<String>,
/// Sends messages (responses) back to the handle
sender: mpsc::Sender<JsonRpcMessage>,
sender: mpsc::Sender<TransportMessageRecv>,
/// Base SSE URL
sse_url: String,
/// For sending HTTP POST requests
@@ -35,7 +35,7 @@ pub struct SseActor {
impl SseActor {
pub fn new(
receiver: mpsc::Receiver<String>,
sender: mpsc::Sender<JsonRpcMessage>,
sender: mpsc::Sender<TransportMessageRecv>,
sse_url: String,
post_endpoint: Arc<RwLock<Option<String>>>,
) -> Self {
@@ -71,7 +71,7 @@ impl SseActor {
/// - If a `message` event is received, parse it as `JsonRpcMessage`
/// and respond to pending requests if it's a `Response`.
async fn handle_incoming_messages(
sender: mpsc::Sender<JsonRpcMessage>,
sender: mpsc::Sender<TransportMessageRecv>,
sse_url: String,
post_endpoint: Arc<RwLock<Option<String>>>,
) {
@@ -109,7 +109,7 @@ impl SseActor {
match event {
SSE::Event(e) if e.event_type == "message" => {
// Attempt to parse the SSE data as a JsonRpcMessage
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
match serde_json::from_str::<TransportMessageRecv>(&e.data) {
Ok(message) => {
let _ = sender.send(message).await;
}
@@ -184,7 +184,7 @@ impl SseActor {
#[derive(Clone)]
pub struct SseTransportHandle {
sender: mpsc::Sender<String>,
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>,
receiver: Arc<Mutex<mpsc::Receiver<TransportMessageRecv>>>,
}
#[async_trait::async_trait]
@@ -193,7 +193,7 @@ impl TransportHandle for SseTransportHandle {
serialize_and_send(&self.sender, message).await
}
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
async fn receive(&self) -> Result<TransportMessageRecv, Error> {
let mut receiver = self.receiver.lock().await;
receiver.recv().await.ok_or(Error::ChannelClosed)
}

View File

@@ -14,6 +14,8 @@ use nix::sys::signal::{kill, Signal};
#[cfg(unix)]
use nix::unistd::{getpgid, Pid};
use crate::transport::TransportMessageRecv;
use super::{serialize_and_send, Error, Transport, TransportHandle};
// Global to track process groups we've created
@@ -24,7 +26,7 @@ static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
/// It uses channels for message passing and handles responses asynchronously through a background task.
pub struct StdioActor {
receiver: Option<mpsc::Receiver<String>>,
sender: Option<mpsc::Sender<JsonRpcMessage>>,
sender: Option<mpsc::Sender<TransportMessageRecv>>,
process: Child, // we store the process to keep it alive
error_sender: mpsc::Sender<Error>,
stdin: Option<ChildStdin>,
@@ -98,7 +100,7 @@ impl StdioActor {
}
}
async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender<JsonRpcMessage>) {
async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender<TransportMessageRecv>) {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
loop {
@@ -108,7 +110,7 @@ impl StdioActor {
break;
} // EOF
Ok(_) => {
if let Ok(message) = serde_json::from_str::<JsonRpcMessage>(&line) {
if let Ok(message) = serde_json::from_str::<TransportMessageRecv>(&line) {
tracing::debug!(
message = ?message,
"Received incoming message"
@@ -149,8 +151,8 @@ impl StdioActor {
#[derive(Clone)]
pub struct StdioTransportHandle {
sender: mpsc::Sender<String>, // to process
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>, // from process
sender: mpsc::Sender<String>, // to process
receiver: Arc<Mutex<mpsc::Receiver<TransportMessageRecv>>>, // from process
error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
}
@@ -163,7 +165,7 @@ impl TransportHandle for StdioTransportHandle {
result
}
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
async fn receive(&self) -> Result<TransportMessageRecv, Error> {
let mut receiver = self.receiver.lock().await;
match receiver.recv().await {
Some(message) => Ok(message),

View File

@@ -1,5 +1,5 @@
use crate::oauth::{authenticate_service, ServiceConfig};
use crate::transport::Error;
use crate::transport::{Error, TransportMessageRecv};
use async_trait::async_trait;
use eventsource_client::{Client, SSE};
use futures::TryStreamExt;
@@ -25,7 +25,7 @@ pub struct StreamableHttpActor {
/// Receives messages (requests/notifications) from the handle
receiver: mpsc::Receiver<String>,
/// Sends messages (responses) back to the handle
sender: mpsc::Sender<JsonRpcMessage>,
sender: mpsc::Sender<TransportMessageRecv>,
/// MCP endpoint URL
mcp_endpoint: String,
/// HTTP client for sending requests
@@ -41,7 +41,7 @@ pub struct StreamableHttpActor {
impl StreamableHttpActor {
pub fn new(
receiver: mpsc::Receiver<String>,
sender: mpsc::Sender<JsonRpcMessage>,
sender: mpsc::Sender<TransportMessageRecv>,
mcp_endpoint: String,
session_id: Arc<RwLock<Option<String>>>,
env: HashMap<String, String>,
@@ -84,8 +84,8 @@ impl StreamableHttpActor {
debug!("Sending message to MCP endpoint: {}", message_str);
// Parse the message to determine if it's a request that expects a response
let parsed_message: JsonRpcMessage =
serde_json::from_str(&message_str).map_err(Error::Serialization)?;
let parsed_message = serde_json::from_str::<TransportMessageRecv>(&message_str)
.map_err(Error::Serialization)?;
let expects_response = matches!(
parsed_message,
@@ -196,8 +196,8 @@ impl StreamableHttpActor {
})?;
if !response_text.is_empty() {
let json_message: JsonRpcMessage =
serde_json::from_str(&response_text).map_err(Error::Serialization)?;
let json_message = serde_json::from_str::<TransportMessageRecv>(&response_text)
.map_err(Error::Serialization)?;
let _ = self.sender.send(json_message).await;
}
@@ -267,7 +267,7 @@ impl StreamableHttpActor {
// Empty line indicates end of event
if !event_data.is_empty() {
// Parse the streamed data as JSON-RPC message
match serde_json::from_str::<JsonRpcMessage>(&event_data) {
match serde_json::from_str::<TransportMessageRecv>(&event_data) {
Ok(message) => {
debug!("Received streaming HTTP response message: {:?}", message);
let _ = self.sender.send(message).await;
@@ -301,7 +301,7 @@ impl StreamableHttpActor {
#[derive(Clone)]
pub struct StreamableHttpTransportHandle {
sender: mpsc::Sender<String>,
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>,
receiver: Arc<Mutex<mpsc::Receiver<TransportMessageRecv>>>,
session_id: Arc<RwLock<Option<String>>>,
mcp_endpoint: String,
http_client: HttpClient,
@@ -314,7 +314,7 @@ impl TransportHandle for StreamableHttpTransportHandle {
serialize_and_send(&self.sender, message).await
}
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
async fn receive(&self) -> Result<TransportMessageRecv, Error> {
let mut receiver = self.receiver.lock().await;
receiver.recv().await.ok_or(Error::ChannelClosed)
}