mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 22:24:21 +01:00
feat: Enable frontend tools (#1778)
This commit is contained in:
@@ -8,6 +8,7 @@ use goose::{
|
|||||||
};
|
};
|
||||||
use http::{HeaderMap, StatusCode};
|
use http::{HeaderMap, StatusCode};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tracing;
|
||||||
|
|
||||||
/// Enum representing the different types of extension configuration requests.
|
/// Enum representing the different types of extension configuration requests.
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@@ -48,6 +49,16 @@ enum ExtensionConfigRequest {
|
|||||||
display_name: Option<String>,
|
display_name: Option<String>,
|
||||||
timeout: Option<u64>,
|
timeout: Option<u64>,
|
||||||
},
|
},
|
||||||
|
/// Frontend extension that provides tools to be executed by the frontend.
|
||||||
|
#[serde(rename = "frontend")]
|
||||||
|
Frontend {
|
||||||
|
/// The name to identify this extension
|
||||||
|
name: String,
|
||||||
|
/// The tools provided by this extension
|
||||||
|
tools: Vec<mcp_core::tool::Tool>,
|
||||||
|
/// Optional instructions for using the tools
|
||||||
|
instructions: Option<String>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Response structure for adding an extension.
|
/// Response structure for adding an extension.
|
||||||
@@ -64,8 +75,26 @@ struct ExtensionResponse {
|
|||||||
async fn add_extension(
|
async fn add_extension(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(request): Json<ExtensionConfigRequest>,
|
raw: axum::extract::Json<serde_json::Value>,
|
||||||
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
||||||
|
// Log the raw request for debugging
|
||||||
|
tracing::info!(
|
||||||
|
"Received extension request: {}",
|
||||||
|
serde_json::to_string_pretty(&raw.0).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Try to parse into our enum
|
||||||
|
let request: ExtensionConfigRequest = match serde_json::from_value(raw.0.clone()) {
|
||||||
|
Ok(req) => req,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to parse extension request: {}", e);
|
||||||
|
tracing::error!(
|
||||||
|
"Raw request was: {}",
|
||||||
|
serde_json::to_string_pretty(&raw.0).unwrap()
|
||||||
|
);
|
||||||
|
return Err(StatusCode::UNPROCESSABLE_ENTITY);
|
||||||
|
}
|
||||||
|
};
|
||||||
// Verify the presence and validity of the secret key.
|
// Verify the presence and validity of the secret key.
|
||||||
let secret_key = headers
|
let secret_key = headers
|
||||||
.get("X-Secret-Key")
|
.get("X-Secret-Key")
|
||||||
@@ -167,6 +196,15 @@ async fn add_extension(
|
|||||||
display_name,
|
display_name,
|
||||||
timeout,
|
timeout,
|
||||||
},
|
},
|
||||||
|
ExtensionConfigRequest::Frontend {
|
||||||
|
name,
|
||||||
|
tools,
|
||||||
|
instructions,
|
||||||
|
} => ExtensionConfig::Frontend {
|
||||||
|
name,
|
||||||
|
tools,
|
||||||
|
instructions,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// Acquire a lock on the agent and attempt to add the extension.
|
// Acquire a lock on the agent and attempt to add the extension.
|
||||||
|
|||||||
@@ -13,9 +13,9 @@ use goose::{
|
|||||||
agents::SessionConfig,
|
agents::SessionConfig,
|
||||||
message::{Message, MessageContent},
|
message::{Message, MessageContent},
|
||||||
};
|
};
|
||||||
|
use mcp_core::{role::Role, Content, ToolResult};
|
||||||
use mcp_core::role::Role;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::json;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::{
|
use std::{
|
||||||
convert::Infallible,
|
convert::Infallible,
|
||||||
@@ -391,12 +391,59 @@ async fn confirm_handler(
|
|||||||
Ok(Json(Value::Object(serde_json::Map::new())))
|
Ok(Json(Value::Object(serde_json::Map::new())))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ToolResultRequest {
|
||||||
|
id: String,
|
||||||
|
result: ToolResult<Vec<Content>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn submit_tool_result(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
headers: HeaderMap,
|
||||||
|
raw: axum::extract::Json<serde_json::Value>,
|
||||||
|
) -> Result<Json<Value>, StatusCode> {
|
||||||
|
// Log the raw request for debugging
|
||||||
|
tracing::info!(
|
||||||
|
"Received tool result request: {}",
|
||||||
|
serde_json::to_string_pretty(&raw.0).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Try to parse into our struct
|
||||||
|
let payload: ToolResultRequest = match serde_json::from_value(raw.0.clone()) {
|
||||||
|
Ok(req) => req,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Failed to parse tool result request: {}", e);
|
||||||
|
tracing::error!(
|
||||||
|
"Raw request was: {}",
|
||||||
|
serde_json::to_string_pretty(&raw.0).unwrap()
|
||||||
|
);
|
||||||
|
return Err(StatusCode::UNPROCESSABLE_ENTITY);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Verify secret key
|
||||||
|
let secret_key = headers
|
||||||
|
.get("X-Secret-Key")
|
||||||
|
.and_then(|value| value.to_str().ok())
|
||||||
|
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||||
|
|
||||||
|
if secret_key != state.secret_key {
|
||||||
|
return Err(StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
let agent = state.agent.read().await;
|
||||||
|
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
|
||||||
|
agent.handle_tool_result(payload.id, payload.result).await;
|
||||||
|
Ok(Json(json!({"status": "ok"})))
|
||||||
|
}
|
||||||
|
|
||||||
// Configure routes for this module
|
// Configure routes for this module
|
||||||
pub fn routes(state: AppState) -> Router {
|
pub fn routes(state: AppState) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/reply", post(handler))
|
.route("/reply", post(handler))
|
||||||
.route("/ask", post(ask_handler))
|
.route("/ask", post(ask_handler))
|
||||||
.route("/confirm", post(confirm_handler))
|
.route("/confirm", post(confirm_handler))
|
||||||
|
.route("/tool_result", post(submit_tool_result))
|
||||||
.with_state(state)
|
.with_state(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,7 @@ use super::extension::{ExtensionConfig, ExtensionResult};
|
|||||||
use crate::message::Message;
|
use crate::message::Message;
|
||||||
use crate::providers::base::Provider;
|
use crate::providers::base::Provider;
|
||||||
use crate::session;
|
use crate::session;
|
||||||
use mcp_core::prompt::Prompt;
|
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult};
|
||||||
use mcp_core::protocol::GetPromptResult;
|
|
||||||
|
|
||||||
/// Session configuration for an agent
|
/// Session configuration for an agent
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -68,4 +67,7 @@ pub trait Agent: Send + Sync {
|
|||||||
|
|
||||||
/// Get a reference to the provider used by this agent
|
/// Get a reference to the provider used by this agent
|
||||||
async fn provider(&self) -> Arc<Box<dyn Provider>>;
|
async fn provider(&self) -> Arc<Box<dyn Provider>>;
|
||||||
|
|
||||||
|
/// Handle a tool result from the frontend
|
||||||
|
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,9 +26,17 @@ static DEFAULT_TIMESTAMP: LazyLock<DateTime<Utc>> =
|
|||||||
|
|
||||||
type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>;
|
type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>;
|
||||||
|
|
||||||
|
/// A frontend tool that will be executed by the frontend rather than an extension
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct FrontendTool {
|
||||||
|
pub name: String,
|
||||||
|
pub tool: Tool,
|
||||||
|
}
|
||||||
|
|
||||||
/// Manages MCP clients and their interactions
|
/// Manages MCP clients and their interactions
|
||||||
pub struct Capabilities {
|
pub struct Capabilities {
|
||||||
clients: HashMap<String, McpClientBox>,
|
clients: HashMap<String, McpClientBox>,
|
||||||
|
frontend_tools: HashMap<String, FrontendTool>,
|
||||||
instructions: HashMap<String, String>,
|
instructions: HashMap<String, String>,
|
||||||
resource_capable_extensions: HashSet<String>,
|
resource_capable_extensions: HashSet<String>,
|
||||||
provider: Arc<Box<dyn Provider>>,
|
provider: Arc<Box<dyn Provider>>,
|
||||||
@@ -96,6 +104,7 @@ impl Capabilities {
|
|||||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
clients: HashMap::new(),
|
clients: HashMap::new(),
|
||||||
|
frontend_tools: HashMap::new(),
|
||||||
instructions: HashMap::new(),
|
instructions: HashMap::new(),
|
||||||
resource_capable_extensions: HashSet::new(),
|
resource_capable_extensions: HashSet::new(),
|
||||||
provider: Arc::new(provider),
|
provider: Arc::new(provider),
|
||||||
@@ -111,6 +120,30 @@ impl Capabilities {
|
|||||||
/// Add a new MCP extension based on the provided client type
|
/// Add a new MCP extension based on the provided client type
|
||||||
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
|
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
|
||||||
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
|
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
|
||||||
|
let sanitized_name = normalize(config.key().to_string());
|
||||||
|
|
||||||
|
match &config {
|
||||||
|
ExtensionConfig::Frontend {
|
||||||
|
name: _,
|
||||||
|
tools,
|
||||||
|
instructions,
|
||||||
|
} => {
|
||||||
|
// For frontend tools, just store them in the frontend_tools map
|
||||||
|
for tool in tools {
|
||||||
|
let frontend_tool = FrontendTool {
|
||||||
|
name: tool.name.clone(),
|
||||||
|
tool: tool.clone(),
|
||||||
|
};
|
||||||
|
self.frontend_tools.insert(tool.name.clone(), frontend_tool);
|
||||||
|
}
|
||||||
|
// Store instructions if provided, using "frontend" as the key
|
||||||
|
if let Some(instructions) = instructions {
|
||||||
|
self.instructions
|
||||||
|
.insert("frontend".to_string(), instructions.clone());
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
let mut client: Box<dyn McpClientTrait> = match &config {
|
let mut client: Box<dyn McpClientTrait> = match &config {
|
||||||
ExtensionConfig::Sse {
|
ExtensionConfig::Sse {
|
||||||
uri, envs, timeout, ..
|
uri, envs, timeout, ..
|
||||||
@@ -142,10 +175,9 @@ impl Capabilities {
|
|||||||
);
|
);
|
||||||
Box::new(McpClient::new(service))
|
Box::new(McpClient::new(service))
|
||||||
}
|
}
|
||||||
#[allow(unused_variables)]
|
|
||||||
ExtensionConfig::Builtin {
|
ExtensionConfig::Builtin {
|
||||||
name,
|
name,
|
||||||
display_name,
|
display_name: _,
|
||||||
timeout,
|
timeout,
|
||||||
} => {
|
} => {
|
||||||
// For builtin extensions, we run the current executable with mcp and extension name
|
// For builtin extensions, we run the current executable with mcp and extension name
|
||||||
@@ -168,6 +200,7 @@ impl Capabilities {
|
|||||||
);
|
);
|
||||||
Box::new(McpClient::new(service))
|
Box::new(McpClient::new(service))
|
||||||
}
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Initialize the client with default capabilities
|
// Initialize the client with default capabilities
|
||||||
@@ -182,8 +215,6 @@ impl Capabilities {
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| ExtensionError::Initialization(config.clone(), e))?;
|
.map_err(|e| ExtensionError::Initialization(config.clone(), e))?;
|
||||||
|
|
||||||
let sanitized_name = normalize(config.key().to_string());
|
|
||||||
|
|
||||||
// Store instructions if provided
|
// Store instructions if provided
|
||||||
if let Some(instructions) = init_result.instructions {
|
if let Some(instructions) = init_result.instructions {
|
||||||
self.instructions
|
self.instructions
|
||||||
@@ -202,6 +233,8 @@ impl Capabilities {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Add a system prompt extension
|
/// Add a system prompt extension
|
||||||
pub fn add_system_prompt_extension(&mut self, extension: String) {
|
pub fn add_system_prompt_extension(&mut self, extension: String) {
|
||||||
@@ -235,6 +268,13 @@ impl Capabilities {
|
|||||||
/// Get all tools from all clients with proper prefixing
|
/// Get all tools from all clients with proper prefixing
|
||||||
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
|
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
|
||||||
let mut tools = Vec::new();
|
let mut tools = Vec::new();
|
||||||
|
|
||||||
|
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
|
||||||
|
for frontend_tool in self.frontend_tools.values() {
|
||||||
|
tools.push(frontend_tool.tool.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tools from MCP extensions with prefixing
|
||||||
for (name, client) in &self.clients {
|
for (name, client) in &self.clients {
|
||||||
let client_guard = client.lock().await;
|
let client_guard = client.lock().await;
|
||||||
let mut client_tools = client_guard.list_tools(None).await?;
|
let mut client_tools = client_guard.list_tools(None).await?;
|
||||||
@@ -317,7 +357,7 @@ impl Capabilities {
|
|||||||
pub async fn get_system_prompt(&self) -> String {
|
pub async fn get_system_prompt(&self) -> String {
|
||||||
let mut context: HashMap<&str, Value> = HashMap::new();
|
let mut context: HashMap<&str, Value> = HashMap::new();
|
||||||
|
|
||||||
let extensions_info: Vec<ExtensionInfo> = self
|
let mut extensions_info: Vec<ExtensionInfo> = self
|
||||||
.clients
|
.clients
|
||||||
.keys()
|
.keys()
|
||||||
.map(|name| {
|
.map(|name| {
|
||||||
@@ -326,6 +366,15 @@ impl Capabilities {
|
|||||||
ExtensionInfo::new(name, &instructions, has_resources)
|
ExtensionInfo::new(name, &instructions, has_resources)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
// Add frontend tools as a special extension if any exist
|
||||||
|
if !self.frontend_tools.is_empty() {
|
||||||
|
let name = "frontend";
|
||||||
|
let instructions = self.instructions.get(name).cloned().unwrap_or_else(||
|
||||||
|
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string()
|
||||||
|
);
|
||||||
|
extensions_info.push(ExtensionInfo::new(name, &instructions, false));
|
||||||
|
}
|
||||||
context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
|
context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
|
||||||
|
|
||||||
let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
|
let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
|
||||||
@@ -364,6 +413,16 @@ impl Capabilities {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if a tool is a frontend tool
|
||||||
|
pub fn is_frontend_tool(&self, name: &str) -> bool {
|
||||||
|
self.frontend_tools.contains_key(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a reference to a frontend tool
|
||||||
|
pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> {
|
||||||
|
self.frontend_tools.get(name)
|
||||||
|
}
|
||||||
|
|
||||||
/// Find and return a reference to the appropriate client for a tool call
|
/// Find and return a reference to the appropriate client for a tool call
|
||||||
fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(&str, McpClientBox)> {
|
fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(&str, McpClientBox)> {
|
||||||
self.clients
|
self.clients
|
||||||
@@ -543,6 +602,11 @@ impl Capabilities {
|
|||||||
self.read_resource(tool_call.arguments.clone()).await
|
self.read_resource(tool_call.arguments.clone()).await
|
||||||
} else if tool_call.name == "platform__list_resources" {
|
} else if tool_call.name == "platform__list_resources" {
|
||||||
self.list_resources(tool_call.arguments.clone()).await
|
self.list_resources(tool_call.arguments.clone()).await
|
||||||
|
} else if self.is_frontend_tool(&tool_call.name) {
|
||||||
|
// For frontend tools, return an error indicating we need frontend execution
|
||||||
|
Err(ToolError::ExecutionError(
|
||||||
|
"Frontend tool execution required".to_string(),
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
// Else, dispatch tool call based on the prefix naming convention
|
// Else, dispatch tool call based on the prefix naming convention
|
||||||
let (client_name, client) = self
|
let (client_name, client) = self
|
||||||
|
|||||||
@@ -149,6 +149,16 @@ pub enum ExtensionConfig {
|
|||||||
display_name: Option<String>, // needed for the UI
|
display_name: Option<String>, // needed for the UI
|
||||||
timeout: Option<u64>,
|
timeout: Option<u64>,
|
||||||
},
|
},
|
||||||
|
/// Frontend-provided tools that will be called through the frontend
|
||||||
|
#[serde(rename = "frontend")]
|
||||||
|
Frontend {
|
||||||
|
/// The name used to identify this extension
|
||||||
|
name: String,
|
||||||
|
/// The tools provided by the frontend
|
||||||
|
tools: Vec<mcp_core::tool::Tool>,
|
||||||
|
/// Instructions for how to use these tools
|
||||||
|
instructions: Option<String>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ExtensionConfig {
|
impl Default for ExtensionConfig {
|
||||||
@@ -224,6 +234,7 @@ impl ExtensionConfig {
|
|||||||
Self::Sse { name, .. } => name,
|
Self::Sse { name, .. } => name,
|
||||||
Self::Stdio { name, .. } => name,
|
Self::Stdio { name, .. } => name,
|
||||||
Self::Builtin { name, .. } => name,
|
Self::Builtin { name, .. } => name,
|
||||||
|
Self::Frontend { name, .. } => name,
|
||||||
}
|
}
|
||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
@@ -239,6 +250,9 @@ impl std::fmt::Display for ExtensionConfig {
|
|||||||
write!(f, "Stdio({}: {} {})", name, cmd, args.join(" "))
|
write!(f, "Stdio({}: {} {})", name, cmd, args.join(" "))
|
||||||
}
|
}
|
||||||
ExtensionConfig::Builtin { name, .. } => write!(f, "Builtin({})", name),
|
ExtensionConfig::Builtin { name, .. } => write!(f, "Builtin({})", name),
|
||||||
|
ExtensionConfig::Frontend { name, tools, .. } => {
|
||||||
|
write!(f, "Frontend({}: {} tools)", name, tools.len())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ mod permission_store;
|
|||||||
mod reference;
|
mod reference;
|
||||||
mod summarize;
|
mod summarize;
|
||||||
mod truncate;
|
mod truncate;
|
||||||
|
mod types;
|
||||||
|
|
||||||
pub use agent::{Agent, SessionConfig};
|
pub use agent::{Agent, SessionConfig};
|
||||||
pub use capabilities::Capabilities;
|
pub use capabilities::Capabilities;
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ use async_trait::async_trait;
|
|||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::{mpsc, Mutex};
|
||||||
use tracing::{debug, instrument};
|
use tracing::{debug, instrument};
|
||||||
|
|
||||||
use super::agent::SessionConfig;
|
use super::agent::SessionConfig;
|
||||||
use super::capabilities::get_parameter_names;
|
use super::capabilities::get_parameter_names;
|
||||||
use super::extension::ToolInfo;
|
use super::extension::ToolInfo;
|
||||||
|
use super::types::ToolResultReceiver;
|
||||||
use super::Agent;
|
use super::Agent;
|
||||||
use crate::agents::capabilities::Capabilities;
|
use crate::agents::capabilities::Capabilities;
|
||||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
||||||
@@ -19,23 +20,27 @@ use crate::token_counter::TokenCounter;
|
|||||||
use crate::{register_agent, session};
|
use crate::{register_agent, session};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use mcp_core::prompt::Prompt;
|
|
||||||
use mcp_core::protocol::GetPromptResult;
|
|
||||||
use mcp_core::tool::{Tool, ToolAnnotations};
|
use mcp_core::tool::{Tool, ToolAnnotations};
|
||||||
|
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
/// Reference implementation of an Agent
|
/// Reference implementation of an Agent
|
||||||
pub struct ReferenceAgent {
|
pub struct ReferenceAgent {
|
||||||
capabilities: Mutex<Capabilities>,
|
capabilities: Mutex<Capabilities>,
|
||||||
_token_counter: TokenCounter,
|
_token_counter: TokenCounter,
|
||||||
|
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||||
|
tool_result_rx: ToolResultReceiver,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ReferenceAgent {
|
impl ReferenceAgent {
|
||||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||||
|
let (tx, rx) = mpsc::channel(32);
|
||||||
Self {
|
Self {
|
||||||
capabilities: Mutex::new(Capabilities::new(provider)),
|
capabilities: Mutex::new(Capabilities::new(provider)),
|
||||||
_token_counter: token_counter,
|
_token_counter: token_counter,
|
||||||
|
tool_result_tx: tx,
|
||||||
|
tool_result_rx: Arc::new(Mutex::new(rx)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -193,24 +198,32 @@ impl Agent for ReferenceAgent {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Then dispatch each in parallel
|
// Then dispatch each in parallel
|
||||||
let futures: Vec<_> = tool_requests
|
|
||||||
.iter()
|
|
||||||
.filter_map(|request| request.tool_call.clone().ok())
|
|
||||||
.map(|tool_call| capabilities.dispatch_tool_call(tool_call))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Process all the futures in parallel but wait until all are finished
|
|
||||||
let outputs = futures::future::join_all(futures).await;
|
|
||||||
|
|
||||||
// Create a message with the responses
|
|
||||||
let mut message_tool_response = Message::user();
|
let mut message_tool_response = Message::user();
|
||||||
// Now combine these into MessageContent::ToolResponse using the original ID
|
for request in tool_requests {
|
||||||
for (request, output) in tool_requests.iter().zip(outputs.into_iter()) {
|
if let Ok(tool_call) = &request.tool_call {
|
||||||
|
// Check if it's a frontend tool
|
||||||
|
if capabilities.is_frontend_tool(&tool_call.name) {
|
||||||
|
// Send frontend tool request and wait for response
|
||||||
|
yield Message::assistant().with_frontend_tool_request(
|
||||||
|
request.id.clone(),
|
||||||
|
request.tool_call.clone()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Wait for the result using our channel
|
||||||
|
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
|
||||||
|
message_tool_response = message_tool_response.with_tool_response(id, result);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle regular tool calls
|
||||||
|
let result = capabilities.dispatch_tool_call(tool_call.clone()).await;
|
||||||
message_tool_response = message_tool_response.with_tool_response(
|
message_tool_response = message_tool_response.with_tool_response(
|
||||||
request.id.clone(),
|
request.id.clone(),
|
||||||
output,
|
result,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
yield message_tool_response.clone();
|
yield message_tool_response.clone();
|
||||||
|
|
||||||
@@ -278,6 +291,12 @@ impl Agent for ReferenceAgent {
|
|||||||
let capabilities = self.capabilities.lock().await;
|
let capabilities = self.capabilities.lock().await;
|
||||||
capabilities.provider()
|
capabilities.provider()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||||
|
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||||
|
tracing::error!("Failed to send tool result: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
register_agent!("reference", ReferenceAgent);
|
register_agent!("reference", ReferenceAgent);
|
||||||
|
|||||||
@@ -28,9 +28,7 @@ use crate::token_counter::TokenCounter;
|
|||||||
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use mcp_core::prompt::Prompt;
|
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolResult};
|
||||||
use mcp_core::protocol::GetPromptResult;
|
|
||||||
use mcp_core::{tool::Tool, Content};
|
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
||||||
@@ -42,19 +40,22 @@ pub struct SummarizeAgent {
|
|||||||
token_counter: TokenCounter,
|
token_counter: TokenCounter,
|
||||||
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
|
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
|
||||||
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
|
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
|
||||||
|
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SummarizeAgent {
|
impl SummarizeAgent {
|
||||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||||
// Create channel with buffer size 32 (adjust if needed)
|
// Create channels with buffer size 32 (adjust if needed)
|
||||||
let (tx, rx) = mpsc::channel(32);
|
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
||||||
|
let (tool_tx, _tool_rx) = mpsc::channel(32);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
capabilities: Mutex::new(Capabilities::new(provider)),
|
capabilities: Mutex::new(Capabilities::new(provider)),
|
||||||
token_counter,
|
token_counter,
|
||||||
confirmation_tx: tx,
|
confirmation_tx: confirm_tx,
|
||||||
confirmation_rx: Mutex::new(rx),
|
confirmation_rx: Mutex::new(confirm_rx),
|
||||||
|
tool_result_tx: tool_tx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -493,6 +494,12 @@ impl Agent for SummarizeAgent {
|
|||||||
let capabilities = self.capabilities.lock().await;
|
let capabilities = self.capabilities.lock().await;
|
||||||
capabilities.provider()
|
capabilities.provider()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||||
|
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||||
|
tracing::error!("Failed to send tool result: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
register_agent!("summarize", SummarizeAgent);
|
register_agent!("summarize", SummarizeAgent);
|
||||||
|
|||||||
@@ -12,12 +12,13 @@ use tracing::{debug, error, instrument, warn};
|
|||||||
use super::agent::SessionConfig;
|
use super::agent::SessionConfig;
|
||||||
use super::detect_read_only_tools;
|
use super::detect_read_only_tools;
|
||||||
use super::extension::ToolInfo;
|
use super::extension::ToolInfo;
|
||||||
|
use super::types::ToolResultReceiver;
|
||||||
use super::Agent;
|
use super::Agent;
|
||||||
use crate::agents::capabilities::{get_parameter_names, Capabilities};
|
use crate::agents::capabilities::{get_parameter_names, Capabilities};
|
||||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
||||||
use crate::agents::ToolPermissionStore;
|
use crate::agents::ToolPermissionStore;
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::message::{Message, ToolRequest};
|
use crate::message::{Message, MessageContent, ToolRequest};
|
||||||
use crate::providers::base::Provider;
|
use crate::providers::base::Provider;
|
||||||
use crate::providers::errors::ProviderError;
|
use crate::providers::errors::ProviderError;
|
||||||
use crate::providers::toolshim::{
|
use crate::providers::toolshim::{
|
||||||
@@ -29,9 +30,9 @@ use crate::token_counter::TokenCounter;
|
|||||||
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use mcp_core::prompt::Prompt;
|
use mcp_core::{
|
||||||
use mcp_core::protocol::GetPromptResult;
|
prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult,
|
||||||
use mcp_core::{tool::Tool, Content, ToolError};
|
};
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
@@ -44,19 +45,24 @@ pub struct TruncateAgent {
|
|||||||
token_counter: TokenCounter,
|
token_counter: TokenCounter,
|
||||||
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
|
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
|
||||||
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
|
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
|
||||||
|
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||||
|
tool_result_rx: ToolResultReceiver,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TruncateAgent {
|
impl TruncateAgent {
|
||||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||||
// Create channel with buffer size 32 (adjust if needed)
|
// Create channels with buffer size 32 (adjust if needed)
|
||||||
let (tx, rx) = mpsc::channel(32);
|
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
||||||
|
let (tool_tx, tool_rx) = mpsc::channel(32);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
capabilities: Mutex::new(Capabilities::new(provider)),
|
capabilities: Mutex::new(Capabilities::new(provider)),
|
||||||
token_counter,
|
token_counter,
|
||||||
confirmation_tx: tx,
|
confirmation_tx: confirm_tx,
|
||||||
confirmation_rx: Mutex::new(rx),
|
confirmation_rx: Mutex::new(confirm_rx),
|
||||||
|
tool_result_tx: tool_tx,
|
||||||
|
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,8 +312,21 @@ impl Agent for TruncateAgent {
|
|||||||
// Reset truncation attempt
|
// Reset truncation attempt
|
||||||
truncation_attempt = 0;
|
truncation_attempt = 0;
|
||||||
|
|
||||||
// Yield the assistant's response
|
// Yield the assistant's response, but filter out frontend tool requests that we'll process separately
|
||||||
yield response.clone();
|
let filtered_response = Message {
|
||||||
|
role: response.role.clone(),
|
||||||
|
created: response.created,
|
||||||
|
content: response.content.iter().filter(|c| {
|
||||||
|
if let MessageContent::ToolRequest(req) = c {
|
||||||
|
// Only filter out frontend tool requests
|
||||||
|
if let Ok(tool_call) = &req.tool_call {
|
||||||
|
return !capabilities.is_frontend_tool(&tool_call.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}).cloned().collect(),
|
||||||
|
};
|
||||||
|
yield filtered_response.clone();
|
||||||
|
|
||||||
tokio::task::yield_now().await;
|
tokio::task::yield_now().await;
|
||||||
|
|
||||||
@@ -323,6 +342,29 @@ impl Agent for TruncateAgent {
|
|||||||
|
|
||||||
// Process tool requests depending on goose_mode
|
// Process tool requests depending on goose_mode
|
||||||
let mut message_tool_response = Message::user();
|
let mut message_tool_response = Message::user();
|
||||||
|
|
||||||
|
// First handle any frontend tool requests
|
||||||
|
let mut remaining_requests = Vec::new();
|
||||||
|
for request in &tool_requests {
|
||||||
|
if let Ok(tool_call) = request.tool_call.clone() {
|
||||||
|
if capabilities.is_frontend_tool(&tool_call.name) {
|
||||||
|
// Send frontend tool request and wait for response
|
||||||
|
yield Message::assistant().with_frontend_tool_request(
|
||||||
|
request.id.clone(),
|
||||||
|
Ok(tool_call.clone())
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
|
||||||
|
message_tool_response = message_tool_response.with_tool_response(id, result);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
remaining_requests.push(request);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
remaining_requests.push(request);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Clone goose_mode once before the match to avoid move issues
|
// Clone goose_mode once before the match to avoid move issues
|
||||||
let mode = goose_mode.clone();
|
let mode = goose_mode.clone();
|
||||||
match mode.as_str() {
|
match mode.as_str() {
|
||||||
@@ -334,7 +376,7 @@ impl Agent for TruncateAgent {
|
|||||||
|
|
||||||
// First check permissions for all tools
|
// First check permissions for all tools
|
||||||
let store = ToolPermissionStore::load()?;
|
let store = ToolPermissionStore::load()?;
|
||||||
for request in tool_requests.iter() {
|
for request in remaining_requests.iter() {
|
||||||
if let Ok(tool_call) = request.tool_call.clone() {
|
if let Ok(tool_call) = request.tool_call.clone() {
|
||||||
if tools_with_readonly_annotation.contains(&tool_call.name) {
|
if tools_with_readonly_annotation.contains(&tool_call.name) {
|
||||||
approved_tools.push((request.id.clone(), tool_call));
|
approved_tools.push((request.id.clone(), tool_call));
|
||||||
@@ -427,7 +469,7 @@ impl Agent for TruncateAgent {
|
|||||||
},
|
},
|
||||||
"chat" => {
|
"chat" => {
|
||||||
// Skip all tool calls in chat mode
|
// Skip all tool calls in chat mode
|
||||||
for request in &tool_requests {
|
for request in &remaining_requests {
|
||||||
message_tool_response = message_tool_response.with_tool_response(
|
message_tool_response = message_tool_response.with_tool_response(
|
||||||
request.id.clone(),
|
request.id.clone(),
|
||||||
Ok(vec![Content::text(
|
Ok(vec![Content::text(
|
||||||
@@ -449,7 +491,7 @@ impl Agent for TruncateAgent {
|
|||||||
}
|
}
|
||||||
// Process tool requests in parallel
|
// Process tool requests in parallel
|
||||||
let mut tool_futures = Vec::new();
|
let mut tool_futures = Vec::new();
|
||||||
for request in &tool_requests {
|
for request in &remaining_requests {
|
||||||
if let Ok(tool_call) = request.tool_call.clone() {
|
if let Ok(tool_call) = request.tool_call.clone() {
|
||||||
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
|
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
|
||||||
tool_futures.push(tool_future);
|
tool_futures.push(tool_future);
|
||||||
@@ -574,6 +616,12 @@ impl Agent for TruncateAgent {
|
|||||||
let capabilities = self.capabilities.lock().await;
|
let capabilities = self.capabilities.lock().await;
|
||||||
capabilities.provider()
|
capabilities.provider()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||||
|
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||||
|
tracing::error!("Failed to send tool result: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
register_agent!("truncate", TruncateAgent);
|
register_agent!("truncate", TruncateAgent);
|
||||||
|
|||||||
6
crates/goose/src/agents/types.rs
Normal file
6
crates/goose/src/agents/types.rs
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
use mcp_core::{Content, ToolResult};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::{mpsc, Mutex};
|
||||||
|
|
||||||
|
/// Type alias for the tool result channel receiver
|
||||||
|
pub type ToolResultReceiver = Arc<Mutex<mpsc::Receiver<(String, ToolResult<Vec<Content>>)>>>;
|
||||||
@@ -70,6 +70,14 @@ pub struct RedactedThinkingContent {
|
|||||||
pub data: String,
|
pub data: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct FrontendToolRequest {
|
||||||
|
pub id: String,
|
||||||
|
#[serde(with = "tool_result_serde")]
|
||||||
|
pub tool_call: ToolResult<ToolCall>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
|
||||||
/// Content passed inside a message, which can be both simple content and tool content
|
/// Content passed inside a message, which can be both simple content and tool content
|
||||||
#[serde(tag = "type", rename_all = "camelCase")]
|
#[serde(tag = "type", rename_all = "camelCase")]
|
||||||
@@ -79,6 +87,7 @@ pub enum MessageContent {
|
|||||||
ToolRequest(ToolRequest),
|
ToolRequest(ToolRequest),
|
||||||
ToolResponse(ToolResponse),
|
ToolResponse(ToolResponse),
|
||||||
ToolConfirmationRequest(ToolConfirmationRequest),
|
ToolConfirmationRequest(ToolConfirmationRequest),
|
||||||
|
FrontendToolRequest(FrontendToolRequest),
|
||||||
Thinking(ThinkingContent),
|
Thinking(ThinkingContent),
|
||||||
RedactedThinking(RedactedThinkingContent),
|
RedactedThinking(RedactedThinkingContent),
|
||||||
}
|
}
|
||||||
@@ -137,6 +146,13 @@ impl MessageContent {
|
|||||||
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
|
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
|
||||||
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
|
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn frontend_tool_request<S: Into<String>>(id: S, tool_call: ToolResult<ToolCall>) -> Self {
|
||||||
|
MessageContent::FrontendToolRequest(FrontendToolRequest {
|
||||||
|
id: id.into(),
|
||||||
|
tool_call,
|
||||||
|
})
|
||||||
|
}
|
||||||
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
|
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
|
||||||
if let MessageContent::ToolRequest(ref tool_request) = self {
|
if let MessageContent::ToolRequest(ref tool_request) = self {
|
||||||
Some(tool_request)
|
Some(tool_request)
|
||||||
@@ -320,6 +336,14 @@ impl Message {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_frontend_tool_request<S: Into<String>>(
|
||||||
|
self,
|
||||||
|
id: S,
|
||||||
|
tool_call: ToolResult<ToolCall>,
|
||||||
|
) -> Self {
|
||||||
|
self.with_content(MessageContent::frontend_tool_request(id, tool_call))
|
||||||
|
}
|
||||||
|
|
||||||
/// Add thinking content to the message
|
/// Add thinking content to the message
|
||||||
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
|
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -74,6 +74,16 @@ pub fn format_messages(messages: &[Message]) -> Vec<Value> {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
MessageContent::Image(_) => continue, // Anthropic doesn't support image content yet
|
MessageContent::Image(_) => continue, // Anthropic doesn't support image content yet
|
||||||
|
MessageContent::FrontendToolRequest(tool_request) => {
|
||||||
|
if let Ok(tool_call) = &tool_request.tool_call {
|
||||||
|
content.push(json!({
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": tool_request.id,
|
||||||
|
"name": tool_call.name,
|
||||||
|
"input": tool_call.arguments
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -57,6 +57,21 @@ pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::C
|
|||||||
}?;
|
}?;
|
||||||
bedrock::ContentBlock::ToolUse(tool_use)
|
bedrock::ContentBlock::ToolUse(tool_use)
|
||||||
}
|
}
|
||||||
|
MessageContent::FrontendToolRequest(tool_req) => {
|
||||||
|
let tool_use_id = tool_req.id.to_string();
|
||||||
|
let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() {
|
||||||
|
bedrock::ToolUseBlock::builder()
|
||||||
|
.tool_use_id(tool_use_id)
|
||||||
|
.name(call.name.to_string())
|
||||||
|
.input(to_bedrock_json(&call.arguments))
|
||||||
|
.build()
|
||||||
|
} else {
|
||||||
|
bedrock::ToolUseBlock::builder()
|
||||||
|
.tool_use_id(tool_use_id)
|
||||||
|
.build()
|
||||||
|
}?;
|
||||||
|
bedrock::ContentBlock::ToolUse(tool_use)
|
||||||
|
}
|
||||||
MessageContent::ToolResponse(tool_res) => {
|
MessageContent::ToolResponse(tool_res) => {
|
||||||
let content = match &tool_res.tool_result {
|
let content = match &tool_res.tool_result {
|
||||||
Ok(content) => Some(
|
Ok(content) => Some(
|
||||||
|
|||||||
@@ -188,6 +188,27 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
|||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
MessageContent::FrontendToolRequest(req) => {
|
||||||
|
// Frontend tool requests are converted to text messages
|
||||||
|
if let Ok(tool_call) = &req.tool_call {
|
||||||
|
content_array.push(json!({
|
||||||
|
"type": "text",
|
||||||
|
"text": format!(
|
||||||
|
"Frontend tool request: {} ({})",
|
||||||
|
tool_call.name,
|
||||||
|
serde_json::to_string_pretty(&tool_call.arguments).unwrap()
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
} else {
|
||||||
|
content_array.push(json!({
|
||||||
|
"type": "text",
|
||||||
|
"text": format!(
|
||||||
|
"Frontend tool request error: {}",
|
||||||
|
req.tool_call.as_ref().unwrap_err()
|
||||||
|
)
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -151,6 +151,32 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
|||||||
// Handle direct image content
|
// Handle direct image content
|
||||||
converted["content"] = json!([convert_image(image, image_format)]);
|
converted["content"] = json!([convert_image(image, image_format)]);
|
||||||
}
|
}
|
||||||
|
MessageContent::FrontendToolRequest(request) => match &request.tool_call {
|
||||||
|
Ok(tool_call) => {
|
||||||
|
let sanitized_name = sanitize_function_name(&tool_call.name);
|
||||||
|
let tool_calls = converted
|
||||||
|
.as_object_mut()
|
||||||
|
.unwrap()
|
||||||
|
.entry("tool_calls")
|
||||||
|
.or_insert(json!([]));
|
||||||
|
|
||||||
|
tool_calls.as_array_mut().unwrap().push(json!({
|
||||||
|
"id": request.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": sanitized_name,
|
||||||
|
"arguments": tool_call.arguments.to_string(),
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
output.push(json!({
|
||||||
|
"role": "tool",
|
||||||
|
"content": format!("Error: {}", e),
|
||||||
|
"tool_call_id": request.id
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
210
examples/frontend_tools.py
Normal file
210
examples/frontend_tools.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import httpx
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
GOOSE_HOST = "127.0.0.1"
|
||||||
|
GOOSE_PORT = "3001"
|
||||||
|
GOOSE_URL = f"http://{GOOSE_HOST}:{GOOSE_PORT}"
|
||||||
|
SECRET_KEY = "test" # Default development secret key
|
||||||
|
|
||||||
|
# A simple calculator tool definition
|
||||||
|
CALCULATOR_TOOL = {
|
||||||
|
"name": "calculator",
|
||||||
|
"description": "Perform basic arithmetic calculations",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"required": ["operation", "numbers"],
|
||||||
|
"properties": {
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["add", "subtract", "multiply", "divide"],
|
||||||
|
"description": "The arithmetic operation to perform",
|
||||||
|
},
|
||||||
|
"numbers": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "number"},
|
||||||
|
"description": "List of numbers to operate on in order",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Frontend extension configuration
|
||||||
|
FRONTEND_CONFIG = {
|
||||||
|
"name": "pythonclient",
|
||||||
|
"type": "frontend",
|
||||||
|
"tools": [CALCULATOR_TOOL],
|
||||||
|
"instructions": "A calculator extension that can perform basic arithmetic operations.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_agent() -> None:
|
||||||
|
"""Initialize the agent with our frontend tool."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# First create the agent
|
||||||
|
response = await client.post(
|
||||||
|
f"{GOOSE_URL}/agent",
|
||||||
|
json={"provider": "databricks", "model": "goose"},
|
||||||
|
headers={"X-Secret-Key": SECRET_KEY},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
print("Successfully created agent")
|
||||||
|
|
||||||
|
# Then add our frontend extension
|
||||||
|
response = await client.post(
|
||||||
|
f"{GOOSE_URL}/extensions/add",
|
||||||
|
json=FRONTEND_CONFIG,
|
||||||
|
headers={"X-Secret-Key": SECRET_KEY},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
print("Successfully added calculator extension")
|
||||||
|
|
||||||
|
|
||||||
|
def execute_calculator(args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
|
"""Execute the calculator tool with the given arguments."""
|
||||||
|
operation = args["operation"]
|
||||||
|
numbers = args["numbers"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = None
|
||||||
|
if operation == "add":
|
||||||
|
result = sum(numbers)
|
||||||
|
elif operation == "subtract":
|
||||||
|
result = numbers[0] - sum(numbers[1:])
|
||||||
|
elif operation == "multiply":
|
||||||
|
result = 1
|
||||||
|
for n in numbers:
|
||||||
|
result *= n
|
||||||
|
elif operation == "divide":
|
||||||
|
result = numbers[0]
|
||||||
|
for n in numbers[1:]:
|
||||||
|
result /= n
|
||||||
|
|
||||||
|
# Return properly structured Content::Text variant
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": str(result),
|
||||||
|
"annotations": None, # Required field in Rust struct
|
||||||
|
}
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"Error: {str(e)}",
|
||||||
|
"annotations": None, # Required field in Rust struct
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None:
|
||||||
|
"""Submit the tool execution result back to Goose.
|
||||||
|
|
||||||
|
The result should be a list of Content variants (Text, Image, or Resource).
|
||||||
|
Each Content variant has a type tag and appropriate fields.
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"id": tool_id,
|
||||||
|
"result": {
|
||||||
|
"Ok": result # Result enum variant with single key for success case
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
with httpx.Client(timeout=2.0) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{GOOSE_URL}/tool_result",
|
||||||
|
json=payload,
|
||||||
|
headers={"X-Secret-Key": SECRET_KEY},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_loop() -> None:
|
||||||
|
"""Main chat loop that handles the conversation with Goose."""
|
||||||
|
session_id = "test-session"
|
||||||
|
|
||||||
|
# Use a client with a longer timeout for streaming
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
# Get user input
|
||||||
|
user_message = input("\nYou: ")
|
||||||
|
if user_message.lower() in ["exit", "quit"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create the message object
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"created": int(datetime.now().timestamp()),
|
||||||
|
"content": [{"type": "text", "text": user_message}],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Send to /reply endpoint
|
||||||
|
payload = {
|
||||||
|
"messages": [message],
|
||||||
|
"session_id": session_id,
|
||||||
|
"session_working_dir": os.getcwd(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process the stream of responses
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"{GOOSE_URL}/reply",
|
||||||
|
json=payload,
|
||||||
|
headers={
|
||||||
|
"X-Secret-Key": SECRET_KEY,
|
||||||
|
"Accept": "text/event-stream",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
) as stream:
|
||||||
|
async for line in stream.aiter_lines():
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle SSE format
|
||||||
|
if line.startswith("data: "):
|
||||||
|
line = line[6:] # Remove "data: " prefix
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(f"Failed to parse line: {line}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if payload["type"] == "Finish":
|
||||||
|
break
|
||||||
|
|
||||||
|
message = payload["message"]
|
||||||
|
# Handle different message types
|
||||||
|
for content in message.get("content", []):
|
||||||
|
if content["type"] == "text":
|
||||||
|
print(f"\nGoose: {content['text']}")
|
||||||
|
elif content["type"] == "frontendToolRequest":
|
||||||
|
# Execute the tool and submit results
|
||||||
|
tool_call = content["toolCall"]["value"]
|
||||||
|
print(f"Calculator: {tool_call}")
|
||||||
|
# Execute the tool
|
||||||
|
result = execute_calculator(tool_call["arguments"])
|
||||||
|
|
||||||
|
# Submit the result
|
||||||
|
submit_tool_result(content["id"], result)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
try:
|
||||||
|
# Initialize the agent with our tool
|
||||||
|
await setup_agent()
|
||||||
|
|
||||||
|
# Start the chat loop
|
||||||
|
await chat_loop()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
raise # Re-raise to see full traceback
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
Reference in New Issue
Block a user