refactor: remove agent flavours, move provider to Agent (#2091)

This commit is contained in:
Salman Mohammed
2025-04-09 15:02:47 -04:00
committed by GitHub
parent a8cbd81c61
commit 513d5c8f5a
30 changed files with 1297 additions and 2272 deletions

View File

@@ -3,7 +3,6 @@ use clap::{Args, Parser, Subcommand};
use goose::config::Config;
use crate::commands::agent_version::AgentCommand;
use crate::commands::bench::agent_generator;
use crate::commands::configure::handle_configure;
use crate::commands::info::handle_info;
@@ -279,9 +278,6 @@ enum Command {
builtin: Vec<String>,
},
/// List available agent versions
Agents(AgentCommand),
/// Update the Goose CLI version
#[command(about = "Update the goose CLI version")]
Update {
@@ -417,10 +413,6 @@ pub async fn cli() -> Result<()> {
return Ok(());
}
Some(Command::Agents(cmd)) => {
cmd.run()?;
return Ok(());
}
Some(Command::Update {
canary,
reconfigure,

View File

@@ -1,33 +0,0 @@
use anyhow::Result;
use clap::Args;
use goose::agents::AgentFactory;
use std::fmt::Write;
/// CLI subcommand that lists the agent versions goose can run.
/// Takes no arguments or flags; all data comes from `AgentFactory`.
#[derive(Args)]
pub struct AgentCommand {}
impl AgentCommand {
    /// Print every available agent version to stdout.
    ///
    /// Each line carries two independent markers:
    /// - a leading `*` when the version is the one currently configured,
    /// - a trailing ` (default)` when it is the factory default.
    ///
    /// The listing is assembled in a buffer first so a formatting error
    /// surfaces as an `Err` before anything is printed.
    pub fn run(&self) -> Result<()> {
        let default_version = AgentFactory::default_version();
        let configured_version = AgentFactory::configured_version();

        let mut listing = String::new();
        writeln!(listing, "Available agent versions:")?;

        for version in AgentFactory::available_versions() {
            // (is_default, is_configured) — all four combinations spelled out
            // so each output line matches the expected format exactly.
            match (version == default_version, version == configured_version) {
                (true, true) => writeln!(listing, "* {} (default)", version)?,
                (true, false) => writeln!(listing, " {} (default)", version)?,
                (false, true) => writeln!(listing, "* {}", version)?,
                (false, false) => writeln!(listing, " {}", version)?,
            }
        }

        print!("{}", listing);
        Ok(())
    }
}

View File

@@ -3,7 +3,8 @@ use console::style;
use goose::agents::{extension::Envs, ExtensionConfig};
use goose::config::extensions::name_to_key;
use goose::config::{
Config, ConfigError, ExperimentManager, ExtensionEntry, ExtensionManager, PermissionManager,
Config, ConfigError, ExperimentManager, ExtensionConfigManager, ExtensionEntry,
PermissionManager,
};
use goose::message::Message;
use goose::providers::{create, providers};
@@ -63,7 +64,7 @@ pub async fn handle_configure() -> Result<(), Box<dyn Error>> {
);
// Since we are setting up for the first time, we'll also enable the developer system
// This operation is best-effort and errors are ignored
ExtensionManager::set(ExtensionEntry {
ExtensionConfigManager::set(ExtensionEntry {
enabled: true,
config: ExtensionConfig::Builtin {
name: "developer".to_string(),
@@ -392,7 +393,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
/// Configure extensions that can be used with goose
/// Dialog for toggling which extensions are enabled/disabled
pub fn toggle_extensions_dialog() -> Result<(), Box<dyn Error>> {
let extensions = ExtensionManager::get_all()?;
let extensions = ExtensionConfigManager::get_all()?;
if extensions.is_empty() {
cliclack::outro(
@@ -430,7 +431,7 @@ pub fn toggle_extensions_dialog() -> Result<(), Box<dyn Error>> {
// Update enabled status for each extension
for name in extension_status.iter().map(|(name, _)| name) {
ExtensionManager::set_enabled(
ExtensionConfigManager::set_enabled(
&name_to_key(name),
selected.iter().any(|s| s.as_str() == name),
)?;
@@ -502,7 +503,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
let display_name = get_display_name(&extension);
ExtensionManager::set(ExtensionEntry {
ExtensionConfigManager::set(ExtensionEntry {
enabled: true,
config: ExtensionConfig::Builtin {
name: extension.clone(),
@@ -514,7 +515,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
cliclack::outro(format!("Enabled {} extension", style(extension).green()))?;
}
"stdio" => {
let extensions = ExtensionManager::get_all_names()?;
let extensions = ExtensionConfigManager::get_all_names()?;
let name: String = cliclack::input("What would you like to call this extension?")
.placeholder("my-extension")
.validate(move |input: &String| {
@@ -590,7 +591,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
}
}
ExtensionManager::set(ExtensionEntry {
ExtensionConfigManager::set(ExtensionEntry {
enabled: true,
config: ExtensionConfig::Stdio {
name: name.clone(),
@@ -605,7 +606,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
cliclack::outro(format!("Added {} extension", style(name).green()))?;
}
"sse" => {
let extensions = ExtensionManager::get_all_names()?;
let extensions = ExtensionConfigManager::get_all_names()?;
let name: String = cliclack::input("What would you like to call this extension?")
.placeholder("my-remote-extension")
.validate(move |input: &String| {
@@ -677,7 +678,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
}
}
ExtensionManager::set(ExtensionEntry {
ExtensionConfigManager::set(ExtensionEntry {
enabled: true,
config: ExtensionConfig::Sse {
name: name.clone(),
@@ -697,7 +698,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
}
pub fn remove_extension_dialog() -> Result<(), Box<dyn Error>> {
let extensions = ExtensionManager::get_all()?;
let extensions = ExtensionConfigManager::get_all()?;
// Create a list of extension names and their enabled status
let extension_status: Vec<(String, bool)> = extensions
@@ -739,7 +740,7 @@ pub fn remove_extension_dialog() -> Result<(), Box<dyn Error>> {
.interact()?;
for name in selected {
ExtensionManager::remove(&name_to_key(name))?;
ExtensionConfigManager::remove(&name_to_key(name))?;
let mut permission_manager = PermissionManager::default();
permission_manager.remove_extension(&name_to_key(name));
cliclack::outro(format!("Removed {} extension", style(name).green()))?;

View File

@@ -1,4 +1,3 @@
pub mod agent_version;
pub mod bench;
pub mod configure;
pub mod info;

View File

@@ -1,7 +1,7 @@
use console::style;
use goose::agents::extension::ExtensionError;
use goose::agents::AgentFactory;
use goose::config::{Config, ExtensionManager};
use goose::agents::Agent;
use goose::config::{Config, ExtensionConfigManager};
use goose::session;
use goose::session::Identifier;
use mcp_client::transport::Error as McpClientError;
@@ -33,8 +33,7 @@ pub async fn build_session(
goose::providers::create(&provider_name, model_config).expect("Failed to create provider");
// Create the agent
let mut agent = AgentFactory::create(&AgentFactory::configured_version(), provider)
.expect("Failed to create agent");
let mut agent = Agent::new(provider);
// Handle session file resolution and resuming
let session_file = if resume {
@@ -93,7 +92,7 @@ pub async fn build_session(
// Setup extensions for the agent
// Extensions need to be added after the session is created because we change directory when resuming a session
for extension in ExtensionManager::get_all().expect("should load extensions") {
for extension in ExtensionConfigManager::get_all().expect("should load extensions") {
if extension.enabled {
let config = extension.config.clone();
agent

View File

@@ -38,7 +38,7 @@ pub enum RunMode {
}
pub struct Session {
agent: Box<dyn Agent>,
agent: Agent,
messages: Vec<Message>,
session_file: PathBuf,
// Cache for completion data - using std::sync for thread safety without async
@@ -76,7 +76,7 @@ pub enum PlannerResponseType {
/// question.
pub async fn classify_planner_response(
message_text: String,
provider: Arc<Box<dyn Provider>>,
provider: Arc<dyn Provider>,
) -> Result<PlannerResponseType> {
let prompt = format!("The text below is the output from an AI model which can either provide a plan or list of clarifying questions. Based on the text below, decide if the output is a \"plan\" or \"clarifying questions\".\n---\n{message_text}");
@@ -101,7 +101,7 @@ pub async fn classify_planner_response(
}
impl Session {
pub fn new(agent: Box<dyn Agent>, session_file: PathBuf, debug: bool) -> Self {
pub fn new(agent: Agent, session_file: PathBuf, debug: bool) -> Self {
let messages = match session::read_messages(&session_file) {
Ok(msgs) => msgs,
Err(e) => {
@@ -278,7 +278,7 @@ impl Session {
async fn process_message(&mut self, message: String) -> Result<()> {
self.messages.push(Message::user().with_text(&message));
// Get the provider from the agent for description generation
let provider = self.agent.provider().await;
let provider = self.agent.provider();
// Persist messages with provider for automatic description generation
session::persist_messages(&self.session_file, &self.messages, Some(provider)).await?;
@@ -350,7 +350,7 @@ impl Session {
self.messages.push(Message::user().with_text(&content));
// Get the provider from the agent for description generation
let provider = self.agent.provider().await;
let provider = self.agent.provider();
// Persist messages with provider for automatic description generation
session::persist_messages(
@@ -535,7 +535,7 @@ impl Session {
async fn plan_with_reasoner_model(
&mut self,
plan_messages: Vec<Message>,
reasoner: Box<dyn Provider + Send + Sync>,
reasoner: Arc<dyn Provider>,
) -> Result<(), anyhow::Error> {
let plan_prompt = self.agent.get_plan_prompt().await?;
output::show_thinking();
@@ -543,7 +543,7 @@ impl Session {
output::render_message(&plan_response, self.debug);
output::hide_thinking();
let planner_response_type =
classify_planner_response(plan_response.as_concat_text(), self.agent.provider().await)
classify_planner_response(plan_response.as_concat_text(), self.agent.provider())
.await?;
match planner_response_type {
@@ -857,7 +857,7 @@ impl Session {
}
}
fn get_reasoner() -> Result<Box<dyn Provider + Send + Sync>, anyhow::Error> {
fn get_reasoner() -> Result<Arc<dyn Provider>, anyhow::Error> {
use goose::model::ModelConfig;
use goose::providers::create;

View File

@@ -5,11 +5,10 @@ use axum::{
routing::{get, post},
Json, Router,
};
use goose::{agents::AgentFactory, config::PermissionManager, model::ModelConfig, providers};
use goose::{
agents::{capabilities::get_parameter_names, extension::ToolInfo},
config::Config,
};
use goose::agents::{extension::ToolInfo, extension_manager::get_parameter_names};
use goose::config::Config;
use goose::config::PermissionManager;
use goose::{agents::Agent, model::ModelConfig, providers};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
@@ -32,7 +31,6 @@ struct ExtendPromptResponse {
#[derive(Deserialize)]
struct CreateAgentRequest {
version: Option<String>,
provider: String,
model: Option<String>,
}
@@ -70,8 +68,8 @@ pub struct GetToolsQuery {
}
async fn get_versions() -> Json<VersionsResponse> {
let versions = AgentFactory::available_versions();
let default_version = AgentFactory::default_version().to_string();
let versions = ["goose".to_string()];
let default_version = "goose".to_string();
Json(VersionsResponse {
available_versions: versions.iter().map(|v| v.to_string()).collect(),
@@ -136,11 +134,8 @@ async fn create_agent(
let provider =
providers::create(&payload.provider, model_config).expect("Failed to create provider");
let version = payload
.version
.unwrap_or_else(|| AgentFactory::default_version().to_string());
let new_agent = AgentFactory::create(&version, provider).expect("Failed to create agent");
let version = String::from("goose");
let new_agent = Agent::new(provider);
let mut agent = state.agent.write().await;
*agent = Some(new_agent);

View File

@@ -8,7 +8,7 @@ use axum::{
use goose::agents::ExtensionConfig;
use goose::config::extensions::name_to_key;
use goose::config::Config;
use goose::config::{ExtensionEntry, ExtensionManager};
use goose::config::{ExtensionConfigManager, ExtensionEntry};
use goose::providers::base::ProviderMetadata;
use goose::providers::providers as get_providers;
use http::{HeaderMap, StatusCode};
@@ -184,7 +184,7 @@ pub async fn get_extensions(
) -> Result<Json<ExtensionResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
match ExtensionManager::get_all() {
match ExtensionConfigManager::get_all() {
Ok(extensions) => Ok(Json(ExtensionResponse { extensions })),
Err(err) => {
// Return UNPROCESSABLE_ENTITY only for DeserializeError, INTERNAL_SERVER_ERROR for everything else
@@ -219,13 +219,13 @@ pub async fn add_extension(
verify_secret_key(&headers, &state)?;
// Get existing extensions to check if this is an update
let extensions = ExtensionManager::get_all().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let extensions =
ExtensionConfigManager::get_all().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let key = name_to_key(&extension_query.name);
let is_update = extensions.iter().any(|e| e.config.key() == key);
// Use ExtensionManager to set the extension
match ExtensionManager::set(ExtensionEntry {
match ExtensionConfigManager::set(ExtensionEntry {
enabled: extension_query.enabled,
config: extension_query.config,
}) {
@@ -257,8 +257,7 @@ pub async fn remove_extension(
verify_secret_key(&headers, &state)?;
let key = name_to_key(&name);
// Use ExtensionManager to remove the extension
match ExtensionManager::remove(&key) {
match ExtensionConfigManager::remove(&key) {
Ok(_) => Ok(Json(format!("Removed extension {}", name))),
Err(_) => Err(StatusCode::NOT_FOUND),
}

View File

@@ -153,7 +153,7 @@ async fn handler(
};
// Get the provider first, before starting the reply stream
let provider = agent.provider().await;
let provider = agent.provider();
let mut stream = match agent
.reply(
@@ -294,7 +294,7 @@ async fn ask_handler(
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
// Get the provider first, before starting the reply stream
let provider = agent.provider().await;
let provider = agent.provider();
// Create a single message for the prompt
let messages = vec![Message::user().with_text(request.prompt)];
@@ -467,7 +467,7 @@ pub fn routes(state: AppState) -> Router {
mod tests {
use super::*;
use goose::{
agents::AgentFactory,
agents::Agent,
model::ModelConfig,
providers::{
base::{Provider, ProviderUsage, Usage},
@@ -518,10 +518,10 @@ mod tests {
async fn test_ask_endpoint() {
// Create a mock app state with mock provider
let mock_model_config = ModelConfig::new("test-model".to_string());
let mock_provider = Box::new(MockProvider {
let mock_provider = Arc::new(MockProvider {
model_config: mock_model_config,
});
let agent = AgentFactory::create("reference", mock_provider).unwrap();
let agent = Agent::new(mock_provider);
let state = AppState {
config: Arc::new(Mutex::new(HashMap::new())),
agent: Arc::new(RwLock::new(Some(agent))),

View File

@@ -9,7 +9,7 @@ use tokio::sync::{Mutex, RwLock};
#[allow(dead_code)]
#[derive(Clone)]
pub struct AppState {
pub agent: Arc<RwLock<Option<Box<dyn Agent>>>>,
pub agent: Arc<RwLock<Option<Agent>>>,
pub secret_key: String,
pub config: Arc<Mutex<HashMap<String, Value>>>,
}

View File

@@ -1,6 +1,8 @@
use std::sync::Arc;
use dotenv::dotenv;
use futures::StreamExt;
use goose::agents::{AgentFactory, ExtensionConfig};
use goose::agents::{Agent, ExtensionConfig};
use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT};
use goose::message::Message;
use goose::providers::databricks::DatabricksProvider;
@@ -10,10 +12,10 @@ async fn main() {
// Setup a model provider from env vars
let _ = dotenv();
let provider = Box::new(DatabricksProvider::default());
let provider = Arc::new(DatabricksProvider::default());
// Setup an agent with the developer extension
let mut agent = AgentFactory::create("reference", provider).expect("default should exist");
let mut agent = Agent::new(provider);
let config = ExtensionConfig::stdio(
"developer",

View File

@@ -1,76 +1,804 @@
use std::collections::HashMap;
use std::path::PathBuf;
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use super::extension::{ExtensionConfig, ExtensionResult};
use crate::providers::base::Provider;
use crate::session;
use crate::{message::Message, permission::PermissionConfirmation};
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, Tool, ToolResult};
use anyhow::{anyhow, Result};
use futures::stream::BoxStream;
/// Session configuration for an agent
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
/// Unique identifier for the session
pub id: session::Identifier,
/// Working directory for the session
pub working_dir: PathBuf,
use serde_json::Value;
use tokio::sync::{mpsc, Mutex};
use tracing::{debug, error, instrument, warn};
use crate::agents::extension::{ExtensionConfig, ExtensionResult, ToolInfo};
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
use crate::agents::types::ToolResultReceiver;
use crate::config::{Config, ExtensionConfigManager};
use crate::message::{Message, MessageContent, ToolRequest};
use crate::permission::{
detect_read_only_tools, Permission, PermissionConfirmation, ToolPermissionStore,
};
use crate::providers::base::Provider;
use crate::providers::errors::ProviderError;
use crate::providers::toolshim::{
augment_message_with_tool_calls, modify_system_prompt_for_tool_json, OllamaInterpreter,
};
use crate::session;
use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation};
use mcp_core::{
prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult,
};
use crate::agents::platform_tools::{
self, PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME,
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
};
use crate::agents::prompt_manager::PromptManager;
use crate::agents::types::SessionConfig;
use super::platform_tools::PLATFORM_ENABLE_EXTENSION_TOOL_NAME;
use super::types::FrontendTool;
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
/// The main goose Agent
pub struct Agent {
provider: Arc<dyn Provider>,
extension_manager: Mutex<ExtensionManager>,
frontend_tools: HashMap<String, FrontendTool>,
frontend_instructions: Option<String>,
prompt_manager: PromptManager,
token_counter: TokenCounter,
confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
tool_result_rx: ToolResultReceiver,
}
/// Core trait defining the behavior of an Agent
#[async_trait]
pub trait Agent: Send + Sync {
/// Create a stream that yields each message as it's generated by the agent
async fn reply(
impl Agent {
pub fn new(provider: Arc<dyn Provider>) -> Self {
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
// Create channels with buffer size 32 (adjust if needed)
let (confirm_tx, confirm_rx) = mpsc::channel(32);
let (tool_tx, tool_rx) = mpsc::channel(32);
Self {
provider,
extension_manager: Mutex::new(ExtensionManager::new()),
frontend_tools: HashMap::new(),
frontend_instructions: None,
prompt_manager: PromptManager::new(),
token_counter,
confirmation_tx: confirm_tx,
confirmation_rx: Mutex::new(confirm_rx),
tool_result_tx: tool_tx,
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
}
}
/// Get a reference count clone to the provider
pub fn provider(&self) -> Arc<dyn Provider> {
Arc::clone(&self.provider)
}
/// Check if a tool is a frontend tool
pub fn is_frontend_tool(&self, name: &str) -> bool {
self.frontend_tools.contains_key(name)
}
/// Get a reference to a frontend tool
pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> {
self.frontend_tools.get(name)
}
/// Get all tools from all clients with proper prefixing
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
let mut tools = self
.extension_manager
.lock()
.await
.get_prefixed_tools()
.await?;
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
for frontend_tool in self.frontend_tools.values() {
tools.push(frontend_tool.tool.clone());
}
Ok(tools)
}
/// Dispatch a single tool call to the appropriate client
#[instrument(skip(tool_call, extension_manager, request_id), fields(input, output))]
async fn create_tool_future(
extension_manager: &ExtensionManager,
tool_call: mcp_core::tool::ToolCall,
is_frontend_tool: bool,
request_id: String,
) -> (String, Result<Vec<Content>, ToolError>) {
let result = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
// Check if the tool is read_resource and handle it separately
extension_manager
.read_resource(tool_call.arguments.clone())
.await
} else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME {
extension_manager
.list_resources(tool_call.arguments.clone())
.await
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
extension_manager.search_available_extensions().await
} else if is_frontend_tool {
// For frontend tools, return an error indicating we need frontend execution
Err(ToolError::ExecutionError(
"Frontend tool execution required".to_string(),
))
} else {
extension_manager
.dispatch_tool_call(tool_call.clone())
.await
};
debug!(
"input" = serde_json::to_string(&tool_call).unwrap(),
"output" = serde_json::to_string(&result).unwrap(),
);
(request_id, result)
}
/// Truncates the messages to fit within the model's context window
/// Ensures the last message is a user message and removes tool call-response pairs
async fn truncate_messages(
&self,
messages: &mut Vec<Message>,
estimate_factor: f32,
system_prompt: &str,
tools: &mut Vec<Tool>,
) -> anyhow::Result<()> {
// Model's actual context limit
let context_limit = self.provider.get_model_config().context_limit();
// Our conservative estimate of the **target** context limit
// Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
let context_limit = (context_limit as f32 * estimate_factor) as usize;
// Take into account the system prompt, and our tools input and subtract that from the
// remaining context limit
let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());
// Check if system prompt + tools exceed our context limit
let remaining_tokens = context_limit
.checked_sub(system_prompt_token_count)
.and_then(|remaining| remaining.checked_sub(tools_token_count))
.ok_or_else(|| {
anyhow::anyhow!("System prompt and tools exceed estimated context limit")
})?;
let context_limit = remaining_tokens;
// Calculate current token count of each message, use count_chat_tokens to ensure we
// capture the full content of the message, include ToolRequests and ToolResponses
let mut token_counts: Vec<usize> = messages
.iter()
.map(|msg| {
self.token_counter
.count_chat_tokens("", std::slice::from_ref(msg), &[])
})
.collect();
truncate_messages(
messages,
&mut token_counts,
context_limit,
&OldestFirstTruncation,
)
}
async fn enable_extension(
extension_manager: &mut ExtensionManager,
extension_name: String,
request_id: String,
) -> (String, Result<Vec<Content>, ToolError>) {
let config = match ExtensionConfigManager::get_config_by_name(&extension_name) {
Ok(Some(config)) => config,
Ok(None) => {
return (
request_id,
Err(ToolError::ExecutionError(format!(
"Extension '{}' not found. Please check the extension name and try again.",
extension_name
))),
)
}
Err(e) => {
return (
request_id,
Err(ToolError::ExecutionError(format!(
"Failed to get extension config: {}",
e
))),
)
}
};
let result = extension_manager
.add_extension(config)
.await
.map(|_| {
vec![Content::text(format!(
"The extension '{}' has been installed successfully",
extension_name
))]
})
.map_err(|e| ToolError::ExecutionError(e.to_string()));
(request_id, result)
}
pub async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
match &extension {
ExtensionConfig::Frontend {
name: _,
tools,
instructions,
} => {
// For frontend tools, just store them in the frontend_tools map
for tool in tools {
let frontend_tool = FrontendTool {
name: tool.name.clone(),
tool: tool.clone(),
};
self.frontend_tools.insert(tool.name.clone(), frontend_tool);
}
// Store instructions if provided, using "frontend" as the key
if let Some(instructions) = instructions {
self.frontend_instructions = Some(instructions.clone());
} else {
// Default frontend instructions if none provided
self.frontend_instructions = Some(
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string(),
);
}
}
_ => {
let mut extension_manager = self.extension_manager.lock().await;
let _ = extension_manager.add_extension(extension).await;
}
};
Ok(())
}
pub async fn list_tools(&self) -> Vec<Tool> {
let mut extension_manager = self.extension_manager.lock().await;
extension_manager
.get_prefixed_tools()
.await
.unwrap_or_default()
}
pub async fn remove_extension(&mut self, name: &str) {
let mut extension_manager = self.extension_manager.lock().await;
extension_manager
.remove_extension(name)
.await
.expect("Failed to remove extension");
}
pub async fn list_extensions(&self) -> Vec<String> {
let extension_manager = self.extension_manager.lock().await;
extension_manager
.list_extensions()
.await
.expect("Failed to list extensions")
}
/// Handle a confirmation response for a tool request
pub async fn handle_confirmation(
&self,
request_id: String,
confirmation: PermissionConfirmation,
) {
if let Err(e) = self.confirmation_tx.send((request_id, confirmation)).await {
error!("Failed to send confirmation: {}", e);
}
}
#[instrument(skip(self, messages, session), fields(user_message))]
pub async fn reply(
&self,
messages: &[Message],
session: Option<SessionConfig>,
) -> Result<BoxStream<'_, Result<Message>>>;
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
let mut messages = messages.to_vec();
let reply_span = tracing::Span::current();
let mut extension_manager = self.extension_manager.lock().await;
let mut tools = extension_manager.get_prefixed_tools().await?;
let mut truncation_attempt: usize = 0;
/// Add a new MCP client to the agent
async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()>;
// Load settings from config
let config = Config::global();
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
/// Remove an extension by name
async fn remove_extension(&mut self, name: &str);
// we add in the 2 resource tools if any extensions support resources
// TODO: make sure there is no collision with another extension's tool name
if extension_manager.supports_resources() {
tools.push(platform_tools::read_resource_tool());
tools.push(platform_tools::list_resources_tool());
}
tools.push(platform_tools::search_available_extensions_tool());
tools.push(platform_tools::enable_extension_tool());
/// List all extensions
// TODO this needs to also include status so we can tell if extensions are dropped
async fn list_extensions(&self) -> Vec<String>;
let (tools_with_readonly_annotation, tools_without_annotation): (Vec<String>, Vec<String>) =
tools.iter().fold((vec![], vec![]), |mut acc, tool| {
match &tool.annotations {
Some(annotations) => {
if annotations.read_only_hint {
acc.0.push(tool.name.clone());
}
}
None => {
acc.1.push(tool.name.clone());
}
}
acc
});
/// List the tools this agent has access to
async fn list_tools(&self) -> Vec<Tool>;
let config = self.provider.get_model_config();
let extensions_info = extension_manager.get_extensions_info().await;
let mut system_prompt = self
.prompt_manager
.build_system_prompt(extensions_info, self.frontend_instructions.clone());
let mut toolshim_tools = vec![];
if config.toolshim {
// If tool interpretation is enabled, modify the system prompt to instruct to return JSON tool requests
system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
// make a copy of tools before empty
toolshim_tools = tools.clone();
// pass empty tools vector to provider completion since toolshim will handle tool calls instead
tools = vec![];
}
/// Pass through a JSON-RPC request to a specific extension
async fn passthrough(&self, extension: &str, request: Value) -> ExtensionResult<Value>;
// Set the user_message field in the span instead of creating a new event
if let Some(content) = messages
.last()
.and_then(|msg| msg.content.first())
.and_then(|c| c.as_text())
{
debug!("user_message" = &content);
}
/// Add custom text to be included in the system prompt
async fn extend_system_prompt(&mut self, extension: String);
Ok(Box::pin(async_stream::try_stream! {
let _reply_guard = reply_span.enter();
loop {
match self.provider().complete(
&system_prompt,
&messages,
&tools,
).await {
Ok((mut response, usage)) => {
// Post-process / structure the response only if tool interpretation is enabled
if config.toolshim {
let interpreter = OllamaInterpreter::new()
.map_err(|e| anyhow::anyhow!("Failed to create OllamaInterpreter: {}", e))?;
/// Handle a confirmation response for a tool request
async fn handle_confirmation(&self, request_id: String, confirmation: PermissionConfirmation);
response = augment_message_with_tool_calls(&interpreter, response, &toolshim_tools).await?;
}
/// Override the system prompt with custom text
async fn override_system_prompt(&mut self, template: String);
// record usage for the session in the session file
if let Some(session) = session.clone() {
// TODO: track session_id in langfuse tracing
let session_file = session::get_path(session.id);
let mut metadata = session::read_metadata(&session_file)?;
metadata.working_dir = session.working_dir;
metadata.total_tokens = usage.usage.total_tokens;
metadata.input_tokens = usage.usage.input_tokens;
metadata.output_tokens = usage.usage.output_tokens;
// The message count is the number of messages in the session + 1 for the response
// The message count does not include the tool response till next iteration
metadata.message_count = messages.len() + 1;
session::update_metadata(&session_file, &metadata).await?;
}
/// Lists all prompts from all extensions
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>>;
// Reset truncation attempt
truncation_attempt = 0;
/// Get a prompt result with the given name and arguments
/// Returns the prompt text that would be used as user input
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult>;
// Yield the assistant's response, but filter out frontend tool requests that we'll process separately
let filtered_response = Message {
role: response.role.clone(),
created: response.created,
content: response.content.iter().filter(|c| {
if let MessageContent::ToolRequest(req) = c {
// Only filter out frontend tool requests
if let Ok(tool_call) = &req.tool_call {
return !self.is_frontend_tool(&tool_call.name);
}
}
true
}).cloned().collect(),
};
yield filtered_response.clone();
/// Get the plan prompt, which will be used with the planner (reasoner) model
async fn get_plan_prompt(&self) -> anyhow::Result<String>;
tokio::task::yield_now().await;
/// Get a reference to the provider used by this agent
async fn provider(&self) -> Arc<Box<dyn Provider>>;
// First collect any tool requests
let tool_requests: Vec<&ToolRequest> = response.content
.iter()
.filter_map(|content| content.as_tool_request())
.collect();
/// Handle a tool result from the frontend
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>);
if tool_requests.is_empty() {
break;
}
// Process tool requests depending on goose_mode
let mut message_tool_response = Message::user();
// First handle any frontend tool requests
let mut remaining_requests = Vec::new();
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
if self.is_frontend_tool(&tool_call.name) {
// Send frontend tool request and wait for response
yield Message::assistant().with_frontend_tool_request(
request.id.clone(),
Ok(tool_call.clone())
);
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
message_tool_response = message_tool_response.with_tool_response(id, result);
}
} else {
remaining_requests.push(request);
}
} else {
remaining_requests.push(request);
}
}
// Split tool requests into enable_extension and others
let (enable_extension_requests, non_enable_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
.into_iter()
.partition(|req| {
req.tool_call.as_ref()
.map(|call| call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME)
.unwrap_or(false)
});
let (search_extension_requests, _non_search_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
.into_iter()
.partition(|req| {
req.tool_call.as_ref()
.map(|call| call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME)
.unwrap_or(false)
});
// Clone goose_mode once before the match to avoid move issues
let mode = goose_mode.clone();
// If there are install extension requests, always require confirmation
// or if goose_mode is approve or smart_approve, check permissions for all tools
if !enable_extension_requests.is_empty() || mode.as_str() == "approve" || mode.as_str() == "smart_approve" {
let mut needs_confirmation = Vec::<&ToolRequest>::new();
let mut approved_tools = Vec::new();
let mut llm_detect_candidates = Vec::<&ToolRequest>::new();
let mut detected_read_only_tools = Vec::new();
// If approve mode or smart approve mode, check permissions for all tools
if mode.as_str() == "approve" || mode.as_str() == "smart_approve" {
let store = ToolPermissionStore::load()?;
for request in &non_enable_extension_requests {
if let Ok(tool_call) = request.tool_call.clone() {
// Regular permission checking for other tools
if tools_with_readonly_annotation.contains(&tool_call.name) {
approved_tools.push((request.id.clone(), tool_call));
} else if let Some(allowed) = store.check_permission(request) {
if allowed {
// Instead of executing immediately, collect approved tools
approved_tools.push((request.id.clone(), tool_call));
} else {
// If the tool doesn't have any annotation, we can use llm-as-a-judge to check permission.
if tools_without_annotation.contains(&tool_call.name) {
llm_detect_candidates.push(request);
}
needs_confirmation.push(request);
}
} else {
if tools_without_annotation.contains(&tool_call.name) {
llm_detect_candidates.push(request);
}
needs_confirmation.push(request);
}
}
}
}
// Only check read-only status for tools needing confirmation
if !llm_detect_candidates.is_empty() && mode == "smart_approve" {
detected_read_only_tools = detect_read_only_tools(self.provider(), llm_detect_candidates.clone()).await;
// Remove install extensions from read-only tools
if !enable_extension_requests.is_empty() {
detected_read_only_tools.retain(|tool_name| {
!enable_extension_requests.iter().any(|req| {
req.tool_call.as_ref()
.map(|call| call.name == *tool_name)
.unwrap_or(false)
})
});
}
}
// Handle pre-approved and read-only tools in parallel
let mut tool_futures = Vec::new();
let mut install_results = Vec::new();
// Handle install extension requests
for request in &enable_extension_requests {
if let Ok(tool_call) = request.tool_call.clone() {
let confirmation = Message::user().with_enable_extension_request(
request.id.clone(),
tool_call.arguments.get("extension_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string()
);
yield confirmation;
let mut rx = self.confirmation_rx.lock().await;
while let Some((req_id, extension_confirmation)) = rx.recv().await {
if req_id == request.id {
if extension_confirmation.permission == Permission::AllowOnce || extension_confirmation.permission == Permission::AlwaysAllow {
let extension_name = tool_call.arguments.get("extension_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let install_result = Self::enable_extension(&mut extension_manager, extension_name, request.id.clone()).await;
install_results.push(install_result);
}
break;
}
}
}
}
// Process read-only tools
for request in &needs_confirmation {
if let Ok(tool_call) = request.tool_call.clone() {
let is_frontend_tool = self.is_frontend_tool(&tool_call.name);
// Skip confirmation if the tool_call.name is in the read_only_tools list
if detected_read_only_tools.contains(&tool_call.name) {
let tool_future = Self::create_tool_future(&extension_manager, tool_call, is_frontend_tool, request.id.clone());
tool_futures.push(tool_future);
} else {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
);
yield confirmation;
// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
while let Some((req_id, tool_confirmation)) = rx.recv().await {
if req_id == request.id {
let confirmed = tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow;
if confirmed {
// Add this tool call to the futures collection
let tool_future = Self::create_tool_future(&extension_manager, tool_call, is_frontend_tool, request.id.clone());
tool_futures.push(tool_future);
} else {
// User declined - add declined response
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text(
"The user has declined to run this tool. \
DO NOT attempt to call this tool again. \
If there are no alternative methods to proceed, clearly explain the situation and STOP.")]),
);
}
break; // Exit the loop once the matching `req_id` is found
}
}
}
}
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
// Check if any install results had errors before processing them
let all_successful = !install_results.iter().any(|(_, result)| result.is_err());
for (request_id, output) in install_results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output
);
}
// Update system prompt and tools if all installations were successful
if all_successful {
let extensions_info = extension_manager.get_extensions_info().await;
system_prompt = self.prompt_manager.build_system_prompt(extensions_info, self.frontend_instructions.clone());
tools = extension_manager.get_prefixed_tools().await?;
}
}
if mode.as_str() == "auto" || !search_extension_requests.is_empty() {
let mut tool_futures = Vec::new();
// Process non_enable_extension_requests and search_extension_requests without duplicates
let mut processed_ids = HashSet::new();
for request in non_enable_extension_requests.iter().chain(search_extension_requests.iter()) {
if processed_ids.insert(request.id.clone()) {
if let Ok(tool_call) = request.tool_call.clone() {
let is_frontend_tool = self.is_frontend_tool(&tool_call.name);
let tool_future = Self::create_tool_future(&extension_manager, tool_call, is_frontend_tool, request.id.clone());
tool_futures.push(tool_future);
}
}
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
}
if mode.as_str() == "chat" {
// Skip all tool calls in chat mode
// Skip search extension requests since they were already processed
let non_search_non_enable_extension_requests = non_enable_extension_requests.iter()
.filter(|req| {
if let Ok(tool_call) = &req.tool_call {
tool_call.name != PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME
} else {
true
}
});
for request in non_search_non_enable_extension_requests {
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text(
"Let the user know the tool call was skipped in Goose chat mode. \
DO NOT apologize for skipping the tool call. DO NOT say sorry. \
Provide an explanation of what the tool call would do, structured as a \
plan for the user. Again, DO NOT apologize. \
**Example Plan:**\n \
1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \
2. **Outline Steps** - Break down the steps.\n \
If needed, adjust the explanation based on user preferences or questions."
)]),
);
}
}
yield message_tool_response.clone();
messages.push(response);
messages.push(message_tool_response);
},
Err(ProviderError::ContextLengthExceeded(_)) => {
if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
// Create an error message & terminate the stream
// the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
// at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
break;
}
truncation_attempt += 1;
warn!("Context length exceeded. Truncation Attempt: {}/{}.", truncation_attempt, MAX_TRUNCATION_ATTEMPTS);
// Decay the estimate factor as we make more truncation attempts
// Estimate factor decays like this over time: 0.9, 0.81, 0.729, ...
let estimate_factor: f32 = ESTIMATE_FACTOR_DECAY.powi(truncation_attempt as i32);
// release the lock before truncation to prevent deadlock
drop(extension_manager);
if let Err(err) = self.truncate_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
break;
}
// Re-acquire the lock
extension_manager = self.extension_manager.lock().await;
// Retry the loop after truncation
continue;
},
Err(e) => {
// Create an error message & terminate the stream
error!("Error: {}", e);
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
break;
}
}
// Yield control back to the scheduler to prevent blocking
tokio::task::yield_now().await;
}
}))
}
/// Extend the system prompt with one line of additional instruction
///
/// The instruction is handed to the prompt manager; presumably it is
/// appended when the system prompt is next rebuilt — confirm against
/// `PromptManager::add_system_prompt_extra`.
pub async fn extend_system_prompt(&mut self, instruction: String) {
    self.prompt_manager.add_system_prompt_extra(instruction);
}
/// Override the system prompt with a custom template
///
/// Replaces the default system prompt template entirely; the template is
/// stored on the prompt manager and used for subsequent prompt builds.
pub async fn override_system_prompt(&mut self, template: String) {
    self.prompt_manager.set_system_prompt_override(template);
}
/// List every prompt advertised by the loaded extensions, keyed by
/// extension name.
///
/// # Panics
///
/// Panics if the extension manager fails to enumerate prompts.
pub async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
    let manager = self.extension_manager.lock().await;
    let listing = manager.list_prompts().await;
    listing.expect("Failed to list prompts")
}
/// Fetch a prompt by name, searching all loaded extensions for the one
/// that advertises a prompt with that name, and render it with the given
/// arguments.
///
/// # Errors
///
/// Returns an error if the prompt listing fails, if no extension owns a
/// prompt called `name`, or if the owning extension fails to render it.
pub async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
    let manager = self.extension_manager.lock().await;

    // Enumerate prompts so we can discover which extension owns `name`.
    let prompts = manager
        .list_prompts()
        .await
        .map_err(|e| anyhow!("Failed to list prompts: {}", e))?;

    // The owning extension is the first whose prompt list contains `name`.
    let owner = prompts.iter().find_map(|(extension, prompt_list)| {
        prompt_list
            .iter()
            .any(|p| p.name == name)
            .then_some(extension)
    });

    match owner {
        Some(extension) => manager
            .get_prompt(extension, name, arguments)
            .await
            .map_err(|e| anyhow!("Failed to get prompt: {}", e)),
        None => Err(anyhow!("Prompt '{}' not found", name)),
    }
}
/// Build the planning prompt used with the planner (reasoner) model.
///
/// Summarises every prefixed tool (name, description, parameter names)
/// and feeds those summaries into the extension manager's planning
/// template.
///
/// # Errors
///
/// Returns an error if the tool list cannot be fetched.
pub async fn get_plan_prompt(&self) -> anyhow::Result<String> {
    let mut manager = self.extension_manager.lock().await;

    // Collect a lightweight summary for each tool available to the agent.
    let mut tool_summaries = Vec::new();
    for tool in manager.get_prefixed_tools().await? {
        tool_summaries.push(ToolInfo::new(
            &tool.name,
            &tool.description,
            get_parameter_names(&tool),
            None,
        ));
    }

    Ok(manager.get_planning_prompt(tool_summaries).await)
}
pub async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
if let Err(e) = self.tool_result_tx.send((id, result)).await {
tracing::error!("Failed to send tool result: {}", e);
}
}
}

View File

@@ -262,9 +262,9 @@ impl std::fmt::Display for ExtensionConfig {
/// Information about the extension used for building prompts
#[derive(Clone, Debug, Serialize)]
pub struct ExtensionInfo {
name: String,
instructions: String,
has_resources: bool,
pub name: String,
pub instructions: String,
pub has_resources: bool,
}
impl ExtensionInfo {

View File

@@ -8,12 +8,11 @@ use std::sync::Arc;
use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::{debug, instrument};
use tracing::debug;
use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
use crate::config::{Config, ExtensionManager};
use crate::config::ExtensionConfigManager;
use crate::prompt_template;
use crate::providers::base::Provider;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
@@ -26,22 +25,11 @@ static DEFAULT_TIMESTAMP: LazyLock<DateTime<Utc>> =
type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>;
/// A frontend tool that will be executed by the frontend rather than an extension
#[derive(Clone)]
pub struct FrontendTool {
pub name: String,
pub tool: Tool,
}
/// Manages MCP clients and their interactions
pub struct Capabilities {
/// Manages Goose extensions / MCP clients and their interactions
pub struct ExtensionManager {
clients: HashMap<String, McpClientBox>,
frontend_tools: HashMap<String, FrontendTool>,
instructions: HashMap<String, String>,
resource_capable_extensions: HashSet<String>,
provider: Arc<Box<dyn Provider>>,
system_prompt_override: Option<String>,
system_prompt_extensions: Vec<String>,
}
/// A flattened representation of a resource used by the agent to prepare inference
@@ -99,17 +87,19 @@ pub fn get_parameter_names(tool: &Tool) -> Vec<String> {
.unwrap_or_default()
}
impl Capabilities {
/// Create a new Capabilities with the specified provider
pub fn new(provider: Box<dyn Provider>) -> Self {
impl Default for ExtensionManager {
fn default() -> Self {
Self::new()
}
}
impl ExtensionManager {
/// Create a new ExtensionManager instance
pub fn new() -> Self {
Self {
clients: HashMap::new(),
frontend_tools: HashMap::new(),
instructions: HashMap::new(),
resource_capable_extensions: HashSet::new(),
provider: Arc::new(provider),
system_prompt_override: None,
system_prompt_extensions: Vec::new(),
}
}
@@ -121,134 +111,106 @@ impl Capabilities {
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
let sanitized_name = normalize(config.key().to_string());
match &config {
ExtensionConfig::Frontend {
name: _,
tools,
instructions,
let mut client: Box<dyn McpClientTrait> = match &config {
ExtensionConfig::Sse {
uri, envs, timeout, ..
} => {
// For frontend tools, just store them in the frontend_tools map
for tool in tools {
let frontend_tool = FrontendTool {
name: tool.name.clone(),
tool: tool.clone(),
};
self.frontend_tools.insert(tool.name.clone(), frontend_tool);
}
// Store instructions if provided, using "frontend" as the key
if let Some(instructions) = instructions {
self.instructions
.insert("frontend".to_string(), instructions.clone());
}
Ok(())
let transport = SseTransport::new(uri, envs.get_env());
let handle = transport.start().await?;
let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
Box::new(McpClient::new(service))
}
_ => {
let mut client: Box<dyn McpClientTrait> = match &config {
ExtensionConfig::Sse {
uri, envs, timeout, ..
} => {
let transport = SseTransport::new(uri, envs.get_env());
let handle = transport.start().await?;
let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
Box::new(McpClient::new(service))
}
ExtensionConfig::Stdio {
cmd,
args,
envs,
timeout,
..
} => {
let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env());
let handle = transport.start().await?;
let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
Box::new(McpClient::new(service))
}
ExtensionConfig::Builtin {
name,
display_name: _,
timeout,
} => {
// For builtin extensions, we run the current executable with mcp and extension name
let cmd = std::env::current_exe()
.expect("should find the current executable")
.to_str()
.expect("should resolve executable to string path")
.to_string();
let transport = StdioTransport::new(
&cmd,
vec!["mcp".to_string(), name.clone()],
HashMap::new(),
);
let handle = transport.start().await?;
let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
Box::new(McpClient::new(service))
}
_ => unreachable!(),
};
// Initialize the client with default capabilities
let info = ClientInfo {
name: "goose".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
};
let capabilities = ClientCapabilities::default();
let init_result = client
.initialize(info, capabilities)
.await
.map_err(|e| ExtensionError::Initialization(config.clone(), e))?;
// Store instructions if provided
if let Some(instructions) = init_result.instructions {
self.instructions
.insert(sanitized_name.clone(), instructions);
}
// if the server is capable if resources we track it
if init_result.capabilities.resources.is_some() {
self.resource_capable_extensions
.insert(sanitized_name.clone());
}
// Store the client using the provided name
self.clients
.insert(sanitized_name.clone(), Arc::new(Mutex::new(client)));
Ok(())
ExtensionConfig::Stdio {
cmd,
args,
envs,
timeout,
..
} => {
let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env());
let handle = transport.start().await?;
let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
Box::new(McpClient::new(service))
}
ExtensionConfig::Builtin {
name,
display_name: _,
timeout,
} => {
// For builtin extensions, we run the current executable with mcp and extension name
let cmd = std::env::current_exe()
.expect("should find the current executable")
.to_str()
.expect("should resolve executable to string path")
.to_string();
let transport = StdioTransport::new(
&cmd,
vec!["mcp".to_string(), name.clone()],
HashMap::new(),
);
let handle = transport.start().await?;
let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
Box::new(McpClient::new(service))
}
_ => unreachable!(),
};
// Initialize the client with default capabilities
let info = ClientInfo {
name: "goose".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
};
let capabilities = ClientCapabilities::default();
let init_result = client
.initialize(info, capabilities)
.await
.map_err(|e| ExtensionError::Initialization(config.clone(), e))?;
// Store instructions if provided
if let Some(instructions) = init_result.instructions {
self.instructions
.insert(sanitized_name.clone(), instructions);
}
// if the server is capable if resources we track it
if init_result.capabilities.resources.is_some() {
self.resource_capable_extensions
.insert(sanitized_name.clone());
}
// Store the client using the provided name
self.clients
.insert(sanitized_name.clone(), Arc::new(Mutex::new(client)));
Ok(())
}
/// Add a system prompt extension
pub fn add_system_prompt_extension(&mut self, extension: String) {
self.system_prompt_extensions.push(extension);
}
/// Override the system prompt with custom text
pub fn set_system_prompt_override(&mut self, template: String) {
self.system_prompt_override = Some(template);
}
/// Get a reference to the provider
pub fn provider(&self) -> Arc<Box<dyn Provider>> {
Arc::clone(&self.provider)
/// Get extensions info
pub async fn get_extensions_info(&self) -> Vec<ExtensionInfo> {
self.clients
.keys()
.map(|name| {
let instructions = self.instructions.get(name).cloned().unwrap_or_default();
let has_resources = self.resource_capable_extensions.contains(name);
ExtensionInfo::new(name, &instructions, has_resources)
})
.collect()
}
/// Get aggregated usage statistics
@@ -269,11 +231,6 @@ impl Capabilities {
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
let mut tools = Vec::new();
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
for frontend_tool in self.frontend_tools.values() {
tools.push(frontend_tool.tool.clone());
}
// Add tools from MCP extensions with prefixing
for (name, client) in &self.clients {
let client_guard = client.lock().await;
@@ -297,6 +254,7 @@ impl Capabilities {
client_tools = client_guard.list_tools(client_tools.next_cursor).await?;
}
}
Ok(tools)
}
@@ -353,76 +311,6 @@ impl Capabilities {
prompt_template::render_global_file("plan.md", &context).expect("Prompt should render")
}
/// Get the extension prompt including client instructions
pub async fn get_system_prompt(&self) -> String {
let mut context: HashMap<&str, Value> = HashMap::new();
let mut extensions_info: Vec<ExtensionInfo> = self
.clients
.keys()
.map(|name| {
let instructions = self.instructions.get(name).cloned().unwrap_or_default();
let has_resources = self.resource_capable_extensions.contains(name);
ExtensionInfo::new(name, &instructions, has_resources)
})
.collect();
// Add frontend tools as a special extension if any exist
if !self.frontend_tools.is_empty() {
let name = "frontend";
let instructions = self.instructions.get(name).cloned().unwrap_or_else(||
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string()
);
extensions_info.push(ExtensionInfo::new(name, &instructions, false));
}
context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
context.insert("current_date_time", Value::String(current_date_time));
// Conditionally load the override prompt or the global system prompt
let base_prompt = if let Some(override_prompt) = &self.system_prompt_override {
prompt_template::render_inline_once(override_prompt, &context)
.expect("Prompt should render")
} else {
prompt_template::render_global_file("system.md", &context)
.expect("Prompt should render")
};
let mut system_prompt_extensions = self.system_prompt_extensions.clone();
let config = Config::global();
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
if goose_mode == "chat" {
system_prompt_extensions.push(
"Right now you are in the chat only mode, no access to any tool use and system."
.to_string(),
);
} else {
system_prompt_extensions
.push("Right now you are *NOT* in the chat only mode and have access to tool use and system.".to_string());
}
if system_prompt_extensions.is_empty() {
base_prompt
} else {
format!(
"{}\n\n# Additional Instructions:\n\n{}",
base_prompt,
system_prompt_extensions.join("\n\n")
)
}
}
/// Check if a tool is a frontend tool
pub fn is_frontend_tool(&self, name: &str) -> bool {
self.frontend_tools.contains_key(name)
}
/// Get a reference to a frontend tool
pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> {
self.frontend_tools.get(name)
}
/// Find and return a reference to the appropriate client for a tool call
fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(&str, McpClientBox)> {
self.clients
@@ -432,7 +320,7 @@ impl Capabilities {
}
// Function that gets executed for read_resource tool
async fn read_resource(&self, params: Value) -> Result<Vec<Content>, ToolError> {
pub async fn read_resource(&self, params: Value) -> Result<Vec<Content>, ToolError> {
let uri = params
.get("uri")
.and_then(|v| v.as_str())
@@ -544,7 +432,7 @@ impl Capabilities {
})
}
async fn list_resources(&self, params: Value) -> Result<Vec<Content>, ToolError> {
pub async fn list_resources(&self, params: Value) -> Result<Vec<Content>, ToolError> {
let extension = params.get("extension").and_then(|v| v.as_str());
match extension {
@@ -594,42 +482,26 @@ impl Capabilities {
}
}
/// Dispatch a single tool call to the appropriate client
#[instrument(skip(self, tool_call), fields(input, output))]
pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> ToolResult<Vec<Content>> {
let result = if tool_call.name == "platform__read_resource" {
// Check if the tool is read_resource and handle it separately
self.read_resource(tool_call.arguments.clone()).await
} else if tool_call.name == "platform__list_resources" {
self.list_resources(tool_call.arguments.clone()).await
} else if tool_call.name == "platform__search_available_extensions" {
self.search_available_extensions().await
} else if self.is_frontend_tool(&tool_call.name) {
// For frontend tools, return an error indicating we need frontend execution
Err(ToolError::ExecutionError(
"Frontend tool execution required".to_string(),
))
} else {
// Else, dispatch tool call based on the prefix naming convention
let (client_name, client) = self
.get_client_for_tool(&tool_call.name)
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
// Dispatch tool call based on the prefix naming convention
let (client_name, client) = self
.get_client_for_tool(&tool_call.name)
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
// rsplit returns the iterator in reverse, tool_name is then at 0
let tool_name = tool_call
.name
.strip_prefix(client_name)
.and_then(|s| s.strip_prefix("__"))
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
// rsplit returns the iterator in reverse, tool_name is then at 0
let tool_name = tool_call
.name
.strip_prefix(client_name)
.and_then(|s| s.strip_prefix("__"))
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
let client_guard = client.lock().await;
let client_guard = client.lock().await;
client_guard
.call_tool(tool_name, tool_call.clone().arguments)
.await
.map(|result| result.content)
.map_err(|e| ToolError::ExecutionError(e.to_string()))
};
let result = client_guard
.call_tool(tool_name, tool_call.clone().arguments)
.await
.map(|result| result.content)
.map_err(|e| ToolError::ExecutionError(e.to_string()));
debug!(
"input" = serde_json::to_string(&tool_call).unwrap(),
@@ -725,7 +597,7 @@ impl Capabilities {
// First get disabled extensions from current config
let mut disabled_extensions: Vec<String> = vec![];
for extension in ExtensionManager::get_all().expect("should load extensions") {
for extension in ExtensionConfigManager::get_all().expect("should load extensions") {
if !extension.enabled {
let config = extension.config.clone();
let description = match &config {
@@ -775,10 +647,6 @@ impl Capabilities {
#[cfg(test)]
mod tests {
use super::*;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
use crate::providers::errors::ProviderError;
use mcp_client::client::Error;
use mcp_client::client::McpClientTrait;
use mcp_core::protocol::{
@@ -787,35 +655,6 @@ mod tests {
};
use serde_json::json;
// Mock Provider implementation for testing
#[derive(Clone)]
struct MockProvider {
model_config: ModelConfig,
}
#[async_trait::async_trait]
impl Provider for MockProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::empty()
}
fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
async fn complete(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message::assistant().with_text("Mock response"),
ProviderUsage::new("mock".to_string(), Usage::default()),
))
}
}
struct MockClient {}
#[async_trait::async_trait]
@@ -871,74 +710,68 @@ mod tests {
#[test]
fn test_get_client_for_tool() {
let mock_model_config =
ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
let mut capabilities = Capabilities::new(Box::new(MockProvider {
model_config: mock_model_config,
}));
let mut extension_manager = ExtensionManager::new();
// Add some mock clients
capabilities.clients.insert(
extension_manager.clients.insert(
normalize("test_client".to_string()),
Arc::new(Mutex::new(Box::new(MockClient {}))),
);
capabilities.clients.insert(
extension_manager.clients.insert(
normalize("__client".to_string()),
Arc::new(Mutex::new(Box::new(MockClient {}))),
);
capabilities.clients.insert(
extension_manager.clients.insert(
normalize("__cli__ent__".to_string()),
Arc::new(Mutex::new(Box::new(MockClient {}))),
);
capabilities.clients.insert(
extension_manager.clients.insert(
normalize("client 🚀".to_string()),
Arc::new(Mutex::new(Box::new(MockClient {}))),
);
// Test basic case
assert!(capabilities
assert!(extension_manager
.get_client_for_tool("test_client__tool")
.is_some());
// Test leading underscores
assert!(capabilities.get_client_for_tool("__client__tool").is_some());
assert!(extension_manager
.get_client_for_tool("__client__tool")
.is_some());
// Test multiple underscores in client name, and ending with __
assert!(capabilities
assert!(extension_manager
.get_client_for_tool("__cli__ent____tool")
.is_some());
// Test unicode in tool name, "client 🚀" should become "client_"
assert!(capabilities.get_client_for_tool("client___tool").is_some());
assert!(extension_manager
.get_client_for_tool("client___tool")
.is_some());
}
#[tokio::test]
async fn test_dispatch_tool_call() {
// test that dispatch_tool_call parses out the sanitized name correctly, and extracts
// tool_names
let mock_model_config =
ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
let mut capabilities = Capabilities::new(Box::new(MockProvider {
model_config: mock_model_config,
}));
let mut extension_manager = ExtensionManager::new();
// Add some mock clients
capabilities.clients.insert(
extension_manager.clients.insert(
normalize("test_client".to_string()),
Arc::new(Mutex::new(Box::new(MockClient {}))),
);
capabilities.clients.insert(
extension_manager.clients.insert(
normalize("__cli__ent__".to_string()),
Arc::new(Mutex::new(Box::new(MockClient {}))),
);
capabilities.clients.insert(
extension_manager.clients.insert(
normalize("client 🚀".to_string()),
Arc::new(Mutex::new(Box::new(MockClient {}))),
);
@@ -949,7 +782,7 @@ mod tests {
arguments: json!({}),
};
let result = capabilities.dispatch_tool_call(tool_call).await;
let result = extension_manager.dispatch_tool_call(tool_call).await;
assert!(result.is_ok());
let tool_call = ToolCall {
@@ -957,7 +790,7 @@ mod tests {
arguments: json!({}),
};
let result = capabilities.dispatch_tool_call(tool_call).await;
let result = extension_manager.dispatch_tool_call(tool_call).await;
assert!(result.is_ok());
// verify a multiple underscores dispatch
@@ -966,7 +799,7 @@ mod tests {
arguments: json!({}),
};
let result = capabilities.dispatch_tool_call(tool_call).await;
let result = extension_manager.dispatch_tool_call(tool_call).await;
assert!(result.is_ok());
// Test unicode in tool name, "client 🚀" should become "client_"
@@ -975,7 +808,7 @@ mod tests {
arguments: json!({}),
};
let result = capabilities.dispatch_tool_call(tool_call).await;
let result = extension_manager.dispatch_tool_call(tool_call).await;
assert!(result.is_ok());
let tool_call = ToolCall {
@@ -983,7 +816,7 @@ mod tests {
arguments: json!({}),
};
let result = capabilities.dispatch_tool_call(tool_call).await;
let result = extension_manager.dispatch_tool_call(tool_call).await;
assert!(result.is_ok());
// this should error out, specifically for an ToolError::ExecutionError
@@ -992,7 +825,9 @@ mod tests {
arguments: json!({}),
};
let result = capabilities.dispatch_tool_call(invalid_tool_call).await;
let result = extension_manager
.dispatch_tool_call(invalid_tool_call)
.await;
assert!(matches!(
result.err().unwrap(),
ToolError::ExecutionError(_)
@@ -1005,7 +840,9 @@ mod tests {
arguments: json!({}),
};
let result = capabilities.dispatch_tool_call(invalid_tool_call).await;
let result = extension_manager
.dispatch_tool_call(invalid_tool_call)
.await;
assert!(matches!(result.err().unwrap(), ToolError::NotFound(_)));
}
}

View File

@@ -1,77 +0,0 @@
use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};
pub use super::Agent;
use crate::config::Config;
use crate::providers::base::Provider;
/// Boxed closure that builds an [`Agent`] around a provider; one constructor
/// is stored per registered version string.
type AgentConstructor = Box<dyn Fn(Box<dyn Provider>) -> Box<dyn Agent> + Send + Sync>;

// Use std::sync::RwLock for interior mutability
static AGENT_REGISTRY: OnceLock<RwLock<HashMap<&'static str, AgentConstructor>>> = OnceLock::new();

/// Initialize the registry if it hasn't been initialized
fn registry() -> &'static RwLock<HashMap<&'static str, AgentConstructor>> {
    AGENT_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
}
/// Register a new agent version
/// Register a new agent version
///
/// Stores `constructor` under `version` so [`AgentFactory::create`] can build
/// the agent later.
///
/// NOTE(review): a poisoned registry lock is silently ignored here (the
/// registration is dropped), while `AgentFactory::create` panics on poison —
/// confirm this asymmetry is intended.
pub fn register_agent(
    version: &'static str,
    constructor: impl Fn(Box<dyn Provider>) -> Box<dyn Agent> + Send + Sync + 'static,
) {
    let registry = registry();
    if let Ok(mut map) = registry.write() {
        map.insert(version, Box::new(constructor));
    }
}
/// Factory for constructing registered agent implementations by version name.
pub struct AgentFactory;

impl AgentFactory {
    /// Instantiate the agent registered under `version`, wiring in `provider`.
    ///
    /// Returns `None` when no constructor was registered for that version.
    /// Panics if the registry lock is poisoned.
    pub fn create(version: &str, provider: Box<dyn Provider>) -> Option<Box<dyn Agent>> {
        let map = registry()
            .read()
            .expect("should be able to read the registry");
        map.get(version).map(|construct| construct(provider))
    }

    /// Every version name currently present in the registry.
    ///
    /// Yields an empty list if the registry lock is poisoned.
    pub fn available_versions() -> Vec<&'static str> {
        match registry().read() {
            Ok(map) => map.keys().copied().collect(),
            Err(_) => Vec::new(),
        }
    }

    /// The version selected via the `GOOSE_AGENT` config key, falling back to
    /// [`AgentFactory::default_version`] when the key is unset.
    pub fn configured_version() -> String {
        Config::global()
            .get_param::<String>("GOOSE_AGENT")
            .unwrap_or_else(|_| Self::default_version().to_string())
    }

    /// The version used when none is configured.
    pub fn default_version() -> &'static str {
        "truncate"
    }
}
/// Macro to help with agent registration
///
/// Expands to a `#[ctor::ctor]` function that runs at program start-up and
/// inserts `$agent_type` into the global registry under `$version`. The
/// function name is derived from the version via `paste` to keep each
/// registration unique.
#[macro_export]
macro_rules! register_agent {
    ($version:expr, $agent_type:ty) => {
        paste::paste! {
            #[ctor::ctor]
            #[allow(non_snake_case)]
            fn [<__register_agent_ $version>]() {
                $crate::agents::factory::register_agent($version, |provider| {
                    Box::new(<$agent_type>::new(provider))
                });
            }
        }
    };
}

View File

@@ -1,13 +1,12 @@
mod agent;
pub mod capabilities;
pub mod extension;
mod factory;
mod reference;
mod summarize;
mod truncate;
pub mod extension_manager;
pub mod platform_tools;
pub mod prompt_manager;
mod types;
pub use agent::{Agent, SessionConfig};
pub use capabilities::Capabilities;
pub use extension::{ExtensionConfig, ExtensionResult};
pub use factory::{register_agent, AgentFactory};
pub use agent::Agent;
pub use extension::ExtensionConfig;
pub use extension_manager::ExtensionManager;
pub use prompt_manager::PromptManager;
pub use types::{FrontendTool, SessionConfig};

View File

@@ -0,0 +1,112 @@
use indoc::indoc;
use mcp_core::tool::{Tool, ToolAnnotations};
use serde_json::json;
/// Tool name for reading a single resource from an extension.
pub const PLATFORM_READ_RESOURCE_TOOL_NAME: &str = "platform__read_resource";
/// Tool name for listing resources across one or all extensions.
pub const PLATFORM_LIST_RESOURCES_TOOL_NAME: &str = "platform__list_resources";
/// Tool name for discovering additional, not-yet-enabled extensions.
pub const PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME: &str =
    "platform__search_available_extensions";
/// Tool name for enabling an extension by name.
pub const PLATFORM_ENABLE_EXTENSION_TOOL_NAME: &str = "platform__enable_extension";
/// Build the `platform__read_resource` tool definition.
///
/// Requires a `uri` argument; `extension_name` is optional (all extensions
/// are searched when it is omitted). Annotated as read-only and
/// non-destructive.
pub fn read_resource_tool() -> Tool {
    Tool::new(
        PLATFORM_READ_RESOURCE_TOOL_NAME.to_string(),
        indoc! {r#"
            Read a resource from an extension.
            Resources allow extensions to share data that provide context to LLMs, such as
            files, database schemas, or application-specific information. This tool searches for the
            resource URI in the provided extension, and reads in the resource content. If no extension
            is provided, the tool will search all extensions for the resource.
        "#}.to_string(),
        json!({
            "type": "object",
            "required": ["uri"],
            "properties": {
                "uri": {"type": "string", "description": "Resource URI"},
                "extension_name": {"type": "string", "description": "Optional extension name"}
            }
        }),
        Some(ToolAnnotations {
            title: Some("Read a resource".to_string()),
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: false,
            open_world_hint: false,
        }),
    )
}
/// Build the `platform__list_resources` tool definition.
///
/// `extension_name` is optional; when omitted all extensions are listed.
/// Annotated as read-only and non-destructive.
pub fn list_resources_tool() -> Tool {
    Tool::new(
        PLATFORM_LIST_RESOURCES_TOOL_NAME.to_string(),
        indoc! {r#"
            List resources from an extension(s).
            Resources allow extensions to share data that provide context to LLMs, such as
            files, database schemas, or application-specific information. This tool lists resources
            in the provided extension, and returns a list for the user to browse. If no extension
            is provided, the tool will search all extensions for the resource.
        "#}
        .to_string(),
        json!({
            "type": "object",
            "properties": {
                "extension_name": {"type": "string", "description": "Optional extension name"}
            }
        }),
        Some(ToolAnnotations {
            title: Some("List resources".to_string()),
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: false,
            open_world_hint: false,
        }),
    )
}
/// Build the `platform__search_available_extensions` tool definition.
///
/// Takes no arguments. Annotated as read-only and non-destructive.
pub fn search_available_extensions_tool() -> Tool {
    Tool::new(
        PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME.to_string(),
        "Searches for additional extensions available to help complete tasks.
        Use this tool when you're unable to find a specific feature or functionality you need to complete your task, or when standard approaches aren't working.
        These extensions might provide the exact tools needed to solve your problem.
        If you find a relevant one, consider using your tools to enable it.".to_string(),
        json!({
            "type": "object",
            "required": [],
            "properties": {}
        }),
        Some(ToolAnnotations {
            title: Some("Discover extensions".to_string()),
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: false,
            open_world_hint: false,
        }),
    )
}
/// Build the `platform__enable_extension` tool definition.
///
/// Requires `extension_name`. Not read-only (it mutates agent state by
/// enabling an extension), but annotated as non-destructive.
pub fn enable_extension_tool() -> Tool {
    Tool::new(
        PLATFORM_ENABLE_EXTENSION_TOOL_NAME.to_string(),
        "Enable extensions to help complete tasks.
        Enable an extension by providing the extension name.
        "
        .to_string(),
        json!({
            "type": "object",
            "required": ["extension_name"],
            "properties": {
                "extension_name": {"type": "string", "description": "The name of the extension to enable"}
            }
        }),
        Some(ToolAnnotations {
            title: Some("Enable extensions".to_string()),
            read_only_hint: false,
            destructive_hint: false,
            idempotent_hint: false,
            open_world_hint: false,
        }),
    )
}

View File

@@ -0,0 +1,95 @@
use chrono::Utc;
use serde_json::Value;
use std::collections::HashMap;
use crate::agents::extension::ExtensionInfo;
use crate::{config::Config, prompt_template};
/// Assembles the system prompt sent to the model: a base template (or a full
/// override) plus any extra instructions appended at the end.
pub struct PromptManager {
    // When set, replaces the global `system.md` template entirely.
    system_prompt_override: Option<String>,
    // Instructions appended under a "# Additional Instructions" section.
    system_prompt_extras: Vec<String>,
}
impl Default for PromptManager {
fn default() -> Self {
PromptManager::new()
}
}
impl PromptManager {
    /// Create a manager with no override and no extra instructions.
    pub fn new() -> Self {
        PromptManager {
            system_prompt_override: None,
            system_prompt_extras: Vec::new(),
        }
    }

    /// Add an additional instruction to the system prompt
    pub fn add_system_prompt_extra(&mut self, instruction: String) {
        self.system_prompt_extras.push(instruction);
    }

    /// Override the system prompt with custom text
    pub fn set_system_prompt_override(&mut self, template: String) {
        self.system_prompt_override = Some(template);
    }

    /// Build the final system prompt
    ///
    /// * `extensions_info` extension information for each extension/MCP
    /// * `frontend_instructions` instructions for the "frontend" tool
    ///
    /// Renders the override template when one is set, otherwise the global
    /// `system.md` template, then appends the extras (plus a chat-mode note)
    /// under "# Additional Instructions". Panics if the template fails to
    /// render.
    pub fn build_system_prompt(
        &self,
        // Taken by value and mutated in place; the previous `.clone()` of an
        // already-owned parameter was redundant.
        mut extensions_info: Vec<ExtensionInfo>,
        frontend_instructions: Option<String>,
    ) -> String {
        let mut context: HashMap<&str, Value> = HashMap::new();

        // Add frontend instructions to extensions_info to simplify json rendering
        if let Some(frontend_instructions) = frontend_instructions {
            extensions_info.push(ExtensionInfo::new(
                "frontend",
                &frontend_instructions,
                false,
            ));
        }
        context.insert("extensions", serde_json::to_value(extensions_info).unwrap());

        let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
        context.insert("current_date_time", Value::String(current_date_time));

        // Conditionally load the override prompt or the global system prompt
        let base_prompt = if let Some(override_prompt) = &self.system_prompt_override {
            prompt_template::render_inline_once(override_prompt, &context)
                .expect("Prompt should render")
        } else {
            prompt_template::render_global_file("system.md", &context)
                .expect("Prompt should render")
        };

        let mut system_prompt_extras = self.system_prompt_extras.clone();
        let config = Config::global();
        // Lazy default: avoid allocating the fallback string when the
        // GOOSE_MODE param is actually set.
        let goose_mode = config
            .get_param("GOOSE_MODE")
            .unwrap_or_else(|_| "auto".to_string());
        if goose_mode == "chat" {
            system_prompt_extras.push(
                "Right now you are in the chat only mode, no access to any tool use and system."
                    .to_string(),
            );
        } else {
            system_prompt_extras
                .push("Right now you are *NOT* in the chat only mode and have access to tool use and system.".to_string());
        }

        if system_prompt_extras.is_empty() {
            base_prompt
        } else {
            format!(
                "{}\n\n# Additional Instructions:\n\n{}",
                base_prompt,
                system_prompt_extras.join("\n\n")
            )
        }
    }
}

View File

@@ -1,319 +0,0 @@
/// A simplified agent implementation used as a reference
/// It makes no attempt to handle context limits, and cannot read resources
use async_trait::async_trait;
use futures::stream::BoxStream;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tracing::{debug, instrument};
use super::agent::SessionConfig;
use super::capabilities::get_parameter_names;
use super::extension::ToolInfo;
use super::types::ToolResultReceiver;
use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::message::{Message, ToolRequest};
use crate::permission::PermissionConfirmation;
use crate::providers::base::Provider;
use crate::token_counter::TokenCounter;
use crate::{register_agent, session};
use anyhow::{anyhow, Result};
use indoc::indoc;
use mcp_core::tool::{Tool, ToolAnnotations};
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult};
use serde_json::{json, Value};
/// Reference implementation of an Agent
///
/// A minimal baseline: it makes no attempt to handle context limits and
/// cannot read resources.
pub struct ReferenceAgent {
    // Provider + extension state, locked so `&self` async methods can mutate it.
    capabilities: Mutex<Capabilities>,
    // Unused by this agent (hence the leading underscore); constructed for
    // parity with the other agent implementations.
    _token_counter: TokenCounter,
    // Channel pair used to hand frontend tool results back into `reply`.
    tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
    tool_result_rx: ToolResultReceiver,
}
impl ReferenceAgent {
    /// Build a reference agent on top of `provider`.
    pub fn new(provider: Box<dyn Provider>) -> Self {
        // The counter is created but unused by this agent (see the
        // underscore-prefixed field).
        let counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
        let (result_tx, result_rx) = mpsc::channel(32);
        Self {
            capabilities: Mutex::new(Capabilities::new(provider)),
            _token_counter: counter,
            tool_result_tx: result_tx,
            tool_result_rx: Arc::new(Mutex::new(result_rx)),
        }
    }
}
#[async_trait]
impl Agent for ReferenceAgent {
    /// Register a new extension (MCP client) with the agent's capabilities.
    async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
        let mut capabilities = self.capabilities.lock().await;
        capabilities.add_extension(extension).await
    }

    /// All tools exposed by registered extensions; empty on error.
    async fn list_tools(&self) -> Vec<Tool> {
        let mut capabilities = self.capabilities.lock().await;
        capabilities.get_prefixed_tools().await.unwrap_or_default()
    }

    /// Remove a previously added extension by name. Panics if removal fails.
    async fn remove_extension(&mut self, name: &str) {
        let mut capabilities = self.capabilities.lock().await;
        capabilities
            .remove_extension(name)
            .await
            .expect("Failed to remove extension");
    }

    /// Names of all registered extensions. Panics if the listing fails.
    async fn list_extensions(&self) -> Vec<String> {
        let capabilities = self.capabilities.lock().await;
        capabilities
            .list_extensions()
            .await
            .expect("Failed to list extensions")
    }

    /// Pass a raw request through to an extension. Not implemented: always
    /// returns `Value::Null`.
    async fn passthrough(&self, _extension: &str, _request: Value) -> ExtensionResult<Value> {
        // TODO implement
        Ok(Value::Null)
    }

    /// Handle a tool-confirmation response. Not implemented: this agent never
    /// asks for confirmation.
    async fn handle_confirmation(
        &self,
        _request_id: String,
        _confirmation: PermissionConfirmation,
    ) {
        // TODO implement
    }

    /// Run the main conversation loop: complete with the provider, yield the
    /// assistant response, dispatch any tool calls, append the results, and
    /// repeat until a response contains no tool requests.
    ///
    /// The capabilities lock is held for the lifetime of the returned stream.
    #[instrument(skip(self, messages, session), fields(user_message))]
    async fn reply(
        &self,
        messages: &[Message],
        session: Option<SessionConfig>,
    ) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
        let mut messages = messages.to_vec();
        let reply_span = tracing::Span::current();
        let mut capabilities = self.capabilities.lock().await;
        let mut tools = capabilities.get_prefixed_tools().await?;
        // we add in the read_resource tool by default
        // TODO: make sure there is no collision with another extension's tool name
        let read_resource_tool = Tool::new(
            "platform__read_resource".to_string(),
            indoc! {r#"
                Read a resource from an extension.
                Resources allow extensions to share data that provide context to LLMs, such as
                files, database schemas, or application-specific information. This tool searches for the
                resource URI in the provided extension, and reads in the resource content. If no extension
                is provided, the tool will search all extensions for the resource.
            "#}.to_string(),
            json!({
                "type": "object",
                "required": ["uri"],
                "properties": {
                    "uri": {"type": "string", "description": "Resource URI"},
                    "extension_name": {"type": "string", "description": "Optional extension name"}
                }
            }),
            Some(ToolAnnotations {
                title: Some("Read a resource".to_string()),
                read_only_hint: true,
                destructive_hint: false,
                idempotent_hint: false,
                open_world_hint: false,
            }),
        );

        let list_resources_tool = Tool::new(
            "platform__list_resources".to_string(),
            indoc! {r#"
                List resources from an extension(s).
                Resources allow extensions to share data that provide context to LLMs, such as
                files, database schemas, or application-specific information. This tool lists resources
                in the provided extension, and returns a list for the user to browse. If no extension
                is provided, the tool will search all extensions for the resource.
            "#}.to_string(),
            json!({
                "type": "object",
                "properties": {
                    "extension_name": {"type": "string", "description": "Optional extension name"}
                }
            }),
            Some(ToolAnnotations {
                title: Some("List resources".to_string()),
                read_only_hint: true,
                destructive_hint: false,
                idempotent_hint: false,
                open_world_hint: false,
            }),
        );

        // Only offer the resource tools when at least one extension supports resources.
        if capabilities.supports_resources() {
            tools.push(read_resource_tool);
            tools.push(list_resources_tool);
        }

        let system_prompt = capabilities.get_system_prompt().await;

        // Set the user_message field in the span instead of creating a new event
        if let Some(content) = messages
            .last()
            .and_then(|msg| msg.content.first())
            .and_then(|c| c.as_text())
        {
            debug!("user_message" = &content);
        }

        Ok(Box::pin(async_stream::try_stream! {
            let _reply_guard = reply_span.enter();
            loop {
                // Get completion from provider
                let (response, usage) = capabilities.provider().complete(
                    &system_prompt,
                    &messages,
                    &tools,
                ).await?;

                // record usage for the session in the session file
                if let Some(session) = session.clone() {
                    // TODO: track session_id in langfuse tracing
                    let session_file = session::get_path(session.id);
                    let mut metadata = session::read_metadata(&session_file)?;
                    metadata.working_dir = session.working_dir;
                    metadata.total_tokens = usage.usage.total_tokens;
                    metadata.input_tokens = usage.usage.input_tokens;
                    metadata.output_tokens = usage.usage.output_tokens;
                    // The message count is the number of messages in the session + 1 for the response
                    // The message count does not include the tool response till next iteration
                    metadata.message_count = messages.len() + 1;
                    session::update_metadata(&session_file, &metadata).await?;
                }

                // Yield the assistant's response
                yield response.clone();

                tokio::task::yield_now().await;

                // First collect any tool requests
                let tool_requests: Vec<&ToolRequest> = response.content
                    .iter()
                    .filter_map(|content| content.as_tool_request())
                    .collect();

                // No tool requests means the conversation turn is complete.
                if tool_requests.is_empty() {
                    break;
                }

                // Then dispatch each in parallel
                // NOTE(review): despite the comment above, dispatch is sequential —
                // each call is awaited inside the loop before the next starts.
                let mut message_tool_response = Message::user();
                for request in tool_requests {
                    if let Ok(tool_call) = &request.tool_call {
                        // Check if it's a frontend tool
                        if capabilities.is_frontend_tool(&tool_call.name) {
                            // Send frontend tool request and wait for response
                            yield Message::assistant().with_frontend_tool_request(
                                request.id.clone(),
                                request.tool_call.clone()
                            );

                            // Wait for the result using our channel
                            if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
                                message_tool_response = message_tool_response.with_tool_response(id, result);
                            }
                            continue;
                        }
                        // Handle regular tool calls
                        let result = capabilities.dispatch_tool_call(tool_call.clone()).await;
                        message_tool_response = message_tool_response.with_tool_response(
                            request.id.clone(),
                            result,
                        );
                    }
                }

                yield message_tool_response.clone();

                // Feed the assistant response and the tool results back in for
                // the next provider turn.
                messages.push(response);
                messages.push(message_tool_response);
            }
        }))
    }

    /// Append an extra instruction to the system prompt.
    async fn extend_system_prompt(&mut self, extension: String) {
        let mut capabilities = self.capabilities.lock().await;
        capabilities.add_system_prompt_extension(extension);
    }

    /// Replace the system prompt template entirely.
    async fn override_system_prompt(&mut self, template: String) {
        let mut capabilities = self.capabilities.lock().await;
        capabilities.set_system_prompt_override(template);
    }

    /// Prompts grouped by extension name. Panics if the listing fails.
    async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
        let capabilities = self.capabilities.lock().await;
        capabilities
            .list_prompts()
            .await
            .expect("Failed to list prompts")
    }

    /// Resolve `name` to whichever extension provides it, then fetch the
    /// rendered prompt. Errors if no extension has a prompt with that name.
    async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
        let capabilities = self.capabilities.lock().await;

        // First find which extension has this prompt
        let prompts = capabilities
            .list_prompts()
            .await
            .map_err(|e| anyhow!("Failed to list prompts: {}", e))?;

        if let Some(extension) = prompts
            .iter()
            .find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name))
            .map(|(extension, _)| extension)
        {
            return capabilities
                .get_prompt(extension, name, arguments)
                .await
                .map_err(|e| anyhow!("Failed to get prompt: {}", e));
        }

        Err(anyhow!("Prompt '{}' not found", name))
    }

    /// Render the planning prompt from the current tool inventory.
    async fn get_plan_prompt(&self) -> anyhow::Result<String> {
        let mut capabilities = self.capabilities.lock().await;
        let tools = capabilities.get_prefixed_tools().await?;
        let tools_info = tools
            .into_iter()
            .map(|tool| {
                ToolInfo::new(
                    &tool.name,
                    &tool.description,
                    get_parameter_names(&tool),
                    None,
                )
            })
            .collect();

        let plan_prompt = capabilities.get_planning_prompt(tools_info).await;

        Ok(plan_prompt)
    }

    /// Shared handle to the underlying model provider.
    async fn provider(&self) -> Arc<Box<dyn Provider>> {
        let capabilities = self.capabilities.lock().await;
        capabilities.provider()
    }

    /// Forward a frontend tool result into the channel `reply` is waiting on.
    /// Failures are logged, not propagated.
    async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
        if let Err(e) = self.tool_result_tx.send((id, result)).await {
            tracing::error!("Failed to send tool result: {}", e);
        }
    }
}
// Register this implementation under the "reference" version name at start-up.
register_agent!("reference", ReferenceAgent);

View File

@@ -1,519 +0,0 @@
/// A summarize agent that lets the model summarize the conversation when the history exceeds the
/// model's context limit. If the model fails to summarize, then it falls back to the legacy
/// truncation method. Still cannot read resources.
use async_trait::async_trait;
use futures::stream::BoxStream;
use mcp_core::tool::ToolAnnotations;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tracing::{debug, error, instrument, warn};
use super::agent::SessionConfig;
use super::capabilities::get_parameter_names;
use super::extension::ToolInfo;
use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::config::Config;
use crate::memory_condense::condense_messages;
use crate::message::{Message, ToolRequest};
use crate::permission::detect_read_only_tools;
use crate::permission::Permission;
use crate::permission::PermissionConfirmation;
use crate::providers::base::Provider;
use crate::providers::errors::ProviderError;
use crate::register_agent;
use crate::session;
use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation};
use anyhow::{anyhow, Result};
use indoc::indoc;
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolResult};
use serde_json::{json, Value};
// Maximum times to retry after a ContextLengthExceeded error before giving up.
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
// Multiplier applied per attempt to shrink the estimated context target
// (0.9, 0.81, 0.729, ...).
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
/// Summarize implementation of an Agent
///
/// Lets the model summarize the conversation when history exceeds the context
/// limit, falling back to the legacy truncation method when summarization
/// fails. Cannot read resources.
pub struct SummarizeAgent {
    // Provider + extension state, locked so `&self` async methods can mutate it.
    capabilities: Mutex<Capabilities>,
    // Counts tokens to estimate context usage for summarization/truncation.
    token_counter: TokenCounter,
    // Channel pair for tool-confirmation responses (used in "approve" mode).
    confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
    confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
    // Sender for frontend tool results.
    tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
}
impl SummarizeAgent {
    /// Build a summarize agent on top of `provider`.
    pub fn new(provider: Box<dyn Provider>) -> Self {
        let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
        // Create channels with buffer size 32 (adjust if needed)
        let (confirm_tx, confirm_rx) = mpsc::channel(32);
        // NOTE(review): the tool-result receiver is dropped immediately here,
        // so sends on `tool_result_tx` will fail unless a receiver is created
        // elsewhere — confirm this is intended.
        let (tool_tx, _tool_rx) = mpsc::channel(32);
        Self {
            capabilities: Mutex::new(Capabilities::new(provider)),
            token_counter,
            confirmation_tx: confirm_tx,
            confirmation_rx: Mutex::new(confirm_rx),
            tool_result_tx: tool_tx,
        }
    }

    /// Truncates the messages to fit within the model's context window
    /// Ensures the last message is a user message and removes tool call-response pairs
    ///
    /// First asks the model to condense the history; if that fails, falls
    /// back to oldest-first truncation. `estimate_factor` scales the model's
    /// context limit down to a conservative target. `messages` and
    /// `tools` are mutated in place. Errors if the system prompt plus tools
    /// alone exceed the estimated limit.
    async fn summarize_messages(
        &self,
        messages: &mut Vec<Message>,
        estimate_factor: f32,
        system_prompt: &str,
        tools: &mut Vec<Tool>,
    ) -> anyhow::Result<()> {
        // Model's actual context limit
        let context_limit = self
            .capabilities
            .lock()
            .await
            .provider()
            .get_model_config()
            .context_limit();

        // Our conservative estimate of the **target** context limit
        // Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
        let context_limit = (context_limit as f32 * estimate_factor) as usize;

        // Take into account the system prompt, and our tools input and subtract that from the
        // remaining context limit
        let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
        let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());

        // Check if system prompt + tools exceed our context limit
        let remaining_tokens = context_limit
            .checked_sub(system_prompt_token_count)
            .and_then(|remaining| remaining.checked_sub(tools_token_count))
            .ok_or_else(|| {
                anyhow::anyhow!("System prompt and tools exceed estimated context limit")
            })?;

        let context_limit = remaining_tokens;

        // Calculate current token count of each message, use count_chat_tokens to ensure we
        // capture the full content of the message, include ToolRequests and ToolResponses
        let mut token_counts: Vec<usize> = messages
            .iter()
            .map(|msg| {
                self.token_counter
                    .count_chat_tokens("", std::slice::from_ref(msg), &[])
            })
            .collect();

        let capabilities_guard = self.capabilities.lock().await;
        if condense_messages(
            &capabilities_guard,
            &self.token_counter,
            messages,
            &mut token_counts,
            context_limit,
        )
        .await
        .is_err()
        {
            // Fallback to the legacy truncator if the model fails to condense the messages.
            truncate_messages(
                messages,
                &mut token_counts,
                context_limit,
                &OldestFirstTruncation,
            )
        } else {
            Ok(())
        }
    }
}
#[async_trait]
impl Agent for SummarizeAgent {
async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
let mut capabilities = self.capabilities.lock().await;
capabilities.add_extension(extension).await
}
async fn list_tools(&self) -> Vec<Tool> {
let mut capabilities = self.capabilities.lock().await;
capabilities.get_prefixed_tools().await.unwrap_or_default()
}
async fn remove_extension(&mut self, name: &str) {
let mut capabilities = self.capabilities.lock().await;
capabilities
.remove_extension(name)
.await
.expect("Failed to remove extension");
}
async fn list_extensions(&self) -> Vec<String> {
let capabilities = self.capabilities.lock().await;
capabilities
.list_extensions()
.await
.expect("Failed to list extensions")
}
async fn passthrough(&self, _extension: &str, _request: Value) -> ExtensionResult<Value> {
// TODO implement
Ok(Value::Null)
}
/// Handle a confirmation response for a tool request
async fn handle_confirmation(&self, request_id: String, confirmation: PermissionConfirmation) {
if let Err(e) = self.confirmation_tx.send((request_id, confirmation)).await {
error!("Failed to send confirmation: {}", e);
}
}
#[instrument(skip(self, messages, session), fields(user_message))]
async fn reply(
&self,
messages: &[Message],
session: Option<SessionConfig>,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
let mut messages = messages.to_vec();
let reply_span = tracing::Span::current();
let mut capabilities = self.capabilities.lock().await;
let mut tools = capabilities.get_prefixed_tools().await?;
let mut truncation_attempt: usize = 0;
// Load settings from config
let config = Config::global();
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
// we add in the 2 resource tools if any extensions support resources
// TODO: make sure there is no collision with another extension's tool name
let read_resource_tool = Tool::new(
"platform__read_resource".to_string(),
indoc! {r#"
Read a resource from an extension.
Resources allow extensions to share data that provide context to LLMs, such as
files, database schemas, or application-specific information. This tool searches for the
resource URI in the provided extension, and reads in the resource content. If no extension
is provided, the tool will search all extensions for the resource.
"#}.to_string(),
json!({
"type": "object",
"required": ["uri"],
"properties": {
"uri": {"type": "string", "description": "Resource URI"},
"extension_name": {"type": "string", "description": "Optional extension name"}
}
}),
Some(ToolAnnotations {
title: Some("Read a resource".to_string()),
read_only_hint: true,
destructive_hint: false,
idempotent_hint: false,
open_world_hint: false,
}),
);
let list_resources_tool = Tool::new(
"platform__list_resources".to_string(),
indoc! {r#"
List resources from an extension(s).
Resources allow extensions to share data that provide context to LLMs, such as
files, database schemas, or application-specific information. This tool lists resources
in the provided extension, and returns a list for the user to browse. If no extension
is provided, the tool will search all extensions for the resource.
"#}.to_string(),
json!({
"type": "object",
"properties": {
"extension_name": {"type": "string", "description": "Optional extension name"}
}
}),
Some(ToolAnnotations {
title: Some("List resources".to_string()),
read_only_hint: true,
destructive_hint: false,
idempotent_hint: false,
open_world_hint: false,
}),
);
if capabilities.supports_resources() {
tools.push(read_resource_tool);
tools.push(list_resources_tool);
}
let system_prompt = capabilities.get_system_prompt().await;
// Set the user_message field in the span instead of creating a new event
if let Some(content) = messages
.last()
.and_then(|msg| msg.content.first())
.and_then(|c| c.as_text())
{
debug!("user_message" = &content);
}
Ok(Box::pin(async_stream::try_stream! {
let _reply_guard = reply_span.enter();
loop {
match capabilities.provider().complete(
&system_prompt,
&messages,
&tools,
).await {
Ok((response, usage)) => {
// record usage for the session in the session file
if let Some(session) = session.clone() {
// TODO: track session_id in langfuse tracing
let session_file = session::get_path(session.id);
let mut metadata = session::read_metadata(&session_file)?;
metadata.working_dir = session.working_dir;
metadata.total_tokens = usage.usage.total_tokens;
metadata.input_tokens = usage.usage.input_tokens;
metadata.output_tokens = usage.usage.output_tokens;
// The message count is the number of messages in the session + 1 for the response
// The message count does not include the tool response till next iteration
metadata.message_count = messages.len() + 1;
session::update_metadata(&session_file, &metadata).await?;
}
// Reset truncation attempt
truncation_attempt = 0;
// Yield the assistant's response
yield response.clone();
tokio::task::yield_now().await;
// First collect any tool requests
let tool_requests: Vec<&ToolRequest> = response.content
.iter()
.filter_map(|content| content.as_tool_request())
.collect();
if tool_requests.is_empty() {
break;
}
// Process tool requests depending on goose_mode
let mut message_tool_response = Message::user();
// Clone goose_mode once before the match to avoid move issues
let mode = goose_mode.clone();
match mode.as_str() {
"approve" => {
let read_only_tools = detect_read_only_tools(&capabilities, tool_requests.clone()).await;
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
// Skip confirmation if the tool_call.name is in the read_only_tools list
if read_only_tools.contains(&tool_call.name) {
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
} else {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
);
yield confirmation;
// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
// Loop the recv until we have a matched req_id due to potential duplicate messages.
while let Some((req_id, tool_confirmation)) = rx.recv().await {
if req_id == request.id {
if tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow {
// User approved - dispatch the tool call
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
} else {
// User declined - add declined response
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text("User declined to run this tool.")]),
);
}
break; // Exit the loop once the matching `req_id` is found
}
}
}
}
}
},
"chat" => {
// Skip all tool calls in chat mode
for request in &tool_requests {
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text(
"The following tool call was skipped in Goose chat mode. \
In chat mode, you cannot run tool calls, instead, you can \
only provide a detailed plan to the user. Provide an \
explanation of the proposed tool call as if it were a plan. \
Only if the user asks, provide a short explanation to the \
user that they could consider running the tool above on \
their own or with a different goose mode."
)]),
);
}
},
_ => {
if mode != "auto" {
warn!("Unknown GOOSE_MODE: {mode:?}. Defaulting to 'auto' mode.");
}
// Process tool requests in parallel
let mut tool_futures = Vec::new();
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
tool_futures.push(async {
let output = capabilities.dispatch_tool_call(tool_call).await;
(request.id.clone(), output)
});
}
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
}
}
yield message_tool_response.clone();
messages.push(response);
messages.push(message_tool_response);
},
Err(ProviderError::ContextLengthExceeded(_)) => {
if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
// Create an error message & terminate the stream
// the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
// at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
break;
}
truncation_attempt += 1;
warn!("Context length exceeded. Truncation Attempt: {}/{}.", truncation_attempt, MAX_TRUNCATION_ATTEMPTS);
// Decay the estimate factor as we make more truncation attempts
// Estimate factor decays like this over time: 0.9, 0.81, 0.729, ...
let estimate_factor: f32 = ESTIMATE_FACTOR_DECAY.powi(truncation_attempt as i32);
// release the lock before truncation to prevent deadlock
drop(capabilities);
if let Err(err) = self.summarize_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
break;
}
// Re-acquire the lock
capabilities = self.capabilities.lock().await;
// Retry the loop after truncation
continue;
},
Err(e) => {
// Create an error message & terminate the stream
error!("Error: {}", e);
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
break;
}
}
// Yield control back to the scheduler to prevent blocking
tokio::task::yield_now().await;
}
}))
}
/// Appends an extension-provided section to the agent's system prompt.
async fn extend_system_prompt(&mut self, extension: String) {
    self.capabilities
        .lock()
        .await
        .add_system_prompt_extension(extension);
}
/// Replaces the whole system prompt with a caller-supplied template.
async fn override_system_prompt(&mut self, template: String) {
    self.capabilities
        .lock()
        .await
        .set_system_prompt_override(template);
}
/// Returns every prompt advertised by the loaded extensions, keyed by extension name.
///
/// Panics if the underlying prompt listing fails.
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
    let guard = self.capabilities.lock().await;
    let listing = guard.list_prompts().await;
    listing.expect("Failed to list prompts")
}
/// Resolves a prompt by name across all extensions and renders it with `arguments`.
///
/// # Errors
/// Fails when the prompt listing cannot be fetched, when no extension
/// advertises `name`, or when rendering the prompt itself fails.
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
    let capabilities = self.capabilities.lock().await;
    let prompts = capabilities
        .list_prompts()
        .await
        .map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
    // Locate the extension that advertises a prompt with this name.
    let owner = prompts.iter().find_map(|(extension, prompt_list)| {
        prompt_list
            .iter()
            .any(|p| p.name == name)
            .then_some(extension)
    });
    match owner {
        Some(extension) => capabilities
            .get_prompt(extension, name, arguments)
            .await
            .map_err(|e| anyhow!("Failed to get prompt: {}", e)),
        None => Err(anyhow!("Prompt '{}' not found", name)),
    }
}
/// Builds the planning prompt from the currently available tools.
async fn get_plan_prompt(&self) -> anyhow::Result<String> {
    let mut capabilities = self.capabilities.lock().await;
    let prefixed_tools = capabilities.get_prefixed_tools().await?;
    // Summarize each tool as name/description/parameter names for the planner.
    let mut tools_info = Vec::new();
    for tool in prefixed_tools {
        tools_info.push(ToolInfo::new(
            &tool.name,
            &tool.description,
            get_parameter_names(&tool),
            None,
        ));
    }
    Ok(capabilities.get_planning_prompt(tools_info).await)
}
/// Returns a shared handle to the underlying LLM provider.
async fn provider(&self) -> Arc<Box<dyn Provider>> {
    self.capabilities.lock().await.provider()
}
/// Forwards a frontend-executed tool result back into the reply loop.
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
    let outcome = self.tool_result_tx.send((id, result)).await;
    if let Err(e) = outcome {
        tracing::error!("Failed to send tool result: {}", e);
    }
}
}
// Register this implementation under the "summarize" agent version name.
register_agent!("summarize", SummarizeAgent);

View File

@@ -1,804 +0,0 @@
/// A truncate agent that truncates the conversation history when it exceeds the model's context limit
/// It makes no attempt to handle context limits, and cannot read resources
use async_trait::async_trait;
use futures::stream::BoxStream;
use mcp_core::tool::{Tool, ToolAnnotations};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tracing::{debug, error, instrument, warn};
use super::agent::SessionConfig;
use super::extension::ToolInfo;
use super::types::ToolResultReceiver;
use super::Agent;
use crate::agents::capabilities::{get_parameter_names, Capabilities};
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::config::{Config, ExtensionManager};
use crate::message::{Message, MessageContent, ToolRequest};
use crate::permission::detect_read_only_tools;
use crate::permission::Permission;
use crate::permission::PermissionConfirmation;
use crate::permission::ToolPermissionStore;
use crate::providers::base::Provider;
use crate::providers::errors::ProviderError;
use crate::providers::toolshim::{
augment_message_with_tool_calls, modify_system_prompt_for_tool_json, OllamaInterpreter,
};
use crate::register_agent;
use crate::session;
use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation};
use anyhow::{anyhow, Result};
use indoc::indoc;
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolError, ToolResult};
use serde_json::{json, Value};
/// Maximum number of context-truncation retries before a reply gives up.
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
/// Per-attempt multiplier shrinking the estimated context target (0.9, 0.81, 0.729, ...).
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
/// Truncate implementation of an Agent
///
/// When the conversation exceeds the model's context window this agent drops
/// the oldest messages (see `truncate_messages`) rather than summarizing them.
pub struct TruncateAgent {
    // Extensions, tools, and provider access, behind an async mutex for shared use.
    capabilities: Mutex<Capabilities>,
    // Estimates token usage so truncation can target the context limit.
    token_counter: TokenCounter,
    // Sends user permission decisions into the reply loop.
    confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
    // Receives permission decisions; locked while awaiting a matching request id.
    confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
    // Sends results of frontend-executed tools back to the agent.
    tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
    // Shared receiver for frontend tool results.
    tool_result_rx: ToolResultReceiver,
}
impl TruncateAgent {
    /// Builds a new agent around `provider`, wiring up the confirmation and
    /// frontend-tool-result channels consumed by `reply`.
    pub fn new(provider: Box<dyn Provider>) -> Self {
        let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
        // Create channels with buffer size 32 (adjust if needed)
        let (confirm_tx, confirm_rx) = mpsc::channel(32);
        let (tool_tx, tool_rx) = mpsc::channel(32);
        Self {
            capabilities: Mutex::new(Capabilities::new(provider)),
            token_counter,
            confirmation_tx: confirm_tx,
            confirmation_rx: Mutex::new(confirm_rx),
            tool_result_tx: tool_tx,
            tool_result_rx: Arc::new(Mutex::new(tool_rx)),
        }
    }
    /// Truncates the messages to fit within the model's context window
    /// Ensures the last message is a user message and removes tool call-response pairs
    ///
    /// `estimate_factor` (0..1) scales the model's context limit down to a
    /// conservative target before the system prompt and tool definitions are
    /// subtracted from the budget.
    async fn truncate_messages(
        &self,
        messages: &mut Vec<Message>,
        estimate_factor: f32,
        system_prompt: &str,
        tools: &mut Vec<Tool>,
    ) -> anyhow::Result<()> {
        // Model's actual context limit
        let context_limit = self
            .capabilities
            .lock()
            .await
            .provider()
            .get_model_config()
            .context_limit();
        // Our conservative estimate of the **target** context limit
        // Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
        let context_limit = (context_limit as f32 * estimate_factor) as usize;
        // Take into account the system prompt, and our tools input and subtract that from the
        // remaining context limit
        let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
        let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());
        // Check if system prompt + tools exceed our context limit
        let remaining_tokens = context_limit
            .checked_sub(system_prompt_token_count)
            .and_then(|remaining| remaining.checked_sub(tools_token_count))
            .ok_or_else(|| {
                anyhow::anyhow!("System prompt and tools exceed estimated context limit")
            })?;
        let context_limit = remaining_tokens;
        // Calculate current token count of each message, use count_chat_tokens to ensure we
        // capture the full content of the message, include ToolRequests and ToolResponses
        let mut token_counts: Vec<usize> = messages
            .iter()
            .map(|msg| {
                self.token_counter
                    .count_chat_tokens("", std::slice::from_ref(msg), &[])
            })
            .collect();
        // Drop oldest messages first until the remaining history fits the budget.
        truncate_messages(
            messages,
            &mut token_counts,
            context_limit,
            &OldestFirstTruncation,
        )
    }
    /// Dispatches a single tool call, pairing the output with its request id so
    /// many calls can be awaited together via `join_all`.
    async fn create_tool_future(
        capabilities: &Capabilities,
        tool_call: mcp_core::tool::ToolCall,
        request_id: String,
    ) -> (String, Result<Vec<Content>, ToolError>) {
        let output = capabilities.dispatch_tool_call(tool_call).await;
        (request_id, output)
    }
    /// Looks up `extension_name` in the configured extensions and installs it,
    /// returning a (request id, result) pair suitable for a tool response.
    async fn enable_extension(
        capabilities: &mut Capabilities,
        extension_name: String,
        request_id: String,
    ) -> (String, Result<Vec<Content>, ToolError>) {
        let config = match ExtensionManager::get_config_by_name(&extension_name) {
            Ok(Some(config)) => config,
            Ok(None) => {
                return (
                    request_id,
                    Err(ToolError::ExecutionError(format!(
                        "Extension '{}' not found. Please check the extension name and try again.",
                        extension_name
                    ))),
                )
            }
            Err(e) => {
                return (
                    request_id,
                    Err(ToolError::ExecutionError(format!(
                        "Failed to get extension config: {}",
                        e
                    ))),
                )
            }
        };
        let result = capabilities
            .add_extension(config)
            .await
            .map(|_| {
                vec![Content::text(format!(
                    "The extension '{}' has been installed successfully",
                    extension_name
                ))]
            })
            .map_err(|e| ToolError::ExecutionError(e.to_string()));
        (request_id, result)
    }
}
#[async_trait]
impl Agent for TruncateAgent {
    /// Registers a new extension with the agent's capabilities.
    async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
        let mut capabilities = self.capabilities.lock().await;
        capabilities.add_extension(extension).await
    }
/// Lists every prefixed tool currently exposed by the loaded extensions,
/// or an empty list if the tools cannot be fetched.
async fn list_tools(&self) -> Vec<Tool> {
    let mut guard = self.capabilities.lock().await;
    guard
        .get_prefixed_tools()
        .await
        .unwrap_or_else(|_| Vec::new())
}
/// Unloads the named extension.
///
/// Panics if removal fails.
async fn remove_extension(&mut self, name: &str) {
    let result = self.capabilities.lock().await.remove_extension(name).await;
    result.expect("Failed to remove extension");
}
/// Names of all currently-loaded extensions.
///
/// Panics if the listing fails.
async fn list_extensions(&self) -> Vec<String> {
    let guard = self.capabilities.lock().await;
    let names = guard.list_extensions().await;
    names.expect("Failed to list extensions")
}
/// Forwards a raw request to a named extension.
///
/// NOTE(review): unimplemented stub — currently ignores both arguments and
/// always returns `Value::Null`.
async fn passthrough(&self, _extension: &str, _request: Value) -> ExtensionResult<Value> {
    // TODO implement
    Ok(Value::Null)
}
/// Handle a confirmation response for a tool request, routing it to the
/// reply loop that is waiting on the matching request id.
async fn handle_confirmation(&self, request_id: String, confirmation: PermissionConfirmation) {
    let outcome = self.confirmation_tx.send((request_id, confirmation)).await;
    if let Err(e) = outcome {
        error!("Failed to send confirmation: {}", e);
    }
}
/// Drives one conversational turn as a stream of messages.
///
/// Repeatedly calls the provider, yields the assistant response, dispatches
/// any requested tool calls according to the GOOSE_MODE policy ("auto",
/// "approve", "smart_approve", "chat"), and feeds the tool results back into
/// the history until the model responds with no tool requests. On context
/// overflow the history is truncated and the call retried, up to
/// MAX_TRUNCATION_ATTEMPTS times.
#[instrument(skip(self, messages, session), fields(user_message))]
async fn reply(
    &self,
    messages: &[Message],
    session: Option<SessionConfig>,
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
    let mut messages = messages.to_vec();
    let reply_span = tracing::Span::current();
    let mut capabilities = self.capabilities.lock().await;
    let mut tools = capabilities.get_prefixed_tools().await?;
    let mut truncation_attempt: usize = 0;
    // Load settings from config
    let config = Config::global();
    let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
    // we add in the 2 resource tools if any extensions support resources
    // TODO: make sure there is no collision with another extension's tool name
    let read_resource_tool = Tool::new(
        "platform__read_resource".to_string(),
        indoc! {r#"
        Read a resource from an extension.
        Resources allow extensions to share data that provide context to LLMs, such as
        files, database schemas, or application-specific information. This tool searches for the
        resource URI in the provided extension, and reads in the resource content. If no extension
        is provided, the tool will search all extensions for the resource.
        "#}.to_string(),
        json!({
            "type": "object",
            "required": ["uri"],
            "properties": {
                "uri": {"type": "string", "description": "Resource URI"},
                "extension_name": {"type": "string", "description": "Optional extension name"}
            }
        }),
        Some(ToolAnnotations {
            title: Some("Read a resource".to_string()),
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: false,
            open_world_hint: false,
        }),
    );
    let list_resources_tool = Tool::new(
        "platform__list_resources".to_string(),
        indoc! {r#"
        List resources from an extension(s).
        Resources allow extensions to share data that provide context to LLMs, such as
        files, database schemas, or application-specific information. This tool lists resources
        in the provided extension, and returns a list for the user to browse. If no extension
        is provided, the tool will search all extensions for the resource.
        "#}.to_string(),
        json!({
            "type": "object",
            "properties": {
                "extension_name": {"type": "string", "description": "Optional extension name"}
            }
        }),
        Some(ToolAnnotations {
            title: Some("List resources".to_string()),
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: false,
            open_world_hint: false,
        }),
    );
    let search_available_extensions = Tool::new(
        "platform__search_available_extensions".to_string(),
        "Searches for additional extensions available to help complete tasks.
        Use this tool when you're unable to find a specific feature or functionality you need to complete your task, or when standard approaches aren't working.
        These extensions might provide the exact tools needed to solve your problem.
        If you find a relevant one, consider using your tools to enable it.".to_string(),
        json!({
            "type": "object",
            "required": [],
            "properties": {}
        }),
        Some(ToolAnnotations {
            title: Some("Discover extensions".to_string()),
            read_only_hint: true,
            destructive_hint: false,
            idempotent_hint: false,
            open_world_hint: false,
        }),
    );
    let enable_extension_tool = Tool::new(
        "platform__enable_extension".to_string(),
        "Enable extensions to help complete tasks.
        Enable an extension by providing the extension name.
        "
        .to_string(),
        json!({
            "type": "object",
            "required": ["extension_name"],
            "properties": {
                "extension_name": {"type": "string", "description": "The name of the extension to enable"}
            }
        }),
        Some(ToolAnnotations {
            title: Some("Enable extensions".to_string()),
            read_only_hint: false,
            destructive_hint: false,
            idempotent_hint: false,
            open_world_hint: false,
        }),
    );
    if capabilities.supports_resources() {
        tools.push(read_resource_tool);
        tools.push(list_resources_tool);
    }
    tools.push(search_available_extensions);
    tools.push(enable_extension_tool);
    // Partition tool names by annotation: read-only-annotated tools skip
    // confirmation; unannotated tools may be classified later by the LLM judge.
    let (tools_with_readonly_annotation, tools_without_annotation): (Vec<String>, Vec<String>) =
        tools.iter().fold((vec![], vec![]), |mut acc, tool| {
            match &tool.annotations {
                Some(annotations) => {
                    if annotations.read_only_hint {
                        acc.0.push(tool.name.clone());
                    }
                }
                None => {
                    acc.1.push(tool.name.clone());
                }
            }
            acc
        });
    // NOTE: `config` is shadowed here — from this point on it is the model config.
    let config = capabilities.provider().get_model_config();
    let mut system_prompt = capabilities.get_system_prompt().await;
    let mut toolshim_tools = vec![];
    if config.toolshim {
        // If tool interpretation is enabled, modify the system prompt to instruct to return JSON tool requests
        system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
        // make a copy of tools before empty
        toolshim_tools = tools.clone();
        // pass empty tools vector to provider completion since toolshim will handle tool calls instead
        tools = vec![];
    }
    // Set the user_message field in the span instead of creating a new event
    if let Some(content) = messages
        .last()
        .and_then(|msg| msg.content.first())
        .and_then(|c| c.as_text())
    {
        debug!("user_message" = &content);
    }
    Ok(Box::pin(async_stream::try_stream! {
        let _reply_guard = reply_span.enter();
        loop {
            match capabilities.provider().complete(
                &system_prompt,
                &messages,
                &tools,
            ).await {
                Ok((mut response, usage)) => {
                    // Post-process / structure the response only if tool interpretation is enabled
                    if config.toolshim {
                        let interpreter = OllamaInterpreter::new()
                            .map_err(|e| anyhow::anyhow!("Failed to create OllamaInterpreter: {}", e))?;
                        response = augment_message_with_tool_calls(&interpreter, response, &toolshim_tools).await?;
                    }
                    // record usage for the session in the session file
                    if let Some(session) = session.clone() {
                        // TODO: track session_id in langfuse tracing
                        let session_file = session::get_path(session.id);
                        let mut metadata = session::read_metadata(&session_file)?;
                        metadata.working_dir = session.working_dir;
                        metadata.total_tokens = usage.usage.total_tokens;
                        metadata.input_tokens = usage.usage.input_tokens;
                        metadata.output_tokens = usage.usage.output_tokens;
                        // The message count is the number of messages in the session + 1 for the response
                        // The message count does not include the tool response till next iteration
                        metadata.message_count = messages.len() + 1;
                        session::update_metadata(&session_file, &metadata).await?;
                    }
                    // Reset truncation attempt
                    truncation_attempt = 0;
                    // Yield the assistant's response, but filter out frontend tool requests that we'll process separately
                    let filtered_response = Message {
                        role: response.role.clone(),
                        created: response.created,
                        content: response.content.iter().filter(|c| {
                            if let MessageContent::ToolRequest(req) = c {
                                // Only filter out frontend tool requests
                                if let Ok(tool_call) = &req.tool_call {
                                    return !capabilities.is_frontend_tool(&tool_call.name);
                                }
                            }
                            true
                        }).cloned().collect(),
                    };
                    yield filtered_response.clone();
                    tokio::task::yield_now().await;
                    // First collect any tool requests
                    let tool_requests: Vec<&ToolRequest> = response.content
                        .iter()
                        .filter_map(|content| content.as_tool_request())
                        .collect();
                    if tool_requests.is_empty() {
                        // No tool calls — the turn is complete.
                        break;
                    }
                    // Process tool requests depending on goose_mode
                    let mut message_tool_response = Message::user();
                    // First handle any frontend tool requests
                    let mut remaining_requests = Vec::new();
                    for request in &tool_requests {
                        if let Ok(tool_call) = request.tool_call.clone() {
                            if capabilities.is_frontend_tool(&tool_call.name) {
                                // Send frontend tool request and wait for response
                                yield Message::assistant().with_frontend_tool_request(
                                    request.id.clone(),
                                    Ok(tool_call.clone())
                                );
                                if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
                                    message_tool_response = message_tool_response.with_tool_response(id, result);
                                }
                            } else {
                                remaining_requests.push(request);
                            }
                        } else {
                            remaining_requests.push(request);
                        }
                    }
                    // Split tool requests into enable_extension and others
                    let (enable_extension_requests, non_enable_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
                        .into_iter()
                        .partition(|req| {
                            req.tool_call.as_ref()
                                .map(|call| call.name == "platform__enable_extension")
                                .unwrap_or(false)
                        });
                    let (search_extension_requests, _non_search_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
                        .into_iter()
                        .partition(|req| {
                            req.tool_call.as_ref()
                                .map(|call| call.name == "platform__search_available_extensions")
                                .unwrap_or(false)
                        });
                    // Clone goose_mode once before the match to avoid move issues
                    let mode = goose_mode.clone();
                    // If there are install extension requests, always require confirmation
                    // or if goose_mode is approve or smart_approve, check permissions for all tools
                    if !enable_extension_requests.is_empty() || mode.as_str() == "approve" || mode.as_str() == "smart_approve" {
                        let mut needs_confirmation = Vec::<&ToolRequest>::new();
                        let mut approved_tools = Vec::new();
                        let mut llm_detect_candidates = Vec::<&ToolRequest>::new();
                        let mut detected_read_only_tools = Vec::new();
                        // If approve mode or smart approve mode, check permissions for all tools
                        if mode.as_str() == "approve" || mode.as_str() == "smart_approve" {
                            let store = ToolPermissionStore::load()?;
                            for request in &non_enable_extension_requests {
                                if let Ok(tool_call) = request.tool_call.clone() {
                                    // Regular permission checking for other tools
                                    if tools_with_readonly_annotation.contains(&tool_call.name) {
                                        approved_tools.push((request.id.clone(), tool_call));
                                    } else if let Some(allowed) = store.check_permission(request) {
                                        if allowed {
                                            // Instead of executing immediately, collect approved tools
                                            approved_tools.push((request.id.clone(), tool_call));
                                        } else {
                                            // If the tool doesn't have any annotation, we can use llm-as-a-judge to check permission.
                                            if tools_without_annotation.contains(&tool_call.name) {
                                                llm_detect_candidates.push(request);
                                            }
                                            needs_confirmation.push(request);
                                        }
                                    } else {
                                        if tools_without_annotation.contains(&tool_call.name) {
                                            llm_detect_candidates.push(request);
                                        }
                                        needs_confirmation.push(request);
                                    }
                                }
                            }
                        }
                        // Only check read-only status for tools needing confirmation
                        if !llm_detect_candidates.is_empty() && mode == "smart_approve" {
                            detected_read_only_tools = detect_read_only_tools(&capabilities, llm_detect_candidates.clone()).await;
                            // Remove install extensions from read-only tools
                            if !enable_extension_requests.is_empty() {
                                detected_read_only_tools.retain(|tool_name| {
                                    !enable_extension_requests.iter().any(|req| {
                                        req.tool_call.as_ref()
                                            .map(|call| call.name == *tool_name)
                                            .unwrap_or(false)
                                    })
                                });
                            }
                        }
                        // Handle pre-approved and read-only tools in parallel
                        let mut tool_futures = Vec::new();
                        let mut install_results = Vec::new();
                        // Handle install extension requests
                        for request in &enable_extension_requests {
                            if let Ok(tool_call) = request.tool_call.clone() {
                                let confirmation = Message::user().with_enable_extension_request(
                                    request.id.clone(),
                                    tool_call.arguments.get("extension_name")
                                        .and_then(|v| v.as_str())
                                        .unwrap_or("")
                                        .to_string()
                                );
                                yield confirmation;
                                let mut rx = self.confirmation_rx.lock().await;
                                while let Some((req_id, extension_confirmation)) = rx.recv().await {
                                    if req_id == request.id {
                                        if extension_confirmation.permission == Permission::AllowOnce || extension_confirmation.permission == Permission::AlwaysAllow {
                                            let extension_name = tool_call.arguments.get("extension_name")
                                                .and_then(|v| v.as_str())
                                                .unwrap_or("")
                                                .to_string();
                                            let install_result = Self::enable_extension(&mut capabilities, extension_name, request.id.clone()).await;
                                            install_results.push(install_result);
                                        }
                                        break;
                                    }
                                }
                            }
                        }
                        // Process read-only tools
                        for request in &needs_confirmation {
                            if let Ok(tool_call) = request.tool_call.clone() {
                                // Skip confirmation if the tool_call.name is in the read_only_tools list
                                if detected_read_only_tools.contains(&tool_call.name) {
                                    let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
                                    tool_futures.push(tool_future);
                                } else {
                                    let confirmation = Message::user().with_tool_confirmation_request(
                                        request.id.clone(),
                                        tool_call.name.clone(),
                                        tool_call.arguments.clone(),
                                        Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
                                    );
                                    yield confirmation;
                                    // Wait for confirmation response through the channel
                                    let mut rx = self.confirmation_rx.lock().await;
                                    while let Some((req_id, tool_confirmation)) = rx.recv().await {
                                        if req_id == request.id {
                                            let confirmed = tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow;
                                            if confirmed {
                                                // Add this tool call to the futures collection
                                                let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
                                                tool_futures.push(tool_future);
                                            } else {
                                                // User declined - add declined response
                                                message_tool_response = message_tool_response.with_tool_response(
                                                    request.id.clone(),
                                                    Ok(vec![Content::text(
                                                        "The user has declined to run this tool. \
                                                        DO NOT attempt to call this tool again. \
                                                        If there are no alternative methods to proceed, clearly explain the situation and STOP.")]),
                                                );
                                            }
                                            break; // Exit the loop once the matching `req_id` is found
                                        }
                                    }
                                }
                            }
                        }
                        // Wait for all tool calls to complete
                        let results = futures::future::join_all(tool_futures).await;
                        for (request_id, output) in results {
                            message_tool_response = message_tool_response.with_tool_response(
                                request_id,
                                output,
                            );
                        }
                        // Check if any install results had errors before processing them
                        let all_successful = !install_results.iter().any(|(_, result)| result.is_err());
                        for (request_id, output) in install_results {
                            message_tool_response = message_tool_response.with_tool_response(
                                request_id,
                                output
                            );
                        }
                        // Update system prompt and tools if all installations were successful
                        if all_successful {
                            system_prompt = capabilities.get_system_prompt().await;
                            tools = capabilities.get_prefixed_tools().await?;
                        }
                    }
                    if mode.as_str() == "auto" || !search_extension_requests.is_empty() {
                        let mut tool_futures = Vec::new();
                        // Process non_enable_extension_requests and search_extension_requests without duplicates
                        let mut processed_ids = HashSet::new();
                        for request in non_enable_extension_requests.iter().chain(search_extension_requests.iter()) {
                            if processed_ids.insert(request.id.clone()) {
                                if let Ok(tool_call) = request.tool_call.clone() {
                                    let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
                                    tool_futures.push(tool_future);
                                }
                            }
                        }
                        // Wait for all tool calls to complete
                        let results = futures::future::join_all(tool_futures).await;
                        for (request_id, output) in results {
                            message_tool_response = message_tool_response.with_tool_response(
                                request_id,
                                output,
                            );
                        }
                    }
                    if mode.as_str() == "chat" {
                        // Skip all tool calls in chat mode
                        // Skip search extension requests since they were already processed
                        let non_search_non_enable_extension_requests = non_enable_extension_requests.iter()
                            .filter(|req| {
                                if let Ok(tool_call) = &req.tool_call {
                                    tool_call.name != "platform__search_available_extensions"
                                } else {
                                    true
                                }
                            });
                        for request in non_search_non_enable_extension_requests {
                            message_tool_response = message_tool_response.with_tool_response(
                                request.id.clone(),
                                Ok(vec![Content::text(
                                    "Let the user know the tool call was skipped in Goose chat mode. \
                                    DO NOT apologize for skipping the tool call. DO NOT say sorry. \
                                    Provide an explanation of what the tool call would do, structured as a \
                                    plan for the user. Again, DO NOT apologize. \
                                    **Example Plan:**\n \
                                    1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \
                                    2. **Outline Steps** - Break down the steps.\n \
                                    If needed, adjust the explanation based on user preferences or questions."
                                )]),
                            );
                        }
                    }
                    yield message_tool_response.clone();
                    messages.push(response);
                    messages.push(message_tool_response);
                },
                Err(ProviderError::ContextLengthExceeded(_)) => {
                    if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
                        // Create an error message & terminate the stream
                        // the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
                        // at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
                        yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
                        break;
                    }
                    truncation_attempt += 1;
                    warn!("Context length exceeded. Truncation Attempt: {}/{}.", truncation_attempt, MAX_TRUNCATION_ATTEMPTS);
                    // Decay the estimate factor as we make more truncation attempts
                    // Estimate factor decays like this over time: 0.9, 0.81, 0.729, ...
                    let estimate_factor: f32 = ESTIMATE_FACTOR_DECAY.powi(truncation_attempt as i32);
                    // release the lock before truncation to prevent deadlock
                    drop(capabilities);
                    if let Err(err) = self.truncate_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
                        yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
                        break;
                    }
                    // Re-acquire the lock
                    capabilities = self.capabilities.lock().await;
                    // Retry the loop after truncation
                    continue;
                },
                Err(e) => {
                    // Create an error message & terminate the stream
                    error!("Error: {}", e);
                    yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
                    break;
                }
            }
            // Yield control back to the scheduler to prevent blocking
            tokio::task::yield_now().await;
        }
    }))
}
/// Appends an extension-provided section to the agent's system prompt.
async fn extend_system_prompt(&mut self, extension: String) {
    self.capabilities
        .lock()
        .await
        .add_system_prompt_extension(extension);
}
/// Replaces the whole system prompt with a caller-supplied template.
async fn override_system_prompt(&mut self, template: String) {
    self.capabilities
        .lock()
        .await
        .set_system_prompt_override(template);
}
/// Returns every prompt advertised by the loaded extensions, keyed by extension name.
///
/// Panics if the underlying prompt listing fails.
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
    let guard = self.capabilities.lock().await;
    let listing = guard.list_prompts().await;
    listing.expect("Failed to list prompts")
}
/// Resolves a prompt by name across all extensions and renders it with `arguments`.
///
/// # Errors
/// Fails when the prompt listing cannot be fetched, when no extension
/// advertises `name`, or when rendering the prompt itself fails.
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
    let capabilities = self.capabilities.lock().await;
    let prompts = capabilities
        .list_prompts()
        .await
        .map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
    // Locate the extension that advertises a prompt with this name.
    let owner = prompts.iter().find_map(|(extension, prompt_list)| {
        prompt_list
            .iter()
            .any(|p| p.name == name)
            .then_some(extension)
    });
    match owner {
        Some(extension) => capabilities
            .get_prompt(extension, name, arguments)
            .await
            .map_err(|e| anyhow!("Failed to get prompt: {}", e)),
        None => Err(anyhow!("Prompt '{}' not found", name)),
    }
}
/// Builds the planning prompt from the currently available tools.
async fn get_plan_prompt(&self) -> anyhow::Result<String> {
    let mut capabilities = self.capabilities.lock().await;
    let prefixed_tools = capabilities.get_prefixed_tools().await?;
    // Summarize each tool as name/description/parameter names for the planner.
    let mut tools_info = Vec::new();
    for tool in prefixed_tools {
        tools_info.push(ToolInfo::new(
            &tool.name,
            &tool.description,
            get_parameter_names(&tool),
            None,
        ));
    }
    Ok(capabilities.get_planning_prompt(tools_info).await)
}
/// Returns a shared handle to the underlying LLM provider.
async fn provider(&self) -> Arc<Box<dyn Provider>> {
    self.capabilities.lock().await.provider()
}
/// Forwards a frontend-executed tool result back into the reply loop.
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
    let outcome = self.tool_result_tx.send((id, result)).await;
    if let Err(e) = outcome {
        tracing::error!("Failed to send tool result: {}", e);
    }
}
}
// Register this implementation under the "truncate" agent version name.
register_agent!("truncate", TruncateAgent);

View File

@@ -1,6 +1,25 @@
use mcp_core::{Content, ToolResult};
use crate::session;
use mcp_core::{Content, Tool, ToolResult};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
/// Type alias for the tool result channel receiver
pub type ToolResultReceiver = Arc<Mutex<mpsc::Receiver<(String, ToolResult<Vec<Content>>)>>>;
/// A frontend tool that will be executed by the frontend rather than an extension
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FrontendTool {
    // Name the tool is invoked by — presumably mirrors `tool.name`; verify against callers.
    pub name: String,
    // Full tool definition (schema, annotations) advertised to the model.
    pub tool: Tool,
}
/// Session configuration for an agent
///
/// Passed to `Agent::reply` so usage metadata can be recorded against the
/// session's on-disk file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Unique identifier for the session
    pub id: session::Identifier,
    /// Working directory for the session
    pub working_dir: PathBuf,
}

View File

@@ -25,9 +25,9 @@ pub fn name_to_key(name: &str) -> String {
}
/// Extension configuration management
pub struct ExtensionManager;
pub struct ExtensionConfigManager;
impl ExtensionManager {
impl ExtensionConfigManager {
/// Get the extension configuration if enabled -- uses key
pub fn get_config(key: &str) -> Result<Option<ExtensionConfig>> {
let config = Config::global();

View File

@@ -6,7 +6,7 @@ pub mod permission;
pub use crate::agents::ExtensionConfig;
pub use base::{Config, ConfigError, APP_STRATEGY};
pub use experiments::ExperimentManager;
pub use extensions::{ExtensionEntry, ExtensionManager};
pub use extensions::{ExtensionConfigManager, ExtensionEntry};
pub use permission::PermissionManager;
pub use extensions::DEFAULT_DISPLAY_NAME;

View File

@@ -1,7 +1,8 @@
use crate::agents::Capabilities;
use crate::message::Message;
use crate::providers::base::Provider;
use crate::token_counter::TokenCounter;
use anyhow::{anyhow, Result};
use std::sync::Arc;
use tracing::debug;
const SYSTEM_PROMPT: &str = "You are good at summarizing.";
@@ -12,17 +13,13 @@ fn create_summarize_request(messages: &[Message]) -> Vec<Message> {
]
}
async fn single_request(
capabilities: &Capabilities,
provider: &Arc<dyn Provider>,
messages: &[Message],
) -> Result<Message, anyhow::Error> {
Ok(capabilities
.provider()
.complete(SYSTEM_PROMPT, messages, &[])
.await?
.0)
Ok(provider.complete(SYSTEM_PROMPT, messages, &[]).await?.0)
}
async fn memory_condense(
capabilities: &Capabilities,
provider: Arc<dyn Provider>,
token_counter: &TokenCounter,
messages: &mut Vec<Message>,
token_counts: &mut Vec<usize>,
@@ -67,9 +64,7 @@ async fn memory_condense(
diff = -(current_tokens as isize);
let request = create_summarize_request(&batch);
let response_text = single_request(capabilities, &request)
.await?
.as_concat_text();
let response_text = single_request(&provider, &request).await?.as_concat_text();
// Ensure the conversation starts with a User message
let curr_messages = vec![
@@ -95,8 +90,9 @@ async fn memory_condense(
}
}
/// TODO: currently not used. we will add this is a feature flag under context mgmt
pub async fn condense_messages(
capabilities: &Capabilities,
provider: Arc<dyn Provider>,
token_counter: &TokenCounter,
messages: &mut Vec<Message>,
token_counts: &mut Vec<usize>,
@@ -108,7 +104,7 @@ pub async fn condense_messages(
// The compressor should determine whether we need to compress the messages or not. This
// function just checks if the limit is satisfied.
memory_condense(
capabilities,
provider,
token_counter,
messages,
token_counts,

View File

@@ -1,13 +1,14 @@
use crate::agents::capabilities::Capabilities;
use crate::config::permission::PermissionLevel;
use crate::config::PermissionManager;
use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::Provider;
use chrono::Utc;
use indoc::indoc;
use mcp_core::tool::ToolAnnotations;
use mcp_core::{tool::Tool, TextContent};
use serde_json::{json, Value};
use std::collections::HashSet;
use std::sync::Arc;
/// Creates the tool definition for checking read-only permissions.
fn create_read_only_tool() -> Tool {
@@ -113,7 +114,7 @@ fn extract_read_only_tools(response: &Message) -> Option<Vec<String>> {
/// Executes the read-only tools detection and returns the list of tools with read-only operations.
pub async fn detect_read_only_tools(
capabilities: &Capabilities,
provider: Arc<dyn Provider>,
tool_requests: Vec<&ToolRequest>,
) -> Vec<String> {
if tool_requests.is_empty() {
@@ -122,8 +123,7 @@ pub async fn detect_read_only_tools(
let tool = create_read_only_tool();
let check_messages = create_check_messages(tool_requests);
let res = capabilities
.provider()
let res = provider
.complete(
"You are a good analyst and can detect operations whether they have read-only operations.",
&check_messages,
@@ -152,7 +152,7 @@ pub async fn check_tool_permissions(
tools_with_readonly_annotation: HashSet<String>,
tools_without_annotation: HashSet<String>,
permission_manager: &mut PermissionManager,
capabilities: &Capabilities,
provider: Arc<dyn Provider>,
) -> PermissionCheckResult {
let mut approved = vec![];
let mut needs_approval = vec![];
@@ -206,7 +206,7 @@ pub async fn check_tool_permissions(
// 3. LLM detect
if !llm_detect_candidates.is_empty() && mode == "smart_approve" {
let detected_readonly_tools =
detect_read_only_tools(capabilities, llm_detect_candidates.iter().collect()).await;
detect_read_only_tools(provider, llm_detect_candidates.iter().collect()).await;
for request in llm_detect_candidates {
if let Ok(tool_call) = request.tool_call.clone() {
if detected_readonly_tools.contains(&tool_call.name) {
@@ -236,7 +236,6 @@ pub async fn check_tool_permissions(
#[cfg(test)]
mod tests {
use super::*;
use crate::agents::capabilities::Capabilities;
use crate::message::{Message, MessageContent, ToolRequest};
use crate::model::ModelConfig;
use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
@@ -287,12 +286,12 @@ mod tests {
}
}
fn create_mock_capabilities() -> Capabilities {
fn create_mock_provider() -> Arc<dyn Provider> {
let mock_model_config =
ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
Capabilities::new(Box::new(MockProvider {
Arc::new(MockProvider {
model_config: mock_model_config,
}))
})
}
#[tokio::test]
@@ -349,7 +348,7 @@ mod tests {
#[tokio::test]
async fn test_detect_read_only_tools() {
let capabilities = create_mock_capabilities();
let provider = create_mock_provider();
let tool_request = ToolRequest {
id: "tool_1".to_string(),
tool_call: ToolResult::Ok(ToolCall {
@@ -358,14 +357,14 @@ mod tests {
}),
};
let result = detect_read_only_tools(&capabilities, vec![&tool_request]).await;
let result = detect_read_only_tools(provider, vec![&tool_request]).await;
assert_eq!(result, vec!["file_reader", "data_fetcher"]);
}
#[tokio::test]
async fn test_detect_read_only_tools_empty_requests() {
let capabilities = create_mock_capabilities();
let result = detect_read_only_tools(&capabilities, vec![]).await;
let provider = create_mock_provider();
let result = detect_read_only_tools(provider, vec![]).await;
assert!(result.is_empty());
}
@@ -375,7 +374,7 @@ mod tests {
let temp_file = NamedTempFile::new().unwrap();
let temp_path = temp_file.path();
let mut permission_manager = PermissionManager::new(temp_path);
let capabilities = create_mock_capabilities();
let provider = create_mock_provider();
let tools_with_readonly_annotation: HashSet<String> =
vec!["file_reader".to_string()].into_iter().collect();
@@ -411,7 +410,7 @@ mod tests {
tools_with_readonly_annotation,
tools_without_annotation,
&mut permission_manager,
&capabilities,
provider,
)
.await;

View File

@@ -1,3 +1,5 @@
use std::sync::Arc;
use super::{
anthropic::AnthropicProvider,
azure::AzureProvider,
@@ -29,18 +31,19 @@ pub fn providers() -> Vec<ProviderMetadata> {
]
}
pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send + Sync>> {
pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
// We use Arc instead of Box to be able to clone for multiple async tasks
match name {
"openai" => Ok(Box::new(OpenAiProvider::from_env(model)?)),
"anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
"azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
"aws_bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)),
"databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
"openrouter" => Ok(Box::new(OpenRouterProvider::from_env(model)?)),
"gcp_vertex_ai" => Ok(Box::new(GcpVertexAIProvider::from_env(model)?)),
"google" => Ok(Box::new(GoogleProvider::from_env(model)?)),
"openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)),
"anthropic" => Ok(Arc::new(AnthropicProvider::from_env(model)?)),
"azure_openai" => Ok(Arc::new(AzureProvider::from_env(model)?)),
"aws_bedrock" => Ok(Arc::new(BedrockProvider::from_env(model)?)),
"databricks" => Ok(Arc::new(DatabricksProvider::from_env(model)?)),
"groq" => Ok(Arc::new(GroqProvider::from_env(model)?)),
"ollama" => Ok(Arc::new(OllamaProvider::from_env(model)?)),
"openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)),
"gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)),
"google" => Ok(Arc::new(GoogleProvider::from_env(model)?)),
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
}
}

View File

@@ -243,7 +243,7 @@ pub fn read_metadata(session_file: &Path) -> Result<SessionMetadata> {
pub async fn persist_messages(
session_file: &Path,
messages: &[Message],
provider: Option<Arc<Box<dyn Provider>>>,
provider: Option<Arc<dyn Provider>>,
) -> Result<()> {
// Count user messages
let user_message_count = messages
@@ -255,7 +255,7 @@ pub async fn persist_messages(
match provider {
Some(provider) if user_message_count < 4 => {
//generate_description is responsible for writing the messages
generate_description(session_file, messages, provider.as_ref().as_ref()).await
generate_description(session_file, messages, provider).await
}
_ => {
// Read existing metadata
@@ -298,7 +298,7 @@ pub fn save_messages_with_metadata(
pub async fn generate_description(
session_file: &Path,
messages: &[Message],
provider: &dyn Provider,
provider: Arc<dyn Provider>,
) -> Result<()> {
// Create a special message asking for a 3-word description
let mut description_prompt = "Based on the conversation so far, provide a concise description of this session in 4 words or less. This will be used for finding the session later in a UI with limited space - reply *ONLY* with the description".to_string();

View File

@@ -1,8 +1,10 @@
// src/lib.rs or tests/truncate_agent_tests.rs
use std::sync::Arc;
use anyhow::Result;
use futures::StreamExt;
use goose::agents::AgentFactory;
use goose::agents::Agent;
use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::base::Provider;
@@ -65,18 +67,18 @@ impl ProviderType {
}
}
fn create_provider(&self, model_config: ModelConfig) -> Result<Box<dyn Provider>> {
fn create_provider(&self, model_config: ModelConfig) -> Result<Arc<dyn Provider>> {
Ok(match self {
ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?),
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
ProviderType::GcpVertexAI => Box::new(GcpVertexAIProvider::from_env(model_config)?),
ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?),
ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
ProviderType::Ollama => Box::new(OllamaProvider::from_env(model_config)?),
ProviderType::OpenRouter => Box::new(OpenRouterProvider::from_env(model_config)?),
ProviderType::Azure => Arc::new(AzureProvider::from_env(model_config)?),
ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(model_config)?),
ProviderType::Anthropic => Arc::new(AnthropicProvider::from_env(model_config)?),
ProviderType::Bedrock => Arc::new(BedrockProvider::from_env(model_config)?),
ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(model_config)?),
ProviderType::GcpVertexAI => Arc::new(GcpVertexAIProvider::from_env(model_config)?),
ProviderType::Google => Arc::new(GoogleProvider::from_env(model_config)?),
ProviderType::Groq => Arc::new(GroqProvider::from_env(model_config)?),
ProviderType::Ollama => Arc::new(OllamaProvider::from_env(model_config)?),
ProviderType::OpenRouter => Arc::new(OpenRouterProvider::from_env(model_config)?),
})
}
}
@@ -108,7 +110,7 @@ async fn run_truncate_test(
.with_temperature(Some(0.0));
let provider = provider_type.create_provider(model_config)?;
let agent = AgentFactory::create("truncate", provider).unwrap();
let agent = Agent::new(provider);
let repeat_count = context_window + 10_000;
let large_message_content = "hello ".repeat(repeat_count);
let messages = vec![
@@ -185,7 +187,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_openai() -> Result<()> {
async fn test_agent_with_openai() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::OpenAi,
model: "o3-mini-low",
@@ -195,7 +197,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_azure() -> Result<()> {
async fn test_agent_with_azure() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Azure,
model: "gpt-4o-mini",
@@ -205,7 +207,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_anthropic() -> Result<()> {
async fn test_agent_with_anthropic() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Anthropic,
model: "claude-3-5-haiku-latest",
@@ -215,7 +217,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_bedrock() -> Result<()> {
async fn test_agent_with_bedrock() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Bedrock,
model: "anthropic.claude-3-5-sonnet-20241022-v2:0",
@@ -225,7 +227,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_databricks() -> Result<()> {
async fn test_agent_with_databricks() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Databricks,
model: "databricks-meta-llama-3-3-70b-instruct",
@@ -235,7 +237,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_databricks_bedrock() -> Result<()> {
async fn test_agent_with_databricks_bedrock() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Databricks,
model: "claude-3-5-sonnet-2",
@@ -245,7 +247,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_databricks_openai() -> Result<()> {
async fn test_agent_with_databricks_openai() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Databricks,
model: "gpt-4o-mini",
@@ -255,7 +257,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_google() -> Result<()> {
async fn test_agent_with_google() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Google,
model: "gemini-2.0-flash-exp",
@@ -265,7 +267,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_groq() -> Result<()> {
async fn test_agent_with_groq() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Groq,
model: "gemma2-9b-it",
@@ -275,7 +277,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_openrouter() -> Result<()> {
async fn test_agent_with_openrouter() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::OpenRouter,
model: "deepseek/deepseek-r1",
@@ -285,7 +287,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_ollama() -> Result<()> {
async fn test_agent_with_ollama() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Ollama,
model: "llama3.2",
@@ -295,7 +297,7 @@ mod tests {
}
#[tokio::test]
async fn test_truncate_agent_with_gcpvertexai() -> Result<()> {
async fn test_agent_with_gcpvertexai() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::GcpVertexAI,
model: "claude-3-5-sonnet-v2@20241022",

View File

@@ -77,14 +77,14 @@ lazy_static::lazy_static! {
/// Generic test harness for any Provider implementation
struct ProviderTester {
provider: Box<dyn Provider + Send + Sync>,
provider: Arc<dyn Provider>,
name: String,
}
impl ProviderTester {
fn new<T: Provider + Send + Sync + 'static>(provider: T, name: String) -> Self {
Self {
provider: Box::new(provider),
provider: Arc::new(provider),
name,
}
}