// File: goose/crates/goose-llm/tests/providers_complete.rs
// (381 lines, 11 KiB, Rust)

use anyhow::Result;
use dotenv::dotenv;
use goose_llm::message::{Message, MessageContent};
use goose_llm::providers::base::Provider;
use goose_llm::providers::errors::ProviderError;
use goose_llm::providers::{databricks, openai};
use goose_llm::types::core::{Content, Tool};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
/// Outcome recorded for a provider in the end-of-run summary.
#[derive(Debug, Clone, Copy)]
enum TestStatus {
    Passed,
    Skipped,
    Failed,
}

impl std::fmt::Display for TestStatus {
    /// Renders the status as a single symbol shown in front of the provider name.
    ///
    /// Each variant must map to a distinct, non-empty symbol; otherwise passing
    /// and failing providers become indistinguishable in the summary.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            TestStatus::Passed => "✅",
            TestStatus::Skipped => "⏭️",
            TestStatus::Failed => "❌",
        };
        f.write_str(symbol)
    }
}
/// Mutex-protected record of the latest status reported for each provider.
///
/// Entries are keyed by provider name; recording a status for a provider that
/// already has one replaces the previous entry.
struct TestReport {
    results: Mutex<HashMap<String, TestStatus>>,
}

impl TestReport {
    /// Builds an empty report behind an `Arc` so it can be shared.
    fn new() -> Arc<Self> {
        let empty = Self {
            results: Mutex::new(HashMap::new()),
        };
        Arc::new(empty)
    }

    /// Stores `status` for `provider`, overwriting any earlier entry.
    ///
    /// # Panics
    /// Panics if the results mutex has been poisoned.
    fn record_status(&self, provider: &str, status: TestStatus) {
        self.results
            .lock()
            .unwrap()
            .insert(provider.to_string(), status);
    }

    /// Marks `provider` as passed.
    fn record_pass(&self, provider: &str) {
        self.record_status(provider, TestStatus::Passed)
    }

    /// Marks `provider` as skipped.
    fn record_skip(&self, provider: &str) {
        self.record_status(provider, TestStatus::Skipped)
    }

    /// Marks `provider` as failed.
    fn record_fail(&self, provider: &str) {
        self.record_status(provider, TestStatus::Failed)
    }

    /// Prints one `<status> <provider>` line per provider, sorted by name,
    /// framed by a header and a footer rule.
    ///
    /// # Panics
    /// Panics if the results mutex has been poisoned.
    fn print_summary(&self) {
        println!("\n============== Providers ==============");
        let guard = self.results.lock().unwrap();
        let mut rows: Vec<(&String, &TestStatus)> = guard.iter().collect();
        // Provider names are unique map keys, so an unstable sort gives the same order.
        rows.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));
        for (provider, status) in rows {
            println!("{} {}", status, provider);
        }
        println!("=======================================\n");
    }
}
lazy_static::lazy_static! {
    /// Serializes the sections of provider tests that read or mutate process
    /// environment variables.
    static ref ENV_LOCK: Mutex<()> = Mutex::default();
    /// Shared report that every provider test records its outcome into.
    static ref TEST_REPORT: Arc<TestReport> = TestReport::new();
}
/// Generic test harness for any [`Provider`] implementation.
///
/// Every check calls [`Provider::complete`] on the wrapped provider and asserts on
/// the shape of the reply. Assertion failures panic; errors returned by the
/// provider are propagated to the caller as `Err`.
struct ProviderTester {
    /// Provider under test, type-erased so one harness works for every backend.
    provider: Arc<dyn Provider>,
    /// Display name used in log output and for provider-specific expectations.
    name: String,
}

impl ProviderTester {
    /// Wraps `provider` in an [`Arc`] and labels it with `name`.
    fn new<T: Provider + Send + Sync + 'static>(provider: T, name: String) -> Self {
        Self {
            provider: Arc::new(provider),
            name,
        }
    }

    /// Returns `true` if this tester's name equals `provider`, ignoring ASCII case.
    fn is_provider(&self, provider: &str) -> bool {
        self.name.eq_ignore_ascii_case(provider)
    }

    /// Sends a trivial prompt and expects exactly one text item in the reply.
    async fn test_basic_response(&self) -> Result<()> {
        let message = Message::user().with_text("Just say hello!");

        let response = self
            .provider
            .complete("You are a helpful assistant.", &[message], &[])
            .await?;

        // For a basic response, we expect a single text response
        assert_eq!(
            response.message.content.len(),
            1,
            "Expected single content item in response"
        );

        // Verify we got a text response
        assert!(
            matches!(response.message.content[0], MessageContent::Text(_)),
            "Expected text response"
        );

        Ok(())
    }

    /// Exercises a full tool-call round trip.
    ///
    /// 1. Offers a `get_weather` tool and expects the reply to contain a tool request.
    /// 2. Answers the last tool request with a canned result and expects a text
    ///    reply. This also checks that the provider accepts a conversation that
    ///    contains a tool request/response pair.
    async fn test_tool_usage(&self) -> Result<()> {
        let weather_tool = Tool::new(
            "get_weather",
            "Get the weather for a location",
            serde_json::json!({
                "type": "object",
                "required": ["location"],
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    }
                }
            }),
        );

        let message = Message::user().with_text("What's the weather like in San Francisco?");

        let response1 = self
            .provider
            .complete(
                "You are a helpful weather assistant.",
                &[message.clone()],
                &[weather_tool.clone()],
            )
            .await?;

        println!("=== {}::response1 ===", self.name);
        dbg!(&response1);
        println!("===================");

        // Verify we got a tool request
        assert!(
            response1
                .message
                .content
                .iter()
                .any(|content| matches!(content, MessageContent::ToolReq(_))),
            "Expected tool request in response"
        );

        // The tool response answers the most recent tool request in the reply.
        let id = &response1
            .message
            .content
            .iter()
            .rev()
            .find_map(|content| content.as_tool_request())
            .expect("got tool request")
            .id;

        // Canned tool output. The literal's lines start at column 0 on purpose:
        // indenting them would change the string content.
        let weather = Message::user().with_tool_response(
            id,
            Ok(vec![Content::text(
                "
50°F°C
Precipitation: 0%
Humidity: 84%
Wind: 2 mph
Weather
Saturday 9:00 PM
Clear",
            )])
            .into(),
        );

        // Verify we construct a valid payload including the request/response pair for the next inference
        let response2 = self
            .provider
            .complete(
                "You are a helpful weather assistant.",
                &[message, response1.message, weather],
                &[weather_tool],
            )
            .await?;

        println!("=== {}::response2 ===", self.name);
        dbg!(&response2);
        println!("===================");

        assert!(
            response2
                .message
                .content
                .iter()
                .any(|content| matches!(content, MessageContent::Text(_))),
            "Expected text for final response"
        );

        Ok(())
    }

    /// Sends a conversation far larger than typical context windows and expects
    /// [`ProviderError::ContextLengthExceeded`].
    ///
    /// Provider-specific expectations, matched on `name` ignoring ASCII case:
    /// - `google`: uses a much larger payload (1.3M repetitions instead of 300k).
    /// - `ollama`: the request is expected to succeed instead of failing.
    async fn test_context_length_exceeded_error(&self) -> Result<()> {
        // Google Gemini has a really long context window
        let repetitions = if self.is_provider("google") {
            1_300_000
        } else {
            300_000
        };
        let large_message_content = "hello ".repeat(repetitions);

        let messages = vec![
            Message::user().with_text("hi there. what is 2 + 2?"),
            Message::assistant().with_text("hey! I think it's 4."),
            Message::user().with_text(&large_message_content),
            Message::assistant().with_text("heyy!!"),
            // Messages before this mark should be truncated
            Message::user().with_text("what's the meaning of life?"),
            Message::assistant().with_text("the meaning of life is 42"),
            Message::user().with_text(
                "did I ask you what's 2+2 in this message history? just respond with 'yes' or 'no'",
            ),
        ];

        // Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded
        let result = self
            .provider
            .complete("You are a helpful assistant.", &messages, &[])
            .await;

        // Print some debug info
        println!("=== {}::context_length_exceeded_error ===", self.name);
        dbg!(&result);
        println!("===================");

        // Ollama truncates by default even when the context window is exceeded
        if self.is_provider("ollama") {
            assert!(
                result.is_ok(),
                "Expected to succeed because of default truncation"
            );
            return Ok(());
        }

        assert!(
            result.is_err(),
            "Expected error when context window is exceeded"
        );
        assert!(
            matches!(result.unwrap_err(), ProviderError::ContextLengthExceeded(_)),
            "Expected error to be ContextLengthExceeded"
        );

        Ok(())
    }

    /// Runs all provider tests in order, stopping at the first returned error.
    async fn run_test_suite(&self) -> Result<()> {
        self.test_basic_response().await?;
        self.test_tool_usage().await?;
        self.test_context_length_exceeded_error().await?;
        Ok(())
    }
}
/// Loads variables from a `.env` file via `dotenv`, logging the file's path on success.
///
/// Any error from `dotenv` is ignored, leaving the process environment as it was.
fn load_env() {
    let loaded = dotenv();
    if let Ok(path) = &loaded {
        println!("Loaded environment from {:?}", path);
    }
}
/// Helper function to run a provider test with proper error handling and reporting.
///
/// Steps:
/// 1. Records `name` as failed up front, so a panic anywhere below still shows
///    up as a failure in the final report.
/// 2. While holding `ENV_LOCK`:
///    - loads `.env`;
///    - applies `env_modifications` (`Some(value)` sets a variable, `None` removes it);
///    - if every variable in `required_vars` is present, builds the provider with
///      `provider_fn`.
/// 3. Restores every touched variable to its previous value on all exit paths
///    (skip, success, or a panic inside `provider_fn`). The lock is released
///    before the async test suite runs.
/// 4. Runs the [`ProviderTester`] suite and records the outcome.
///
/// Returns `Ok(())` without running anything when a required variable is missing;
/// the provider is then recorded as skipped.
///
/// # Errors
/// Propagates the first error returned by the test suite.
async fn test_provider<F, T>(
    name: &str,
    required_vars: &[&str],
    env_modifications: Option<HashMap<&str, Option<String>>>,
    provider_fn: F,
) -> Result<()>
where
    F: FnOnce() -> T,
    T: Provider + Send + Sync + 'static,
{
    /// Snapshot of environment variables that is written back on drop.
    /// `Some(value)` is restored as-is; `None` means the variable was unset and
    /// is removed again.
    struct EnvRestore(Vec<(String, Option<std::ffi::OsString>)>);

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            for (var, value) in &self.0 {
                match value {
                    Some(value) => std::env::set_var(var, value),
                    None => std::env::remove_var(var),
                }
            }
        }
    }

    // We start off as failed, so that if the process panics it is seen as a failure
    TEST_REPORT.record_fail(name);

    let provider = {
        // Take exclusive access to environment modifications. The mutex guards no
        // data, so recover it if a previous test panicked while holding it instead
        // of cascading that panic into every later test.
        let _env_lock = ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        load_env();

        let mods = env_modifications.unwrap_or_default();

        // Save the current state of the required and modified variables. `var_os`
        // also preserves non-UTF-8 values. This guard is declared after the lock
        // guard, so it drops first and the environment is restored while the lock
        // is still held.
        let _restore_env = EnvRestore(
            required_vars
                .iter()
                .chain(mods.keys())
                .map(|&var| (var.to_string(), std::env::var_os(var)))
                .collect(),
        );

        // Apply any environment modifications
        for (&var, value) in &mods {
            match value {
                Some(val) => std::env::set_var(var, val),
                None => std::env::remove_var(var),
            }
        }

        // Set up the provider only when every required variable is configured.
        let credentials_present = required_vars
            .iter()
            .all(|var| std::env::var(var).is_ok());
        credentials_present.then(provider_fn)
        // `_restore_env` and then `_env_lock` are dropped here, on every path.
    };

    let provider = match provider {
        Some(provider) => provider,
        None => {
            println!("Skipping {} tests - credentials not configured", name);
            TEST_REPORT.record_skip(name);
            return Ok(());
        }
    };

    let tester = ProviderTester::new(provider, name.to_string());
    match tester.run_test_suite().await {
        Ok(()) => {
            TEST_REPORT.record_pass(name);
            Ok(())
        }
        Err(e) => {
            println!("{} test failed: {}", name, e);
            TEST_REPORT.record_fail(name);
            Err(e)
        }
    }
}
/// Runs the provider suite against OpenAI; `test_provider` skips it unless
/// `OPENAI_API_KEY` is set.
#[tokio::test]
async fn openai_complete() -> Result<()> {
    const REQUIRED_VARS: &[&str] = &["OPENAI_API_KEY"];
    let make_provider = openai::OpenAiProvider::default;
    test_provider("OpenAI", REQUIRED_VARS, None, make_provider).await
}
/// Runs the provider suite against Databricks; `test_provider` skips it unless
/// both `DATABRICKS_HOST` and `DATABRICKS_TOKEN` are set.
#[tokio::test]
async fn databricks_complete() -> Result<()> {
    const REQUIRED_VARS: &[&str] = &["DATABRICKS_HOST", "DATABRICKS_TOKEN"];
    let make_provider = databricks::DatabricksProvider::default;
    test_provider("Databricks", REQUIRED_VARS, None, make_provider).await
}
// Print the final test report
#[ctor::dtor]
fn print_test_report() {
TEST_REPORT.print_summary();
}