mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 14:14:26 +01:00
feat: Enable frontend tools (#1778)
This commit is contained in:
@@ -8,6 +8,7 @@ use goose::{
|
||||
};
|
||||
use http::{HeaderMap, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing;
|
||||
|
||||
/// Enum representing the different types of extension configuration requests.
|
||||
#[derive(Deserialize)]
|
||||
@@ -48,6 +49,16 @@ enum ExtensionConfigRequest {
|
||||
display_name: Option<String>,
|
||||
timeout: Option<u64>,
|
||||
},
|
||||
/// Frontend extension that provides tools to be executed by the frontend.
|
||||
#[serde(rename = "frontend")]
|
||||
Frontend {
|
||||
/// The name to identify this extension
|
||||
name: String,
|
||||
/// The tools provided by this extension
|
||||
tools: Vec<mcp_core::tool::Tool>,
|
||||
/// Optional instructions for using the tools
|
||||
instructions: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Response structure for adding an extension.
|
||||
@@ -64,8 +75,26 @@ struct ExtensionResponse {
|
||||
async fn add_extension(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<ExtensionConfigRequest>,
|
||||
raw: axum::extract::Json<serde_json::Value>,
|
||||
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
||||
// Log the raw request for debugging
|
||||
tracing::info!(
|
||||
"Received extension request: {}",
|
||||
serde_json::to_string_pretty(&raw.0).unwrap()
|
||||
);
|
||||
|
||||
// Try to parse into our enum
|
||||
let request: ExtensionConfigRequest = match serde_json::from_value(raw.0.clone()) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to parse extension request: {}", e);
|
||||
tracing::error!(
|
||||
"Raw request was: {}",
|
||||
serde_json::to_string_pretty(&raw.0).unwrap()
|
||||
);
|
||||
return Err(StatusCode::UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
};
|
||||
// Verify the presence and validity of the secret key.
|
||||
let secret_key = headers
|
||||
.get("X-Secret-Key")
|
||||
@@ -167,6 +196,15 @@ async fn add_extension(
|
||||
display_name,
|
||||
timeout,
|
||||
},
|
||||
ExtensionConfigRequest::Frontend {
|
||||
name,
|
||||
tools,
|
||||
instructions,
|
||||
} => ExtensionConfig::Frontend {
|
||||
name,
|
||||
tools,
|
||||
instructions,
|
||||
},
|
||||
};
|
||||
|
||||
// Acquire a lock on the agent and attempt to add the extension.
|
||||
|
||||
@@ -13,9 +13,9 @@ use goose::{
|
||||
agents::SessionConfig,
|
||||
message::{Message, MessageContent},
|
||||
};
|
||||
|
||||
use mcp_core::role::Role;
|
||||
use mcp_core::{role::Role, Content, ToolResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
@@ -391,12 +391,59 @@ async fn confirm_handler(
|
||||
Ok(Json(Value::Object(serde_json::Map::new())))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ToolResultRequest {
|
||||
id: String,
|
||||
result: ToolResult<Vec<Content>>,
|
||||
}
|
||||
|
||||
async fn submit_tool_result(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
raw: axum::extract::Json<serde_json::Value>,
|
||||
) -> Result<Json<Value>, StatusCode> {
|
||||
// Log the raw request for debugging
|
||||
tracing::info!(
|
||||
"Received tool result request: {}",
|
||||
serde_json::to_string_pretty(&raw.0).unwrap()
|
||||
);
|
||||
|
||||
// Try to parse into our struct
|
||||
let payload: ToolResultRequest = match serde_json::from_value(raw.0.clone()) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to parse tool result request: {}", e);
|
||||
tracing::error!(
|
||||
"Raw request was: {}",
|
||||
serde_json::to_string_pretty(&raw.0).unwrap()
|
||||
);
|
||||
return Err(StatusCode::UNPROCESSABLE_ENTITY);
|
||||
}
|
||||
};
|
||||
|
||||
// Verify secret key
|
||||
let secret_key = headers
|
||||
.get("X-Secret-Key")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if secret_key != state.secret_key {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let agent = state.agent.read().await;
|
||||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
|
||||
agent.handle_tool_result(payload.id, payload.result).await;
|
||||
Ok(Json(json!({"status": "ok"})))
|
||||
}
|
||||
|
||||
// Configure routes for this module
|
||||
pub fn routes(state: AppState) -> Router {
|
||||
Router::new()
|
||||
.route("/reply", post(handler))
|
||||
.route("/ask", post(ask_handler))
|
||||
.route("/confirm", post(confirm_handler))
|
||||
.route("/tool_result", post(submit_tool_result))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ use super::extension::{ExtensionConfig, ExtensionResult};
|
||||
use crate::message::Message;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::session;
|
||||
use mcp_core::prompt::Prompt;
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult};
|
||||
|
||||
/// Session configuration for an agent
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -68,4 +67,7 @@ pub trait Agent: Send + Sync {
|
||||
|
||||
/// Get a reference to the provider used by this agent
|
||||
async fn provider(&self) -> Arc<Box<dyn Provider>>;
|
||||
|
||||
/// Handle a tool result from the frontend
|
||||
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>);
|
||||
}
|
||||
|
||||
@@ -26,9 +26,17 @@ static DEFAULT_TIMESTAMP: LazyLock<DateTime<Utc>> =
|
||||
|
||||
type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>;
|
||||
|
||||
/// A frontend tool that will be executed by the frontend rather than an extension
|
||||
#[derive(Clone)]
|
||||
pub struct FrontendTool {
|
||||
pub name: String,
|
||||
pub tool: Tool,
|
||||
}
|
||||
|
||||
/// Manages MCP clients and their interactions
|
||||
pub struct Capabilities {
|
||||
clients: HashMap<String, McpClientBox>,
|
||||
frontend_tools: HashMap<String, FrontendTool>,
|
||||
instructions: HashMap<String, String>,
|
||||
resource_capable_extensions: HashSet<String>,
|
||||
provider: Arc<Box<dyn Provider>>,
|
||||
@@ -96,6 +104,7 @@ impl Capabilities {
|
||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||
Self {
|
||||
clients: HashMap::new(),
|
||||
frontend_tools: HashMap::new(),
|
||||
instructions: HashMap::new(),
|
||||
resource_capable_extensions: HashSet::new(),
|
||||
provider: Arc::new(provider),
|
||||
@@ -111,6 +120,30 @@ impl Capabilities {
|
||||
/// Add a new MCP extension based on the provided client type
|
||||
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
|
||||
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
|
||||
let sanitized_name = normalize(config.key().to_string());
|
||||
|
||||
match &config {
|
||||
ExtensionConfig::Frontend {
|
||||
name: _,
|
||||
tools,
|
||||
instructions,
|
||||
} => {
|
||||
// For frontend tools, just store them in the frontend_tools map
|
||||
for tool in tools {
|
||||
let frontend_tool = FrontendTool {
|
||||
name: tool.name.clone(),
|
||||
tool: tool.clone(),
|
||||
};
|
||||
self.frontend_tools.insert(tool.name.clone(), frontend_tool);
|
||||
}
|
||||
// Store instructions if provided, using "frontend" as the key
|
||||
if let Some(instructions) = instructions {
|
||||
self.instructions
|
||||
.insert("frontend".to_string(), instructions.clone());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
let mut client: Box<dyn McpClientTrait> = match &config {
|
||||
ExtensionConfig::Sse {
|
||||
uri, envs, timeout, ..
|
||||
@@ -142,10 +175,9 @@ impl Capabilities {
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
#[allow(unused_variables)]
|
||||
ExtensionConfig::Builtin {
|
||||
name,
|
||||
display_name,
|
||||
display_name: _,
|
||||
timeout,
|
||||
} => {
|
||||
// For builtin extensions, we run the current executable with mcp and extension name
|
||||
@@ -168,6 +200,7 @@ impl Capabilities {
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
// Initialize the client with default capabilities
|
||||
@@ -182,8 +215,6 @@ impl Capabilities {
|
||||
.await
|
||||
.map_err(|e| ExtensionError::Initialization(config.clone(), e))?;
|
||||
|
||||
let sanitized_name = normalize(config.key().to_string());
|
||||
|
||||
// Store instructions if provided
|
||||
if let Some(instructions) = init_result.instructions {
|
||||
self.instructions
|
||||
@@ -202,6 +233,8 @@ impl Capabilities {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a system prompt extension
|
||||
pub fn add_system_prompt_extension(&mut self, extension: String) {
|
||||
@@ -235,6 +268,13 @@ impl Capabilities {
|
||||
/// Get all tools from all clients with proper prefixing
|
||||
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
|
||||
let mut tools = Vec::new();
|
||||
|
||||
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
|
||||
for frontend_tool in self.frontend_tools.values() {
|
||||
tools.push(frontend_tool.tool.clone());
|
||||
}
|
||||
|
||||
// Add tools from MCP extensions with prefixing
|
||||
for (name, client) in &self.clients {
|
||||
let client_guard = client.lock().await;
|
||||
let mut client_tools = client_guard.list_tools(None).await?;
|
||||
@@ -317,7 +357,7 @@ impl Capabilities {
|
||||
pub async fn get_system_prompt(&self) -> String {
|
||||
let mut context: HashMap<&str, Value> = HashMap::new();
|
||||
|
||||
let extensions_info: Vec<ExtensionInfo> = self
|
||||
let mut extensions_info: Vec<ExtensionInfo> = self
|
||||
.clients
|
||||
.keys()
|
||||
.map(|name| {
|
||||
@@ -326,6 +366,15 @@ impl Capabilities {
|
||||
ExtensionInfo::new(name, &instructions, has_resources)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Add frontend tools as a special extension if any exist
|
||||
if !self.frontend_tools.is_empty() {
|
||||
let name = "frontend";
|
||||
let instructions = self.instructions.get(name).cloned().unwrap_or_else(||
|
||||
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string()
|
||||
);
|
||||
extensions_info.push(ExtensionInfo::new(name, &instructions, false));
|
||||
}
|
||||
context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
|
||||
|
||||
let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
|
||||
@@ -364,6 +413,16 @@ impl Capabilities {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a tool is a frontend tool
|
||||
pub fn is_frontend_tool(&self, name: &str) -> bool {
|
||||
self.frontend_tools.contains_key(name)
|
||||
}
|
||||
|
||||
/// Get a reference to a frontend tool
|
||||
pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> {
|
||||
self.frontend_tools.get(name)
|
||||
}
|
||||
|
||||
/// Find and return a reference to the appropriate client for a tool call
|
||||
fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(&str, McpClientBox)> {
|
||||
self.clients
|
||||
@@ -543,6 +602,11 @@ impl Capabilities {
|
||||
self.read_resource(tool_call.arguments.clone()).await
|
||||
} else if tool_call.name == "platform__list_resources" {
|
||||
self.list_resources(tool_call.arguments.clone()).await
|
||||
} else if self.is_frontend_tool(&tool_call.name) {
|
||||
// For frontend tools, return an error indicating we need frontend execution
|
||||
Err(ToolError::ExecutionError(
|
||||
"Frontend tool execution required".to_string(),
|
||||
))
|
||||
} else {
|
||||
// Else, dispatch tool call based on the prefix naming convention
|
||||
let (client_name, client) = self
|
||||
|
||||
@@ -149,6 +149,16 @@ pub enum ExtensionConfig {
|
||||
display_name: Option<String>, // needed for the UI
|
||||
timeout: Option<u64>,
|
||||
},
|
||||
/// Frontend-provided tools that will be called through the frontend
|
||||
#[serde(rename = "frontend")]
|
||||
Frontend {
|
||||
/// The name used to identify this extension
|
||||
name: String,
|
||||
/// The tools provided by the frontend
|
||||
tools: Vec<mcp_core::tool::Tool>,
|
||||
/// Instructions for how to use these tools
|
||||
instructions: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for ExtensionConfig {
|
||||
@@ -224,6 +234,7 @@ impl ExtensionConfig {
|
||||
Self::Sse { name, .. } => name,
|
||||
Self::Stdio { name, .. } => name,
|
||||
Self::Builtin { name, .. } => name,
|
||||
Self::Frontend { name, .. } => name,
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
@@ -239,6 +250,9 @@ impl std::fmt::Display for ExtensionConfig {
|
||||
write!(f, "Stdio({}: {} {})", name, cmd, args.join(" "))
|
||||
}
|
||||
ExtensionConfig::Builtin { name, .. } => write!(f, "Builtin({})", name),
|
||||
ExtensionConfig::Frontend { name, tools, .. } => {
|
||||
write!(f, "Frontend({}: {} tools)", name, tools.len())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ mod permission_store;
|
||||
mod reference;
|
||||
mod summarize;
|
||||
mod truncate;
|
||||
mod types;
|
||||
|
||||
pub use agent::{Agent, SessionConfig};
|
||||
pub use capabilities::Capabilities;
|
||||
|
||||
@@ -4,12 +4,13 @@ use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
use super::agent::SessionConfig;
|
||||
use super::capabilities::get_parameter_names;
|
||||
use super::extension::ToolInfo;
|
||||
use super::types::ToolResultReceiver;
|
||||
use super::Agent;
|
||||
use crate::agents::capabilities::Capabilities;
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
||||
@@ -19,23 +20,27 @@ use crate::token_counter::TokenCounter;
|
||||
use crate::{register_agent, session};
|
||||
use anyhow::{anyhow, Result};
|
||||
use indoc::indoc;
|
||||
use mcp_core::prompt::Prompt;
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
use mcp_core::tool::{Tool, ToolAnnotations};
|
||||
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Reference implementation of an Agent
|
||||
pub struct ReferenceAgent {
|
||||
capabilities: Mutex<Capabilities>,
|
||||
_token_counter: TokenCounter,
|
||||
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
tool_result_rx: ToolResultReceiver,
|
||||
}
|
||||
|
||||
impl ReferenceAgent {
|
||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
let (tx, rx) = mpsc::channel(32);
|
||||
Self {
|
||||
capabilities: Mutex::new(Capabilities::new(provider)),
|
||||
_token_counter: token_counter,
|
||||
tool_result_tx: tx,
|
||||
tool_result_rx: Arc::new(Mutex::new(rx)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -193,24 +198,32 @@ impl Agent for ReferenceAgent {
|
||||
}
|
||||
|
||||
// Then dispatch each in parallel
|
||||
let futures: Vec<_> = tool_requests
|
||||
.iter()
|
||||
.filter_map(|request| request.tool_call.clone().ok())
|
||||
.map(|tool_call| capabilities.dispatch_tool_call(tool_call))
|
||||
.collect();
|
||||
|
||||
// Process all the futures in parallel but wait until all are finished
|
||||
let outputs = futures::future::join_all(futures).await;
|
||||
|
||||
// Create a message with the responses
|
||||
let mut message_tool_response = Message::user();
|
||||
// Now combine these into MessageContent::ToolResponse using the original ID
|
||||
for (request, output) in tool_requests.iter().zip(outputs.into_iter()) {
|
||||
for request in tool_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
// Check if it's a frontend tool
|
||||
if capabilities.is_frontend_tool(&tool_call.name) {
|
||||
// Send frontend tool request and wait for response
|
||||
yield Message::assistant().with_frontend_tool_request(
|
||||
request.id.clone(),
|
||||
request.tool_call.clone()
|
||||
);
|
||||
|
||||
// Wait for the result using our channel
|
||||
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
|
||||
message_tool_response = message_tool_response.with_tool_response(id, result);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle regular tool calls
|
||||
let result = capabilities.dispatch_tool_call(tool_call.clone()).await;
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
output,
|
||||
result,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
yield message_tool_response.clone();
|
||||
|
||||
@@ -278,6 +291,12 @@ impl Agent for ReferenceAgent {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
capabilities.provider()
|
||||
}
|
||||
|
||||
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||
tracing::error!("Failed to send tool result: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
register_agent!("reference", ReferenceAgent);
|
||||
|
||||
@@ -28,9 +28,7 @@ use crate::token_counter::TokenCounter;
|
||||
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
||||
use anyhow::{anyhow, Result};
|
||||
use indoc::indoc;
|
||||
use mcp_core::prompt::Prompt;
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
use mcp_core::{tool::Tool, Content};
|
||||
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolResult};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
||||
@@ -42,19 +40,22 @@ pub struct SummarizeAgent {
|
||||
token_counter: TokenCounter,
|
||||
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
|
||||
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
|
||||
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
}
|
||||
|
||||
impl SummarizeAgent {
|
||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
// Create channel with buffer size 32 (adjust if needed)
|
||||
let (tx, rx) = mpsc::channel(32);
|
||||
// Create channels with buffer size 32 (adjust if needed)
|
||||
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
||||
let (tool_tx, _tool_rx) = mpsc::channel(32);
|
||||
|
||||
Self {
|
||||
capabilities: Mutex::new(Capabilities::new(provider)),
|
||||
token_counter,
|
||||
confirmation_tx: tx,
|
||||
confirmation_rx: Mutex::new(rx),
|
||||
confirmation_tx: confirm_tx,
|
||||
confirmation_rx: Mutex::new(confirm_rx),
|
||||
tool_result_tx: tool_tx,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -493,6 +494,12 @@ impl Agent for SummarizeAgent {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
capabilities.provider()
|
||||
}
|
||||
|
||||
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||
tracing::error!("Failed to send tool result: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
register_agent!("summarize", SummarizeAgent);
|
||||
|
||||
@@ -12,12 +12,13 @@ use tracing::{debug, error, instrument, warn};
|
||||
use super::agent::SessionConfig;
|
||||
use super::detect_read_only_tools;
|
||||
use super::extension::ToolInfo;
|
||||
use super::types::ToolResultReceiver;
|
||||
use super::Agent;
|
||||
use crate::agents::capabilities::{get_parameter_names, Capabilities};
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
||||
use crate::agents::ToolPermissionStore;
|
||||
use crate::config::Config;
|
||||
use crate::message::{Message, ToolRequest};
|
||||
use crate::message::{Message, MessageContent, ToolRequest};
|
||||
use crate::providers::base::Provider;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::providers::toolshim::{
|
||||
@@ -29,9 +30,9 @@ use crate::token_counter::TokenCounter;
|
||||
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
||||
use anyhow::{anyhow, Result};
|
||||
use indoc::indoc;
|
||||
use mcp_core::prompt::Prompt;
|
||||
use mcp_core::protocol::GetPromptResult;
|
||||
use mcp_core::{tool::Tool, Content, ToolError};
|
||||
use mcp_core::{
|
||||
prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult,
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -44,19 +45,24 @@ pub struct TruncateAgent {
|
||||
token_counter: TokenCounter,
|
||||
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
|
||||
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
|
||||
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
tool_result_rx: ToolResultReceiver,
|
||||
}
|
||||
|
||||
impl TruncateAgent {
|
||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
// Create channel with buffer size 32 (adjust if needed)
|
||||
let (tx, rx) = mpsc::channel(32);
|
||||
// Create channels with buffer size 32 (adjust if needed)
|
||||
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
||||
let (tool_tx, tool_rx) = mpsc::channel(32);
|
||||
|
||||
Self {
|
||||
capabilities: Mutex::new(Capabilities::new(provider)),
|
||||
token_counter,
|
||||
confirmation_tx: tx,
|
||||
confirmation_rx: Mutex::new(rx),
|
||||
confirmation_tx: confirm_tx,
|
||||
confirmation_rx: Mutex::new(confirm_rx),
|
||||
tool_result_tx: tool_tx,
|
||||
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -306,8 +312,21 @@ impl Agent for TruncateAgent {
|
||||
// Reset truncation attempt
|
||||
truncation_attempt = 0;
|
||||
|
||||
// Yield the assistant's response
|
||||
yield response.clone();
|
||||
// Yield the assistant's response, but filter out frontend tool requests that we'll process separately
|
||||
let filtered_response = Message {
|
||||
role: response.role.clone(),
|
||||
created: response.created,
|
||||
content: response.content.iter().filter(|c| {
|
||||
if let MessageContent::ToolRequest(req) = c {
|
||||
// Only filter out frontend tool requests
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
return !capabilities.is_frontend_tool(&tool_call.name);
|
||||
}
|
||||
}
|
||||
true
|
||||
}).cloned().collect(),
|
||||
};
|
||||
yield filtered_response.clone();
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
@@ -323,6 +342,29 @@ impl Agent for TruncateAgent {
|
||||
|
||||
// Process tool requests depending on goose_mode
|
||||
let mut message_tool_response = Message::user();
|
||||
|
||||
// First handle any frontend tool requests
|
||||
let mut remaining_requests = Vec::new();
|
||||
for request in &tool_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
if capabilities.is_frontend_tool(&tool_call.name) {
|
||||
// Send frontend tool request and wait for response
|
||||
yield Message::assistant().with_frontend_tool_request(
|
||||
request.id.clone(),
|
||||
Ok(tool_call.clone())
|
||||
);
|
||||
|
||||
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
|
||||
message_tool_response = message_tool_response.with_tool_response(id, result);
|
||||
}
|
||||
} else {
|
||||
remaining_requests.push(request);
|
||||
}
|
||||
} else {
|
||||
remaining_requests.push(request);
|
||||
}
|
||||
}
|
||||
|
||||
// Clone goose_mode once before the match to avoid move issues
|
||||
let mode = goose_mode.clone();
|
||||
match mode.as_str() {
|
||||
@@ -334,7 +376,7 @@ impl Agent for TruncateAgent {
|
||||
|
||||
// First check permissions for all tools
|
||||
let store = ToolPermissionStore::load()?;
|
||||
for request in tool_requests.iter() {
|
||||
for request in remaining_requests.iter() {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
if tools_with_readonly_annotation.contains(&tool_call.name) {
|
||||
approved_tools.push((request.id.clone(), tool_call));
|
||||
@@ -427,7 +469,7 @@ impl Agent for TruncateAgent {
|
||||
},
|
||||
"chat" => {
|
||||
// Skip all tool calls in chat mode
|
||||
for request in &tool_requests {
|
||||
for request in &remaining_requests {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text(
|
||||
@@ -449,7 +491,7 @@ impl Agent for TruncateAgent {
|
||||
}
|
||||
// Process tool requests in parallel
|
||||
let mut tool_futures = Vec::new();
|
||||
for request in &tool_requests {
|
||||
for request in &remaining_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
|
||||
tool_futures.push(tool_future);
|
||||
@@ -574,6 +616,12 @@ impl Agent for TruncateAgent {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
capabilities.provider()
|
||||
}
|
||||
|
||||
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||
tracing::error!("Failed to send tool result: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
register_agent!("truncate", TruncateAgent);
|
||||
|
||||
6
crates/goose/src/agents/types.rs
Normal file
6
crates/goose/src/agents/types.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
use mcp_core::{Content, ToolResult};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
/// Type alias for the tool result channel receiver
|
||||
pub type ToolResultReceiver = Arc<Mutex<mpsc::Receiver<(String, ToolResult<Vec<Content>>)>>>;
|
||||
@@ -70,6 +70,14 @@ pub struct RedactedThinkingContent {
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FrontendToolRequest {
|
||||
pub id: String,
|
||||
#[serde(with = "tool_result_serde")]
|
||||
pub tool_call: ToolResult<ToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||
/// Content passed inside a message, which can be both simple content and tool content
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
@@ -79,6 +87,7 @@ pub enum MessageContent {
|
||||
ToolRequest(ToolRequest),
|
||||
ToolResponse(ToolResponse),
|
||||
ToolConfirmationRequest(ToolConfirmationRequest),
|
||||
FrontendToolRequest(FrontendToolRequest),
|
||||
Thinking(ThinkingContent),
|
||||
RedactedThinking(RedactedThinkingContent),
|
||||
}
|
||||
@@ -137,6 +146,13 @@ impl MessageContent {
|
||||
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
|
||||
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
|
||||
}
|
||||
|
||||
pub fn frontend_tool_request<S: Into<String>>(id: S, tool_call: ToolResult<ToolCall>) -> Self {
|
||||
MessageContent::FrontendToolRequest(FrontendToolRequest {
|
||||
id: id.into(),
|
||||
tool_call,
|
||||
})
|
||||
}
|
||||
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
|
||||
if let MessageContent::ToolRequest(ref tool_request) = self {
|
||||
Some(tool_request)
|
||||
@@ -320,6 +336,14 @@ impl Message {
|
||||
))
|
||||
}
|
||||
|
||||
pub fn with_frontend_tool_request<S: Into<String>>(
|
||||
self,
|
||||
id: S,
|
||||
tool_call: ToolResult<ToolCall>,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::frontend_tool_request(id, tool_call))
|
||||
}
|
||||
|
||||
/// Add thinking content to the message
|
||||
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
|
||||
self,
|
||||
|
||||
@@ -74,6 +74,16 @@ pub fn format_messages(messages: &[Message]) -> Vec<Value> {
|
||||
}));
|
||||
}
|
||||
MessageContent::Image(_) => continue, // Anthropic doesn't support image content yet
|
||||
MessageContent::FrontendToolRequest(tool_request) => {
|
||||
if let Ok(tool_call) = &tool_request.tool_call {
|
||||
content.push(json!({
|
||||
"type": "tool_use",
|
||||
"id": tool_request.id,
|
||||
"name": tool_call.name,
|
||||
"input": tool_call.arguments
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -57,6 +57,21 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C
|
||||
}?;
|
||||
bedrock::ContentBlock::ToolUse(tool_use)
|
||||
}
|
||||
MessageContent::FrontendToolRequest(tool_req) => {
|
||||
let tool_use_id = tool_req.id.to_string();
|
||||
let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() {
|
||||
bedrock::ToolUseBlock::builder()
|
||||
.tool_use_id(tool_use_id)
|
||||
.name(call.name.to_string())
|
||||
.input(to_bedrock_json(&call.arguments))
|
||||
.build()
|
||||
} else {
|
||||
bedrock::ToolUseBlock::builder()
|
||||
.tool_use_id(tool_use_id)
|
||||
.build()
|
||||
}?;
|
||||
bedrock::ContentBlock::ToolUse(tool_use)
|
||||
}
|
||||
MessageContent::ToolResponse(tool_res) => {
|
||||
let content = match &tool_res.tool_result {
|
||||
Ok(content) => Some(
|
||||
|
||||
@@ -188,6 +188,27 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
}
|
||||
}));
|
||||
}
|
||||
MessageContent::FrontendToolRequest(req) => {
|
||||
// Frontend tool requests are converted to text messages
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
content_array.push(json!({
|
||||
"type": "text",
|
||||
"text": format!(
|
||||
"Frontend tool request: {} ({})",
|
||||
tool_call.name,
|
||||
serde_json::to_string_pretty(&tool_call.arguments).unwrap()
|
||||
)
|
||||
}));
|
||||
} else {
|
||||
content_array.push(json!({
|
||||
"type": "text",
|
||||
"text": format!(
|
||||
"Frontend tool request error: {}",
|
||||
req.tool_call.as_ref().unwrap_err()
|
||||
)
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -151,6 +151,32 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
// Handle direct image content
|
||||
converted["content"] = json!([convert_image(image, image_format)]);
|
||||
}
|
||||
MessageContent::FrontendToolRequest(request) => match &request.tool_call {
|
||||
Ok(tool_call) => {
|
||||
let sanitized_name = sanitize_function_name(&tool_call.name);
|
||||
let tool_calls = converted
|
||||
.as_object_mut()
|
||||
.unwrap()
|
||||
.entry("tool_calls")
|
||||
.or_insert(json!([]));
|
||||
|
||||
tool_calls.as_array_mut().unwrap().push(json!({
|
||||
"id": request.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": sanitized_name,
|
||||
"arguments": tool_call.arguments.to_string(),
|
||||
}
|
||||
}));
|
||||
}
|
||||
Err(e) => {
|
||||
output.push(json!({
|
||||
"role": "tool",
|
||||
"content": format!("Error: {}", e),
|
||||
"tool_call_id": request.id
|
||||
}));
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
210
examples/frontend_tools.py
Normal file
210
examples/frontend_tools.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import List, Dict, Any
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
|
||||
# Configuration
|
||||
GOOSE_HOST = "127.0.0.1"
|
||||
GOOSE_PORT = "3001"
|
||||
GOOSE_URL = f"http://{GOOSE_HOST}:{GOOSE_PORT}"
|
||||
SECRET_KEY = "test" # Default development secret key
|
||||
|
||||
# A simple calculator tool definition
|
||||
CALCULATOR_TOOL = {
|
||||
"name": "calculator",
|
||||
"description": "Perform basic arithmetic calculations",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"required": ["operation", "numbers"],
|
||||
"properties": {
|
||||
"operation": {
|
||||
"type": "string",
|
||||
"enum": ["add", "subtract", "multiply", "divide"],
|
||||
"description": "The arithmetic operation to perform",
|
||||
},
|
||||
"numbers": {
|
||||
"type": "array",
|
||||
"items": {"type": "number"},
|
||||
"description": "List of numbers to operate on in order",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Frontend extension configuration
|
||||
FRONTEND_CONFIG = {
|
||||
"name": "pythonclient",
|
||||
"type": "frontend",
|
||||
"tools": [CALCULATOR_TOOL],
|
||||
"instructions": "A calculator extension that can perform basic arithmetic operations.",
|
||||
}
|
||||
|
||||
|
||||
async def setup_agent() -> None:
|
||||
"""Initialize the agent with our frontend tool."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
# First create the agent
|
||||
response = await client.post(
|
||||
f"{GOOSE_URL}/agent",
|
||||
json={"provider": "databricks", "model": "goose"},
|
||||
headers={"X-Secret-Key": SECRET_KEY},
|
||||
)
|
||||
response.raise_for_status()
|
||||
print("Successfully created agent")
|
||||
|
||||
# Then add our frontend extension
|
||||
response = await client.post(
|
||||
f"{GOOSE_URL}/extensions/add",
|
||||
json=FRONTEND_CONFIG,
|
||||
headers={"X-Secret-Key": SECRET_KEY},
|
||||
)
|
||||
response.raise_for_status()
|
||||
print("Successfully added calculator extension")
|
||||
|
||||
|
||||
def execute_calculator(args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Execute the calculator tool with the given arguments."""
|
||||
operation = args["operation"]
|
||||
numbers = args["numbers"]
|
||||
|
||||
try:
|
||||
result = None
|
||||
if operation == "add":
|
||||
result = sum(numbers)
|
||||
elif operation == "subtract":
|
||||
result = numbers[0] - sum(numbers[1:])
|
||||
elif operation == "multiply":
|
||||
result = 1
|
||||
for n in numbers:
|
||||
result *= n
|
||||
elif operation == "divide":
|
||||
result = numbers[0]
|
||||
for n in numbers[1:]:
|
||||
result /= n
|
||||
|
||||
# Return properly structured Content::Text variant
|
||||
return [
|
||||
{
|
||||
"type": "text",
|
||||
"text": str(result),
|
||||
"annotations": None, # Required field in Rust struct
|
||||
}
|
||||
]
|
||||
except Exception as e:
|
||||
return [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Error: {str(e)}",
|
||||
"annotations": None, # Required field in Rust struct
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None:
|
||||
"""Submit the tool execution result back to Goose.
|
||||
|
||||
The result should be a list of Content variants (Text, Image, or Resource).
|
||||
Each Content variant has a type tag and appropriate fields.
|
||||
"""
|
||||
payload = {
|
||||
"id": tool_id,
|
||||
"result": {
|
||||
"Ok": result # Result enum variant with single key for success case
|
||||
},
|
||||
}
|
||||
|
||||
with httpx.Client(timeout=2.0) as client:
|
||||
response = client.post(
|
||||
f"{GOOSE_URL}/tool_result",
|
||||
json=payload,
|
||||
headers={"X-Secret-Key": SECRET_KEY},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
async def chat_loop() -> None:
|
||||
"""Main chat loop that handles the conversation with Goose."""
|
||||
session_id = "test-session"
|
||||
|
||||
# Use a client with a longer timeout for streaming
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
# Get user input
|
||||
user_message = input("\nYou: ")
|
||||
if user_message.lower() in ["exit", "quit"]:
|
||||
return
|
||||
|
||||
# Create the message object
|
||||
message = {
|
||||
"role": "user",
|
||||
"created": int(datetime.now().timestamp()),
|
||||
"content": [{"type": "text", "text": user_message}],
|
||||
}
|
||||
|
||||
# Send to /reply endpoint
|
||||
payload = {
|
||||
"messages": [message],
|
||||
"session_id": session_id,
|
||||
"session_working_dir": os.getcwd(),
|
||||
}
|
||||
|
||||
# Process the stream of responses
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{GOOSE_URL}/reply",
|
||||
json=payload,
|
||||
headers={
|
||||
"X-Secret-Key": SECRET_KEY,
|
||||
"Accept": "text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
) as stream:
|
||||
async for line in stream.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Handle SSE format
|
||||
if line.startswith("data: "):
|
||||
line = line[6:] # Remove "data: " prefix
|
||||
|
||||
try:
|
||||
payload = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
print(f"Failed to parse line: {line}")
|
||||
continue
|
||||
|
||||
if payload["type"] == "Finish":
|
||||
break
|
||||
|
||||
message = payload["message"]
|
||||
# Handle different message types
|
||||
for content in message.get("content", []):
|
||||
if content["type"] == "text":
|
||||
print(f"\nGoose: {content['text']}")
|
||||
elif content["type"] == "frontendToolRequest":
|
||||
# Execute the tool and submit results
|
||||
tool_call = content["toolCall"]["value"]
|
||||
print(f"Calculator: {tool_call}")
|
||||
# Execute the tool
|
||||
result = execute_calculator(tool_call["arguments"])
|
||||
|
||||
# Submit the result
|
||||
submit_tool_result(content["id"], result)
|
||||
|
||||
|
||||
async def main():
|
||||
try:
|
||||
# Initialize the agent with our tool
|
||||
await setup_agent()
|
||||
|
||||
# Start the chat loop
|
||||
await chat_loop()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
raise # Re-raise to see full traceback
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user