feat(cli): add mcp prompt support via slash commands (#1323)

This commit is contained in:
Kalvin C
2025-02-27 15:47:29 -08:00
committed by GitHub
parent 5bf05d545e
commit d0ca46983e
24 changed files with 958 additions and 82 deletions

1
Cargo.lock generated
View File

@@ -2203,6 +2203,7 @@ dependencies = [
"serde",
"serde_json",
"serde_yaml",
"shlex",
"temp-env",
"tempfile",
"test-case",

View File

@@ -47,6 +47,7 @@ chrono = "0.4"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] }
tracing-appender = "0.2"
once_cell = "1.20.2"
shlex = "1.3.0"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

View File

@@ -1,5 +1,7 @@
use anyhow::Result;
use rustyline::Editor;
use shlex;
use std::collections::HashMap;
#[derive(Debug)]
pub enum InputResult {
@@ -9,6 +11,15 @@ pub enum InputResult {
AddBuiltin(String),
ToggleTheme,
Retry,
ListPrompts,
PromptCommand(PromptCommandOptions),
}
#[derive(Debug)]
pub struct PromptCommandOptions {
pub name: String,
pub info: bool,
pub arguments: HashMap<String, String>,
}
pub fn get_input(
@@ -59,12 +70,67 @@ fn handle_slash_command(input: &str) -> Option<InputResult> {
Some(InputResult::Retry)
}
"/t" => Some(InputResult::ToggleTheme),
"/prompts" => Some(InputResult::ListPrompts),
s if s.starts_with("/prompt") => {
if s == "/prompt" {
// No arguments case
Some(InputResult::PromptCommand(PromptCommandOptions {
name: String::new(), // Empty name will trigger the error message in the rendering
info: false,
arguments: HashMap::new(),
}))
} else if let Some(stripped) = s.strip_prefix("/prompt ") {
// Has arguments case
parse_prompt_command(stripped)
} else {
// Handle invalid cases like "/promptxyz"
None
}
}
s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())),
s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())),
_ => None,
}
}
fn parse_prompt_command(args: &str) -> Option<InputResult> {
let parts: Vec<String> = shlex::split(args).unwrap_or_default();
// set name to empty and error out in the rendering
let mut options = PromptCommandOptions {
name: parts.first().cloned().unwrap_or_default(),
info: false,
arguments: HashMap::new(),
};
// handle info at any point in the command
if parts.iter().any(|part| part == "--info") {
options.info = true;
}
// Parse remaining arguments
let mut i = 1;
while i < parts.len() {
let part = &parts[i];
// Skip flag arguments
if part == "--info" {
i += 1;
continue;
}
// Process key=value pairs - removed redundant contains check
if let Some((key, value)) = part.split_once('=') {
options.arguments.insert(key.to_string(), value.to_string());
}
i += 1;
}
Some(InputResult::PromptCommand(options))
}
fn print_help() {
println!(
"Available commands:
@@ -72,6 +138,8 @@ fn print_help() {
/t - Toggle Light/Dark/Ansi theme
/extension <command> - Add a stdio extension (format: ENV1=val1 command args...)
/builtin <names> - Add builtin extensions by name (comma-separated)
/prompts - List all available prompts by name
/prompt <name> [--info] [key=value...] - Get prompt info or execute a prompt
/? or /help - Display this help message
Navigation:
@@ -131,6 +199,33 @@ mod tests {
assert!(handle_slash_command("/unknown").is_none());
}
#[test]
fn test_prompt_command() {
// Test basic prompt info command
if let Some(InputResult::PromptCommand(opts)) =
handle_slash_command("/prompt test-prompt --info")
{
assert_eq!(opts.name, "test-prompt");
assert!(opts.info);
assert!(opts.arguments.is_empty());
} else {
panic!("Expected PromptCommand");
}
// Test prompt with arguments
if let Some(InputResult::PromptCommand(opts)) =
handle_slash_command("/prompt test-prompt arg1=val1 arg2=val2")
{
assert_eq!(opts.name, "test-prompt");
assert!(!opts.info);
assert_eq!(opts.arguments.len(), 2);
assert_eq!(opts.arguments.get("arg1"), Some(&"val1".to_string()));
assert_eq!(opts.arguments.get("arg2"), Some(&"val2".to_string()));
} else {
panic!("Expected PromptCommand");
}
}
// Test whitespace handling
#[test]
fn test_whitespace_handling() {
@@ -149,4 +244,73 @@ mod tests {
panic!("Expected AddBuiltin");
}
}
// Test prompt with no arguments
#[test]
fn test_prompt_no_args() {
// Test just "/prompt" with no arguments
if let Some(InputResult::PromptCommand(opts)) = handle_slash_command("/prompt") {
assert_eq!(opts.name, "");
assert!(!opts.info);
assert!(opts.arguments.is_empty());
} else {
panic!("Expected PromptCommand");
}
// Test invalid prompt command
assert!(handle_slash_command("/promptxyz").is_none());
}
// Test quoted arguments
#[test]
fn test_quoted_arguments() {
// Test prompt with quoted arguments
if let Some(InputResult::PromptCommand(opts)) = handle_slash_command(
r#"/prompt test-prompt arg1="value with spaces" arg2="another value""#,
) {
assert_eq!(opts.name, "test-prompt");
assert_eq!(opts.arguments.len(), 2);
assert_eq!(
opts.arguments.get("arg1"),
Some(&"value with spaces".to_string())
);
assert_eq!(
opts.arguments.get("arg2"),
Some(&"another value".to_string())
);
} else {
panic!("Expected PromptCommand");
}
// Test prompt with mixed quoted and unquoted arguments
if let Some(InputResult::PromptCommand(opts)) = handle_slash_command(
r#"/prompt test-prompt simple=value quoted="value with \"nested\" quotes""#,
) {
assert_eq!(opts.name, "test-prompt");
assert_eq!(opts.arguments.len(), 2);
assert_eq!(opts.arguments.get("simple"), Some(&"value".to_string()));
assert_eq!(
opts.arguments.get("quoted"),
Some(&r#"value with "nested" quotes"#.to_string())
);
} else {
panic!("Expected PromptCommand");
}
}
// Test invalid arguments
#[test]
fn test_invalid_arguments() {
// Test prompt with invalid arguments
if let Some(InputResult::PromptCommand(opts)) =
handle_slash_command(r#"/prompt test-prompt valid=value invalid_arg another_invalid"#)
{
assert_eq!(opts.name, "test-prompt");
assert_eq!(opts.arguments.len(), 1);
assert_eq!(opts.arguments.get("valid"), Some(&"value".to_string()));
// Invalid arguments are ignored but logged
} else {
panic!("Expected PromptCommand");
}
}
}

View File

@@ -14,7 +14,11 @@ use goose::agents::extension::{Envs, ExtensionConfig};
use goose::agents::Agent;
use goose::message::{Message, MessageContent};
use mcp_core::handler::ToolError;
use mcp_core::prompt::PromptMessage;
use rand::{distributions::Alphanumeric, Rng};
use serde_json::Value;
use std::collections::HashMap;
use std::path::PathBuf;
use tokio;
@@ -104,6 +108,40 @@ impl Session {
Ok(())
}
pub async fn list_prompts(&mut self) -> HashMap<String, Vec<String>> {
let prompts = self.agent.list_extension_prompts().await;
prompts
.into_iter()
.map(|(extension, prompt_list)| {
let names = prompt_list.into_iter().map(|p| p.name).collect();
(extension, names)
})
.collect()
}
pub async fn get_prompt_info(&mut self, name: &str) -> Result<Option<output::PromptInfo>> {
let prompts = self.agent.list_extension_prompts().await;
// Find which extension has this prompt
for (extension, prompt_list) in prompts {
if let Some(prompt) = prompt_list.iter().find(|p| p.name == name) {
return Ok(Some(output::PromptInfo {
name: prompt.name.clone(),
description: prompt.description.clone(),
arguments: prompt.arguments.clone(),
extension: Some(extension),
}));
}
}
Ok(None)
}
pub async fn get_prompt(&mut self, name: &str, arguments: Value) -> Result<Vec<PromptMessage>> {
let result = self.agent.get_prompt(name, arguments).await?;
Ok(result.messages)
}
/// Process a single message and get the response
async fn process_message(&mut self, message: String) -> Result<()> {
self.messages.push(Message::user().with_text(&message));
@@ -179,6 +217,68 @@ impl Session {
continue;
}
input::InputResult::Retry => continue,
input::InputResult::ListPrompts => {
output::render_prompts(&self.list_prompts().await)
}
input::InputResult::PromptCommand(opts) => {
// name is required
if opts.name.is_empty() {
output::render_error("Prompt name argument is required");
continue;
}
if opts.info {
match self.get_prompt_info(&opts.name).await? {
Some(info) => output::render_prompt_info(&info),
None => {
output::render_error(&format!("Prompt '{}' not found", opts.name))
}
}
} else {
// Convert the arguments HashMap to a Value
let arguments = serde_json::to_value(opts.arguments)
.map_err(|e| anyhow::anyhow!("Failed to serialize arguments: {}", e))?;
match self.get_prompt(&opts.name, arguments).await {
Ok(messages) => {
let start_len = self.messages.len();
let mut valid = true;
for (i, prompt_message) in messages.into_iter().enumerate() {
let msg = Message::from(prompt_message);
// ensure we get a User - Assistant - User type pattern
let expected_role = if i % 2 == 0 {
mcp_core::Role::User
} else {
mcp_core::Role::Assistant
};
if msg.role != expected_role {
output::render_error(&format!(
"Expected {:?} message at position {}, but found {:?}",
expected_role, i, msg.role
));
valid = false;
// get rid of everything we added to messages
self.messages.truncate(start_len);
break;
}
if msg.role == mcp_core::Role::User {
output::render_message(&msg);
}
self.messages.push(msg);
}
if valid {
output::show_thinking();
self.process_agent_response(true).await?;
output::hide_thinking();
}
}
Err(e) => output::render_error(&e.to_string()),
}
}
}
}
}

View File

@@ -2,9 +2,11 @@ use bat::WrappingMode;
use console::style;
use goose::config::Config;
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
use mcp_core::prompt::PromptArgument;
use mcp_core::tool::ToolCall;
use serde_json::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
// Re-export theme for use in main
@@ -73,6 +75,14 @@ impl ThinkingIndicator {
}
}
#[derive(Debug)]
pub struct PromptInfo {
pub name: String,
pub description: Option<String>,
pub arguments: Option<Vec<PromptArgument>>,
pub extension: Option<String>,
}
// Global thinking indicator
thread_local! {
static THINKING: RefCell<ThinkingIndicator> = RefCell::new(ThinkingIndicator::default());
@@ -154,6 +164,51 @@ pub fn render_error(message: &str) {
println!("\n {} {}\n", style("error:").red().bold(), message);
}
pub fn render_prompts(prompts: &HashMap<String, Vec<String>>) {
println!();
for (extension, prompts) in prompts {
println!(" {}", style(extension).green());
for prompt in prompts {
println!(" - {}", style(prompt).cyan());
}
}
println!();
}
pub fn render_prompt_info(info: &PromptInfo) {
println!();
if let Some(ext) = &info.extension {
println!(" {}: {}", style("Extension").green(), ext);
}
println!(" Prompt: {}", style(&info.name).cyan().bold());
if let Some(desc) = &info.description {
println!("\n {}", desc);
}
if let Some(args) = &info.arguments {
println!("\n Arguments:");
for arg in args {
let required = arg.required.unwrap_or(false);
let req_str = if required {
style("(required)").red()
} else {
style("(optional)").dim()
};
println!(
" {} {} {}",
style(&arg.name).yellow(),
req_str,
arg.description.as_deref().unwrap_or("")
);
}
}
println!();
}
pub fn render_extension_success(name: &str) {
println!();
println!(

View File

@@ -9,7 +9,8 @@ use std::{
use tokio::process::Command;
use mcp_core::{
handler::{ResourceError, ToolError},
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
resource::Resource,
tool::Tool,
@@ -819,4 +820,21 @@ impl Router for ComputerControllerRouter {
}
})
}
fn list_prompts(&self) -> Vec<Prompt> {
vec![]
}
fn get_prompt(
&self,
prompt_name: &str,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
let prompt_name = prompt_name.to_string();
Box::pin(async move {
Err(PromptError::NotFound(format!(
"Prompt {} not found",
prompt_name
)))
})
}
}

View File

@@ -72,9 +72,9 @@ pub fn load_prompt_files() -> HashMap<String, Prompt> {
description: arg.description,
required: arg.required,
})
.collect();
.collect::<Vec<PromptArgument>>();
let prompt = Prompt::new(&template.id, &template.template, arguments);
let prompt = Prompt::new(&template.id, Some(&template.template), Some(arguments));
if prompts.contains_key(&prompt.name) {
eprintln!("Duplicate prompt name '{}' found. Skipping.", prompt.name);
@@ -905,47 +905,35 @@ impl Router for DeveloperRouter {
Box::pin(async move { Ok("".to_string()) })
}
fn list_prompts(&self) -> Option<Vec<Prompt>> {
if self.prompts.is_empty() {
None
} else {
Some(self.prompts.values().cloned().collect())
}
fn list_prompts(&self) -> Vec<Prompt> {
self.prompts.values().cloned().collect()
}
fn get_prompt(
&self,
prompt_name: &str,
) -> Option<Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>>> {
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
let prompt_name = prompt_name.trim().to_owned();
// Validate prompt name is not empty
if prompt_name.is_empty() {
return Some(Box::pin(async move {
return Box::pin(async move {
Err(PromptError::InvalidParameters(
"Prompt name cannot be empty".to_string(),
))
}));
});
}
let prompts = Arc::clone(&self.prompts);
Some(Box::pin(async move {
Box::pin(async move {
match prompts.get(&prompt_name) {
Some(prompt) => {
if prompt.description.trim().is_empty() {
Err(PromptError::InternalError(format!(
"Prompt '{prompt_name}' has an empty description"
)))
} else {
Ok(prompt.description.clone())
}
}
Some(prompt) => Ok(prompt.description.clone().unwrap_or_default()),
None => Err(PromptError::NotFound(format!(
"Prompt '{prompt_name}' not found"
))),
}
}))
})
}
}

View File

@@ -5,7 +5,8 @@ use serde_json::{json, Value};
use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin};
use mcp_core::{
handler::{ResourceError, ToolError},
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
resource::Resource,
tool::Tool,
@@ -618,6 +619,23 @@ impl Router for GoogleDriveRouter {
let uri_clone = uri.to_string();
Box::pin(async move { this.read_google_resource(uri_clone).await })
}
fn list_prompts(&self) -> Vec<Prompt> {
vec![]
}
fn get_prompt(
&self,
prompt_name: &str,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
let prompt_name = prompt_name.to_string();
Box::pin(async move {
Err(PromptError::NotFound(format!(
"Prompt {} not found",
prompt_name
)))
})
}
}
impl Clone for GoogleDriveRouter {

View File

@@ -3,7 +3,8 @@ mod proxy;
use anyhow::Result;
use mcp_core::{
content::Content,
handler::{ResourceError, ToolError},
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
resource::Resource,
role::Role,
@@ -176,6 +177,23 @@ impl Router for JetBrainsRouter {
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
Box::pin(async { Err(ResourceError::NotFound("Resource not found".into())) })
}
fn list_prompts(&self) -> Vec<Prompt> {
vec![]
}
fn get_prompt(
&self,
prompt_name: &str,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
let prompt_name = prompt_name.to_string();
Box::pin(async move {
Err(PromptError::NotFound(format!(
"Prompt {} not found",
prompt_name
)))
})
}
}
impl Clone for JetBrainsRouter {

View File

@@ -12,7 +12,8 @@ use std::{
};
use mcp_core::{
handler::{ResourceError, ToolError},
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
resource::Resource,
tool::{Tool, ToolCall},
@@ -493,6 +494,22 @@ impl Router for MemoryRouter {
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
Box::pin(async move { Ok("".to_string()) })
}
fn list_prompts(&self) -> Vec<Prompt> {
vec![]
}
fn get_prompt(
&self,
prompt_name: &str,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
let prompt_name = prompt_name.to_string();
Box::pin(async move {
Err(PromptError::NotFound(format!(
"Prompt {} not found",
prompt_name
)))
})
}
}
#[derive(Debug)]

View File

@@ -5,7 +5,8 @@ use serde_json::{json, Value};
use std::{future::Future, pin::Pin};
use mcp_core::{
handler::{ResourceError, ToolError},
handler::{PromptError, ResourceError, ToolError},
prompt::Prompt,
protocol::ServerCapabilities,
resource::Resource,
role::Role,
@@ -156,6 +157,23 @@ impl Router for TutorialRouter {
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
Box::pin(async move { Ok("".to_string()) })
}
fn list_prompts(&self) -> Vec<Prompt> {
vec![]
}
fn get_prompt(
&self,
prompt_name: &str,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
let prompt_name = prompt_name.to_string();
Box::pin(async move {
Err(PromptError::NotFound(format!(
"Prompt {} not found",
prompt_name
)))
})
}
}
impl Clone for TutorialRouter {

View File

@@ -1,3 +1,5 @@
use std::collections::HashMap;
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
@@ -6,6 +8,8 @@ use serde_json::Value;
use super::extension::{ExtensionConfig, ExtensionResult};
use crate::message::Message;
use crate::providers::base::ProviderUsage;
use mcp_core::prompt::Prompt;
use mcp_core::protocol::GetPromptResult;
/// Core trait defining the behavior of an Agent
#[async_trait]
@@ -37,4 +41,11 @@ pub trait Agent: Send + Sync {
/// Override the system prompt with custom text
async fn override_system_prompt(&mut self, template: String);
/// Lists all prompts from all extensions
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>>;
/// Get a prompt result with the given name and arguments
/// Returns the prompt text that would be used as user input
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult>;
}

View File

@@ -1,6 +1,8 @@
use anyhow::Result;
use chrono::{DateTime, TimeZone, Utc};
use futures::stream::{FuturesUnordered, StreamExt};
use mcp_client::McpService;
use mcp_core::protocol::GetPromptResult;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::sync::LazyLock;
@@ -13,7 +15,7 @@ use crate::prompt_template::{load_prompt, load_prompt_file};
use crate::providers::base::{Provider, ProviderUsage};
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult};
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
use serde_json::Value;
// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
@@ -544,6 +546,87 @@ impl Capabilities {
result
}
pub async fn list_prompts_from_extension(
&self,
extension_name: &str,
) -> Result<Vec<Prompt>, ToolError> {
let client = self.clients.get(extension_name).ok_or_else(|| {
ToolError::InvalidParameters(format!("Extension {} is not valid", extension_name))
})?;
let client_guard = client.lock().await;
client_guard
.list_prompts(None)
.await
.map_err(|e| {
ToolError::ExecutionError(format!(
"Unable to list prompts for {}, {:?}",
extension_name, e
))
})
.map(|lp| lp.prompts)
}
pub async fn list_prompts(&self) -> Result<HashMap<String, Vec<Prompt>>, ToolError> {
let mut futures = FuturesUnordered::new();
for extension_name in self.clients.keys() {
futures.push(async move {
(
extension_name,
self.list_prompts_from_extension(extension_name).await,
)
});
}
let mut all_prompts = HashMap::new();
let mut errors = Vec::new();
// Process results as they complete
while let Some(result) = futures.next().await {
let (name, prompts) = result;
match prompts {
Ok(content) => {
all_prompts.insert(name.to_string(), content);
}
Err(tool_error) => {
errors.push(tool_error);
}
}
}
// Log any errors that occurred
if !errors.is_empty() {
tracing::error!(
errors = ?errors
.into_iter()
.map(|e| format!("{:?}", e))
.collect::<Vec<_>>(),
"errors from listing prompts"
);
}
Ok(all_prompts)
}
pub async fn get_prompt(
&self,
extension_name: &str,
name: &str,
arguments: Value,
) -> Result<GetPromptResult> {
let client = self
.clients
.get(extension_name)
.ok_or_else(|| anyhow::anyhow!("Extension {} not found", extension_name))?;
let client_guard = client.lock().await;
client_guard
.get_prompt(name, arguments)
.await
.map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e))
}
}
#[cfg(test)]
@@ -556,7 +639,8 @@ mod tests {
use mcp_client::client::Error;
use mcp_client::client::McpClientTrait;
use mcp_core::protocol::{
CallToolResult, InitializeResult, ListResourcesResult, ListToolsResult, ReadResourceResult,
CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult,
ListToolsResult, ReadResourceResult,
};
use serde_json::json;
@@ -625,6 +709,21 @@ mod tests {
_ => Err(Error::NotInitialized),
}
}
async fn list_prompts(
&self,
_next_cursor: Option<String>,
) -> Result<ListPromptsResult, Error> {
Err(Error::NotInitialized)
}
async fn get_prompt(
&self,
_name: &str,
_arguments: Value,
) -> Result<GetPromptResult, Error> {
Err(Error::NotInitialized)
}
}
#[test]

View File

@@ -2,6 +2,7 @@
/// It makes no attempt to handle context limits, and cannot read resources
use async_trait::async_trait;
use futures::stream::BoxStream;
use std::collections::HashMap;
use tokio::sync::Mutex;
use tracing::{debug, instrument};
@@ -13,7 +14,10 @@ use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
use crate::register_agent;
use crate::token_counter::TokenCounter;
use anyhow::{anyhow, Result};
use indoc::indoc;
use mcp_core::prompt::Prompt;
use mcp_core::protocol::GetPromptResult;
use mcp_core::tool::Tool;
use serde_json::{json, Value};
@@ -198,6 +202,37 @@ impl Agent for ReferenceAgent {
let mut capabilities = self.capabilities.lock().await;
capabilities.set_system_prompt_override(template);
}
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
let capabilities = self.capabilities.lock().await;
capabilities
.list_prompts()
.await
.expect("Failed to list prompts")
}
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
let capabilities = self.capabilities.lock().await;
// First find which extension has this prompt
let prompts = capabilities
.list_prompts()
.await
.map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
if let Some(extension) = prompts
.iter()
.find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name))
.map(|(extension, _)| extension)
{
return capabilities
.get_prompt(extension, name, arguments)
.await
.map_err(|e| anyhow!("Failed to get prompt: {}", e));
}
Err(anyhow!("Prompt '{}' not found", name))
}
}
register_agent!("reference", ReferenceAgent);

View File

@@ -2,6 +2,7 @@
/// It makes no attempt to handle context limits, and cannot read resources
use async_trait::async_trait;
use futures::stream::BoxStream;
use std::collections::HashMap;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tracing::{debug, error, instrument, warn};
@@ -19,7 +20,10 @@ use crate::providers::errors::ProviderError;
use crate::register_agent;
use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation};
use anyhow::{anyhow, Result};
use indoc::indoc;
use mcp_core::prompt::Prompt;
use mcp_core::protocol::GetPromptResult;
use mcp_core::{tool::Tool, Content};
use serde_json::{json, Value};
@@ -398,6 +402,37 @@ impl Agent for TruncateAgent {
let mut capabilities = self.capabilities.lock().await;
capabilities.set_system_prompt_override(template);
}
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
let capabilities = self.capabilities.lock().await;
capabilities
.list_prompts()
.await
.expect("Failed to list prompts")
}
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
let capabilities = self.capabilities.lock().await;
// First find which extension has this prompt
let prompts = capabilities
.list_prompts()
.await
.map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
if let Some(extension) = prompts
.iter()
.find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name))
.map(|(extension, _)| extension)
{
return capabilities
.get_prompt(extension, name, arguments)
.await
.map_err(|e| anyhow!("Failed to get prompt: {}", e));
}
Err(anyhow!("Prompt '{}' not found", name))
}
}
register_agent!("truncate", TruncateAgent);

View File

@@ -10,6 +10,8 @@ use std::collections::HashSet;
use chrono::Utc;
use mcp_core::content::{Content, ImageContent, TextContent};
use mcp_core::handler::ToolResult;
use mcp_core::prompt::{PromptMessage, PromptMessageContent, PromptMessageRole};
use mcp_core::resource::ResourceContents;
use mcp_core::role::Role;
use mcp_core::tool::ToolCall;
use serde_json::Value;
@@ -156,6 +158,37 @@ impl From<Content> for MessageContent {
}
}
impl From<PromptMessage> for Message {
fn from(prompt_message: PromptMessage) -> Self {
// Create a new message with the appropriate role
let message = match prompt_message.role {
PromptMessageRole::User => Message::user(),
PromptMessageRole::Assistant => Message::assistant(),
};
// Convert and add the content
let content = match prompt_message.content {
PromptMessageContent::Text { text } => MessageContent::text(text),
PromptMessageContent::Image { image } => {
MessageContent::image(image.data, image.mime_type)
}
PromptMessageContent::Resource { resource } => {
// For resources, convert to text content with the resource text
match resource.resource {
ResourceContents::TextResourceContents { text, .. } => {
MessageContent::text(text)
}
ResourceContents::BlobResourceContents { blob, .. } => {
MessageContent::text(format!("[Binary content: {}]", blob))
}
}
}
};
message.with_content(content)
}
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
@@ -305,7 +338,10 @@ impl Message {
#[cfg(test)]
mod tests {
use super::*;
use mcp_core::content::EmbeddedResource;
use mcp_core::handler::ToolError;
use mcp_core::prompt::PromptMessageContent;
use mcp_core::resource::ResourceContents;
use serde_json::{json, Value};
#[test]
@@ -420,4 +456,158 @@ mod tests {
panic!("Expected ToolRequest content");
}
}
#[test]
fn test_from_prompt_message_text() {
let prompt_content = PromptMessageContent::Text {
text: "Hello, world!".to_string(),
};
let prompt_message = PromptMessage {
role: PromptMessageRole::User,
content: prompt_content,
};
let message = Message::from(prompt_message);
if let MessageContent::Text(text_content) = &message.content[0] {
assert_eq!(text_content.text, "Hello, world!");
} else {
panic!("Expected MessageContent::Text");
}
}
#[test]
fn test_from_prompt_message_image() {
let prompt_content = PromptMessageContent::Image {
image: ImageContent {
data: "base64data".to_string(),
mime_type: "image/jpeg".to_string(),
annotations: None,
},
};
let prompt_message = PromptMessage {
role: PromptMessageRole::User,
content: prompt_content,
};
let message = Message::from(prompt_message);
if let MessageContent::Image(image_content) = &message.content[0] {
assert_eq!(image_content.data, "base64data");
assert_eq!(image_content.mime_type, "image/jpeg");
} else {
panic!("Expected MessageContent::Image");
}
}
#[test]
fn test_from_prompt_message_text_resource() {
let resource = ResourceContents::TextResourceContents {
uri: "file:///test.txt".to_string(),
mime_type: Some("text/plain".to_string()),
text: "Resource content".to_string(),
};
let prompt_content = PromptMessageContent::Resource {
resource: EmbeddedResource {
resource,
annotations: None,
},
};
let prompt_message = PromptMessage {
role: PromptMessageRole::User,
content: prompt_content,
};
let message = Message::from(prompt_message);
if let MessageContent::Text(text_content) = &message.content[0] {
assert_eq!(text_content.text, "Resource content");
} else {
panic!("Expected MessageContent::Text");
}
}
#[test]
fn test_from_prompt_message_blob_resource() {
let resource = ResourceContents::BlobResourceContents {
uri: "file:///test.bin".to_string(),
mime_type: Some("application/octet-stream".to_string()),
blob: "binary_data".to_string(),
};
let prompt_content = PromptMessageContent::Resource {
resource: EmbeddedResource {
resource,
annotations: None,
},
};
let prompt_message = PromptMessage {
role: PromptMessageRole::User,
content: prompt_content,
};
let message = Message::from(prompt_message);
if let MessageContent::Text(text_content) = &message.content[0] {
assert_eq!(text_content.text, "[Binary content: binary_data]");
} else {
panic!("Expected MessageContent::Text");
}
}
#[test]
fn test_from_prompt_message() {
// Test user message conversion
let prompt_message = PromptMessage {
role: PromptMessageRole::User,
content: PromptMessageContent::Text {
text: "Hello, world!".to_string(),
},
};
let message = Message::from(prompt_message);
assert_eq!(message.role, Role::User);
assert_eq!(message.content.len(), 1);
assert_eq!(message.as_concat_text(), "Hello, world!");
// Test assistant message conversion
let prompt_message = PromptMessage {
role: PromptMessageRole::Assistant,
content: PromptMessageContent::Text {
text: "I can help with that.".to_string(),
},
};
let message = Message::from(prompt_message);
assert_eq!(message.role, Role::Assistant);
assert_eq!(message.content.len(), 1);
assert_eq!(message.as_concat_text(), "I can help with that.");
}
#[test]
fn test_message_with_text() {
let message = Message::user().with_text("Hello");
assert_eq!(message.as_concat_text(), "Hello");
}
#[test]
fn test_message_with_tool_request() {
let tool_call = Ok(ToolCall {
name: "test_tool".to_string(),
arguments: serde_json::json!({}),
});
let message = Message::assistant().with_tool_request("req1", tool_call);
assert!(message.is_tool_call());
assert!(!message.is_tool_response());
let ids = message.get_tool_ids();
assert_eq!(ids.len(), 1);
assert!(ids.contains("req1"));
}
}

View File

@@ -82,5 +82,16 @@ async fn main() -> Result<(), ClientError> {
let resource = client.read_resource("memo://insights").await?;
println!("Resource: {resource:?}\n");
let prompts = client.list_prompts(None).await?;
println!("Prompts: {prompts:?}\n");
let prompt = client
.get_prompt(
"example_prompt",
serde_json::json!({"message": "hello there!"}),
)
.await?;
println!("Prompt: {prompt:?}\n");
Ok(())
}

View File

@@ -1,7 +1,7 @@
use mcp_core::protocol::{
CallToolResult, Implementation, InitializeResult, JsonRpcError, JsonRpcMessage,
JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult,
ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -93,6 +93,10 @@ pub trait McpClientTrait: Send + Sync {
async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error>;
async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error>;
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
}
/// The MCP client is the interface for MCP operations.
@@ -346,4 +350,42 @@ where
// https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2
self.send_request("tools/call", params).await
}
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If prompts is not supported, return an error
if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'prompts' capability".to_string(),
});
}
let payload = next_cursor
.map(|cursor| serde_json::json!({"cursor": cursor}))
.unwrap_or_else(|| serde_json::json!({}));
self.send_request("prompts/list", payload).await
}
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
// If prompts is not supported, return an error
if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'prompts' capability".to_string(),
});
}
let params = serde_json::json!({ "name": name, "arguments": arguments });
self.send_request("prompts/get", params).await
}
}

View File

@@ -111,13 +111,23 @@ impl SseActor {
// Attempt to parse the SSE data as a JsonRpcMessage
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
Ok(message) => {
// If it's a response, complete the pending request
if let JsonRpcMessage::Response(resp) = &message {
if let Some(id) = &resp.id {
pending_requests.respond(&id.to_string(), Ok(message)).await;
match &message {
JsonRpcMessage::Response(response) => {
if let Some(id) = &response.id {
pending_requests
.respond(&id.to_string(), Ok(message))
.await;
}
}
JsonRpcMessage::Error(error) => {
if let Some(id) = &error.id {
pending_requests
.respond(&id.to_string(), Ok(message))
.await;
}
}
_ => {} // TODO: Handle other variants (Request, etc.)
}
// If it's something else (notification, etc.), handle as needed
}
Err(err) => {
warn!("Failed to parse SSE message: {err}");

View File

@@ -87,10 +87,18 @@ impl StdioActor {
"Received incoming message"
);
if let JsonRpcMessage::Response(response) = &message {
if let Some(id) = &response.id {
pending_requests.respond(&id.to_string(), Ok(message)).await;
match &message {
JsonRpcMessage::Response(response) => {
if let Some(id) = &response.id {
pending_requests.respond(&id.to_string(), Ok(message)).await;
}
}
JsonRpcMessage::Error(error) => {
if let Some(id) = &error.id {
pending_requests.respond(&id.to_string(), Ok(message)).await;
}
}
_ => {} // TODO: Handle other variants (Request, etc.)
}
}
line.clear();

View File

@@ -10,22 +10,28 @@ use serde::{Deserialize, Serialize};
pub struct Prompt {
/// The name of the prompt
pub name: String,
/// A description of what the prompt does
pub description: String,
/// The arguments that can be passed to customize the prompt
pub arguments: Vec<PromptArgument>,
/// Optional description of what the prompt does
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Optional arguments that can be passed to customize the prompt
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<Vec<PromptArgument>>,
}
impl Prompt {
/// Create a new prompt with the given name, description and arguments
pub fn new<N, D>(name: N, description: D, arguments: Vec<PromptArgument>) -> Self
pub fn new<N, D>(
name: N,
description: Option<D>,
arguments: Option<Vec<PromptArgument>>,
) -> Self
where
N: Into<String>,
D: Into<String>,
{
Prompt {
name: name.into(),
description: description.into(),
description: description.map(Into::into),
arguments,
}
}
@@ -37,9 +43,11 @@ pub struct PromptArgument {
/// The name of the argument
pub name: String,
/// A description of what the argument is used for
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Whether this argument is required
pub required: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
}
/// Represents the role of a message sender in a prompt conversation
@@ -151,6 +159,6 @@ pub struct PromptTemplate {
#[derive(Debug, Serialize, Deserialize)]
pub struct PromptArgumentTemplate {
pub name: String,
pub description: String,
pub required: bool,
pub description: Option<String>,
pub required: Option<bool>,
}

View File

@@ -1,6 +1,7 @@
use anyhow::Result;
use mcp_core::content::Content;
use mcp_core::handler::ResourceError;
use mcp_core::handler::{PromptError, ResourceError};
use mcp_core::prompt::{Prompt, PromptArgument};
use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool};
use mcp_server::router::{CapabilitiesBuilder, RouterService};
use mcp_server::{ByteTransport, Router, Server};
@@ -61,6 +62,7 @@ impl Router for CounterRouter {
CapabilitiesBuilder::new()
.with_tools(false)
.with_resources(false, false)
.with_prompts(false)
.build()
}
@@ -153,6 +155,37 @@ impl Router for CounterRouter {
}
})
}
fn list_prompts(&self) -> Vec<Prompt> {
vec![Prompt::new(
"example_prompt",
Some("This is an example prompt that takes one required agrument, message"),
Some(vec![PromptArgument {
name: "message".to_string(),
description: Some("A message to put in the prompt".to_string()),
required: Some(true),
}]),
)]
}
fn get_prompt(
&self,
prompt_name: &str,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
let prompt_name = prompt_name.to_string();
Box::pin(async move {
match prompt_name.as_str() {
"example_prompt" => {
let prompt = "This is an example prompt with your message here: '{message}'";
Ok(prompt.to_string())
}
_ => Err(PromptError::NotFound(format!(
"Prompt {} not found",
prompt_name
))),
}
})
}
}
#[tokio::main]

View File

@@ -97,12 +97,8 @@ pub trait Router: Send + Sync + 'static {
&self,
uri: &str,
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>>;
fn list_prompts(&self) -> Option<Vec<Prompt>> {
None
}
fn get_prompt(&self, _prompt_name: &str) -> Option<PromptFuture> {
None
}
fn list_prompts(&self) -> Vec<Prompt>;
fn get_prompt(&self, prompt_name: &str) -> PromptFuture;
// Helper method to create base response
fn create_response(&self, id: Option<u64>) -> JsonRpcResponse {
@@ -257,7 +253,7 @@ pub trait Router: Send + Sync + 'static {
req: JsonRpcRequest,
) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
async move {
let prompts = self.list_prompts().unwrap_or_default();
let prompts = self.list_prompts();
let result = ListPromptsResult { prompts };
@@ -294,36 +290,36 @@ pub trait Router: Send + Sync + 'static {
.ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?;
// Fetch the prompt definition first
let prompt = match self.list_prompts() {
Some(prompts) => prompts
.into_iter()
.find(|p| p.name == prompt_name)
.ok_or_else(|| {
RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name))
})?,
None => return Err(RouterError::PromptNotFound("No prompts available".into())),
};
let prompt = self
.list_prompts()
.into_iter()
.find(|p| p.name == prompt_name)
.ok_or_else(|| {
RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name))
})?;
// Validate required arguments
for arg in &prompt.arguments {
if arg.required
&& (!arguments.contains_key(&arg.name)
|| arguments
.get(&arg.name)
.and_then(Value::as_str)
.is_none_or(str::is_empty))
{
return Err(RouterError::InvalidParams(format!(
"Missing required argument: '{}'",
arg.name
)));
if let Some(args) = &prompt.arguments {
for arg in args {
if arg.required.is_some()
&& arg.required.unwrap()
&& (!arguments.contains_key(&arg.name)
|| arguments
.get(&arg.name)
.and_then(Value::as_str)
.is_none_or(str::is_empty))
{
return Err(RouterError::InvalidParams(format!(
"Missing required argument: '{}'",
arg.name
)));
}
}
}
// Now get the prompt content
let description = self
.get_prompt(prompt_name)
.ok_or_else(|| RouterError::PromptNotFound("Prompt not found".into()))?
.await
.map_err(|e| RouterError::Internal(e.to_string()))?;