mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
feat: Handle MCP server notification messages (#2613)
Co-authored-by: Michael Neale <michael.neale@gmail.com>
This commit is contained in:
@@ -6,7 +6,7 @@ use serde_json::{json, Value};
|
||||
use std::{
|
||||
collections::HashMap, fs, future::Future, path::PathBuf, pin::Pin, sync::Arc, sync::Mutex,
|
||||
};
|
||||
use tokio::process::Command;
|
||||
use tokio::{process::Command, sync::mpsc};
|
||||
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
@@ -14,7 +14,7 @@ use std::os::unix::fs::PermissionsExt;
|
||||
use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||
resource::Resource,
|
||||
tool::{Tool, ToolAnnotations},
|
||||
Content,
|
||||
@@ -1155,6 +1155,7 @@ impl Router for ComputerControllerRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
@@ -13,13 +13,17 @@ use std::{
|
||||
path::{Path, PathBuf},
|
||||
pin::Pin,
|
||||
};
|
||||
use tokio::process::Command;
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, BufReader},
|
||||
process::Command,
|
||||
sync::mpsc,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use include_dir::{include_dir, Dir};
|
||||
use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
protocol::ServerCapabilities,
|
||||
protocol::{JsonRpcMessage, JsonRpcNotification, ServerCapabilities},
|
||||
resource::Resource,
|
||||
tool::Tool,
|
||||
Content,
|
||||
@@ -456,7 +460,11 @@ impl DeveloperRouter {
|
||||
}
|
||||
|
||||
// Shell command execution with platform-specific handling
|
||||
async fn bash(&self, params: Value) -> Result<Vec<Content>, ToolError> {
|
||||
async fn bash(
|
||||
&self,
|
||||
params: Value,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Result<Vec<Content>, ToolError> {
|
||||
let command =
|
||||
params
|
||||
.get("command")
|
||||
@@ -488,27 +496,92 @@ impl DeveloperRouter {
|
||||
|
||||
// Get platform-specific shell configuration
|
||||
let shell_config = get_shell_config();
|
||||
let cmd_with_redirect = format_command_for_platform(command);
|
||||
let cmd_str = format_command_for_platform(command);
|
||||
|
||||
// Execute the command using platform-specific shell
|
||||
let child = Command::new(&shell_config.executable)
|
||||
let mut child = Command::new(&shell_config.executable)
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.stdin(Stdio::null())
|
||||
.kill_on_drop(true)
|
||||
.arg(&shell_config.arg)
|
||||
.arg(cmd_with_redirect)
|
||||
.arg(cmd_str)
|
||||
.spawn()
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
let stderr = child.stderr.take().unwrap();
|
||||
|
||||
let mut stdout_reader = BufReader::new(stdout);
|
||||
let mut stderr_reader = BufReader::new(stderr);
|
||||
|
||||
let output_task = tokio::spawn(async move {
|
||||
let mut combined_output = String::new();
|
||||
|
||||
let mut stdout_buf = Vec::new();
|
||||
let mut stderr_buf = Vec::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
n = stdout_reader.read_until(b'\n', &mut stdout_buf) => {
|
||||
if n? == 0 {
|
||||
break;
|
||||
}
|
||||
let line = String::from_utf8_lossy(&stdout_buf);
|
||||
|
||||
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: "notifications/message".to_string(),
|
||||
params: Some(json!({
|
||||
"data": {
|
||||
"type": "shell",
|
||||
"stream": "stdout",
|
||||
"output": line.to_string(),
|
||||
}
|
||||
})),
|
||||
}))
|
||||
.ok();
|
||||
|
||||
combined_output.push_str(&line);
|
||||
stdout_buf.clear();
|
||||
}
|
||||
n = stderr_reader.read_until(b'\n', &mut stderr_buf) => {
|
||||
if n? == 0 {
|
||||
break;
|
||||
}
|
||||
let line = String::from_utf8_lossy(&stderr_buf);
|
||||
|
||||
notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: "notifications/message".to_string(),
|
||||
params: Some(json!({
|
||||
"data": {
|
||||
"type": "shell",
|
||||
"stream": "stderr",
|
||||
"output": line.to_string(),
|
||||
}
|
||||
})),
|
||||
}))
|
||||
.ok();
|
||||
|
||||
combined_output.push_str(&line);
|
||||
stderr_buf.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok::<_, std::io::Error>(combined_output)
|
||||
});
|
||||
|
||||
// Wait for the command to complete and get output
|
||||
let output = child
|
||||
.wait_with_output()
|
||||
child
|
||||
.wait()
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;
|
||||
|
||||
let stdout_str = String::from_utf8_lossy(&output.stdout);
|
||||
let output_str = stdout_str;
|
||||
let output_str = match output_task.await {
|
||||
Ok(result) => result.map_err(|e| ToolError::ExecutionError(e.to_string()))?,
|
||||
Err(e) => return Err(ToolError::ExecutionError(e.to_string())),
|
||||
};
|
||||
|
||||
// Check the character count of the output
|
||||
const MAX_CHAR_COUNT: usize = 400_000; // 409600 chars = 400KB
|
||||
@@ -1048,12 +1121,13 @@ impl Router for DeveloperRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
Box::pin(async move {
|
||||
match tool_name.as_str() {
|
||||
"shell" => this.bash(arguments).await,
|
||||
"shell" => this.bash(arguments, notifier).await,
|
||||
"text_editor" => this.text_editor(arguments).await,
|
||||
"list_windows" => this.list_windows(arguments).await,
|
||||
"screen_capture" => this.screen_capture(arguments).await,
|
||||
@@ -1195,6 +1269,10 @@ mod tests {
|
||||
.await
|
||||
}
|
||||
|
||||
fn dummy_sender() -> mpsc::Sender<JsonRpcMessage> {
|
||||
mpsc::channel(1).0
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_shell_missing_parameters() {
|
||||
@@ -1202,7 +1280,7 @@ mod tests {
|
||||
std::env::set_current_dir(&temp_dir).unwrap();
|
||||
|
||||
let router = get_router().await;
|
||||
let result = router.call_tool("shell", json!({})).await;
|
||||
let result = router.call_tool("shell", json!({}), dummy_sender()).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
let err = result.err().unwrap();
|
||||
@@ -1263,6 +1341,7 @@ mod tests {
|
||||
"command": "view",
|
||||
"path": large_file_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1288,6 +1367,7 @@ mod tests {
|
||||
"command": "view",
|
||||
"path": many_chars_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1319,6 +1399,7 @@ mod tests {
|
||||
"path": file_path_str,
|
||||
"file_text": "Hello, world!"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1331,6 +1412,7 @@ mod tests {
|
||||
"command": "view",
|
||||
"path": file_path_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1369,6 +1451,7 @@ mod tests {
|
||||
"path": file_path_str,
|
||||
"file_text": "Hello, world!"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1383,6 +1466,7 @@ mod tests {
|
||||
"old_str": "world",
|
||||
"new_str": "Rust"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1407,6 +1491,7 @@ mod tests {
|
||||
"command": "view",
|
||||
"path": file_path_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1444,6 +1529,7 @@ mod tests {
|
||||
"path": file_path_str,
|
||||
"file_text": "First line"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1458,6 +1544,7 @@ mod tests {
|
||||
"old_str": "First line",
|
||||
"new_str": "Second line"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1470,6 +1557,7 @@ mod tests {
|
||||
"command": "undo_edit",
|
||||
"path": file_path_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1485,6 +1573,7 @@ mod tests {
|
||||
"command": "view",
|
||||
"path": file_path_str
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1583,6 +1672,7 @@ mod tests {
|
||||
"path": temp_dir.path().join("secret.txt").to_str().unwrap(),
|
||||
"file_text": "test content"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1601,6 +1691,7 @@ mod tests {
|
||||
"path": temp_dir.path().join("allowed.txt").to_str().unwrap(),
|
||||
"file_text": "test content"
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1642,6 +1733,7 @@ mod tests {
|
||||
json!({
|
||||
"command": format!("cat {}", secret_file_path.to_str().unwrap())
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1658,6 +1750,7 @@ mod tests {
|
||||
json!({
|
||||
"command": format!("cat {}", allowed_file_path.to_str().unwrap())
|
||||
}),
|
||||
dummy_sender(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ use std::env;
|
||||
pub struct ShellConfig {
|
||||
pub executable: String,
|
||||
pub arg: String,
|
||||
pub redirect_syntax: String,
|
||||
}
|
||||
|
||||
impl Default for ShellConfig {
|
||||
@@ -14,13 +13,11 @@ impl Default for ShellConfig {
|
||||
Self {
|
||||
executable: "powershell.exe".to_string(),
|
||||
arg: "-NoProfile -NonInteractive -Command".to_string(),
|
||||
redirect_syntax: "2>&1".to_string(),
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
executable: "bash".to_string(),
|
||||
arg: "-c".to_string(),
|
||||
redirect_syntax: "2>&1".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -31,13 +28,12 @@ pub fn get_shell_config() -> ShellConfig {
|
||||
}
|
||||
|
||||
pub fn format_command_for_platform(command: &str) -> String {
|
||||
let config = get_shell_config();
|
||||
if cfg!(windows) {
|
||||
// For PowerShell, wrap the command in braces to handle special characters
|
||||
format!("{{ {} }} {}", command, config.redirect_syntax)
|
||||
format!("{{ {} }}", command)
|
||||
} else {
|
||||
// For other shells, no braces needed
|
||||
format!("{} {}", command, config.redirect_syntax)
|
||||
command.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ use base64::Engine;
|
||||
use chrono::NaiveDate;
|
||||
use indoc::indoc;
|
||||
use lazy_static::lazy_static;
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use mcp_core::tool::ToolAnnotations;
|
||||
use oauth_pkce::PkceOAuth2Client;
|
||||
use regex::Regex;
|
||||
@@ -14,6 +15,7 @@ use serde_json::{json, Value};
|
||||
use std::io::Cursor;
|
||||
use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc};
|
||||
use storage::CredentialsManager;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use mcp_core::content::Content;
|
||||
use mcp_core::{
|
||||
@@ -3281,6 +3283,7 @@ impl Router for GoogleDriveRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
@@ -5,7 +5,7 @@ use mcp_core::{
|
||||
content::Content,
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||
resource::Resource,
|
||||
role::Role,
|
||||
tool::Tool,
|
||||
@@ -16,7 +16,7 @@ use serde_json::Value;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tracing::error;
|
||||
|
||||
@@ -158,6 +158,7 @@ impl Router for JetBrainsRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
@@ -10,11 +10,12 @@ use std::{
|
||||
path::PathBuf,
|
||||
pin::Pin,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||
resource::Resource,
|
||||
tool::{Tool, ToolAnnotations, ToolCall},
|
||||
Content,
|
||||
@@ -520,6 +521,7 @@ impl Router for MemoryRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
@@ -3,11 +3,12 @@ use include_dir::{include_dir, Dir};
|
||||
use indoc::formatdoc;
|
||||
use serde_json::{json, Value};
|
||||
use std::{future::Future, pin::Pin};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use mcp_core::{
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
protocol::{JsonRpcMessage, ServerCapabilities},
|
||||
resource::Resource,
|
||||
role::Role,
|
||||
tool::{Tool, ToolAnnotations},
|
||||
@@ -130,6 +131,7 @@ impl Router for TutorialRouter {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: Value,
|
||||
_notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
|
||||
let this = self.clone();
|
||||
let tool_name = tool_name.to_string();
|
||||
|
||||
Reference in New Issue
Block a user