mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-23 07:24:24 +01:00
feat: implement a tool permission store (#1516)
This commit is contained in:
26
Cargo.lock
generated
26
Cargo.lock
generated
@@ -185,6 +185,12 @@ version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236"
|
||||
|
||||
[[package]]
|
||||
name = "arrayref"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.6"
|
||||
@@ -918,6 +924,19 @@ version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2"
|
||||
|
||||
[[package]]
|
||||
name = "blake3"
|
||||
version = "1.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "675f87afced0413c9bb02843499dbbd3882a237645883f71a2b59644a6d2f753"
|
||||
dependencies = [
|
||||
"arrayref",
|
||||
"arrayvec",
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"constant_time_eq",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.4"
|
||||
@@ -1313,6 +1332,12 @@ dependencies = [
|
||||
"tiny-keccak",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "constant_time_eq"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
|
||||
|
||||
[[package]]
|
||||
name = "content_inspector"
|
||||
version = "0.2.4"
|
||||
@@ -2179,6 +2204,7 @@ dependencies = [
|
||||
"aws-smithy-types",
|
||||
"axum 0.7.9",
|
||||
"base64 0.21.7",
|
||||
"blake3",
|
||||
"chrono",
|
||||
"criterion",
|
||||
"ctor",
|
||||
|
||||
@@ -69,6 +69,9 @@ aws-sdk-bedrockruntime = "1.72.0"
|
||||
# For GCP Vertex AI provider auth
|
||||
jsonwebtoken = "9.3.1"
|
||||
|
||||
# Added blake3 hashing library as a dependency
|
||||
blake3 = "1.5"
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
winapi = { version = "0.3", features = ["wincred"] }
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ mod capabilities;
|
||||
pub mod extension;
|
||||
mod factory;
|
||||
mod permission_judge;
|
||||
mod permission_store;
|
||||
mod reference;
|
||||
mod summarize;
|
||||
mod truncate;
|
||||
@@ -12,3 +13,4 @@ pub use capabilities::Capabilities;
|
||||
pub use extension::ExtensionConfig;
|
||||
pub use factory::{register_agent, AgentFactory};
|
||||
pub use permission_judge::detect_read_only_tools;
|
||||
pub use permission_store::ToolPermissionStore;
|
||||
|
||||
149
crates/goose/src/agents/permission_store.rs
Normal file
149
crates/goose/src/agents/permission_store.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
use crate::message::ToolRequest;
|
||||
use anyhow::Result;
|
||||
use blake3::Hasher;
|
||||
use chrono::Utc;
|
||||
use etcetera::{choose_app_strategy, AppStrategy};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use std::{fs::File, path::PathBuf};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ToolPermissionRecord {
|
||||
tool_name: String,
|
||||
allowed: bool,
|
||||
context_hash: String, // Hash of the tool's arguments/context to differentiate similar calls
|
||||
#[serde(skip_serializing_if = "Option::is_none")] // Don't serialize if None
|
||||
readable_context: Option<String>, // Add this field
|
||||
timestamp: i64,
|
||||
expiry: Option<i64>, // Optional expiry timestamp
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ToolPermissionStore {
|
||||
permissions: HashMap<String, Vec<ToolPermissionRecord>>,
|
||||
version: u32, // For future schema migrations
|
||||
#[serde(skip)] // Don't serialize this field
|
||||
permissions_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl Default for ToolPermissionStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolPermissionStore {
|
||||
pub fn new() -> Self {
|
||||
let permissions_dir = choose_app_strategy(crate::config::APP_STRATEGY.clone())
|
||||
.map(|strategy| strategy.config_dir())
|
||||
.unwrap_or_else(|_| PathBuf::from(".config/goose"));
|
||||
|
||||
Self {
|
||||
permissions: HashMap::new(),
|
||||
version: 1,
|
||||
permissions_dir,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load() -> Result<Self> {
|
||||
let store = Self::new();
|
||||
let file_path = store.permissions_dir.join("tool_permissions.json");
|
||||
|
||||
if !file_path.exists() {
|
||||
return Ok(store);
|
||||
}
|
||||
|
||||
let file = File::open(file_path)?;
|
||||
let mut permissions: ToolPermissionStore = serde_json::from_reader(file)?;
|
||||
permissions.permissions_dir = store.permissions_dir;
|
||||
|
||||
// Clean up expired entries on load
|
||||
permissions.cleanup_expired()?;
|
||||
|
||||
Ok(permissions)
|
||||
}
|
||||
|
||||
pub fn save(&self) -> anyhow::Result<()> {
|
||||
std::fs::create_dir_all(&self.permissions_dir)?;
|
||||
|
||||
let path = self.permissions_dir.join("tool_permissions.json");
|
||||
let temp_path = path.with_extension("tmp");
|
||||
|
||||
// Write complete content to temporary file
|
||||
let content = serde_json::to_string_pretty(self)?;
|
||||
std::fs::write(&temp_path, &content)?;
|
||||
|
||||
// Atomically rename temp file to target file
|
||||
std::fs::rename(temp_path, path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn check_permission(&self, tool_request: &ToolRequest) -> Option<bool> {
|
||||
let context_hash = self.hash_tool_context(tool_request);
|
||||
let tool_call = tool_request.tool_call.as_ref().unwrap();
|
||||
let key = format!("{}:{}", tool_call.name, context_hash);
|
||||
|
||||
self.permissions.get(&key).and_then(|records| {
|
||||
records
|
||||
.iter()
|
||||
.filter(|record| record.expiry.is_none_or(|exp| exp > Utc::now().timestamp()))
|
||||
.last()
|
||||
.map(|record| record.allowed)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn record_permission(
|
||||
&mut self,
|
||||
tool_request: &ToolRequest,
|
||||
allowed: bool,
|
||||
expiry_duration: Option<Duration>,
|
||||
) -> anyhow::Result<()> {
|
||||
let context_hash = self.hash_tool_context(tool_request);
|
||||
let tool_call = tool_request.tool_call.as_ref().unwrap();
|
||||
let key = format!("{}:{}", tool_call.name, context_hash);
|
||||
|
||||
let record = ToolPermissionRecord {
|
||||
tool_name: tool_call.name.clone(),
|
||||
allowed,
|
||||
context_hash,
|
||||
readable_context: Some(tool_request.to_readable_string()),
|
||||
timestamp: Utc::now().timestamp(),
|
||||
expiry: expiry_duration.map(|d| Utc::now().timestamp() + d.as_secs() as i64),
|
||||
};
|
||||
|
||||
self.permissions.entry(key).or_default().push(record);
|
||||
|
||||
self.save()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn hash_tool_context(&self, tool_request: &ToolRequest) -> String {
|
||||
// Create a hash of the tool's arguments to differentiate similar calls
|
||||
// This helps identify when the same tool is being used in a different context
|
||||
let mut hasher = Hasher::new();
|
||||
hasher.update(
|
||||
serde_json::to_string(&tool_request.tool_call.as_ref().unwrap().arguments)
|
||||
.unwrap_or_default()
|
||||
.as_bytes(),
|
||||
);
|
||||
hasher.finalize().to_hex().to_string()
|
||||
}
|
||||
|
||||
pub fn cleanup_expired(&mut self) -> anyhow::Result<()> {
|
||||
let now = Utc::now().timestamp();
|
||||
let mut changed = false;
|
||||
|
||||
self.permissions.retain(|_, records| {
|
||||
records.retain(|record| record.expiry.is_none_or(|exp| exp > now));
|
||||
changed = changed || records.is_empty();
|
||||
!records.is_empty()
|
||||
});
|
||||
|
||||
if changed {
|
||||
self.save()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ use super::detect_read_only_tools;
|
||||
use super::Agent;
|
||||
use crate::agents::capabilities::Capabilities;
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
||||
use crate::agents::ToolPermissionStore;
|
||||
use crate::config::Config;
|
||||
use crate::config::ExperimentManager;
|
||||
use crate::message::{Message, ToolRequest};
|
||||
@@ -29,6 +30,7 @@ use mcp_core::prompt::Prompt;
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
use mcp_core::{tool::Tool, Content};
|
||||
use serde_json::{json, Value};
|
||||
use std::time::Duration;
|
||||
|
||||
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
||||
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
|
||||
@@ -267,19 +269,43 @@ impl Agent for TruncateAgent {
|
||||
match mode.as_str() {
|
||||
"approve" => {
|
||||
let mut read_only_tools = Vec::new();
|
||||
// Process each tool request sequentially with confirmation
|
||||
if ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? {
|
||||
read_only_tools = detect_read_only_tools(&capabilities, tool_requests.clone()).await;
|
||||
let mut needs_confirmation = Vec::<&ToolRequest>::new();
|
||||
|
||||
// First check permissions for all tools
|
||||
let store = ToolPermissionStore::load()?;
|
||||
for request in tool_requests.iter() {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
if let Some(allowed) = store.check_permission(request) {
|
||||
if allowed {
|
||||
let output = capabilities.dispatch_tool_call(tool_call).await;
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
output,
|
||||
);
|
||||
} else {
|
||||
needs_confirmation.push(request);
|
||||
}
|
||||
} else {
|
||||
needs_confirmation.push(request);
|
||||
}
|
||||
}
|
||||
}
|
||||
for request in &tool_requests {
|
||||
|
||||
// Only check read-only status for tools needing confirmation
|
||||
if !needs_confirmation.is_empty() && ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? {
|
||||
read_only_tools = detect_read_only_tools(&capabilities, needs_confirmation.clone()).await;
|
||||
}
|
||||
|
||||
// Process remaining tools that need confirmation
|
||||
for request in &needs_confirmation {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
// Skip confirmation if the tool_call.name is in the read_only_tools list
|
||||
if read_only_tools.contains(&tool_call.name) {
|
||||
let output = capabilities.dispatch_tool_call(tool_call).await;
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
output,
|
||||
);
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
output,
|
||||
);
|
||||
} else {
|
||||
let confirmation = Message::user().with_tool_confirmation_request(
|
||||
request.id.clone(),
|
||||
@@ -291,9 +317,12 @@ impl Agent for TruncateAgent {
|
||||
|
||||
// Wait for confirmation response through the channel
|
||||
let mut rx = self.confirmation_rx.lock().await;
|
||||
// Loop the recv until we have a matched req_id due to potential duplicate messages.
|
||||
while let Some((req_id, confirmed)) = rx.recv().await {
|
||||
if req_id == request.id {
|
||||
// Store the user's response with 30-day expiration
|
||||
let mut store = ToolPermissionStore::load()?;
|
||||
store.record_permission(request, confirmed, Some(Duration::from_secs(30 * 24 * 60 * 60)))?;
|
||||
|
||||
if confirmed {
|
||||
// User approved - dispatch the tool call
|
||||
let output = capabilities.dispatch_tool_call(tool_call).await;
|
||||
|
||||
@@ -26,6 +26,22 @@ pub struct ToolRequest {
|
||||
pub tool_call: ToolResult<ToolCall>,
|
||||
}
|
||||
|
||||
impl ToolRequest {
|
||||
pub fn to_readable_string(&self) -> String {
|
||||
match &self.tool_call {
|
||||
Ok(tool_call) => {
|
||||
format!(
|
||||
"Tool: {}, Args: {}",
|
||||
tool_call.name,
|
||||
serde_json::to_string_pretty(&tool_call.arguments)
|
||||
.unwrap_or_else(|_| "<<invalid json>>".to_string())
|
||||
)
|
||||
}
|
||||
Err(e) => format!("Invalid tool call: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolResponse {
|
||||
|
||||
Reference in New Issue
Block a user