diff --git a/Cargo.lock b/Cargo.lock
index d27ec8c1..fe08495d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2406,7 +2406,7 @@ dependencies = [
  "fs2",
  "futures",
  "include_dir",
- "indoc",
+ "indoc 2.0.6",
  "jsonwebtoken",
  "keyring",
  "lazy_static",
@@ -2527,7 +2527,11 @@ dependencies = [
  "base64 0.21.7",
  "chrono",
  "criterion",
+ "ctor",
+ "dotenv",
  "include_dir",
+ "indoc 1.0.9",
+ "lazy_static",
  "minijinja",
  "once_cell",
  "regex",
@@ -2559,7 +2563,7 @@ dependencies = [
  "ignore",
  "image 0.24.9",
  "include_dir",
- "indoc",
+ "indoc 2.0.6",
  "keyring",
  "kill_tree",
  "lazy_static",
@@ -3263,6 +3267,12 @@ dependencies = [
  "web-time",
 ]
 
+[[package]]
+name = "indoc"
+version = "1.0.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306"
+
 [[package]]
 name = "indoc"
 version = "2.0.6"
diff --git a/crates/goose-llm/Cargo.toml b/crates/goose-llm/Cargo.toml
index cdf2d845..6bb99c42 100644
--- a/crates/goose-llm/Cargo.toml
+++ b/crates/goose-llm/Cargo.toml
@@ -35,10 +35,14 @@ base64 = "0.21"
 regex = "1.11.1"
 tracing = "0.1"
 smallvec = { version = "1.13", features = ["serde"] }
+indoc = "1.0"
 
 [dev-dependencies]
 criterion = "0.5"
 tempfile = "3.15.0"
+dotenv = "0.15"
+lazy_static = "1.5"
+ctor = "0.2.7"
 
 [[example]]
diff --git a/crates/goose-llm/src/extractors/mod.rs b/crates/goose-llm/src/extractors/mod.rs
new file mode 100644
index 00000000..6b5e3be5
--- /dev/null
+++ b/crates/goose-llm/src/extractors/mod.rs
@@ -0,0 +1,5 @@
+mod session_name;
+mod tooltip;
+
+pub use session_name::generate_session_name;
+pub use tooltip::generate_tooltip;
diff --git a/crates/goose-llm/src/extractors/session_name.rs b/crates/goose-llm/src/extractors/session_name.rs
new file mode 100644
index 00000000..4358a7ec
--- /dev/null
+++ b/crates/goose-llm/src/extractors/session_name.rs
@@ -0,0 +1,107 @@
+use crate::message::Message;
+use crate::model::ModelConfig;
+use crate::providers::base::Provider;
+use crate::providers::databricks::DatabricksProvider;
+use crate::providers::errors::ProviderError;
+use crate::types::core::Role;
+use anyhow::Result;
+use indoc::indoc;
+use serde_json::{json, Value};
+
+const SESSION_NAME_EXAMPLES: &[&str] = &[
+    "Research Synthesis",
+    "Sentiment Analysis",
+    "Performance Report",
+    "Feedback Collector",
+    "Accessibility Check",
+    "Design Reminder",
+    "Project Reminder",
+    "Launch Checklist",
+    "Metrics Monitor",
+    "Incident Response",
+    "Deploy Cabinet App",
+    "Design Reminder Alert",
+    "Generate Monthly Expense Report",
+    "Automate Incident Response Workflow",
+    "Analyze Brand Sentiment Trends",
+    "Monitor Device Health Issues",
+    "Collect UI Feedback Summary",
+    "Schedule Project Deadline Reminders",
+];
+
+fn build_system_prompt() -> String {
+    let examples = SESSION_NAME_EXAMPLES
+        .iter()
+        .map(|e| format!("- {}", e))
+        .collect::<Vec<_>>()
+        .join("\n");
+
+    indoc! {r#"
+        You are an assistant that crafts a concise session title.
+        Given the first couple user messages in the conversation so far,
+        reply with only a short name (up to 4 words) that best describes
+        this session’s goal.
+
+        Examples:
+    "#}
+    .to_string()
+        + &examples
+}
+
+/// Generates a short (≤4 words) session name
+pub async fn generate_session_name(messages: &[Message]) -> Result<String, ProviderError> {
+    // Collect up to the first 3 user messages (truncated to 300 chars each)
+    let context: Vec<String> = messages
+        .iter()
+        .filter(|m| m.role == Role::User)
+        .take(3)
+        .map(|m| {
+            let text = m.content.concat_text_str();
+            if text.len() > 300 {
+                text.chars().take(300).collect()
+            } else {
+                text
+            }
+        })
+        .collect();
+
+    if context.is_empty() {
+        return Err(ProviderError::ExecutionError(
+            "No user messages found to generate a session name.".to_string(),
+        ));
+    }
+
+    let system_prompt = build_system_prompt();
+    let user_msg_text = format!("Here are the user messages:\n{}", context.join("\n"));
+
+    // Instantiate DatabricksProvider with goose-gpt-4-1
+    let model_cfg = ModelConfig::new("goose-gpt-4-1".to_string()).with_temperature(Some(0.0));
+    let provider = DatabricksProvider::from_env(model_cfg)?;
+
+    // Use `extract` with a simple string schema
+    let schema = json!({
+        "type": "object",
+        "properties": {
+            "name": { "type": "string" }
+        },
+        "required": ["name"],
+        "additionalProperties": false
+    });
+    let user_msg = Message::user().with_text(&user_msg_text);
+    let resp = provider
+        .extract(&system_prompt, &[user_msg], &schema)
+        .await?;
+
+    let obj = resp
+        .data
+        .as_object()
+        .ok_or_else(|| ProviderError::ResponseParseError("Expected object".into()))?;
+
+    let name = obj
+        .get("name")
+        .and_then(Value::as_str)
+        .ok_or_else(|| ProviderError::ResponseParseError("Missing or non-string name".into()))?
+        .to_string();
+
+    Ok(name)
+}
diff --git a/crates/goose-llm/src/extractors/tooltip.rs b/crates/goose-llm/src/extractors/tooltip.rs
new file mode 100644
index 00000000..164823c6
--- /dev/null
+++ b/crates/goose-llm/src/extractors/tooltip.rs
@@ -0,0 +1,165 @@
+use crate::message::{Message, MessageContent};
+use crate::model::ModelConfig;
+use crate::providers::base::Provider;
+use crate::providers::databricks::DatabricksProvider;
+use crate::providers::errors::ProviderError;
+use crate::types::core::{Content, Role};
+use anyhow::Result;
+use indoc::indoc;
+use serde_json::{json, Value};
+
+const TOOLTIP_EXAMPLES: &[&str] = &[
+    "analyzing KPIs",
+    "detecting anomalies",
+    "building artifacts in Buildkite",
+    "categorizing issues",
+    "checking dependencies",
+    "collecting feedback",
+    "deploying changes in AWS",
+    "drafting report in Google Docs",
+    "extracting action items",
+    "generating insights",
+    "logging issues",
+    "monitoring tickets in Zendesk",
+    "notifying design team",
+    "running integration tests",
+    "scanning threads in Figma",
+    "sending reminders in Gmail",
+    "sending surveys",
+    "sharing with stakeholders",
+    "summarizing findings",
+    "transcribing meeting",
+    "tracking resolution",
+    "updating status in Linear",
+];
+
+fn build_system_prompt() -> String {
+    let examples = TOOLTIP_EXAMPLES
+        .iter()
+        .map(|e| format!("- {}", e))
+        .collect::<Vec<_>>()
+        .join("\n");
+
+    indoc! {r#"
+        You are an assistant that summarizes the recent conversation into a tooltip.
+        Given the last two messages, reply with only a short tooltip (up to 4 words)
+        describing what is happening now.
+
+        Examples:
+    "#}
+    .to_string()
+        + &examples
+}
+
+/// Generates a tooltip summarizing the last two messages in the session,
+/// including any tool calls or results.
+pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderError> {
+    // Need at least two messages to summarize
+    if messages.len() < 2 {
+        return Err(ProviderError::ExecutionError(
+            "Need at least two messages to generate a tooltip".to_string(),
+        ));
+    }
+
+    // Helper to render a single message's content
+    fn render_message(m: &Message) -> String {
+        let mut parts = Vec::new();
+        for content in m.content.iter() {
+            match content {
+                MessageContent::Text(text_block) => {
+                    let txt = text_block.text.trim();
+                    if !txt.is_empty() {
+                        parts.push(txt.to_string());
+                    }
+                }
+                MessageContent::ToolRequest(req) => {
+                    if let Ok(tool_call) = &req.tool_call {
+                        parts.push(format!(
+                            "called tool '{}' with args {}",
+                            tool_call.name, tool_call.arguments
+                        ));
+                    } else if let Err(e) = &req.tool_call {
+                        parts.push(format!("tool request error: {}", e));
+                    }
+                }
+                MessageContent::ToolResponse(resp) => match &resp.tool_result {
+                    Ok(contents) => {
+                        let results: Vec<String> = contents
+                            .iter()
+                            .map(|c| match c {
+                                Content::Text(t) => t.text.clone(),
+                                Content::Image(_) => "[image]".to_string(),
+                            })
+                            .collect();
+                        parts.push(format!("tool responded with: {}", results.join(" ")));
+                    }
+                    Err(e) => {
+                        parts.push(format!("tool error: {}", e));
+                    }
+                },
+                _ => {} // ignore other variants
+            }
+        }
+
+        let role = match m.role {
+            Role::User => "User",
+            Role::Assistant => "Assistant",
+        };
+
+        format!("{}: {}", role, parts.join("; "))
+    }
+
+    // Take the last two messages (in correct chronological order)
+    let rendered: Vec<String> = messages
+        .iter()
+        .rev()
+        .take(2)
+        .map(render_message)
+        .collect::<Vec<_>>()
+        .into_iter()
+        .rev()
+        .collect();
+
+    let system_prompt = build_system_prompt();
+
+    let user_msg_text = format!(
+        "Here are the last two messages:\n{}\n\nTooltip:",
+        rendered.join("\n")
+    );
+
+    // Instantiate the provider
+    let model_cfg = ModelConfig::new("goose-gpt-4-1".to_string()).with_temperature(Some(0.0));
+    let provider = DatabricksProvider::from_env(model_cfg)?;
+
+    // Schema wrapping our tooltip string
+    let schema = json!({
+        "type": "object",
+        "properties": {
+            "tooltip": { "type": "string" }
+        },
+        "required": ["tooltip"],
+        "additionalProperties": false
+    });
+
+    // Call extract
+    let user_msg = Message::user().with_text(&user_msg_text);
+    let resp = provider
+        .extract(&system_prompt, &[user_msg], &schema)
+        .await?;
+
+    // Pull out the tooltip field
+    let obj = resp
+        .data
+        .as_object()
+        .ok_or_else(|| ProviderError::ResponseParseError("Expected JSON object".into()))?;
+
+    let tooltip = obj
+        .get("tooltip")
+        .and_then(Value::as_str)
+        .ok_or_else(|| {
+            ProviderError::ResponseParseError("Missing or non-string `tooltip` field".into())
+        })?
+        .to_string();
+
+    Ok(tooltip)
+}
diff --git a/crates/goose-llm/src/lib.rs b/crates/goose-llm/src/lib.rs
index 3798cc9d..5dc8dac6 100644
--- a/crates/goose-llm/src/lib.rs
+++ b/crates/goose-llm/src/lib.rs
@@ -1,8 +1,9 @@
 mod completion;
-mod message;
+pub mod extractors;
+pub mod message;
 mod model;
 mod prompt_template;
-mod providers;
+pub mod providers;
 pub mod types;
 
 pub use completion::completion;
diff --git a/crates/goose-llm/src/providers/base.rs b/crates/goose-llm/src/providers/base.rs
index eb580490..c03f5f04 100644
--- a/crates/goose-llm/src/providers/base.rs
+++ b/crates/goose-llm/src/providers/base.rs
@@ -43,6 +43,23 @@ impl ProviderCompleteResponse {
     }
 }
 
+/// Response from a structured-extraction call
+#[derive(Debug, Clone)]
+pub struct ProviderExtractResponse {
+    /// The extracted JSON object
+    pub data: serde_json::Value,
+    /// Which model produced it
+    pub model: String,
+    /// Token usage stats
+    pub usage: Usage,
+}
+
+impl ProviderExtractResponse {
+    pub fn new(data: serde_json::Value, model: String, usage: Usage) -> Self {
+        Self { data, model, usage }
+    }
+}
+
 /// Base trait for AI providers (OpenAI, Anthropic, etc)
 #[async_trait]
 pub trait Provider: Send + Sync {
@@ -65,6 +82,27 @@ pub trait Provider: Send + Sync {
         messages: &[Message],
         tools: &[Tool],
     ) -> Result<ProviderCompleteResponse, ProviderError>;
+
+    /// Structured extraction: always JSON-Schema
+    ///
+    /// # Arguments
+    /// * `system` – system prompt guiding the extraction task
+    /// * `messages` – conversation history
+    /// * `schema` – a JSON-Schema for the expected output.
+    ///   Will set strict=true for OpenAI & Databricks.
+    ///
+    /// # Returns
+    /// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`.
+    ///
+    /// # Errors
+    /// * `ProviderError::ContextLengthExceeded` if the prompt is too large
+    /// * other `ProviderError` variants for API/network failures
+    async fn extract(
+        &self,
+        system: &str,
+        messages: &[Message],
+        schema: &serde_json::Value,
+    ) -> Result<ProviderExtractResponse, ProviderError>;
 }
 
 #[cfg(test)]
diff --git a/crates/goose-llm/src/providers/databricks.rs b/crates/goose-llm/src/providers/databricks.rs
index 013b8a89..2a20c7fa 100644
--- a/crates/goose-llm/src/providers/databricks.rs
+++ b/crates/goose-llm/src/providers/databricks.rs
@@ -4,7 +4,7 @@ use anyhow::Result;
 use async_trait::async_trait;
 use reqwest::{Client, StatusCode};
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{json, Value};
 use url::Url;
 
 use super::{
@@ -15,17 +15,15 @@ use super::{
 use crate::{
     message::Message,
     model::ModelConfig,
-    providers::{Provider, ProviderCompleteResponse, Usage},
+    providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
     types::core::Tool,
 };
 
-pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-meta-llama-3-3-70b-instruct";
+pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet";
 // Databricks can passthrough to a wide range of models, we only provide the default
 pub const _DATABRICKS_KNOWN_MODELS: &[&str] = &[
     "databricks-meta-llama-3-3-70b-instruct",
-    "databricks-meta-llama-3-1-405b-instruct",
-    "databricks-dbrx-instruct",
-    "databricks-mixtral-8x7b-instruct",
+    "databricks-claude-3-7-sonnet",
 ];
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -215,4 +213,63 @@ impl Provider for DatabricksProvider {
 
         Ok(ProviderCompleteResponse::new(message, model, usage))
     }
+
+    async fn extract(
+        &self,
+        system: &str,
+        messages: &[Message],
+        schema: &Value,
+    ) -> Result<ProviderExtractResponse, ProviderError> {
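+        // Both extract() implementations (Databricks here, OpenAI below) speak
+        // the same structured-output wire format: the request carries a
+        // `response_format` of type `json_schema` with strict=true, and the
+        // model's reply arrives in `message.content`, either as a JSON string
+        // or as an already-parsed JSON object/array.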
+        // 1. Build base payload (no tools)
+        let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
+
+        // 2. Inject strict JSON-Schema wrapper
+        payload
+            .as_object_mut()
+            .expect("payload must be an object")
+            .insert(
+                "response_format".to_string(),
+                json!({
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "schema": schema,
+                        "strict": true
+                    }
+                }),
+            );
+
+        // 3. Call Databricks
+        let response = self.post(payload.clone()).await?;
+
+        // 4. Extract the assistant’s `content` and parse it into JSON
+        let msg = &response["choices"][0]["message"];
+        let raw = msg.get("content").cloned().ok_or_else(|| {
+            ProviderError::ResponseParseError("Missing content in extract response".into())
+        })?;
+        let data = match raw {
+            Value::String(s) => serde_json::from_str(&s)
+                .map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
+            Value::Object(_) | Value::Array(_) => raw,
+            other => {
+                return Err(ProviderError::ResponseParseError(format!(
+                    "Unexpected content type: {:?}",
+                    other
+                )))
+            }
+        };
+
+        // 5. Gather usage & model info
+        let usage = match get_usage(&response) {
+            Ok(u) => u,
+            Err(ProviderError::UsageError(e)) => {
+                tracing::debug!("Failed to get usage in extract: {}", e);
+                Usage::default()
+            }
+            Err(e) => return Err(e),
+        };
+        let model = get_model(&response);
+
+        Ok(ProviderExtractResponse::new(data, model, usage))
+    }
 }
diff --git a/crates/goose-llm/src/providers/errors.rs b/crates/goose-llm/src/providers/errors.rs
index 4b31d286..c44ff140 100644
--- a/crates/goose-llm/src/providers/errors.rs
+++ b/crates/goose-llm/src/providers/errors.rs
@@ -22,6 +22,9 @@ pub enum ProviderError {
 
     #[error("Usage data error: {0}")]
     UsageError(String),
+
+    #[error("Invalid response: {0}")]
+    ResponseParseError(String),
 }
 
 impl From<anyhow::Error> for ProviderError {
diff --git a/crates/goose-llm/src/providers/formats/databricks.rs b/crates/goose-llm/src/providers/formats/databricks.rs
index 8aece4d2..209bc175 100644
--- a/crates/goose-llm/src/providers/formats/databricks.rs
+++ b/crates/goose-llm/src/providers/formats/databricks.rs
@@ -200,6 +200,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
 }
 
 /// Convert internal Tool format to OpenAI's API tool specification
+/// https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/api-reference#functionobject
 pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
     let mut tool_names = std::collections::HashSet::new();
     let mut result = Vec::new();
diff --git a/crates/goose-llm/src/providers/mod.rs b/crates/goose-llm/src/providers/mod.rs
index c952d2d3..c8089380 100644
--- a/crates/goose-llm/src/providers/mod.rs
+++ b/crates/goose-llm/src/providers/mod.rs
@@ -6,5 +6,5 @@ pub mod formats;
 pub mod openai;
 pub mod utils;
 
-pub use base::{Provider, ProviderCompleteResponse, Usage};
+pub use base::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage};
 pub use factory::create;
diff --git a/crates/goose-llm/src/providers/openai.rs b/crates/goose-llm/src/providers/openai.rs
index 450dfc46..b08c094d 100644
--- a/crates/goose-llm/src/providers/openai.rs
+++ b/crates/goose-llm/src/providers/openai.rs
@@ -3,7 +3,7 @@ use std::{collections::HashMap, time::Duration};
 use anyhow::Result;
 use async_trait::async_trait;
 use reqwest::Client;
-use serde_json::Value;
+use serde_json::{json, Value};
 
 use super::{
     errors::ProviderError,
@@ -13,20 +13,12 @@ use super::{
 use crate::{
     message::Message,
     model::ModelConfig,
-    providers::{Provider, ProviderCompleteResponse, Usage},
+    providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
     types::core::Tool,
 };
 
 pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
-pub const _OPEN_AI_KNOWN_MODELS: &[&str] = &[
-    "gpt-4o",
-    "gpt-4o-mini",
-    "gpt-4-turbo",
-    "gpt-3.5-turbo",
-    "o1",
-    "o3",
-    "o4-mini",
-];
+pub const _OPEN_AI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4.1", "o1", "o3", "o4-mini"];
 
 #[derive(Debug)]
 pub struct OpenAiProvider {
@@ -146,6 +138,65 @@ impl Provider for OpenAiProvider {
         emit_debug_trace(&self.model, &payload, &response, &usage);
         Ok(ProviderCompleteResponse::new(message, model, usage))
     }
+
+    async fn extract(
+        &self,
+        system: &str,
+        messages: &[Message],
+        schema: &Value,
+    ) -> Result<ProviderExtractResponse, ProviderError> {
+        // 1. Build base payload (no tools)
+        let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
+
+        // 2. Inject strict JSON-Schema wrapper
+        payload
+            .as_object_mut()
+            .expect("payload must be an object")
+            .insert(
+                "response_format".to_string(),
+                json!({
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "extraction",
+                        "schema": schema,
+                        "strict": true
+                    }
+                }),
+            );
+
+        // 3. Call OpenAI
+        let response = self.post(payload.clone()).await?;
+
+        // 4. Extract the assistant’s `content` and parse it into JSON
+        let msg = &response["choices"][0]["message"];
+        let raw = msg.get("content").cloned().ok_or_else(|| {
+            ProviderError::ResponseParseError("Missing content in extract response".into())
+        })?;
+        let data = match raw {
+            Value::String(s) => serde_json::from_str(&s)
+                .map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
+            Value::Object(_) | Value::Array(_) => raw,
+            other => {
+                return Err(ProviderError::ResponseParseError(format!(
+                    "Unexpected content type: {:?}",
+                    other
+                )))
+            }
+        };
+
+        // 5. Gather usage & model info
+        let usage = match get_usage(&response) {
+            Ok(u) => u,
+            Err(ProviderError::UsageError(e)) => {
+                tracing::debug!("Failed to get usage in extract: {}", e);
+                Usage::default()
+            }
+            Err(e) => return Err(e),
+        };
+        let model = get_model(&response);
+
+        Ok(ProviderExtractResponse::new(data, model, usage))
+    }
 }
 
 fn parse_custom_headers(s: String) -> HashMap<String, String> {
diff --git a/crates/goose-llm/tests/extract_session_name.rs b/crates/goose-llm/tests/extract_session_name.rs
new file mode 100644
index 00000000..d8cd6e9b
--- /dev/null
+++ b/crates/goose-llm/tests/extract_session_name.rs
@@ -0,0 +1,56 @@
+use anyhow::Result;
+use dotenv::dotenv;
+use goose_llm::extractors::generate_session_name;
+use goose_llm::message::Message;
+use goose_llm::providers::errors::ProviderError;
+
+#[tokio::test]
+async fn test_generate_session_name_success() -> Result<(), ProviderError> {
+    // Load .env for Databricks credentials
+    dotenv().ok();
+    let has_creds =
+        std::env::var("DATABRICKS_HOST").is_ok() && std::env::var("DATABRICKS_TOKEN").is_ok();
+    if !has_creds {
+        println!("Skipping generate_session_name test – Databricks creds not set");
+        return Ok(());
+    }
+
+    // Build a few messages with at least two user messages
+    let messages = vec![
+        Message::user().with_text("Hello, how are you?"),
+        Message::assistant().with_text("I’m fine, thanks!"),
+        Message::user().with_text("What’s the weather in New York tomorrow?"),
+    ];
+
+    let name = generate_session_name(&messages).await?;
+    println!("Generated session name: {:?}", name);
+
+    // Should be non-empty and at most 4 words
+    let name = name.trim();
+    assert!(!name.is_empty(), "Name must not be empty");
+    let word_count = name.split_whitespace().count();
+    assert!(
+        word_count <= 4,
+        "Name must be 4 words or less, got {}: {}",
+        word_count,
+        name
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_generate_session_name_no_user() {
+    // No user messages → expect ExecutionError
+    let messages = vec![
+        Message::assistant().with_text("System starting…"),
+        Message::assistant().with_text("All systems go."),
+    ];
+
+    let err = generate_session_name(&messages).await;
+    assert!(
+        matches!(err, Err(ProviderError::ExecutionError(_))),
+        "Expected ExecutionError when there are no user messages, got: {:?}",
+        err
+    );
+}
diff --git a/crates/goose-llm/tests/extract_tooltip.rs b/crates/goose-llm/tests/extract_tooltip.rs
new file mode 100644
index 00000000..8ad1c3c5
--- /dev/null
+++ b/crates/goose-llm/tests/extract_tooltip.rs
@@ -0,0 +1,69 @@
+use anyhow::Result;
+use dotenv::dotenv;
+use goose_llm::extractors::generate_tooltip;
+use goose_llm::message::{Message, MessageContent, ToolRequest};
+use goose_llm::providers::errors::ProviderError;
+use goose_llm::types::core::{Content, ToolCall};
+use serde_json::json;
+
+#[tokio::test]
+async fn test_generate_tooltip_simple() -> Result<(), ProviderError> {
+    // Skip if no Databricks creds
+    dotenv().ok();
+    if std::env::var("DATABRICKS_HOST").is_err() || std::env::var("DATABRICKS_TOKEN").is_err() {
+        println!("Skipping simple tooltip test – Databricks creds not set");
+        return Ok(());
+    }
+
+    // Two plain-text messages
+    let messages = vec![
+        Message::user().with_text("Hello, how are you?"),
+        Message::assistant().with_text("I'm fine, thanks! How can I help?"),
+    ];
+
+    let tooltip = generate_tooltip(&messages).await?;
+    println!("Generated tooltip: {:?}", tooltip);
+
+    assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
+    assert!(
+        tooltip.len() < 100,
+        "Tooltip should be reasonably short (<100 chars)"
+    );
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_generate_tooltip_with_tools() -> Result<(), ProviderError> {
+    // Skip if no Databricks creds
+    dotenv().ok();
+    if std::env::var("DATABRICKS_HOST").is_err() || std::env::var("DATABRICKS_TOKEN").is_err() {
+        println!("Skipping tool-based tooltip test – Databricks creds not set");
+        return Ok(());
+    }
+
+    // 1) Assistant message with a tool request
+    let mut tool_req_msg = Message::assistant();
+    let req = ToolRequest {
+        id: "1".to_string(),
+        tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))),
+    };
+    tool_req_msg.content.push(MessageContent::ToolRequest(req));
+
+    // 2) User message with the tool response
+    let tool_resp_msg = Message::user().with_tool_response(
+        "1",
+        Ok(vec![Content::text("The current time is 12:00 UTC")]),
+    );
+
+    let messages = vec![tool_req_msg, tool_resp_msg];
+
+    let tooltip = generate_tooltip(&messages).await?;
+    println!("Generated tooltip (tools): {:?}", tooltip);
+
+    assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
+    assert!(
+        tooltip.len() < 100,
+        "Tooltip should be reasonably short (<100 chars)"
+    );
+    Ok(())
+}
diff --git a/crates/goose-llm/tests/providers_complete.rs b/crates/goose-llm/tests/providers_complete.rs
new file mode 100644
index 00000000..43212bfd
--- /dev/null
+++ b/crates/goose-llm/tests/providers_complete.rs
@@ -0,0 +1,379 @@
+use anyhow::Result;
+use dotenv::dotenv;
+use goose_llm::message::{Message, MessageContent};
+use goose_llm::providers::base::Provider;
+use goose_llm::providers::errors::ProviderError;
+use goose_llm::providers::{databricks, openai};
+use goose_llm::types::core::{Content, Tool};
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::sync::Mutex;
+
+#[derive(Debug, Clone, Copy)]
+enum TestStatus {
+    Passed,
+    Skipped,
+    Failed,
+}
+
+impl std::fmt::Display for TestStatus {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            TestStatus::Passed => write!(f, "✅"),
+            TestStatus::Skipped => write!(f, "⏭️"),
+            TestStatus::Failed => write!(f, "❌"),
+        }
+    }
+}
+
+struct TestReport {
+    results: Mutex<HashMap<String, TestStatus>>,
+}
+
+impl TestReport {
+    fn new() -> Arc<Self> {
+        Arc::new(Self {
+            results: Mutex::new(HashMap::new()),
+        })
+    }
+
+    fn record_status(&self, provider: &str, status: TestStatus) {
+        let mut results = self.results.lock().unwrap();
+        results.insert(provider.to_string(), status);
+    }
+
+    fn record_pass(&self, provider: &str) {
+        self.record_status(provider, TestStatus::Passed);
+    }
+
+    fn record_skip(&self, provider: &str) {
+        self.record_status(provider, TestStatus::Skipped);
+    }
+
+    fn record_fail(&self, provider: &str) {
+        self.record_status(provider, TestStatus::Failed);
+    }
+
+    fn print_summary(&self) {
+        println!("\n============== Providers ==============");
+        let results = self.results.lock().unwrap();
+        let mut providers: Vec<_> = results.iter().collect();
+        providers.sort_by(|a, b| a.0.cmp(b.0));
+
+        for (provider, status) in providers {
+            println!("{} {}", status, provider);
+        }
+        println!("=======================================\n");
+    }
+}
+
+lazy_static::lazy_static! {
+    static ref TEST_REPORT: Arc<TestReport> = TestReport::new();
+    static ref ENV_LOCK: Mutex<()> = Mutex::new(());
+}
+
+/// Generic test harness for any Provider implementation
+struct ProviderTester {
+    provider: Arc<dyn Provider>,
+    name: String,
+}
+
+impl ProviderTester {
+    fn new<T: Provider + Send + Sync + 'static>(provider: T, name: String) -> Self {
+        Self {
+            provider: Arc::new(provider),
+            name,
+        }
+    }
+
+    async fn test_basic_response(&self) -> Result<()> {
+        let message = Message::user().with_text("Just say hello!");
+
+        let response = self
+            .provider
+            .complete("You are a helpful assistant.", &[message], &[])
+            .await?;
+
+        // For a basic response, we expect a single text response
+        assert_eq!(
+            response.message.content.len(),
+            1,
+            "Expected single content item in response"
+        );
+
+        // Verify we got a text response
+        assert!(
+            matches!(response.message.content[0], MessageContent::Text(_)),
+            "Expected text response"
+        );
+
+        Ok(())
+    }
+
+    async fn test_tool_usage(&self) -> Result<()> {
+        let weather_tool = Tool::new(
+            "get_weather",
+            "Get the weather for a location",
+            serde_json::json!({
+                "type": "object",
+                "required": ["location"],
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA"
+                    }
+                }
+            }),
+        );
+
+        let message = Message::user().with_text("What's the weather like in San Francisco?");
+
+        let response1 = self
+            .provider
+            .complete(
+                "You are a helpful weather assistant.",
+                &[message.clone()],
+                &[weather_tool.clone()],
+            )
+            .await?;
+
+        println!("=== {}::response1 ===", self.name);
+        dbg!(&response1);
+        println!("===================");
+
+        // Verify we got a tool request
+        assert!(
+            response1
+                .message
+                .content
+                .iter()
+                .any(|content| matches!(content, MessageContent::ToolRequest(_))),
+            "Expected tool request in response"
+        );
+
+        let id = &response1
+            .message
+            .content
+            .iter()
+            .filter_map(|message| message.as_tool_request())
+            .last()
+            .expect("got tool request")
+            .id;
+
+        let weather = Message::user().with_tool_response(
+            id,
+            Ok(vec![Content::text(
+                "
+                    50°F°C
+                    Precipitation: 0%
+                    Humidity: 84%
+                    Wind: 2 mph
+                    Weather
+                    Saturday 9:00 PM
+                    Clear",
+            )]),
+        );
+
+        // Verify we construct a valid payload including the request/response pair for the next inference
+        let response2 = self
+            .provider
+            .complete(
+                "You are a helpful weather assistant.",
+                &[message, response1.message, weather],
+                &[weather_tool],
+            )
+            .await?;
+
+        println!("=== {}::response2 ===", self.name);
+        dbg!(&response2);
+        println!("===================");
+
+        assert!(
+            response2
+                .message
+                .content
+                .iter()
+                .any(|content| matches!(content, MessageContent::Text(_))),
+            "Expected text for final response"
+        );
+
+        Ok(())
+    }
+
+    async fn test_context_length_exceeded_error(&self) -> Result<()> {
+        // Google Gemini has a really long context window
+        let large_message_content = if self.name.to_lowercase() == "google" {
+            "hello ".repeat(1_300_000)
+        } else {
+            "hello ".repeat(300_000)
+        };
+
+        let messages = vec![
+            Message::user().with_text("hi there. what is 2 + 2?"),
+            Message::assistant().with_text("hey! I think it's 4."),
+            Message::user().with_text(&large_message_content),
+            Message::assistant().with_text("heyy!!"),
+            // Messages before this mark should be truncated
+            Message::user().with_text("what's the meaning of life?"),
+            Message::assistant().with_text("the meaning of life is 42"),
+            Message::user().with_text(
+                "did I ask you what's 2+2 in this message history? just respond with 'yes' or 'no'",
just respond with 'yes' or 'no'", + ), + ]; + + // Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded + let result = self + .provider + .complete("You are a helpful assistant.", &messages, &[]) + .await; + + // Print some debug info + println!("=== {}::context_length_exceeded_error ===", self.name); + dbg!(&result); + println!("==================="); + + // Ollama truncates by default even when the context window is exceeded + if self.name.to_lowercase() == "ollama" { + assert!( + result.is_ok(), + "Expected to succeed because of default truncation" + ); + return Ok(()); + } + + assert!( + result.is_err(), + "Expected error when context window is exceeded" + ); + assert!( + matches!(result.unwrap_err(), ProviderError::ContextLengthExceeded(_)), + "Expected error to be ContextLengthExceeded" + ); + + Ok(()) + } + + /// Run all provider tests + async fn run_test_suite(&self) -> Result<()> { + self.test_basic_response().await?; + self.test_tool_usage().await?; + self.test_context_length_exceeded_error().await?; + Ok(()) + } +} + +fn load_env() { + if let Ok(path) = dotenv() { + println!("Loaded environment from {:?}", path); + } +} + +/// Helper function to run a provider test with proper error handling and reporting +async fn test_provider( + name: &str, + required_vars: &[&str], + env_modifications: Option>>, + provider_fn: F, +) -> Result<()> +where + F: FnOnce() -> T, + T: Provider + Send + Sync + 'static, +{ + // We start off as failed, so that if the process panics it is seen as a failure + TEST_REPORT.record_fail(name); + + // Take exclusive access to environment modifications + let lock = ENV_LOCK.lock().unwrap(); + + load_env(); + + // Save current environment state for required vars and modified vars + let mut original_env = HashMap::new(); + for &var in required_vars { + if let Ok(val) = std::env::var(var) { + original_env.insert(var, val); + } + } + if let Some(mods) = &env_modifications { + for &var in mods.keys() { + if let Ok(val) = std::env::var(var) { + original_env.insert(var, val); + } + } + } + + // Apply any environment modifications + if let Some(mods) = &env_modifications { + for (&var, value) in mods.iter() { + match value { + Some(val) => std::env::set_var(var, val), + None => std::env::remove_var(var), + } + } + } + + // Setup the provider + let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err()); + if missing_vars { + println!("Skipping {} tests - credentials not configured", name); + TEST_REPORT.record_skip(name); + return Ok(()); + } + + let provider = provider_fn(); + + // Restore original environment + for (&var, value) in original_env.iter() { + std::env::set_var(var, value); + } + if let Some(mods) = env_modifications { + for &var in mods.keys() { + if !original_env.contains_key(var) { + std::env::remove_var(var); + } + } + } + + std::mem::drop(lock); + + let tester = ProviderTester::new(provider, name.to_string()); + match tester.run_test_suite().await { + Ok(_) => { + TEST_REPORT.record_pass(name); + Ok(()) + } + Err(e) => { + println!("{} test failed: {}", name, e); + TEST_REPORT.record_fail(name); + Err(e) + } + } +} + +#[tokio::test] +async fn openai_complete() -> Result<()> { + test_provider( + "OpenAI", + &["OPENAI_API_KEY"], + None, + openai::OpenAiProvider::default, + ) + .await +} + +#[tokio::test] +async fn databricks_complete() -> Result<()> { + test_provider( + "Databricks", + &["DATABRICKS_HOST", "DATABRICKS_TOKEN"], + None, + databricks::DatabricksProvider::default, + ) + .await +} + 
+// Print the final test report
+#[ctor::dtor]
+fn print_test_report() {
+    TEST_REPORT.print_summary();
+}
diff --git a/crates/goose-llm/tests/providers_extract.rs b/crates/goose-llm/tests/providers_extract.rs
new file mode 100644
index 00000000..544040f7
--- /dev/null
+++ b/crates/goose-llm/tests/providers_extract.rs
@@ -0,0 +1,195 @@
+// tests/providers_extract.rs
+
+use anyhow::Result;
+use dotenv::dotenv;
+use goose_llm::message::Message;
+use goose_llm::providers::base::Provider;
+use goose_llm::providers::{databricks::DatabricksProvider, openai::OpenAiProvider};
+use goose_llm::ModelConfig;
+use serde_json::{json, Value};
+use std::sync::Arc;
+
+#[derive(Debug, PartialEq, Copy, Clone)]
+enum ProviderType {
+    OpenAi,
+    Databricks,
+}
+
+impl ProviderType {
+    fn required_env(&self) -> &'static [&'static str] {
+        match self {
+            ProviderType::OpenAi => &["OPENAI_API_KEY"],
+            ProviderType::Databricks => &["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
+        }
+    }
+
+    fn create_provider(&self, cfg: ModelConfig) -> Result<Arc<dyn Provider>> {
+        Ok(match self {
+            ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(cfg)?),
+            ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(cfg)?),
+        })
+    }
+}
+
+fn check_required_env_vars(required: &[&str]) -> bool {
+    let missing: Vec<_> = required
+        .iter()
+        .filter(|&&v| std::env::var(v).is_err())
+        .cloned()
+        .collect();
+    if !missing.is_empty() {
+        println!("Skipping test; missing env vars: {:?}", missing);
+        false
+    } else {
+        true
+    }
+}
+
+// --- Shared inputs for "paper" task ---
+const PAPER_SYSTEM: &str =
+    "You are an expert at structured data extraction. Extract the metadata of a research paper into JSON.";
+const PAPER_TEXT: &str =
+    "Application of Quantum Algorithms in Interstellar Navigation: A New Frontier \
+     by Dr. Stella Voyager, Dr. Nova Star, Dr. Lyra Hunter. Abstract: This paper \
+     investigates the utilization of quantum algorithms to improve interstellar \
+     navigation systems. Keywords: Quantum algorithms, interstellar navigation, \
+     space-time anomalies, quantum superposition, quantum entanglement, space travel.";
+
+fn paper_schema() -> Value {
+    json!({
+        "type": "object",
+        "properties": {
+            "title": { "type": "string" },
+            "authors": { "type": "array", "items": { "type": "string" } },
+            "abstract": { "type": "string" },
+            "keywords": { "type": "array", "items": { "type": "string" } }
+        },
+        "required": ["title","authors","abstract","keywords"],
+        "additionalProperties": false
+    })
+}
+
+// --- Shared inputs for "UI" task ---
+const UI_SYSTEM: &str = "You are a UI generator AI. Convert the user input into a JSON-driven UI.";
Convert the user input into a JSON-driven UI."; +const UI_TEXT: &str = "Make a User Profile Form"; + +fn ui_schema() -> Value { + json!({ + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["div","button","header","section","field","form"] + }, + "label": { "type": "string" }, + "children": { + "type": "array", + "items": { "$ref": "#" } + }, + "attributes": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "value": { "type": "string" } + }, + "required": ["name","value"], + "additionalProperties": false + } + } + }, + "required": ["type","label","children","attributes"], + "additionalProperties": false + }) +} + +/// Generic runner for any extract task +async fn run_extract_test( + provider_type: ProviderType, + model: &str, + system: &'static str, + user_text: &'static str, + schema: Value, + validate: F, +) -> Result<()> +where + F: Fn(&Value) -> bool, +{ + dotenv().ok(); + if !check_required_env_vars(provider_type.required_env()) { + return Ok(()); + } + + let cfg = ModelConfig::new(model.to_string()).with_temperature(Some(0.0)); + let provider = provider_type.create_provider(cfg)?; + + let msg = Message::user().with_text(user_text); + let resp = provider.extract(system, &[msg], &schema).await?; + + println!("[{:?}] extract => {}", provider_type, resp.data); + + assert!( + validate(&resp.data), + "{:?} failed validation on {}", + provider_type, + resp.data + ); + Ok(()) +} + +/// Helper for the "paper" task +async fn run_extract_paper_test(provider: ProviderType, model: &str) -> Result<()> { + run_extract_test( + provider, + model, + PAPER_SYSTEM, + PAPER_TEXT, + paper_schema(), + |v| { + v.as_object() + .map(|o| { + ["title", "authors", "abstract", "keywords"] + .iter() + .all(|k| o.contains_key(*k)) + }) + .unwrap_or(false) + }, + ) + .await +} + +/// Helper for the "UI" task +async fn run_extract_ui_test(provider: ProviderType, model: &str) -> Result<()> { + run_extract_test(provider, model, UI_SYSTEM, UI_TEXT, ui_schema(), |v| { + v.as_object() + .and_then(|o| o.get("type").and_then(Value::as_str)) + == Some("form") + }) + .await +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn openai_extract_paper() -> Result<()> { + run_extract_paper_test(ProviderType::OpenAi, "gpt-4o").await + } + + #[tokio::test] + async fn openai_extract_ui() -> Result<()> { + run_extract_ui_test(ProviderType::OpenAi, "gpt-4o").await + } + + #[tokio::test] + async fn databricks_extract_paper() -> Result<()> { + run_extract_paper_test(ProviderType::Databricks, "goose-gpt-4-1").await + } + + #[tokio::test] + async fn databricks_extract_ui() -> Result<()> { + run_extract_ui_test(ProviderType::Databricks, "goose-gpt-4-1").await + } +}