feat: avoid duplicate confirmation handle code (#2165)

This commit is contained in:
Yingjie He
2025-04-11 17:01:37 -07:00
committed by GitHub
parent 33255cd554
commit 255c8dc0f6
2 changed files with 89 additions and 77 deletions

View File

@@ -7,7 +7,8 @@ use futures::stream::BoxStream;
use crate::config::permission::PermissionLevel;
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
use crate::message::{Message, MessageContent, ToolRequest};
use crate::permission::permission_judge::check_tool_permissions;
use crate::permission::permission_confirmation::PrincipalType;
use crate::permission::permission_judge::{check_tool_permissions, get_confirmation_message};
use crate::permission::{Permission, PermissionConfirmation};
use crate::providers::base::Provider;
use crate::providers::errors::ProviderError;
@@ -22,8 +23,8 @@ use tracing::{debug, error, instrument, warn};
use crate::agents::extension::{ExtensionConfig, ExtensionResult, ToolInfo};
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
use crate::agents::platform_tools::{
PLATFORM_ENABLE_EXTENSION_TOOL_NAME, PLATFORM_LIST_RESOURCES_TOOL_NAME,
PLATFORM_READ_RESOURCE_TOOL_NAME, PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME,
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
};
use crate::agents::prompt_manager::PromptManager;
use crate::agents::types::SessionConfig;
@@ -423,16 +424,8 @@ impl Agent {
);
}
} else {
// Split tool requests into enable_extension and others
let (enable_extension_requests, non_enable_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
.into_iter()
.partition(|req| {
req.tool_call.as_ref()
.map(|call| call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME)
.unwrap_or(false)
});
let mut permission_manager = PermissionManager::default();
let permission_check_result = check_tool_permissions(non_enable_extension_requests,
let permission_check_result = check_tool_permissions(remaining_requests.into_iter().copied().collect(),
&mode,
tools_with_readonly_annotation.clone(),
tools_without_annotation.clone(),
@@ -447,40 +440,6 @@ impl Agent {
"The user has declined to run this tool. \
DO NOT attempt to call this tool again. \
If there are no alternative methods to proceed, clearly explain the situation and STOP.");
// Handle install extension requests
for request in &enable_extension_requests {
if let Ok(tool_call) = request.tool_call.clone() {
let confirmation = Message::user().with_enable_extension_request(
request.id.clone(),
tool_call.arguments.get("extension_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string()
);
yield confirmation;
let mut rx = self.confirmation_rx.lock().await;
while let Some((req_id, extension_confirmation)) = rx.recv().await {
let extension_name = tool_call.arguments.get("extension_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if req_id == request.id {
if extension_confirmation.permission == Permission::AllowOnce || extension_confirmation.permission == Permission::AlwaysAllow {
let install_result = self.enable_extension(extension_name, request.id.clone()).await;
install_results.push(install_result);
} else {
// User declined - add declined response
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![denied_content_text.clone()]),
);
}
break;
}
}
}
}
// Skip the confirmation for approved tools
for request in &permission_check_result.approved {
@@ -497,28 +456,31 @@ impl Agent {
);
}
// Process read-only tools
// Process tools requiring approval
for request in &permission_check_result.needs_approval {
if let Ok(tool_call) = request.tool_call.clone() {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
);
let (principal_type, confirmation) = get_confirmation_message(&request.id.clone(), tool_call.clone());
yield confirmation;
// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
while let Some((req_id, tool_confirmation)) = rx.recv().await {
while let Some((req_id, confirmation)) = rx.recv().await {
if req_id == request.id {
let confirmed = tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow;
if confirmed {
// Add this tool call to the futures collection
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone());
tool_futures.push(tool_future);
if tool_confirmation.permission == Permission::AlwaysAllow {
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
if principal_type == PrincipalType::Extension {
let extension_name = tool_call.arguments.get("extension_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let install_result = self.enable_extension(extension_name, request.id.clone()).await;
install_results.push(install_result);
} else {
// Add this tool call to the futures collection
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone());
tool_futures.push(tool_future);
if confirmation.permission == Permission::AlwaysAllow {
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);
}
}
} else {
// User declined - add declined response
@@ -535,27 +497,19 @@ impl Agent {
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
// Check if any install results had errors before processing them
let all_install_successful = !install_results.iter().any(|(_, result)| result.is_err());
for (request_id, output) in install_results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output
);
for (request_id, output) in results.into_iter().chain(install_results.into_iter()) {
message_tool_response = message_tool_response.with_tool_response(request_id, output);
}
// Update system prompt and tools if installations were successful
if all_install_successful {
(tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?;
}
}
}
yield message_tool_response.clone();

View File

@@ -1,3 +1,4 @@
use crate::agents::platform_tools::PLATFORM_ENABLE_EXTENSION_TOOL_NAME;
use crate::config::permission::PermissionLevel;
use crate::config::PermissionManager;
use crate::message::{Message, MessageContent, ToolRequest};
@@ -5,11 +6,14 @@ use crate::providers::base::Provider;
use chrono::Utc;
use indoc::indoc;
use mcp_core::tool::ToolAnnotations;
use mcp_core::ToolCall;
use mcp_core::{tool::Tool, TextContent};
use serde_json::{json, Value};
use std::collections::HashSet;
use std::sync::Arc;
use super::permission_confirmation::PrincipalType;
/// Creates the tool definition for checking read-only permissions.
fn create_read_only_tool() -> Tool {
Tool::new(
@@ -150,6 +154,35 @@ pub async fn detect_read_only_tools(
}
}
/// Gets the boolean value whether the message is enable extension related and
/// the cconfirmation message based on the tool call
pub fn get_confirmation_message(request_id: &str, tool_call: ToolCall) -> (PrincipalType, Message) {
if tool_call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME {
(
PrincipalType::Extension,
Message::user().with_enable_extension_request(
request_id,
tool_call
.arguments
.get("extension_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
),
)
} else {
(
PrincipalType::Tool,
Message::user().with_tool_confirmation_request(
request_id,
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
),
)
}
}
// Define return structure
pub struct PermissionCheckResult {
pub approved: Vec<ToolRequest>,
@@ -172,6 +205,13 @@ pub async fn check_tool_permissions(
for request in candidate_requests {
if let Ok(tool_call) = request.tool_call.clone() {
// Always ask approval for enable extension tool.
if tool_call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME {
// Insert at the front of the list so that enable extension can be run before other tools.
needs_approval.insert(0, request.clone());
continue;
}
if mode == "chat" {
continue;
} else if mode == "auto" {
@@ -418,8 +458,16 @@ mod tests {
}),
};
// Create a Vec of references to ToolRequests
let candidate_requests: Vec<&ToolRequest> = vec![&tool_request_1, &tool_request_2];
let enable_extension = ToolRequest {
id: "tool_3".to_string(),
tool_call: ToolResult::Ok(ToolCall {
name: PLATFORM_ENABLE_EXTENSION_TOOL_NAME.to_string(),
arguments: serde_json::json!({"url": "http://example.com"}),
}),
};
let candidate_requests: Vec<&ToolRequest> =
vec![&tool_request_1, &tool_request_2, &enable_extension];
// Call the function under test
let result = check_tool_permissions(
@@ -434,12 +482,23 @@ mod tests {
// Validate the result
assert_eq!(result.approved.len(), 1); // file_reader should be approved
assert_eq!(result.needs_approval.len(), 1); // data_fetcher should need approval
assert_eq!(result.needs_approval.len(), 2); // data_fetcher should need approval
assert_eq!(result.denied.len(), 0); // No tool should be denied in this test
// Ensure the right tools are in the approved and needs_approval lists
assert!(result.approved.iter().any(|req| req.id == "tool_1"));
assert!(result.needs_approval.iter().any(|req| req.id == "tool_2"));
let tool_0 = result.needs_approval.get(0);
assert!(
tool_0.is_some(),
"Expected at least one tool in needs_approval"
);
assert_eq!(
tool_0.unwrap().id,
"tool_3",
"PLATFORM_ENABLE_EXTENSION_TOOL_NAME should be the first in needs_approval"
);
}
#[tokio::test]
@@ -475,7 +534,6 @@ mod tests {
}),
};
// Create a Vec of references to ToolRequests
let candidate_requests: Vec<&ToolRequest> = vec![&tool_request_1, &tool_request_2];
// Call the function under test