feat: start use tool permission confirmation struct (#2044)

This commit is contained in:
Yingjie He
2025-04-04 16:39:04 -07:00
committed by GitHub
parent bfed33c5a4
commit 890f266e7b
12 changed files with 88 additions and 29 deletions

View File

@@ -6,6 +6,9 @@ mod prompt;
mod thinking; mod thinking;
pub use builder::build_session; pub use builder::build_session;
use goose::permission::permission_confirmation::PrincipalType;
use goose::permission::Permission;
use goose::permission::PermissionConfirmation;
use goose::providers::base::Provider; use goose::providers::base::Provider;
pub use goose::session::Identifier; pub use goose::session::Identifier;
@@ -598,7 +601,16 @@ impl Session {
// Get confirmation from user // Get confirmation from user
let confirmed = cliclack::confirm(prompt).initial_value(true).interact()?; let confirmed = cliclack::confirm(prompt).initial_value(true).interact()?;
self.agent.handle_confirmation(confirmation.id.clone(), confirmed).await; let permission = if confirmed {
Permission::AllowOnce
} else {
Permission::DenyOnce
};
self.agent.handle_confirmation(confirmation.id.clone(), PermissionConfirmation {
principal_name: "tool_name_placeholder".to_string(),
principal_type: PrincipalType::Tool,
permission,
},).await;
} }
// otherwise we have a model/tool to render // otherwise we have a model/tool to render
else { else {

View File

@@ -8,10 +8,14 @@ use axum::{
}; };
use bytes::Bytes; use bytes::Bytes;
use futures::{stream::StreamExt, Stream}; use futures::{stream::StreamExt, Stream};
use goose::session;
use goose::{ use goose::{
agents::SessionConfig, agents::SessionConfig,
message::{Message, MessageContent}, message::{Message, MessageContent},
permission::permission_confirmation::PrincipalType,
};
use goose::{
permission::{Permission, PermissionConfirmation},
session,
}; };
use mcp_core::{role::Role, Content, ToolResult}; use mcp_core::{role::Role, Content, ToolResult};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -385,8 +389,20 @@ async fn confirm_handler(
let agent = state.agent.clone(); let agent = state.agent.clone();
let agent = agent.read().await; let agent = agent.read().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
let permission = if request.confirmed {
Permission::AllowOnce
} else {
Permission::DenyOnce
};
agent agent
.handle_confirmation(request.id.clone(), request.confirmed) .handle_confirmation(
request.id.clone(),
PermissionConfirmation {
principal_name: "tool_name_placeholder".to_string(),
principal_type: PrincipalType::Tool,
permission,
},
)
.await; .await;
Ok(Json(Value::Object(serde_json::Map::new()))) Ok(Json(Value::Object(serde_json::Map::new())))
} }

View File

@@ -9,9 +9,9 @@ use serde_json::Value;
use std::sync::Arc; use std::sync::Arc;
use super::extension::{ExtensionConfig, ExtensionResult}; use super::extension::{ExtensionConfig, ExtensionResult};
use crate::message::Message;
use crate::providers::base::Provider; use crate::providers::base::Provider;
use crate::session; use crate::session;
use crate::{message::Message, permission::PermissionConfirmation};
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult}; use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult};
/// Session configuration for an agent /// Session configuration for an agent
@@ -50,7 +50,7 @@ pub trait Agent: Send + Sync {
async fn extend_system_prompt(&mut self, extension: String); async fn extend_system_prompt(&mut self, extension: String);
/// Handle a confirmation response for a tool request /// Handle a confirmation response for a tool request
async fn handle_confirmation(&self, request_id: String, confirmed: bool); async fn handle_confirmation(&self, request_id: String, confirmation: PermissionConfirmation);
/// Override the system prompt with custom text /// Override the system prompt with custom text
async fn override_system_prompt(&mut self, template: String); async fn override_system_prompt(&mut self, template: String);

View File

@@ -1,9 +1,7 @@
mod agent; mod agent;
mod capabilities; pub mod capabilities;
pub mod extension; pub mod extension;
mod factory; mod factory;
mod permission_judge;
mod permission_store;
mod reference; mod reference;
mod summarize; mod summarize;
mod truncate; mod truncate;
@@ -13,5 +11,3 @@ pub use agent::{Agent, SessionConfig};
pub use capabilities::Capabilities; pub use capabilities::Capabilities;
pub use extension::ExtensionConfig; pub use extension::ExtensionConfig;
pub use factory::{register_agent, AgentFactory}; pub use factory::{register_agent, AgentFactory};
pub use permission_judge::detect_read_only_tools;
pub use permission_store::ToolPermissionStore;

View File

@@ -15,6 +15,7 @@ use super::Agent;
use crate::agents::capabilities::Capabilities; use crate::agents::capabilities::Capabilities;
use crate::agents::extension::{ExtensionConfig, ExtensionResult}; use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::message::{Message, ToolRequest}; use crate::message::{Message, ToolRequest};
use crate::permission::PermissionConfirmation;
use crate::providers::base::Provider; use crate::providers::base::Provider;
use crate::token_counter::TokenCounter; use crate::token_counter::TokenCounter;
use crate::{register_agent, session}; use crate::{register_agent, session};
@@ -73,7 +74,11 @@ impl Agent for ReferenceAgent {
Ok(Value::Null) Ok(Value::Null)
} }
async fn handle_confirmation(&self, _request_id: String, _confirmed: bool) { async fn handle_confirmation(
&self,
_request_id: String,
_confirmation: PermissionConfirmation,
) {
// TODO implement // TODO implement
} }

View File

@@ -12,7 +12,6 @@ use tracing::{debug, error, instrument, warn};
use super::agent::SessionConfig; use super::agent::SessionConfig;
use super::capabilities::get_parameter_names; use super::capabilities::get_parameter_names;
use super::detect_read_only_tools;
use super::extension::ToolInfo; use super::extension::ToolInfo;
use super::Agent; use super::Agent;
use crate::agents::capabilities::Capabilities; use crate::agents::capabilities::Capabilities;
@@ -20,6 +19,9 @@ use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::config::Config; use crate::config::Config;
use crate::memory_condense::condense_messages; use crate::memory_condense::condense_messages;
use crate::message::{Message, ToolRequest}; use crate::message::{Message, ToolRequest};
use crate::permission::detect_read_only_tools;
use crate::permission::Permission;
use crate::permission::PermissionConfirmation;
use crate::providers::base::Provider; use crate::providers::base::Provider;
use crate::providers::errors::ProviderError; use crate::providers::errors::ProviderError;
use crate::register_agent; use crate::register_agent;
@@ -38,8 +40,8 @@ const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
pub struct SummarizeAgent { pub struct SummarizeAgent {
capabilities: Mutex<Capabilities>, capabilities: Mutex<Capabilities>,
token_counter: TokenCounter, token_counter: TokenCounter,
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed) confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>, confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>, tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
} }
@@ -159,8 +161,8 @@ impl Agent for SummarizeAgent {
} }
/// Handle a confirmation response for a tool request /// Handle a confirmation response for a tool request
async fn handle_confirmation(&self, request_id: String, confirmed: bool) { async fn handle_confirmation(&self, request_id: String, confirmation: PermissionConfirmation) {
if let Err(e) = self.confirmation_tx.send((request_id, confirmed)).await { if let Err(e) = self.confirmation_tx.send((request_id, confirmation)).await {
error!("Failed to send confirmation: {}", e); error!("Failed to send confirmation: {}", e);
} }
} }
@@ -321,9 +323,9 @@ impl Agent for SummarizeAgent {
// Wait for confirmation response through the channel // Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await; let mut rx = self.confirmation_rx.lock().await;
// Loop the recv until we have a matched req_id due to potential duplicate messages. // Loop the recv until we have a matched req_id due to potential duplicate messages.
while let Some((req_id, confirmed)) = rx.recv().await { while let Some((req_id, tool_confirmation)) = rx.recv().await {
if req_id == request.id { if req_id == request.id {
if confirmed { if tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow {
// User approved - dispatch the tool call // User approved - dispatch the tool call
let output = capabilities.dispatch_tool_call(tool_call).await; let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response( message_tool_response = message_tool_response.with_tool_response(

View File

@@ -10,15 +10,17 @@ use tokio::sync::Mutex;
use tracing::{debug, error, instrument, warn}; use tracing::{debug, error, instrument, warn};
use super::agent::SessionConfig; use super::agent::SessionConfig;
use super::detect_read_only_tools;
use super::extension::ToolInfo; use super::extension::ToolInfo;
use super::types::ToolResultReceiver; use super::types::ToolResultReceiver;
use super::Agent; use super::Agent;
use crate::agents::capabilities::{get_parameter_names, Capabilities}; use crate::agents::capabilities::{get_parameter_names, Capabilities};
use crate::agents::extension::{ExtensionConfig, ExtensionResult}; use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::agents::ToolPermissionStore;
use crate::config::Config; use crate::config::Config;
use crate::message::{Message, MessageContent, ToolRequest}; use crate::message::{Message, MessageContent, ToolRequest};
use crate::permission::detect_read_only_tools;
use crate::permission::Permission;
use crate::permission::PermissionConfirmation;
use crate::permission::ToolPermissionStore;
use crate::providers::base::Provider; use crate::providers::base::Provider;
use crate::providers::errors::ProviderError; use crate::providers::errors::ProviderError;
use crate::providers::toolshim::{ use crate::providers::toolshim::{
@@ -34,7 +36,6 @@ use mcp_core::{
prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult, prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult,
}; };
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::time::Duration;
const MAX_TRUNCATION_ATTEMPTS: usize = 3; const MAX_TRUNCATION_ATTEMPTS: usize = 3;
const ESTIMATE_FACTOR_DECAY: f32 = 0.9; const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
@@ -43,8 +44,8 @@ const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
pub struct TruncateAgent { pub struct TruncateAgent {
capabilities: Mutex<Capabilities>, capabilities: Mutex<Capabilities>,
token_counter: TokenCounter, token_counter: TokenCounter,
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed) confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>, confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>, tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
tool_result_rx: ToolResultReceiver, tool_result_rx: ToolResultReceiver,
} }
@@ -160,8 +161,8 @@ impl Agent for TruncateAgent {
} }
/// Handle a confirmation response for a tool request /// Handle a confirmation response for a tool request
async fn handle_confirmation(&self, request_id: String, confirmed: bool) { async fn handle_confirmation(&self, request_id: String, confirmation: PermissionConfirmation) {
if let Err(e) = self.confirmation_tx.send((request_id, confirmed)).await { if let Err(e) = self.confirmation_tx.send((request_id, confirmation)).await {
error!("Failed to send confirmation: {}", e); error!("Failed to send confirmation: {}", e);
} }
} }
@@ -432,12 +433,10 @@ impl Agent for TruncateAgent {
// Wait for confirmation response through the channel // Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await; let mut rx = self.confirmation_rx.lock().await;
while let Some((req_id, confirmed)) = rx.recv().await { while let Some((req_id, tool_confirmation)) = rx.recv().await {
if req_id == request.id { if req_id == request.id {
// Store the user's response with 30-day expiration // Store the user's response with 30-day expiration
let mut store = ToolPermissionStore::load()?; let confirmed = tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow;
store.record_permission(request, confirmed, Some(Duration::from_secs(30 * 24 * 60 * 60)))?;
if confirmed { if confirmed {
// Add this tool call to the futures collection // Add this tool call to the futures collection
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone()); let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());

View File

@@ -3,6 +3,7 @@ pub mod config;
pub mod memory_condense; pub mod memory_condense;
pub mod message; pub mod message;
pub mod model; pub mod model;
pub mod permission;
pub mod prompt_template; pub mod prompt_template;
pub mod providers; pub mod providers;
pub mod session; pub mod session;

View File

@@ -0,0 +1,7 @@
pub mod permission_confirmation;
pub mod permission_judge;
pub mod permission_store;
pub use permission_confirmation::{Permission, PermissionConfirmation};
pub use permission_judge::detect_read_only_tools;
pub use permission_store::ToolPermissionStore;

View File

@@ -0,0 +1,21 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum Permission {
AlwaysAllow,
AllowOnce,
DenyOnce,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum PrincipalType {
Extention,
Tool,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PermissionConfirmation {
pub principal_name: String,
pub principal_type: PrincipalType,
pub permission: Permission,
}