From d0ca46983ef5e67a61620053047df7bbf5e9c0a0 Mon Sep 17 00:00:00 2001 From: Kalvin C Date: Thu, 27 Feb 2025 15:47:29 -0800 Subject: [PATCH] feat(cli): add mcp prompt support via slash commands (#1323) --- Cargo.lock | 1 + crates/goose-cli/Cargo.toml | 1 + crates/goose-cli/src/session/input.rs | 164 +++++++++++++++ crates/goose-cli/src/session/mod.rs | 100 +++++++++ crates/goose-cli/src/session/output.rs | 55 +++++ .../goose-mcp/src/computercontroller/mod.rs | 20 +- crates/goose-mcp/src/developer/mod.rs | 32 +-- crates/goose-mcp/src/google_drive/mod.rs | 20 +- crates/goose-mcp/src/jetbrains/mod.rs | 20 +- crates/goose-mcp/src/memory/mod.rs | 19 +- crates/goose-mcp/src/tutorial/mod.rs | 20 +- crates/goose/src/agents/agent.rs | 11 + crates/goose/src/agents/capabilities.rs | 103 +++++++++- crates/goose/src/agents/reference.rs | 35 ++++ crates/goose/src/agents/truncate.rs | 35 ++++ crates/goose/src/message.rs | 190 ++++++++++++++++++ .../mcp-client/examples/stdio_integration.rs | 11 + crates/mcp-client/src/client.rs | 48 ++++- crates/mcp-client/src/transport/sse.rs | 20 +- crates/mcp-client/src/transport/stdio.rs | 14 +- crates/mcp-core/src/prompt.rs | 28 ++- crates/mcp-server/src/main.rs | 35 +++- crates/mcp-server/src/router.rs | 54 +++-- ui/desktop/src/components/UserMessage.tsx | 4 +- 24 files changed, 958 insertions(+), 82 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6732b2cc..4734d782 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2203,6 +2203,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "shlex", "temp-env", "tempfile", "test-case", diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index 41cb85a6..7addc964 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -47,6 +47,7 @@ chrono = "0.4" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] } tracing-appender = "0.2" once_cell = "1.20.2" +shlex = "1.3.0" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index 7cfa94d3..38fd2ea5 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -1,5 +1,7 @@ use anyhow::Result; use rustyline::Editor; +use shlex; +use std::collections::HashMap; #[derive(Debug)] pub enum InputResult { @@ -9,6 +11,15 @@ pub enum InputResult { AddBuiltin(String), ToggleTheme, Retry, + ListPrompts, + PromptCommand(PromptCommandOptions), +} + +#[derive(Debug)] +pub struct PromptCommandOptions { + pub name: String, + pub info: bool, + pub arguments: HashMap, } pub fn get_input( @@ -59,12 +70,67 @@ fn handle_slash_command(input: &str) -> Option { Some(InputResult::Retry) } "/t" => Some(InputResult::ToggleTheme), + "/prompts" => Some(InputResult::ListPrompts), + s if s.starts_with("/prompt") => { + if s == "/prompt" { + // No arguments case + Some(InputResult::PromptCommand(PromptCommandOptions { + name: String::new(), // Empty name will trigger the error message in the rendering + info: false, + arguments: HashMap::new(), + })) + } else if let Some(stripped) = s.strip_prefix("/prompt ") { + // Has arguments case + parse_prompt_command(stripped) + } else { + // Handle invalid cases like "/promptxyz" + None + } + } s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())), s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), _ => None, } } +fn parse_prompt_command(args: &str) -> Option { + let parts: Vec = shlex::split(args).unwrap_or_default(); + + // set name to empty and error out in the rendering + let mut options = PromptCommandOptions { + name: parts.first().cloned().unwrap_or_default(), + info: false, + arguments: HashMap::new(), + }; + + // handle info at any point in the command + if parts.iter().any(|part| part == "--info") { + options.info = true; + } + + // Parse remaining arguments + let mut i = 1; + + while i < parts.len() { + let part = &parts[i]; + + // Skip flag arguments + if part == "--info" { + i += 1; + continue; + } + + // Process key=value pairs - removed redundant contains check + if let Some((key, value)) = part.split_once('=') { + options.arguments.insert(key.to_string(), value.to_string()); + } + + i += 1; + } + + Some(InputResult::PromptCommand(options)) +} + fn print_help() { println!( "Available commands: @@ -72,6 +138,8 @@ fn print_help() { /t - Toggle Light/Dark/Ansi theme /extension - Add a stdio extension (format: ENV1=val1 command args...) /builtin - Add builtin extensions by name (comma-separated) +/prompts - List all available prompts by name +/prompt [--info] [key=value...] - Get prompt info or execute a prompt /? or /help - Display this help message Navigation: @@ -131,6 +199,33 @@ mod tests { assert!(handle_slash_command("/unknown").is_none()); } + #[test] + fn test_prompt_command() { + // Test basic prompt info command + if let Some(InputResult::PromptCommand(opts)) = + handle_slash_command("/prompt test-prompt --info") + { + assert_eq!(opts.name, "test-prompt"); + assert!(opts.info); + assert!(opts.arguments.is_empty()); + } else { + panic!("Expected PromptCommand"); + } + + // Test prompt with arguments + if let Some(InputResult::PromptCommand(opts)) = + handle_slash_command("/prompt test-prompt arg1=val1 arg2=val2") + { + assert_eq!(opts.name, "test-prompt"); + assert!(!opts.info); + assert_eq!(opts.arguments.len(), 2); + assert_eq!(opts.arguments.get("arg1"), Some(&"val1".to_string())); + assert_eq!(opts.arguments.get("arg2"), Some(&"val2".to_string())); + } else { + panic!("Expected PromptCommand"); + } + } + // Test whitespace handling #[test] fn test_whitespace_handling() { @@ -149,4 +244,73 @@ mod tests { panic!("Expected AddBuiltin"); } } + + // Test prompt with no arguments + #[test] + fn test_prompt_no_args() { + // Test just "/prompt" with no arguments + if let Some(InputResult::PromptCommand(opts)) = handle_slash_command("/prompt") { + assert_eq!(opts.name, ""); + assert!(!opts.info); + assert!(opts.arguments.is_empty()); + } else { + panic!("Expected PromptCommand"); + } + + // Test invalid prompt command + assert!(handle_slash_command("/promptxyz").is_none()); + } + + // Test quoted arguments + #[test] + fn test_quoted_arguments() { + // Test prompt with quoted arguments + if let Some(InputResult::PromptCommand(opts)) = handle_slash_command( + r#"/prompt test-prompt arg1="value with spaces" arg2="another value""#, + ) { + assert_eq!(opts.name, "test-prompt"); + assert_eq!(opts.arguments.len(), 2); + assert_eq!( + opts.arguments.get("arg1"), + Some(&"value with spaces".to_string()) + ); + assert_eq!( + opts.arguments.get("arg2"), + Some(&"another value".to_string()) + ); + } else { + panic!("Expected PromptCommand"); + } + + // Test prompt with mixed quoted and unquoted arguments + if let Some(InputResult::PromptCommand(opts)) = handle_slash_command( + r#"/prompt test-prompt simple=value quoted="value with \"nested\" quotes""#, + ) { + assert_eq!(opts.name, "test-prompt"); + assert_eq!(opts.arguments.len(), 2); + assert_eq!(opts.arguments.get("simple"), Some(&"value".to_string())); + assert_eq!( + opts.arguments.get("quoted"), + Some(&r#"value with "nested" quotes"#.to_string()) + ); + } else { + panic!("Expected PromptCommand"); + } + } + + // Test invalid arguments + #[test] + fn test_invalid_arguments() { + // Test prompt with invalid arguments + if let Some(InputResult::PromptCommand(opts)) = + handle_slash_command(r#"/prompt test-prompt valid=value invalid_arg another_invalid"#) + { + assert_eq!(opts.name, "test-prompt"); + assert_eq!(opts.arguments.len(), 1); + assert_eq!(opts.arguments.get("valid"), Some(&"value".to_string())); + // Invalid arguments are ignored but logged + } else { + panic!("Expected PromptCommand"); + } + } } diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 88b68d0c..b3826d2b 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -14,7 +14,11 @@ use goose::agents::extension::{Envs, ExtensionConfig}; use goose::agents::Agent; use goose::message::{Message, MessageContent}; use mcp_core::handler::ToolError; +use mcp_core::prompt::PromptMessage; + use rand::{distributions::Alphanumeric, Rng}; +use serde_json::Value; +use std::collections::HashMap; use std::path::PathBuf; use tokio; @@ -104,6 +108,40 @@ impl Session { Ok(()) } + pub async fn list_prompts(&mut self) -> HashMap> { + let prompts = self.agent.list_extension_prompts().await; + prompts + .into_iter() + .map(|(extension, prompt_list)| { + let names = prompt_list.into_iter().map(|p| p.name).collect(); + (extension, names) + }) + .collect() + } + + pub async fn get_prompt_info(&mut self, name: &str) -> Result> { + let prompts = self.agent.list_extension_prompts().await; + + // Find which extension has this prompt + for (extension, prompt_list) in prompts { + if let Some(prompt) = prompt_list.iter().find(|p| p.name == name) { + return Ok(Some(output::PromptInfo { + name: prompt.name.clone(), + description: prompt.description.clone(), + arguments: prompt.arguments.clone(), + extension: Some(extension), + })); + } + } + + Ok(None) + } + + pub async fn get_prompt(&mut self, name: &str, arguments: Value) -> Result> { + let result = self.agent.get_prompt(name, arguments).await?; + Ok(result.messages) + } + /// Process a single message and get the response async fn process_message(&mut self, message: String) -> Result<()> { self.messages.push(Message::user().with_text(&message)); @@ -179,6 +217,68 @@ impl Session { continue; } input::InputResult::Retry => continue, + input::InputResult::ListPrompts => { + output::render_prompts(&self.list_prompts().await) + } + input::InputResult::PromptCommand(opts) => { + // name is required + if opts.name.is_empty() { + output::render_error("Prompt name argument is required"); + continue; + } + + if opts.info { + match self.get_prompt_info(&opts.name).await? { + Some(info) => output::render_prompt_info(&info), + None => { + output::render_error(&format!("Prompt '{}' not found", opts.name)) + } + } + } else { + // Convert the arguments HashMap to a Value + let arguments = serde_json::to_value(opts.arguments) + .map_err(|e| anyhow::anyhow!("Failed to serialize arguments: {}", e))?; + + match self.get_prompt(&opts.name, arguments).await { + Ok(messages) => { + let start_len = self.messages.len(); + let mut valid = true; + for (i, prompt_message) in messages.into_iter().enumerate() { + let msg = Message::from(prompt_message); + // ensure we get a User - Assistant - User type pattern + let expected_role = if i % 2 == 0 { + mcp_core::Role::User + } else { + mcp_core::Role::Assistant + }; + + if msg.role != expected_role { + output::render_error(&format!( + "Expected {:?} message at position {}, but found {:?}", + expected_role, i, msg.role + )); + valid = false; + // get rid of everything we added to messages + self.messages.truncate(start_len); + break; + } + + if msg.role == mcp_core::Role::User { + output::render_message(&msg); + } + self.messages.push(msg); + } + + if valid { + output::show_thinking(); + self.process_agent_response(true).await?; + output::hide_thinking(); + } + } + Err(e) => output::render_error(&e.to_string()), + } + } + } } } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 06d38343..d1377eb6 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -2,9 +2,11 @@ use bat::WrappingMode; use console::style; use goose::config::Config; use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; +use mcp_core::prompt::PromptArgument; use mcp_core::tool::ToolCall; use serde_json::Value; use std::cell::RefCell; +use std::collections::HashMap; use std::path::Path; // Re-export theme for use in main @@ -73,6 +75,14 @@ impl ThinkingIndicator { } } +#[derive(Debug)] +pub struct PromptInfo { + pub name: String, + pub description: Option, + pub arguments: Option>, + pub extension: Option, +} + // Global thinking indicator thread_local! { static THINKING: RefCell = RefCell::new(ThinkingIndicator::default()); @@ -154,6 +164,51 @@ pub fn render_error(message: &str) { println!("\n {} {}\n", style("error:").red().bold(), message); } +pub fn render_prompts(prompts: &HashMap>) { + println!(); + for (extension, prompts) in prompts { + println!(" {}", style(extension).green()); + for prompt in prompts { + println!(" - {}", style(prompt).cyan()); + } + } + println!(); +} + +pub fn render_prompt_info(info: &PromptInfo) { + println!(); + + if let Some(ext) = &info.extension { + println!(" {}: {}", style("Extension").green(), ext); + } + + println!(" Prompt: {}", style(&info.name).cyan().bold()); + + if let Some(desc) = &info.description { + println!("\n {}", desc); + } + + if let Some(args) = &info.arguments { + println!("\n Arguments:"); + for arg in args { + let required = arg.required.unwrap_or(false); + let req_str = if required { + style("(required)").red() + } else { + style("(optional)").dim() + }; + + println!( + " {} {} {}", + style(&arg.name).yellow(), + req_str, + arg.description.as_deref().unwrap_or("") + ); + } + } + println!(); +} + pub fn render_extension_success(name: &str) { println!(); println!( diff --git a/crates/goose-mcp/src/computercontroller/mod.rs b/crates/goose-mcp/src/computercontroller/mod.rs index 9dd2a5e1..832814b6 100644 --- a/crates/goose-mcp/src/computercontroller/mod.rs +++ b/crates/goose-mcp/src/computercontroller/mod.rs @@ -9,7 +9,8 @@ use std::{ use tokio::process::Command; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, tool::Tool, @@ -819,4 +820,21 @@ impl Router for ComputerControllerRouter { } }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index 28d0f0db..028e0453 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -72,9 +72,9 @@ pub fn load_prompt_files() -> HashMap { description: arg.description, required: arg.required, }) - .collect(); + .collect::>(); - let prompt = Prompt::new(&template.id, &template.template, arguments); + let prompt = Prompt::new(&template.id, Some(&template.template), Some(arguments)); if prompts.contains_key(&prompt.name) { eprintln!("Duplicate prompt name '{}' found. Skipping.", prompt.name); @@ -905,47 +905,35 @@ impl Router for DeveloperRouter { Box::pin(async move { Ok("".to_string()) }) } - fn list_prompts(&self) -> Option> { - if self.prompts.is_empty() { - None - } else { - Some(self.prompts.values().cloned().collect()) - } + fn list_prompts(&self) -> Vec { + self.prompts.values().cloned().collect() } fn get_prompt( &self, prompt_name: &str, - ) -> Option> + Send + 'static>>> { + ) -> Pin> + Send + 'static>> { let prompt_name = prompt_name.trim().to_owned(); // Validate prompt name is not empty if prompt_name.is_empty() { - return Some(Box::pin(async move { + return Box::pin(async move { Err(PromptError::InvalidParameters( "Prompt name cannot be empty".to_string(), )) - })); + }); } let prompts = Arc::clone(&self.prompts); - Some(Box::pin(async move { + Box::pin(async move { match prompts.get(&prompt_name) { - Some(prompt) => { - if prompt.description.trim().is_empty() { - Err(PromptError::InternalError(format!( - "Prompt '{prompt_name}' has an empty description" - ))) - } else { - Ok(prompt.description.clone()) - } - } + Some(prompt) => Ok(prompt.description.clone().unwrap_or_default()), None => Err(PromptError::NotFound(format!( "Prompt '{prompt_name}' not found" ))), } - })) + }) } } diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 2ba36a57..2ed1e7f1 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -5,7 +5,8 @@ use serde_json::{json, Value}; use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin}; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, tool::Tool, @@ -618,6 +619,23 @@ impl Router for GoogleDriveRouter { let uri_clone = uri.to_string(); Box::pin(async move { this.read_google_resource(uri_clone).await }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } impl Clone for GoogleDriveRouter { diff --git a/crates/goose-mcp/src/jetbrains/mod.rs b/crates/goose-mcp/src/jetbrains/mod.rs index 319cdcd3..0cdf8018 100644 --- a/crates/goose-mcp/src/jetbrains/mod.rs +++ b/crates/goose-mcp/src/jetbrains/mod.rs @@ -3,7 +3,8 @@ mod proxy; use anyhow::Result; use mcp_core::{ content::Content, - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, role::Role, @@ -176,6 +177,23 @@ impl Router for JetBrainsRouter { ) -> Pin> + Send + 'static>> { Box::pin(async { Err(ResourceError::NotFound("Resource not found".into())) }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } impl Clone for JetBrainsRouter { diff --git a/crates/goose-mcp/src/memory/mod.rs b/crates/goose-mcp/src/memory/mod.rs index 4a7411a5..a9fd1fa3 100644 --- a/crates/goose-mcp/src/memory/mod.rs +++ b/crates/goose-mcp/src/memory/mod.rs @@ -12,7 +12,8 @@ use std::{ }; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, tool::{Tool, ToolCall}, @@ -493,6 +494,22 @@ impl Router for MemoryRouter { ) -> Pin> + Send + 'static>> { Box::pin(async move { Ok("".to_string()) }) } + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } #[derive(Debug)] diff --git a/crates/goose-mcp/src/tutorial/mod.rs b/crates/goose-mcp/src/tutorial/mod.rs index 9d6ba3d7..2f32b03a 100644 --- a/crates/goose-mcp/src/tutorial/mod.rs +++ b/crates/goose-mcp/src/tutorial/mod.rs @@ -5,7 +5,8 @@ use serde_json::{json, Value}; use std::{future::Future, pin::Pin}; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, role::Role, @@ -156,6 +157,23 @@ impl Router for TutorialRouter { ) -> Pin> + Send + 'static>> { Box::pin(async move { Ok("".to_string()) }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } impl Clone for TutorialRouter { diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 469418b2..007848a4 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; @@ -6,6 +8,8 @@ use serde_json::Value; use super::extension::{ExtensionConfig, ExtensionResult}; use crate::message::Message; use crate::providers::base::ProviderUsage; +use mcp_core::prompt::Prompt; +use mcp_core::protocol::GetPromptResult; /// Core trait defining the behavior of an Agent #[async_trait] @@ -37,4 +41,11 @@ pub trait Agent: Send + Sync { /// Override the system prompt with custom text async fn override_system_prompt(&mut self, template: String); + + /// Lists all prompts from all extensions + async fn list_extension_prompts(&self) -> HashMap>; + + /// Get a prompt result with the given name and arguments + /// Returns the prompt text that would be used as user input + async fn get_prompt(&self, name: &str, arguments: Value) -> Result; } diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index 783f15de..2e95ec70 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -1,6 +1,8 @@ +use anyhow::Result; use chrono::{DateTime, TimeZone, Utc}; use futures::stream::{FuturesUnordered, StreamExt}; use mcp_client::McpService; +use mcp_core::protocol::GetPromptResult; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::sync::LazyLock; @@ -13,7 +15,7 @@ use crate::prompt_template::{load_prompt, load_prompt_file}; use crate::providers::base::{Provider, ProviderUsage}; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; use mcp_client::transport::{SseTransport, StdioTransport, Transport}; -use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult}; +use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult}; use serde_json::Value; // By default, we set it to Jan 1, 2020 if the resource does not have a timestamp @@ -544,6 +546,87 @@ impl Capabilities { result } + + pub async fn list_prompts_from_extension( + &self, + extension_name: &str, + ) -> Result, ToolError> { + let client = self.clients.get(extension_name).ok_or_else(|| { + ToolError::InvalidParameters(format!("Extension {} is not valid", extension_name)) + })?; + + let client_guard = client.lock().await; + client_guard + .list_prompts(None) + .await + .map_err(|e| { + ToolError::ExecutionError(format!( + "Unable to list prompts for {}, {:?}", + extension_name, e + )) + }) + .map(|lp| lp.prompts) + } + + pub async fn list_prompts(&self) -> Result>, ToolError> { + let mut futures = FuturesUnordered::new(); + + for extension_name in self.clients.keys() { + futures.push(async move { + ( + extension_name, + self.list_prompts_from_extension(extension_name).await, + ) + }); + } + + let mut all_prompts = HashMap::new(); + let mut errors = Vec::new(); + + // Process results as they complete + while let Some(result) = futures.next().await { + let (name, prompts) = result; + match prompts { + Ok(content) => { + all_prompts.insert(name.to_string(), content); + } + Err(tool_error) => { + errors.push(tool_error); + } + } + } + + // Log any errors that occurred + if !errors.is_empty() { + tracing::error!( + errors = ?errors + .into_iter() + .map(|e| format!("{:?}", e)) + .collect::>(), + "errors from listing prompts" + ); + } + + Ok(all_prompts) + } + + pub async fn get_prompt( + &self, + extension_name: &str, + name: &str, + arguments: Value, + ) -> Result { + let client = self + .clients + .get(extension_name) + .ok_or_else(|| anyhow::anyhow!("Extension {} not found", extension_name))?; + + let client_guard = client.lock().await; + client_guard + .get_prompt(name, arguments) + .await + .map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e)) + } } #[cfg(test)] @@ -556,7 +639,8 @@ mod tests { use mcp_client::client::Error; use mcp_client::client::McpClientTrait; use mcp_core::protocol::{ - CallToolResult, InitializeResult, ListResourcesResult, ListToolsResult, ReadResourceResult, + CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult, + ListToolsResult, ReadResourceResult, }; use serde_json::json; @@ -625,6 +709,21 @@ mod tests { _ => Err(Error::NotInitialized), } } + + async fn list_prompts( + &self, + _next_cursor: Option, + ) -> Result { + Err(Error::NotInitialized) + } + + async fn get_prompt( + &self, + _name: &str, + _arguments: Value, + ) -> Result { + Err(Error::NotInitialized) + } } #[test] diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index bda3acce..9eccd23c 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -2,6 +2,7 @@ /// It makes no attempt to handle context limits, and cannot read resources use async_trait::async_trait; use futures::stream::BoxStream; +use std::collections::HashMap; use tokio::sync::Mutex; use tracing::{debug, instrument}; @@ -13,7 +14,10 @@ use crate::providers::base::Provider; use crate::providers::base::ProviderUsage; use crate::register_agent; use crate::token_counter::TokenCounter; +use anyhow::{anyhow, Result}; use indoc::indoc; +use mcp_core::prompt::Prompt; +use mcp_core::protocol::GetPromptResult; use mcp_core::tool::Tool; use serde_json::{json, Value}; @@ -198,6 +202,37 @@ impl Agent for ReferenceAgent { let mut capabilities = self.capabilities.lock().await; capabilities.set_system_prompt_override(template); } + + async fn list_extension_prompts(&self) -> HashMap> { + let capabilities = self.capabilities.lock().await; + capabilities + .list_prompts() + .await + .expect("Failed to list prompts") + } + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + let capabilities = self.capabilities.lock().await; + + // First find which extension has this prompt + let prompts = capabilities + .list_prompts() + .await + .map_err(|e| anyhow!("Failed to list prompts: {}", e))?; + + if let Some(extension) = prompts + .iter() + .find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name)) + .map(|(extension, _)| extension) + { + return capabilities + .get_prompt(extension, name, arguments) + .await + .map_err(|e| anyhow!("Failed to get prompt: {}", e)); + } + + Err(anyhow!("Prompt '{}' not found", name)) + } } register_agent!("reference", ReferenceAgent); diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index 8bf581d6..c795f93b 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -2,6 +2,7 @@ /// It makes no attempt to handle context limits, and cannot read resources use async_trait::async_trait; use futures::stream::BoxStream; +use std::collections::HashMap; use tokio::sync::mpsc; use tokio::sync::Mutex; use tracing::{debug, error, instrument, warn}; @@ -19,7 +20,10 @@ use crate::providers::errors::ProviderError; use crate::register_agent; use crate::token_counter::TokenCounter; use crate::truncate::{truncate_messages, OldestFirstTruncation}; +use anyhow::{anyhow, Result}; use indoc::indoc; +use mcp_core::prompt::Prompt; +use mcp_core::protocol::GetPromptResult; use mcp_core::{tool::Tool, Content}; use serde_json::{json, Value}; @@ -398,6 +402,37 @@ impl Agent for TruncateAgent { let mut capabilities = self.capabilities.lock().await; capabilities.set_system_prompt_override(template); } + + async fn list_extension_prompts(&self) -> HashMap> { + let capabilities = self.capabilities.lock().await; + capabilities + .list_prompts() + .await + .expect("Failed to list prompts") + } + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + let capabilities = self.capabilities.lock().await; + + // First find which extension has this prompt + let prompts = capabilities + .list_prompts() + .await + .map_err(|e| anyhow!("Failed to list prompts: {}", e))?; + + if let Some(extension) = prompts + .iter() + .find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name)) + .map(|(extension, _)| extension) + { + return capabilities + .get_prompt(extension, name, arguments) + .await + .map_err(|e| anyhow!("Failed to get prompt: {}", e)); + } + + Err(anyhow!("Prompt '{}' not found", name)) + } } register_agent!("truncate", TruncateAgent); diff --git a/crates/goose/src/message.rs b/crates/goose/src/message.rs index e5771a23..41193a53 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/message.rs @@ -10,6 +10,8 @@ use std::collections::HashSet; use chrono::Utc; use mcp_core::content::{Content, ImageContent, TextContent}; use mcp_core::handler::ToolResult; +use mcp_core::prompt::{PromptMessage, PromptMessageContent, PromptMessageRole}; +use mcp_core::resource::ResourceContents; use mcp_core::role::Role; use mcp_core::tool::ToolCall; use serde_json::Value; @@ -156,6 +158,37 @@ impl From for MessageContent { } } +impl From for Message { + fn from(prompt_message: PromptMessage) -> Self { + // Create a new message with the appropriate role + let message = match prompt_message.role { + PromptMessageRole::User => Message::user(), + PromptMessageRole::Assistant => Message::assistant(), + }; + + // Convert and add the content + let content = match prompt_message.content { + PromptMessageContent::Text { text } => MessageContent::text(text), + PromptMessageContent::Image { image } => { + MessageContent::image(image.data, image.mime_type) + } + PromptMessageContent::Resource { resource } => { + // For resources, convert to text content with the resource text + match resource.resource { + ResourceContents::TextResourceContents { text, .. } => { + MessageContent::text(text) + } + ResourceContents::BlobResourceContents { blob, .. } => { + MessageContent::text(format!("[Binary content: {}]", blob)) + } + } + } + }; + + message.with_content(content) + } +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// A message to or from an LLM #[serde(rename_all = "camelCase")] @@ -305,7 +338,10 @@ impl Message { #[cfg(test)] mod tests { use super::*; + use mcp_core::content::EmbeddedResource; use mcp_core::handler::ToolError; + use mcp_core::prompt::PromptMessageContent; + use mcp_core::resource::ResourceContents; use serde_json::{json, Value}; #[test] @@ -420,4 +456,158 @@ mod tests { panic!("Expected ToolRequest content"); } } + + #[test] + fn test_from_prompt_message_text() { + let prompt_content = PromptMessageContent::Text { + text: "Hello, world!".to_string(), + }; + + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: prompt_content, + }; + + let message = Message::from(prompt_message); + + if let MessageContent::Text(text_content) = &message.content[0] { + assert_eq!(text_content.text, "Hello, world!"); + } else { + panic!("Expected MessageContent::Text"); + } + } + + #[test] + fn test_from_prompt_message_image() { + let prompt_content = PromptMessageContent::Image { + image: ImageContent { + data: "base64data".to_string(), + mime_type: "image/jpeg".to_string(), + annotations: None, + }, + }; + + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: prompt_content, + }; + + let message = Message::from(prompt_message); + + if let MessageContent::Image(image_content) = &message.content[0] { + assert_eq!(image_content.data, "base64data"); + assert_eq!(image_content.mime_type, "image/jpeg"); + } else { + panic!("Expected MessageContent::Image"); + } + } + + #[test] + fn test_from_prompt_message_text_resource() { + let resource = ResourceContents::TextResourceContents { + uri: "file:///test.txt".to_string(), + mime_type: Some("text/plain".to_string()), + text: "Resource content".to_string(), + }; + + let prompt_content = PromptMessageContent::Resource { + resource: EmbeddedResource { + resource, + annotations: None, + }, + }; + + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: prompt_content, + }; + + let message = Message::from(prompt_message); + + if let MessageContent::Text(text_content) = &message.content[0] { + assert_eq!(text_content.text, "Resource content"); + } else { + panic!("Expected MessageContent::Text"); + } + } + + #[test] + fn test_from_prompt_message_blob_resource() { + let resource = ResourceContents::BlobResourceContents { + uri: "file:///test.bin".to_string(), + mime_type: Some("application/octet-stream".to_string()), + blob: "binary_data".to_string(), + }; + + let prompt_content = PromptMessageContent::Resource { + resource: EmbeddedResource { + resource, + annotations: None, + }, + }; + + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: prompt_content, + }; + + let message = Message::from(prompt_message); + + if let MessageContent::Text(text_content) = &message.content[0] { + assert_eq!(text_content.text, "[Binary content: binary_data]"); + } else { + panic!("Expected MessageContent::Text"); + } + } + + #[test] + fn test_from_prompt_message() { + // Test user message conversion + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: PromptMessageContent::Text { + text: "Hello, world!".to_string(), + }, + }; + + let message = Message::from(prompt_message); + assert_eq!(message.role, Role::User); + assert_eq!(message.content.len(), 1); + assert_eq!(message.as_concat_text(), "Hello, world!"); + + // Test assistant message conversion + let prompt_message = PromptMessage { + role: PromptMessageRole::Assistant, + content: PromptMessageContent::Text { + text: "I can help with that.".to_string(), + }, + }; + + let message = Message::from(prompt_message); + assert_eq!(message.role, Role::Assistant); + assert_eq!(message.content.len(), 1); + assert_eq!(message.as_concat_text(), "I can help with that."); + } + + #[test] + fn test_message_with_text() { + let message = Message::user().with_text("Hello"); + assert_eq!(message.as_concat_text(), "Hello"); + } + + #[test] + fn test_message_with_tool_request() { + let tool_call = Ok(ToolCall { + name: "test_tool".to_string(), + arguments: serde_json::json!({}), + }); + + let message = Message::assistant().with_tool_request("req1", tool_call); + assert!(message.is_tool_call()); + assert!(!message.is_tool_response()); + + let ids = message.get_tool_ids(); + assert_eq!(ids.len(), 1); + assert!(ids.contains("req1")); + } } diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs index 9acd2086..ffdcc10c 100644 --- a/crates/mcp-client/examples/stdio_integration.rs +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -82,5 +82,16 @@ async fn main() -> Result<(), ClientError> { let resource = client.read_resource("memo://insights").await?; println!("Resource: {resource:?}\n"); + let prompts = client.list_prompts(None).await?; + println!("Prompts: {prompts:?}\n"); + + let prompt = client + .get_prompt( + "example_prompt", + serde_json::json!({"message": "hello there!"}), + ) + .await?; + println!("Prompt: {prompt:?}\n"); + Ok(()) } diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 0a00e8c7..0d722e55 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -1,7 +1,7 @@ use mcp_core::protocol::{ - CallToolResult, Implementation, InitializeResult, JsonRpcError, JsonRpcMessage, - JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, - ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, + CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError, + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult, + ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, }; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -93,6 +93,10 @@ pub trait McpClientTrait: Send + Sync { async fn list_tools(&self, next_cursor: Option) -> Result; async fn call_tool(&self, name: &str, arguments: Value) -> Result; + + async fn list_prompts(&self, next_cursor: Option) -> Result; + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result; } /// The MCP client is the interface for MCP operations. @@ -346,4 +350,42 @@ where // https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2 self.send_request("tools/call", params).await } + + async fn list_prompts(&self, next_cursor: Option) -> Result { + if !self.completed_initialization() { + return Err(Error::NotInitialized); + } + + // If prompts is not supported, return an error + if self.server_capabilities.as_ref().unwrap().prompts.is_none() { + return Err(Error::RpcError { + code: METHOD_NOT_FOUND, + message: "Server does not support 'prompts' capability".to_string(), + }); + } + + let payload = next_cursor + .map(|cursor| serde_json::json!({"cursor": cursor})) + .unwrap_or_else(|| serde_json::json!({})); + + self.send_request("prompts/list", payload).await + } + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + if !self.completed_initialization() { + return Err(Error::NotInitialized); + } + + // If prompts is not supported, return an error + if self.server_capabilities.as_ref().unwrap().prompts.is_none() { + return Err(Error::RpcError { + code: METHOD_NOT_FOUND, + message: "Server does not support 'prompts' capability".to_string(), + }); + } + + let params = serde_json::json!({ "name": name, "arguments": arguments }); + + self.send_request("prompts/get", params).await + } } diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index ed08e480..90dc5f2f 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -111,13 +111,23 @@ impl SseActor { // Attempt to parse the SSE data as a JsonRpcMessage match serde_json::from_str::(&e.data) { Ok(message) => { - // If it's a response, complete the pending request - if let JsonRpcMessage::Response(resp) = &message { - if let Some(id) = &resp.id { - pending_requests.respond(&id.to_string(), Ok(message)).await; + match &message { + JsonRpcMessage::Response(response) => { + if let Some(id) = &response.id { + pending_requests + .respond(&id.to_string(), Ok(message)) + .await; + } } + JsonRpcMessage::Error(error) => { + if let Some(id) = &error.id { + pending_requests + .respond(&id.to_string(), Ok(message)) + .await; + } + } + _ => {} // TODO: Handle other variants (Request, etc.) } - // If it's something else (notification, etc.), handle as needed } Err(err) => { warn!("Failed to parse SSE message: {err}"); diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 59d90054..7980816b 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -87,10 +87,18 @@ impl StdioActor { "Received incoming message" ); - if let JsonRpcMessage::Response(response) = &message { - if let Some(id) = &response.id { - pending_requests.respond(&id.to_string(), Ok(message)).await; + match &message { + JsonRpcMessage::Response(response) => { + if let Some(id) = &response.id { + pending_requests.respond(&id.to_string(), Ok(message)).await; + } } + JsonRpcMessage::Error(error) => { + if let Some(id) = &error.id { + pending_requests.respond(&id.to_string(), Ok(message)).await; + } + } + _ => {} // TODO: Handle other variants (Request, etc.) } } line.clear(); diff --git a/crates/mcp-core/src/prompt.rs b/crates/mcp-core/src/prompt.rs index 7b814fd4..4a0106e3 100644 --- a/crates/mcp-core/src/prompt.rs +++ b/crates/mcp-core/src/prompt.rs @@ -10,22 +10,28 @@ use serde::{Deserialize, Serialize}; pub struct Prompt { /// The name of the prompt pub name: String, - /// A description of what the prompt does - pub description: String, - /// The arguments that can be passed to customize the prompt - pub arguments: Vec, + /// Optional description of what the prompt does + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Optional arguments that can be passed to customize the prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, } impl Prompt { /// Create a new prompt with the given name, description and arguments - pub fn new(name: N, description: D, arguments: Vec) -> Self + pub fn new( + name: N, + description: Option, + arguments: Option>, + ) -> Self where N: Into, D: Into, { Prompt { name: name.into(), - description: description.into(), + description: description.map(Into::into), arguments, } } @@ -37,9 +43,11 @@ pub struct PromptArgument { /// The name of the argument pub name: String, /// A description of what the argument is used for - pub description: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, /// Whether this argument is required - pub required: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option, } /// Represents the role of a message sender in a prompt conversation @@ -151,6 +159,6 @@ pub struct PromptTemplate { #[derive(Debug, Serialize, Deserialize)] pub struct PromptArgumentTemplate { pub name: String, - pub description: String, - pub required: bool, + pub description: Option, + pub required: Option, } diff --git a/crates/mcp-server/src/main.rs b/crates/mcp-server/src/main.rs index eee25002..907cc1b1 100644 --- a/crates/mcp-server/src/main.rs +++ b/crates/mcp-server/src/main.rs @@ -1,6 +1,7 @@ use anyhow::Result; use mcp_core::content::Content; -use mcp_core::handler::ResourceError; +use mcp_core::handler::{PromptError, ResourceError}; +use mcp_core::prompt::{Prompt, PromptArgument}; use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; use mcp_server::router::{CapabilitiesBuilder, RouterService}; use mcp_server::{ByteTransport, Router, Server}; @@ -61,6 +62,7 @@ impl Router for CounterRouter { CapabilitiesBuilder::new() .with_tools(false) .with_resources(false, false) + .with_prompts(false) .build() } @@ -153,6 +155,37 @@ impl Router for CounterRouter { } }) } + + fn list_prompts(&self) -> Vec { + vec![Prompt::new( + "example_prompt", + Some("This is an example prompt that takes one required agrument, message"), + Some(vec![PromptArgument { + name: "message".to_string(), + description: Some("A message to put in the prompt".to_string()), + required: Some(true), + }]), + )] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + match prompt_name.as_str() { + "example_prompt" => { + let prompt = "This is an example prompt with your message here: '{message}'"; + Ok(prompt.to_string()) + } + _ => Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))), + } + }) + } } #[tokio::main] diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs index d2918311..2c277d1c 100644 --- a/crates/mcp-server/src/router.rs +++ b/crates/mcp-server/src/router.rs @@ -97,12 +97,8 @@ pub trait Router: Send + Sync + 'static { &self, uri: &str, ) -> Pin> + Send + 'static>>; - fn list_prompts(&self) -> Option> { - None - } - fn get_prompt(&self, _prompt_name: &str) -> Option { - None - } + fn list_prompts(&self) -> Vec; + fn get_prompt(&self, prompt_name: &str) -> PromptFuture; // Helper method to create base response fn create_response(&self, id: Option) -> JsonRpcResponse { @@ -257,7 +253,7 @@ pub trait Router: Send + Sync + 'static { req: JsonRpcRequest, ) -> impl Future> + Send { async move { - let prompts = self.list_prompts().unwrap_or_default(); + let prompts = self.list_prompts(); let result = ListPromptsResult { prompts }; @@ -294,36 +290,36 @@ pub trait Router: Send + Sync + 'static { .ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?; // Fetch the prompt definition first - let prompt = match self.list_prompts() { - Some(prompts) => prompts - .into_iter() - .find(|p| p.name == prompt_name) - .ok_or_else(|| { - RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name)) - })?, - None => return Err(RouterError::PromptNotFound("No prompts available".into())), - }; + let prompt = self + .list_prompts() + .into_iter() + .find(|p| p.name == prompt_name) + .ok_or_else(|| { + RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name)) + })?; // Validate required arguments - for arg in &prompt.arguments { - if arg.required - && (!arguments.contains_key(&arg.name) - || arguments - .get(&arg.name) - .and_then(Value::as_str) - .is_none_or(str::is_empty)) - { - return Err(RouterError::InvalidParams(format!( - "Missing required argument: '{}'", - arg.name - ))); + if let Some(args) = &prompt.arguments { + for arg in args { + if arg.required.is_some() + && arg.required.unwrap() + && (!arguments.contains_key(&arg.name) + || arguments + .get(&arg.name) + .and_then(Value::as_str) + .is_none_or(str::is_empty)) + { + return Err(RouterError::InvalidParams(format!( + "Missing required argument: '{}'", + arg.name + ))); + } } } // Now get the prompt content let description = self .get_prompt(prompt_name) - .ok_or_else(|| RouterError::PromptNotFound("Prompt not found".into()))? .await .map_err(|e| RouterError::Internal(e.to_string()))?; diff --git a/ui/desktop/src/components/UserMessage.tsx b/ui/desktop/src/components/UserMessage.tsx index 42bb072c..28d49cb0 100644 --- a/ui/desktop/src/components/UserMessage.tsx +++ b/ui/desktop/src/components/UserMessage.tsx @@ -11,7 +11,7 @@ interface UserMessageProps { export default function UserMessage({ message }: UserMessageProps) { // Extract text content from the message const textContent = getTextContent(message); - + // Extract URLs which explicitly contain the http:// or https:// protocol const urls = extractUrls(textContent, []); @@ -33,4 +33,4 @@ export default function UserMessage({ message }: UserMessageProps) { ); -} \ No newline at end of file +}