feat: Handle MCP server notification messages (#2613)

Co-authored-by: Michael Neale <michael.neale@gmail.com>
This commit is contained in:
Jack Amadeo
2025-05-30 11:50:14 -04:00
committed by GitHub
parent eeb61ace22
commit 03e5549b54
40 changed files with 1186 additions and 443 deletions

View File

@@ -6,7 +6,7 @@ use serde_json::{json, Value};
use std::{
collections::HashMap, fs, future::Future, path::PathBuf, pin::Pin, sync::Arc, sync::Mutex,
};
use tokio::process::Command;
use tokio::{process::Command, sync::mpsc};
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
@@ -14,7 +14,7 @@ use std::os::unix::fs::PermissionsExt;
use mcp_core::{
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource,
tool::{Tool, ToolAnnotations},
Content,
@@ -1155,6 +1155,7 @@ impl Router for ComputerControllerRouter {
&self,
tool_name: &str,
arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

View File

@@ -13,13 +13,17 @@ use std::{
path::{Path, PathBuf},
pin::Pin,
};
use tokio::process::Command;
use tokio::{
io::{AsyncBufReadExt, BufReader},
process::Command,
sync::mpsc,
};
use url::Url;
use include_dir::{include_dir, Dir};
use mcp_core::{
handler::{PromptError, ResourceError, ToolError},
protocol::ServerCapabilities,
protocol::{JsonRpcMessage, JsonRpcNotification, ServerCapabilities},
resource::Resource,
tool::Tool,
Content,
@@ -456,7 +460,11 @@ impl DeveloperRouter {
}
// Shell command execution with platform-specific handling
async fn bash(&self, params: Value) -> Result<Vec<Content>, ToolError> {
async fn bash(
&self,
params: Value,
notifier: mpsc::Sender<JsonRpcMessage>,
) -> Result<Vec<Content>, ToolError> {
let command =
params
.get("command")
@@ -488,27 +496,92 @@ impl DeveloperRouter {
// Get platform-specific shell configuration
let shell_config = get_shell_config();
let cmd_with_redirect = format_command_for_platform(command);
let cmd_str = format_command_for_platform(command);
// Execute the command using platform-specific shell
let child = Command::new(&shell_config.executable)
let mut child = Command::new(&shell_config.executable)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.stdin(Stdio::null())
.kill_on_drop(true)
.arg(&shell_config.arg)
.arg(cmd_with_redirect)
.arg(cmd_str)
.spawn()
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
let mut stdout_reader = BufReader::new(stdout);
let mut stderr_reader = BufReader::new(stderr);
let output_task = tokio::spawn(async move {
let mut combined_output = String::new();
let mut stdout_buf = Vec::new();
let mut stderr_buf = Vec::new();
loop {
tokio::select! {
n = stdout_reader.read_until(b'\n', &mut stdout_buf) => {
if n? == 0 {
break;
}
let line = String::from_utf8_lossy(&stdout_buf);
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: "notifications/message".to_string(),
params: Some(json!({
"data": {
"type": "shell",
"stream": "stdout",
"output": line.to_string(),
}
})),
}))
.ok();
combined_output.push_str(&line);
stdout_buf.clear();
}
n = stderr_reader.read_until(b'\n', &mut stderr_buf) => {
if n? == 0 {
break;
}
let line = String::from_utf8_lossy(&stderr_buf);
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: "notifications/message".to_string(),
params: Some(json!({
"data": {
"type": "shell",
"stream": "stderr",
"output": line.to_string(),
}
})),
}))
.ok();
combined_output.push_str(&line);
stderr_buf.clear();
}
}
}
Ok::<_, std::io::Error>(combined_output)
});
// Wait for the command to complete and get output
let output = child
.wait_with_output()
child
.wait()
.await
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
let stdout_str = String::from_utf8_lossy(&output.stdout);
let output_str = stdout_str;
let output_str = match output_task.await {
Ok(result) => result.map_err(|e| ToolError::ExecutionError(e.to_string()))?,
Err(e) => return Err(ToolError::ExecutionError(e.to_string())),
};
// Check the character count of the output
const MAX_CHAR_COUNT: usize = 400_000; // 409600 chars = 400KB
@@ -1048,12 +1121,13 @@ impl Router for DeveloperRouter {
&self,
tool_name: &str,
arguments: Value,
notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();
Box::pin(async move {
match tool_name.as_str() {
"shell" => this.bash(arguments).await,
"shell" => this.bash(arguments, notifier).await,
"text_editor" => this.text_editor(arguments).await,
"list_windows" => this.list_windows(arguments).await,
"screen_capture" => this.screen_capture(arguments).await,
@@ -1195,6 +1269,10 @@ mod tests {
.await
}
fn dummy_sender() -> mpsc::Sender<JsonRpcMessage> {
mpsc::channel(1).0
}
#[tokio::test]
#[serial]
async fn test_shell_missing_parameters() {
@@ -1202,7 +1280,7 @@ mod tests {
std::env::set_current_dir(&temp_dir).unwrap();
let router = get_router().await;
let result = router.call_tool("shell", json!({})).await;
let result = router.call_tool("shell", json!({}), dummy_sender()).await;
assert!(result.is_err());
let err = result.err().unwrap();
@@ -1263,6 +1341,7 @@ mod tests {
"command": "view",
"path": large_file_str
}),
dummy_sender(),
)
.await;
@@ -1288,6 +1367,7 @@ mod tests {
"command": "view",
"path": many_chars_str
}),
dummy_sender(),
)
.await;
@@ -1319,6 +1399,7 @@ mod tests {
"path": file_path_str,
"file_text": "Hello, world!"
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1331,6 +1412,7 @@ mod tests {
"command": "view",
"path": file_path_str
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1369,6 +1451,7 @@ mod tests {
"path": file_path_str,
"file_text": "Hello, world!"
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1383,6 +1466,7 @@ mod tests {
"old_str": "world",
"new_str": "Rust"
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1407,6 +1491,7 @@ mod tests {
"command": "view",
"path": file_path_str
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1444,6 +1529,7 @@ mod tests {
"path": file_path_str,
"file_text": "First line"
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1458,6 +1544,7 @@ mod tests {
"old_str": "First line",
"new_str": "Second line"
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1470,6 +1557,7 @@ mod tests {
"command": "undo_edit",
"path": file_path_str
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1485,6 +1573,7 @@ mod tests {
"command": "view",
"path": file_path_str
}),
dummy_sender(),
)
.await
.unwrap();
@@ -1583,6 +1672,7 @@ mod tests {
"path": temp_dir.path().join("secret.txt").to_str().unwrap(),
"file_text": "test content"
}),
dummy_sender(),
)
.await;
@@ -1601,6 +1691,7 @@ mod tests {
"path": temp_dir.path().join("allowed.txt").to_str().unwrap(),
"file_text": "test content"
}),
dummy_sender(),
)
.await;
@@ -1642,6 +1733,7 @@ mod tests {
json!({
"command": format!("cat {}", secret_file_path.to_str().unwrap())
}),
dummy_sender(),
)
.await;
@@ -1658,6 +1750,7 @@ mod tests {
json!({
"command": format!("cat {}", allowed_file_path.to_str().unwrap())
}),
dummy_sender(),
)
.await;

View File

@@ -4,7 +4,6 @@ use std::env;
pub struct ShellConfig {
pub executable: String,
pub arg: String,
pub redirect_syntax: String,
}
impl Default for ShellConfig {
@@ -14,13 +13,11 @@ impl Default for ShellConfig {
Self {
executable: "powershell.exe".to_string(),
arg: "-NoProfile -NonInteractive -Command".to_string(),
redirect_syntax: "2>&1".to_string(),
}
} else {
Self {
executable: "bash".to_string(),
arg: "-c".to_string(),
redirect_syntax: "2>&1".to_string(),
}
}
}
@@ -31,13 +28,12 @@ pub fn get_shell_config() -> ShellConfig {
}
pub fn format_command_for_platform(command: &str) -> String {
let config = get_shell_config();
if cfg!(windows) {
// For PowerShell, wrap the command in braces to handle special characters
format!("{{ {} }} {}", command, config.redirect_syntax)
format!("{{ {} }}", command)
} else {
// For other shells, no braces needed
format!("{} {}", command, config.redirect_syntax)
command.to_string()
}
}

View File

@@ -7,6 +7,7 @@ use base64::Engine;
use chrono::NaiveDate;
use indoc::indoc;
use lazy_static::lazy_static;
use mcp_core::protocol::JsonRpcMessage;
use mcp_core::tool::ToolAnnotations;
use oauth_pkce::PkceOAuth2Client;
use regex::Regex;
@@ -14,6 +15,7 @@ use serde_json::{json, Value};
use std::io::Cursor;
use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc};
use storage::CredentialsManager;
use tokio::sync::mpsc;
use mcp_core::content::Content;
use mcp_core::{
@@ -3281,6 +3283,7 @@ impl Router for GoogleDriveRouter {
&self,
tool_name: &str,
arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

View File

@@ -5,7 +5,7 @@ use mcp_core::{
content::Content,
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource,
role::Role,
tool::Tool,
@@ -16,7 +16,7 @@ use serde_json::Value;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{sleep, Duration};
use tracing::error;
@@ -158,6 +158,7 @@ impl Router for JetBrainsRouter {
&self,
tool_name: &str,
arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

View File

@@ -10,11 +10,12 @@ use std::{
path::PathBuf,
pin::Pin,
};
use tokio::sync::mpsc;
use mcp_core::{
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource,
tool::{Tool, ToolAnnotations, ToolCall},
Content,
@@ -520,6 +521,7 @@ impl Router for MemoryRouter {
&self,
tool_name: &str,
arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

View File

@@ -3,11 +3,12 @@ use include_dir::{include_dir, Dir};
use indoc::formatdoc;
use serde_json::{json, Value};
use std::{future::Future, pin::Pin};
use tokio::sync::mpsc;
use mcp_core::{
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
protocol::{JsonRpcMessage, ServerCapabilities},
resource::Resource,
role::Role,
tool::{Tool, ToolAnnotations},
@@ -130,6 +131,7 @@ impl Router for TutorialRouter {
&self,
tool_name: &str,
arguments: Value,
_notifier: mpsc::Sender<JsonRpcMessage>,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();