feat: add permission manager for tool permission control (#2060)

This commit is contained in:
Yingjie He
2025-04-07 20:11:01 -07:00
committed by GitHub
parent 490944c3f8
commit 268fc5e057
3 changed files with 431 additions and 0 deletions

View File

@@ -1,11 +1,13 @@
pub mod base;
mod experiments;
pub mod extensions;
pub mod permission;
pub use crate::agents::ExtensionConfig;
pub use base::{Config, ConfigError, APP_STRATEGY};
pub use experiments::ExperimentManager;
pub use extensions::{ExtensionEntry, ExtensionManager};
pub use permission::PermissionManager;
pub use extensions::DEFAULT_DISPLAY_NAME;
pub use extensions::DEFAULT_EXTENSION;

View File

@@ -0,0 +1,275 @@
use super::APP_STRATEGY;
use etcetera::{choose_app_strategy, AppStrategy};
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
/// Enum representing the possible permission levels for a tool.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum PermissionLevel {
AlwaysAllow, // Tool can always be used without prompt
AskBefore, // Tool requires permission to be granted before use
NeverAllow, // Tool is never allowed to be used
}
/// Struct representing the configuration of permissions, categorized by level.
#[derive(Debug, Deserialize, Serialize, Default, Clone)]
pub struct PermissionConfig {
pub always_allow: Vec<String>, // List of tools that are always allowed
pub ask_before: Vec<String>, // List of tools that require user consent
pub never_allow: Vec<String>, // List of tools that are never allowed
}
/// PermissionManager manages permission configurations for various tools.
#[derive(Debug)]
pub struct PermissionManager {
config_path: PathBuf, // Path to the permission configuration file
permission_map: HashMap<String, PermissionConfig>, // Mapping of permission names to configurations
}
// Global singleton for the PermissionManager
static GLOBAL_PERMISSION_MANAGER: OnceCell<PermissionManager> = OnceCell::new();
// Constants representing specific permission categories
const USER_PERMISSION: &str = "user";
const SMART_APPROVE_PERMISSION: &str = "smart_approve";
/// Implements the default constructor for `PermissionManager`.
impl Default for PermissionManager {
fn default() -> Self {
// Choose the app strategy and determine the config directory
let config_dir = choose_app_strategy(APP_STRATEGY.clone())
.expect("goose requires a home dir")
.config_dir();
// Ensure the configuration directory exists
std::fs::create_dir_all(&config_dir).expect("Failed to create config directory");
let config_path = config_dir.join("permission.yaml");
// Load the existing configuration file or create an empty map if the file doesn't exist
let permission_map = if config_path.exists() {
// Load the configuration file
let file_contents =
fs::read_to_string(&config_path).expect("Failed to read permission.yaml");
serde_yaml::from_str(&file_contents).unwrap_or_else(|_| HashMap::new())
} else {
HashMap::new() // No config file, create an empty map
};
PermissionManager {
config_path,
permission_map,
}
}
}
impl PermissionManager {
/// Returns the global instance of the PermissionManager, initializing it if necessary.
pub fn global() -> &'static PermissionManager {
GLOBAL_PERMISSION_MANAGER.get_or_init(PermissionManager::default)
}
/// Creates a new `PermissionManager` with a specified config path.
pub fn new<P: AsRef<Path>>(config_path: P) -> Self {
let config_path = config_path.as_ref().to_path_buf();
// Load the existing configuration file or create an empty map if the file doesn't exist
let permission_map = if config_path.exists() {
// Load the configuration file
let file_contents =
fs::read_to_string(&config_path).expect("Failed to read permission.yaml");
serde_yaml::from_str(&file_contents).unwrap_or_else(|_| HashMap::new())
} else {
HashMap::new() // No config file, create an empty map
};
PermissionManager {
config_path,
permission_map,
}
}
/// Returns a list of all the names (keys) in the permission map.
pub fn get_permission_names(&self) -> Vec<String> {
self.permission_map.keys().cloned().collect()
}
/// Retrieves the user permission level for a specific tool.
pub fn get_user_permission(&self, principal_name: &str) -> Option<PermissionLevel> {
self.get_permission(USER_PERMISSION, principal_name)
}
/// Retrieves the smart approve permission level for a specific tool.
pub fn get_smart_approve_permission(&self, principal_name: &str) -> Option<PermissionLevel> {
self.get_permission(SMART_APPROVE_PERMISSION, principal_name)
}
/// Helper function to retrieve the permission level for a specific permission category and tool.
fn get_permission(&self, name: &str, principal_name: &str) -> Option<PermissionLevel> {
// Check if the permission category exists in the map
if let Some(permission_config) = self.permission_map.get(name) {
// Check the permission levels for the given tool
if permission_config
.always_allow
.contains(&principal_name.to_string())
{
return Some(PermissionLevel::AlwaysAllow);
} else if permission_config
.ask_before
.contains(&principal_name.to_string())
{
return Some(PermissionLevel::AskBefore);
} else if permission_config
.never_allow
.contains(&principal_name.to_string())
{
return Some(PermissionLevel::NeverAllow);
}
}
None // Return None if no matching permission level is found
}
/// Updates the user permission level for a specific tool.
pub fn update_user_permission(&mut self, principal_name: &str, level: PermissionLevel) {
self.update_permission(USER_PERMISSION, principal_name, level)
}
/// Updates the smart approve permission level for a specific tool.
pub fn update_smart_approve_permission(
&mut self,
principal_name: &str,
level: PermissionLevel,
) {
self.update_permission(SMART_APPROVE_PERMISSION, principal_name, level)
}
/// Helper function to update a permission level for a specific tool in a given permission category.
fn update_permission(&mut self, name: &str, principal_name: &str, level: PermissionLevel) {
// Get or create a new PermissionConfig for the specified category
let permission_config = self.permission_map.entry(name.to_string()).or_default();
// Remove the principal from all existing lists to avoid duplicates
permission_config
.always_allow
.retain(|p| p != principal_name);
permission_config.ask_before.retain(|p| p != principal_name);
permission_config
.never_allow
.retain(|p| p != principal_name);
// Add the principal to the appropriate list
match level {
PermissionLevel::AlwaysAllow => permission_config
.always_allow
.push(principal_name.to_string()),
PermissionLevel::AskBefore => permission_config
.ask_before
.push(principal_name.to_string()),
PermissionLevel::NeverAllow => permission_config
.never_allow
.push(principal_name.to_string()),
}
// Serialize the updated permission map and write it back to the config file
let yaml_content = serde_yaml::to_string(&self.permission_map)
.expect("Failed to serialize permission config");
fs::write(&self.config_path, yaml_content).expect("Failed to write to permission.yaml");
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
// Helper function to create a test instance of PermissionManager with a temp dir
fn create_test_permission_manager() -> PermissionManager {
let temp_file = NamedTempFile::new().unwrap();
let temp_path = temp_file.path();
PermissionManager::new(temp_path)
}
#[test]
fn test_get_permission_names_empty() {
let manager = create_test_permission_manager();
assert!(manager.get_permission_names().is_empty());
}
#[test]
fn test_update_user_permission() {
let mut manager = create_test_permission_manager();
manager.update_user_permission("tool1", PermissionLevel::AlwaysAllow);
let permission = manager.get_user_permission("tool1");
assert_eq!(permission, Some(PermissionLevel::AlwaysAllow));
}
#[test]
fn test_update_smart_approve_permission() {
let mut manager = create_test_permission_manager();
manager.update_smart_approve_permission("tool2", PermissionLevel::AskBefore);
let permission = manager.get_smart_approve_permission("tool2");
assert_eq!(permission, Some(PermissionLevel::AskBefore));
}
#[test]
fn test_get_permission_not_found() {
let manager = create_test_permission_manager();
let permission = manager.get_user_permission("non_existent_tool");
assert_eq!(permission, None);
}
#[test]
fn test_permission_levels() {
let mut manager = create_test_permission_manager();
manager.update_user_permission("tool4", PermissionLevel::AlwaysAllow);
manager.update_user_permission("tool5", PermissionLevel::AskBefore);
manager.update_user_permission("tool6", PermissionLevel::NeverAllow);
// Check the permission levels
assert_eq!(
manager.get_user_permission("tool4"),
Some(PermissionLevel::AlwaysAllow)
);
assert_eq!(
manager.get_user_permission("tool5"),
Some(PermissionLevel::AskBefore)
);
assert_eq!(
manager.get_user_permission("tool6"),
Some(PermissionLevel::NeverAllow)
);
}
#[test]
fn test_permission_update_replaces_existing_level() {
let mut manager = create_test_permission_manager();
// Initially AlwaysAllow
manager.update_user_permission("tool7", PermissionLevel::AlwaysAllow);
assert_eq!(
manager.get_user_permission("tool7"),
Some(PermissionLevel::AlwaysAllow)
);
// Now change to NeverAllow
manager.update_user_permission("tool7", PermissionLevel::NeverAllow);
assert_eq!(
manager.get_user_permission("tool7"),
Some(PermissionLevel::NeverAllow)
);
// Ensure it's removed from other levels
let config = manager.permission_map.get(USER_PERMISSION).unwrap();
assert!(!config.always_allow.contains(&"tool7".to_string()));
assert!(!config.ask_before.contains(&"tool7".to_string()));
assert!(config.never_allow.contains(&"tool7".to_string()));
}
}

View File

@@ -1,10 +1,13 @@
use crate::agents::capabilities::Capabilities;
use crate::config::permission::PermissionLevel;
use crate::config::PermissionManager;
use crate::message::{Message, MessageContent, ToolRequest};
use chrono::Utc;
use indoc::indoc;
use mcp_core::tool::ToolAnnotations;
use mcp_core::{tool::Tool, TextContent};
use serde_json::{json, Value};
use std::collections::HashSet;
/// Creates the tool definition for checking read-only permissions.
fn create_read_only_tool() -> Tool {
@@ -136,6 +139,100 @@ pub async fn detect_read_only_tools(
}
}
// Define return structure
pub struct PermissionCheckResult {
pub approved: Vec<ToolRequest>,
pub needs_approval: Vec<ToolRequest>,
pub denied: Vec<ToolRequest>,
}
pub async fn check_tool_permissions(
remaining_requests: Vec<ToolRequest>,
mode: &str,
tools_with_readonly_annotation: HashSet<String>,
tools_without_annotation: HashSet<String>,
permission_manager: &mut PermissionManager,
capabilities: &Capabilities,
) -> PermissionCheckResult {
let mut approved = vec![];
let mut needs_approval = vec![];
let mut denied = vec![];
let mut llm_detect_candidates = vec![];
for request in remaining_requests {
if let Ok(tool_call) = request.tool_call.clone() {
// 1. Check user-defined permission
if let Some(level) = permission_manager.get_user_permission(&tool_call.name) {
match level {
PermissionLevel::AlwaysAllow => approved.push(request),
PermissionLevel::AskBefore => needs_approval.push(request),
PermissionLevel::NeverAllow => denied.push(request),
}
continue;
}
// 2. Fallback based on mode
match mode {
"manual_approve" => {
needs_approval.push(request);
}
"smart_approve" => {
if let Some(level) =
permission_manager.get_smart_approve_permission(&tool_call.name)
{
match level {
PermissionLevel::AlwaysAllow => approved.push(request),
PermissionLevel::AskBefore => needs_approval.push(request),
PermissionLevel::NeverAllow => denied.push(request),
}
continue;
}
if tools_with_readonly_annotation.contains(&tool_call.name) {
approved.push(request);
} else if tools_without_annotation.contains(&tool_call.name) {
llm_detect_candidates.push(request);
} else {
needs_approval.push(request);
}
}
_ => {
needs_approval.push(request);
}
}
}
}
// 3. LLM detect
if !llm_detect_candidates.is_empty() && mode == "smart_approve" {
let detected_readonly_tools =
detect_read_only_tools(capabilities, llm_detect_candidates.iter().collect()).await;
for request in llm_detect_candidates {
if let Ok(tool_call) = request.tool_call.clone() {
if detected_readonly_tools.contains(&tool_call.name) {
approved.push(request);
permission_manager.update_smart_approve_permission(
&tool_call.name,
PermissionLevel::AlwaysAllow,
);
} else {
needs_approval.push(request);
permission_manager.update_smart_approve_permission(
&tool_call.name,
PermissionLevel::AskBefore,
);
}
}
}
}
PermissionCheckResult {
approved,
needs_approval,
denied,
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -148,6 +245,7 @@ mod tests {
use mcp_core::ToolCall;
use mcp_core::{tool::Tool, Role, ToolResult};
use serde_json::json;
use tempfile::NamedTempFile;
#[derive(Clone)]
struct MockProvider {
@@ -270,4 +368,60 @@ mod tests {
let result = detect_read_only_tools(&capabilities, vec![]).await;
assert!(result.is_empty());
}
#[tokio::test]
async fn test_check_tool_permissions() {
// Setup mocks
let temp_file = NamedTempFile::new().unwrap();
let temp_path = temp_file.path();
let mut permission_manager = PermissionManager::new(temp_path);
let capabilities = create_mock_capabilities();
let tools_with_readonly_annotation: HashSet<String> =
vec!["file_reader".to_string()].into_iter().collect();
let tools_without_annotation: HashSet<String> =
vec!["data_fetcher".to_string()].into_iter().collect();
permission_manager.update_user_permission("file_reader", PermissionLevel::AlwaysAllow);
permission_manager
.update_smart_approve_permission("data_fetcher", PermissionLevel::AskBefore);
let tool_request_1 = ToolRequest {
id: "tool_1".to_string(),
tool_call: ToolResult::Ok(ToolCall {
name: "file_reader".to_string(),
arguments: serde_json::json!({"path": "/path/to/file"}),
}),
};
let tool_request_2 = ToolRequest {
id: "tool_2".to_string(),
tool_call: ToolResult::Ok(ToolCall {
name: "data_fetcher".to_string(),
arguments: serde_json::json!({"url": "http://example.com"}),
}),
};
let remaining_requests = vec![tool_request_1, tool_request_2];
// Call the function under test
let result = check_tool_permissions(
remaining_requests,
"smart_approve",
tools_with_readonly_annotation,
tools_without_annotation,
&mut permission_manager,
&capabilities,
)
.await;
// Validate the result
assert_eq!(result.approved.len(), 1); // file_reader should be approved
assert_eq!(result.needs_approval.len(), 1); // data_fetcher should need approval
assert_eq!(result.denied.len(), 0); // No tool should be denied in this test
// Ensure the right tools are in the approved and needs_approval lists
assert!(result.approved.iter().any(|req| req.id == "tool_1"));
assert!(result.needs_approval.iter().any(|req| req.id == "tool_2"));
}
}