feat: lead/worker model (#2719)

This commit is contained in:
Michael Neale
2025-06-05 13:55:32 +10:00
committed by GitHub
parent 6076c9b5dc
commit 2f8f8e5767
9 changed files with 1088 additions and 11 deletions

View File

@@ -23,6 +23,31 @@ Whether you're prototyping an idea, refining existing code, or managing intricat
Designed for maximum flexibility, goose works with any LLM, seamlessly integrates with MCP servers, and is available as both a desktop app as well as CLI - making it the ultimate AI assistant for developers who want to move faster and focus on innovation.
## Multiple Model Configuration
goose supports using different models for different purposes to optimize performance and cost; these combinations can span different model providers as well as different models.
### Lead/Worker Model Pattern
Use a powerful model for initial planning and complex reasoning, then switch to a faster/cheaper model for execution; goose performs this switch automatically:
```bash
# Required: Enable lead model mode
export GOOSE_LEAD_MODEL=modelY
# Optional: configure a provider for the lead model if not the default provider
export GOOSE_LEAD_PROVIDER=providerX # Defaults to main provider
```
### Planning Model Configuration
Use a specialized model for the `/plan` command in CLI mode; this is invoked explicitly when you want to plan (vs. execute):
```bash
# Optional: Use different model for planning
export GOOSE_PLANNER_PROVIDER=openai
export GOOSE_PLANNER_MODEL=gpt-4
```
Both patterns help you balance model capabilities with cost and speed for optimal results, and switch between models and vendors as required.
# Quick Links
- [Quickstart](https://block.github.io/goose/docs/quickstart)

View File

@@ -7,6 +7,7 @@ use goose::session;
use goose::session::Identifier;
use mcp_client::transport::Error as McpClientError;
use std::process;
use std::sync::Arc;
use super::output;
use super::Session;
@@ -55,6 +56,22 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
// Create the agent
let agent: Agent = Agent::new();
let new_provider = create(&provider_name, model_config).unwrap();
// Keep a reference to the provider for display_session_info
let provider_for_display = Arc::clone(&new_provider);
// Log model information at startup
if let Some(lead_worker) = new_provider.as_lead_worker() {
let (lead_model, worker_model) = lead_worker.get_model_info();
tracing::info!(
"🤖 Lead/Worker Mode Enabled: Lead model (first 3 turns): {}, Worker model (turn 4+): {}, Auto-fallback on failures: Enabled",
lead_model,
worker_model
);
} else {
tracing::info!("🤖 Using model: {}", model);
}
agent
.update_provider(new_provider)
.await
@@ -217,6 +234,12 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
session.agent.override_system_prompt(override_prompt).await;
}
output::display_session_info(session_config.resume, &provider_name, &model, &session_file);
output::display_session_info(
session_config.resume,
&provider_name,
&model,
&session_file,
Some(&provider_for_display),
);
session
}

View File

@@ -10,6 +10,7 @@ use std::cell::RefCell;
use std::collections::HashMap;
use std::io::Error;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
// Re-export theme for use in main
@@ -536,7 +537,13 @@ fn shorten_path(path: &str, debug: bool) -> String {
}
// Session display functions
pub fn display_session_info(resume: bool, provider: &str, model: &str, session_file: &Path) {
pub fn display_session_info(
resume: bool,
provider: &str,
model: &str,
session_file: &Path,
provider_instance: Option<&Arc<dyn goose::providers::base::Provider>>,
) {
let start_session_msg = if resume {
"resuming session |"
} else if session_file.to_str() == Some("/dev/null") || session_file.to_str() == Some("NUL") {
@@ -544,6 +551,22 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
} else {
"starting session |"
};
// Check if we have lead/worker mode
if let Some(provider_inst) = provider_instance {
if let Some(lead_worker) = provider_inst.as_lead_worker() {
let (lead_model, worker_model) = lead_worker.get_model_info();
println!(
"{} {} {} {} {} {} {}",
style(start_session_msg).dim(),
style("provider:").dim(),
style(provider).cyan().dim(),
style("lead model:").dim(),
style(&lead_model).cyan().dim(),
style("worker model:").dim(),
style(&worker_model).cyan().dim(),
);
} else {
println!(
"{} {} {} {} {}",
style(start_session_msg).dim(),
@@ -552,6 +575,18 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
style("model:").dim(),
style(model).cyan().dim(),
);
}
} else {
// Fallback to original behavior if no provider instance
println!(
"{} {} {} {} {}",
style(start_session_msg).dim(),
style("provider:").dim(),
style(provider).cyan().dim(),
style("model:").dim(),
style(model).cyan().dim(),
);
}
if session_file.to_str() != Some("/dev/null") && session_file.to_str() != Some("NUL") {
println!(

View File

@@ -148,6 +148,12 @@ impl Usage {
use async_trait::async_trait;
/// Trait for LeadWorkerProvider-specific functionality.
///
/// Implemented by providers that wrap a lead and a worker model so that
/// callers (e.g. session-startup logging) can discover both model names
/// through `Provider::as_lead_worker` without downcasting.
pub trait LeadWorkerProviderTrait {
    /// Get information about the lead and worker models for logging.
    ///
    /// Returns `(lead_model_name, worker_model_name)`.
    fn get_model_info(&self) -> (String, String);
}
/// Base trait for AI providers (OpenAI, Anthropic, etc)
#[async_trait]
pub trait Provider: Send + Sync {
@@ -195,6 +201,12 @@ pub trait Provider: Send + Sync {
"This provider does not support embeddings".to_string(),
))
}
    /// Check if this provider is a LeadWorkerProvider.
    ///
    /// This is used for logging model information at startup. The default
    /// implementation returns `None`; a lead/worker wrapper overrides it to
    /// return `Some(self)`.
    fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
        None
    }
}
#[cfg(test)]

View File

@@ -10,6 +10,7 @@ use super::{
githubcopilot::GithubCopilotProvider,
google::GoogleProvider,
groq::GroqProvider,
lead_worker::LeadWorkerProvider,
ollama::OllamaProvider,
openai::OpenAiProvider,
openrouter::OpenRouterProvider,
@@ -19,6 +20,21 @@ use super::{
use crate::model::ModelConfig;
use anyhow::Result;
#[cfg(test)]
use super::errors::ProviderError;
#[cfg(test)]
use mcp_core::tool::Tool;
/// Default number of initial conversation turns served by the lead model
/// before switching to the worker model (override with GOOSE_LEAD_TURNS).
fn default_lead_turns() -> usize {
    3
}
/// Default number of consecutive task failures tolerated before falling
/// back to the lead model (override with GOOSE_LEAD_FAILURE_THRESHOLD).
fn default_failure_threshold() -> usize {
    2
}
/// Default number of turns spent on the lead model while in fallback mode
/// before resuming the worker model (override with GOOSE_LEAD_FALLBACK_TURNS).
fn default_fallback_turns() -> usize {
    2
}
pub fn providers() -> Vec<ProviderMetadata> {
vec![
AnthropicProvider::metadata(),
@@ -38,6 +54,62 @@ pub fn providers() -> Vec<ProviderMetadata> {
}
/// Build a provider for `name` with the given model configuration.
///
/// When the `GOOSE_LEAD_MODEL` setting is present, a lead/worker provider
/// is assembled from the environment instead of a plain provider.
pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
    let config = crate::config::Config::global();

    match config.get_param::<String>("GOOSE_LEAD_MODEL") {
        // Lead model configured: wrap a lead and a worker provider together.
        Ok(lead_model_name) => {
            tracing::info!("Creating lead/worker provider from environment variables");
            create_lead_worker_from_env(name, &model, &lead_model_name)
        }
        // No lead model configured: build a regular provider.
        Err(_) => create_provider(name, model),
    }
}
/// Create a lead/worker provider from environment variables
/// Create a lead/worker provider from environment variables.
///
/// The lead model name is required (read by the caller from
/// `GOOSE_LEAD_MODEL`); the lead provider and all tuning knobs are optional
/// and fall back to the main provider/model or to built-in defaults.
fn create_lead_worker_from_env(
    default_provider_name: &str,
    default_model: &ModelConfig,
    lead_model_name: &str,
) -> Result<Arc<dyn Provider>> {
    let config = crate::config::Config::global();

    // Lead provider is optional and defaults to the main provider.
    let lead_provider_name = match config.get_param::<String>("GOOSE_LEAD_PROVIDER") {
        Ok(name) => name,
        Err(_) => default_provider_name.to_string(),
    };

    // Tuning knobs, each falling back to a built-in default.
    let usize_param =
        |key: &str, fallback: usize| config.get_param::<usize>(key).unwrap_or(fallback);
    let lead_turns = usize_param("GOOSE_LEAD_TURNS", default_lead_turns());
    let failure_threshold =
        usize_param("GOOSE_LEAD_FAILURE_THRESHOLD", default_failure_threshold());
    let fallback_turns = usize_param("GOOSE_LEAD_FALLBACK_TURNS", default_fallback_turns());

    // Lead uses the configured lead model; worker keeps the main model.
    let lead_provider = create_provider(
        &lead_provider_name,
        ModelConfig::new(lead_model_name.to_string()),
    )?;
    let worker_provider = create_provider(default_provider_name, default_model.clone())?;

    // Wrap both providers with the configured switching behavior.
    Ok(Arc::new(LeadWorkerProvider::new_with_settings(
        lead_provider,
        worker_provider,
        lead_turns,
        failure_threshold,
        fallback_turns,
    )))
}
fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
// We use Arc instead of Box to be able to clone for multiple async tasks
match name {
"openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)),
@@ -56,3 +128,215 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{Message, MessageContent};
    use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
    use chrono::Utc;
    use mcp_core::{content::TextContent, Role};
    use std::env;

    // Minimal in-memory Provider used to exercise the factory code paths
    // without any network access.
    // NOTE(review): the tests below mutate process-wide environment
    // variables, so they can race with each other under the default
    // parallel test runner — consider serializing them. TODO confirm.
    #[derive(Clone)]
    struct MockTestProvider {
        name: String,              // provider label echoed into responses
        model_config: ModelConfig, // model reported by get_model_config
    }

    #[async_trait::async_trait]
    impl Provider for MockTestProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::new(
                "mock_test",
                "Mock Test Provider",
                "A mock provider for testing",
                "mock-model",
                vec!["mock-model"],
                "",
                vec![],
            )
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        // Always succeeds; the text echoes the provider name and model so a
        // test can tell which provider served the request.
        async fn complete(
            &self,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            Ok((
                Message {
                    role: Role::Assistant,
                    created: Utc::now().timestamp(),
                    content: vec![MessageContent::Text(TextContent {
                        text: format!(
                            "Response from {} with model {}",
                            self.name, self.model_config.model_name
                        ),
                        annotations: None,
                    })],
                },
                ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()),
            ))
        }
    }

    // Setting GOOSE_LEAD_MODEL should route `create` through the
    // lead/worker construction path.
    #[test]
    fn test_create_lead_worker_provider() {
        // Save current env vars so they can be restored on exit
        let saved_lead = env::var("GOOSE_LEAD_MODEL").ok();
        let saved_provider = env::var("GOOSE_LEAD_PROVIDER").ok();
        let saved_turns = env::var("GOOSE_LEAD_TURNS").ok();

        // Test with basic lead model configuration
        env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");

        // This will try to create a lead/worker provider
        let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));

        // The creation might succeed or fail depending on API keys, but we can verify the logic path
        match result {
            Ok(_) => {
                // If it succeeds, it means we created a lead/worker provider successfully
                // This would happen if API keys are available in the test environment
            }
            Err(error) => {
                // If it fails, it should be due to missing API keys, confirming we tried to create providers
                let error_msg = error.to_string();
                assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
            }
        }

        // Test with different lead provider
        env::set_var("GOOSE_LEAD_PROVIDER", "anthropic");
        env::set_var("GOOSE_LEAD_TURNS", "5");

        let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
        // Similar validation as above - will fail due to missing API keys but confirms the logic

        // Restore env vars (remove any that were unset before the test)
        match saved_lead {
            Some(val) => env::set_var("GOOSE_LEAD_MODEL", val),
            None => env::remove_var("GOOSE_LEAD_MODEL"),
        }
        match saved_provider {
            Some(val) => env::set_var("GOOSE_LEAD_PROVIDER", val),
            None => env::remove_var("GOOSE_LEAD_PROVIDER"),
        }
        match saved_turns {
            Some(val) => env::set_var("GOOSE_LEAD_TURNS", val),
            None => env::remove_var("GOOSE_LEAD_TURNS"),
        }
    }

    // With only GOOSE_LEAD_MODEL set, every tuning knob should fall back to
    // its built-in default.
    #[test]
    fn test_lead_model_env_vars_with_defaults() {
        // Save current env vars (keyed so restore can loop over them)
        let saved_vars = [
            ("GOOSE_LEAD_MODEL", env::var("GOOSE_LEAD_MODEL").ok()),
            ("GOOSE_LEAD_PROVIDER", env::var("GOOSE_LEAD_PROVIDER").ok()),
            ("GOOSE_LEAD_TURNS", env::var("GOOSE_LEAD_TURNS").ok()),
            (
                "GOOSE_LEAD_FAILURE_THRESHOLD",
                env::var("GOOSE_LEAD_FAILURE_THRESHOLD").ok(),
            ),
            (
                "GOOSE_LEAD_FALLBACK_TURNS",
                env::var("GOOSE_LEAD_FALLBACK_TURNS").ok(),
            ),
        ];

        // Clear all lead env vars
        for (key, _) in &saved_vars {
            env::remove_var(key);
        }

        // Set only the required lead model
        env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");

        // This should use defaults for all other values
        let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));

        // Should attempt to create lead/worker provider (will fail due to missing API keys but confirms logic)
        match result {
            Ok(_) => {
                // Success means we have API keys and created the provider
            }
            Err(error) => {
                // Should fail due to missing API keys, confirming we tried to create providers
                let error_msg = error.to_string();
                assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
            }
        }

        // Test with custom values
        env::set_var("GOOSE_LEAD_TURNS", "7");
        env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", "4");
        env::set_var("GOOSE_LEAD_FALLBACK_TURNS", "3");

        let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
        // Should still attempt to create lead/worker provider with custom settings

        // Restore all env vars
        for (key, value) in saved_vars {
            match value {
                Some(val) => env::set_var(key, val),
                None => env::remove_var(key),
            }
        }
    }

    // Without any GOOSE_LEAD_* configuration, `create` should build a plain
    // (non lead/worker) provider.
    #[test]
    fn test_create_regular_provider_without_lead_config() {
        // Save current env vars
        let saved_lead = env::var("GOOSE_LEAD_MODEL").ok();
        let saved_provider = env::var("GOOSE_LEAD_PROVIDER").ok();
        let saved_turns = env::var("GOOSE_LEAD_TURNS").ok();
        let saved_threshold = env::var("GOOSE_LEAD_FAILURE_THRESHOLD").ok();
        let saved_fallback = env::var("GOOSE_LEAD_FALLBACK_TURNS").ok();

        // Ensure all GOOSE_LEAD_* variables are not set
        env::remove_var("GOOSE_LEAD_MODEL");
        env::remove_var("GOOSE_LEAD_PROVIDER");
        env::remove_var("GOOSE_LEAD_TURNS");
        env::remove_var("GOOSE_LEAD_FAILURE_THRESHOLD");
        env::remove_var("GOOSE_LEAD_FALLBACK_TURNS");

        // This should try to create a regular provider
        let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));

        // The creation might succeed or fail depending on API keys
        match result {
            Ok(_) => {
                // If it succeeds, it means we created a regular provider successfully
                // This would happen if API keys are available in the test environment
            }
            Err(error) => {
                // If it fails, it should be due to missing API keys
                let error_msg = error.to_string();
                assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
            }
        }

        // Restore env vars (vars were already removed above, so only
        // previously-present values need to be written back)
        if let Some(val) = saved_lead {
            env::set_var("GOOSE_LEAD_MODEL", val);
        }
        if let Some(val) = saved_provider {
            env::set_var("GOOSE_LEAD_PROVIDER", val);
        }
        if let Some(val) = saved_turns {
            env::set_var("GOOSE_LEAD_TURNS", val);
        }
        if let Some(val) = saved_threshold {
            env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", val);
        }
        if let Some(val) = saved_fallback {
            env::set_var("GOOSE_LEAD_FALLBACK_TURNS", val);
        }
    }
}

View File

@@ -0,0 +1,637 @@
use anyhow::Result;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;
use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use mcp_core::{tool::Tool, Content};
/// A provider that switches between a lead model and a worker model based on turn count
/// and can fallback to lead model on consecutive failures.
///
/// All mutable counters live behind `Arc<tokio::sync::Mutex<_>>` so the
/// provider can be shared across async tasks while staying internally mutable.
pub struct LeadWorkerProvider {
    lead_provider: Arc<dyn Provider>,   // model used for initial turns and fallback
    worker_provider: Arc<dyn Provider>, // model used after the initial lead turns
    lead_turns: usize,                  // number of initial turns served by the lead
    turn_count: Arc<Mutex<usize>>,      // completed turns (success or task failure)
    failure_count: Arc<Mutex<usize>>,   // consecutive task failures observed
    max_failures_before_fallback: usize, // failures required to trigger fallback
    fallback_turns: usize,              // lead turns granted when fallback triggers
    in_fallback_mode: Arc<Mutex<bool>>, // true while fallback to the lead is active
    fallback_remaining: Arc<Mutex<usize>>, // fallback turns left before resuming worker
}
impl LeadWorkerProvider {
/// Create a new LeadWorkerProvider
///
/// # Arguments
/// * `lead_provider` - The provider to use for the initial turns
/// * `worker_provider` - The provider to use after lead_turns
/// * `lead_turns` - Number of turns to use the lead provider (default: 3)
pub fn new(
lead_provider: Arc<dyn Provider>,
worker_provider: Arc<dyn Provider>,
lead_turns: Option<usize>,
) -> Self {
Self {
lead_provider,
worker_provider,
lead_turns: lead_turns.unwrap_or(3),
turn_count: Arc::new(Mutex::new(0)),
failure_count: Arc::new(Mutex::new(0)),
max_failures_before_fallback: 2, // Fallback after 2 consecutive failures
fallback_turns: 2, // Use lead model for 2 turns when in fallback mode
in_fallback_mode: Arc::new(Mutex::new(false)),
fallback_remaining: Arc::new(Mutex::new(0)),
}
}
/// Create a new LeadWorkerProvider with custom settings
///
/// # Arguments
/// * `lead_provider` - The provider to use for the initial turns
/// * `worker_provider` - The provider to use after lead_turns
/// * `lead_turns` - Number of turns to use the lead provider
/// * `failure_threshold` - Number of consecutive failures before fallback
/// * `fallback_turns` - Number of turns to use lead model in fallback mode
pub fn new_with_settings(
lead_provider: Arc<dyn Provider>,
worker_provider: Arc<dyn Provider>,
lead_turns: usize,
failure_threshold: usize,
fallback_turns: usize,
) -> Self {
Self {
lead_provider,
worker_provider,
lead_turns,
turn_count: Arc::new(Mutex::new(0)),
failure_count: Arc::new(Mutex::new(0)),
max_failures_before_fallback: failure_threshold,
fallback_turns,
in_fallback_mode: Arc::new(Mutex::new(false)),
fallback_remaining: Arc::new(Mutex::new(0)),
}
}
/// Reset the turn counter and failure tracking (useful for new conversations)
pub async fn reset_turn_count(&self) {
let mut count = self.turn_count.lock().await;
*count = 0;
let mut failures = self.failure_count.lock().await;
*failures = 0;
let mut fallback = self.in_fallback_mode.lock().await;
*fallback = false;
let mut remaining = self.fallback_remaining.lock().await;
*remaining = 0;
}
/// Get the current turn count
pub async fn get_turn_count(&self) -> usize {
*self.turn_count.lock().await
}
/// Get the current failure count
pub async fn get_failure_count(&self) -> usize {
*self.failure_count.lock().await
}
/// Check if currently in fallback mode
pub async fn is_in_fallback_mode(&self) -> bool {
*self.in_fallback_mode.lock().await
}
/// Get the currently active provider based on turn count and fallback state
async fn get_active_provider(&self) -> Arc<dyn Provider> {
let count = *self.turn_count.lock().await;
let in_fallback = *self.in_fallback_mode.lock().await;
// Use lead provider if we're in initial turns OR in fallback mode
if count < self.lead_turns || in_fallback {
Arc::clone(&self.lead_provider)
} else {
Arc::clone(&self.worker_provider)
}
}
/// Handle the result of a completion attempt and update failure tracking
async fn handle_completion_result(
&self,
result: &Result<(Message, ProviderUsage), ProviderError>,
) {
match result {
Ok((message, _usage)) => {
// Check for task-level failures in the response
let has_task_failure = self.detect_task_failures(message).await;
if has_task_failure {
// Task failure detected - increment failure count
let mut failures = self.failure_count.lock().await;
*failures += 1;
let failure_count = *failures;
let turn_count = *self.turn_count.lock().await;
tracing::warn!(
"Task failure detected in response (failure count: {})",
failure_count
);
// Check if we should trigger fallback
if turn_count >= self.lead_turns
&& !*self.in_fallback_mode.lock().await
&& failure_count >= self.max_failures_before_fallback
{
let mut in_fallback = self.in_fallback_mode.lock().await;
let mut fallback_remaining = self.fallback_remaining.lock().await;
*in_fallback = true;
*fallback_remaining = self.fallback_turns;
*failures = 0; // Reset failure count when entering fallback
tracing::warn!(
"🔄 SWITCHING TO LEAD MODEL: Entering fallback mode after {} consecutive task failures - using lead model for {} turns",
self.max_failures_before_fallback,
self.fallback_turns
);
}
} else {
// Success - reset failure count and handle fallback mode
let mut failures = self.failure_count.lock().await;
*failures = 0;
let mut in_fallback = self.in_fallback_mode.lock().await;
let mut fallback_remaining = self.fallback_remaining.lock().await;
if *in_fallback {
*fallback_remaining -= 1;
if *fallback_remaining == 0 {
*in_fallback = false;
tracing::info!("✅ SWITCHING BACK TO WORKER MODEL: Exiting fallback mode - worker model resumed");
}
}
}
// Increment turn count on any completion (success or task failure)
let mut count = self.turn_count.lock().await;
*count += 1;
}
Err(_) => {
// Technical failure - just log and let it bubble up
// For technical failures (API/LLM issues), we don't want to second-guess
// the model choice - just let the default model handle it
tracing::warn!(
"Technical failure detected - API/LLM issue, will use default model"
);
// Don't increment turn count or failure tracking for technical failures
// as these are temporary infrastructure issues, not model capability issues
}
}
}
/// Detect task-level failures in the model's response
async fn detect_task_failures(&self, message: &Message) -> bool {
let mut failure_indicators = 0;
for content in &message.content {
match content {
MessageContent::ToolRequest(tool_request) => {
// Check if tool request itself failed (malformed, etc.)
if tool_request.tool_call.is_err() {
failure_indicators += 1;
tracing::debug!(
"Failed tool request detected: {:?}",
tool_request.tool_call
);
}
}
MessageContent::ToolResponse(tool_response) => {
// Check if tool execution failed
if let Err(tool_error) = &tool_response.tool_result {
failure_indicators += 1;
tracing::debug!("Tool execution failure detected: {:?}", tool_error);
} else if let Ok(contents) = &tool_response.tool_result {
// Check tool output for error indicators
if self.contains_error_indicators(contents) {
failure_indicators += 1;
tracing::debug!("Tool output contains error indicators");
}
}
}
MessageContent::Text(text_content) => {
// Check for user correction patterns or error acknowledgments
if self.contains_user_correction_patterns(&text_content.text) {
failure_indicators += 1;
tracing::debug!("User correction pattern detected in text");
}
}
_ => {}
}
}
// Consider it a failure if we have multiple failure indicators
failure_indicators >= 1
}
/// Check if tool output contains error indicators
fn contains_error_indicators(&self, contents: &[Content]) -> bool {
for content in contents {
if let Content::Text(text_content) = content {
let text_lower = text_content.text.to_lowercase();
// Common error patterns in tool outputs
if text_lower.contains("error:")
|| text_lower.contains("failed:")
|| text_lower.contains("exception:")
|| text_lower.contains("traceback")
|| text_lower.contains("syntax error")
|| text_lower.contains("permission denied")
|| text_lower.contains("file not found")
|| text_lower.contains("command not found")
|| text_lower.contains("compilation failed")
|| text_lower.contains("test failed")
|| text_lower.contains("assertion failed")
{
return true;
}
}
}
false
}
/// Check for user correction patterns in text
fn contains_user_correction_patterns(&self, text: &str) -> bool {
let text_lower = text.to_lowercase();
// Patterns indicating user is correcting or expressing dissatisfaction
text_lower.contains("that's wrong")
|| text_lower.contains("that's not right")
|| text_lower.contains("that doesn't work")
|| text_lower.contains("try again")
|| text_lower.contains("let me correct")
|| text_lower.contains("actually, ")
|| text_lower.contains("no, that's")
|| text_lower.contains("that's incorrect")
|| text_lower.contains("fix this")
|| text_lower.contains("this is broken")
|| text_lower.contains("this doesn't")
|| text_lower.starts_with("no,")
|| text_lower.starts_with("wrong")
|| text_lower.starts_with("incorrect")
}
}
impl LeadWorkerProviderTrait for LeadWorkerProvider {
    /// Report the lead and worker model names as a pair, for startup logging.
    fn get_model_info(&self) -> (String, String) {
        (
            self.lead_provider.get_model_config().model_name,
            self.worker_provider.get_model_config().model_name,
        )
    }
}
#[async_trait]
impl Provider for LeadWorkerProvider {
    fn metadata() -> ProviderMetadata {
        // This is a wrapper provider, so we return minimal metadata
        ProviderMetadata::new(
            "lead_worker",
            "Lead/Worker Provider",
            "A provider that switches between lead and worker models based on turn count",
            "", // No default model as this is determined by the wrapped providers
            vec![], // No known models as this depends on wrapped providers
            "", // No doc link
            vec![], // No config keys as configuration is done through wrapped providers
        )
    }

    fn get_model_config(&self) -> ModelConfig {
        // Return the lead provider's model config as the default
        // In practice, this might need to be more sophisticated
        self.lead_provider.get_model_config()
    }

    /// Route the completion to the active provider (lead or worker) and, on
    /// a technical failure (an `Err` from that provider), retry once with
    /// the lead provider before surfacing the original error.
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        // Get the active provider
        let provider = self.get_active_provider().await;

        // Log which provider is being used (snapshots of the counters; each
        // lock is held only for the read)
        let turn_count = *self.turn_count.lock().await;
        let in_fallback = *self.in_fallback_mode.lock().await;
        let fallback_remaining = *self.fallback_remaining.lock().await;
        let provider_type = if turn_count < self.lead_turns {
            "lead (initial)"
        } else if in_fallback {
            "lead (fallback)"
        } else {
            "worker"
        };
        if in_fallback {
            tracing::info!(
                "🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining)",
                provider_type,
                turn_count + 1,
                fallback_remaining
            );
        } else {
            tracing::info!(
                "Using {} provider for turn {} (lead_turns: {})",
                provider_type,
                turn_count + 1,
                self.lead_turns
            );
        }

        // Make the completion request
        let result = provider.complete(system, messages, tools).await;

        // For technical failures, try with default model (lead provider) instead
        let final_result = match &result {
            Err(_) => {
                tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type);

                // Try with lead provider as the default/fallback for technical failures
                let default_result = self.lead_provider.complete(system, messages, tools).await;
                match &default_result {
                    Ok(_) => {
                        tracing::info!(
                            "✅ Default model (lead provider) succeeded after technical failure"
                        );
                        default_result
                    }
                    Err(_) => {
                        tracing::error!("❌ Default model (lead provider) also failed - returning original error");
                        result // Return the original error
                    }
                }
            }
            Ok(_) => result, // Success with original provider
        };

        // Handle the result and update tracking (only for successful completions)
        self.handle_completion_result(&final_result).await;

        final_result
    }

    /// Union of the model lists reported by both wrapped providers,
    /// sorted and de-duplicated.
    async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
        // Combine models from both providers
        let lead_models = self.lead_provider.fetch_supported_models_async().await?;
        let worker_models = self.worker_provider.fetch_supported_models_async().await?;

        match (lead_models, worker_models) {
            (Some(lead), Some(worker)) => {
                let mut all_models = lead;
                all_models.extend(worker);
                all_models.sort();
                all_models.dedup();
                Ok(Some(all_models))
            }
            (Some(models), None) | (None, Some(models)) => Ok(Some(models)),
            (None, None) => Ok(None),
        }
    }

    fn supports_embeddings(&self) -> bool {
        // Support embeddings if either provider supports them
        self.lead_provider.supports_embeddings() || self.worker_provider.supports_embeddings()
    }

    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
        // Use the lead provider for embeddings if it supports them, otherwise use worker
        if self.lead_provider.supports_embeddings() {
            self.lead_provider.create_embeddings(texts).await
        } else if self.worker_provider.supports_embeddings() {
            self.worker_provider.create_embeddings(texts).await
        } else {
            Err(ProviderError::ExecutionError(
                "Neither lead nor worker provider supports embeddings".to_string(),
            ))
        }
    }

    /// Check if this provider is a LeadWorkerProvider — overrides the
    /// default `None` so callers can reach `get_model_info` for logging.
    fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
        Some(self)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::message::MessageContent;
use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
use chrono::Utc;
use mcp_core::{content::TextContent, Role};
#[derive(Clone)]
struct MockProvider {
name: String,
model_config: ModelConfig,
}
#[async_trait]
impl Provider for MockProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::empty()
}
fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
async fn complete(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Ok((
Message {
role: Role::Assistant,
created: Utc::now().timestamp(),
content: vec![MessageContent::Text(TextContent {
text: format!("Response from {}", self.name),
annotations: None,
})],
},
ProviderUsage::new(self.name.clone(), Usage::default()),
))
}
}
#[tokio::test]
async fn test_lead_worker_switching() {
let lead_provider = Arc::new(MockProvider {
name: "lead".to_string(),
model_config: ModelConfig::new("lead-model".to_string()),
});
let worker_provider = Arc::new(MockProvider {
name: "worker".to_string(),
model_config: ModelConfig::new("worker-model".to_string()),
});
let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(3));
// First three turns should use lead provider
for i in 0..3 {
let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
assert_eq!(usage.model, "lead");
assert_eq!(provider.get_turn_count().await, i + 1);
assert!(!provider.is_in_fallback_mode().await);
}
// Subsequent turns should use worker provider
for i in 3..6 {
let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
assert_eq!(usage.model, "worker");
assert_eq!(provider.get_turn_count().await, i + 1);
assert!(!provider.is_in_fallback_mode().await);
}
// Reset and verify it goes back to lead
provider.reset_turn_count().await;
assert_eq!(provider.get_turn_count().await, 0);
assert_eq!(provider.get_failure_count().await, 0);
assert!(!provider.is_in_fallback_mode().await);
let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
assert_eq!(usage.model, "lead");
}
#[tokio::test]
async fn test_technical_failure_retry() {
let lead_provider = Arc::new(MockFailureProvider {
name: "lead".to_string(),
model_config: ModelConfig::new("lead-model".to_string()),
should_fail: false, // Lead provider works
});
let worker_provider = Arc::new(MockFailureProvider {
name: "worker".to_string(),
model_config: ModelConfig::new("worker-model".to_string()),
should_fail: true, // Worker will fail
});
let provider = LeadWorkerProvider::new(lead_provider, worker_provider, Some(2));
// First two turns use lead (should succeed)
for _i in 0..2 {
let result = provider.complete("system", &[], &[]).await;
assert!(result.is_ok());
assert_eq!(result.unwrap().1.model, "lead");
assert!(!provider.is_in_fallback_mode().await);
}
// Next turn uses worker (will fail, but should retry with lead and succeed)
let result = provider.complete("system", &[], &[]).await;
assert!(result.is_ok()); // Should succeed because lead provider is used as fallback
assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider
assert_eq!(provider.get_failure_count().await, 0); // No failure tracking for technical failures
assert!(!provider.is_in_fallback_mode().await); // Not in fallback mode
// Another turn - should still try worker first, then retry with lead
let result = provider.complete("system", &[], &[]).await;
assert!(result.is_ok()); // Should succeed because lead provider is used as fallback
assert_eq!(result.unwrap().1.model, "lead"); // Should be lead provider
assert_eq!(provider.get_failure_count().await, 0); // Still no failure tracking
assert!(!provider.is_in_fallback_mode().await); // Still not in fallback mode
}
#[tokio::test]
async fn test_fallback_on_task_failures() {
    // Task failures (as opposed to technical failures) should also drive the
    // fallback machinery. Simulating task failures inside successful responses
    // would require a different mock, so here we exercise the fallback-mode
    // bookkeeping directly by poking the provider's internal state.
    let lead = Arc::new(MockFailureProvider {
        name: "lead".to_string(),
        model_config: ModelConfig::new("lead-model".to_string()),
        should_fail: false,
    });
    let worker = Arc::new(MockFailureProvider {
        name: "worker".to_string(),
        model_config: ModelConfig::new("worker-model".to_string()),
        should_fail: false,
    });
    let provider = LeadWorkerProvider::new(lead, worker, Some(2));

    // Force the provider into fallback mode with two fallback turns remaining,
    // positioned past the initial lead-only turns.
    *provider.in_fallback_mode.lock().await = true;
    *provider.fallback_remaining.lock().await = 2;
    *provider.turn_count.lock().await = 4;

    // First fallback turn: the lead model answers and fallback mode persists.
    let response = provider.complete("system", &[], &[]).await;
    assert!(response.is_ok());
    assert_eq!(response.unwrap().1.model, "lead");
    assert!(provider.is_in_fallback_mode().await);

    // Second (final) fallback turn: lead answers again, then fallback mode ends.
    let response = provider.complete("system", &[], &[]).await;
    assert!(response.is_ok());
    assert_eq!(response.unwrap().1.model, "lead");
    assert!(!provider.is_in_fallback_mode().await);
}
// Minimal mock provider for exercising LeadWorkerProvider failure handling:
// returns a canned text response, or an error when `should_fail` is set.
#[derive(Clone)]
struct MockFailureProvider {
    // Identifier embedded in the response text and reported as the usage model name.
    name: String,
    // Configuration handed back verbatim by `get_model_config`.
    model_config: ModelConfig,
    // When true, `complete` always returns ProviderError::ExecutionError.
    should_fail: bool,
}
#[async_trait]
impl Provider for MockFailureProvider {
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::empty()
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model_config.clone()
    }

    /// Produces a canned "Response from {name}" assistant message, or a
    /// simulated execution error when the mock is configured to fail.
    async fn complete(
        &self,
        _system: &str,
        _messages: &[Message],
        _tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        // Failure path: surface a technical error the caller must handle.
        if self.should_fail {
            return Err(ProviderError::ExecutionError(
                "Simulated failure".to_string(),
            ));
        }

        // Success path: a single text content item identifying this mock.
        let reply = Message {
            role: Role::Assistant,
            created: Utc::now().timestamp(),
            content: vec![MessageContent::Text(TextContent {
                text: format!("Response from {}", self.name),
                annotations: None,
            })],
        };
        Ok((reply, ProviderUsage::new(self.name.clone(), Usage::default())))
    }
}
}

View File

@@ -13,6 +13,7 @@ pub mod gcpvertexai;
pub mod githubcopilot;
pub mod google;
pub mod groq;
pub mod lead_worker;
pub mod oauth;
pub mod ollama;
pub mod openai;

View File

@@ -9,6 +9,7 @@ Goose supports various environment variables that allow you to customize its beh
## Model Configuration
These variables control the [language models](/docs/getting-started/providers) and their behavior.
### Basic Provider Configuration
These are the minimum required variables to get started with Goose.
@@ -27,6 +28,7 @@ export GOOSE_PROVIDER="anthropic"
export GOOSE_MODEL="claude-3.5-sonnet"
export GOOSE_TEMPERATURE=0.7
```
### Advanced Provider Configuration
These variables are needed when using custom endpoints, enterprise deployments, or specific provider implementations.
@@ -45,7 +47,34 @@ export GOOSE_PROVIDER__TYPE="anthropic"
export GOOSE_PROVIDER__HOST="https://api.anthropic.com"
export GOOSE_PROVIDER__API_KEY="your-api-key-here"
```
## Multiple Model Configuration
### Lead/Worker Model Configuration
Configure a lead/worker model pattern where a powerful model handles initial planning and complex reasoning, then switches to a faster/cheaper model for execution.
| Variable | Purpose | Values | Default |
|----------|---------|---------|---------|
| `GOOSE_LEAD_MODEL` | **Required to enable lead mode.** Specifies the lead model name | Model name (e.g., "gpt-4o", "claude-3.5-sonnet") | None |
| `GOOSE_LEAD_PROVIDER` | Provider for the lead model | [See available providers](/docs/getting-started/providers#available-providers) | Falls back to GOOSE_PROVIDER |
| `GOOSE_LEAD_TURNS` | Number of initial turns using the lead model | Integer | 3 |
| `GOOSE_LEAD_FAILURE_THRESHOLD` | Consecutive failures before fallback to lead model | Integer | 2 |
| `GOOSE_LEAD_FALLBACK_TURNS` | Number of turns to use lead model in fallback mode | Integer | 2 |
**Examples**
```bash
# Basic lead/worker setup
export GOOSE_LEAD_MODEL="o4"
# Advanced lead/worker configuration
export GOOSE_LEAD_MODEL="claude4-opus"
export GOOSE_LEAD_PROVIDER="anthropic"
export GOOSE_LEAD_TURNS=5
export GOOSE_LEAD_FAILURE_THRESHOLD=3
export GOOSE_LEAD_FALLBACK_TURNS=2
```
### Planning Mode Configuration
These variables control Goose's [planning functionality](/docs/guides/creating-plans).

31
test_lead_worker.sh Executable file
View File

@@ -0,0 +1,31 @@
#!/bin/bash
# Test script for lead/worker provider functionality.
# Walks through the supported GOOSE_LEAD_* configurations and prints a
# summary of how each environment-variable combination is interpreted.

# Set up test environment variables
export GOOSE_PROVIDER="openai"
export GOOSE_MODEL="gpt-4o-mini"
export OPENAI_API_KEY="test-key"

# Test 1: Default behavior (no lead/worker) — all lead/worker vars cleared,
# so GOOSE_MODEL is used for every turn.
echo "Test 1: Default behavior (no lead/worker)"
unset GOOSE_LEAD_MODEL
unset GOOSE_WORKER_MODEL
unset GOOSE_LEAD_TURNS

# Test 2: Lead/worker with same provider — explicit lead, worker, and turn count.
echo -e "\nTest 2: Lead/worker with same provider"
export GOOSE_LEAD_MODEL="gpt-4o"
export GOOSE_WORKER_MODEL="gpt-4o-mini"
export GOOSE_LEAD_TURNS="3"

# Test 3: Lead/worker with default worker (worker falls back to GOOSE_MODEL).
echo -e "\nTest 3: Lead/worker with default worker"
export GOOSE_LEAD_MODEL="gpt-4o"
unset GOOSE_WORKER_MODEL
export GOOSE_LEAD_TURNS="5"

echo -e "\nConfiguration examples:"
echo "- Default: Uses GOOSE_MODEL for all turns"
echo "- Lead/Worker: Set GOOSE_LEAD_MODEL to use a different model for initial turns"
# Default corrected from 5 to 3 to match the documented GOOSE_LEAD_TURNS default.
echo "- GOOSE_LEAD_TURNS: Number of turns to use lead model (default: 3)"
echo "- GOOSE_WORKER_MODEL: Model to use after lead turns (default: GOOSE_MODEL)"