From b805e8bd7a1d2f5a5ac879c4f201c1715da6211c Mon Sep 17 00:00:00 2001 From: Yingjie He Date: Thu, 10 Apr 2025 15:41:37 -0700 Subject: [PATCH] fix: fix the mismatched tool result/call when using enable/search extension tools (#2138) Co-authored-by: Alice Hau --- crates/goose/src/agents/agent.rs | 129 ++++++----------- .../goose/src/permission/permission_judge.rs | 136 ++++++++++++------ crates/goose/src/prompts/system.md | 2 +- 3 files changed, 139 insertions(+), 128 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index fe0a4896..3bee1e59 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -475,50 +475,50 @@ impl Agent { } } - // Split tool requests into enable_extension and others - let (enable_extension_requests, non_enable_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone() - .into_iter() - .partition(|req| { - req.tool_call.as_ref() - .map(|call| call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME) - .unwrap_or(false) - }); - - let (search_extension_requests, _non_search_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone() - .into_iter() - .partition(|req| { - req.tool_call.as_ref() - .map(|call| call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME) - .unwrap_or(false) - }); - // Clone goose_mode once before the match to avoid move issues let mode = goose_mode.clone(); - - // If there are install extension requests, always require confirmation - // or if goose_mode is approve or smart_approve, check permissions for all tools - if !enable_extension_requests.is_empty() || mode.as_str() == "approve" || mode.as_str() == "smart_approve" { + if mode.as_str() == "chat" { + // Skip all tool calls in chat mode + for request in remaining_requests { + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + Ok(vec![Content::text( + "Let the user know the tool call was skipped in 
Goose chat mode. \ + DO NOT apologize for skipping the tool call. DO NOT say sorry. \ + Provide an explanation of what the tool call would do, structured as a \ + plan for the user. Again, DO NOT apologize. \ + **Example Plan:**\n \ + 1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \ + 2. **Outline Steps** - Break down the steps.\n \ + If needed, adjust the explanation based on user preferences or questions." + )]), + ); + } + } else { + // Split tool requests into enable_extension and others + let (enable_extension_requests, non_enable_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone() + .into_iter() + .partition(|req| { + req.tool_call.as_ref() + .map(|call| call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME) + .unwrap_or(false) + }); let mut permission_manager = PermissionManager::default(); - // Skip the platform tools - remaining_requests.retain(|req| { - if let Ok(tool_call) = &req.tool_call { - !tool_call.name.starts_with("platform__") - } else { - true // If there's an error (Err), don't skip the request - } - }); - let permission_check_result = check_tool_permissions(remaining_requests, + let permission_check_result = check_tool_permissions(non_enable_extension_requests, &mode, tools_with_readonly_annotation.clone(), tools_without_annotation.clone(), &mut permission_manager, self.provider()).await; - // Handle pre-approved and read-only tools in parallel let mut tool_futures = Vec::new(); let mut install_results = Vec::new(); + let denied_content_text = Content::text( + "The user has declined to run this tool. \ + DO NOT attempt to call this tool again. 
\ + If there are no alternative methods to proceed, clearly explain the situation and STOP."); // Handle install extension requests for request in &enable_extension_requests { if let Ok(tool_call) = request.tool_call.clone() { @@ -541,6 +541,12 @@ impl Agent { .to_string(); let install_result = Self::enable_extension(&mut extension_manager, extension_name, request.id.clone()).await; install_results.push(install_result); + } else { + // User declined - add declined response + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + Ok(vec![denied_content_text.clone()]), + ); } break; } @@ -557,10 +563,6 @@ impl Agent { } } - let denied_content_text = Content::text( - "The user has declined to run this tool. \ - DO NOT attempt to call this tool again. \ - If there are no alternative methods to proceed, clearly explain the situation and STOP."); for request in &permission_check_result.denied { message_tool_response = message_tool_response.with_tool_response( request.id.clone(), @@ -629,59 +631,12 @@ impl Agent { let extensions_info = extension_manager.get_extensions_info().await; system_prompt = self.prompt_manager.build_system_prompt(extensions_info, self.frontend_instructions.clone()); tools = extension_manager.get_prefixed_tools().await?; - } - } - - if mode.as_str() == "auto" || !search_extension_requests.is_empty() { - let mut tool_futures = Vec::new(); - // Process non_enable_extension_requests and search_extension_requests without duplicates - let mut processed_ids = HashSet::new(); - - for request in non_enable_extension_requests.iter().chain(search_extension_requests.iter()) { - if processed_ids.insert(request.id.clone()) { - if let Ok(tool_call) = request.tool_call.clone() { - let is_frontend_tool = self.is_frontend_tool(&tool_call.name); - let tool_future = Self::create_tool_future(&extension_manager, tool_call, is_frontend_tool, request.id.clone()); - tool_futures.push(tool_future); - } + if 
extension_manager.supports_resources() { + tools.push(platform_tools::read_resource_tool()); + tools.push(platform_tools::list_resources_tool()); } - } - - // Wait for all tool calls to complete - let results = futures::future::join_all(tool_futures).await; - for (request_id, output) in results { - message_tool_response = message_tool_response.with_tool_response( - request_id, - output, - ); - } - } - - if mode.as_str() == "chat" { - // Skip all tool calls in chat mode - // Skip search extension requests since they were already processed - let non_search_non_enable_extension_requests = non_enable_extension_requests.iter() - .filter(|req| { - if let Ok(tool_call) = &req.tool_call { - tool_call.name != PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME - } else { - true - } - }); - for request in non_search_non_enable_extension_requests { - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - Ok(vec![Content::text( - "Let the user know the tool call was skipped in Goose chat mode. \ - DO NOT apologize for skipping the tool call. DO NOT say sorry. \ - Provide an explanation of what the tool call would do, structured as a \ - plan for the user. Again, DO NOT apologize. \ - **Example Plan:**\n \ - 1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \ - 2. **Outline Steps** - Break down the steps.\n \ - If needed, adjust the explanation based on user preferences or questions." 
- )]), - ); + tools.push(platform_tools::search_available_extensions_tool()); + tools.push(platform_tools::enable_extension_tool()); } } diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index 658ec52f..b80c3023 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -158,7 +158,7 @@ pub struct PermissionCheckResult { } pub async fn check_tool_permissions( - remaining_requests: Vec<&&ToolRequest>, + candidate_requests: Vec<&ToolRequest>, mode: &str, tools_with_readonly_annotation: HashSet, tools_without_annotation: HashSet, @@ -170,45 +170,51 @@ pub async fn check_tool_permissions( let mut denied = vec![]; let mut llm_detect_candidates = vec![]; - for &&request in &remaining_requests { + for request in candidate_requests { if let Ok(tool_call) = request.tool_call.clone() { - // 1. Check user-defined permission - if let Some(level) = permission_manager.get_user_permission(&tool_call.name) { - match level { - PermissionLevel::AlwaysAllow => approved.push(request.clone()), - PermissionLevel::AskBefore => needs_approval.push(request.clone()), - PermissionLevel::NeverAllow => denied.push(request.clone()), - } + if mode == "chat" { continue; - } - - // 2. Fallback based on mode - match mode { - "approve" => { - needs_approval.push(request.clone()); - } - "smart_approve" => { - if let Some(level) = - permission_manager.get_smart_approve_permission(&tool_call.name) - { - match level { - PermissionLevel::AlwaysAllow => approved.push(request.clone()), - PermissionLevel::AskBefore => needs_approval.push(request.clone()), - PermissionLevel::NeverAllow => denied.push(request.clone()), - } - continue; + } else if mode == "auto" { + approved.push(request.clone()); + } else { + // 1. 
Check user-defined permission + if let Some(level) = permission_manager.get_user_permission(&tool_call.name) { + match level { + PermissionLevel::AlwaysAllow => approved.push(request.clone()), + PermissionLevel::AskBefore => needs_approval.push(request.clone()), + PermissionLevel::NeverAllow => denied.push(request.clone()), } + continue; + } - if tools_with_readonly_annotation.contains(&tool_call.name) { - approved.push(request.clone()); - } else if tools_without_annotation.contains(&tool_call.name) { - llm_detect_candidates.push(request.clone()); - } else { + // 2. Fallback based on mode + match mode { + "approve" => { + needs_approval.push(request.clone()); + } + "smart_approve" => { + if let Some(level) = + permission_manager.get_smart_approve_permission(&tool_call.name) + { + match level { + PermissionLevel::AlwaysAllow => approved.push(request.clone()), + PermissionLevel::AskBefore => needs_approval.push(request.clone()), + PermissionLevel::NeverAllow => denied.push(request.clone()), + } + continue; + } + + if tools_with_readonly_annotation.contains(&tool_call.name) { + approved.push(request.clone()); + } else if tools_without_annotation.contains(&tool_call.name) { + llm_detect_candidates.push(request.clone()); + } else { + needs_approval.push(request.clone()); + } + } + _ => { needs_approval.push(request.clone()); } - } - _ => { - needs_approval.push(request.clone()); } } } @@ -380,7 +386,7 @@ mod tests { } #[tokio::test] - async fn test_check_tool_permissions() { + async fn test_check_tool_permissions_smart_approve() { // Setup mocks let temp_file = NamedTempFile::new().unwrap(); let temp_path = temp_file.path(); @@ -412,15 +418,12 @@ mod tests { }), }; - // Store ToolRequests in a Vec - let tool_requests = vec![&tool_request_1, &tool_request_2]; - // Create a Vec of references to ToolRequests - let remaining_requests: Vec<&&ToolRequest> = tool_requests.iter().collect(); + let candidate_requests: Vec<&ToolRequest> = vec![&tool_request_1, &tool_request_2]; // 
Call the function under test let result = check_tool_permissions( - remaining_requests, + candidate_requests, "smart_approve", tools_with_readonly_annotation, tools_without_annotation, @@ -438,4 +441,57 @@ mod tests { assert!(result.approved.iter().any(|req| req.id == "tool_1")); assert!(result.needs_approval.iter().any(|req| req.id == "tool_2")); } + + #[tokio::test] + async fn test_check_tool_permissions_auto() { + // Setup mocks + let temp_file = NamedTempFile::new().unwrap(); + let temp_path = temp_file.path(); + let mut permission_manager = PermissionManager::new(temp_path); + let provider = create_mock_provider(); + + let tools_with_readonly_annotation: HashSet = + vec!["file_reader".to_string()].into_iter().collect(); + let tools_without_annotation: HashSet = + vec!["data_fetcher".to_string()].into_iter().collect(); + + permission_manager.update_user_permission("file_reader", PermissionLevel::AlwaysAllow); + permission_manager + .update_smart_approve_permission("data_fetcher", PermissionLevel::AskBefore); + + let tool_request_1 = ToolRequest { + id: "tool_1".to_string(), + tool_call: ToolResult::Ok(ToolCall { + name: "file_reader".to_string(), + arguments: serde_json::json!({"path": "/path/to/file"}), + }), + }; + + let tool_request_2 = ToolRequest { + id: "tool_2".to_string(), + tool_call: ToolResult::Ok(ToolCall { + name: "data_fetcher".to_string(), + arguments: serde_json::json!({"url": "http://example.com"}), + }), + }; + + // Create a Vec of references to ToolRequests + let candidate_requests: Vec<&ToolRequest> = vec![&tool_request_1, &tool_request_2]; + + // Call the function under test + let result = check_tool_permissions( + candidate_requests, + "auto", + tools_with_readonly_annotation, + tools_without_annotation, + &mut permission_manager, + provider, + ) + .await; + + // Validate the result + assert_eq!(result.approved.len(), 2); // auto mode approves every request, regardless of stored permissions + assert_eq!(result.needs_approval.len(), 0); // auto mode never asks the user for approval +
assert_eq!(result.denied.len(), 0); // No tool should be denied in this test + } } diff --git a/crates/goose/src/prompts/system.md b/crates/goose/src/prompts/system.md index 34b7b66d..65781dee 100644 --- a/crates/goose/src/prompts/system.md +++ b/crates/goose/src/prompts/system.md @@ -9,7 +9,7 @@ These models have varying knowledge cut-off dates depending on when they were tr Extensions allow other applications to provide context to Goose. Extensions connect Goose to different data sources and tools. You are capable of dynamically plugging into new extensions and learning how to use them. You solve higher level problems using the tools in these extensions, and can interact with multiple at once. -Use the search_available_extensions tool to find additional extensions to enable to help with your task. To enable extensions, use the enable_extensions tool. You should only enable extensions found from the search_available_extensions tool. +Use the search_available_extensions tool to find additional extensions to enable to help with your task. To enable extensions, use the enable_extension tool. You should only enable extensions found from the search_available_extensions tool. {% if (extensions is defined) and extensions %} Because you dynamically load extensions, your conversation history may refer