mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 22:24:21 +01:00
381 lines
11 KiB
Rust
381 lines
11 KiB
Rust
use anyhow::Result;
|
|
use dotenv::dotenv;
|
|
use goose_llm::message::{Message, MessageContent};
|
|
use goose_llm::providers::base::Provider;
|
|
use goose_llm::providers::errors::ProviderError;
|
|
use goose_llm::providers::{databricks, openai};
|
|
use goose_llm::types::core::{Content, Tool};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::sync::Mutex;
|
|
|
|
#[derive(Debug, Clone, Copy)]
|
|
enum TestStatus {
|
|
Passed,
|
|
Skipped,
|
|
Failed,
|
|
}
|
|
|
|
impl std::fmt::Display for TestStatus {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
match self {
|
|
TestStatus::Passed => write!(f, "✅"),
|
|
TestStatus::Skipped => write!(f, "⏭️"),
|
|
TestStatus::Failed => write!(f, "❌"),
|
|
}
|
|
}
|
|
}
|
|
|
|
struct TestReport {
|
|
results: Mutex<HashMap<String, TestStatus>>,
|
|
}
|
|
|
|
impl TestReport {
|
|
fn new() -> Arc<Self> {
|
|
Arc::new(Self {
|
|
results: Mutex::new(HashMap::new()),
|
|
})
|
|
}
|
|
|
|
fn record_status(&self, provider: &str, status: TestStatus) {
|
|
let mut results = self.results.lock().unwrap();
|
|
results.insert(provider.to_string(), status);
|
|
}
|
|
|
|
fn record_pass(&self, provider: &str) {
|
|
self.record_status(provider, TestStatus::Passed);
|
|
}
|
|
|
|
fn record_skip(&self, provider: &str) {
|
|
self.record_status(provider, TestStatus::Skipped);
|
|
}
|
|
|
|
fn record_fail(&self, provider: &str) {
|
|
self.record_status(provider, TestStatus::Failed);
|
|
}
|
|
|
|
fn print_summary(&self) {
|
|
println!("\n============== Providers ==============");
|
|
let results = self.results.lock().unwrap();
|
|
let mut providers: Vec<_> = results.iter().collect();
|
|
providers.sort_by(|a, b| a.0.cmp(b.0));
|
|
|
|
for (provider, status) in providers {
|
|
println!("{} {}", status, provider);
|
|
}
|
|
println!("=======================================\n");
|
|
}
|
|
}
|
|
|
|
lazy_static::lazy_static! {
|
|
static ref TEST_REPORT: Arc<TestReport> = TestReport::new();
|
|
static ref ENV_LOCK: Mutex<()> = Mutex::new(());
|
|
}
|
|
|
|
/// Generic test harness for any Provider implementation
|
|
struct ProviderTester {
|
|
provider: Arc<dyn Provider>,
|
|
name: String,
|
|
}
|
|
|
|
impl ProviderTester {
|
|
fn new<T: Provider + Send + Sync + 'static>(provider: T, name: String) -> Self {
|
|
Self {
|
|
provider: Arc::new(provider),
|
|
name,
|
|
}
|
|
}
|
|
|
|
async fn test_basic_response(&self) -> Result<()> {
|
|
let message = Message::user().with_text("Just say hello!");
|
|
|
|
let response = self
|
|
.provider
|
|
.complete("You are a helpful assistant.", &[message], &[])
|
|
.await?;
|
|
|
|
// For a basic response, we expect a single text response
|
|
assert_eq!(
|
|
response.message.content.len(),
|
|
1,
|
|
"Expected single content item in response"
|
|
);
|
|
|
|
// Verify we got a text response
|
|
assert!(
|
|
matches!(response.message.content[0], MessageContent::Text(_)),
|
|
"Expected text response"
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn test_tool_usage(&self) -> Result<()> {
|
|
let weather_tool = Tool::new(
|
|
"get_weather",
|
|
"Get the weather for a location",
|
|
serde_json::json!({
|
|
"type": "object",
|
|
"required": ["location"],
|
|
"properties": {
|
|
"location": {
|
|
"type": "string",
|
|
"description": "The city and state, e.g. San Francisco, CA"
|
|
}
|
|
}
|
|
}),
|
|
);
|
|
|
|
let message = Message::user().with_text("What's the weather like in San Francisco?");
|
|
|
|
let response1 = self
|
|
.provider
|
|
.complete(
|
|
"You are a helpful weather assistant.",
|
|
&[message.clone()],
|
|
&[weather_tool.clone()],
|
|
)
|
|
.await?;
|
|
|
|
println!("=== {}::reponse1 ===", self.name);
|
|
dbg!(&response1);
|
|
println!("===================");
|
|
|
|
// Verify we got a tool request
|
|
assert!(
|
|
response1
|
|
.message
|
|
.content
|
|
.iter()
|
|
.any(|content| matches!(content, MessageContent::ToolReq(_))),
|
|
"Expected tool request in response"
|
|
);
|
|
|
|
let id = &response1
|
|
.message
|
|
.content
|
|
.iter()
|
|
.filter_map(|message| message.as_tool_request())
|
|
.last()
|
|
.expect("got tool request")
|
|
.id;
|
|
|
|
let weather = Message::user().with_tool_response(
|
|
id,
|
|
Ok(vec![Content::text(
|
|
"
|
|
50°F°C
|
|
Precipitation: 0%
|
|
Humidity: 84%
|
|
Wind: 2 mph
|
|
Weather
|
|
Saturday 9:00 PM
|
|
Clear",
|
|
)])
|
|
.into(),
|
|
);
|
|
|
|
// Verify we construct a valid payload including the request/response pair for the next inference
|
|
let response2 = self
|
|
.provider
|
|
.complete(
|
|
"You are a helpful weather assistant.",
|
|
&[message, response1.message, weather],
|
|
&[weather_tool],
|
|
)
|
|
.await?;
|
|
|
|
println!("=== {}::reponse2 ===", self.name);
|
|
dbg!(&response2);
|
|
println!("===================");
|
|
|
|
assert!(
|
|
response2
|
|
.message
|
|
.content
|
|
.iter()
|
|
.any(|content| matches!(content, MessageContent::Text(_))),
|
|
"Expected text for final response"
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn test_context_length_exceeded_error(&self) -> Result<()> {
|
|
// Google Gemini has a really long context window
|
|
let large_message_content = if self.name.to_lowercase() == "google" {
|
|
"hello ".repeat(1_300_000)
|
|
} else {
|
|
"hello ".repeat(300_000)
|
|
};
|
|
|
|
let messages = vec![
|
|
Message::user().with_text("hi there. what is 2 + 2?"),
|
|
Message::assistant().with_text("hey! I think it's 4."),
|
|
Message::user().with_text(&large_message_content),
|
|
Message::assistant().with_text("heyy!!"),
|
|
// Messages before this mark should be truncated
|
|
Message::user().with_text("what's the meaning of life?"),
|
|
Message::assistant().with_text("the meaning of life is 42"),
|
|
Message::user().with_text(
|
|
"did I ask you what's 2+2 in this message history? just respond with 'yes' or 'no'",
|
|
),
|
|
];
|
|
|
|
// Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded
|
|
let result = self
|
|
.provider
|
|
.complete("You are a helpful assistant.", &messages, &[])
|
|
.await;
|
|
|
|
// Print some debug info
|
|
println!("=== {}::context_length_exceeded_error ===", self.name);
|
|
dbg!(&result);
|
|
println!("===================");
|
|
|
|
// Ollama truncates by default even when the context window is exceeded
|
|
if self.name.to_lowercase() == "ollama" {
|
|
assert!(
|
|
result.is_ok(),
|
|
"Expected to succeed because of default truncation"
|
|
);
|
|
return Ok(());
|
|
}
|
|
|
|
assert!(
|
|
result.is_err(),
|
|
"Expected error when context window is exceeded"
|
|
);
|
|
assert!(
|
|
matches!(result.unwrap_err(), ProviderError::ContextLengthExceeded(_)),
|
|
"Expected error to be ContextLengthExceeded"
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Run all provider tests
|
|
async fn run_test_suite(&self) -> Result<()> {
|
|
self.test_basic_response().await?;
|
|
self.test_tool_usage().await?;
|
|
self.test_context_length_exceeded_error().await?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
fn load_env() {
|
|
if let Ok(path) = dotenv() {
|
|
println!("Loaded environment from {:?}", path);
|
|
}
|
|
}
|
|
|
|
/// Helper function to run a provider test with proper error handling and reporting
|
|
async fn test_provider<F, T>(
|
|
name: &str,
|
|
required_vars: &[&str],
|
|
env_modifications: Option<HashMap<&str, Option<String>>>,
|
|
provider_fn: F,
|
|
) -> Result<()>
|
|
where
|
|
F: FnOnce() -> T,
|
|
T: Provider + Send + Sync + 'static,
|
|
{
|
|
// We start off as failed, so that if the process panics it is seen as a failure
|
|
TEST_REPORT.record_fail(name);
|
|
|
|
// Take exclusive access to environment modifications
|
|
let lock = ENV_LOCK.lock().unwrap();
|
|
|
|
load_env();
|
|
|
|
// Save current environment state for required vars and modified vars
|
|
let mut original_env = HashMap::new();
|
|
for &var in required_vars {
|
|
if let Ok(val) = std::env::var(var) {
|
|
original_env.insert(var, val);
|
|
}
|
|
}
|
|
if let Some(mods) = &env_modifications {
|
|
for &var in mods.keys() {
|
|
if let Ok(val) = std::env::var(var) {
|
|
original_env.insert(var, val);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Apply any environment modifications
|
|
if let Some(mods) = &env_modifications {
|
|
for (&var, value) in mods.iter() {
|
|
match value {
|
|
Some(val) => std::env::set_var(var, val),
|
|
None => std::env::remove_var(var),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Setup the provider
|
|
let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err());
|
|
if missing_vars {
|
|
println!("Skipping {} tests - credentials not configured", name);
|
|
TEST_REPORT.record_skip(name);
|
|
return Ok(());
|
|
}
|
|
|
|
let provider = provider_fn();
|
|
|
|
// Restore original environment
|
|
for (&var, value) in original_env.iter() {
|
|
std::env::set_var(var, value);
|
|
}
|
|
if let Some(mods) = env_modifications {
|
|
for &var in mods.keys() {
|
|
if !original_env.contains_key(var) {
|
|
std::env::remove_var(var);
|
|
}
|
|
}
|
|
}
|
|
|
|
std::mem::drop(lock);
|
|
|
|
let tester = ProviderTester::new(provider, name.to_string());
|
|
match tester.run_test_suite().await {
|
|
Ok(_) => {
|
|
TEST_REPORT.record_pass(name);
|
|
Ok(())
|
|
}
|
|
Err(e) => {
|
|
println!("{} test failed: {}", name, e);
|
|
TEST_REPORT.record_fail(name);
|
|
Err(e)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn openai_complete() -> Result<()> {
|
|
test_provider(
|
|
"OpenAI",
|
|
&["OPENAI_API_KEY"],
|
|
None,
|
|
openai::OpenAiProvider::default,
|
|
)
|
|
.await
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn databricks_complete() -> Result<()> {
|
|
test_provider(
|
|
"Databricks",
|
|
&["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
|
|
None,
|
|
databricks::DatabricksProvider::default,
|
|
)
|
|
.await
|
|
}
|
|
|
|
// Print the final test report
|
|
#[ctor::dtor]
|
|
fn print_test_report() {
|
|
TEST_REPORT.print_summary();
|
|
}
|