diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 2099cd77..5564470c 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -25,6 +25,10 @@ use crate::agents::platform_tools::{ PLATFORM_READ_RESOURCE_TOOL_NAME, PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME, }; use crate::agents::prompt_manager::PromptManager; +use crate::agents::router_tool_selector::{ + create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector, +}; +use crate::agents::router_tools::ROUTER_VECTOR_SEARCH_TOOL_NAME; use crate::agents::types::SessionConfig; use crate::agents::types::{FrontendTool, ToolResultReceiver}; use mcp_core::{ @@ -32,6 +36,7 @@ use mcp_core::{ }; use super::platform_tools; +use super::router_tools; use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; /// The main goose Agent @@ -46,6 +51,7 @@ pub struct Agent { pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult>)>, pub(super) tool_result_rx: ToolResultReceiver, pub(super) tool_monitor: Mutex>, + pub(super) router_tool_selector: Mutex>>, } impl Agent { @@ -54,6 +60,16 @@ impl Agent { let (confirm_tx, confirm_rx) = mpsc::channel(32); let (tool_tx, tool_rx) = mpsc::channel(32); + let router_tool_selection_strategy = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") + .ok() + .and_then(|s| { + if s.eq_ignore_ascii_case("vector") { + Some(RouterToolSelectionStrategy::Vector) + } else { + None + } + }); + Self { provider: Mutex::new(None), extension_manager: Mutex::new(ExtensionManager::new()), @@ -65,6 +81,9 @@ impl Agent { tool_result_tx: tool_tx, tool_result_rx: Arc::new(Mutex::new(tool_rx)), tool_monitor: Mutex::new(None), + router_tool_selector: Mutex::new(Some(create_tool_selector( + router_tool_selection_strategy, + ))), } } @@ -184,6 +203,15 @@ impl Agent { Err(ToolError::ExecutionError( "Frontend tool execution required".to_string(), )) + } else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME { + let router_tool_selector = self.router_tool_selector.lock().await; + if let Some(selector) = router_tool_selector.as_ref() { + selector.select_tools(tool_call.arguments.clone()).await + } else { + Err(ToolError::ExecutionError( + "Encountered vector search error.".to_string(), + )) + } } else { extension_manager .dispatch_tool_call(tool_call.clone()) @@ -318,6 +346,29 @@ impl Agent { prefixed_tools } + pub async fn list_tools_for_router( + &self, + strategy: Option, + ) -> Vec { + let extension_manager = self.extension_manager.lock().await; + + let mut prefixed_tools = vec![]; + match strategy { + Some(RouterToolSelectionStrategy::Vector) => { + prefixed_tools.push(router_tools::vector_search_tool()); + } + None => {} + } + prefixed_tools.push(platform_tools::search_available_extensions_tool()); + prefixed_tools.push(platform_tools::manage_extensions_tool()); + + if extension_manager.supports_resources() { + prefixed_tools.push(platform_tools::read_resource_tool()); + prefixed_tools.push(platform_tools::list_resources_tool()); + } + prefixed_tools + } + pub async fn remove_extension(&self, name: &str) { let mut extension_manager = self.extension_manager.lock().await; extension_manager @@ -474,7 +525,7 @@ impl Agent { &permission_check_result.needs_approval, tool_futures_arc.clone(), &mut permission_manager, - message_tool_response.clone(), + message_tool_response.clone() ); // We have a stream of tool_approval_requests to handle @@ -629,6 +680,7 @@ impl Agent { self.frontend_instructions.lock().await.clone(), extension_manager.suggest_disable_extensions_prompt().await, Some(model_name), + None, ); let recipe_prompt = prompt_manager.get_recipe_prompt().await; diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 3b39a6f8..fc44bcab 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -6,6 +6,8 @@ mod large_response_handler; pub mod platform_tools; pub mod prompt_manager; mod reply_parts; +mod router_tool_selector; +mod router_tools; mod tool_execution; mod types; diff --git a/crates/goose/src/agents/prompt_manager.rs b/crates/goose/src/agents/prompt_manager.rs index f464af7d..cdd367c9 100644 --- a/crates/goose/src/agents/prompt_manager.rs +++ b/crates/goose/src/agents/prompt_manager.rs @@ -3,6 +3,8 @@ use serde_json::Value; use std::collections::HashMap; use crate::agents::extension::ExtensionInfo; +use crate::agents::router_tool_selector::RouterToolSelectionStrategy; +use crate::agents::router_tools::vector_search_tool_prompt; use crate::providers::base::get_current_model; use crate::{config::Config, prompt_template}; @@ -67,6 +69,7 @@ impl PromptManager { frontend_instructions: Option, suggest_disable_extensions_prompt: Value, model_name: Option<&str>, + tool_selection_strategy: Option, ) -> String { let mut context: HashMap<&str, Value> = HashMap::new(); let mut extensions_info = extensions_info.clone(); @@ -82,6 +85,16 @@ impl PromptManager { context.insert("extensions", serde_json::to_value(extensions_info).unwrap()); + match tool_selection_strategy { + Some(RouterToolSelectionStrategy::Vector) => { + context.insert( + "tool_selection_strategy", + Value::String(vector_search_tool_prompt()), + ); + } + None => {} + } + context.insert( "current_date_time", Value::String(self.current_date_timestamp.clone()), diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 5d07635d..72515c4a 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -2,6 +2,7 @@ use anyhow::Result; use std::collections::HashSet; use std::sync::Arc; +use crate::agents::router_tool_selector::RouterToolSelectionStrategy; use crate::message::{Message, MessageContent, ToolRequest}; use crate::providers::base::{Provider, ProviderUsage}; use crate::providers::errors::ProviderError; @@ -18,9 +19,24 @@ impl Agent { pub(crate) async fn prepare_tools_and_prompt( &self, ) -> anyhow::Result<(Vec, Vec, String)> { + // Get tool selection strategy + let tool_selection_strategy = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") + .ok() + .and_then(|s| { + if s.eq_ignore_ascii_case("vector") { + Some(RouterToolSelectionStrategy::Vector) + } else { + None + } + }); // Get tools from extension manager - let mut tools = self.list_tools(None).await; - + let mut tools = match tool_selection_strategy { + Some(RouterToolSelectionStrategy::Vector) => { + self.list_tools_for_router(Some(RouterToolSelectionStrategy::Vector)) + .await + } + _ => self.list_tools(None).await, + }; // Add frontend tools let frontend_tools = self.frontend_tools.lock().await; for frontend_tool in frontend_tools.values() { @@ -42,6 +58,7 @@ impl Agent { self.frontend_instructions.lock().await.clone(), extension_manager.suggest_disable_extensions_prompt().await, Some(model_name), + tool_selection_strategy, ); // Handle toolshim if enabled diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs new file mode 100644 index 00000000..6427952c --- /dev/null +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -0,0 +1,36 @@ +use mcp_core::{Content, ToolError}; + +use async_trait::async_trait; +use serde_json::Value; + +pub enum RouterToolSelectionStrategy { + Vector, +} + +#[async_trait] +pub trait RouterToolSelector: Send + Sync { + async fn select_tools(&self, params: Value) -> Result, ToolError>; +} + +pub struct VectorToolSelector; + +#[async_trait] +impl RouterToolSelector for VectorToolSelector { + async fn select_tools(&self, params: Value) -> Result, ToolError> { + let query = params.get("query").and_then(|v| v.as_str()); + println!("query: {:?}", query); + let selected_tools = Vec::new(); + // TODO: placeholder for vector tool selection + Ok(selected_tools) + } +} + +// Helper function to create a boxed tool selector +pub fn create_tool_selector( + strategy: Option, +) -> Box { + match strategy { + Some(RouterToolSelectionStrategy::Vector) => Box::new(VectorToolSelector), + _ => Box::new(VectorToolSelector), // Default to VectorToolSelector + } +} diff --git a/crates/goose/src/agents/router_tools.rs b/crates/goose/src/agents/router_tools.rs new file mode 100644 index 00000000..fe7aee59 --- /dev/null +++ b/crates/goose/src/agents/router_tools.rs @@ -0,0 +1,46 @@ +use indoc::indoc; +use mcp_core::tool::{Tool, ToolAnnotations}; +use serde_json::json; + +pub const ROUTER_VECTOR_SEARCH_TOOL_NAME: &str = "router__vector_search"; + +pub fn vector_search_tool() -> Tool { + Tool::new( + ROUTER_VECTOR_SEARCH_TOOL_NAME.to_string(), + indoc! {r#" + Searches for relevant tools based on the user's messages. + Format a query to search for the most relevant tools based on the user's messages. + Pay attention to the keywords in the user's messages, especially the last message and potential tools they are asking for. + This tool should be invoked when the user's messages suggest they are asking for a tool to be run. + Examples: + - {"User": "what is the weather in Tokyo?", "Query": "weather in Tokyo"} + - {"User": "read this pdf file for me", "Query": "read pdf file"} + - {"User": "run this command ls -l in the terminal", "Query": "run command in terminal ls -l"} + "#} + .to_string(), + json!({ + "type": "object", + "required": ["query"], + "properties": { + "query": {"type": "string", "description": "The query to search for the most relevant tools based on the user's messages"} + } + }), + Some(ToolAnnotations { + title: Some("Vector search for relevant tools".to_string()), + read_only_hint: true, + destructive_hint: false, + idempotent_hint: false, + open_world_hint: false, + }), + ) +} + +pub fn vector_search_tool_prompt() -> String { + r#"# Tool Selection Instructions + Imporant: the user has opted to dynamically enable tools, so although an extension could be enabled, \ + please invoke the vector search tool to actually retrieve the most relevant tools to use according to the user's messages. + For example, if the user has 3 extensions enabled, but they are asking for a tool to read a pdf file, \ + you would invoke the vector_search tool to find the most relevant read pdf tool. + By dynamically enabling tools, you (Goose) as the agent save context window space and allow the user to dynamically retrieve the most relevant tools. + Be sure to format the query to search rather than pass in the user's messages directly."#.to_string() +} diff --git a/crates/goose/src/prompts/system.md b/crates/goose/src/prompts/system.md index 5cf2a41b..2681696a 100644 --- a/crates/goose/src/prompts/system.md +++ b/crates/goose/src/prompts/system.md @@ -36,6 +36,8 @@ No extensions are defined. You should let the user know that they should add ext {{suggest_disable}} {% endif %} +{{tool_selection_strategy}} + # Response Guidelines - Use Markdown formatting for all responses.