feat: support tool level permission control in ui (#2133)

This commit is contained in:
Yingjie He
2025-04-11 12:19:08 -07:00
committed by GitHub
parent cb32160a49
commit df7f2b8ab9
14 changed files with 1387 additions and 113 deletions

View File

@@ -21,6 +21,7 @@ use utoipa::OpenApi;
super::routes::config_management::get_extensions,
super::routes::config_management::read_all_config,
super::routes::config_management::providers,
super::routes::config_management::upsert_permissions,
super::routes::agent::get_tools,
),
components(schemas(
@@ -31,6 +32,8 @@ use utoipa::OpenApi;
super::routes::config_management::ProviderDetails,
super::routes::config_management::ExtensionResponse,
super::routes::config_management::ExtensionQuery,
super::routes::config_management::ToolPermission,
super::routes::config_management::UpsertPermissionsQuery,
ProviderMetadata,
ExtensionEntry,
ExtensionConfig,

View File

@@ -5,10 +5,13 @@ use axum::{
routing::{get, post},
Json, Router,
};
use goose::agents::{extension::ToolInfo, extension_manager::get_parameter_names};
use goose::config::Config;
use goose::config::PermissionManager;
use goose::{agents::Agent, model::ModelConfig, providers};
use goose::{
agents::{extension::ToolInfo, extension_manager::get_parameter_names},
config::permission::PermissionLevel,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
@@ -173,7 +176,7 @@ async fn list_providers() -> Json<Vec<ProviderList>> {
("extension_name" = Option<String>, Query, description = "Optional extension name to filter tools")
),
responses(
(status = 200, description = "Tools retrieved successfully", body = Vec<Tool>),
(status = 200, description = "Tools retrieved successfully", body = Vec<ToolInfo>),
(status = 401, description = "Unauthorized - invalid secret key"),
(status = 424, description = "Agent not initialized"),
(status = 500, description = "Internal server error")
@@ -193,11 +196,13 @@ async fn get_tools(
return Err(StatusCode::UNAUTHORIZED);
}
let config = Config::global();
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
let mut agent = state.agent.write().await;
let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
let permission_manager = PermissionManager::default();
let tools = agent
let mut tools: Vec<ToolInfo> = agent
.list_tools()
.await
.into_iter()
@@ -210,14 +215,27 @@ async fn get_tools(
}
})
.map(|tool| {
let permission = permission_manager
.get_user_permission(&tool.name)
.or_else(|| {
if goose_mode == "smart_approve" {
permission_manager.get_smart_approve_permission(&tool.name)
} else if goose_mode == "approve" {
Some(PermissionLevel::AskBefore)
} else {
None
}
});
ToolInfo::new(
&tool.name,
&tool.description,
get_parameter_names(&tool),
permission_manager.get_user_permission(&tool.name),
permission,
)
})
.collect();
.collect::<Vec<ToolInfo>>();
tools.sort_by(|a, b| a.name.cmp(&b.name));
Ok(Json(tools))
}

View File

@@ -5,12 +5,12 @@ use axum::{
routing::{delete, get, post},
Json, Router,
};
use goose::agents::ExtensionConfig;
use goose::config::extensions::name_to_key;
use goose::config::Config;
use goose::config::{extensions::name_to_key, PermissionManager};
use goose::config::{ExtensionConfigManager, ExtensionEntry};
use goose::providers::base::ProviderMetadata;
use goose::providers::providers as get_providers;
use goose::{agents::ExtensionConfig, config::permission::PermissionLevel};
use http::{HeaderMap, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -78,6 +78,18 @@ pub struct ProvidersResponse {
pub providers: Vec<ProviderDetails>,
}
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ToolPermission {
/// Unique identifier and name of the tool, format <extension_name>__<tool_name>
pub tool_name: String,
pub permission: PermissionLevel,
}
#[derive(Deserialize, ToSchema)]
pub struct UpsertPermissionsQuery {
pub tool_permissions: Vec<ToolPermission>,
}
#[utoipa::path(
post,
path = "/config/upsert",
@@ -389,6 +401,34 @@ pub async fn init_config(
}
}
#[utoipa::path(
post,
path = "/config/permissions",
request_body = UpsertPermissionsQuery,
responses(
(status = 200, description = "Permission update completed", body = String),
(status = 400, description = "Invalid request"),
)
)]
pub async fn upsert_permissions(
State(state): State<AppState>,
headers: HeaderMap,
Json(query): Json<UpsertPermissionsQuery>,
) -> Result<Json<String>, StatusCode> {
verify_secret_key(&headers, &state)?;
let mut permission_manager = PermissionManager::default();
// Iterate over each tool permission and update permissions
for tool_permission in &query.tool_permissions {
permission_manager.update_user_permission(
&tool_permission.tool_name,
tool_permission.permission.clone(),
);
}
Ok(Json("Permissions updated successfully".to_string()))
}
pub fn routes(state: AppState) -> Router {
Router::new()
.route("/config", get(read_all_config))
@@ -400,5 +440,6 @@ pub fn routes(state: AppState) -> Router {
.route("/config/extensions/:name", delete(remove_extension))
.route("/config/providers", get(providers))
.route("/config/init", post(init_config))
.route("/config/permissions", post(upsert_permissions))
.with_state(state)
}

View File

@@ -368,7 +368,7 @@ async fn ask_handler(
#[derive(Debug, Deserialize)]
struct ToolConfirmationRequest {
id: String,
confirmed: bool,
action: String,
}
async fn confirm_handler(
@@ -389,11 +389,14 @@ async fn confirm_handler(
let agent = state.agent.clone();
let agent = agent.read().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
let permission = if request.confirmed {
Permission::AllowOnce
} else {
Permission::DenyOnce
let permission = match request.action.as_str() {
"always_allow" => Permission::AlwaysAllow,
"allow_once" => Permission::AllowOnce,
"deny" => Permission::DenyOnce,
_ => Permission::DenyOnce,
};
agent
.handle_confirmation(
request.id.clone(),

View File

@@ -24,6 +24,8 @@ pub enum ExtensionError {
Transport(#[from] mcp_client::transport::Error),
#[error("Environment variable `{0}` is not allowed to be overridden.")]
InvalidEnvVar(String),
#[error("Join error occurred during task execution: {0}")]
TaskJoinError(#[from] tokio::task::JoinError),
}
pub type ExtensionResult<T> = Result<T, ExtensionError>;

View File

@@ -1,5 +1,6 @@
use anyhow::Result;
use chrono::{DateTime, TimeZone, Utc};
use futures::future;
use futures::stream::{FuturesUnordered, StreamExt};
use mcp_client::McpService;
use mcp_core::protocol::GetPromptResult;
@@ -8,6 +9,7 @@ use std::sync::Arc;
use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::task;
use tracing::debug;
use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
@@ -230,29 +232,47 @@ impl ExtensionManager {
/// Get all tools from all clients with proper prefixing
pub async fn get_prefixed_tools(&self) -> ExtensionResult<Vec<Tool>> {
let client_futures = self.clients.iter().map(|(name, client)| {
let name = name.clone();
let client = client.clone();
task::spawn(async move {
let mut tools = Vec::new();
let client_guard = client.lock().await;
let mut client_tools = client_guard.list_tools(None).await?;
loop {
for tool in client_tools.tools {
tools.push(Tool::new(
format!("{}__{}", name, tool.name),
&tool.description,
tool.input_schema,
tool.annotations,
));
}
// Exit loop when there are no more pages
if client_tools.next_cursor.is_none() {
break;
}
client_tools = client_guard.list_tools(client_tools.next_cursor).await?;
}
Ok::<Vec<Tool>, ExtensionError>(tools)
})
});
// Collect all results concurrently
let results = future::join_all(client_futures).await;
// Aggregate tools and handle errors
let mut tools = Vec::new();
// Add tools from MCP extensions with prefixing
for (name, client) in &self.clients {
let client_guard = client.lock().await;
let mut client_tools = client_guard.list_tools(None).await?;
loop {
for tool in client_tools.tools {
tools.push(Tool::new(
format!("{}__{}", name, tool.name),
&tool.description,
tool.input_schema,
tool.annotations,
));
}
// exit loop when there are no more pages
if client_tools.next_cursor.is_none() {
break;
}
client_tools = client_guard.list_tools(client_tools.next_cursor).await?;
for result in results {
match result {
Ok(Ok(client_tools)) => tools.extend(client_tools),
Ok(Err(err)) => return Err(err),
Err(join_err) => return Err(ExtensionError::from(join_err)),
}
}