mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
feat: lead/worker model (#2719)
This commit is contained in:
25
README.md
25
README.md
@@ -23,6 +23,31 @@ Whether you're prototyping an idea, refining existing code, or managing intricat
|
|||||||
|
|
||||||
Designed for maximum flexibility, goose works with any LLM, seamlessly integrates with MCP servers, and is available as both a desktop app as well as CLI - making it the ultimate AI assistant for developers who want to move faster and focus on innovation.
|
Designed for maximum flexibility, goose works with any LLM, seamlessly integrates with MCP servers, and is available as both a desktop app as well as CLI - making it the ultimate AI assistant for developers who want to move faster and focus on innovation.
|
||||||
|
|
||||||
|
## Multiple Model Configuration
|
||||||
|
|
||||||
|
goose supports using different models for different purposes to optimize performance and cost. The models can come from the same provider or from different providers.
|
||||||
|
|
||||||
|
### Lead/Worker Model Pattern
|
||||||
|
Use a powerful model for initial planning and complex reasoning, then switch to a faster/cheaper model for execution. goose performs this switch automatically:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Required: Enable lead model mode
|
||||||
|
export GOOSE_LEAD_MODEL=modelY
|
||||||
|
# Optional: configure a provider for the lead model if not the default provider
|
||||||
|
export GOOSE_LEAD_PROVIDER=providerX # Defaults to main provider
|
||||||
|
```
|
||||||
|
|
||||||
|
### Planning Model Configuration
|
||||||
|
Use a specialized model for the `/plan` command in CLI mode. This model is used only when you explicitly invoke `/plan` to plan (rather than execute).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Optional: Use different model for planning
|
||||||
|
export GOOSE_PLANNER_PROVIDER=openai
|
||||||
|
export GOOSE_PLANNER_MODEL=gpt-4
|
||||||
|
```
|
||||||
|
|
||||||
|
Both patterns help you balance model capabilities with cost and speed for optimal results, and switch between models and vendors as required.
|
||||||
|
|
||||||
|
|
||||||
# Quick Links
|
# Quick Links
|
||||||
- [Quickstart](https://block.github.io/goose/docs/quickstart)
|
- [Quickstart](https://block.github.io/goose/docs/quickstart)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ use goose::session;
|
|||||||
use goose::session::Identifier;
|
use goose::session::Identifier;
|
||||||
use mcp_client::transport::Error as McpClientError;
|
use mcp_client::transport::Error as McpClientError;
|
||||||
use std::process;
|
use std::process;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use super::output;
|
use super::output;
|
||||||
use super::Session;
|
use super::Session;
|
||||||
@@ -55,6 +56,22 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
|||||||
// Create the agent
|
// Create the agent
|
||||||
let agent: Agent = Agent::new();
|
let agent: Agent = Agent::new();
|
||||||
let new_provider = create(&provider_name, model_config).unwrap();
|
let new_provider = create(&provider_name, model_config).unwrap();
|
||||||
|
|
||||||
|
// Keep a reference to the provider for display_session_info
|
||||||
|
let provider_for_display = Arc::clone(&new_provider);
|
||||||
|
|
||||||
|
// Log model information at startup
|
||||||
|
if let Some(lead_worker) = new_provider.as_lead_worker() {
|
||||||
|
let (lead_model, worker_model) = lead_worker.get_model_info();
|
||||||
|
tracing::info!(
|
||||||
|
"🤖 Lead/Worker Mode Enabled: Lead model (first 3 turns): {}, Worker model (turn 4+): {}, Auto-fallback on failures: Enabled",
|
||||||
|
lead_model,
|
||||||
|
worker_model
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
tracing::info!("🤖 Using model: {}", model);
|
||||||
|
}
|
||||||
|
|
||||||
agent
|
agent
|
||||||
.update_provider(new_provider)
|
.update_provider(new_provider)
|
||||||
.await
|
.await
|
||||||
@@ -217,6 +234,12 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
|||||||
session.agent.override_system_prompt(override_prompt).await;
|
session.agent.override_system_prompt(override_prompt).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
output::display_session_info(session_config.resume, &provider_name, &model, &session_file);
|
output::display_session_info(
|
||||||
|
session_config.resume,
|
||||||
|
&provider_name,
|
||||||
|
&model,
|
||||||
|
&session_file,
|
||||||
|
Some(&provider_for_display),
|
||||||
|
);
|
||||||
session
|
session
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use std::cell::RefCell;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::io::Error;
|
use std::io::Error;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
// Re-export theme for use in main
|
// Re-export theme for use in main
|
||||||
@@ -536,7 +537,13 @@ fn shorten_path(path: &str, debug: bool) -> String {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Session display functions
|
// Session display functions
|
||||||
pub fn display_session_info(resume: bool, provider: &str, model: &str, session_file: &Path) {
|
pub fn display_session_info(
|
||||||
|
resume: bool,
|
||||||
|
provider: &str,
|
||||||
|
model: &str,
|
||||||
|
session_file: &Path,
|
||||||
|
provider_instance: Option<&Arc<dyn goose::providers::base::Provider>>,
|
||||||
|
) {
|
||||||
let start_session_msg = if resume {
|
let start_session_msg = if resume {
|
||||||
"resuming session |"
|
"resuming session |"
|
||||||
} else if session_file.to_str() == Some("/dev/null") || session_file.to_str() == Some("NUL") {
|
} else if session_file.to_str() == Some("/dev/null") || session_file.to_str() == Some("NUL") {
|
||||||
@@ -544,6 +551,22 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
|
|||||||
} else {
|
} else {
|
||||||
"starting session |"
|
"starting session |"
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Check if we have lead/worker mode
|
||||||
|
if let Some(provider_inst) = provider_instance {
|
||||||
|
if let Some(lead_worker) = provider_inst.as_lead_worker() {
|
||||||
|
let (lead_model, worker_model) = lead_worker.get_model_info();
|
||||||
|
println!(
|
||||||
|
"{} {} {} {} {} {} {}",
|
||||||
|
style(start_session_msg).dim(),
|
||||||
|
style("provider:").dim(),
|
||||||
|
style(provider).cyan().dim(),
|
||||||
|
style("lead model:").dim(),
|
||||||
|
style(&lead_model).cyan().dim(),
|
||||||
|
style("worker model:").dim(),
|
||||||
|
style(&worker_model).cyan().dim(),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
println!(
|
println!(
|
||||||
"{} {} {} {} {}",
|
"{} {} {} {} {}",
|
||||||
style(start_session_msg).dim(),
|
style(start_session_msg).dim(),
|
||||||
@@ -552,6 +575,18 @@ pub fn display_session_info(resume: bool, provider: &str, model: &str, session_f
|
|||||||
style("model:").dim(),
|
style("model:").dim(),
|
||||||
style(model).cyan().dim(),
|
style(model).cyan().dim(),
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Fallback to original behavior if no provider instance
|
||||||
|
println!(
|
||||||
|
"{} {} {} {} {}",
|
||||||
|
style(start_session_msg).dim(),
|
||||||
|
style("provider:").dim(),
|
||||||
|
style(provider).cyan().dim(),
|
||||||
|
style("model:").dim(),
|
||||||
|
style(model).cyan().dim(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if session_file.to_str() != Some("/dev/null") && session_file.to_str() != Some("NUL") {
|
if session_file.to_str() != Some("/dev/null") && session_file.to_str() != Some("NUL") {
|
||||||
println!(
|
println!(
|
||||||
|
|||||||
@@ -148,6 +148,12 @@ impl Usage {
|
|||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
/// Extra capabilities exposed by a lead/worker style provider.
///
/// `Provider::as_lead_worker` returns this trait as a `&dyn` object, so every
/// method here has to stay object-safe.
pub trait LeadWorkerProviderTrait {
    /// Returns the model names as a `(lead, worker)` pair, intended for
    /// logging and display.
    fn get_model_info(&self) -> (String, String);
}
|
||||||
|
|
||||||
/// Base trait for AI providers (OpenAI, Anthropic, etc)
|
/// Base trait for AI providers (OpenAI, Anthropic, etc)
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Provider: Send + Sync {
|
pub trait Provider: Send + Sync {
|
||||||
@@ -195,6 +201,12 @@ pub trait Provider: Send + Sync {
|
|||||||
"This provider does not support embeddings".to_string(),
|
"This provider does not support embeddings".to_string(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if this provider is a LeadWorkerProvider
|
||||||
|
/// This is used for logging model information at startup
|
||||||
|
fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use super::{
|
|||||||
githubcopilot::GithubCopilotProvider,
|
githubcopilot::GithubCopilotProvider,
|
||||||
google::GoogleProvider,
|
google::GoogleProvider,
|
||||||
groq::GroqProvider,
|
groq::GroqProvider,
|
||||||
|
lead_worker::LeadWorkerProvider,
|
||||||
ollama::OllamaProvider,
|
ollama::OllamaProvider,
|
||||||
openai::OpenAiProvider,
|
openai::OpenAiProvider,
|
||||||
openrouter::OpenRouterProvider,
|
openrouter::OpenRouterProvider,
|
||||||
@@ -19,6 +20,21 @@ use super::{
|
|||||||
use crate::model::ModelConfig;
|
use crate::model::ModelConfig;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
use super::errors::ProviderError;
|
||||||
|
#[cfg(test)]
|
||||||
|
use mcp_core::tool::Tool;
|
||||||
|
|
||||||
|
/// Number of initial turns given to the lead model when `GOOSE_LEAD_TURNS`
/// cannot be read.
fn default_lead_turns() -> usize {
    const LEAD_TURNS: usize = 3;
    LEAD_TURNS
}
|
||||||
|
/// Number of consecutive task failures that triggers fallback when
/// `GOOSE_LEAD_FAILURE_THRESHOLD` cannot be read.
fn default_failure_threshold() -> usize {
    const FAILURE_THRESHOLD: usize = 2;
    FAILURE_THRESHOLD
}
|
||||||
|
/// Number of turns spent on the lead model during fallback when
/// `GOOSE_LEAD_FALLBACK_TURNS` cannot be read.
fn default_fallback_turns() -> usize {
    const FALLBACK_TURNS: usize = 2;
    FALLBACK_TURNS
}
|
||||||
|
|
||||||
pub fn providers() -> Vec<ProviderMetadata> {
|
pub fn providers() -> Vec<ProviderMetadata> {
|
||||||
vec![
|
vec![
|
||||||
AnthropicProvider::metadata(),
|
AnthropicProvider::metadata(),
|
||||||
@@ -38,6 +54,62 @@ pub fn providers() -> Vec<ProviderMetadata> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
|
pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
|
||||||
|
let config = crate::config::Config::global();
|
||||||
|
|
||||||
|
// Check for lead model environment variables
|
||||||
|
if let Ok(lead_model_name) = config.get_param::<String>("GOOSE_LEAD_MODEL") {
|
||||||
|
tracing::info!("Creating lead/worker provider from environment variables");
|
||||||
|
|
||||||
|
return create_lead_worker_from_env(name, &model, &lead_model_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default: create regular provider
|
||||||
|
create_provider(name, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a lead/worker provider from environment variables
|
||||||
|
fn create_lead_worker_from_env(
|
||||||
|
default_provider_name: &str,
|
||||||
|
default_model: &ModelConfig,
|
||||||
|
lead_model_name: &str,
|
||||||
|
) -> Result<Arc<dyn Provider>> {
|
||||||
|
let config = crate::config::Config::global();
|
||||||
|
|
||||||
|
// Get lead provider (optional, defaults to main provider)
|
||||||
|
let lead_provider_name = config
|
||||||
|
.get_param::<String>("GOOSE_LEAD_PROVIDER")
|
||||||
|
.unwrap_or_else(|_| default_provider_name.to_string());
|
||||||
|
|
||||||
|
// Get configuration parameters with defaults
|
||||||
|
let lead_turns = config
|
||||||
|
.get_param::<usize>("GOOSE_LEAD_TURNS")
|
||||||
|
.unwrap_or(default_lead_turns());
|
||||||
|
let failure_threshold = config
|
||||||
|
.get_param::<usize>("GOOSE_LEAD_FAILURE_THRESHOLD")
|
||||||
|
.unwrap_or(default_failure_threshold());
|
||||||
|
let fallback_turns = config
|
||||||
|
.get_param::<usize>("GOOSE_LEAD_FALLBACK_TURNS")
|
||||||
|
.unwrap_or(default_fallback_turns());
|
||||||
|
|
||||||
|
// Create model configs
|
||||||
|
let lead_model_config = ModelConfig::new(lead_model_name.to_string());
|
||||||
|
let worker_model_config = default_model.clone();
|
||||||
|
|
||||||
|
// Create the providers
|
||||||
|
let lead_provider = create_provider(&lead_provider_name, lead_model_config)?;
|
||||||
|
let worker_provider = create_provider(default_provider_name, worker_model_config)?;
|
||||||
|
|
||||||
|
// Create the lead/worker provider with configured settings
|
||||||
|
Ok(Arc::new(LeadWorkerProvider::new_with_settings(
|
||||||
|
lead_provider,
|
||||||
|
worker_provider,
|
||||||
|
lead_turns,
|
||||||
|
failure_threshold,
|
||||||
|
fallback_turns,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_provider(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
|
||||||
// We use Arc instead of Box to be able to clone for multiple async tasks
|
// We use Arc instead of Box to be able to clone for multiple async tasks
|
||||||
match name {
|
match name {
|
||||||
"openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)),
|
"openai" => Ok(Arc::new(OpenAiProvider::from_env(model)?)),
|
||||||
@@ -56,3 +128,215 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
|
|||||||
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
|
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{Message, MessageContent};
    use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
    use chrono::Utc;
    use mcp_core::{content::TextContent, Role};
    use std::env;

    /// Every `GOOSE_LEAD_*` variable these tests read or write.
    const LEAD_ENV_VARS: [&str; 5] = [
        "GOOSE_LEAD_MODEL",
        "GOOSE_LEAD_PROVIDER",
        "GOOSE_LEAD_TURNS",
        "GOOSE_LEAD_FAILURE_THRESHOLD",
        "GOOSE_LEAD_FALLBACK_TURNS",
    ];

    /// Serializes the tests below: the environment is process-global and the
    /// test harness runs tests on parallel threads.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Holds `ENV_LOCK` and puts the `GOOSE_LEAD_*` variables back to their
    /// prior values when dropped, including when the test panics.
    struct LeadEnvGuard {
        saved: Vec<(&'static str, Option<String>)>,
        // Declared after `saved`; released only once `drop` has restored the
        // environment.
        _serial: std::sync::MutexGuard<'static, ()>,
    }

    impl LeadEnvGuard {
        /// Take the lock and snapshot the current values without changing them.
        fn acquire() -> Self {
            // A poisoned lock only means another test failed; the guard in
            // that test already restored the environment, so keep going.
            let serial = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = LEAD_ENV_VARS
                .iter()
                .map(|&key| (key, env::var(key).ok()))
                .collect();
            Self {
                saved,
                _serial: serial,
            }
        }

        /// Remove every `GOOSE_LEAD_*` variable for the duration of the test.
        fn clear_all(&self) {
            for key in LEAD_ENV_VARS.iter() {
                env::remove_var(key);
            }
        }
    }

    impl Drop for LeadEnvGuard {
        fn drop(&mut self) {
            for (key, value) in self.saved.drain(..) {
                match value {
                    Some(val) => env::set_var(key, val),
                    None => env::remove_var(key),
                }
            }
        }
    }

    /// Creation may succeed (API keys present) or fail; a failure must be
    /// about missing credentials, which shows provider creation was attempted.
    fn assert_created_or_missing_credentials(result: Result<Arc<dyn Provider>>) {
        if let Err(error) = result {
            let error_msg = error.to_string();
            assert!(error_msg.contains("OPENAI_API_KEY") || error_msg.contains("secret"));
        }
    }

    #[derive(Clone)]
    struct MockTestProvider {
        name: String,
        model_config: ModelConfig,
    }

    #[async_trait::async_trait]
    impl Provider for MockTestProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::new(
                "mock_test",
                "Mock Test Provider",
                "A mock provider for testing",
                "mock-model",
                vec!["mock-model"],
                "",
                vec![],
            )
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        async fn complete(
            &self,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            Ok((
                Message {
                    role: Role::Assistant,
                    created: Utc::now().timestamp(),
                    content: vec![MessageContent::Text(TextContent {
                        text: format!(
                            "Response from {} with model {}",
                            self.name, self.model_config.model_name
                        ),
                        annotations: None,
                    })],
                },
                ProviderUsage::new(self.model_config.model_name.clone(), Usage::default()),
            ))
        }
    }

    #[test]
    fn test_create_lead_worker_provider() {
        // Snapshot + serialize; restored automatically at end of scope.
        let _env = LeadEnvGuard::acquire();

        // Test with basic lead model configuration
        env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");

        // This will try to create a lead/worker provider
        let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
        assert_created_or_missing_credentials(result);

        // Test with different lead provider
        env::set_var("GOOSE_LEAD_PROVIDER", "anthropic");
        env::set_var("GOOSE_LEAD_TURNS", "5");

        // Outcome depends on available API keys; this only exercises the path.
        let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
    }

    #[test]
    fn test_lead_model_env_vars_with_defaults() {
        let env_guard = LeadEnvGuard::acquire();
        env_guard.clear_all();

        // Set only the required lead model
        env::set_var("GOOSE_LEAD_MODEL", "gpt-4o");

        // This should use defaults for all other values
        let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
        assert_created_or_missing_credentials(result);

        // Test with custom values
        env::set_var("GOOSE_LEAD_TURNS", "7");
        env::set_var("GOOSE_LEAD_FAILURE_THRESHOLD", "4");
        env::set_var("GOOSE_LEAD_FALLBACK_TURNS", "3");

        // Should still attempt to create lead/worker provider with custom settings
        let _result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
    }

    #[test]
    fn test_create_regular_provider_without_lead_config() {
        let env_guard = LeadEnvGuard::acquire();

        // Ensure all GOOSE_LEAD_* variables are not set
        env_guard.clear_all();

        // This should try to create a regular provider
        let result = create("openai", ModelConfig::new("gpt-4o-mini".to_string()));
        assert_created_or_missing_credentials(result);
    }
}
|
||||||
|
|||||||
637
crates/goose/src/providers/lead_worker.rs
Normal file
637
crates/goose/src/providers/lead_worker.rs
Normal file
@@ -0,0 +1,637 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderUsage};
|
||||||
|
use super::errors::ProviderError;
|
||||||
|
use crate::message::{Message, MessageContent};
|
||||||
|
use crate::model::ModelConfig;
|
||||||
|
use mcp_core::{tool::Tool, Content};
|
||||||
|
|
||||||
|
/// A provider that switches between a lead model and a worker model based on turn count
|
||||||
|
/// and can fallback to lead model on consecutive failures
|
||||||
|
pub struct LeadWorkerProvider {
|
||||||
|
lead_provider: Arc<dyn Provider>,
|
||||||
|
worker_provider: Arc<dyn Provider>,
|
||||||
|
lead_turns: usize,
|
||||||
|
turn_count: Arc<Mutex<usize>>,
|
||||||
|
failure_count: Arc<Mutex<usize>>,
|
||||||
|
max_failures_before_fallback: usize,
|
||||||
|
fallback_turns: usize,
|
||||||
|
in_fallback_mode: Arc<Mutex<bool>>,
|
||||||
|
fallback_remaining: Arc<Mutex<usize>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LeadWorkerProvider {
|
||||||
|
/// Create a new LeadWorkerProvider
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `lead_provider` - The provider to use for the initial turns
|
||||||
|
/// * `worker_provider` - The provider to use after lead_turns
|
||||||
|
/// * `lead_turns` - Number of turns to use the lead provider (default: 3)
|
||||||
|
pub fn new(
|
||||||
|
lead_provider: Arc<dyn Provider>,
|
||||||
|
worker_provider: Arc<dyn Provider>,
|
||||||
|
lead_turns: Option<usize>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
lead_provider,
|
||||||
|
worker_provider,
|
||||||
|
lead_turns: lead_turns.unwrap_or(3),
|
||||||
|
turn_count: Arc::new(Mutex::new(0)),
|
||||||
|
failure_count: Arc::new(Mutex::new(0)),
|
||||||
|
max_failures_before_fallback: 2, // Fallback after 2 consecutive failures
|
||||||
|
fallback_turns: 2, // Use lead model for 2 turns when in fallback mode
|
||||||
|
in_fallback_mode: Arc::new(Mutex::new(false)),
|
||||||
|
fallback_remaining: Arc::new(Mutex::new(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new LeadWorkerProvider with custom settings
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `lead_provider` - The provider to use for the initial turns
|
||||||
|
/// * `worker_provider` - The provider to use after lead_turns
|
||||||
|
/// * `lead_turns` - Number of turns to use the lead provider
|
||||||
|
/// * `failure_threshold` - Number of consecutive failures before fallback
|
||||||
|
/// * `fallback_turns` - Number of turns to use lead model in fallback mode
|
||||||
|
pub fn new_with_settings(
|
||||||
|
lead_provider: Arc<dyn Provider>,
|
||||||
|
worker_provider: Arc<dyn Provider>,
|
||||||
|
lead_turns: usize,
|
||||||
|
failure_threshold: usize,
|
||||||
|
fallback_turns: usize,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
lead_provider,
|
||||||
|
worker_provider,
|
||||||
|
lead_turns,
|
||||||
|
turn_count: Arc::new(Mutex::new(0)),
|
||||||
|
failure_count: Arc::new(Mutex::new(0)),
|
||||||
|
max_failures_before_fallback: failure_threshold,
|
||||||
|
fallback_turns,
|
||||||
|
in_fallback_mode: Arc::new(Mutex::new(false)),
|
||||||
|
fallback_remaining: Arc::new(Mutex::new(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reset the turn counter and failure tracking (useful for new conversations)
|
||||||
|
pub async fn reset_turn_count(&self) {
|
||||||
|
let mut count = self.turn_count.lock().await;
|
||||||
|
*count = 0;
|
||||||
|
let mut failures = self.failure_count.lock().await;
|
||||||
|
*failures = 0;
|
||||||
|
let mut fallback = self.in_fallback_mode.lock().await;
|
||||||
|
*fallback = false;
|
||||||
|
let mut remaining = self.fallback_remaining.lock().await;
|
||||||
|
*remaining = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current turn count
|
||||||
|
pub async fn get_turn_count(&self) -> usize {
|
||||||
|
*self.turn_count.lock().await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current failure count
|
||||||
|
pub async fn get_failure_count(&self) -> usize {
|
||||||
|
*self.failure_count.lock().await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if currently in fallback mode
|
||||||
|
pub async fn is_in_fallback_mode(&self) -> bool {
|
||||||
|
*self.in_fallback_mode.lock().await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the currently active provider based on turn count and fallback state
|
||||||
|
async fn get_active_provider(&self) -> Arc<dyn Provider> {
|
||||||
|
let count = *self.turn_count.lock().await;
|
||||||
|
let in_fallback = *self.in_fallback_mode.lock().await;
|
||||||
|
|
||||||
|
// Use lead provider if we're in initial turns OR in fallback mode
|
||||||
|
if count < self.lead_turns || in_fallback {
|
||||||
|
Arc::clone(&self.lead_provider)
|
||||||
|
} else {
|
||||||
|
Arc::clone(&self.worker_provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle the result of a completion attempt and update failure tracking
|
||||||
|
async fn handle_completion_result(
|
||||||
|
&self,
|
||||||
|
result: &Result<(Message, ProviderUsage), ProviderError>,
|
||||||
|
) {
|
||||||
|
match result {
|
||||||
|
Ok((message, _usage)) => {
|
||||||
|
// Check for task-level failures in the response
|
||||||
|
let has_task_failure = self.detect_task_failures(message).await;
|
||||||
|
|
||||||
|
if has_task_failure {
|
||||||
|
// Task failure detected - increment failure count
|
||||||
|
let mut failures = self.failure_count.lock().await;
|
||||||
|
*failures += 1;
|
||||||
|
|
||||||
|
let failure_count = *failures;
|
||||||
|
let turn_count = *self.turn_count.lock().await;
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
"Task failure detected in response (failure count: {})",
|
||||||
|
failure_count
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check if we should trigger fallback
|
||||||
|
if turn_count >= self.lead_turns
|
||||||
|
&& !*self.in_fallback_mode.lock().await
|
||||||
|
&& failure_count >= self.max_failures_before_fallback
|
||||||
|
{
|
||||||
|
let mut in_fallback = self.in_fallback_mode.lock().await;
|
||||||
|
let mut fallback_remaining = self.fallback_remaining.lock().await;
|
||||||
|
|
||||||
|
*in_fallback = true;
|
||||||
|
*fallback_remaining = self.fallback_turns;
|
||||||
|
*failures = 0; // Reset failure count when entering fallback
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
"🔄 SWITCHING TO LEAD MODEL: Entering fallback mode after {} consecutive task failures - using lead model for {} turns",
|
||||||
|
self.max_failures_before_fallback,
|
||||||
|
self.fallback_turns
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Success - reset failure count and handle fallback mode
|
||||||
|
let mut failures = self.failure_count.lock().await;
|
||||||
|
*failures = 0;
|
||||||
|
|
||||||
|
let mut in_fallback = self.in_fallback_mode.lock().await;
|
||||||
|
let mut fallback_remaining = self.fallback_remaining.lock().await;
|
||||||
|
|
||||||
|
if *in_fallback {
|
||||||
|
*fallback_remaining -= 1;
|
||||||
|
if *fallback_remaining == 0 {
|
||||||
|
*in_fallback = false;
|
||||||
|
tracing::info!("✅ SWITCHING BACK TO WORKER MODEL: Exiting fallback mode - worker model resumed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment turn count on any completion (success or task failure)
|
||||||
|
let mut count = self.turn_count.lock().await;
|
||||||
|
*count += 1;
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// Technical failure - just log and let it bubble up
|
||||||
|
// For technical failures (API/LLM issues), we don't want to second-guess
|
||||||
|
// the model choice - just let the default model handle it
|
||||||
|
tracing::warn!(
|
||||||
|
"Technical failure detected - API/LLM issue, will use default model"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Don't increment turn count or failure tracking for technical failures
|
||||||
|
// as these are temporary infrastructure issues, not model capability issues
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detect task-level failures in the model's response
|
||||||
|
async fn detect_task_failures(&self, message: &Message) -> bool {
|
||||||
|
let mut failure_indicators = 0;
|
||||||
|
|
||||||
|
for content in &message.content {
|
||||||
|
match content {
|
||||||
|
MessageContent::ToolRequest(tool_request) => {
|
||||||
|
// Check if tool request itself failed (malformed, etc.)
|
||||||
|
if tool_request.tool_call.is_err() {
|
||||||
|
failure_indicators += 1;
|
||||||
|
tracing::debug!(
|
||||||
|
"Failed tool request detected: {:?}",
|
||||||
|
tool_request.tool_call
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MessageContent::ToolResponse(tool_response) => {
|
||||||
|
// Check if tool execution failed
|
||||||
|
if let Err(tool_error) = &tool_response.tool_result {
|
||||||
|
failure_indicators += 1;
|
||||||
|
tracing::debug!("Tool execution failure detected: {:?}", tool_error);
|
||||||
|
} else if let Ok(contents) = &tool_response.tool_result {
|
||||||
|
// Check tool output for error indicators
|
||||||
|
if self.contains_error_indicators(contents) {
|
||||||
|
failure_indicators += 1;
|
||||||
|
tracing::debug!("Tool output contains error indicators");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MessageContent::Text(text_content) => {
|
||||||
|
// Check for user correction patterns or error acknowledgments
|
||||||
|
if self.contains_user_correction_patterns(&text_content.text) {
|
||||||
|
failure_indicators += 1;
|
||||||
|
tracing::debug!("User correction pattern detected in text");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consider it a failure if we have multiple failure indicators
|
||||||
|
failure_indicators >= 1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if tool output contains error indicators
|
||||||
|
fn contains_error_indicators(&self, contents: &[Content]) -> bool {
|
||||||
|
for content in contents {
|
||||||
|
if let Content::Text(text_content) = content {
|
||||||
|
let text_lower = text_content.text.to_lowercase();
|
||||||
|
|
||||||
|
// Common error patterns in tool outputs
|
||||||
|
if text_lower.contains("error:")
|
||||||
|
|| text_lower.contains("failed:")
|
||||||
|
|| text_lower.contains("exception:")
|
||||||
|
|| text_lower.contains("traceback")
|
||||||
|
|| text_lower.contains("syntax error")
|
||||||
|
|| text_lower.contains("permission denied")
|
||||||
|
|| text_lower.contains("file not found")
|
||||||
|
|| text_lower.contains("command not found")
|
||||||
|
|| text_lower.contains("compilation failed")
|
||||||
|
|| text_lower.contains("test failed")
|
||||||
|
|| text_lower.contains("assertion failed")
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check for user correction patterns in text
|
||||||
|
fn contains_user_correction_patterns(&self, text: &str) -> bool {
|
||||||
|
let text_lower = text.to_lowercase();
|
||||||
|
|
||||||
|
// Patterns indicating user is correcting or expressing dissatisfaction
|
||||||
|
text_lower.contains("that's wrong")
|
||||||
|
|| text_lower.contains("that's not right")
|
||||||
|
|| text_lower.contains("that doesn't work")
|
||||||
|
|| text_lower.contains("try again")
|
||||||
|
|| text_lower.contains("let me correct")
|
||||||
|
|| text_lower.contains("actually, ")
|
||||||
|
|| text_lower.contains("no, that's")
|
||||||
|
|| text_lower.contains("that's incorrect")
|
||||||
|
|| text_lower.contains("fix this")
|
||||||
|
|| text_lower.contains("this is broken")
|
||||||
|
|| text_lower.contains("this doesn't")
|
||||||
|
|| text_lower.starts_with("no,")
|
||||||
|
|| text_lower.starts_with("wrong")
|
||||||
|
|| text_lower.starts_with("incorrect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LeadWorkerProviderTrait for LeadWorkerProvider {
    /// Return the configured model names as `(lead, worker)`, intended for logging.
    ///
    /// Each name is read from the respective wrapped provider's
    /// `get_model_config().model_name`.
    fn get_model_info(&self) -> (String, String) {
        (
            self.lead_provider.get_model_config().model_name,
            self.worker_provider.get_model_config().model_name,
        )
    }
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for LeadWorkerProvider {
|
||||||
|
fn metadata() -> ProviderMetadata {
|
||||||
|
// This is a wrapper provider, so we return minimal metadata
|
||||||
|
ProviderMetadata::new(
|
||||||
|
"lead_worker",
|
||||||
|
"Lead/Worker Provider",
|
||||||
|
"A provider that switches between lead and worker models based on turn count",
|
||||||
|
"", // No default model as this is determined by the wrapped providers
|
||||||
|
vec![], // No known models as this depends on wrapped providers
|
||||||
|
"", // No doc link
|
||||||
|
vec![], // No config keys as configuration is done through wrapped providers
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_model_config(&self) -> ModelConfig {
|
||||||
|
// Return the lead provider's model config as the default
|
||||||
|
// In practice, this might need to be more sophisticated
|
||||||
|
self.lead_provider.get_model_config()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn complete(
|
||||||
|
&self,
|
||||||
|
system: &str,
|
||||||
|
messages: &[Message],
|
||||||
|
tools: &[Tool],
|
||||||
|
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||||
|
// Get the active provider
|
||||||
|
let provider = self.get_active_provider().await;
|
||||||
|
|
||||||
|
// Log which provider is being used
|
||||||
|
let turn_count = *self.turn_count.lock().await;
|
||||||
|
let in_fallback = *self.in_fallback_mode.lock().await;
|
||||||
|
let fallback_remaining = *self.fallback_remaining.lock().await;
|
||||||
|
|
||||||
|
let provider_type = if turn_count < self.lead_turns {
|
||||||
|
"lead (initial)"
|
||||||
|
} else if in_fallback {
|
||||||
|
"lead (fallback)"
|
||||||
|
} else {
|
||||||
|
"worker"
|
||||||
|
};
|
||||||
|
|
||||||
|
if in_fallback {
|
||||||
|
tracing::info!(
|
||||||
|
"🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining)",
|
||||||
|
provider_type,
|
||||||
|
turn_count + 1,
|
||||||
|
fallback_remaining
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
tracing::info!(
|
||||||
|
"Using {} provider for turn {} (lead_turns: {})",
|
||||||
|
provider_type,
|
||||||
|
turn_count + 1,
|
||||||
|
self.lead_turns
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make the completion request
|
||||||
|
let result = provider.complete(system, messages, tools).await;
|
||||||
|
|
||||||
|
// For technical failures, try with default model (lead provider) instead
|
||||||
|
let final_result = match &result {
|
||||||
|
Err(_) => {
|
||||||
|
tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type);
|
||||||
|
|
||||||
|
// Try with lead provider as the default/fallback for technical failures
|
||||||
|
let default_result = self.lead_provider.complete(system, messages, tools).await;
|
||||||
|
|
||||||
|
match &default_result {
|
||||||
|
Ok(_) => {
|
||||||
|
tracing::info!(
|
||||||
|
"✅ Default model (lead provider) succeeded after technical failure"
|
||||||
|
);
|
||||||
|
default_result
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
tracing::error!("❌ Default model (lead provider) also failed - returning original error");
|
||||||
|
result // Return the original error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(_) => result, // Success with original provider
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle the result and update tracking (only for successful completions)
|
||||||
|
self.handle_completion_result(&final_result).await;
|
||||||
|
|
||||||
|
final_result
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_supported_models_async(&self) -> Result<Option<Vec<String>>, ProviderError> {
|
||||||
|
// Combine models from both providers
|
||||||
|
let lead_models = self.lead_provider.fetch_supported_models_async().await?;
|
||||||
|
let worker_models = self.worker_provider.fetch_supported_models_async().await?;
|
||||||
|
|
||||||
|
match (lead_models, worker_models) {
|
||||||
|
(Some(lead), Some(worker)) => {
|
||||||
|
let mut all_models = lead;
|
||||||
|
all_models.extend(worker);
|
||||||
|
all_models.sort();
|
||||||
|
all_models.dedup();
|
||||||
|
Ok(Some(all_models))
|
||||||
|
}
|
||||||
|
(Some(models), None) | (None, Some(models)) => Ok(Some(models)),
|
||||||
|
(None, None) => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_embeddings(&self) -> bool {
|
||||||
|
// Support embeddings if either provider supports them
|
||||||
|
self.lead_provider.supports_embeddings() || self.worker_provider.supports_embeddings()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
|
||||||
|
// Use the lead provider for embeddings if it supports them, otherwise use worker
|
||||||
|
if self.lead_provider.supports_embeddings() {
|
||||||
|
self.lead_provider.create_embeddings(texts).await
|
||||||
|
} else if self.worker_provider.supports_embeddings() {
|
||||||
|
self.worker_provider.create_embeddings(texts).await
|
||||||
|
} else {
|
||||||
|
Err(ProviderError::ExecutionError(
|
||||||
|
"Neither lead nor worker provider supports embeddings".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this provider is a LeadWorkerProvider
|
||||||
|
fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
|
||||||
|
Some(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::MessageContent;
    use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
    use chrono::Utc;
    use mcp_core::{content::TextContent, Role};

    /// Successful reply used by both mocks: an assistant text message
    /// "Response from <name>" plus a usage record whose model is `name`.
    fn canned_reply(name: &str) -> (Message, ProviderUsage) {
        let message = Message {
            role: Role::Assistant,
            created: Utc::now().timestamp(),
            content: vec![MessageContent::Text(TextContent {
                text: format!("Response from {}", name),
                annotations: None,
            })],
        };
        (
            message,
            ProviderUsage::new(name.to_string(), Usage::default()),
        )
    }

    /// Mock provider whose `complete` always succeeds.
    #[derive(Clone)]
    struct MockProvider {
        name: String,
        model_config: ModelConfig,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::empty()
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        async fn complete(
            &self,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            Ok(canned_reply(&self.name))
        }
    }

    /// Mock provider whose `complete` fails whenever `should_fail` is set.
    #[derive(Clone)]
    struct MockFailureProvider {
        name: String,
        model_config: ModelConfig,
        should_fail: bool,
    }

    #[async_trait]
    impl Provider for MockFailureProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::empty()
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        async fn complete(
            &self,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            if self.should_fail {
                return Err(ProviderError::ExecutionError(
                    "Simulated failure".to_string(),
                ));
            }
            Ok(canned_reply(&self.name))
        }
    }

    /// Always-succeeding mock named `name` with model "<name>-model".
    fn reliable(name: &str) -> Arc<MockProvider> {
        Arc::new(MockProvider {
            name: name.to_string(),
            model_config: ModelConfig::new(format!("{}-model", name)),
        })
    }

    /// Mock named `name` with model "<name>-model" that fails on demand.
    fn failing_if(name: &str, should_fail: bool) -> Arc<MockFailureProvider> {
        Arc::new(MockFailureProvider {
            name: name.to_string(),
            model_config: ModelConfig::new(format!("{}-model", name)),
            should_fail,
        })
    }

    #[tokio::test]
    async fn test_lead_worker_switching() {
        let provider = LeadWorkerProvider::new(reliable("lead"), reliable("worker"), Some(3));

        // Turns 1..=3 are served by the lead provider.
        for expected_turn in 1..=3 {
            let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
            assert_eq!(usage.model, "lead");
            assert_eq!(provider.get_turn_count().await, expected_turn);
            assert!(!provider.is_in_fallback_mode().await);
        }

        // Turns 4..=6 are served by the worker provider.
        for expected_turn in 4..=6 {
            let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
            assert_eq!(usage.model, "worker");
            assert_eq!(provider.get_turn_count().await, expected_turn);
            assert!(!provider.is_in_fallback_mode().await);
        }

        // After a reset all counters are cleared and the lead is used again.
        provider.reset_turn_count().await;
        assert_eq!(provider.get_turn_count().await, 0);
        assert_eq!(provider.get_failure_count().await, 0);
        assert!(!provider.is_in_fallback_mode().await);

        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "lead");
    }

    #[tokio::test]
    async fn test_technical_failure_retry() {
        // Lead works, worker always errors.
        let provider = LeadWorkerProvider::new(
            failing_if("lead", false),
            failing_if("worker", true),
            Some(2),
        );

        // The two initial lead turns succeed directly.
        for _ in 0..2 {
            let result = provider.complete("system", &[], &[]).await;
            assert!(result.is_ok());
            assert_eq!(result.unwrap().1.model, "lead");
            assert!(!provider.is_in_fallback_mode().await);
        }

        // Worker turn: the worker errors, the retry on the lead succeeds.
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().1.model, "lead");
        // Technical failures are neither counted nor do they trigger fallback mode.
        assert_eq!(provider.get_failure_count().await, 0);
        assert!(!provider.is_in_fallback_mode().await);

        // Same again: worker is tried first, lead answers on retry.
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().1.model, "lead");
        assert_eq!(provider.get_failure_count().await, 0);
        assert!(!provider.is_in_fallback_mode().await);
    }

    #[tokio::test]
    async fn test_fallback_on_task_failures() {
        // Task failures (as opposed to technical failures) would need a mock
        // that returns failing content inside a successful response. Until
        // such a mock exists, fallback mode is exercised by setting the
        // provider's state directly.
        let provider = LeadWorkerProvider::new(
            failing_if("lead", false),
            failing_if("worker", false),
            Some(2),
        );

        // Put the provider into fallback mode with two turns left, past the
        // initial lead turns.
        *provider.in_fallback_mode.lock().await = true;
        *provider.fallback_remaining.lock().await = 2;
        *provider.turn_count.lock().await = 4;

        // First fallback turn: lead answers, still in fallback mode.
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().1.model, "lead");
        assert!(provider.is_in_fallback_mode().await);

        // Second fallback turn: lead answers, fallback mode ends.
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().1.model, "lead");
        assert!(!provider.is_in_fallback_mode().await);
    }
}
|
||||||
@@ -13,6 +13,7 @@ pub mod gcpvertexai;
|
|||||||
pub mod githubcopilot;
|
pub mod githubcopilot;
|
||||||
pub mod google;
|
pub mod google;
|
||||||
pub mod groq;
|
pub mod groq;
|
||||||
|
pub mod lead_worker;
|
||||||
pub mod oauth;
|
pub mod oauth;
|
||||||
pub mod ollama;
|
pub mod ollama;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ Goose supports various environment variables that allow you to customize its beh
|
|||||||
## Model Configuration
|
## Model Configuration
|
||||||
|
|
||||||
These variables control the [language models](/docs/getting-started/providers) and their behavior.
|
These variables control the [language models](/docs/getting-started/providers) and their behavior.
|
||||||
|
|
||||||
### Basic Provider Configuration
|
### Basic Provider Configuration
|
||||||
|
|
||||||
These are the minimum required variables to get started with Goose.
|
These are the minimum required variables to get started with Goose.
|
||||||
@@ -27,6 +28,7 @@ export GOOSE_PROVIDER="anthropic"
|
|||||||
export GOOSE_MODEL="claude-3.5-sonnet"
|
export GOOSE_MODEL="claude-3.5-sonnet"
|
||||||
export GOOSE_TEMPERATURE=0.7
|
export GOOSE_TEMPERATURE=0.7
|
||||||
```
|
```
|
||||||
|
|
||||||
### Advanced Provider Configuration
|
### Advanced Provider Configuration
|
||||||
|
|
||||||
These variables are needed when using custom endpoints, enterprise deployments, or specific provider implementations.
|
These variables are needed when using custom endpoints, enterprise deployments, or specific provider implementations.
|
||||||
@@ -45,7 +47,34 @@ export GOOSE_PROVIDER__TYPE="anthropic"
|
|||||||
export GOOSE_PROVIDER__HOST="https://api.anthropic.com"
|
export GOOSE_PROVIDER__HOST="https://api.anthropic.com"
|
||||||
export GOOSE_PROVIDER__API_KEY="your-api-key-here"
|
export GOOSE_PROVIDER__API_KEY="your-api-key-here"
|
||||||
```
|
```
|
||||||
## Planning Mode Configuration
|
|
||||||
|
### Lead/Worker Model Configuration
|
||||||
|
|
||||||
|
Configure a lead/worker model pattern in which a powerful lead model handles initial planning and complex reasoning, after which Goose switches to a faster/cheaper worker model for execution.
|
||||||
|
|
||||||
|
| Variable | Purpose | Values | Default |
|
||||||
|
|----------|---------|---------|---------|
|
||||||
|
| `GOOSE_LEAD_MODEL` | **Required to enable lead mode.** Specifies the lead model name | Model name (e.g., "gpt-4o", "claude-3.5-sonnet") | None |
|
||||||
|
| `GOOSE_LEAD_PROVIDER` | Provider for the lead model | [See available providers](/docs/getting-started/providers#available-providers) | Falls back to GOOSE_PROVIDER |
|
||||||
|
| `GOOSE_LEAD_TURNS` | Number of initial turns using the lead model | Integer | 3 |
|
||||||
|
| `GOOSE_LEAD_FAILURE_THRESHOLD` | Consecutive failures before fallback to lead model | Integer | 2 |
|
||||||
|
| `GOOSE_LEAD_FALLBACK_TURNS` | Number of turns to use lead model in fallback mode | Integer | 2 |
|
||||||
|
|
||||||
|
**Examples**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Basic lead/worker setup
|
||||||
|
export GOOSE_LEAD_MODEL="o4"
|
||||||
|
|
||||||
|
# Advanced lead/worker configuration
|
||||||
|
export GOOSE_LEAD_MODEL="claude4-opus"
|
||||||
|
export GOOSE_LEAD_PROVIDER="anthropic"
|
||||||
|
export GOOSE_LEAD_TURNS=5
|
||||||
|
export GOOSE_LEAD_FAILURE_THRESHOLD=3
|
||||||
|
export GOOSE_LEAD_FALLBACK_TURNS=2
|
||||||
|
```
|
||||||
|
|
||||||
|
### Planning Mode Configuration
|
||||||
|
|
||||||
These variables control Goose's [planning functionality](/docs/guides/creating-plans).
|
These variables control Goose's [planning functionality](/docs/guides/creating-plans).
|
||||||
|
|
||||||
|
|||||||
31
test_lead_worker.sh
Executable file
31
test_lead_worker.sh
Executable file
@@ -0,0 +1,31 @@
|
|||||||
|
#!/bin/bash
# Test script for lead/worker provider functionality.
#
# This script only prepares environment variables for three configurations
# and prints a short summary of the options; it does not invoke goose itself.
#
# NOTE(review): GOOSE_WORKER_MODEL is set/unset below but is not listed in the
# documented lead/worker variables -- confirm it is actually read by goose.

# Set up test environment variables
export GOOSE_PROVIDER="openai"
export GOOSE_MODEL="gpt-4o-mini"
export OPENAI_API_KEY="test-key"

# Test 1: Default behavior (no lead/worker)
echo "Test 1: Default behavior (no lead/worker)"
unset GOOSE_LEAD_MODEL
unset GOOSE_WORKER_MODEL
unset GOOSE_LEAD_TURNS

# Test 2: Lead/worker with same provider
echo -e "\nTest 2: Lead/worker with same provider"
export GOOSE_LEAD_MODEL="gpt-4o"
export GOOSE_WORKER_MODEL="gpt-4o-mini"
export GOOSE_LEAD_TURNS="3"

# Test 3: Lead/worker with default worker (uses main model)
echo -e "\nTest 3: Lead/worker with default worker"
export GOOSE_LEAD_MODEL="gpt-4o"
unset GOOSE_WORKER_MODEL
export GOOSE_LEAD_TURNS="5"

echo -e "\nConfiguration examples:"
echo "- Default: Uses GOOSE_MODEL for all turns"
echo "- Lead/Worker: Set GOOSE_LEAD_MODEL to use a different model for initial turns"
# The documented default for GOOSE_LEAD_TURNS is 3, not 5.
echo "- GOOSE_LEAD_TURNS: Number of turns to use lead model (default: 3)"
echo "- GOOSE_WORKER_MODEL: Model to use after lead turns (default: GOOSE_MODEL)"
|
||||||
Reference in New Issue
Block a user