diff --git a/crates/goose-server/src/routes/extension.rs b/crates/goose-server/src/routes/extension.rs index c362be69..6bff014f 100644 --- a/crates/goose-server/src/routes/extension.rs +++ b/crates/goose-server/src/routes/extension.rs @@ -8,6 +8,7 @@ use goose::{ }; use http::{HeaderMap, StatusCode}; use serde::{Deserialize, Serialize}; +use tracing; /// Enum representing the different types of extension configuration requests. #[derive(Deserialize)] @@ -48,6 +49,16 @@ enum ExtensionConfigRequest { display_name: Option, timeout: Option, }, + /// Frontend extension that provides tools to be executed by the frontend. + #[serde(rename = "frontend")] + Frontend { + /// The name to identify this extension + name: String, + /// The tools provided by this extension + tools: Vec, + /// Optional instructions for using the tools + instructions: Option, + }, } /// Response structure for adding an extension. @@ -64,8 +75,26 @@ struct ExtensionResponse { async fn add_extension( State(state): State, headers: HeaderMap, - Json(request): Json, + raw: axum::extract::Json, ) -> Result, StatusCode> { + // Log the raw request for debugging + tracing::info!( + "Received extension request: {}", + serde_json::to_string_pretty(&raw.0).unwrap() + ); + + // Try to parse into our enum + let request: ExtensionConfigRequest = match serde_json::from_value(raw.0.clone()) { + Ok(req) => req, + Err(e) => { + tracing::error!("Failed to parse extension request: {}", e); + tracing::error!( + "Raw request was: {}", + serde_json::to_string_pretty(&raw.0).unwrap() + ); + return Err(StatusCode::UNPROCESSABLE_ENTITY); + } + }; // Verify the presence and validity of the secret key. let secret_key = headers .get("X-Secret-Key") @@ -167,6 +196,15 @@ async fn add_extension( display_name, timeout, }, + ExtensionConfigRequest::Frontend { + name, + tools, + instructions, + } => ExtensionConfig::Frontend { + name, + tools, + instructions, + }, }; // Acquire a lock on the agent and attempt to add the extension. diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 9a2271d0..15d75eb5 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -13,9 +13,9 @@ use goose::{ agents::SessionConfig, message::{Message, MessageContent}, }; - -use mcp_core::role::Role; +use mcp_core::{role::Role, Content, ToolResult}; use serde::{Deserialize, Serialize}; +use serde_json::json; use serde_json::Value; use std::{ convert::Infallible, @@ -391,12 +391,59 @@ async fn confirm_handler( Ok(Json(Value::Object(serde_json::Map::new()))) } +#[derive(Debug, Deserialize)] +struct ToolResultRequest { + id: String, + result: ToolResult>, +} + +async fn submit_tool_result( + State(state): State, + headers: HeaderMap, + raw: axum::extract::Json, +) -> Result, StatusCode> { + // Log the raw request for debugging + tracing::info!( + "Received tool result request: {}", + serde_json::to_string_pretty(&raw.0).unwrap() + ); + + // Try to parse into our struct + let payload: ToolResultRequest = match serde_json::from_value(raw.0.clone()) { + Ok(req) => req, + Err(e) => { + tracing::error!("Failed to parse tool result request: {}", e); + tracing::error!( + "Raw request was: {}", + serde_json::to_string_pretty(&raw.0).unwrap() + ); + return Err(StatusCode::UNPROCESSABLE_ENTITY); + } + }; + + // Verify secret key + let secret_key = headers + .get("X-Secret-Key") + .and_then(|value| value.to_str().ok()) + .ok_or(StatusCode::UNAUTHORIZED)?; + + if secret_key != state.secret_key { + return Err(StatusCode::UNAUTHORIZED); + } + + let agent = state.agent.read().await; + let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; + agent.handle_tool_result(payload.id, payload.result).await; + Ok(Json(json!({"status": "ok"}))) +} + // Configure routes for this module pub fn routes(state: AppState) -> Router { Router::new() .route("/reply", post(handler)) .route("/ask", post(ask_handler)) .route("/confirm", post(confirm_handler)) + .route("/tool_result", post(submit_tool_result)) .with_state(state) } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 76ea04cb..185f6b87 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -12,8 +12,7 @@ use super::extension::{ExtensionConfig, ExtensionResult}; use crate::message::Message; use crate::providers::base::Provider; use crate::session; -use mcp_core::prompt::Prompt; -use mcp_core::protocol::GetPromptResult; +use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult}; /// Session configuration for an agent #[derive(Debug, Clone, Serialize, Deserialize)] @@ -68,4 +67,7 @@ pub trait Agent: Send + Sync { /// Get a reference to the provider used by this agent async fn provider(&self) -> Arc>; + + /// Handle a tool result from the frontend + async fn handle_tool_result(&self, id: String, result: ToolResult>); } diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index 40ac4c4f..73ee67bf 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -26,9 +26,17 @@ static DEFAULT_TIMESTAMP: LazyLock> = type McpClientBox = Arc>>; +/// A frontend tool that will be executed by the frontend rather than an extension +#[derive(Clone)] +pub struct FrontendTool { + pub name: String, + pub tool: Tool, +} + /// Manages MCP clients and their interactions pub struct Capabilities { clients: HashMap, + frontend_tools: HashMap, instructions: HashMap, resource_capable_extensions: HashSet, provider: Arc>, @@ -96,6 +104,7 @@ impl Capabilities { pub fn new(provider: Box) -> Self { Self { clients: HashMap::new(), + frontend_tools: HashMap::new(), instructions: HashMap::new(), resource_capable_extensions: HashSet::new(), provider: Arc::new(provider), @@ -111,96 +120,120 @@ impl Capabilities { /// Add a new MCP extension based on the provided client type // TODO IMPORTANT need to ensure this times out if the extension command is broken! pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> { - let mut client: Box = match &config { - ExtensionConfig::Sse { - uri, envs, timeout, .. - } => { - let transport = SseTransport::new(uri, envs.get_env()); - let handle = transport.start().await?; - let service = McpService::with_timeout( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ); - Box::new(McpClient::new(service)) - } - ExtensionConfig::Stdio { - cmd, - args, - envs, - timeout, - .. - } => { - let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env()); - let handle = transport.start().await?; - let service = McpService::with_timeout( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ); - Box::new(McpClient::new(service)) - } - #[allow(unused_variables)] - ExtensionConfig::Builtin { - name, - display_name, - timeout, - } => { - // For builtin extensions, we run the current executable with mcp and extension name - let cmd = std::env::current_exe() - .expect("should find the current executable") - .to_str() - .expect("should resolve executable to string path") - .to_string(); - let transport = StdioTransport::new( - &cmd, - vec!["mcp".to_string(), name.clone()], - HashMap::new(), - ); - let handle = transport.start().await?; - let service = McpService::with_timeout( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ); - Box::new(McpClient::new(service)) - } - }; - - // Initialize the client with default capabilities - let info = ClientInfo { - name: "goose".to_string(), - version: env!("CARGO_PKG_VERSION").to_string(), - }; - let capabilities = ClientCapabilities::default(); - - let init_result = client - .initialize(info, capabilities) - .await - .map_err(|e| ExtensionError::Initialization(config.clone(), e))?; - let sanitized_name = normalize(config.key().to_string()); - // Store instructions if provided - if let Some(instructions) = init_result.instructions { - self.instructions - .insert(sanitized_name.clone(), instructions); + match &config { + ExtensionConfig::Frontend { + name: _, + tools, + instructions, + } => { + // For frontend tools, just store them in the frontend_tools map + for tool in tools { + let frontend_tool = FrontendTool { + name: tool.name.clone(), + tool: tool.clone(), + }; + self.frontend_tools.insert(tool.name.clone(), frontend_tool); + } + // Store instructions if provided, using "frontend" as the key + if let Some(instructions) = instructions { + self.instructions + .insert("frontend".to_string(), instructions.clone()); + } + Ok(()) + } + _ => { + let mut client: Box = match &config { + ExtensionConfig::Sse { + uri, envs, timeout, .. + } => { + let transport = SseTransport::new(uri, envs.get_env()); + let handle = transport.start().await?; + let service = McpService::with_timeout( + handle, + Duration::from_secs( + timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), + ), + ); + Box::new(McpClient::new(service)) + } + ExtensionConfig::Stdio { + cmd, + args, + envs, + timeout, + .. + } => { + let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env()); + let handle = transport.start().await?; + let service = McpService::with_timeout( + handle, + Duration::from_secs( + timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), + ), + ); + Box::new(McpClient::new(service)) + } + ExtensionConfig::Builtin { + name, + display_name: _, + timeout, + } => { + // For builtin extensions, we run the current executable with mcp and extension name + let cmd = std::env::current_exe() + .expect("should find the current executable") + .to_str() + .expect("should resolve executable to string path") + .to_string(); + let transport = StdioTransport::new( + &cmd, + vec!["mcp".to_string(), name.clone()], + HashMap::new(), + ); + let handle = transport.start().await?; + let service = McpService::with_timeout( + handle, + Duration::from_secs( + timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), + ), + ); + Box::new(McpClient::new(service)) + } + _ => unreachable!(), + }; + + // Initialize the client with default capabilities + let info = ClientInfo { + name: "goose".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }; + let capabilities = ClientCapabilities::default(); + + let init_result = client + .initialize(info, capabilities) + .await + .map_err(|e| ExtensionError::Initialization(config.clone(), e))?; + + // Store instructions if provided + if let Some(instructions) = init_result.instructions { + self.instructions + .insert(sanitized_name.clone(), instructions); + } + + // if the server is capable if resources we track it + if init_result.capabilities.resources.is_some() { + self.resource_capable_extensions + .insert(sanitized_name.clone()); + } + + // Store the client using the provided name + self.clients + .insert(sanitized_name.clone(), Arc::new(Mutex::new(client))); + + Ok(()) + } } - - // if the server is capable if resources we track it - if init_result.capabilities.resources.is_some() { - self.resource_capable_extensions - .insert(sanitized_name.clone()); - } - - // Store the client using the provided name - self.clients - .insert(sanitized_name.clone(), Arc::new(Mutex::new(client))); - - Ok(()) } /// Add a system prompt extension @@ -235,6 +268,13 @@ impl Capabilities { /// Get all tools from all clients with proper prefixing pub async fn get_prefixed_tools(&mut self) -> ExtensionResult> { let mut tools = Vec::new(); + + // Add frontend tools directly - they don't need prefixing since they're already uniquely named + for frontend_tool in self.frontend_tools.values() { + tools.push(frontend_tool.tool.clone()); + } + + // Add tools from MCP extensions with prefixing for (name, client) in &self.clients { let client_guard = client.lock().await; let mut client_tools = client_guard.list_tools(None).await?; @@ -317,7 +357,7 @@ impl Capabilities { pub async fn get_system_prompt(&self) -> String { let mut context: HashMap<&str, Value> = HashMap::new(); - let extensions_info: Vec = self + let mut extensions_info: Vec = self .clients .keys() .map(|name| { @@ -326,6 +366,15 @@ impl Capabilities { ExtensionInfo::new(name, &instructions, has_resources) }) .collect(); + + // Add frontend tools as a special extension if any exist + if !self.frontend_tools.is_empty() { + let name = "frontend"; + let instructions = self.instructions.get(name).cloned().unwrap_or_else(|| + "The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string() + ); + extensions_info.push(ExtensionInfo::new(name, &instructions, false)); + } context.insert("extensions", serde_json::to_value(extensions_info).unwrap()); let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string(); @@ -364,6 +413,16 @@ impl Capabilities { } } + /// Check if a tool is a frontend tool + pub fn is_frontend_tool(&self, name: &str) -> bool { + self.frontend_tools.contains_key(name) + } + + /// Get a reference to a frontend tool + pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> { + self.frontend_tools.get(name) + } + /// Find and return a reference to the appropriate client for a tool call fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(&str, McpClientBox)> { self.clients @@ -543,6 +602,11 @@ impl Capabilities { self.read_resource(tool_call.arguments.clone()).await } else if tool_call.name == "platform__list_resources" { self.list_resources(tool_call.arguments.clone()).await + } else if self.is_frontend_tool(&tool_call.name) { + // For frontend tools, return an error indicating we need frontend execution + Err(ToolError::ExecutionError( + "Frontend tool execution required".to_string(), + )) } else { // Else, dispatch tool call based on the prefix naming convention let (client_name, client) = self diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index d44de5a3..056786e2 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -149,6 +149,16 @@ pub enum ExtensionConfig { display_name: Option, // needed for the UI timeout: Option, }, + /// Frontend-provided tools that will be called through the frontend + #[serde(rename = "frontend")] + Frontend { + /// The name used to identify this extension + name: String, + /// The tools provided by the frontend + tools: Vec, + /// Instructions for how to use these tools + instructions: Option, + }, } impl Default for ExtensionConfig { @@ -224,6 +234,7 @@ impl ExtensionConfig { Self::Sse { name, .. } => name, Self::Stdio { name, .. } => name, Self::Builtin { name, .. } => name, + Self::Frontend { name, .. } => name, } .to_string() } @@ -239,6 +250,9 @@ impl std::fmt::Display for ExtensionConfig { write!(f, "Stdio({}: {} {})", name, cmd, args.join(" ")) } ExtensionConfig::Builtin { name, .. } => write!(f, "Builtin({})", name), + ExtensionConfig::Frontend { name, tools, .. } => { + write!(f, "Frontend({}: {} tools)", name, tools.len()) + } } } } diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index c1af723a..89c66da5 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -7,6 +7,7 @@ mod permission_store; mod reference; mod summarize; mod truncate; +mod types; pub use agent::{Agent, SessionConfig}; pub use capabilities::Capabilities; diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index 9989a80c..71ca7bfd 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -4,12 +4,13 @@ use async_trait::async_trait; use futures::stream::BoxStream; use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::sync::{mpsc, Mutex}; use tracing::{debug, instrument}; use super::agent::SessionConfig; use super::capabilities::get_parameter_names; use super::extension::ToolInfo; +use super::types::ToolResultReceiver; use super::Agent; use crate::agents::capabilities::Capabilities; use crate::agents::extension::{ExtensionConfig, ExtensionResult}; @@ -19,23 +20,27 @@ use crate::token_counter::TokenCounter; use crate::{register_agent, session}; use anyhow::{anyhow, Result}; use indoc::indoc; -use mcp_core::prompt::Prompt; -use mcp_core::protocol::GetPromptResult; use mcp_core::tool::{Tool, ToolAnnotations}; +use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult}; use serde_json::{json, Value}; /// Reference implementation of an Agent pub struct ReferenceAgent { capabilities: Mutex, _token_counter: TokenCounter, + tool_result_tx: mpsc::Sender<(String, ToolResult>)>, + tool_result_rx: ToolResultReceiver, } impl ReferenceAgent { pub fn new(provider: Box) -> Self { let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name()); + let (tx, rx) = mpsc::channel(32); Self { capabilities: Mutex::new(Capabilities::new(provider)), _token_counter: token_counter, + tool_result_tx: tx, + tool_result_rx: Arc::new(Mutex::new(rx)), } } } @@ -193,23 +198,31 @@ impl Agent for ReferenceAgent { } // Then dispatch each in parallel - let futures: Vec<_> = tool_requests - .iter() - .filter_map(|request| request.tool_call.clone().ok()) - .map(|tool_call| capabilities.dispatch_tool_call(tool_call)) - .collect(); - - // Process all the futures in parallel but wait until all are finished - let outputs = futures::future::join_all(futures).await; - - // Create a message with the responses let mut message_tool_response = Message::user(); - // Now combine these into MessageContent::ToolResponse using the original ID - for (request, output) in tool_requests.iter().zip(outputs.into_iter()) { - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - output, - ); + for request in tool_requests { + if let Ok(tool_call) = &request.tool_call { + // Check if it's a frontend tool + if capabilities.is_frontend_tool(&tool_call.name) { + // Send frontend tool request and wait for response + yield Message::assistant().with_frontend_tool_request( + request.id.clone(), + request.tool_call.clone() + ); + + // Wait for the result using our channel + if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await { + message_tool_response = message_tool_response.with_tool_response(id, result); + } + continue; + } + + // Handle regular tool calls + let result = capabilities.dispatch_tool_call(tool_call.clone()).await; + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + result, + ); + } } yield message_tool_response.clone(); @@ -278,6 +291,12 @@ impl Agent for ReferenceAgent { let capabilities = self.capabilities.lock().await; capabilities.provider() } + + async fn handle_tool_result(&self, id: String, result: ToolResult>) { + if let Err(e) = self.tool_result_tx.send((id, result)).await { + tracing::error!("Failed to send tool result: {}", e); + } + } } register_agent!("reference", ReferenceAgent); diff --git a/crates/goose/src/agents/summarize.rs b/crates/goose/src/agents/summarize.rs index a0daf777..486e1944 100644 --- a/crates/goose/src/agents/summarize.rs +++ b/crates/goose/src/agents/summarize.rs @@ -28,9 +28,7 @@ use crate::token_counter::TokenCounter; use crate::truncate::{truncate_messages, OldestFirstTruncation}; use anyhow::{anyhow, Result}; use indoc::indoc; -use mcp_core::prompt::Prompt; -use mcp_core::protocol::GetPromptResult; -use mcp_core::{tool::Tool, Content}; +use mcp_core::{prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolResult}; use serde_json::{json, Value}; const MAX_TRUNCATION_ATTEMPTS: usize = 3; @@ -42,19 +40,22 @@ pub struct SummarizeAgent { token_counter: TokenCounter, confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed) confirmation_rx: Mutex>, + tool_result_tx: mpsc::Sender<(String, ToolResult>)>, } impl SummarizeAgent { pub fn new(provider: Box) -> Self { let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name()); - // Create channel with buffer size 32 (adjust if needed) - let (tx, rx) = mpsc::channel(32); + // Create channels with buffer size 32 (adjust if needed) + let (confirm_tx, confirm_rx) = mpsc::channel(32); + let (tool_tx, _tool_rx) = mpsc::channel(32); Self { capabilities: Mutex::new(Capabilities::new(provider)), token_counter, - confirmation_tx: tx, - confirmation_rx: Mutex::new(rx), + confirmation_tx: confirm_tx, + confirmation_rx: Mutex::new(confirm_rx), + tool_result_tx: tool_tx, } } @@ -493,6 +494,12 @@ impl Agent for SummarizeAgent { let capabilities = self.capabilities.lock().await; capabilities.provider() } + + async fn handle_tool_result(&self, id: String, result: ToolResult>) { + if let Err(e) = self.tool_result_tx.send((id, result)).await { + tracing::error!("Failed to send tool result: {}", e); + } + } } register_agent!("summarize", SummarizeAgent); diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index 3eb3de5f..7c0d4ec0 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -12,12 +12,13 @@ use tracing::{debug, error, instrument, warn}; use super::agent::SessionConfig; use super::detect_read_only_tools; use super::extension::ToolInfo; +use super::types::ToolResultReceiver; use super::Agent; use crate::agents::capabilities::{get_parameter_names, Capabilities}; use crate::agents::extension::{ExtensionConfig, ExtensionResult}; use crate::agents::ToolPermissionStore; use crate::config::Config; -use crate::message::{Message, ToolRequest}; +use crate::message::{Message, MessageContent, ToolRequest}; use crate::providers::base::Provider; use crate::providers::errors::ProviderError; use crate::providers::toolshim::{ @@ -29,9 +30,9 @@ use crate::token_counter::TokenCounter; use crate::truncate::{truncate_messages, OldestFirstTruncation}; use anyhow::{anyhow, Result}; use indoc::indoc; -use mcp_core::prompt::Prompt; -use mcp_core::protocol::GetPromptResult; -use mcp_core::{tool::Tool, Content, ToolError}; +use mcp_core::{ + prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult, +}; use serde_json::{json, Value}; use std::time::Duration; @@ -44,19 +45,24 @@ pub struct TruncateAgent { token_counter: TokenCounter, confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed) confirmation_rx: Mutex>, + tool_result_tx: mpsc::Sender<(String, ToolResult>)>, + tool_result_rx: ToolResultReceiver, } impl TruncateAgent { pub fn new(provider: Box) -> Self { let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name()); - // Create channel with buffer size 32 (adjust if needed) - let (tx, rx) = mpsc::channel(32); + // Create channels with buffer size 32 (adjust if needed) + let (confirm_tx, confirm_rx) = mpsc::channel(32); + let (tool_tx, tool_rx) = mpsc::channel(32); Self { capabilities: Mutex::new(Capabilities::new(provider)), token_counter, - confirmation_tx: tx, - confirmation_rx: Mutex::new(rx), + confirmation_tx: confirm_tx, + confirmation_rx: Mutex::new(confirm_rx), + tool_result_tx: tool_tx, + tool_result_rx: Arc::new(Mutex::new(tool_rx)), } } @@ -306,8 +312,21 @@ impl Agent for TruncateAgent { // Reset truncation attempt truncation_attempt = 0; - // Yield the assistant's response - yield response.clone(); + // Yield the assistant's response, but filter out frontend tool requests that we'll process separately + let filtered_response = Message { + role: response.role.clone(), + created: response.created, + content: response.content.iter().filter(|c| { + if let MessageContent::ToolRequest(req) = c { + // Only filter out frontend tool requests + if let Ok(tool_call) = &req.tool_call { + return !capabilities.is_frontend_tool(&tool_call.name); + } + } + true + }).cloned().collect(), + }; + yield filtered_response.clone(); tokio::task::yield_now().await; @@ -323,6 +342,29 @@ impl Agent for TruncateAgent { // Process tool requests depending on goose_mode let mut message_tool_response = Message::user(); + + // First handle any frontend tool requests + let mut remaining_requests = Vec::new(); + for request in &tool_requests { + if let Ok(tool_call) = request.tool_call.clone() { + if capabilities.is_frontend_tool(&tool_call.name) { + // Send frontend tool request and wait for response + yield Message::assistant().with_frontend_tool_request( + request.id.clone(), + Ok(tool_call.clone()) + ); + + if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await { + message_tool_response = message_tool_response.with_tool_response(id, result); + } + } else { + remaining_requests.push(request); + } + } else { + remaining_requests.push(request); + } + } + // Clone goose_mode once before the match to avoid move issues let mode = goose_mode.clone(); match mode.as_str() { @@ -334,7 +376,7 @@ impl Agent for TruncateAgent { // First check permissions for all tools let store = ToolPermissionStore::load()?; - for request in tool_requests.iter() { + for request in remaining_requests.iter() { if let Ok(tool_call) = request.tool_call.clone() { if tools_with_readonly_annotation.contains(&tool_call.name) { approved_tools.push((request.id.clone(), tool_call)); @@ -427,7 +469,7 @@ impl Agent for TruncateAgent { }, "chat" => { // Skip all tool calls in chat mode - for request in &tool_requests { + for request in &remaining_requests { message_tool_response = message_tool_response.with_tool_response( request.id.clone(), Ok(vec![Content::text( @@ -449,7 +491,7 @@ impl Agent for TruncateAgent { } // Process tool requests in parallel let mut tool_futures = Vec::new(); - for request in &tool_requests { + for request in &remaining_requests { if let Ok(tool_call) = request.tool_call.clone() { let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone()); tool_futures.push(tool_future); @@ -574,6 +616,12 @@ impl Agent for TruncateAgent { let capabilities = self.capabilities.lock().await; capabilities.provider() } + + async fn handle_tool_result(&self, id: String, result: ToolResult>) { + if let Err(e) = self.tool_result_tx.send((id, result)).await { + tracing::error!("Failed to send tool result: {}", e); + } + } } register_agent!("truncate", TruncateAgent); diff --git a/crates/goose/src/agents/types.rs b/crates/goose/src/agents/types.rs new file mode 100644 index 00000000..9ae5653d --- /dev/null +++ b/crates/goose/src/agents/types.rs @@ -0,0 +1,6 @@ +use mcp_core::{Content, ToolResult}; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; + +/// Type alias for the tool result channel receiver +pub type ToolResultReceiver = Arc>)>>>; diff --git a/crates/goose/src/message.rs b/crates/goose/src/message.rs index c370c174..ecdf4157 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/message.rs @@ -70,6 +70,14 @@ pub struct RedactedThinkingContent { pub data: String, } +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FrontendToolRequest { + pub id: String, + #[serde(with = "tool_result_serde")] + pub tool_call: ToolResult, +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// Content passed inside a message, which can be both simple content and tool content #[serde(tag = "type", rename_all = "camelCase")] @@ -79,6 +87,7 @@ pub enum MessageContent { ToolRequest(ToolRequest), ToolResponse(ToolResponse), ToolConfirmationRequest(ToolConfirmationRequest), + FrontendToolRequest(FrontendToolRequest), Thinking(ThinkingContent), RedactedThinking(RedactedThinkingContent), } @@ -137,6 +146,13 @@ impl MessageContent { pub fn redacted_thinking>(data: S) -> Self { MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() }) } + + pub fn frontend_tool_request>(id: S, tool_call: ToolResult) -> Self { + MessageContent::FrontendToolRequest(FrontendToolRequest { + id: id.into(), + tool_call, + }) + } pub fn as_tool_request(&self) -> Option<&ToolRequest> { if let MessageContent::ToolRequest(ref tool_request) = self { Some(tool_request) @@ -320,6 +336,14 @@ impl Message { )) } + pub fn with_frontend_tool_request>( + self, + id: S, + tool_call: ToolResult, + ) -> Self { + self.with_content(MessageContent::frontend_tool_request(id, tool_call)) + } + /// Add thinking content to the message pub fn with_thinking, S2: Into>( self, diff --git a/crates/goose/src/providers/formats/anthropic.rs b/crates/goose/src/providers/formats/anthropic.rs index b8219f09..b0f05fae 100644 --- a/crates/goose/src/providers/formats/anthropic.rs +++ b/crates/goose/src/providers/formats/anthropic.rs @@ -74,6 +74,16 @@ pub fn format_messages(messages: &[Message]) -> Vec { })); } MessageContent::Image(_) => continue, // Anthropic doesn't support image content yet + MessageContent::FrontendToolRequest(tool_request) => { + if let Ok(tool_call) = &tool_request.tool_call { + content.push(json!({ + "type": "tool_use", + "id": tool_request.id, + "name": tool_call.name, + "input": tool_call.arguments + })); + } + } } } diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs index 70fb7591..da2bb332 100644 --- a/crates/goose/src/providers/formats/bedrock.rs +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -57,6 +57,21 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result { + let tool_use_id = tool_req.id.to_string(); + let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() { + bedrock::ToolUseBlock::builder() + .tool_use_id(tool_use_id) + .name(call.name.to_string()) + .input(to_bedrock_json(&call.arguments)) + .build() + } else { + bedrock::ToolUseBlock::builder() + .tool_use_id(tool_use_id) + .build() + }?; + bedrock::ContentBlock::ToolUse(tool_use) + } MessageContent::ToolResponse(tool_res) => { let content = match &tool_res.tool_result { Ok(content) => Some( diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index fff5f1ad..fce49f93 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -188,6 +188,27 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< } })); } + MessageContent::FrontendToolRequest(req) => { + // Frontend tool requests are converted to text messages + if let Ok(tool_call) = &req.tool_call { + content_array.push(json!({ + "type": "text", + "text": format!( + "Frontend tool request: {} ({})", + tool_call.name, + serde_json::to_string_pretty(&tool_call.arguments).unwrap() + ) + })); + } else { + content_array.push(json!({ + "type": "text", + "text": format!( + "Frontend tool request error: {}", + req.tool_call.as_ref().unwrap_err() + ) + })); + } + } } } diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index c3368f2b..63bb6e2e 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -151,6 +151,32 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< // Handle direct image content converted["content"] = json!([convert_image(image, image_format)]); } + MessageContent::FrontendToolRequest(request) => match &request.tool_call { + Ok(tool_call) => { + let sanitized_name = sanitize_function_name(&tool_call.name); + let tool_calls = converted + .as_object_mut() + .unwrap() + .entry("tool_calls") + .or_insert(json!([])); + + tool_calls.as_array_mut().unwrap().push(json!({ + "id": request.id, + "type": "function", + "function": { + "name": sanitized_name, + "arguments": tool_call.arguments.to_string(), + } + })); + } + Err(e) => { + output.push(json!({ + "role": "tool", + "content": format!("Error: {}", e), + "tool_call_id": request.id + })); + } + }, } } diff --git a/examples/frontend_tools.py b/examples/frontend_tools.py new file mode 100644 index 00000000..0e032dc2 --- /dev/null +++ b/examples/frontend_tools.py @@ -0,0 +1,210 @@ +import asyncio +import json +import os +from typing import List, Dict, Any +import httpx +from datetime import datetime + +# Configuration +GOOSE_HOST = "127.0.0.1" +GOOSE_PORT = "3001" +GOOSE_URL = f"http://{GOOSE_HOST}:{GOOSE_PORT}" +SECRET_KEY = "test" # Default development secret key + +# A simple calculator tool definition +CALCULATOR_TOOL = { + "name": "calculator", + "description": "Perform basic arithmetic calculations", + "inputSchema": { + "type": "object", + "required": ["operation", "numbers"], + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The arithmetic operation to perform", + }, + "numbers": { + "type": "array", + "items": {"type": "number"}, + "description": "List of numbers to operate on in order", + }, + }, + }, +} + +# Frontend extension configuration +FRONTEND_CONFIG = { + "name": "pythonclient", + "type": "frontend", + "tools": [CALCULATOR_TOOL], + "instructions": "A calculator extension that can perform basic arithmetic operations.", +} + + +async def setup_agent() -> None: + """Initialize the agent with our frontend tool.""" + async with httpx.AsyncClient() as client: + # First create the agent + response = await client.post( + f"{GOOSE_URL}/agent", + json={"provider": "databricks", "model": "goose"}, + headers={"X-Secret-Key": SECRET_KEY}, + ) + response.raise_for_status() + print("Successfully created agent") + + # Then add our frontend extension + response = await client.post( + f"{GOOSE_URL}/extensions/add", + json=FRONTEND_CONFIG, + headers={"X-Secret-Key": SECRET_KEY}, + ) + response.raise_for_status() + print("Successfully added calculator extension") + + +def execute_calculator(args: Dict[str, Any]) -> List[Dict[str, Any]]: + """Execute the calculator tool with the given arguments.""" + operation = args["operation"] + numbers = args["numbers"] + + try: + result = None + if operation == "add": + result = sum(numbers) + elif operation == "subtract": + result = numbers[0] - sum(numbers[1:]) + elif operation == "multiply": + result = 1 + for n in numbers: + result *= n + elif operation == "divide": + result = numbers[0] + for n in numbers[1:]: + result /= n + + # Return properly structured Content::Text variant + return [ + { + "type": "text", + "text": str(result), + "annotations": None, # Required field in Rust struct + } + ] + except Exception as e: + return [ + { + "type": "text", + "text": f"Error: {str(e)}", + "annotations": None, # Required field in Rust struct + } + ] + + +def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None: + """Submit the tool execution result back to Goose. + + The result should be a list of Content variants (Text, Image, or Resource). + Each Content variant has a type tag and appropriate fields. + """ + payload = { + "id": tool_id, + "result": { + "Ok": result # Result enum variant with single key for success case + }, + } + + with httpx.Client(timeout=2.0) as client: + response = client.post( + f"{GOOSE_URL}/tool_result", + json=payload, + headers={"X-Secret-Key": SECRET_KEY}, + ) + response.raise_for_status() + + +async def chat_loop() -> None: + """Main chat loop that handles the conversation with Goose.""" + session_id = "test-session" + + # Use a client with a longer timeout for streaming + async with httpx.AsyncClient(timeout=30.0) as client: + # Get user input + user_message = input("\nYou: ") + if user_message.lower() in ["exit", "quit"]: + return + + # Create the message object + message = { + "role": "user", + "created": int(datetime.now().timestamp()), + "content": [{"type": "text", "text": user_message}], + } + + # Send to /reply endpoint + payload = { + "messages": [message], + "session_id": session_id, + "session_working_dir": os.getcwd(), + } + + # Process the stream of responses + async with client.stream( + "POST", + f"{GOOSE_URL}/reply", + json=payload, + headers={ + "X-Secret-Key": SECRET_KEY, + "Accept": "text/event-stream", + "Content-Type": "application/json", + }, + ) as stream: + async for line in stream.aiter_lines(): + if not line: + continue + + # Handle SSE format + if line.startswith("data: "): + line = line[6:] # Remove "data: " prefix + + try: + payload = json.loads(line) + except json.JSONDecodeError: + print(f"Failed to parse line: {line}") + continue + + if payload["type"] == "Finish": + break + + message = payload["message"] + # Handle different message types + for content in message.get("content", []): + if content["type"] == "text": + print(f"\nGoose: {content['text']}") + elif content["type"] == "frontendToolRequest": + # Execute the tool and submit results + tool_call = content["toolCall"]["value"] + print(f"Calculator: {tool_call}") + # Execute the tool + result = execute_calculator(tool_call["arguments"]) + + # Submit the result + submit_tool_result(content["id"], result) + + +async def main(): + try: + # Initialize the agent with our tool + await setup_agent() + + # Start the chat loop + await chat_loop() + + except Exception as e: + print(f"Error: {e}") + raise # Re-raise to see full traceback + + +if __name__ == "__main__": + asyncio.run(main())