mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-23 07:24:24 +01:00
feat: add permission manager for tool permission control (#2060)
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
pub mod base;
|
||||
mod experiments;
|
||||
pub mod extensions;
|
||||
pub mod permission;
|
||||
|
||||
pub use crate::agents::ExtensionConfig;
|
||||
pub use base::{Config, ConfigError, APP_STRATEGY};
|
||||
pub use experiments::ExperimentManager;
|
||||
pub use extensions::{ExtensionEntry, ExtensionManager};
|
||||
pub use permission::PermissionManager;
|
||||
|
||||
pub use extensions::DEFAULT_DISPLAY_NAME;
|
||||
pub use extensions::DEFAULT_EXTENSION;
|
||||
|
||||
275
crates/goose/src/config/permission.rs
Normal file
275
crates/goose/src/config/permission.rs
Normal file
@@ -0,0 +1,275 @@
|
||||
use super::APP_STRATEGY;
|
||||
use etcetera::{choose_app_strategy, AppStrategy};
|
||||
use once_cell::sync::OnceCell;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Enum representing the possible permission levels for a tool.
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PermissionLevel {
|
||||
AlwaysAllow, // Tool can always be used without prompt
|
||||
AskBefore, // Tool requires permission to be granted before use
|
||||
NeverAllow, // Tool is never allowed to be used
|
||||
}
|
||||
|
||||
/// Struct representing the configuration of permissions, categorized by level.
|
||||
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
|
||||
pub struct PermissionConfig {
|
||||
pub always_allow: Vec<String>, // List of tools that are always allowed
|
||||
pub ask_before: Vec<String>, // List of tools that require user consent
|
||||
pub never_allow: Vec<String>, // List of tools that are never allowed
|
||||
}
|
||||
|
||||
/// PermissionManager manages permission configurations for various tools.
|
||||
#[derive(Debug)]
|
||||
pub struct PermissionManager {
|
||||
config_path: PathBuf, // Path to the permission configuration file
|
||||
permission_map: HashMap<String, PermissionConfig>, // Mapping of permission names to configurations
|
||||
}
|
||||
|
||||
// Global singleton for the PermissionManager
|
||||
static GLOBAL_PERMISSION_MANAGER: OnceCell<PermissionManager> = OnceCell::new();
|
||||
|
||||
// Constants representing specific permission categories
|
||||
const USER_PERMISSION: &str = "user";
|
||||
const SMART_APPROVE_PERMISSION: &str = "smart_approve";
|
||||
|
||||
/// Implements the default constructor for `PermissionManager`.
|
||||
impl Default for PermissionManager {
|
||||
fn default() -> Self {
|
||||
// Choose the app strategy and determine the config directory
|
||||
let config_dir = choose_app_strategy(APP_STRATEGY.clone())
|
||||
.expect("goose requires a home dir")
|
||||
.config_dir();
|
||||
|
||||
// Ensure the configuration directory exists
|
||||
std::fs::create_dir_all(&config_dir).expect("Failed to create config directory");
|
||||
let config_path = config_dir.join("permission.yaml");
|
||||
|
||||
// Load the existing configuration file or create an empty map if the file doesn't exist
|
||||
let permission_map = if config_path.exists() {
|
||||
// Load the configuration file
|
||||
let file_contents =
|
||||
fs::read_to_string(&config_path).expect("Failed to read permission.yaml");
|
||||
serde_yaml::from_str(&file_contents).unwrap_or_else(|_| HashMap::new())
|
||||
} else {
|
||||
HashMap::new() // No config file, create an empty map
|
||||
};
|
||||
|
||||
PermissionManager {
|
||||
config_path,
|
||||
permission_map,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PermissionManager {
|
||||
/// Returns the global instance of the PermissionManager, initializing it if necessary.
|
||||
pub fn global() -> &'static PermissionManager {
|
||||
GLOBAL_PERMISSION_MANAGER.get_or_init(PermissionManager::default)
|
||||
}
|
||||
|
||||
/// Creates a new `PermissionManager` with a specified config path.
|
||||
pub fn new<P: AsRef<Path>>(config_path: P) -> Self {
|
||||
let config_path = config_path.as_ref().to_path_buf();
|
||||
|
||||
// Load the existing configuration file or create an empty map if the file doesn't exist
|
||||
let permission_map = if config_path.exists() {
|
||||
// Load the configuration file
|
||||
let file_contents =
|
||||
fs::read_to_string(&config_path).expect("Failed to read permission.yaml");
|
||||
serde_yaml::from_str(&file_contents).unwrap_or_else(|_| HashMap::new())
|
||||
} else {
|
||||
HashMap::new() // No config file, create an empty map
|
||||
};
|
||||
|
||||
PermissionManager {
|
||||
config_path,
|
||||
permission_map,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a list of all the names (keys) in the permission map.
|
||||
pub fn get_permission_names(&self) -> Vec<String> {
|
||||
self.permission_map.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Retrieves the user permission level for a specific tool.
|
||||
pub fn get_user_permission(&self, principal_name: &str) -> Option<PermissionLevel> {
|
||||
self.get_permission(USER_PERMISSION, principal_name)
|
||||
}
|
||||
|
||||
/// Retrieves the smart approve permission level for a specific tool.
|
||||
pub fn get_smart_approve_permission(&self, principal_name: &str) -> Option<PermissionLevel> {
|
||||
self.get_permission(SMART_APPROVE_PERMISSION, principal_name)
|
||||
}
|
||||
|
||||
/// Helper function to retrieve the permission level for a specific permission category and tool.
|
||||
fn get_permission(&self, name: &str, principal_name: &str) -> Option<PermissionLevel> {
|
||||
// Check if the permission category exists in the map
|
||||
if let Some(permission_config) = self.permission_map.get(name) {
|
||||
// Check the permission levels for the given tool
|
||||
if permission_config
|
||||
.always_allow
|
||||
.contains(&principal_name.to_string())
|
||||
{
|
||||
return Some(PermissionLevel::AlwaysAllow);
|
||||
} else if permission_config
|
||||
.ask_before
|
||||
.contains(&principal_name.to_string())
|
||||
{
|
||||
return Some(PermissionLevel::AskBefore);
|
||||
} else if permission_config
|
||||
.never_allow
|
||||
.contains(&principal_name.to_string())
|
||||
{
|
||||
return Some(PermissionLevel::NeverAllow);
|
||||
}
|
||||
}
|
||||
None // Return None if no matching permission level is found
|
||||
}
|
||||
|
||||
/// Updates the user permission level for a specific tool.
|
||||
pub fn update_user_permission(&mut self, principal_name: &str, level: PermissionLevel) {
|
||||
self.update_permission(USER_PERMISSION, principal_name, level)
|
||||
}
|
||||
|
||||
/// Updates the smart approve permission level for a specific tool.
|
||||
pub fn update_smart_approve_permission(
|
||||
&mut self,
|
||||
principal_name: &str,
|
||||
level: PermissionLevel,
|
||||
) {
|
||||
self.update_permission(SMART_APPROVE_PERMISSION, principal_name, level)
|
||||
}
|
||||
|
||||
/// Helper function to update a permission level for a specific tool in a given permission category.
|
||||
fn update_permission(&mut self, name: &str, principal_name: &str, level: PermissionLevel) {
|
||||
// Get or create a new PermissionConfig for the specified category
|
||||
let permission_config = self.permission_map.entry(name.to_string()).or_default();
|
||||
|
||||
// Remove the principal from all existing lists to avoid duplicates
|
||||
permission_config
|
||||
.always_allow
|
||||
.retain(|p| p != principal_name);
|
||||
permission_config.ask_before.retain(|p| p != principal_name);
|
||||
permission_config
|
||||
.never_allow
|
||||
.retain(|p| p != principal_name);
|
||||
|
||||
// Add the principal to the appropriate list
|
||||
match level {
|
||||
PermissionLevel::AlwaysAllow => permission_config
|
||||
.always_allow
|
||||
.push(principal_name.to_string()),
|
||||
PermissionLevel::AskBefore => permission_config
|
||||
.ask_before
|
||||
.push(principal_name.to_string()),
|
||||
PermissionLevel::NeverAllow => permission_config
|
||||
.never_allow
|
||||
.push(principal_name.to_string()),
|
||||
}
|
||||
|
||||
// Serialize the updated permission map and write it back to the config file
|
||||
let yaml_content = serde_yaml::to_string(&self.permission_map)
|
||||
.expect("Failed to serialize permission config");
|
||||
fs::write(&self.config_path, yaml_content).expect("Failed to write to permission.yaml");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
// Helper function to create a test instance of PermissionManager with a temp dir
|
||||
fn create_test_permission_manager() -> PermissionManager {
|
||||
let temp_file = NamedTempFile::new().unwrap();
|
||||
let temp_path = temp_file.path();
|
||||
PermissionManager::new(temp_path)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_permission_names_empty() {
|
||||
let manager = create_test_permission_manager();
|
||||
|
||||
assert!(manager.get_permission_names().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_user_permission() {
|
||||
let mut manager = create_test_permission_manager();
|
||||
manager.update_user_permission("tool1", PermissionLevel::AlwaysAllow);
|
||||
|
||||
let permission = manager.get_user_permission("tool1");
|
||||
assert_eq!(permission, Some(PermissionLevel::AlwaysAllow));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_update_smart_approve_permission() {
|
||||
let mut manager = create_test_permission_manager();
|
||||
manager.update_smart_approve_permission("tool2", PermissionLevel::AskBefore);
|
||||
|
||||
let permission = manager.get_smart_approve_permission("tool2");
|
||||
assert_eq!(permission, Some(PermissionLevel::AskBefore));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_permission_not_found() {
|
||||
let manager = create_test_permission_manager();
|
||||
|
||||
let permission = manager.get_user_permission("non_existent_tool");
|
||||
assert_eq!(permission, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permission_levels() {
|
||||
let mut manager = create_test_permission_manager();
|
||||
|
||||
manager.update_user_permission("tool4", PermissionLevel::AlwaysAllow);
|
||||
manager.update_user_permission("tool5", PermissionLevel::AskBefore);
|
||||
manager.update_user_permission("tool6", PermissionLevel::NeverAllow);
|
||||
|
||||
// Check the permission levels
|
||||
assert_eq!(
|
||||
manager.get_user_permission("tool4"),
|
||||
Some(PermissionLevel::AlwaysAllow)
|
||||
);
|
||||
assert_eq!(
|
||||
manager.get_user_permission("tool5"),
|
||||
Some(PermissionLevel::AskBefore)
|
||||
);
|
||||
assert_eq!(
|
||||
manager.get_user_permission("tool6"),
|
||||
Some(PermissionLevel::NeverAllow)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permission_update_replaces_existing_level() {
|
||||
let mut manager = create_test_permission_manager();
|
||||
|
||||
// Initially AlwaysAllow
|
||||
manager.update_user_permission("tool7", PermissionLevel::AlwaysAllow);
|
||||
assert_eq!(
|
||||
manager.get_user_permission("tool7"),
|
||||
Some(PermissionLevel::AlwaysAllow)
|
||||
);
|
||||
|
||||
// Now change to NeverAllow
|
||||
manager.update_user_permission("tool7", PermissionLevel::NeverAllow);
|
||||
assert_eq!(
|
||||
manager.get_user_permission("tool7"),
|
||||
Some(PermissionLevel::NeverAllow)
|
||||
);
|
||||
|
||||
// Ensure it's removed from other levels
|
||||
let config = manager.permission_map.get(USER_PERMISSION).unwrap();
|
||||
assert!(!config.always_allow.contains(&"tool7".to_string()));
|
||||
assert!(!config.ask_before.contains(&"tool7".to_string()));
|
||||
assert!(config.never_allow.contains(&"tool7".to_string()));
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
use crate::agents::capabilities::Capabilities;
|
||||
use crate::config::permission::PermissionLevel;
|
||||
use crate::config::PermissionManager;
|
||||
use crate::message::{Message, MessageContent, ToolRequest};
|
||||
use chrono::Utc;
|
||||
use indoc::indoc;
|
||||
use mcp_core::tool::ToolAnnotations;
|
||||
use mcp_core::{tool::Tool, TextContent};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Creates the tool definition for checking read-only permissions.
|
||||
fn create_read_only_tool() -> Tool {
|
||||
@@ -136,6 +139,100 @@ pub async fn detect_read_only_tools(
|
||||
}
|
||||
}
|
||||
|
||||
// Define return structure
|
||||
pub struct PermissionCheckResult {
|
||||
pub approved: Vec<ToolRequest>,
|
||||
pub needs_approval: Vec<ToolRequest>,
|
||||
pub denied: Vec<ToolRequest>,
|
||||
}
|
||||
|
||||
pub async fn check_tool_permissions(
|
||||
remaining_requests: Vec<ToolRequest>,
|
||||
mode: &str,
|
||||
tools_with_readonly_annotation: HashSet<String>,
|
||||
tools_without_annotation: HashSet<String>,
|
||||
permission_manager: &mut PermissionManager,
|
||||
capabilities: &Capabilities,
|
||||
) -> PermissionCheckResult {
|
||||
let mut approved = vec![];
|
||||
let mut needs_approval = vec![];
|
||||
let mut denied = vec![];
|
||||
let mut llm_detect_candidates = vec![];
|
||||
|
||||
for request in remaining_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
// 1. Check user-defined permission
|
||||
if let Some(level) = permission_manager.get_user_permission(&tool_call.name) {
|
||||
match level {
|
||||
PermissionLevel::AlwaysAllow => approved.push(request),
|
||||
PermissionLevel::AskBefore => needs_approval.push(request),
|
||||
PermissionLevel::NeverAllow => denied.push(request),
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// 2. Fallback based on mode
|
||||
match mode {
|
||||
"manual_approve" => {
|
||||
needs_approval.push(request);
|
||||
}
|
||||
"smart_approve" => {
|
||||
if let Some(level) =
|
||||
permission_manager.get_smart_approve_permission(&tool_call.name)
|
||||
{
|
||||
match level {
|
||||
PermissionLevel::AlwaysAllow => approved.push(request),
|
||||
PermissionLevel::AskBefore => needs_approval.push(request),
|
||||
PermissionLevel::NeverAllow => denied.push(request),
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if tools_with_readonly_annotation.contains(&tool_call.name) {
|
||||
approved.push(request);
|
||||
} else if tools_without_annotation.contains(&tool_call.name) {
|
||||
llm_detect_candidates.push(request);
|
||||
} else {
|
||||
needs_approval.push(request);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
needs_approval.push(request);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. LLM detect
|
||||
if !llm_detect_candidates.is_empty() && mode == "smart_approve" {
|
||||
let detected_readonly_tools =
|
||||
detect_read_only_tools(capabilities, llm_detect_candidates.iter().collect()).await;
|
||||
for request in llm_detect_candidates {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
if detected_readonly_tools.contains(&tool_call.name) {
|
||||
approved.push(request);
|
||||
permission_manager.update_smart_approve_permission(
|
||||
&tool_call.name,
|
||||
PermissionLevel::AlwaysAllow,
|
||||
);
|
||||
} else {
|
||||
needs_approval.push(request);
|
||||
permission_manager.update_smart_approve_permission(
|
||||
&tool_call.name,
|
||||
PermissionLevel::AskBefore,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PermissionCheckResult {
|
||||
approved,
|
||||
needs_approval,
|
||||
denied,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -148,6 +245,7 @@ mod tests {
|
||||
use mcp_core::ToolCall;
|
||||
use mcp_core::{tool::Tool, Role, ToolResult};
|
||||
use serde_json::json;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MockProvider {
|
||||
@@ -270,4 +368,60 @@ mod tests {
|
||||
let result = detect_read_only_tools(&capabilities, vec![]).await;
|
||||
assert!(result.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_tool_permissions() {
|
||||
// Setup mocks
|
||||
let temp_file = NamedTempFile::new().unwrap();
|
||||
let temp_path = temp_file.path();
|
||||
let mut permission_manager = PermissionManager::new(temp_path);
|
||||
let capabilities = create_mock_capabilities();
|
||||
|
||||
let tools_with_readonly_annotation: HashSet<String> =
|
||||
vec!["file_reader".to_string()].into_iter().collect();
|
||||
let tools_without_annotation: HashSet<String> =
|
||||
vec!["data_fetcher".to_string()].into_iter().collect();
|
||||
|
||||
permission_manager.update_user_permission("file_reader", PermissionLevel::AlwaysAllow);
|
||||
permission_manager
|
||||
.update_smart_approve_permission("data_fetcher", PermissionLevel::AskBefore);
|
||||
|
||||
let tool_request_1 = ToolRequest {
|
||||
id: "tool_1".to_string(),
|
||||
tool_call: ToolResult::Ok(ToolCall {
|
||||
name: "file_reader".to_string(),
|
||||
arguments: serde_json::json!({"path": "/path/to/file"}),
|
||||
}),
|
||||
};
|
||||
|
||||
let tool_request_2 = ToolRequest {
|
||||
id: "tool_2".to_string(),
|
||||
tool_call: ToolResult::Ok(ToolCall {
|
||||
name: "data_fetcher".to_string(),
|
||||
arguments: serde_json::json!({"url": "http://example.com"}),
|
||||
}),
|
||||
};
|
||||
|
||||
let remaining_requests = vec![tool_request_1, tool_request_2];
|
||||
|
||||
// Call the function under test
|
||||
let result = check_tool_permissions(
|
||||
remaining_requests,
|
||||
"smart_approve",
|
||||
tools_with_readonly_annotation,
|
||||
tools_without_annotation,
|
||||
&mut permission_manager,
|
||||
&capabilities,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Validate the result
|
||||
assert_eq!(result.approved.len(), 1); // file_reader should be approved
|
||||
assert_eq!(result.needs_approval.len(), 1); // data_fetcher should need approval
|
||||
assert_eq!(result.denied.len(), 0); // No tool should be denied in this test
|
||||
|
||||
// Ensure the right tools are in the approved and needs_approval lists
|
||||
assert!(result.approved.iter().any(|req| req.id == "tool_1"));
|
||||
assert!(result.needs_approval.iter().any(|req| req.id == "tool_2"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user