From a5d77950db59ff1167919ebe3f48fca0dc4a15e2 Mon Sep 17 00:00:00 2001 From: Salman Mohammed Date: Tue, 15 Jul 2025 18:06:37 -0400 Subject: [PATCH] [goose-llm] fix image content bug, add optional request_id field (#3439) --- bindings/kotlin/example/RuntimeStats.kt | 115 +++++ bindings/kotlin/example/Usage.kt | 476 ++++++++---------- bindings/kotlin/uniffi/goose_llm/goose_llm.kt | 20 +- crates/goose-llm/Cargo.toml | 4 + crates/goose-llm/examples/image.rs | 53 ++ crates/goose-llm/examples/simple.rs | 2 +- .../examples/test_assets/test_image.png | Bin 0 -> 4339 bytes crates/goose-llm/src/completion.rs | 7 +- .../goose-llm/src/extractors/session_name.rs | 4 +- crates/goose-llm/src/extractors/tooltip.rs | 4 +- crates/goose-llm/src/providers/base.rs | 4 + crates/goose-llm/src/providers/databricks.rs | 24 + .../src/providers/formats/databricks.rs | 79 +-- .../goose-llm/src/providers/formats/openai.rs | 63 +-- crates/goose-llm/src/providers/openai.rs | 2 + crates/goose-llm/src/providers/utils.rs | 99 ---- crates/goose-llm/src/structured_outputs.rs | 7 +- crates/goose-llm/src/types/completion.rs | 19 +- .../goose-llm/tests/extract_session_name.rs | 2 +- crates/goose-llm/tests/extract_tooltip.rs | 2 +- crates/goose-llm/tests/providers_complete.rs | 6 +- crates/goose-llm/tests/providers_extract.rs | 2 +- 22 files changed, 482 insertions(+), 512 deletions(-) create mode 100644 bindings/kotlin/example/RuntimeStats.kt create mode 100644 crates/goose-llm/examples/image.rs create mode 100644 crates/goose-llm/examples/test_assets/test_image.png diff --git a/bindings/kotlin/example/RuntimeStats.kt b/bindings/kotlin/example/RuntimeStats.kt new file mode 100644 index 00000000..688d382f --- /dev/null +++ b/bindings/kotlin/example/RuntimeStats.kt @@ -0,0 +1,115 @@ +import kotlin.system.measureNanoTime +import kotlinx.coroutines.runBlocking +import uniffi.goose_llm.* + +import java.net.URI +import java.net.http.HttpClient +import java.net.http.HttpRequest +import java.net.http.HttpResponse + +/* ---------- Goose helpers ---------- */ + +fun buildProviderConfig(host: String, token: String): String = + """{ "host": "$host", "token": "$token" }""" + +suspend fun timeGooseCall( + modelCfg: ModelConfig, + providerName: String, + providerCfg: String +): Pair { + + val req = createCompletionRequest( + providerName, + providerCfg, + modelCfg, + systemPreamble = "You are a helpful assistant.", + messages = listOf( + Message( + Role.USER, + System.currentTimeMillis() / 1000, + listOf(MessageContent.Text(TextContent("Write me a 1000 word chapter about learning Go vs Rust in the world of LLMs and AI."))) + ) + ), + extensions = emptyList() + ) + + lateinit var resp: CompletionResponse + val wallMs = measureNanoTime { resp = completion(req) } / 1_000_000.0 + return wallMs to resp +} + +/* ---------- OpenAI helpers ---------- */ + +fun timeOpenAiCall(client: HttpClient, apiKey: String): Double { + val body = """ + { + "model": "gpt-4.1", + "max_tokens": 500, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Write me a 1000 word chapter about learning Go vs Rust in the world of LLMs and AI."} + ] + } + """.trimIndent() + + val request = HttpRequest.newBuilder() + .uri(URI.create("https://api.openai.com/v1/chat/completions")) + .header("Authorization", "Bearer $apiKey") + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(body)) + .build() + + val wallMs = measureNanoTime { + client.send(request, 
HttpResponse.BodyHandlers.ofString()) + } / 1_000_000.0 + + return wallMs +} + +/* ---------- main ---------- */ + +fun main() = runBlocking { + /* Goose provider setup */ + val providerName = "databricks" + val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set") + val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set") + val providerCfg = buildProviderConfig(host, token) + + /* OpenAI setup */ + val openAiKey = System.getenv("OPENAI_API_KEY") ?: error("OPENAI_API_KEY not set") + val httpClient = HttpClient.newBuilder().build() + + val gooseModels = listOf("goose-claude-4-sonnet", "goose-gpt-4-1") + val runsPerModel = 3 + + /* --- Goose timing --- */ + for (model in gooseModels) { + val maxTokens = 500 + val cfg = ModelConfig(model, 100_000u, 0.0f, maxTokens) + var wallSum = 0.0 + var gooseSum = 0.0 + + println("=== Goose: $model ===") + repeat(runsPerModel) { run -> + val (wall, resp) = timeGooseCall(cfg, providerName, providerCfg) + val gooseMs = resp.runtimeMetrics.totalTimeSec * 1_000 + val overhead = wall - gooseMs + wallSum += wall + gooseSum += gooseMs + println("run ${run + 1}: wall = %.1f ms | goose-llm = %.1f ms | overhead = %.1f ms" + .format(wall, gooseMs, overhead)) + } + println("-- avg wall = %.1f ms | avg overhead = %.1f ms --\n" + .format(wallSum / runsPerModel, (wallSum - gooseSum) / runsPerModel)) + } + + /* --- OpenAI direct timing --- */ + var oaSum = 0.0 + println("=== OpenAI: gpt-4.1 (direct HTTPS) ===") + repeat(runsPerModel) { run -> + val wall = timeOpenAiCall(httpClient, openAiKey) + oaSum += wall + println("run ${run + 1}: wall = %.1f ms".format(wall)) + } + println("-- avg wall = %.1f ms --".format(oaSum / runsPerModel)) +} diff --git a/bindings/kotlin/example/Usage.kt b/bindings/kotlin/example/Usage.kt index cdb06c82..90ee002d 100644 --- a/bindings/kotlin/example/Usage.kt +++ b/bindings/kotlin/example/Usage.kt @@ -1,292 +1,228 @@ +import java.io.File +import java.util.Base64 import kotlinx.coroutines.runBlocking import uniffi.goose_llm.* -fun main() = runBlocking { - val now = System.currentTimeMillis() / 1000 - val msgs = listOf( - // 1) User sends a plain-text prompt - Message( - role = Role.USER, - created = now, - content = listOf( - MessageContent.Text( - TextContent("What is 7 x 6?") - ) - ) - ), +/* ---------- shared helpers ---------- */ - // 2) Assistant makes a tool request (ToolReq) to calculate 7×6 - Message( - role = Role.ASSISTANT, - created = now + 2, - content = listOf( - MessageContent.ToolReq( - ToolRequest( - id = "calc1", - toolCall = """ - { - "status": "success", - "value": { - "name": "calculator_extension__toolname", - "arguments": { - "operation": "doesnotexist", - "numbers": [7, 6] - }, - "needsApproval": false - } - } - """.trimIndent() - ) - ) - ) - ), - - // 3) User (on behalf of the tool) responds with the tool result (ToolResp) - Message( - role = Role.USER, - created = now + 3, - content = listOf( - MessageContent.ToolResp( - ToolResponse( - id = "calc1", - toolResult = """ - { - "status": "error", - "error": "Invalid value for operation: 'doesnotexist'. 
Valid values are: ['add', 'subtract', 'multiply', 'divide']" - } - """.trimIndent() - ) - ) - ) - ), - - // 4) Assistant makes a tool request (ToolReq) to calculate 7×6 - Message( - role = Role.ASSISTANT, - created = now + 4, - content = listOf( - MessageContent.ToolReq( - ToolRequest( - id = "calc1", - toolCall = """ - { - "status": "success", - "value": { - "name": "calculator_extension__toolname", - "arguments": { - "operation": "multiply", - "numbers": [7, 6] - }, - "needsApproval": false - } - } - """.trimIndent() - ) - ) - ) - ), - - // 5) User (on behalf of the tool) responds with the tool result (ToolResp) - Message( - role = Role.USER, - created = now + 5, - content = listOf( - MessageContent.ToolResp( - ToolResponse( - id = "calc1", - toolResult = """ - { - "status": "success", - "value": [ - {"type": "text", "text": "42"} - ] - } - """.trimIndent() - ) - ) - ) - ), - ) - - printMessages(msgs) - println("---\n") - - // Setup provider - val providerName = "databricks" - val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set") - val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set") - val providerConfig = """{"host": "$host", "token": "$token"}""" - - println("Provider Name: $providerName") - println("Provider Config: $providerConfig") - - - val sessionName = generateSessionName(providerName, providerConfig, msgs) - println("\nSession Name: $sessionName") - - val tooltip = generateTooltip(providerName, providerConfig, msgs) - println("\nTooltip: $tooltip") - - // Completion - val modelName = "goose-gpt-4-1" - val modelConfig = ModelConfig( - modelName, - 100000u, // UInt - 0.1f, // Float - 200 // Int - ) +fun buildProviderConfig(host: String, token: String, imageFormat: String = "OpenAi"): String = """ +{ + "host": "$host", + "token": "$token", + "image_format": "$imageFormat" +} +""".trimIndent() +fun calculatorExtension(): ExtensionConfig { val calculatorTool = createToolConfig( - name = "calculator", + name = "calculator", description = "Perform basic arithmetic operations", inputSchema = """ { - "type": "object", - "required": ["operation", "numbers"], - "properties": { - "operation": { - "type": "string", - "enum": ["add", "subtract", "multiply", "divide"], - "description": "The arithmetic operation to perform" - }, - "numbers": { - "type": "array", - "items": { "type": "number" }, - "description": "List of numbers to operate on in order" - } + "type": "object", + "required": ["operation", "numbers"], + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The arithmetic operation to perform" + }, + "numbers": { + "type": "array", + "items": { "type": "number" }, + "description": "List of numbers to operate on in order" } + } } """.trimIndent(), approvalMode = ToolApprovalMode.AUTO ) - - val calculator_extension = ExtensionConfig( - name = "calculator_extension", + return ExtensionConfig( + name = "calculator_extension", instructions = "This extension provides a calculator tool.", - tools = listOf(calculatorTool) + tools = listOf(calculatorTool) ) - - val extensions = listOf(calculator_extension) - val systemPreamble = "You are a helpful assistant." 
- - // Testing with tool calls with an error in tool name - val reqToolErr = createCompletionRequest( - providerName, - providerConfig, - modelConfig, - systemPreamble, - messages = listOf( - Message( - role = Role.USER, - created = now, - content = listOf( - MessageContent.Text( - TextContent("What is 7 x 6?") - ) - ) - )), - extensions = extensions - ) - - val respToolErr = completion(reqToolErr) - println("\nCompletion Response (one msg):\n${respToolErr.message}") - println() - - val reqAll = createCompletionRequest( - providerName, - providerConfig, - modelConfig, - systemPreamble, - messages = msgs, - extensions = extensions - ) - - val respAll = completion(reqAll) - println("\nCompletion Response (all msgs):\n${respAll.message}") - println() - - // ---- UI Extraction (custom schema) ---- - runUiExtraction(providerName, providerConfig) - - // --- Prompt Override --- - val prompt_req = createCompletionRequest( - providerName, - providerConfig, - modelConfig, - systemPreamble = null, - systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.", - messages=listOf( - Message( - role = Role.USER, - created = now, - content = listOf( - MessageContent.Text( - TextContent("What's your name?") - ) - ) - ) - ), - extensions=emptyList() - ) - - val prompt_resp = completion(prompt_req) - - println("\nPrompt Override Response:\n${prompt_resp.message}") } +/* ---------- demos ---------- */ + +suspend fun runCalculatorDemo( + modelConfig: ModelConfig, + providerName: String, + providerConfig: String +) { + val now = System.currentTimeMillis() / 1000 + val msgs = listOf( + // same conversation you already had + Message(Role.USER, now, listOf(MessageContent.Text(TextContent("What is 7 x 6?")))), + Message(Role.ASSISTANT, now + 2, listOf(MessageContent.ToolReq( + ToolRequest( + id = "calc1", + toolCall = """ + { + "status": "success", + "value": { + "name": "calculator_extension__toolname", + "arguments": { "operation": "doesnotexist", "numbers": [7,6] }, + "needsApproval": false + } + } + """.trimIndent() + )))), + Message(Role.USER, now + 3, listOf(MessageContent.ToolResp( + ToolResponse( + id = "calc1", + toolResult = """ + { + "status": "error", + "error": "Invalid value for operation: 'doesnotexist'. 
Valid values are: ['add','subtract','multiply','divide']" + } + """.trimIndent() + )))), + Message(Role.ASSISTANT, now + 4, listOf(MessageContent.ToolReq( + ToolRequest( + id = "calc1", + toolCall = """ + { + "status": "success", + "value": { + "name": "calculator_extension__toolname", + "arguments": { "operation": "multiply", "numbers": [7,6] }, + "needsApproval": false + } + } + """.trimIndent() + )))), + Message(Role.USER, now + 5, listOf(MessageContent.ToolResp( + ToolResponse( + id = "calc1", + toolResult = """ + { + "status": "success", + "value": [ { "type": "text", "text": "42" } ] + } + """.trimIndent() + )))) + ) + + /* one-shot prompt with error */ + val reqErr = createCompletionRequest( + providerName, providerConfig, modelConfig, + "You are a helpful assistant.", + messages = listOf(msgs.first()), + extensions = listOf(calculatorExtension()) + ) + println("\n[${modelConfig.modelName}] Calculator (single-msg) → ${completion(reqErr).message}") + + /* full conversation */ + val reqAll = createCompletionRequest( + providerName, providerConfig, modelConfig, + "You are a helpful assistant.", + messages = msgs, + extensions = listOf(calculatorExtension()) + ) + println("[${modelConfig.modelName}] Calculator (full chat) → ${completion(reqAll).message}") +} + +suspend fun runImageExample( + modelConfig: ModelConfig, + providerName: String, + providerConfig: String +) { + val imagePath = "../../crates/goose/examples/test_assets/test_image.png" + val base64Image = Base64.getEncoder().encodeToString(File(imagePath).readBytes()) + val now = System.currentTimeMillis() / 1000 + + val msgs = listOf( + Message(Role.USER, now, listOf( + MessageContent.Text(TextContent("What is in this image?")), + MessageContent.Image(ImageContent(base64Image, "image/png")) + )), + ) + + val req = createCompletionRequest( + providerName, providerConfig, modelConfig, + "You are a helpful assistant. Please describe any text you see in the image.", + messages = msgs, + extensions = emptyList() + ) + + println("\n[${modelConfig.modelName}] Image example → ${completion(req).message}") +} + +suspend fun runPromptOverride( + modelConfig: ModelConfig, + providerName: String, + providerConfig: String +) { + val now = System.currentTimeMillis() / 1000 + val req = createCompletionRequest( + providerName, providerConfig, modelConfig, + systemPreamble = null, + systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.", + messages = listOf( + Message(Role.USER, now, listOf(MessageContent.Text(TextContent("What's your name?")))) + ), + extensions = emptyList() + ) + println("\n[${modelConfig.modelName}] Prompt override → ${completion(req).message}") +} suspend fun runUiExtraction(providerName: String, providerConfig: String) { - val systemPrompt = "You are a UI generator AI. Convert the user input into a JSON-driven UI." 
- val messages = listOf( - Message( - role = Role.USER, - created = System.currentTimeMillis() / 1000, - content = listOf( - MessageContent.Text( - TextContent("Make a User Profile Form") - ) - ) - ) - ) - val schema = """{ - "type": "object", - "properties": { - "type": { - "type": "string", - "enum": ["div","button","header","section","field","form"] - }, - "label": { "type": "string" }, - "children": { - "type": "array", - "items": { "${'$'}ref": "#" } - }, - "attributes": { - "type": "array", - "items": { - "type": "object", - "properties": { - "name": { "type": "string" }, - "value": { "type": "string" } - }, - "required": ["name","value"], - "additionalProperties": false - } - } - }, - "required": ["type","label","children","attributes"], - "additionalProperties": false - }""".trimIndent(); + val schema = /* same JSON schema as before */ """ + { + "type":"object", + "properties":{ + "type":{"type":"string","enum":["div","button","header","section","field","form"]}, + "label":{"type":"string"}, + "children":{"type":"array","items":{"${'$'}ref":"#"}}, + "attributes":{"type":"array","items":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"string"}},"required":["name","value"],"additionalProperties":false}} + }, + "required":["type","label","children","attributes"], + "additionalProperties":false + } + """.trimIndent() - try { - val response = generateStructuredOutputs( - providerName = providerName, - providerConfig = providerConfig, - systemPrompt = systemPrompt, - messages = messages, - schema = schema - ) - println("\nUI Extraction Output:\n${response}") - } catch (e: ProviderException) { - println("\nUI Extraction failed:\n${e.message}") - } + val messages = listOf( + Message(Role.USER, System.currentTimeMillis()/1000, + listOf(MessageContent.Text(TextContent("Make a User Profile Form")))) + ) + + val res = generateStructuredOutputs( + providerName, providerConfig, + systemPrompt = "You are a UI generator AI. 
Convert the user input into a JSON-driven UI.", + messages = messages, + schema = schema + ) + println("\n[UI-Extraction] → $res") +} + +/* ---------- entry-point ---------- */ + +fun main() = runBlocking { + /* --- provider setup --- */ + val providerName = "databricks" + val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set") + val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set") + val providerConfig = buildProviderConfig(host, token) + + println("Provider: $providerName") + println("Config : $providerConfig\n") + + /* --- run demos for each model --- */ + // NOTE: `claude-3-5-haiku` does NOT support images + val modelNames = listOf("kgoose-gpt-4o", "goose-claude-4-sonnet") + + for (name in modelNames) { + val modelConfig = ModelConfig(name, 100000u, 0.1f, 200) + println("\n===== Running demos for model: $name =====") + + runCalculatorDemo(modelConfig, providerName, providerConfig) + runImageExample(modelConfig, providerName, providerConfig) + runPromptOverride(modelConfig, providerName, providerConfig) + println("===== End demos for $name =====\n") + } + + /* UI extraction is model-agnostic, so run it once */ + runUiExtraction(providerName, providerConfig) } diff --git a/bindings/kotlin/uniffi/goose_llm/goose_llm.kt b/bindings/kotlin/uniffi/goose_llm/goose_llm.kt index 76e60aaf..f0195694 100644 --- a/bindings/kotlin/uniffi/goose_llm/goose_llm.kt +++ b/bindings/kotlin/uniffi/goose_llm/goose_llm.kt @@ -833,6 +833,7 @@ internal interface UniffiLib : Library { `systemPromptOverride`: RustBuffer.ByValue, `messages`: RustBuffer.ByValue, `extensions`: RustBuffer.ByValue, + `requestId`: RustBuffer.ByValue, uniffi_out_err: UniffiRustCallStatus, ): RustBuffer.ByValue @@ -848,6 +849,7 @@ internal interface UniffiLib : Library { `providerName`: RustBuffer.ByValue, `providerConfig`: RustBuffer.ByValue, `messages`: RustBuffer.ByValue, + `requestId`: RustBuffer.ByValue, ): Long fun uniffi_goose_llm_fn_func_generate_structured_outputs( @@ -856,12 +858,14 @@ internal interface UniffiLib : Library { `systemPrompt`: RustBuffer.ByValue, `messages`: RustBuffer.ByValue, `schema`: RustBuffer.ByValue, + `requestId`: RustBuffer.ByValue, ): Long fun uniffi_goose_llm_fn_func_generate_tooltip( `providerName`: RustBuffer.ByValue, `providerConfig`: RustBuffer.ByValue, `messages`: RustBuffer.ByValue, + `requestId`: RustBuffer.ByValue, ): Long fun uniffi_goose_llm_fn_func_print_messages( @@ -1101,19 +1105,19 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) { if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } - if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 50798.toShort()) { + if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 15391.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } - if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 64087.toShort()) { + if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 34350.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } - if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 43426.toShort()) { + 
if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 4576.toShort()) {
         throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
     }
-    if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 41121.toShort()) {
+    if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 36439.toShort()) {
         throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
     }
     if (lib.uniffi_goose_llm_checksum_func_print_messages() != 30278.toShort()) {
@@ -2960,6 +2964,7 @@ fun `createCompletionRequest`(
     `systemPromptOverride`: kotlin.String? = null,
     `messages`: List<Message>,
     `extensions`: List<ExtensionConfig>,
+    `requestId`: kotlin.String? = null,
 ): CompletionRequest =
     FfiConverterTypeCompletionRequest.lift(
         uniffiRustCall { _status ->
@@ -2971,6 +2976,7 @@ fun `createCompletionRequest`(
                 FfiConverterOptionalString.lower(`systemPromptOverride`),
                 FfiConverterSequenceTypeMessage.lower(`messages`),
                 FfiConverterSequenceTypeExtensionConfig.lower(`extensions`),
+                FfiConverterOptionalString.lower(`requestId`),
                 _status,
             )
         },
     )
@@ -3003,12 +3009,14 @@ suspend fun `generateSessionName`(
     `providerName`: kotlin.String,
     `providerConfig`: Value,
     `messages`: List<Message>,
+    `requestId`: kotlin.String? = null,
 ): kotlin.String =
     uniffiRustCallAsync(
         UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_session_name(
             FfiConverterString.lower(`providerName`),
             FfiConverterTypeValue.lower(`providerConfig`),
             FfiConverterSequenceTypeMessage.lower(`messages`),
+            FfiConverterOptionalString.lower(`requestId`),
         ),
         { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
         { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
@@ -3031,6 +3039,7 @@ suspend fun `generateStructuredOutputs`(
     `systemPrompt`: kotlin.String,
     `messages`: List<Message>,
     `schema`: Value,
+    `requestId`: kotlin.String? = null,
 ): ProviderExtractResponse =
     uniffiRustCallAsync(
         UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_structured_outputs(
             FfiConverterString.lower(`providerName`),
             FfiConverterTypeValue.lower(`providerConfig`),
             FfiConverterString.lower(`systemPrompt`),
             FfiConverterSequenceTypeMessage.lower(`messages`),
             FfiConverterTypeValue.lower(`schema`),
+            FfiConverterOptionalString.lower(`requestId`),
         ),
         { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
         { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
@@ -3059,12 +3069,14 @@ suspend fun `generateTooltip`(
     `providerName`: kotlin.String,
     `providerConfig`: Value,
     `messages`: List<Message>,
+    `requestId`: kotlin.String?
= null, ): kotlin.String = uniffiRustCallAsync( UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_tooltip( FfiConverterString.lower(`providerName`), FfiConverterTypeValue.lower(`providerConfig`), FfiConverterSequenceTypeMessage.lower(`messages`), + FfiConverterOptionalString.lower(`requestId`), ), { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) }, { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) }, diff --git a/crates/goose-llm/Cargo.toml b/crates/goose-llm/Cargo.toml index ef073a37..9f3dd9ed 100644 --- a/crates/goose-llm/Cargo.toml +++ b/crates/goose-llm/Cargo.toml @@ -64,6 +64,10 @@ path = "uniffi-bindgen.rs" name = "simple" path = "examples/simple.rs" +[[example]] +name = "image" +path = "examples/image.rs" + [[example]] name = "prompt_override" path = "examples/prompt_override.rs" diff --git a/crates/goose-llm/examples/image.rs b/crates/goose-llm/examples/image.rs new file mode 100644 index 00000000..7c607713 --- /dev/null +++ b/crates/goose-llm/examples/image.rs @@ -0,0 +1,53 @@ +use anyhow::Result; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; +use goose_llm::{ + completion, + message::MessageContent, + types::completion::{CompletionRequest, CompletionResponse}, + Message, ModelConfig, +}; +use serde_json::json; +use std::{fs, vec}; + +#[tokio::main] +async fn main() -> Result<()> { + let provider = "databricks"; + let provider_config = json!({ + "host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"), + "token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"), + }); + let model_name = "goose-claude-4-sonnet"; // "gpt-4o"; + let model_config = ModelConfig::new(model_name.to_string()); + + let system_preamble = "You are a helpful assistant."; + + // Read and encode test image + let image_data = fs::read("examples/test_assets/test_image.png")?; + let base64_image = BASE64.encode(image_data); + + let user_msg = Message::user() + .with_text("What do you see in this image?") + .with_content(MessageContent::image(base64_image, "image/png")); + + let messages = vec![user_msg]; + + let completion_response: CompletionResponse = completion( + CompletionRequest::new( + provider.to_string(), + provider_config.clone(), + model_config.clone(), + Some(system_preamble.to_string()), + None, + messages, + vec![], + ) + .with_request_id("test-image-1".to_string()), + ) + .await?; + + // Print the response + println!("\nCompletion Response:"); + println!("{}", serde_json::to_string_pretty(&completion_response)?); + + Ok(()) +} diff --git a/crates/goose-llm/examples/simple.rs b/crates/goose-llm/examples/simple.rs index e7d36a78..efab4b0a 100644 --- a/crates/goose-llm/examples/simple.rs +++ b/crates/goose-llm/examples/simple.rs @@ -116,7 +116,7 @@ async fn main() -> Result<()> { println!("\nCompletion Response:"); println!("{}", serde_json::to_string_pretty(&completion_response)?); - let tooltip = generate_tooltip(provider, provider_config.clone(), &messages).await?; + let tooltip = generate_tooltip(provider, provider_config.clone(), &messages, None).await?; println!("\nTooltip: {}", tooltip); } diff --git a/crates/goose-llm/examples/test_assets/test_image.png b/crates/goose-llm/examples/test_assets/test_image.png new file mode 100644 index 0000000000000000000000000000000000000000..f72b65986d199187ebba43d7157e9bcc7d4c3f1a GIT binary patch literal 4339 
zcmdT{XHXML*A5+}7a??mC_Z#$Rz+EKp6n|=?DNoUH|~XzIpE~ zuhJQe&Su7jfRnh?9x{E$=5JzQ#J0o%6cXYvhWFJ206e26hI-b)6RY{CVC&(j&$>bU z`s}>#`Lp<)2O=NloV9AdCLPWA9?{@6j5Pu$*<0Sd^mVg{q_Dcx+Pc}~$efM6?JXs! zqKJI$_}2b9U)uRgPIA0Lr_9cZor;q2qw2b?J=CI6zsO{0&1e-yG%ae9{rYIsb&7US z9|f+=ESl-tVIO$`=pFCuVIXKARs@{j?HIsZ#9TI1X#gd@=Cb7cUt(FR$qXqc&rYW~ z0q;Ky3G)Dc8hMpcqqVOObtCR_w|wp&7sMT%a!GdrrW$(w+Fclny6ERS!sZjXx%vJK znd}9HR_>@LCQU!XFD0J$i9l@OGQbaTB8n!@k4idcO&i395xNu5(NtfA#oXY*OI3|K2{r3!Se|V!A3``a|fP0cGe2+4O6*;6+bb zvE4BNYU9*m!p$hr|Acwe#nn|u_Se2y@npzujAlXopxj62>A|G#v^c<5WV3e-L&+LJ zb#AFNQf00$4-By4iG|j6&fw(Nl~RO)n7;bGq2#VIhC3Jk6`vaQ{hl?3#9D&B5X;Mk zWMLiY;(GL*Jz4u=EH8^D;~h)^XGWCKi0ItW{2*M9B8JyzuRVxs)5QgxuH}9qOg9u} z@bibtobBD4+OrlUHrq>@E_}+=Rg;PqmS7E+XAd3Y@T{l^R0Es&cA)k+gJjK!+(*gj z{EHHV7eG}pTki{5_?Q{JYrB{0K+8=pgQP`5;>U(}(Sx$KrK#ZI4)Y z)lEMu2w|Ft?aC1gGVw~n{05Bx^2VYLry%pKXY7iE@v##-KxP)*Qh~OBh1(r0EC+Pp z%+t64ONfu?dI?|Hp)KGFDI;|Mt_iLr55$>oTpwU2^VnwT; zn=VDaw4n?7)>mfdw@xZ3AXvDE=HFNp7&}TE02X<54z*(~3`UsHM>A1vAUhftZ%_(Z z3viLs)R#fHvMigHZJgpsjyd0bW9q7OHzC*wSKoaKV!ZsKV>Wl6#@u4e0PWs~75JcR`nk!0zRqIMOjgGGE~hekAq?19R2n#lPneTtR`=BZ!KSAc!hNu(H! z;|1=kQk4wJfUGBhe%1A>NL{Z*Z6U$%abf6jFVtPIKg|}XDmXTc6hiOL#Ld@{K)#@} z>r&JzhUEA-iClwsI52)yDYdGz_efXG|46n?Q9EDDtT16>FkwO1APHZY^CxZQ;F@ef ztWBo=W`{3VR1(%bqvd-cQ?nFMfM1NQroWDS6>>(A#xuz6X7@8$QN`|Fn0WAL@|sC> z;c1Y{$0;`co-;PT$!&)NcAE^x;({R@L%j45^#=i4T_dHZxyR#2|8rdjLec%S zdj2xxZaJuMaeh9ZhkJg0UV)|oy{gv9f@AI9Jo?rlv0(aesSD40JypPCeo3fyq*gQ3g+<-w)3Vmx%$Wil-+haFr3d>h6muJKijxEI(P>&v83hD#Y#>Q^-ju!HGnp?q~7=K05mT5mI2k9V3m{eQ%wf+ z+H)&Mpi~`C=}${~9^apqgV*lXE}a^%`IAGU5@%4EVnLao%C?Ev`&Xh}=I3TkpL+AB z;?VDYwZycD)6Q!@9Ske1v=f%UKkMOL&8PFh(lzx5GLgVo$kqU7NcY;U`tfMptyb#(XE6i#mTuTITGiMFDhbhRfxgahU@-mYd<(t3qz{rW+#t5OC*>G-j2#hdjM9D zt=7~F#+Ykns6r*zkv@d)o)uor>-vMPG2eMn=|raV`*m)u9J)(>;qELCk5|p~l=3D0 z7{AsNVQXDMTd4_Bwp=sMX{&~sCGb()!axo`MU6F?Cj@W7CXaZ8r)3G5hl}0pPtNQ29Y_VOn~Pudy!`m=;3r;`RR<5-<@UR^P`Mot zk!iQzwXCTw$~=1Q9%?TB3vUhExxB5>y>wy9tD@^=C+AC#@c13JL`Jw4Po=t)LP+Z? z$Cn-KP?VgS`|(H+L6q&AT_PXyeT}6Xc+Y3dmajGK_8RiX9zj@n;;!s4M8s49hPFFR z-b4%Odj2i&+P8?g+)mzhmfX4!qzDvsbo`mT+Nk?t^YNzdlUV1h8axAQu3k3b&>s6! z%rRLh?7lNTh5qc?2Kz6>FyOfsggH~zwj^y+h_K4;v;bS*J9uBlR$V7y>Hw9j;_beF zuBk|6BkH{)j3sT39Fs_EjB`DrPS?jC zJ}j~%Iq=J)qAR_#FU@czTV{vKHxD!aN&AI;z<3rFj{_$CW8c8UAJ|K5;xzpHtT#;S z>oyrUQcZg*Q~l8J+V#H(oGab5ic!P+c8MsoS`Od>GDu%XFz4$*Z-3l<$)=`IbLg!fhf6Ro4Ak4SDyjMUgXxXN#H3*XPk-PD8?m(Fh_GToI_$>A;vVB;D05&OKSZH)oAH_G5B;cL0Ff zWU{6X2ta$M(9>79!}&}xx*2dg{1(q0!L}Z~EP-8Zqv$7Cy$EO|Elx(}jn1PN%OX^l zw^(J5S+$QvRN^sl5QKct7S#X9=r;h#>P}igJtFdM{{6k9@O9^LNl2t23vI z(F=HeU$2xU& zhd%=268(yvsspRzh%LtTmv}HQ3D#!VWz=86SW1EhX`cgWfW_8g3^A2dVmoS9f2D@? 
z_gSUC<9Ta}=rFTS>G2FP; z+Ewk{@dt*K(#BPmPq`dw0S1ILJf5$3+1NzP;&7*Cwy?mPmaL$RS1DdrEBSq?5a3P3Np~U{`3Auw*D?o{;rS*k6q~m zpsK8@uAri#psH@IstTbG6&0|uGDKM!_8}_oKLFlH7f(0TzXN0nw9e827yn;_zo)mW apTCp0@4sXGl7J_448X()YFMT382x{&;}I_a literal 0 HcmV?d00001 diff --git a/crates/goose-llm/src/completion.rs b/crates/goose-llm/src/completion.rs index d39b1b8d..13f09810 100644 --- a/crates/goose-llm/src/completion.rs +++ b/crates/goose-llm/src/completion.rs @@ -46,7 +46,12 @@ pub async fn completion(req: CompletionRequest) -> Result String { } /// Generates a short (≤4 words) session name -#[uniffi::export(async_runtime = "tokio")] +#[uniffi::export(async_runtime = "tokio", default(request_id = None))] pub async fn generate_session_name( provider_name: &str, provider_config: JsonValueFfi, messages: &[Message], + request_id: Option, ) -> Result { // Collect up to the first 3 user messages (truncated to 300 chars each) let context: Vec = messages @@ -90,6 +91,7 @@ pub async fn generate_session_name( &system_prompt, &[Message::user().with_text(&user_msg_text)], schema, + request_id, ) .await?; diff --git a/crates/goose-llm/src/extractors/tooltip.rs b/crates/goose-llm/src/extractors/tooltip.rs index 37d83ffe..48336a54 100644 --- a/crates/goose-llm/src/extractors/tooltip.rs +++ b/crates/goose-llm/src/extractors/tooltip.rs @@ -52,11 +52,12 @@ fn build_system_prompt() -> String { /// Generates a tooltip summarizing the last two messages in the session, /// including any tool calls or results. -#[uniffi::export(async_runtime = "tokio")] +#[uniffi::export(async_runtime = "tokio", default(request_id = None))] pub async fn generate_tooltip( provider_name: &str, provider_config: JsonValueFfi, messages: &[Message], + request_id: Option, ) -> Result { // Need at least two messages to generate a tooltip if messages.len() < 2 { @@ -148,6 +149,7 @@ pub async fn generate_tooltip( &system_prompt, &[Message::user().with_text(&user_msg_text)], schema, + request_id, ) .await?; diff --git a/crates/goose-llm/src/providers/base.rs b/crates/goose-llm/src/providers/base.rs index dcfecbd1..92a3948d 100644 --- a/crates/goose-llm/src/providers/base.rs +++ b/crates/goose-llm/src/providers/base.rs @@ -69,6 +69,7 @@ pub trait Provider: Send + Sync { /// * `system` - The system prompt that guides the model's behavior /// * `messages` - The conversation history as a sequence of messages /// * `tools` - Optional list of tools the model can use + /// * `request_id` - Optional request ID (only used by some providers like Databricks) /// /// # Returns /// A tuple containing the model's response message and provider usage statistics @@ -81,6 +82,7 @@ pub trait Provider: Send + Sync { system: &str, messages: &[Message], tools: &[Tool], + request_id: Option<&str>, ) -> Result; /// Structured extraction: always JSON‐Schema @@ -90,6 +92,7 @@ pub trait Provider: Send + Sync { /// * `messages` – conversation history /// * `schema` – a JSON‐Schema for the expected output. /// Will set strict=true for OpenAI & Databricks. + /// * `request_id` - Optional request ID (only used by some providers like Databricks) /// /// # Returns /// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`. 
@@ -102,6 +105,7 @@ pub trait Provider: Send + Sync { system: &str, messages: &[Message], schema: &serde_json::Value, + request_id: Option<&str>, ) -> Result; } diff --git a/crates/goose-llm/src/providers/databricks.rs b/crates/goose-llm/src/providers/databricks.rs index 3dd31493..0bfe2ffe 100644 --- a/crates/goose-llm/src/providers/databricks.rs +++ b/crates/goose-llm/src/providers/databricks.rs @@ -210,6 +210,7 @@ impl Provider for DatabricksProvider { system: &str, messages: &[Message], tools: &[Tool], + request_id: Option<&str>, ) -> Result { let mut payload = create_request( &self.model, @@ -224,6 +225,17 @@ impl Provider for DatabricksProvider { .expect("payload should have model key") .remove("model"); + // Add client_request_id if provided + if let Some(req_id) = request_id { + payload + .as_object_mut() + .expect("payload should be an object") + .insert( + "client_request_id".to_string(), + serde_json::Value::String(req_id.to_string()), + ); + } + let response = self.post(payload.clone()).await?; // Parse response @@ -247,6 +259,7 @@ impl Provider for DatabricksProvider { system: &str, messages: &[Message], schema: &Value, + request_id: Option<&str>, ) -> Result { // 1. Build base payload (no tools) let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?; @@ -267,6 +280,17 @@ impl Provider for DatabricksProvider { }), ); + // Add client_request_id if provided + if let Some(req_id) = request_id { + payload + .as_object_mut() + .expect("payload should be an object") + .insert( + "client_request_id".to_string(), + serde_json::Value::String(req_id.to_string()), + ); + } + // 3. Call OpenAI let response = self.post(payload.clone()).await?; diff --git a/crates/goose-llm/src/providers/formats/databricks.rs b/crates/goose-llm/src/providers/formats/databricks.rs index d69c31bb..37343f2e 100644 --- a/crates/goose-llm/src/providers/formats/databricks.rs +++ b/crates/goose-llm/src/providers/formats/databricks.rs @@ -7,10 +7,7 @@ use crate::{ providers::{ base::Usage, errors::ProviderError, - utils::{ - convert_image, detect_image_path, is_valid_function_name, load_image_file, - sanitize_function_name, ImageFormat, - }, + utils::{convert_image, is_valid_function_name, sanitize_function_name, ImageFormat}, }, types::core::{Content, Role, Tool, ToolCall, ToolError}, }; @@ -34,30 +31,17 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< match content { MessageContent::Text(text) => { if !text.text.is_empty() { - // Check for image paths in the text - if let Some(image_path) = detect_image_path(&text.text) { - has_multiple_content = true; - // Try to load and convert the image - if let Ok(image) = load_image_file(image_path) { - content_array.push(json!({ - "type": "text", - "text": text.text - })); - content_array.push(convert_image(&image, image_format)); - } else { - content_array.push(json!({ - "type": "text", - "text": text.text - })); - } - } else { - content_array.push(json!({ - "type": "text", - "text": text.text - })); - } + content_array.push(json!({ + "type": "text", + "text": text.text + })); } } + MessageContent::Image(image) => { + // Handle direct image content + let converted_image = convert_image(image, image_format); + content_array.push(converted_image); + } MessageContent::Thinking(content) => { has_multiple_content = true; content_array.push(json!({ @@ -166,15 +150,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< } } } - MessageContent::Image(image) => { - // Handle direct 
image content - content_array.push(json!({ - "type": "image_url", - "image_url": { - "url": convert_image(image, image_format) - } - })); - } } } @@ -791,40 +766,6 @@ mod tests { Ok(()) } - #[test] - fn test_format_messages_with_image_path() -> anyhow::Result<()> { - // Create a temporary PNG file with valid PNG magic numbers - let temp_dir = tempfile::tempdir()?; - let png_path = temp_dir.path().join("test.png"); - let png_data = [ - 0x89, 0x50, 0x4E, 0x47, // PNG magic number - 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data - ]; - std::fs::write(&png_path, png_data)?; - let png_path_str = png_path.to_str().unwrap(); - - // Create message with image path - let message = Message::user().with_text(format!("Here is an image: {}", png_path_str)); - let spec = format_messages(&[message], &ImageFormat::OpenAi); - - assert_eq!(spec.len(), 1); - assert_eq!(spec[0]["role"], "user"); - - // Content should be an array with text and image - let content = spec[0]["content"].as_array().unwrap(); - assert_eq!(content.len(), 2); - assert_eq!(content[0]["type"], "text"); - assert!(content[0]["text"].as_str().unwrap().contains(png_path_str)); - assert_eq!(content[1]["type"], "image_url"); - assert!(content[1]["image_url"]["url"] - .as_str() - .unwrap() - .starts_with("data:image/png;base64,")); - - Ok(()) - } - #[test] fn test_response_to_message_text() -> anyhow::Result<()> { let response = json!({ diff --git a/crates/goose-llm/src/providers/formats/openai.rs b/crates/goose-llm/src/providers/formats/openai.rs index afc48745..a2eb43b4 100644 --- a/crates/goose-llm/src/providers/formats/openai.rs +++ b/crates/goose-llm/src/providers/formats/openai.rs @@ -7,10 +7,7 @@ use crate::{ providers::{ base::Usage, errors::ProviderError, - utils::{ - convert_image, detect_image_path, is_valid_function_name, load_image_file, - sanitize_function_name, ImageFormat, - }, + utils::{convert_image, is_valid_function_name, sanitize_function_name, ImageFormat}, }, types::core::{Content, Role, Tool, ToolCall, ToolError}, }; @@ -31,23 +28,13 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< match content { MessageContent::Text(text) => { if !text.text.is_empty() { - // Check for image paths in the text - if let Some(image_path) = detect_image_path(&text.text) { - // Try to load and convert the image - if let Ok(image) = load_image_file(image_path) { - converted["content"] = json!([ - {"type": "text", "text": text.text}, - convert_image(&image, image_format) - ]); - } else { - // If image loading fails, just use the text - converted["content"] = json!(text.text); - } - } else { - converted["content"] = json!(text.text); - } + converted["content"] = json!(text.text); } } + MessageContent::Image(image) => { + // Handle direct image content + converted["content"] = json!([convert_image(image, image_format)]); + } MessageContent::Thinking(_) => { // Thinking blocks are not directly used in OpenAI format continue; @@ -134,10 +121,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< } } } - MessageContent::Image(image) => { - // Handle direct image content - converted["content"] = json!([convert_image(image, image_format)]); - } } } @@ -664,40 +647,6 @@ mod tests { Ok(()) } - #[test] - fn test_format_messages_with_image_path() -> anyhow::Result<()> { - // Create a temporary PNG file with valid PNG magic numbers - let temp_dir = tempfile::tempdir()?; - let png_path = temp_dir.path().join("test.png"); - let png_data = [ - 0x89, 
0x50, 0x4E, 0x47, // PNG magic number - 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data - ]; - std::fs::write(&png_path, png_data)?; - let png_path_str = png_path.to_str().unwrap(); - - // Create message with image path - let message = Message::user().with_text(format!("Here is an image: {}", png_path_str)); - let spec = format_messages(&[message], &ImageFormat::OpenAi); - - assert_eq!(spec.len(), 1); - assert_eq!(spec[0]["role"], "user"); - - // Content should be an array with text and image - let content = spec[0]["content"].as_array().unwrap(); - assert_eq!(content.len(), 2); - assert_eq!(content[0]["type"], "text"); - assert!(content[0]["text"].as_str().unwrap().contains(png_path_str)); - assert_eq!(content[1]["type"], "image_url"); - assert!(content[1]["image_url"]["url"] - .as_str() - .unwrap() - .starts_with("data:image/png;base64,")); - - Ok(()) - } - #[test] fn test_response_to_message_text() -> anyhow::Result<()> { let response = json!({ diff --git a/crates/goose-llm/src/providers/openai.rs b/crates/goose-llm/src/providers/openai.rs index bc0dc088..82d736f3 100644 --- a/crates/goose-llm/src/providers/openai.rs +++ b/crates/goose-llm/src/providers/openai.rs @@ -149,6 +149,7 @@ impl Provider for OpenAiProvider { system: &str, messages: &[Message], tools: &[Tool], + _request_id: Option<&str>, // OpenAI doesn't use request_id, so we ignore it ) -> Result { let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; @@ -175,6 +176,7 @@ impl Provider for OpenAiProvider { system: &str, messages: &[Message], schema: &Value, + _request_id: Option<&str>, // OpenAI doesn't use request_id, so we ignore it ) -> Result { // 1. Build base payload (no tools) let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?; diff --git a/crates/goose-llm/src/providers/utils.rs b/crates/goose-llm/src/providers/utils.rs index 1a3945dc..b6c00e7b 100644 --- a/crates/goose-llm/src/providers/utils.rs +++ b/crates/goose-llm/src/providers/utils.rs @@ -181,30 +181,6 @@ fn is_image_file(path: &Path) -> bool { false } -/// Detect if a string contains a path to an image file -pub fn detect_image_path(text: &str) -> Option<&str> { - // Basic image file extension check - let extensions = [".png", ".jpg", ".jpeg"]; - - // Find any word that ends with an image extension - for word in text.split_whitespace() { - if extensions - .iter() - .any(|ext| word.to_lowercase().ends_with(ext)) - { - let path = Path::new(word); - // Check if it's an absolute path and file exists - if path.is_absolute() && path.is_file() { - // Verify it's actually an image file - if is_image_file(path) { - return Some(word); - } - } - } - } - None -} - /// Convert a local image file to base64 encoded ImageContent pub fn load_image_file(path: &str) -> Result { let path = Path::new(path); @@ -267,81 +243,6 @@ pub fn emit_debug_trace( mod tests { use super::*; - #[test] - fn test_detect_image_path() { - // Create a temporary PNG file with valid PNG magic numbers - let temp_dir = tempfile::tempdir().unwrap(); - let png_path = temp_dir.path().join("test.png"); - let png_data = [ - 0x89, 0x50, 0x4E, 0x47, // PNG magic number - 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data - ]; - std::fs::write(&png_path, png_data).unwrap(); - let png_path_str = png_path.to_str().unwrap(); - - // Create a fake PNG (wrong magic numbers) - let fake_png_path = temp_dir.path().join("fake.png"); - std::fs::write(&fake_png_path, 
b"not a real png").unwrap(); - - // Test with valid PNG file using absolute path - let text = format!("Here is an image {}", png_path_str); - assert_eq!(detect_image_path(&text), Some(png_path_str)); - - // Test with non-image file that has .png extension - let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap()); - assert_eq!(detect_image_path(&text), None); - - // Test with non-existent file - let text = "Here is a fake.png that doesn't exist"; - assert_eq!(detect_image_path(text), None); - - // Test with non-image file - let text = "Here is a file.txt"; - assert_eq!(detect_image_path(text), None); - - // Test with relative path (should not match) - let text = "Here is a relative/path/image.png"; - assert_eq!(detect_image_path(text), None); - } - - #[test] - fn test_load_image_file() { - // Create a temporary PNG file with valid PNG magic numbers - let temp_dir = tempfile::tempdir().unwrap(); - let png_path = temp_dir.path().join("test.png"); - let png_data = [ - 0x89, 0x50, 0x4E, 0x47, // PNG magic number - 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data - ]; - std::fs::write(&png_path, png_data).unwrap(); - let png_path_str = png_path.to_str().unwrap(); - - // Create a fake PNG (wrong magic numbers) - let fake_png_path = temp_dir.path().join("fake.png"); - std::fs::write(&fake_png_path, b"not a real png").unwrap(); - let fake_png_path_str = fake_png_path.to_str().unwrap(); - - // Test loading valid PNG file - let result = load_image_file(png_path_str); - assert!(result.is_ok()); - let image = result.unwrap(); - assert_eq!(image.mime_type, "image/png"); - - // Test loading fake PNG file - let result = load_image_file(fake_png_path_str); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("not a valid image")); - - // Test non-existent file - let result = load_image_file("nonexistent.png"); - assert!(result.is_err()); - } - #[test] fn test_sanitize_function_name() { assert_eq!(sanitize_function_name("hello-world"), "hello-world"); diff --git a/crates/goose-llm/src/structured_outputs.rs b/crates/goose-llm/src/structured_outputs.rs index 8f478d8a..b6690b64 100644 --- a/crates/goose-llm/src/structured_outputs.rs +++ b/crates/goose-llm/src/structured_outputs.rs @@ -6,13 +6,14 @@ use crate::{ /// Generates a structured output based on the provided schema, /// system prompt and user messages. 
-#[uniffi::export(async_runtime = "tokio")]
+#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
 pub async fn generate_structured_outputs(
     provider_name: &str,
     provider_config: JsonValueFfi,
     system_prompt: &str,
     messages: &[Message],
     schema: JsonValueFfi,
+    request_id: Option<String>,
 ) -> Result<ProviderExtractResponse, ProviderError> {
     // Use OpenAI models specifically for this task
     let model_name = if provider_name == "databricks" {
@@ -23,7 +24,9 @@ pub async fn generate_structured_outputs(
     let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
     let provider = create(provider_name, provider_config, model_cfg)?;
 
-    let resp = provider.extract(system_prompt, messages, &schema).await?;
+    let resp = provider
+        .extract(system_prompt, messages, &schema, request_id.as_deref())
+        .await?;
 
     Ok(resp)
 }
diff --git a/crates/goose-llm/src/types/completion.rs b/crates/goose-llm/src/types/completion.rs
index 21e0bcd9..ce54f607 100644
--- a/crates/goose-llm/src/types/completion.rs
+++ b/crates/goose-llm/src/types/completion.rs
@@ -20,6 +20,7 @@ pub struct CompletionRequest {
     pub system_prompt_override: Option<String>,
     pub messages: Vec<Message>,
     pub extensions: Vec<ExtensionConfig>,
+    pub request_id: Option<String>,
 }
 
 impl CompletionRequest {
@@ -40,10 +41,17 @@ impl CompletionRequest {
             system_preamble,
             messages,
             extensions,
+            request_id: None,
         }
     }
+
+    pub fn with_request_id(mut self, request_id: String) -> Self {
+        self.request_id = Some(request_id);
+        self
+    }
 }
 
+#[allow(clippy::too_many_arguments)]
 #[uniffi::export(default(system_preamble = None, system_prompt_override = None))]
 pub fn create_completion_request(
     provider_name: &str,
@@ -53,8 +61,9 @@ pub fn create_completion_request(
     system_prompt_override: Option<String>,
     messages: Vec<Message>,
     extensions: Vec<ExtensionConfig>,
+    request_id: Option<String>,
 ) -> CompletionRequest {
-    CompletionRequest::new(
+    let mut request = CompletionRequest::new(
         provider_name.to_string(),
         provider_config,
         model_config,
@@ -62,7 +71,13 @@ pub fn create_completion_request(
         system_prompt_override,
         messages,
         extensions,
-    )
+    );
+
+    if let Some(req_id) = request_id {
+        request = request.with_request_id(req_id);
+    }
+
+    request
 }
 
 uniffi::custom_type!(CompletionRequest, String, {
diff --git a/crates/goose-llm/tests/extract_session_name.rs b/crates/goose-llm/tests/extract_session_name.rs
index 5326fdbe..58d0a6b4 100644
--- a/crates/goose-llm/tests/extract_session_name.rs
+++ b/crates/goose-llm/tests/extract_session_name.rs
@@ -22,7 +22,7 @@ async fn _generate_session_name(messages: &[Message]) -> Result Result {}", provider_type, resp.data);
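
For binding consumers, the net effect of the new field is one optional argument at the end of each call. A minimal Kotlin sketch of passing it through `createCompletionRequest` — the model name, the Databricks host/token wiring, and the `trace-1234` id below are illustrative placeholders, not values from this patch; per the provider changes above, Databricks forwards the id as `client_request_id`, the OpenAI provider ignores it, and omitting `requestId` keeps the previous behaviour:

    import kotlinx.coroutines.runBlocking
    import uniffi.goose_llm.*

    fun main() = runBlocking {
        val providerName = "databricks"
        val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set")
        val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set")
        val providerConfig = """{ "host": "$host", "token": "$token" }"""

        val msgs = listOf(
            Message(Role.USER, System.currentTimeMillis() / 1000,
                listOf(MessageContent.Text(TextContent("What is 7 x 6?"))))
        )

        // requestId defaults to null, so existing call sites keep compiling unchanged.
        val req = createCompletionRequest(
            providerName,
            providerConfig,
            ModelConfig("goose-gpt-4-1", 100_000u, 0.1f, 200),
            systemPreamble = "You are a helpful assistant.",
            messages = msgs,
            extensions = emptyList(),
            requestId = "trace-1234"
        )
        println(completion(req).message)

        // generateSessionName, generateTooltip, and generateStructuredOutputs
        // accept the same optional id as their final parameter.
    }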