mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
feat: support tool level permission control in ui (#2133)
This commit is contained in:
@@ -21,6 +21,7 @@ use utoipa::OpenApi;
|
||||
super::routes::config_management::get_extensions,
|
||||
super::routes::config_management::read_all_config,
|
||||
super::routes::config_management::providers,
|
||||
super::routes::config_management::upsert_permissions,
|
||||
super::routes::agent::get_tools,
|
||||
),
|
||||
components(schemas(
|
||||
@@ -31,6 +32,8 @@ use utoipa::OpenApi;
|
||||
super::routes::config_management::ProviderDetails,
|
||||
super::routes::config_management::ExtensionResponse,
|
||||
super::routes::config_management::ExtensionQuery,
|
||||
super::routes::config_management::ToolPermission,
|
||||
super::routes::config_management::UpsertPermissionsQuery,
|
||||
ProviderMetadata,
|
||||
ExtensionEntry,
|
||||
ExtensionConfig,
|
||||
|
||||
@@ -5,10 +5,13 @@ use axum::{
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use goose::agents::{extension::ToolInfo, extension_manager::get_parameter_names};
|
||||
use goose::config::Config;
|
||||
use goose::config::PermissionManager;
|
||||
use goose::{agents::Agent, model::ModelConfig, providers};
|
||||
use goose::{
|
||||
agents::{extension::ToolInfo, extension_manager::get_parameter_names},
|
||||
config::permission::PermissionLevel,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
@@ -173,7 +176,7 @@ async fn list_providers() -> Json<Vec<ProviderList>> {
|
||||
("extension_name" = Option<String>, Query, description = "Optional extension name to filter tools")
|
||||
),
|
||||
responses(
|
||||
(status = 200, description = "Tools retrieved successfully", body = Vec<Tool>),
|
||||
(status = 200, description = "Tools retrieved successfully", body = Vec<ToolInfo>),
|
||||
(status = 401, description = "Unauthorized - invalid secret key"),
|
||||
(status = 424, description = "Agent not initialized"),
|
||||
(status = 500, description = "Internal server error")
|
||||
@@ -193,11 +196,13 @@ async fn get_tools(
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let config = Config::global();
|
||||
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
|
||||
let mut agent = state.agent.write().await;
|
||||
let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?;
|
||||
let permission_manager = PermissionManager::default();
|
||||
|
||||
let tools = agent
|
||||
let mut tools: Vec<ToolInfo> = agent
|
||||
.list_tools()
|
||||
.await
|
||||
.into_iter()
|
||||
@@ -210,14 +215,27 @@ async fn get_tools(
|
||||
}
|
||||
})
|
||||
.map(|tool| {
|
||||
let permission = permission_manager
|
||||
.get_user_permission(&tool.name)
|
||||
.or_else(|| {
|
||||
if goose_mode == "smart_approve" {
|
||||
permission_manager.get_smart_approve_permission(&tool.name)
|
||||
} else if goose_mode == "approve" {
|
||||
Some(PermissionLevel::AskBefore)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
ToolInfo::new(
|
||||
&tool.name,
|
||||
&tool.description,
|
||||
get_parameter_names(&tool),
|
||||
permission_manager.get_user_permission(&tool.name),
|
||||
permission,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
.collect::<Vec<ToolInfo>>();
|
||||
tools.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
Ok(Json(tools))
|
||||
}
|
||||
|
||||
@@ -5,12 +5,12 @@ use axum::{
|
||||
routing::{delete, get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use goose::agents::ExtensionConfig;
|
||||
use goose::config::extensions::name_to_key;
|
||||
use goose::config::Config;
|
||||
use goose::config::{extensions::name_to_key, PermissionManager};
|
||||
use goose::config::{ExtensionConfigManager, ExtensionEntry};
|
||||
use goose::providers::base::ProviderMetadata;
|
||||
use goose::providers::providers as get_providers;
|
||||
use goose::{agents::ExtensionConfig, config::permission::PermissionLevel};
|
||||
use http::{HeaderMap, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@@ -78,6 +78,18 @@ pub struct ProvidersResponse {
|
||||
pub providers: Vec<ProviderDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct ToolPermission {
|
||||
/// Unique identifier and name of the tool, format <extension_name>__<tool_name>
|
||||
pub tool_name: String,
|
||||
pub permission: PermissionLevel,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, ToSchema)]
|
||||
pub struct UpsertPermissionsQuery {
|
||||
pub tool_permissions: Vec<ToolPermission>,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/config/upsert",
|
||||
@@ -389,6 +401,34 @@ pub async fn init_config(
|
||||
}
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/config/permissions",
|
||||
request_body = UpsertPermissionsQuery,
|
||||
responses(
|
||||
(status = 200, description = "Permission update completed", body = String),
|
||||
(status = 400, description = "Invalid request"),
|
||||
)
|
||||
)]
|
||||
pub async fn upsert_permissions(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(query): Json<UpsertPermissionsQuery>,
|
||||
) -> Result<Json<String>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let mut permission_manager = PermissionManager::default();
|
||||
// Iterate over each tool permission and update permissions
|
||||
for tool_permission in &query.tool_permissions {
|
||||
permission_manager.update_user_permission(
|
||||
&tool_permission.tool_name,
|
||||
tool_permission.permission.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Json("Permissions updated successfully".to_string()))
|
||||
}
|
||||
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
Router::new()
|
||||
.route("/config", get(read_all_config))
|
||||
@@ -400,5 +440,6 @@ pub fn routes(state: AppState) -> Router {
|
||||
.route("/config/extensions/:name", delete(remove_extension))
|
||||
.route("/config/providers", get(providers))
|
||||
.route("/config/init", post(init_config))
|
||||
.route("/config/permissions", post(upsert_permissions))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
@@ -368,7 +368,7 @@ async fn ask_handler(
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ToolConfirmationRequest {
|
||||
id: String,
|
||||
confirmed: bool,
|
||||
action: String,
|
||||
}
|
||||
|
||||
async fn confirm_handler(
|
||||
@@ -389,11 +389,14 @@ async fn confirm_handler(
|
||||
let agent = state.agent.clone();
|
||||
let agent = agent.read().await;
|
||||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
|
||||
let permission = if request.confirmed {
|
||||
Permission::AllowOnce
|
||||
} else {
|
||||
Permission::DenyOnce
|
||||
|
||||
let permission = match request.action.as_str() {
|
||||
"always_allow" => Permission::AlwaysAllow,
|
||||
"allow_once" => Permission::AllowOnce,
|
||||
"deny" => Permission::DenyOnce,
|
||||
_ => Permission::DenyOnce,
|
||||
};
|
||||
|
||||
agent
|
||||
.handle_confirmation(
|
||||
request.id.clone(),
|
||||
|
||||
@@ -24,6 +24,8 @@ pub enum ExtensionError {
|
||||
Transport(#[from] mcp_client::transport::Error),
|
||||
#[error("Environment variable `{0}` is not allowed to be overridden.")]
|
||||
InvalidEnvVar(String),
|
||||
#[error("Join error occurred during task execution: {0}")]
|
||||
TaskJoinError(#[from] tokio::task::JoinError),
|
||||
}
|
||||
|
||||
pub type ExtensionResult<T> = Result<T, ExtensionError>;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, TimeZone, Utc};
|
||||
use futures::future;
|
||||
use futures::stream::{FuturesUnordered, StreamExt};
|
||||
use mcp_client::McpService;
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
@@ -8,6 +9,7 @@ use std::sync::Arc;
|
||||
use std::sync::LazyLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task;
|
||||
use tracing::debug;
|
||||
|
||||
use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
|
||||
@@ -230,29 +232,47 @@ impl ExtensionManager {
|
||||
|
||||
/// Get all tools from all clients with proper prefixing
|
||||
pub async fn get_prefixed_tools(&self) -> ExtensionResult<Vec<Tool>> {
|
||||
let client_futures = self.clients.iter().map(|(name, client)| {
|
||||
let name = name.clone();
|
||||
let client = client.clone();
|
||||
|
||||
task::spawn(async move {
|
||||
let mut tools = Vec::new();
|
||||
let client_guard = client.lock().await;
|
||||
let mut client_tools = client_guard.list_tools(None).await?;
|
||||
|
||||
loop {
|
||||
for tool in client_tools.tools {
|
||||
tools.push(Tool::new(
|
||||
format!("{}__{}", name, tool.name),
|
||||
&tool.description,
|
||||
tool.input_schema,
|
||||
tool.annotations,
|
||||
));
|
||||
}
|
||||
|
||||
// Exit loop when there are no more pages
|
||||
if client_tools.next_cursor.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
client_tools = client_guard.list_tools(client_tools.next_cursor).await?;
|
||||
}
|
||||
|
||||
Ok::<Vec<Tool>, ExtensionError>(tools)
|
||||
})
|
||||
});
|
||||
|
||||
// Collect all results concurrently
|
||||
let results = future::join_all(client_futures).await;
|
||||
|
||||
// Aggregate tools and handle errors
|
||||
let mut tools = Vec::new();
|
||||
|
||||
// Add tools from MCP extensions with prefixing
|
||||
for (name, client) in &self.clients {
|
||||
let client_guard = client.lock().await;
|
||||
let mut client_tools = client_guard.list_tools(None).await?;
|
||||
|
||||
loop {
|
||||
for tool in client_tools.tools {
|
||||
tools.push(Tool::new(
|
||||
format!("{}__{}", name, tool.name),
|
||||
&tool.description,
|
||||
tool.input_schema,
|
||||
tool.annotations,
|
||||
));
|
||||
}
|
||||
|
||||
// exit loop when there are no more pages
|
||||
if client_tools.next_cursor.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
client_tools = client_guard.list_tools(client_tools.next_cursor).await?;
|
||||
for result in results {
|
||||
match result {
|
||||
Ok(Ok(client_tools)) => tools.extend(client_tools),
|
||||
Ok(Err(err)) => return Err(err),
|
||||
Err(join_err) => return Err(ExtensionError::from(join_err)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user