[goose-llm] fix image content bug, add optional request_id field (#3439)
bindings/kotlin/example/RuntimeStats.kt (new file, 115 lines)
@@ -0,0 +1,115 @@
import kotlin.system.measureNanoTime
import kotlinx.coroutines.runBlocking
import uniffi.goose_llm.*

import java.net.URI
import java.net.http.HttpClient
import java.net.http.HttpRequest
import java.net.http.HttpResponse

/* ---------- Goose helpers ---------- */

fun buildProviderConfig(host: String, token: String): String =
    """{ "host": "$host", "token": "$token" }"""

suspend fun timeGooseCall(
    modelCfg: ModelConfig,
    providerName: String,
    providerCfg: String
): Pair<Double, CompletionResponse> {

    val req = createCompletionRequest(
        providerName,
        providerCfg,
        modelCfg,
        systemPreamble = "You are a helpful assistant.",
        messages = listOf(
            Message(
                Role.USER,
                System.currentTimeMillis() / 1000,
                listOf(MessageContent.Text(TextContent("Write me a 1000 word chapter about learning Go vs Rust in the world of LLMs and AI.")))
            )
        ),
        extensions = emptyList()
    )

    lateinit var resp: CompletionResponse
    val wallMs = measureNanoTime { resp = completion(req) } / 1_000_000.0
    return wallMs to resp
}

/* ---------- OpenAI helpers ---------- */

fun timeOpenAiCall(client: HttpClient, apiKey: String): Double {
    val body = """
        {
          "model": "gpt-4.1",
          "max_tokens": 500,
          "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Write me a 1000 word chapter about learning Go vs Rust in the world of LLMs and AI."}
          ]
        }
    """.trimIndent()

    val request = HttpRequest.newBuilder()
        .uri(URI.create("https://api.openai.com/v1/chat/completions"))
        .header("Authorization", "Bearer $apiKey")
        .header("Content-Type", "application/json")
        .POST(HttpRequest.BodyPublishers.ofString(body))
        .build()

    val wallMs = measureNanoTime {
        client.send(request, HttpResponse.BodyHandlers.ofString())
    } / 1_000_000.0

    return wallMs
}

/* ---------- main ---------- */

fun main() = runBlocking {
    /* Goose provider setup */
    val providerName = "databricks"
    val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set")
    val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set")
    val providerCfg = buildProviderConfig(host, token)

    /* OpenAI setup */
    val openAiKey = System.getenv("OPENAI_API_KEY") ?: error("OPENAI_API_KEY not set")
    val httpClient = HttpClient.newBuilder().build()

    val gooseModels = listOf("goose-claude-4-sonnet", "goose-gpt-4-1")
    val runsPerModel = 3

    /* --- Goose timing --- */
    for (model in gooseModels) {
        val maxTokens = 500
        val cfg = ModelConfig(model, 100_000u, 0.0f, maxTokens)
        var wallSum = 0.0
        var gooseSum = 0.0

        println("=== Goose: $model ===")
        repeat(runsPerModel) { run ->
            val (wall, resp) = timeGooseCall(cfg, providerName, providerCfg)
            val gooseMs = resp.runtimeMetrics.totalTimeSec * 1_000
            val overhead = wall - gooseMs
            wallSum += wall
            gooseSum += gooseMs
            println("run ${run + 1}: wall = %.1f ms | goose-llm = %.1f ms | overhead = %.1f ms"
                .format(wall, gooseMs, overhead))
        }
        println("-- avg wall = %.1f ms | avg overhead = %.1f ms --\n"
            .format(wallSum / runsPerModel, (wallSum - gooseSum) / runsPerModel))
    }

    /* --- OpenAI direct timing --- */
    var oaSum = 0.0
    println("=== OpenAI: gpt-4.1 (direct HTTPS) ===")
    repeat(runsPerModel) { run ->
        val wall = timeOpenAiCall(httpClient, openAiKey)
        oaSum += wall
        println("run ${run + 1}: wall = %.1f ms".format(wall))
    }
    println("-- avg wall = %.1f ms --".format(oaSum / runsPerModel))
}
@@ -1,292 +1,228 @@
import java.io.File
import java.util.Base64
import kotlinx.coroutines.runBlocking
import uniffi.goose_llm.*

fun main() = runBlocking {
    val now = System.currentTimeMillis() / 1000
    val msgs = listOf(
        // 1) User sends a plain-text prompt
        Message(
            role = Role.USER,
            created = now,
            content = listOf(
                MessageContent.Text(
                    TextContent("What is 7 x 6?")
                )
            )
        ),
/* ---------- shared helpers ---------- */

        // 2) Assistant makes a tool request (ToolReq) to calculate 7×6
        Message(
            role = Role.ASSISTANT,
            created = now + 2,
            content = listOf(
                MessageContent.ToolReq(
                    ToolRequest(
                        id = "calc1",
                        toolCall = """
                            {
                              "status": "success",
                              "value": {
                                "name": "calculator_extension__toolname",
                                "arguments": {
                                  "operation": "doesnotexist",
                                  "numbers": [7, 6]
                                },
                                "needsApproval": false
                              }
                            }
                        """.trimIndent()
                    )
                )
            )
        ),

        // 3) User (on behalf of the tool) responds with the tool result (ToolResp)
        Message(
            role = Role.USER,
            created = now + 3,
            content = listOf(
                MessageContent.ToolResp(
                    ToolResponse(
                        id = "calc1",
                        toolResult = """
                            {
                              "status": "error",
                              "error": "Invalid value for operation: 'doesnotexist'. Valid values are: ['add', 'subtract', 'multiply', 'divide']"
                            }
                        """.trimIndent()
                    )
                )
            )
        ),

        // 4) Assistant makes a tool request (ToolReq) to calculate 7×6
        Message(
            role = Role.ASSISTANT,
            created = now + 4,
            content = listOf(
                MessageContent.ToolReq(
                    ToolRequest(
                        id = "calc1",
                        toolCall = """
                            {
                              "status": "success",
                              "value": {
                                "name": "calculator_extension__toolname",
                                "arguments": {
                                  "operation": "multiply",
                                  "numbers": [7, 6]
                                },
                                "needsApproval": false
                              }
                            }
                        """.trimIndent()
                    )
                )
            )
        ),

        // 5) User (on behalf of the tool) responds with the tool result (ToolResp)
        Message(
            role = Role.USER,
            created = now + 5,
            content = listOf(
                MessageContent.ToolResp(
                    ToolResponse(
                        id = "calc1",
                        toolResult = """
                            {
                              "status": "success",
                              "value": [
                                {"type": "text", "text": "42"}
                              ]
                            }
                        """.trimIndent()
                    )
                )
            )
        ),
    )

    printMessages(msgs)
    println("---\n")

    // Setup provider
    val providerName = "databricks"
    val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set")
    val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set")
    val providerConfig = """{"host": "$host", "token": "$token"}"""

    println("Provider Name: $providerName")
    println("Provider Config: $providerConfig")

    val sessionName = generateSessionName(providerName, providerConfig, msgs)
    println("\nSession Name: $sessionName")

    val tooltip = generateTooltip(providerName, providerConfig, msgs)
    println("\nTooltip: $tooltip")

    // Completion
    val modelName = "goose-gpt-4-1"
    val modelConfig = ModelConfig(
        modelName,
        100000u, // UInt
        0.1f, // Float
        200 // Int
    )
fun buildProviderConfig(host: String, token: String, imageFormat: String = "OpenAi"): String = """
    {
      "host": "$host",
      "token": "$token",
      "image_format": "$imageFormat"
    }
""".trimIndent()

fun calculatorExtension(): ExtensionConfig {
    val calculatorTool = createToolConfig(
        name = "calculator",
        name = "calculator",
        description = "Perform basic arithmetic operations",
        inputSchema = """
            {
                "type": "object",
                "required": ["operation", "numbers"],
                "properties": {
                    "operation": {
                        "type": "string",
                        "enum": ["add", "subtract", "multiply", "divide"],
                        "description": "The arithmetic operation to perform"
                    },
                    "numbers": {
                        "type": "array",
                        "items": { "type": "number" },
                        "description": "List of numbers to operate on in order"
                    }
              "type": "object",
              "required": ["operation", "numbers"],
              "properties": {
                "operation": {
                  "type": "string",
                  "enum": ["add", "subtract", "multiply", "divide"],
                  "description": "The arithmetic operation to perform"
                },
                "numbers": {
                  "type": "array",
                  "items": { "type": "number" },
                  "description": "List of numbers to operate on in order"
                }
              }
            }
        """.trimIndent(),
        approvalMode = ToolApprovalMode.AUTO
    )

    val calculator_extension = ExtensionConfig(
        name = "calculator_extension",
    return ExtensionConfig(
        name = "calculator_extension",
        instructions = "This extension provides a calculator tool.",
        tools = listOf(calculatorTool)
        tools = listOf(calculatorTool)
    )

    val extensions = listOf(calculator_extension)
    val systemPreamble = "You are a helpful assistant."

    // Testing with tool calls with an error in tool name
    val reqToolErr = createCompletionRequest(
        providerName,
        providerConfig,
        modelConfig,
        systemPreamble,
        messages = listOf(
            Message(
                role = Role.USER,
                created = now,
                content = listOf(
                    MessageContent.Text(
                        TextContent("What is 7 x 6?")
                    )
                )
            )),
        extensions = extensions
    )

    val respToolErr = completion(reqToolErr)
    println("\nCompletion Response (one msg):\n${respToolErr.message}")
    println()

    val reqAll = createCompletionRequest(
        providerName,
        providerConfig,
        modelConfig,
        systemPreamble,
        messages = msgs,
        extensions = extensions
    )

    val respAll = completion(reqAll)
    println("\nCompletion Response (all msgs):\n${respAll.message}")
    println()

    // ---- UI Extraction (custom schema) ----
    runUiExtraction(providerName, providerConfig)

    // --- Prompt Override ---
    val prompt_req = createCompletionRequest(
        providerName,
        providerConfig,
        modelConfig,
        systemPreamble = null,
        systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.",
        messages = listOf(
            Message(
                role = Role.USER,
                created = now,
                content = listOf(
                    MessageContent.Text(
                        TextContent("What's your name?")
                    )
                )
            )
        ),
        extensions = emptyList()
    )

    val prompt_resp = completion(prompt_req)

    println("\nPrompt Override Response:\n${prompt_resp.message}")
}

/* ---------- demos ---------- */

suspend fun runCalculatorDemo(
    modelConfig: ModelConfig,
    providerName: String,
    providerConfig: String
) {
    val now = System.currentTimeMillis() / 1000
    val msgs = listOf(
        // same conversation you already had
        Message(Role.USER, now, listOf(MessageContent.Text(TextContent("What is 7 x 6?")))),
        Message(Role.ASSISTANT, now + 2, listOf(MessageContent.ToolReq(
            ToolRequest(
                id = "calc1",
                toolCall = """
                    {
                      "status": "success",
                      "value": {
                        "name": "calculator_extension__toolname",
                        "arguments": { "operation": "doesnotexist", "numbers": [7,6] },
                        "needsApproval": false
                      }
                    }
                """.trimIndent()
            )))),
        Message(Role.USER, now + 3, listOf(MessageContent.ToolResp(
            ToolResponse(
                id = "calc1",
                toolResult = """
                    {
                      "status": "error",
                      "error": "Invalid value for operation: 'doesnotexist'. Valid values are: ['add','subtract','multiply','divide']"
                    }
                """.trimIndent()
            )))),
        Message(Role.ASSISTANT, now + 4, listOf(MessageContent.ToolReq(
            ToolRequest(
                id = "calc1",
                toolCall = """
                    {
                      "status": "success",
                      "value": {
                        "name": "calculator_extension__toolname",
                        "arguments": { "operation": "multiply", "numbers": [7,6] },
                        "needsApproval": false
                      }
                    }
                """.trimIndent()
            )))),
        Message(Role.USER, now + 5, listOf(MessageContent.ToolResp(
            ToolResponse(
                id = "calc1",
                toolResult = """
                    {
                      "status": "success",
                      "value": [ { "type": "text", "text": "42" } ]
                    }
                """.trimIndent()
            ))))
    )

    /* one-shot prompt with error */
    val reqErr = createCompletionRequest(
        providerName, providerConfig, modelConfig,
        "You are a helpful assistant.",
        messages = listOf(msgs.first()),
        extensions = listOf(calculatorExtension())
    )
    println("\n[${modelConfig.modelName}] Calculator (single-msg) → ${completion(reqErr).message}")

    /* full conversation */
    val reqAll = createCompletionRequest(
        providerName, providerConfig, modelConfig,
        "You are a helpful assistant.",
        messages = msgs,
        extensions = listOf(calculatorExtension())
    )
    println("[${modelConfig.modelName}] Calculator (full chat) → ${completion(reqAll).message}")
}

suspend fun runImageExample(
    modelConfig: ModelConfig,
    providerName: String,
    providerConfig: String
) {
    val imagePath = "../../crates/goose/examples/test_assets/test_image.png"
    val base64Image = Base64.getEncoder().encodeToString(File(imagePath).readBytes())
    val now = System.currentTimeMillis() / 1000

    val msgs = listOf(
        Message(Role.USER, now, listOf(
            MessageContent.Text(TextContent("What is in this image?")),
            MessageContent.Image(ImageContent(base64Image, "image/png"))
        )),
    )

    val req = createCompletionRequest(
        providerName, providerConfig, modelConfig,
        "You are a helpful assistant. Please describe any text you see in the image.",
        messages = msgs,
        extensions = emptyList()
    )

    println("\n[${modelConfig.modelName}] Image example → ${completion(req).message}")
}

suspend fun runPromptOverride(
    modelConfig: ModelConfig,
    providerName: String,
    providerConfig: String
) {
    val now = System.currentTimeMillis() / 1000
    val req = createCompletionRequest(
        providerName, providerConfig, modelConfig,
        systemPreamble = null,
        systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.",
        messages = listOf(
            Message(Role.USER, now, listOf(MessageContent.Text(TextContent("What's your name?"))))
        ),
        extensions = emptyList()
    )
    println("\n[${modelConfig.modelName}] Prompt override → ${completion(req).message}")
}

suspend fun runUiExtraction(providerName: String, providerConfig: String) {
    val systemPrompt = "You are a UI generator AI. Convert the user input into a JSON-driven UI."
    val messages = listOf(
        Message(
            role = Role.USER,
            created = System.currentTimeMillis() / 1000,
            content = listOf(
                MessageContent.Text(
                    TextContent("Make a User Profile Form")
                )
            )
        )
    )
    val schema = """{
        "type": "object",
        "properties": {
            "type": {
                "type": "string",
                "enum": ["div","button","header","section","field","form"]
            },
            "label": { "type": "string" },
            "children": {
                "type": "array",
                "items": { "${'$'}ref": "#" }
            },
            "attributes": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" },
                        "value": { "type": "string" }
                    },
                    "required": ["name","value"],
                    "additionalProperties": false
                }
            }
        },
        "required": ["type","label","children","attributes"],
        "additionalProperties": false
    }""".trimIndent();
    val schema = /* same JSON schema as before */ """
        {
          "type":"object",
          "properties":{
            "type":{"type":"string","enum":["div","button","header","section","field","form"]},
            "label":{"type":"string"},
            "children":{"type":"array","items":{"${'$'}ref":"#"}},
            "attributes":{"type":"array","items":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"string"}},"required":["name","value"],"additionalProperties":false}}
          },
          "required":["type","label","children","attributes"],
          "additionalProperties":false
        }
    """.trimIndent()

    try {
        val response = generateStructuredOutputs(
            providerName = providerName,
            providerConfig = providerConfig,
            systemPrompt = systemPrompt,
            messages = messages,
            schema = schema
        )
        println("\nUI Extraction Output:\n${response}")
    } catch (e: ProviderException) {
        println("\nUI Extraction failed:\n${e.message}")
    }
    val messages = listOf(
        Message(Role.USER, System.currentTimeMillis()/1000,
            listOf(MessageContent.Text(TextContent("Make a User Profile Form"))))
    )

    val res = generateStructuredOutputs(
        providerName, providerConfig,
        systemPrompt = "You are a UI generator AI. Convert the user input into a JSON-driven UI.",
        messages = messages,
        schema = schema
    )
    println("\n[UI-Extraction] → $res")
}

/* ---------- entry-point ---------- */

fun main() = runBlocking {
    /* --- provider setup --- */
    val providerName = "databricks"
    val host = System.getenv("DATABRICKS_HOST") ?: error("DATABRICKS_HOST not set")
    val token = System.getenv("DATABRICKS_TOKEN") ?: error("DATABRICKS_TOKEN not set")
    val providerConfig = buildProviderConfig(host, token)

    println("Provider: $providerName")
    println("Config : $providerConfig\n")

    /* --- run demos for each model --- */
    // NOTE: `claude-3-5-haiku` does NOT support images
    val modelNames = listOf("kgoose-gpt-4o", "goose-claude-4-sonnet")

    for (name in modelNames) {
        val modelConfig = ModelConfig(name, 100000u, 0.1f, 200)
        println("\n===== Running demos for model: $name =====")

        runCalculatorDemo(modelConfig, providerName, providerConfig)
        runImageExample(modelConfig, providerName, providerConfig)
        runPromptOverride(modelConfig, providerName, providerConfig)
        println("===== End demos for $name =====\n")
    }

    /* UI extraction is model-agnostic, so run it once */
    runUiExtraction(providerName, providerConfig)
}
@@ -833,6 +833,7 @@ internal interface UniffiLib : Library {
    `systemPromptOverride`: RustBuffer.ByValue,
    `messages`: RustBuffer.ByValue,
    `extensions`: RustBuffer.ByValue,
    `requestId`: RustBuffer.ByValue,
    uniffi_out_err: UniffiRustCallStatus,
): RustBuffer.ByValue

@@ -848,6 +849,7 @@ internal interface UniffiLib : Library {
    `providerName`: RustBuffer.ByValue,
    `providerConfig`: RustBuffer.ByValue,
    `messages`: RustBuffer.ByValue,
    `requestId`: RustBuffer.ByValue,
): Long

fun uniffi_goose_llm_fn_func_generate_structured_outputs(
@@ -856,12 +858,14 @@ internal interface UniffiLib : Library {
    `systemPrompt`: RustBuffer.ByValue,
    `messages`: RustBuffer.ByValue,
    `schema`: RustBuffer.ByValue,
    `requestId`: RustBuffer.ByValue,
): Long

fun uniffi_goose_llm_fn_func_generate_tooltip(
    `providerName`: RustBuffer.ByValue,
    `providerConfig`: RustBuffer.ByValue,
    `messages`: RustBuffer.ByValue,
    `requestId`: RustBuffer.ByValue,
): Long

fun uniffi_goose_llm_fn_func_print_messages(
@@ -1101,19 +1105,19 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) {
    if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) {
        throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
    }
    if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 50798.toShort()) {
    if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 15391.toShort()) {
        throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
    }
    if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) {
        throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
    }
    if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 64087.toShort()) {
    if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 34350.toShort()) {
        throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
    }
    if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 43426.toShort()) {
    if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 4576.toShort()) {
        throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
    }
    if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 41121.toShort()) {
    if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 36439.toShort()) {
        throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
    }
    if (lib.uniffi_goose_llm_checksum_func_print_messages() != 30278.toShort()) {
@@ -2960,6 +2964,7 @@ fun `createCompletionRequest`(
    `systemPromptOverride`: kotlin.String? = null,
    `messages`: List<Message>,
    `extensions`: List<ExtensionConfig>,
    `requestId`: kotlin.String? = null,
): CompletionRequest =
    FfiConverterTypeCompletionRequest.lift(
        uniffiRustCall { _status ->
@@ -2971,6 +2976,7 @@ fun `createCompletionRequest`(
            FfiConverterOptionalString.lower(`systemPromptOverride`),
            FfiConverterSequenceTypeMessage.lower(`messages`),
            FfiConverterSequenceTypeExtensionConfig.lower(`extensions`),
            FfiConverterOptionalString.lower(`requestId`),
            _status,
        )
    },
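Since `requestId` defaults to `null` on the Kotlin side (and to `None` via `default(request_id = None)` on the Rust side), existing call sites such as the examples above continue to compile unchanged; supplying a request id is strictly opt-in.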
@@ -3003,12 +3009,14 @@ suspend fun `generateSessionName`(
    `providerName`: kotlin.String,
    `providerConfig`: Value,
    `messages`: List<Message>,
    `requestId`: kotlin.String? = null,
): kotlin.String =
    uniffiRustCallAsync(
        UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_session_name(
            FfiConverterString.lower(`providerName`),
            FfiConverterTypeValue.lower(`providerConfig`),
            FfiConverterSequenceTypeMessage.lower(`messages`),
            FfiConverterOptionalString.lower(`requestId`),
        ),
        { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
        { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
@@ -3031,6 +3039,7 @@ suspend fun `generateStructuredOutputs`(
    `systemPrompt`: kotlin.String,
    `messages`: List<Message>,
    `schema`: Value,
    `requestId`: kotlin.String? = null,
): ProviderExtractResponse =
    uniffiRustCallAsync(
        UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_structured_outputs(
@@ -3039,6 +3048,7 @@ suspend fun `generateStructuredOutputs`(
            FfiConverterString.lower(`systemPrompt`),
            FfiConverterSequenceTypeMessage.lower(`messages`),
            FfiConverterTypeValue.lower(`schema`),
            FfiConverterOptionalString.lower(`requestId`),
        ),
        { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
        { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
@@ -3059,12 +3069,14 @@ suspend fun `generateTooltip`(
    `providerName`: kotlin.String,
    `providerConfig`: Value,
    `messages`: List<Message>,
    `requestId`: kotlin.String? = null,
): kotlin.String =
    uniffiRustCallAsync(
        UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_tooltip(
            FfiConverterString.lower(`providerName`),
            FfiConverterTypeValue.lower(`providerConfig`),
            FfiConverterSequenceTypeMessage.lower(`messages`),
            FfiConverterOptionalString.lower(`requestId`),
        ),
        { future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
        { future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },

@@ -64,6 +64,10 @@ path = "uniffi-bindgen.rs"
name = "simple"
path = "examples/simple.rs"

[[example]]
name = "image"
path = "examples/image.rs"

[[example]]
name = "prompt_override"
path = "examples/prompt_override.rs"
crates/goose-llm/examples/image.rs (new file, 53 lines)
@@ -0,0 +1,53 @@
use anyhow::Result;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use goose_llm::{
    completion,
    message::MessageContent,
    types::completion::{CompletionRequest, CompletionResponse},
    Message, ModelConfig,
};
use serde_json::json;
use std::{fs, vec};

#[tokio::main]
async fn main() -> Result<()> {
    let provider = "databricks";
    let provider_config = json!({
        "host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
        "token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
    });
    let model_name = "goose-claude-4-sonnet"; // "gpt-4o";
    let model_config = ModelConfig::new(model_name.to_string());

    let system_preamble = "You are a helpful assistant.";

    // Read and encode test image
    let image_data = fs::read("examples/test_assets/test_image.png")?;
    let base64_image = BASE64.encode(image_data);

    let user_msg = Message::user()
        .with_text("What do you see in this image?")
        .with_content(MessageContent::image(base64_image, "image/png"));

    let messages = vec![user_msg];

    let completion_response: CompletionResponse = completion(
        CompletionRequest::new(
            provider.to_string(),
            provider_config.clone(),
            model_config.clone(),
            Some(system_preamble.to_string()),
            None,
            messages,
            vec![],
        )
        .with_request_id("test-image-1".to_string()),
    )
    .await?;

    // Print the response
    println!("\nCompletion Response:");
    println!("{}", serde_json::to_string_pretty(&completion_response)?);

    Ok(())
}
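Given the `[[example]] name = "image"` entry added to Cargo.toml above, this example should be runnable with something like `cargo run --example image` from the goose-llm crate (with `DATABRICKS_HOST` and `DATABRICKS_TOKEN` set), though the exact invocation is not shown in this diff.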
@@ -116,7 +116,7 @@ async fn main() -> Result<()> {
    println!("\nCompletion Response:");
    println!("{}", serde_json::to_string_pretty(&completion_response)?);

    let tooltip = generate_tooltip(provider, provider_config.clone(), &messages).await?;
    let tooltip = generate_tooltip(provider, provider_config.clone(), &messages, None).await?;
    println!("\nTooltip: {}", tooltip);
}
crates/goose-llm/examples/test_assets/test_image.png (new binary file, 4.2 KiB; binary content not shown)
@@ -46,7 +46,12 @@ pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, Co
    // Call the LLM provider
    let start_provider = Instant::now();
    let mut response = provider
        .complete(&system_prompt, &req.messages, &tools)
        .complete(
            &system_prompt,
            &req.messages,
            &tools,
            req.request_id.as_deref(),
        )
        .await?;
    let provider_elapsed_sec = start_provider.elapsed().as_secs_f32();
    let usage_tokens = response.usage.total_tokens;
@@ -48,11 +48,12 @@ fn build_system_prompt() -> String {
}

/// Generates a short (≤4 words) session name
#[uniffi::export(async_runtime = "tokio")]
#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
pub async fn generate_session_name(
    provider_name: &str,
    provider_config: JsonValueFfi,
    messages: &[Message],
    request_id: Option<String>,
) -> Result<String, ProviderError> {
    // Collect up to the first 3 user messages (truncated to 300 chars each)
    let context: Vec<String> = messages
@@ -90,6 +91,7 @@ pub async fn generate_session_name(
        &system_prompt,
        &[Message::user().with_text(&user_msg_text)],
        schema,
        request_id,
    )
    .await?;

@@ -52,11 +52,12 @@ fn build_system_prompt() -> String {

/// Generates a tooltip summarizing the last two messages in the session,
/// including any tool calls or results.
#[uniffi::export(async_runtime = "tokio")]
#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
pub async fn generate_tooltip(
    provider_name: &str,
    provider_config: JsonValueFfi,
    messages: &[Message],
    request_id: Option<String>,
) -> Result<String, ProviderError> {
    // Need at least two messages to generate a tooltip
    if messages.len() < 2 {
@@ -148,6 +149,7 @@ pub async fn generate_tooltip(
        &system_prompt,
        &[Message::user().with_text(&user_msg_text)],
        schema,
        request_id,
    )
    .await?;
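A minimal sketch of the updated call shape, mirroring the test updates later in this diff; imports and the surrounding goose-llm types are assumed and omitted here:

// Sketch only: None keeps the previous behaviour, Some(id) tags the underlying
// provider call (Databricks forwards it as client_request_id, see below).
async fn demo(config: JsonValueFfi, messages: &[Message]) -> Result<(), ProviderError> {
    let tooltip = generate_tooltip("databricks", config, messages, None).await?;
    println!("{tooltip}");
    Ok(())
}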
@@ -69,6 +69,7 @@ pub trait Provider: Send + Sync {
/// * `system` - The system prompt that guides the model's behavior
/// * `messages` - The conversation history as a sequence of messages
/// * `tools` - Optional list of tools the model can use
/// * `request_id` - Optional request ID (only used by some providers like Databricks)
///
/// # Returns
/// A tuple containing the model's response message and provider usage statistics
@@ -81,6 +82,7 @@ pub trait Provider: Send + Sync {
    system: &str,
    messages: &[Message],
    tools: &[Tool],
    request_id: Option<&str>,
) -> Result<ProviderCompleteResponse, ProviderError>;

/// Structured extraction: always JSON-Schema
@@ -90,6 +92,7 @@ pub trait Provider: Send + Sync {
/// * `messages` – conversation history
/// * `schema` – a JSON-Schema for the expected output.
///   Will set strict=true for OpenAI & Databricks.
/// * `request_id` - Optional request ID (only used by some providers like Databricks)
///
/// # Returns
/// A `ProviderExtractResponse` whose `data` is a JSON object matching `schema`.
@@ -102,6 +105,7 @@ pub trait Provider: Send + Sync {
    system: &str,
    messages: &[Message],
    schema: &serde_json::Value,
    request_id: Option<&str>,
) -> Result<ProviderExtractResponse, ProviderError>;
}
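A call-site sketch against the updated trait signature; the surrounding types come from goose-llm, the imports are omitted, and the "trace-123" id is a placeholder, so treat this as illustrative rather than the crate's documented API:

async fn ask(provider: &dyn Provider, messages: &[Message], tools: &[Tool]) -> Result<(), ProviderError> {
    // Some(id) is only meaningful for providers that use it (e.g. Databricks); None opts out.
    let resp = provider
        .complete("You are a helpful assistant.", messages, tools, Some("trace-123"))
        .await?;
    println!("total tokens: {:?}", resp.usage.total_tokens);
    Ok(())
}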
@@ -210,6 +210,7 @@ impl Provider for DatabricksProvider {
    system: &str,
    messages: &[Message],
    tools: &[Tool],
    request_id: Option<&str>,
) -> Result<ProviderCompleteResponse, ProviderError> {
    let mut payload = create_request(
        &self.model,
@@ -224,6 +225,17 @@ impl Provider for DatabricksProvider {
        .expect("payload should have model key")
        .remove("model");

    // Add client_request_id if provided
    if let Some(req_id) = request_id {
        payload
            .as_object_mut()
            .expect("payload should be an object")
            .insert(
                "client_request_id".to_string(),
                serde_json::Value::String(req_id.to_string()),
            );
    }

    let response = self.post(payload.clone()).await?;

    // Parse response
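For illustration, a request id such as "test-image-1" (the value used in examples/image.rs above) ends up in the Databricks request body roughly as sketched below; the other fields are the usual chat-completions fields and are abbreviated here:

fn main() {
    // Illustrative shape only; built with serde_json to mirror the insertion above.
    let payload = serde_json::json!({
        "messages": [{ "role": "user", "content": "What do you see in this image?" }],
        "client_request_id": "test-image-1"
    });
    println!("{payload}");
}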
@@ -247,6 +259,7 @@ impl Provider for DatabricksProvider {
    system: &str,
    messages: &[Message],
    schema: &Value,
    request_id: Option<&str>,
) -> Result<ProviderExtractResponse, ProviderError> {
    // 1. Build base payload (no tools)
    let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
@@ -267,6 +280,17 @@ impl Provider for DatabricksProvider {
        }),
    );

    // Add client_request_id if provided
    if let Some(req_id) = request_id {
        payload
            .as_object_mut()
            .expect("payload should be an object")
            .insert(
                "client_request_id".to_string(),
                serde_json::Value::String(req_id.to_string()),
            );
    }

    // 3. Call OpenAI
    let response = self.post(payload.clone()).await?;
@@ -7,10 +7,7 @@ use crate::{
    providers::{
        base::Usage,
        errors::ProviderError,
        utils::{
            convert_image, detect_image_path, is_valid_function_name, load_image_file,
            sanitize_function_name, ImageFormat,
        },
        utils::{convert_image, is_valid_function_name, sanitize_function_name, ImageFormat},
    },
    types::core::{Content, Role, Tool, ToolCall, ToolError},
};
@@ -34,30 +31,17 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
            match content {
                MessageContent::Text(text) => {
                    if !text.text.is_empty() {
                        // Check for image paths in the text
                        if let Some(image_path) = detect_image_path(&text.text) {
                            has_multiple_content = true;
                            // Try to load and convert the image
                            if let Ok(image) = load_image_file(image_path) {
                                content_array.push(json!({
                                    "type": "text",
                                    "text": text.text
                                }));
                                content_array.push(convert_image(&image, image_format));
                            } else {
                                content_array.push(json!({
                                    "type": "text",
                                    "text": text.text
                                }));
                            }
                        } else {
                            content_array.push(json!({
                                "type": "text",
                                "text": text.text
                            }));
                        }
                        content_array.push(json!({
                            "type": "text",
                            "text": text.text
                        }));
                    }
                }
                MessageContent::Image(image) => {
                    // Handle direct image content
                    let converted_image = convert_image(image, image_format);
                    content_array.push(converted_image);
                }
                MessageContent::Thinking(content) => {
                    has_multiple_content = true;
                    content_array.push(json!({
@@ -166,15 +150,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
                    }
                }
            }
                MessageContent::Image(image) => {
                    // Handle direct image content
                    content_array.push(json!({
                        "type": "image_url",
                        "image_url": {
                            "url": convert_image(image, image_format)
                        }
                    }));
                }
            }
        }

@@ -791,40 +766,6 @@ mod tests {
        Ok(())
    }

    #[test]
    fn test_format_messages_with_image_path() -> anyhow::Result<()> {
        // Create a temporary PNG file with valid PNG magic numbers
        let temp_dir = tempfile::tempdir()?;
        let png_path = temp_dir.path().join("test.png");
        let png_data = [
            0x89, 0x50, 0x4E, 0x47, // PNG magic number
            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
        ];
        std::fs::write(&png_path, png_data)?;
        let png_path_str = png_path.to_str().unwrap();

        // Create message with image path
        let message = Message::user().with_text(format!("Here is an image: {}", png_path_str));
        let spec = format_messages(&[message], &ImageFormat::OpenAi);

        assert_eq!(spec.len(), 1);
        assert_eq!(spec[0]["role"], "user");

        // Content should be an array with text and image
        let content = spec[0]["content"].as_array().unwrap();
        assert_eq!(content.len(), 2);
        assert_eq!(content[0]["type"], "text");
        assert!(content[0]["text"].as_str().unwrap().contains(png_path_str));
        assert_eq!(content[1]["type"], "image_url");
        assert!(content[1]["image_url"]["url"]
            .as_str()
            .unwrap()
            .starts_with("data:image/png;base64,"));

        Ok(())
    }

    #[test]
    fn test_response_to_message_text() -> anyhow::Result<()> {
        let response = json!({
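The practical effect of this change for callers, sketched below (mirroring examples/image.rs earlier in the diff): images are attached explicitly as message content instead of being detected from file paths embedded in the text.

use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use goose_llm::{message::MessageContent, Message};

// Attach an image explicitly; plain file paths inside message text are no longer auto-loaded.
fn image_message(png_bytes: &[u8]) -> Message {
    let b64 = BASE64.encode(png_bytes);
    Message::user()
        .with_text("What do you see in this image?")
        .with_content(MessageContent::image(b64, "image/png"))
}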
@@ -7,10 +7,7 @@ use crate::{
    providers::{
        base::Usage,
        errors::ProviderError,
        utils::{
            convert_image, detect_image_path, is_valid_function_name, load_image_file,
            sanitize_function_name, ImageFormat,
        },
        utils::{convert_image, is_valid_function_name, sanitize_function_name, ImageFormat},
    },
    types::core::{Content, Role, Tool, ToolCall, ToolError},
};
@@ -31,23 +28,13 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
            match content {
                MessageContent::Text(text) => {
                    if !text.text.is_empty() {
                        // Check for image paths in the text
                        if let Some(image_path) = detect_image_path(&text.text) {
                            // Try to load and convert the image
                            if let Ok(image) = load_image_file(image_path) {
                                converted["content"] = json!([
                                    {"type": "text", "text": text.text},
                                    convert_image(&image, image_format)
                                ]);
                            } else {
                                // If image loading fails, just use the text
                                converted["content"] = json!(text.text);
                            }
                        } else {
                            converted["content"] = json!(text.text);
                        }
                        converted["content"] = json!(text.text);
                    }
                }
                MessageContent::Image(image) => {
                    // Handle direct image content
                    converted["content"] = json!([convert_image(image, image_format)]);
                }
                MessageContent::Thinking(_) => {
                    // Thinking blocks are not directly used in OpenAI format
                    continue;
@@ -134,10 +121,6 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
                    }
                }
            }
                MessageContent::Image(image) => {
                    // Handle direct image content
                    converted["content"] = json!([convert_image(image, image_format)]);
                }
            }
        }

@@ -664,40 +647,6 @@ mod tests {
        Ok(())
    }

    #[test]
    fn test_format_messages_with_image_path() -> anyhow::Result<()> {
        // Create a temporary PNG file with valid PNG magic numbers
        let temp_dir = tempfile::tempdir()?;
        let png_path = temp_dir.path().join("test.png");
        let png_data = [
            0x89, 0x50, 0x4E, 0x47, // PNG magic number
            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
        ];
        std::fs::write(&png_path, png_data)?;
        let png_path_str = png_path.to_str().unwrap();

        // Create message with image path
        let message = Message::user().with_text(format!("Here is an image: {}", png_path_str));
        let spec = format_messages(&[message], &ImageFormat::OpenAi);

        assert_eq!(spec.len(), 1);
        assert_eq!(spec[0]["role"], "user");

        // Content should be an array with text and image
        let content = spec[0]["content"].as_array().unwrap();
        assert_eq!(content.len(), 2);
        assert_eq!(content[0]["type"], "text");
        assert!(content[0]["text"].as_str().unwrap().contains(png_path_str));
        assert_eq!(content[1]["type"], "image_url");
        assert!(content[1]["image_url"]["url"]
            .as_str()
            .unwrap()
            .starts_with("data:image/png;base64,"));

        Ok(())
    }

    #[test]
    fn test_response_to_message_text() -> anyhow::Result<()> {
        let response = json!({
@@ -149,6 +149,7 @@ impl Provider for OpenAiProvider {
    system: &str,
    messages: &[Message],
    tools: &[Tool],
    _request_id: Option<&str>, // OpenAI doesn't use request_id, so we ignore it
) -> Result<ProviderCompleteResponse, ProviderError> {
    let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;

@@ -175,6 +176,7 @@ impl Provider for OpenAiProvider {
    system: &str,
    messages: &[Message],
    schema: &Value,
    _request_id: Option<&str>, // OpenAI doesn't use request_id, so we ignore it
) -> Result<ProviderExtractResponse, ProviderError> {
    // 1. Build base payload (no tools)
    let mut payload = create_request(&self.model, system, messages, &[], &ImageFormat::OpenAi)?;
@@ -181,30 +181,6 @@ fn is_image_file(path: &Path) -> bool {
    false
}

/// Detect if a string contains a path to an image file
pub fn detect_image_path(text: &str) -> Option<&str> {
    // Basic image file extension check
    let extensions = [".png", ".jpg", ".jpeg"];

    // Find any word that ends with an image extension
    for word in text.split_whitespace() {
        if extensions
            .iter()
            .any(|ext| word.to_lowercase().ends_with(ext))
        {
            let path = Path::new(word);
            // Check if it's an absolute path and file exists
            if path.is_absolute() && path.is_file() {
                // Verify it's actually an image file
                if is_image_file(path) {
                    return Some(word);
                }
            }
        }
    }
    None
}

/// Convert a local image file to base64 encoded ImageContent
pub fn load_image_file(path: &str) -> Result<ImageContent, ProviderError> {
    let path = Path::new(path);
@@ -267,81 +243,6 @@ pub fn emit_debug_trace(
mod tests {
    use super::*;

    #[test]
    fn test_detect_image_path() {
        // Create a temporary PNG file with valid PNG magic numbers
        let temp_dir = tempfile::tempdir().unwrap();
        let png_path = temp_dir.path().join("test.png");
        let png_data = [
            0x89, 0x50, 0x4E, 0x47, // PNG magic number
            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
        ];
        std::fs::write(&png_path, png_data).unwrap();
        let png_path_str = png_path.to_str().unwrap();

        // Create a fake PNG (wrong magic numbers)
        let fake_png_path = temp_dir.path().join("fake.png");
        std::fs::write(&fake_png_path, b"not a real png").unwrap();

        // Test with valid PNG file using absolute path
        let text = format!("Here is an image {}", png_path_str);
        assert_eq!(detect_image_path(&text), Some(png_path_str));

        // Test with non-image file that has .png extension
        let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap());
        assert_eq!(detect_image_path(&text), None);

        // Test with non-existent file
        let text = "Here is a fake.png that doesn't exist";
        assert_eq!(detect_image_path(text), None);

        // Test with non-image file
        let text = "Here is a file.txt";
        assert_eq!(detect_image_path(text), None);

        // Test with relative path (should not match)
        let text = "Here is a relative/path/image.png";
        assert_eq!(detect_image_path(text), None);
    }

    #[test]
    fn test_load_image_file() {
        // Create a temporary PNG file with valid PNG magic numbers
        let temp_dir = tempfile::tempdir().unwrap();
        let png_path = temp_dir.path().join("test.png");
        let png_data = [
            0x89, 0x50, 0x4E, 0x47, // PNG magic number
            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
        ];
        std::fs::write(&png_path, png_data).unwrap();
        let png_path_str = png_path.to_str().unwrap();

        // Create a fake PNG (wrong magic numbers)
        let fake_png_path = temp_dir.path().join("fake.png");
        std::fs::write(&fake_png_path, b"not a real png").unwrap();
        let fake_png_path_str = fake_png_path.to_str().unwrap();

        // Test loading valid PNG file
        let result = load_image_file(png_path_str);
        assert!(result.is_ok());
        let image = result.unwrap();
        assert_eq!(image.mime_type, "image/png");

        // Test loading fake PNG file
        let result = load_image_file(fake_png_path_str);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("not a valid image"));

        // Test non-existent file
        let result = load_image_file("nonexistent.png");
        assert!(result.is_err());
    }

    #[test]
    fn test_sanitize_function_name() {
        assert_eq!(sanitize_function_name("hello-world"), "hello-world");
@@ -6,13 +6,14 @@ use crate::{
/// Generates a structured output based on the provided schema,
/// system prompt and user messages.
#[uniffi::export(async_runtime = "tokio")]
#[uniffi::export(async_runtime = "tokio", default(request_id = None))]
pub async fn generate_structured_outputs(
    provider_name: &str,
    provider_config: JsonValueFfi,
    system_prompt: &str,
    messages: &[Message],
    schema: JsonValueFfi,
    request_id: Option<String>,
) -> Result<ProviderExtractResponse, ProviderError> {
    // Use OpenAI models specifically for this task
    let model_name = if provider_name == "databricks" {
@@ -23,7 +24,9 @@ pub async fn generate_structured_outputs(
    let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
    let provider = create(provider_name, provider_config, model_cfg)?;

    let resp = provider.extract(system_prompt, messages, &schema).await?;
    let resp = provider
        .extract(system_prompt, messages, &schema, request_id.as_deref())
        .await?;

    Ok(resp)
}
@@ -20,6 +20,7 @@ pub struct CompletionRequest {
    pub system_prompt_override: Option<String>,
    pub messages: Vec<Message>,
    pub extensions: Vec<ExtensionConfig>,
    pub request_id: Option<String>,
}

impl CompletionRequest {
@@ -40,10 +41,17 @@ impl CompletionRequest {
            system_preamble,
            messages,
            extensions,
            request_id: None,
        }
    }

    pub fn with_request_id(mut self, request_id: String) -> Self {
        self.request_id = Some(request_id);
        self
    }
}

#[allow(clippy::too_many_arguments)]
#[uniffi::export(default(system_preamble = None, system_prompt_override = None))]
pub fn create_completion_request(
    provider_name: &str,
@@ -53,8 +61,9 @@ pub fn create_completion_request(
    system_prompt_override: Option<String>,
    messages: Vec<Message>,
    extensions: Vec<ExtensionConfig>,
    request_id: Option<String>,
) -> CompletionRequest {
    CompletionRequest::new(
    let mut request = CompletionRequest::new(
        provider_name.to_string(),
        provider_config,
        model_config,
@@ -62,7 +71,13 @@ pub fn create_completion_request(
        system_prompt_override,
        messages,
        extensions,
    )
    );

    if let Some(req_id) = request_id {
        request = request.with_request_id(req_id);
    }

    request
}

uniffi::custom_type!(CompletionRequest, String, {
@@ -22,7 +22,7 @@ async fn _generate_session_name(messages: &[Message]) -> Result<String, Provider
        "token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
    });

    generate_session_name(provider_name, provider_config, messages).await
    generate_session_name(provider_name, provider_config, messages, None).await
}

#[tokio::test]

@@ -24,7 +24,7 @@ async fn _generate_tooltip(messages: &[Message]) -> Result<String, ProviderError
        "token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
    });

    generate_tooltip(provider_name, provider_config, messages).await
    generate_tooltip(provider_name, provider_config, messages, None).await
}

#[tokio::test]

@@ -91,7 +91,7 @@ impl ProviderTester {

        let response = self
            .provider
            .complete("You are a helpful assistant.", &[message], &[])
            .complete("You are a helpful assistant.", &[message], &[], None)
            .await?;

        // For a basic response, we expect a single text response
@@ -134,6 +134,7 @@ impl ProviderTester {
                "You are a helpful weather assistant.",
                &[message.clone()],
                &[weather_tool.clone()],
                None,
            )
            .await?;

@@ -182,6 +183,7 @@ impl ProviderTester {
                "You are a helpful weather assistant.",
                &[message, response1.message, weather],
                &[weather_tool],
                None,
            )
            .await?;

@@ -225,7 +227,7 @@ impl ProviderTester {
        // Test that we get ProviderError::ContextLengthExceeded when the context window is exceeded
        let result = self
            .provider
            .complete("You are a helpful assistant.", &messages, &[])
            .complete("You are a helpful assistant.", &messages, &[], None)
            .await;

        // Print some debug info

@@ -125,7 +125,7 @@ where
        let provider = provider_type.create_provider(cfg)?;

        let msg = Message::user().with_text(user_text);
        let resp = provider.extract(system, &[msg], &schema).await?;
        let resp = provider.extract(system, &[msg], &schema, None).await?;

        println!("[{:?}] extract => {}", provider_type, resp.data);