mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-02 21:24:34 +01:00
fix: update index when tool selection strategy changes (#2991)
This commit is contained in:
@@ -67,6 +67,11 @@ pub struct GetToolsQuery {
|
||||
extension_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ErrorResponse {
|
||||
error: String,
|
||||
}
|
||||
|
||||
async fn get_versions() -> Json<VersionsResponse> {
|
||||
let versions = ["goose".to_string()];
|
||||
let default_version = "goose".to_string();
|
||||
@@ -217,6 +222,46 @@ async fn update_agent_provider(
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/agent/update_router_tool_selector",
|
||||
responses(
|
||||
(status = 200, description = "Tool selection strategy updated successfully", body = String),
|
||||
(status = 500, description = "Internal server error")
|
||||
)
|
||||
)]
|
||||
async fn update_router_tool_selector(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<String>, Json<ErrorResponse>> {
|
||||
verify_secret_key(&headers, &state).map_err(|_| {
|
||||
Json(ErrorResponse {
|
||||
error: "Unauthorized - Invalid or missing API key".to_string(),
|
||||
})
|
||||
})?;
|
||||
|
||||
let agent = state.get_agent().await.map_err(|e| {
|
||||
tracing::error!("Failed to get agent: {}", e);
|
||||
Json(ErrorResponse {
|
||||
error: format!("Failed to get agent: {}", e),
|
||||
})
|
||||
})?;
|
||||
|
||||
agent
|
||||
.update_router_tool_selector(None, Some(true))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to update tool selection strategy: {}", e);
|
||||
Json(ErrorResponse {
|
||||
error: format!("Failed to update tool selection strategy: {}", e),
|
||||
})
|
||||
})?;
|
||||
|
||||
Ok(Json(
|
||||
"Tool selection strategy updated successfully".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn routes(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/agent/versions", get(get_versions))
|
||||
@@ -224,5 +269,9 @@ pub fn routes(state: Arc<AppState>) -> Router {
|
||||
.route("/agent/prompt", post(extend_prompt))
|
||||
.route("/agent/tools", get(get_tools))
|
||||
.route("/agent/update_provider", post(update_agent_provider))
|
||||
.route(
|
||||
"/agent/update_router_tool_selector",
|
||||
post(update_router_tool_selector),
|
||||
)
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
@@ -842,12 +842,23 @@ impl Agent {
|
||||
/// Update the provider used by this agent
|
||||
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
|
||||
*self.provider.lock().await = Some(provider.clone());
|
||||
self.update_router_tool_selector(provider).await?;
|
||||
self.update_router_tool_selector(Some(provider), None)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_router_tool_selector(&self, provider: Arc<dyn Provider>) -> Result<()> {
|
||||
pub async fn update_router_tool_selector(
|
||||
&self,
|
||||
provider: Option<Arc<dyn Provider>>,
|
||||
reindex_all: Option<bool>,
|
||||
) -> Result<()> {
|
||||
let config = Config::global();
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
let provider = match provider {
|
||||
Some(p) => p,
|
||||
None => self.provider().await?,
|
||||
};
|
||||
|
||||
let router_tool_selection_strategy = config
|
||||
.get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
|
||||
.unwrap_or_else(|_| "default".to_string());
|
||||
@@ -861,21 +872,44 @@ impl Agent {
|
||||
let selector = match strategy {
|
||||
Some(RouterToolSelectionStrategy::Vector) => {
|
||||
let table_name = generate_table_id();
|
||||
let selector = create_tool_selector(strategy, provider, Some(table_name))
|
||||
let selector = create_tool_selector(strategy, provider.clone(), Some(table_name))
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
|
||||
Arc::new(selector)
|
||||
}
|
||||
Some(RouterToolSelectionStrategy::Llm) => {
|
||||
let selector = create_tool_selector(strategy, provider, None)
|
||||
let selector = create_tool_selector(strategy, provider.clone(), None)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
|
||||
Arc::new(selector)
|
||||
}
|
||||
None => return Ok(()),
|
||||
};
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
|
||||
// First index platform tools
|
||||
ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?;
|
||||
|
||||
if reindex_all.unwrap_or(false) {
|
||||
let enabled_extensions = extension_manager.list_extensions().await?;
|
||||
for extension_name in enabled_extensions {
|
||||
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
|
||||
&selector,
|
||||
&extension_manager,
|
||||
&extension_name,
|
||||
"add",
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
"Failed to index tools for extension {}: {}",
|
||||
extension_name,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the selector
|
||||
*self.router_tool_selector.lock().await = Some(selector.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -166,12 +166,36 @@ impl RouterToolSelector for VectorToolSelector {
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Index all tools at once
|
||||
// Get vector_db lock
|
||||
let vector_db = self.vector_db.read().await;
|
||||
vector_db
|
||||
.index_tools(tool_records)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(format!("Failed to index tools: {}", e)))?;
|
||||
|
||||
// Filter out tools that already exist in the database
|
||||
let mut new_tool_records = Vec::new();
|
||||
for record in tool_records {
|
||||
// Check if tool exists by searching for it
|
||||
let existing_tools = vector_db
|
||||
.search_tools(record.vector.clone(), 1, Some(&record.extension_name))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
ToolError::ExecutionError(format!("Failed to search for existing tools: {}", e))
|
||||
})?;
|
||||
|
||||
// Only add if no exact match found
|
||||
if !existing_tools
|
||||
.iter()
|
||||
.any(|t| t.tool_name == record.tool_name)
|
||||
{
|
||||
new_tool_records.push(record);
|
||||
}
|
||||
}
|
||||
|
||||
// Only index if there are new tools to add
|
||||
if !new_tool_records.is_empty() {
|
||||
vector_db
|
||||
.index_tools(new_tool_records)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(format!("Failed to index tools: {}", e)))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -282,7 +306,7 @@ impl RouterToolSelector for LLMToolSelector {
|
||||
}
|
||||
}
|
||||
|
||||
async fn index_tools(&self, tools: &[Tool], _extension_name: &str) -> Result<(), ToolError> {
|
||||
async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError> {
|
||||
let mut tool_strings = self.tool_strings.write().await;
|
||||
|
||||
for tool in tools {
|
||||
@@ -294,8 +318,11 @@ impl RouterToolSelector for LLMToolSelector {
|
||||
.unwrap_or_else(|_| "{}".to_string())
|
||||
);
|
||||
|
||||
if let Some(extension_name) = tool.name.split("__").next() {
|
||||
let entry = tool_strings.entry(extension_name.to_string()).or_default();
|
||||
// Use the provided extension_name instead of parsing from tool name
|
||||
let entry = tool_strings.entry(extension_name.to_string()).or_default();
|
||||
|
||||
// Check if this tool already exists in the entry
|
||||
if !entry.contains(&format!("Tool: {}", tool.name)) {
|
||||
if !entry.is_empty() {
|
||||
entry.push_str("\n\n");
|
||||
}
|
||||
@@ -305,7 +332,6 @@ impl RouterToolSelector for LLMToolSelector {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError> {
|
||||
let mut tool_strings = self.tool_strings.write().await;
|
||||
if let Some(extension_name) = tool_name.split("__").next() {
|
||||
|
||||
@@ -44,13 +44,26 @@ pub fn vector_search_tool() -> Tool {
|
||||
}
|
||||
|
||||
pub fn vector_search_tool_prompt() -> String {
|
||||
r#"# Tool Selection Instructions
|
||||
format!(
|
||||
r#"# Tool Selection Instructions
|
||||
Important: the user has opted to dynamically enable tools, so although an extension could be enabled, \
|
||||
please invoke the vector search tool to actually retrieve the most relevant tools to use according to the user's messages.
|
||||
For example, if the user has 3 extensions enabled, but they are asking for a tool to read a pdf file, \
|
||||
you would invoke the vector_search tool to find the most relevant read pdf tool.
|
||||
By dynamically enabling tools, you (Goose) as the agent save context window space and allow the user to dynamically retrieve the most relevant tools.
|
||||
Be sure to format the query to search rather than pass in the user's messages directly."#.to_string()
|
||||
Be sure to format the query to search rather than pass in the user's messages directly.
|
||||
In addition to the extension names available to you, you also have platform extension tools available to you.
|
||||
The platform extension contains the following tools:
|
||||
- {}
|
||||
- {}
|
||||
- {}
|
||||
- {}
|
||||
"#,
|
||||
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
|
||||
PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME,
|
||||
PLATFORM_READ_RESOURCE_TOOL_NAME,
|
||||
PLATFORM_LIST_RESOURCES_TOOL_NAME
|
||||
)
|
||||
}
|
||||
|
||||
pub fn llm_search_tool() -> Tool {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useEffect, useState, useCallback } from 'react';
|
||||
import { View, ViewOptions } from '../../../App';
|
||||
import { useConfig } from '../../ConfigContext';
|
||||
import { getApiUrl, getSecretKey } from '../../../config';
|
||||
|
||||
interface ToolSelectionStrategySectionProps {
|
||||
setView: (view: View, viewOptions?: ViewOptions) => void;
|
||||
@@ -29,15 +30,65 @@ export const ToolSelectionStrategySection = ({
|
||||
setView: _setView,
|
||||
}: ToolSelectionStrategySectionProps) => {
|
||||
const [currentStrategy, setCurrentStrategy] = useState('default');
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const { read, upsert } = useConfig();
|
||||
|
||||
const handleStrategyChange = async (newStrategy: string) => {
|
||||
if (isLoading) return; // Prevent multiple simultaneous requests
|
||||
|
||||
setError(null); // Clear any previous errors
|
||||
setIsLoading(true);
|
||||
|
||||
try {
|
||||
await upsert('GOOSE_ROUTER_TOOL_SELECTION_STRATEGY', newStrategy, false);
|
||||
// First update the configuration
|
||||
try {
|
||||
await upsert('GOOSE_ROUTER_TOOL_SELECTION_STRATEGY', newStrategy, false);
|
||||
} catch (error) {
|
||||
console.error('Error updating configuration:', error);
|
||||
setError(`Failed to update configuration: ${error}`);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Then update the backend
|
||||
try {
|
||||
const response = await fetch(getApiUrl('/agent/update_router_tool_selector'), {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Secret-Key': getSecretKey(),
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response
|
||||
.json()
|
||||
.catch(() => ({ error: 'Unknown error from backend' }));
|
||||
throw new Error(errorData.error || 'Unknown error from backend');
|
||||
}
|
||||
|
||||
// Parse the success response
|
||||
const data = await response
|
||||
.json()
|
||||
.catch(() => ({ message: 'Tool selection strategy updated successfully' }));
|
||||
if (data.error) {
|
||||
throw new Error(data.error);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error updating backend:', error);
|
||||
setError(`Failed to update backend: ${error}`);
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// If both succeeded, update the UI
|
||||
setCurrentStrategy(newStrategy);
|
||||
} catch (error) {
|
||||
console.error('Error updating tool selection strategy:', error);
|
||||
throw new Error(`Failed to store new tool selection strategy: ${newStrategy}`);
|
||||
setError(`Failed to update tool selection strategy: ${error}`);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -49,6 +100,7 @@ export const ToolSelectionStrategySection = ({
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching current tool selection strategy:', error);
|
||||
setError(`Failed to fetch current strategy: ${error}`);
|
||||
}
|
||||
}, [read]);
|
||||
|
||||
@@ -66,12 +118,26 @@ export const ToolSelectionStrategySection = ({
|
||||
Configure how Goose selects tools for your requests. Recommended when many extensions are
|
||||
enabled. Available only with Claude models served on Databricks for now.
|
||||
</p>
|
||||
{error && (
|
||||
<div className="mb-4 p-3 bg-red-100 border border-red-400 text-red-700 rounded">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
{isLoading && (
|
||||
<div className="mb-4 p-3 bg-blue-100 border border-blue-400 text-blue-700 rounded flex items-center gap-2">
|
||||
<div className="animate-spin rounded-full h-4 w-4 border-b-2 border-blue-700"></div>
|
||||
Updating tool selection strategy...
|
||||
</div>
|
||||
)}
|
||||
<div>
|
||||
{all_tool_selection_strategies.map((strategy) => (
|
||||
<div className="group hover:cursor-pointer" key={strategy.key}>
|
||||
<div
|
||||
className={`group ${isLoading ? 'opacity-50' : 'hover:cursor-pointer'}`}
|
||||
key={strategy.key}
|
||||
>
|
||||
<div
|
||||
className="flex items-center justify-between text-textStandard py-2 px-4 hover:bg-bgSubtle"
|
||||
onClick={() => handleStrategyChange(strategy.key)}
|
||||
className={`flex items-center justify-between text-textStandard py-2 px-4 ${!isLoading ? 'hover:bg-bgSubtle' : ''}`}
|
||||
onClick={() => !isLoading && handleStrategyChange(strategy.key)}
|
||||
>
|
||||
<div className="flex">
|
||||
<div>
|
||||
@@ -86,14 +152,15 @@ export const ToolSelectionStrategySection = ({
|
||||
name="tool-selection-strategy"
|
||||
value={strategy.key}
|
||||
checked={currentStrategy === strategy.key}
|
||||
onChange={() => handleStrategyChange(strategy.key)}
|
||||
onChange={() => !isLoading && handleStrategyChange(strategy.key)}
|
||||
disabled={isLoading}
|
||||
className="peer sr-only"
|
||||
/>
|
||||
<div
|
||||
className="h-4 w-4 rounded-full border border-borderStandard
|
||||
className={`h-4 w-4 rounded-full border border-borderStandard
|
||||
peer-checked:border-[6px] peer-checked:border-black dark:peer-checked:border-white
|
||||
peer-checked:bg-white dark:peer-checked:bg-black
|
||||
transition-all duration-200 ease-in-out group-hover:border-borderProminent"
|
||||
transition-all duration-200 ease-in-out ${!isLoading ? 'group-hover:border-borderProminent' : ''}`}
|
||||
></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user