feat: parallel processing in approve mode (#1575)

This commit is contained in:
Wendy Tang
2025-03-07 15:25:00 -08:00
committed by GitHub
parent 00fc3a5de8
commit 3faff03f7c

View File

@@ -28,7 +28,7 @@ use anyhow::{anyhow, Result};
use indoc::indoc;
use mcp_core::prompt::Prompt;
use mcp_core::protocol::GetPromptResult;
use mcp_core::{tool::Tool, Content};
use mcp_core::{tool::Tool, Content, ToolError};
use serde_json::{json, Value};
use std::time::Duration;
@@ -111,6 +111,15 @@ impl TruncateAgent {
&OldestFirstTruncation,
)
}
async fn create_tool_future(
capabilities: &Capabilities,
tool_call: mcp_core::tool::ToolCall,
request_id: String,
) -> (String, Result<Vec<Content>, ToolError>) {
let output = capabilities.dispatch_tool_call(tool_call).await;
(request_id, output)
}
}
#[async_trait]
@@ -270,6 +279,7 @@ impl Agent for TruncateAgent {
"approve" => {
let mut read_only_tools = Vec::new();
let mut needs_confirmation = Vec::<&ToolRequest>::new();
let mut approved_tools = Vec::new();
// First check permissions for all tools
let store = ToolPermissionStore::load()?;
@@ -277,11 +287,8 @@ impl Agent for TruncateAgent {
if let Ok(tool_call) = request.tool_call.clone() {
if let Some(allowed) = store.check_permission(request) {
if allowed {
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
// Instead of executing immediately, collect approved tools
approved_tools.push((request.id.clone(), tool_call));
} else {
needs_confirmation.push(request);
}
@@ -296,16 +303,22 @@ impl Agent for TruncateAgent {
read_only_tools = detect_read_only_tools(&capabilities, needs_confirmation.clone()).await;
}
// Process remaining tools that need confirmation
// Handle pre-approved and read-only tools in parallel
let mut tool_futures = Vec::new();
// Add pre-approved tools
for (request_id, tool_call) in approved_tools {
let tool_future = Self::create_tool_future(&capabilities, tool_call, request_id.clone());
tool_futures.push(tool_future);
}
// Process read-only tools
for request in &needs_confirmation {
if let Ok(tool_call) = request.tool_call.clone() {
// Skip confirmation if the tool_call.name is in the read_only_tools list
if read_only_tools.contains(&tool_call.name) {
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
tool_futures.push(tool_future);
} else {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
@@ -324,12 +337,9 @@ impl Agent for TruncateAgent {
store.record_permission(request, confirmed, Some(Duration::from_secs(30 * 24 * 60 * 60)))?;
if confirmed {
// User approved - dispatch the tool call
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
// Add this tool call to the futures collection
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
tool_futures.push(tool_future);
} else {
// User declined - add declined response
message_tool_response = message_tool_response.with_tool_response(
@@ -343,6 +353,14 @@ impl Agent for TruncateAgent {
}
}
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
},
"chat" => {
// Skip all tool calls in chat mode
@@ -370,10 +388,8 @@ impl Agent for TruncateAgent {
let mut tool_futures = Vec::new();
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
tool_futures.push(async {
let output = capabilities.dispatch_tool_call(tool_call).await;
(request.id.clone(), output)
});
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
tool_futures.push(tool_future);
}
}
// Wait for all tool calls to complete