mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 15:14:21 +01:00
[goose-llm] add generate tooltip & session name via extract method (#2467)
* extract method in provider to use for structured outputs * generate session name from msgs * generate tooltip from msgs * add provider tests
This commit is contained in:
14
Cargo.lock
generated
14
Cargo.lock
generated
@@ -2406,7 +2406,7 @@ dependencies = [
|
||||
"fs2",
|
||||
"futures",
|
||||
"include_dir",
|
||||
"indoc",
|
||||
"indoc 2.0.6",
|
||||
"jsonwebtoken",
|
||||
"keyring",
|
||||
"lazy_static",
|
||||
@@ -2527,7 +2527,11 @@ dependencies = [
|
||||
"base64 0.21.7",
|
||||
"chrono",
|
||||
"criterion",
|
||||
"ctor",
|
||||
"dotenv",
|
||||
"include_dir",
|
||||
"indoc 1.0.9",
|
||||
"lazy_static",
|
||||
"minijinja",
|
||||
"once_cell",
|
||||
"regex",
|
||||
@@ -2559,7 +2563,7 @@ dependencies = [
|
||||
"ignore",
|
||||
"image 0.24.9",
|
||||
"include_dir",
|
||||
"indoc",
|
||||
"indoc 2.0.6",
|
||||
"keyring",
|
||||
"kill_tree",
|
||||
"lazy_static",
|
||||
@@ -3263,6 +3267,12 @@ dependencies = [
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indoc"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306"
|
||||
|
||||
[[package]]
|
||||
name = "indoc"
|
||||
version = "2.0.6"
|
||||
|
||||
@@ -35,10 +35,14 @@ base64 = "0.21"
|
||||
regex = "1.11.1"
|
||||
tracing = "0.1"
|
||||
smallvec = { version = "1.13", features = ["serde"] }
|
||||
indoc = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5"
|
||||
tempfile = "3.15.0"
|
||||
dotenv = "0.15"
|
||||
lazy_static = "1.5"
|
||||
ctor = "0.2.7"
|
||||
|
||||
|
||||
[[example]]
|
||||
|
||||
5
crates/goose-llm/src/extractors/mod.rs
Normal file
5
crates/goose-llm/src/extractors/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod session_name;
|
||||
mod tooltip;
|
||||
|
||||
pub use session_name::generate_session_name;
|
||||
pub use tooltip::generate_tooltip;
|
||||
107
crates/goose-llm/src/extractors/session_name.rs
Normal file
107
crates/goose-llm/src/extractors/session_name.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use crate::message::Message;
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::providers::databricks::DatabricksProvider;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::types::core::Role;
|
||||
use anyhow::Result;
|
||||
use indoc::indoc;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
const SESSION_NAME_EXAMPLES: &[&str] = &[
|
||||
"Research Synthesis",
|
||||
"Sentiment Analysis",
|
||||
"Performance Report",
|
||||
"Feedback Collector",
|
||||
"Accessibility Check",
|
||||
"Design Reminder",
|
||||
"Project Reminder",
|
||||
"Launch Checklist",
|
||||
"Metrics Monitor",
|
||||
"Incident Response",
|
||||
"Deploy Cabinet App",
|
||||
"Design Reminder Alert",
|
||||
"Generate Monthly Expense Report",
|
||||
"Automate Incident Response Workflow",
|
||||
"Analyze Brand Sentiment Trends",
|
||||
"Monitor Device Health Issues",
|
||||
"Collect UI Feedback Summary",
|
||||
"Schedule Project Deadline Reminders",
|
||||
];
|
||||
|
||||
fn build_system_prompt() -> String {
|
||||
let examples = SESSION_NAME_EXAMPLES
|
||||
.iter()
|
||||
.map(|e| format!("- {}", e))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
indoc! {r#"
|
||||
You are an assistant that crafts a concise session title.
|
||||
Given the first couple user messages in the conversation so far,
|
||||
reply with only a short name (up to 4 words) that best describes
|
||||
this session’s goal.
|
||||
|
||||
Examples:
|
||||
"#}
|
||||
.to_string()
|
||||
+ &examples
|
||||
}
|
||||
|
||||
/// Generates a short (≤4 words) session name
|
||||
pub async fn generate_session_name(messages: &[Message]) -> Result<String, ProviderError> {
|
||||
// Collect up to the first 3 user messages (truncated to 300 chars each)
|
||||
let context: Vec<String> = messages
|
||||
.iter()
|
||||
.filter(|m| m.role == Role::User)
|
||||
.take(3)
|
||||
.map(|m| {
|
||||
let text = m.content.concat_text_str();
|
||||
if text.len() > 300 {
|
||||
text.chars().take(300).collect()
|
||||
} else {
|
||||
text
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if context.is_empty() {
|
||||
return Err(ProviderError::ExecutionError(
|
||||
"No user messages found to generate a session name.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let system_prompt = build_system_prompt();
|
||||
let user_msg_text = format!("Here are the user messages:\n{}", context.join("\n"));
|
||||
|
||||
// Instantiate DatabricksProvider with goose-gpt-4-1
|
||||
let model_cfg = ModelConfig::new("goose-gpt-4-1".to_string()).with_temperature(Some(0.0));
|
||||
let provider = DatabricksProvider::from_env(model_cfg)?;
|
||||
|
||||
// Use `extract` with a simple string schema
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
},
|
||||
"required": ["name"],
|
||||
"additionalProperties": false
|
||||
});
|
||||
let user_msg = Message::user().with_text(&user_msg_text);
|
||||
let resp = provider
|
||||
.extract(&system_prompt, &[user_msg], &schema)
|
||||
.await?;
|
||||
|
||||
let obj = resp
|
||||
.data
|
||||
.as_object()
|
||||
.ok_or_else(|| ProviderError::ResponseParseError("Expected object".into()))?;
|
||||
|
||||
let name = obj
|
||||
.get("name")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| ProviderError::ResponseParseError("Missing or non-string name".into()))?
|
||||
.to_string();
|
||||
|
||||
Ok(name)
|
||||
}
|
||||
165
crates/goose-llm/src/extractors/tooltip.rs
Normal file
165
crates/goose-llm/src/extractors/tooltip.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
use crate::message::{Message, MessageContent};
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::providers::databricks::DatabricksProvider;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::types::core::{Content, Role};
|
||||
use anyhow::Result;
|
||||
use indoc::indoc;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
const TOOLTIP_EXAMPLES: &[&str] = &[
|
||||
"analyzing KPIs",
|
||||
"detecting anomalies",
|
||||
"building artifacts in Buildkite",
|
||||
"categorizing issues",
|
||||
"checking dependencies",
|
||||
"collecting feedback",
|
||||
"deploying changes in AWS",
|
||||
"drafting report in Google Docs",
|
||||
"extracting action items",
|
||||
"generating insights",
|
||||
"logging issues",
|
||||
"monitoring tickets in Zendesk",
|
||||
"notifying design team",
|
||||
"running integration tests",
|
||||
"scanning threads in Figma",
|
||||
"sending reminders in Gmail",
|
||||
"sending surveys",
|
||||
"sharing with stakeholders",
|
||||
"summarizing findings",
|
||||
"transcribing meeting",
|
||||
"tracking resolution",
|
||||
"updating status in Linear",
|
||||
];
|
||||
|
||||
fn build_system_prompt() -> String {
|
||||
let examples = TOOLTIP_EXAMPLES
|
||||
.iter()
|
||||
.map(|e| format!("- {}", e))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
indoc! {r#"
|
||||
You are an assistant that summarizes the recent conversation into a tooltip.
|
||||
Given the last two messages, reply with only a short tooltip (up to 4 words)
|
||||
describing what is happening now.
|
||||
|
||||
Examples:
|
||||
"#}
|
||||
.to_string()
|
||||
+ &examples
|
||||
}
|
||||
|
||||
/// Generates a tooltip summarizing the last two messages in the session,
|
||||
/// including any tool calls or results.
|
||||
pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderError> {
|
||||
// Need at least two messages to summarize
|
||||
if messages.len() < 2 {
|
||||
return Err(ProviderError::ExecutionError(
|
||||
"Need at least two messages to generate a tooltip".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Helper to render a single message's content
|
||||
fn render_message(m: &Message) -> String {
|
||||
let mut parts = Vec::new();
|
||||
for content in m.content.iter() {
|
||||
match content {
|
||||
MessageContent::Text(text_block) => {
|
||||
let txt = text_block.text.trim();
|
||||
if !txt.is_empty() {
|
||||
parts.push(txt.to_string());
|
||||
}
|
||||
}
|
||||
MessageContent::ToolRequest(req) => {
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
parts.push(format!(
|
||||
"called tool '{}' with args {}",
|
||||
tool_call.name, tool_call.arguments
|
||||
));
|
||||
} else if let Err(e) = &req.tool_call {
|
||||
parts.push(format!("tool request error: {}", e));
|
||||
}
|
||||
}
|
||||
MessageContent::ToolResponse(resp) => match &resp.tool_result {
|
||||
Ok(contents) => {
|
||||
let results: Vec<String> = contents
|
||||
.iter()
|
||||
.map(|c| match c {
|
||||
Content::Text(t) => t.text.clone(),
|
||||
Content::Image(_) => "[image]".to_string(),
|
||||
})
|
||||
.collect();
|
||||
parts.push(format!("tool responded with: {}", results.join(" ")));
|
||||
}
|
||||
Err(e) => {
|
||||
parts.push(format!("tool error: {}", e));
|
||||
}
|
||||
},
|
||||
_ => {} // ignore other variants
|
||||
}
|
||||
}
|
||||
|
||||
let role = match m.role {
|
||||
Role::User => "User",
|
||||
Role::Assistant => "Assistant",
|
||||
};
|
||||
|
||||
format!("{}: {}", role, parts.join("; "))
|
||||
}
|
||||
|
||||
// Take the last two messages (in correct chronological order)
|
||||
let rendered: Vec<String> = messages
|
||||
.iter()
|
||||
.rev()
|
||||
.take(2)
|
||||
.map(render_message)
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.rev()
|
||||
.collect();
|
||||
|
||||
let system_prompt = build_system_prompt();
|
||||
|
||||
let user_msg_text = format!(
|
||||
"Here are the last two messages:\n{}\n\nTooltip:",
|
||||
rendered.join("\n")
|
||||
);
|
||||
|
||||
// Instantiate the provider
|
||||
let model_cfg = ModelConfig::new("goose-gpt-4-1".to_string()).with_temperature(Some(0.0));
|
||||
let provider = DatabricksProvider::from_env(model_cfg)?;
|
||||
|
||||
// Schema wrapping our tooltip string
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tooltip": { "type": "string" }
|
||||
},
|
||||
"required": ["tooltip"],
|
||||
"additionalProperties": false
|
||||
});
|
||||
|
||||
// Call extract
|
||||
let user_msg = Message::user().with_text(&user_msg_text);
|
||||
let resp = provider
|
||||
.extract(&system_prompt, &[user_msg], &schema)
|
||||
.await?;
|
||||
|
||||
// Pull out the tooltip field
|
||||
let obj = resp
|
||||
.data
|
||||
.as_object()
|
||||
.ok_or_else(|| ProviderError::ResponseParseError("Expected JSON object".into()))?;
|
||||
|
||||
let tooltip = obj
|
||||
.get("tooltip")
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| {
|
||||
ProviderError::ResponseParseError("Missing or non-string `tooltip` field".into())
|
||||
})?
|
||||
.to_string();
|
||||
|
||||
Ok(tooltip)
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
mod completion;
|
||||
mod message;
|
||||
pub mod extractors;
|
||||
pub mod message;
|
||||
mod model;
|
||||
mod prompt_template;
|
||||
mod providers;
|
||||
pub mod providers;
|
||||
pub mod types;
|
||||
|
||||
pub use completion::completion;
|
||||
|
||||
@@ -43,6 +43,23 @@ impl ProviderCompleteResponse {
|
||||
}
|
||||
}
|
||||
|
||||
/// Response from a structured‐extraction call
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProviderExtractResponse {
|
||||
/// The extracted JSON object
|
||||
pub data: serde_json::Value,
|
||||
/// Which model produced it
|
||||
pub model: String,
|
||||
/// Token usage stats
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
impl ProviderExtractResponse {
|
||||
pub fn new(data: serde_json::Value, model: String, usage: Usage) -> Self {
|
||||
Self { data, model, usage }
|
||||
}
|
||||
}
|
||||
|
||||
/// Base trait for AI providers (OpenAI, Anthropic, etc)
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
@@ -65,6 +82,27 @@ pub trait Provider: Send + Sync {
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<ProviderCompleteResponse, ProviderError>;
|
||||
|
||||
/// Structured extraction: always JSON‐Schema
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `system` – system prompt guiding the extraction task
|
||||
/// * `messages` – conversation history
|
||||
/// * `schema` – a JSON‐Schema for the expected output.
|
||||
/// Will set strict=true for OpenAI & Databricks.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`.
|
||||
///
|
||||
/// # Errors
|
||||
/// * `ProviderError::ContextLengthExceeded` if the prompt is too large
|
||||
/// * other `ProviderError` variants for API/network failures
|
||||
async fn extract(
|
||||
&self,
|
||||
system: &str,
|
||||
messages: &[Message],
|
||||
schema: &serde_json::Value,
|
||||
) -> Result<ProviderExtractResponse, ProviderError>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -4,7 +4,7 @@ use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::{Client, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_json::{json, Value};
|
||||
use url::Url;
|
||||
|
||||
use super::{
|
||||
@@ -15,17 +15,15 @@ use super::{
|
||||
use crate::{
|
||||
message::Message,
|
||||
model::ModelConfig,
|
||||
providers::{Provider, ProviderCompleteResponse, Usage},
|
||||
providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
|
||||
types::core::Tool,
|
||||
};
|
||||
|
||||
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-meta-llama-3-3-70b-instruct";
|
||||
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet";
|
||||
// Databricks can passthrough to a wide range of models, we only provide the default
|
||||
pub const _DATABRICKS_KNOWN_MODELS: &[&str] = &[
|
||||
"databricks-meta-llama-3-3-70b-instruct",
|
||||
"databricks-meta-llama-3-1-405b-instruct",
|
||||
"databricks-dbrx-instruct",
|
||||
"databricks-mixtral-8x7b-instruct",
|
||||
"databricks-claude-3-7-sonnet",
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -215,4 +213,63 @@ impl Provider for DatabricksProvider {
|
||||
|
||||
Ok(ProviderCompleteResponse::new(message, model, usage))
|
||||
}
|
||||
|
||||
async fn extract(
|
||||
&self,
|
||||
system: &str,
|
||||
messages: &[Message],
|
||||
schema: &Value,
|
||||
) -> Result<ProviderExtractResponse, ProviderError> {
|
||||
// 1. Build base payload (no tools)
|
||||
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
|
||||
|
||||
// 2. Inject strict JSON‐Schema wrapper
|
||||
payload
|
||||
.as_object_mut()
|
||||
.expect("payload must be an object")
|
||||
.insert(
|
||||
"response_format".to_string(),
|
||||
json!({
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "extraction",
|
||||
"schema": schema,
|
||||
"strict": true
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// 3. Call OpenAI
|
||||
let response = self.post(payload.clone()).await?;
|
||||
|
||||
// 4. Extract the assistant’s `content` and parse it into JSON
|
||||
let msg = &response["choices"][0]["message"];
|
||||
let raw = msg.get("content").cloned().ok_or_else(|| {
|
||||
ProviderError::ResponseParseError("Missing content in extract response".into())
|
||||
})?;
|
||||
let data = match raw {
|
||||
Value::String(s) => serde_json::from_str(&s)
|
||||
.map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
|
||||
Value::Object(_) | Value::Array(_) => raw,
|
||||
other => {
|
||||
return Err(ProviderError::ResponseParseError(format!(
|
||||
"Unexpected content type: {:?}",
|
||||
other
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// 5. Gather usage & model info
|
||||
let usage = match get_usage(&response) {
|
||||
Ok(u) => u,
|
||||
Err(ProviderError::UsageError(e)) => {
|
||||
tracing::debug!("Failed to get usage in extract: {}", e);
|
||||
Usage::default()
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
let model = get_model(&response);
|
||||
|
||||
Ok(ProviderExtractResponse::new(data, model, usage))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,9 @@ pub enum ProviderError {
|
||||
|
||||
#[error("Usage data error: {0}")]
|
||||
UsageError(String),
|
||||
|
||||
#[error("Invalid response: {0}")]
|
||||
ResponseParseError(String),
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for ProviderError {
|
||||
|
||||
@@ -200,6 +200,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
}
|
||||
|
||||
/// Convert internal Tool format to OpenAI's API tool specification
|
||||
/// https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/api-reference#functionobject
|
||||
pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
|
||||
let mut tool_names = std::collections::HashSet::new();
|
||||
let mut result = Vec::new();
|
||||
|
||||
@@ -6,5 +6,5 @@ pub mod formats;
|
||||
pub mod openai;
|
||||
pub mod utils;
|
||||
|
||||
pub use base::{Provider, ProviderCompleteResponse, Usage};
|
||||
pub use base::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage};
|
||||
pub use factory::create;
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::{collections::HashMap, time::Duration};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::{
|
||||
errors::ProviderError,
|
||||
@@ -13,20 +13,12 @@ use super::{
|
||||
use crate::{
|
||||
message::Message,
|
||||
model::ModelConfig,
|
||||
providers::{Provider, ProviderCompleteResponse, Usage},
|
||||
providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
|
||||
types::core::Tool,
|
||||
};
|
||||
|
||||
pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
|
||||
pub const _OPEN_AI_KNOWN_MODELS: &[&str] = &[
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-turbo",
|
||||
"gpt-3.5-turbo",
|
||||
"o1",
|
||||
"o3",
|
||||
"o4-mini",
|
||||
];
|
||||
pub const _OPEN_AI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4.1", "o1", "o3", "o4-mini"];
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OpenAiProvider {
|
||||
@@ -146,6 +138,65 @@ impl Provider for OpenAiProvider {
|
||||
emit_debug_trace(&self.model, &payload, &response, &usage);
|
||||
Ok(ProviderCompleteResponse::new(message, model, usage))
|
||||
}
|
||||
|
||||
async fn extract(
|
||||
&self,
|
||||
system: &str,
|
||||
messages: &[Message],
|
||||
schema: &Value,
|
||||
) -> Result<ProviderExtractResponse, ProviderError> {
|
||||
// 1. Build base payload (no tools)
|
||||
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
|
||||
|
||||
// 2. Inject strict JSON‐Schema wrapper
|
||||
payload
|
||||
.as_object_mut()
|
||||
.expect("payload must be an object")
|
||||
.insert(
|
||||
"response_format".to_string(),
|
||||
json!({
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "extraction",
|
||||
"schema": schema,
|
||||
"strict": true
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// 3. Call OpenAI
|
||||
let response = self.post(payload.clone()).await?;
|
||||
|
||||
// 4. Extract the assistant’s `content` and parse it into JSON
|
||||
let msg = &response["choices"][0]["message"];
|
||||
let raw = msg.get("content").cloned().ok_or_else(|| {
|
||||
ProviderError::ResponseParseError("Missing content in extract response".into())
|
||||
})?;
|
||||
let data = match raw {
|
||||
Value::String(s) => serde_json::from_str(&s)
|
||||
.map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
|
||||
Value::Object(_) | Value::Array(_) => raw,
|
||||
other => {
|
||||
return Err(ProviderError::ResponseParseError(format!(
|
||||
"Unexpected content type: {:?}",
|
||||
other
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// 5. Gather usage & model info
|
||||
let usage = match get_usage(&response) {
|
||||
Ok(u) => u,
|
||||
Err(ProviderError::UsageError(e)) => {
|
||||
tracing::debug!("Failed to get usage in extract: {}", e);
|
||||
Usage::default()
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
let model = get_model(&response);
|
||||
|
||||
Ok(ProviderExtractResponse::new(data, model, usage))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_custom_headers(s: String) -> HashMap<String, String> {
|
||||
|
||||
56
crates/goose-llm/tests/extract_session_name.rs
Normal file
56
crates/goose-llm/tests/extract_session_name.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use anyhow::Result;
|
||||
use dotenv::dotenv;
|
||||
use goose_llm::extractors::generate_session_name;
|
||||
use goose_llm::message::Message;
|
||||
use goose_llm::providers::errors::ProviderError;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_session_name_success() -> Result<(), ProviderError> {
|
||||
// Load .env for Databricks credentials
|
||||
dotenv().ok();
|
||||
let has_creds =
|
||||
std::env::var("DATABRICKS_HOST").is_ok() && std::env::var("DATABRICKS_TOKEN").is_ok();
|
||||
if !has_creds {
|
||||
println!("Skipping generate_session_name test – Databricks creds not set");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Build a few messages with at least two user messages
|
||||
let messages = vec![
|
||||
Message::user().with_text("Hello, how are you?"),
|
||||
Message::assistant().with_text("I’m fine, thanks!"),
|
||||
Message::user().with_text("What’s the weather in New York tomorrow?"),
|
||||
];
|
||||
|
||||
let name = generate_session_name(&messages).await?;
|
||||
println!("Generated session name: {:?}", name);
|
||||
|
||||
// Should be non-empty and at most 4 words
|
||||
let name = name.trim();
|
||||
assert!(!name.is_empty(), "Name must not be empty");
|
||||
let word_count = name.split_whitespace().count();
|
||||
assert!(
|
||||
word_count <= 4,
|
||||
"Name must be 4 words or less, got {}: {}",
|
||||
word_count,
|
||||
name
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_session_name_no_user() {
|
||||
// No user messages → expect ExecutionError
|
||||
let messages = vec![
|
||||
Message::assistant().with_text("System starting…"),
|
||||
Message::assistant().with_text("All systems go."),
|
||||
];
|
||||
|
||||
let err = generate_session_name(&messages).await;
|
||||
assert!(
|
||||
matches!(err, Err(ProviderError::ExecutionError(_))),
|
||||
"Expected ExecutionError when there are no user messages, got: {:?}",
|
||||
err
|
||||
);
|
||||
}
|
||||
69
crates/goose-llm/tests/extract_tooltip.rs
Normal file
69
crates/goose-llm/tests/extract_tooltip.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use anyhow::Result;
|
||||
use dotenv::dotenv;
|
||||
use goose_llm::extractors::generate_tooltip;
|
||||
use goose_llm::message::{Message, MessageContent, ToolRequest};
|
||||
use goose_llm::providers::errors::ProviderError;
|
||||
use goose_llm::types::core::{Content, ToolCall};
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_tooltip_simple() -> Result<(), ProviderError> {
|
||||
// Skip if no Databricks creds
|
||||
dotenv().ok();
|
||||
if std::env::var("DATABRICKS_HOST").is_err() || std::env::var("DATABRICKS_TOKEN").is_err() {
|
||||
println!("Skipping simple tooltip test – Databricks creds not set");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Two plain-text messages
|
||||
let messages = vec![
|
||||
Message::user().with_text("Hello, how are you?"),
|
||||
Message::assistant().with_text("I'm fine, thanks! How can I help?"),
|
||||
];
|
||||
|
||||
let tooltip = generate_tooltip(&messages).await?;
|
||||
println!("Generated tooltip: {:?}", tooltip);
|
||||
|
||||
assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
|
||||
assert!(
|
||||
tooltip.len() < 100,
|
||||
"Tooltip should be reasonably short (<100 chars)"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_tooltip_with_tools() -> Result<(), ProviderError> {
|
||||
// Skip if no Databricks creds
|
||||
dotenv().ok();
|
||||
if std::env::var("DATABRICKS_HOST").is_err() || std::env::var("DATABRICKS_TOKEN").is_err() {
|
||||
println!("Skipping tool‐based tooltip test – Databricks creds not set");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 1) Assistant message with a tool request
|
||||
let mut tool_req_msg = Message::assistant();
|
||||
let req = ToolRequest {
|
||||
id: "1".to_string(),
|
||||
tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))),
|
||||
};
|
||||
tool_req_msg.content.push(MessageContent::ToolRequest(req));
|
||||
|
||||
// 2) User message with the tool response
|
||||
let tool_resp_msg = Message::user().with_tool_response(
|
||||
"1",
|
||||
Ok(vec![Content::text("The current time is 12:00 UTC")]),
|
||||
);
|
||||
|
||||
let messages = vec![tool_req_msg, tool_resp_msg];
|
||||
|
||||
let tooltip = generate_tooltip(&messages).await?;
|
||||
println!("Generated tooltip (tools): {:?}", tooltip);
|
||||
|
||||
assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
|
||||
assert!(
|
||||
tooltip.len() < 100,
|
||||
"Tooltip should be reasonably short (<100 chars)"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
379
crates/goose-llm/tests/providers_complete.rs
Normal file
379
crates/goose-llm/tests/providers_complete.rs
Normal file
@@ -0,0 +1,379 @@
|
||||
use anyhow::Result;
|
||||
use dotenv::dotenv;
|
||||
use goose_llm::message::{Message, MessageContent};
|
||||
use goose_llm::providers::base::Provider;
|
||||
use goose_llm::providers::errors::ProviderError;
|
||||
use goose_llm::providers::{databricks, openai};
|
||||
use goose_llm::types::core::{Content, Tool};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum TestStatus {
|
||||
Passed,
|
||||
Skipped,
|
||||
Failed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TestStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TestStatus::Passed => write!(f, "✅"),
|
||||
TestStatus::Skipped => write!(f, "⏭️"),
|
||||
TestStatus::Failed => write!(f, "❌"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TestReport {
|
||||
results: Mutex<HashMap<String, TestStatus>>,
|
||||
}
|
||||
|
||||
impl TestReport {
|
||||
fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
results: Mutex::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
fn record_status(&self, provider: &str, status: TestStatus) {
|
||||
let mut results = self.results.lock().unwrap();
|
||||
results.insert(provider.to_string(), status);
|
||||
}
|
||||
|
||||
fn record_pass(&self, provider: &str) {
|
||||
self.record_status(provider, TestStatus::Passed);
|
||||
}
|
||||
|
||||
fn record_skip(&self, provider: &str) {
|
||||
self.record_status(provider, TestStatus::Skipped);
|
||||
}
|
||||
|
||||
fn record_fail(&self, provider: &str) {
|
||||
self.record_status(provider, TestStatus::Failed);
|
||||
}
|
||||
|
||||
fn print_summary(&self) {
|
||||
println!("\n============== Providers ==============");
|
||||
let results = self.results.lock().unwrap();
|
||||
let mut providers: Vec<_> = results.iter().collect();
|
||||
providers.sort_by(|a, b| a.0.cmp(b.0));
|
||||
|
||||
for (provider, status) in providers {
|
||||
println!("{} {}", status, provider);
|
||||
}
|
||||
println!("=======================================\n");
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref TEST_REPORT: Arc<TestReport> = TestReport::new();
|
||||
static ref ENV_LOCK: Mutex<()> = Mutex::new(());
|
||||
}
|
||||
|
||||
/// Generic test harness for any Provider implementation
|
||||
struct ProviderTester {
|
||||
provider: Arc<dyn Provider>,
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl ProviderTester {
|
||||
fn new<T: Provider + Send + Sync + 'static>(provider: T, name: String) -> Self {
|
||||
Self {
|
||||
provider: Arc::new(provider),
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
async fn test_basic_response(&self) -> Result<()> {
|
||||
let message = Message::user().with_text("Just say hello!");
|
||||
|
||||
let response = self
|
||||
.provider
|
||||
.complete("You are a helpful assistant.", &[message], &[])
|
||||
.await?;
|
||||
|
||||
// For a basic response, we expect a single text response
|
||||
assert_eq!(
|
||||
response.message.content.len(),
|
||||
1,
|
||||
"Expected single content item in response"
|
||||
);
|
||||
|
||||
// Verify we got a text response
|
||||
assert!(
|
||||
matches!(response.message.content[0], MessageContent::Text(_)),
|
||||
"Expected text response"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn test_tool_usage(&self) -> Result<()> {
|
||||
let weather_tool = Tool::new(
|
||||
"get_weather",
|
||||
"Get the weather for a location",
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"required": ["location"],
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let message = Message::user().with_text("What's the weather like in San Francisco?");
|
||||
|
||||
let response1 = self
|
||||
.provider
|
||||
.complete(
|
||||
"You are a helpful weather assistant.",
|
||||
&[message.clone()],
|
||||
&[weather_tool.clone()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
println!("=== {}::reponse1 ===", self.name);
|
||||
dbg!(&response1);
|
||||
println!("===================");
|
||||
|
||||
// Verify we got a tool request
|
||||
assert!(
|
||||
response1
|
||||
.message
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| matches!(content, MessageContent::ToolRequest(_))),
|
||||
"Expected tool request in response"
|
||||
);
|
||||
|
||||
let id = &response1
|
||||
.message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|message| message.as_tool_request())
|
||||
.last()
|
||||
.expect("got tool request")
|
||||
.id;
|
||||
|
||||
let weather = Message::user().with_tool_response(
|
||||
id,
|
||||
Ok(vec![Content::text(
|
||||
"
|
||||
50°F°C
|
||||
Precipitation: 0%
|
||||
Humidity: 84%
|
||||
Wind: 2 mph
|
||||
Weather
|
||||
Saturday 9:00 PM
|
||||
Clear",
|
||||
)]),
|
||||
);
|
||||
|
||||
// Verify we construct a valid payload including the request/response pair for the next inference
|
||||
let response2 = self
|
||||
.provider
|
||||
.complete(
|
||||
"You are a helpful weather assistant.",
|
||||
&[message, response1.message, weather],
|
||||
&[weather_tool],
|
||||
)
|
||||
.await?;
|
||||
|
||||
println!("=== {}::reponse2 ===", self.name);
|
||||
dbg!(&response2);
|
||||
println!("===================");
|
||||
|
||||
assert!(
|
||||
response2
|
||||
.message
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| matches!(content, MessageContent::Text(_))),
|
||||
"Expected text for final response"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn test_context_length_exceeded_error(&self) -> Result<()> {
|
||||
// Google Gemini has a really long context window
|
||||
let large_message_content = if self.name.to_lowercase() == "google" {
|
||||
"hello ".repeat(1_300_000)
|
||||
} else {
|
||||
"hello ".repeat(300_000)
|
||||
};
|
||||
|
||||
let messages = vec![
|
||||
Message::user().with_text("hi there. what is 2 + 2?"),
|
||||
Message::assistant().with_text("hey! I think it's 4."),
|
||||
Message::user().with_text(&large_message_content),
|
||||
Message::assistant().with_text("heyy!!"),
|
||||
// Messages before this mark should be truncated
|
||||
Message::user().with_text("what's the meaning of life?"),
|
||||
Message::assistant().with_text("the meaning of life is 42"),
|
||||
Message::user().with_text(
|
||||
"did I ask you what's 2+2 in this message history? just respond with 'yes' or 'no'",
|
||||
),
|
||||
];
|
||||
|
||||
// Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded
|
||||
let result = self
|
||||
.provider
|
||||
.complete("You are a helpful assistant.", &messages, &[])
|
||||
.await;
|
||||
|
||||
// Print some debug info
|
||||
println!("=== {}::context_length_exceeded_error ===", self.name);
|
||||
dbg!(&result);
|
||||
println!("===================");
|
||||
|
||||
// Ollama truncates by default even when the context window is exceeded
|
||||
if self.name.to_lowercase() == "ollama" {
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Expected to succeed because of default truncation"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"Expected error when context window is exceeded"
|
||||
);
|
||||
assert!(
|
||||
matches!(result.unwrap_err(), ProviderError::ContextLengthExceeded(_)),
|
||||
"Expected error to be ContextLengthExceeded"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run all provider tests
|
||||
async fn run_test_suite(&self) -> Result<()> {
|
||||
self.test_basic_response().await?;
|
||||
self.test_tool_usage().await?;
|
||||
self.test_context_length_exceeded_error().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn load_env() {
|
||||
if let Ok(path) = dotenv() {
|
||||
println!("Loaded environment from {:?}", path);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to run a provider test with proper error handling and reporting
///
/// * `name` - display name used in the final test report.
/// * `required_vars` - env vars that must be set; otherwise the test is skipped.
/// * `env_modifications` - vars to set (`Some(value)`) or unset (`None`) while
///   the provider is being constructed; the prior state is restored afterwards.
/// * `provider_fn` - builds the provider under test from the (modified) env.
async fn test_provider<F, T>(
    name: &str,
    required_vars: &[&str],
    env_modifications: Option<HashMap<&str, Option<String>>>,
    provider_fn: F,
) -> Result<()>
where
    F: FnOnce() -> T,
    T: Provider + Send + Sync + 'static,
{
    // We start off as failed, so that if the process panics it is seen as a failure
    TEST_REPORT.record_fail(name);

    // Take exclusive access to environment modifications
    let lock = ENV_LOCK.lock().unwrap();

    load_env();

    // Save current environment state for required vars and modified vars
    let mut original_env = HashMap::new();
    for &var in required_vars {
        if let Ok(val) = std::env::var(var) {
            original_env.insert(var, val);
        }
    }
    if let Some(mods) = &env_modifications {
        for &var in mods.keys() {
            if let Ok(val) = std::env::var(var) {
                original_env.insert(var, val);
            }
        }
    }

    // Apply any environment modifications
    if let Some(mods) = &env_modifications {
        for (&var, value) in mods.iter() {
            match value {
                Some(val) => std::env::set_var(var, val),
                None => std::env::remove_var(var),
            }
        }
    }

    // Setup the provider
    // NOTE: missing credentials result in a recorded skip, not a failure.
    let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err());
    if missing_vars {
        println!("Skipping {} tests - credentials not configured", name);
        TEST_REPORT.record_skip(name);
        return Ok(());
    }

    // The provider captures its configuration now, while the modified env is
    // still in effect.
    let provider = provider_fn();

    // Restore original environment
    for (&var, value) in original_env.iter() {
        std::env::set_var(var, value);
    }
    if let Some(mods) = env_modifications {
        for &var in mods.keys() {
            // Vars that did not exist before the modifications are removed again.
            if !original_env.contains_key(var) {
                std::env::remove_var(var);
            }
        }
    }

    // Environment is fully restored, so the lock can be released before the
    // (potentially slow) network-bound test suite runs.
    std::mem::drop(lock);

    let tester = ProviderTester::new(provider, name.to_string());
    match tester.run_test_suite().await {
        Ok(_) => {
            TEST_REPORT.record_pass(name);
            Ok(())
        }
        Err(e) => {
            println!("{} test failed: {}", name, e);
            // Re-record the failure explicitly alongside the logged error.
            TEST_REPORT.record_fail(name);
            Err(e)
        }
    }
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn openai_complete() -> Result<()> {
|
||||
test_provider(
|
||||
"OpenAI",
|
||||
&["OPENAI_API_KEY"],
|
||||
None,
|
||||
openai::OpenAiProvider::default,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn databricks_complete() -> Result<()> {
|
||||
test_provider(
|
||||
"Databricks",
|
||||
&["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
|
||||
None,
|
||||
databricks::DatabricksProvider::default,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// Print the final test report
// `#[ctor::dtor]` registers this to run at process shutdown, after every
// test has recorded its pass/fail/skip status in TEST_REPORT.
#[ctor::dtor]
fn print_test_report() {
    TEST_REPORT.print_summary();
}
|
||||
195
crates/goose-llm/tests/providers_extract.rs
Normal file
195
crates/goose-llm/tests/providers_extract.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
// tests/providers_extract.rs
|
||||
|
||||
use anyhow::Result;
|
||||
use dotenv::dotenv;
|
||||
use goose_llm::message::Message;
|
||||
use goose_llm::providers::base::Provider;
|
||||
use goose_llm::providers::{databricks::DatabricksProvider, openai::OpenAiProvider};
|
||||
use goose_llm::ModelConfig;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Providers exercised by the extract tests. The variant selects both the
/// required credentials and the concrete provider construction (see the
/// `impl` below).
#[derive(Debug, PartialEq, Copy, Clone)]
enum ProviderType {
    OpenAi,
    Databricks,
}
|
||||
|
||||
impl ProviderType {
|
||||
fn required_env(&self) -> &'static [&'static str] {
|
||||
match self {
|
||||
ProviderType::OpenAi => &["OPENAI_API_KEY"],
|
||||
ProviderType::Databricks => &["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
|
||||
}
|
||||
}
|
||||
|
||||
fn create_provider(&self, cfg: ModelConfig) -> Result<Arc<dyn Provider>> {
|
||||
Ok(match self {
|
||||
ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(cfg)?),
|
||||
ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(cfg)?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when every variable in `required` is set in the process
/// environment. When any are missing, prints the missing names (so the test
/// log explains the skip) and returns `false`.
fn check_required_env_vars(required: &[&str]) -> bool {
    // `.copied()` is the idiomatic (and clippy-preferred) way to deref
    // `Copy` items such as `&str`; the original used `.cloned()`.
    let missing: Vec<&str> = required
        .iter()
        .copied()
        .filter(|&v| std::env::var(v).is_err())
        .collect();
    if missing.is_empty() {
        true
    } else {
        println!("Skipping test; missing env vars: {:?}", missing);
        false
    }
}
|
||||
|
||||
// --- Shared inputs for "paper" task ---
// System prompt and user text shared by every paper-extraction test below;
// `paper_schema()` describes the JSON shape expected back from `extract`.
const PAPER_SYSTEM: &str =
    "You are an expert at structured data extraction. Extract the metadata of a research paper into JSON.";
const PAPER_TEXT: &str =
    "Application of Quantum Algorithms in Interstellar Navigation: A New Frontier \
     by Dr. Stella Voyager, Dr. Nova Star, Dr. Lyra Hunter. Abstract: This paper \
     investigates the utilization of quantum algorithms to improve interstellar \
     navigation systems. Keywords: Quantum algorithms, interstellar navigation, \
     space-time anomalies, quantum superposition, quantum entanglement, space travel.";
|
||||
|
||||
fn paper_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": { "type": "string" },
|
||||
"authors": { "type": "array", "items": { "type": "string" } },
|
||||
"abstract": { "type": "string" },
|
||||
"keywords": { "type": "array", "items": { "type": "string" } }
|
||||
},
|
||||
"required": ["title","authors","abstract","keywords"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
|
||||
// --- Shared inputs for "UI" task ---
// System prompt and user input shared by the UI-generation tests below;
// `ui_schema()` describes the JSON-driven UI tree expected back.
const UI_SYSTEM: &str = "You are a UI generator AI. Convert the user input into a JSON-driven UI.";
const UI_TEXT: &str = "Make a User Profile Form";
|
||||
|
||||
fn ui_schema() -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["div","button","header","section","field","form"]
|
||||
},
|
||||
"label": { "type": "string" },
|
||||
"children": {
|
||||
"type": "array",
|
||||
"items": { "$ref": "#" }
|
||||
},
|
||||
"attributes": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"value": { "type": "string" }
|
||||
},
|
||||
"required": ["name","value"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["type","label","children","attributes"],
|
||||
"additionalProperties": false
|
||||
})
|
||||
}
|
||||
|
||||
/// Generic runner for any extract task
|
||||
async fn run_extract_test<F>(
|
||||
provider_type: ProviderType,
|
||||
model: &str,
|
||||
system: &'static str,
|
||||
user_text: &'static str,
|
||||
schema: Value,
|
||||
validate: F,
|
||||
) -> Result<()>
|
||||
where
|
||||
F: Fn(&Value) -> bool,
|
||||
{
|
||||
dotenv().ok();
|
||||
if !check_required_env_vars(provider_type.required_env()) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let cfg = ModelConfig::new(model.to_string()).with_temperature(Some(0.0));
|
||||
let provider = provider_type.create_provider(cfg)?;
|
||||
|
||||
let msg = Message::user().with_text(user_text);
|
||||
let resp = provider.extract(system, &[msg], &schema).await?;
|
||||
|
||||
println!("[{:?}] extract => {}", provider_type, resp.data);
|
||||
|
||||
assert!(
|
||||
validate(&resp.data),
|
||||
"{:?} failed validation on {}",
|
||||
provider_type,
|
||||
resp.data
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Helper for the "paper" task
|
||||
async fn run_extract_paper_test(provider: ProviderType, model: &str) -> Result<()> {
|
||||
run_extract_test(
|
||||
provider,
|
||||
model,
|
||||
PAPER_SYSTEM,
|
||||
PAPER_TEXT,
|
||||
paper_schema(),
|
||||
|v| {
|
||||
v.as_object()
|
||||
.map(|o| {
|
||||
["title", "authors", "abstract", "keywords"]
|
||||
.iter()
|
||||
.all(|k| o.contains_key(*k))
|
||||
})
|
||||
.unwrap_or(false)
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Helper for the "UI" task
|
||||
async fn run_extract_ui_test(provider: ProviderType, model: &str) -> Result<()> {
|
||||
run_extract_test(provider, model, UI_SYSTEM, UI_TEXT, ui_schema(), |v| {
|
||||
v.as_object()
|
||||
.and_then(|o| o.get("type").and_then(Value::as_str))
|
||||
== Some("form")
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Live integration tests: each exercises one provider/task pair and
    // returns Ok (via the skip path in run_extract_test) when the provider's
    // credentials are not set.

    #[tokio::test]
    async fn openai_extract_paper() -> Result<()> {
        run_extract_paper_test(ProviderType::OpenAi, "gpt-4o").await
    }

    #[tokio::test]
    async fn openai_extract_ui() -> Result<()> {
        run_extract_ui_test(ProviderType::OpenAi, "gpt-4o").await
    }

    #[tokio::test]
    async fn databricks_extract_paper() -> Result<()> {
        run_extract_paper_test(ProviderType::Databricks, "goose-gpt-4-1").await
    }

    #[tokio::test]
    async fn databricks_extract_ui() -> Result<()> {
        run_extract_ui_test(ProviderType::Databricks, "goose-gpt-4-1").await
    }
}
|
||||
Reference in New Issue
Block a user