diff --git a/Cargo.lock b/Cargo.lock index d5142e42..399a1c43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2323,7 +2323,7 @@ dependencies = [ [[package]] name = "goose" -version = "1.0.18" +version = "1.0.17" dependencies = [ "anyhow", "async-stream", @@ -2378,7 +2378,7 @@ dependencies = [ [[package]] name = "goose-bench" -version = "1.0.18" +version = "1.0.17" dependencies = [ "anyhow", "async-trait", @@ -2401,7 +2401,7 @@ dependencies = [ [[package]] name = "goose-cli" -version = "1.0.18" +version = "1.0.17" dependencies = [ "anyhow", "async-trait", @@ -2439,7 +2439,7 @@ dependencies = [ [[package]] name = "goose-mcp" -version = "1.0.18" +version = "1.0.17" dependencies = [ "anyhow", "async-trait", @@ -2485,7 +2485,7 @@ dependencies = [ [[package]] name = "goose-server" -version = "1.0.18" +version = "1.0.17" dependencies = [ "anyhow", "async-trait", diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index f14c4434..cfa61ccb 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -3,13 +3,12 @@ use std::sync::Arc; use anyhow::{anyhow, Result}; use futures::stream::BoxStream; +use futures::TryStreamExt; -use crate::config::permission::PermissionLevel; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; -use crate::message::{Message, MessageContent, ToolRequest}; -use crate::permission::permission_confirmation::PrincipalType; -use crate::permission::permission_judge::{check_tool_permissions, get_confirmation_message}; -use crate::permission::{Permission, PermissionConfirmation}; +use crate::message::Message; +use crate::permission::permission_judge::check_tool_permissions; +use crate::permission::PermissionConfirmation; use crate::providers::base::Provider; use crate::providers::errors::ProviderError; use crate::recipe::{Author, Recipe}; @@ -33,6 +32,10 @@ use mcp_core::{ prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult, }; +use 
super::tool_execution::{ + ExtensionInstallResult, ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE, +}; + const MAX_TRUNCATION_ATTEMPTS: usize = 3; const ESTIMATE_FACTOR_DECAY: f32 = 0.9; @@ -105,7 +108,7 @@ impl Agent { /// Dispatch a single tool call to the appropriate client #[instrument(skip(self, tool_call, request_id), fields(input, output))] - async fn dispatch_tool_call( + pub(super) async fn dispatch_tool_call( &self, tool_call: mcp_core::tool::ToolCall, request_id: String, @@ -190,7 +193,7 @@ impl Agent { ) } - async fn enable_extension( + pub(super) async fn enable_extension( &self, extension_name: String, request_id: String, @@ -351,57 +354,37 @@ impl Agent { // Reset truncation attempt truncation_attempt = 0; - // Yield the assistant's response, but filter out frontend tool requests that we'll process separately - let filtered_response = Message { - role: response.role.clone(), - created: response.created, - content: response.content.iter().filter(|c| { - if let MessageContent::ToolRequest(req) = c { - // Only filter out frontend tool requests - if let Ok(tool_call) = &req.tool_call { - return !self.is_frontend_tool(&tool_call.name); - } - } - true - }).cloned().collect(), - }; + // categorize the type of requests we need to handle + let (frontend_requests, + remaining_requests, + filtered_response) = + self.categorize_tool_requests(&response); + + + // Yield the assistant's response with frontend tool requests filtered out yield filtered_response.clone(); tokio::task::yield_now().await; - // First collect any tool requests - let tool_requests: Vec<&ToolRequest> = response.content - .iter() - .filter_map(|content| content.as_tool_request()) - .collect(); - - if tool_requests.is_empty() { + let num_tool_requests = frontend_requests.len() + remaining_requests.len(); + if num_tool_requests == 0 { break; } - // Process tool requests depending on goose_mode - let mut message_tool_response = Message::user(); + // Process tool requests 
depending on frontend tools and then goose_mode + let message_tool_response = Arc::new(Mutex::new(Message::user())); // First handle any frontend tool requests - let mut remaining_requests = Vec::new(); - for request in &tool_requests { - if let Ok(tool_call) = request.tool_call.clone() { - if self.is_frontend_tool(&tool_call.name) { - // Send frontend tool request and wait for response - yield Message::assistant().with_frontend_tool_request( - request.id.clone(), - Ok(tool_call.clone()) - ); + let mut frontend_tool_stream = self.handle_frontend_tool_requests( + &frontend_requests, + message_tool_response.clone() + ); - if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await { - message_tool_response = message_tool_response.with_tool_response(id, result); - } - } else { - remaining_requests.push(request); - } - } else { - remaining_requests.push(request); - } + // we have a stream of frontend tools to handle, inside the stream + // execution is yielded back to this reply loop, and is of the same Message + // type, so we can yield that back up to be handled + while let Some(msg) = frontend_tool_stream.try_next().await? { + yield msg; } // Clone goose_mode once before the match to avoid move issues @@ -409,112 +392,110 @@ impl Agent { if mode.as_str() == "chat" { // Skip all tool calls in chat mode for request in remaining_requests { - message_tool_response = message_tool_response.with_tool_response( + let mut response = message_tool_response.lock().await; + *response = response.clone().with_tool_response( request.id.clone(), - Ok(vec![Content::text( - "Let the user know the tool call was skipped in Goose chat mode. \ - DO NOT apologize for skipping the tool call. DO NOT say sorry. \ - Provide an explanation of what the tool call would do, structured as a \ - plan for the user. Again, DO NOT apologize. \ - **Example Plan:**\n \ - 1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \ - 2.
**Outline Steps** - Break down the steps.\n \ - If needed, adjust the explanation based on user preferences or questions." - )]), + Ok(vec![Content::text(CHAT_MODE_TOOL_SKIPPED_RESPONSE)]), ); } } else { - let mut permission_manager = PermissionManager::default(); - let permission_check_result = check_tool_permissions(remaining_requests.into_iter().copied().collect(), - &mode, - tools_with_readonly_annotation.clone(), - tools_without_annotation.clone(), - &mut permission_manager, - self.provider()).await; + // At this point, we have handled the frontend tool requests and know goose_mode != "chat" + // What remains is handling the remaining tool requests (enable extension, + // regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"] + let mut permission_manager = PermissionManager::default(); + let permission_check_result = check_tool_permissions(&remaining_requests, + &mode, + tools_with_readonly_annotation.clone(), + tools_without_annotation.clone(), + &mut permission_manager, + self.provider()).await; - // Handle pre-approved and read-only tools in parallel - let mut tool_futures = Vec::new(); - let mut install_results = Vec::new(); - let denied_content_text = Content::text( - "The user has declined to run this tool. \ - DO NOT attempt to call this tool again. 
\ - If there are no alternative methods to proceed, clearly explain the situation and STOP."); + // Handle pre-approved and read-only tools in parallel + let mut tool_futures: Vec = Vec::new(); + let mut install_results: Vec = Vec::new(); + let install_results_arc = Arc::new(Mutex::new(install_results)); - // Skip the confirmation for approved tools - for request in &permission_check_result.approved { - if let Ok(tool_call) = request.tool_call.clone() { - let tool_future = self.dispatch_tool_call(tool_call, request.id.clone()); - tool_futures.push(tool_future); - } - } - - for request in &permission_check_result.denied { - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - Ok(vec![denied_content_text.clone()]), - ); - } - - // Process tools requiring approval - for request in &permission_check_result.needs_approval { - if let Ok(tool_call) = request.tool_call.clone() { - let (principal_type, confirmation) = get_confirmation_message(&request.id.clone(), tool_call.clone()); - yield confirmation; - - // Wait for confirmation response through the channel - let mut rx = self.confirmation_rx.lock().await; - while let Some((req_id, confirmation)) = rx.recv().await { - if req_id == request.id { - if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow { - if principal_type == PrincipalType::Extension { - let extension_name = tool_call.arguments.get("extension_name") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let install_result = self.enable_extension(extension_name, request.id.clone()).await; - install_results.push(install_result); - } else { - // Add this tool call to the futures collection - let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone()); - tool_futures.push(tool_future); - if confirmation.permission == Permission::AlwaysAllow { - permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow); - } - } - } 
else { - // User declined - add declined response - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - Ok(vec![denied_content_text.clone()]), - ); - } - break; // Exit the loop once the matching `req_id` is found - } - } - } - } - - // Wait for all tool calls to complete - let results = futures::future::join_all(tool_futures).await; - - // Check if any install results had errors before processing them - let all_install_successful = !install_results.iter().any(|(_, result)| result.is_err()); - - for (request_id, output) in results.into_iter().chain(install_results.into_iter()) { - message_tool_response = message_tool_response.with_tool_response(request_id, output); - } - - // Update system prompt and tools if installations were successful - if all_install_successful { - (tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?; + // Skip the confirmation for approved tools + for request in &permission_check_result.approved { + if let Ok(tool_call) = request.tool_call.clone() { + let tool_future = self.dispatch_tool_call(tool_call, request.id.clone()); + tool_futures.push(Box::pin(tool_future)); } } + for request in &permission_check_result.denied { + let mut response = message_tool_response.lock().await; + *response = response.clone().with_tool_response( + request.id.clone(), + Ok(vec![Content::text(DECLINED_RESPONSE)]), + ); + } - yield message_tool_response.clone(); + // we need interior mutability in handle_approval_tool_requests + let tool_futures_arc = Arc::new(Mutex::new(tool_futures)); + // Process tools requiring approval (enable extension, regular tool calls) + let mut tool_approval_stream = self.handle_approval_tool_requests( + &permission_check_result.needs_approval, + install_results_arc.clone(), + tool_futures_arc.clone(), + &mut permission_manager, + message_tool_response.clone() + ); + + // we have a stream of tool_approval_requests to handle + // execution is yielded back to this reply loop,
and is of the same Message + // type, so we can yield the Message back up to be handled and grab any + // confirmations or denials + while let Some(msg) = tool_approval_stream.try_next().await? { + yield msg; + } + + tool_futures = { + // Lock the mutex asynchronously. + let mut futures_lock = tool_futures_arc.lock().await; + // Drain the vector and collect into a new Vec. + futures_lock.drain(..).collect::>() + }; + + install_results = { + // Lock the mutex asynchronously. + let mut results_lock = install_results_arc.lock().await; + // Drain the vector and collect into a new Vec. + results_lock.drain(..).collect::>() + }; + + + // Wait for all tool calls to complete + let results = futures::future::join_all(tool_futures).await; + for (request_id, output) in results { + let mut response = message_tool_response.lock().await; + *response = response.clone().with_tool_response( + request_id, + output, + ); + } + + // Check if any install results had errors before processing them + let all_install_successful = !install_results.iter().any(|(_, result)| result.is_err()); + for (request_id, output) in install_results { + let mut response = message_tool_response.lock().await; + *response = response.clone().with_tool_response( + request_id, + output + ); + } + + // Update system prompt and tools if installations were successful + if all_install_successful { + (tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?; + } + } + + let final_message_tool_resp = message_tool_response.lock().await.clone(); + yield final_message_tool_resp.clone(); messages.push(response); - messages.push(message_tool_response); + messages.push(final_message_tool_resp); }, Err(ProviderError::ContextLengthExceeded(_)) => { if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS { diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index d4c5ee6a..bd6de72e 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -4,6 +4,7 @@ pub mod
extension_manager; pub mod platform_tools; pub mod prompt_manager; mod reply_parts; +mod tool_execution; mod types; pub use agent::Agent; diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index b0ab2ea6..543f7394 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use anyhow::Result; use crate::agents::platform_tools; -use crate::message::Message; +use crate::message::{Message, MessageContent, ToolRequest}; use crate::providers::base::{Provider, ProviderUsage}; use crate::providers::errors::ProviderError; use crate::providers::toolshim::{ @@ -111,6 +111,70 @@ impl Agent { Ok((response, usage)) } + /// Categorize tool requests from the response into different types + /// Returns: + /// - frontend_requests: Tool requests that should be handled by the frontend + /// - other_requests: All other tool requests (including requests to enable extensions) + /// - filtered_message: The original message with frontend tool requests removed + pub(crate) fn categorize_tool_requests( + &self, + response: &Message, + ) -> (Vec, Vec, Message) { + // First collect all tool requests + let tool_requests: Vec = response + .content + .iter() + .filter_map(|content| { + if let MessageContent::ToolRequest(req) = content { + Some(req.clone()) + } else { + None + } + }) + .collect(); + + // Create a filtered message with frontend tool requests removed + let filtered_content = response + .content + .iter() + .filter(|c| { + if let MessageContent::ToolRequest(req) = c { + // Only filter out frontend tool requests + if let Ok(tool_call) = &req.tool_call { + return !self.is_frontend_tool(&tool_call.name); + } + } + true + }) + .cloned() + .collect(); + + let filtered_message = Message { + role: response.role.clone(), + created: response.created, + content: filtered_content, + }; + + // Categorize tool requests + let mut frontend_requests = Vec::new(); + let mut 
other_requests = Vec::new(); + + for request in tool_requests { + if let Ok(tool_call) = &request.tool_call { + if self.is_frontend_tool(&tool_call.name) { + frontend_requests.push(request); + } else { + other_requests.push(request); + } + } else { + // If there's an error in the tool call, add it to other_requests + other_requests.push(request); + } + } + + (frontend_requests, other_requests, filtered_message) + } + /// Update session metrics after a response pub(crate) async fn update_session_metrics( session_config: crate::agents::types::SessionConfig, diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs new file mode 100644 index 00000000..e76f54c7 --- /dev/null +++ b/crates/goose/src/agents/tool_execution.rs @@ -0,0 +1,120 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use async_stream::try_stream; +use futures::stream::BoxStream; +use futures::StreamExt; +use tokio::sync::Mutex; + +use crate::config::permission::PermissionLevel; +use crate::config::PermissionManager; +use crate::message::{Message, ToolRequest}; +use crate::permission::permission_confirmation::PrincipalType; +use crate::permission::permission_judge::get_confirmation_message; +use crate::permission::Permission; +use mcp_core::{Content, ToolError}; + +// Type alias for ToolFutures - used in the agent loop to join all futures together +pub(crate) type ToolFuture<'a> = + Pin, ToolError>)> + Send + 'a>>; +pub(crate) type ToolFuturesVec<'a> = Arc>>>; +// Type alias for extension installation results +pub(crate) type ExtensionInstallResult = (String, Result, ToolError>); +pub(crate) type ExtensionInstallResults = Arc>>; + +use crate::agents::Agent; + +pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \ + DO NOT attempt to call this tool again. 
\ + If there are no alternative methods to proceed, clearly explain the situation and STOP."; + +pub const CHAT_MODE_TOOL_SKIPPED_RESPONSE: &str = "Let the user know the tool call was skipped in Goose chat mode. \ + DO NOT apologize for skipping the tool call. DO NOT say sorry. \ + Provide an explanation of what the tool call would do, structured as a \ + plan for the user. Again, DO NOT apologize. \ + **Example Plan:**\n \ + 1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \ + 2. **Outline Steps** - Break down the steps.\n \ + If needed, adjust the explanation based on user preferences or questions."; + +impl Agent { + pub(crate) fn handle_approval_tool_requests<'a>( + &'a self, + tool_requests: &'a [ToolRequest], + install_results: ExtensionInstallResults, + tool_futures: ToolFuturesVec<'a>, + permission_manager: &'a mut PermissionManager, + message_tool_response: Arc>, + ) -> BoxStream<'a, anyhow::Result> { + try_stream! { + for request in tool_requests { + if let Ok(tool_call) = request.tool_call.clone() { + let (principal_type, confirmation) = get_confirmation_message(&request.id.clone(), tool_call.clone()); + yield confirmation; + + let mut rx = self.confirmation_rx.lock().await; + while let Some((req_id, confirmation)) = rx.recv().await { + if req_id == request.id { + if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow { + if principal_type == PrincipalType::Extension { + let extension_name = tool_call.arguments.get("extension_name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let mut results = install_results.lock().await; + let install_result = self.enable_extension(extension_name, request.id.clone()).await; + results.push(install_result); + } else { + // Add this tool call to the futures collection + let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone()); + let mut futures = tool_futures.lock().await; + 
futures.push(Box::pin(tool_future)); + + if confirmation.permission == Permission::AlwaysAllow { + permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow); + } + } + } else { + // User declined - add declined response + let mut response = message_tool_response.lock().await; + *response = response.clone().with_tool_response( + request.id.clone(), + Ok(vec![Content::text(DECLINED_RESPONSE)]), + ); + } + break; // Exit the loop once the matching `req_id` is found + } + } + } + } + }.boxed() + } + + pub(crate) fn handle_frontend_tool_requests<'a>( + &'a self, + tool_requests: &'a [ToolRequest], + message_tool_response: Arc>, + ) -> BoxStream<'a, anyhow::Result> { + try_stream! { + for request in tool_requests { + if let Ok(tool_call) = request.tool_call.clone() { + if self.is_frontend_tool(&tool_call.name) { + // Send frontend tool request and wait for response + yield Message::assistant().with_frontend_tool_request( + request.id.clone(), + Ok(tool_call.clone()) + ); + + if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await { + let mut response = message_tool_response.lock().await; + *response = response.clone().with_tool_response(id, result); + } + } + } + } + } + .boxed() + } +} diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index bdc6339b..1664a5fd 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -191,7 +191,7 @@ pub struct PermissionCheckResult { } pub async fn check_tool_permissions( - candidate_requests: Vec<&ToolRequest>, + candidate_requests: &[ToolRequest], mode: &str, tools_with_readonly_annotation: HashSet, tools_without_annotation: HashSet, @@ -466,12 +466,12 @@ mod tests { }), }; - let candidate_requests: Vec<&ToolRequest> = - vec![&tool_request_1, &tool_request_2, &enable_extension]; + let candidate_requests: Vec = + vec![tool_request_1, tool_request_2, enable_extension]; 
// Call the function under test let result = check_tool_permissions( - candidate_requests, + &candidate_requests, "smart_approve", tools_with_readonly_annotation, tools_without_annotation, @@ -534,11 +534,11 @@ mod tests { }), }; - let candidate_requests: Vec<&ToolRequest> = vec![&tool_request_1, &tool_request_2]; + let candidate_requests: Vec = vec![tool_request_1, tool_request_2]; // Call the function under test let result = check_tool_permissions( - candidate_requests, + &candidate_requests, "auto", tools_with_readonly_annotation, tools_without_annotation,