feat: lead/worker model (#2719)

This commit is contained in:
Michael Neale
2025-06-05 13:55:32 +10:00
committed by GitHub
parent 6076c9b5dc
commit 2f8f8e5767
9 changed files with 1088 additions and 11 deletions

View File

@@ -23,6 +23,31 @@ Whether you're prototyping an idea, refining existing code, or managing intricat
Designed for maximum flexibility, goose works with any LLM, seamlessly integrates with MCP servers, and is available as both a desktop app as well as CLI - making it the ultimate AI assistant for developers who want to move faster and focus on innovation. Designed for maximum flexibility, goose works with any LLM, seamlessly integrates with MCP servers, and is available as both a desktop app as well as CLI - making it the ultimate AI assistant for developers who want to move faster and focus on innovation.
## Multiple Model Configuration
goose supports using different models for different purposes to optimize performance and cost; this works across model providers as well as across individual models.
### Lead/Worker Model Pattern
Use a powerful model for initial planning and complex reasoning, then switch to a faster/cheaper model for execution; goose handles this switch automatically:
```bash
# Required: Enable lead model mode
export GOOSE_LEAD_MODEL=modelY
# Optional: configure a provider for the lead model if not the default provider
export GOOSE_LEAD_PROVIDER=providerX # Defaults to main provider
```
### Planning Model Configuration
Use a specialized model for the `/plan` command in CLI mode; this is explicitly invoked when you want to plan (vs. execute).
```bash
# Optional: Use different model for planning
export GOOSE_PLANNER_PROVIDER=openai
export GOOSE_PLANNER_MODEL=gpt-4
```
Both patterns help you balance model capabilities with cost and speed for optimal results, and switch between models and vendors as required.
# Quick Links # Quick Links
- [Quickstart](https://block.github.io/goose/docs/quickstart) - [Quickstart](https://block.github.io/goose/docs/quickstart)

View File

@@ -7,6 +7,7 @@ use goose::session;
use goose::session::Identifier; use goose::session::Identifier;
use mcp_client::transport::Error as McpClientError; use mcp_client::transport::Error as McpClientError;
use std::process; use std::process;
use std::sync::Arc;
use super::output; use super::output;
use super::Session; use super::Session;
@@ -55,6 +56,22 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
// Create the agent // Create the agent
let agent: Agent = Agent::new(); let agent: Agent = Agent::new();
let new_provider = create(&provider_name, model_config).unwrap(); let new_provider = create(&provider_name, model_config).unwrap();
// Keep a reference to the provider for display_session_info
let provider_for_display = Arc::clone(&new_provider);
// Log model information at startup
if let Some(lead_worker) = new_provider.as_lead_worker() {
let (lead_model, worker_model) = lead_worker.get_model_info();
tracing::info!(
"🤖 Lead/Worker Mode Enabled: Lead model (first 3 turns): {}, Worker model (turn 4+): {}, Auto-fallback on failures: Enabled",
lead_model,
worker_model
);
} else {
tracing::info!("🤖 Using model: {}", model);
}
agent agent
.update_provider(new_provider) .update_provider(new_provider)
.await .await
@@ -217,6 +234,12 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
session.agent.override_system_prompt(override_prompt).await; session.agent.override_system_prompt(override_prompt).await;
} }
output::display_session_info(session_config.resume, &provider_name, &model, &session_file); output::display_session_info(
session_config.resume,
&provider_name,
&model,
&session_file,
Some(&provider_for_display),
);
session session
} }

View File

@@ -10,6 +10,7 @@ use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::Error; use std::io::Error;
use std::path::Path; use std::path::Path;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
// Re-export theme for use in main // Re-export theme for use in main
@@ -536,7 +537,13 @@ fn shorten_path(path: &str, debug: bool) -> String {
} }
// Session display functions // Session display functions
pub fn display_session_info(resume: bool, provider: &str, model: &str, session_file: &Path) { pub fn display_session_info(
resume: bool,
provider: &str,
model: &str,
session_file: &Path,
provider_instance: Option<&Arc<dyn goose::providers::base::Provider>>,
) {
let start_session_msg = if resume { let start_session_msg = if resume {
"resuming session |" "resuming session |"
} else if session_file.to_str() == Some("/dev/null") || session_file.to_str() == Some("NUL") { } else if session_file.to_str() == Some("/dev/null") || session_file.to_str() == Some("NUL") {
@@ -544,6 +551,22 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
} else { } else {
"starting session |" "starting session |"
}; };
// Check if we have lead/worker mode
if let Some(provider_inst) = provider_instance {
if let Some(lead_worker) = provider_inst.as_lead_worker() {
let (lead_model, worker_model) = lead_worker.get_model_info();
println!(
"{} {} {} {} {} {} {}",
style(start_session_msg).dim(),
style("provider:").dim(),
style(provider).cyan().dim(),
style("lead model:").dim(),
style(&lead_model).cyan().dim(),
style("worker model:").dim(),
style(&worker_model).cyan().dim(),
);
} else {
println!( println!(
"{} {} {} {} {}", "{} {} {} {} {}",
style(start_session_msg).dim(), style(start_session_msg).dim(),
@@ -552,6 +575,18 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
style("model:").dim(), style("model:").dim(),
style(model).cyan().dim(), style(model).cyan().dim(),
); );
}
} else {
// Fallback to original behavior if no provider instance
println!(
"{} {} {} {} {}",
style(start_session_msg).dim(),
style("provider:").dim(),
style(provider).cyan().dim(),
style("model:").dim(),
style(model).cyan().dim(),
);
}
if session_file.to_str() != Some("/dev/null") && session_file.to_str() != Some("NUL") { if session_file.to_str() != Some("/dev/null") && session_file.to_str() != Some("NUL") {
println!( println!(

View File

@@ -148,6 +148,12 @@ impl Usage {
use async_trait::async_trait; use async_trait::async_trait;
/// Trait for LeadWorkerProvider-specific functionality
pub trait LeadWorkerProviderTrait {
/// Get information about the lead and worker models for logging
fn get_model_info(&self) -> (String, String);
}
/// Base trait for AI providers (OpenAI, Anthropic, etc) /// Base trait for AI providers (OpenAI, Anthropic, etc)
#[async_trait] #[async_trait]
pub trait Provider: Send + Sync { pub trait Provider: Send + Sync {
@@ -195,6 +201,12 @@ pub trait Provider: Send + Sync {
"This provider does not support embeddings".to_string(), "This provider does not support embeddings".to_string(),
)) ))
} }
/// Check if this provider is a LeadWorkerProvider
/// This is used for logging model information at startup
fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
None
}
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -10,6 +10,7 @@ use super::{
githubcopilot::GithubCopilotProvider, githubcopilot::GithubCopilotProvider,
google::GoogleProvider, google::GoogleProvider,
groq::GroqProvider, groq::GroqProvider,
lead_worker::LeadWorkerProvider,
ollama::OllamaProvider, ollama::OllamaProvider,
openai::OpenAiProvider, openai::OpenAiProvider,
openrouter::OpenRouterProvider, openrouter::OpenRouterProvider,
@@ -19,6 +20,21 @@ use super::{
use crate::model::ModelConfig; use crate::model::ModelConfig;
use anyhow::Result; use anyhow::Result;
#[cfg(test)]
use super::errors::ProviderError;
#[cfg(test)]
use mcp_core::tool::Tool;
/// Default number of initial turns handled by the lead model.
/// `const fn` so the default can also be used in const contexts.
const fn default_lead_turns() -> usize {
    3
}
/// Default number of consecutive task failures that triggers fallback
/// to the lead model. `const fn` for use in const contexts.
const fn default_failure_threshold() -> usize {
    2
}
/// Default number of turns the lead model serves while in fallback mode.
/// `const fn` for use in const contexts.
const fn default_fallback_turns() -> usize {
    2
}
pub fn providers() -> Vec<ProviderMetadata> { pub fn providers() -> Vec<ProviderMetadata> {
vec![ vec![
AnthropicProvider::metadata(), AnthropicProvider::metadata(),
@@ -38,6 +54,62 @@ pub fn providers() -> Vec<ProviderMetadata> {
} }
pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> { pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
let config = crate::config::Config::global();
// Check for lead model environment variables
if let Ok(lead_model_name) = config.get_param::<String>("GOOSE_LEAD_MODEL") {
tracing::info!("Creating lead/worker provider from environment variables");
return create_lead_worker_from_env(name, &model, &lead_model_name);
}
// Default: create regular provider
create_provider(name, model)
}
/// Create a lead/worker provider from environment variables
///
/// The lead model serves the opening turns and the worker (the default model)
/// takes over afterwards; tuning knobs come from `GOOSE_LEAD_*` variables,
/// each falling back to its compiled-in default when unset.
fn create_lead_worker_from_env(
    default_provider_name: &str,
    default_model: &ModelConfig,
    lead_model_name: &str,
) -> Result<Arc<dyn Provider>> {
    let config = crate::config::Config::global();

    // The lead provider is optional and defaults to the main provider.
    let lead_provider_name = config
        .get_param::<String>("GOOSE_LEAD_PROVIDER")
        .unwrap_or_else(|_| default_provider_name.to_string());

    // Numeric tuning parameters, each with a default.
    let read_usize =
        |key: &str, fallback: usize| config.get_param::<usize>(key).unwrap_or(fallback);
    let lead_turns = read_usize("GOOSE_LEAD_TURNS", default_lead_turns());
    let failure_threshold =
        read_usize("GOOSE_LEAD_FAILURE_THRESHOLD", default_failure_threshold());
    let fallback_turns = read_usize("GOOSE_LEAD_FALLBACK_TURNS", default_fallback_turns());

    // The lead gets its own model config; the worker reuses the default model.
    let lead_provider = create_provider(
        &lead_provider_name,
        ModelConfig::new(lead_model_name.to_string()),
    )?;
    let worker_provider = create_provider(default_provider_name, default_model.clone())?;

    Ok(Arc::new(LeadWorkerProvider::new_with_settings(
        lead_provider,
        worker_provider,
        lead_turns,
        failure_threshold,
        fallback_turns,
    )))
}
fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
// We use Arc instead of Box to be able to clone for multiple async tasks // We use Arc instead of Box to be able to clone for multiple async tasks
match name { match name {
"openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)), "openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)),
@@ -56,3 +128,215 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)), _ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
} }
} }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{Message, MessageContent};
    use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
    use chrono::Utc;
    use mcp_core::{content::TextContent, Role};
    use std::env;
    use std::sync::Mutex;

    /// Rust runs tests in parallel and the process environment is global, so
    /// every test that mutates GOOSE_LEAD_* variables must hold this lock to
    /// avoid clobbering another test's configuration mid-run.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Every GOOSE_LEAD_* variable consulted by the provider factory.
    const LEAD_ENV_KEYS: [&str; 5] = [
        "GOOSE_LEAD_MODEL",
        "GOOSE_LEAD_PROVIDER",
        "GOOSE_LEAD_TURNS",
        "GOOSE_LEAD_FAILURE_THRESHOLD",
        "GOOSE_LEAD_FALLBACK_TURNS",
    ];

    /// Snapshot the current values of all lead env vars for later restoration.
    fn save_lead_env() -> Vec<(&'static str, Option<String>)> {
        LEAD_ENV_KEYS
            .iter()
            .map(|&key| (key, env::var(key).ok()))
            .collect()
    }

    /// Restore a snapshot taken by [`save_lead_env`].
    fn restore_lead_env(saved: Vec<(&'static str, Option<String>)>) {
        for (key, value) in saved {
            match value {
                Some(val) => env::set_var(key, val),
                None => env::remove_var(key),
            }
        }
    }

    /// The factory call either succeeds (API keys present in the test
    /// environment) or fails because credentials are missing — both outcomes
    /// confirm the intended construction path was taken.
    fn assert_created_or_missing_keys(result: Result<Arc<dyn Provider>>) {
        if let Err(error) = result {
            let error_msg = error.to_string();
            assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
        }
    }

    /// Mock provider kept as a template for future factory tests; the tests
    /// below exercise `create` directly and do not construct it.
    #[allow(dead_code)]
    #[derive(Clone)]
    struct MockTestProvider {
        name: String,
        model_config: ModelConfig,
    }

    #[async_trait::async_trait]
    impl Provider for MockTestProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::new(
                "mock_test",
                "Mock Test Provider",
                "A mock provider for testing",
                "mock-model",
                vec!["mock-model"],
                "",
                vec![],
            )
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        // Always succeeds with a text message naming the provider and model.
        async fn complete(
            &self,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            Ok((
                Message {
                    role: Role::Assistant,
                    created: Utc::now().timestamp(),
                    content: vec![MessageContent::Text(TextContent {
                        text: format!(
                            "Response from {} with model {}",
                            self.name, self.model_config.model_name
                        ),
                        annotations: None,
                    })],
                },
                ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()),
            ))
        }
    }

    #[test]
    fn test_create_lead_worker_provider() {
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let saved = save_lead_env();

        // Basic configuration: only the required lead model is set.
        env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");
        let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
        assert_created_or_missing_keys(result);

        // A distinct lead provider and turn count still route through the
        // lead/worker construction path.
        env::set_var("GOOSE_LEAD_PROVIDER", "anthropic");
        env::set_var("GOOSE_LEAD_TURNS", "5");
        let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));

        restore_lead_env(saved);
    }

    #[test]
    fn test_lead_model_env_vars_with_defaults() {
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let saved = save_lead_env();

        // Clean slate, then set only the required variable so every other
        // setting falls back to its compiled-in default.
        for key in LEAD_ENV_KEYS {
            env::remove_var(key);
        }
        env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");
        let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
        assert_created_or_missing_keys(result);

        // Custom tuning values take the same construction path.
        env::set_var("GOOSE_LEAD_TURNS", "7");
        env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", "4");
        env::set_var("GOOSE_LEAD_FALLBACK_TURNS", "3");
        let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));

        restore_lead_env(saved);
    }

    #[test]
    fn test_create_regular_provider_without_lead_config() {
        let _guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let saved = save_lead_env();

        // With no GOOSE_LEAD_* variables set, a regular provider is created.
        for key in LEAD_ENV_KEYS {
            env::remove_var(key);
        }
        let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
        assert_created_or_missing_keys(result);

        restore_lead_env(saved);
    }
}

View File

@@ -0,0 +1,637 @@
use anyhow::Result;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;
use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use mcp_core::{tool::Tool, Content};
/// A provider that switches between a lead model and a worker model based on turn count
/// and can fallback to lead model on consecutive failures
pub struct LeadWorkerProvider {
    /// Provider used for the first `lead_turns` turns and while in fallback mode.
    lead_provider: Arc<dyn Provider>,
    /// Provider used once the initial lead turns are exhausted.
    worker_provider: Arc<dyn Provider>,
    /// Number of initial turns served by the lead provider.
    lead_turns: usize,
    /// Completed turns so far; wrapped in `Arc<Mutex<_>>` because `complete` is async.
    turn_count: Arc<Mutex<usize>>,
    /// Consecutive task-level failures seen in otherwise-successful responses.
    failure_count: Arc<Mutex<usize>>,
    /// Failure count that triggers a temporary switch back to the lead provider.
    max_failures_before_fallback: usize,
    /// How many turns the lead provider handles once fallback is triggered.
    fallback_turns: usize,
    /// True while the lead provider is temporarily substituting for the worker.
    in_fallback_mode: Arc<Mutex<bool>>,
    /// Successful fallback turns still owed before returning to the worker.
    fallback_remaining: Arc<Mutex<usize>>,
}
impl LeadWorkerProvider {
    /// Create a new LeadWorkerProvider
    ///
    /// # Arguments
    /// * `lead_provider` - The provider to use for the initial turns
    /// * `worker_provider` - The provider to use after lead_turns
    /// * `lead_turns` - Number of turns to use the lead provider (default: 3)
    pub fn new(
        lead_provider: Arc<dyn Provider>,
        worker_provider: Arc<dyn Provider>,
        lead_turns: Option<usize>,
    ) -> Self {
        Self {
            lead_provider,
            worker_provider,
            lead_turns: lead_turns.unwrap_or(3),
            turn_count: Arc::new(Mutex::new(0)),
            failure_count: Arc::new(Mutex::new(0)),
            max_failures_before_fallback: 2, // Fallback after 2 consecutive failures
            fallback_turns: 2, // Use lead model for 2 turns when in fallback mode
            in_fallback_mode: Arc::new(Mutex::new(false)),
            fallback_remaining: Arc::new(Mutex::new(0)),
        }
    }

    /// Create a new LeadWorkerProvider with custom settings
    ///
    /// # Arguments
    /// * `lead_provider` - The provider to use for the initial turns
    /// * `worker_provider` - The provider to use after lead_turns
    /// * `lead_turns` - Number of turns to use the lead provider
    /// * `failure_threshold` - Number of consecutive failures before fallback
    /// * `fallback_turns` - Number of turns to use lead model in fallback mode
    pub fn new_with_settings(
        lead_provider: Arc<dyn Provider>,
        worker_provider: Arc<dyn Provider>,
        lead_turns: usize,
        failure_threshold: usize,
        fallback_turns: usize,
    ) -> Self {
        Self {
            lead_provider,
            worker_provider,
            lead_turns,
            turn_count: Arc::new(Mutex::new(0)),
            failure_count: Arc::new(Mutex::new(0)),
            max_failures_before_fallback: failure_threshold,
            fallback_turns,
            in_fallback_mode: Arc::new(Mutex::new(false)),
            fallback_remaining: Arc::new(Mutex::new(0)),
        }
    }

    /// Reset the turn counter and failure tracking (useful for new conversations)
    // NOTE(review): the four fields are reset under separate, sequential lock
    // acquisitions, so the reset is not atomic with respect to a concurrent
    // `complete()` call — confirm callers only reset between turns.
    pub async fn reset_turn_count(&self) {
        let mut count = self.turn_count.lock().await;
        *count = 0;
        let mut failures = self.failure_count.lock().await;
        *failures = 0;
        let mut fallback = self.in_fallback_mode.lock().await;
        *fallback = false;
        let mut remaining = self.fallback_remaining.lock().await;
        *remaining = 0;
    }

    /// Get the current turn count
    pub async fn get_turn_count(&self) -> usize {
        *self.turn_count.lock().await
    }

    /// Get the current failure count
    pub async fn get_failure_count(&self) -> usize {
        *self.failure_count.lock().await
    }

    /// Check if currently in fallback mode
    pub async fn is_in_fallback_mode(&self) -> bool {
        *self.in_fallback_mode.lock().await
    }

    /// Get the currently active provider based on turn count and fallback state
    async fn get_active_provider(&self) -> Arc<dyn Provider> {
        let count = *self.turn_count.lock().await;
        let in_fallback = *self.in_fallback_mode.lock().await;

        // Use lead provider if we're in initial turns OR in fallback mode
        if count < self.lead_turns || in_fallback {
            Arc::clone(&self.lead_provider)
        } else {
            Arc::clone(&self.worker_provider)
        }
    }

    /// Handle the result of a completion attempt and update failure tracking
    ///
    /// Successful responses are scanned for task-level failures (failed tool
    /// calls, error-looking tool output, correction phrasing). Enough
    /// consecutive task failures after the lead turns flips the provider into
    /// fallback mode; each success while in fallback counts down toward
    /// resuming the worker. Technical (Err) results leave all counters alone.
    async fn handle_completion_result(
        &self,
        result: &Result<(Message, ProviderUsage), ProviderError>,
    ) {
        match result {
            Ok((message, _usage)) => {
                // Check for task-level failures in the response
                let has_task_failure = self.detect_task_failures(message).await;

                if has_task_failure {
                    // Task failure detected - increment failure count
                    let mut failures = self.failure_count.lock().await;
                    *failures += 1;
                    let failure_count = *failures;
                    let turn_count = *self.turn_count.lock().await;

                    tracing::warn!(
                        "Task failure detected in response (failure count: {})",
                        failure_count
                    );

                    // Check if we should trigger fallback
                    // (only after the initial lead turns, and not if already falling back)
                    if turn_count >= self.lead_turns
                        && !*self.in_fallback_mode.lock().await
                        && failure_count >= self.max_failures_before_fallback
                    {
                        let mut in_fallback = self.in_fallback_mode.lock().await;
                        let mut fallback_remaining = self.fallback_remaining.lock().await;
                        *in_fallback = true;
                        *fallback_remaining = self.fallback_turns;
                        *failures = 0; // Reset failure count when entering fallback

                        tracing::warn!(
                            "🔄 SWITCHING TO LEAD MODEL: Entering fallback mode after {} consecutive task failures - using lead model for {} turns",
                            self.max_failures_before_fallback,
                            self.fallback_turns
                        );
                    }
                } else {
                    // Success - reset failure count and handle fallback mode
                    let mut failures = self.failure_count.lock().await;
                    *failures = 0;

                    let mut in_fallback = self.in_fallback_mode.lock().await;
                    let mut fallback_remaining = self.fallback_remaining.lock().await;

                    if *in_fallback {
                        // NOTE(review): if `fallback_turns` were configured as 0 this
                        // subtraction would underflow (panic in debug builds) — confirm
                        // the configured value is always >= 1.
                        *fallback_remaining -= 1;
                        if *fallback_remaining == 0 {
                            *in_fallback = false;
                            tracing::info!("✅ SWITCHING BACK TO WORKER MODEL: Exiting fallback mode - worker model resumed");
                        }
                    }
                }

                // Increment turn count on any completion (success or task failure)
                let mut count = self.turn_count.lock().await;
                *count += 1;
            }
            Err(_) => {
                // Technical failure - just log and let it bubble up
                // For technical failures (API/LLM issues), we don't want to second-guess
                // the model choice - just let the default model handle it
                tracing::warn!(
                    "Technical failure detected - API/LLM issue, will use default model"
                );
                // Don't increment turn count or failure tracking for technical failures
                // as these are temporary infrastructure issues, not model capability issues
            }
        }
    }

    /// Detect task-level failures in the model's response
    ///
    /// Counts failure indicators across the message content: malformed tool
    /// requests, tool responses carrying errors or error-looking output, and
    /// text matching correction phrasing.
    async fn detect_task_failures(&self, message: &Message) -> bool {
        let mut failure_indicators = 0;

        for content in &message.content {
            match content {
                MessageContent::ToolRequest(tool_request) => {
                    // Check if tool request itself failed (malformed, etc.)
                    if tool_request.tool_call.is_err() {
                        failure_indicators += 1;
                        tracing::debug!(
                            "Failed tool request detected: {:?}",
                            tool_request.tool_call
                        );
                    }
                }
                MessageContent::ToolResponse(tool_response) => {
                    // Check if tool execution failed
                    if let Err(tool_error) = &tool_response.tool_result {
                        failure_indicators += 1;
                        tracing::debug!("Tool execution failure detected: {:?}", tool_error);
                    } else if let Ok(contents) = &tool_response.tool_result {
                        // Check tool output for error indicators
                        if self.contains_error_indicators(contents) {
                            failure_indicators += 1;
                            tracing::debug!("Tool output contains error indicators");
                        }
                    }
                }
                MessageContent::Text(text_content) => {
                    // Check for user correction patterns or error acknowledgments
                    if self.contains_user_correction_patterns(&text_content.text) {
                        failure_indicators += 1;
                        tracing::debug!("User correction pattern detected in text");
                    }
                }
                _ => {}
            }
        }

        // A single indicator is currently enough to flag the turn as a task failure.
        failure_indicators >= 1
    }

    /// Check if tool output contains error indicators
    fn contains_error_indicators(&self, contents: &[Content]) -> bool {
        for content in contents {
            if let Content::Text(text_content) = content {
                let text_lower = text_content.text.to_lowercase();
                // Common error patterns in tool outputs
                if text_lower.contains("error:")
                    || text_lower.contains("failed:")
                    || text_lower.contains("exception:")
                    || text_lower.contains("traceback")
                    || text_lower.contains("syntax error")
                    || text_lower.contains("permission denied")
                    || text_lower.contains("file not found")
                    || text_lower.contains("command not found")
                    || text_lower.contains("compilation failed")
                    || text_lower.contains("test failed")
                    || text_lower.contains("assertion failed")
                {
                    return true;
                }
            }
        }
        false
    }

    /// Check for user correction patterns in text
    // NOTE(review): this is applied to text content of the completion response;
    // phrases like "actually, " can appear in a perfectly good assistant answer,
    // so the heuristic may over-count failures — confirm acceptable.
    fn contains_user_correction_patterns(&self, text: &str) -> bool {
        let text_lower = text.to_lowercase();
        // Patterns indicating user is correcting or expressing dissatisfaction
        text_lower.contains("that's wrong")
            || text_lower.contains("that's not right")
            || text_lower.contains("that doesn't work")
            || text_lower.contains("try again")
            || text_lower.contains("let me correct")
            || text_lower.contains("actually, ")
            || text_lower.contains("no, that's")
            || text_lower.contains("that's incorrect")
            || text_lower.contains("fix this")
            || text_lower.contains("this is broken")
            || text_lower.contains("this doesn't")
            || text_lower.starts_with("no,")
            || text_lower.starts_with("wrong")
            || text_lower.starts_with("incorrect")
    }
}
impl LeadWorkerProviderTrait for LeadWorkerProvider {
    /// Report the lead and worker model names as a `(lead, worker)` pair for logging.
    fn get_model_info(&self) -> (String, String) {
        (
            self.lead_provider.get_model_config().model_name,
            self.worker_provider.get_model_config().model_name,
        )
    }
}
#[async_trait]
impl Provider for LeadWorkerProvider {
    fn metadata() -> ProviderMetadata {
        // This is a wrapper provider, so we return minimal metadata
        ProviderMetadata::new(
            "lead_worker",
            "Lead/Worker Provider",
            "A provider that switches between lead and worker models based on turn count",
            "",     // No default model as this is determined by the wrapped providers
            vec![], // No known models as this depends on wrapped providers
            "",     // No doc link
            vec![], // No config keys as configuration is done through wrapped providers
        )
    }

    fn get_model_config(&self) -> ModelConfig {
        // Return the lead provider's model config as the default
        // In practice, this might need to be more sophisticated
        self.lead_provider.get_model_config()
    }

    /// Route the completion to the active provider (lead for the initial and
    /// fallback turns, worker otherwise), retry once with the lead provider on
    /// a technical (Err) failure, then update turn/failure tracking.
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        // Get the active provider
        let provider = self.get_active_provider().await;

        // Log which provider is being used
        let turn_count = *self.turn_count.lock().await;
        let in_fallback = *self.in_fallback_mode.lock().await;
        let fallback_remaining = *self.fallback_remaining.lock().await;

        let provider_type = if turn_count < self.lead_turns {
            "lead (initial)"
        } else if in_fallback {
            "lead (fallback)"
        } else {
            "worker"
        };

        if in_fallback {
            tracing::info!(
                "🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining)",
                provider_type,
                turn_count + 1,
                fallback_remaining
            );
        } else {
            tracing::info!(
                "Using {} provider for turn {} (lead_turns: {})",
                provider_type,
                turn_count + 1,
                self.lead_turns
            );
        }

        // Make the completion request
        let result = provider.complete(system, messages, tools).await;

        // For technical failures, try with default model (lead provider) instead
        // NOTE(review): when the active provider is already the lead, this retries
        // the very same provider once — possibly intended as a transient-error
        // retry, but worth confirming.
        let final_result = match &result {
            Err(_) => {
                tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type);

                // Try with lead provider as the default/fallback for technical failures
                let default_result = self.lead_provider.complete(system, messages, tools).await;

                match &default_result {
                    Ok(_) => {
                        tracing::info!(
                            "✅ Default model (lead provider) succeeded after technical failure"
                        );
                        default_result
                    }
                    Err(_) => {
                        tracing::error!("❌ Default model (lead provider) also failed - returning original error");
                        result // Return the original error
                    }
                }
            }
            Ok(_) => result, // Success with original provider
        };

        // Update tracking from the final outcome; the Err arm inside only logs,
        // so counters advance only on successful completions.
        self.handle_completion_result(&final_result).await;

        final_result
    }

    async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
        // Combine models from both providers
        let lead_models = self.lead_provider.fetch_supported_models_async().await?;
        let worker_models = self.worker_provider.fetch_supported_models_async().await?;

        match (lead_models, worker_models) {
            (Some(lead), Some(worker)) => {
                // Merge, sort and de-duplicate the two model lists.
                let mut all_models = lead;
                all_models.extend(worker);
                all_models.sort();
                all_models.dedup();
                Ok(Some(all_models))
            }
            (Some(models), None) | (None, Some(models)) => Ok(Some(models)),
            (None, None) => Ok(None),
        }
    }

    fn supports_embeddings(&self) -> bool {
        // Support embeddings if either provider supports them
        self.lead_provider.supports_embeddings() || self.worker_provider.supports_embeddings()
    }

    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
        // Use the lead provider for embeddings if it supports them, otherwise use worker
        if self.lead_provider.supports_embeddings() {
            self.lead_provider.create_embeddings(texts).await
        } else if self.worker_provider.supports_embeddings() {
            self.worker_provider.create_embeddings(texts).await
        } else {
            Err(ProviderError::ExecutionError(
                "Neither lead nor worker provider supports embeddings".to_string(),
            ))
        }
    }

    /// Check if this provider is a LeadWorkerProvider
    fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
        Some(self)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::MessageContent;
use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
use chrono::Utc;
use mcp_core::{content::TextContent, Role};
/// Minimal in-memory provider whose responses name the provider ("lead" or
/// "worker"), letting tests assert which side handled each turn.
#[derive(Clone)]
struct MockProvider {
    // Name echoed in the response text and reported as the usage model.
    name: String,
    model_config: ModelConfig,
}

#[async_trait]
impl Provider for MockProvider {
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::empty()
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model_config.clone()
    }

    // Always succeeds with a single text message identifying this provider.
    async fn complete(
        &self,
        _system: &str,
        _messages: &[Message],
        _tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        Ok((
            Message {
                role: Role::Assistant,
                created: Utc::now().timestamp(),
                content: vec![MessageContent::Text(TextContent {
                    text: format!("Response from {}", self.name),
                    annotations: None,
                })],
            },
            ProviderUsage::new(self.name.clone(), Usage::default()),
        ))
    }
}
// Verifies the core routing: the lead provider serves the first `lead_turns`
// turns, the worker serves subsequent turns, and `reset_turn_count` returns
// routing to the lead.
#[tokio::test]
async fn test_lead_worker_switching() {
    let lead_provider = Arc::new(MockProvider {
        name: "lead".to_string(),
        model_config: ModelConfig::new("lead-model".to_string()),
    });
    let worker_provider = Arc::new(MockProvider {
        name: "worker".to_string(),
        model_config: ModelConfig::new("worker-model".to_string()),
    });

    let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(3));

    // First three turns should use lead provider
    for i in 0..3 {
        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "lead");
        assert_eq!(provider.get_turn_count().await, i + 1);
        assert!(!provider.is_in_fallback_mode().await);
    }

    // Subsequent turns should use worker provider
    for i in 3..6 {
        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "worker");
        assert_eq!(provider.get_turn_count().await, i + 1);
        assert!(!provider.is_in_fallback_mode().await);
    }

    // Reset and verify it goes back to lead
    provider.reset_turn_count().await;
    assert_eq!(provider.get_turn_count().await, 0);
    assert_eq!(provider.get_failure_count().await, 0);
    assert!(!provider.is_in_fallback_mode().await);

    let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
    assert_eq!(usage.model, "lead");
}
// Verifies that a technical (Err) failure from the worker is retried against
// the lead provider, succeeds, and leaves failure tracking and fallback mode
// untouched. (`MockFailureProvider` is defined later in this module.)
#[tokio::test]
async fn test_technical_failure_retry() {
    let lead_provider = Arc::new(MockFailureProvider {
        name: "lead".to_string(),
        model_config: ModelConfig::new("lead-model".to_string()),
        should_fail: false, // Lead provider works
    });
    let worker_provider = Arc::new(MockFailureProvider {
        name: "worker".to_string(),
        model_config: ModelConfig::new("worker-model".to_string()),
        should_fail: true, // Worker will fail
    });

    let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2));

    // First two turns use lead (should succeed)
    for _i in 0..2 {
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().1.model, "lead");
        assert!(!provider.is_in_fallback_mode().await);
    }

    // Next turn uses worker (will fail, but should retry with lead and succeed)
    let result = provider.complete("system", &[], &[]).await;
    assert!(result.is_ok()); // Should succeed because lead provider is used as fallback
    assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider
    assert_eq!(provider.get_failure_count().await, 0); // No failure tracking for technical failures
    assert!(!provider.is_in_fallback_mode().await); // Not in fallback mode

    // Another turn - should still try worker first, then retry with lead
    let result = provider.complete("system", &[], &[]).await;
    assert!(result.is_ok()); // Should succeed because lead provider is used as fallback
    assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider
    assert_eq!(provider.get_failure_count().await, 0); // Still no failure tracking
    assert!(!provider.is_in_fallback_mode().await); // Still not in fallback mode
}
#[tokio::test]
async fn test_fallback_on_task_failures() {
    // Fallback mode is normally entered after repeated task-level failures
    // in otherwise-successful responses; simulating those needs a richer
    // mock, so this test forces the internal fallback state directly and
    // verifies the provider serves the remaining fallback turns from the
    // lead model before exiting fallback mode.
    let lead = Arc::new(MockFailureProvider {
        name: "lead".to_string(),
        model_config: ModelConfig::new("lead-model".to_string()),
        should_fail: false,
    });
    let worker = Arc::new(MockFailureProvider {
        name: "worker".to_string(),
        model_config: ModelConfig::new("worker-model".to_string()),
        should_fail: false,
    });
    let provider = LeadWorkerProvider::new(lead, worker, Some(2));

    // Hand-craft the state: in fallback with two turns remaining, and a
    // turn count past the initial lead window (so the worker would
    // otherwise be active). The scope ensures all guards are dropped.
    {
        *provider.in_fallback_mode.lock().await = true;
        *provider.fallback_remaining.lock().await = 2;
        *provider.turn_count.lock().await = 4;
    }

    // First fallback turn: lead model answers, still in fallback.
    let (_msg, usage) = provider.complete("system", &[], &[]).await.unwrap();
    assert_eq!(usage.model, "lead");
    assert!(provider.is_in_fallback_mode().await);

    // Second (final) fallback turn: lead again, then fallback mode ends.
    let (_msg, usage) = provider.complete("system", &[], &[]).await.unwrap();
    assert_eq!(usage.model, "lead");
    assert!(!provider.is_in_fallback_mode().await);
}
/// Mock provider whose `complete` either returns a canned assistant message
/// or always errors, depending on `should_fail`. Used to exercise the
/// lead/worker retry and fallback paths.
#[derive(Clone)]
struct MockFailureProvider {
    // Reported as the model name in the returned ProviderUsage, letting
    // tests identify which provider served a given turn.
    name: String,
    model_config: ModelConfig,
    // When true, every complete() call fails with ProviderError::ExecutionError.
    should_fail: bool,
}
#[async_trait]
impl Provider for MockFailureProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::empty()
}
fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
async fn complete(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
if self.should_fail {
Err(ProviderError::ExecutionError(
"Simulated failure".to_string(),
))
} else {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
text: format!("Response from {}", self.name),
annotations: None,
})],
},
ProviderUsage::new(self.name.clone(), Usage::default()),
))
}
}
}
}

View File

@@ -13,6 +13,7 @@ pub mod gcpvertexai;
pub mod githubcopilot;
pub mod google;
pub mod groq;
pub mod lead_worker;
pub mod oauth;
pub mod ollama;
pub mod openai;

View File

@@ -9,6 +9,7 @@ Goose supports various environment variables that allow you to customize its beh
## Model Configuration
These variables control the [language models](/docs/getting-started/providers) and their behavior.
### Basic Provider Configuration
These are the minimum required variables to get started with Goose.
@@ -27,6 +28,7 @@ export GOOSE_PROVIDER="anthropic"
export GOOSE_MODEL="claude-3.5-sonnet"
export GOOSE_TEMPERATURE=0.7
```
### Advanced Provider Configuration
These variables are needed when using custom endpoints, enterprise deployments, or specific provider implementations.
@@ -45,7 +47,34 @@ export GOOSE_PROVIDER__TYPE="anthropic"
export GOOSE_PROVIDER__HOST="https://api.anthropic.com"
export GOOSE_PROVIDER__API_KEY="your-api-key-here"
```
## Multiple Model Configuration
### Lead/Worker Model Configuration
Configure a lead/worker model pattern where a powerful model handles initial planning and complex reasoning, then switches to a faster/cheaper model for execution.
| Variable | Purpose | Values | Default |
|----------|---------|---------|---------|
| `GOOSE_LEAD_MODEL` | **Required to enable lead mode.** Specifies the lead model name | Model name (e.g., "gpt-4o", "claude-3.5-sonnet") | None |
| `GOOSE_LEAD_PROVIDER` | Provider for the lead model | [See available providers](/docs/getting-started/providers#available-providers) | Falls back to GOOSE_PROVIDER |
| `GOOSE_LEAD_TURNS` | Number of initial turns using the lead model | Integer | 3 |
| `GOOSE_LEAD_FAILURE_THRESHOLD` | Consecutive failures before fallback to lead model | Integer | 2 |
| `GOOSE_LEAD_FALLBACK_TURNS` | Number of turns to use lead model in fallback mode | Integer | 2 |
**Examples**
```bash
# Basic lead/worker setup
export GOOSE_LEAD_MODEL="o4"
# Advanced lead/worker configuration
export GOOSE_LEAD_MODEL="claude4-opus"
export GOOSE_LEAD_PROVIDER="anthropic"
export GOOSE_LEAD_TURNS=5
export GOOSE_LEAD_FAILURE_THRESHOLD=3
export GOOSE_LEAD_FALLBACK_TURNS=2
```
### Planning Mode Configuration
These variables control Goose's [planning functionality](/docs/guides/creating-plans).

31
test_lead_worker.sh Executable file
View File

@@ -0,0 +1,31 @@
#!/bin/bash
# Test script for lead/worker provider functionality.
# Walks through the supported GOOSE_LEAD_* configurations; each section sets
# up the environment a user would export before launching goose.

# Baseline provider/model shared by all scenarios.
export GOOSE_PROVIDER="openai"
export GOOSE_MODEL="gpt-4o-mini"
export OPENAI_API_KEY="test-key"

# Test 1: Default behavior — no lead/worker variables set, so GOOSE_MODEL
# is used for every turn.
echo "Test 1: Default behavior (no lead/worker)"
unset GOOSE_LEAD_MODEL
unset GOOSE_WORKER_MODEL
unset GOOSE_LEAD_TURNS

# Test 2: Explicit lead and worker models from the same provider.
echo -e "\nTest 2: Lead/worker with same provider"
export GOOSE_LEAD_MODEL="gpt-4o"
export GOOSE_WORKER_MODEL="gpt-4o-mini"
export GOOSE_LEAD_TURNS="3"

# Test 3: Lead model only — the worker defaults to the main GOOSE_MODEL.
echo -e "\nTest 3: Lead/worker with default worker"
export GOOSE_LEAD_MODEL="gpt-4o"
unset GOOSE_WORKER_MODEL
export GOOSE_LEAD_TURNS="5"

echo -e "\nConfiguration examples:"
echo "- Default: Uses GOOSE_MODEL for all turns"
echo "- Lead/Worker: Set GOOSE_LEAD_MODEL to use a different model for initial turns"
# Default corrected from 5 to 3 to match the documented GOOSE_LEAD_TURNS default.
echo "- GOOSE_LEAD_TURNS: Number of turns to use lead model (default: 3)"
echo "- GOOSE_WORKER_MODEL: Model to use after lead turns (default: GOOSE_MODEL)"