feat: Enable frontend tools (#1778)

This commit is contained in:
Bradley Axen
2025-04-03 14:12:42 -07:00
committed by GitHub
parent 51edc0c7af
commit c8f8963545
16 changed files with 684 additions and 132 deletions

View File

@@ -8,6 +8,7 @@ use goose::{
}; };
use http::{HeaderMap, StatusCode}; use http::{HeaderMap, StatusCode};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing;
/// Enum representing the different types of extension configuration requests. /// Enum representing the different types of extension configuration requests.
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -48,6 +49,16 @@ enum ExtensionConfigRequest {
display_name: Option<String>, display_name: Option<String>,
timeout: Option<u64>, timeout: Option<u64>,
}, },
/// Frontend extension that provides tools to be executed by the frontend.
#[serde(rename = "frontend")]
Frontend {
/// The name to identify this extension
name: String,
/// The tools provided by this extension
tools: Vec<mcp_core::tool::Tool>,
/// Optional instructions for using the tools
instructions: Option<String>,
},
} }
/// Response structure for adding an extension. /// Response structure for adding an extension.
@@ -64,8 +75,26 @@ struct ExtensionResponse {
async fn add_extension( async fn add_extension(
State(state): State<AppState>, State(state): State<AppState>,
headers: HeaderMap, headers: HeaderMap,
Json(request): Json<ExtensionConfigRequest>, raw: axum::extract::Json<serde_json::Value>,
) -> Result<Json<ExtensionResponse>, StatusCode> { ) -> Result<Json<ExtensionResponse>, StatusCode> {
// Log the raw request for debugging
tracing::info!(
"Received extension request: {}",
serde_json::to_string_pretty(&raw.0).unwrap()
);
// Try to parse into our enum
let request: ExtensionConfigRequest = match serde_json::from_value(raw.0.clone()) {
Ok(req) => req,
Err(e) => {
tracing::error!("Failed to parse extension request: {}", e);
tracing::error!(
"Raw request was: {}",
serde_json::to_string_pretty(&raw.0).unwrap()
);
return Err(StatusCode::UNPROCESSABLE_ENTITY);
}
};
// Verify the presence and validity of the secret key. // Verify the presence and validity of the secret key.
let secret_key = headers let secret_key = headers
.get("X-Secret-Key") .get("X-Secret-Key")
@@ -167,6 +196,15 @@ async fn add_extension(
display_name, display_name,
timeout, timeout,
}, },
ExtensionConfigRequest::Frontend {
name,
tools,
instructions,
} => ExtensionConfig::Frontend {
name,
tools,
instructions,
},
}; };
// Acquire a lock on the agent and attempt to add the extension. // Acquire a lock on the agent and attempt to add the extension.

View File

@@ -13,9 +13,9 @@ use goose::{
agents::SessionConfig, agents::SessionConfig,
message::{Message, MessageContent}, message::{Message, MessageContent},
}; };
use mcp_core::{role::Role, Content, ToolResult};
use mcp_core::role::Role;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value; use serde_json::Value;
use std::{ use std::{
convert::Infallible, convert::Infallible,
@@ -391,12 +391,59 @@ async fn confirm_handler(
Ok(Json(Value::Object(serde_json::Map::new()))) Ok(Json(Value::Object(serde_json::Map::new())))
} }
#[derive(Debug, Deserialize)]
struct ToolResultRequest {
id: String,
result: ToolResult<Vec<Content>>,
}
async fn submit_tool_result(
State(state): State<AppState>,
headers: HeaderMap,
raw: axum::extract::Json<serde_json::Value>,
) -> Result<Json<Value>, StatusCode> {
// Log the raw request for debugging
tracing::info!(
"Received tool result request: {}",
serde_json::to_string_pretty(&raw.0).unwrap()
);
// Try to parse into our struct
let payload: ToolResultRequest = match serde_json::from_value(raw.0.clone()) {
Ok(req) => req,
Err(e) => {
tracing::error!("Failed to parse tool result request: {}", e);
tracing::error!(
"Raw request was: {}",
serde_json::to_string_pretty(&raw.0).unwrap()
);
return Err(StatusCode::UNPROCESSABLE_ENTITY);
}
};
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}
let agent = state.agent.read().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
agent.handle_tool_result(payload.id, payload.result).await;
Ok(Json(json!({"status": "ok"})))
}
// Configure routes for this module // Configure routes for this module
pub fn routes(state: AppState) -> Router { pub fn routes(state: AppState) -> Router {
Router::new() Router::new()
.route("/reply", post(handler)) .route("/reply", post(handler))
.route("/ask", post(ask_handler)) .route("/ask", post(ask_handler))
.route("/confirm", post(confirm_handler)) .route("/confirm", post(confirm_handler))
.route("/tool_result", post(submit_tool_result))
.with_state(state) .with_state(state)
} }

View File

@@ -12,8 +12,7 @@ use super::extension::{ExtensionConfig, ExtensionResult};
use crate::message::Message; use crate::message::Message;
use crate::providers::base::Provider; use crate::providers::base::Provider;
use crate::session; use crate::session;
use mcp_core::prompt::Prompt; use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult};
use mcp_core::protocol::GetPromptResult;
/// Session configuration for an agent /// Session configuration for an agent
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -68,4 +67,7 @@ pub trait Agent: Send + Sync {
/// Get a reference to the provider used by this agent /// Get a reference to the provider used by this agent
async fn provider(&self) -> Arc<Box<dyn Provider>>; async fn provider(&self) -> Arc<Box<dyn Provider>>;
/// Handle a tool result from the frontend
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>);
} }

View File

@@ -26,9 +26,17 @@ static DEFAULT_TIMESTAMP: LazyLock<DateTime<Utc>> =
type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>; type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>;
/// A frontend tool that will be executed by the frontend rather than an extension
#[derive(Clone)]
pub struct FrontendTool {
pub name: String,
pub tool: Tool,
}
/// Manages MCP clients and their interactions /// Manages MCP clients and their interactions
pub struct Capabilities { pub struct Capabilities {
clients: HashMap<String, McpClientBox>, clients: HashMap<String, McpClientBox>,
frontend_tools: HashMap<String, FrontendTool>,
instructions: HashMap<String, String>, instructions: HashMap<String, String>,
resource_capable_extensions: HashSet<String>, resource_capable_extensions: HashSet<String>,
provider: Arc<Box<dyn Provider>>, provider: Arc<Box<dyn Provider>>,
@@ -96,6 +104,7 @@ impl Capabilities {
pub fn new(provider: Box<dyn Provider>) -> Self { pub fn new(provider: Box<dyn Provider>) -> Self {
Self { Self {
clients: HashMap::new(), clients: HashMap::new(),
frontend_tools: HashMap::new(),
instructions: HashMap::new(), instructions: HashMap::new(),
resource_capable_extensions: HashSet::new(), resource_capable_extensions: HashSet::new(),
provider: Arc::new(provider), provider: Arc::new(provider),
@@ -111,6 +120,30 @@ impl Capabilities {
/// Add a new MCP extension based on the provided client type /// Add a new MCP extension based on the provided client type
// TODO IMPORTANT need to ensure this times out if the extension command is broken! // TODO IMPORTANT need to ensure this times out if the extension command is broken!
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> { pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
let sanitized_name = normalize(config.key().to_string());
match &config {
ExtensionConfig::Frontend {
name: _,
tools,
instructions,
} => {
// For frontend tools, just store them in the frontend_tools map
for tool in tools {
let frontend_tool = FrontendTool {
name: tool.name.clone(),
tool: tool.clone(),
};
self.frontend_tools.insert(tool.name.clone(), frontend_tool);
}
// Store instructions if provided, using "frontend" as the key
if let Some(instructions) = instructions {
self.instructions
.insert("frontend".to_string(), instructions.clone());
}
Ok(())
}
_ => {
let mut client: Box<dyn McpClientTrait> = match &config { let mut client: Box<dyn McpClientTrait> = match &config {
ExtensionConfig::Sse { ExtensionConfig::Sse {
uri, envs, timeout, .. uri, envs, timeout, ..
@@ -142,10 +175,9 @@ impl Capabilities {
); );
Box::new(McpClient::new(service)) Box::new(McpClient::new(service))
} }
#[allow(unused_variables)]
ExtensionConfig::Builtin { ExtensionConfig::Builtin {
name, name,
display_name, display_name: _,
timeout, timeout,
} => { } => {
// For builtin extensions, we run the current executable with mcp and extension name // For builtin extensions, we run the current executable with mcp and extension name
@@ -168,6 +200,7 @@ impl Capabilities {
); );
Box::new(McpClient::new(service)) Box::new(McpClient::new(service))
} }
_ => unreachable!(),
}; };
// Initialize the client with default capabilities // Initialize the client with default capabilities
@@ -182,8 +215,6 @@ impl Capabilities {
.await .await
.map_err(|e| ExtensionError::Initialization(config.clone(), e))?; .map_err(|e| ExtensionError::Initialization(config.clone(), e))?;
let sanitized_name = normalize(config.key().to_string());
// Store instructions if provided // Store instructions if provided
if let Some(instructions) = init_result.instructions { if let Some(instructions) = init_result.instructions {
self.instructions self.instructions
@@ -202,6 +233,8 @@ impl Capabilities {
Ok(()) Ok(())
} }
}
}
/// Add a system prompt extension /// Add a system prompt extension
pub fn add_system_prompt_extension(&mut self, extension: String) { pub fn add_system_prompt_extension(&mut self, extension: String) {
@@ -235,6 +268,13 @@ impl Capabilities {
/// Get all tools from all clients with proper prefixing /// Get all tools from all clients with proper prefixing
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> { pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
let mut tools = Vec::new(); let mut tools = Vec::new();
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
for frontend_tool in self.frontend_tools.values() {
tools.push(frontend_tool.tool.clone());
}
// Add tools from MCP extensions with prefixing
for (name, client) in &self.clients { for (name, client) in &self.clients {
let client_guard = client.lock().await; let client_guard = client.lock().await;
let mut client_tools = client_guard.list_tools(None).await?; let mut client_tools = client_guard.list_tools(None).await?;
@@ -317,7 +357,7 @@ impl Capabilities {
pub async fn get_system_prompt(&self) -> String { pub async fn get_system_prompt(&self) -> String {
let mut context: HashMap<&str, Value> = HashMap::new(); let mut context: HashMap<&str, Value> = HashMap::new();
let extensions_info: Vec<ExtensionInfo> = self let mut extensions_info: Vec<ExtensionInfo> = self
.clients .clients
.keys() .keys()
.map(|name| { .map(|name| {
@@ -326,6 +366,15 @@ impl Capabilities {
ExtensionInfo::new(name, &instructions, has_resources) ExtensionInfo::new(name, &instructions, has_resources)
}) })
.collect(); .collect();
// Add frontend tools as a special extension if any exist
if !self.frontend_tools.is_empty() {
let name = "frontend";
let instructions = self.instructions.get(name).cloned().unwrap_or_else(||
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string()
);
extensions_info.push(ExtensionInfo::new(name, &instructions, false));
}
context.insert("extensions", serde_json::to_value(extensions_info).unwrap()); context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string(); let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
@@ -364,6 +413,16 @@ impl Capabilities {
} }
} }
/// Check if a tool is a frontend tool
pub fn is_frontend_tool(&self, name: &str) -> bool {
self.frontend_tools.contains_key(name)
}
/// Get a reference to a frontend tool
pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> {
self.frontend_tools.get(name)
}
/// Find and return a reference to the appropriate client for a tool call /// Find and return a reference to the appropriate client for a tool call
fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(&str, McpClientBox)> { fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(&str, McpClientBox)> {
self.clients self.clients
@@ -543,6 +602,11 @@ impl Capabilities {
self.read_resource(tool_call.arguments.clone()).await self.read_resource(tool_call.arguments.clone()).await
} else if tool_call.name == "platform__list_resources" { } else if tool_call.name == "platform__list_resources" {
self.list_resources(tool_call.arguments.clone()).await self.list_resources(tool_call.arguments.clone()).await
} else if self.is_frontend_tool(&tool_call.name) {
// For frontend tools, return an error indicating we need frontend execution
Err(ToolError::ExecutionError(
"Frontend tool execution required".to_string(),
))
} else { } else {
// Else, dispatch tool call based on the prefix naming convention // Else, dispatch tool call based on the prefix naming convention
let (client_name, client) = self let (client_name, client) = self

View File

@@ -149,6 +149,16 @@ pub enum ExtensionConfig {
display_name: Option<String>, // needed for the UI display_name: Option<String>, // needed for the UI
timeout: Option<u64>, timeout: Option<u64>,
}, },
/// Frontend-provided tools that will be called through the frontend
#[serde(rename = "frontend")]
Frontend {
/// The name used to identify this extension
name: String,
/// The tools provided by the frontend
tools: Vec<mcp_core::tool::Tool>,
/// Instructions for how to use these tools
instructions: Option<String>,
},
} }
impl Default for ExtensionConfig { impl Default for ExtensionConfig {
@@ -224,6 +234,7 @@ impl ExtensionConfig {
Self::Sse { name, .. } => name, Self::Sse { name, .. } => name,
Self::Stdio { name, .. } => name, Self::Stdio { name, .. } => name,
Self::Builtin { name, .. } => name, Self::Builtin { name, .. } => name,
Self::Frontend { name, .. } => name,
} }
.to_string() .to_string()
} }
@@ -239,6 +250,9 @@ impl std::fmt::Display for ExtensionConfig {
write!(f, "Stdio({}: {} {})", name, cmd, args.join(" ")) write!(f, "Stdio({}: {} {})", name, cmd, args.join(" "))
} }
ExtensionConfig::Builtin { name, .. } => write!(f, "Builtin({})", name), ExtensionConfig::Builtin { name, .. } => write!(f, "Builtin({})", name),
ExtensionConfig::Frontend { name, tools, .. } => {
write!(f, "Frontend({}: {} tools)", name, tools.len())
}
} }
} }
} }

View File

@@ -7,6 +7,7 @@ mod permission_store;
mod reference; mod reference;
mod summarize; mod summarize;
mod truncate; mod truncate;
mod types;
pub use agent::{Agent, SessionConfig}; pub use agent::{Agent, SessionConfig};
pub use capabilities::Capabilities; pub use capabilities::Capabilities;

View File

@@ -4,12 +4,13 @@ use async_trait::async_trait;
use futures::stream::BoxStream; use futures::stream::BoxStream;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::{mpsc, Mutex};
use tracing::{debug, instrument}; use tracing::{debug, instrument};
use super::agent::SessionConfig; use super::agent::SessionConfig;
use super::capabilities::get_parameter_names; use super::capabilities::get_parameter_names;
use super::extension::ToolInfo; use super::extension::ToolInfo;
use super::types::ToolResultReceiver;
use super::Agent; use super::Agent;
use crate::agents::capabilities::Capabilities; use crate::agents::capabilities::Capabilities;
use crate::agents::extension::{ExtensionConfig, ExtensionResult}; use crate::agents::extension::{ExtensionConfig, ExtensionResult};
@@ -19,23 +20,27 @@ use crate::token_counter::TokenCounter;
use crate::{register_agent, session}; use crate::{register_agent, session};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use indoc::indoc; use indoc::indoc;
use mcp_core::prompt::Prompt;
use mcp_core::protocol::GetPromptResult;
use mcp_core::tool::{Tool, ToolAnnotations}; use mcp_core::tool::{Tool, ToolAnnotations};
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult};
use serde_json::{json, Value}; use serde_json::{json, Value};
/// Reference implementation of an Agent /// Reference implementation of an Agent
pub struct ReferenceAgent { pub struct ReferenceAgent {
capabilities: Mutex<Capabilities>, capabilities: Mutex<Capabilities>,
_token_counter: TokenCounter, _token_counter: TokenCounter,
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
tool_result_rx: ToolResultReceiver,
} }
impl ReferenceAgent { impl ReferenceAgent {
pub fn new(provider: Box<dyn Provider>) -> Self { pub fn new(provider: Box<dyn Provider>) -> Self {
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name()); let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
let (tx, rx) = mpsc::channel(32);
Self { Self {
capabilities: Mutex::new(Capabilities::new(provider)), capabilities: Mutex::new(Capabilities::new(provider)),
_token_counter: token_counter, _token_counter: token_counter,
tool_result_tx: tx,
tool_result_rx: Arc::new(Mutex::new(rx)),
} }
} }
} }
@@ -193,24 +198,32 @@ impl Agent for ReferenceAgent {
} }
// Then dispatch each in parallel // Then dispatch each in parallel
let futures: Vec<_> = tool_requests
.iter()
.filter_map(|request| request.tool_call.clone().ok())
.map(|tool_call| capabilities.dispatch_tool_call(tool_call))
.collect();
// Process all the futures in parallel but wait until all are finished
let outputs = futures::future::join_all(futures).await;
// Create a message with the responses
let mut message_tool_response = Message::user(); let mut message_tool_response = Message::user();
// Now combine these into MessageContent::ToolResponse using the original ID for request in tool_requests {
for (request, output) in tool_requests.iter().zip(outputs.into_iter()) { if let Ok(tool_call) = &request.tool_call {
// Check if it's a frontend tool
if capabilities.is_frontend_tool(&tool_call.name) {
// Send frontend tool request and wait for response
yield Message::assistant().with_frontend_tool_request(
request.id.clone(),
request.tool_call.clone()
);
// Wait for the result using our channel
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
message_tool_response = message_tool_response.with_tool_response(id, result);
}
continue;
}
// Handle regular tool calls
let result = capabilities.dispatch_tool_call(tool_call.clone()).await;
message_tool_response = message_tool_response.with_tool_response( message_tool_response = message_tool_response.with_tool_response(
request.id.clone(), request.id.clone(),
output, result,
); );
} }
}
yield message_tool_response.clone(); yield message_tool_response.clone();
@@ -278,6 +291,12 @@ impl Agent for ReferenceAgent {
let capabilities = self.capabilities.lock().await; let capabilities = self.capabilities.lock().await;
capabilities.provider() capabilities.provider()
} }
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
if let Err(e) = self.tool_result_tx.send((id, result)).await {
tracing::error!("Failed to send tool result: {}", e);
}
}
} }
register_agent!("reference", ReferenceAgent); register_agent!("reference", ReferenceAgent);

View File

@@ -28,9 +28,7 @@ use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation}; use crate::truncate::{truncate_messages, OldestFirstTruncation};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use indoc::indoc; use indoc::indoc;
use mcp_core::prompt::Prompt; use mcp_core::{prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolResult};
use mcp_core::protocol::GetPromptResult;
use mcp_core::{tool::Tool, Content};
use serde_json::{json, Value}; use serde_json::{json, Value};
const MAX_TRUNCATION_ATTEMPTS: usize = 3; const MAX_TRUNCATION_ATTEMPTS: usize = 3;
@@ -42,19 +40,22 @@ pub struct SummarizeAgent {
token_counter: TokenCounter, token_counter: TokenCounter,
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed) confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>, confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
} }
impl SummarizeAgent { impl SummarizeAgent {
pub fn new(provider: Box<dyn Provider>) -> Self { pub fn new(provider: Box<dyn Provider>) -> Self {
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name()); let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
// Create channel with buffer size 32 (adjust if needed) // Create channels with buffer size 32 (adjust if needed)
let (tx, rx) = mpsc::channel(32); let (confirm_tx, confirm_rx) = mpsc::channel(32);
let (tool_tx, _tool_rx) = mpsc::channel(32);
Self { Self {
capabilities: Mutex::new(Capabilities::new(provider)), capabilities: Mutex::new(Capabilities::new(provider)),
token_counter, token_counter,
confirmation_tx: tx, confirmation_tx: confirm_tx,
confirmation_rx: Mutex::new(rx), confirmation_rx: Mutex::new(confirm_rx),
tool_result_tx: tool_tx,
} }
} }
@@ -493,6 +494,12 @@ impl Agent for SummarizeAgent {
let capabilities = self.capabilities.lock().await; let capabilities = self.capabilities.lock().await;
capabilities.provider() capabilities.provider()
} }
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
if let Err(e) = self.tool_result_tx.send((id, result)).await {
tracing::error!("Failed to send tool result: {}", e);
}
}
} }
register_agent!("summarize", SummarizeAgent); register_agent!("summarize", SummarizeAgent);

View File

@@ -12,12 +12,13 @@ use tracing::{debug, error, instrument, warn};
use super::agent::SessionConfig; use super::agent::SessionConfig;
use super::detect_read_only_tools; use super::detect_read_only_tools;
use super::extension::ToolInfo; use super::extension::ToolInfo;
use super::types::ToolResultReceiver;
use super::Agent; use super::Agent;
use crate::agents::capabilities::{get_parameter_names, Capabilities}; use crate::agents::capabilities::{get_parameter_names, Capabilities};
use crate::agents::extension::{ExtensionConfig, ExtensionResult}; use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::agents::ToolPermissionStore; use crate::agents::ToolPermissionStore;
use crate::config::Config; use crate::config::Config;
use crate::message::{Message, ToolRequest}; use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::Provider; use crate::providers::base::Provider;
use crate::providers::errors::ProviderError; use crate::providers::errors::ProviderError;
use crate::providers::toolshim::{ use crate::providers::toolshim::{
@@ -29,9 +30,9 @@ use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation}; use crate::truncate::{truncate_messages, OldestFirstTruncation};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use indoc::indoc; use indoc::indoc;
use mcp_core::prompt::Prompt; use mcp_core::{
use mcp_core::protocol::GetPromptResult; prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult,
use mcp_core::{tool::Tool, Content, ToolError}; };
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::time::Duration; use std::time::Duration;
@@ -44,19 +45,24 @@ pub struct TruncateAgent {
token_counter: TokenCounter, token_counter: TokenCounter,
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed) confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>, confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
tool_result_rx: ToolResultReceiver,
} }
impl TruncateAgent { impl TruncateAgent {
pub fn new(provider: Box<dyn Provider>) -> Self { pub fn new(provider: Box<dyn Provider>) -> Self {
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name()); let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
// Create channel with buffer size 32 (adjust if needed) // Create channels with buffer size 32 (adjust if needed)
let (tx, rx) = mpsc::channel(32); let (confirm_tx, confirm_rx) = mpsc::channel(32);
let (tool_tx, tool_rx) = mpsc::channel(32);
Self { Self {
capabilities: Mutex::new(Capabilities::new(provider)), capabilities: Mutex::new(Capabilities::new(provider)),
token_counter, token_counter,
confirmation_tx: tx, confirmation_tx: confirm_tx,
confirmation_rx: Mutex::new(rx), confirmation_rx: Mutex::new(confirm_rx),
tool_result_tx: tool_tx,
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
} }
} }
@@ -306,8 +312,21 @@ impl Agent for TruncateAgent {
// Reset truncation attempt // Reset truncation attempt
truncation_attempt = 0; truncation_attempt = 0;
// Yield the assistant's response // Yield the assistant's response, but filter out frontend tool requests that we'll process separately
yield response.clone(); let filtered_response = Message {
role: response.role.clone(),
created: response.created,
content: response.content.iter().filter(|c| {
if let MessageContent::ToolRequest(req) = c {
// Only filter out frontend tool requests
if let Ok(tool_call) = &req.tool_call {
return !capabilities.is_frontend_tool(&tool_call.name);
}
}
true
}).cloned().collect(),
};
yield filtered_response.clone();
tokio::task::yield_now().await; tokio::task::yield_now().await;
@@ -323,6 +342,29 @@ impl Agent for TruncateAgent {
// Process tool requests depending on goose_mode // Process tool requests depending on goose_mode
let mut message_tool_response = Message::user(); let mut message_tool_response = Message::user();
// First handle any frontend tool requests
let mut remaining_requests = Vec::new();
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
if capabilities.is_frontend_tool(&tool_call.name) {
// Send frontend tool request and wait for response
yield Message::assistant().with_frontend_tool_request(
request.id.clone(),
Ok(tool_call.clone())
);
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
message_tool_response = message_tool_response.with_tool_response(id, result);
}
} else {
remaining_requests.push(request);
}
} else {
remaining_requests.push(request);
}
}
// Clone goose_mode once before the match to avoid move issues // Clone goose_mode once before the match to avoid move issues
let mode = goose_mode.clone(); let mode = goose_mode.clone();
match mode.as_str() { match mode.as_str() {
@@ -334,7 +376,7 @@ impl Agent for TruncateAgent {
// First check permissions for all tools // First check permissions for all tools
let store = ToolPermissionStore::load()?; let store = ToolPermissionStore::load()?;
for request in tool_requests.iter() { for request in remaining_requests.iter() {
if let Ok(tool_call) = request.tool_call.clone() { if let Ok(tool_call) = request.tool_call.clone() {
if tools_with_readonly_annotation.contains(&tool_call.name) { if tools_with_readonly_annotation.contains(&tool_call.name) {
approved_tools.push((request.id.clone(), tool_call)); approved_tools.push((request.id.clone(), tool_call));
@@ -427,7 +469,7 @@ impl Agent for TruncateAgent {
}, },
"chat" => { "chat" => {
// Skip all tool calls in chat mode // Skip all tool calls in chat mode
for request in &tool_requests { for request in &remaining_requests {
message_tool_response = message_tool_response.with_tool_response( message_tool_response = message_tool_response.with_tool_response(
request.id.clone(), request.id.clone(),
Ok(vec![Content::text( Ok(vec![Content::text(
@@ -449,7 +491,7 @@ impl Agent for TruncateAgent {
} }
// Process tool requests in parallel // Process tool requests in parallel
let mut tool_futures = Vec::new(); let mut tool_futures = Vec::new();
for request in &tool_requests { for request in &remaining_requests {
if let Ok(tool_call) = request.tool_call.clone() { if let Ok(tool_call) = request.tool_call.clone() {
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone()); let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
tool_futures.push(tool_future); tool_futures.push(tool_future);
@@ -574,6 +616,12 @@ impl Agent for TruncateAgent {
let capabilities = self.capabilities.lock().await; let capabilities = self.capabilities.lock().await;
capabilities.provider() capabilities.provider()
} }
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
if let Err(e) = self.tool_result_tx.send((id, result)).await {
tracing::error!("Failed to send tool result: {}", e);
}
}
} }
register_agent!("truncate", TruncateAgent); register_agent!("truncate", TruncateAgent);

View File

@@ -0,0 +1,6 @@
use mcp_core::{Content, ToolResult};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
/// Type alias for the tool result channel receiver
pub type ToolResultReceiver = Arc<Mutex<mpsc::Receiver<(String, ToolResult<Vec<Content>>)>>>;

View File

@@ -70,6 +70,14 @@ pub struct RedactedThinkingContent {
pub data: String, pub data: String,
} }
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FrontendToolRequest {
pub id: String,
#[serde(with = "tool_result_serde")]
pub tool_call: ToolResult<ToolCall>,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// Content passed inside a message, which can be both simple content and tool content /// Content passed inside a message, which can be both simple content and tool content
#[serde(tag = "type", rename_all = "camelCase")] #[serde(tag = "type", rename_all = "camelCase")]
@@ -79,6 +87,7 @@ pub enum MessageContent {
ToolRequest(ToolRequest), ToolRequest(ToolRequest),
ToolResponse(ToolResponse), ToolResponse(ToolResponse),
ToolConfirmationRequest(ToolConfirmationRequest), ToolConfirmationRequest(ToolConfirmationRequest),
FrontendToolRequest(FrontendToolRequest),
Thinking(ThinkingContent), Thinking(ThinkingContent),
RedactedThinking(RedactedThinkingContent), RedactedThinking(RedactedThinkingContent),
} }
@@ -137,6 +146,13 @@ impl MessageContent {
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self { pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() }) MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
} }
pub fn frontend_tool_request<S: Into<String>>(id: S, tool_call: ToolResult<ToolCall>) -> Self {
MessageContent::FrontendToolRequest(FrontendToolRequest {
id: id.into(),
tool_call,
})
}
pub fn as_tool_request(&self) -> Option<&ToolRequest> { pub fn as_tool_request(&self) -> Option<&ToolRequest> {
if let MessageContent::ToolRequest(ref tool_request) = self { if let MessageContent::ToolRequest(ref tool_request) = self {
Some(tool_request) Some(tool_request)
@@ -320,6 +336,14 @@ impl Message {
)) ))
} }
pub fn with_frontend_tool_request<S: Into<String>>(
self,
id: S,
tool_call: ToolResult<ToolCall>,
) -> Self {
self.with_content(MessageContent::frontend_tool_request(id, tool_call))
}
/// Add thinking content to the message /// Add thinking content to the message
pub fn with_thinking<S1: Into<String>, S2: Into<String>>( pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
self, self,

View File

@@ -74,6 +74,16 @@ pub fn format_messages(messages: &[Message]) -> Vec<Value> {
})); }));
} }
MessageContent::Image(_) => continue, // Anthropic doesn't support image content yet MessageContent::Image(_) => continue, // Anthropic doesn't support image content yet
MessageContent::FrontendToolRequest(tool_request) => {
if let Ok(tool_call) = &tool_request.tool_call {
content.push(json!({
"type": "tool_use",
"id": tool_request.id,
"name": tool_call.name,
"input": tool_call.arguments
}));
}
}
} }
} }

View File

@@ -57,6 +57,21 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C
}?; }?;
bedrock::ContentBlock::ToolUse(tool_use) bedrock::ContentBlock::ToolUse(tool_use)
} }
MessageContent::FrontendToolRequest(tool_req) => {
let tool_use_id = tool_req.id.to_string();
let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() {
bedrock::ToolUseBlock::builder()
.tool_use_id(tool_use_id)
.name(call.name.to_string())
.input(to_bedrock_json(&call.arguments))
.build()
} else {
bedrock::ToolUseBlock::builder()
.tool_use_id(tool_use_id)
.build()
}?;
bedrock::ContentBlock::ToolUse(tool_use)
}
MessageContent::ToolResponse(tool_res) => { MessageContent::ToolResponse(tool_res) => {
let content = match &tool_res.tool_result { let content = match &tool_res.tool_result {
Ok(content) => Some( Ok(content) => Some(

View File

@@ -188,6 +188,27 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
} }
})); }));
} }
MessageContent::FrontendToolRequest(req) => {
// Frontend tool requests are converted to text messages
if let Ok(tool_call) = &req.tool_call {
content_array.push(json!({
"type": "text",
"text": format!(
"Frontend tool request: {} ({})",
tool_call.name,
serde_json::to_string_pretty(&tool_call.arguments).unwrap()
)
}));
} else {
content_array.push(json!({
"type": "text",
"text": format!(
"Frontend tool request error: {}",
req.tool_call.as_ref().unwrap_err()
)
}));
}
}
} }
} }

View File

@@ -151,6 +151,32 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
// Handle direct image content // Handle direct image content
converted["content"] = json!([convert_image(image, image_format)]); converted["content"] = json!([convert_image(image, image_format)]);
} }
MessageContent::FrontendToolRequest(request) => match &request.tool_call {
Ok(tool_call) => {
let sanitized_name = sanitize_function_name(&tool_call.name);
let tool_calls = converted
.as_object_mut()
.unwrap()
.entry("tool_calls")
.or_insert(json!([]));
tool_calls.as_array_mut().unwrap().push(json!({
"id": request.id,
"type": "function",
"function": {
"name": sanitized_name,
"arguments": tool_call.arguments.to_string(),
}
}));
}
Err(e) => {
output.push(json!({
"role": "tool",
"content": format!("Error: {}", e),
"tool_call_id": request.id
}));
}
},
} }
} }

210
examples/frontend_tools.py Normal file
View File

@@ -0,0 +1,210 @@
import asyncio
import json
import os
from typing import List, Dict, Any
import httpx
from datetime import datetime
# Configuration
GOOSE_HOST = "127.0.0.1"
GOOSE_PORT = "3001"
GOOSE_URL = f"http://{GOOSE_HOST}:{GOOSE_PORT}"
SECRET_KEY = "test" # Default development secret key
# A simple calculator tool definition
CALCULATOR_TOOL = {
"name": "calculator",
"description": "Perform basic arithmetic calculations",
"inputSchema": {
"type": "object",
"required": ["operation", "numbers"],
"properties": {
"operation": {
"type": "string",
"enum": ["add", "subtract", "multiply", "divide"],
"description": "The arithmetic operation to perform",
},
"numbers": {
"type": "array",
"items": {"type": "number"},
"description": "List of numbers to operate on in order",
},
},
},
}
# Frontend extension configuration
FRONTEND_CONFIG = {
"name": "pythonclient",
"type": "frontend",
"tools": [CALCULATOR_TOOL],
"instructions": "A calculator extension that can perform basic arithmetic operations.",
}
async def setup_agent() -> None:
"""Initialize the agent with our frontend tool."""
async with httpx.AsyncClient() as client:
# First create the agent
response = await client.post(
f"{GOOSE_URL}/agent",
json={"provider": "databricks", "model": "goose"},
headers={"X-Secret-Key": SECRET_KEY},
)
response.raise_for_status()
print("Successfully created agent")
# Then add our frontend extension
response = await client.post(
f"{GOOSE_URL}/extensions/add",
json=FRONTEND_CONFIG,
headers={"X-Secret-Key": SECRET_KEY},
)
response.raise_for_status()
print("Successfully added calculator extension")
def execute_calculator(args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Execute the calculator tool with the given arguments."""
operation = args["operation"]
numbers = args["numbers"]
try:
result = None
if operation == "add":
result = sum(numbers)
elif operation == "subtract":
result = numbers[0] - sum(numbers[1:])
elif operation == "multiply":
result = 1
for n in numbers:
result *= n
elif operation == "divide":
result = numbers[0]
for n in numbers[1:]:
result /= n
# Return properly structured Content::Text variant
return [
{
"type": "text",
"text": str(result),
"annotations": None, # Required field in Rust struct
}
]
except Exception as e:
return [
{
"type": "text",
"text": f"Error: {str(e)}",
"annotations": None, # Required field in Rust struct
}
]
def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None:
"""Submit the tool execution result back to Goose.
The result should be a list of Content variants (Text, Image, or Resource).
Each Content variant has a type tag and appropriate fields.
"""
payload = {
"id": tool_id,
"result": {
"Ok": result # Result enum variant with single key for success case
},
}
with httpx.Client(timeout=2.0) as client:
response = client.post(
f"{GOOSE_URL}/tool_result",
json=payload,
headers={"X-Secret-Key": SECRET_KEY},
)
response.raise_for_status()
async def chat_loop() -> None:
"""Main chat loop that handles the conversation with Goose."""
session_id = "test-session"
# Use a client with a longer timeout for streaming
async with httpx.AsyncClient(timeout=30.0) as client:
# Get user input
user_message = input("\nYou: ")
if user_message.lower() in ["exit", "quit"]:
return
# Create the message object
message = {
"role": "user",
"created": int(datetime.now().timestamp()),
"content": [{"type": "text", "text": user_message}],
}
# Send to /reply endpoint
payload = {
"messages": [message],
"session_id": session_id,
"session_working_dir": os.getcwd(),
}
# Process the stream of responses
async with client.stream(
"POST",
f"{GOOSE_URL}/reply",
json=payload,
headers={
"X-Secret-Key": SECRET_KEY,
"Accept": "text/event-stream",
"Content-Type": "application/json",
},
) as stream:
async for line in stream.aiter_lines():
if not line:
continue
# Handle SSE format
if line.startswith("data: "):
line = line[6:] # Remove "data: " prefix
try:
payload = json.loads(line)
except json.JSONDecodeError:
print(f"Failed to parse line: {line}")
continue
if payload["type"] == "Finish":
break
message = payload["message"]
# Handle different message types
for content in message.get("content", []):
if content["type"] == "text":
print(f"\nGoose: {content['text']}")
elif content["type"] == "frontendToolRequest":
# Execute the tool and submit results
tool_call = content["toolCall"]["value"]
print(f"Calculator: {tool_call}")
# Execute the tool
result = execute_calculator(tool_call["arguments"])
# Submit the result
submit_tool_result(content["id"], result)
async def main():
try:
# Initialize the agent with our tool
await setup_agent()
# Start the chat loop
await chat_loop()
except Exception as e:
print(f"Error: {e}")
raise # Re-raise to see full traceback
if __name__ == "__main__":
asyncio.run(main())