mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
[goose-llm] add generate tooltip & session name via extract method (#2467)
* extract method in provider to use for structured outputs * generate session name from msgs * generate tooltip from msgs * add provider tests
This commit is contained in:
14
Cargo.lock
generated
14
Cargo.lock
generated
@@ -2406,7 +2406,7 @@ dependencies = [
|
|||||||
"fs2",
|
"fs2",
|
||||||
"futures",
|
"futures",
|
||||||
"include_dir",
|
"include_dir",
|
||||||
"indoc",
|
"indoc 2.0.6",
|
||||||
"jsonwebtoken",
|
"jsonwebtoken",
|
||||||
"keyring",
|
"keyring",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
@@ -2527,7 +2527,11 @@ dependencies = [
|
|||||||
"base64 0.21.7",
|
"base64 0.21.7",
|
||||||
"chrono",
|
"chrono",
|
||||||
"criterion",
|
"criterion",
|
||||||
|
"ctor",
|
||||||
|
"dotenv",
|
||||||
"include_dir",
|
"include_dir",
|
||||||
|
"indoc 1.0.9",
|
||||||
|
"lazy_static",
|
||||||
"minijinja",
|
"minijinja",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"regex",
|
"regex",
|
||||||
@@ -2559,7 +2563,7 @@ dependencies = [
|
|||||||
"ignore",
|
"ignore",
|
||||||
"image 0.24.9",
|
"image 0.24.9",
|
||||||
"include_dir",
|
"include_dir",
|
||||||
"indoc",
|
"indoc 2.0.6",
|
||||||
"keyring",
|
"keyring",
|
||||||
"kill_tree",
|
"kill_tree",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
@@ -3263,6 +3267,12 @@ dependencies = [
|
|||||||
"web-time",
|
"web-time",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "indoc"
|
||||||
|
version = "1.0.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indoc"
|
name = "indoc"
|
||||||
version = "2.0.6"
|
version = "2.0.6"
|
||||||
|
|||||||
@@ -35,10 +35,14 @@ base64 = "0.21"
|
|||||||
regex = "1.11.1"
|
regex = "1.11.1"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
smallvec = { version = "1.13", features = ["serde"] }
|
smallvec = { version = "1.13", features = ["serde"] }
|
||||||
|
indoc = "1.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = "0.5"
|
criterion = "0.5"
|
||||||
tempfile = "3.15.0"
|
tempfile = "3.15.0"
|
||||||
|
dotenv = "0.15"
|
||||||
|
lazy_static = "1.5"
|
||||||
|
ctor = "0.2.7"
|
||||||
|
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
|
|||||||
5
crates/goose-llm/src/extractors/mod.rs
Normal file
5
crates/goose-llm/src/extractors/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
mod session_name;
|
||||||
|
mod tooltip;
|
||||||
|
|
||||||
|
pub use session_name::generate_session_name;
|
||||||
|
pub use tooltip::generate_tooltip;
|
||||||
107
crates/goose-llm/src/extractors/session_name.rs
Normal file
107
crates/goose-llm/src/extractors/session_name.rs
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
use crate::message::Message;
|
||||||
|
use crate::model::ModelConfig;
|
||||||
|
use crate::providers::base::Provider;
|
||||||
|
use crate::providers::databricks::DatabricksProvider;
|
||||||
|
use crate::providers::errors::ProviderError;
|
||||||
|
use crate::types::core::Role;
|
||||||
|
use anyhow::Result;
|
||||||
|
use indoc::indoc;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
|
const SESSION_NAME_EXAMPLES: &[&str] = &[
|
||||||
|
"Research Synthesis",
|
||||||
|
"Sentiment Analysis",
|
||||||
|
"Performance Report",
|
||||||
|
"Feedback Collector",
|
||||||
|
"Accessibility Check",
|
||||||
|
"Design Reminder",
|
||||||
|
"Project Reminder",
|
||||||
|
"Launch Checklist",
|
||||||
|
"Metrics Monitor",
|
||||||
|
"Incident Response",
|
||||||
|
"Deploy Cabinet App",
|
||||||
|
"Design Reminder Alert",
|
||||||
|
"Generate Monthly Expense Report",
|
||||||
|
"Automate Incident Response Workflow",
|
||||||
|
"Analyze Brand Sentiment Trends",
|
||||||
|
"Monitor Device Health Issues",
|
||||||
|
"Collect UI Feedback Summary",
|
||||||
|
"Schedule Project Deadline Reminders",
|
||||||
|
];
|
||||||
|
|
||||||
|
fn build_system_prompt() -> String {
|
||||||
|
let examples = SESSION_NAME_EXAMPLES
|
||||||
|
.iter()
|
||||||
|
.map(|e| format!("- {}", e))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
indoc! {r#"
|
||||||
|
You are an assistant that crafts a concise session title.
|
||||||
|
Given the first couple user messages in the conversation so far,
|
||||||
|
reply with only a short name (up to 4 words) that best describes
|
||||||
|
this session’s goal.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
"#}
|
||||||
|
.to_string()
|
||||||
|
+ &examples
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a short (≤4 words) session name
|
||||||
|
pub async fn generate_session_name(messages: &[Message]) -> Result<String, ProviderError> {
|
||||||
|
// Collect up to the first 3 user messages (truncated to 300 chars each)
|
||||||
|
let context: Vec<String> = messages
|
||||||
|
.iter()
|
||||||
|
.filter(|m| m.role == Role::User)
|
||||||
|
.take(3)
|
||||||
|
.map(|m| {
|
||||||
|
let text = m.content.concat_text_str();
|
||||||
|
if text.len() > 300 {
|
||||||
|
text.chars().take(300).collect()
|
||||||
|
} else {
|
||||||
|
text
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if context.is_empty() {
|
||||||
|
return Err(ProviderError::ExecutionError(
|
||||||
|
"No user messages found to generate a session name.".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let system_prompt = build_system_prompt();
|
||||||
|
let user_msg_text = format!("Here are the user messages:\n{}", context.join("\n"));
|
||||||
|
|
||||||
|
// Instantiate DatabricksProvider with goose-gpt-4-1
|
||||||
|
let model_cfg = ModelConfig::new("goose-gpt-4-1".to_string()).with_temperature(Some(0.0));
|
||||||
|
let provider = DatabricksProvider::from_env(model_cfg)?;
|
||||||
|
|
||||||
|
// Use `extract` with a simple string schema
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": { "type": "string" }
|
||||||
|
},
|
||||||
|
"required": ["name"],
|
||||||
|
"additionalProperties": false
|
||||||
|
});
|
||||||
|
let user_msg = Message::user().with_text(&user_msg_text);
|
||||||
|
let resp = provider
|
||||||
|
.extract(&system_prompt, &[user_msg], &schema)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let obj = resp
|
||||||
|
.data
|
||||||
|
.as_object()
|
||||||
|
.ok_or_else(|| ProviderError::ResponseParseError("Expected object".into()))?;
|
||||||
|
|
||||||
|
let name = obj
|
||||||
|
.get("name")
|
||||||
|
.and_then(Value::as_str)
|
||||||
|
.ok_or_else(|| ProviderError::ResponseParseError("Missing or non-string name".into()))?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
Ok(name)
|
||||||
|
}
|
||||||
165
crates/goose-llm/src/extractors/tooltip.rs
Normal file
165
crates/goose-llm/src/extractors/tooltip.rs
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
use crate::message::{Message, MessageContent};
|
||||||
|
use crate::model::ModelConfig;
|
||||||
|
use crate::providers::base::Provider;
|
||||||
|
use crate::providers::databricks::DatabricksProvider;
|
||||||
|
use crate::providers::errors::ProviderError;
|
||||||
|
use crate::types::core::{Content, Role};
|
||||||
|
use anyhow::Result;
|
||||||
|
use indoc::indoc;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
|
const TOOLTIP_EXAMPLES: &[&str] = &[
|
||||||
|
"analyzing KPIs",
|
||||||
|
"detecting anomalies",
|
||||||
|
"building artifacts in Buildkite",
|
||||||
|
"categorizing issues",
|
||||||
|
"checking dependencies",
|
||||||
|
"collecting feedback",
|
||||||
|
"deploying changes in AWS",
|
||||||
|
"drafting report in Google Docs",
|
||||||
|
"extracting action items",
|
||||||
|
"generating insights",
|
||||||
|
"logging issues",
|
||||||
|
"monitoring tickets in Zendesk",
|
||||||
|
"notifying design team",
|
||||||
|
"running integration tests",
|
||||||
|
"scanning threads in Figma",
|
||||||
|
"sending reminders in Gmail",
|
||||||
|
"sending surveys",
|
||||||
|
"sharing with stakeholders",
|
||||||
|
"summarizing findings",
|
||||||
|
"transcribing meeting",
|
||||||
|
"tracking resolution",
|
||||||
|
"updating status in Linear",
|
||||||
|
];
|
||||||
|
|
||||||
|
fn build_system_prompt() -> String {
|
||||||
|
let examples = TOOLTIP_EXAMPLES
|
||||||
|
.iter()
|
||||||
|
.map(|e| format!("- {}", e))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
indoc! {r#"
|
||||||
|
You are an assistant that summarizes the recent conversation into a tooltip.
|
||||||
|
Given the last two messages, reply with only a short tooltip (up to 4 words)
|
||||||
|
describing what is happening now.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
"#}
|
||||||
|
.to_string()
|
||||||
|
+ &examples
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a tooltip summarizing the last two messages in the session,
|
||||||
|
/// including any tool calls or results.
|
||||||
|
pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderError> {
|
||||||
|
// Need at least two messages to summarize
|
||||||
|
if messages.len() < 2 {
|
||||||
|
return Err(ProviderError::ExecutionError(
|
||||||
|
"Need at least two messages to generate a tooltip".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to render a single message's content
|
||||||
|
fn render_message(m: &Message) -> String {
|
||||||
|
let mut parts = Vec::new();
|
||||||
|
for content in m.content.iter() {
|
||||||
|
match content {
|
||||||
|
MessageContent::Text(text_block) => {
|
||||||
|
let txt = text_block.text.trim();
|
||||||
|
if !txt.is_empty() {
|
||||||
|
parts.push(txt.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MessageContent::ToolRequest(req) => {
|
||||||
|
if let Ok(tool_call) = &req.tool_call {
|
||||||
|
parts.push(format!(
|
||||||
|
"called tool '{}' with args {}",
|
||||||
|
tool_call.name, tool_call.arguments
|
||||||
|
));
|
||||||
|
} else if let Err(e) = &req.tool_call {
|
||||||
|
parts.push(format!("tool request error: {}", e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MessageContent::ToolResponse(resp) => match &resp.tool_result {
|
||||||
|
Ok(contents) => {
|
||||||
|
let results: Vec<String> = contents
|
||||||
|
.iter()
|
||||||
|
.map(|c| match c {
|
||||||
|
Content::Text(t) => t.text.clone(),
|
||||||
|
Content::Image(_) => "[image]".to_string(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
parts.push(format!("tool responded with: {}", results.join(" ")));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
parts.push(format!("tool error: {}", e));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => {} // ignore other variants
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let role = match m.role {
|
||||||
|
Role::User => "User",
|
||||||
|
Role::Assistant => "Assistant",
|
||||||
|
};
|
||||||
|
|
||||||
|
format!("{}: {}", role, parts.join("; "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take the last two messages (in correct chronological order)
|
||||||
|
let rendered: Vec<String> = messages
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.take(2)
|
||||||
|
.map(render_message)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.into_iter()
|
||||||
|
.rev()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let system_prompt = build_system_prompt();
|
||||||
|
|
||||||
|
let user_msg_text = format!(
|
||||||
|
"Here are the last two messages:\n{}\n\nTooltip:",
|
||||||
|
rendered.join("\n")
|
||||||
|
);
|
||||||
|
|
||||||
|
// Instantiate the provider
|
||||||
|
let model_cfg = ModelConfig::new("goose-gpt-4-1".to_string()).with_temperature(Some(0.0));
|
||||||
|
let provider = DatabricksProvider::from_env(model_cfg)?;
|
||||||
|
|
||||||
|
// Schema wrapping our tooltip string
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"tooltip": { "type": "string" }
|
||||||
|
},
|
||||||
|
"required": ["tooltip"],
|
||||||
|
"additionalProperties": false
|
||||||
|
});
|
||||||
|
|
||||||
|
// Call extract
|
||||||
|
let user_msg = Message::user().with_text(&user_msg_text);
|
||||||
|
let resp = provider
|
||||||
|
.extract(&system_prompt, &[user_msg], &schema)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Pull out the tooltip field
|
||||||
|
let obj = resp
|
||||||
|
.data
|
||||||
|
.as_object()
|
||||||
|
.ok_or_else(|| ProviderError::ResponseParseError("Expected JSON object".into()))?;
|
||||||
|
|
||||||
|
let tooltip = obj
|
||||||
|
.get("tooltip")
|
||||||
|
.and_then(Value::as_str)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
ProviderError::ResponseParseError("Missing or non-string `tooltip` field".into())
|
||||||
|
})?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
Ok(tooltip)
|
||||||
|
}
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
mod completion;
|
mod completion;
|
||||||
mod message;
|
pub mod extractors;
|
||||||
|
pub mod message;
|
||||||
mod model;
|
mod model;
|
||||||
mod prompt_template;
|
mod prompt_template;
|
||||||
mod providers;
|
pub mod providers;
|
||||||
pub mod types;
|
pub mod types;
|
||||||
|
|
||||||
pub use completion::completion;
|
pub use completion::completion;
|
||||||
|
|||||||
@@ -43,6 +43,23 @@ impl ProviderCompleteResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Response from a structured‐extraction call
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ProviderExtractResponse {
|
||||||
|
/// The extracted JSON object
|
||||||
|
pub data: serde_json::Value,
|
||||||
|
/// Which model produced it
|
||||||
|
pub model: String,
|
||||||
|
/// Token usage stats
|
||||||
|
pub usage: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderExtractResponse {
|
||||||
|
pub fn new(data: serde_json::Value, model: String, usage: Usage) -> Self {
|
||||||
|
Self { data, model, usage }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Base trait for AI providers (OpenAI, Anthropic, etc)
|
/// Base trait for AI providers (OpenAI, Anthropic, etc)
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Provider: Send + Sync {
|
pub trait Provider: Send + Sync {
|
||||||
@@ -65,6 +82,27 @@ pub trait Provider: Send + Sync {
|
|||||||
messages: &[Message],
|
messages: &[Message],
|
||||||
tools: &[Tool],
|
tools: &[Tool],
|
||||||
) -> Result<ProviderCompleteResponse, ProviderError>;
|
) -> Result<ProviderCompleteResponse, ProviderError>;
|
||||||
|
|
||||||
|
/// Structured extraction: always JSON‐Schema
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `system` – system prompt guiding the extraction task
|
||||||
|
/// * `messages` – conversation history
|
||||||
|
/// * `schema` – a JSON‐Schema for the expected output.
|
||||||
|
/// Will set strict=true for OpenAI & Databricks.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
/// * `ProviderError::ContextLengthExceeded` if the prompt is too large
|
||||||
|
/// * other `ProviderError` variants for API/network failures
|
||||||
|
async fn extract(
|
||||||
|
&self,
|
||||||
|
system: &str,
|
||||||
|
messages: &[Message],
|
||||||
|
schema: &serde_json::Value,
|
||||||
|
) -> Result<ProviderExtractResponse, ProviderError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use anyhow::Result;
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::{Client, StatusCode};
|
use reqwest::{Client, StatusCode};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::{json, Value};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
@@ -15,17 +15,15 @@ use super::{
|
|||||||
use crate::{
|
use crate::{
|
||||||
message::Message,
|
message::Message,
|
||||||
model::ModelConfig,
|
model::ModelConfig,
|
||||||
providers::{Provider, ProviderCompleteResponse, Usage},
|
providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
|
||||||
types::core::Tool,
|
types::core::Tool,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-meta-llama-3-3-70b-instruct";
|
pub const DATABRICKS_DEFAULT_MODEL: &str = "databricks-claude-3-7-sonnet";
|
||||||
// Databricks can passthrough to a wide range of models, we only provide the default
|
// Databricks can passthrough to a wide range of models, we only provide the default
|
||||||
pub const _DATABRICKS_KNOWN_MODELS: &[&str] = &[
|
pub const _DATABRICKS_KNOWN_MODELS: &[&str] = &[
|
||||||
"databricks-meta-llama-3-3-70b-instruct",
|
"databricks-meta-llama-3-3-70b-instruct",
|
||||||
"databricks-meta-llama-3-1-405b-instruct",
|
"databricks-claude-3-7-sonnet",
|
||||||
"databricks-dbrx-instruct",
|
|
||||||
"databricks-mixtral-8x7b-instruct",
|
|
||||||
];
|
];
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -215,4 +213,63 @@ impl Provider for DatabricksProvider {
|
|||||||
|
|
||||||
Ok(ProviderCompleteResponse::new(message, model, usage))
|
Ok(ProviderCompleteResponse::new(message, model, usage))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn extract(
|
||||||
|
&self,
|
||||||
|
system: &str,
|
||||||
|
messages: &[Message],
|
||||||
|
schema: &Value,
|
||||||
|
) -> Result<ProviderExtractResponse, ProviderError> {
|
||||||
|
// 1. Build base payload (no tools)
|
||||||
|
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
|
||||||
|
|
||||||
|
// 2. Inject strict JSON‐Schema wrapper
|
||||||
|
payload
|
||||||
|
.as_object_mut()
|
||||||
|
.expect("payload must be an object")
|
||||||
|
.insert(
|
||||||
|
"response_format".to_string(),
|
||||||
|
json!({
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": {
|
||||||
|
"name": "extraction",
|
||||||
|
"schema": schema,
|
||||||
|
"strict": true
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// 3. Call OpenAI
|
||||||
|
let response = self.post(payload.clone()).await?;
|
||||||
|
|
||||||
|
// 4. Extract the assistant’s `content` and parse it into JSON
|
||||||
|
let msg = &response["choices"][0]["message"];
|
||||||
|
let raw = msg.get("content").cloned().ok_or_else(|| {
|
||||||
|
ProviderError::ResponseParseError("Missing content in extract response".into())
|
||||||
|
})?;
|
||||||
|
let data = match raw {
|
||||||
|
Value::String(s) => serde_json::from_str(&s)
|
||||||
|
.map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
|
||||||
|
Value::Object(_) | Value::Array(_) => raw,
|
||||||
|
other => {
|
||||||
|
return Err(ProviderError::ResponseParseError(format!(
|
||||||
|
"Unexpected content type: {:?}",
|
||||||
|
other
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// 5. Gather usage & model info
|
||||||
|
let usage = match get_usage(&response) {
|
||||||
|
Ok(u) => u,
|
||||||
|
Err(ProviderError::UsageError(e)) => {
|
||||||
|
tracing::debug!("Failed to get usage in extract: {}", e);
|
||||||
|
Usage::default()
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
};
|
||||||
|
let model = get_model(&response);
|
||||||
|
|
||||||
|
Ok(ProviderExtractResponse::new(data, model, usage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ pub enum ProviderError {
|
|||||||
|
|
||||||
#[error("Usage data error: {0}")]
|
#[error("Usage data error: {0}")]
|
||||||
UsageError(String),
|
UsageError(String),
|
||||||
|
|
||||||
|
#[error("Invalid response: {0}")]
|
||||||
|
ResponseParseError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<anyhow::Error> for ProviderError {
|
impl From<anyhow::Error> for ProviderError {
|
||||||
|
|||||||
@@ -200,6 +200,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Convert internal Tool format to OpenAI's API tool specification
|
/// Convert internal Tool format to OpenAI's API tool specification
|
||||||
|
/// https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/api-reference#functionobject
|
||||||
pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
|
pub fn format_tools(tools: &[Tool]) -> anyhow::Result<Vec<Value>> {
|
||||||
let mut tool_names = std::collections::HashSet::new();
|
let mut tool_names = std::collections::HashSet::new();
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
|
|||||||
@@ -6,5 +6,5 @@ pub mod formats;
|
|||||||
pub mod openai;
|
pub mod openai;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
||||||
pub use base::{Provider, ProviderCompleteResponse, Usage};
|
pub use base::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage};
|
||||||
pub use factory::create;
|
pub use factory::create;
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use std::{collections::HashMap, time::Duration};
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::Value;
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
errors::ProviderError,
|
errors::ProviderError,
|
||||||
@@ -13,20 +13,12 @@ use super::{
|
|||||||
use crate::{
|
use crate::{
|
||||||
message::Message,
|
message::Message,
|
||||||
model::ModelConfig,
|
model::ModelConfig,
|
||||||
providers::{Provider, ProviderCompleteResponse, Usage},
|
providers::{Provider, ProviderCompleteResponse, ProviderExtractResponse, Usage},
|
||||||
types::core::Tool,
|
types::core::Tool,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
|
pub const OPEN_AI_DEFAULT_MODEL: &str = "gpt-4o";
|
||||||
pub const _OPEN_AI_KNOWN_MODELS: &[&str] = &[
|
pub const _OPEN_AI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4.1", "o1", "o3", "o4-mini"];
|
||||||
"gpt-4o",
|
|
||||||
"gpt-4o-mini",
|
|
||||||
"gpt-4-turbo",
|
|
||||||
"gpt-3.5-turbo",
|
|
||||||
"o1",
|
|
||||||
"o3",
|
|
||||||
"o4-mini",
|
|
||||||
];
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct OpenAiProvider {
|
pub struct OpenAiProvider {
|
||||||
@@ -146,6 +138,65 @@ impl Provider for OpenAiProvider {
|
|||||||
emit_debug_trace(&self.model, &payload, &response, &usage);
|
emit_debug_trace(&self.model, &payload, &response, &usage);
|
||||||
Ok(ProviderCompleteResponse::new(message, model, usage))
|
Ok(ProviderCompleteResponse::new(message, model, usage))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn extract(
|
||||||
|
&self,
|
||||||
|
system: &str,
|
||||||
|
messages: &[Message],
|
||||||
|
schema: &Value,
|
||||||
|
) -> Result<ProviderExtractResponse, ProviderError> {
|
||||||
|
// 1. Build base payload (no tools)
|
||||||
|
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
|
||||||
|
|
||||||
|
// 2. Inject strict JSON‐Schema wrapper
|
||||||
|
payload
|
||||||
|
.as_object_mut()
|
||||||
|
.expect("payload must be an object")
|
||||||
|
.insert(
|
||||||
|
"response_format".to_string(),
|
||||||
|
json!({
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": {
|
||||||
|
"name": "extraction",
|
||||||
|
"schema": schema,
|
||||||
|
"strict": true
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// 3. Call OpenAI
|
||||||
|
let response = self.post(payload.clone()).await?;
|
||||||
|
|
||||||
|
// 4. Extract the assistant’s `content` and parse it into JSON
|
||||||
|
let msg = &response["choices"][0]["message"];
|
||||||
|
let raw = msg.get("content").cloned().ok_or_else(|| {
|
||||||
|
ProviderError::ResponseParseError("Missing content in extract response".into())
|
||||||
|
})?;
|
||||||
|
let data = match raw {
|
||||||
|
Value::String(s) => serde_json::from_str(&s)
|
||||||
|
.map_err(|e| ProviderError::ResponseParseError(format!("Invalid JSON: {}", e)))?,
|
||||||
|
Value::Object(_) | Value::Array(_) => raw,
|
||||||
|
other => {
|
||||||
|
return Err(ProviderError::ResponseParseError(format!(
|
||||||
|
"Unexpected content type: {:?}",
|
||||||
|
other
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// 5. Gather usage & model info
|
||||||
|
let usage = match get_usage(&response) {
|
||||||
|
Ok(u) => u,
|
||||||
|
Err(ProviderError::UsageError(e)) => {
|
||||||
|
tracing::debug!("Failed to get usage in extract: {}", e);
|
||||||
|
Usage::default()
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
};
|
||||||
|
let model = get_model(&response);
|
||||||
|
|
||||||
|
Ok(ProviderExtractResponse::new(data, model, usage))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_custom_headers(s: String) -> HashMap<String, String> {
|
fn parse_custom_headers(s: String) -> HashMap<String, String> {
|
||||||
|
|||||||
56
crates/goose-llm/tests/extract_session_name.rs
Normal file
56
crates/goose-llm/tests/extract_session_name.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use dotenv::dotenv;
|
||||||
|
use goose_llm::extractors::generate_session_name;
|
||||||
|
use goose_llm::message::Message;
|
||||||
|
use goose_llm::providers::errors::ProviderError;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_generate_session_name_success() -> Result<(), ProviderError> {
|
||||||
|
// Load .env for Databricks credentials
|
||||||
|
dotenv().ok();
|
||||||
|
let has_creds =
|
||||||
|
std::env::var("DATABRICKS_HOST").is_ok() && std::env::var("DATABRICKS_TOKEN").is_ok();
|
||||||
|
if !has_creds {
|
||||||
|
println!("Skipping generate_session_name test – Databricks creds not set");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a few messages with at least two user messages
|
||||||
|
let messages = vec![
|
||||||
|
Message::user().with_text("Hello, how are you?"),
|
||||||
|
Message::assistant().with_text("I’m fine, thanks!"),
|
||||||
|
Message::user().with_text("What’s the weather in New York tomorrow?"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let name = generate_session_name(&messages).await?;
|
||||||
|
println!("Generated session name: {:?}", name);
|
||||||
|
|
||||||
|
// Should be non-empty and at most 4 words
|
||||||
|
let name = name.trim();
|
||||||
|
assert!(!name.is_empty(), "Name must not be empty");
|
||||||
|
let word_count = name.split_whitespace().count();
|
||||||
|
assert!(
|
||||||
|
word_count <= 4,
|
||||||
|
"Name must be 4 words or less, got {}: {}",
|
||||||
|
word_count,
|
||||||
|
name
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_generate_session_name_no_user() {
|
||||||
|
// No user messages → expect ExecutionError
|
||||||
|
let messages = vec![
|
||||||
|
Message::assistant().with_text("System starting…"),
|
||||||
|
Message::assistant().with_text("All systems go."),
|
||||||
|
];
|
||||||
|
|
||||||
|
let err = generate_session_name(&messages).await;
|
||||||
|
assert!(
|
||||||
|
matches!(err, Err(ProviderError::ExecutionError(_))),
|
||||||
|
"Expected ExecutionError when there are no user messages, got: {:?}",
|
||||||
|
err
|
||||||
|
);
|
||||||
|
}
|
||||||
69
crates/goose-llm/tests/extract_tooltip.rs
Normal file
69
crates/goose-llm/tests/extract_tooltip.rs
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use dotenv::dotenv;
|
||||||
|
use goose_llm::extractors::generate_tooltip;
|
||||||
|
use goose_llm::message::{Message, MessageContent, ToolRequest};
|
||||||
|
use goose_llm::providers::errors::ProviderError;
|
||||||
|
use goose_llm::types::core::{Content, ToolCall};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_generate_tooltip_simple() -> Result<(), ProviderError> {
|
||||||
|
// Skip if no Databricks creds
|
||||||
|
dotenv().ok();
|
||||||
|
if std::env::var("DATABRICKS_HOST").is_err() || std::env::var("DATABRICKS_TOKEN").is_err() {
|
||||||
|
println!("Skipping simple tooltip test – Databricks creds not set");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two plain-text messages
|
||||||
|
let messages = vec![
|
||||||
|
Message::user().with_text("Hello, how are you?"),
|
||||||
|
Message::assistant().with_text("I'm fine, thanks! How can I help?"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let tooltip = generate_tooltip(&messages).await?;
|
||||||
|
println!("Generated tooltip: {:?}", tooltip);
|
||||||
|
|
||||||
|
assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
|
||||||
|
assert!(
|
||||||
|
tooltip.len() < 100,
|
||||||
|
"Tooltip should be reasonably short (<100 chars)"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_generate_tooltip_with_tools() -> Result<(), ProviderError> {
|
||||||
|
// Skip if no Databricks creds
|
||||||
|
dotenv().ok();
|
||||||
|
if std::env::var("DATABRICKS_HOST").is_err() || std::env::var("DATABRICKS_TOKEN").is_err() {
|
||||||
|
println!("Skipping tool‐based tooltip test – Databricks creds not set");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1) Assistant message with a tool request
|
||||||
|
let mut tool_req_msg = Message::assistant();
|
||||||
|
let req = ToolRequest {
|
||||||
|
id: "1".to_string(),
|
||||||
|
tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))),
|
||||||
|
};
|
||||||
|
tool_req_msg.content.push(MessageContent::ToolRequest(req));
|
||||||
|
|
||||||
|
// 2) User message with the tool response
|
||||||
|
let tool_resp_msg = Message::user().with_tool_response(
|
||||||
|
"1",
|
||||||
|
Ok(vec![Content::text("The current time is 12:00 UTC")]),
|
||||||
|
);
|
||||||
|
|
||||||
|
let messages = vec![tool_req_msg, tool_resp_msg];
|
||||||
|
|
||||||
|
let tooltip = generate_tooltip(&messages).await?;
|
||||||
|
println!("Generated tooltip (tools): {:?}", tooltip);
|
||||||
|
|
||||||
|
assert!(!tooltip.trim().is_empty(), "Tooltip must not be empty");
|
||||||
|
assert!(
|
||||||
|
tooltip.len() < 100,
|
||||||
|
"Tooltip should be reasonably short (<100 chars)"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
379
crates/goose-llm/tests/providers_complete.rs
Normal file
379
crates/goose-llm/tests/providers_complete.rs
Normal file
@@ -0,0 +1,379 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use dotenv::dotenv;
|
||||||
|
use goose_llm::message::{Message, MessageContent};
|
||||||
|
use goose_llm::providers::base::Provider;
|
||||||
|
use goose_llm::providers::errors::ProviderError;
|
||||||
|
use goose_llm::providers::{databricks, openai};
|
||||||
|
use goose_llm::types::core::{Content, Tool};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
enum TestStatus {
|
||||||
|
Passed,
|
||||||
|
Skipped,
|
||||||
|
Failed,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for TestStatus {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
TestStatus::Passed => write!(f, "✅"),
|
||||||
|
TestStatus::Skipped => write!(f, "⏭️"),
|
||||||
|
TestStatus::Failed => write!(f, "❌"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TestReport {
|
||||||
|
results: Mutex<HashMap<String, TestStatus>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestReport {
|
||||||
|
fn new() -> Arc<Self> {
|
||||||
|
Arc::new(Self {
|
||||||
|
results: Mutex::new(HashMap::new()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_status(&self, provider: &str, status: TestStatus) {
|
||||||
|
let mut results = self.results.lock().unwrap();
|
||||||
|
results.insert(provider.to_string(), status);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_pass(&self, provider: &str) {
|
||||||
|
self.record_status(provider, TestStatus::Passed);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_skip(&self, provider: &str) {
|
||||||
|
self.record_status(provider, TestStatus::Skipped);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_fail(&self, provider: &str) {
|
||||||
|
self.record_status(provider, TestStatus::Failed);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn print_summary(&self) {
|
||||||
|
println!("\n============== Providers ==============");
|
||||||
|
let results = self.results.lock().unwrap();
|
||||||
|
let mut providers: Vec<_> = results.iter().collect();
|
||||||
|
providers.sort_by(|a, b| a.0.cmp(b.0));
|
||||||
|
|
||||||
|
for (provider, status) in providers {
|
||||||
|
println!("{} {}", status, provider);
|
||||||
|
}
|
||||||
|
println!("=======================================\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lazy_static::lazy_static! {
    // Global report shared by every provider test in this binary; printed
    // from the #[ctor::dtor] hook at process exit.
    static ref TEST_REPORT: Arc<TestReport> = TestReport::new();
    // Serializes environment-variable mutation, since std::env is process-global.
    static ref ENV_LOCK: Mutex<()> = Mutex::new(());
}
|
||||||
|
|
||||||
|
/// Generic test harness for any Provider implementation
struct ProviderTester {
    // Provider under test, behind a trait object so one harness covers all providers.
    provider: Arc<dyn Provider>,
    // Human-readable provider name, used in log output and provider-specific branches.
    name: String,
}
|
||||||
|
|
||||||
|
impl ProviderTester {
|
||||||
|
fn new<T: Provider + Send + Sync + 'static>(provider: T, name: String) -> Self {
|
||||||
|
Self {
|
||||||
|
provider: Arc::new(provider),
|
||||||
|
name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn test_basic_response(&self) -> Result<()> {
|
||||||
|
let message = Message::user().with_text("Just say hello!");
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.provider
|
||||||
|
.complete("You are a helpful assistant.", &[message], &[])
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// For a basic response, we expect a single text response
|
||||||
|
assert_eq!(
|
||||||
|
response.message.content.len(),
|
||||||
|
1,
|
||||||
|
"Expected single content item in response"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify we got a text response
|
||||||
|
assert!(
|
||||||
|
matches!(response.message.content[0], MessageContent::Text(_)),
|
||||||
|
"Expected text response"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn test_tool_usage(&self) -> Result<()> {
|
||||||
|
let weather_tool = Tool::new(
|
||||||
|
"get_weather",
|
||||||
|
"Get the weather for a location",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"required": ["location"],
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
let message = Message::user().with_text("What's the weather like in San Francisco?");
|
||||||
|
|
||||||
|
let response1 = self
|
||||||
|
.provider
|
||||||
|
.complete(
|
||||||
|
"You are a helpful weather assistant.",
|
||||||
|
&[message.clone()],
|
||||||
|
&[weather_tool.clone()],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
println!("=== {}::reponse1 ===", self.name);
|
||||||
|
dbg!(&response1);
|
||||||
|
println!("===================");
|
||||||
|
|
||||||
|
// Verify we got a tool request
|
||||||
|
assert!(
|
||||||
|
response1
|
||||||
|
.message
|
||||||
|
.content
|
||||||
|
.iter()
|
||||||
|
.any(|content| matches!(content, MessageContent::ToolRequest(_))),
|
||||||
|
"Expected tool request in response"
|
||||||
|
);
|
||||||
|
|
||||||
|
let id = &response1
|
||||||
|
.message
|
||||||
|
.content
|
||||||
|
.iter()
|
||||||
|
.filter_map(|message| message.as_tool_request())
|
||||||
|
.last()
|
||||||
|
.expect("got tool request")
|
||||||
|
.id;
|
||||||
|
|
||||||
|
let weather = Message::user().with_tool_response(
|
||||||
|
id,
|
||||||
|
Ok(vec![Content::text(
|
||||||
|
"
|
||||||
|
50°F°C
|
||||||
|
Precipitation: 0%
|
||||||
|
Humidity: 84%
|
||||||
|
Wind: 2 mph
|
||||||
|
Weather
|
||||||
|
Saturday 9:00 PM
|
||||||
|
Clear",
|
||||||
|
)]),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify we construct a valid payload including the request/response pair for the next inference
|
||||||
|
let response2 = self
|
||||||
|
.provider
|
||||||
|
.complete(
|
||||||
|
"You are a helpful weather assistant.",
|
||||||
|
&[message, response1.message, weather],
|
||||||
|
&[weather_tool],
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
println!("=== {}::reponse2 ===", self.name);
|
||||||
|
dbg!(&response2);
|
||||||
|
println!("===================");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
response2
|
||||||
|
.message
|
||||||
|
.content
|
||||||
|
.iter()
|
||||||
|
.any(|content| matches!(content, MessageContent::Text(_))),
|
||||||
|
"Expected text for final response"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn test_context_length_exceeded_error(&self) -> Result<()> {
|
||||||
|
// Google Gemini has a really long context window
|
||||||
|
let large_message_content = if self.name.to_lowercase() == "google" {
|
||||||
|
"hello ".repeat(1_300_000)
|
||||||
|
} else {
|
||||||
|
"hello ".repeat(300_000)
|
||||||
|
};
|
||||||
|
|
||||||
|
let messages = vec![
|
||||||
|
Message::user().with_text("hi there. what is 2 + 2?"),
|
||||||
|
Message::assistant().with_text("hey! I think it's 4."),
|
||||||
|
Message::user().with_text(&large_message_content),
|
||||||
|
Message::assistant().with_text("heyy!!"),
|
||||||
|
// Messages before this mark should be truncated
|
||||||
|
Message::user().with_text("what's the meaning of life?"),
|
||||||
|
Message::assistant().with_text("the meaning of life is 42"),
|
||||||
|
Message::user().with_text(
|
||||||
|
"did I ask you what's 2+2 in this message history? just respond with 'yes' or 'no'",
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded
|
||||||
|
let result = self
|
||||||
|
.provider
|
||||||
|
.complete("You are a helpful assistant.", &messages, &[])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Print some debug info
|
||||||
|
println!("=== {}::context_length_exceeded_error ===", self.name);
|
||||||
|
dbg!(&result);
|
||||||
|
println!("===================");
|
||||||
|
|
||||||
|
// Ollama truncates by default even when the context window is exceeded
|
||||||
|
if self.name.to_lowercase() == "ollama" {
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"Expected to succeed because of default truncation"
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
result.is_err(),
|
||||||
|
"Expected error when context window is exceeded"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
matches!(result.unwrap_err(), ProviderError::ContextLengthExceeded(_)),
|
||||||
|
"Expected error to be ContextLengthExceeded"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run all provider tests
|
||||||
|
async fn run_test_suite(&self) -> Result<()> {
|
||||||
|
self.test_basic_response().await?;
|
||||||
|
self.test_tool_usage().await?;
|
||||||
|
self.test_context_length_exceeded_error().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_env() {
|
||||||
|
if let Ok(path) = dotenv() {
|
||||||
|
println!("Loaded environment from {:?}", path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to run a provider test with proper error handling and reporting
///
/// Flow: pessimistically record Failed (so a panic still shows up as a failure),
/// mutate the process environment under ENV_LOCK, construct the provider (or
/// skip when required credentials are absent), restore the environment, then
/// run the suite and record the real outcome.
async fn test_provider<F, T>(
    name: &str,
    required_vars: &[&str],
    env_modifications: Option<HashMap<&str, Option<String>>>,
    provider_fn: F,
) -> Result<()>
where
    F: FnOnce() -> T,
    T: Provider + Send + Sync + 'static,
{
    // We start off as failed, so that if the process panics it is seen as a failure
    TEST_REPORT.record_fail(name);

    // Take exclusive access to environment modifications
    // (std::env is process-global, so concurrent tests must not interleave here)
    let lock = ENV_LOCK.lock().unwrap();

    load_env();

    // Save current environment state for required vars and modified vars
    let mut original_env = HashMap::new();
    for &var in required_vars {
        if let Ok(val) = std::env::var(var) {
            original_env.insert(var, val);
        }
    }
    if let Some(mods) = &env_modifications {
        for &var in mods.keys() {
            if let Ok(val) = std::env::var(var) {
                original_env.insert(var, val);
            }
        }
    }

    // Apply any environment modifications
    if let Some(mods) = &env_modifications {
        for (&var, value) in mods.iter() {
            match value {
                Some(val) => std::env::set_var(var, val),
                None => std::env::remove_var(var),
            }
        }
    }

    // Setup the provider
    let missing_vars = required_vars.iter().any(|var| std::env::var(var).is_err());
    if missing_vars {
        println!("Skipping {} tests - credentials not configured", name);
        TEST_REPORT.record_skip(name);
        // NOTE(review): this early return skips the environment-restore step
        // below, so any env_modifications applied above remain in effect —
        // confirm this is intended for the skip path.
        return Ok(());
    }

    // Construct the provider while the modified environment is still in place.
    let provider = provider_fn();

    // Restore original environment
    for (&var, value) in original_env.iter() {
        std::env::set_var(var, value);
    }
    if let Some(mods) = env_modifications {
        for &var in mods.keys() {
            // Vars that were introduced (not overwritten) by the modifications
            // have no saved original, so they are removed again.
            if !original_env.contains_key(var) {
                std::env::remove_var(var);
            }
        }
    }

    // Release the env lock before the (potentially slow) network suite runs;
    // the provider captured its configuration at construction time.
    std::mem::drop(lock);

    let tester = ProviderTester::new(provider, name.to_string());
    match tester.run_test_suite().await {
        Ok(_) => {
            TEST_REPORT.record_pass(name);
            Ok(())
        }
        Err(e) => {
            println!("{} test failed: {}", name, e);
            TEST_REPORT.record_fail(name);
            Err(e)
        }
    }
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn openai_complete() -> Result<()> {
|
||||||
|
test_provider(
|
||||||
|
"OpenAI",
|
||||||
|
&["OPENAI_API_KEY"],
|
||||||
|
None,
|
||||||
|
openai::OpenAiProvider::default,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn databricks_complete() -> Result<()> {
|
||||||
|
test_provider(
|
||||||
|
"Databricks",
|
||||||
|
&["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
|
||||||
|
None,
|
||||||
|
databricks::DatabricksProvider::default,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print the final test report
// Runs once at process exit via the ctor crate's destructor hook, after all
// #[tokio::test] functions have finished.
#[ctor::dtor]
fn print_test_report() {
    TEST_REPORT.print_summary();
}
|
||||||
195
crates/goose-llm/tests/providers_extract.rs
Normal file
195
crates/goose-llm/tests/providers_extract.rs
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
// tests/providers_extract.rs
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use dotenv::dotenv;
|
||||||
|
use goose_llm::message::Message;
|
||||||
|
use goose_llm::providers::base::Provider;
|
||||||
|
use goose_llm::providers::{databricks::DatabricksProvider, openai::OpenAiProvider};
|
||||||
|
use goose_llm::ModelConfig;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Providers covered by the structured-extraction tests.
#[derive(Debug, PartialEq, Copy, Clone)]
enum ProviderType {
    OpenAi,
    Databricks,
}
|
||||||
|
|
||||||
|
impl ProviderType {
|
||||||
|
fn required_env(&self) -> &'static [&'static str] {
|
||||||
|
match self {
|
||||||
|
ProviderType::OpenAi => &["OPENAI_API_KEY"],
|
||||||
|
ProviderType::Databricks => &["DATABRICKS_HOST", "DATABRICKS_TOKEN"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_provider(&self, cfg: ModelConfig) -> Result<Arc<dyn Provider>> {
|
||||||
|
Ok(match self {
|
||||||
|
ProviderType::OpenAi => Arc::new(OpenAiProvider::from_env(cfg)?),
|
||||||
|
ProviderType::Databricks => Arc::new(DatabricksProvider::from_env(cfg)?),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true when every variable in `required` is present in the
/// environment; otherwise logs the missing names and returns false.
fn check_required_env_vars(required: &[&str]) -> bool {
    let missing: Vec<&str> = required
        .iter()
        .copied()
        .filter(|var| std::env::var(var).is_err())
        .collect();
    if missing.is_empty() {
        return true;
    }
    println!("Skipping test; missing env vars: {:?}", missing);
    false
}
|
||||||
|
|
||||||
|
// --- Shared inputs for "paper" task ---
// System prompt instructing the model to emit paper metadata as JSON.
const PAPER_SYSTEM: &str =
    "You are an expert at structured data extraction. Extract the metadata of a research paper into JSON.";
// Fixture text containing a title, three authors, an abstract, and keywords,
// matching the four fields required by paper_schema().
const PAPER_TEXT: &str =
    "Application of Quantum Algorithms in Interstellar Navigation: A New Frontier \
     by Dr. Stella Voyager, Dr. Nova Star, Dr. Lyra Hunter. Abstract: This paper \
     investigates the utilization of quantum algorithms to improve interstellar \
     navigation systems. Keywords: Quantum algorithms, interstellar navigation, \
     space-time anomalies, quantum superposition, quantum entanglement, space travel.";
|
||||||
|
|
||||||
|
/// JSON Schema for the "paper" extraction task: four required fields
/// (title/abstract strings, authors/keywords string arrays), no extras.
fn paper_schema() -> Value {
    json!({
        "type": "object",
        "properties": {
            "title": { "type": "string" },
            "authors": { "type": "array", "items": { "type": "string" } },
            "abstract": { "type": "string" },
            "keywords": { "type": "array", "items": { "type": "string" } }
        },
        "required": ["title","authors","abstract","keywords"],
        "additionalProperties": false
    })
}
|
||||||
|
|
||||||
|
// --- Shared inputs for "UI" task ---
// System prompt asking the model to emit a JSON-driven UI tree.
const UI_SYSTEM: &str = "You are a UI generator AI. Convert the user input into a JSON-driven UI.";
// User request; the validator expects the resulting root node's type to be "form".
const UI_TEXT: &str = "Make a User Profile Form";
|
||||||
|
|
||||||
|
/// Recursive JSON Schema for the "UI" extraction task: each node has a type
/// (from a fixed enum), a label, child nodes (self-referential via "$ref": "#"),
/// and name/value attribute pairs; all four fields are required.
fn ui_schema() -> Value {
    json!({
        "type": "object",
        "properties": {
            "type": {
                "type": "string",
                "enum": ["div","button","header","section","field","form"]
            },
            "label": { "type": "string" },
            "children": {
                "type": "array",
                "items": { "$ref": "#" }
            },
            "attributes": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" },
                        "value": { "type": "string" }
                    },
                    "required": ["name","value"],
                    "additionalProperties": false
                }
            }
        },
        "required": ["type","label","children","attributes"],
        "additionalProperties": false
    })
}
|
||||||
|
|
||||||
|
/// Generic runner for any extract task
///
/// Skips (returning Ok) when the provider's credentials are absent; otherwise
/// builds the provider at temperature 0.0, calls `extract` with the given
/// system prompt, user text, and JSON schema, and asserts that `validate`
/// accepts the structured result.
async fn run_extract_test<F>(
    provider_type: ProviderType,
    model: &str,
    system: &'static str,
    user_text: &'static str,
    schema: Value,
    validate: F,
) -> Result<()>
where
    F: Fn(&Value) -> bool,
{
    dotenv().ok();
    if !check_required_env_vars(provider_type.required_env()) {
        return Ok(());
    }

    // Temperature 0.0 keeps the extraction as deterministic as the API allows.
    let cfg = ModelConfig::new(model.to_string()).with_temperature(Some(0.0));
    let provider = provider_type.create_provider(cfg)?;

    let msg = Message::user().with_text(user_text);
    let resp = provider.extract(system, &[msg], &schema).await?;

    println!("[{:?}] extract => {}", provider_type, resp.data);

    assert!(
        validate(&resp.data),
        "{:?} failed validation on {}",
        provider_type,
        resp.data
    );
    Ok(())
}
|
||||||
|
|
||||||
|
/// Helper for the "paper" task
|
||||||
|
async fn run_extract_paper_test(provider: ProviderType, model: &str) -> Result<()> {
|
||||||
|
run_extract_test(
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
PAPER_SYSTEM,
|
||||||
|
PAPER_TEXT,
|
||||||
|
paper_schema(),
|
||||||
|
|v| {
|
||||||
|
v.as_object()
|
||||||
|
.map(|o| {
|
||||||
|
["title", "authors", "abstract", "keywords"]
|
||||||
|
.iter()
|
||||||
|
.all(|k| o.contains_key(*k))
|
||||||
|
})
|
||||||
|
.unwrap_or(false)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper for the "UI" task
|
||||||
|
async fn run_extract_ui_test(provider: ProviderType, model: &str) -> Result<()> {
|
||||||
|
run_extract_test(provider, model, UI_SYSTEM, UI_TEXT, ui_schema(), |v| {
|
||||||
|
v.as_object()
|
||||||
|
.and_then(|o| o.get("type").and_then(Value::as_str))
|
||||||
|
== Some("form")
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    // Live API tests: each delegates to a shared helper and is skipped at
    // runtime when the provider's required env vars are not configured.

    #[tokio::test]
    async fn openai_extract_paper() -> Result<()> {
        run_extract_paper_test(ProviderType::OpenAi, "gpt-4o").await
    }

    #[tokio::test]
    async fn openai_extract_ui() -> Result<()> {
        run_extract_ui_test(ProviderType::OpenAi, "gpt-4o").await
    }

    #[tokio::test]
    async fn databricks_extract_paper() -> Result<()> {
        run_extract_paper_test(ProviderType::Databricks, "goose-gpt-4-1").await
    }

    #[tokio::test]
    async fn databricks_extract_ui() -> Result<()> {
        run_extract_ui_test(ProviderType::Databricks, "goose-gpt-4-1").await
    }
}
|
||||||
Reference in New Issue
Block a user