feat: Handle MCP server notification messages (#2613)

Co-authored-by: Michael Neale <michael.neale@gmail.com>
This commit is contained in:
Jack Amadeo
2025-05-30 11:50:14 -04:00
committed by GitHub
parent eeb61ace22
commit 03e5549b54
40 changed files with 1186 additions and 443 deletions

View File

@@ -1,7 +1,6 @@
use mcp_client::{
client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait},
transport::{SseTransport, StdioTransport, Transport},
McpService,
};
use rand::Rng;
use rand::SeedableRng;
@@ -20,18 +19,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
let handle1 = transport1.start().await?;
let service1 = McpService::with_timeout(handle1, Duration::from_secs(30));
let client1 = McpClient::new(service1);
let client1 = McpClient::connect(handle1, Duration::from_secs(30)).await?;
let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
let handle2 = transport2.start().await?;
let service2 = McpService::with_timeout(handle2, Duration::from_secs(30));
let client2 = McpClient::new(service2);
let client2 = McpClient::connect(handle2, Duration::from_secs(30)).await?;
let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new());
let handle3 = transport3.start().await?;
let service3 = McpService::with_timeout(handle3, Duration::from_secs(10));
let client3 = McpClient::new(service3);
let client3 = McpClient::connect(handle3, Duration::from_secs(10)).await?;
// Initialize both clients
let mut clients: Vec<Box<dyn McpClientTrait>> =

View File

@@ -0,0 +1,122 @@
use anyhow::Result;
use futures::lock::Mutex;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, Transport};
use mcp_client::StdioTransport;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing_subscriber::EnvFilter;
#[tokio::main]
async fn main() -> Result<()> {
// Initialize logging
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::from_default_env()
.add_directive("mcp_client=debug".parse().unwrap())
.add_directive("eventsource_client=info".parse().unwrap()),
)
.init();
test_transport(sse_transport().await?).await?;
test_transport(stdio_transport().await?).await?;
Ok(())
}
async fn sse_transport() -> Result<SseTransport> {
let port = "60053";
tokio::process::Command::new("npx")
.env("PORT", port)
.arg("@modelcontextprotocol/server-everything")
.arg("sse")
.spawn()?;
tokio::time::sleep(Duration::from_secs(1)).await;
Ok(SseTransport::new(
format!("http://localhost:{}/sse", port),
HashMap::new(),
))
}
async fn stdio_transport() -> Result<StdioTransport> {
Ok(StdioTransport::new(
"npx",
vec!["@modelcontextprotocol/server-everything"]
.into_iter()
.map(|s| s.to_string())
.collect(),
HashMap::new(),
))
}
async fn test_transport<T>(transport: T) -> Result<()>
where
T: Transport + Send + 'static,
{
// Start transport
let handle = transport.start().await?;
// Create client
let mut client = McpClient::connect(handle, Duration::from_secs(10)).await?;
println!("Client created\n");
let mut receiver = client.subscribe().await;
let events = Arc::new(Mutex::new(Vec::new()));
let events_clone = events.clone();
tokio::spawn(async move {
while let Some(event) = receiver.recv().await {
println!("Received event: {event:?}");
events_clone.lock().await.push(event);
}
});
// Initialize
let server_info = client
.initialize(
ClientInfo {
name: "test-client".into(),
version: "1.0.0".into(),
},
ClientCapabilities::default(),
)
.await?;
println!("Connected to server: {server_info:?}\n");
// Sleep for 100ms to allow the server to start - surprisingly this is required!
tokio::time::sleep(Duration::from_millis(500)).await;
// List tools
let tools = client.list_tools(None).await?;
println!("Available tools: {tools:#?}\n");
// Call tool
let tool_result = client
.call_tool("echo", serde_json::json!({ "message": "honk" }))
.await?;
println!("Tool result: {tool_result:#?}\n");
let collected_eventes_before = events.lock().await.len();
let n_steps = 5;
let long_op = client
.call_tool(
"longRunningOperation",
serde_json::json!({ "duration": 3, "steps": n_steps }),
)
.await?;
println!("Long op result: {long_op:#?}\n");
let collected_events_after = events.lock().await.len();
assert_eq!(collected_events_after - collected_eventes_before, n_steps);
// List resources
let resources = client.list_resources(None).await?;
println!("Resources: {resources:#?}\n");
// Read resource
let resource = client.read_resource("test://static/resource/1").await?;
println!("Resource: {resource:#?}\n");
Ok(())
}

View File

@@ -1,7 +1,6 @@
use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, Transport};
use mcp_client::McpService;
use std::collections::HashMap;
use std::time::Duration;
use tracing_subscriber::EnvFilter;
@@ -23,11 +22,8 @@ async fn main() -> Result<()> {
// Start transport
let handle = transport.start().await?;
// Create the service with timeout middleware
let service = McpService::with_timeout(handle, Duration::from_secs(3));
// Create client
let mut client = McpClient::new(service);
let mut client = McpClient::connect(handle, Duration::from_secs(3)).await?;
println!("Client created\n");
// Initialize

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use anyhow::Result;
use mcp_client::{
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, McpService,
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
StdioTransport, Transport,
};
use std::time::Duration;
@@ -25,11 +25,8 @@ async fn main() -> Result<(), ClientError> {
// 2) Start the transport to get a handle
let transport_handle = transport.start().await?;
// 3) Create the service with timeout middleware
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));
// 4) Create the client with the middleware-wrapped service
let mut client = McpClient::new(service);
// 3) Create the client with the middleware-wrapped service
let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?;
// Initialize
let server_info = client

View File

@@ -5,7 +5,6 @@ use mcp_client::client::{
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
};
use mcp_client::transport::{StdioTransport, Transport};
use mcp_client::McpService;
use std::collections::HashMap;
use std::time::Duration;
use tracing_subscriber::EnvFilter;
@@ -34,11 +33,8 @@ async fn main() -> Result<(), ClientError> {
// Start the transport to get a handle
let transport_handle = transport.start().await.unwrap();
// Create the service with timeout middleware
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));
// Create client
let mut client = McpClient::new(service);
let mut client = McpClient::connect(transport_handle, Duration::from_secs(10)).await?;
// Initialize
let server_info = client

View File

@@ -4,11 +4,16 @@ use mcp_core::protocol::{
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::atomic::{AtomicU64, Ordering};
use serde_json::{json, Value};
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use thiserror::Error;
use tokio::sync::Mutex;
use tower::{Service, ServiceExt}; // for Service::ready()
use tokio::sync::{mpsc, Mutex};
use tower::{timeout::TimeoutLayer, Layer, Service, ServiceExt};
use crate::{McpService, TransportHandle};
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
@@ -97,34 +102,67 @@ pub trait McpClientTrait: Send + Sync {
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage>;
}
/// The MCP client is the interface for MCP operations.
pub struct McpClient<S>
pub struct McpClient<T>
where
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
S::Error: Into<Error>,
S::Future: Send,
T: TransportHandle + Send + Sync + 'static,
{
service: Mutex<S>,
service: Mutex<tower::timeout::Timeout<McpService<T>>>,
next_id: AtomicU64,
server_capabilities: Option<ServerCapabilities>,
server_info: Option<Implementation>,
notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<JsonRpcMessage>>>>,
}
impl<S> McpClient<S>
impl<T> McpClient<T>
where
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
S::Error: Into<Error>,
S::Future: Send,
T: TransportHandle + Send + Sync + 'static,
{
pub fn new(service: S) -> Self {
Self {
service: Mutex::new(service),
pub async fn connect(transport: T, timeout: std::time::Duration) -> Result<Self, Error> {
let service = McpService::new(transport.clone());
let service_ptr = service.clone();
let notification_subscribers =
Arc::new(Mutex::new(Vec::<mpsc::Sender<JsonRpcMessage>>::new()));
let subscribers_ptr = notification_subscribers.clone();
tokio::spawn(async move {
loop {
match transport.receive().await {
Ok(message) => {
tracing::info!("Received message: {:?}", message);
match message {
JsonRpcMessage::Response(JsonRpcResponse { id: Some(id), .. }) => {
service_ptr.respond(&id.to_string(), Ok(message)).await;
}
_ => {
let mut subs = subscribers_ptr.lock().await;
subs.retain(|sub| sub.try_send(message.clone()).is_ok());
}
}
}
Err(e) => {
tracing::error!("transport error: {:?}", e);
service_ptr.hangup().await;
subscribers_ptr.lock().await.clear();
break;
}
}
}
});
let middleware = TimeoutLayer::new(timeout);
Ok(Self {
service: Mutex::new(middleware.layer(service)),
next_id: AtomicU64::new(1),
server_capabilities: None,
server_info: None,
}
notification_subscribers,
})
}
/// Send a JSON-RPC request and check we don't get an error response.
@@ -134,13 +172,18 @@ where
{
let mut service = self.service.lock().await;
service.ready().await.map_err(|_| Error::NotReady)?;
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let mut params = params.clone();
params["_meta"] = json!({
"progressToken": format!("prog-{}", id),
});
let request = JsonRpcMessage::Request(JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(id),
method: method.to_string(),
params: Some(params.clone()),
params: Some(params),
});
let response_msg = service
@@ -154,7 +197,7 @@ where
.unwrap_or("".to_string()),
method: method.to_string(),
// we don't need include params because it can be really large
source: Box::new(e.into()),
source: Box::<Error>::new(e.into()),
})?;
match response_msg {
@@ -220,7 +263,7 @@ where
.unwrap_or("".to_string()),
method: method.to_string(),
// we don't need include params because it can be really large
source: Box::new(e.into()),
source: Box::<Error>::new(e.into()),
})?;
Ok(())
@@ -233,11 +276,9 @@ where
}
#[async_trait::async_trait]
impl<S> McpClientTrait for McpClient<S>
impl<T> McpClientTrait for McpClient<T>
where
S: Service<JsonRpcMessage, Response = JsonRpcMessage> + Clone + Send + Sync + 'static,
S::Error: Into<Error>,
S::Future: Send,
T: TransportHandle + Send + Sync + 'static,
{
async fn initialize(
&mut self,
@@ -388,4 +429,10 @@ where
self.send_request("prompts/get", params).await
}
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
let (tx, rx) = mpsc::channel(16);
self.notification_subscribers.lock().await.push(tx);
rx
}
}

View File

@@ -1,7 +1,9 @@
use futures::future::BoxFuture;
use mcp_core::protocol::JsonRpcMessage;
use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest};
use std::collections::HashMap;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::{oneshot, RwLock};
use tower::{timeout::Timeout, Service, ServiceBuilder};
use crate::transport::{Error, TransportHandle};
@@ -10,14 +12,24 @@ use crate::transport::{Error, TransportHandle};
#[derive(Clone)]
pub struct McpService<T: TransportHandle> {
inner: Arc<T>,
pending_requests: Arc<PendingRequests>,
}
impl<T: TransportHandle> McpService<T> {
pub fn new(transport: T) -> Self {
Self {
inner: Arc::new(transport),
pending_requests: Arc::new(PendingRequests::default()),
}
}
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
self.pending_requests.respond(id, response).await
}
pub async fn hangup(&self) {
self.pending_requests.broadcast_close().await
}
}
impl<T> Service<JsonRpcMessage> for McpService<T>
@@ -35,7 +47,31 @@ where
fn call(&mut self, request: JsonRpcMessage) -> Self::Future {
let transport = self.inner.clone();
Box::pin(async move { transport.send(request).await })
let pending_requests = self.pending_requests.clone();
Box::pin(async move {
match request {
JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) => {
// Create a channel to receive the response
let (sender, receiver) = oneshot::channel();
pending_requests.insert(id.to_string(), sender).await;
transport.send(request).await?;
receiver.await.map_err(|_| Error::ChannelClosed)?
}
JsonRpcMessage::Request(_) => {
// Handle notifications without waiting for a response
transport.send(request).await?;
Ok(JsonRpcMessage::Nil)
}
JsonRpcMessage::Notification(_) => {
// Handle notifications without waiting for a response
transport.send(request).await?;
Ok(JsonRpcMessage::Nil)
}
_ => Err(Error::UnsupportedMessage),
}
})
}
}
@@ -50,3 +86,50 @@ where
.service(McpService::new(transport))
}
}
// A data structure to store pending requests and their response channels
pub struct PendingRequests {
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
}
impl Default for PendingRequests {
fn default() -> Self {
Self::new()
}
}
impl PendingRequests {
pub fn new() -> Self {
Self {
requests: RwLock::new(HashMap::new()),
}
}
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
self.requests.write().await.insert(id, sender);
}
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
if let Some(tx) = self.requests.write().await.remove(id) {
let _ = tx.send(response);
}
}
pub async fn broadcast_close(&self) {
for (_, tx) in self.requests.write().await.drain() {
let _ = tx.send(Err(Error::ChannelClosed));
}
}
pub async fn clear(&self) {
self.requests.write().await.clear();
}
pub async fn len(&self) -> usize {
self.requests.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.len().await == 0
}
}

View File

@@ -1,8 +1,7 @@
use async_trait::async_trait;
use mcp_core::protocol::JsonRpcMessage;
use std::collections::HashMap;
use thiserror::Error;
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::sync::{mpsc, oneshot};
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
/// A generic error type for transport operations.
@@ -57,74 +56,20 @@ pub trait Transport {
#[async_trait]
pub trait TransportHandle: Send + Sync + Clone + 'static {
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error>;
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error>;
async fn receive(&self) -> Result<JsonRpcMessage, Error>;
}
// Helper function that contains the common send implementation
pub async fn send_message(
sender: &mpsc::Sender<TransportMessage>,
pub async fn serialize_and_send(
sender: &mpsc::Sender<String>,
message: JsonRpcMessage,
) -> Result<JsonRpcMessage, Error> {
match message {
JsonRpcMessage::Request(request) => {
let (respond_to, response) = oneshot::channel();
let msg = TransportMessage {
message: JsonRpcMessage::Request(request),
response_tx: Some(respond_to),
};
sender.send(msg).await.map_err(|_| Error::ChannelClosed)?;
Ok(response.await.map_err(|_| Error::ChannelClosed)??)
) -> Result<(), Error> {
match serde_json::to_string(&message).map_err(Error::Serialization) {
Ok(msg) => sender.send(msg).await.map_err(|_| Error::ChannelClosed),
Err(e) => {
tracing::error!(error = ?e, "Error serializing message");
Err(e)
}
JsonRpcMessage::Notification(notification) => {
let msg = TransportMessage {
message: JsonRpcMessage::Notification(notification),
response_tx: None,
};
sender.send(msg).await.map_err(|_| Error::ChannelClosed)?;
Ok(JsonRpcMessage::Nil)
}
_ => Err(Error::UnsupportedMessage),
}
}
// A data structure to store pending requests and their response channels
pub struct PendingRequests {
requests: RwLock<HashMap<String, oneshot::Sender<Result<JsonRpcMessage, Error>>>>,
}
impl Default for PendingRequests {
fn default() -> Self {
Self::new()
}
}
impl PendingRequests {
pub fn new() -> Self {
Self {
requests: RwLock::new(HashMap::new()),
}
}
pub async fn insert(&self, id: String, sender: oneshot::Sender<Result<JsonRpcMessage, Error>>) {
self.requests.write().await.insert(id, sender);
}
pub async fn respond(&self, id: &str, response: Result<JsonRpcMessage, Error>) {
if let Some(tx) = self.requests.write().await.remove(id) {
let _ = tx.send(response);
}
}
pub async fn clear(&self) {
self.requests.write().await.clear();
}
pub async fn len(&self) -> usize {
self.requests.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.len().await == 0
}
}

View File

@@ -1,17 +1,17 @@
use crate::transport::{Error, PendingRequests, TransportMessage};
use crate::transport::Error;
use async_trait::async_trait;
use eventsource_client::{Client, SSE};
use futures::TryStreamExt;
use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest};
use mcp_core::protocol::JsonRpcMessage;
use reqwest::Client as HttpClient;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio::time::{timeout, Duration};
use tracing::warn;
use url::Url;
use super::{send_message, Transport, TransportHandle};
use super::{serialize_and_send, Transport, TransportHandle};
// Timeout for the endpoint discovery
const ENDPOINT_TIMEOUT_SECS: u64 = 5;
@@ -21,9 +21,9 @@ const ENDPOINT_TIMEOUT_SECS: u64 = 5;
/// - Sends outgoing messages via HTTP POST (once the post endpoint is known).
pub struct SseActor {
/// Receives messages (requests/notifications) from the handle
receiver: mpsc::Receiver<TransportMessage>,
/// Map of request-id -> oneshot sender
pending_requests: Arc<PendingRequests>,
receiver: mpsc::Receiver<String>,
/// Sends messages (responses) back to the handle
sender: mpsc::Sender<JsonRpcMessage>,
/// Base SSE URL
sse_url: String,
/// For sending HTTP POST requests
@@ -34,14 +34,14 @@ pub struct SseActor {
impl SseActor {
pub fn new(
receiver: mpsc::Receiver<TransportMessage>,
pending_requests: Arc<PendingRequests>,
receiver: mpsc::Receiver<String>,
sender: mpsc::Sender<JsonRpcMessage>,
sse_url: String,
post_endpoint: Arc<RwLock<Option<String>>>,
) -> Self {
Self {
receiver,
pending_requests,
sender,
sse_url,
post_endpoint,
http_client: HttpClient::new(),
@@ -54,15 +54,14 @@ impl SseActor {
pub async fn run(self) {
tokio::join!(
Self::handle_incoming_messages(
self.sender,
self.sse_url.clone(),
Arc::clone(&self.pending_requests),
Arc::clone(&self.post_endpoint)
),
Self::handle_outgoing_messages(
self.receiver,
self.http_client.clone(),
Arc::clone(&self.post_endpoint),
Arc::clone(&self.pending_requests),
)
);
}
@@ -72,14 +71,13 @@ impl SseActor {
/// - If a `message` event is received, parse it as `JsonRpcMessage`
/// and respond to pending requests if it's a `Response`.
async fn handle_incoming_messages(
sender: mpsc::Sender<JsonRpcMessage>,
sse_url: String,
pending_requests: Arc<PendingRequests>,
post_endpoint: Arc<RwLock<Option<String>>>,
) {
let client = match eventsource_client::ClientBuilder::for_url(&sse_url) {
Ok(builder) => builder.build(),
Err(e) => {
pending_requests.clear().await;
warn!("Failed to connect SSE client: {}", e);
return;
}
@@ -105,84 +103,54 @@ impl SseActor {
}
// Now handle subsequent events
while let Ok(Some(event)) = stream.try_next().await {
match event {
SSE::Event(e) if e.event_type == "message" => {
// Attempt to parse the SSE data as a JsonRpcMessage
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
Ok(message) => {
match &message {
JsonRpcMessage::Response(response) => {
if let Some(id) = &response.id {
pending_requests
.respond(&id.to_string(), Ok(message))
.await;
}
loop {
match stream.try_next().await {
Ok(Some(event)) => {
match event {
SSE::Event(e) if e.event_type == "message" => {
// Attempt to parse the SSE data as a JsonRpcMessage
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
Ok(message) => {
let _ = sender.send(message).await;
}
JsonRpcMessage::Error(error) => {
if let Some(id) = &error.id {
pending_requests
.respond(&id.to_string(), Ok(message))
.await;
}
Err(err) => {
warn!("Failed to parse SSE message: {err}");
}
_ => {} // TODO: Handle other variants (Request, etc.)
}
}
Err(err) => {
warn!("Failed to parse SSE message: {err}");
}
_ => { /* ignore other events */ }
}
}
_ => { /* ignore other events */ }
Ok(None) => {
// Stream ended
tracing::info!("SSE stream ended.");
break;
}
Err(e) => {
warn!("Error reading SSE stream: {e}");
break;
}
}
}
// SSE stream ended or errored; signal any pending requests
tracing::error!("SSE stream ended or encountered an error; clearing pending requests.");
pending_requests.clear().await;
tracing::error!("SSE stream ended or encountered an error.");
}
/// Continuously receives messages from the `mpsc::Receiver`.
/// - If it's a request, store the oneshot in `pending_requests`.
/// - POST the message to the discovered endpoint (once known).
async fn handle_outgoing_messages(
mut receiver: mpsc::Receiver<TransportMessage>,
mut receiver: mpsc::Receiver<String>,
http_client: HttpClient,
post_endpoint: Arc<RwLock<Option<String>>>,
pending_requests: Arc<PendingRequests>,
) {
while let Some(transport_msg) = receiver.recv().await {
while let Some(message_str) = receiver.recv().await {
let post_url = match post_endpoint.read().await.as_ref() {
Some(url) => url.clone(),
None => {
if let Some(response_tx) = transport_msg.response_tx {
let _ = response_tx.send(Err(Error::NotConnected));
}
// TODO: the endpoint isn't discovered yet. This shouldn't happen -- we only return the handle
// after the endpoint is set.
continue;
}
};
// Serialize the JSON-RPC message
let message_str = match serde_json::to_string(&transport_msg.message) {
Ok(s) => s,
Err(e) => {
if let Some(tx) = transport_msg.response_tx {
let _ = tx.send(Err(Error::Serialization(e)));
}
continue;
}
};
// If it's a request, store the channel so we can respond later
if let Some(response_tx) = transport_msg.response_tx {
if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) =
&transport_msg.message
{
pending_requests.insert(id.to_string(), response_tx).await;
}
}
// Perform the HTTP POST
match http_client
.post(&post_url)
@@ -209,26 +177,25 @@ impl SseActor {
}
}
// mpsc channel closed => no more outgoing messages
let pending = pending_requests.len().await;
if pending > 0 {
tracing::error!("SSE stream ended or encountered an error with {pending} unfulfilled pending requests.");
pending_requests.clear().await;
} else {
tracing::info!("SseActor shutdown cleanly. No pending requests.");
}
tracing::info!("SseActor shut down.");
}
}
#[derive(Clone)]
pub struct SseTransportHandle {
sender: mpsc::Sender<TransportMessage>,
sender: mpsc::Sender<String>,
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>,
}
#[async_trait::async_trait]
impl TransportHandle for SseTransportHandle {
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
send_message(&self.sender, message).await
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> {
serialize_and_send(&self.sender, message).await
}
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
let mut receiver = self.receiver.lock().await;
receiver.recv().await.ok_or(Error::ChannelClosed)
}
}
@@ -279,17 +246,13 @@ impl Transport for SseTransport {
// Create a channel for outgoing TransportMessages
let (tx, rx) = mpsc::channel(32);
let (otx, orx) = mpsc::channel(32);
let post_endpoint: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
let post_endpoint_clone = Arc::clone(&post_endpoint);
// Build the actor
let actor = SseActor::new(
rx,
Arc::new(PendingRequests::new()),
self.sse_url.clone(),
post_endpoint,
);
let actor = SseActor::new(rx, otx, self.sse_url.clone(), post_endpoint);
// Spawn the actor task
tokio::spawn(actor.run());
@@ -301,7 +264,10 @@ impl Transport for SseTransport {
)
.await
{
Ok(_) => Ok(SseTransportHandle { sender: tx }),
Ok(_) => Ok(SseTransportHandle {
sender: tx,
receiver: Arc::new(Mutex::new(orx)),
}),
Err(e) => Err(Error::SseConnection(e.to_string())),
}
}

View File

@@ -14,7 +14,7 @@ use nix::sys::signal::{kill, Signal};
#[cfg(unix)]
use nix::unistd::{getpgid, Pid};
use super::{send_message, Error, PendingRequests, Transport, TransportHandle, TransportMessage};
use super::{serialize_and_send, Error, Transport, TransportHandle};
// Global to track process groups we've created
static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
@@ -23,8 +23,8 @@ static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
///
/// It uses channels for message passing and handles responses asynchronously through a background task.
pub struct StdioActor {
receiver: Option<mpsc::Receiver<TransportMessage>>,
pending_requests: Arc<PendingRequests>,
receiver: Option<mpsc::Receiver<String>>,
sender: Option<mpsc::Sender<JsonRpcMessage>>,
process: Child, // we store the process to keep it alive
error_sender: mpsc::Sender<Error>,
stdin: Option<ChildStdin>,
@@ -55,11 +55,11 @@ impl StdioActor {
let stdout = self.stdout.take().expect("stdout should be available");
let stdin = self.stdin.take().expect("stdin should be available");
let receiver = self.receiver.take().expect("receiver should be available");
let msg_inbox = self.receiver.take().expect("receiver should be available");
let msg_outbox = self.sender.take().expect("sender should be available");
let incoming = Self::handle_incoming_messages(stdout, self.pending_requests.clone());
let outgoing =
Self::handle_outgoing_messages(receiver, stdin, self.pending_requests.clone());
let incoming = Self::handle_proc_output(stdout, msg_outbox);
let outgoing = Self::handle_proc_input(stdin, msg_inbox);
// take ownership of futures for tokio::select
pin!(incoming);
@@ -96,12 +96,9 @@ impl StdioActor {
.await;
}
}
// Clean up regardless of which path we took
self.pending_requests.clear().await;
}
async fn handle_incoming_messages(stdout: ChildStdout, pending_requests: Arc<PendingRequests>) {
async fn handle_proc_output(stdout: ChildStdout, sender: mpsc::Sender<JsonRpcMessage>) {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
loop {
@@ -116,20 +113,12 @@ impl StdioActor {
message = ?message,
"Received incoming message"
);
match &message {
JsonRpcMessage::Response(response) => {
if let Some(id) = &response.id {
pending_requests.respond(&id.to_string(), Ok(message)).await;
}
}
JsonRpcMessage::Error(error) => {
if let Some(id) = &error.id {
pending_requests.respond(&id.to_string(), Ok(message)).await;
}
}
_ => {} // TODO: Handle other variants (Request, etc.)
}
let _ = sender.send(message).await;
} else {
tracing::warn!(
message = ?line,
"Failed to parse incoming message"
);
}
line.clear();
}
@@ -141,44 +130,20 @@ impl StdioActor {
}
}
async fn handle_outgoing_messages(
mut receiver: mpsc::Receiver<TransportMessage>,
mut stdin: ChildStdin,
pending_requests: Arc<PendingRequests>,
) {
while let Some(mut transport_msg) = receiver.recv().await {
let message_str = match serde_json::to_string(&transport_msg.message) {
Ok(s) => s,
Err(e) => {
if let Some(tx) = transport_msg.response_tx.take() {
let _ = tx.send(Err(Error::Serialization(e)));
}
continue;
}
};
tracing::debug!(message = ?transport_msg.message, "Sending outgoing message");
if let Some(response_tx) = transport_msg.response_tx.take() {
if let JsonRpcMessage::Request(request) = &transport_msg.message {
if let Some(id) = &request.id {
pending_requests.insert(id.to_string(), response_tx).await;
}
}
}
async fn handle_proc_input(mut stdin: ChildStdin, mut receiver: mpsc::Receiver<String>) {
while let Some(message_str) = receiver.recv().await {
tracing::debug!(message = ?message_str, "Sending outgoing message");
if let Err(e) = stdin
.write_all(format!("{}\n", message_str).as_bytes())
.await
{
tracing::error!(error = ?e, "Error writing message to child process");
pending_requests.clear().await;
break;
}
if let Err(e) = stdin.flush().await {
tracing::error!(error = ?e, "Error flushing message to child process");
pending_requests.clear().await;
break;
}
}
@@ -187,18 +152,24 @@ impl StdioActor {
#[derive(Clone)]
pub struct StdioTransportHandle {
sender: mpsc::Sender<TransportMessage>,
sender: mpsc::Sender<String>, // to process
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>, // from process
error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
}
#[async_trait::async_trait]
impl TransportHandle for StdioTransportHandle {
async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
let result = send_message(&self.sender, message).await;
async fn send(&self, message: JsonRpcMessage) -> Result<(), Error> {
let result = serialize_and_send(&self.sender, message).await;
// Check for any pending errors even if send is successful
self.check_for_errors().await?;
result
}
async fn receive(&self) -> Result<JsonRpcMessage, Error> {
let mut receiver = self.receiver.lock().await;
receiver.recv().await.ok_or(Error::ChannelClosed)
}
}
impl StdioTransportHandle {
@@ -289,12 +260,13 @@ impl Transport for StdioTransport {
async fn start(&self) -> Result<Self::Handle, Error> {
let (process, stdin, stdout, stderr) = self.spawn_process().await?;
let (message_tx, message_rx) = mpsc::channel(32);
let (outbox_tx, outbox_rx) = mpsc::channel(32);
let (inbox_tx, inbox_rx) = mpsc::channel(32);
let (error_tx, error_rx) = mpsc::channel(1);
let actor = StdioActor {
receiver: Some(message_rx),
pending_requests: Arc::new(PendingRequests::new()),
receiver: Some(outbox_rx), // client to process
sender: Some(inbox_tx), // process to client
process,
error_sender: error_tx,
stdin: Some(stdin),
@@ -305,7 +277,8 @@ impl Transport for StdioTransport {
tokio::spawn(actor.run());
let handle = StdioTransportHandle {
sender: message_tx,
sender: outbox_tx, // client to process
receiver: Arc::new(Mutex::new(inbox_rx)), // process to client
error_receiver: Arc::new(Mutex::new(error_rx)),
};
Ok(handle)