mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-23 15:34:27 +01:00
refactor: remove agent flavours, move provider to Agent (#2091)
This commit is contained in:
@@ -3,7 +3,6 @@ use clap::{Args, Parser, Subcommand};
|
||||
|
||||
use goose::config::Config;
|
||||
|
||||
use crate::commands::agent_version::AgentCommand;
|
||||
use crate::commands::bench::agent_generator;
|
||||
use crate::commands::configure::handle_configure;
|
||||
use crate::commands::info::handle_info;
|
||||
@@ -279,9 +278,6 @@ enum Command {
|
||||
builtin: Vec<String>,
|
||||
},
|
||||
|
||||
/// List available agent versions
|
||||
Agents(AgentCommand),
|
||||
|
||||
/// Update the Goose CLI version
|
||||
#[command(about = "Update the goose CLI version")]
|
||||
Update {
|
||||
@@ -417,10 +413,6 @@ pub async fn cli() -> Result<()> {
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
Some(Command::Agents(cmd)) => {
|
||||
cmd.run()?;
|
||||
return Ok(());
|
||||
}
|
||||
Some(Command::Update {
|
||||
canary,
|
||||
reconfigure,
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use clap::Args;
|
||||
use goose::agents::AgentFactory;
|
||||
use std::fmt::Write;
|
||||
|
||||
#[derive(Args)]
|
||||
pub struct AgentCommand {}
|
||||
|
||||
impl AgentCommand {
|
||||
pub fn run(&self) -> Result<()> {
|
||||
let mut output = String::new();
|
||||
writeln!(output, "Available agent versions:")?;
|
||||
|
||||
let versions = AgentFactory::available_versions();
|
||||
let default_version = AgentFactory::default_version();
|
||||
let configured_version = AgentFactory::configured_version();
|
||||
|
||||
for version in versions {
|
||||
if version == default_version && version == configured_version {
|
||||
writeln!(output, "* {} (default)", version)?;
|
||||
} else if version == default_version {
|
||||
writeln!(output, " {} (default)", version)?;
|
||||
} else if version == configured_version {
|
||||
writeln!(output, "* {}", version)?;
|
||||
} else {
|
||||
writeln!(output, " {}", version)?;
|
||||
}
|
||||
}
|
||||
|
||||
print!("{}", output);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,8 @@ use console::style;
|
||||
use goose::agents::{extension::Envs, ExtensionConfig};
|
||||
use goose::config::extensions::name_to_key;
|
||||
use goose::config::{
|
||||
Config, ConfigError, ExperimentManager, ExtensionEntry, ExtensionManager, PermissionManager,
|
||||
Config, ConfigError, ExperimentManager, ExtensionConfigManager, ExtensionEntry,
|
||||
PermissionManager,
|
||||
};
|
||||
use goose::message::Message;
|
||||
use goose::providers::{create, providers};
|
||||
@@ -63,7 +64,7 @@ pub async fn handle_configure() -> Result<(), Box<dyn Error>> {
|
||||
);
|
||||
// Since we are setting up for the first time, we'll also enable the developer system
|
||||
// This operation is best-effort and errors are ignored
|
||||
ExtensionManager::set(ExtensionEntry {
|
||||
ExtensionConfigManager::set(ExtensionEntry {
|
||||
enabled: true,
|
||||
config: ExtensionConfig::Builtin {
|
||||
name: "developer".to_string(),
|
||||
@@ -392,7 +393,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
|
||||
/// Configure extensions that can be used with goose
|
||||
/// Dialog for toggling which extensions are enabled/disabled
|
||||
pub fn toggle_extensions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
let extensions = ExtensionManager::get_all()?;
|
||||
let extensions = ExtensionConfigManager::get_all()?;
|
||||
|
||||
if extensions.is_empty() {
|
||||
cliclack::outro(
|
||||
@@ -430,7 +431,7 @@ pub fn toggle_extensions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
|
||||
// Update enabled status for each extension
|
||||
for name in extension_status.iter().map(|(name, _)| name) {
|
||||
ExtensionManager::set_enabled(
|
||||
ExtensionConfigManager::set_enabled(
|
||||
&name_to_key(name),
|
||||
selected.iter().any(|s| s.as_str() == name),
|
||||
)?;
|
||||
@@ -502,7 +503,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
|
||||
let display_name = get_display_name(&extension);
|
||||
|
||||
ExtensionManager::set(ExtensionEntry {
|
||||
ExtensionConfigManager::set(ExtensionEntry {
|
||||
enabled: true,
|
||||
config: ExtensionConfig::Builtin {
|
||||
name: extension.clone(),
|
||||
@@ -514,7 +515,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
cliclack::outro(format!("Enabled {} extension", style(extension).green()))?;
|
||||
}
|
||||
"stdio" => {
|
||||
let extensions = ExtensionManager::get_all_names()?;
|
||||
let extensions = ExtensionConfigManager::get_all_names()?;
|
||||
let name: String = cliclack::input("What would you like to call this extension?")
|
||||
.placeholder("my-extension")
|
||||
.validate(move |input: &String| {
|
||||
@@ -590,7 +591,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
}
|
||||
}
|
||||
|
||||
ExtensionManager::set(ExtensionEntry {
|
||||
ExtensionConfigManager::set(ExtensionEntry {
|
||||
enabled: true,
|
||||
config: ExtensionConfig::Stdio {
|
||||
name: name.clone(),
|
||||
@@ -605,7 +606,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
cliclack::outro(format!("Added {} extension", style(name).green()))?;
|
||||
}
|
||||
"sse" => {
|
||||
let extensions = ExtensionManager::get_all_names()?;
|
||||
let extensions = ExtensionConfigManager::get_all_names()?;
|
||||
let name: String = cliclack::input("What would you like to call this extension?")
|
||||
.placeholder("my-remote-extension")
|
||||
.validate(move |input: &String| {
|
||||
@@ -677,7 +678,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
}
|
||||
}
|
||||
|
||||
ExtensionManager::set(ExtensionEntry {
|
||||
ExtensionConfigManager::set(ExtensionEntry {
|
||||
enabled: true,
|
||||
config: ExtensionConfig::Sse {
|
||||
name: name.clone(),
|
||||
@@ -697,7 +698,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box<dyn Error>> {
|
||||
}
|
||||
|
||||
pub fn remove_extension_dialog() -> Result<(), Box<dyn Error>> {
|
||||
let extensions = ExtensionManager::get_all()?;
|
||||
let extensions = ExtensionConfigManager::get_all()?;
|
||||
|
||||
// Create a list of extension names and their enabled status
|
||||
let extension_status: Vec<(String, bool)> = extensions
|
||||
@@ -739,7 +740,7 @@ pub fn remove_extension_dialog() -> Result<(), Box<dyn Error>> {
|
||||
.interact()?;
|
||||
|
||||
for name in selected {
|
||||
ExtensionManager::remove(&name_to_key(name))?;
|
||||
ExtensionConfigManager::remove(&name_to_key(name))?;
|
||||
let mut permission_manager = PermissionManager::default();
|
||||
permission_manager.remove_extension(&name_to_key(name));
|
||||
cliclack::outro(format!("Removed {} extension", style(name).green()))?;
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
pub mod agent_version;
|
||||
pub mod bench;
|
||||
pub mod configure;
|
||||
pub mod info;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use console::style;
|
||||
use goose::agents::extension::ExtensionError;
|
||||
use goose::agents::AgentFactory;
|
||||
use goose::config::{Config, ExtensionManager};
|
||||
use goose::agents::Agent;
|
||||
use goose::config::{Config, ExtensionConfigManager};
|
||||
use goose::session;
|
||||
use goose::session::Identifier;
|
||||
use mcp_client::transport::Error as McpClientError;
|
||||
@@ -33,8 +33,7 @@ pub async fn build_session(
|
||||
goose::providers::create(&provider_name, model_config).expect("Failed to create provider");
|
||||
|
||||
// Create the agent
|
||||
let mut agent = AgentFactory::create(&AgentFactory::configured_version(), provider)
|
||||
.expect("Failed to create agent");
|
||||
let mut agent = Agent::new(provider);
|
||||
|
||||
// Handle session file resolution and resuming
|
||||
let session_file = if resume {
|
||||
@@ -93,7 +92,7 @@ pub async fn build_session(
|
||||
|
||||
// Setup extensions for the agent
|
||||
// Extensions need to be added after the session is created because we change directory when resuming a session
|
||||
for extension in ExtensionManager::get_all().expect("should load extensions") {
|
||||
for extension in ExtensionConfigManager::get_all().expect("should load extensions") {
|
||||
if extension.enabled {
|
||||
let config = extension.config.clone();
|
||||
agent
|
||||
|
||||
@@ -38,7 +38,7 @@ pub enum RunMode {
|
||||
}
|
||||
|
||||
pub struct Session {
|
||||
agent: Box<dyn Agent>,
|
||||
agent: Agent,
|
||||
messages: Vec<Message>,
|
||||
session_file: PathBuf,
|
||||
// Cache for completion data - using std::sync for thread safety without async
|
||||
@@ -76,7 +76,7 @@ pub enum PlannerResponseType {
|
||||
/// question.
|
||||
pub async fn classify_planner_response(
|
||||
message_text: String,
|
||||
provider: Arc<Box<dyn Provider>>,
|
||||
provider: Arc<dyn Provider>,
|
||||
) -> Result<PlannerResponseType> {
|
||||
let prompt = format!("The text below is the output from an AI model which can either provide a plan or list of clarifying questions. Based on the text below, decide if the output is a \"plan\" or \"clarifying questions\".\n---\n{message_text}");
|
||||
|
||||
@@ -101,7 +101,7 @@ pub async fn classify_planner_response(
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new(agent: Box<dyn Agent>, session_file: PathBuf, debug: bool) -> Self {
|
||||
pub fn new(agent: Agent, session_file: PathBuf, debug: bool) -> Self {
|
||||
let messages = match session::read_messages(&session_file) {
|
||||
Ok(msgs) => msgs,
|
||||
Err(e) => {
|
||||
@@ -278,7 +278,7 @@ impl Session {
|
||||
async fn process_message(&mut self, message: String) -> Result<()> {
|
||||
self.messages.push(Message::user().with_text(&message));
|
||||
// Get the provider from the agent for description generation
|
||||
let provider = self.agent.provider().await;
|
||||
let provider = self.agent.provider();
|
||||
|
||||
// Persist messages with provider for automatic description generation
|
||||
session::persist_messages(&self.session_file, &self.messages, Some(provider)).await?;
|
||||
@@ -350,7 +350,7 @@ impl Session {
|
||||
self.messages.push(Message::user().with_text(&content));
|
||||
|
||||
// Get the provider from the agent for description generation
|
||||
let provider = self.agent.provider().await;
|
||||
let provider = self.agent.provider();
|
||||
|
||||
// Persist messages with provider for automatic description generation
|
||||
session::persist_messages(
|
||||
@@ -535,7 +535,7 @@ impl Session {
|
||||
async fn plan_with_reasoner_model(
|
||||
&mut self,
|
||||
plan_messages: Vec<Message>,
|
||||
reasoner: Box<dyn Provider + Send + Sync>,
|
||||
reasoner: Arc<dyn Provider>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let plan_prompt = self.agent.get_plan_prompt().await?;
|
||||
output::show_thinking();
|
||||
@@ -543,7 +543,7 @@ impl Session {
|
||||
output::render_message(&plan_response, self.debug);
|
||||
output::hide_thinking();
|
||||
let planner_response_type =
|
||||
classify_planner_response(plan_response.as_concat_text(), self.agent.provider().await)
|
||||
classify_planner_response(plan_response.as_concat_text(), self.agent.provider())
|
||||
.await?;
|
||||
|
||||
match planner_response_type {
|
||||
@@ -857,7 +857,7 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_reasoner() -> Result<Box<dyn Provider + Send + Sync>, anyhow::Error> {
|
||||
fn get_reasoner() -> Result<Arc<dyn Provider>, anyhow::Error> {
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::create;
|
||||
|
||||
|
||||
@@ -5,11 +5,10 @@ use axum::{
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use goose::{agents::AgentFactory, config::PermissionManager, model::ModelConfig, providers};
|
||||
use goose::{
|
||||
agents::{capabilities::get_parameter_names, extension::ToolInfo},
|
||||
config::Config,
|
||||
};
|
||||
use goose::agents::{extension::ToolInfo, extension_manager::get_parameter_names};
|
||||
use goose::config::Config;
|
||||
use goose::config::PermissionManager;
|
||||
use goose::{agents::Agent, model::ModelConfig, providers};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
@@ -32,7 +31,6 @@ struct ExtendPromptResponse {
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct CreateAgentRequest {
|
||||
version: Option<String>,
|
||||
provider: String,
|
||||
model: Option<String>,
|
||||
}
|
||||
@@ -70,8 +68,8 @@ pub struct GetToolsQuery {
|
||||
}
|
||||
|
||||
async fn get_versions() -> Json<VersionsResponse> {
|
||||
let versions = AgentFactory::available_versions();
|
||||
let default_version = AgentFactory::default_version().to_string();
|
||||
let versions = ["goose".to_string()];
|
||||
let default_version = "goose".to_string();
|
||||
|
||||
Json(VersionsResponse {
|
||||
available_versions: versions.iter().map(|v| v.to_string()).collect(),
|
||||
@@ -136,11 +134,8 @@ async fn create_agent(
|
||||
let provider =
|
||||
providers::create(&payload.provider, model_config).expect("Failed to create provider");
|
||||
|
||||
let version = payload
|
||||
.version
|
||||
.unwrap_or_else(|| AgentFactory::default_version().to_string());
|
||||
|
||||
let new_agent = AgentFactory::create(&version, provider).expect("Failed to create agent");
|
||||
let version = String::from("goose");
|
||||
let new_agent = Agent::new(provider);
|
||||
|
||||
let mut agent = state.agent.write().await;
|
||||
*agent = Some(new_agent);
|
||||
|
||||
@@ -8,7 +8,7 @@ use axum::{
|
||||
use goose::agents::ExtensionConfig;
|
||||
use goose::config::extensions::name_to_key;
|
||||
use goose::config::Config;
|
||||
use goose::config::{ExtensionEntry, ExtensionManager};
|
||||
use goose::config::{ExtensionConfigManager, ExtensionEntry};
|
||||
use goose::providers::base::ProviderMetadata;
|
||||
use goose::providers::providers as get_providers;
|
||||
use http::{HeaderMap, StatusCode};
|
||||
@@ -184,7 +184,7 @@ pub async fn get_extensions(
|
||||
) -> Result<Json<ExtensionResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
match ExtensionManager::get_all() {
|
||||
match ExtensionConfigManager::get_all() {
|
||||
Ok(extensions) => Ok(Json(ExtensionResponse { extensions })),
|
||||
Err(err) => {
|
||||
// Return UNPROCESSABLE_ENTITY only for DeserializeError, INTERNAL_SERVER_ERROR for everything else
|
||||
@@ -219,13 +219,13 @@ pub async fn add_extension(
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
// Get existing extensions to check if this is an update
|
||||
let extensions = ExtensionManager::get_all().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let extensions =
|
||||
ExtensionConfigManager::get_all().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let key = name_to_key(&extension_query.name);
|
||||
|
||||
let is_update = extensions.iter().any(|e| e.config.key() == key);
|
||||
|
||||
// Use ExtensionConfigManager to set the extension
|
||||
match ExtensionManager::set(ExtensionEntry {
|
||||
match ExtensionConfigManager::set(ExtensionEntry {
|
||||
enabled: extension_query.enabled,
|
||||
config: extension_query.config,
|
||||
}) {
|
||||
@@ -257,8 +257,7 @@ pub async fn remove_extension(
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let key = name_to_key(&name);
|
||||
// Use ExtensionConfigManager to remove the extension
|
||||
match ExtensionManager::remove(&key) {
|
||||
match ExtensionConfigManager::remove(&key) {
|
||||
Ok(_) => Ok(Json(format!("Removed extension {}", name))),
|
||||
Err(_) => Err(StatusCode::NOT_FOUND),
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ async fn handler(
|
||||
};
|
||||
|
||||
// Get the provider first, before starting the reply stream
|
||||
let provider = agent.provider().await;
|
||||
let provider = agent.provider();
|
||||
|
||||
let mut stream = match agent
|
||||
.reply(
|
||||
@@ -294,7 +294,7 @@ async fn ask_handler(
|
||||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Get the provider first, before starting the reply stream
|
||||
let provider = agent.provider().await;
|
||||
let provider = agent.provider();
|
||||
|
||||
// Create a single message for the prompt
|
||||
let messages = vec![Message::user().with_text(request.prompt)];
|
||||
@@ -467,7 +467,7 @@ pub fn routes(state: AppState) -> Router {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use goose::{
|
||||
agents::AgentFactory,
|
||||
agents::Agent,
|
||||
model::ModelConfig,
|
||||
providers::{
|
||||
base::{Provider, ProviderUsage, Usage},
|
||||
@@ -518,10 +518,10 @@ mod tests {
|
||||
async fn test_ask_endpoint() {
|
||||
// Create a mock app state with mock provider
|
||||
let mock_model_config = ModelConfig::new("test-model".to_string());
|
||||
let mock_provider = Box::new(MockProvider {
|
||||
let mock_provider = Arc::new(MockProvider {
|
||||
model_config: mock_model_config,
|
||||
});
|
||||
let agent = AgentFactory::create("reference", mock_provider).unwrap();
|
||||
let agent = Agent::new(mock_provider);
|
||||
let state = AppState {
|
||||
config: Arc::new(Mutex::new(HashMap::new())),
|
||||
agent: Arc::new(RwLock::new(Some(agent))),
|
||||
|
||||
@@ -9,7 +9,7 @@ use tokio::sync::{Mutex, RwLock};
|
||||
#[allow(dead_code)]
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub agent: Arc<RwLock<Option<Box<dyn Agent>>>>,
|
||||
pub agent: Arc<RwLock<Option<Agent>>>,
|
||||
pub secret_key: String,
|
||||
pub config: Arc<Mutex<HashMap<String, Value>>>,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dotenv::dotenv;
|
||||
use futures::StreamExt;
|
||||
use goose::agents::{AgentFactory, ExtensionConfig};
|
||||
use goose::agents::{Agent, ExtensionConfig};
|
||||
use goose::config::{DEFAULT_EXTENSION_DESCRIPTION, DEFAULT_EXTENSION_TIMEOUT};
|
||||
use goose::message::Message;
|
||||
use goose::providers::databricks::DatabricksProvider;
|
||||
@@ -10,10 +12,10 @@ async fn main() {
|
||||
// Setup a model provider from env vars
|
||||
let _ = dotenv();
|
||||
|
||||
let provider = Box::new(DatabricksProvider::default());
|
||||
let provider = Arc::new(DatabricksProvider::default());
|
||||
|
||||
// Setup an agent with the developer extension
|
||||
let mut agent = AgentFactory::create("reference", provider).expect("default should exist");
|
||||
let mut agent = Agent::new(provider);
|
||||
|
||||
let config = ExtensionConfig::stdio(
|
||||
"developer",
|
||||
|
||||
@@ -1,76 +1,804 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::extension::{ExtensionConfig, ExtensionResult};
|
||||
use crate::providers::base::Provider;
|
||||
use crate::session;
|
||||
use crate::{message::Message, permission::PermissionConfirmation};
|
||||
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, Tool, ToolResult};
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::stream::BoxStream;
|
||||
|
||||
/// Session configuration for an agent
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionConfig {
|
||||
/// Unique identifier for the session
|
||||
pub id: session::Identifier,
|
||||
/// Working directory for the session
|
||||
pub working_dir: PathBuf,
|
||||
use serde_json::Value;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tracing::{debug, error, instrument, warn};
|
||||
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult, ToolInfo};
|
||||
use crate::agents::extension_manager::{get_parameter_names, ExtensionManager};
|
||||
use crate::agents::types::ToolResultReceiver;
|
||||
use crate::config::{Config, ExtensionConfigManager};
|
||||
use crate::message::{Message, MessageContent, ToolRequest};
|
||||
use crate::permission::{
|
||||
detect_read_only_tools, Permission, PermissionConfirmation, ToolPermissionStore,
|
||||
};
|
||||
use crate::providers::base::Provider;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::providers::toolshim::{
|
||||
augment_message_with_tool_calls, modify_system_prompt_for_tool_json, OllamaInterpreter,
|
||||
};
|
||||
use crate::session;
|
||||
use crate::token_counter::TokenCounter;
|
||||
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
||||
|
||||
use mcp_core::{
|
||||
prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult,
|
||||
};
|
||||
|
||||
use crate::agents::platform_tools::{
|
||||
self, PLATFORM_LIST_RESOURCES_TOOL_NAME, PLATFORM_READ_RESOURCE_TOOL_NAME,
|
||||
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
|
||||
};
|
||||
use crate::agents::prompt_manager::PromptManager;
|
||||
use crate::agents::types::SessionConfig;
|
||||
|
||||
use super::platform_tools::PLATFORM_ENABLE_EXTENSION_TOOL_NAME;
|
||||
use super::types::FrontendTool;
|
||||
|
||||
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
||||
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
|
||||
|
||||
/// The main goose Agent
///
/// Bundles the model provider with the extension, prompt, token-counting and
/// channel state that the methods in `impl Agent` operate on.
|
||||
pub struct Agent {
|
||||
/// Shared handle to the model provider; `provider()` hands out clones of it.
provider: Arc<dyn Provider>,
|
||||
/// Extension manager behind an async mutex; locked whenever tools or
/// extensions are listed, added or removed.
extension_manager: Mutex<ExtensionManager>,
|
||||
/// Frontend-provided tools, keyed by tool name.
frontend_tools: HashMap<String, FrontendTool>,
|
||||
/// Instructions for frontend tools; set when a frontend extension is added
/// and passed to the prompt manager when the system prompt is built.
frontend_instructions: Option<String>,
|
||||
/// Builds the system prompt used by `reply`.
prompt_manager: PromptManager,
|
||||
/// Token counter created from the provider's configured tokenizer name.
token_counter: TokenCounter,
|
||||
/// Sending half of the confirmation channel; `handle_confirmation` pushes
/// `(request_id, confirmation)` pairs into it.
confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
|
||||
/// Receiving half of the confirmation channel.
// NOTE(review): the consumer is not visible in this part of the file.
confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
|
||||
/// Sending half of the tool-result channel.
// NOTE(review): presumably the String is a request id, as on the
// confirmation channel — confirm against the sender.
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
/// Receiving half of the tool-result channel (an `Arc<Mutex<_>>` as built in `new`).
tool_result_rx: ToolResultReceiver,
|
||||
}
|
||||
|
||||
/// Core trait defining the behavior of an Agent
|
||||
#[async_trait]
|
||||
pub trait Agent: Send + Sync {
|
||||
/// Create a stream that yields each message as it's generated by the agent
|
||||
async fn reply(
|
||||
impl Agent {
|
||||
pub fn new(provider: Arc<dyn Provider>) -> Self {
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
// Create channels with buffer size 32 (adjust if needed)
|
||||
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
||||
let (tool_tx, tool_rx) = mpsc::channel(32);
|
||||
|
||||
Self {
|
||||
provider,
|
||||
extension_manager: Mutex::new(ExtensionManager::new()),
|
||||
frontend_tools: HashMap::new(),
|
||||
frontend_instructions: None,
|
||||
prompt_manager: PromptManager::new(),
|
||||
token_counter,
|
||||
confirmation_tx: confirm_tx,
|
||||
confirmation_rx: Mutex::new(confirm_rx),
|
||||
tool_result_tx: tool_tx,
|
||||
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a reference count clone to the provider
|
||||
pub fn provider(&self) -> Arc<dyn Provider> {
|
||||
Arc::clone(&self.provider)
|
||||
}
|
||||
|
||||
/// Check if a tool is a frontend tool
|
||||
pub fn is_frontend_tool(&self, name: &str) -> bool {
|
||||
self.frontend_tools.contains_key(name)
|
||||
}
|
||||
|
||||
/// Looks up a frontend tool by its name.
///
/// Returns `None` when no frontend tool is registered under `name`.
|
||||
pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> {
|
||||
self.frontend_tools.get(name)
|
||||
}
|
||||
|
||||
/// Get all tools from all clients with proper prefixing
|
||||
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
|
||||
let mut tools = self
|
||||
.extension_manager
|
||||
.lock()
|
||||
.await
|
||||
.get_prefixed_tools()
|
||||
.await?;
|
||||
|
||||
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
|
||||
for frontend_tool in self.frontend_tools.values() {
|
||||
tools.push(frontend_tool.tool.clone());
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
/// Dispatch a single tool call to the appropriate client
|
||||
#[instrument(skip(tool_call, extension_manager, request_id), fields(input, output))]
|
||||
async fn create_tool_future(
|
||||
extension_manager: &ExtensionManager,
|
||||
tool_call: mcp_core::tool::ToolCall,
|
||||
is_frontend_tool: bool,
|
||||
request_id: String,
|
||||
) -> (String, Result<Vec<Content>, ToolError>) {
|
||||
let result = if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
|
||||
// Check if the tool is read_resource and handle it separately
|
||||
extension_manager
|
||||
.read_resource(tool_call.arguments.clone())
|
||||
.await
|
||||
} else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME {
|
||||
extension_manager
|
||||
.list_resources(tool_call.arguments.clone())
|
||||
.await
|
||||
} else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME {
|
||||
extension_manager.search_available_extensions().await
|
||||
} else if is_frontend_tool {
|
||||
// For frontend tools, return an error indicating we need frontend execution
|
||||
Err(ToolError::ExecutionError(
|
||||
"Frontend tool execution required".to_string(),
|
||||
))
|
||||
} else {
|
||||
extension_manager
|
||||
.dispatch_tool_call(tool_call.clone())
|
||||
.await
|
||||
};
|
||||
|
||||
debug!(
|
||||
"input" = serde_json::to_string(&tool_call).unwrap(),
|
||||
"output" = serde_json::to_string(&result).unwrap(),
|
||||
);
|
||||
|
||||
(request_id, result)
|
||||
}
|
||||
|
||||
/// Truncates the messages to fit within the model's context window
|
||||
/// Ensures the last message is a user message and removes tool call-response pairs
|
||||
async fn truncate_messages(
|
||||
&self,
|
||||
messages: &mut Vec<Message>,
|
||||
estimate_factor: f32,
|
||||
system_prompt: &str,
|
||||
tools: &mut Vec<Tool>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Model's actual context limit
|
||||
let context_limit = self.provider.get_model_config().context_limit();
|
||||
|
||||
// Our conservative estimate of the **target** context limit
|
||||
// Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
|
||||
let context_limit = (context_limit as f32 * estimate_factor) as usize;
|
||||
|
||||
// Take into account the system prompt, and our tools input and subtract that from the
|
||||
// remaining context limit
|
||||
let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
|
||||
let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());
|
||||
|
||||
// Check if system prompt + tools exceed our context limit
|
||||
let remaining_tokens = context_limit
|
||||
.checked_sub(system_prompt_token_count)
|
||||
.and_then(|remaining| remaining.checked_sub(tools_token_count))
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("System prompt and tools exceed estimated context limit")
|
||||
})?;
|
||||
|
||||
let context_limit = remaining_tokens;
|
||||
|
||||
// Calculate current token count of each message, use count_chat_tokens to ensure we
|
||||
// capture the full content of the message, include ToolRequests and ToolResponses
|
||||
let mut token_counts: Vec<usize> = messages
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
self.token_counter
|
||||
.count_chat_tokens("", std::slice::from_ref(msg), &[])
|
||||
})
|
||||
.collect();
|
||||
|
||||
truncate_messages(
|
||||
messages,
|
||||
&mut token_counts,
|
||||
context_limit,
|
||||
&OldestFirstTruncation,
|
||||
)
|
||||
}
|
||||
|
||||
async fn enable_extension(
|
||||
extension_manager: &mut ExtensionManager,
|
||||
extension_name: String,
|
||||
request_id: String,
|
||||
) -> (String, Result<Vec<Content>, ToolError>) {
|
||||
let config = match ExtensionConfigManager::get_config_by_name(&extension_name) {
|
||||
Ok(Some(config)) => config,
|
||||
Ok(None) => {
|
||||
return (
|
||||
request_id,
|
||||
Err(ToolError::ExecutionError(format!(
|
||||
"Extension '{}' not found. Please check the extension name and try again.",
|
||||
extension_name
|
||||
))),
|
||||
)
|
||||
}
|
||||
Err(e) => {
|
||||
return (
|
||||
request_id,
|
||||
Err(ToolError::ExecutionError(format!(
|
||||
"Failed to get extension config: {}",
|
||||
e
|
||||
))),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let result = extension_manager
|
||||
.add_extension(config)
|
||||
.await
|
||||
.map(|_| {
|
||||
vec![Content::text(format!(
|
||||
"The extension '{}' has been installed successfully",
|
||||
extension_name
|
||||
))]
|
||||
})
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()));
|
||||
|
||||
(request_id, result)
|
||||
}
|
||||
|
||||
pub async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
|
||||
match &extension {
|
||||
ExtensionConfig::Frontend {
|
||||
name: _,
|
||||
tools,
|
||||
instructions,
|
||||
} => {
|
||||
// For frontend tools, just store them in the frontend_tools map
|
||||
for tool in tools {
|
||||
let frontend_tool = FrontendTool {
|
||||
name: tool.name.clone(),
|
||||
tool: tool.clone(),
|
||||
};
|
||||
self.frontend_tools.insert(tool.name.clone(), frontend_tool);
|
||||
}
|
||||
// Store instructions if provided, using "frontend" as the key
|
||||
if let Some(instructions) = instructions {
|
||||
self.frontend_instructions = Some(instructions.clone());
|
||||
} else {
|
||||
// Default frontend instructions if none provided
|
||||
self.frontend_instructions = Some(
|
||||
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let mut extension_manager = self.extension_manager.lock().await;
|
||||
let _ = extension_manager.add_extension(extension).await;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Lists the prefixed tools known to the extension manager.
///
/// A failure while collecting the tools yields an empty list. Frontend tools
/// are not added here.
pub async fn list_tools(&self) -> Vec<Tool> {
    self.extension_manager
        .lock()
        .await
        .get_prefixed_tools()
        .await
        .unwrap_or_default()
}
|
||||
|
||||
pub async fn remove_extension(&mut self, name: &str) {
|
||||
let mut extension_manager = self.extension_manager.lock().await;
|
||||
extension_manager
|
||||
.remove_extension(name)
|
||||
.await
|
||||
.expect("Failed to remove extension");
|
||||
}
|
||||
|
||||
pub async fn list_extensions(&self) -> Vec<String> {
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
extension_manager
|
||||
.list_extensions()
|
||||
.await
|
||||
.expect("Failed to list extensions")
|
||||
}
|
||||
|
||||
/// Handle a confirmation response for a tool request
|
||||
pub async fn handle_confirmation(
|
||||
&self,
|
||||
request_id: String,
|
||||
confirmation: PermissionConfirmation,
|
||||
) {
|
||||
if let Err(e) = self.confirmation_tx.send((request_id, confirmation)).await {
|
||||
error!("Failed to send confirmation: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self, messages, session), fields(user_message))]
|
||||
pub async fn reply(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session: Option<SessionConfig>,
|
||||
) -> Result<BoxStream<'_, Result<Message>>>;
|
||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
|
||||
let mut messages = messages.to_vec();
|
||||
let reply_span = tracing::Span::current();
|
||||
let mut extension_manager = self.extension_manager.lock().await;
|
||||
let mut tools = extension_manager.get_prefixed_tools().await?;
|
||||
let mut truncation_attempt: usize = 0;
|
||||
|
||||
/// Add a new MCP client to the agent
|
||||
async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()>;
|
||||
// Load settings from config
|
||||
let config = Config::global();
|
||||
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
|
||||
|
||||
/// Remove an extension by name
|
||||
async fn remove_extension(&mut self, name: &str);
|
||||
// we add in the 2 resource tools if any extensions support resources
|
||||
// TODO: make sure there is no collision with another extension's tool name
|
||||
if extension_manager.supports_resources() {
|
||||
tools.push(platform_tools::read_resource_tool());
|
||||
tools.push(platform_tools::list_resources_tool());
|
||||
}
|
||||
tools.push(platform_tools::search_available_extensions_tool());
|
||||
tools.push(platform_tools::enable_extension_tool());
|
||||
|
||||
/// List all extensions
|
||||
// TODO this needs to also include status so we can tell if extensions are dropped
|
||||
async fn list_extensions(&self) -> Vec<String>;
|
||||
let (tools_with_readonly_annotation, tools_without_annotation): (Vec<String>, Vec<String>) =
|
||||
tools.iter().fold((vec![], vec![]), |mut acc, tool| {
|
||||
match &tool.annotations {
|
||||
Some(annotations) => {
|
||||
if annotations.read_only_hint {
|
||||
acc.0.push(tool.name.clone());
|
||||
}
|
||||
}
|
||||
None => {
|
||||
acc.1.push(tool.name.clone());
|
||||
}
|
||||
}
|
||||
acc
|
||||
});
|
||||
|
||||
/// List the tools this agent has access to
|
||||
async fn list_tools(&self) -> Vec<Tool>;
|
||||
let config = self.provider.get_model_config();
|
||||
let extensions_info = extension_manager.get_extensions_info().await;
|
||||
let mut system_prompt = self
|
||||
.prompt_manager
|
||||
.build_system_prompt(extensions_info, self.frontend_instructions.clone());
|
||||
let mut toolshim_tools = vec![];
|
||||
if config.toolshim {
|
||||
// If tool interpretation is enabled, modify the system prompt to instruct to return JSON tool requests
|
||||
system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
|
||||
// make a copy of tools before empty
|
||||
toolshim_tools = tools.clone();
|
||||
// pass empty tools vector to provider completion since toolshim will handle tool calls instead
|
||||
tools = vec![];
|
||||
}
|
||||
|
||||
/// Pass through a JSON-RPC request to a specific extension
|
||||
async fn passthrough(&self, extension: &str, request: Value) -> ExtensionResult<Value>;
|
||||
// Set the user_message field in the span instead of creating a new event
|
||||
if let Some(content) = messages
|
||||
.last()
|
||||
.and_then(|msg| msg.content.first())
|
||||
.and_then(|c| c.as_text())
|
||||
{
|
||||
debug!("user_message" = &content);
|
||||
}
|
||||
|
||||
/// Add custom text to be included in the system prompt
|
||||
async fn extend_system_prompt(&mut self, extension: String);
|
||||
Ok(Box::pin(async_stream::try_stream! {
|
||||
let _reply_guard = reply_span.enter();
|
||||
loop {
|
||||
match self.provider().complete(
|
||||
&system_prompt,
|
||||
&messages,
|
||||
&tools,
|
||||
).await {
|
||||
Ok((mut response, usage)) => {
|
||||
// Post-process / structure the response only if tool interpretation is enabled
|
||||
if config.toolshim {
|
||||
let interpreter = OllamaInterpreter::new()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create OllamaInterpreter: {}", e))?;
|
||||
|
||||
/// Handle a confirmation response for a tool request
|
||||
async fn handle_confirmation(&self, request_id: String, confirmation: PermissionConfirmation);
|
||||
response = augment_message_with_tool_calls(&interpreter, response, &toolshim_tools).await?;
|
||||
}
|
||||
|
||||
/// Override the system prompt with custom text
|
||||
async fn override_system_prompt(&mut self, template: String);
|
||||
// record usage for the session in the session file
|
||||
if let Some(session) = session.clone() {
|
||||
// TODO: track session_id in langfuse tracing
|
||||
let session_file = session::get_path(session.id);
|
||||
let mut metadata = session::read_metadata(&session_file)?;
|
||||
metadata.working_dir = session.working_dir;
|
||||
metadata.total_tokens = usage.usage.total_tokens;
|
||||
metadata.input_tokens = usage.usage.input_tokens;
|
||||
metadata.output_tokens = usage.usage.output_tokens;
|
||||
// The message count is the number of messages in the session + 1 for the response
|
||||
// The message count does not include the tool response till next iteration
|
||||
metadata.message_count = messages.len() + 1;
|
||||
session::update_metadata(&session_file, &metadata).await?;
|
||||
}
|
||||
|
||||
/// Lists all prompts from all extensions
|
||||
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>>;
|
||||
// Reset truncation attempt
|
||||
truncation_attempt = 0;
|
||||
|
||||
/// Get a prompt result with the given name and arguments
|
||||
/// Returns the prompt text that would be used as user input
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult>;
|
||||
// Yield the assistant's response, but filter out frontend tool requests that we'll process separately
|
||||
let filtered_response = Message {
|
||||
role: response.role.clone(),
|
||||
created: response.created,
|
||||
content: response.content.iter().filter(|c| {
|
||||
if let MessageContent::ToolRequest(req) = c {
|
||||
// Only filter out frontend tool requests
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
return !self.is_frontend_tool(&tool_call.name);
|
||||
}
|
||||
}
|
||||
true
|
||||
}).cloned().collect(),
|
||||
};
|
||||
yield filtered_response.clone();
|
||||
|
||||
/// Get the plan prompt, which will be used with the planner (reasoner) model
|
||||
async fn get_plan_prompt(&self) -> anyhow::Result<String>;
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
/// Get a reference to the provider used by this agent
|
||||
async fn provider(&self) -> Arc<Box<dyn Provider>>;
|
||||
// First collect any tool requests
|
||||
let tool_requests: Vec<&ToolRequest> = response.content
|
||||
.iter()
|
||||
.filter_map(|content| content.as_tool_request())
|
||||
.collect();
|
||||
|
||||
/// Handle a tool result from the frontend
|
||||
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>);
|
||||
if tool_requests.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Process tool requests depending on goose_mode
|
||||
let mut message_tool_response = Message::user();
|
||||
|
||||
// First handle any frontend tool requests
|
||||
let mut remaining_requests = Vec::new();
|
||||
for request in &tool_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
if self.is_frontend_tool(&tool_call.name) {
|
||||
// Send frontend tool request and wait for response
|
||||
yield Message::assistant().with_frontend_tool_request(
|
||||
request.id.clone(),
|
||||
Ok(tool_call.clone())
|
||||
);
|
||||
|
||||
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
|
||||
message_tool_response = message_tool_response.with_tool_response(id, result);
|
||||
}
|
||||
} else {
|
||||
remaining_requests.push(request);
|
||||
}
|
||||
} else {
|
||||
remaining_requests.push(request);
|
||||
}
|
||||
}
|
||||
|
||||
// Split tool requests into enable_extension and others
|
||||
let (enable_extension_requests, non_enable_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
|
||||
.into_iter()
|
||||
.partition(|req| {
|
||||
req.tool_call.as_ref()
|
||||
.map(|call| call.name == PLATFORM_ENABLE_EXTENSION_TOOL_NAME)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
let (search_extension_requests, _non_search_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
|
||||
.into_iter()
|
||||
.partition(|req| {
|
||||
req.tool_call.as_ref()
|
||||
.map(|call| call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
// Clone goose_mode once before the match to avoid move issues
|
||||
let mode = goose_mode.clone();
|
||||
|
||||
// If there are install extension requests, always require confirmation
|
||||
// or if goose_mode is approve or smart_approve, check permissions for all tools
|
||||
if !enable_extension_requests.is_empty() || mode.as_str() == "approve" || mode.as_str() == "smart_approve" {
|
||||
let mut needs_confirmation = Vec::<&ToolRequest>::new();
|
||||
let mut approved_tools = Vec::new();
|
||||
let mut llm_detect_candidates = Vec::<&ToolRequest>::new();
|
||||
let mut detected_read_only_tools = Vec::new();
|
||||
|
||||
// If approve mode or smart approve mode, check permissions for all tools
|
||||
if mode.as_str() == "approve" || mode.as_str() == "smart_approve" {
|
||||
let store = ToolPermissionStore::load()?;
|
||||
for request in &non_enable_extension_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
// Regular permission checking for other tools
|
||||
if tools_with_readonly_annotation.contains(&tool_call.name) {
|
||||
approved_tools.push((request.id.clone(), tool_call));
|
||||
} else if let Some(allowed) = store.check_permission(request) {
|
||||
if allowed {
|
||||
// Instead of executing immediately, collect approved tools
|
||||
approved_tools.push((request.id.clone(), tool_call));
|
||||
} else {
|
||||
// If the tool doesn't have any annotation, we can use llm-as-a-judge to check permission.
|
||||
if tools_without_annotation.contains(&tool_call.name) {
|
||||
llm_detect_candidates.push(request);
|
||||
}
|
||||
needs_confirmation.push(request);
|
||||
}
|
||||
} else {
|
||||
if tools_without_annotation.contains(&tool_call.name) {
|
||||
llm_detect_candidates.push(request);
|
||||
}
|
||||
needs_confirmation.push(request);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Only check read-only status for tools needing confirmation
|
||||
if !llm_detect_candidates.is_empty() && mode == "smart_approve" {
|
||||
detected_read_only_tools = detect_read_only_tools(self.provider(), llm_detect_candidates.clone()).await;
|
||||
// Remove install extensions from read-only tools
|
||||
if !enable_extension_requests.is_empty() {
|
||||
detected_read_only_tools.retain(|tool_name| {
|
||||
!enable_extension_requests.iter().any(|req| {
|
||||
req.tool_call.as_ref()
|
||||
.map(|call| call.name == *tool_name)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle pre-approved and read-only tools in parallel
|
||||
let mut tool_futures = Vec::new();
|
||||
let mut install_results = Vec::new();
|
||||
|
||||
// Handle install extension requests
|
||||
for request in &enable_extension_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let confirmation = Message::user().with_enable_extension_request(
|
||||
request.id.clone(),
|
||||
tool_call.arguments.get("extension_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
);
|
||||
yield confirmation;
|
||||
|
||||
let mut rx = self.confirmation_rx.lock().await;
|
||||
while let Some((req_id, extension_confirmation)) = rx.recv().await {
|
||||
if req_id == request.id {
|
||||
if extension_confirmation.permission == Permission::AllowOnce || extension_confirmation.permission == Permission::AlwaysAllow {
|
||||
let extension_name = tool_call.arguments.get("extension_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let install_result = Self::enable_extension(&mut extension_manager, extension_name, request.id.clone()).await;
|
||||
install_results.push(install_result);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process read-only tools
|
||||
for request in &needs_confirmation {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let is_frontend_tool = self.is_frontend_tool(&tool_call.name);
|
||||
// Skip confirmation if the tool_call.name is in the read_only_tools list
|
||||
if detected_read_only_tools.contains(&tool_call.name) {
|
||||
let tool_future = Self::create_tool_future(&extension_manager, tool_call, is_frontend_tool, request.id.clone());
|
||||
tool_futures.push(tool_future);
|
||||
} else {
|
||||
let confirmation = Message::user().with_tool_confirmation_request(
|
||||
request.id.clone(),
|
||||
tool_call.name.clone(),
|
||||
tool_call.arguments.clone(),
|
||||
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
|
||||
);
|
||||
yield confirmation;
|
||||
|
||||
// Wait for confirmation response through the channel
|
||||
let mut rx = self.confirmation_rx.lock().await;
|
||||
while let Some((req_id, tool_confirmation)) = rx.recv().await {
|
||||
if req_id == request.id {
|
||||
let confirmed = tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow;
|
||||
if confirmed {
|
||||
// Add this tool call to the futures collection
|
||||
let tool_future = Self::create_tool_future(&extension_manager, tool_call, is_frontend_tool, request.id.clone());
|
||||
tool_futures.push(tool_future);
|
||||
} else {
|
||||
// User declined - add declined response
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text(
|
||||
"The user has declined to run this tool. \
|
||||
DO NOT attempt to call this tool again. \
|
||||
If there are no alternative methods to proceed, clearly explain the situation and STOP.")]),
|
||||
);
|
||||
}
|
||||
break; // Exit the loop once the matching `req_id` is found
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all tool calls to complete
|
||||
let results = futures::future::join_all(tool_futures).await;
|
||||
for (request_id, output) in results {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request_id,
|
||||
output,
|
||||
);
|
||||
}
|
||||
|
||||
// Check if any install results had errors before processing them
|
||||
let all_successful = !install_results.iter().any(|(_, result)| result.is_err());
|
||||
|
||||
for (request_id, output) in install_results {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request_id,
|
||||
output
|
||||
);
|
||||
}
|
||||
|
||||
// Update system prompt and tools if all installations were successful
|
||||
if all_successful {
|
||||
let extensions_info = extension_manager.get_extensions_info().await;
|
||||
system_prompt = self.prompt_manager.build_system_prompt(extensions_info, self.frontend_instructions.clone());
|
||||
tools = extension_manager.get_prefixed_tools().await?;
|
||||
}
|
||||
}
|
||||
|
||||
if mode.as_str() == "auto" || !search_extension_requests.is_empty() {
|
||||
let mut tool_futures = Vec::new();
|
||||
// Process non_enable_extension_requests and search_extension_requests without duplicates
|
||||
let mut processed_ids = HashSet::new();
|
||||
|
||||
for request in non_enable_extension_requests.iter().chain(search_extension_requests.iter()) {
|
||||
if processed_ids.insert(request.id.clone()) {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let is_frontend_tool = self.is_frontend_tool(&tool_call.name);
|
||||
let tool_future = Self::create_tool_future(&extension_manager, tool_call, is_frontend_tool, request.id.clone());
|
||||
tool_futures.push(tool_future);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all tool calls to complete
|
||||
let results = futures::future::join_all(tool_futures).await;
|
||||
for (request_id, output) in results {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request_id,
|
||||
output,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if mode.as_str() == "chat" {
|
||||
// Skip all tool calls in chat mode
|
||||
// Skip search extension requests since they were already processed
|
||||
let non_search_non_enable_extension_requests = non_enable_extension_requests.iter()
|
||||
.filter(|req| {
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
tool_call.name != PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
for request in non_search_non_enable_extension_requests {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text(
|
||||
"Let the user know the tool call was skipped in Goose chat mode. \
|
||||
DO NOT apologize for skipping the tool call. DO NOT say sorry. \
|
||||
Provide an explanation of what the tool call would do, structured as a \
|
||||
plan for the user. Again, DO NOT apologize. \
|
||||
**Example Plan:**\n \
|
||||
1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \
|
||||
2. **Outline Steps** - Break down the steps.\n \
|
||||
If needed, adjust the explanation based on user preferences or questions."
|
||||
)]),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
yield message_tool_response.clone();
|
||||
|
||||
messages.push(response);
|
||||
messages.push(message_tool_response);
|
||||
},
|
||||
Err(ProviderError::ContextLengthExceeded(_)) => {
|
||||
if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
|
||||
// Create an error message & terminate the stream
|
||||
// the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
|
||||
// at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
|
||||
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
|
||||
break;
|
||||
}
|
||||
|
||||
truncation_attempt += 1;
|
||||
warn!("Context length exceeded. Truncation Attempt: {}/{}.", truncation_attempt, MAX_TRUNCATION_ATTEMPTS);
|
||||
|
||||
// Decay the estimate factor as we make more truncation attempts
|
||||
// Estimate factor decays like this over time: 0.9, 0.81, 0.729, ...
|
||||
let estimate_factor: f32 = ESTIMATE_FACTOR_DECAY.powi(truncation_attempt as i32);
|
||||
|
||||
// release the lock before truncation to prevent deadlock
|
||||
drop(extension_manager);
|
||||
|
||||
if let Err(err) = self.truncate_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
|
||||
yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Re-acquire the lock
|
||||
extension_manager = self.extension_manager.lock().await;
|
||||
|
||||
// Retry the loop after truncation
|
||||
continue;
|
||||
},
|
||||
Err(e) => {
|
||||
// Create an error message & terminate the stream
|
||||
error!("Error: {}", e);
|
||||
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Yield control back to the scheduler to prevent blocking
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Extend the system prompt with one line of additional instruction
|
||||
pub async fn extend_system_prompt(&mut self, instruction: String) {
|
||||
self.prompt_manager.add_system_prompt_extra(instruction);
|
||||
}
|
||||
|
||||
/// Override the system prompt with a custom template
|
||||
pub async fn override_system_prompt(&mut self, template: String) {
|
||||
self.prompt_manager.set_system_prompt_override(template);
|
||||
}
|
||||
|
||||
pub async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
extension_manager
|
||||
.list_prompts()
|
||||
.await
|
||||
.expect("Failed to list prompts")
|
||||
}
|
||||
|
||||
pub async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
|
||||
let extension_manager = self.extension_manager.lock().await;
|
||||
|
||||
// First find which extension has this prompt
|
||||
let prompts = extension_manager
|
||||
.list_prompts()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
|
||||
|
||||
if let Some(extension) = prompts
|
||||
.iter()
|
||||
.find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name))
|
||||
.map(|(extension, _)| extension)
|
||||
{
|
||||
return extension_manager
|
||||
.get_prompt(extension, name, arguments)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get prompt: {}", e));
|
||||
}
|
||||
|
||||
Err(anyhow!("Prompt '{}' not found", name))
|
||||
}
|
||||
|
||||
pub async fn get_plan_prompt(&self) -> anyhow::Result<String> {
|
||||
let mut extension_manager = self.extension_manager.lock().await;
|
||||
let tools = extension_manager.get_prefixed_tools().await?;
|
||||
let tools_info = tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
ToolInfo::new(
|
||||
&tool.name,
|
||||
&tool.description,
|
||||
get_parameter_names(&tool),
|
||||
None,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let plan_prompt = extension_manager.get_planning_prompt(tools_info).await;
|
||||
|
||||
Ok(plan_prompt)
|
||||
}
|
||||
|
||||
pub async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||
tracing::error!("Failed to send tool result: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,9 +262,9 @@ impl std::fmt::Display for ExtensionConfig {
|
||||
/// Information about the extension used for building prompts
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct ExtensionInfo {
|
||||
name: String,
|
||||
instructions: String,
|
||||
has_resources: bool,
|
||||
pub name: String,
|
||||
pub instructions: String,
|
||||
pub has_resources: bool,
|
||||
}
|
||||
|
||||
impl ExtensionInfo {
|
||||
|
||||
@@ -8,12 +8,11 @@ use std::sync::Arc;
|
||||
use std::sync::LazyLock;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, instrument};
|
||||
use tracing::debug;
|
||||
|
||||
use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo};
|
||||
use crate::config::{Config, ExtensionManager};
|
||||
use crate::config::ExtensionConfigManager;
|
||||
use crate::prompt_template;
|
||||
use crate::providers::base::Provider;
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
|
||||
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
|
||||
@@ -26,22 +25,11 @@ static DEFAULT_TIMESTAMP: LazyLock<DateTime<Utc>> =
|
||||
|
||||
/// Shared handle to a boxed MCP client: reference-counted (`Arc`) and guarded
/// by a `Mutex`, so the same client can be held in several places and locked
/// before use.
type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>;
|
||||
|
||||
/// A frontend tool that will be executed by the frontend rather than an extension
|
||||
#[derive(Clone)]
|
||||
pub struct FrontendTool {
|
||||
/// Name of the tool; the visible registration code copies it from `tool.name`
/// and uses it as the lookup key.
pub name: String,
|
||||
/// Full tool definition, cloned into the tool list offered to the model.
pub tool: Tool,
|
||||
}
|
||||
|
||||
/// Manages MCP clients and their interactions
|
||||
pub struct Capabilities {
|
||||
/// Manages Goose extensions / MCP clients and their interactions
|
||||
pub struct ExtensionManager {
|
||||
clients: HashMap<String, McpClientBox>,
|
||||
frontend_tools: HashMap<String, FrontendTool>,
|
||||
instructions: HashMap<String, String>,
|
||||
resource_capable_extensions: HashSet<String>,
|
||||
provider: Arc<Box<dyn Provider>>,
|
||||
system_prompt_override: Option<String>,
|
||||
system_prompt_extensions: Vec<String>,
|
||||
}
|
||||
|
||||
/// A flattened representation of a resource used by the agent to prepare inference
|
||||
@@ -99,17 +87,19 @@ pub fn get_parameter_names(tool: &Tool) -> Vec<String> {
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
impl Capabilities {
|
||||
/// Create a new Capabilities with the specified provider
|
||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||
impl Default for ExtensionManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtensionManager {
|
||||
/// Create a new ExtensionManager instance
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
clients: HashMap::new(),
|
||||
frontend_tools: HashMap::new(),
|
||||
instructions: HashMap::new(),
|
||||
resource_capable_extensions: HashSet::new(),
|
||||
provider: Arc::new(provider),
|
||||
system_prompt_override: None,
|
||||
system_prompt_extensions: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,134 +111,106 @@ impl Capabilities {
|
||||
// TODO IMPORTANT need to ensure this times out if the extension command is broken!
|
||||
pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> {
|
||||
let sanitized_name = normalize(config.key().to_string());
|
||||
|
||||
match &config {
|
||||
ExtensionConfig::Frontend {
|
||||
name: _,
|
||||
tools,
|
||||
instructions,
|
||||
let mut client: Box<dyn McpClientTrait> = match &config {
|
||||
ExtensionConfig::Sse {
|
||||
uri, envs, timeout, ..
|
||||
} => {
|
||||
// For frontend tools, just store them in the frontend_tools map
|
||||
for tool in tools {
|
||||
let frontend_tool = FrontendTool {
|
||||
name: tool.name.clone(),
|
||||
tool: tool.clone(),
|
||||
};
|
||||
self.frontend_tools.insert(tool.name.clone(), frontend_tool);
|
||||
}
|
||||
// Store instructions if provided, using "frontend" as the key
|
||||
if let Some(instructions) = instructions {
|
||||
self.instructions
|
||||
.insert("frontend".to_string(), instructions.clone());
|
||||
}
|
||||
Ok(())
|
||||
let transport = SseTransport::new(uri, envs.get_env());
|
||||
let handle = transport.start().await?;
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
_ => {
|
||||
let mut client: Box<dyn McpClientTrait> = match &config {
|
||||
ExtensionConfig::Sse {
|
||||
uri, envs, timeout, ..
|
||||
} => {
|
||||
let transport = SseTransport::new(uri, envs.get_env());
|
||||
let handle = transport.start().await?;
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
ExtensionConfig::Stdio {
|
||||
cmd,
|
||||
args,
|
||||
envs,
|
||||
timeout,
|
||||
..
|
||||
} => {
|
||||
let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env());
|
||||
let handle = transport.start().await?;
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
ExtensionConfig::Builtin {
|
||||
name,
|
||||
display_name: _,
|
||||
timeout,
|
||||
} => {
|
||||
// For builtin extensions, we run the current executable with mcp and extension name
|
||||
let cmd = std::env::current_exe()
|
||||
.expect("should find the current executable")
|
||||
.to_str()
|
||||
.expect("should resolve executable to string path")
|
||||
.to_string();
|
||||
let transport = StdioTransport::new(
|
||||
&cmd,
|
||||
vec!["mcp".to_string(), name.clone()],
|
||||
HashMap::new(),
|
||||
);
|
||||
let handle = transport.start().await?;
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
// Initialize the client with default capabilities
|
||||
let info = ClientInfo {
|
||||
name: "goose".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
};
|
||||
let capabilities = ClientCapabilities::default();
|
||||
|
||||
let init_result = client
|
||||
.initialize(info, capabilities)
|
||||
.await
|
||||
.map_err(|e| ExtensionError::Initialization(config.clone(), e))?;
|
||||
|
||||
// Store instructions if provided
|
||||
if let Some(instructions) = init_result.instructions {
|
||||
self.instructions
|
||||
.insert(sanitized_name.clone(), instructions);
|
||||
}
|
||||
|
||||
// if the server is capable if resources we track it
|
||||
if init_result.capabilities.resources.is_some() {
|
||||
self.resource_capable_extensions
|
||||
.insert(sanitized_name.clone());
|
||||
}
|
||||
|
||||
// Store the client using the provided name
|
||||
self.clients
|
||||
.insert(sanitized_name.clone(), Arc::new(Mutex::new(client)));
|
||||
|
||||
Ok(())
|
||||
ExtensionConfig::Stdio {
|
||||
cmd,
|
||||
args,
|
||||
envs,
|
||||
timeout,
|
||||
..
|
||||
} => {
|
||||
let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env());
|
||||
let handle = transport.start().await?;
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
ExtensionConfig::Builtin {
|
||||
name,
|
||||
display_name: _,
|
||||
timeout,
|
||||
} => {
|
||||
// For builtin extensions, we run the current executable with mcp and extension name
|
||||
let cmd = std::env::current_exe()
|
||||
.expect("should find the current executable")
|
||||
.to_str()
|
||||
.expect("should resolve executable to string path")
|
||||
.to_string();
|
||||
let transport = StdioTransport::new(
|
||||
&cmd,
|
||||
vec!["mcp".to_string(), name.clone()],
|
||||
HashMap::new(),
|
||||
);
|
||||
let handle = transport.start().await?;
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
// Initialize the client with default capabilities
|
||||
let info = ClientInfo {
|
||||
name: "goose".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
};
|
||||
let capabilities = ClientCapabilities::default();
|
||||
|
||||
let init_result = client
|
||||
.initialize(info, capabilities)
|
||||
.await
|
||||
.map_err(|e| ExtensionError::Initialization(config.clone(), e))?;
|
||||
|
||||
// Store instructions if provided
|
||||
if let Some(instructions) = init_result.instructions {
|
||||
self.instructions
|
||||
.insert(sanitized_name.clone(), instructions);
|
||||
}
|
||||
|
||||
// if the server is capable if resources we track it
|
||||
if init_result.capabilities.resources.is_some() {
|
||||
self.resource_capable_extensions
|
||||
.insert(sanitized_name.clone());
|
||||
}
|
||||
|
||||
// Store the client using the provided name
|
||||
self.clients
|
||||
.insert(sanitized_name.clone(), Arc::new(Mutex::new(client)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Add a system prompt extension
|
||||
pub fn add_system_prompt_extension(&mut self, extension: String) {
|
||||
self.system_prompt_extensions.push(extension);
|
||||
}
|
||||
|
||||
/// Override the system prompt with custom text
|
||||
pub fn set_system_prompt_override(&mut self, template: String) {
|
||||
self.system_prompt_override = Some(template);
|
||||
}
|
||||
|
||||
/// Get a reference to the provider
|
||||
pub fn provider(&self) -> Arc<Box<dyn Provider>> {
|
||||
Arc::clone(&self.provider)
|
||||
/// Get extensions info
|
||||
pub async fn get_extensions_info(&self) -> Vec<ExtensionInfo> {
|
||||
self.clients
|
||||
.keys()
|
||||
.map(|name| {
|
||||
let instructions = self.instructions.get(name).cloned().unwrap_or_default();
|
||||
let has_resources = self.resource_capable_extensions.contains(name);
|
||||
ExtensionInfo::new(name, &instructions, has_resources)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get aggregated usage statistics
|
||||
@@ -269,11 +231,6 @@ impl Capabilities {
|
||||
pub async fn get_prefixed_tools(&mut self) -> ExtensionResult<Vec<Tool>> {
|
||||
let mut tools = Vec::new();
|
||||
|
||||
// Add frontend tools directly - they don't need prefixing since they're already uniquely named
|
||||
for frontend_tool in self.frontend_tools.values() {
|
||||
tools.push(frontend_tool.tool.clone());
|
||||
}
|
||||
|
||||
// Add tools from MCP extensions with prefixing
|
||||
for (name, client) in &self.clients {
|
||||
let client_guard = client.lock().await;
|
||||
@@ -297,6 +254,7 @@ impl Capabilities {
|
||||
client_tools = client_guard.list_tools(client_tools.next_cursor).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
}
|
||||
|
||||
@@ -353,76 +311,6 @@ impl Capabilities {
|
||||
prompt_template::render_global_file("plan.md", &context).expect("Prompt should render")
|
||||
}
|
||||
|
||||
/// Get the extension prompt including client instructions
|
||||
pub async fn get_system_prompt(&self) -> String {
|
||||
let mut context: HashMap<&str, Value> = HashMap::new();
|
||||
|
||||
let mut extensions_info: Vec<ExtensionInfo> = self
|
||||
.clients
|
||||
.keys()
|
||||
.map(|name| {
|
||||
let instructions = self.instructions.get(name).cloned().unwrap_or_default();
|
||||
let has_resources = self.resource_capable_extensions.contains(name);
|
||||
ExtensionInfo::new(name, &instructions, has_resources)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Add frontend tools as a special extension if any exist
|
||||
if !self.frontend_tools.is_empty() {
|
||||
let name = "frontend";
|
||||
let instructions = self.instructions.get(name).cloned().unwrap_or_else(||
|
||||
"The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string()
|
||||
);
|
||||
extensions_info.push(ExtensionInfo::new(name, &instructions, false));
|
||||
}
|
||||
context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
|
||||
|
||||
let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
|
||||
context.insert("current_date_time", Value::String(current_date_time));
|
||||
|
||||
// Conditionally load the override prompt or the global system prompt
|
||||
let base_prompt = if let Some(override_prompt) = &self.system_prompt_override {
|
||||
prompt_template::render_inline_once(override_prompt, &context)
|
||||
.expect("Prompt should render")
|
||||
} else {
|
||||
prompt_template::render_global_file("system.md", &context)
|
||||
.expect("Prompt should render")
|
||||
};
|
||||
|
||||
let mut system_prompt_extensions = self.system_prompt_extensions.clone();
|
||||
let config = Config::global();
|
||||
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
|
||||
if goose_mode == "chat" {
|
||||
system_prompt_extensions.push(
|
||||
"Right now you are in the chat only mode, no access to any tool use and system."
|
||||
.to_string(),
|
||||
);
|
||||
} else {
|
||||
system_prompt_extensions
|
||||
.push("Right now you are *NOT* in the chat only mode and have access to tool use and system.".to_string());
|
||||
}
|
||||
|
||||
if system_prompt_extensions.is_empty() {
|
||||
base_prompt
|
||||
} else {
|
||||
format!(
|
||||
"{}\n\n# Additional Instructions:\n\n{}",
|
||||
base_prompt,
|
||||
system_prompt_extensions.join("\n\n")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a tool is a frontend tool
|
||||
pub fn is_frontend_tool(&self, name: &str) -> bool {
|
||||
self.frontend_tools.contains_key(name)
|
||||
}
|
||||
|
||||
/// Get a reference to a frontend tool
|
||||
pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> {
|
||||
self.frontend_tools.get(name)
|
||||
}
|
||||
|
||||
/// Find and return a reference to the appropriate client for a tool call
|
||||
fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(&str, McpClientBox)> {
|
||||
self.clients
|
||||
@@ -432,7 +320,7 @@ impl Capabilities {
|
||||
}
|
||||
|
||||
// Function that gets executed for read_resource tool
|
||||
async fn read_resource(&self, params: Value) -> Result<Vec<Content>, ToolError> {
|
||||
pub async fn read_resource(&self, params: Value) -> Result<Vec<Content>, ToolError> {
|
||||
let uri = params
|
||||
.get("uri")
|
||||
.and_then(|v| v.as_str())
|
||||
@@ -544,7 +432,7 @@ impl Capabilities {
|
||||
})
|
||||
}
|
||||
|
||||
async fn list_resources(&self, params: Value) -> Result<Vec<Content>, ToolError> {
|
||||
pub async fn list_resources(&self, params: Value) -> Result<Vec<Content>, ToolError> {
|
||||
let extension = params.get("extension").and_then(|v| v.as_str());
|
||||
|
||||
match extension {
|
||||
@@ -594,42 +482,26 @@ impl Capabilities {
|
||||
}
|
||||
}
|
||||
|
||||
/// Dispatch a single tool call to the appropriate client
|
||||
#[instrument(skip(self, tool_call), fields(input, output))]
|
||||
pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> ToolResult<Vec<Content>> {
|
||||
let result = if tool_call.name == "platform__read_resource" {
|
||||
// Check if the tool is read_resource and handle it separately
|
||||
self.read_resource(tool_call.arguments.clone()).await
|
||||
} else if tool_call.name == "platform__list_resources" {
|
||||
self.list_resources(tool_call.arguments.clone()).await
|
||||
} else if tool_call.name == "platform__search_available_extensions" {
|
||||
self.search_available_extensions().await
|
||||
} else if self.is_frontend_tool(&tool_call.name) {
|
||||
// For frontend tools, return an error indicating we need frontend execution
|
||||
Err(ToolError::ExecutionError(
|
||||
"Frontend tool execution required".to_string(),
|
||||
))
|
||||
} else {
|
||||
// Else, dispatch tool call based on the prefix naming convention
|
||||
let (client_name, client) = self
|
||||
.get_client_for_tool(&tool_call.name)
|
||||
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
|
||||
// Dispatch tool call based on the prefix naming convention
|
||||
let (client_name, client) = self
|
||||
.get_client_for_tool(&tool_call.name)
|
||||
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
|
||||
|
||||
// Strip the client name and the `__` separator; what remains is the tool name
|
||||
let tool_name = tool_call
|
||||
.name
|
||||
.strip_prefix(client_name)
|
||||
.and_then(|s| s.strip_prefix("__"))
|
||||
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
|
||||
// Strip the client name and the `__` separator; what remains is the tool name
|
||||
let tool_name = tool_call
|
||||
.name
|
||||
.strip_prefix(client_name)
|
||||
.and_then(|s| s.strip_prefix("__"))
|
||||
.ok_or_else(|| ToolError::NotFound(tool_call.name.clone()))?;
|
||||
|
||||
let client_guard = client.lock().await;
|
||||
let client_guard = client.lock().await;
|
||||
|
||||
client_guard
|
||||
.call_tool(tool_name, tool_call.clone().arguments)
|
||||
.await
|
||||
.map(|result| result.content)
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))
|
||||
};
|
||||
let result = client_guard
|
||||
.call_tool(tool_name, tool_call.clone().arguments)
|
||||
.await
|
||||
.map(|result| result.content)
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()));
|
||||
|
||||
debug!(
|
||||
"input" = serde_json::to_string(&tool_call).unwrap(),
|
||||
@@ -725,7 +597,7 @@ impl Capabilities {
|
||||
|
||||
// First get disabled extensions from current config
|
||||
let mut disabled_extensions: Vec<String> = vec![];
|
||||
for extension in ExtensionManager::get_all().expect("should load extensions") {
|
||||
for extension in ExtensionConfigManager::get_all().expect("should load extensions") {
|
||||
if !extension.enabled {
|
||||
let config = extension.config.clone();
|
||||
let description = match &config {
|
||||
@@ -775,10 +647,6 @@ impl Capabilities {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::message::Message;
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
|
||||
use crate::providers::errors::ProviderError;
|
||||
use mcp_client::client::Error;
|
||||
use mcp_client::client::McpClientTrait;
|
||||
use mcp_core::protocol::{
|
||||
@@ -787,35 +655,6 @@ mod tests {
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
// Mock Provider implementation for testing
|
||||
#[derive(Clone)]
|
||||
struct MockProvider {
|
||||
model_config: ModelConfig,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Provider for MockProvider {
|
||||
fn metadata() -> ProviderMetadata {
|
||||
ProviderMetadata::empty()
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model_config.clone()
|
||||
}
|
||||
|
||||
async fn complete(
|
||||
&self,
|
||||
_system: &str,
|
||||
_messages: &[Message],
|
||||
_tools: &[Tool],
|
||||
) -> anyhow::Result<(Message, ProviderUsage), ProviderError> {
|
||||
Ok((
|
||||
Message::assistant().with_text("Mock response"),
|
||||
ProviderUsage::new("mock".to_string(), Usage::default()),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
struct MockClient {}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -871,74 +710,68 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_get_client_for_tool() {
|
||||
let mock_model_config =
|
||||
ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
|
||||
|
||||
let mut capabilities = Capabilities::new(Box::new(MockProvider {
|
||||
model_config: mock_model_config,
|
||||
}));
|
||||
let mut extension_manager = ExtensionManager::new();
|
||||
|
||||
// Add some mock clients
|
||||
capabilities.clients.insert(
|
||||
extension_manager.clients.insert(
|
||||
normalize("test_client".to_string()),
|
||||
Arc::new(Mutex::new(Box::new(MockClient {}))),
|
||||
);
|
||||
|
||||
capabilities.clients.insert(
|
||||
extension_manager.clients.insert(
|
||||
normalize("__client".to_string()),
|
||||
Arc::new(Mutex::new(Box::new(MockClient {}))),
|
||||
);
|
||||
|
||||
capabilities.clients.insert(
|
||||
extension_manager.clients.insert(
|
||||
normalize("__cli__ent__".to_string()),
|
||||
Arc::new(Mutex::new(Box::new(MockClient {}))),
|
||||
);
|
||||
|
||||
capabilities.clients.insert(
|
||||
extension_manager.clients.insert(
|
||||
normalize("client 🚀".to_string()),
|
||||
Arc::new(Mutex::new(Box::new(MockClient {}))),
|
||||
);
|
||||
|
||||
// Test basic case
|
||||
assert!(capabilities
|
||||
assert!(extension_manager
|
||||
.get_client_for_tool("test_client__tool")
|
||||
.is_some());
|
||||
|
||||
// Test leading underscores
|
||||
assert!(capabilities.get_client_for_tool("__client__tool").is_some());
|
||||
assert!(extension_manager
|
||||
.get_client_for_tool("__client__tool")
|
||||
.is_some());
|
||||
|
||||
// Test multiple underscores in client name, and ending with __
|
||||
assert!(capabilities
|
||||
assert!(extension_manager
|
||||
.get_client_for_tool("__cli__ent____tool")
|
||||
.is_some());
|
||||
|
||||
// Test unicode in tool name, "client 🚀" should become "client_"
|
||||
assert!(capabilities.get_client_for_tool("client___tool").is_some());
|
||||
assert!(extension_manager
|
||||
.get_client_for_tool("client___tool")
|
||||
.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dispatch_tool_call() {
|
||||
// test that dispatch_tool_call parses out the sanitized name correctly, and extracts
|
||||
// tool_names
|
||||
let mock_model_config =
|
||||
ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
|
||||
|
||||
let mut capabilities = Capabilities::new(Box::new(MockProvider {
|
||||
model_config: mock_model_config,
|
||||
}));
|
||||
let mut extension_manager = ExtensionManager::new();
|
||||
|
||||
// Add some mock clients
|
||||
capabilities.clients.insert(
|
||||
extension_manager.clients.insert(
|
||||
normalize("test_client".to_string()),
|
||||
Arc::new(Mutex::new(Box::new(MockClient {}))),
|
||||
);
|
||||
|
||||
capabilities.clients.insert(
|
||||
extension_manager.clients.insert(
|
||||
normalize("__cli__ent__".to_string()),
|
||||
Arc::new(Mutex::new(Box::new(MockClient {}))),
|
||||
);
|
||||
|
||||
capabilities.clients.insert(
|
||||
extension_manager.clients.insert(
|
||||
normalize("client 🚀".to_string()),
|
||||
Arc::new(Mutex::new(Box::new(MockClient {}))),
|
||||
);
|
||||
@@ -949,7 +782,7 @@ mod tests {
|
||||
arguments: json!({}),
|
||||
};
|
||||
|
||||
let result = capabilities.dispatch_tool_call(tool_call).await;
|
||||
let result = extension_manager.dispatch_tool_call(tool_call).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let tool_call = ToolCall {
|
||||
@@ -957,7 +790,7 @@ mod tests {
|
||||
arguments: json!({}),
|
||||
};
|
||||
|
||||
let result = capabilities.dispatch_tool_call(tool_call).await;
|
||||
let result = extension_manager.dispatch_tool_call(tool_call).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// verify a multiple underscores dispatch
|
||||
@@ -966,7 +799,7 @@ mod tests {
|
||||
arguments: json!({}),
|
||||
};
|
||||
|
||||
let result = capabilities.dispatch_tool_call(tool_call).await;
|
||||
let result = extension_manager.dispatch_tool_call(tool_call).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Test unicode in tool name, "client 🚀" should become "client_"
|
||||
@@ -975,7 +808,7 @@ mod tests {
|
||||
arguments: json!({}),
|
||||
};
|
||||
|
||||
let result = capabilities.dispatch_tool_call(tool_call).await;
|
||||
let result = extension_manager.dispatch_tool_call(tool_call).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let tool_call = ToolCall {
|
||||
@@ -983,7 +816,7 @@ mod tests {
|
||||
arguments: json!({}),
|
||||
};
|
||||
|
||||
let result = capabilities.dispatch_tool_call(tool_call).await;
|
||||
let result = extension_manager.dispatch_tool_call(tool_call).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// this should error out, specifically for an ToolError::ExecutionError
|
||||
@@ -992,7 +825,9 @@ mod tests {
|
||||
arguments: json!({}),
|
||||
};
|
||||
|
||||
let result = capabilities.dispatch_tool_call(invalid_tool_call).await;
|
||||
let result = extension_manager
|
||||
.dispatch_tool_call(invalid_tool_call)
|
||||
.await;
|
||||
assert!(matches!(
|
||||
result.err().unwrap(),
|
||||
ToolError::ExecutionError(_)
|
||||
@@ -1005,7 +840,9 @@ mod tests {
|
||||
arguments: json!({}),
|
||||
};
|
||||
|
||||
let result = capabilities.dispatch_tool_call(invalid_tool_call).await;
|
||||
let result = extension_manager
|
||||
.dispatch_tool_call(invalid_tool_call)
|
||||
.await;
|
||||
assert!(matches!(result.err().unwrap(), ToolError::NotFound(_)));
|
||||
}
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{OnceLock, RwLock};
|
||||
|
||||
pub use super::Agent;
|
||||
use crate::config::Config;
|
||||
use crate::providers::base::Provider;
|
||||
|
||||
type AgentConstructor = Box<dyn Fn(Box<dyn Provider>) -> Box<dyn Agent> + Send + Sync>;
|
||||
|
||||
// Use std::sync::RwLock for interior mutability
|
||||
static AGENT_REGISTRY: OnceLock<RwLock<HashMap<&'static str, AgentConstructor>>> = OnceLock::new();
|
||||
|
||||
/// Initialize the registry if it hasn't been initialized
|
||||
fn registry() -> &'static RwLock<HashMap<&'static str, AgentConstructor>> {
|
||||
AGENT_REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
/// Register a new agent version
|
||||
pub fn register_agent(
|
||||
version: &'static str,
|
||||
constructor: impl Fn(Box<dyn Provider>) -> Box<dyn Agent> + Send + Sync + 'static,
|
||||
) {
|
||||
let registry = registry();
|
||||
if let Ok(mut map) = registry.write() {
|
||||
map.insert(version, Box::new(constructor));
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AgentFactory;
|
||||
|
||||
impl AgentFactory {
|
||||
/// Create a new agent instance of the specified version
|
||||
pub fn create(version: &str, provider: Box<dyn Provider>) -> Option<Box<dyn Agent>> {
|
||||
let registry = registry();
|
||||
let map = registry
|
||||
.read()
|
||||
.expect("should be able to read the registry");
|
||||
let constructor = map.get(version)?;
|
||||
Some(constructor(provider))
|
||||
}
|
||||
|
||||
/// Get a list of all available agent versions
|
||||
pub fn available_versions() -> Vec<&'static str> {
|
||||
registry()
|
||||
.read()
|
||||
.map(|map| map.keys().copied().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn configured_version() -> String {
|
||||
let config = Config::global();
|
||||
config
|
||||
.get_param::<String>("GOOSE_AGENT")
|
||||
.unwrap_or_else(|_| Self::default_version().to_string())
|
||||
}
|
||||
|
||||
/// Get the default version name
|
||||
pub fn default_version() -> &'static str {
|
||||
"truncate"
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate the registration function for one agent type.
///
/// `register_agent!(version, AgentType)` expands to a `#[ctor::ctor]` function
/// whose name is `__register_agent_` followed by the version (joined via
/// `paste`). That function calls `register_agent` with a constructor that
/// boxes `AgentType::new(provider)`.
#[macro_export]
macro_rules! register_agent {
    ($name:expr, $agent:ty) => {
        paste::paste! {
            #[ctor::ctor]
            #[allow(non_snake_case)]
            fn [<__register_agent_ $name>]() {
                // The closure stays inline so its return value is coerced to
                // the boxed trait object the registry stores.
                $crate::agents::factory::register_agent($name, |llm_provider| {
                    Box::new(<$agent>::new(llm_provider))
                });
            }
        }
    };
}
|
||||
@@ -1,13 +1,12 @@
|
||||
mod agent;
|
||||
pub mod capabilities;
|
||||
pub mod extension;
|
||||
mod factory;
|
||||
mod reference;
|
||||
mod summarize;
|
||||
mod truncate;
|
||||
pub mod extension_manager;
|
||||
pub mod platform_tools;
|
||||
pub mod prompt_manager;
|
||||
mod types;
|
||||
|
||||
pub use agent::{Agent, SessionConfig};
|
||||
pub use capabilities::Capabilities;
|
||||
pub use extension::{ExtensionConfig, ExtensionResult};
|
||||
pub use factory::{register_agent, AgentFactory};
|
||||
pub use agent::Agent;
|
||||
pub use extension::ExtensionConfig;
|
||||
pub use extension_manager::ExtensionManager;
|
||||
pub use prompt_manager::PromptManager;
|
||||
pub use types::{FrontendTool, SessionConfig};
|
||||
|
||||
112
crates/goose/src/agents/platform_tools.rs
Normal file
112
crates/goose/src/agents/platform_tools.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use indoc::indoc;
|
||||
use mcp_core::tool::{Tool, ToolAnnotations};
|
||||
use serde_json::json;
|
||||
|
||||
pub const PLATFORM_READ_RESOURCE_TOOL_NAME: &str = "platform__read_resource";
|
||||
pub const PLATFORM_LIST_RESOURCES_TOOL_NAME: &str = "platform__list_resources";
|
||||
pub const PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME: &str =
|
||||
"platform__search_available_extensions";
|
||||
pub const PLATFORM_ENABLE_EXTENSION_TOOL_NAME: &str = "platform__enable_extension";
|
||||
|
||||
pub fn read_resource_tool() -> Tool {
|
||||
Tool::new(
|
||||
PLATFORM_READ_RESOURCE_TOOL_NAME.to_string(),
|
||||
indoc! {r#"
|
||||
Read a resource from an extension.
|
||||
|
||||
Resources allow extensions to share data that provide context to LLMs, such as
|
||||
files, database schemas, or application-specific information. This tool searches for the
|
||||
resource URI in the provided extension, and reads in the resource content. If no extension
|
||||
is provided, the tool will search all extensions for the resource.
|
||||
"#}.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["uri"],
|
||||
"properties": {
|
||||
"uri": {"type": "string", "description": "Resource URI"},
|
||||
"extension_name": {"type": "string", "description": "Optional extension name"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("Read a resource".to_string()),
|
||||
read_only_hint: true,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn list_resources_tool() -> Tool {
|
||||
Tool::new(
|
||||
PLATFORM_LIST_RESOURCES_TOOL_NAME.to_string(),
|
||||
indoc! {r#"
|
||||
List resources from an extension(s).
|
||||
|
||||
Resources allow extensions to share data that provide context to LLMs, such as
|
||||
files, database schemas, or application-specific information. This tool lists resources
|
||||
in the provided extension, and returns a list for the user to browse. If no extension
|
||||
is provided, the tool will search all extensions for the resource.
|
||||
"#}
|
||||
.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"extension_name": {"type": "string", "description": "Optional extension name"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("List resources".to_string()),
|
||||
read_only_hint: true,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn search_available_extensions_tool() -> Tool {
|
||||
Tool::new(
|
||||
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME.to_string(),
|
||||
"Searches for additional extensions available to help complete tasks.
|
||||
Use this tool when you're unable to find a specific feature or functionality you need to complete your task, or when standard approaches aren't working.
|
||||
These extensions might provide the exact tools needed to solve your problem.
|
||||
If you find a relevant one, consider using your tools to enable it.".to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": [],
|
||||
"properties": {}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("Discover extensions".to_string()),
|
||||
read_only_hint: true,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn enable_extension_tool() -> Tool {
|
||||
Tool::new(
|
||||
PLATFORM_ENABLE_EXTENSION_TOOL_NAME.to_string(),
|
||||
"Enable extensions to help complete tasks.
|
||||
Enable an extension by providing the extension name.
|
||||
"
|
||||
.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["extension_name"],
|
||||
"properties": {
|
||||
"extension_name": {"type": "string", "description": "The name of the extension to enable"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("Enable extensions".to_string()),
|
||||
read_only_hint: false,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
)
|
||||
}
|
||||
95
crates/goose/src/agents/prompt_manager.rs
Normal file
95
crates/goose/src/agents/prompt_manager.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
use chrono::Utc;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::agents::extension::ExtensionInfo;
|
||||
use crate::{config::Config, prompt_template};
|
||||
|
||||
/// Holds the pieces used to assemble the system prompt.
#[derive(Debug, Clone)]
pub struct PromptManager {
    /// Template that replaces the global `system.md` prompt when set.
    system_prompt_override: Option<String>,
    /// Extra instructions appended after the rendered base prompt.
    system_prompt_extras: Vec<String>,
}
|
||||
|
||||
impl Default for PromptManager {
|
||||
fn default() -> Self {
|
||||
PromptManager::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptManager {
|
||||
pub fn new() -> Self {
|
||||
PromptManager {
|
||||
system_prompt_override: None,
|
||||
system_prompt_extras: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an additional instruction to the system prompt
|
||||
pub fn add_system_prompt_extra(&mut self, instruction: String) {
|
||||
self.system_prompt_extras.push(instruction);
|
||||
}
|
||||
|
||||
/// Override the system prompt with custom text
|
||||
pub fn set_system_prompt_override(&mut self, template: String) {
|
||||
self.system_prompt_override = Some(template);
|
||||
}
|
||||
|
||||
/// Build the final system prompt
|
||||
///
|
||||
/// * `extensions_info` – extension information for each extension/MCP
|
||||
/// * `frontend_instructions` – instructions for the "frontend" tool
|
||||
pub fn build_system_prompt(
|
||||
&self,
|
||||
extensions_info: Vec<ExtensionInfo>,
|
||||
frontend_instructions: Option<String>,
|
||||
) -> String {
|
||||
let mut context: HashMap<&str, Value> = HashMap::new();
|
||||
let mut extensions_info = extensions_info.clone();
|
||||
|
||||
// Add frontend instructions to extensions_info to simplify json rendering
|
||||
if let Some(frontend_instructions) = frontend_instructions {
|
||||
extensions_info.push(ExtensionInfo::new(
|
||||
"frontend",
|
||||
&frontend_instructions,
|
||||
false,
|
||||
));
|
||||
}
|
||||
|
||||
context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
|
||||
|
||||
let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
|
||||
context.insert("current_date_time", Value::String(current_date_time));
|
||||
|
||||
// Conditionally load the override prompt or the global system prompt
|
||||
let base_prompt = if let Some(override_prompt) = &self.system_prompt_override {
|
||||
prompt_template::render_inline_once(override_prompt, &context)
|
||||
.expect("Prompt should render")
|
||||
} else {
|
||||
prompt_template::render_global_file("system.md", &context)
|
||||
.expect("Prompt should render")
|
||||
};
|
||||
|
||||
let mut system_prompt_extras = self.system_prompt_extras.clone();
|
||||
let config = Config::global();
|
||||
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
|
||||
if goose_mode == "chat" {
|
||||
system_prompt_extras.push(
|
||||
"Right now you are in the chat only mode, no access to any tool use and system."
|
||||
.to_string(),
|
||||
);
|
||||
} else {
|
||||
system_prompt_extras
|
||||
.push("Right now you are *NOT* in the chat only mode and have access to tool use and system.".to_string());
|
||||
}
|
||||
|
||||
if system_prompt_extras.is_empty() {
|
||||
base_prompt
|
||||
} else {
|
||||
format!(
|
||||
"{}\n\n# Additional Instructions:\n\n{}",
|
||||
base_prompt,
|
||||
system_prompt_extras.join("\n\n")
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,319 +0,0 @@
|
||||
/// A simplified agent implementation used as a reference
|
||||
/// It makes no attempt to handle context limits, and cannot read resources
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
use super::agent::SessionConfig;
|
||||
use super::capabilities::get_parameter_names;
|
||||
use super::extension::ToolInfo;
|
||||
use super::types::ToolResultReceiver;
|
||||
use super::Agent;
|
||||
use crate::agents::capabilities::Capabilities;
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
||||
use crate::message::{Message, ToolRequest};
|
||||
use crate::permission::PermissionConfirmation;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::token_counter::TokenCounter;
|
||||
use crate::{register_agent, session};
|
||||
use anyhow::{anyhow, Result};
|
||||
use indoc::indoc;
|
||||
use mcp_core::tool::{Tool, ToolAnnotations};
|
||||
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolResult};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Reference implementation of an Agent
|
||||
pub struct ReferenceAgent {
|
||||
capabilities: Mutex<Capabilities>,
|
||||
_token_counter: TokenCounter,
|
||||
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
tool_result_rx: ToolResultReceiver,
|
||||
}
|
||||
|
||||
impl ReferenceAgent {
|
||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
let (tx, rx) = mpsc::channel(32);
|
||||
Self {
|
||||
capabilities: Mutex::new(Capabilities::new(provider)),
|
||||
_token_counter: token_counter,
|
||||
tool_result_tx: tx,
|
||||
tool_result_rx: Arc::new(Mutex::new(rx)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Agent for ReferenceAgent {
|
||||
async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities.add_extension(extension).await
|
||||
}
|
||||
|
||||
async fn list_tools(&self) -> Vec<Tool> {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities.get_prefixed_tools().await.unwrap_or_default()
|
||||
}
|
||||
|
||||
async fn remove_extension(&mut self, name: &str) {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities
|
||||
.remove_extension(name)
|
||||
.await
|
||||
.expect("Failed to remove extension");
|
||||
}
|
||||
|
||||
async fn list_extensions(&self) -> Vec<String> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
capabilities
|
||||
.list_extensions()
|
||||
.await
|
||||
.expect("Failed to list extensions")
|
||||
}
|
||||
|
||||
async fn passthrough(&self, _extension: &str, _request: Value) -> ExtensionResult<Value> {
|
||||
// TODO implement
|
||||
Ok(Value::Null)
|
||||
}
|
||||
|
||||
async fn handle_confirmation(
|
||||
&self,
|
||||
_request_id: String,
|
||||
_confirmation: PermissionConfirmation,
|
||||
) {
|
||||
// TODO implement
|
||||
}
|
||||
|
||||
#[instrument(skip(self, messages, session), fields(user_message))]
|
||||
async fn reply(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session: Option<SessionConfig>,
|
||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
|
||||
let mut messages = messages.to_vec();
|
||||
let reply_span = tracing::Span::current();
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
let mut tools = capabilities.get_prefixed_tools().await?;
|
||||
// we add in the read_resource tool by default
|
||||
// TODO: make sure there is no collision with another extension's tool name
|
||||
let read_resource_tool = Tool::new(
|
||||
"platform__read_resource".to_string(),
|
||||
indoc! {r#"
|
||||
Read a resource from an extension.
|
||||
|
||||
Resources allow extensions to share data that provide context to LLMs, such as
|
||||
files, database schemas, or application-specific information. This tool searches for the
|
||||
resource URI in the provided extension, and reads in the resource content. If no extension
|
||||
is provided, the tool will search all extensions for the resource.
|
||||
"#}.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["uri"],
|
||||
"properties": {
|
||||
"uri": {"type": "string", "description": "Resource URI"},
|
||||
"extension_name": {"type": "string", "description": "Optional extension name"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("Read a resource".to_string()),
|
||||
read_only_hint: true,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
);
|
||||
|
||||
let list_resources_tool = Tool::new(
|
||||
"platform__list_resources".to_string(),
|
||||
indoc! {r#"
|
||||
List resources from an extension(s).
|
||||
|
||||
Resources allow extensions to share data that provide context to LLMs, such as
|
||||
files, database schemas, or application-specific information. This tool lists resources
|
||||
in the provided extension, and returns a list for the user to browse. If no extension
|
||||
is provided, the tool will search all extensions for the resource.
|
||||
"#}.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"extension_name": {"type": "string", "description": "Optional extension name"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("List resources".to_string()),
|
||||
read_only_hint: true,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
);
|
||||
|
||||
if capabilities.supports_resources() {
|
||||
tools.push(read_resource_tool);
|
||||
tools.push(list_resources_tool);
|
||||
}
|
||||
|
||||
let system_prompt = capabilities.get_system_prompt().await;
|
||||
|
||||
// Set the user_message field in the span instead of creating a new event
|
||||
if let Some(content) = messages
|
||||
.last()
|
||||
.and_then(|msg| msg.content.first())
|
||||
.and_then(|c| c.as_text())
|
||||
{
|
||||
debug!("user_message" = &content);
|
||||
}
|
||||
|
||||
Ok(Box::pin(async_stream::try_stream! {
|
||||
let _reply_guard = reply_span.enter();
|
||||
loop {
|
||||
// Get completion from provider
|
||||
let (response, usage) = capabilities.provider().complete(
|
||||
&system_prompt,
|
||||
&messages,
|
||||
&tools,
|
||||
).await?;
|
||||
|
||||
// record usage for the session in the session file
|
||||
if let Some(session) = session.clone() {
|
||||
// TODO: track session_id in langfuse tracing
|
||||
let session_file = session::get_path(session.id);
|
||||
let mut metadata = session::read_metadata(&session_file)?;
|
||||
metadata.working_dir = session.working_dir;
|
||||
metadata.total_tokens = usage.usage.total_tokens;
|
||||
metadata.input_tokens = usage.usage.input_tokens;
|
||||
metadata.output_tokens = usage.usage.output_tokens;
|
||||
// The message count is the number of messages in the session + 1 for the response
|
||||
// The message count does not include the tool response till next iteration
|
||||
metadata.message_count = messages.len() + 1;
|
||||
session::update_metadata(&session_file, &metadata).await?;
|
||||
}
|
||||
|
||||
// Yield the assistant's response
|
||||
yield response.clone();
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
// First collect any tool requests
|
||||
let tool_requests: Vec<&ToolRequest> = response.content
|
||||
.iter()
|
||||
.filter_map(|content| content.as_tool_request())
|
||||
.collect();
|
||||
|
||||
if tool_requests.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Then dispatch each in parallel
|
||||
let mut message_tool_response = Message::user();
|
||||
for request in tool_requests {
|
||||
if let Ok(tool_call) = &request.tool_call {
|
||||
// Check if it's a frontend tool
|
||||
if capabilities.is_frontend_tool(&tool_call.name) {
|
||||
// Send frontend tool request and wait for response
|
||||
yield Message::assistant().with_frontend_tool_request(
|
||||
request.id.clone(),
|
||||
request.tool_call.clone()
|
||||
);
|
||||
|
||||
// Wait for the result using our channel
|
||||
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
|
||||
message_tool_response = message_tool_response.with_tool_response(id, result);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle regular tool calls
|
||||
let result = capabilities.dispatch_tool_call(tool_call.clone()).await;
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
result,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
yield message_tool_response.clone();
|
||||
|
||||
messages.push(response);
|
||||
messages.push(message_tool_response);
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
async fn extend_system_prompt(&mut self, extension: String) {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities.add_system_prompt_extension(extension);
|
||||
}
|
||||
|
||||
async fn override_system_prompt(&mut self, template: String) {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities.set_system_prompt_override(template);
|
||||
}
|
||||
|
||||
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
capabilities
|
||||
.list_prompts()
|
||||
.await
|
||||
.expect("Failed to list prompts")
|
||||
}
|
||||
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
|
||||
// First find which extension has this prompt
|
||||
let prompts = capabilities
|
||||
.list_prompts()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
|
||||
|
||||
if let Some(extension) = prompts
|
||||
.iter()
|
||||
.find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name))
|
||||
.map(|(extension, _)| extension)
|
||||
{
|
||||
return capabilities
|
||||
.get_prompt(extension, name, arguments)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get prompt: {}", e));
|
||||
}
|
||||
|
||||
Err(anyhow!("Prompt '{}' not found", name))
|
||||
}
|
||||
|
||||
async fn get_plan_prompt(&self) -> anyhow::Result<String> {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
let tools = capabilities.get_prefixed_tools().await?;
|
||||
let tools_info = tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
ToolInfo::new(
|
||||
&tool.name,
|
||||
&tool.description,
|
||||
get_parameter_names(&tool),
|
||||
None,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let plan_prompt = capabilities.get_planning_prompt(tools_info).await;
|
||||
|
||||
Ok(plan_prompt)
|
||||
}
|
||||
|
||||
/// Returns the provider handle (`Arc<Box<dyn Provider>>`) obtained from the capabilities.
async fn provider(&self) -> Arc<Box<dyn Provider>> {
    self.capabilities.lock().await.provider()
}
|
||||
|
||||
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||
tracing::error!("Failed to send tool result: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Passes `ReferenceAgent` and the name "reference" to the `register_agent!` macro
// (macro defined elsewhere; presumably this makes the agent selectable by that name).
register_agent!("reference", ReferenceAgent);
|
||||
@@ -1,519 +0,0 @@
|
||||
//! A summarize agent that lets the model summarize the conversation when the history exceeds the
|
||||
//! model's context limit. If the model fails to summarize, it falls back to the legacy
|
||||
//! truncation method. Still cannot read resources.
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use mcp_core::tool::ToolAnnotations;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, error, instrument, warn};
|
||||
|
||||
use super::agent::SessionConfig;
|
||||
use super::capabilities::get_parameter_names;
|
||||
use super::extension::ToolInfo;
|
||||
use super::Agent;
|
||||
use crate::agents::capabilities::Capabilities;
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
||||
use crate::config::Config;
|
||||
use crate::memory_condense::condense_messages;
|
||||
use crate::message::{Message, ToolRequest};
|
||||
use crate::permission::detect_read_only_tools;
|
||||
use crate::permission::Permission;
|
||||
use crate::permission::PermissionConfirmation;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::register_agent;
|
||||
use crate::session;
|
||||
use crate::token_counter::TokenCounter;
|
||||
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
||||
use anyhow::{anyhow, Result};
|
||||
use indoc::indoc;
|
||||
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolResult};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Upper bound on consecutive context-shrinking retries in `reply` before it gives up
/// and reports a context-length error to the user.
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
||||
/// Base of the shrinking estimate factor: attempt `n` targets `0.9^n` of the model's
/// context limit (0.9, 0.81, 0.729, ...).
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
|
||||
|
||||
/// Summarize implementation of an Agent
///
/// When the provider reports that the context length was exceeded, `reply` asks
/// `summarize_messages` to condense the history and retries.
|
||||
pub struct SummarizeAgent {
|
||||
/// Shared capabilities (provider, extensions, prompts); locked by every trait method.
capabilities: Mutex<Capabilities>,
|
||||
/// Token counter built from the provider's tokenizer name; used to budget the context.
token_counter: TokenCounter,
|
||||
/// Sending half of the confirmation channel, fed by `handle_confirmation`.
confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
|
||||
/// Receiving half of the confirmation channel, awaited by `reply` in "approve" mode.
confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
|
||||
/// Sending half of the tool-result channel, fed by `handle_tool_result`.
/// NOTE(review): `new` does not store the matching receiver.
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
}
|
||||
|
||||
impl SummarizeAgent {
|
||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
// Create channels with buffer size 32 (adjust if needed)
|
||||
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
||||
let (tool_tx, _tool_rx) = mpsc::channel(32);
|
||||
|
||||
Self {
|
||||
capabilities: Mutex::new(Capabilities::new(provider)),
|
||||
token_counter,
|
||||
confirmation_tx: confirm_tx,
|
||||
confirmation_rx: Mutex::new(confirm_rx),
|
||||
tool_result_tx: tool_tx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates the messages to fit within the model's context window
|
||||
/// Ensures the last message is a user message and removes tool call-response pairs
|
||||
async fn summarize_messages(
|
||||
&self,
|
||||
messages: &mut Vec<Message>,
|
||||
estimate_factor: f32,
|
||||
system_prompt: &str,
|
||||
tools: &mut Vec<Tool>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Model's actual context limit
|
||||
let context_limit = self
|
||||
.capabilities
|
||||
.lock()
|
||||
.await
|
||||
.provider()
|
||||
.get_model_config()
|
||||
.context_limit();
|
||||
|
||||
// Our conservative estimate of the **target** context limit
|
||||
// Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
|
||||
let context_limit = (context_limit as f32 * estimate_factor) as usize;
|
||||
|
||||
// Take into account the system prompt, and our tools input and subtract that from the
|
||||
// remaining context limit
|
||||
let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
|
||||
let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());
|
||||
|
||||
// Check if system prompt + tools exceed our context limit
|
||||
let remaining_tokens = context_limit
|
||||
.checked_sub(system_prompt_token_count)
|
||||
.and_then(|remaining| remaining.checked_sub(tools_token_count))
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("System prompt and tools exceed estimated context limit")
|
||||
})?;
|
||||
|
||||
let context_limit = remaining_tokens;
|
||||
|
||||
// Calculate current token count of each message, use count_chat_tokens to ensure we
|
||||
// capture the full content of the message, include ToolRequests and ToolResponses
|
||||
let mut token_counts: Vec<usize> = messages
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
self.token_counter
|
||||
.count_chat_tokens("", std::slice::from_ref(msg), &[])
|
||||
})
|
||||
.collect();
|
||||
|
||||
let capabilities_guard = self.capabilities.lock().await;
|
||||
if condense_messages(
|
||||
&capabilities_guard,
|
||||
&self.token_counter,
|
||||
messages,
|
||||
&mut token_counts,
|
||||
context_limit,
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
// Fallback to the legacy truncator if the model fails to condense the messages.
|
||||
truncate_messages(
|
||||
messages,
|
||||
&mut token_counts,
|
||||
context_limit,
|
||||
&OldestFirstTruncation,
|
||||
)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Agent for SummarizeAgent {
|
||||
/// Adds `extension` to the agent's capabilities and returns the outcome of that call.
async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
    let mut guard = self.capabilities.lock().await;
    let outcome = guard.add_extension(extension).await;
    outcome
}
|
||||
|
||||
/// Lists the prefixed tools currently exposed by the capabilities.
///
/// If the tools cannot be fetched, an empty list is returned.
async fn list_tools(&self) -> Vec<Tool> {
    let mut guard = self.capabilities.lock().await;
    match guard.get_prefixed_tools().await {
        Ok(tools) => tools,
        Err(_) => Vec::new(),
    }
}
|
||||
|
||||
async fn remove_extension(&mut self, name: &str) {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities
|
||||
.remove_extension(name)
|
||||
.await
|
||||
.expect("Failed to remove extension");
|
||||
}
|
||||
|
||||
async fn list_extensions(&self) -> Vec<String> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
capabilities
|
||||
.list_extensions()
|
||||
.await
|
||||
.expect("Failed to list extensions")
|
||||
}
|
||||
|
||||
async fn passthrough(&self, _extension: &str, _request: Value) -> ExtensionResult<Value> {
|
||||
// TODO implement
|
||||
Ok(Value::Null)
|
||||
}
|
||||
|
||||
/// Handle a confirmation response for a tool request
|
||||
async fn handle_confirmation(&self, request_id: String, confirmation: PermissionConfirmation) {
|
||||
if let Err(e) = self.confirmation_tx.send((request_id, confirmation)).await {
|
||||
error!("Failed to send confirmation: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self, messages, session), fields(user_message))]
|
||||
async fn reply(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session: Option<SessionConfig>,
|
||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
|
||||
let mut messages = messages.to_vec();
|
||||
let reply_span = tracing::Span::current();
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
let mut tools = capabilities.get_prefixed_tools().await?;
|
||||
let mut truncation_attempt: usize = 0;
|
||||
|
||||
// Load settings from config
|
||||
let config = Config::global();
|
||||
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
|
||||
|
||||
// we add in the 2 resource tools if any extensions support resources
|
||||
// TODO: make sure there is no collision with another extension's tool name
|
||||
let read_resource_tool = Tool::new(
|
||||
"platform__read_resource".to_string(),
|
||||
indoc! {r#"
|
||||
Read a resource from an extension.
|
||||
|
||||
Resources allow extensions to share data that provide context to LLMs, such as
|
||||
files, database schemas, or application-specific information. This tool searches for the
|
||||
resource URI in the provided extension, and reads in the resource content. If no extension
|
||||
is provided, the tool will search all extensions for the resource.
|
||||
"#}.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["uri"],
|
||||
"properties": {
|
||||
"uri": {"type": "string", "description": "Resource URI"},
|
||||
"extension_name": {"type": "string", "description": "Optional extension name"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("Read a resource".to_string()),
|
||||
read_only_hint: true,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
);
|
||||
|
||||
let list_resources_tool = Tool::new(
|
||||
"platform__list_resources".to_string(),
|
||||
indoc! {r#"
|
||||
List resources from an extension(s).
|
||||
|
||||
Resources allow extensions to share data that provide context to LLMs, such as
|
||||
files, database schemas, or application-specific information. This tool lists resources
|
||||
in the provided extension, and returns a list for the user to browse. If no extension
|
||||
is provided, the tool will search all extensions for the resource.
|
||||
"#}.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"extension_name": {"type": "string", "description": "Optional extension name"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("List resources".to_string()),
|
||||
read_only_hint: true,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
);
|
||||
|
||||
if capabilities.supports_resources() {
|
||||
tools.push(read_resource_tool);
|
||||
tools.push(list_resources_tool);
|
||||
}
|
||||
|
||||
let system_prompt = capabilities.get_system_prompt().await;
|
||||
|
||||
// Set the user_message field in the span instead of creating a new event
|
||||
if let Some(content) = messages
|
||||
.last()
|
||||
.and_then(|msg| msg.content.first())
|
||||
.and_then(|c| c.as_text())
|
||||
{
|
||||
debug!("user_message" = &content);
|
||||
}
|
||||
|
||||
Ok(Box::pin(async_stream::try_stream! {
|
||||
let _reply_guard = reply_span.enter();
|
||||
loop {
|
||||
match capabilities.provider().complete(
|
||||
&system_prompt,
|
||||
&messages,
|
||||
&tools,
|
||||
).await {
|
||||
Ok((response, usage)) => {
|
||||
// record usage for the session in the session file
|
||||
if let Some(session) = session.clone() {
|
||||
// TODO: track session_id in langfuse tracing
|
||||
let session_file = session::get_path(session.id);
|
||||
let mut metadata = session::read_metadata(&session_file)?;
|
||||
metadata.working_dir = session.working_dir;
|
||||
metadata.total_tokens = usage.usage.total_tokens;
|
||||
metadata.input_tokens = usage.usage.input_tokens;
|
||||
metadata.output_tokens = usage.usage.output_tokens;
|
||||
// The message count is the number of messages in the session + 1 for the response
|
||||
// The message count does not include the tool response till next iteration
|
||||
metadata.message_count = messages.len() + 1;
|
||||
session::update_metadata(&session_file, &metadata).await?;
|
||||
}
|
||||
|
||||
// Reset truncation attempt
|
||||
truncation_attempt = 0;
|
||||
|
||||
// Yield the assistant's response
|
||||
yield response.clone();
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
|
||||
// First collect any tool requests
|
||||
let tool_requests: Vec<&ToolRequest> = response.content
|
||||
.iter()
|
||||
.filter_map(|content| content.as_tool_request())
|
||||
.collect();
|
||||
|
||||
if tool_requests.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Process tool requests depending on goose_mode
|
||||
let mut message_tool_response = Message::user();
|
||||
// Clone goose_mode once before the match to avoid move issues
|
||||
let mode = goose_mode.clone();
|
||||
match mode.as_str() {
|
||||
"approve" => {
|
||||
let read_only_tools = detect_read_only_tools(&capabilities, tool_requests.clone()).await;
|
||||
for request in &tool_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
// Skip confirmation if the tool_call.name is in the read_only_tools list
|
||||
if read_only_tools.contains(&tool_call.name) {
|
||||
let output = capabilities.dispatch_tool_call(tool_call).await;
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
output,
|
||||
);
|
||||
} else {
|
||||
let confirmation = Message::user().with_tool_confirmation_request(
|
||||
request.id.clone(),
|
||||
tool_call.name.clone(),
|
||||
tool_call.arguments.clone(),
|
||||
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
|
||||
);
|
||||
yield confirmation;
|
||||
|
||||
// Wait for confirmation response through the channel
|
||||
let mut rx = self.confirmation_rx.lock().await;
|
||||
// Loop the recv until we have a matched req_id due to potential duplicate messages.
|
||||
while let Some((req_id, tool_confirmation)) = rx.recv().await {
|
||||
if req_id == request.id {
|
||||
if tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow {
|
||||
// User approved - dispatch the tool call
|
||||
let output = capabilities.dispatch_tool_call(tool_call).await;
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
output,
|
||||
);
|
||||
} else {
|
||||
// User declined - add declined response
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text("User declined to run this tool.")]),
|
||||
);
|
||||
}
|
||||
break; // Exit the loop once the matching `req_id` is found
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"chat" => {
|
||||
// Skip all tool calls in chat mode
|
||||
for request in &tool_requests {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text(
|
||||
"The following tool call was skipped in Goose chat mode. \
|
||||
In chat mode, you cannot run tool calls, instead, you can \
|
||||
only provide a detailed plan to the user. Provide an \
|
||||
explanation of the proposed tool call as if it were a plan. \
|
||||
Only if the user asks, provide a short explanation to the \
|
||||
user that they could consider running the tool above on \
|
||||
their own or with a different goose mode."
|
||||
)]),
|
||||
);
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
if mode != "auto" {
|
||||
warn!("Unknown GOOSE_MODE: {mode:?}. Defaulting to 'auto' mode.");
|
||||
}
|
||||
// Process tool requests in parallel
|
||||
let mut tool_futures = Vec::new();
|
||||
for request in &tool_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
tool_futures.push(async {
|
||||
let output = capabilities.dispatch_tool_call(tool_call).await;
|
||||
(request.id.clone(), output)
|
||||
});
|
||||
}
|
||||
}
|
||||
// Wait for all tool calls to complete
|
||||
let results = futures::future::join_all(tool_futures).await;
|
||||
for (request_id, output) in results {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request_id,
|
||||
output,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
yield message_tool_response.clone();
|
||||
|
||||
messages.push(response);
|
||||
messages.push(message_tool_response);
|
||||
},
|
||||
Err(ProviderError::ContextLengthExceeded(_)) => {
|
||||
if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
|
||||
// Create an error message & terminate the stream
|
||||
// the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
|
||||
// at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
|
||||
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
|
||||
break;
|
||||
}
|
||||
|
||||
truncation_attempt += 1;
|
||||
warn!("Context length exceeded. Truncation Attempt: {}/{}.", truncation_attempt, MAX_TRUNCATION_ATTEMPTS);
|
||||
|
||||
// Decay the estimate factor as we make more truncation attempts
|
||||
// Estimate factor decays like this over time: 0.9, 0.81, 0.729, ...
|
||||
let estimate_factor: f32 = ESTIMATE_FACTOR_DECAY.powi(truncation_attempt as i32);
|
||||
|
||||
// release the lock before truncation to prevent deadlock
|
||||
drop(capabilities);
|
||||
|
||||
if let Err(err) = self.summarize_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
|
||||
yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Re-acquire the lock
|
||||
capabilities = self.capabilities.lock().await;
|
||||
|
||||
// Retry the loop after truncation
|
||||
continue;
|
||||
},
|
||||
Err(e) => {
|
||||
// Create an error message & terminate the stream
|
||||
error!("Error: {}", e);
|
||||
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Yield control back to the scheduler to prevent blocking
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
async fn extend_system_prompt(&mut self, extension: String) {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities.add_system_prompt_extension(extension);
|
||||
}
|
||||
|
||||
async fn override_system_prompt(&mut self, template: String) {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities.set_system_prompt_override(template);
|
||||
}
|
||||
|
||||
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
capabilities
|
||||
.list_prompts()
|
||||
.await
|
||||
.expect("Failed to list prompts")
|
||||
}
|
||||
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
|
||||
// First find which extension has this prompt
|
||||
let prompts = capabilities
|
||||
.list_prompts()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
|
||||
|
||||
if let Some(extension) = prompts
|
||||
.iter()
|
||||
.find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name))
|
||||
.map(|(extension, _)| extension)
|
||||
{
|
||||
return capabilities
|
||||
.get_prompt(extension, name, arguments)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get prompt: {}", e));
|
||||
}
|
||||
|
||||
Err(anyhow!("Prompt '{}' not found", name))
|
||||
}
|
||||
|
||||
async fn get_plan_prompt(&self) -> anyhow::Result<String> {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
let tools = capabilities.get_prefixed_tools().await?;
|
||||
let tools_info = tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
ToolInfo::new(
|
||||
&tool.name,
|
||||
&tool.description,
|
||||
get_parameter_names(&tool),
|
||||
None,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let plan_prompt = capabilities.get_planning_prompt(tools_info).await;
|
||||
|
||||
Ok(plan_prompt)
|
||||
}
|
||||
|
||||
/// Returns the provider handle (`Arc<Box<dyn Provider>>`) obtained from the capabilities.
async fn provider(&self) -> Arc<Box<dyn Provider>> {
    let guard = self.capabilities.lock().await;
    let handle = guard.provider();
    handle
}
|
||||
|
||||
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||
tracing::error!("Failed to send tool result: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Passes `SummarizeAgent` and the name "summarize" to the `register_agent!` macro
// (macro defined elsewhere; presumably this makes the agent selectable by that name).
register_agent!("summarize", SummarizeAgent);
|
||||
@@ -1,804 +0,0 @@
|
||||
//! A truncate agent that truncates the conversation history when it exceeds the model's context limit.
|
||||
//! Apart from truncation, it makes no attempt to handle context limits, and cannot read resources.
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use mcp_core::tool::{Tool, ToolAnnotations};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, error, instrument, warn};
|
||||
|
||||
use super::agent::SessionConfig;
|
||||
use super::extension::ToolInfo;
|
||||
use super::types::ToolResultReceiver;
|
||||
use super::Agent;
|
||||
use crate::agents::capabilities::{get_parameter_names, Capabilities};
|
||||
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
|
||||
use crate::config::{Config, ExtensionManager};
|
||||
use crate::message::{Message, MessageContent, ToolRequest};
|
||||
use crate::permission::detect_read_only_tools;
|
||||
use crate::permission::Permission;
|
||||
use crate::permission::PermissionConfirmation;
|
||||
use crate::permission::ToolPermissionStore;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::providers::toolshim::{
|
||||
augment_message_with_tool_calls, modify_system_prompt_for_tool_json, OllamaInterpreter,
|
||||
};
|
||||
use crate::register_agent;
|
||||
use crate::session;
|
||||
use crate::token_counter::TokenCounter;
|
||||
use crate::truncate::{truncate_messages, OldestFirstTruncation};
|
||||
use anyhow::{anyhow, Result};
|
||||
use indoc::indoc;
|
||||
use mcp_core::{prompt::Prompt, protocol::GetPromptResult, Content, ToolError, ToolResult};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Maximum number of truncation retries.
/// NOTE(review): presumably checked by `reply` before giving up — confirm at the use site.
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
||||
/// Decay base for the context estimate factor.
/// NOTE(review): presumably raised to the attempt number (0.9, 0.81, ...) — confirm at the use site.
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
|
||||
|
||||
/// Truncate implementation of an Agent
///
/// Its `truncate_messages` helper drops the oldest messages until the conversation
/// fits the estimated context budget.
|
||||
pub struct TruncateAgent {
|
||||
/// Shared capabilities (provider, extensions, prompts); locked by the trait methods.
capabilities: Mutex<Capabilities>,
|
||||
/// Token counter built from the provider's tokenizer name; used to budget the context.
token_counter: TokenCounter,
|
||||
/// Sending half of the confirmation channel, fed by `handle_confirmation`.
confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>,
|
||||
/// Receiving half of the confirmation channel.
confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
|
||||
/// Sending half of the tool-result channel.
tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
/// Receiving half of the tool-result channel, wrapped in `Arc<Mutex<..>>` by `new`.
tool_result_rx: ToolResultReceiver,
|
||||
}
|
||||
|
||||
impl TruncateAgent {
|
||||
pub fn new(provider: Box<dyn Provider>) -> Self {
|
||||
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
|
||||
// Create channels with buffer size 32 (adjust if needed)
|
||||
let (confirm_tx, confirm_rx) = mpsc::channel(32);
|
||||
let (tool_tx, tool_rx) = mpsc::channel(32);
|
||||
|
||||
Self {
|
||||
capabilities: Mutex::new(Capabilities::new(provider)),
|
||||
token_counter,
|
||||
confirmation_tx: confirm_tx,
|
||||
confirmation_rx: Mutex::new(confirm_rx),
|
||||
tool_result_tx: tool_tx,
|
||||
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates the messages to fit within the model's context window
|
||||
/// Ensures the last message is a user message and removes tool call-response pairs
|
||||
async fn truncate_messages(
|
||||
&self,
|
||||
messages: &mut Vec<Message>,
|
||||
estimate_factor: f32,
|
||||
system_prompt: &str,
|
||||
tools: &mut Vec<Tool>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Model's actual context limit
|
||||
let context_limit = self
|
||||
.capabilities
|
||||
.lock()
|
||||
.await
|
||||
.provider()
|
||||
.get_model_config()
|
||||
.context_limit();
|
||||
|
||||
// Our conservative estimate of the **target** context limit
|
||||
// Our token count is an estimate since model providers often don't provide the tokenizer (eg. Claude)
|
||||
let context_limit = (context_limit as f32 * estimate_factor) as usize;
|
||||
|
||||
// Take into account the system prompt, and our tools input and subtract that from the
|
||||
// remaining context limit
|
||||
let system_prompt_token_count = self.token_counter.count_tokens(system_prompt);
|
||||
let tools_token_count = self.token_counter.count_tokens_for_tools(tools.as_slice());
|
||||
|
||||
// Check if system prompt + tools exceed our context limit
|
||||
let remaining_tokens = context_limit
|
||||
.checked_sub(system_prompt_token_count)
|
||||
.and_then(|remaining| remaining.checked_sub(tools_token_count))
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("System prompt and tools exceed estimated context limit")
|
||||
})?;
|
||||
|
||||
let context_limit = remaining_tokens;
|
||||
|
||||
// Calculate current token count of each message, use count_chat_tokens to ensure we
|
||||
// capture the full content of the message, include ToolRequests and ToolResponses
|
||||
let mut token_counts: Vec<usize> = messages
|
||||
.iter()
|
||||
.map(|msg| {
|
||||
self.token_counter
|
||||
.count_chat_tokens("", std::slice::from_ref(msg), &[])
|
||||
})
|
||||
.collect();
|
||||
|
||||
truncate_messages(
|
||||
messages,
|
||||
&mut token_counts,
|
||||
context_limit,
|
||||
&OldestFirstTruncation,
|
||||
)
|
||||
}
|
||||
|
||||
async fn create_tool_future(
|
||||
capabilities: &Capabilities,
|
||||
tool_call: mcp_core::tool::ToolCall,
|
||||
request_id: String,
|
||||
) -> (String, Result<Vec<Content>, ToolError>) {
|
||||
let output = capabilities.dispatch_tool_call(tool_call).await;
|
||||
(request_id, output)
|
||||
}
|
||||
|
||||
async fn enable_extension(
|
||||
capabilities: &mut Capabilities,
|
||||
extension_name: String,
|
||||
request_id: String,
|
||||
) -> (String, Result<Vec<Content>, ToolError>) {
|
||||
let config = match ExtensionManager::get_config_by_name(&extension_name) {
|
||||
Ok(Some(config)) => config,
|
||||
Ok(None) => {
|
||||
return (
|
||||
request_id,
|
||||
Err(ToolError::ExecutionError(format!(
|
||||
"Extension '{}' not found. Please check the extension name and try again.",
|
||||
extension_name
|
||||
))),
|
||||
)
|
||||
}
|
||||
Err(e) => {
|
||||
return (
|
||||
request_id,
|
||||
Err(ToolError::ExecutionError(format!(
|
||||
"Failed to get extension config: {}",
|
||||
e
|
||||
))),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let result = capabilities
|
||||
.add_extension(config)
|
||||
.await
|
||||
.map(|_| {
|
||||
vec![Content::text(format!(
|
||||
"The extension '{}' has been installed successfully",
|
||||
extension_name
|
||||
))]
|
||||
})
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()));
|
||||
|
||||
(request_id, result)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Agent for TruncateAgent {
|
||||
/// Adds `extension` to the agent's capabilities and returns the outcome of that call.
async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> {
    self.capabilities
        .lock()
        .await
        .add_extension(extension)
        .await
}
|
||||
|
||||
/// Lists the prefixed tools currently exposed by the capabilities.
///
/// If the tools cannot be fetched, an empty list is returned.
async fn list_tools(&self) -> Vec<Tool> {
    let mut guard = self.capabilities.lock().await;
    let fetched = guard.get_prefixed_tools().await;
    fetched.unwrap_or_default()
}
|
||||
|
||||
async fn remove_extension(&mut self, name: &str) {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities
|
||||
.remove_extension(name)
|
||||
.await
|
||||
.expect("Failed to remove extension");
|
||||
}
|
||||
|
||||
async fn list_extensions(&self) -> Vec<String> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
capabilities
|
||||
.list_extensions()
|
||||
.await
|
||||
.expect("Failed to list extensions")
|
||||
}
|
||||
|
||||
async fn passthrough(&self, _extension: &str, _request: Value) -> ExtensionResult<Value> {
|
||||
// TODO implement
|
||||
Ok(Value::Null)
|
||||
}
|
||||
|
||||
/// Handle a confirmation response for a tool request
|
||||
async fn handle_confirmation(&self, request_id: String, confirmation: PermissionConfirmation) {
|
||||
if let Err(e) = self.confirmation_tx.send((request_id, confirmation)).await {
|
||||
error!("Failed to send confirmation: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
/// Produces the agent's reply to `messages` as a stream of messages.
///
/// Before streaming starts, this locks `self.capabilities`, collects the prefixed
/// extension tools, appends the platform tools defined below (the two resource tools
/// only when `supports_resources()` is true), reads `GOOSE_MODE` from the global
/// config (falling back to "auto") and fetches the system prompt. When the model
/// config has `toolshim` set, the prompt is rewritten and the provider is handed an
/// empty tool list.
///
/// The returned stream then loops:
/// 1. Calls `provider().complete(..)` with the current prompt, messages and tools.
/// 2. On success, updates the session metadata when a `session` was given, yields the
///    response with frontend tool requests filtered out, and stops if the response
///    contains no tool requests.
/// 3. Otherwise resolves the tool requests (frontend tools, extension enabling,
///    approve / auto / chat handling), yields the combined tool-response message and
///    appends both messages to the local history before the next iteration.
/// 4. On `ProviderError::ContextLengthExceeded`, truncates the history and retries,
///    giving up once `MAX_TRUNCATION_ATTEMPTS` is reached.
/// 5. On any other provider error, yields an error text message and ends.
///
/// The capabilities lock taken here is used inside the stream and is only released
/// around the call to `truncate_messages`.
#[instrument(skip(self, messages, session), fields(user_message))]
|
||||
async fn reply(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
session: Option<SessionConfig>,
|
||||
) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
|
||||
let mut messages = messages.to_vec();
|
||||
let reply_span = tracing::Span::current();
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
let mut tools = capabilities.get_prefixed_tools().await?;
|
||||
let mut truncation_attempt: usize = 0;
|
||||
|
|
||||
// Load settings from config
|
||||
let config = Config::global();
|
||||
let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
|
||||
|
|
||||
// we add in the 2 resource tools if any extensions support resources
|
||||
// TODO: make sure there is no collision with another extension's tool name
|
||||
let read_resource_tool = Tool::new(
|
||||
"platform__read_resource".to_string(),
|
||||
indoc! {r#"
|
||||
Read a resource from an extension.
|
||||
|
|
||||
Resources allow extensions to share data that provide context to LLMs, such as
|
||||
files, database schemas, or application-specific information. This tool searches for the
|
||||
resource URI in the provided extension, and reads in the resource content. If no extension
|
||||
is provided, the tool will search all extensions for the resource.
|
||||
"#}.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["uri"],
|
||||
"properties": {
|
||||
"uri": {"type": "string", "description": "Resource URI"},
|
||||
"extension_name": {"type": "string", "description": "Optional extension name"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("Read a resource".to_string()),
|
||||
read_only_hint: true,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
);
|
||||
|
|
||||
let list_resources_tool = Tool::new(
|
||||
"platform__list_resources".to_string(),
|
||||
indoc! {r#"
|
||||
List resources from an extension(s).
|
||||
|
|
||||
Resources allow extensions to share data that provide context to LLMs, such as
|
||||
files, database schemas, or application-specific information. This tool lists resources
|
||||
in the provided extension, and returns a list for the user to browse. If no extension
|
||||
is provided, the tool will search all extensions for the resource.
|
||||
"#}.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"extension_name": {"type": "string", "description": "Optional extension name"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("List resources".to_string()),
|
||||
read_only_hint: true,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
);
|
||||
|
|
||||
let search_available_extensions = Tool::new(
|
||||
"platform__search_available_extensions".to_string(),
|
||||
"Searches for additional extensions available to help complete tasks.
|
||||
Use this tool when you're unable to find a specific feature or functionality you need to complete your task, or when standard approaches aren't working.
|
||||
These extensions might provide the exact tools needed to solve your problem.
|
||||
If you find a relevant one, consider using your tools to enable it.".to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": [],
|
||||
"properties": {}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("Discover extensions".to_string()),
|
||||
read_only_hint: true,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
);
|
||||
|
|
||||
let enable_extension_tool = Tool::new(
|
||||
"platform__enable_extension".to_string(),
|
||||
"Enable extensions to help complete tasks.
|
||||
Enable an extension by providing the extension name.
|
||||
"
|
||||
.to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["extension_name"],
|
||||
"properties": {
|
||||
"extension_name": {"type": "string", "description": "The name of the extension to enable"}
|
||||
}
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("Enable extensions".to_string()),
|
||||
read_only_hint: false,
|
||||
destructive_hint: false,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: false,
|
||||
}),
|
||||
);
|
||||
|
|
||||
if capabilities.supports_resources() {
|
||||
tools.push(read_resource_tool);
|
||||
tools.push(list_resources_tool);
|
||||
}
|
||||
tools.push(search_available_extensions);
|
||||
tools.push(enable_extension_tool);
|
||||
|
|
||||
// Split tool names into two lists: tools annotated with `read_only_hint == true`, and
// tools with no annotations at all. Annotated tools whose hint is false land in neither list.
let (tools_with_readonly_annotation, tools_without_annotation): (Vec<String>, Vec<String>) =
|
||||
tools.iter().fold((vec![], vec![]), |mut acc, tool| {
|
||||
match &tool.annotations {
|
||||
Some(annotations) => {
|
||||
if annotations.read_only_hint {
|
||||
acc.0.push(tool.name.clone());
|
||||
}
|
||||
}
|
||||
None => {
|
||||
acc.1.push(tool.name.clone());
|
||||
}
|
||||
}
|
||||
acc
|
||||
});
|
||||
|
|
||||
// From here on `config` is the provider's model config; it shadows the global `Config` above.
let config = capabilities.provider().get_model_config();
|
||||
let mut system_prompt = capabilities.get_system_prompt().await;
|
||||
let mut toolshim_tools = vec![];
|
||||
if config.toolshim {
|
||||
// If tool interpretation is enabled, modify the system prompt to instruct to return JSON tool requests
|
||||
system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools);
|
||||
// make a copy of tools before empty
|
||||
toolshim_tools = tools.clone();
|
||||
// pass empty tools vector to provider completion since toolshim will handle tool calls instead
|
||||
tools = vec![];
|
||||
}
|
||||
|
|
||||
// Set the user_message field in the span instead of creating a new event
|
||||
if let Some(content) = messages
|
||||
.last()
|
||||
.and_then(|msg| msg.content.first())
|
||||
.and_then(|c| c.as_text())
|
||||
{
|
||||
debug!("user_message" = &content);
|
||||
}
|
||||
|
|
||||
Ok(Box::pin(async_stream::try_stream! {
|
||||
// NOTE(review): this guard stays alive across the `.await` points below; the tracing docs warn
// that holding an `enter()` guard across awaits can produce incorrect traces - confirm intended.
let _reply_guard = reply_span.enter();
|
||||
loop {
|
||||
match capabilities.provider().complete(
|
||||
&system_prompt,
|
||||
&messages,
|
||||
&tools,
|
||||
).await {
|
||||
Ok((mut response, usage)) => {
|
||||
// Post-process / structure the response only if tool interpretation is enabled
|
||||
if config.toolshim {
|
||||
let interpreter = OllamaInterpreter::new()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create OllamaInterpreter: {}", e))?;
|
||||
|
|
||||
response = augment_message_with_tool_calls(&interpreter, response, &toolshim_tools).await?;
|
||||
}
|
||||
|
|
||||
// record usage for the session in the session file
|
||||
if let Some(session) = session.clone() {
|
||||
// TODO: track session_id in langfuse tracing
|
||||
let session_file = session::get_path(session.id);
|
||||
let mut metadata = session::read_metadata(&session_file)?;
|
||||
metadata.working_dir = session.working_dir;
|
||||
metadata.total_tokens = usage.usage.total_tokens;
|
||||
metadata.input_tokens = usage.usage.input_tokens;
|
||||
metadata.output_tokens = usage.usage.output_tokens;
|
||||
// The message count is the number of messages in the session + 1 for the response
|
||||
// The message count does not include the tool response till next iteration
|
||||
metadata.message_count = messages.len() + 1;
|
||||
session::update_metadata(&session_file, &metadata).await?;
|
||||
}
|
||||
|
|
||||
// Reset truncation attempt
|
||||
truncation_attempt = 0;
|
||||
|
|
||||
// Yield the assistant's response, but filter out frontend tool requests that we'll process separately
|
||||
let filtered_response = Message {
|
||||
role: response.role.clone(),
|
||||
created: response.created,
|
||||
content: response.content.iter().filter(|c| {
|
||||
if let MessageContent::ToolRequest(req) = c {
|
||||
// Only filter out frontend tool requests
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
return !capabilities.is_frontend_tool(&tool_call.name);
|
||||
}
|
||||
}
|
||||
true
|
||||
}).cloned().collect(),
|
||||
};
|
||||
yield filtered_response.clone();
|
||||
|
|
||||
tokio::task::yield_now().await;
|
||||
|
|
||||
// First collect any tool requests
|
||||
let tool_requests: Vec<&ToolRequest> = response.content
|
||||
.iter()
|
||||
.filter_map(|content| content.as_tool_request())
|
||||
.collect();
|
||||
|
|
||||
if tool_requests.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
|
||||
// Process tool requests depending on goose_mode
|
||||
let mut message_tool_response = Message::user();
|
||||
|
|
||||
// First handle any frontend tool requests
|
||||
let mut remaining_requests = Vec::new();
|
||||
for request in &tool_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
if capabilities.is_frontend_tool(&tool_call.name) {
|
||||
// Send frontend tool request and wait for response
|
||||
yield Message::assistant().with_frontend_tool_request(
|
||||
request.id.clone(),
|
||||
Ok(tool_call.clone())
|
||||
);
|
||||
|
|
||||
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
|
||||
message_tool_response = message_tool_response.with_tool_response(id, result);
|
||||
}
|
||||
} else {
|
||||
remaining_requests.push(request);
|
||||
}
|
||||
} else {
|
||||
remaining_requests.push(request);
|
||||
}
|
||||
}
|
||||
|
|
||||
// Split tool requests into enable_extension and others
|
||||
let (enable_extension_requests, non_enable_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
|
||||
.into_iter()
|
||||
.partition(|req| {
|
||||
req.tool_call.as_ref()
|
||||
.map(|call| call.name == "platform__enable_extension")
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
|
||||
// Separately pick out the search requests; they remain part of `non_enable_extension_requests` too.
let (search_extension_requests, _non_search_extension_requests): (Vec<&ToolRequest>, Vec<&ToolRequest>) = remaining_requests.clone()
|
||||
.into_iter()
|
||||
.partition(|req| {
|
||||
req.tool_call.as_ref()
|
||||
.map(|call| call.name == "platform__search_available_extensions")
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
|
||||
// Clone goose_mode once before the match to avoid move issues
|
||||
let mode = goose_mode.clone();
|
||||
|
|
||||
// If there are install extension requests, always require confirmation
|
||||
// or if goose_mode is approve or smart_approve, check permissions for all tools
|
||||
if !enable_extension_requests.is_empty() || mode.as_str() == "approve" || mode.as_str() == "smart_approve" {
|
||||
let mut needs_confirmation = Vec::<&ToolRequest>::new();
|
||||
// NOTE(review): `approved_tools` is filled below but never read anywhere in this function -
// confirm whether pre-approved tools are meant to be executed here.
let mut approved_tools = Vec::new();
|
||||
let mut llm_detect_candidates = Vec::<&ToolRequest>::new();
|
||||
let mut detected_read_only_tools = Vec::new();
|
||||
|
|
||||
// If approve mode or smart approve mode, check permissions for all tools
|
||||
if mode.as_str() == "approve" || mode.as_str() == "smart_approve" {
|
||||
let store = ToolPermissionStore::load()?;
|
||||
for request in &non_enable_extension_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
// Regular permission checking for other tools
|
||||
if tools_with_readonly_annotation.contains(&tool_call.name) {
|
||||
approved_tools.push((request.id.clone(), tool_call));
|
||||
} else if let Some(allowed) = store.check_permission(request) {
|
||||
if allowed {
|
||||
// Instead of executing immediately, collect approved tools
|
||||
approved_tools.push((request.id.clone(), tool_call));
|
||||
} else {
|
||||
// If the tool doesn't have any annotation, we can use llm-as-a-judge to check permission.
|
||||
if tools_without_annotation.contains(&tool_call.name) {
|
||||
llm_detect_candidates.push(request);
|
||||
}
|
||||
needs_confirmation.push(request);
|
||||
}
|
||||
} else {
|
||||
if tools_without_annotation.contains(&tool_call.name) {
|
||||
llm_detect_candidates.push(request);
|
||||
}
|
||||
needs_confirmation.push(request);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Only check read-only status for tools needing confirmation
|
||||
if !llm_detect_candidates.is_empty() && mode == "smart_approve" {
|
||||
detected_read_only_tools = detect_read_only_tools(&capabilities, llm_detect_candidates.clone()).await;
|
||||
// Remove install extensions from read-only tools
|
||||
if !enable_extension_requests.is_empty() {
|
||||
detected_read_only_tools.retain(|tool_name| {
|
||||
!enable_extension_requests.iter().any(|req| {
|
||||
req.tool_call.as_ref()
|
||||
.map(|call| call.name == *tool_name)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
||||
// Handle pre-approved and read-only tools in parallel
|
||||
let mut tool_futures = Vec::new();
|
||||
let mut install_results = Vec::new();
|
||||
|
|
||||
// Handle install extension requests
|
||||
for request in &enable_extension_requests {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let confirmation = Message::user().with_enable_extension_request(
|
||||
request.id.clone(),
|
||||
tool_call.arguments.get("extension_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
);
|
||||
yield confirmation;
|
||||
|
|
||||
// Wait for the confirmation whose id matches this request; confirmations received for
// other ids while waiting are consumed and discarded.
let mut rx = self.confirmation_rx.lock().await;
|
||||
while let Some((req_id, extension_confirmation)) = rx.recv().await {
|
||||
if req_id == request.id {
|
||||
if extension_confirmation.permission == Permission::AllowOnce || extension_confirmation.permission == Permission::AlwaysAllow {
|
||||
let extension_name = tool_call.arguments.get("extension_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let install_result = Self::enable_extension(&mut capabilities, extension_name, request.id.clone()).await;
|
||||
install_results.push(install_result);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
||||
// Process the requests that need confirmation; tools the LLM check detected as read-only skip the prompt
|
||||
for request in &needs_confirmation {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
// Skip confirmation if the tool_call.name is in the read_only_tools list
|
||||
if detected_read_only_tools.contains(&tool_call.name) {
|
||||
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
|
||||
tool_futures.push(tool_future);
|
||||
} else {
|
||||
let confirmation = Message::user().with_tool_confirmation_request(
|
||||
request.id.clone(),
|
||||
tool_call.name.clone(),
|
||||
tool_call.arguments.clone(),
|
||||
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
|
||||
);
|
||||
yield confirmation;
|
||||
|
|
||||
// Wait for confirmation response through the channel
|
||||
let mut rx = self.confirmation_rx.lock().await;
|
||||
while let Some((req_id, tool_confirmation)) = rx.recv().await {
|
||||
if req_id == request.id {
|
||||
let confirmed = tool_confirmation.permission == Permission::AllowOnce || tool_confirmation.permission == Permission::AlwaysAllow;
|
||||
if confirmed {
|
||||
// Add this tool call to the futures collection
|
||||
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
|
||||
tool_futures.push(tool_future);
|
||||
} else {
|
||||
// User declined - add declined response
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text(
|
||||
"The user has declined to run this tool. \
|
||||
DO NOT attempt to call this tool again. \
|
||||
If there are no alternative methods to proceed, clearly explain the situation and STOP.")]),
|
||||
);
|
||||
}
|
||||
break; // Exit the loop once the matching `req_id` is found
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
||||
// Wait for all tool calls to complete
|
||||
let results = futures::future::join_all(tool_futures).await;
|
||||
for (request_id, output) in results {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request_id,
|
||||
output,
|
||||
);
|
||||
}
|
||||
|
|
||||
// Check if any install results had errors before processing them
|
||||
// NOTE(review): this is also true when `install_results` is empty, so the refresh below runs on
// every pass through this branch; the refreshed `tools` come straight from `get_prefixed_tools()`
// without the platform tools pushed at the top of `reply` - confirm that is intended.
let all_successful = !install_results.iter().any(|(_, result)| result.is_err());
|
||||
|
|
||||
for (request_id, output) in install_results {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request_id,
|
||||
output
|
||||
);
|
||||
}
|
||||
|
|
||||
// Update system prompt and tools if all installations were successful
|
||||
if all_successful {
|
||||
system_prompt = capabilities.get_system_prompt().await;
|
||||
tools = capabilities.get_prefixed_tools().await?;
|
||||
}
|
||||
}
|
||||
|
|
||||
// NOTE(review): when a search request is present this branch runs every non-enable request
// regardless of mode, including ones already handled in the branch above - verify.
if mode.as_str() == "auto" || !search_extension_requests.is_empty() {
|
||||
let mut tool_futures = Vec::new();
|
||||
// Process non_enable_extension_requests and search_extension_requests without duplicates
|
||||
let mut processed_ids = HashSet::new();
|
||||
|
|
||||
for request in non_enable_extension_requests.iter().chain(search_extension_requests.iter()) {
|
||||
if processed_ids.insert(request.id.clone()) {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
let tool_future = Self::create_tool_future(&capabilities, tool_call, request.id.clone());
|
||||
tool_futures.push(tool_future);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
||||
// Wait for all tool calls to complete
|
||||
let results = futures::future::join_all(tool_futures).await;
|
||||
for (request_id, output) in results {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request_id,
|
||||
output,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
||||
if mode.as_str() == "chat" {
|
||||
// Skip all tool calls in chat mode
|
||||
// Skip search extension requests since they were already processed
|
||||
let non_search_non_enable_extension_requests = non_enable_extension_requests.iter()
|
||||
.filter(|req| {
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
tool_call.name != "platform__search_available_extensions"
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
for request in non_search_non_enable_extension_requests {
|
||||
message_tool_response = message_tool_response.with_tool_response(
|
||||
request.id.clone(),
|
||||
Ok(vec![Content::text(
|
||||
"Let the user know the tool call was skipped in Goose chat mode. \
|
||||
DO NOT apologize for skipping the tool call. DO NOT say sorry. \
|
||||
Provide an explanation of what the tool call would do, structured as a \
|
||||
plan for the user. Again, DO NOT apologize. \
|
||||
**Example Plan:**\n \
|
||||
1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \
|
||||
2. **Outline Steps** - Break down the steps.\n \
|
||||
If needed, adjust the explanation based on user preferences or questions."
|
||||
)]),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
||||
yield message_tool_response.clone();
|
||||
|
|
||||
messages.push(response);
|
||||
messages.push(message_tool_response);
|
||||
},
|
||||
Err(ProviderError::ContextLengthExceeded(_)) => {
|
||||
if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
|
||||
// Create an error message & terminate the stream
|
||||
// the previous message would have been a user message (e.g. before any tool calls, this is just after the input message.
|
||||
// at the start of a loop after a tool call, it would be after a tool_use assistant followed by a tool_result user)
|
||||
yield Message::assistant().with_text("Error: Context length exceeds limits even after multiple attempts to truncate. Please start a new session with fresh context and try again.");
|
||||
break;
|
||||
}
|
||||
|
|
||||
truncation_attempt += 1;
|
||||
warn!("Context length exceeded. Truncation Attempt: {}/{}.", truncation_attempt, MAX_TRUNCATION_ATTEMPTS);
|
||||
|
|
||||
// Decay the estimate factor as we make more truncation attempts
|
||||
// Estimate factor decays like this over time: 0.9, 0.81, 0.729, ...
|
||||
let estimate_factor: f32 = ESTIMATE_FACTOR_DECAY.powi(truncation_attempt as i32);
|
||||
|
|
||||
// release the lock before truncation to prevent deadlock
|
||||
drop(capabilities);
|
||||
|
|
||||
if let Err(err) = self.truncate_messages(&mut messages, estimate_factor, &system_prompt, &mut tools).await {
|
||||
yield Message::assistant().with_text(format!("Error: Unable to truncate messages to stay within context limit. \n\nRan into this error: {}.\n\nPlease start a new session with fresh context and try again.", err));
|
||||
break;
|
||||
}
|
||||
|
|
||||
|
|
||||
// Re-acquire the lock
|
||||
capabilities = self.capabilities.lock().await;
|
||||
|
|
||||
// Retry the loop after truncation
|
||||
continue;
|
||||
},
|
||||
Err(e) => {
|
||||
// Create an error message & terminate the stream
|
||||
error!("Error: {}", e);
|
||||
yield Message::assistant().with_text(format!("Ran into this error: {e}.\n\nPlease retry if you think this is a transient or recoverable error."));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
||||
// Yield control back to the scheduler to prevent blocking
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
async fn extend_system_prompt(&mut self, extension: String) {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities.add_system_prompt_extension(extension);
|
||||
}
|
||||
|
||||
async fn override_system_prompt(&mut self, template: String) {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
capabilities.set_system_prompt_override(template);
|
||||
}
|
||||
|
||||
/// Returns the prompts reported by the capabilities as a map of `String` keys to prompt lists.
///
/// # Panics
///
/// Panics with "Failed to list prompts" if the underlying listing call fails.
async fn list_extension_prompts(&self) -> HashMap<String, Vec<Prompt>> {
    let guard = self.capabilities.lock().await;
    let listing = guard.list_prompts().await;
    listing.expect("Failed to list prompts")
}
|
||||
|
||||
async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult> {
|
||||
let capabilities = self.capabilities.lock().await;
|
||||
|
||||
// First find which extension has this prompt
|
||||
let prompts = capabilities
|
||||
.list_prompts()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to list prompts: {}", e))?;
|
||||
|
||||
if let Some(extension) = prompts
|
||||
.iter()
|
||||
.find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name))
|
||||
.map(|(extension, _)| extension)
|
||||
{
|
||||
return capabilities
|
||||
.get_prompt(extension, name, arguments)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to get prompt: {}", e));
|
||||
}
|
||||
|
||||
Err(anyhow!("Prompt '{}' not found", name))
|
||||
}
|
||||
|
||||
async fn get_plan_prompt(&self) -> anyhow::Result<String> {
|
||||
let mut capabilities = self.capabilities.lock().await;
|
||||
let tools = capabilities.get_prefixed_tools().await?;
|
||||
let tools_info = tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
ToolInfo::new(
|
||||
&tool.name,
|
||||
&tool.description,
|
||||
get_parameter_names(&tool),
|
||||
None,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let plan_prompt = capabilities.get_planning_prompt(tools_info).await;
|
||||
|
||||
Ok(plan_prompt)
|
||||
}
|
||||
|
||||
/// Returns the provider handle held by the capabilities.
async fn provider(&self) -> Arc<Box<dyn Provider>> {
    self.capabilities.lock().await.provider()
}
|
||||
|
||||
async fn handle_tool_result(&self, id: String, result: ToolResult<Vec<Content>>) {
|
||||
if let Err(e) = self.tool_result_tx.send((id, result)).await {
|
||||
tracing::error!("Failed to send tool result: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(review): presumably registers `TruncateAgent` under the name "truncate" so it can be
// selected by that name - confirm against the `register_agent!` macro definition.
register_agent!("truncate", TruncateAgent);
|
||||
@@ -1,6 +1,25 @@
|
||||
use mcp_core::{Content, ToolResult};
|
||||
use crate::session;
|
||||
use mcp_core::{Content, Tool, ToolResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
/// Type alias for the tool result channel receiver: an `mpsc::Receiver` of
/// `(String, ToolResult<Vec<Content>>)` pairs wrapped in `Arc<Mutex<..>>` so the
/// handle can be cloned and the receiver locked before use.
|
||||
pub type ToolResultReceiver = Arc<Mutex<mpsc::Receiver<(String, ToolResult<Vec<Content>>)>>>;
|
||||
|
||||
/// A frontend tool that will be executed by the frontend rather than an extension
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FrontendTool {
|
||||
/// Name of the frontend tool
pub name: String,
|
||||
/// Tool definition associated with that name
pub tool: Tool,
|
||||
}
|
||||
|
||||
/// Session configuration for an agent: identifies the session and the directory it works in
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionConfig {
|
||||
/// Unique identifier for the session
|
||||
pub id: session::Identifier,
|
||||
/// Working directory for the session
|
||||
pub working_dir: PathBuf,
|
||||
}
|
||||
|
||||
@@ -25,9 +25,9 @@ pub fn name_to_key(name: &str) -> String {
|
||||
}
|
||||
|
||||
/// Extension configuration management
|
||||
pub struct ExtensionManager;
|
||||
pub struct ExtensionConfigManager;
|
||||
|
||||
impl ExtensionManager {
|
||||
impl ExtensionConfigManager {
|
||||
/// Get the extension configuration if enabled -- uses key
|
||||
pub fn get_config(key: &str) -> Result<Option<ExtensionConfig>> {
|
||||
let config = Config::global();
|
||||
|
||||
@@ -6,7 +6,7 @@ pub mod permission;
|
||||
pub use crate::agents::ExtensionConfig;
|
||||
pub use base::{Config, ConfigError, APP_STRATEGY};
|
||||
pub use experiments::ExperimentManager;
|
||||
pub use extensions::{ExtensionEntry, ExtensionManager};
|
||||
pub use extensions::{ExtensionConfigManager, ExtensionEntry};
|
||||
pub use permission::PermissionManager;
|
||||
|
||||
pub use extensions::DEFAULT_DISPLAY_NAME;
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::agents::Capabilities;
|
||||
use crate::message::Message;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::token_counter::TokenCounter;
|
||||
use anyhow::{anyhow, Result};
|
||||
use std::sync::Arc;
|
||||
use tracing::debug;
|
||||
|
||||
const SYSTEM_PROMPT: &str = "You are good at summarizing.";
|
||||
@@ -12,17 +13,13 @@ fn create_summarize_request(messages: &[Message]) -> Vec<Message> {
|
||||
]
|
||||
}
|
||||
async fn single_request(
|
||||
capabilities: &Capabilities,
|
||||
provider: &Arc<dyn Provider>,
|
||||
messages: &[Message],
|
||||
) -> Result<Message, anyhow::Error> {
|
||||
Ok(capabilities
|
||||
.provider()
|
||||
.complete(SYSTEM_PROMPT, messages, &[])
|
||||
.await?
|
||||
.0)
|
||||
Ok(provider.complete(SYSTEM_PROMPT, messages, &[]).await?.0)
|
||||
}
|
||||
async fn memory_condense(
|
||||
capabilities: &Capabilities,
|
||||
provider: Arc<dyn Provider>,
|
||||
token_counter: &TokenCounter,
|
||||
messages: &mut Vec<Message>,
|
||||
token_counts: &mut Vec<usize>,
|
||||
@@ -67,9 +64,7 @@ async fn memory_condense(
|
||||
|
||||
diff = -(current_tokens as isize);
|
||||
let request = create_summarize_request(&batch);
|
||||
let response_text = single_request(capabilities, &request)
|
||||
.await?
|
||||
.as_concat_text();
|
||||
let response_text = single_request(&provider, &request).await?.as_concat_text();
|
||||
|
||||
// Ensure the conversation starts with a User message
|
||||
let curr_messages = vec![
|
||||
@@ -95,8 +90,9 @@ async fn memory_condense(
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO: currently not used. we will add this is a feature flag under context mgmt
|
||||
pub async fn condense_messages(
|
||||
capabilities: &Capabilities,
|
||||
provider: Arc<dyn Provider>,
|
||||
token_counter: &TokenCounter,
|
||||
messages: &mut Vec<Message>,
|
||||
token_counts: &mut Vec<usize>,
|
||||
@@ -108,7 +104,7 @@ pub async fn condense_messages(
|
||||
// The compressor should determine whether we need to compress the messages or not. This
|
||||
// function just checks if the limit is satisfied.
|
||||
memory_condense(
|
||||
capabilities,
|
||||
provider,
|
||||
token_counter,
|
||||
messages,
|
||||
token_counts,
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
use crate::agents::capabilities::Capabilities;
|
||||
use crate::config::permission::PermissionLevel;
|
||||
use crate::config::PermissionManager;
|
||||
use crate::message::{Message, MessageContent, ToolRequest};
|
||||
use crate::providers::base::Provider;
|
||||
use chrono::Utc;
|
||||
use indoc::indoc;
|
||||
use mcp_core::tool::ToolAnnotations;
|
||||
use mcp_core::{tool::Tool, TextContent};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Creates the tool definition for checking read-only permissions.
|
||||
fn create_read_only_tool() -> Tool {
|
||||
@@ -113,7 +114,7 @@ fn extract_read_only_tools(response: &Message) -> Option<Vec<String>> {
|
||||
|
||||
/// Executes the read-only tools detection and returns the list of tools with read-only operations.
|
||||
pub async fn detect_read_only_tools(
|
||||
capabilities: &Capabilities,
|
||||
provider: Arc<dyn Provider>,
|
||||
tool_requests: Vec<&ToolRequest>,
|
||||
) -> Vec<String> {
|
||||
if tool_requests.is_empty() {
|
||||
@@ -122,8 +123,7 @@ pub async fn detect_read_only_tools(
|
||||
let tool = create_read_only_tool();
|
||||
let check_messages = create_check_messages(tool_requests);
|
||||
|
||||
let res = capabilities
|
||||
.provider()
|
||||
let res = provider
|
||||
.complete(
|
||||
"You are a good analyst and can detect operations whether they have read-only operations.",
|
||||
&check_messages,
|
||||
@@ -152,7 +152,7 @@ pub async fn check_tool_permissions(
|
||||
tools_with_readonly_annotation: HashSet<String>,
|
||||
tools_without_annotation: HashSet<String>,
|
||||
permission_manager: &mut PermissionManager,
|
||||
capabilities: &Capabilities,
|
||||
provider: Arc<dyn Provider>,
|
||||
) -> PermissionCheckResult {
|
||||
let mut approved = vec![];
|
||||
let mut needs_approval = vec![];
|
||||
@@ -206,7 +206,7 @@ pub async fn check_tool_permissions(
|
||||
// 3. LLM detect
|
||||
if !llm_detect_candidates.is_empty() && mode == "smart_approve" {
|
||||
let detected_readonly_tools =
|
||||
detect_read_only_tools(capabilities, llm_detect_candidates.iter().collect()).await;
|
||||
detect_read_only_tools(provider, llm_detect_candidates.iter().collect()).await;
|
||||
for request in llm_detect_candidates {
|
||||
if let Ok(tool_call) = request.tool_call.clone() {
|
||||
if detected_readonly_tools.contains(&tool_call.name) {
|
||||
@@ -236,7 +236,6 @@ pub async fn check_tool_permissions(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::agents::capabilities::Capabilities;
|
||||
use crate::message::{Message, MessageContent, ToolRequest};
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::base::{Provider, ProviderMetadata, ProviderUsage, Usage};
|
||||
@@ -287,12 +286,12 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn create_mock_capabilities() -> Capabilities {
|
||||
fn create_mock_provider() -> Arc<dyn Provider> {
|
||||
let mock_model_config =
|
||||
ModelConfig::new("test-model".to_string()).with_context_limit(200_000.into());
|
||||
Capabilities::new(Box::new(MockProvider {
|
||||
Arc::new(MockProvider {
|
||||
model_config: mock_model_config,
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -349,7 +348,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_detect_read_only_tools() {
|
||||
let capabilities = create_mock_capabilities();
|
||||
let provider = create_mock_provider();
|
||||
let tool_request = ToolRequest {
|
||||
id: "tool_1".to_string(),
|
||||
tool_call: ToolResult::Ok(ToolCall {
|
||||
@@ -358,14 +357,14 @@ mod tests {
|
||||
}),
|
||||
};
|
||||
|
||||
let result = detect_read_only_tools(&capabilities, vec![&tool_request]).await;
|
||||
let result = detect_read_only_tools(provider, vec![&tool_request]).await;
|
||||
assert_eq!(result, vec!["file_reader", "data_fetcher"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_detect_read_only_tools_empty_requests() {
|
||||
let capabilities = create_mock_capabilities();
|
||||
let result = detect_read_only_tools(&capabilities, vec![]).await;
|
||||
let provider = create_mock_provider();
|
||||
let result = detect_read_only_tools(provider, vec![]).await;
|
||||
assert!(result.is_empty());
|
||||
}
|
||||
|
||||
@@ -375,7 +374,7 @@ mod tests {
|
||||
let temp_file = NamedTempFile::new().unwrap();
|
||||
let temp_path = temp_file.path();
|
||||
let mut permission_manager = PermissionManager::new(temp_path);
|
||||
let capabilities = create_mock_capabilities();
|
||||
let provider = create_mock_provider();
|
||||
|
||||
let tools_with_readonly_annotation: HashSet<String> =
|
||||
vec!["file_reader".to_string()].into_iter().collect();
|
||||
@@ -411,7 +410,7 @@ mod tests {
|
||||
tools_with_readonly_annotation,
|
||||
tools_without_annotation,
|
||||
&mut permission_manager,
|
||||
&capabilities,
|
||||
provider,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::{
|
||||
anthropic::AnthropicProvider,
|
||||
azure::AzureProvider,
|
||||
@@ -29,18 +31,19 @@ pub fn providers() -> Vec<ProviderMetadata> {
|
||||
]
|
||||
}
|
||||
|
||||
pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send + Sync>> {
|
||||
pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
|
||||
// We use Arc instead of Box to be able to clone for multiple async tasks
|
||||
match name {
|
||||
"openai" => Ok(Box::new(OpenAiProvider::from_env(model)?)),
|
||||
"anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
|
||||
"azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
|
||||
"aws_bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)),
|
||||
"databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
|
||||
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
|
||||
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
|
||||
"openrouter" => Ok(Box::new(OpenRouterProvider::from_env(model)?)),
|
||||
"gcp_vertex_ai" => Ok(Box::new(GcpVertexAIProvider::from_env(model)?)),
|
||||
"google" => Ok(Box::new(GoogleProvider::from_env(model)?)),
|
||||
"openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)),
|
||||
"anthropic" => Ok(Arc::new(AnthropicProvider::from_env(model)?)),
|
||||
"azure_openai" => Ok(Arc::new(AzureProvider::from_env(model)?)),
|
||||
"aws_bedrock" => Ok(Arc::new(BedrockProvider::from_env(model)?)),
|
||||
"databricks" => Ok(Arc::new(DatabricksProvider::from_env(model)?)),
|
||||
"groq" => Ok(Arc::new(GroqProvider::from_env(model)?)),
|
||||
"ollama" => Ok(Arc::new(OllamaProvider::from_env(model)?)),
|
||||
"openrouter" => Ok(Arc::new(OpenRouterProvider::from_env(model)?)),
|
||||
"gcp_vertex_ai" => Ok(Arc::new(GcpVertexAIProvider::from_env(model)?)),
|
||||
"google" => Ok(Arc::new(GoogleProvider::from_env(model)?)),
|
||||
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,7 +243,7 @@ pub fn read_metadata(session_file: &Path) -> Result<SessionMetadata> {
|
||||
pub async fn persist_messages(
|
||||
session_file: &Path,
|
||||
messages: &[Message],
|
||||
provider: Option<Arc<Box<dyn Provider>>>,
|
||||
provider: Option<Arc<dyn Provider>>,
|
||||
) -> Result<()> {
|
||||
// Count user messages
|
||||
let user_message_count = messages
|
||||
@@ -255,7 +255,7 @@ pub async fn persist_messages(
|
||||
match provider {
|
||||
Some(provider) if user_message_count < 4 => {
|
||||
// generate_description is responsible for writing the messages
|
||||
generate_description(session_file, messages, provider.as_ref().as_ref()).await
|
||||
generate_description(session_file, messages, provider).await
|
||||
}
|
||||
_ => {
|
||||
// Read existing metadata
|
||||
@@ -298,7 +298,7 @@ pub fn save_messages_with_metadata(
|
||||
pub async fn generate_description(
|
||||
session_file: &Path,
|
||||
messages: &[Message],
|
||||
provider: &dyn Provider,
|
||||
provider: Arc<dyn Provider>,
|
||||
) -> Result<()> {
|
||||
// Create a special message asking for a description of 4 words or less
|
||||
let mut description_prompt = "Based on the conversation so far, provide a concise description of this session in 4 words or less. This will be used for finding the session later in a UI with limited space - reply *ONLY* with the description".to_string();
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
// src/lib.rs or tests/truncate_agent_tests.rs
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use futures::StreamExt;
|
||||
use goose::agents::AgentFactory;
|
||||
use goose::agents::Agent;
|
||||
use goose::message::Message;
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::base::Provider;
|
||||
@@ -65,18 +67,18 @@ impl ProviderType {
|
||||
}
|
||||
}
|
||||
|
||||
fn create_provider(&self, model_config: ModelConfig) -> Result<Box<dyn Provider>> {
|
||||
fn create_provider(&self, model_config: ModelConfig) -> Result<Arc<dyn Provider>> {
|
||||
Ok(match self {
|
||||
ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
|
||||
ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
|
||||
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
|
||||
ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?),
|
||||
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
|
||||
ProviderType::GcpVertexAI => Box::new(GcpVertexAIProvider::from_env(model_config)?),
|
||||
ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?),
|
||||
ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
|
||||
ProviderType::Ollama => Box::new(OllamaProvider::from_env(model_config)?),
|
||||
ProviderType::OpenRouter => Box::new(OpenRouterProvider::from_env(model_config)?),
|
||||
ProviderType::Azure => Arc::new(AzureProvider::from_env(model_config)?),
|
||||
ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(model_config)?),
|
||||
ProviderType::Anthropic => Arc::new(AnthropicProvider::from_env(model_config)?),
|
||||
ProviderType::Bedrock => Arc::new(BedrockProvider::from_env(model_config)?),
|
||||
ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(model_config)?),
|
||||
ProviderType::GcpVertexAI => Arc::new(GcpVertexAIProvider::from_env(model_config)?),
|
||||
ProviderType::Google => Arc::new(GoogleProvider::from_env(model_config)?),
|
||||
ProviderType::Groq => Arc::new(GroqProvider::from_env(model_config)?),
|
||||
ProviderType::Ollama => Arc::new(OllamaProvider::from_env(model_config)?),
|
||||
ProviderType::OpenRouter => Arc::new(OpenRouterProvider::from_env(model_config)?),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -108,7 +110,7 @@ async fn run_truncate_test(
|
||||
.with_temperature(Some(0.0));
|
||||
let provider = provider_type.create_provider(model_config)?;
|
||||
|
||||
let agent = AgentFactory::create("truncate", provider).unwrap();
|
||||
let agent = Agent::new(provider);
|
||||
let repeat_count = context_window + 10_000;
|
||||
let large_message_content = "hello ".repeat(repeat_count);
|
||||
let messages = vec![
|
||||
@@ -185,7 +187,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_openai() -> Result<()> {
|
||||
async fn test_agent_with_openai() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::OpenAi,
|
||||
model: "o3-mini-low",
|
||||
@@ -195,7 +197,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_azure() -> Result<()> {
|
||||
async fn test_agent_with_azure() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Azure,
|
||||
model: "gpt-4o-mini",
|
||||
@@ -205,7 +207,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_anthropic() -> Result<()> {
|
||||
async fn test_agent_with_anthropic() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Anthropic,
|
||||
model: "claude-3-5-haiku-latest",
|
||||
@@ -215,7 +217,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_bedrock() -> Result<()> {
|
||||
async fn test_agent_with_bedrock() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Bedrock,
|
||||
model: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
@@ -225,7 +227,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_databricks() -> Result<()> {
|
||||
async fn test_agent_with_databricks() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Databricks,
|
||||
model: "databricks-meta-llama-3-3-70b-instruct",
|
||||
@@ -235,7 +237,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_databricks_bedrock() -> Result<()> {
|
||||
async fn test_agent_with_databricks_bedrock() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Databricks,
|
||||
model: "claude-3-5-sonnet-2",
|
||||
@@ -245,7 +247,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_databricks_openai() -> Result<()> {
|
||||
async fn test_agent_with_databricks_openai() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Databricks,
|
||||
model: "gpt-4o-mini",
|
||||
@@ -255,7 +257,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_google() -> Result<()> {
|
||||
async fn test_agent_with_google() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Google,
|
||||
model: "gemini-2.0-flash-exp",
|
||||
@@ -265,7 +267,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_groq() -> Result<()> {
|
||||
async fn test_agent_with_groq() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Groq,
|
||||
model: "gemma2-9b-it",
|
||||
@@ -275,7 +277,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_openrouter() -> Result<()> {
|
||||
async fn test_agent_with_openrouter() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::OpenRouter,
|
||||
model: "deepseek/deepseek-r1",
|
||||
@@ -285,7 +287,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_ollama() -> Result<()> {
|
||||
async fn test_agent_with_ollama() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Ollama,
|
||||
model: "llama3.2",
|
||||
@@ -295,7 +297,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_gcpvertexai() -> Result<()> {
|
||||
async fn test_agent_with_gcpvertexai() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::GcpVertexAI,
|
||||
model: "claude-3-5-sonnet-v2@20241022",
|
||||
@@ -77,14 +77,14 @@ lazy_static::lazy_static! {
|
||||
|
||||
/// Generic test harness for any Provider implementation
|
||||
struct ProviderTester {
|
||||
provider: Box<dyn Provider + Send + Sync>,
|
||||
provider: Arc<dyn Provider>,
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl ProviderTester {
|
||||
fn new<T: Provider + Send + Sync + 'static>(provider: T, name: String) -> Self {
|
||||
Self {
|
||||
provider: Box::new(provider),
|
||||
provider: Arc::new(provider),
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user