mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 22:24:21 +01:00
196 lines
5.7 KiB
Rust
196 lines
5.7 KiB
Rust
// tests/providers_extract.rs
|
|
|
|
use anyhow::Result;
|
|
use dotenv::dotenv;
|
|
use goose_llm::message::Message;
|
|
use goose_llm::providers::base::Provider;
|
|
use goose_llm::providers::{databricks::DatabricksProvider, openai::OpenAiProvider};
|
|
use goose_llm::ModelConfig;
|
|
use serde_json::{json, Value};
|
|
use std::sync::Arc;
|
|
|
|
/// Identifies which LLM backend a test case runs against.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ProviderType {
    /// The OpenAI-backed provider.
    OpenAi,
    /// The Databricks-backed provider.
    Databricks,
}
|
|
|
|
impl ProviderType {
|
|
fn required_env(&self) -> &'static [&'static str] {
|
|
match self {
|
|
ProviderType::OpenAi => &["OPENAI_API_KEY"],
|
|
ProviderType::Databricks => &["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
|
|
}
|
|
}
|
|
|
|
fn create_provider(&self, cfg: ModelConfig) -> Result<Arc<dyn Provider>> {
|
|
Ok(match self {
|
|
ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(cfg)),
|
|
ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(cfg)),
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Reports whether every environment variable named in `required` is usable.
///
/// A variable counts as missing when it is unset, is not valid Unicode, or is
/// set to an empty / whitespace-only value. The blank case is treated as
/// missing so that a credential exported as `""` (for example by a CI setup
/// that has no secret to inject) makes the test skip instead of failing later
/// on authentication.
///
/// Returns `true` when nothing is missing (including when `required` is
/// empty). Otherwise prints the missing names and returns `false` so the
/// caller can skip the test.
fn check_required_env_vars(required: &[&str]) -> bool {
    let missing: Vec<&str> = required
        .iter()
        .copied()
        .filter(|name| std::env::var(name).map_or(true, |value| value.trim().is_empty()))
        .collect();

    if missing.is_empty() {
        return true;
    }

    println!("Skipping test; missing env vars: {:?}", missing);
    false
}
|
|
|
|
// --- Shared inputs for "paper" task ---
|
|
/// System prompt used for the "paper" metadata-extraction task.
const PAPER_SYSTEM: &str = concat!(
    "You are an expert at structured data extraction. ",
    "Extract the metadata of a research paper into JSON.",
);

/// User message for the "paper" task: a made-up paper's title, authors,
/// abstract and keywords in a single paragraph.
const PAPER_TEXT: &str = concat!(
    "Application of Quantum Algorithms in Interstellar Navigation: A New Frontier ",
    "by Dr. Stella Voyager, Dr. Nova Star, Dr. Lyra Hunter. Abstract: This paper ",
    "investigates the utilization of quantum algorithms to improve interstellar ",
    "navigation systems. Keywords: Quantum algorithms, interstellar navigation, ",
    "space-time anomalies, quantum superposition, quantum entanglement, space travel.",
);
|
|
|
|
fn paper_schema() -> Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"title": { "type": "string" },
|
|
"authors": { "type": "array", "items": { "type": "string" } },
|
|
"abstract": { "type": "string" },
|
|
"keywords": { "type": "array", "items": { "type": "string" } }
|
|
},
|
|
"required": ["title","authors","abstract","keywords"],
|
|
"additionalProperties": false
|
|
})
|
|
}
|
|
|
|
// --- Shared inputs for "UI" task ---
|
|
/// System prompt used for the "UI" generation task.
const UI_SYSTEM: &str = concat!(
    "You are a UI generator AI. ",
    "Convert the user input into a JSON-driven UI.",
);

/// User request sent for the "UI" task.
const UI_TEXT: &str = "Make a User Profile Form";
|
|
|
|
fn ui_schema() -> Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"type": {
|
|
"type": "string",
|
|
"enum": ["div","button","header","section","field","form"]
|
|
},
|
|
"label": { "type": "string" },
|
|
"children": {
|
|
"type": "array",
|
|
"items": { "$ref": "#" }
|
|
},
|
|
"attributes": {
|
|
"type": "array",
|
|
"items": {
|
|
"type": "object",
|
|
"properties": {
|
|
"name": { "type": "string" },
|
|
"value": { "type": "string" }
|
|
},
|
|
"required": ["name","value"],
|
|
"additionalProperties": false
|
|
}
|
|
}
|
|
},
|
|
"required": ["type","label","children","attributes"],
|
|
"additionalProperties": false
|
|
})
|
|
}
|
|
|
|
/// Generic runner for any extract task
|
|
async fn run_extract_test<F>(
|
|
provider_type: ProviderType,
|
|
model: &str,
|
|
system: &'static str,
|
|
user_text: &'static str,
|
|
schema: Value,
|
|
validate: F,
|
|
) -> Result<()>
|
|
where
|
|
F: Fn(&Value) -> bool,
|
|
{
|
|
dotenv().ok();
|
|
if !check_required_env_vars(provider_type.required_env()) {
|
|
return Ok(());
|
|
}
|
|
|
|
let cfg = ModelConfig::new(model.to_string()).with_temperature(Some(0.0));
|
|
let provider = provider_type.create_provider(cfg)?;
|
|
|
|
let msg = Message::user().with_text(user_text);
|
|
let resp = provider.extract(system, &[msg], &schema).await?;
|
|
|
|
println!("[{:?}] extract => {}", provider_type, resp.data);
|
|
|
|
assert!(
|
|
validate(&resp.data),
|
|
"{:?} failed validation on {}",
|
|
provider_type,
|
|
resp.data
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
/// Helper for the "paper" task
|
|
async fn run_extract_paper_test(provider: ProviderType, model: &str) -> Result<()> {
|
|
run_extract_test(
|
|
provider,
|
|
model,
|
|
PAPER_SYSTEM,
|
|
PAPER_TEXT,
|
|
paper_schema(),
|
|
|v| {
|
|
v.as_object()
|
|
.map(|o| {
|
|
["title", "authors", "abstract", "keywords"]
|
|
.iter()
|
|
.all(|k| o.contains_key(*k))
|
|
})
|
|
.unwrap_or(false)
|
|
},
|
|
)
|
|
.await
|
|
}
|
|
|
|
/// Helper for the "UI" task
|
|
async fn run_extract_ui_test(provider: ProviderType, model: &str) -> Result<()> {
|
|
run_extract_test(provider, model, UI_SYSTEM, UI_TEXT, ui_schema(), |v| {
|
|
v.as_object()
|
|
.and_then(|o| o.get("type").and_then(Value::as_str))
|
|
== Some("form")
|
|
})
|
|
.await
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Model name requested in the OpenAI test cases.
    const OPENAI_MODEL: &str = "gpt-4o";
    /// Model name requested in the Databricks test cases.
    const DATABRICKS_MODEL: &str = "goose-gpt-4-1";

    /// "paper" task against OpenAI.
    #[tokio::test]
    async fn openai_extract_paper() -> Result<()> {
        run_extract_paper_test(ProviderType::OpenAi, OPENAI_MODEL).await
    }

    /// "UI" task against OpenAI.
    #[tokio::test]
    async fn openai_extract_ui() -> Result<()> {
        run_extract_ui_test(ProviderType::OpenAi, OPENAI_MODEL).await
    }

    /// "paper" task against Databricks.
    #[tokio::test]
    async fn databricks_extract_paper() -> Result<()> {
        run_extract_paper_test(ProviderType::Databricks, DATABRICKS_MODEL).await
    }

    /// "UI" task against Databricks.
    #[tokio::test]
    async fn databricks_extract_ui() -> Result<()> {
        run_extract_ui_test(ProviderType::Databricks, DATABRICKS_MODEL).await
    }
}
|