[goose-llm] fix image content bug, add optional request_id field (#3439)

Salman Mohammed
2025-07-15 18:06:37 -04:00
committed by GitHub
parent f4e3d06f9e
commit a5d77950db
22 changed files with 482 additions and 512 deletions
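At the API level, the two headline changes combine like this. A minimal Rust sketch, not taken from the diff: `provider_config` and `base64_png` are assumed to be built as in the examples below, and the request id value is arbitrary. Images must now be attached as explicit `MessageContent::Image` parts (path sniffing in message text was removed), and the optional id is forwarded by Databricks as `client_request_id`:

use goose_llm::{completion, message::MessageContent, types::completion::CompletionRequest, Message, ModelConfig};

// Attach an explicit image part; plain-text image paths are no longer auto-loaded.
let msg = Message::user()
    .with_text("What is in this image?")
    .with_content(MessageContent::image(base64_png, "image/png"));

let req = CompletionRequest::new(
    "databricks".to_string(),
    provider_config.clone(),
    ModelConfig::new("goose-claude-4-sonnet".to_string()),
    Some("You are a helpful assistant.".to_string()),
    None,      // no system prompt override
    vec![msg],
    vec![],    // no extensions
)
.with_request_id("trace-123".to_string()); // optional; omit to send no client_request_id

let resp = completion(req).await?;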

View File

@@ -0,0 +1,115 @@
import kotlin.system.measureNanoTime
import kotlinx.coroutines.runBlocking
import uniffi.goose_llm.*
import java.net.URI
import java.net.http.HttpClient
import java.net.http.HttpRequest
import java.net.http.HttpResponse
/* ---------- Goose helpers ---------- */
fun buildProviderConfig(host: String, token: String): String =
"""{ "host": "$host", "token": "$token" }"""
suspend fun timeGooseCall(
modelCfg: ModelConfig,
providerName: String,
providerCfg: String
): Pair<Double, CompletionResponse> {
val req = createCompletionRequest(
providerName,
providerCfg,
modelCfg,
systemPreamble = "You are a helpful assistant.",
messages = listOf(
Message(
Role.USER,
System.currentTimeMillis() / 1000,
listOf(MessageContent.Text(TextContent("Write me a 1000 word chapter about learning Go vs Rust in the world of LLMs and AI.")))
)
),
extensions = emptyList()
)
lateinit var resp: CompletionResponse
val wallMs = measureNanoTime { resp = completion(req) } / 1_000_000.0
return wallMs to resp
}
/* ---------- OpenAI helpers ---------- */
fun timeOpenAiCall(client: HttpClient, apiKey: String): Double {
val body = """
{
"model": "gpt-4.1",
"max_tokens": 500,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write me a 1000 word chapter about learning Go vs Rust in the world of LLMs and AI."}
]
}
""".trimIndent()
val request = HttpRequest.newBuilder()
.uri(URI.create("https://api.openai.com/v1/chat/completions"))
.header("Authorization", "Bearer $apiKey")
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(body))
.build()
val wallMs = measureNanoTime {
client.send(request, HttpResponse.BodyHandlers.ofString())
} / 1_000_000.0
return wallMs
}
/* ---------- main ---------- */
fun main() = runBlocking {
/* Goose provider setup */
val providerName = "databricks"
val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set")
val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set")
val providerCfg = buildProviderConfig(host, token)
/* OpenAI setup */
val openAiKey = System.getenv("OPENAI_API_KEY") ?: error("OPENAI_API_KEY not set")
val httpClient = HttpClient.newBuilder().build()
val gooseModels = listOf("goose-claude-4-sonnet", "goose-gpt-4-1")
val runsPerModel = 3
/* --- Goose timing --- */
for (model in gooseModels) {
val maxTokens = 500
val cfg = ModelConfig(model, 100_000u, 0.0f, maxTokens)
var wallSum = 0.0
var gooseSum = 0.0
println("=== Goose: $model ===")
repeat(runsPerModel) { run ->
val (wall, resp) = timeGooseCall(cfg, providerName, providerCfg)
val gooseMs = resp.runtimeMetrics.totalTimeSec * 1_000
val overhead = wall - gooseMs
wallSum += wall
gooseSum += gooseMs
println("run ${run + 1}: wall = %.1f ms | goose-llm = %.1f ms | overhead = %.1f ms"
.format(wall, gooseMs, overhead))
}
println("-- avg wall = %.1f ms | avg overhead = %.1f ms --\n"
.format(wallSum / runsPerModel, (wallSum - gooseSum) / runsPerModel))
}
/* --- OpenAI direct timing --- */
var oaSum = 0.0
println("=== OpenAI: gpt-4.1 (direct HTTPS) ===")
repeat(runsPerModel) { run ->
val wall = timeOpenAiCall(httpClient, openAiKey)
oaSum += wall
println("run ${run + 1}: wall = %.1f ms".format(wall))
}
println("-- avg wall = %.1f ms --".format(oaSum / runsPerModel))
}

View File

@@ -1,292 +1,228 @@
import java.io.File
import java.util.Base64
import kotlinx.coroutines.runBlocking
import uniffi.goose_llm.*
fun main() = runBlocking {
val now = System.currentTimeMillis() / 1000
val msgs = listOf(
// 1) User sends a plain-text prompt
Message(
role = Role.USER,
created = now,
content = listOf(
MessageContent.Text(
TextContent("What is 7 x 6?")
)
)
),
// 2) Assistant makes a tool request (ToolReq) with an invalid operation ("doesnotexist")
Message(
role = Role.ASSISTANT,
created = now + 2,
content = listOf(
MessageContent.ToolReq(
ToolRequest(
id = "calc1",
toolCall = """
{
"status": "success",
"value": {
"name": "calculator_extension__toolname",
"arguments": {
"operation": "doesnotexist",
"numbers": [7, 6]
},
"needsApproval": false
}
}
""".trimIndent()
)
)
)
),
// 3) User (on behalf of the tool) responds with the tool result (ToolResp)
Message(
role = Role.USER,
created = now + 3,
content = listOf(
MessageContent.ToolResp(
ToolResponse(
id = "calc1",
toolResult = """
{
"status": "error",
"error": "Invalid value for operation: 'doesnotexist'. Valid values are: ['add', 'subtract', 'multiply', 'divide']"
}
""".trimIndent()
)
)
)
),
// 4) Assistant makes a tool request (ToolReq) to calculate 7×6
Message(
role = Role.ASSISTANT,
created = now + 4,
content = listOf(
MessageContent.ToolReq(
ToolRequest(
id = "calc1",
toolCall = """
{
"status": "success",
"value": {
"name": "calculator_extension__toolname",
"arguments": {
"operation": "multiply",
"numbers": [7, 6]
},
"needsApproval": false
}
}
""".trimIndent()
)
)
)
),
// 5) User (on behalf of the tool) responds with the tool result (ToolResp)
Message(
role = Role.USER,
created = now + 5,
content = listOf(
MessageContent.ToolResp(
ToolResponse(
id = "calc1",
toolResult = """
{
"status": "success",
"value": [
{"type": "text", "text": "42"}
]
}
""".trimIndent()
)
)
)
),
)
printMessages(msgs)
println("---\n")
// Setup provider
val providerName = "databricks"
val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set")
val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set")
val providerConfig = """{"host": "$host", "token": "$token"}"""
println("Provider Name: $providerName")
println("Provider Config: $providerConfig")
val sessionName = generateSessionName(providerName, providerConfig, msgs)
println("\nSession Name: $sessionName")
val tooltip = generateTooltip(providerName, providerConfig, msgs)
println("\nTooltip: $tooltip")
// Completion
val modelName = "goose-gpt-4-1"
val modelConfig = ModelConfig(
modelName,
100000u, // UInt
0.1f, // Float
200 // Int
)
val calculatorTool = createToolConfig(
    name = "calculator",
    description = "Perform basic arithmetic operations",
    inputSchema = """
    {
        "type": "object",
        "required": ["operation", "numbers"],
        "properties": {
            "operation": {
                "type": "string",
                "enum": ["add", "subtract", "multiply", "divide"],
                "description": "The arithmetic operation to perform"
            },
            "numbers": {
                "type": "array",
                "items": { "type": "number" },
                "description": "List of numbers to operate on in order"
            }
        }
    }
    """.trimIndent(),
    approvalMode = ToolApprovalMode.AUTO
)
val calculator_extension = ExtensionConfig(
    name = "calculator_extension",
    instructions = "This extension provides a calculator tool.",
    tools = listOf(calculatorTool)
)
val extensions = listOf(calculator_extension)
val systemPreamble = "You are a helpful assistant."
// Testing with tool calls with an error in tool name
val reqToolErr = createCompletionRequest(
providerName,
providerConfig,
modelConfig,
systemPreamble,
messages = listOf(
Message(
role = Role.USER,
created = now,
content = listOf(
MessageContent.Text(
TextContent("What is 7 x 6?")
)
)
)),
extensions = extensions
)
val respToolErr = completion(reqToolErr)
println("\nCompletion Response (one msg):\n${respToolErr.message}")
println()
val reqAll = createCompletionRequest(
providerName,
providerConfig,
modelConfig,
systemPreamble,
messages = msgs,
extensions = extensions
)
val respAll = completion(reqAll)
println("\nCompletion Response (all msgs):\n${respAll.message}")
println()
// ---- UI Extraction (custom schema) ----
runUiExtraction(providerName, providerConfig)
// --- Prompt Override ---
val prompt_req = createCompletionRequest(
providerName,
providerConfig,
modelConfig,
systemPreamble = null,
systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.",
messages=listOf(
Message(
role = Role.USER,
created = now,
content = listOf(
MessageContent.Text(
TextContent("What's your name?")
)
)
)
),
extensions=emptyList()
)
val prompt_resp = completion(prompt_req)
println("\nPrompt Override Response:\n${prompt_resp.message}")
}
/* ---------- shared helpers ---------- */
fun buildProviderConfig(host: String, token: String, imageFormat: String = "OpenAi"): String = """
{
    "host": "$host",
    "token": "$token",
    "image_format": "$imageFormat"
}
""".trimIndent()
fun calculatorExtension(): ExtensionConfig {
    val calculatorTool = createToolConfig(
        name = "calculator",
        description = "Perform basic arithmetic operations",
        inputSchema = """
        {
            "type": "object",
            "required": ["operation", "numbers"],
            "properties": {
                "operation": {
                    "type": "string",
                    "enum": ["add", "subtract", "multiply", "divide"],
                    "description": "The arithmetic operation to perform"
                },
                "numbers": {
                    "type": "array",
                    "items": { "type": "number" },
                    "description": "List of numbers to operate on in order"
                }
            }
        }
        """.trimIndent(),
        approvalMode = ToolApprovalMode.AUTO
    )
    return ExtensionConfig(
        name = "calculator_extension",
        instructions = "This extension provides a calculator tool.",
        tools = listOf(calculatorTool)
    )
}
/* ---------- demos ---------- */
suspend fun runCalculatorDemo(
modelConfig: ModelConfig,
providerName: String,
providerConfig: String
) {
val now = System.currentTimeMillis() / 1000
val msgs = listOf(
// same conversation you already had
Message(Role.USER, now, listOf(MessageContent.Text(TextContent("What is 7 x 6?")))),
Message(Role.ASSISTANT, now + 2, listOf(MessageContent.ToolReq(
ToolRequest(
id = "calc1",
toolCall = """
{
"status": "success",
"value": {
"name": "calculator_extension__toolname",
"arguments": { "operation": "doesnotexist", "numbers": [7,6] },
"needsApproval": false
}
}
""".trimIndent()
)))),
Message(Role.USER, now + 3, listOf(MessageContent.ToolResp(
ToolResponse(
id = "calc1",
toolResult = """
{
"status": "error",
"error": "Invalid value for operation: 'doesnotexist'. Valid values are: ['add','subtract','multiply','divide']"
}
""".trimIndent()
)))),
Message(Role.ASSISTANT, now + 4, listOf(MessageContent.ToolReq(
ToolRequest(
id = "calc1",
toolCall = """
{
"status": "success",
"value": {
"name": "calculator_extension__toolname",
"arguments": { "operation": "multiply", "numbers": [7,6] },
"needsApproval": false
}
}
""".trimIndent()
)))),
Message(Role.USER, now + 5, listOf(MessageContent.ToolResp(
ToolResponse(
id = "calc1",
toolResult = """
{
"status": "success",
"value": [ { "type": "text", "text": "42" } ]
}
""".trimIndent()
))))
)
/* one-shot prompt with error */
val reqErr = createCompletionRequest(
providerName, providerConfig, modelConfig,
"You are a helpful assistant.",
messages = listOf(msgs.first()),
extensions = listOf(calculatorExtension())
)
println("\n[${modelConfig.modelName}] Calculator (single-msg) → ${completion(reqErr).message}")
/* full conversation */
val reqAll = createCompletionRequest(
providerName, providerConfig, modelConfig,
"You are a helpful assistant.",
messages = msgs,
extensions = listOf(calculatorExtension())
)
println("[${modelConfig.modelName}] Calculator (full chat) → ${completion(reqAll).message}")
}
suspend fun runImageExample(
modelConfig: ModelConfig,
providerName: String,
providerConfig: String
) {
val imagePath = "../../crates/goose/examples/test_assets/test_image.png"
val base64Image = Base64.getEncoder().encodeToString(File(imagePath).readBytes())
val now = System.currentTimeMillis() / 1000
val msgs = listOf(
Message(Role.USER, now, listOf(
MessageContent.Text(TextContent("What is in this image?")),
MessageContent.Image(ImageContent(base64Image, "image/png"))
)),
)
val req = createCompletionRequest(
providerName, providerConfig, modelConfig,
"You are a helpful assistant. Please describe any text you see in the image.",
messages = msgs,
extensions = emptyList()
)
println("\n[${modelConfig.modelName}] Image example → ${completion(req).message}")
}
suspend fun runPromptOverride(
modelConfig: ModelConfig,
providerName: String,
providerConfig: String
) {
val now = System.currentTimeMillis() / 1000
val req = createCompletionRequest(
providerName, providerConfig, modelConfig,
systemPreamble = null,
systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.",
messages = listOf(
Message(Role.USER, now, listOf(MessageContent.Text(TextContent("What's your name?"))))
),
extensions = emptyList()
)
println("\n[${modelConfig.modelName}] Prompt override → ${completion(req).message}")
}
suspend fun runUiExtraction(providerName: String, providerConfig: String) {
    val schema = """
    {
        "type": "object",
        "properties": {
            "type": { "type": "string", "enum": ["div", "button", "header", "section", "field", "form"] },
            "label": { "type": "string" },
            "children": { "type": "array", "items": { "${'$'}ref": "#" } },
            "attributes": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" },
                        "value": { "type": "string" }
                    },
                    "required": ["name", "value"],
                    "additionalProperties": false
                }
            }
        },
        "required": ["type", "label", "children", "attributes"],
        "additionalProperties": false
    }
    """.trimIndent()
    val messages = listOf(
        Message(Role.USER, System.currentTimeMillis() / 1000,
            listOf(MessageContent.Text(TextContent("Make a User Profile Form"))))
    )
    try {
        val res = generateStructuredOutputs(
            providerName = providerName,
            providerConfig = providerConfig,
            systemPrompt = "You are a UI generator AI. Convert the user input into a JSON-driven UI.",
            messages = messages,
            schema = schema
        )
        println("\n[UI-Extraction] → $res")
    } catch (e: ProviderException) {
        println("\nUI Extraction failed:\n${e.message}")
    }
}
/* ---------- entry-point ---------- */
fun main() = runBlocking {
/* --- provider setup --- */
val providerName = "databricks"
val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set")
val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set")
val providerConfig = buildProviderConfig(host, token)
println("Provider: $providerName")
println("Config : $providerConfig\n")
/* --- run demos for each model --- */
// NOTE: `claude-3-5-haiku` does NOT support images
val modelNames = listOf("goose-gpt-4o", "goose-claude-4-sonnet")
for (name in modelNames) {
val modelConfig = ModelConfig(name, 100000u, 0.1f, 200)
println("\n===== Running demos for model: $name =====")
runCalculatorDemo(modelConfig, providerName, providerConfig)
runImageExample(modelConfig, providerName, providerConfig)
runPromptOverride(modelConfig, providerName, providerConfig)
println("===== End demos for $name =====\n")
}
/* UI extraction is model-agnostic, so run it once */
runUiExtraction(providerName, providerConfig)
}

View File

@@ -833,6 +833,7 @@ internal interface UniffiLib : Library {
`systemPromptOverride`: RustBuffer.ByValue,
`messages`: RustBuffer.ByValue,
`extensions`: RustBuffer.ByValue,
`requestId`: RustBuffer.ByValue,
uniffi_out_err: UniffiRustCallStatus,
): RustBuffer.ByValue
@@ -848,6 +849,7 @@ internal interface UniffiLib : Library {
`providerName`: RustBuffer.ByValue,
`providerConfig`: RustBuffer.ByValue,
`messages`: RustBuffer.ByValue,
`requestId`: RustBuffer.ByValue,
): Long
fun uniffi_goose_llm_fn_func_generate_structured_outputs(
@@ -856,12 +858,14 @@ internal interface UniffiLib : Library {
`systemPrompt`: RustBuffer.ByValue,
`messages`: RustBuffer.ByValue,
`schema`: RustBuffer.ByValue,
`requestId`: RustBuffer.ByValue,
): Long
fun uniffi_goose_llm_fn_func_generate_tooltip(
`providerName`: RustBuffer.ByValue,
`providerConfig`: RustBuffer.ByValue,
`messages`: RustBuffer.ByValue,
`requestId`: RustBuffer.ByValue,
): Long
fun uniffi_goose_llm_fn_func_print_messages(
@@ -1101,19 +1105,19 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) {
if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 50798.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 15391.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 64087.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 34350.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 43426.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 4576.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 41121.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 36439.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_print_messages() != 30278.toShort()) {
@@ -2960,6 +2964,7 @@ fun `createCompletionRequest`(
`systemPromptOverride`: kotlin.String? = null,
`messages`: List<Message>,
`extensions`: List<ExtensionConfig>,
`requestId`: kotlin.String? = null,
): CompletionRequest =
FfiConverterTypeCompletionRequest.lift(
uniffiRustCall { _status ->
@@ -2971,6 +2976,7 @@ fun `createCompletionRequest`(
FfiConverterOptionalString.lower(`systemPromptOverride`),
FfiConverterSequenceTypeMessage.lower(`messages`),
FfiConverterSequenceTypeExtensionConfig.lower(`extensions`),
FfiConverterOptionalString.lower(`requestId`),
_status,
)
},
@@ -3003,12 +3009,14 @@ suspend fun `generateSessionName`(
`providerName`: kotlin.String,
`providerConfig`: Value,
`messages`: List<Message>,
`requestId`: kotlin.String? = null,
): kotlin.String =
uniffiRustCallAsync(
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_session_name(
FfiConverterString.lower(`providerName`),
FfiConverterTypeValue.lower(`providerConfig`),
FfiConverterSequenceTypeMessage.lower(`messages`),
FfiConverterOptionalString.lower(`requestId`),
),
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
{ future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
@@ -3031,6 +3039,7 @@ suspend fun `generateStructuredOutputs`(
`systemPrompt`: kotlin.String,
`messages`: List<Message>,
`schema`: Value,
`requestId`: kotlin.String? = null,
): ProviderExtractResponse =
uniffiRustCallAsync(
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_structured_outputs(
@@ -3039,6 +3048,7 @@ suspend fun `generateStructuredOutputs`(
FfiConverterString.lower(`systemPrompt`),
FfiConverterSequenceTypeMessage.lower(`messages`),
FfiConverterTypeValue.lower(`schema`),
FfiConverterOptionalString.lower(`requestId`),
),
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
{ future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
@@ -3059,12 +3069,14 @@ suspend fun `generateTooltip`(
`providerName`: kotlin.String,
`providerConfig`: Value,
`messages`: List<Message>,
`requestId`: kotlin.String? = null,
): kotlin.String =
uniffiRustCallAsync(
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_tooltip(
FfiConverterString.lower(`providerName`),
FfiConverterTypeValue.lower(`providerConfig`),
FfiConverterSequenceTypeMessage.lower(`messages`),
FfiConverterOptionalString.lower(`requestId`),
),
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
{ future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
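On the Kotlin side, each regenerated function gains a trailing `requestId: kotlin.String? = null` parameter, so existing call sites compile unchanged. The updated checksum constants above are the expected side effect of the signature changes; a stale native library now fails fast with the "try cleaning and rebuilding" error instead of misbehaving silently.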

View File

@@ -64,6 +64,10 @@ path = "uniffi-bindgen.rs"
name = "simple"
path = "examples/simple.rs"
[[example]]
name = "image"
path = "examples/image.rs"
[[example]]
name = "prompt_override"
path = "examples/prompt_override.rs"

View File

@@ -0,0 +1,53 @@
use anyhow::Result;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use goose_llm::{
completion,
message::MessageContent,
types::completion::{CompletionRequest, CompletionResponse},
Message, ModelConfig,
};
use serde_json::json;
use std::fs;
#[tokio::main]
async fn main() -> Result<()> {
let provider = "databricks";
let provider_config = json!({
"host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
});
let model_name = "goose-claude-4-sonnet"; // "gpt-4o";
let model_config = ModelConfig::new(model_name.to_string());
let system_preamble = "You are a helpful assistant.";
// Read and encode test image
let image_data = fs::read("examples/test_assets/test_image.png")?;
let base64_image = BASE64.encode(image_data);
let user_msg = Message::user()
.with_text("What do you see in this image?")
.with_content(MessageContent::image(base64_image, "image/png"));
let messages = vec![user_msg];
let completion_response: CompletionResponse = completion(
CompletionRequest::new(
provider.to_string(),
provider_config.clone(),
model_config.clone(),
Some(system_preamble.to_string()),
None,
messages,
vec![],
)
.with_request_id("test-image-1".to_string()),
)
.await?;
// Print the response
println!("\nCompletion Response:");
println!("{}", serde_json::to_string_pretty(&completion_response)?);
Ok(())
}

View File

@@ -116,7 +116,7 @@ async fn main() -> Result<()> {
println!("\nCompletion Response:");
println!("{}", serde_json::to_string_pretty(&completion_response)?);
let tooltip = generate_tooltip(provider, provider_config.clone(), &messages).await?;
let tooltip = generate_tooltip(provider, provider_config.clone(), &messages, None).await?;
println!("\nTooltip: {}", tooltip);
}

Binary file not shown (4.2 KiB).

View File

@@ -46,7 +46,12 @@ pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, Co
// Call the LLM provider
let start_provider = Instant::now();
let mut response = provider
.complete(&system_prompt, &req.messages, &tools)
.complete(
&system_prompt,
&req.messages,
&tools,
req.request_id.as_deref(),
)
.await?;
let provider_elapsed_sec = start_provider.elapsed().as_secs_f32();
let usage_tokens = response.usage.total_tokens;
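Note that the id rides inside `CompletionRequest` (via `with_request_id`, shown further down) rather than as an extra argument to `completion` itself, so the public entry point keeps its single-struct signature while the provider's `complete` receives the id as `Option<&str>`.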

View File

@@ -48,11 +48,12 @@ fn build_system_prompt() -> String {
}
/// Generates a short (≤4 words) session name
#[uniffi::export(async_runtime = "tokio")]
#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
pub async fn generate_session_name(
provider_name: &str,
provider_config: JsonValueFfi,
messages: &[Message],
request_id: Option<String>,
) -> Result<String, ProviderError> {
// Collect up to the first 3 user messages (truncated to 300 chars each)
let context: Vec<String> = messages
@@ -90,6 +91,7 @@ pub async fn generate_session_name(
&system_prompt,
&[Message::user().with_text(&user_msg_text)],
schema,
request_id,
)
.await?;

View File

@@ -52,11 +52,12 @@ fn build_system_prompt() -> String {
/// Generates a tooltip summarizing the last two messages in the session,
/// including any tool calls or results.
#[uniffi::export(async_runtime = "tokio")]
#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
pub async fn generate_tooltip(
provider_name: &str,
provider_config: JsonValueFfi,
messages: &[Message],
request_id: Option<String>,
) -> Result<String, ProviderError> {
// Need at least two messages to generate a tooltip
if messages.len() < 2 {
@@ -148,6 +149,7 @@ pub async fn generate_tooltip(
&system_prompt,
&[Message::user().with_text(&user_msg_text)],
schema,
request_id,
)
.await?;
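Because the export declares `default(request_id = None)`, foreign-language callers are unaffected; Rust callers pass the new argument positionally. A hedged sketch of both helpers with an explicit id (value illustrative, `messages` assumed in scope):

let name = generate_session_name(provider_name, provider_config.clone(), messages, Some("req-42".to_string())).await?;
let tooltip = generate_tooltip(provider_name, provider_config.clone(), messages, Some("req-42".to_string())).await?;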

View File

@@ -69,6 +69,7 @@ pub trait Provider: Send + Sync {
/// * `system` - The system prompt that guides the model's behavior
/// * `messages` - The conversation history as a sequence of messages
/// * `tools` - Optional list of tools the model can use
/// * `request_id` - Optional request ID (only used by some providers like Databricks)
///
/// # Returns
/// A tuple containing the model's response message and provider usage statistics
@@ -81,6 +82,7 @@ pub trait Provider: Send + Sync {
system: &str,
messages: &[Message],
tools: &[Tool],
request_id: Option<&str>,
) -> Result<ProviderCompleteResponse, ProviderError>;
/// Structured extraction: always JSONSchema
@@ -90,6 +92,7 @@ pub trait Provider: Send + Sync {
/// * `messages` conversation history
/// * `schema` a JSONSchema for the expected output.
/// Will set strict=true for OpenAI & Databricks.
/// * `request_id` - Optional request ID (only used by some providers like Databricks)
///
/// # Returns
/// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`.
@@ -102,6 +105,7 @@ pub trait Provider: Send + Sync {
system: &str,
messages: &[Message],
schema: &serde_json::Value,
request_id: Option<&str>,
) -> Result<ProviderExtractResponse, ProviderError>;
}
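For code that calls a `Provider` directly, the new parameter is one extra argument; `None` reproduces the old behavior, and providers with no use for it (such as OpenAI) ignore it. A minimal sketch, with the prompt, messages, tools and schema assumed to be in scope:

let response = provider
    .complete("You are a helpful assistant.", &messages, &tools, Some("trace-123"))
    .await?;
let extracted = provider
    .extract(system_prompt, &messages, &schema, None)
    .await?;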

View File

@@ -210,6 +210,7 @@ impl Provider for DatabricksProvider {
system: &str,
messages: &[Message],
tools: &[Tool],
request_id: Option<&str>,
) -> Result<ProviderCompleteResponse, ProviderError> {
let mut payload = create_request(
&self.model,
@@ -224,6 +225,17 @@ impl Provider for DatabricksProvider {
.expect("payload should have model key")
.remove("model");
// Add client_request_id if provided
if let Some(req_id) = request_id {
payload
.as_object_mut()
.expect("payload should be an object")
.insert(
"client_request_id".to_string(),
serde_json::Value::String(req_id.to_string()),
);
}
let response = self.post(payload.clone()).await?;
// Parse response
@@ -247,6 +259,7 @@ impl Provider for DatabricksProvider {
system: &str,
messages: &[Message],
schema: &Value,
request_id: Option<&str>,
) -> Result<ProviderExtractResponse, ProviderError> {
// 1. Build base payload (no tools)
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
@@ -267,6 +280,17 @@ impl Provider for DatabricksProvider {
}),
);
// Add client_request_id if provided
if let Some(req_id) = request_id {
payload
.as_object_mut()
.expect("payload should be an object")
.insert(
"client_request_id".to_string(),
serde_json::Value::String(req_id.to_string()),
);
}
// 3. Call OpenAI
let response = self.post(payload.clone()).await?;
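On the wire this is a single extra top-level field in the chat payload. A standalone sketch of the same `serde_json` manipulation (payload abbreviated, id illustrative):

use serde_json::{json, Value};

let request_id = Some("trace-123");
let mut payload = json!({
    "messages": [{ "role": "user", "content": "hi" }],
    "temperature": 0.0
});
if let Some(req_id) = request_id {
    payload
        .as_object_mut()
        .expect("payload should be an object")
        .insert("client_request_id".to_string(), Value::String(req_id.to_string()));
}
assert_eq!(payload["client_request_id"], "trace-123");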

View File

@@ -7,10 +7,7 @@ use crate::{
providers::{
base::Usage,
errors::ProviderError,
utils::{
convert_image, detect_image_path, is_valid_function_name, load_image_file,
sanitize_function_name, ImageFormat,
},
utils::{convert_image, is_valid_function_name, sanitize_function_name, ImageFormat},
},
types::core::{Content, Role, Tool, ToolCall, ToolError},
};
@@ -34,30 +31,17 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
match content {
MessageContent::Text(text) => {
if !text.text.is_empty() {
// Check for image paths in the text
if let Some(image_path) = detect_image_path(&text.text) {
has_multiple_content = true;
// Try to load and convert the image
if let Ok(image) = load_image_file(image_path) {
content_array.push(json!({
"type": "text",
"text": text.text
}));
content_array.push(convert_image(&image, image_format));
} else {
content_array.push(json!({
"type": "text",
"text": text.text
}));
}
} else {
content_array.push(json!({
"type": "text",
"text": text.text
}));
}
content_array.push(json!({
"type": "text",
"text": text.text
}));
}
}
MessageContent::Image(image) => {
// Handle direct image content
let converted_image = convert_image(image, image_format);
content_array.push(converted_image);
}
MessageContent::Thinking(content) => {
has_multiple_content = true;
content_array.push(json!({
@@ -166,15 +150,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
}
}
}
MessageContent::Image(image) => {
// Handle direct image content
content_array.push(json!({
"type": "image_url",
"image_url": {
"url": convert_image(image, image_format)
}
}));
}
}
}
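This is the image-content bug from the commit title: `convert_image` already returns a complete, provider-formatted content object, so the `MessageContent::Image` arm removed just above wrapped that object in a second `image_url` envelope. With the new arm earlier in the match, an explicit image part formats to a single well-formed entry. A sketch of the expected shape, modeled on the removed path-detection test (`base64_png` illustrative):

let message = Message::user()
    .with_text("What is in this image?")
    .with_content(MessageContent::image(base64_png, "image/png"));

let spec = format_messages(&[message], &ImageFormat::OpenAi);
let content = spec[0]["content"].as_array().unwrap();
assert_eq!(content[1]["type"], "image_url"); // flat, not nested inside another image_url
assert!(content[1]["image_url"]["url"].as_str().unwrap().starts_with("data:image/png;base64,"));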
@@ -791,40 +766,6 @@ mod tests {
Ok(())
}
#[test]
fn test_format_messages_with_image_path() -> anyhow::Result<()> {
// Create a temporary PNG file with valid PNG magic numbers
let temp_dir = tempfile::tempdir()?;
let png_path = temp_dir.path().join("test.png");
let png_data = [
0x89, 0x50, 0x4E, 0x47, // PNG magic number
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
];
std::fs::write(&png_path, png_data)?;
let png_path_str = png_path.to_str().unwrap();
// Create message with image path
let message = Message::user().with_text(format!("Here is an image: {}", png_path_str));
let spec = format_messages(&[message], &ImageFormat::OpenAi);
assert_eq!(spec.len(), 1);
assert_eq!(spec[0]["role"], "user");
// Content should be an array with text and image
let content = spec[0]["content"].as_array().unwrap();
assert_eq!(content.len(), 2);
assert_eq!(content[0]["type"], "text");
assert!(content[0]["text"].as_str().unwrap().contains(png_path_str));
assert_eq!(content[1]["type"], "image_url");
assert!(content[1]["image_url"]["url"]
.as_str()
.unwrap()
.starts_with("data:image/png;base64,"));
Ok(())
}
#[test]
fn test_response_to_message_text() -> anyhow::Result<()> {
let response = json!({

View File

@@ -7,10 +7,7 @@ use crate::{
providers::{
base::Usage,
errors::ProviderError,
utils::{
convert_image, detect_image_path, is_valid_function_name, load_image_file,
sanitize_function_name, ImageFormat,
},
utils::{convert_image, is_valid_function_name, sanitize_function_name, ImageFormat},
},
types::core::{Content, Role, Tool, ToolCall, ToolError},
};
@@ -31,23 +28,13 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
match content {
MessageContent::Text(text) => {
if !text.text.is_empty() {
// Check for image paths in the text
if let Some(image_path) = detect_image_path(&text.text) {
// Try to load and convert the image
if let Ok(image) = load_image_file(image_path) {
converted["content"] = json!([
{"type": "text", "text": text.text},
convert_image(&image, image_format)
]);
} else {
// If image loading fails, just use the text
converted["content"] = json!(text.text);
}
} else {
converted["content"] = json!(text.text);
}
converted["content"] = json!(text.text);
}
}
MessageContent::Image(image) => {
// Handle direct image content
converted["content"] = json!([convert_image(image, image_format)]);
}
MessageContent::Thinking(_) => {
// Thinking blocks are not directly used in OpenAI format
continue;
@@ -134,10 +121,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
}
}
}
MessageContent::Image(image) => {
// Handle direct image content
converted["content"] = json!([convert_image(image, image_format)]);
}
}
}
@@ -664,40 +647,6 @@ mod tests {
Ok(())
}
#[test]
fn test_format_messages_with_image_path() -> anyhow::Result<()> {
// Create a temporary PNG file with valid PNG magic numbers
let temp_dir = tempfile::tempdir()?;
let png_path = temp_dir.path().join("test.png");
let png_data = [
0x89, 0x50, 0x4E, 0x47, // PNG magic number
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
];
std::fs::write(&png_path, png_data)?;
let png_path_str = png_path.to_str().unwrap();
// Create message with image path
let message = Message::user().with_text(format!("Here is an image: {}", png_path_str));
let spec = format_messages(&[message], &ImageFormat::OpenAi);
assert_eq!(spec.len(), 1);
assert_eq!(spec[0]["role"], "user");
// Content should be an array with text and image
let content = spec[0]["content"].as_array().unwrap();
assert_eq!(content.len(), 2);
assert_eq!(content[0]["type"], "text");
assert!(content[0]["text"].as_str().unwrap().contains(png_path_str));
assert_eq!(content[1]["type"], "image_url");
assert!(content[1]["image_url"]["url"]
.as_str()
.unwrap()
.starts_with("data:image/png;base64,"));
Ok(())
}
#[test]
fn test_response_to_message_text() -> anyhow::Result<()> {
let response = json!({

View File

@@ -149,6 +149,7 @@ impl Provider for OpenAiProvider {
system: &str,
messages: &[Message],
tools: &[Tool],
_request_id: Option<&str>, // OpenAI doesn't use request_id, so we ignore it
) -> Result<ProviderCompleteResponse, ProviderError> {
let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
@@ -175,6 +176,7 @@ impl Provider for OpenAiProvider {
system: &str,
messages: &[Message],
schema: &Value,
_request_id: Option<&str>, // OpenAI doesn't use request_id, so we ignore it
) -> Result<ProviderExtractResponse, ProviderError> {
// 1. Build base payload (no tools)
let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;

View File

@@ -181,30 +181,6 @@ fn is_image_file(path: &Path) -> bool {
false
}
/// Detect if a string contains a path to an image file
pub fn detect_image_path(text: &str) -> Option<&str> {
// Basic image file extension check
let extensions = [".png", ".jpg", ".jpeg"];
// Find any word that ends with an image extension
for word in text.split_whitespace() {
if extensions
.iter()
.any(|ext| word.to_lowercase().ends_with(ext))
{
let path = Path::new(word);
// Check if it's an absolute path and file exists
if path.is_absolute() && path.is_file() {
// Verify it's actually an image file
if is_image_file(path) {
return Some(word);
}
}
}
}
None
}
/// Convert a local image file to base64 encoded ImageContent
pub fn load_image_file(path: &str) -> Result<ImageContent, ProviderError> {
let path = Path::new(path);
@@ -267,81 +243,6 @@ pub fn emit_debug_trace(
mod tests {
use super::*;
#[test]
fn test_detect_image_path() {
// Create a temporary PNG file with valid PNG magic numbers
let temp_dir = tempfile::tempdir().unwrap();
let png_path = temp_dir.path().join("test.png");
let png_data = [
0x89, 0x50, 0x4E, 0x47, // PNG magic number
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
];
std::fs::write(&png_path, png_data).unwrap();
let png_path_str = png_path.to_str().unwrap();
// Create a fake PNG (wrong magic numbers)
let fake_png_path = temp_dir.path().join("fake.png");
std::fs::write(&fake_png_path, b"not a real png").unwrap();
// Test with valid PNG file using absolute path
let text = format!("Here is an image {}", png_path_str);
assert_eq!(detect_image_path(&text), Some(png_path_str));
// Test with non-image file that has .png extension
let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap());
assert_eq!(detect_image_path(&text), None);
// Test with non-existent file
let text = "Here is a fake.png that doesn't exist";
assert_eq!(detect_image_path(text), None);
// Test with non-image file
let text = "Here is a file.txt";
assert_eq!(detect_image_path(text), None);
// Test with relative path (should not match)
let text = "Here is a relative/path/image.png";
assert_eq!(detect_image_path(text), None);
}
#[test]
fn test_load_image_file() {
// Create a temporary PNG file with valid PNG magic numbers
let temp_dir = tempfile::tempdir().unwrap();
let png_path = temp_dir.path().join("test.png");
let png_data = [
0x89, 0x50, 0x4E, 0x47, // PNG magic number
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
];
std::fs::write(&png_path, png_data).unwrap();
let png_path_str = png_path.to_str().unwrap();
// Create a fake PNG (wrong magic numbers)
let fake_png_path = temp_dir.path().join("fake.png");
std::fs::write(&fake_png_path, b"not a real png").unwrap();
let fake_png_path_str = fake_png_path.to_str().unwrap();
// Test loading valid PNG file
let result = load_image_file(png_path_str);
assert!(result.is_ok());
let image = result.unwrap();
assert_eq!(image.mime_type, "image/png");
// Test loading fake PNG file
let result = load_image_file(fake_png_path_str);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("not a valid image"));
// Test non-existent file
let result = load_image_file("nonexistent.png");
assert!(result.is_err());
}
#[test]
fn test_sanitize_function_name() {
assert_eq!(sanitize_function_name("hello-world"), "hello-world");

View File

@@ -6,13 +6,14 @@ use crate::{
/// Generates a structured output based on the provided schema,
/// system prompt and user messages.
#[uniffi::export(async_runtime = "tokio")]
#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
pub async fn generate_structured_outputs(
provider_name: &str,
provider_config: JsonValueFfi,
system_prompt: &str,
messages: &[Message],
schema: JsonValueFfi,
request_id: Option<String>,
) -> Result<ProviderExtractResponse, ProviderError> {
// Use OpenAI models specifically for this task
let model_name = if provider_name == "databricks" {
@@ -23,7 +24,9 @@ pub async fn generate_structured_outputs(
let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
let provider = create(provider_name, provider_config, model_cfg)?;
let resp = provider.extract(system_prompt, messages, &schema).await?;
let resp = provider
.extract(system_prompt, messages, &schema, request_id.as_deref())
.await?;
Ok(resp)
}
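Structured extraction follows the same pattern, with the id optional here as well (value illustrative, other arguments assumed in scope):

let resp = generate_structured_outputs(provider_name, provider_config, system_prompt, messages, schema, Some("req-42".to_string())).await?;
println!("{}", resp.data);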

View File

@@ -20,6 +20,7 @@ pub struct CompletionRequest {
pub system_prompt_override: Option<String>,
pub messages: Vec<Message>,
pub extensions: Vec<ExtensionConfig>,
pub request_id: Option<String>,
}
impl CompletionRequest {
@@ -40,10 +41,17 @@ impl CompletionRequest {
system_preamble,
messages,
extensions,
request_id: None,
}
}
pub fn with_request_id(mut self, request_id: String) -> Self {
self.request_id = Some(request_id);
self
}
}
#[allow(clippy::too_many_arguments)]
#[uniffi::export(default(system_preamble = None, system_prompt_override = None))]
pub fn create_completion_request(
provider_name: &str,
@@ -53,8 +61,9 @@ pub fn create_completion_request(
system_prompt_override: Option<String>,
messages: Vec<Message>,
extensions: Vec<ExtensionConfig>,
request_id: Option<String>,
) -> CompletionRequest {
CompletionRequest::new(
let mut request = CompletionRequest::new(
provider_name.to_string(),
provider_config,
model_config,
@@ -62,7 +71,13 @@ pub fn create_completion_request(
system_prompt_override,
messages,
extensions,
)
);
if let Some(req_id) = request_id {
request = request.with_request_id(req_id);
}
request
}
uniffi::custom_type!(CompletionRequest, String, {

View File

@@ -22,7 +22,7 @@ async fn _generate_session_name(messages: &[Message]) -> Result<String, Provider
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
});
generate_session_name(provider_name, provider_config, messages).await
generate_session_name(provider_name, provider_config, messages, None).await
}
#[tokio::test]

View File

@@ -24,7 +24,7 @@ async fn _generate_tooltip(messages: &[Message]) -> Result<String, ProviderError
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
});
generate_tooltip(provider_name, provider_config, messages).await
generate_tooltip(provider_name, provider_config, messages, None).await
}
#[tokio::test]

View File

@@ -91,7 +91,7 @@ impl ProviderTester {
let response = self
.provider
.complete("You are a helpful assistant.", &[message], &[])
.complete("You are a helpful assistant.", &[message], &[], None)
.await?;
// For a basic response, we expect a single text response
@@ -134,6 +134,7 @@ impl ProviderTester {
"You are a helpful weather assistant.",
&[message.clone()],
&[weather_tool.clone()],
None,
)
.await?;
@@ -182,6 +183,7 @@ impl ProviderTester {
"You are a helpful weather assistant.",
&[message, response1.message, weather],
&[weather_tool],
None,
)
.await?;
@@ -225,7 +227,7 @@ impl ProviderTester {
// Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded
let result = self
.provider
.complete("You are a helpful assistant.", &messages, &[])
.complete("You are a helpful assistant.", &messages, &[], None)
.await;
// Print some debug info

View File

@@ -125,7 +125,7 @@ where
let provider = provider_type.create_provider(cfg)?;
let msg = Message::user().with_text(user_text);
let resp = provider.extract(system, &[msg], &schema).await?;
let resp = provider.extract(system, &[msg], &schema, None).await?;
println!("[{:?}] extract => {}", provider_type, resp.data);