feat: implement a tool permission store (#1516)

This commit is contained in:
Wendy Tang
2025-03-07 10:54:01 -08:00
committed by GitHub
parent 2383de5921
commit c4f571aeec
6 changed files with 234 additions and 9 deletions

26
Cargo.lock generated
View File

@@ -185,6 +185,12 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236"
[[package]]
name = "arrayref"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
[[package]]
name = "arrayvec"
version = "0.7.6"
@@ -918,6 +924,19 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2"
[[package]]
name = "blake3"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "675f87afced0413c9bb02843499dbbd3882a237645883f71a2b59644a6d2f753"
dependencies = [
"arrayref",
"arrayvec",
"cc",
"cfg-if",
"constant_time_eq",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@@ -1313,6 +1332,12 @@ dependencies = [
"tiny-keccak",
]
[[package]]
name = "constant_time_eq"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
[[package]]
name = "content_inspector"
version = "0.2.4"
@@ -2179,6 +2204,7 @@ dependencies = [
"aws-smithy-types",
"axum 0.7.9",
"base64 0.21.7",
"blake3",
"chrono",
"criterion",
"ctor",

View File

@@ -69,6 +69,9 @@ aws-sdk-bedrockruntime = "1.72.0"
# For GCP Vertex AI provider auth
jsonwebtoken = "9.3.1"
# Added blake3 hashing library as a dependency
blake3 = "1.5"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

View File

@@ -3,6 +3,7 @@ mod capabilities;
pub mod extension;
mod factory;
mod permission_judge;
mod permission_store;
mod reference;
mod summarize;
mod truncate;
@@ -12,3 +13,4 @@ pub use capabilities::Capabilities;
pub use extension::ExtensionConfig;
pub use factory::{register_agent, AgentFactory};
pub use permission_judge::detect_read_only_tools;
pub use permission_store::ToolPermissionStore;

View File

@@ -0,0 +1,149 @@
use crate::message::ToolRequest;
use anyhow::Result;
use blake3::Hasher;
use chrono::Utc;
use etcetera::{choose_app_strategy, AppStrategy};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use std::{fs::File, path::PathBuf};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolPermissionRecord {
tool_name: String,
allowed: bool,
context_hash: String, // Hash of the tool's arguments/context to differentiate similar calls
#[serde(skip_serializing_if = "Option::is_none")] // Don't serialize if None
readable_context: Option<String>, // Add this field
timestamp: i64,
expiry: Option<i64>, // Optional expiry timestamp
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ToolPermissionStore {
permissions: HashMap<String, Vec<ToolPermissionRecord>>,
version: u32, // For future schema migrations
#[serde(skip)] // Don't serialize this field
permissions_dir: PathBuf,
}
impl Default for ToolPermissionStore {
fn default() -> Self {
Self::new()
}
}
impl ToolPermissionStore {
pub fn new() -> Self {
let permissions_dir = choose_app_strategy(crate::config::APP_STRATEGY.clone())
.map(|strategy| strategy.config_dir())
.unwrap_or_else(|_| PathBuf::from(".config/goose"));
Self {
permissions: HashMap::new(),
version: 1,
permissions_dir,
}
}
pub fn load() -> Result<Self> {
let store = Self::new();
let file_path = store.permissions_dir.join("tool_permissions.json");
if !file_path.exists() {
return Ok(store);
}
let file = File::open(file_path)?;
let mut permissions: ToolPermissionStore = serde_json::from_reader(file)?;
permissions.permissions_dir = store.permissions_dir;
// Clean up expired entries on load
permissions.cleanup_expired()?;
Ok(permissions)
}
pub fn save(&self) -> anyhow::Result<()> {
std::fs::create_dir_all(&self.permissions_dir)?;
let path = self.permissions_dir.join("tool_permissions.json");
let temp_path = path.with_extension("tmp");
// Write complete content to temporary file
let content = serde_json::to_string_pretty(self)?;
std::fs::write(&temp_path, &content)?;
// Atomically rename temp file to target file
std::fs::rename(temp_path, path)?;
Ok(())
}
pub fn check_permission(&self, tool_request: &ToolRequest) -> Option<bool> {
let context_hash = self.hash_tool_context(tool_request);
let tool_call = tool_request.tool_call.as_ref().unwrap();
let key = format!("{}:{}", tool_call.name, context_hash);
self.permissions.get(&key).and_then(|records| {
records
.iter()
.filter(|record| record.expiry.is_none_or(|exp| exp > Utc::now().timestamp()))
.last()
.map(|record| record.allowed)
})
}
pub fn record_permission(
&mut self,
tool_request: &ToolRequest,
allowed: bool,
expiry_duration: Option<Duration>,
) -> anyhow::Result<()> {
let context_hash = self.hash_tool_context(tool_request);
let tool_call = tool_request.tool_call.as_ref().unwrap();
let key = format!("{}:{}", tool_call.name, context_hash);
let record = ToolPermissionRecord {
tool_name: tool_call.name.clone(),
allowed,
context_hash,
readable_context: Some(tool_request.to_readable_string()),
timestamp: Utc::now().timestamp(),
expiry: expiry_duration.map(|d| Utc::now().timestamp() + d.as_secs() as i64),
};
self.permissions.entry(key).or_default().push(record);
self.save()?;
Ok(())
}
fn hash_tool_context(&self, tool_request: &ToolRequest) -> String {
// Create a hash of the tool's arguments to differentiate similar calls
// This helps identify when the same tool is being used in a different context
let mut hasher = Hasher::new();
hasher.update(
serde_json::to_string(&tool_request.tool_call.as_ref().unwrap().arguments)
.unwrap_or_default()
.as_bytes(),
);
hasher.finalize().to_hex().to_string()
}
pub fn cleanup_expired(&mut self) -> anyhow::Result<()> {
let now = Utc::now().timestamp();
let mut changed = false;
self.permissions.retain(|_, records| {
records.retain(|record| record.expiry.is_none_or(|exp| exp > now));
changed = changed || records.is_empty();
!records.is_empty()
});
if changed {
self.save()?;
}
Ok(())
}
}

View File

@@ -13,6 +13,7 @@ use super::detect_read_only_tools;
use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::agents::ToolPermissionStore;
use crate::config::Config;
use crate::config::ExperimentManager;
use crate::message::{Message, ToolRequest};
@@ -29,6 +30,7 @@ use mcp_core::prompt::Prompt;
use mcp_core::protocol::GetPromptResult;
use mcp_core::{tool::Tool, Content};
use serde_json::{json, Value};
use std::time::Duration;
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
@@ -267,19 +269,43 @@ impl Agent for TruncateAgent {
match mode.as_str() {
"approve" => {
let mut read_only_tools = Vec::new();
// Process each tool request sequentially with confirmation
if ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? {
read_only_tools = detect_read_only_tools(&capabilities, tool_requests.clone()).await;
let mut needs_confirmation = Vec::<&ToolRequest>::new();
// First check permissions for all tools
let store = ToolPermissionStore::load()?;
for request in tool_requests.iter() {
if let Ok(tool_call) = request.tool_call.clone() {
if let Some(allowed) = store.check_permission(request) {
if allowed {
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
} else {
needs_confirmation.push(request);
}
} else {
needs_confirmation.push(request);
}
}
}
for request in &tool_requests {
// Only check read-only status for tools needing confirmation
if !needs_confirmation.is_empty() && ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? {
read_only_tools = detect_read_only_tools(&capabilities, needs_confirmation.clone()).await;
}
// Process remaining tools that need confirmation
for request in &needs_confirmation {
if let Ok(tool_call) = request.tool_call.clone() {
// Skip confirmation if the tool_call.name is in the read_only_tools list
if read_only_tools.contains(&tool_call.name) {
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
} else {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
@@ -291,9 +317,12 @@ impl Agent for TruncateAgent {
// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
// Loop the recv until we have a matched req_id due to potential duplicate messages.
while let Some((req_id, confirmed)) = rx.recv().await {
if req_id == request.id {
// Store the user's response with 30-day expiration
let mut store = ToolPermissionStore::load()?;
store.record_permission(request, confirmed, Some(Duration::from_secs(30 * 24 * 60 * 60)))?;
if confirmed {
// User approved - dispatch the tool call
let output = capabilities.dispatch_tool_call(tool_call).await;

View File

@@ -26,6 +26,22 @@ pub struct ToolRequest {
pub tool_call: ToolResult<ToolCall>,
}
impl ToolRequest {
pub fn to_readable_string(&self) -> String {
match &self.tool_call {
Ok(tool_call) => {
format!(
"Tool: {}, Args: {}",
tool_call.name,
serde_json::to_string_pretty(&tool_call.arguments)
.unwrap_or_else(|_| "<<invalid json>>".to_string())
)
}
Err(e) => format!("Invalid tool call: {}", e),
}
}
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResponse {