diff --git a/crates/goose/src/config/mod.rs b/crates/goose/src/config/mod.rs index ae7570ab..79178332 100644 --- a/crates/goose/src/config/mod.rs +++ b/crates/goose/src/config/mod.rs @@ -1,11 +1,13 @@ pub mod base; mod experiments; pub mod extensions; +pub mod permission; pub use crate::agents::ExtensionConfig; pub use base::{Config, ConfigError, APP_STRATEGY}; pub use experiments::ExperimentManager; pub use extensions::{ExtensionEntry, ExtensionManager}; +pub use permission::PermissionManager; pub use extensions::DEFAULT_DISPLAY_NAME; pub use extensions::DEFAULT_EXTENSION; diff --git a/crates/goose/src/config/permission.rs b/crates/goose/src/config/permission.rs new file mode 100644 index 00000000..62c7a142 --- /dev/null +++ b/crates/goose/src/config/permission.rs @@ -0,0 +1,275 @@ +use super::APP_STRATEGY; +use etcetera::{choose_app_strategy, AppStrategy}; +use once_cell::sync::OnceCell; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs; +use std::path::{Path, PathBuf}; + +/// Enum representing the possible permission levels for a tool. +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum PermissionLevel { + AlwaysAllow, // Tool can always be used without prompt + AskBefore, // Tool requires permission to be granted before use + NeverAllow, // Tool is never allowed to be used +} + +/// Struct representing the configuration of permissions, categorized by level. +#[derive(Debug, Deserialize, Serialize, Default, Clone)] +pub struct PermissionConfig { + pub always_allow: Vec, // List of tools that are always allowed + pub ask_before: Vec, // List of tools that require user consent + pub never_allow: Vec, // List of tools that are never allowed +} + +/// PermissionManager manages permission configurations for various tools. +#[derive(Debug)] +pub struct PermissionManager { + config_path: PathBuf, // Path to the permission configuration file + permission_map: HashMap, // Mapping of permission names to configurations +} + +// Global singleton for the PermissionManager +static GLOBAL_PERMISSION_MANAGER: OnceCell = OnceCell::new(); + +// Constants representing specific permission categories +const USER_PERMISSION: &str = "user"; +const SMART_APPROVE_PERMISSION: &str = "smart_approve"; + +/// Implements the default constructor for `PermissionManager`. +impl Default for PermissionManager { + fn default() -> Self { + // Choose the app strategy and determine the config directory + let config_dir = choose_app_strategy(APP_STRATEGY.clone()) + .expect("goose requires a home dir") + .config_dir(); + + // Ensure the configuration directory exists + std::fs::create_dir_all(&config_dir).expect("Failed to create config directory"); + let config_path = config_dir.join("permission.yaml"); + + // Load the existing configuration file or create an empty map if the file doesn't exist + let permission_map = if config_path.exists() { + // Load the configuration file + let file_contents = + fs::read_to_string(&config_path).expect("Failed to read permission.yaml"); + serde_yaml::from_str(&file_contents).unwrap_or_else(|_| HashMap::new()) + } else { + HashMap::new() // No config file, create an empty map + }; + + PermissionManager { + config_path, + permission_map, + } + } +} + +impl PermissionManager { + /// Returns the global instance of the PermissionManager, initializing it if necessary. + pub fn global() -> &'static PermissionManager { + GLOBAL_PERMISSION_MANAGER.get_or_init(PermissionManager::default) + } + + /// Creates a new `PermissionManager` with a specified config path. + pub fn new>(config_path: P) -> Self { + let config_path = config_path.as_ref().to_path_buf(); + + // Load the existing configuration file or create an empty map if the file doesn't exist + let permission_map = if config_path.exists() { + // Load the configuration file + let file_contents = + fs::read_to_string(&config_path).expect("Failed to read permission.yaml"); + serde_yaml::from_str(&file_contents).unwrap_or_else(|_| HashMap::new()) + } else { + HashMap::new() // No config file, create an empty map + }; + + PermissionManager { + config_path, + permission_map, + } + } + + /// Returns a list of all the names (keys) in the permission map. + pub fn get_permission_names(&self) -> Vec { + self.permission_map.keys().cloned().collect() + } + + /// Retrieves the user permission level for a specific tool. + pub fn get_user_permission(&self, principal_name: &str) -> Option { + self.get_permission(USER_PERMISSION, principal_name) + } + + /// Retrieves the smart approve permission level for a specific tool. + pub fn get_smart_approve_permission(&self, principal_name: &str) -> Option { + self.get_permission(SMART_APPROVE_PERMISSION, principal_name) + } + + /// Helper function to retrieve the permission level for a specific permission category and tool. + fn get_permission(&self, name: &str, principal_name: &str) -> Option { + // Check if the permission category exists in the map + if let Some(permission_config) = self.permission_map.get(name) { + // Check the permission levels for the given tool + if permission_config + .always_allow + .contains(&principal_name.to_string()) + { + return Some(PermissionLevel::AlwaysAllow); + } else if permission_config + .ask_before + .contains(&principal_name.to_string()) + { + return Some(PermissionLevel::AskBefore); + } else if permission_config + .never_allow + .contains(&principal_name.to_string()) + { + return Some(PermissionLevel::NeverAllow); + } + } + None // Return None if no matching permission level is found + } + + /// Updates the user permission level for a specific tool. + pub fn update_user_permission(&mut self, principal_name: &str, level: PermissionLevel) { + self.update_permission(USER_PERMISSION, principal_name, level) + } + + /// Updates the smart approve permission level for a specific tool. + pub fn update_smart_approve_permission( + &mut self, + principal_name: &str, + level: PermissionLevel, + ) { + self.update_permission(SMART_APPROVE_PERMISSION, principal_name, level) + } + + /// Helper function to update a permission level for a specific tool in a given permission category. + fn update_permission(&mut self, name: &str, principal_name: &str, level: PermissionLevel) { + // Get or create a new PermissionConfig for the specified category + let permission_config = self.permission_map.entry(name.to_string()).or_default(); + + // Remove the principal from all existing lists to avoid duplicates + permission_config + .always_allow + .retain(|p| p != principal_name); + permission_config.ask_before.retain(|p| p != principal_name); + permission_config + .never_allow + .retain(|p| p != principal_name); + + // Add the principal to the appropriate list + match level { + PermissionLevel::AlwaysAllow => permission_config + .always_allow + .push(principal_name.to_string()), + PermissionLevel::AskBefore => permission_config + .ask_before + .push(principal_name.to_string()), + PermissionLevel::NeverAllow => permission_config + .never_allow + .push(principal_name.to_string()), + } + + // Serialize the updated permission map and write it back to the config file + let yaml_content = serde_yaml::to_string(&self.permission_map) + .expect("Failed to serialize permission config"); + fs::write(&self.config_path, yaml_content).expect("Failed to write to permission.yaml"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + + // Helper function to create a test instance of PermissionManager with a temp dir + fn create_test_permission_manager() -> PermissionManager { + let temp_file = NamedTempFile::new().unwrap(); + let temp_path = temp_file.path(); + PermissionManager::new(temp_path) + } + + #[test] + fn test_get_permission_names_empty() { + let manager = create_test_permission_manager(); + + assert!(manager.get_permission_names().is_empty()); + } + + #[test] + fn test_update_user_permission() { + let mut manager = create_test_permission_manager(); + manager.update_user_permission("tool1", PermissionLevel::AlwaysAllow); + + let permission = manager.get_user_permission("tool1"); + assert_eq!(permission, Some(PermissionLevel::AlwaysAllow)); + } + + #[test] + fn test_update_smart_approve_permission() { + let mut manager = create_test_permission_manager(); + manager.update_smart_approve_permission("tool2", PermissionLevel::AskBefore); + + let permission = manager.get_smart_approve_permission("tool2"); + assert_eq!(permission, Some(PermissionLevel::AskBefore)); + } + + #[test] + fn test_get_permission_not_found() { + let manager = create_test_permission_manager(); + + let permission = manager.get_user_permission("non_existent_tool"); + assert_eq!(permission, None); + } + + #[test] + fn test_permission_levels() { + let mut manager = create_test_permission_manager(); + + manager.update_user_permission("tool4", PermissionLevel::AlwaysAllow); + manager.update_user_permission("tool5", PermissionLevel::AskBefore); + manager.update_user_permission("tool6", PermissionLevel::NeverAllow); + + // Check the permission levels + assert_eq!( + manager.get_user_permission("tool4"), + Some(PermissionLevel::AlwaysAllow) + ); + assert_eq!( + manager.get_user_permission("tool5"), + Some(PermissionLevel::AskBefore) + ); + assert_eq!( + manager.get_user_permission("tool6"), + Some(PermissionLevel::NeverAllow) + ); + } + + #[test] + fn test_permission_update_replaces_existing_level() { + let mut manager = create_test_permission_manager(); + + // Initially AlwaysAllow + manager.update_user_permission("tool7", PermissionLevel::AlwaysAllow); + assert_eq!( + manager.get_user_permission("tool7"), + Some(PermissionLevel::AlwaysAllow) + ); + + // Now change to NeverAllow + manager.update_user_permission("tool7", PermissionLevel::NeverAllow); + assert_eq!( + manager.get_user_permission("tool7"), + Some(PermissionLevel::NeverAllow) + ); + + // Ensure it's removed from other levels + let config = manager.permission_map.get(USER_PERMISSION).unwrap(); + assert!(!config.always_allow.contains(&"tool7".to_string())); + assert!(!config.ask_before.contains(&"tool7".to_string())); + assert!(config.never_allow.contains(&"tool7".to_string())); + } +} diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index 9be3660a..f2924ec9 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -1,10 +1,13 @@ use crate::agents::capabilities::Capabilities; +use crate::config::permission::PermissionLevel; +use crate::config::PermissionManager; use crate::message::{Message, MessageContent, ToolRequest}; use chrono::Utc; use indoc::indoc; use mcp_core::tool::ToolAnnotations; use mcp_core::{tool::Tool, TextContent}; use serde_json::{json, Value}; +use std::collections::HashSet; /// Creates the tool definition for checking read-only permissions. fn create_read_only_tool() -> Tool { @@ -136,6 +139,100 @@ pub async fn detect_read_only_tools( } } +// Define return structure +pub struct PermissionCheckResult { + pub approved: Vec, + pub needs_approval: Vec, + pub denied: Vec, +} + +pub async fn check_tool_permissions( + remaining_requests: Vec, + mode: &str, + tools_with_readonly_annotation: HashSet, + tools_without_annotation: HashSet, + permission_manager: &mut PermissionManager, + capabilities: &Capabilities, +) -> PermissionCheckResult { + let mut approved = vec![]; + let mut needs_approval = vec![]; + let mut denied = vec![]; + let mut llm_detect_candidates = vec![]; + + for request in remaining_requests { + if let Ok(tool_call) = request.tool_call.clone() { + // 1. Check user-defined permission + if let Some(level) = permission_manager.get_user_permission(&tool_call.name) { + match level { + PermissionLevel::AlwaysAllow => approved.push(request), + PermissionLevel::AskBefore => needs_approval.push(request), + PermissionLevel::NeverAllow => denied.push(request), + } + continue; + } + + // 2. Fallback based on mode + match mode { + "manual_approve" => { + needs_approval.push(request); + } + "smart_approve" => { + if let Some(level) = + permission_manager.get_smart_approve_permission(&tool_call.name) + { + match level { + PermissionLevel::AlwaysAllow => approved.push(request), + PermissionLevel::AskBefore => needs_approval.push(request), + PermissionLevel::NeverAllow => denied.push(request), + } + continue; + } + + if tools_with_readonly_annotation.contains(&tool_call.name) { + approved.push(request); + } else if tools_without_annotation.contains(&tool_call.name) { + llm_detect_candidates.push(request); + } else { + needs_approval.push(request); + } + } + _ => { + needs_approval.push(request); + } + } + } + } + + // 3. LLM detect + if !llm_detect_candidates.is_empty() && mode == "smart_approve" { + let detected_readonly_tools = + detect_read_only_tools(capabilities, llm_detect_candidates.iter().collect()).await; + for request in llm_detect_candidates { + if let Ok(tool_call) = request.tool_call.clone() { + if detected_readonly_tools.contains(&tool_call.name) { + approved.push(request); + permission_manager.update_smart_approve_permission( + &tool_call.name, + PermissionLevel::AlwaysAllow, + ); + } else { + needs_approval.push(request); + permission_manager.update_smart_approve_permission( + &tool_call.name, + PermissionLevel::AskBefore, + ); + } + } + } + } + + PermissionCheckResult { + approved, + needs_approval, + denied, + } +} + #[cfg(test)] mod tests { use super::*; @@ -148,6 +245,7 @@ mod tests { use mcp_core::ToolCall; use mcp_core::{tool::Tool, Role, ToolResult}; use serde_json::json; + use tempfile::NamedTempFile; #[derive(Clone)] struct MockProvider { @@ -270,4 +368,60 @@ mod tests { let result = detect_read_only_tools(&capabilities, vec![]).await; assert!(result.is_empty()); } + + #[tokio::test] + async fn test_check_tool_permissions() { + // Setup mocks + let temp_file = NamedTempFile::new().unwrap(); + let temp_path = temp_file.path(); + let mut permission_manager = PermissionManager::new(temp_path); + let capabilities = create_mock_capabilities(); + + let tools_with_readonly_annotation: HashSet = + vec!["file_reader".to_string()].into_iter().collect(); + let tools_without_annotation: HashSet = + vec!["data_fetcher".to_string()].into_iter().collect(); + + permission_manager.update_user_permission("file_reader", PermissionLevel::AlwaysAllow); + permission_manager + .update_smart_approve_permission("data_fetcher", PermissionLevel::AskBefore); + + let tool_request_1 = ToolRequest { + id: "tool_1".to_string(), + tool_call: ToolResult::Ok(ToolCall { + name: "file_reader".to_string(), + arguments: serde_json::json!({"path": "/path/to/file"}), + }), + }; + + let tool_request_2 = ToolRequest { + id: "tool_2".to_string(), + tool_call: ToolResult::Ok(ToolCall { + name: "data_fetcher".to_string(), + arguments: serde_json::json!({"url": "http://example.com"}), + }), + }; + + let remaining_requests = vec![tool_request_1, tool_request_2]; + + // Call the function under test + let result = check_tool_permissions( + remaining_requests, + "smart_approve", + tools_with_readonly_annotation, + tools_without_annotation, + &mut permission_manager, + &capabilities, + ) + .await; + + // Validate the result + assert_eq!(result.approved.len(), 1); // file_reader should be approved + assert_eq!(result.needs_approval.len(), 1); // data_fetcher should need approval + assert_eq!(result.denied.len(), 0); // No tool should be denied in this test + + // Ensure the right tools are in the approved and needs_approval lists + assert!(result.approved.iter().any(|req| req.id == "tool_1")); + assert!(result.needs_approval.iter().any(|req| req.id == "tool_2")); + } }