feat: mcp router extension discovery and install tool (#1995)

Co-authored-by: Alice Hau <ahau@squareup.com>
This commit is contained in:
Wendy Tang
2025-04-07 14:39:35 -07:00
committed by GitHub
parent 050a8f2f42
commit 7efe06096d
11 changed files with 401 additions and 104 deletions

View File

@@ -642,6 +642,24 @@ impl Session {
principal_type: PrincipalType::Tool,
permission,
},).await;
} else if let Some(MessageContent::EnableExtensionRequest(enable_extension_request)) = message.content.first() {
output::hide_thinking();
let prompt = "Goose would like to install the following extension, do you approve?".to_string();
let confirmed = cliclack::select(prompt)
.item(true, "Yes, for this session", "Enable the extension for this session")
.item(false, "No", "Do not enable the extension")
.interact()?;
let permission = if confirmed {
Permission::AllowOnce
} else {
Permission::DenyOnce
};
self.agent.handle_confirmation(enable_extension_request.id.clone(), PermissionConfirmation {
principal_name: "extension_name_placeholder".to_string(),
principal_type: PrincipalType::Extension,
permission,
},).await;
}
// otherwise we have a model/tool to render
else {

View File

@@ -11,7 +11,7 @@ use tokio::sync::Mutex;
use tracing::{debug, instrument};
use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
use crate::config::Config;
use crate::config::{Config, ExtensionManager};
use crate::prompt_template;
use crate::providers::base::Provider;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
@@ -602,6 +602,8 @@ impl Capabilities {
self.read_resource(tool_call.arguments.clone()).await
} else if tool_call.name == "platform__list_resources" {
self.list_resources(tool_call.arguments.clone()).await
} else if tool_call.name == "platform__search_available_extensions" {
self.search_available_extensions().await
} else if self.is_frontend_tool(&tool_call.name) {
// For frontend tools, return an error indicating we need frontend execution
Err(ToolError::ExecutionError(
@@ -717,6 +719,57 @@ impl Capabilities {
.await
.map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e))
}
pub async fn search_available_extensions(&self) -> Result<Vec<Content>, ToolError> {
let mut output_parts = vec![];
// First get disabled extensions from current config
let mut disabled_extensions: Vec<String> = vec![];
for extension in ExtensionManager::get_all().expect("should load extensions") {
if !extension.enabled {
let config = extension.config.clone();
let description = match &config {
ExtensionConfig::Builtin {
name, display_name, ..
} => {
// For builtin extensions, use display name if available
display_name
.as_ref()
.map(|s| s.to_string())
.unwrap_or_else(|| name.clone())
}
ExtensionConfig::Sse {
description, name, ..
}
| ExtensionConfig::Stdio {
description, name, ..
} => {
// For SSE/Stdio, use description if available
description
.as_ref()
.map(|s| s.to_string())
.unwrap_or_else(|| format!("Extension '{}'", name))
}
ExtensionConfig::Frontend { name, .. } => {
format!("Frontend extension '{}'", name)
}
};
disabled_extensions.push(format!("- {} - {}", config.name(), description));
}
}
if !disabled_extensions.is_empty() {
output_parts.push(format!(
"Currently available extensions user can enable:\n{}\n",
disabled_extensions.join("\n")
));
} else {
output_parts
.push("No available extensions found in current configuration.\n".to_string());
}
Ok(vec![Content::text(output_parts.join("\n"))])
}
}
#[cfg(test)]

View File

@@ -2,8 +2,8 @@
/// It makes no attempt to handle context limits, and cannot read resources
use async_trait::async_trait;
use futures::stream::BoxStream;
use mcp_core::tool::ToolAnnotations;
use std::collections::HashMap;
use mcp_core::tool::{Tool, ToolAnnotations};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
@@ -15,7 +15,7 @@ use super::types::ToolResultReceiver;
use super::Agent;
use crate::agents::capabilities::{get_parameter_names, Capabilities};
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::config::Config;
use crate::config::{Config, ExtensionManager};
use crate::message::{Message, MessageContent, ToolRequest};
use crate::permission::detect_read_only_tools;
use crate::permission::Permission;
@@ -32,9 +32,7 @@ use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation};
use anyhow::{anyhow, Result};
use indoc::indoc;
use mcp_core::{
prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult,
};
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolError, ToolResult};
use serde_json::{json, Value};
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
@@ -130,6 +128,47 @@ impl TruncateAgent {
let output = capabilities.dispatch_tool_call(tool_call).await;
(request_id, output)
}
async fn enable_extension(
capabilities: &mut Capabilities,
extension_name: String,
request_id: String,
) -> (String, Result<Vec<Content>, ToolError>) {
let config = match ExtensionManager::get_config_by_name(&extension_name) {
Ok(Some(config)) => config,
Ok(None) => {
return (
request_id,
Err(ToolError::ExecutionError(format!(
"Extension '{}' not found. Please check the extension name and try again.",
extension_name
))),
)
}
Err(e) => {
return (
request_id,
Err(ToolError::ExecutionError(format!(
"Failed to get extension config: {}",
e
))),
)
}
};
let result = capabilities
.add_extension(config)
.await
.map(|_| {
vec![Content::text(format!(
"The extension '{}' has been installed successfully",
extension_name
))]
})
.map_err(|e| ToolError::ExecutionError(e.to_string()));
(request_id, result)
}
}
#[async_trait]
@@ -237,10 +276,54 @@ impl Agent for TruncateAgent {
}),
);
let search_available_extensions = Tool::new(
"platform__search_available_extensions".to_string(),
"Searches for additional extensions available to help complete tasks.
Use this tool when you're unable to find a specific feature or functionality you need to complete your task, or when standard approaches aren't working.
These extensions might provide the exact tools needed to solve your problem.
If you find a relevant one, consider using your tools to enable it.".to_string(),
json!({
"type": "object",
"required": [],
"properties": {}
}),
Some(ToolAnnotations {
title: Some("Discover extensions".to_string()),
read_only_hint: true,
destructive_hint: false,
idempotent_hint: false,
open_world_hint: false,
}),
);
let enable_extension_tool = Tool::new(
"platform__enable_extension".to_string(),
"Enable extensions to help complete tasks.
Enable an extension by providing the extension name.
"
.to_string(),
json!({
"type": "object",
"required": ["extension_name"],
"properties": {
"extension_name": {"type": "string", "description": "The name of the extension to enable"}
}
}),
Some(ToolAnnotations {
title: Some("Enable extensions".to_string()),
read_only_hint: false,
destructive_hint: false,
idempotent_hint: false,
open_world_hint: false,
}),
);
if capabilities.supports_resources() {
tools.push(read_resource_tool);
tools.push(list_resources_tool);
}
tools.push(search_available_extensions);
tools.push(enable_extension_tool);
let (tools_with_readonly_annotation, tools_without_annotation): (Vec<String>, Vec<String>) =
tools.iter().fold((vec![], vec![]), |mut acc, tool| {
@@ -366,19 +449,40 @@ impl Agent for TruncateAgent {
}
}
// Split tool requests into enable_extension and others
let (enable_extension_requests, non_enable_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
.into_iter()
.partition(|req| {
req.tool_call.as_ref()
.map(|call| call.name == "platform__enable_extension")
.unwrap_or(false)
});
let (search_extension_requests, _non_search_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
.into_iter()
.partition(|req| {
req.tool_call.as_ref()
.map(|call| call.name == "platform__search_available_extensions")
.unwrap_or(false)
});
// Clone goose_mode once before the match to avoid move issues
let mode = goose_mode.clone();
match mode.as_str() {
"approve" | "smart_approve" => {
let mut needs_confirmation = Vec::<&ToolRequest>::new();
let mut approved_tools = Vec::new();
let mut llm_detect_candidates = Vec::<&ToolRequest>::new();
let mut detected_read_only_tools = Vec::new();
// First check permissions for all tools
// If there are install extension requests, always require confirmation
// or if goose_mode is approve or smart_approve, check permissions for all tools
if !enable_extension_requests.is_empty() || mode.as_str() == "approve" || mode.as_str() == "smart_approve" {
let mut needs_confirmation = Vec::<&ToolRequest>::new();
let mut approved_tools = Vec::new();
let mut llm_detect_candidates = Vec::<&ToolRequest>::new();
let mut detected_read_only_tools = Vec::new();
// If approve mode or smart approve mode, check permissions for all tools
if mode.as_str() == "approve" || mode.as_str() == "smart_approve" {
let store = ToolPermissionStore::load()?;
for request in remaining_requests.iter() {
for request in &non_enable_extension_requests {
if let Ok(tool_call) = request.tool_call.clone() {
// Regular permission checking for other tools
if tools_with_readonly_annotation.contains(&tool_call.name) {
approved_tools.push((request.id.clone(), tool_call));
} else if let Some(allowed) = store.check_permission(request) {
@@ -400,110 +504,172 @@ impl Agent for TruncateAgent {
}
}
}
// Only check read-only status for tools without annotation
if !llm_detect_candidates.is_empty() && mode == "smart_approve" {
detected_read_only_tools = detect_read_only_tools(&capabilities, llm_detect_candidates.clone()).await;
}
// Only check read-only status for tools needing confirmation
if !llm_detect_candidates.is_empty() && mode == "smart_approve" {
detected_read_only_tools = detect_read_only_tools(&capabilities, llm_detect_candidates.clone()).await;
// Remove install extensions from read-only tools
if !enable_extension_requests.is_empty() {
detected_read_only_tools.retain(|tool_name| {
!enable_extension_requests.iter().any(|req| {
req.tool_call.as_ref()
.map(|call| call.name == *tool_name)
.unwrap_or(false)
})
});
}
}
// Handle pre-approved and read-only tools in parallel
let mut tool_futures = Vec::new();
// Handle pre-approved and read-only tools in parallel
let mut tool_futures = Vec::new();
let mut install_results = Vec::new();
// Add pre-approved tools
for (request_id, tool_call) in approved_tools {
let tool_future = Self::create_tool_future(&capabilities, tool_call, request_id.clone());
tool_futures.push(tool_future);
// Handle install extension requests
for request in &enable_extension_requests {
if let Ok(tool_call) = request.tool_call.clone() {
let confirmation = Message::user().with_enable_extension_request(
request.id.clone(),
tool_call.arguments.get("extension_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string()
);
yield confirmation;
let mut rx = self.confirmation_rx.lock().await;
while let Some((req_id, extension_confirmation)) = rx.recv().await {
if req_id == request.id {
if extension_confirmation.permission == Permission::AllowOnce || extension_confirmation.permission == Permission::AlwaysAllow {
let extension_name = tool_call.arguments.get("extension_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let install_result = Self::enable_extension(&mut capabilities, extension_name, request.id.clone()).await;
install_results.push(install_result);
}
break;
}
}
}
}
// Process read-only tools
for request in &needs_confirmation {
if let Ok(tool_call) = request.tool_call.clone() {
// Skip confirmation if the tool_call.name is in the read_only_tools list
if detected_read_only_tools.contains(&tool_call.name) {
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
tool_futures.push(tool_future);
} else {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
);
yield confirmation;
// Process read-only tools
for request in &needs_confirmation {
if let Ok(tool_call) = request.tool_call.clone() {
// Skip confirmation if the tool_call.name is in the read_only_tools list
if detected_read_only_tools.contains(&tool_call.name) {
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
tool_futures.push(tool_future);
} else {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
);
yield confirmation;
// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
while let Some((req_id, tool_confirmation)) = rx.recv().await {
if req_id == request.id {
// Store the user's response with 30-day expiration
let confirmed = tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow;
if confirmed {
// Add this tool call to the futures collection
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
tool_futures.push(tool_future);
} else {
// User declined - add declined response
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text(
"The user has declined to run this tool. \
DO NOT attempt to call this tool again. \
If there are no alternative methods to proceed, clearly explain the situation and STOP.")]),
);
}
break; // Exit the loop once the matching `req_id` is found
// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
while let Some((req_id, tool_confirmation)) = rx.recv().await {
if req_id == request.id {
let confirmed = tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow;
if confirmed {
// Add this tool call to the futures collection
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
tool_futures.push(tool_future);
} else {
// User declined - add declined response
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text(
"The user has declined to run this tool. \
DO NOT attempt to call this tool again. \
If there are no alternative methods to proceed, clearly explain the situation and STOP.")]),
);
}
break; // Exit the loop once the matching `req_id` is found
}
}
}
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
},
"chat" => {
// Skip all tool calls in chat mode
for request in &remaining_requests {
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text(
"Let the user know the tool call was skipped in Goose chat mode. \
DO NOT apologize for skipping the tool call. DO NOT say sorry. \
Provide an explanation of what the tool call would do, structured as a \
plan for the user. Again, DO NOT apologize. \
**Example Plan:**\n \
1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \
2. **Outline Steps** - Break down the steps.\n \
If needed, adjust the explanation based on user preferences or questions."
)]),
);
}
},
_ => {
if mode != "auto" {
warn!("Unknown GOOSE_MODE: {mode:?}. Defaulting to 'auto' mode.");
}
// Process tool requests in parallel
let mut tool_futures = Vec::new();
for request in &remaining_requests {
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
// Check if any install results had errors before processing them
let all_successful = !install_results.iter().any(|(_, result)| result.is_err());
for (request_id, output) in install_results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output
);
}
// Update system prompt and tools if all installations were successful
if all_successful {
system_prompt = capabilities.get_system_prompt().await;
tools = capabilities.get_prefixed_tools().await?;
}
}
if mode.as_str() == "auto" || !search_extension_requests.is_empty() {
let mut tool_futures = Vec::new();
// Process non_enable_extension_requests and search_extension_requests without duplicates
let mut processed_ids = HashSet::new();
for request in non_enable_extension_requests.iter().chain(search_extension_requests.iter()) {
if processed_ids.insert(request.id.clone()) {
if let Ok(tool_call) = request.tool_call.clone() {
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
tool_futures.push(tool_future);
}
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
}
if mode.as_str() == "chat" {
// Skip all tool calls in chat mode
// Skip search extension requests since they were already processed
let non_search_non_enable_extension_requests = non_enable_extension_requests.iter()
.filter(|req| {
if let Ok(tool_call) = &req.tool_call {
tool_call.name != "platform__search_available_extensions"
} else {
true
}
});
for request in non_search_non_enable_extension_requests {
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text(
"Let the user know the tool call was skipped in Goose chat mode. \
DO NOT apologize for skipping the tool call. DO NOT say sorry. \
Provide an explanation of what the tool call would do, structured as a \
plan for the user. Again, DO NOT apologize. \
**Example Plan:**\n \
1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \
2. **Outline Steps** - Break down the steps.\n \
If needed, adjust the explanation based on user preferences or questions."
)]),
);
}
}

View File

@@ -63,6 +63,22 @@ impl ExtensionManager {
}))
}
pub fn get_config_by_name(name: &str) -> Result<Option<ExtensionConfig>> {
let config = Config::global();
// Try to get the extension entry
let extensions: HashMap<String, ExtensionEntry> = match config.get_param("extensions") {
Ok(exts) => exts,
Err(super::ConfigError::NotFound(_)) => HashMap::new(),
Err(_) => HashMap::new(),
};
Ok(extensions
.values()
.find(|entry| entry.config.name() == name)
.map(|entry| entry.config.clone()))
}
/// Set or update an extension configuration
pub fn set(entry: ExtensionEntry) -> Result<()> {
let config = Config::global();

View File

@@ -59,6 +59,13 @@ pub struct ToolConfirmationRequest {
pub prompt: Option<String>,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EnableExtensionRequest {
pub id: String,
pub extension_name: String,
}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ThinkingContent {
pub thinking: String,
@@ -87,6 +94,7 @@ pub enum MessageContent {
ToolRequest(ToolRequest),
ToolResponse(ToolResponse),
ToolConfirmationRequest(ToolConfirmationRequest),
EnableExtensionRequest(EnableExtensionRequest),
FrontendToolRequest(FrontendToolRequest),
Thinking(ThinkingContent),
RedactedThinking(RedactedThinkingContent),
@@ -136,6 +144,13 @@ impl MessageContent {
})
}
pub fn enable_extension_request<S: Into<String>>(id: S, extension_name: String) -> Self {
MessageContent::EnableExtensionRequest(EnableExtensionRequest {
id: id.into(),
extension_name,
})
}
pub fn thinking<S1: Into<String>, S2: Into<String>>(thinking: S1, signature: S2) -> Self {
MessageContent::Thinking(ThinkingContent {
thinking: thinking.into(),
@@ -177,6 +192,14 @@ impl MessageContent {
}
}
pub fn as_enable_extension_request(&self) -> Option<&EnableExtensionRequest> {
if let MessageContent::EnableExtensionRequest(ref enable_extension_request) = self {
Some(enable_extension_request)
} else {
None
}
}
pub fn as_tool_response_text(&self) -> Option<String> {
if let Some(tool_response) = self.as_tool_response() {
if let Ok(contents) = &tool_response.tool_result {
@@ -336,6 +359,14 @@ impl Message {
))
}
pub fn with_enable_extension_request<S: Into<String>>(
self,
id: S,
extension_name: String,
) -> Self {
self.with_content(MessageContent::enable_extension_request(id, extension_name))
}
pub fn with_frontend_tool_request<S: Into<String>>(
self,
id: S,

View File

@@ -9,7 +9,7 @@ pub enum Permission {
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum PrincipalType {
Extention,
Extension,
Tool,
}

View File

@@ -9,6 +9,7 @@ These models have varying knowledge cut-off dates depending on when they were tr
Extensions allow other applications to provide context to Goose. Extensions connect Goose to different data sources and tools.
You are capable of dynamically plugging into new extensions and learning how to use them. You solve higher level problems using the tools in these extensions, and can interact with multiple at once.
Use the search_available_extensions tool to find additional extensions to enable to help with your task. To enable extensions, use the enable_extensions tool. You should only enable extensions found from the search_available_extensions tool.
{% if (extensions is defined) and extensions %}
Because you dynamically load extensions, your conversation history may refer

View File

@@ -60,6 +60,9 @@ pub fn format_messages(messages: &[Message]) -> Vec<Value> {
MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => {
// Skip tool confirmation requests
}
MessageContent::EnableExtensionRequest(_enable_extension_request) => {
// Skip enable extension requests
}
MessageContent::Thinking(thinking) => {
content.push(json!({
"type": "thinking",

View File

@@ -31,6 +31,9 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C
MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => {
bedrock::ContentBlock::Text("".to_string())
}
MessageContent::EnableExtensionRequest(_enable_extension_request) => {
bedrock::ContentBlock::Text("".to_string())
}
MessageContent::Image(_) => {
bail!("Image content is not supported by Bedrock provider yet")
}

View File

@@ -179,6 +179,9 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
MessageContent::ToolConfirmationRequest(_) => {
// Skip tool confirmation requests
}
MessageContent::EnableExtensionRequest(_) => {
// Skip enable extension requests
}
MessageContent::Image(image) => {
// Handle direct image content
content_array.push(json!({

View File

@@ -147,6 +147,9 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
MessageContent::ToolConfirmationRequest(_) => {
// Skip tool confirmation requests
}
MessageContent::EnableExtensionRequest(_) => {
// Skip enable extension requests
}
MessageContent::Image(image) => {
// Handle direct image content
converted["content"] = json!([convert_image(image, image_format)]);