From 3faff03f7c767b1d100026ea45a73febbcf09bf1 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Fri, 7 Mar 2025 15:25:00 -0800 Subject: [PATCH] feat: parallel processing in approve mode (#1575) --- crates/goose/src/agents/truncate.rs | 60 ++++++++++++++++++----------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index 01f4cc0e..679a7c49 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -28,7 +28,7 @@ use anyhow::{anyhow, Result}; use indoc::indoc; use mcp_core::prompt::Prompt; use mcp_core::protocol::GetPromptResult; -use mcp_core::{tool::Tool, Content}; +use mcp_core::{tool::Tool, Content, ToolError}; use serde_json::{json, Value}; use std::time::Duration; @@ -111,6 +111,15 @@ impl TruncateAgent { &OldestFirstTruncation, ) } + + async fn create_tool_future( + capabilities: &Capabilities, + tool_call: mcp_core::tool::ToolCall, + request_id: String, + ) -> (String, Result, ToolError>) { + let output = capabilities.dispatch_tool_call(tool_call).await; + (request_id, output) + } } #[async_trait] @@ -270,6 +279,7 @@ impl Agent for TruncateAgent { "approve" => { let mut read_only_tools = Vec::new(); let mut needs_confirmation = Vec::<&ToolRequest>::new(); + let mut approved_tools = Vec::new(); // First check permissions for all tools let store = ToolPermissionStore::load()?; @@ -277,11 +287,8 @@ impl Agent for TruncateAgent { if let Ok(tool_call) = request.tool_call.clone() { if let Some(allowed) = store.check_permission(request) { if allowed { - let output = capabilities.dispatch_tool_call(tool_call).await; - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - output, - ); + // Instead of executing immediately, collect approved tools + approved_tools.push((request.id.clone(), tool_call)); } else { needs_confirmation.push(request); } @@ -296,16 +303,22 @@ impl Agent for TruncateAgent { read_only_tools = detect_read_only_tools(&capabilities, needs_confirmation.clone()).await; } - // Process remaining tools that need confirmation + // Handle pre-approved and read-only tools in parallel + let mut tool_futures = Vec::new(); + + // Add pre-approved tools + for (request_id, tool_call) in approved_tools { + let tool_future = Self::create_tool_future(&capabilities, tool_call, request_id.clone()); + tool_futures.push(tool_future); + } + + // Process read-only tools for request in &needs_confirmation { if let Ok(tool_call) = request.tool_call.clone() { // Skip confirmation if the tool_call.name is in the read_only_tools list if read_only_tools.contains(&tool_call.name) { - let output = capabilities.dispatch_tool_call(tool_call).await; - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - output, - ); + let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone()); + tool_futures.push(tool_future); } else { let confirmation = Message::user().with_tool_confirmation_request( request.id.clone(), @@ -324,12 +337,9 @@ impl Agent for TruncateAgent { store.record_permission(request, confirmed, Some(Duration::from_secs(30 * 24 * 60 * 60)))?; if confirmed { - // User approved - dispatch the tool call - let output = capabilities.dispatch_tool_call(tool_call).await; - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - output, - ); + // Add this tool call to the futures collection + let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone()); + tool_futures.push(tool_future); } else { // User declined - add declined response message_tool_response = message_tool_response.with_tool_response( @@ -343,6 +353,14 @@ impl Agent for TruncateAgent { } } } + // Wait for all tool calls to complete + let results = futures::future::join_all(tool_futures).await; + for (request_id, output) in results { + message_tool_response = message_tool_response.with_tool_response( + request_id, + output, + ); + } }, "chat" => { // Skip all tool calls in chat mode @@ -370,10 +388,8 @@ impl Agent for TruncateAgent { let mut tool_futures = Vec::new(); for request in &tool_requests { if let Ok(tool_call) = request.tool_call.clone() { - tool_futures.push(async { - let output = capabilities.dispatch_tool_call(tool_call).await; - (request.id.clone(), output) - }); + let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone()); + tool_futures.push(tool_future); } } // Wait for all tool calls to complete