mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-19 05:24:28 +01:00
feat: (tool router) adds extension name in vector db & search tool (#2855)
This commit is contained in:
@@ -246,13 +246,29 @@ impl Agent {
|
||||
)))
|
||||
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
|
||||
let selector = self.router_tool_selector.lock().await.clone();
|
||||
ToolCallResult::from(if let Some(selector) = selector {
|
||||
selector.select_tools(tool_call.arguments.clone()).await
|
||||
} else {
|
||||
Err(ToolError::ExecutionError(
|
||||
"Encountered vector search error.".to_string(),
|
||||
))
|
||||
})
|
||||
let selected_tools = match selector.as_ref() {
|
||||
Some(selector) => match selector.select_tools(tool_call.arguments.clone()).await {
|
||||
Ok(tools) => tools,
|
||||
Err(e) => {
|
||||
return (
|
||||
request_id,
|
||||
Err(ToolError::ExecutionError(format!(
|
||||
"Failed to select tools: {}",
|
||||
e
|
||||
))),
|
||||
)
|
||||
}
|
||||
},
|
||||
None => {
|
||||
return (
|
||||
request_id,
|
||||
Err(ToolError::ExecutionError(
|
||||
"No tool selector available".to_string(),
|
||||
)),
|
||||
)
|
||||
}
|
||||
};
|
||||
ToolCallResult::from(Ok(selected_tools))
|
||||
} else {
|
||||
// Clone the result to ensure no references to extension_manager are returned
|
||||
let result = extension_manager
|
||||
|
||||
@@ -22,7 +22,7 @@ pub enum RouterToolSelectionStrategy {
|
||||
#[async_trait]
|
||||
pub trait RouterToolSelector: Send + Sync {
|
||||
async fn select_tools(&self, params: Value) -> Result<Vec<Content>, ToolError>;
|
||||
async fn index_tools(&self, tools: &[Tool]) -> Result<(), ToolError>;
|
||||
async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError>;
|
||||
async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError>;
|
||||
async fn record_tool_call(&self, tool_name: &str) -> Result<(), ToolError>;
|
||||
async fn get_recent_tool_calls(&self, limit: usize) -> Result<Vec<String>, ToolError>;
|
||||
@@ -76,6 +76,9 @@ impl RouterToolSelector for VectorToolSelector {
|
||||
|
||||
let k = params.get("k").and_then(|v| v.as_u64()).unwrap_or(5) as usize;
|
||||
|
||||
// Extract extension_name from params if present
|
||||
let extension_name = params.get("extension_name").and_then(|v| v.as_str());
|
||||
|
||||
// Check if provider supports embeddings
|
||||
if !self.embedding_provider.supports_embeddings() {
|
||||
return Err(ToolError::ExecutionError(
|
||||
@@ -98,7 +101,7 @@ impl RouterToolSelector for VectorToolSelector {
|
||||
|
||||
let vector_db = self.vector_db.read().await;
|
||||
let tools = vector_db
|
||||
.search_tools(query_embedding, k)
|
||||
.search_tools(query_embedding, k, extension_name)
|
||||
.await
|
||||
.map_err(|e| ToolError::ExecutionError(format!("Failed to search tools: {}", e)))?;
|
||||
|
||||
@@ -119,7 +122,7 @@ impl RouterToolSelector for VectorToolSelector {
|
||||
Ok(selected_tools)
|
||||
}
|
||||
|
||||
async fn index_tools(&self, tools: &[Tool]) -> Result<(), ToolError> {
|
||||
async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError> {
|
||||
let texts_to_embed: Vec<String> = tools
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
@@ -155,6 +158,7 @@ impl RouterToolSelector for VectorToolSelector {
|
||||
description: tool.description.clone(),
|
||||
schema: schema_str,
|
||||
vector,
|
||||
extension_name: extension_name.to_string(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -12,18 +12,20 @@ pub fn vector_search_tool() -> Tool {
|
||||
Format a query to search for the most relevant tools based on the user's messages.
|
||||
Pay attention to the keywords in the user's messages, especially the last message and potential tools they are asking for.
|
||||
This tool should be invoked when the user's messages suggest they are asking for a tool to be run.
|
||||
Examples:
|
||||
- {"User": "what is the weather in Tokyo?", "Query": "weather in Tokyo"}
|
||||
- {"User": "read this pdf file for me", "Query": "read pdf file"}
|
||||
- {"User": "run this command ls -l in the terminal", "Query": "run command in terminal ls -l"}
|
||||
You have the list of extension names available to you in your system prompt.
|
||||
Use the extension_name parameter to filter tools by the appropriate extension.
|
||||
For example, if the user is asking to list the files in the current directory, you filter for the "developer" extension.
|
||||
Example: {"User": "list the files in the current directory", "Query": "list files in current directory", "Extension Name": "developer", "k": 5}
|
||||
Extension name is not optional, it is required.
|
||||
"#}
|
||||
.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["query"],
|
||||
"required": ["query", "extension_name"],
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "The query to search for the most relevant tools based on the user's messages"},
|
||||
"k": {"type": "integer", "description": "The number of tools to retrieve (defaults to 5)", "default": 5}
|
||||
"k": {"type": "integer", "description": "The number of tools to retrieve (defaults to 5)", "default": 5},
|
||||
"extension_name": {"type": "string", "description": "Name of the extension to filter tools by"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
|
||||
@@ -26,13 +26,16 @@ impl ToolRouterIndexManager {
|
||||
|
||||
if !tools.is_empty() {
|
||||
// Index all tools at once
|
||||
selector.index_tools(&tools).await.map_err(|e| {
|
||||
anyhow!(
|
||||
"Failed to index tools for extension {}: {}",
|
||||
extension_name,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
selector
|
||||
.index_tools(&tools, extension_name)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
anyhow!(
|
||||
"Failed to index tools for extension {}: {}",
|
||||
extension_name,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(
|
||||
"Indexed {} tools for extension {}",
|
||||
@@ -42,16 +45,20 @@ impl ToolRouterIndexManager {
|
||||
}
|
||||
}
|
||||
"remove" => {
|
||||
// Get tool names for the extension to remove them
|
||||
// Remove all tools for this extension
|
||||
let tools = extension_manager
|
||||
.get_prefixed_tools(Some(extension_name.to_string()))
|
||||
.await?;
|
||||
|
||||
for tool in &tools {
|
||||
selector
|
||||
.remove_tool(&tool.name)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to remove tool {}: {}", tool.name, e))?;
|
||||
selector.remove_tool(&tool.name).await.map_err(|e| {
|
||||
anyhow!(
|
||||
"Failed to remove tool {} for extension {}: {}",
|
||||
tool.name,
|
||||
extension_name,
|
||||
e
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
@@ -61,7 +68,7 @@ impl ToolRouterIndexManager {
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!("Invalid action '{}' for tool indexing", action);
|
||||
return Err(anyhow!("Invalid action: {}", action));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +94,7 @@ impl ToolRouterIndexManager {
|
||||
|
||||
// Index all platform tools at once
|
||||
selector
|
||||
.index_tools(&tools)
|
||||
.index_tools(&tools, "platform")
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to index platform tools: {}", e))?;
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ pub struct ToolRecord {
|
||||
pub description: String,
|
||||
pub schema: String,
|
||||
pub vector: Vec<f32>,
|
||||
pub extension_name: String,
|
||||
}
|
||||
|
||||
pub struct ToolVectorDB {
|
||||
@@ -84,12 +85,14 @@ impl ToolVectorDB {
|
||||
),
|
||||
false,
|
||||
),
|
||||
Field::new("extension_name", DataType::Utf8, false),
|
||||
]));
|
||||
|
||||
// Create empty table
|
||||
let tool_names = StringArray::from(vec![] as Vec<&str>);
|
||||
let descriptions = StringArray::from(vec![] as Vec<&str>);
|
||||
let schemas = StringArray::from(vec![] as Vec<&str>);
|
||||
let extension_names = StringArray::from(vec![] as Vec<&str>);
|
||||
|
||||
// Create empty fixed size list array for vectors
|
||||
let mut vectors_builder =
|
||||
@@ -103,6 +106,7 @@ impl ToolVectorDB {
|
||||
Arc::new(descriptions),
|
||||
Arc::new(schemas),
|
||||
Arc::new(vectors),
|
||||
Arc::new(extension_names),
|
||||
],
|
||||
)
|
||||
.context("Failed to create record batch")?;
|
||||
@@ -163,6 +167,7 @@ impl ToolVectorDB {
|
||||
let tool_names: Vec<&str> = tools.iter().map(|t| t.tool_name.as_str()).collect();
|
||||
let descriptions: Vec<&str> = tools.iter().map(|t| t.description.as_str()).collect();
|
||||
let schemas: Vec<&str> = tools.iter().map(|t| t.schema.as_str()).collect();
|
||||
let extension_names: Vec<&str> = tools.iter().map(|t| t.extension_name.as_str()).collect();
|
||||
|
||||
let vectors_data: Vec<Option<Vec<Option<f32>>>> = tools
|
||||
.iter()
|
||||
@@ -181,11 +186,13 @@ impl ToolVectorDB {
|
||||
),
|
||||
false,
|
||||
),
|
||||
Field::new("extension_name", DataType::Utf8, false),
|
||||
]));
|
||||
|
||||
let tool_names_array = StringArray::from(tool_names);
|
||||
let descriptions_array = StringArray::from(descriptions);
|
||||
let schemas_array = StringArray::from(schemas);
|
||||
let extension_names_array = StringArray::from(extension_names);
|
||||
// Build vectors array
|
||||
let mut vectors_builder =
|
||||
FixedSizeListBuilder::new(arrow::array::Float32Builder::new(), 1536);
|
||||
@@ -213,6 +220,7 @@ impl ToolVectorDB {
|
||||
Arc::new(descriptions_array),
|
||||
Arc::new(schemas_array),
|
||||
Arc::new(vectors_array),
|
||||
Arc::new(extension_names_array),
|
||||
],
|
||||
)
|
||||
.context("Failed to create record batch")?;
|
||||
@@ -239,7 +247,12 @@ impl ToolVectorDB {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn search_tools(&self, query_vector: Vec<f32>, k: usize) -> Result<Vec<ToolRecord>> {
|
||||
pub async fn search_tools(
|
||||
&self,
|
||||
query_vector: Vec<f32>,
|
||||
k: usize,
|
||||
extension_name: Option<&str>,
|
||||
) -> Result<Vec<ToolRecord>> {
|
||||
let connection = self.connection.read().await;
|
||||
|
||||
let table = connection
|
||||
@@ -248,9 +261,11 @@ impl ToolVectorDB {
|
||||
.await
|
||||
.context("Failed to open tools table")?;
|
||||
|
||||
let results = table
|
||||
let search = table
|
||||
.vector_search(query_vector)
|
||||
.context("Failed to create vector search")?
|
||||
.context("Failed to create vector search")?;
|
||||
|
||||
let results = search
|
||||
.limit(k)
|
||||
.execute()
|
||||
.await
|
||||
@@ -281,6 +296,13 @@ impl ToolVectorDB {
|
||||
.downcast_ref::<StringArray>()
|
||||
.context("Invalid schema column type")?;
|
||||
|
||||
let extension_names = batch
|
||||
.column_by_name("extension_name")
|
||||
.context("Missing extension_name column")?
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.context("Invalid extension_name column type")?;
|
||||
|
||||
// Get the distance scores
|
||||
let distances = batch
|
||||
.column_by_name("_distance")
|
||||
@@ -292,12 +314,21 @@ impl ToolVectorDB {
|
||||
for i in 0..batch.num_rows() {
|
||||
let tool_name = tool_names.value(i).to_string();
|
||||
let _distance = distances.value(i);
|
||||
let ext_name = extension_names.value(i).to_string();
|
||||
|
||||
// Filter by extension name if provided
|
||||
if let Some(filter_ext) = extension_name {
|
||||
if ext_name != filter_ext {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
tools.push(ToolRecord {
|
||||
tool_name,
|
||||
description: descriptions.value(i).to_string(),
|
||||
schema: schemas.value(i).to_string(),
|
||||
vector: vec![], // We don't need to return the vector
|
||||
extension_name: ext_name,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -356,6 +387,7 @@ mod tests {
|
||||
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
|
||||
.to_string(),
|
||||
vector: vec![0.1; 1536], // Mock embedding vector
|
||||
extension_name: "test_extension".to_string(),
|
||||
},
|
||||
ToolRecord {
|
||||
tool_name: "test_tool_2".to_string(),
|
||||
@@ -363,6 +395,7 @@ mod tests {
|
||||
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
|
||||
.to_string(),
|
||||
vector: vec![0.2; 1536], // Different mock embedding vector
|
||||
extension_name: "test_extension".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
@@ -371,7 +404,7 @@ mod tests {
|
||||
|
||||
// Search for tools using a query vector similar to test_tool_1
|
||||
let query_vector = vec![0.1; 1536];
|
||||
let results = db.search_tools(query_vector, 2).await?;
|
||||
let results = db.search_tools(query_vector.clone(), 2, None).await?;
|
||||
|
||||
// Verify results
|
||||
assert_eq!(results.len(), 2, "Should find both tools");
|
||||
@@ -384,6 +417,25 @@ mod tests {
|
||||
"Second result should be test_tool_2"
|
||||
);
|
||||
|
||||
// Test filtering by extension name
|
||||
let results = db
|
||||
.search_tools(query_vector.clone(), 2, Some("test_extension"))
|
||||
.await?;
|
||||
assert_eq!(
|
||||
results.len(),
|
||||
2,
|
||||
"Should find both tools with test_extension"
|
||||
);
|
||||
|
||||
let results = db
|
||||
.search_tools(query_vector.clone(), 2, Some("nonexistent_extension"))
|
||||
.await?;
|
||||
assert_eq!(
|
||||
results.len(),
|
||||
0,
|
||||
"Should find no tools with nonexistent_extension"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -397,7 +449,7 @@ mod tests {
|
||||
|
||||
// Search in empty database
|
||||
let query_vector = vec![0.1; 1536];
|
||||
let results = db.search_tools(query_vector, 2).await?;
|
||||
let results = db.search_tools(query_vector, 2, None).await?;
|
||||
|
||||
// Verify no results returned
|
||||
assert_eq!(results.len(), 0, "Empty database should return no results");
|
||||
@@ -419,20 +471,21 @@ mod tests {
|
||||
description: "A test tool that will be deleted".to_string(),
|
||||
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#.to_string(),
|
||||
vector: vec![0.1; 1536],
|
||||
extension_name: "test_extension".to_string(),
|
||||
};
|
||||
|
||||
db.index_tools(vec![test_tool]).await?;
|
||||
|
||||
// Verify tool exists
|
||||
let query_vector = vec![0.1; 1536];
|
||||
let results = db.search_tools(query_vector.clone(), 1).await?;
|
||||
let results = db.search_tools(query_vector.clone(), 1, None).await?;
|
||||
assert_eq!(results.len(), 1, "Tool should exist before deletion");
|
||||
|
||||
// Delete the tool
|
||||
db.remove_tool("test_tool_to_delete").await?;
|
||||
|
||||
// Verify tool is gone
|
||||
let results = db.search_tools(query_vector.clone(), 1).await?;
|
||||
let results = db.search_tools(query_vector.clone(), 1, None).await?;
|
||||
assert_eq!(results.len(), 0, "Tool should be deleted");
|
||||
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user