feat: (tool router) adds extension name in vector db & search tool (#2855)

This commit is contained in:
Wendy Tang
2025-06-13 10:28:02 -07:00
committed by GitHub
parent 9368a242ce
commit de5fc9b450
5 changed files with 119 additions and 37 deletions

View File

@@ -246,13 +246,29 @@ impl Agent {
)))
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
let selector = self.router_tool_selector.lock().await.clone();
ToolCallResult::from(if let Some(selector) = selector {
selector.select_tools(tool_call.arguments.clone()).await
} else {
Err(ToolError::ExecutionError(
"Encountered vector search error.".to_string(),
))
})
let selected_tools = match selector.as_ref() {
Some(selector) => match selector.select_tools(tool_call.arguments.clone()).await {
Ok(tools) => tools,
Err(e) => {
return (
request_id,
Err(ToolError::ExecutionError(format!(
"Failed to select tools: {}",
e
))),
)
}
},
None => {
return (
request_id,
Err(ToolError::ExecutionError(
"No tool selector available".to_string(),
)),
)
}
};
ToolCallResult::from(Ok(selected_tools))
} else {
// Clone the result to ensure no references to extension_manager are returned
let result = extension_manager

View File

@@ -22,7 +22,7 @@ pub enum RouterToolSelectionStrategy {
#[async_trait]
pub trait RouterToolSelector: Send + Sync {
async fn select_tools(&self, params: Value) -> Result<Vec<Content>, ToolError>;
async fn index_tools(&self, tools: &[Tool]) -> Result<(), ToolError>;
async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError>;
async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError>;
async fn record_tool_call(&self, tool_name: &str) -> Result<(), ToolError>;
async fn get_recent_tool_calls(&self, limit: usize) -> Result<Vec<String>, ToolError>;
@@ -76,6 +76,9 @@ impl RouterToolSelector for VectorToolSelector {
let k = params.get("k").and_then(|v| v.as_u64()).unwrap_or(5) as usize;
// Extract extension_name from params if present
let extension_name = params.get("extension_name").and_then(|v| v.as_str());
// Check if provider supports embeddings
if !self.embedding_provider.supports_embeddings() {
return Err(ToolError::ExecutionError(
@@ -98,7 +101,7 @@ impl RouterToolSelector for VectorToolSelector {
let vector_db = self.vector_db.read().await;
let tools = vector_db
.search_tools(query_embedding, k)
.search_tools(query_embedding, k, extension_name)
.await
.map_err(|e| ToolError::ExecutionError(format!("Failed to search tools: {}", e)))?;
@@ -119,7 +122,7 @@ impl RouterToolSelector for VectorToolSelector {
Ok(selected_tools)
}
async fn index_tools(&self, tools: &[Tool]) -> Result<(), ToolError> {
async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError> {
let texts_to_embed: Vec<String> = tools
.iter()
.map(|tool| {
@@ -155,6 +158,7 @@ impl RouterToolSelector for VectorToolSelector {
description: tool.description.clone(),
schema: schema_str,
vector,
extension_name: extension_name.to_string(),
}
})
.collect();

View File

@@ -12,18 +12,20 @@ pub fn vector_search_tool() -> Tool {
Format a query to search for the most relevant tools based on the user's messages.
Pay attention to the keywords in the user's messages, especially the last message and potential tools they are asking for.
This tool should be invoked when the user's messages suggest they are asking for a tool to be run.
Examples:
- {"User": "what is the weather in Tokyo?", "Query": "weather in Tokyo"}
- {"User": "read this pdf file for me", "Query": "read pdf file"}
- {"User": "run this command ls -l in the terminal", "Query": "run command in terminal ls -l"}
You have the list of extension names available to you in your system prompt.
Use the extension_name parameter to filter tools by the appropriate extension.
For example, if the user is asking to list the files in the current directory, you filter for the "developer" extension.
Example: {"User": "list the files in the current directory", "Query": "list files in current directory", "Extension Name": "developer", "k": 5}
Extension name is not optional, it is required.
"#}
.to_string(),
json!({
"type": "object",
"required": ["query"],
"required": ["query", "extension_name"],
"properties": {
"query": {"type": "string", "description": "The query to search for the most relevant tools based on the user's messages"},
"k": {"type": "integer", "description": "The number of tools to retrieve (defaults to 5)", "default": 5}
"k": {"type": "integer", "description": "The number of tools to retrieve (defaults to 5)", "default": 5},
"extension_name": {"type": "string", "description": "Name of the extension to filter tools by"}
}
}),
Some(ToolAnnotations {

View File

@@ -26,13 +26,16 @@ impl ToolRouterIndexManager {
if !tools.is_empty() {
// Index all tools at once
selector.index_tools(&tools).await.map_err(|e| {
anyhow!(
"Failed to index tools for extension {}: {}",
extension_name,
e
)
})?;
selector
.index_tools(&tools, extension_name)
.await
.map_err(|e| {
anyhow!(
"Failed to index tools for extension {}: {}",
extension_name,
e
)
})?;
tracing::info!(
"Indexed {} tools for extension {}",
@@ -42,16 +45,20 @@ impl ToolRouterIndexManager {
}
}
"remove" => {
// Get tool names for the extension to remove them
// Remove all tools for this extension
let tools = extension_manager
.get_prefixed_tools(Some(extension_name.to_string()))
.await?;
for tool in &tools {
selector
.remove_tool(&tool.name)
.await
.map_err(|e| anyhow!("Failed to remove tool {}: {}", tool.name, e))?;
selector.remove_tool(&tool.name).await.map_err(|e| {
anyhow!(
"Failed to remove tool {} for extension {}: {}",
tool.name,
extension_name,
e
)
})?;
}
tracing::info!(
@@ -61,7 +68,7 @@ impl ToolRouterIndexManager {
);
}
_ => {
anyhow::bail!("Invalid action '{}' for tool indexing", action);
return Err(anyhow!("Invalid action: {}", action));
}
}
@@ -87,7 +94,7 @@ impl ToolRouterIndexManager {
// Index all platform tools at once
selector
.index_tools(&tools)
.index_tools(&tools, "platform")
.await
.map_err(|e| anyhow!("Failed to index platform tools: {}", e))?;

View File

@@ -18,6 +18,7 @@ pub struct ToolRecord {
pub description: String,
pub schema: String,
pub vector: Vec<f32>,
pub extension_name: String,
}
pub struct ToolVectorDB {
@@ -84,12 +85,14 @@ impl ToolVectorDB {
),
false,
),
Field::new("extension_name", DataType::Utf8, false),
]));
// Create empty table
let tool_names = StringArray::from(vec![] as Vec<&str>);
let descriptions = StringArray::from(vec![] as Vec<&str>);
let schemas = StringArray::from(vec![] as Vec<&str>);
let extension_names = StringArray::from(vec![] as Vec<&str>);
// Create empty fixed size list array for vectors
let mut vectors_builder =
@@ -103,6 +106,7 @@ impl ToolVectorDB {
Arc::new(descriptions),
Arc::new(schemas),
Arc::new(vectors),
Arc::new(extension_names),
],
)
.context("Failed to create record batch")?;
@@ -163,6 +167,7 @@ impl ToolVectorDB {
let tool_names: Vec<&str> = tools.iter().map(|t| t.tool_name.as_str()).collect();
let descriptions: Vec<&str> = tools.iter().map(|t| t.description.as_str()).collect();
let schemas: Vec<&str> = tools.iter().map(|t| t.schema.as_str()).collect();
let extension_names: Vec<&str> = tools.iter().map(|t| t.extension_name.as_str()).collect();
let vectors_data: Vec<Option<Vec<Option<f32>>>> = tools
.iter()
@@ -181,11 +186,13 @@ impl ToolVectorDB {
),
false,
),
Field::new("extension_name", DataType::Utf8, false),
]));
let tool_names_array = StringArray::from(tool_names);
let descriptions_array = StringArray::from(descriptions);
let schemas_array = StringArray::from(schemas);
let extension_names_array = StringArray::from(extension_names);
// Build vectors array
let mut vectors_builder =
FixedSizeListBuilder::new(arrow::array::Float32Builder::new(), 1536);
@@ -213,6 +220,7 @@ impl ToolVectorDB {
Arc::new(descriptions_array),
Arc::new(schemas_array),
Arc::new(vectors_array),
Arc::new(extension_names_array),
],
)
.context("Failed to create record batch")?;
@@ -239,7 +247,12 @@ impl ToolVectorDB {
Ok(())
}
pub async fn search_tools(&self, query_vector: Vec<f32>, k: usize) -> Result<Vec<ToolRecord>> {
pub async fn search_tools(
&self,
query_vector: Vec<f32>,
k: usize,
extension_name: Option<&str>,
) -> Result<Vec<ToolRecord>> {
let connection = self.connection.read().await;
let table = connection
@@ -248,9 +261,11 @@ impl ToolVectorDB {
.await
.context("Failed to open tools table")?;
let results = table
let search = table
.vector_search(query_vector)
.context("Failed to create vector search")?
.context("Failed to create vector search")?;
let results = search
.limit(k)
.execute()
.await
@@ -281,6 +296,13 @@ impl ToolVectorDB {
.downcast_ref::<StringArray>()
.context("Invalid schema column type")?;
let extension_names = batch
.column_by_name("extension_name")
.context("Missing extension_name column")?
.as_any()
.downcast_ref::<StringArray>()
.context("Invalid extension_name column type")?;
// Get the distance scores
let distances = batch
.column_by_name("_distance")
@@ -292,12 +314,21 @@ impl ToolVectorDB {
for i in 0..batch.num_rows() {
let tool_name = tool_names.value(i).to_string();
let _distance = distances.value(i);
let ext_name = extension_names.value(i).to_string();
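// Filter by extension name if provided.
// NOTE(review): this filter is applied to rows already truncated by `.limit(k)` in the
// query above, so a filtered search may return fewer than `k` tools even when more
// matching tools exist — confirm whether the filter should be pushed into the query.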
if let Some(filter_ext) = extension_name {
if ext_name != filter_ext {
continue;
}
}
tools.push(ToolRecord {
tool_name,
description: descriptions.value(i).to_string(),
schema: schemas.value(i).to_string(),
vector: vec![], // We don't need to return the vector
extension_name: ext_name,
});
}
}
@@ -356,6 +387,7 @@ mod tests {
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
.to_string(),
vector: vec![0.1; 1536], // Mock embedding vector
extension_name: "test_extension".to_string(),
},
ToolRecord {
tool_name: "test_tool_2".to_string(),
@@ -363,6 +395,7 @@ mod tests {
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
.to_string(),
vector: vec![0.2; 1536], // Different mock embedding vector
extension_name: "test_extension".to_string(),
},
];
@@ -371,7 +404,7 @@ mod tests {
// Search for tools using a query vector similar to test_tool_1
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector, 2).await?;
let results = db.search_tools(query_vector.clone(), 2, None).await?;
// Verify results
assert_eq!(results.len(), 2, "Should find both tools");
@@ -384,6 +417,25 @@ mod tests {
"Second result should be test_tool_2"
);
// Test filtering by extension name
let results = db
.search_tools(query_vector.clone(), 2, Some("test_extension"))
.await?;
assert_eq!(
results.len(),
2,
"Should find both tools with test_extension"
);
let results = db
.search_tools(query_vector.clone(), 2, Some("nonexistent_extension"))
.await?;
assert_eq!(
results.len(),
0,
"Should find no tools with nonexistent_extension"
);
Ok(())
}
@@ -397,7 +449,7 @@ mod tests {
// Search in empty database
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector, 2).await?;
let results = db.search_tools(query_vector, 2, None).await?;
// Verify no results returned
assert_eq!(results.len(), 0, "Empty database should return no results");
@@ -419,20 +471,21 @@ mod tests {
description: "A test tool that will be deleted".to_string(),
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#.to_string(),
vector: vec![0.1; 1536],
extension_name: "test_extension".to_string(),
};
db.index_tools(vec![test_tool]).await?;
// Verify tool exists
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector.clone(), 1).await?;
let results = db.search_tools(query_vector.clone(), 1, None).await?;
assert_eq!(results.len(), 1, "Tool should exist before deletion");
// Delete the tool
db.remove_tool("test_tool_to_delete").await?;
// Verify tool is gone
let results = db.search_tools(query_vector.clone(), 1).await?;
let results = db.search_tools(query_vector.clone(), 1, None).await?;
assert_eq!(results.len(), 0, "Tool should be deleted");
Ok(())