diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 0becddcd..f14c4434 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -7,7 +7,8 @@ use futures::stream::BoxStream; use crate::config::permission::PermissionLevel; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::{Message, MessageContent, ToolRequest}; -use crate::permission::permission_judge::check_tool_permissions; +use crate::permission::permission_confirmation::PrincipalType; +use crate::permission::permission_judge::{check_tool_permissions, get_confirmation_message}; use crate::permission::{Permission, PermissionConfirmation}; use crate::providers::base::Provider; use crate::providers::errors::ProviderError; @@ -22,8 +23,8 @@ use tracing::{debug, error, instrument, warn}; use crate::agents::extension::{ExtensionConfig, ExtensionResult, ToolInfo}; use crate::agents::extension_manager::{get_parameter_names, ExtensionManager}; use crate::agents::platform_tools::{ - PLATFORM_ENABLE_EXTENSION_TOOL_NAME, PLATFORM_LIST_RESOURCES_TOOL_NAME, - PLATFORM_READ_RESOURCE_TOOL_NAME, PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME, + PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME, + PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME, }; use crate::agents::prompt_manager::PromptManager; use crate::agents::types::SessionConfig; @@ -423,16 +424,8 @@ impl Agent { ); } } else { - // Split tool requests into enable_extension and others - let (enable_extension_requests, non_enable_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone() - .into_iter() - .partition(|req| { - req.tool_call.as_ref() - .map(|call| call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME) - .unwrap_or(false) - }); let mut permission_manager = PermissionManager::default(); - let permission_check_result = check_tool_permissions(non_enable_extension_requests, + let permission_check_result = check_tool_permissions(remaining_requests.into_iter().copied().collect(), &mode, tools_with_readonly_annotation.clone(), tools_without_annotation.clone(), @@ -447,40 +440,6 @@ impl Agent { "The user has declined to run this tool. \ DO NOT attempt to call this tool again. \ If there are no alternative methods to proceed, clearly explain the situation and STOP."); - // Handle install extension requests - for request in &enable_extension_requests { - if let Ok(tool_call) = request.tool_call.clone() { - let confirmation = Message::user().with_enable_extension_request( - request.id.clone(), - tool_call.arguments.get("extension_name") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string() - ); - yield confirmation; - - let mut rx = self.confirmation_rx.lock().await; - while let Some((req_id, extension_confirmation)) = rx.recv().await { - let extension_name = tool_call.arguments.get("extension_name") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - if req_id == request.id { - if extension_confirmation.permission == Permission::AllowOnce || extension_confirmation.permission == Permission::AlwaysAllow { - let install_result = self.enable_extension(extension_name, request.id.clone()).await; - install_results.push(install_result); - } else { - // User declined - add declined response - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - Ok(vec![denied_content_text.clone()]), - ); - } - break; - } - } - } - } // Skip the confirmation for approved tools for request in &permission_check_result.approved { @@ -497,28 +456,31 @@ impl Agent { ); } - // Process read-only tools + // Process tools requiring approval for request in &permission_check_result.needs_approval { if let Ok(tool_call) = request.tool_call.clone() { - let confirmation = Message::user().with_tool_confirmation_request( - request.id.clone(), - tool_call.name.clone(), - tool_call.arguments.clone(), - Some("Goose would like to call the above tool. Allow? (y/n):".to_string()), - ); + let (principal_type, confirmation) = get_confirmation_message(&request.id.clone(), tool_call.clone()); yield confirmation; // Wait for confirmation response through the channel let mut rx = self.confirmation_rx.lock().await; - while let Some((req_id, tool_confirmation)) = rx.recv().await { + while let Some((req_id, confirmation)) = rx.recv().await { if req_id == request.id { - let confirmed = tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow; - if confirmed { - // Add this tool call to the futures collection - let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone()); - tool_futures.push(tool_future); - if tool_confirmation.permission == Permission::AlwaysAllow { - permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow); + if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow { + if principal_type == PrincipalType::Extension { + let extension_name = tool_call.arguments.get("extension_name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let install_result = self.enable_extension(extension_name, request.id.clone()).await; + install_results.push(install_result); + } else { + // Add this tool call to the futures collection + let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone()); + tool_futures.push(tool_future); + if confirmation.permission == Permission::AlwaysAllow { + permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow); + } } } else { // User declined - add declined response @@ -535,27 +497,19 @@ impl Agent { // Wait for all tool calls to complete let results = futures::future::join_all(tool_futures).await; - for (request_id, output) in results { - message_tool_response = message_tool_response.with_tool_response( - request_id, - output, - ); - } // Check if any install results had errors before processing them let all_install_successful = !install_results.iter().any(|(_, result)| result.is_err()); - for (request_id, output) in install_results { - message_tool_response = message_tool_response.with_tool_response( - request_id, - output - ); + + for (request_id, output) in results.into_iter().chain(install_results.into_iter()) { + message_tool_response = message_tool_response.with_tool_response(request_id, output); } // Update system prompt and tools if installations were successful if all_install_successful { (tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?; } - } + } yield message_tool_response.clone(); diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index b80c3023..bdc6339b 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -1,3 +1,4 @@ +use crate::agents::platform_tools::PLATFORM_ENABLE_EXTENSION_TOOL_NAME; use crate::config::permission::PermissionLevel; use crate::config::PermissionManager; use crate::message::{Message, MessageContent, ToolRequest}; @@ -5,11 +6,14 @@ use crate::providers::base::Provider; use chrono::Utc; use indoc::indoc; use mcp_core::tool::ToolAnnotations; +use mcp_core::ToolCall; use mcp_core::{tool::Tool, TextContent}; use serde_json::{json, Value}; use std::collections::HashSet; use std::sync::Arc; +use super::permission_confirmation::PrincipalType; + /// Creates the tool definition for checking read-only permissions. fn create_read_only_tool() -> Tool { Tool::new( @@ -150,6 +154,35 @@ pub async fn detect_read_only_tools( } } +/// Gets the boolean value whether the message is enable extension related and +/// the cconfirmation message based on the tool call +pub fn get_confirmation_message(request_id: &str, tool_call: ToolCall) -> (PrincipalType, Message) { + if tool_call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME { + ( + PrincipalType::Extension, + Message::user().with_enable_extension_request( + request_id, + tool_call + .arguments + .get("extension_name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + ), + ) + } else { + ( + PrincipalType::Tool, + Message::user().with_tool_confirmation_request( + request_id, + tool_call.name.clone(), + tool_call.arguments.clone(), + Some("Goose would like to call the above tool. Allow? (y/n):".to_string()), + ), + ) + } +} + // Define return structure pub struct PermissionCheckResult { pub approved: Vec, @@ -172,6 +205,13 @@ pub async fn check_tool_permissions( for request in candidate_requests { if let Ok(tool_call) = request.tool_call.clone() { + // Always ask approval for enable extension tool. + if tool_call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME { + // Insert at the front of the list so that enable extension can be run before other tools. + needs_approval.insert(0, request.clone()); + continue; + } + if mode == "chat" { continue; } else if mode == "auto" { @@ -418,8 +458,16 @@ mod tests { }), }; - // Create a Vec of references to ToolRequests - let candidate_requests: Vec<&ToolRequest> = vec![&tool_request_1, &tool_request_2]; + let enable_extension = ToolRequest { + id: "tool_3".to_string(), + tool_call: ToolResult::Ok(ToolCall { + name: PLATFORM_ENABLE_EXTENSION_TOOL_NAME.to_string(), + arguments: serde_json::json!({"url": "http://example.com"}), + }), + }; + + let candidate_requests: Vec<&ToolRequest> = + vec![&tool_request_1, &tool_request_2, &enable_extension]; // Call the function under test let result = check_tool_permissions( @@ -434,12 +482,23 @@ mod tests { // Validate the result assert_eq!(result.approved.len(), 1); // file_reader should be approved - assert_eq!(result.needs_approval.len(), 1); // data_fetcher should need approval + assert_eq!(result.needs_approval.len(), 2); // data_fetcher should need approval assert_eq!(result.denied.len(), 0); // No tool should be denied in this test // Ensure the right tools are in the approved and needs_approval lists assert!(result.approved.iter().any(|req| req.id == "tool_1")); assert!(result.needs_approval.iter().any(|req| req.id == "tool_2")); + + let tool_0 = result.needs_approval.get(0); + assert!( + tool_0.is_some(), + "Expected at least one tool in needs_approval" + ); + assert_eq!( + tool_0.unwrap().id, + "tool_3", + "PLATFORM_ENABLE_EXTENSION_TOOL_NAME should be the first in needs_approval" + ); } #[tokio::test] @@ -475,7 +534,6 @@ mod tests { }), }; - // Create a Vec of references to ToolRequests let candidate_requests: Vec<&ToolRequest> = vec![&tool_request_1, &tool_request_2]; // Call the function under test