fix: update index when tool selection strategy changes (#2991)

This commit is contained in:
Wendy Tang
2025-06-23 11:45:54 -07:00
committed by GitHub
parent ce7eabf20d
commit cebdbdb3d2
5 changed files with 213 additions and 24 deletions

View File

@@ -67,6 +67,11 @@ pub struct GetToolsQuery {
extension_name: Option<String>,
}
#[derive(Serialize)]
struct ErrorResponse {
error: String,
}
async fn get_versions() -> Json<VersionsResponse> {
let versions = ["goose".to_string()];
let default_version = "goose".to_string();
@@ -217,6 +222,46 @@ async fn update_agent_provider(
Ok(StatusCode::OK)
}
#[utoipa::path(
post,
path = "/agent/update_router_tool_selector",
responses(
(status = 200, description = "Tool selection strategy updated successfully", body = String),
(status = 500, description = "Internal server error")
)
)]
async fn update_router_tool_selector(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<String>, Json<ErrorResponse>> {
verify_secret_key(&headers, &state).map_err(|_| {
Json(ErrorResponse {
error: "Unauthorized - Invalid or missing API key".to_string(),
})
})?;
let agent = state.get_agent().await.map_err(|e| {
tracing::error!("Failed to get agent: {}", e);
Json(ErrorResponse {
error: format!("Failed to get agent: {}", e),
})
})?;
agent
.update_router_tool_selector(None, Some(true))
.await
.map_err(|e| {
tracing::error!("Failed to update tool selection strategy: {}", e);
Json(ErrorResponse {
error: format!("Failed to update tool selection strategy: {}", e),
})
})?;
Ok(Json(
"Tool selection strategy updated successfully".to_string(),
))
}
pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/agent/versions", get(get_versions))
@@ -224,5 +269,9 @@ pub fn routes(state: Arc<AppState>) -> Router {
.route("/agent/prompt", post(extend_prompt))
.route("/agent/tools", get(get_tools))
.route("/agent/update_provider", post(update_agent_provider))
.route(
"/agent/update_router_tool_selector",
post(update_router_tool_selector),
)
.with_state(state)
}

View File

@@ -842,12 +842,23 @@ impl Agent {
/// Update the provider used by this agent
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
*self.provider.lock().await = Some(provider.clone());
self.update_router_tool_selector(provider).await?;
self.update_router_tool_selector(Some(provider), None)
.await?;
Ok(())
}
async fn update_router_tool_selector(&self, provider: Arc<dyn Provider>) -> Result<()> {
pub async fn update_router_tool_selector(
&self,
provider: Option<Arc<dyn Provider>>,
reindex_all: Option<bool>,
) -> Result<()> {
let config = Config::global();
let extension_manager = self.extension_manager.lock().await;
let provider = match provider {
Some(p) => p,
None => self.provider().await?,
};
let router_tool_selection_strategy = config
.get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.unwrap_or_else(|_| "default".to_string());
@@ -861,21 +872,44 @@ impl Agent {
let selector = match strategy {
Some(RouterToolSelectionStrategy::Vector) => {
let table_name = generate_table_id();
let selector = create_tool_selector(strategy, provider, Some(table_name))
let selector = create_tool_selector(strategy, provider.clone(), Some(table_name))
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
Arc::new(selector)
}
Some(RouterToolSelectionStrategy::Llm) => {
let selector = create_tool_selector(strategy, provider, None)
let selector = create_tool_selector(strategy, provider.clone(), None)
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
Arc::new(selector)
}
None => return Ok(()),
};
let extension_manager = self.extension_manager.lock().await;
// First index platform tools
ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?;
if reindex_all.unwrap_or(false) {
let enabled_extensions = extension_manager.list_extensions().await?;
for extension_name in enabled_extensions {
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
&selector,
&extension_manager,
&extension_name,
"add",
)
.await
{
tracing::error!(
"Failed to index tools for extension {}: {}",
extension_name,
e
);
}
}
}
// Update the selector
*self.router_tool_selector.lock().await = Some(selector.clone());
Ok(())
}

View File

@@ -166,12 +166,36 @@ impl RouterToolSelector for VectorToolSelector {
})
.collect();
// Index all tools at once
// Get vector_db lock
let vector_db = self.vector_db.read().await;
vector_db
.index_tools(tool_records)
.await
.map_err(|e| ToolError::ExecutionError(format!("Failed to index tools: {}", e)))?;
// Filter out tools that already exist in the database
let mut new_tool_records = Vec::new();
for record in tool_records {
// Check if tool exists by searching for it
let existing_tools = vector_db
.search_tools(record.vector.clone(), 1, Some(&record.extension_name))
.await
.map_err(|e| {
ToolError::ExecutionError(format!("Failed to search for existing tools: {}", e))
})?;
// Only add if no exact match found
if !existing_tools
.iter()
.any(|t| t.tool_name == record.tool_name)
{
new_tool_records.push(record);
}
}
// Only index if there are new tools to add
if !new_tool_records.is_empty() {
vector_db
.index_tools(new_tool_records)
.await
.map_err(|e| ToolError::ExecutionError(format!("Failed to index tools: {}", e)))?;
}
Ok(())
}
@@ -282,7 +306,7 @@ impl RouterToolSelector for LLMToolSelector {
}
}
async fn index_tools(&self, tools: &[Tool], _extension_name: &str) -> Result<(), ToolError> {
async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError> {
let mut tool_strings = self.tool_strings.write().await;
for tool in tools {
@@ -294,8 +318,11 @@ impl RouterToolSelector for LLMToolSelector {
.unwrap_or_else(|_| "{}".to_string())
);
if let Some(extension_name) = tool.name.split("__").next() {
let entry = tool_strings.entry(extension_name.to_string()).or_default();
// Use the provided extension_name instead of parsing from tool name
let entry = tool_strings.entry(extension_name.to_string()).or_default();
// Check if this tool already exists in the entry
if !entry.contains(&format!("Tool: {}", tool.name)) {
if !entry.is_empty() {
entry.push_str("\n\n");
}
@@ -305,7 +332,6 @@ impl RouterToolSelector for LLMToolSelector {
Ok(())
}
async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError> {
let mut tool_strings = self.tool_strings.write().await;
if let Some(extension_name) = tool_name.split("__").next() {

View File

@@ -44,13 +44,26 @@ pub fn vector_search_tool() -> Tool {
}
pub fn vector_search_tool_prompt() -> String {
r#"# Tool Selection Instructions
format!(
r#"# Tool Selection Instructions
Important: the user has opted to dynamically enable tools, so although an extension could be enabled, \
please invoke the vector search tool to actually retrieve the most relevant tools to use according to the user's messages.
For example, if the user has 3 extensions enabled, but they are asking for a tool to read a pdf file, \
you would invoke the vector_search tool to find the most relevant read pdf tool.
By dynamically enabling tools, you (Goose) as the agent save context window space and allow the user to dynamically retrieve the most relevant tools.
Be sure to format the query to search rather than pass in the user's messages directly."#.to_string()
Be sure to format the query to search rather than pass in the user's messages directly.
In addition to the extension names available to you, you also have platform extension tools available to you.
The platform extension contains the following tools:
- {}
- {}
- {}
- {}
"#,
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME,
PLATFORM_READ_RESOURCE_TOOL_NAME,
PLATFORM_LIST_RESOURCES_TOOL_NAME
)
}
pub fn llm_search_tool() -> Tool {

View File

@@ -1,6 +1,7 @@
import { useEffect, useState, useCallback } from 'react';
import { View, ViewOptions } from '../../../App';
import { useConfig } from '../../ConfigContext';
import { getApiUrl, getSecretKey } from '../../../config';
interface ToolSelectionStrategySectionProps {
setView: (view: View, viewOptions?: ViewOptions) => void;
@@ -29,15 +30,65 @@ export const ToolSelectionStrategySection = ({
setView: _setView,
}: ToolSelectionStrategySectionProps) => {
const [currentStrategy, setCurrentStrategy] = useState('default');
const [error, setError] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(false);
const { read, upsert } = useConfig();
const handleStrategyChange = async (newStrategy: string) => {
if (isLoading) return; // Prevent multiple simultaneous requests
setError(null); // Clear any previous errors
setIsLoading(true);
try {
await upsert('GOOSE_ROUTER_TOOL_SELECTION_STRATEGY', newStrategy, false);
// First update the configuration
try {
await upsert('GOOSE_ROUTER_TOOL_SELECTION_STRATEGY', newStrategy, false);
} catch (error) {
console.error('Error updating configuration:', error);
setError(`Failed to update configuration: ${error}`);
setIsLoading(false);
return;
}
// Then update the backend
try {
const response = await fetch(getApiUrl('/agent/update_router_tool_selector'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
});
if (!response.ok) {
const errorData = await response
.json()
.catch(() => ({ error: 'Unknown error from backend' }));
throw new Error(errorData.error || 'Unknown error from backend');
}
// Parse the success response
const data = await response
.json()
.catch(() => ({ message: 'Tool selection strategy updated successfully' }));
if (data.error) {
throw new Error(data.error);
}
} catch (error) {
console.error('Error updating backend:', error);
setError(`Failed to update backend: ${error}`);
setIsLoading(false);
return;
}
// If both succeeded, update the UI
setCurrentStrategy(newStrategy);
} catch (error) {
console.error('Error updating tool selection strategy:', error);
throw new Error(`Failed to store new tool selection strategy: ${newStrategy}`);
setError(`Failed to update tool selection strategy: ${error}`);
} finally {
setIsLoading(false);
}
};
@@ -49,6 +100,7 @@ export const ToolSelectionStrategySection = ({
}
} catch (error) {
console.error('Error fetching current tool selection strategy:', error);
setError(`Failed to fetch current strategy: ${error}`);
}
}, [read]);
@@ -66,12 +118,26 @@ export const ToolSelectionStrategySection = ({
Configure how Goose selects tools for your requests. Recommended when many extensions are
enabled. Available only with Claude models served on Databricks for now.
</p>
{error && (
<div className="mb-4 p-3 bg-red-100 border border-red-400 text-red-700 rounded">
{error}
</div>
)}
{isLoading && (
<div className="mb-4 p-3 bg-blue-100 border border-blue-400 text-blue-700 rounded flex items-center gap-2">
<div className="animate-spin rounded-full h-4 w-4 border-b-2 border-blue-700"></div>
Updating tool selection strategy...
</div>
)}
<div>
{all_tool_selection_strategies.map((strategy) => (
<div className="group hover:cursor-pointer" key={strategy.key}>
<div
className={`group ${isLoading ? 'opacity-50' : 'hover:cursor-pointer'}`}
key={strategy.key}
>
<div
className="flex items-center justify-between text-textStandard py-2 px-4 hover:bg-bgSubtle"
onClick={() => handleStrategyChange(strategy.key)}
className={`flex items-center justify-between text-textStandard py-2 px-4 ${!isLoading ? 'hover:bg-bgSubtle' : ''}`}
onClick={() => !isLoading && handleStrategyChange(strategy.key)}
>
<div className="flex">
<div>
@@ -86,14 +152,15 @@ export const ToolSelectionStrategySection = ({
name="tool-selection-strategy"
value={strategy.key}
checked={currentStrategy === strategy.key}
onChange={() => handleStrategyChange(strategy.key)}
onChange={() => !isLoading && handleStrategyChange(strategy.key)}
disabled={isLoading}
className="peer sr-only"
/>
<div
className="h-4 w-4 rounded-full border border-borderStandard
className={`h-4 w-4 rounded-full border border-borderStandard
peer-checked:border-[6px] peer-checked:border-black dark:peer-checked:border-white
peer-checked:bg-white dark:peer-checked:bg-black
transition-all duration-200 ease-in-out group-hover:border-borderProminent"
transition-all duration-200 ease-in-out ${!isLoading ? 'group-hover:border-borderProminent' : ''}`}
></div>
</div>
</div>