diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 98a461af..4ab527bb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -98,7 +98,6 @@ jobs: source ../bin/activate-hermit && cargo test working-directory: crates - # Add disk space cleanup before linting - name: Check disk space before cleanup run: df -h diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 5ddccc3b..01417ac7 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -33,7 +33,7 @@ use crate::agents::prompt_manager::PromptManager; use crate::agents::router_tool_selector::{ create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector, }; -use crate::agents::router_tools::ROUTER_VECTOR_SEARCH_TOOL_NAME; +use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME}; use crate::agents::tool_router_index_manager::ToolRouterIndexManager; use crate::agents::tool_vectordb::generate_table_id; use crate::agents::types::SessionConfig; @@ -244,7 +244,9 @@ impl Agent { ToolCallResult::from(Err(ToolError::ExecutionError( "Frontend tool execution required".to_string(), ))) - } else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME { + } else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME + || tool_call.name == ROUTER_LLM_SEARCH_TOOL_NAME + { let selector = self.router_tool_selector.lock().await.clone(); let selected_tools = match selector.as_ref() { Some(selector) => match selector.select_tools(tool_call.arguments.clone()).await { @@ -351,10 +353,11 @@ impl Agent { // Update vector index if operation was successful and vector routing is enabled if result.is_ok() { let selector = self.router_tool_selector.lock().await.clone(); - if ToolRouterIndexManager::vector_tool_router_enabled(&selector) { + if ToolRouterIndexManager::is_tool_router_enabled(&selector) { if let Some(selector) = selector { let vector_action = if action == "disable" { "remove" } else { "add" }; let extension_manager = self.extension_manager.lock().await; + let selector = Arc::new(selector); if let Err(e) = ToolRouterIndexManager::update_extension_tools( &selector, &extension_manager, @@ -414,9 +417,10 @@ impl Agent { // If vector tool selection is enabled, index the tools let selector = self.router_tool_selector.lock().await.clone(); - if ToolRouterIndexManager::vector_tool_router_enabled(&selector) { + if ToolRouterIndexManager::is_tool_router_enabled(&selector) { if let Some(selector) = selector { let extension_manager = self.extension_manager.lock().await; + let selector = Arc::new(selector); if let Err(e) = ToolRouterIndexManager::update_extension_tools( &selector, &extension_manager, @@ -468,6 +472,9 @@ impl Agent { Some(RouterToolSelectionStrategy::Vector) => { prefixed_tools.push(router_tools::vector_search_tool()); } + Some(RouterToolSelectionStrategy::Llm) => { + prefixed_tools.push(router_tools::llm_search_tool()); + } None => {} } @@ -500,7 +507,7 @@ impl Agent { // If vector tool selection is enabled, remove tools from the index let selector = self.router_tool_selector.lock().await.clone(); - if ToolRouterIndexManager::vector_tool_router_enabled(&selector) { + if ToolRouterIndexManager::is_tool_router_enabled(&selector) { if let Some(selector) = selector { let extension_manager = self.extension_manager.lock().await; ToolRouterIndexManager::update_extension_tools( @@ -793,22 +800,29 @@ impl Agent { let strategy = match router_tool_selection_strategy.to_lowercase().as_str() { "vector" => Some(RouterToolSelectionStrategy::Vector), + "llm" => Some(RouterToolSelectionStrategy::Llm), _ => None, }; - if let Some(strategy) = strategy { - let table_name = generate_table_id(); - let selector = create_tool_selector(Some(strategy), provider, table_name) - .await - .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; - - let selector = Arc::new(selector); - *self.router_tool_selector.lock().await = Some(selector.clone()); - - let extension_manager = self.extension_manager.lock().await; - ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?; - } - + let selector = match strategy { + Some(RouterToolSelectionStrategy::Vector) => { + let table_name = generate_table_id(); + let selector = create_tool_selector(strategy, provider, Some(table_name)) + .await + .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; + Arc::new(selector) + } + Some(RouterToolSelectionStrategy::Llm) => { + let selector = create_tool_selector(strategy, provider, None) + .await + .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; + Arc::new(selector) + } + None => return Ok(()), + }; + let extension_manager = self.extension_manager.lock().await; + ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?; + *self.router_tool_selector.lock().await = Some(selector.clone()); Ok(()) } diff --git a/crates/goose/src/agents/prompt_manager.rs b/crates/goose/src/agents/prompt_manager.rs index a5a4e41f..191d604d 100644 --- a/crates/goose/src/agents/prompt_manager.rs +++ b/crates/goose/src/agents/prompt_manager.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use crate::agents::extension::ExtensionInfo; use crate::agents::router_tool_selector::RouterToolSelectionStrategy; -use crate::agents::router_tools::vector_search_tool_prompt; +use crate::agents::router_tools::{llm_search_tool_prompt, vector_search_tool_prompt}; use crate::providers::base::get_current_model; use crate::{config::Config, prompt_template}; @@ -92,6 +92,12 @@ impl PromptManager { Value::String(vector_search_tool_prompt()), ); } + Some(RouterToolSelectionStrategy::Llm) => { + context.insert( + "tool_selection_strategy", + Value::String(llm_search_tool_prompt()), + ); + } None => {} } diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 5b4b6d71..35aa3abe 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -29,6 +29,7 @@ impl Agent { let tool_selection_strategy = match router_tool_selection_strategy.to_lowercase().as_str() { "vector" => Some(RouterToolSelectionStrategy::Vector), + "llm" => Some(RouterToolSelectionStrategy::Llm), _ => None, }; @@ -38,6 +39,10 @@ impl Agent { self.list_tools_for_router(Some(RouterToolSelectionStrategy::Vector)) .await } + Some(RouterToolSelectionStrategy::Llm) => { + self.list_tools_for_router(Some(RouterToolSelectionStrategy::Llm)) + .await + } _ => self.list_tools(None).await, }; // Add frontend tools diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs index 5ca570f8..f87005c2 100644 --- a/crates/goose/src/agents/router_tool_selector.rs +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -5,18 +5,21 @@ use mcp_core::{Content, ToolError}; use anyhow::{Context, Result}; use async_trait::async_trait; use serde_json::Value; +use std::collections::HashMap; use std::collections::VecDeque; use std::env; use std::sync::Arc; use tokio::sync::RwLock; use crate::agents::tool_vectordb::ToolVectorDB; +use crate::message::Message; use crate::model::ModelConfig; use crate::providers::{self, base::Provider}; #[derive(Debug, Clone, PartialEq)] pub enum RouterToolSelectionStrategy { Vector, + Llm, } #[async_trait] @@ -200,19 +203,153 @@ impl RouterToolSelector for VectorToolSelector { } } +pub struct LLMToolSelector { + llm_provider: Arc, + tool_strings: Arc>>, // extension_name -> tool_string + recent_tool_calls: Arc>>, +} + +impl LLMToolSelector { + pub async fn new(provider: Arc) -> Result { + Ok(Self { + llm_provider: provider.clone(), + tool_strings: Arc::new(RwLock::new(HashMap::new())), + recent_tool_calls: Arc::new(RwLock::new(VecDeque::with_capacity(100))), + }) + } +} + +#[async_trait] +impl RouterToolSelector for LLMToolSelector { + async fn select_tools(&self, params: Value) -> Result, ToolError> { + let query = params + .get("query") + .and_then(|v| v.as_str()) + .ok_or_else(|| ToolError::InvalidParameters("Missing 'query' parameter".to_string()))?; + + let extension_name = params + .get("extension_name") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + // Get relevant tool strings based on extension_name + let tool_strings = self.tool_strings.read().await; + let relevant_tools = if let Some(ext) = &extension_name { + tool_strings.get(ext).cloned() + } else { + // If no extension specified, use all tools + Some( + tool_strings + .values() + .cloned() + .collect::>() + .join("\n"), + ) + }; + + if let Some(tools) = relevant_tools { + // Use LLM to search through tools + let prompt = format!( + "Given the following tools:\n{}\n\nFind the most relevant tools for the query: {}\n\nReturn the tools in this exact format for each tool:\nTool: \nDescription: \nSchema: ", + tools, query + ); + let system_message = Message::user().with_text("You are a tool selection assistant. Your task is to find the most relevant tools based on the user's query."); + let response = self + .llm_provider + .complete(&prompt, &[system_message], &[]) + .await + .map_err(|e| ToolError::ExecutionError(format!("Failed to search tools: {}", e)))?; + + // Extract just the message content from the response + let (message, _usage) = response; + let text = message.content[0].as_text().unwrap_or_default(); + + // Split the response into individual tool entries + let tool_entries: Vec = text + .split("\n\n") + .filter(|entry| entry.trim().starts_with("Tool:")) + .map(|entry| { + Content::Text(TextContent { + text: entry.trim().to_string(), + annotations: None, + }) + }) + .collect(); + + Ok(tool_entries) + } else { + Ok(vec![]) + } + } + + async fn index_tools(&self, tools: &[Tool]) -> Result<(), ToolError> { + let mut tool_strings = self.tool_strings.write().await; + + for tool in tools { + let tool_string = format!( + "Tool: {}\nDescription: {}\nSchema: {}", + tool.name, + tool.description, + serde_json::to_string_pretty(&tool.input_schema) + .unwrap_or_else(|_| "{}".to_string()) + ); + + if let Some(extension_name) = tool.name.split("__").next() { + let entry = tool_strings.entry(extension_name.to_string()).or_default(); + if !entry.is_empty() { + entry.push_str("\n\n"); + } + entry.push_str(&tool_string); + } + } + + Ok(()) + } + + async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError> { + let mut tool_strings = self.tool_strings.write().await; + if let Some(extension_name) = tool_name.split("__").next() { + tool_strings.remove(extension_name); + } + Ok(()) + } + + async fn record_tool_call(&self, tool_name: &str) -> Result<(), ToolError> { + let mut recent_calls = self.recent_tool_calls.write().await; + if recent_calls.len() >= 100 { + recent_calls.pop_front(); + } + recent_calls.push_back(tool_name.to_string()); + Ok(()) + } + + async fn get_recent_tool_calls(&self, limit: usize) -> Result, ToolError> { + let recent_calls = self.recent_tool_calls.read().await; + Ok(recent_calls.iter().rev().take(limit).cloned().collect()) + } + + fn selector_type(&self) -> RouterToolSelectionStrategy { + RouterToolSelectionStrategy::Llm + } +} + // Helper function to create a boxed tool selector pub async fn create_tool_selector( strategy: Option, provider: Arc, - table_name: String, + table_name: Option, ) -> Result> { match strategy { Some(RouterToolSelectionStrategy::Vector) => { - let selector = VectorToolSelector::new(provider, table_name).await?; + let selector = VectorToolSelector::new(provider, table_name.unwrap()).await?; + Ok(Box::new(selector)) + } + Some(RouterToolSelectionStrategy::Llm) => { + let selector = LLMToolSelector::new(provider).await?; Ok(Box::new(selector)) } None => { - let selector = VectorToolSelector::new(provider, table_name).await?; + let selector = LLMToolSelector::new(provider).await?; Ok(Box::new(selector)) } } diff --git a/crates/goose/src/agents/router_tools.rs b/crates/goose/src/agents/router_tools.rs index 8be66800..6660955f 100644 --- a/crates/goose/src/agents/router_tools.rs +++ b/crates/goose/src/agents/router_tools.rs @@ -1,8 +1,13 @@ +use super::platform_tools::{ + PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME, + PLATFORM_READ_RESOURCE_TOOL_NAME, PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME, +}; use indoc::indoc; use mcp_core::tool::{Tool, ToolAnnotations}; use serde_json::json; pub const ROUTER_VECTOR_SEARCH_TOOL_NAME: &str = "router__vector_search"; +pub const ROUTER_LLM_SEARCH_TOOL_NAME: &str = "router__llm_search"; pub fn vector_search_tool() -> Tool { Tool::new( @@ -47,3 +52,60 @@ pub fn vector_search_tool_prompt() -> String { By dynamically enabling tools, you (Goose) as the agent save context window space and allow the user to dynamically retrieve the most relevant tools. Be sure to format the query to search rather than pass in the user's messages directly."#.to_string() } + +pub fn llm_search_tool() -> Tool { + Tool::new( + ROUTER_LLM_SEARCH_TOOL_NAME.to_string(), + indoc! {r#" + Searches for relevant tools based on the user's messages. + Format a query to search for the most relevant tools based on the user's messages. + Pay attention to the keywords in the user's messages, especially the last message and potential tools they are asking for. + This tool should be invoked when the user's messages suggest they are asking for a tool to be run. + Use the extension_name parameter to filter tools by the appropriate extension. + For example, if the user is asking to list the files in the current directory, you filter for the "developer" extension. + Example: {"User": "list the files in the current directory", "Query": "list files in current directory", "Extension Name": "developer", "k": 5} + Extension name is not optional, it is required. + The returned result will be a list of tool names, descriptions, and schemas from which you, the agent can select the most relevant tool to invoke. + "#} + .to_string(), + json!({ + "type": "object", + "required": ["query", "extension_name"], + "properties": { + "extension_name": {"type": "string", "description": "The name of the extension to filter tools by"}, + "query": {"type": "string", "description": "The query to search for the most relevant tools based on the user's messages"}, + "k": {"type": "integer", "description": "The number of tools to retrieve (defaults to 5)", "default": 5} + } + }), + Some(ToolAnnotations { + title: Some("LLM search for relevant tools".to_string()), + read_only_hint: true, + destructive_hint: false, + idempotent_hint: false, + open_world_hint: false, + }), + ) +} + +pub fn llm_search_tool_prompt() -> String { + format!( + r#"# LLM Tool Selection Instructions + Important: the user has opted to dynamically enable tools, so although an extension could be enabled, \ + please invoke the llm search tool to actually retrieve the most relevant tools to use according to the user's messages. + For example, if the user has 3 extensions enabled, but they are asking for a tool to read a pdf file, \ + you would invoke the llm_search tool to find the most relevant read pdf tool. + By dynamically enabling tools, you (Goose) as the agent save context window space and allow the user to dynamically retrieve the most relevant tools. + Be sure to format a query packed with relevant keywords to search for the most relevant tools. + In addition to the extension names available to you, you also have platform extension tools available to you. + The platform extension contains the following tools: + - {} + - {} + - {} + - {} + "#, + PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME, + PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME, + PLATFORM_READ_RESOURCE_TOOL_NAME, + PLATFORM_LIST_RESOURCES_TOOL_NAME + ) +} diff --git a/crates/goose/src/agents/tool_router_index_manager.rs b/crates/goose/src/agents/tool_router_index_manager.rs index a0fb0d63..ec1a3ac9 100644 --- a/crates/goose/src/agents/tool_router_index_manager.rs +++ b/crates/goose/src/agents/tool_router_index_manager.rs @@ -102,12 +102,10 @@ impl ToolRouterIndexManager { Ok(()) } - /// Helper to check if vector tool router is enabled - pub fn vector_tool_router_enabled(selector: &Option>>) -> bool { - if let Some(selector) = selector { - selector.selector_type() == RouterToolSelectionStrategy::Vector - } else { - false - } + /// Helper to check if vector or llm tool router is enabled + pub fn is_tool_router_enabled(selector: &Option>>) -> bool { + selector.is_some() + && (selector.as_ref().unwrap().selector_type() == RouterToolSelectionStrategy::Vector + || selector.as_ref().unwrap().selector_type() == RouterToolSelectionStrategy::Llm) } } diff --git a/ui/desktop/src/components/settings/tool_selection_strategy/ToolSelectionStrategySection.tsx b/ui/desktop/src/components/settings/tool_selection_strategy/ToolSelectionStrategySection.tsx index b6cb85f4..c3535dbf 100644 --- a/ui/desktop/src/components/settings/tool_selection_strategy/ToolSelectionStrategySection.tsx +++ b/ui/desktop/src/components/settings/tool_selection_strategy/ToolSelectionStrategySection.tsx @@ -16,7 +16,13 @@ export const all_tool_selection_strategies = [ key: 'vector', label: 'Vector', description: - 'Filter tools based on vector-based similarity. Recommended when many extensions are enabled.', + 'Filter tools based on vector similarity.', + }, + { + key: 'llm', + label: 'LLM-based', + description: + 'Uses LLM to intelligently select the most relevant tools based on the user query context.', }, ]; @@ -58,8 +64,8 @@ export const ToolSelectionStrategySection = ({

- Configure how Goose selects tools for your requests. Available only with Claude models - served on Databricks. + Configure how Goose selects tools for your requests. Recommended when many extensions are enabled. + Available only with Claude models served on Databricks for now.

{all_tool_selection_strategies.map((strategy) => (