From 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721 Mon Sep 17 00:00:00 2001
From: Michael Neale
Date: Thu, 5 Jun 2025 13:55:32 +1000
Subject: [PATCH] feat: lead/worker model (#2719)

---
 README.md | 25 +
 crates/goose-cli/src/session/builder.rs | 25 +-
 crates/goose-cli/src/session/output.rs | 53 +-
 crates/goose/src/providers/base.rs | 12 +
 crates/goose/src/providers/factory.rs | 284 ++++++++
 crates/goose/src/providers/lead_worker.rs | 637 ++++++++++++++++++
 crates/goose/src/providers/mod.rs | 1 +
 .../docs/guides/environment-variables.md | 31 +-
 test_lead_worker.sh | 31 +
 9 files changed, 1088 insertions(+), 11 deletions(-)
 create mode 100644 crates/goose/src/providers/lead_worker.rs
 create mode 100755 test_lead_worker.sh

diff --git a/README.md b/README.md
index f2baddfe..ab0c9123 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,31 @@ Whether you're prototyping an idea, refining existing code, or managing intricat
 Designed for maximum flexibility, goose works with any LLM, seamlessly integrates with MCP servers, and is available as both a desktop app as well as CLI - making it the ultimate AI assistant for developers who want to move faster and focus on innovation.
+## Multiple Model Configuration
+
+goose supports using different models for different purposes, letting you optimize for performance and cost and mix models across providers.
+
+### Lead/Worker Model Pattern
+Use a powerful model for initial planning and complex reasoning, then switch to a faster/cheaper model for execution. goose handles the switch automatically:
+
+```bash
+# Required: Enable lead model mode
+export GOOSE_LEAD_MODEL=modelY
+# Optional: configure a provider for the lead model if not the default provider
+export GOOSE_LEAD_PROVIDER=providerX # Defaults to main provider
+```
+
+### Planning Model Configuration
+Use a specialized model for the `/plan` command in CLI mode. This is invoked explicitly when you want to plan (rather than execute):
+
+```bash
+# Optional: Use different model for planning
+export GOOSE_PLANNER_PROVIDER=openai
+export GOOSE_PLANNER_MODEL=gpt-4
+```
+
+Both patterns help you balance model capability against cost and speed, and let you switch between models and vendors as required, as shown in the example below.
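+
+For example, the lead/worker pattern can pair a stronger lead model from one provider with the regular worker model from another. The provider and model names below are illustrative, and the tuning variables (documented in the environment variables guide) are optional:
+
+```bash
+# Worker: the regular goose provider/model handles execution turns
+export GOOSE_PROVIDER=openai
+export GOOSE_MODEL=gpt-4o-mini
+
+# Lead: a stronger model (optionally from another provider) handles the first turns
+export GOOSE_LEAD_PROVIDER=anthropic
+export GOOSE_LEAD_MODEL=claude-3-5-sonnet-latest
+
+# Optional tuning (defaults shown)
+export GOOSE_LEAD_TURNS=3              # turns served by the lead model
+export GOOSE_LEAD_FAILURE_THRESHOLD=2  # consecutive task failures before falling back to the lead
+export GOOSE_LEAD_FALLBACK_TURNS=2     # turns the lead serves while in fallback mode
+```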
+ # Quick Links - [Quickstart](https://block.github.io/goose/docs/quickstart) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index f7cfeba7..1190220b 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -7,6 +7,7 @@ use goose::session; use goose::session::Identifier; use mcp_client::transport::Error as McpClientError; use std::process; +use std::sync::Arc; use super::output; use super::Session; @@ -55,6 +56,22 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { // Create the agent let agent: Agent = Agent::new(); let new_provider = create(&provider_name, model_config).unwrap(); + + // Keep a reference to the provider for display_session_info + let provider_for_display = Arc::clone(&new_provider); + + // Log model information at startup + if let Some(lead_worker) = new_provider.as_lead_worker() { + let (lead_model, worker_model) = lead_worker.get_model_info(); + tracing::info!( + "🤖 Lead/Worker Mode Enabled: Lead model (first 3 turns): {}, Worker model (turn 4+): {}, Auto-fallback on failures: Enabled", + lead_model, + worker_model + ); + } else { + tracing::info!("🤖 Using model: {}", model); + } + agent .update_provider(new_provider) .await @@ -217,6 +234,12 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { session.agent.override_system_prompt(override_prompt).await; } - output::display_session_info(session_config.resume, &provider_name, &model, &session_file); + output::display_session_info( + session_config.resume, + &provider_name, + &model, + &session_file, + Some(&provider_for_display), + ); session } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index ea822c55..525faa74 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -10,6 +10,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::io::Error; use std::path::Path; +use std::sync::Arc; use std::time::Duration; // Re-export theme for use in main @@ -536,7 +537,13 @@ fn shorten_path(path: &str, debug: bool) -> String { } // Session display functions -pub fn display_session_info(resume: bool, provider: &str, model: &str, session_file: &Path) { +pub fn display_session_info( + resume: bool, + provider: &str, + model: &str, + session_file: &Path, + provider_instance: Option<&Arc>, +) { let start_session_msg = if resume { "resuming session |" } else if session_file.to_str() == Some("/dev/null") || session_file.to_str() == Some("NUL") { @@ -544,14 +551,42 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f } else { "starting session |" }; - println!( - "{} {} {} {} {}", - style(start_session_msg).dim(), - style("provider:").dim(), - style(provider).cyan().dim(), - style("model:").dim(), - style(model).cyan().dim(), - ); + + // Check if we have lead/worker mode + if let Some(provider_inst) = provider_instance { + if let Some(lead_worker) = provider_inst.as_lead_worker() { + let (lead_model, worker_model) = lead_worker.get_model_info(); + println!( + "{} {} {} {} {} {} {}", + style(start_session_msg).dim(), + style("provider:").dim(), + style(provider).cyan().dim(), + style("lead model:").dim(), + style(&lead_model).cyan().dim(), + style("worker model:").dim(), + style(&worker_model).cyan().dim(), + ); + } else { + println!( + "{} {} {} {} {}", + style(start_session_msg).dim(), + style("provider:").dim(), + style(provider).cyan().dim(), + 
style("model:").dim(), + style(model).cyan().dim(), + ); + } + } else { + // Fallback to original behavior if no provider instance + println!( + "{} {} {} {} {}", + style(start_session_msg).dim(), + style("provider:").dim(), + style(provider).cyan().dim(), + style("model:").dim(), + style(model).cyan().dim(), + ); + } if session_file.to_str() != Some("/dev/null") && session_file.to_str() != Some("NUL") { println!( diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index c7062642..2059ab00 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -148,6 +148,12 @@ impl Usage { use async_trait::async_trait; +/// Trait for LeadWorkerProvider-specific functionality +pub trait LeadWorkerProviderTrait { + /// Get information about the lead and worker models for logging + fn get_model_info(&self) -> (String, String); +} + /// Base trait for AI providers (OpenAI, Anthropic, etc) #[async_trait] pub trait Provider: Send + Sync { @@ -195,6 +201,12 @@ pub trait Provider: Send + Sync { "This provider does not support embeddings".to_string(), )) } + + /// Check if this provider is a LeadWorkerProvider + /// This is used for logging model information at startup + fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> { + None + } } #[cfg(test)] diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index 42d7e69b..22bdaa95 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -10,6 +10,7 @@ use super::{ githubcopilot::GithubCopilotProvider, google::GoogleProvider, groq::GroqProvider, + lead_worker::LeadWorkerProvider, ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider, @@ -19,6 +20,21 @@ use super::{ use crate::model::ModelConfig; use anyhow::Result; +#[cfg(test)] +use super::errors::ProviderError; +#[cfg(test)] +use mcp_core::tool::Tool; + +fn default_lead_turns() -> usize { + 3 +} +fn default_failure_threshold() -> usize { + 2 +} +fn default_fallback_turns() -> usize { + 2 +} + pub fn providers() -> Vec { vec![ AnthropicProvider::metadata(), @@ -38,6 +54,62 @@ pub fn providers() -> Vec { } pub fn create(name: &str, model: ModelConfig) -> Result> { + let config = crate::config::Config::global(); + + // Check for lead model environment variables + if let Ok(lead_model_name) = config.get_param::("GOOSE_LEAD_MODEL") { + tracing::info!("Creating lead/worker provider from environment variables"); + + return create_lead_worker_from_env(name, &model, &lead_model_name); + } + + // Default: create regular provider + create_provider(name, model) +} + +/// Create a lead/worker provider from environment variables +fn create_lead_worker_from_env( + default_provider_name: &str, + default_model: &ModelConfig, + lead_model_name: &str, +) -> Result> { + let config = crate::config::Config::global(); + + // Get lead provider (optional, defaults to main provider) + let lead_provider_name = config + .get_param::("GOOSE_LEAD_PROVIDER") + .unwrap_or_else(|_| default_provider_name.to_string()); + + // Get configuration parameters with defaults + let lead_turns = config + .get_param::("GOOSE_LEAD_TURNS") + .unwrap_or(default_lead_turns()); + let failure_threshold = config + .get_param::("GOOSE_LEAD_FAILURE_THRESHOLD") + .unwrap_or(default_failure_threshold()); + let fallback_turns = config + .get_param::("GOOSE_LEAD_FALLBACK_TURNS") + .unwrap_or(default_fallback_turns()); + + // Create model configs + let lead_model_config = 
ModelConfig::new(lead_model_name.to_string()); + let worker_model_config = default_model.clone(); + + // Create the providers + let lead_provider = create_provider(&lead_provider_name, lead_model_config)?; + let worker_provider = create_provider(default_provider_name, worker_model_config)?; + + // Create the lead/worker provider with configured settings + Ok(Arc::new(LeadWorkerProvider::new_with_settings( + lead_provider, + worker_provider, + lead_turns, + failure_threshold, + fallback_turns, + ))) +} + +fn create_provider(name: &str, model: ModelConfig) -> Result> { // We use Arc instead of Box to be able to clone for multiple async tasks match name { "openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)), @@ -56,3 +128,215 @@ pub fn create(name: &str, model: ModelConfig) -> Result> { _ => Err(anyhow::anyhow!("Unknown provider: {}", name)), } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::{Message, MessageContent}; + use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage}; + use chrono::Utc; + use mcp_core::{content::TextContent, Role}; + use std::env; + + #[derive(Clone)] + struct MockTestProvider { + name: String, + model_config: ModelConfig, + } + + #[async_trait::async_trait] + impl Provider for MockTestProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + "mock_test", + "Mock Test Provider", + "A mock provider for testing", + "mock-model", + vec!["mock-model"], + "", + vec![], + ) + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + async fn complete( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + Ok(( + Message { + role: Role::Assistant, + created: Utc::now().timestamp(), + content: vec![MessageContent::Text(TextContent { + text: format!( + "Response from {} with model {}", + self.name, self.model_config.model_name + ), + annotations: None, + })], + }, + ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()), + )) + } + } + + #[test] + fn test_create_lead_worker_provider() { + // Save current env vars + let saved_lead = env::var("GOOSE_LEAD_MODEL").ok(); + let saved_provider = env::var("GOOSE_LEAD_PROVIDER").ok(); + let saved_turns = env::var("GOOSE_LEAD_TURNS").ok(); + + // Test with basic lead model configuration + env::set_var("GOOSE_LEAD_MODEL", "gpt-4o"); + + // This will try to create a lead/worker provider + let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); + + // The creation might succeed or fail depending on API keys, but we can verify the logic path + match result { + Ok(_) => { + // If it succeeds, it means we created a lead/worker provider successfully + // This would happen if API keys are available in the test environment + } + Err(error) => { + // If it fails, it should be due to missing API keys, confirming we tried to create providers + let error_msg = error.to_string(); + assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret")); + } + } + + // Test with different lead provider + env::set_var("GOOSE_LEAD_PROVIDER", "anthropic"); + env::set_var("GOOSE_LEAD_TURNS", "5"); + + let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); + // Similar validation as above - will fail due to missing API keys but confirms the logic + + // Restore env vars + match saved_lead { + Some(val) => env::set_var("GOOSE_LEAD_MODEL", val), + None => env::remove_var("GOOSE_LEAD_MODEL"), + } + match saved_provider { + Some(val) => 
env::set_var("GOOSE_LEAD_PROVIDER", val), + None => env::remove_var("GOOSE_LEAD_PROVIDER"), + } + match saved_turns { + Some(val) => env::set_var("GOOSE_LEAD_TURNS", val), + None => env::remove_var("GOOSE_LEAD_TURNS"), + } + } + + #[test] + fn test_lead_model_env_vars_with_defaults() { + // Save current env vars + let saved_vars = [ + ("GOOSE_LEAD_MODEL", env::var("GOOSE_LEAD_MODEL").ok()), + ("GOOSE_LEAD_PROVIDER", env::var("GOOSE_LEAD_PROVIDER").ok()), + ("GOOSE_LEAD_TURNS", env::var("GOOSE_LEAD_TURNS").ok()), + ( + "GOOSE_LEAD_FAILURE_THRESHOLD", + env::var("GOOSE_LEAD_FAILURE_THRESHOLD").ok(), + ), + ( + "GOOSE_LEAD_FALLBACK_TURNS", + env::var("GOOSE_LEAD_FALLBACK_TURNS").ok(), + ), + ]; + + // Clear all lead env vars + for (key, _) in &saved_vars { + env::remove_var(key); + } + + // Set only the required lead model + env::set_var("GOOSE_LEAD_MODEL", "gpt-4o"); + + // This should use defaults for all other values + let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); + + // Should attempt to create lead/worker provider (will fail due to missing API keys but confirms logic) + match result { + Ok(_) => { + // Success means we have API keys and created the provider + } + Err(error) => { + // Should fail due to missing API keys, confirming we tried to create providers + let error_msg = error.to_string(); + assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret")); + } + } + + // Test with custom values + env::set_var("GOOSE_LEAD_TURNS", "7"); + env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", "4"); + env::set_var("GOOSE_LEAD_FALLBACK_TURNS", "3"); + + let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); + // Should still attempt to create lead/worker provider with custom settings + + // Restore all env vars + for (key, value) in saved_vars { + match value { + Some(val) => env::set_var(key, val), + None => env::remove_var(key), + } + } + } + + #[test] + fn test_create_regular_provider_without_lead_config() { + // Save current env vars + let saved_lead = env::var("GOOSE_LEAD_MODEL").ok(); + let saved_provider = env::var("GOOSE_LEAD_PROVIDER").ok(); + let saved_turns = env::var("GOOSE_LEAD_TURNS").ok(); + let saved_threshold = env::var("GOOSE_LEAD_FAILURE_THRESHOLD").ok(); + let saved_fallback = env::var("GOOSE_LEAD_FALLBACK_TURNS").ok(); + + // Ensure all GOOSE_LEAD_* variables are not set + env::remove_var("GOOSE_LEAD_MODEL"); + env::remove_var("GOOSE_LEAD_PROVIDER"); + env::remove_var("GOOSE_LEAD_TURNS"); + env::remove_var("GOOSE_LEAD_FAILURE_THRESHOLD"); + env::remove_var("GOOSE_LEAD_FALLBACK_TURNS"); + + // This should try to create a regular provider + let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string())); + + // The creation might succeed or fail depending on API keys + match result { + Ok(_) => { + // If it succeeds, it means we created a regular provider successfully + // This would happen if API keys are available in the test environment + } + Err(error) => { + // If it fails, it should be due to missing API keys + let error_msg = error.to_string(); + assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret")); + } + } + + // Restore env vars + if let Some(val) = saved_lead { + env::set_var("GOOSE_LEAD_MODEL", val); + } + if let Some(val) = saved_provider { + env::set_var("GOOSE_LEAD_PROVIDER", val); + } + if let Some(val) = saved_turns { + env::set_var("GOOSE_LEAD_TURNS", val); + } + if let Some(val) = saved_threshold { + env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", val); + } + if 
let Some(val) = saved_fallback { + env::set_var("GOOSE_LEAD_FALLBACK_TURNS", val); + } + } +} diff --git a/crates/goose/src/providers/lead_worker.rs b/crates/goose/src/providers/lead_worker.rs new file mode 100644 index 00000000..a242dcb9 --- /dev/null +++ b/crates/goose/src/providers/lead_worker.rs @@ -0,0 +1,637 @@ +use anyhow::Result; +use async_trait::async_trait; +use std::sync::Arc; +use tokio::sync::Mutex; + +use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderUsage}; +use super::errors::ProviderError; +use crate::message::{Message, MessageContent}; +use crate::model::ModelConfig; +use mcp_core::{tool::Tool, Content}; + +/// A provider that switches between a lead model and a worker model based on turn count +/// and can fallback to lead model on consecutive failures +pub struct LeadWorkerProvider { + lead_provider: Arc, + worker_provider: Arc, + lead_turns: usize, + turn_count: Arc>, + failure_count: Arc>, + max_failures_before_fallback: usize, + fallback_turns: usize, + in_fallback_mode: Arc>, + fallback_remaining: Arc>, +} + +impl LeadWorkerProvider { + /// Create a new LeadWorkerProvider + /// + /// # Arguments + /// * `lead_provider` - The provider to use for the initial turns + /// * `worker_provider` - The provider to use after lead_turns + /// * `lead_turns` - Number of turns to use the lead provider (default: 3) + pub fn new( + lead_provider: Arc, + worker_provider: Arc, + lead_turns: Option, + ) -> Self { + Self { + lead_provider, + worker_provider, + lead_turns: lead_turns.unwrap_or(3), + turn_count: Arc::new(Mutex::new(0)), + failure_count: Arc::new(Mutex::new(0)), + max_failures_before_fallback: 2, // Fallback after 2 consecutive failures + fallback_turns: 2, // Use lead model for 2 turns when in fallback mode + in_fallback_mode: Arc::new(Mutex::new(false)), + fallback_remaining: Arc::new(Mutex::new(0)), + } + } + + /// Create a new LeadWorkerProvider with custom settings + /// + /// # Arguments + /// * `lead_provider` - The provider to use for the initial turns + /// * `worker_provider` - The provider to use after lead_turns + /// * `lead_turns` - Number of turns to use the lead provider + /// * `failure_threshold` - Number of consecutive failures before fallback + /// * `fallback_turns` - Number of turns to use lead model in fallback mode + pub fn new_with_settings( + lead_provider: Arc, + worker_provider: Arc, + lead_turns: usize, + failure_threshold: usize, + fallback_turns: usize, + ) -> Self { + Self { + lead_provider, + worker_provider, + lead_turns, + turn_count: Arc::new(Mutex::new(0)), + failure_count: Arc::new(Mutex::new(0)), + max_failures_before_fallback: failure_threshold, + fallback_turns, + in_fallback_mode: Arc::new(Mutex::new(false)), + fallback_remaining: Arc::new(Mutex::new(0)), + } + } + + /// Reset the turn counter and failure tracking (useful for new conversations) + pub async fn reset_turn_count(&self) { + let mut count = self.turn_count.lock().await; + *count = 0; + let mut failures = self.failure_count.lock().await; + *failures = 0; + let mut fallback = self.in_fallback_mode.lock().await; + *fallback = false; + let mut remaining = self.fallback_remaining.lock().await; + *remaining = 0; + } + + /// Get the current turn count + pub async fn get_turn_count(&self) -> usize { + *self.turn_count.lock().await + } + + /// Get the current failure count + pub async fn get_failure_count(&self) -> usize { + *self.failure_count.lock().await + } + + /// Check if currently in fallback mode + pub async fn is_in_fallback_mode(&self) 
-> bool { + *self.in_fallback_mode.lock().await + } + + /// Get the currently active provider based on turn count and fallback state + async fn get_active_provider(&self) -> Arc { + let count = *self.turn_count.lock().await; + let in_fallback = *self.in_fallback_mode.lock().await; + + // Use lead provider if we're in initial turns OR in fallback mode + if count < self.lead_turns || in_fallback { + Arc::clone(&self.lead_provider) + } else { + Arc::clone(&self.worker_provider) + } + } + + /// Handle the result of a completion attempt and update failure tracking + async fn handle_completion_result( + &self, + result: &Result<(Message, ProviderUsage), ProviderError>, + ) { + match result { + Ok((message, _usage)) => { + // Check for task-level failures in the response + let has_task_failure = self.detect_task_failures(message).await; + + if has_task_failure { + // Task failure detected - increment failure count + let mut failures = self.failure_count.lock().await; + *failures += 1; + + let failure_count = *failures; + let turn_count = *self.turn_count.lock().await; + + tracing::warn!( + "Task failure detected in response (failure count: {})", + failure_count + ); + + // Check if we should trigger fallback + if turn_count >= self.lead_turns + && !*self.in_fallback_mode.lock().await + && failure_count >= self.max_failures_before_fallback + { + let mut in_fallback = self.in_fallback_mode.lock().await; + let mut fallback_remaining = self.fallback_remaining.lock().await; + + *in_fallback = true; + *fallback_remaining = self.fallback_turns; + *failures = 0; // Reset failure count when entering fallback + + tracing::warn!( + "🔄 SWITCHING TO LEAD MODEL: Entering fallback mode after {} consecutive task failures - using lead model for {} turns", + self.max_failures_before_fallback, + self.fallback_turns + ); + } + } else { + // Success - reset failure count and handle fallback mode + let mut failures = self.failure_count.lock().await; + *failures = 0; + + let mut in_fallback = self.in_fallback_mode.lock().await; + let mut fallback_remaining = self.fallback_remaining.lock().await; + + if *in_fallback { + *fallback_remaining -= 1; + if *fallback_remaining == 0 { + *in_fallback = false; + tracing::info!("✅ SWITCHING BACK TO WORKER MODEL: Exiting fallback mode - worker model resumed"); + } + } + } + + // Increment turn count on any completion (success or task failure) + let mut count = self.turn_count.lock().await; + *count += 1; + } + Err(_) => { + // Technical failure - just log and let it bubble up + // For technical failures (API/LLM issues), we don't want to second-guess + // the model choice - just let the default model handle it + tracing::warn!( + "Technical failure detected - API/LLM issue, will use default model" + ); + + // Don't increment turn count or failure tracking for technical failures + // as these are temporary infrastructure issues, not model capability issues + } + } + } + + /// Detect task-level failures in the model's response + async fn detect_task_failures(&self, message: &Message) -> bool { + let mut failure_indicators = 0; + + for content in &message.content { + match content { + MessageContent::ToolRequest(tool_request) => { + // Check if tool request itself failed (malformed, etc.) 
+ if tool_request.tool_call.is_err() { + failure_indicators += 1; + tracing::debug!( + "Failed tool request detected: {:?}", + tool_request.tool_call + ); + } + } + MessageContent::ToolResponse(tool_response) => { + // Check if tool execution failed + if let Err(tool_error) = &tool_response.tool_result { + failure_indicators += 1; + tracing::debug!("Tool execution failure detected: {:?}", tool_error); + } else if let Ok(contents) = &tool_response.tool_result { + // Check tool output for error indicators + if self.contains_error_indicators(contents) { + failure_indicators += 1; + tracing::debug!("Tool output contains error indicators"); + } + } + } + MessageContent::Text(text_content) => { + // Check for user correction patterns or error acknowledgments + if self.contains_user_correction_patterns(&text_content.text) { + failure_indicators += 1; + tracing::debug!("User correction pattern detected in text"); + } + } + _ => {} + } + } + + // Consider it a failure if we have multiple failure indicators + failure_indicators >= 1 + } + + /// Check if tool output contains error indicators + fn contains_error_indicators(&self, contents: &[Content]) -> bool { + for content in contents { + if let Content::Text(text_content) = content { + let text_lower = text_content.text.to_lowercase(); + + // Common error patterns in tool outputs + if text_lower.contains("error:") + || text_lower.contains("failed:") + || text_lower.contains("exception:") + || text_lower.contains("traceback") + || text_lower.contains("syntax error") + || text_lower.contains("permission denied") + || text_lower.contains("file not found") + || text_lower.contains("command not found") + || text_lower.contains("compilation failed") + || text_lower.contains("test failed") + || text_lower.contains("assertion failed") + { + return true; + } + } + } + false + } + + /// Check for user correction patterns in text + fn contains_user_correction_patterns(&self, text: &str) -> bool { + let text_lower = text.to_lowercase(); + + // Patterns indicating user is correcting or expressing dissatisfaction + text_lower.contains("that's wrong") + || text_lower.contains("that's not right") + || text_lower.contains("that doesn't work") + || text_lower.contains("try again") + || text_lower.contains("let me correct") + || text_lower.contains("actually, ") + || text_lower.contains("no, that's") + || text_lower.contains("that's incorrect") + || text_lower.contains("fix this") + || text_lower.contains("this is broken") + || text_lower.contains("this doesn't") + || text_lower.starts_with("no,") + || text_lower.starts_with("wrong") + || text_lower.starts_with("incorrect") + } +} + +impl LeadWorkerProviderTrait for LeadWorkerProvider { + /// Get information about the lead and worker models for logging + fn get_model_info(&self) -> (String, String) { + let lead_model = self.lead_provider.get_model_config().model_name; + let worker_model = self.worker_provider.get_model_config().model_name; + (lead_model, worker_model) + } +} + +#[async_trait] +impl Provider for LeadWorkerProvider { + fn metadata() -> ProviderMetadata { + // This is a wrapper provider, so we return minimal metadata + ProviderMetadata::new( + "lead_worker", + "Lead/Worker Provider", + "A provider that switches between lead and worker models based on turn count", + "", // No default model as this is determined by the wrapped providers + vec![], // No known models as this depends on wrapped providers + "", // No doc link + vec![], // No config keys as configuration is done through wrapped providers + ) + 
} + + fn get_model_config(&self) -> ModelConfig { + // Return the lead provider's model config as the default + // In practice, this might need to be more sophisticated + self.lead_provider.get_model_config() + } + + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + // Get the active provider + let provider = self.get_active_provider().await; + + // Log which provider is being used + let turn_count = *self.turn_count.lock().await; + let in_fallback = *self.in_fallback_mode.lock().await; + let fallback_remaining = *self.fallback_remaining.lock().await; + + let provider_type = if turn_count < self.lead_turns { + "lead (initial)" + } else if in_fallback { + "lead (fallback)" + } else { + "worker" + }; + + if in_fallback { + tracing::info!( + "🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining)", + provider_type, + turn_count + 1, + fallback_remaining + ); + } else { + tracing::info!( + "Using {} provider for turn {} (lead_turns: {})", + provider_type, + turn_count + 1, + self.lead_turns + ); + } + + // Make the completion request + let result = provider.complete(system, messages, tools).await; + + // For technical failures, try with default model (lead provider) instead + let final_result = match &result { + Err(_) => { + tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type); + + // Try with lead provider as the default/fallback for technical failures + let default_result = self.lead_provider.complete(system, messages, tools).await; + + match &default_result { + Ok(_) => { + tracing::info!( + "✅ Default model (lead provider) succeeded after technical failure" + ); + default_result + } + Err(_) => { + tracing::error!("❌ Default model (lead provider) also failed - returning original error"); + result // Return the original error + } + } + } + Ok(_) => result, // Success with original provider + }; + + // Handle the result and update tracking (only for successful completions) + self.handle_completion_result(&final_result).await; + + final_result + } + + async fn fetch_supported_models_async(&self) -> Result>, ProviderError> { + // Combine models from both providers + let lead_models = self.lead_provider.fetch_supported_models_async().await?; + let worker_models = self.worker_provider.fetch_supported_models_async().await?; + + match (lead_models, worker_models) { + (Some(lead), Some(worker)) => { + let mut all_models = lead; + all_models.extend(worker); + all_models.sort(); + all_models.dedup(); + Ok(Some(all_models)) + } + (Some(models), None) | (None, Some(models)) => Ok(Some(models)), + (None, None) => Ok(None), + } + } + + fn supports_embeddings(&self) -> bool { + // Support embeddings if either provider supports them + self.lead_provider.supports_embeddings() || self.worker_provider.supports_embeddings() + } + + async fn create_embeddings(&self, texts: Vec) -> Result>, ProviderError> { + // Use the lead provider for embeddings if it supports them, otherwise use worker + if self.lead_provider.supports_embeddings() { + self.lead_provider.create_embeddings(texts).await + } else if self.worker_provider.supports_embeddings() { + self.worker_provider.create_embeddings(texts).await + } else { + Err(ProviderError::ExecutionError( + "Neither lead nor worker provider supports embeddings".to_string(), + )) + } + } + + /// Check if this provider is a LeadWorkerProvider + fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> { 
+ Some(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::MessageContent; + use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage}; + use chrono::Utc; + use mcp_core::{content::TextContent, Role}; + + #[derive(Clone)] + struct MockProvider { + name: String, + model_config: ModelConfig, + } + + #[async_trait] + impl Provider for MockProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::empty() + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + async fn complete( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + Ok(( + Message { + role: Role::Assistant, + created: Utc::now().timestamp(), + content: vec![MessageContent::Text(TextContent { + text: format!("Response from {}", self.name), + annotations: None, + })], + }, + ProviderUsage::new(self.name.clone(), Usage::default()), + )) + } + } + + #[tokio::test] + async fn test_lead_worker_switching() { + let lead_provider = Arc::new(MockProvider { + name: "lead".to_string(), + model_config: ModelConfig::new("lead-model".to_string()), + }); + + let worker_provider = Arc::new(MockProvider { + name: "worker".to_string(), + model_config: ModelConfig::new("worker-model".to_string()), + }); + + let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(3)); + + // First three turns should use lead provider + for i in 0..3 { + let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap(); + assert_eq!(usage.model, "lead"); + assert_eq!(provider.get_turn_count().await, i + 1); + assert!(!provider.is_in_fallback_mode().await); + } + + // Subsequent turns should use worker provider + for i in 3..6 { + let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap(); + assert_eq!(usage.model, "worker"); + assert_eq!(provider.get_turn_count().await, i + 1); + assert!(!provider.is_in_fallback_mode().await); + } + + // Reset and verify it goes back to lead + provider.reset_turn_count().await; + assert_eq!(provider.get_turn_count().await, 0); + assert_eq!(provider.get_failure_count().await, 0); + assert!(!provider.is_in_fallback_mode().await); + + let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap(); + assert_eq!(usage.model, "lead"); + } + + #[tokio::test] + async fn test_technical_failure_retry() { + let lead_provider = Arc::new(MockFailureProvider { + name: "lead".to_string(), + model_config: ModelConfig::new("lead-model".to_string()), + should_fail: false, // Lead provider works + }); + + let worker_provider = Arc::new(MockFailureProvider { + name: "worker".to_string(), + model_config: ModelConfig::new("worker-model".to_string()), + should_fail: true, // Worker will fail + }); + + let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2)); + + // First two turns use lead (should succeed) + for _i in 0..2 { + let result = provider.complete("system", &[], &[]).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().1.model, "lead"); + assert!(!provider.is_in_fallback_mode().await); + } + + // Next turn uses worker (will fail, but should retry with lead and succeed) + let result = provider.complete("system", &[], &[]).await; + assert!(result.is_ok()); // Should succeed because lead provider is used as fallback + assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider + assert_eq!(provider.get_failure_count().await, 0); // No failure tracking for technical failures + 
assert!(!provider.is_in_fallback_mode().await); // Not in fallback mode + + // Another turn - should still try worker first, then retry with lead + let result = provider.complete("system", &[], &[]).await; + assert!(result.is_ok()); // Should succeed because lead provider is used as fallback + assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider + assert_eq!(provider.get_failure_count().await, 0); // Still no failure tracking + assert!(!provider.is_in_fallback_mode().await); // Still not in fallback mode + } + + #[tokio::test] + async fn test_fallback_on_task_failures() { + // Test that task failures (not technical failures) still trigger fallback mode + // This would need a different mock that simulates task failures in successful responses + // For now, we'll test the fallback mode functionality directly + let lead_provider = Arc::new(MockFailureProvider { + name: "lead".to_string(), + model_config: ModelConfig::new("lead-model".to_string()), + should_fail: false, + }); + + let worker_provider = Arc::new(MockFailureProvider { + name: "worker".to_string(), + model_config: ModelConfig::new("worker-model".to_string()), + should_fail: false, + }); + + let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2)); + + // Simulate being in fallback mode + { + let mut in_fallback = provider.in_fallback_mode.lock().await; + *in_fallback = true; + let mut fallback_remaining = provider.fallback_remaining.lock().await; + *fallback_remaining = 2; + let mut turn_count = provider.turn_count.lock().await; + *turn_count = 4; // Past initial lead turns + } + + // Should use lead provider in fallback mode + let result = provider.complete("system", &[], &[]).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().1.model, "lead"); + assert!(provider.is_in_fallback_mode().await); + + // One more fallback turn + let result = provider.complete("system", &[], &[]).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().1.model, "lead"); + assert!(!provider.is_in_fallback_mode().await); // Should exit fallback mode + } + + #[derive(Clone)] + struct MockFailureProvider { + name: String, + model_config: ModelConfig, + should_fail: bool, + } + + #[async_trait] + impl Provider for MockFailureProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::empty() + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + async fn complete( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + if self.should_fail { + Err(ProviderError::ExecutionError( + "Simulated failure".to_string(), + )) + } else { + Ok(( + Message { + role: Role::Assistant, + created: Utc::now().timestamp(), + content: vec![MessageContent::Text(TextContent { + text: format!("Response from {}", self.name), + annotations: None, + })], + }, + ProviderUsage::new(self.name.clone(), Usage::default()), + )) + } + } + } +} diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 2c1a4571..2f6f1f87 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -13,6 +13,7 @@ pub mod gcpvertexai; pub mod githubcopilot; pub mod google; pub mod groq; +pub mod lead_worker; pub mod oauth; pub mod ollama; pub mod openai; diff --git a/documentation/docs/guides/environment-variables.md b/documentation/docs/guides/environment-variables.md index 4f29ff0e..487661a2 100644 --- a/documentation/docs/guides/environment-variables.md +++ 
b/documentation/docs/guides/environment-variables.md
@@ -9,6 +9,7 @@ Goose supports various environment variables that allow you to customize its beh
 ## Model Configuration
 These variables control the [language models](/docs/getting-started/providers) and their behavior.
+
 ### Basic Provider Configuration
 These are the minimum required variables to get started with Goose.
@@ -27,6 +28,7 @@ export GOOSE_PROVIDER="anthropic"
 export GOOSE_MODEL="claude-3.5-sonnet"
 export GOOSE_TEMPERATURE=0.7
+
 ### Advanced Provider Configuration
 These variables are needed when using custom endpoints, enterprise deployments, or specific provider implementations.
@@ -45,7 +47,34 @@ export GOOSE_PROVIDER__TYPE="anthropic"
 export GOOSE_PROVIDER__HOST="https://api.anthropic.com"
 export GOOSE_PROVIDER__API_KEY="your-api-key-here"
-## Planning Mode Configuration
+
+### Lead/Worker Model Configuration
+
+Configure a lead/worker model pattern where a powerful model handles initial planning and complex reasoning, then switches to a faster/cheaper model for execution.
+
+| Variable | Purpose | Values | Default |
+|----------|---------|---------|---------|
+| `GOOSE_LEAD_MODEL` | **Required to enable lead mode.** Specifies the lead model name | Model name (e.g., "gpt-4o", "claude-3.5-sonnet") | None |
+| `GOOSE_LEAD_PROVIDER` | Provider for the lead model | [See available providers](/docs/getting-started/providers#available-providers) | Falls back to GOOSE_PROVIDER |
+| `GOOSE_LEAD_TURNS` | Number of initial turns using the lead model | Integer | 3 |
+| `GOOSE_LEAD_FAILURE_THRESHOLD` | Consecutive failures before fallback to lead model | Integer | 2 |
+| `GOOSE_LEAD_FALLBACK_TURNS` | Number of turns to use lead model in fallback mode | Integer | 2 |
+
+**Examples**
+
+```bash
+# Basic lead/worker setup
+export GOOSE_LEAD_MODEL="o4"
+
+# Advanced lead/worker configuration
+export GOOSE_LEAD_MODEL="claude4-opus"
+export GOOSE_LEAD_PROVIDER="anthropic"
+export GOOSE_LEAD_TURNS=5
+export GOOSE_LEAD_FAILURE_THRESHOLD=3
+export GOOSE_LEAD_FALLBACK_TURNS=2
+```
+
+### Planning Mode Configuration
 These variables control Goose's [planning functionality](/docs/guides/creating-plans).
diff --git a/test_lead_worker.sh b/test_lead_worker.sh
new file mode 100755
index 00000000..3d403b82
--- /dev/null
+++ b/test_lead_worker.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Test script for lead/worker provider functionality
+
+# Set up test environment variables
+export GOOSE_PROVIDER="openai"
+export GOOSE_MODEL="gpt-4o-mini"
+export OPENAI_API_KEY="test-key"
+
+# Test 1: Default behavior (no lead/worker)
+echo "Test 1: Default behavior (no lead/worker)"
+unset GOOSE_LEAD_MODEL
+unset GOOSE_WORKER_MODEL
+unset GOOSE_LEAD_TURNS
+
+# Test 2: Lead/worker with same provider
+echo -e "\nTest 2: Lead/worker with same provider"
+export GOOSE_LEAD_MODEL="gpt-4o"
+export GOOSE_WORKER_MODEL="gpt-4o-mini"
+export GOOSE_LEAD_TURNS="3"
+
+# Test 3: Lead/worker with default worker (uses main model)
+echo -e "\nTest 3: Lead/worker with default worker"
+export GOOSE_LEAD_MODEL="gpt-4o"
+unset GOOSE_WORKER_MODEL
+export GOOSE_LEAD_TURNS="5"
+
+echo -e "\nConfiguration examples:"
+echo "- Default: Uses GOOSE_MODEL for all turns"
+echo "- Lead/Worker: Set GOOSE_LEAD_MODEL to use a different model for initial turns"
+echo "- GOOSE_LEAD_TURNS: Number of turns to use lead model (default: 3)"
+echo "- GOOSE_WORKER_MODEL: Model to use after lead turns (default: GOOSE_MODEL)"
\ No newline at end of file
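The script above only exports variables and prints notes; it never launches goose. A minimal way to exercise the lead/worker path end to end, assuming the `goose session` CLI command and a valid `OPENAI_API_KEY`, would be:

```bash
# Illustrative sketch: start a session with lead/worker enabled.
# When GOOSE_LEAD_MODEL is set, the startup banner printed by
# display_session_info should show both "lead model:" and "worker model:".
export GOOSE_PROVIDER="openai"
export GOOSE_MODEL="gpt-4o-mini"   # worker model (used after the lead turns)
export GOOSE_LEAD_MODEL="gpt-4o"   # lead model (first GOOSE_LEAD_TURNS turns, default 3)
goose session
```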