Files
goose/crates/mcp-client/src/client.rs
2025-06-13 14:43:45 -04:00

439 lines
14 KiB
Rust

use mcp_core::protocol::{
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use thiserror::Error;
use tokio::sync::{mpsc, Mutex};
use tower::{timeout::TimeoutLayer, Layer, Service, ServiceExt};
use crate::{McpService, TransportHandle};
/// Boxed, thread-safe error type; converted into [`Error::ServerBoxError`]
/// by the `From<BoxError>` impl below.
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
/// Error type for MCP client operations.
#[derive(Debug, Error)]
pub enum Error {
    /// Error reported by the transport layer (`super::transport::Error`).
    #[error("Transport error: {0}")]
    Transport(#[from] super::transport::Error),
    /// A JSON-RPC error: either the `error` object of a server response, or
    /// synthesized locally (with `METHOD_NOT_FOUND`) when the server did not
    /// advertise the capability a method needs.
    #[error("RPC error: code={code}, message={message}")]
    RpcError { code: i32, message: String },
    /// JSON (de)serialization failed, e.g. a response `result` did not match
    /// the expected result type.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// The server replied with something other than a well-formed response to
    /// the request that was sent (id mismatch, wrong message type, no result).
    #[error("Unexpected response from server: {0}")]
    UnexpectedResponse(String),
    /// An operation that requires a completed `initialize` handshake was
    /// called before it finished.
    #[error("Not initialized")]
    NotInitialized,
    /// The underlying service reported an error while waiting to become ready.
    #[error("Timeout or service not ready")]
    NotReady,
    /// Conversion target for tower's timeout `Elapsed` error.
    #[error("Request timed out")]
    Timeout(#[from] tower::timeout::error::Elapsed),
    /// An opaque boxed error coming out of the service stack.
    #[error("Error from mcp-server: {0}")]
    ServerBoxError(BoxError),
    /// Sending a request or notification failed. Carries the JSON-RPC method
    /// and the server name (empty when the server has not identified itself
    /// yet) for context; request params are intentionally not included.
    #[error("Call to '{server}' failed for '{method}'. {source}")]
    McpServerError {
        method: String,
        server: String,
        #[source]
        source: BoxError,
    },
}
// BoxError from mcp-server gets converted to our Error type
impl From<BoxError> for Error {
fn from(err: BoxError) -> Self {
Error::ServerBoxError(err)
}
}
/// Identifies this client to the server; sent as the `clientInfo` field of
/// the `initialize` request.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientInfo {
    /// Client name reported to the server.
    pub name: String,
    /// Client version reported to the server.
    pub version: String,
}
/// Capabilities this client advertises in the `initialize` request.
///
/// Currently empty. Keep the braced-struct form: it serializes as `{}`,
/// whereas a unit struct (`struct ClientCapabilities;`) would serialize as
/// `null`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ClientCapabilities {
    // Add fields as needed. For now, empty capabilities are fine.
}
/// Parameters of the `initialize` request.
///
/// Every field is (de)serialized under its camelCase name:
/// `protocolVersion`, `capabilities`, `clientInfo`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeParams {
    /// Protocol revision the client asks the server to speak.
    pub protocol_version: String,
    /// Capabilities advertised by the client.
    pub capabilities: ClientCapabilities,
    /// Name and version of the client.
    pub client_info: ClientInfo,
}
/// Client-side operations of the Model Context Protocol.
///
/// `initialize` takes `&mut self`; every other method takes `&self`.
#[async_trait::async_trait]
pub trait McpClientTrait: Send + Sync {
    /// Performs the `initialize` handshake with the server, advertising
    /// `info` and `capabilities`, and returns the server's `InitializeResult`.
    async fn initialize(
        &mut self,
        info: ClientInfo,
        capabilities: ClientCapabilities,
    ) -> Result<InitializeResult, Error>;
    /// Lists resources (`resources/list`), starting at `next_cursor` when given.
    async fn list_resources(
        &self,
        next_cursor: Option<String>,
    ) -> Result<ListResourcesResult, Error>;
    /// Reads the resource identified by `uri` (`resources/read`).
    async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error>;
    /// Lists tools (`tools/list`), starting at `next_cursor` when given.
    async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error>;
    /// Invokes tool `name` with JSON `arguments` (`tools/call`).
    async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error>;
    /// Lists prompts (`prompts/list`), starting at `next_cursor` when given.
    async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
    /// Fetches prompt `name` rendered with JSON `arguments` (`prompts/get`).
    async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
    /// Registers a new subscriber and returns the receiving end of its channel.
    /// NOTE(review): which messages are delivered is implementation-defined;
    /// `McpClient` forwards server messages that are not id-bearing responses.
    async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage>;
}
/// The MCP client is the interface for MCP operations.
pub struct McpClient<T>
where
    T: TransportHandle + Send + Sync + 'static,
{
    /// Transport-backed service wrapped in a timeout layer. The mutex guard is
    /// held for the whole of a send, so sends through one client are serialized.
    service: Mutex<tower::timeout::Timeout<McpService<T>>>,
    /// Source of JSON-RPC request ids (starts at 1).
    next_id: AtomicU64,
    /// Set only after `initialize` fully succeeds; its presence is what marks
    /// the client as initialized.
    server_capabilities: Option<ServerCapabilities>,
    /// Server identity from the `initialize` result; its name is used in
    /// `Error::McpServerError` context.
    server_info: Option<Implementation>,
    /// Channels registered through `subscribe`; shared with the background
    /// receive task started by `connect`.
    notification_subscribers: Arc<Mutex<Vec<mpsc::Sender<JsonRpcMessage>>>>,
}
impl<T> McpClient<T>
where
T: TransportHandle + Send + Sync + 'static,
{
pub async fn connect(transport: T, timeout: std::time::Duration) -> Result<Self, Error> {
let service = McpService::new(transport.clone());
let service_ptr = service.clone();
let notification_subscribers =
Arc::new(Mutex::new(Vec::<mpsc::Sender<JsonRpcMessage>>::new()));
let subscribers_ptr = notification_subscribers.clone();
tokio::spawn(async move {
loop {
match transport.receive().await {
Ok(message) => {
tracing::info!("Received message: {:?}", message);
match message {
JsonRpcMessage::Response(JsonRpcResponse { id: Some(id), .. })
| JsonRpcMessage::Error(JsonRpcError { id: Some(id), .. }) => {
service_ptr.respond(&id.to_string(), Ok(message)).await;
}
_ => {
let mut subs = subscribers_ptr.lock().await;
subs.retain(|sub| sub.try_send(message.clone()).is_ok());
}
}
}
Err(e) => {
service_ptr.hangup(e).await;
subscribers_ptr.lock().await.clear();
break;
}
}
}
});
let middleware = TimeoutLayer::new(timeout);
Ok(Self {
service: Mutex::new(middleware.layer(service)),
next_id: AtomicU64::new(1),
server_capabilities: None,
server_info: None,
notification_subscribers,
})
}
/// Send a JSON-RPC request and check we don't get an error response.
async fn send_request<R>(&self, method: &str, params: Value) -> Result<R, Error>
where
R: for<'de> Deserialize<'de>,
{
let mut service = self.service.lock().await;
service.ready().await.map_err(|_| Error::NotReady)?;
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let mut params = params.clone();
params["_meta"] = json!({
"progressToken": format!("prog-{}", id),
});
let request = JsonRpcMessage::Request(JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(id),
method: method.to_string(),
params: Some(params),
});
let response_msg = service
.call(request)
.await
.map_err(|e| Error::McpServerError {
server: self
.server_info
.as_ref()
.map(|s| s.name.clone())
.unwrap_or("".to_string()),
method: method.to_string(),
// we don't need include params because it can be really large
source: Box::<Error>::new(e.into()),
})?;
match response_msg {
JsonRpcMessage::Response(JsonRpcResponse {
id, result, error, ..
}) => {
// Verify id matches
if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
return Err(Error::UnexpectedResponse(
"id mismatch for JsonRpcResponse".to_string(),
));
}
if let Some(err) = error {
Err(Error::RpcError {
code: err.code,
message: err.message,
})
} else if let Some(r) = result {
Ok(serde_json::from_value(r)?)
} else {
Err(Error::UnexpectedResponse("missing result".to_string()))
}
}
JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => {
if id != Some(self.next_id.load(Ordering::SeqCst) - 1) {
return Err(Error::UnexpectedResponse(
"id mismatch for JsonRpcError".to_string(),
));
}
Err(Error::RpcError {
code: error.code,
message: error.message,
})
}
_ => {
// Requests/notifications not expected as a response
Err(Error::UnexpectedResponse(
"unexpected message type".to_string(),
))
}
}
}
/// Send a JSON-RPC notification.
async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> {
let mut service = self.service.lock().await;
service.ready().await.map_err(|_| Error::NotReady)?;
let notification = JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: method.to_string(),
params: Some(params.clone()),
});
service
.call(notification)
.await
.map_err(|e| Error::McpServerError {
server: self
.server_info
.as_ref()
.map(|s| s.name.clone())
.unwrap_or("".to_string()),
method: method.to_string(),
// we don't need include params because it can be really large
source: Box::<Error>::new(e.into()),
})?;
Ok(())
}
// Check if the client has completed initialization
fn completed_initialization(&self) -> bool {
self.server_capabilities.is_some()
}
}
#[async_trait::async_trait]
impl<T> McpClientTrait for McpClient<T>
where
T: TransportHandle + Send + Sync + 'static,
{
async fn initialize(
&mut self,
info: ClientInfo,
capabilities: ClientCapabilities,
) -> Result<InitializeResult, Error> {
let params = InitializeParams {
protocol_version: "2025-03-26".to_string(),
client_info: info,
capabilities,
};
let result: InitializeResult = self
.send_request("initialize", serde_json::to_value(params)?)
.await?;
self.send_notification("notifications/initialized", serde_json::json!({}))
.await?;
self.server_capabilities = Some(result.capabilities.clone());
self.server_info = Some(result.server_info.clone());
Ok(result)
}
async fn list_resources(
&self,
next_cursor: Option<String>,
) -> Result<ListResourcesResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If resources is not supported, return an empty list
if self
.server_capabilities
.as_ref()
.unwrap()
.resources
.is_none()
{
return Ok(ListResourcesResult {
resources: vec![],
next_cursor: None,
});
}
let payload = next_cursor
.map(|cursor| serde_json::json!({"cursor": cursor}))
.unwrap_or_else(|| serde_json::json!({}));
self.send_request("resources/list", payload).await
}
async fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If resources is not supported, return an error
if self
.server_capabilities
.as_ref()
.unwrap()
.resources
.is_none()
{
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'resources' capability".to_string(),
});
}
let params = serde_json::json!({ "uri": uri });
self.send_request("resources/read", params).await
}
async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If tools is not supported, return an empty list
if self.server_capabilities.as_ref().unwrap().tools.is_none() {
return Ok(ListToolsResult {
tools: vec![],
next_cursor: None,
});
}
let payload = next_cursor
.map(|cursor| serde_json::json!({"cursor": cursor}))
.unwrap_or_else(|| serde_json::json!({}));
self.send_request("tools/list", payload).await
}
async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If tools is not supported, return an error
if self.server_capabilities.as_ref().unwrap().tools.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'tools' capability".to_string(),
});
}
let params = serde_json::json!({ "name": name, "arguments": arguments });
// TODO ERROR: check that if there is an error, we send back is_error: true with msg
// https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2
self.send_request("tools/call", params).await
}
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If prompts is not supported, return an error
if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'prompts' capability".to_string(),
});
}
let payload = next_cursor
.map(|cursor| serde_json::json!({"cursor": cursor}))
.unwrap_or_else(|| serde_json::json!({}));
self.send_request("prompts/list", payload).await
}
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If prompts is not supported, return an error
if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'prompts' capability".to_string(),
});
}
let params = serde_json::json!({ "name": name, "arguments": arguments });
self.send_request("prompts/get", params).await
}
async fn subscribe(&self) -> mpsc::Receiver<JsonRpcMessage> {
let (tx, rx) = mpsc::channel(16);
self.notification_subscribers.lock().await.push(tx);
rx
}
}