mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 06:34:26 +01:00
feat(cli): add mcp prompt support via slash commands (#1323)
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2203,6 +2203,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"shlex",
|
||||
"temp-env",
|
||||
"tempfile",
|
||||
"test-case",
|
||||
|
||||
@@ -47,6 +47,7 @@ chrono = "0.4"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] }
|
||||
tracing-appender = "0.2"
|
||||
once_cell = "1.20.2"
|
||||
shlex = "1.3.0"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
winapi = { version = "0.3", features = ["wincred"] }
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use rustyline::Editor;
|
||||
use shlex;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum InputResult {
|
||||
@@ -9,6 +11,15 @@ pub enum InputResult {
|
||||
AddBuiltin(String),
|
||||
ToggleTheme,
|
||||
Retry,
|
||||
ListPrompts,
|
||||
PromptCommand(PromptCommandOptions),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PromptCommandOptions {
|
||||
pub name: String,
|
||||
pub info: bool,
|
||||
pub arguments: HashMap<String, String>,
|
||||
}
|
||||
|
||||
pub fn get_input(
|
||||
@@ -59,12 +70,67 @@ fn handle_slash_command(input: &str) -> Option<InputResult> {
|
||||
Some(InputResult::Retry)
|
||||
}
|
||||
"/t" => Some(InputResult::ToggleTheme),
|
||||
"/prompts" => Some(InputResult::ListPrompts),
|
||||
s if s.starts_with("/prompt") => {
|
||||
if s == "/prompt" {
|
||||
// No arguments case
|
||||
Some(InputResult::PromptCommand(PromptCommandOptions {
|
||||
name: String::new(), // Empty name will trigger the error message in the rendering
|
||||
info: false,
|
||||
arguments: HashMap::new(),
|
||||
}))
|
||||
} else if let Some(stripped) = s.strip_prefix("/prompt ") {
|
||||
// Has arguments case
|
||||
parse_prompt_command(stripped)
|
||||
} else {
|
||||
// Handle invalid cases like "/promptxyz"
|
||||
None
|
||||
}
|
||||
}
|
||||
s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())),
|
||||
s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_prompt_command(args: &str) -> Option<InputResult> {
|
||||
let parts: Vec<String> = shlex::split(args).unwrap_or_default();
|
||||
|
||||
// set name to empty and error out in the rendering
|
||||
let mut options = PromptCommandOptions {
|
||||
name: parts.first().cloned().unwrap_or_default(),
|
||||
info: false,
|
||||
arguments: HashMap::new(),
|
||||
};
|
||||
|
||||
// handle info at any point in the command
|
||||
if parts.iter().any(|part| part == "--info") {
|
||||
options.info = true;
|
||||
}
|
||||
|
||||
// Parse remaining arguments
|
||||
let mut i = 1;
|
||||
|
||||
while i < parts.len() {
|
||||
let part = &parts[i];
|
||||
|
||||
// Skip flag arguments
|
||||
if part == "--info" {
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Process key=value pairs - removed redundant contains check
|
||||
if let Some((key, value)) = part.split_once('=') {
|
||||
options.arguments.insert(key.to_string(), value.to_string());
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
Some(InputResult::PromptCommand(options))
|
||||
}
|
||||
|
||||
fn print_help() {
|
||||
println!(
|
||||
"Available commands:
|
||||
@@ -72,6 +138,8 @@ fn print_help() {
|
||||
/t - Toggle Light/Dark/Ansi theme
|
||||
/extension <command> - Add a stdio extension (format: ENV1=val1 command args...)
|
||||
/builtin <names> - Add builtin extensions by name (comma-separated)
|
||||
/prompts - List all available prompts by name
|
||||
/prompt <name> [--info] [key=value...] - Get prompt info or execute a prompt
|
||||
/? or /help - Display this help message
|
||||
|
||||
Navigation:
|
||||
@@ -131,6 +199,33 @@ mod tests {
|
||||
assert!(handle_slash_command("/unknown").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prompt_command() {
|
||||
// Test basic prompt info command
|
||||
if let Some(InputResult::PromptCommand(opts)) =
|
||||
handle_slash_command("/prompt test-prompt --info")
|
||||
{
|
||||
assert_eq!(opts.name, "test-prompt");
|
||||
assert!(opts.info);
|
||||
assert!(opts.arguments.is_empty());
|
||||
} else {
|
||||
panic!("Expected PromptCommand");
|
||||
}
|
||||
|
||||
// Test prompt with arguments
|
||||
if let Some(InputResult::PromptCommand(opts)) =
|
||||
handle_slash_command("/prompt test-prompt arg1=val1 arg2=val2")
|
||||
{
|
||||
assert_eq!(opts.name, "test-prompt");
|
||||
assert!(!opts.info);
|
||||
assert_eq!(opts.arguments.len(), 2);
|
||||
assert_eq!(opts.arguments.get("arg1"), Some(&"val1".to_string()));
|
||||
assert_eq!(opts.arguments.get("arg2"), Some(&"val2".to_string()));
|
||||
} else {
|
||||
panic!("Expected PromptCommand");
|
||||
}
|
||||
}
|
||||
|
||||
// Test whitespace handling
|
||||
#[test]
|
||||
fn test_whitespace_handling() {
|
||||
@@ -149,4 +244,73 @@ mod tests {
|
||||
panic!("Expected AddBuiltin");
|
||||
}
|
||||
}
|
||||
|
||||
// Test prompt with no arguments
|
||||
#[test]
|
||||
fn test_prompt_no_args() {
|
||||
// Test just "/prompt" with no arguments
|
||||
if let Some(InputResult::PromptCommand(opts)) = handle_slash_command("/prompt") {
|
||||
assert_eq!(opts.name, "");
|
||||
assert!(!opts.info);
|
||||
assert!(opts.arguments.is_empty());
|
||||
} else {
|
||||
panic!("Expected PromptCommand");
|
||||
}
|
||||
|
||||
// Test invalid prompt command
|
||||
assert!(handle_slash_command("/promptxyz").is_none());
|
||||
}
|
||||
|
||||
// Test quoted arguments
|
||||
#[test]
|
||||
fn test_quoted_arguments() {
|
||||
// Test prompt with quoted arguments
|
||||
if let Some(InputResult::PromptCommand(opts)) = handle_slash_command(
|
||||
r#"/prompt test-prompt arg1="value with spaces" arg2="another value""#,
|
||||
) {
|
||||
assert_eq!(opts.name, "test-prompt");
|
||||
assert_eq!(opts.arguments.len(), 2);
|
||||
assert_eq!(
|
||||
opts.arguments.get("arg1"),
|
||||
Some(&"value with spaces".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
opts.arguments.get("arg2"),
|
||||
Some(&"another value".to_string())
|
||||
);
|
||||
} else {
|
||||
panic!("Expected PromptCommand");
|
||||
}
|
||||
|
||||
// Test prompt with mixed quoted and unquoted arguments
|
||||
if let Some(InputResult::PromptCommand(opts)) = handle_slash_command(
|
||||
r#"/prompt test-prompt simple=value quoted="value with \"nested\" quotes""#,
|
||||
) {
|
||||
assert_eq!(opts.name, "test-prompt");
|
||||
assert_eq!(opts.arguments.len(), 2);
|
||||
assert_eq!(opts.arguments.get("simple"), Some(&"value".to_string()));
|
||||
assert_eq!(
|
||||
opts.arguments.get("quoted"),
|
||||
Some(&r#"value with "nested" quotes"#.to_string())
|
||||
);
|
||||
} else {
|
||||
panic!("Expected PromptCommand");
|
||||
}
|
||||
}
|
||||
|
||||
// Test invalid arguments
|
||||
#[test]
|
||||
fn test_invalid_arguments() {
|
||||
// Test prompt with invalid arguments
|
||||
if let Some(InputResult::PromptCommand(opts)) =
|
||||
handle_slash_command(r#"/prompt test-prompt valid=value invalid_arg another_invalid"#)
|
||||
{
|
||||
assert_eq!(opts.name, "test-prompt");
|
||||
assert_eq!(opts.arguments.len(), 1);
|
||||
assert_eq!(opts.arguments.get("valid"), Some(&"value".to_string()));
|
||||
// Invalid arguments are ignored but logged
|
||||
} else {
|
||||
panic!("Expected PromptCommand");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,11 @@ use goose::agents::extension::{Envs, ExtensionConfig};
|
||||
use goose::agents::Agent;
|
||||
use goose::message::{Message, MessageContent};
|
||||
use mcp_core::handler::ToolError;
|
||||
use mcp_core::prompt::PromptMessage;
|
||||
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use tokio;
|
||||
|
||||
@@ -104,6 +108,40 @@ impl Session {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_prompts(&mut self) -> HashMap<String, Vec<String>> {
|
||||
let prompts = self.agent.list_extension_prompts().await;
|
||||
prompts
|
||||
.into_iter()
|
||||
.map(|(extension, prompt_list)| {
|
||||
let names = prompt_list.into_iter().map(|p| p.name).collect();
|
||||
(extension, names)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn get_prompt_info(&mut self, name: &str) -> Result<Option<output::PromptInfo>> {
|
||||
let prompts = self.agent.list_extension_prompts().await;
|
||||
|
||||
// Find which extension has this prompt
|
||||
for (extension, prompt_list) in prompts {
|
||||
if let Some(prompt) = prompt_list.iter().find(|p| p.name == name) {
|
||||
return Ok(Some(output::PromptInfo {
|
||||
name: prompt.name.clone(),
|
||||
description: prompt.description.clone(),
|
||||
arguments: prompt.arguments.clone(),
|
||||
extension: Some(extension),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn get_prompt(&mut self, name: &str, arguments: Value) -> Result<Vec<PromptMessage>> {
|
||||
let result = self.agent.get_prompt(name, arguments).await?;
|
||||
Ok(result.messages)
|
||||
}
|
||||
|
||||
/// Process a single message and get the response
|
||||
async fn process_message(&mut self, message: String) -> Result<()> {
|
||||
self.messages.push(Message::user().with_text(&message));
|
||||
@@ -179,6 +217,68 @@ impl Session {
|
||||
continue;
|
||||
}
|
||||
input::InputResult::Retry => continue,
|
||||
input::InputResult::ListPrompts => {
|
||||
output::render_prompts(&self.list_prompts().await)
|
||||
}
|
||||
input::InputResult::PromptCommand(opts) => {
|
||||
// name is required
|
||||
if opts.name.is_empty() {
|
||||
output::render_error("Prompt name argument is required");
|
||||
continue;
|
||||
}
|
||||
|
||||
if opts.info {
|
||||
match self.get_prompt_info(&opts.name).await? {
|
||||
Some(info) => output::render_prompt_info(&info),
|
||||
None => {
|
||||
output::render_error(&format!("Prompt '{}' not found", opts.name))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Convert the arguments HashMap to a Value
|
||||
let arguments = serde_json::to_value(opts.arguments)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to serialize arguments: {}", e))?;
|
||||
|
||||
match self.get_prompt(&opts.name, arguments).await {
|
||||
Ok(messages) => {
|
||||
let start_len = self.messages.len();
|
||||
let mut valid = true;
|
||||
for (i, prompt_message) in messages.into_iter().enumerate() {
|
||||
let msg = Message::from(prompt_message);
|
||||
// ensure we get a User - Assistant - User type pattern
|
||||
let expected_role = if i % 2 == 0 {
|
||||
mcp_core::Role::User
|
||||
} else {
|
||||
mcp_core::Role::Assistant
|
||||
};
|
||||
|
||||
if msg.role != expected_role {
|
||||
output::render_error(&format!(
|
||||
"Expected {:?} message at position {}, but found {:?}",
|
||||
expected_role, i, msg.role
|
||||
));
|
||||
valid = false;
|
||||
// get rid of everything we added to messages
|
||||
self.messages.truncate(start_len);
|
||||
break;
|
||||
}
|
||||
|
||||
if msg.role == mcp_core::Role::User {
|
||||
output::render_message(&msg);
|
||||
}
|
||||
self.messages.push(msg);
|
||||
}
|
||||
|
||||
if valid {
|
||||
output::show_thinking();
|
||||
self.process_agent_response(true).await?;
|
||||
output::hide_thinking();
|
||||
}
|
||||
}
|
||||
Err(e) => output::render_error(&e.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,11 @@ use bat::WrappingMode;
|
||||
use console::style;
|
||||
use goose::config::Config;
|
||||
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
|
||||
use mcp_core::prompt::PromptArgument;
|
||||
use mcp_core::tool::ToolCall;
|
||||
use serde_json::Value;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
// Re-export theme for use in main
|
||||
@@ -73,6 +75,14 @@ impl ThinkingIndicator {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PromptInfo {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub arguments: Option<Vec<PromptArgument>>,
|
||||
pub extension: Option<String>,
|
||||
}
|
||||
|
||||
// Global thinking indicator
|
||||
thread_local! {
|
||||
static THINKING: RefCell<ThinkingIndicator> = RefCell::new(ThinkingIndicator::default());
|
||||
@@ -154,6 +164,51 @@ pub fn render_error(message: &str) {
|
||||
println!("\n {} {}\n", style("error:").red().bold(), message);
|
||||
}
|
||||
|
||||
pub fn render_prompts(prompts: &HashMap<String, Vec<String>>) {
|
||||
println!();
|
||||
for (extension, prompts) in prompts {
|
||||
println!(" {}", style(extension).green());
|
||||
for prompt in prompts {
|
||||
println!(" - {}", style(prompt).cyan());
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
pub fn render_prompt_info(info: &PromptInfo) {
|
||||
println!();
|
||||
|
||||
if let Some(ext) = &info.extension {
|
||||
println!(" {}: {}", style("Extension").green(), ext);
|
||||
}
|
||||
|
||||
println!(" Prompt: {}", style(&info.name).cyan().bold());
|
||||
|
||||
if let Some(desc) = &info.description {
|
||||
println!("\n {}", desc);
|
||||
}
|
||||
|
||||
if let Some(args) = &info.arguments {
|
||||
println!("\n Arguments:");
|
||||
for arg in args {
|
||||
let required = arg.required.unwrap_or(false);
|
||||
let req_str = if required {
|
||||
style("(required)").red()
|
||||
} else {
|
||||
style("(optional)").dim()
|
||||
};
|
||||
|
||||
println!(
|
||||
" {} {} {}",
|
||||
style(&arg.name).yellow(),
|
||||
req_str,
|
||||
arg.description.as_deref().unwrap_or("")
|
||||
);
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
pub fn render_extension_success(name: &str) {
|
||||
println!();
|
||||
println!(
|
||||
|
||||
@@ -9,7 +9,8 @@ use std::{
|
||||
use tokio::process::Command;
|
||||
|
||||
use mcp_core::{
|
||||
handler::{ResourceError, ToolError},
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
resource::Resource,
|
||||
tool::Tool,
|
||||
@@ -819,4 +820,21 @@ impl Router for ComputerControllerRouter {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn list_prompts(&self) -> Vec<Prompt> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn get_prompt(
|
||||
&self,
|
||||
prompt_name: &str,
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
|
||||
let prompt_name = prompt_name.to_string();
|
||||
Box::pin(async move {
|
||||
Err(PromptError::NotFound(format!(
|
||||
"Prompt {} not found",
|
||||
prompt_name
|
||||
)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,9 +72,9 @@ pub fn load_prompt_files() -> HashMap<String, Prompt> {
|
||||
description: arg.description,
|
||||
required: arg.required,
|
||||
})
|
||||
.collect();
|
||||
.collect::<Vec<PromptArgument>>();
|
||||
|
||||
let prompt = Prompt::new(&template.id, &template.template, arguments);
|
||||
let prompt = Prompt::new(&template.id, Some(&template.template), Some(arguments));
|
||||
|
||||
if prompts.contains_key(&prompt.name) {
|
||||
eprintln!("Duplicate prompt name '{}' found. Skipping.", prompt.name);
|
||||
@@ -905,47 +905,35 @@ impl Router for DeveloperRouter {
|
||||
Box::pin(async move { Ok("".to_string()) })
|
||||
}
|
||||
|
||||
fn list_prompts(&self) -> Option<Vec<Prompt>> {
|
||||
if self.prompts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.prompts.values().cloned().collect())
|
||||
}
|
||||
fn list_prompts(&self) -> Vec<Prompt> {
|
||||
self.prompts.values().cloned().collect()
|
||||
}
|
||||
|
||||
fn get_prompt(
|
||||
&self,
|
||||
prompt_name: &str,
|
||||
) -> Option<Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>>> {
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
|
||||
let prompt_name = prompt_name.trim().to_owned();
|
||||
|
||||
// Validate prompt name is not empty
|
||||
if prompt_name.is_empty() {
|
||||
return Some(Box::pin(async move {
|
||||
return Box::pin(async move {
|
||||
Err(PromptError::InvalidParameters(
|
||||
"Prompt name cannot be empty".to_string(),
|
||||
))
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
let prompts = Arc::clone(&self.prompts);
|
||||
|
||||
Some(Box::pin(async move {
|
||||
Box::pin(async move {
|
||||
match prompts.get(&prompt_name) {
|
||||
Some(prompt) => {
|
||||
if prompt.description.trim().is_empty() {
|
||||
Err(PromptError::InternalError(format!(
|
||||
"Prompt '{prompt_name}' has an empty description"
|
||||
)))
|
||||
} else {
|
||||
Ok(prompt.description.clone())
|
||||
}
|
||||
}
|
||||
Some(prompt) => Ok(prompt.description.clone().unwrap_or_default()),
|
||||
None => Err(PromptError::NotFound(format!(
|
||||
"Prompt '{prompt_name}' not found"
|
||||
))),
|
||||
}
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ use serde_json::{json, Value};
|
||||
use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin};
|
||||
|
||||
use mcp_core::{
|
||||
handler::{ResourceError, ToolError},
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
resource::Resource,
|
||||
tool::Tool,
|
||||
@@ -618,6 +619,23 @@ impl Router for GoogleDriveRouter {
|
||||
let uri_clone = uri.to_string();
|
||||
Box::pin(async move { this.read_google_resource(uri_clone).await })
|
||||
}
|
||||
|
||||
fn list_prompts(&self) -> Vec<Prompt> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn get_prompt(
|
||||
&self,
|
||||
prompt_name: &str,
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
|
||||
let prompt_name = prompt_name.to_string();
|
||||
Box::pin(async move {
|
||||
Err(PromptError::NotFound(format!(
|
||||
"Prompt {} not found",
|
||||
prompt_name
|
||||
)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for GoogleDriveRouter {
|
||||
|
||||
@@ -3,7 +3,8 @@ mod proxy;
|
||||
use anyhow::Result;
|
||||
use mcp_core::{
|
||||
content::Content,
|
||||
handler::{ResourceError, ToolError},
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
resource::Resource,
|
||||
role::Role,
|
||||
@@ -176,6 +177,23 @@ impl Router for JetBrainsRouter {
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
|
||||
Box::pin(async { Err(ResourceError::NotFound("Resource not found".into())) })
|
||||
}
|
||||
|
||||
fn list_prompts(&self) -> Vec<Prompt> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn get_prompt(
|
||||
&self,
|
||||
prompt_name: &str,
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
|
||||
let prompt_name = prompt_name.to_string();
|
||||
Box::pin(async move {
|
||||
Err(PromptError::NotFound(format!(
|
||||
"Prompt {} not found",
|
||||
prompt_name
|
||||
)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for JetBrainsRouter {
|
||||
|
||||
@@ -12,7 +12,8 @@ use std::{
|
||||
};
|
||||
|
||||
use mcp_core::{
|
||||
handler::{ResourceError, ToolError},
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
resource::Resource,
|
||||
tool::{Tool, ToolCall},
|
||||
@@ -493,6 +494,22 @@ impl Router for MemoryRouter {
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
|
||||
Box::pin(async move { Ok("".to_string()) })
|
||||
}
|
||||
fn list_prompts(&self) -> Vec<Prompt> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn get_prompt(
|
||||
&self,
|
||||
prompt_name: &str,
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
|
||||
let prompt_name = prompt_name.to_string();
|
||||
Box::pin(async move {
|
||||
Err(PromptError::NotFound(format!(
|
||||
"Prompt {} not found",
|
||||
prompt_name
|
||||
)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -5,7 +5,8 @@ use serde_json::{json, Value};
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
use mcp_core::{
|
||||
handler::{ResourceError, ToolError},
|
||||
handler::{PromptError, ResourceError, ToolError},
|
||||
prompt::Prompt,
|
||||
protocol::ServerCapabilities,
|
||||
resource::Resource,
|
||||
role::Role,
|
||||
@@ -156,6 +157,23 @@ impl Router for TutorialRouter {
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
|
||||
Box::pin(async move { Ok("".to_string()) })
|
||||
}
|
||||
|
||||
fn list_prompts(&self) -> Vec<Prompt> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn get_prompt(
|
||||
&self,
|
||||
prompt_name: &str,
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
|
||||
let prompt_name = prompt_name.to_string();
|
||||
Box::pin(async move {
|
||||
Err(PromptError::NotFound(format!(
|
||||
"Prompt {} not found",
|
||||
prompt_name
|
||||
)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TutorialRouter {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
@@ -6,6 +8,8 @@ use serde_json::Value;
|
||||
use super::extension::{ExtensionConfig, ExtensionResult};
|
||||
use crate::message::Message;
|
||||
use crate::providers::base::ProviderUsage;
|
||||
use mcp_core::prompt::Prompt;
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
|
||||
/// Core trait defining the behavior of an Agent
|
||||
#[async_trait]
|
||||
@@ -37,4 +41,11 @@ pub trait Agent: Send + Sync {
|
||||
|
||||
/// Override the system prompt with custom text
|
||||
async fn override_system_prompt(&mut self, template: String);
|
||||
|
||||
/// Lists all prompts from all extensions
|
||||
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>>;
|
||||
|
||||
/// Get a prompt result with the given name and arguments
|
||||
/// Returns the prompt text that would be used as user input
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult>;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, TimeZone, Utc};
|
||||
use futures::stream::{FuturesUnordered, StreamExt};
|
||||
use mcp_client::McpService;
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use std::sync::LazyLock;
|
||||
@@ -13,7 +15,7 @@ use crate::prompt_template::{load_prompt, load_prompt_file};
|
||||
use crate::providers::base::{Provider, ProviderUsage};
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
|
||||
use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult};
|
||||
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
|
||||
use serde_json::Value;
|
||||
|
||||
// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
|
||||
@@ -544,6 +546,87 @@ impl Capabilities {
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn list_prompts_from_extension(
|
||||
&self,
|
||||
extension_name: &str,
|
||||
) -> Result<Vec<Prompt>, ToolError> {
|
||||
let client = self.clients.get(extension_name).ok_or_else(|| {
|
||||
ToolError::InvalidParameters(format!("Extension {} is not valid", extension_name))
|
||||
})?;
|
||||
|
||||
let client_guard = client.lock().await;
|
||||
client_guard
|
||||
.list_prompts(None)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
ToolError::ExecutionError(format!(
|
||||
"Unable to list prompts for {}, {:?}",
|
||||
extension_name, e
|
||||
))
|
||||
})
|
||||
.map(|lp| lp.prompts)
|
||||
}
|
||||
|
||||
pub async fn list_prompts(&self) -> Result<HashMap<String, Vec<Prompt>>, ToolError> {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
for extension_name in self.clients.keys() {
|
||||
futures.push(async move {
|
||||
(
|
||||
extension_name,
|
||||
self.list_prompts_from_extension(extension_name).await,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
let mut all_prompts = HashMap::new();
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Process results as they complete
|
||||
while let Some(result) = futures.next().await {
|
||||
let (name, prompts) = result;
|
||||
match prompts {
|
||||
Ok(content) => {
|
||||
all_prompts.insert(name.to_string(), content);
|
||||
}
|
||||
Err(tool_error) => {
|
||||
errors.push(tool_error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Log any errors that occurred
|
||||
if !errors.is_empty() {
|
||||
tracing::error!(
|
||||
errors = ?errors
|
||||
.into_iter()
|
||||
.map(|e| format!("{:?}", e))
|
||||
.collect::<Vec<_>>(),
|
||||
"errors from listing prompts"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(all_prompts)
|
||||
}
|
||||
|
||||
pub async fn get_prompt(
|
||||
&self,
|
||||
extension_name: &str,
|
||||
name: &str,
|
||||
arguments: Value,
|
||||
) -> Result<GetPromptResult> {
|
||||
let client = self
|
||||
.clients
|
||||
.get(extension_name)
|
||||
.ok_or_else(|| anyhow::anyhow!("Extension {} not found", extension_name))?;
|
||||
|
||||
let client_guard = client.lock().await;
|
||||
client_guard
|
||||
.get_prompt(name, arguments)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -556,7 +639,8 @@ mod tests {
|
||||
use mcp_client::client::Error;
|
||||
use mcp_client::client::McpClientTrait;
|
||||
use mcp_core::protocol::{
|
||||
CallToolResult, InitializeResult, ListResourcesResult, ListToolsResult, ReadResourceResult,
|
||||
CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult,
|
||||
ListToolsResult, ReadResourceResult,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
@@ -625,6 +709,21 @@ mod tests {
|
||||
_ => Err(Error::NotInitialized),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_prompts(
|
||||
&self,
|
||||
_next_cursor: Option<String>,
|
||||
) -> Result<ListPromptsResult, Error> {
|
||||
Err(Error::NotInitialized)
|
||||
}
|
||||
|
||||
async fn get_prompt(
|
||||
&self,
|
||||
_name: &str,
|
||||
_arguments: Value,
|
||||
) -> Result<GetPromptResult, Error> {
|
||||
Err(Error::NotInitialized)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
/// It makes no attempt to handle context limits, and cannot read resources
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
@@ -13,7 +14,10 @@ use crate::providers::base::Provider;
|
||||
use crate::providers::base::ProviderUsage;
|
||||
use crate::register_agent;
|
||||
use crate::token_counter::TokenCounter;
|
||||
use anyhow::{anyhow, Result};
|
||||
use indoc::indoc;
|
||||
use mcp_core::prompt::Prompt;
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
use mcp_core::tool::Tool;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
@@ -198,6 +202,37 @@ impl Agent for ReferenceAgent {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities.set_system_prompt_override(template);
|
||||
}
|
||||
|
||||
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
capabilities
|
||||
.list_prompts()
|
||||
.await
|
||||
.expect("Failed to list prompts")
|
||||
}
|
||||
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
|
||||
// First find which extension has this prompt
|
||||
let prompts = capabilities
|
||||
.list_prompts()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
|
||||
|
||||
if let Some(extension) = prompts
|
||||
.iter()
|
||||
.find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name))
|
||||
.map(|(extension, _)| extension)
|
||||
{
|
||||
return capabilities
|
||||
.get_prompt(extension, name, arguments)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get prompt: {}", e));
|
||||
}
|
||||
|
||||
Err(anyhow!("Prompt '{}' not found", name))
|
||||
}
|
||||
}
|
||||
|
||||
register_agent!("reference", ReferenceAgent);
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
/// It makes no attempt to handle context limits, and cannot read resources
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, error, instrument, warn};
|
||||
@@ -19,7 +20,10 @@ use crate::providers::errors::ProviderError;
|
||||
use crate::register_agent;
|
||||
use crate::token_counter::TokenCounter;
|
||||
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
||||
use anyhow::{anyhow, Result};
|
||||
use indoc::indoc;
|
||||
use mcp_core::prompt::Prompt;
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
use mcp_core::{tool::Tool, Content};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
@@ -398,6 +402,37 @@ impl Agent for TruncateAgent {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities.set_system_prompt_override(template);
|
||||
}
|
||||
|
||||
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
capabilities
|
||||
.list_prompts()
|
||||
.await
|
||||
.expect("Failed to list prompts")
|
||||
}
|
||||
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
|
||||
// First find which extension has this prompt
|
||||
let prompts = capabilities
|
||||
.list_prompts()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
|
||||
|
||||
if let Some(extension) = prompts
|
||||
.iter()
|
||||
.find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name))
|
||||
.map(|(extension, _)| extension)
|
||||
{
|
||||
return capabilities
|
||||
.get_prompt(extension, name, arguments)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get prompt: {}", e));
|
||||
}
|
||||
|
||||
Err(anyhow!("Prompt '{}' not found", name))
|
||||
}
|
||||
}
|
||||
|
||||
register_agent!("truncate", TruncateAgent);
|
||||
|
||||
@@ -10,6 +10,8 @@ use std::collections::HashSet;
|
||||
use chrono::Utc;
|
||||
use mcp_core::content::{Content, ImageContent, TextContent};
|
||||
use mcp_core::handler::ToolResult;
|
||||
use mcp_core::prompt::{PromptMessage, PromptMessageContent, PromptMessageRole};
|
||||
use mcp_core::resource::ResourceContents;
|
||||
use mcp_core::role::Role;
|
||||
use mcp_core::tool::ToolCall;
|
||||
use serde_json::Value;
|
||||
@@ -156,6 +158,37 @@ impl From<Content> for MessageContent {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PromptMessage> for Message {
|
||||
fn from(prompt_message: PromptMessage) -> Self {
|
||||
// Create a new message with the appropriate role
|
||||
let message = match prompt_message.role {
|
||||
PromptMessageRole::User => Message::user(),
|
||||
PromptMessageRole::Assistant => Message::assistant(),
|
||||
};
|
||||
|
||||
// Convert and add the content
|
||||
let content = match prompt_message.content {
|
||||
PromptMessageContent::Text { text } => MessageContent::text(text),
|
||||
PromptMessageContent::Image { image } => {
|
||||
MessageContent::image(image.data, image.mime_type)
|
||||
}
|
||||
PromptMessageContent::Resource { resource } => {
|
||||
// For resources, convert to text content with the resource text
|
||||
match resource.resource {
|
||||
ResourceContents::TextResourceContents { text, .. } => {
|
||||
MessageContent::text(text)
|
||||
}
|
||||
ResourceContents::BlobResourceContents { blob, .. } => {
|
||||
MessageContent::text(format!("[Binary content: {}]", blob))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
message.with_content(content)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
/// A message to or from an LLM
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -305,7 +338,10 @@ impl Message {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use mcp_core::content::EmbeddedResource;
|
||||
use mcp_core::handler::ToolError;
|
||||
use mcp_core::prompt::PromptMessageContent;
|
||||
use mcp_core::resource::ResourceContents;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
#[test]
|
||||
@@ -420,4 +456,158 @@ mod tests {
|
||||
panic!("Expected ToolRequest content");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_prompt_message_text() {
|
||||
let prompt_content = PromptMessageContent::Text {
|
||||
text: "Hello, world!".to_string(),
|
||||
};
|
||||
|
||||
let prompt_message = PromptMessage {
|
||||
role: PromptMessageRole::User,
|
||||
content: prompt_content,
|
||||
};
|
||||
|
||||
let message = Message::from(prompt_message);
|
||||
|
||||
if let MessageContent::Text(text_content) = &message.content[0] {
|
||||
assert_eq!(text_content.text, "Hello, world!");
|
||||
} else {
|
||||
panic!("Expected MessageContent::Text");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_prompt_message_image() {
|
||||
let prompt_content = PromptMessageContent::Image {
|
||||
image: ImageContent {
|
||||
data: "base64data".to_string(),
|
||||
mime_type: "image/jpeg".to_string(),
|
||||
annotations: None,
|
||||
},
|
||||
};
|
||||
|
||||
let prompt_message = PromptMessage {
|
||||
role: PromptMessageRole::User,
|
||||
content: prompt_content,
|
||||
};
|
||||
|
||||
let message = Message::from(prompt_message);
|
||||
|
||||
if let MessageContent::Image(image_content) = &message.content[0] {
|
||||
assert_eq!(image_content.data, "base64data");
|
||||
assert_eq!(image_content.mime_type, "image/jpeg");
|
||||
} else {
|
||||
panic!("Expected MessageContent::Image");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_prompt_message_text_resource() {
|
||||
let resource = ResourceContents::TextResourceContents {
|
||||
uri: "file:///test.txt".to_string(),
|
||||
mime_type: Some("text/plain".to_string()),
|
||||
text: "Resource content".to_string(),
|
||||
};
|
||||
|
||||
let prompt_content = PromptMessageContent::Resource {
|
||||
resource: EmbeddedResource {
|
||||
resource,
|
||||
annotations: None,
|
||||
},
|
||||
};
|
||||
|
||||
let prompt_message = PromptMessage {
|
||||
role: PromptMessageRole::User,
|
||||
content: prompt_content,
|
||||
};
|
||||
|
||||
let message = Message::from(prompt_message);
|
||||
|
||||
if let MessageContent::Text(text_content) = &message.content[0] {
|
||||
assert_eq!(text_content.text, "Resource content");
|
||||
} else {
|
||||
panic!("Expected MessageContent::Text");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_prompt_message_blob_resource() {
|
||||
let resource = ResourceContents::BlobResourceContents {
|
||||
uri: "file:///test.bin".to_string(),
|
||||
mime_type: Some("application/octet-stream".to_string()),
|
||||
blob: "binary_data".to_string(),
|
||||
};
|
||||
|
||||
let prompt_content = PromptMessageContent::Resource {
|
||||
resource: EmbeddedResource {
|
||||
resource,
|
||||
annotations: None,
|
||||
},
|
||||
};
|
||||
|
||||
let prompt_message = PromptMessage {
|
||||
role: PromptMessageRole::User,
|
||||
content: prompt_content,
|
||||
};
|
||||
|
||||
let message = Message::from(prompt_message);
|
||||
|
||||
if let MessageContent::Text(text_content) = &message.content[0] {
|
||||
assert_eq!(text_content.text, "[Binary content: binary_data]");
|
||||
} else {
|
||||
panic!("Expected MessageContent::Text");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_prompt_message() {
|
||||
// Test user message conversion
|
||||
let prompt_message = PromptMessage {
|
||||
role: PromptMessageRole::User,
|
||||
content: PromptMessageContent::Text {
|
||||
text: "Hello, world!".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let message = Message::from(prompt_message);
|
||||
assert_eq!(message.role, Role::User);
|
||||
assert_eq!(message.content.len(), 1);
|
||||
assert_eq!(message.as_concat_text(), "Hello, world!");
|
||||
|
||||
// Test assistant message conversion
|
||||
let prompt_message = PromptMessage {
|
||||
role: PromptMessageRole::Assistant,
|
||||
content: PromptMessageContent::Text {
|
||||
text: "I can help with that.".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let message = Message::from(prompt_message);
|
||||
assert_eq!(message.role, Role::Assistant);
|
||||
assert_eq!(message.content.len(), 1);
|
||||
assert_eq!(message.as_concat_text(), "I can help with that.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_with_text() {
|
||||
let message = Message::user().with_text("Hello");
|
||||
assert_eq!(message.as_concat_text(), "Hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_with_tool_request() {
|
||||
let tool_call = Ok(ToolCall {
|
||||
name: "test_tool".to_string(),
|
||||
arguments: serde_json::json!({}),
|
||||
});
|
||||
|
||||
let message = Message::assistant().with_tool_request("req1", tool_call);
|
||||
assert!(message.is_tool_call());
|
||||
assert!(!message.is_tool_response());
|
||||
|
||||
let ids = message.get_tool_ids();
|
||||
assert_eq!(ids.len(), 1);
|
||||
assert!(ids.contains("req1"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,5 +82,16 @@ async fn main() -> Result<(), ClientError> {
|
||||
let resource = client.read_resource("memo://insights").await?;
|
||||
println!("Resource: {resource:?}\n");
|
||||
|
||||
let prompts = client.list_prompts(None).await?;
|
||||
println!("Prompts: {prompts:?}\n");
|
||||
|
||||
let prompt = client
|
||||
.get_prompt(
|
||||
"example_prompt",
|
||||
serde_json::json!({"message": "hello there!"}),
|
||||
)
|
||||
.await?;
|
||||
println!("Prompt: {prompt:?}\n");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use mcp_core::protocol::{
|
||||
CallToolResult, Implementation, InitializeResult, JsonRpcError, JsonRpcMessage,
|
||||
JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult,
|
||||
ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
|
||||
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
|
||||
JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
|
||||
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@@ -93,6 +93,10 @@ pub trait McpClientTrait: Send + Sync {
|
||||
async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error>;
|
||||
|
||||
async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error>;
|
||||
|
||||
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error>;
|
||||
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;
|
||||
}
|
||||
|
||||
/// The MCP client is the interface for MCP operations.
|
||||
@@ -346,4 +350,42 @@ where
|
||||
// https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2
|
||||
self.send_request("tools/call", params).await
|
||||
}
|
||||
|
||||
async fn list_prompts(&self, next_cursor: Option<String>) -> Result<ListPromptsResult, Error> {
|
||||
if !self.completed_initialization() {
|
||||
return Err(Error::NotInitialized);
|
||||
}
|
||||
|
||||
// If prompts is not supported, return an error
|
||||
if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
|
||||
return Err(Error::RpcError {
|
||||
code: METHOD_NOT_FOUND,
|
||||
message: "Server does not support 'prompts' capability".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let payload = next_cursor
|
||||
.map(|cursor| serde_json::json!({"cursor": cursor}))
|
||||
.unwrap_or_else(|| serde_json::json!({}));
|
||||
|
||||
self.send_request("prompts/list", payload).await
|
||||
}
|
||||
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
|
||||
if !self.completed_initialization() {
|
||||
return Err(Error::NotInitialized);
|
||||
}
|
||||
|
||||
// If prompts is not supported, return an error
|
||||
if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
|
||||
return Err(Error::RpcError {
|
||||
code: METHOD_NOT_FOUND,
|
||||
message: "Server does not support 'prompts' capability".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let params = serde_json::json!({ "name": name, "arguments": arguments });
|
||||
|
||||
self.send_request("prompts/get", params).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,13 +111,23 @@ impl SseActor {
|
||||
// Attempt to parse the SSE data as a JsonRpcMessage
|
||||
match serde_json::from_str::<JsonRpcMessage>(&e.data) {
|
||||
Ok(message) => {
|
||||
// If it's a response, complete the pending request
|
||||
if let JsonRpcMessage::Response(resp) = &message {
|
||||
if let Some(id) = &resp.id {
|
||||
pending_requests.respond(&id.to_string(), Ok(message)).await;
|
||||
match &message {
|
||||
JsonRpcMessage::Response(response) => {
|
||||
if let Some(id) = &response.id {
|
||||
pending_requests
|
||||
.respond(&id.to_string(), Ok(message))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
JsonRpcMessage::Error(error) => {
|
||||
if let Some(id) = &error.id {
|
||||
pending_requests
|
||||
.respond(&id.to_string(), Ok(message))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
_ => {} // TODO: Handle other variants (Request, etc.)
|
||||
}
|
||||
// If it's something else (notification, etc.), handle as needed
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to parse SSE message: {err}");
|
||||
|
||||
@@ -87,10 +87,18 @@ impl StdioActor {
|
||||
"Received incoming message"
|
||||
);
|
||||
|
||||
if let JsonRpcMessage::Response(response) = &message {
|
||||
if let Some(id) = &response.id {
|
||||
pending_requests.respond(&id.to_string(), Ok(message)).await;
|
||||
match &message {
|
||||
JsonRpcMessage::Response(response) => {
|
||||
if let Some(id) = &response.id {
|
||||
pending_requests.respond(&id.to_string(), Ok(message)).await;
|
||||
}
|
||||
}
|
||||
JsonRpcMessage::Error(error) => {
|
||||
if let Some(id) = &error.id {
|
||||
pending_requests.respond(&id.to_string(), Ok(message)).await;
|
||||
}
|
||||
}
|
||||
_ => {} // TODO: Handle other variants (Request, etc.)
|
||||
}
|
||||
}
|
||||
line.clear();
|
||||
|
||||
@@ -10,22 +10,28 @@ use serde::{Deserialize, Serialize};
|
||||
pub struct Prompt {
|
||||
/// The name of the prompt
|
||||
pub name: String,
|
||||
/// A description of what the prompt does
|
||||
pub description: String,
|
||||
/// The arguments that can be passed to customize the prompt
|
||||
pub arguments: Vec<PromptArgument>,
|
||||
/// Optional description of what the prompt does
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
/// Optional arguments that can be passed to customize the prompt
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<Vec<PromptArgument>>,
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
/// Create a new prompt with the given name, description and arguments
|
||||
pub fn new<N, D>(name: N, description: D, arguments: Vec<PromptArgument>) -> Self
|
||||
pub fn new<N, D>(
|
||||
name: N,
|
||||
description: Option<D>,
|
||||
arguments: Option<Vec<PromptArgument>>,
|
||||
) -> Self
|
||||
where
|
||||
N: Into<String>,
|
||||
D: Into<String>,
|
||||
{
|
||||
Prompt {
|
||||
name: name.into(),
|
||||
description: description.into(),
|
||||
description: description.map(Into::into),
|
||||
arguments,
|
||||
}
|
||||
}
|
||||
@@ -37,9 +43,11 @@ pub struct PromptArgument {
|
||||
/// The name of the argument
|
||||
pub name: String,
|
||||
/// A description of what the argument is used for
|
||||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
/// Whether this argument is required
|
||||
pub required: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
}
|
||||
|
||||
/// Represents the role of a message sender in a prompt conversation
|
||||
@@ -151,6 +159,6 @@ pub struct PromptTemplate {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PromptArgumentTemplate {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub required: bool,
|
||||
pub description: Option<String>,
|
||||
pub required: Option<bool>,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use mcp_core::content::Content;
|
||||
use mcp_core::handler::ResourceError;
|
||||
use mcp_core::handler::{PromptError, ResourceError};
|
||||
use mcp_core::prompt::{Prompt, PromptArgument};
|
||||
use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool};
|
||||
use mcp_server::router::{CapabilitiesBuilder, RouterService};
|
||||
use mcp_server::{ByteTransport, Router, Server};
|
||||
@@ -61,6 +62,7 @@ impl Router for CounterRouter {
|
||||
CapabilitiesBuilder::new()
|
||||
.with_tools(false)
|
||||
.with_resources(false, false)
|
||||
.with_prompts(false)
|
||||
.build()
|
||||
}
|
||||
|
||||
@@ -153,6 +155,37 @@ impl Router for CounterRouter {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn list_prompts(&self) -> Vec<Prompt> {
|
||||
vec![Prompt::new(
|
||||
"example_prompt",
|
||||
Some("This is an example prompt that takes one required agrument, message"),
|
||||
Some(vec![PromptArgument {
|
||||
name: "message".to_string(),
|
||||
description: Some("A message to put in the prompt".to_string()),
|
||||
required: Some(true),
|
||||
}]),
|
||||
)]
|
||||
}
|
||||
|
||||
fn get_prompt(
|
||||
&self,
|
||||
prompt_name: &str,
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
|
||||
let prompt_name = prompt_name.to_string();
|
||||
Box::pin(async move {
|
||||
match prompt_name.as_str() {
|
||||
"example_prompt" => {
|
||||
let prompt = "This is an example prompt with your message here: '{message}'";
|
||||
Ok(prompt.to_string())
|
||||
}
|
||||
_ => Err(PromptError::NotFound(format!(
|
||||
"Prompt {} not found",
|
||||
prompt_name
|
||||
))),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
|
||||
@@ -97,12 +97,8 @@ pub trait Router: Send + Sync + 'static {
|
||||
&self,
|
||||
uri: &str,
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>>;
|
||||
fn list_prompts(&self) -> Option<Vec<Prompt>> {
|
||||
None
|
||||
}
|
||||
fn get_prompt(&self, _prompt_name: &str) -> Option<PromptFuture> {
|
||||
None
|
||||
}
|
||||
fn list_prompts(&self) -> Vec<Prompt>;
|
||||
fn get_prompt(&self, prompt_name: &str) -> PromptFuture;
|
||||
|
||||
// Helper method to create base response
|
||||
fn create_response(&self, id: Option<u64>) -> JsonRpcResponse {
|
||||
@@ -257,7 +253,7 @@ pub trait Router: Send + Sync + 'static {
|
||||
req: JsonRpcRequest,
|
||||
) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
|
||||
async move {
|
||||
let prompts = self.list_prompts().unwrap_or_default();
|
||||
let prompts = self.list_prompts();
|
||||
|
||||
let result = ListPromptsResult { prompts };
|
||||
|
||||
@@ -294,36 +290,36 @@ pub trait Router: Send + Sync + 'static {
|
||||
.ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?;
|
||||
|
||||
// Fetch the prompt definition first
|
||||
let prompt = match self.list_prompts() {
|
||||
Some(prompts) => prompts
|
||||
.into_iter()
|
||||
.find(|p| p.name == prompt_name)
|
||||
.ok_or_else(|| {
|
||||
RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name))
|
||||
})?,
|
||||
None => return Err(RouterError::PromptNotFound("No prompts available".into())),
|
||||
};
|
||||
let prompt = self
|
||||
.list_prompts()
|
||||
.into_iter()
|
||||
.find(|p| p.name == prompt_name)
|
||||
.ok_or_else(|| {
|
||||
RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name))
|
||||
})?;
|
||||
|
||||
// Validate required arguments
|
||||
for arg in &prompt.arguments {
|
||||
if arg.required
|
||||
&& (!arguments.contains_key(&arg.name)
|
||||
|| arguments
|
||||
.get(&arg.name)
|
||||
.and_then(Value::as_str)
|
||||
.is_none_or(str::is_empty))
|
||||
{
|
||||
return Err(RouterError::InvalidParams(format!(
|
||||
"Missing required argument: '{}'",
|
||||
arg.name
|
||||
)));
|
||||
if let Some(args) = &prompt.arguments {
|
||||
for arg in args {
|
||||
if arg.required.is_some()
|
||||
&& arg.required.unwrap()
|
||||
&& (!arguments.contains_key(&arg.name)
|
||||
|| arguments
|
||||
.get(&arg.name)
|
||||
.and_then(Value::as_str)
|
||||
.is_none_or(str::is_empty))
|
||||
{
|
||||
return Err(RouterError::InvalidParams(format!(
|
||||
"Missing required argument: '{}'",
|
||||
arg.name
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now get the prompt content
|
||||
let description = self
|
||||
.get_prompt(prompt_name)
|
||||
.ok_or_else(|| RouterError::PromptNotFound("Prompt not found".into()))?
|
||||
.await
|
||||
.map_err(|e| RouterError::Internal(e.to_string()))?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user