[goose-llm] add generate tooltip & session name via extract method (#2467)

* extract method in provider to use for structured outputs

* generate session name from msgs

* generate tooltip from msgs

* add provider tests
Author: Salman Mohammed
Date: 2025-05-07 15:42:03 -04:00
Committed by: GitHub
Parent: a4f0ec365f
Commit: 300dd06ec8
16 changed files with 1163 additions and 22 deletions

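In short: provider-level structured output is factored into a single `extract` trait method, and two helpers (`generate_session_name`, `generate_tooltip`) are built on top of it. A minimal sketch of the intended call sites, assuming the module layout in this diff (the wrapper function and message text are illustrative, not part of the change):

use goose_llm::extractors::{generate_session_name, generate_tooltip};
use goose_llm::message::Message;

// Hypothetical caller; a real session would pass its live message history.
async fn label_session(messages: &[Message]) -> anyhow::Result<()> {
    let name = generate_session_name(messages).await?; // e.g. "Weather Inquiry"
    let tooltip = generate_tooltip(messages).await?; // e.g. "checking weather forecast"
    println!("session: {name} | now: {tooltip}");
    Ok(())
}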
Cargo.lock (generated) - 14 lines changed
View File

@@ -2406,7 +2406,7 @@ dependencies = [
"fs2", "fs2",
"futures", "futures",
"include_dir", "include_dir",
"indoc", "indoc 2.0.6",
"jsonwebtoken", "jsonwebtoken",
"keyring", "keyring",
"lazy_static", "lazy_static",
@@ -2527,7 +2527,11 @@ dependencies = [
"base64 0.21.7", "base64 0.21.7",
"chrono", "chrono",
"criterion", "criterion",
"ctor",
"dotenv",
"include_dir", "include_dir",
"indoc 1.0.9",
"lazy_static",
"minijinja", "minijinja",
"once_cell", "once_cell",
"regex", "regex",
@@ -2559,7 +2563,7 @@ dependencies = [
"ignore", "ignore",
"image 0.24.9", "image 0.24.9",
"include_dir", "include_dir",
"indoc", "indoc 2.0.6",
"keyring", "keyring",
"kill_tree", "kill_tree",
"lazy_static", "lazy_static",
@@ -3263,6 +3267,12 @@ dependencies = [
"web-time", "web-time",
] ]
[[package]]
name = "indoc"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306"
[[package]] [[package]]
name = "indoc" name = "indoc"
version = "2.0.6" version = "2.0.6"

View File

@@ -35,10 +35,14 @@ base64 = "0.21"
regex = "1.11.1"
tracing = "0.1"
smallvec = { version = "1.13", features = ["serde"] }
+indoc = "1.0"

[dev-dependencies]
criterion = "0.5"
tempfile = "3.15.0"
+dotenv = "0.15"
+lazy_static = "1.5"
+ctor = "0.2.7"

[[example]]

View File

@@ -0,0 +1,5 @@
mod session_name;
mod tooltip;
pub use session_name::generate_session_name;
pub use tooltip::generate_tooltip;

View File

@@ -0,0 +1,107 @@
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::Provider;
use crate::providers::databricks::DatabricksProvider;
use crate::providers::errors::ProviderError;
use crate::types::core::Role;
use anyhow::Result;
use indoc::indoc;
use serde_json::{json, Value};
const SESSION_NAME_EXAMPLES: &[&str] = &[
"Research Synthesis",
"Sentiment Analysis",
"Performance Report",
"Feedback Collector",
"Accessibility Check",
"Design Reminder",
"Project Reminder",
"Launch Checklist",
"Metrics Monitor",
"Incident Response",
"Deploy Cabinet App",
"Design Reminder Alert",
"Generate Monthly Expense Report",
"Automate Incident Response Workflow",
"Analyze Brand Sentiment Trends",
"Monitor Device Health Issues",
"Collect UI Feedback Summary",
"Schedule Project Deadline Reminders",
];
fn build_system_prompt() -> String {
let examples = SESSION_NAME_EXAMPLES
.iter()
.map(|e| format!("- {}", e))
.collect::<Vec<_>>()
.join("\n");
indoc! {r#"
You are an assistant that crafts a concise session title.
Given the first couple of user messages in the conversation so far,
reply with only a short name (up to 4 words) that best describes
this session's goal.
Examples:
"#}
.to_string()
+ &examples
}
/// Generates a short (≤4 words) session name
pub async fn generate_session_name(messages: &[Message]) -> Result<String, ProviderError> {
// Collect up to the first 3 user messages (truncated to 300 chars each)
let context: Vec<String> = messages
.iter()
.filter(|m| m.role == Role::User)
.take(3)
.map(|m| {
let text = m.content.concat_text_str();
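// Note: the 300 limit is checked in bytes (len) but applied per char below,
// so multi-byte text may retain more than 300 bytes after truncation.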
if text.len() > 300 {
text.chars().take(300).collect()
} else {
text
}
})
.collect();
if context.is_empty() {
return Err(ProviderError::ExecutionError(
"No user messages found to generate a session name.".to_string(),
));
}
let system_prompt = build_system_prompt();
let user_msg_text = format!("Here are the user messages:\n{}", context.join("\n"));
// Instantiate DatabricksProvider with goose-gpt-4-1
let model_cfg = ModelConfig::new("goose-gpt-4-1".to_string()).with_temperature(Some(0.0));
let provider = DatabricksProvider::from_env(model_cfg)?;
// Use `extract` with a simple string schema
let schema = json!({
"type": "object",
"properties": {
"name": { "type": "string" }
},
"required": ["name"],
"additionalProperties": false
});
let user_msg = Message::user().with_text(&user_msg_text);
let resp = provider
.extract(&system_prompt, &[user_msg], &schema)
.await?;
let obj = resp
.data
.as_object()
.ok_or_else(|| ProviderError::ResponseParseError("Expected object".into()))?;
let name = obj
.get("name")
.and_then(Value::as_str)
.ok_or_else(|| ProviderError::ResponseParseError("Missing or non-string name".into()))?
.to_string();
Ok(name)
}

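With the schema above, a successful call returns a single-field object along the lines of {"name": "Research Synthesis"} (value illustrative), which the accessor chain unwraps to the bare string.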
View File

@@ -0,0 +1,165 @@
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::providers::base::Provider;
use crate::providers::databricks::DatabricksProvider;
use crate::providers::errors::ProviderError;
use crate::types::core::{Content, Role};
use anyhow::Result;
use indoc::indoc;
use serde_json::{json, Value};
const TOOLTIP_EXAMPLES: &[&str] = &[
"analyzing KPIs",
"detecting anomalies",
"building artifacts in Buildkite",
"categorizing issues",
"checking dependencies",
"collecting feedback",
"deploying changes in AWS",
"drafting report in Google Docs",
"extracting action items",
"generating insights",
"logging issues",
"monitoring tickets in Zendesk",
"notifying design team",
"running integration tests",
"scanning threads in Figma",
"sending reminders in Gmail",
"sending surveys",
"sharing with stakeholders",
"summarizing findings",
"transcribing meeting",
"tracking resolution",
"updating status in Linear",
];
fn build_system_prompt() -> String {
let examples = TOOLTIP_EXAMPLES
.iter()
.map(|e| format!("- {}", e))
.collect::<Vec<_>>()
.join("\n");
indoc! {r#"
You are an assistant that summarizes the recent conversation into a tooltip.
Given the last two messages, reply with only a short tooltip (up to 4 words)
describing what is happening now.
Examples:
"#}
.to_string()
+ &examples
}
/// Generates a tooltip summarizing the last two messages in the session,
/// including any tool calls or results.
pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderError> {
// Need at least two messages to summarize
if messages.len() < 2 {
return Err(ProviderError::ExecutionError(
"Need at least two messages to generate a tooltip".to_string(),
));
}
// Helper to render a single message's content
fn render_message(m: &Message) -> String {
let mut parts = Vec::new();
for content in m.content.iter() {
match content {
MessageContent::Text(text_block) => {
let txt = text_block.text.trim();
if !txt.is_empty() {
parts.push(txt.to_string());
}
}
MessageContent::ToolRequest(req) => {
if let Ok(tool_call) = &req.tool_call {
parts.push(format!(
"called tool '{}' with args {}",
tool_call.name, tool_call.arguments
));
} else if let Err(e) = &req.tool_call {
parts.push(format!("tool request error: {}", e));
}
}
MessageContent::ToolResponse(resp) => match &resp.tool_result {
Ok(contents) => {
let results: Vec<String> = contents
.iter()
.map(|c| match c {
Content::Text(t) => t.text.clone(),
Content::Image(_) => "[image]".to_string(),
})
.collect();
parts.push(format!("tool responded with: {}", results.join(" ")));
}
Err(e) => {
parts.push(format!("tool error: {}", e));
}
},
_ => {} // ignore other variants
}
}
let role = match m.role {
Role::User => "User",
Role::Assistant => "Assistant",
};
format!("{}: {}", role, parts.join("; "))
}
// Take the last two messages (in correct chronological order)
let rendered: Vec<String> = messages
.iter()
.rev()
.take(2)
.map(render_message)
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect();
let system_prompt = build_system_prompt();
let user_msg_text = format!(
"Here are the last two messages:\n{}\n\nTooltip:",
rendered.join("\n")
);
// Instantiate the provider
let model_cfg = ModelConfig::new("goose-gpt-4-1".to_string()).with_temperature(Some(0.0));
let provider = DatabricksProvider::from_env(model_cfg)?;
// Schema wrapping our tooltip string
let schema = json!({
"type": "object",
"properties": {
"tooltip": { "type": "string" }
},
"required": ["tooltip"],
"additionalProperties": false
});
// Call extract
let user_msg = Message::user().with_text(&user_msg_text);
let resp = provider
.extract(&system_prompt, &[user_msg], &schema)
.await?;
// Pull out the tooltip field
let obj = resp
.data
.as_object()
.ok_or_else(|| ProviderError::ResponseParseError("Expected JSON object".into()))?;
let tooltip = obj
.get("tooltip")
.and_then(Value::as_str)
.ok_or_else(|| {
ProviderError::ResponseParseError("Missing or non-string `tooltip` field".into())
})?
.to_string();
Ok(tooltip)
}

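For a tool-call exchange like the one exercised in the tests further down, render_message produces context lines of roughly this shape (values illustrative):

Assistant: called tool 'get_time' with args {"timezone":"UTC"}
User: tool responded with: The current time is 12:00 UTC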
View File

@@ -1,8 +1,9 @@
mod completion;
-mod message;
+pub mod extractors;
+pub mod message;
mod model;
mod prompt_template;
-mod providers;
+pub mod providers;
pub mod types;

pub use completion::completion;

View File

@@ -43,6 +43,23 @@ impl ProviderCompleteResponse {
}
}

+/// Response from a structured extraction call
+#[derive(Debug, Clone)]
+pub struct ProviderExtractResponse {
+/// The extracted JSON object
+pub data: serde_json::Value,
+/// Which model produced it
+pub model: String,
+/// Token usage stats
+pub usage: Usage,
+}
+
+impl ProviderExtractResponse {
+pub fn new(data: serde_json::Value, model: String, usage: Usage) -> Self {
+Self { data, model, usage }
+}
+}
+
/// Base trait for AI providers (OpenAI, Anthropic, etc)
#[async_trait]
pub trait Provider: Send + Sync {
@@ -65,6 +82,27 @@ pub trait Provider: Send + Sync {
messages: &[Message],
tools: &[Tool],
) -> Result<ProviderCompleteResponse, ProviderError>;
+
+/// Structured extraction: the output always matches the given JSON Schema
+///
+/// # Arguments
+/// * `system` - system prompt guiding the extraction task
+/// * `messages` - conversation history
+/// * `schema` - a JSON Schema for the expected output.
+/// Will set strict=true for OpenAI & Databricks.
+///
+/// # Returns
+/// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`.
+///
+/// # Errors
+/// * `ProviderError::ContextLengthExceeded` if the prompt is too large
+/// * other `ProviderError` variants for API/network failures
+async fn extract(
+&self,
+system: &str,
+messages: &[Message],
+schema: &serde_json::Value,
+) -> Result<ProviderExtractResponse, ProviderError>;
}

#[cfg(test)]

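Because `extract` is a trait method, callers can stay provider-agnostic. A minimal sketch under that assumption (the helper, prompt, and schema are hypothetical):

use serde_json::json;

async fn extract_color(provider: &dyn Provider) -> Result<String, ProviderError> {
    // Strict-mode-friendly schema: all properties required, no extras allowed.
    let schema = json!({
        "type": "object",
        "properties": { "color": { "type": "string" } },
        "required": ["color"],
        "additionalProperties": false
    });
    let msg = Message::user().with_text("Pick a color for a warning banner.");
    let resp = provider.extract("Reply in JSON.", &[msg], &schema).await?;
    Ok(resp.data["color"].as_str().unwrap_or_default().to_string())
}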
View File

@@ -4,7 +4,7 @@ use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{json, Value};
use url::Url;

use super::{
@@ -15,17 +15,15 @@ use super::{
use crate::{
message::Message,
model::ModelConfig,
-providers::{Provider, ProviderCompleteResponse, Usage},
+providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
types::core::Tool,
};

-pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-meta-llama-3-3-70b-instruct";
+pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet";
// Databricks can passthrough to a wide range of models, we only provide the default
pub const _DATABRICKS_KNOWN_MODELS: &[&str] = &[
"databricks-meta-llama-3-3-70b-instruct",
-"databricks-meta-llama-3-1-405b-instruct",
-"databricks-dbrx-instruct",
-"databricks-mixtral-8x7b-instruct",
+"databricks-claude-3-7-sonnet",
];

#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -215,4 +213,63 @@ impl Provider for DatabricksProvider {
Ok(ProviderCompleteResponse::new(message, model, usage))
}
async fn extract(
&self,
system: &str,
messages: &[Message],
schema: &Value,
) -> Result<ProviderExtractResponse, ProviderError> {
// 1. Build base payload (no tools)
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
// 2. Inject strict JSONSchema wrapper
payload
.as_object_mut()
.expect("payload must be an object")
.insert(
"response_format".to_string(),
json!({
"type": "json_schema",
"json_schema": {
"name": "extraction",
"schema": schema,
"strict": true
}
}),
);
// 3. Call the Databricks endpoint
let response = self.post(payload.clone()).await?;
// 4. Extract the assistant's `content` and parse it into JSON
let msg = &response["choices"][0]["message"];
let raw = msg.get("content").cloned().ok_or_else(|| {
ProviderError::ResponseParseError("Missing content in extract response".into())
})?;
let data = match raw {
Value::String(s) => serde_json::from_str(&s)
.map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
Value::Object(_) | Value::Array(_) => raw,
other => {
return Err(ProviderError::ResponseParseError(format!(
"Unexpected content type: {:?}",
other
)))
}
};
// 5. Gather usage & model info
let usage = match get_usage(&response) {
Ok(u) => u,
Err(ProviderError::UsageError(e)) => {
tracing::debug!("Failed to get usage in extract: {}", e);
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
Ok(ProviderExtractResponse::new(data, model, usage))
}
}

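For reference, after step 2 the request body carries a wrapper of roughly this shape (sketched with json!; the surrounding fields come from create_request):

let _shape = serde_json::json!({
    "model": "...",
    "messages": [ /* system prompt + conversation */ ],
    "response_format": {
        "type": "json_schema",
        "json_schema": { "name": "extraction", "schema": { /* caller-supplied */ }, "strict": true }
    }
});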
View File

@@ -22,6 +22,9 @@ pub enum ProviderError {
#[error("Usage data error: {0}")]
UsageError(String),
+
+#[error("Invalid response: {0}")]
+ResponseParseError(String),
}

impl From<anyhow::Error> for ProviderError {

View File

@@ -200,6 +200,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
}

/// Convert internal Tool format to OpenAI's API tool specification
+/// https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/api-reference#functionobject
pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
let mut tool_names = std::collections::HashSet::new();
let mut result = Vec::new();

View File

@@ -6,5 +6,5 @@ pub mod formats;
pub mod openai;
pub mod utils;

-pub use base::{Provider, ProviderCompleteResponse, Usage};
+pub use base::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage};
pub use factory::create;

View File

@@ -3,7 +3,7 @@ use std::{collections::HashMap, time::Duration};
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
-use serde_json::Value;
+use serde_json::{json, Value};

use super::{
errors::ProviderError,
@@ -13,20 +13,12 @@ use super::{
use crate::{
message::Message,
model::ModelConfig,
-providers::{Provider, ProviderCompleteResponse, Usage},
+providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
types::core::Tool,
};

pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
-pub const _OPEN_AI_KNOWN_MODELS: &[&str] = &[
-"gpt-4o",
-"gpt-4o-mini",
-"gpt-4-turbo",
-"gpt-3.5-turbo",
-"o1",
-"o3",
-"o4-mini",
-];
+pub const _OPEN_AI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4.1", "o1", "o3", "o4-mini"];

#[derive(Debug)]
pub struct OpenAiProvider {
@@ -146,6 +138,65 @@ impl Provider for OpenAiProvider {
emit_debug_trace(&self.model, &payload, &response, &usage);
Ok(ProviderCompleteResponse::new(message, model, usage))
}
async fn extract(
&self,
system: &str,
messages: &[Message],
schema: &Value,
) -> Result<ProviderExtractResponse, ProviderError> {
// 1. Build base payload (no tools)
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
// 2. Inject strict JSONSchema wrapper
payload
.as_object_mut()
.expect("payload must be an object")
.insert(
"response_format".to_string(),
json!({
"type": "json_schema",
"json_schema": {
"name": "extraction",
"schema": schema,
"strict": true
}
}),
);
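// Note: OpenAI strict mode requires every property to be listed in "required"
// and "additionalProperties": false; the schemas used in this PR follow that.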
// 3. Call OpenAI
let response = self.post(payload.clone()).await?;
// 4. Extract the assistant's `content` and parse it into JSON
let msg = &response["choices"][0]["message"];
let raw = msg.get("content").cloned().ok_or_else(|| {
ProviderError::ResponseParseError("Missing content in extract response".into())
})?;
let data = match raw {
Value::String(s) => serde_json::from_str(&s)
.map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
Value::Object(_) | Value::Array(_) => raw,
other => {
return Err(ProviderError::ResponseParseError(format!(
"Unexpected content type: {:?}",
other
)))
}
};
// 5. Gather usage & model info
let usage = match get_usage(&response) {
Ok(u) => u,
Err(ProviderError::UsageError(e)) => {
tracing::debug!("Failed to get usage in extract: {}", e);
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
Ok(ProviderExtractResponse::new(data, model, usage))
}
}

fn parse_custom_headers(s: String) -> HashMap<String, String> {

View File

@@ -0,0 +1,56 @@
use anyhow::Result;
use dotenv::dotenv;
use goose_llm::extractors::generate_session_name;
use goose_llm::message::Message;
use goose_llm::providers::errors::ProviderError;
#[tokio::test]
async fn test_generate_session_name_success() -> Result<(), ProviderError> {
// Load .env for Databricks credentials
dotenv().ok();
let has_creds =
std::env::var("DATABRICKS_HOST").is_ok() && std::env::var("DATABRICKS_TOKEN").is_ok();
if !has_creds {
println!("Skipping generate_session_name test Databricks creds not set");
return Ok(());
}
// Build a few messages with at least two user messages
let messages = vec![
Message::user().with_text("Hello, how are you?"),
Message::assistant().with_text("I'm fine, thanks!"),
Message::user().with_text("What's the weather in New York tomorrow?"),
];
let name = generate_session_name(&messages).await?;
println!("Generated session name: {:?}", name);
// Should be non-empty and at most 4 words
let name = name.trim();
assert!(!name.is_empty(), "Name must not be empty");
let word_count = name.split_whitespace().count();
assert!(
word_count <= 4,
"Name must be 4 words or less, got {}: {}",
word_count,
name
);
Ok(())
}
#[tokio::test]
async fn test_generate_session_name_no_user() {
// No user messages → expect ExecutionError
let messages = vec![
Message::assistant().with_text("System starting…"),
Message::assistant().with_text("All systems go."),
];
let err = generate_session_name(&messages).await;
assert!(
matches!(err, Err(ProviderError::ExecutionError(_))),
"Expected ExecutionError when there are no user messages, got: {:?}",
err
);
}

View File

@@ -0,0 +1,69 @@
use anyhow::Result;
use dotenv::dotenv;
use goose_llm::extractors::generate_tooltip;
use goose_llm::message::{Message, MessageContent, ToolRequest};
use goose_llm::providers::errors::ProviderError;
use goose_llm::types::core::{Content, ToolCall};
use serde_json::json;
#[tokio::test]
async fn test_generate_tooltip_simple() -> Result<(), ProviderError> {
// Skip if no Databricks creds
dotenv().ok();
if std::env::var("DATABRICKS_HOST").is_err() || std::env::var("DATABRICKS_TOKEN").is_err() {
println!("Skipping simple tooltip test Databricks creds not set");
return Ok(());
}
// Two plain-text messages
let messages = vec![
Message::user().with_text("Hello, how are you?"),
Message::assistant().with_text("I'm fine, thanks! How can I help?"),
];
let tooltip = generate_tooltip(&messages).await?;
println!("Generated tooltip: {:?}", tooltip);
assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
assert!(
tooltip.len() < 100,
"Tooltip should be reasonably short (<100 chars)"
);
Ok(())
}
#[tokio::test]
async fn test_generate_tooltip_with_tools() -> Result<(), ProviderError> {
// Skip if no Databricks creds
dotenv().ok();
if std::env::var("DATABRICKS_HOST").is_err() || std::env::var("DATABRICKS_TOKEN").is_err() {
println!("Skipping toolbased tooltip test Databricks creds not set");
return Ok(());
}
// 1) Assistant message with a tool request
let mut tool_req_msg = Message::assistant();
let req = ToolRequest {
id: "1".to_string(),
tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))),
};
tool_req_msg.content.push(MessageContent::ToolRequest(req));
// 2) User message with the tool response
let tool_resp_msg = Message::user().with_tool_response(
"1",
Ok(vec![Content::text("The current time is 12:00 UTC")]),
);
let messages = vec![tool_req_msg, tool_resp_msg];
let tooltip = generate_tooltip(&messages).await?;
println!("Generated tooltip (tools): {:?}", tooltip);
assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
assert!(
tooltip.len() < 100,
"Tooltip should be reasonably short (<100 chars)"
);
Ok(())
}

View File

@@ -0,0 +1,379 @@
use anyhow::Result;
use dotenv::dotenv;
use goose_llm::message::{Message, MessageContent};
use goose_llm::providers::base::Provider;
use goose_llm::providers::errors::ProviderError;
use goose_llm::providers::{databricks, openai};
use goose_llm::types::core::{Content, Tool};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
#[derive(Debug, Clone, Copy)]
enum TestStatus {
Passed,
Skipped,
Failed,
}
impl std::fmt::Display for TestStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TestStatus::Passed => write!(f, "✅"),
TestStatus::Skipped => write!(f, "⏭️"),
TestStatus::Failed => write!(f, "❌"),
}
}
}
struct TestReport {
results: Mutex<HashMap<String, TestStatus>>,
}
impl TestReport {
fn new() -> Arc<Self> {
Arc::new(Self {
results: Mutex::new(HashMap::new()),
})
}
fn record_status(&self, provider: &str, status: TestStatus) {
let mut results = self.results.lock().unwrap();
results.insert(provider.to_string(), status);
}
fn record_pass(&self, provider: &str) {
self.record_status(provider, TestStatus::Passed);
}
fn record_skip(&self, provider: &str) {
self.record_status(provider, TestStatus::Skipped);
}
fn record_fail(&self, provider: &str) {
self.record_status(provider, TestStatus::Failed);
}
fn print_summary(&self) {
println!("\n============== Providers ==============");
let results = self.results.lock().unwrap();
let mut providers: Vec<_> = results.iter().collect();
providers.sort_by(|a, b| a.0.cmp(b.0));
for (provider, status) in providers {
println!("{} {}", status, provider);
}
println!("=======================================\n");
}
}
lazy_static::lazy_static! {
static ref TEST_REPORT: Arc<TestReport> = TestReport::new();
static ref ENV_LOCK: Mutex<()> = Mutex::new(());
}
/// Generic test harness for any Provider implementation
struct ProviderTester {
provider: Arc<dyn Provider>,
name: String,
}
impl ProviderTester {
fn new<T: Provider + Send + Sync + 'static>(provider: T, name: String) -> Self {
Self {
provider: Arc::new(provider),
name,
}
}
async fn test_basic_response(&self) -> Result<()> {
let message = Message::user().with_text("Just say hello!");
let response = self
.provider
.complete("You are a helpful assistant.", &[message], &[])
.await?;
// For a basic response, we expect a single text response
assert_eq!(
response.message.content.len(),
1,
"Expected single content item in response"
);
// Verify we got a text response
assert!(
matches!(response.message.content[0], MessageContent::Text(_)),
"Expected text response"
);
Ok(())
}
async fn test_tool_usage(&self) -> Result<()> {
let weather_tool = Tool::new(
"get_weather",
"Get the weather for a location",
serde_json::json!({
"type": "object",
"required": ["location"],
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
}
}
}),
);
let message = Message::user().with_text("What's the weather like in San Francisco?");
let response1 = self
.provider
.complete(
"You are a helpful weather assistant.",
&[message.clone()],
&[weather_tool.clone()],
)
.await?;
println!("=== {}::reponse1 ===", self.name);
dbg!(&response1);
println!("===================");
// Verify we got a tool request
assert!(
response1
.message
.content
.iter()
.any(|content| matches!(content, MessageContent::ToolRequest(_))),
"Expected tool request in response"
);
let id = &response1
.message
.content
.iter()
.filter_map(|message| message.as_tool_request())
.last()
.expect("got tool request")
.id;
let weather = Message::user().with_tool_response(
id,
Ok(vec![Content::text(
"
50°F°C
Precipitation: 0%
Humidity: 84%
Wind: 2 mph
Weather
Saturday 9:00 PM
Clear",
)]),
);
// Verify we construct a valid payload including the request/response pair for the next inference
let response2 = self
.provider
.complete(
"You are a helpful weather assistant.",
&[message, response1.message, weather],
&[weather_tool],
)
.await?;
println!("=== {}::reponse2 ===", self.name);
dbg!(&response2);
println!("===================");
assert!(
response2
.message
.content
.iter()
.any(|content| matches!(content, MessageContent::Text(_))),
"Expected text for final response"
);
Ok(())
}
async fn test_context_length_exceeded_error(&self) -> Result<()> {
// Google Gemini has a really long context window
let large_message_content = if self.name.to_lowercase() == "google" {
"hello ".repeat(1_300_000)
} else {
"hello ".repeat(300_000)
};
let messages = vec![
Message::user().with_text("hi there. what is 2 + 2?"),
Message::assistant().with_text("hey! I think it's 4."),
Message::user().with_text(&large_message_content),
Message::assistant().with_text("heyy!!"),
// Messages before this mark should be truncated
Message::user().with_text("what's the meaning of life?"),
Message::assistant().with_text("the meaning of life is 42"),
Message::user().with_text(
"did I ask you what's 2+2 in this message history? just respond with 'yes' or 'no'",
),
];
// Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded
let result = self
.provider
.complete("You are a helpful assistant.", &messages, &[])
.await;
// Print some debug info
println!("=== {}::context_length_exceeded_error ===", self.name);
dbg!(&result);
println!("===================");
// Ollama truncates by default even when the context window is exceeded
if self.name.to_lowercase() == "ollama" {
assert!(
result.is_ok(),
"Expected to succeed because of default truncation"
);
return Ok(());
}
assert!(
result.is_err(),
"Expected error when context window is exceeded"
);
assert!(
matches!(result.unwrap_err(), ProviderError::ContextLengthExceeded(_)),
"Expected error to be ContextLengthExceeded"
);
Ok(())
}
/// Run all provider tests
async fn run_test_suite(&self) -> Result<()> {
self.test_basic_response().await?;
self.test_tool_usage().await?;
self.test_context_length_exceeded_error().await?;
Ok(())
}
}
fn load_env() {
if let Ok(path) = dotenv() {
println!("Loaded environment from {:?}", path);
}
}
/// Helper function to run a provider test with proper error handling and reporting
async fn test_provider<F, T>(
name: &str,
required_vars: &[&str],
env_modifications: Option<HashMap<&str, Option<String>>>,
provider_fn: F,
) -> Result<()>
where
F: FnOnce() -> T,
T: Provider + Send + Sync + 'static,
{
// We start off as failed, so that if the process panics it is seen as a failure
TEST_REPORT.record_fail(name);
// Take exclusive access to environment modifications
let lock = ENV_LOCK.lock().unwrap();
load_env();
// Save current environment state for required vars and modified vars
let mut original_env = HashMap::new();
for &var in required_vars {
if let Ok(val) = std::env::var(var) {
original_env.insert(var, val);
}
}
if let Some(mods) = &env_modifications {
for &var in mods.keys() {
if let Ok(val) = std::env::var(var) {
original_env.insert(var, val);
}
}
}
// Apply any environment modifications
if let Some(mods) = &env_modifications {
for (&var, value) in mods.iter() {
match value {
Some(val) => std::env::set_var(var, val),
None => std::env::remove_var(var),
}
}
}
// Setup the provider
let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err());
if missing_vars {
println!("Skipping {} tests - credentials not configured", name);
TEST_REPORT.record_skip(name);
return Ok(());
}
let provider = provider_fn();
// Restore original environment
for (&var, value) in original_env.iter() {
std::env::set_var(var, value);
}
if let Some(mods) = env_modifications {
for &var in mods.keys() {
if !original_env.contains_key(var) {
std::env::remove_var(var);
}
}
}
std::mem::drop(lock);
let tester = ProviderTester::new(provider, name.to_string());
match tester.run_test_suite().await {
Ok(_) => {
TEST_REPORT.record_pass(name);
Ok(())
}
Err(e) => {
println!("{} test failed: {}", name, e);
TEST_REPORT.record_fail(name);
Err(e)
}
}
}
#[tokio::test]
async fn openai_complete() -> Result<()> {
test_provider(
"OpenAI",
&["OPENAI_API_KEY"],
None,
openai::OpenAiProvider::default,
)
.await
}
#[tokio::test]
async fn databricks_complete() -> Result<()> {
test_provider(
"Databricks",
&["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
None,
databricks::DatabricksProvider::default,
)
.await
}
// Print the final test report
#[ctor::dtor]
fn print_test_report() {
TEST_REPORT.print_summary();
}

View File

@@ -0,0 +1,195 @@
// tests/providers_extract.rs
use anyhow::Result;
use dotenv::dotenv;
use goose_llm::message::Message;
use goose_llm::providers::base::Provider;
use goose_llm::providers::{databricks::DatabricksProvider, openai::OpenAiProvider};
use goose_llm::ModelConfig;
use serde_json::{json, Value};
use std::sync::Arc;
#[derive(Debug, PartialEq, Copy, Clone)]
enum ProviderType {
OpenAi,
Databricks,
}
impl ProviderType {
fn required_env(&self) -> &'static [&'static str] {
match self {
ProviderType::OpenAi => &["OPENAI_API_KEY"],
ProviderType::Databricks => &["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
}
}
fn create_provider(&self, cfg: ModelConfig) -> Result<Arc<dyn Provider>> {
Ok(match self {
ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(cfg)?),
ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(cfg)?),
})
}
}
fn check_required_env_vars(required: &[&str]) -> bool {
let missing: Vec<_> = required
.iter()
.filter(|&&v| std::env::var(v).is_err())
.cloned()
.collect();
if !missing.is_empty() {
println!("Skipping test; missing env vars: {:?}", missing);
false
} else {
true
}
}
// --- Shared inputs for "paper" task ---
const PAPER_SYSTEM: &str =
"You are an expert at structured data extraction. Extract the metadata of a research paper into JSON.";
const PAPER_TEXT: &str =
"Application of Quantum Algorithms in Interstellar Navigation: A New Frontier \
by Dr. Stella Voyager, Dr. Nova Star, Dr. Lyra Hunter. Abstract: This paper \
investigates the utilization of quantum algorithms to improve interstellar \
navigation systems. Keywords: Quantum algorithms, interstellar navigation, \
space-time anomalies, quantum superposition, quantum entanglement, space travel.";
fn paper_schema() -> Value {
json!({
"type": "object",
"properties": {
"title": { "type": "string" },
"authors": { "type": "array", "items": { "type": "string" } },
"abstract": { "type": "string" },
"keywords": { "type": "array", "items": { "type": "string" } }
},
"required": ["title","authors","abstract","keywords"],
"additionalProperties": false
})
}
// --- Shared inputs for "UI" task ---
const UI_SYSTEM: &str = "You are a UI generator AI. Convert the user input into a JSON-driven UI.";
const UI_TEXT: &str = "Make a User Profile Form";
fn ui_schema() -> Value {
json!({
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": ["div","button","header","section","field","form"]
},
"label": { "type": "string" },
"children": {
"type": "array",
"items": { "$ref": "#" }
},
"attributes": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"value": { "type": "string" }
},
"required": ["name","value"],
"additionalProperties": false
}
}
},
"required": ["type","label","children","attributes"],
"additionalProperties": false
})
}
/// Generic runner for any extract task
async fn run_extract_test<F>(
provider_type: ProviderType,
model: &str,
system: &'static str,
user_text: &'static str,
schema: Value,
validate: F,
) -> Result<()>
where
F: Fn(&Value) -> bool,
{
dotenv().ok();
if !check_required_env_vars(provider_type.required_env()) {
return Ok(());
}
let cfg = ModelConfig::new(model.to_string()).with_temperature(Some(0.0));
let provider = provider_type.create_provider(cfg)?;
let msg = Message::user().with_text(user_text);
let resp = provider.extract(system, &[msg], &schema).await?;
println!("[{:?}] extract => {}", provider_type, resp.data);
assert!(
validate(&resp.data),
"{:?} failed validation on {}",
provider_type,
resp.data
);
Ok(())
}
/// Helper for the "paper" task
async fn run_extract_paper_test(provider: ProviderType, model: &str) -> Result<()> {
run_extract_test(
provider,
model,
PAPER_SYSTEM,
PAPER_TEXT,
paper_schema(),
|v| {
v.as_object()
.map(|o| {
["title", "authors", "abstract", "keywords"]
.iter()
.all(|k| o.contains_key(*k))
})
.unwrap_or(false)
},
)
.await
}
/// Helper for the "UI" task
async fn run_extract_ui_test(provider: ProviderType, model: &str) -> Result<()> {
run_extract_test(provider, model, UI_SYSTEM, UI_TEXT, ui_schema(), |v| {
v.as_object()
.and_then(|o| o.get("type").and_then(Value::as_str))
== Some("form")
})
.await
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn openai_extract_paper() -> Result<()> {
run_extract_paper_test(ProviderType::OpenAi, "gpt-4o").await
}
#[tokio::test]
async fn openai_extract_ui() -> Result<()> {
run_extract_ui_test(ProviderType::OpenAi, "gpt-4o").await
}
#[tokio::test]
async fn databricks_extract_paper() -> Result<()> {
run_extract_paper_test(ProviderType::Databricks, "goose-gpt-4-1").await
}
#[tokio::test]
async fn databricks_extract_ui() -> Result<()> {
run_extract_ui_test(ProviderType::Databricks, "goose-gpt-4-1").await
}
}
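
For the paper task, a passing run should produce data shaped like the following (abridged; values derive from PAPER_TEXT):

{"title": "Application of Quantum Algorithms in Interstellar Navigation: A New Frontier", "authors": ["Dr. Stella Voyager", "Dr. Nova Star", "Dr. Lyra Hunter"], "abstract": "...", "keywords": ["Quantum algorithms", "..."]}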