mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-19 21:44:24 +01:00
feat: avoid duplicate confirmation handle code (#2165)
This commit is contained in:
@@ -7,7 +7,8 @@ use futures::stream::BoxStream;
|
||||
use crate::config::permission::PermissionLevel;
|
||||
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
|
||||
use crate::message::{Message, MessageContent, ToolRequest};
|
||||
use crate::permission::permission_judge::check_tool_permissions;
|
||||
use crate::permission::permission_confirmation::PrincipalType;
|
||||
use crate::permission::permission_judge::{check_tool_permissions, get_confirmation_message};
|
||||
use crate::permission::{Permission, PermissionConfirmation};
|
||||
use crate::providers::base::Provider;
|
||||
use crate::providers::errors::ProviderError;
|
||||
@@ -22,8 +23,8 @@ use tracing::{debug, error, instrument, warn};
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult, ToolInfo};
|
||||
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
|
||||
use crate::agents::platform_tools::{
|
||||
PLATFORM_ENABLE_EXTENSION_TOOL_NAME, PLATFORM_LIST_RESOURCES_TOOL_NAME,
|
||||
PLATFORM_READ_RESOURCE_TOOL_NAME, PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
|
||||
PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME,
|
||||
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
|
||||
};
|
||||
use crate::agents::prompt_manager::PromptManager;
|
||||
use crate::agents::types::SessionConfig;
|
||||
@@ -423,16 +424,8 @@ impl Agent {
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Split tool requests into enable_extension and others
|
||||
let (enable_extension_requests, non_enable_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
|
||||
.into_iter()
|
||||
.partition(|req| {
|
||||
req.tool_call.as_ref()
|
||||
.map(|call| call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
let mut permission_manager = PermissionManager::default();
|
||||
let permission_check_result = check_tool_permissions(non_enable_extension_requests,
|
||||
let permission_check_result = check_tool_permissions(remaining_requests.into_iter().copied().collect(),
|
||||
&mode,
|
||||
tools_with_readonly_annotation.clone(),
|
||||
tools_without_annotation.clone(),
|
||||
@@ -447,40 +440,6 @@ impl Agent {
|
||||
"The user has declined to run this tool. \
|
||||
DO NOT attempt to call this tool again. \
|
||||
If there are no alternative methods to proceed, clearly explain the situation and STOP.");
|
||||
// Handle install extension requests
|
||||
for request in &enable_extension_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let confirmation = Message::user().with_enable_extension_request(
|
||||
request.id.clone(),
|
||||
tool_call.arguments.get("extension_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
);
|
||||
yield confirmation;
|
||||
|
||||
let mut rx = self.confirmation_rx.lock().await;
|
||||
while let Some((req_id, extension_confirmation)) = rx.recv().await {
|
||||
let extension_name = tool_call.arguments.get("extension_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
if req_id == request.id {
|
||||
if extension_confirmation.permission == Permission::AllowOnce || extension_confirmation.permission == Permission::AlwaysAllow {
|
||||
let install_result = self.enable_extension(extension_name, request.id.clone()).await;
|
||||
install_results.push(install_result);
|
||||
} else {
|
||||
// User declined - add declined response
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![denied_content_text.clone()]),
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip the confirmation for approved tools
|
||||
for request in &permission_check_result.approved {
|
||||
@@ -497,28 +456,31 @@ impl Agent {
|
||||
);
|
||||
}
|
||||
|
||||
// Process read-only tools
|
||||
// Process tools requiring approval
|
||||
for request in &permission_check_result.needs_approval {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let confirmation = Message::user().with_tool_confirmation_request(
|
||||
request.id.clone(),
|
||||
tool_call.name.clone(),
|
||||
tool_call.arguments.clone(),
|
||||
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
|
||||
);
|
||||
let (principal_type, confirmation) = get_confirmation_message(&request.id.clone(), tool_call.clone());
|
||||
yield confirmation;
|
||||
|
||||
// Wait for confirmation response through the channel
|
||||
let mut rx = self.confirmation_rx.lock().await;
|
||||
while let Some((req_id, tool_confirmation)) = rx.recv().await {
|
||||
while let Some((req_id, confirmation)) = rx.recv().await {
|
||||
if req_id == request.id {
|
||||
let confirmed = tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow;
|
||||
if confirmed {
|
||||
// Add this tool call to the futures collection
|
||||
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone());
|
||||
tool_futures.push(tool_future);
|
||||
if tool_confirmation.permission == Permission::AlwaysAllow {
|
||||
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);
|
||||
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
|
||||
if principal_type == PrincipalType::Extension {
|
||||
let extension_name = tool_call.arguments.get("extension_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let install_result = self.enable_extension(extension_name, request.id.clone()).await;
|
||||
install_results.push(install_result);
|
||||
} else {
|
||||
// Add this tool call to the futures collection
|
||||
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone());
|
||||
tool_futures.push(tool_future);
|
||||
if confirmation.permission == Permission::AlwaysAllow {
|
||||
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// User declined - add declined response
|
||||
@@ -535,27 +497,19 @@ impl Agent {
|
||||
|
||||
// Wait for all tool calls to complete
|
||||
let results = futures::future::join_all(tool_futures).await;
|
||||
for (request_id, output) in results {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request_id,
|
||||
output,
|
||||
);
|
||||
}
|
||||
|
||||
// Check if any install results had errors before processing them
|
||||
let all_install_successful = !install_results.iter().any(|(_, result)| result.is_err());
|
||||
for (request_id, output) in install_results {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request_id,
|
||||
output
|
||||
);
|
||||
|
||||
for (request_id, output) in results.into_iter().chain(install_results.into_iter()) {
|
||||
message_tool_response = message_tool_response.with_tool_response(request_id, output);
|
||||
}
|
||||
|
||||
// Update system prompt and tools if installations were successful
|
||||
if all_install_successful {
|
||||
(tools, toolshim_tools, system_prompt) = self.prepare_tools_and_prompt().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
yield message_tool_response.clone();
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::agents::platform_tools::PLATFORM_ENABLE_EXTENSION_TOOL_NAME;
|
||||
use crate::config::permission::PermissionLevel;
|
||||
use crate::config::PermissionManager;
|
||||
use crate::message::{Message, MessageContent, ToolRequest};
|
||||
@@ -5,11 +6,14 @@ use crate::providers::base::Provider;
|
||||
use chrono::Utc;
|
||||
use indoc::indoc;
|
||||
use mcp_core::tool::ToolAnnotations;
|
||||
use mcp_core::ToolCall;
|
||||
use mcp_core::{tool::Tool, TextContent};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::permission_confirmation::PrincipalType;
|
||||
|
||||
/// Creates the tool definition for checking read-only permissions.
|
||||
fn create_read_only_tool() -> Tool {
|
||||
Tool::new(
|
||||
@@ -150,6 +154,35 @@ pub async fn detect_read_only_tools(
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the boolean value whether the message is enable extension related and
|
||||
/// the cconfirmation message based on the tool call
|
||||
pub fn get_confirmation_message(request_id: &str, tool_call: ToolCall) -> (PrincipalType, Message) {
|
||||
if tool_call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME {
|
||||
(
|
||||
PrincipalType::Extension,
|
||||
Message::user().with_enable_extension_request(
|
||||
request_id,
|
||||
tool_call
|
||||
.arguments
|
||||
.get("extension_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
PrincipalType::Tool,
|
||||
Message::user().with_tool_confirmation_request(
|
||||
request_id,
|
||||
tool_call.name.clone(),
|
||||
tool_call.arguments.clone(),
|
||||
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Define return structure
|
||||
pub struct PermissionCheckResult {
|
||||
pub approved: Vec<ToolRequest>,
|
||||
@@ -172,6 +205,13 @@ pub async fn check_tool_permissions(
|
||||
|
||||
for request in candidate_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
// Always ask approval for enable extension tool.
|
||||
if tool_call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME {
|
||||
// Insert at the front of the list so that enable extension can be run before other tools.
|
||||
needs_approval.insert(0, request.clone());
|
||||
continue;
|
||||
}
|
||||
|
||||
if mode == "chat" {
|
||||
continue;
|
||||
} else if mode == "auto" {
|
||||
@@ -418,8 +458,16 @@ mod tests {
|
||||
}),
|
||||
};
|
||||
|
||||
// Create a Vec of references to ToolRequests
|
||||
let candidate_requests: Vec<&ToolRequest> = vec![&tool_request_1, &tool_request_2];
|
||||
let enable_extension = ToolRequest {
|
||||
id: "tool_3".to_string(),
|
||||
tool_call: ToolResult::Ok(ToolCall {
|
||||
name: PLATFORM_ENABLE_EXTENSION_TOOL_NAME.to_string(),
|
||||
arguments: serde_json::json!({"url": "http://example.com"}),
|
||||
}),
|
||||
};
|
||||
|
||||
let candidate_requests: Vec<&ToolRequest> =
|
||||
vec![&tool_request_1, &tool_request_2, &enable_extension];
|
||||
|
||||
// Call the function under test
|
||||
let result = check_tool_permissions(
|
||||
@@ -434,12 +482,23 @@ mod tests {
|
||||
|
||||
// Validate the result
|
||||
assert_eq!(result.approved.len(), 1); // file_reader should be approved
|
||||
assert_eq!(result.needs_approval.len(), 1); // data_fetcher should need approval
|
||||
assert_eq!(result.needs_approval.len(), 2); // data_fetcher should need approval
|
||||
assert_eq!(result.denied.len(), 0); // No tool should be denied in this test
|
||||
|
||||
// Ensure the right tools are in the approved and needs_approval lists
|
||||
assert!(result.approved.iter().any(|req| req.id == "tool_1"));
|
||||
assert!(result.needs_approval.iter().any(|req| req.id == "tool_2"));
|
||||
|
||||
let tool_0 = result.needs_approval.get(0);
|
||||
assert!(
|
||||
tool_0.is_some(),
|
||||
"Expected at least one tool in needs_approval"
|
||||
);
|
||||
assert_eq!(
|
||||
tool_0.unwrap().id,
|
||||
"tool_3",
|
||||
"PLATFORM_ENABLE_EXTENSION_TOOL_NAME should be the first in needs_approval"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -475,7 +534,6 @@ mod tests {
|
||||
}),
|
||||
};
|
||||
|
||||
// Create a Vec of references to ToolRequests
|
||||
let candidate_requests: Vec<&ToolRequest> = vec![&tool_request_1, &tool_request_2];
|
||||
|
||||
// Call the function under test
|
||||
|
||||
Reference in New Issue
Block a user