mirror of
https://github.com/aljazceru/goose.git
synced 2026-01-31 12:14:32 +01:00
[goose-llm] kotlin fn for getting structured outputs (#2547)
This commit is contained in:
21
Justfile
21
Justfile
@@ -324,3 +324,24 @@ win-total-dbg *allparam:
|
||||
win-total-rls *allparam:
|
||||
just win-bld-rls{{allparam}}
|
||||
just win-run-rls
|
||||
|
||||
### Build and run the Kotlin example with
|
||||
### auto-generated bindings for goose-llm
|
||||
kotlin-example:
|
||||
# Build Rust dylib and generate Kotlin bindings
|
||||
cargo build -p goose-llm
|
||||
cargo run --features=uniffi/cli --bin uniffi-bindgen generate \
|
||||
--library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
|
||||
|
||||
# Compile and run the Kotlin example
|
||||
cd bindings/kotlin/ && kotlinc \
|
||||
example/Usage.kt \
|
||||
uniffi/goose_llm/goose_llm.kt \
|
||||
-classpath "libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
|
||||
-include-runtime \
|
||||
-d example.jar
|
||||
|
||||
cd bindings/kotlin/ && java \
|
||||
-Djna.library.path=$HOME/Development/goose/target/debug \
|
||||
-classpath "example.jar:libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
|
||||
UsageKt
|
||||
@@ -135,6 +135,66 @@ fun main() = runBlocking {
|
||||
)
|
||||
|
||||
val response = completion(req)
|
||||
println("\nCompletion Response:")
|
||||
println(response.message)
|
||||
println("\nCompletion Response:\n${response.message}")
|
||||
println()
|
||||
|
||||
// ---- UI Extraction (custom schema) ----
|
||||
runUiExtraction(providerName, providerConfig)
|
||||
}
|
||||
|
||||
|
||||
suspend fun runUiExtraction(providerName: String, providerConfig: String) {
|
||||
val systemPrompt = "You are a UI generator AI. Convert the user input into a JSON-driven UI."
|
||||
val messages = listOf(
|
||||
Message(
|
||||
role = Role.USER,
|
||||
created = System.currentTimeMillis() / 1000,
|
||||
content = listOf(
|
||||
MessageContent.Text(
|
||||
TextContent("Make a User Profile Form")
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
val schema = """{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["div","button","header","section","field","form"]
|
||||
},
|
||||
"label": { "type": "string" },
|
||||
"children": {
|
||||
"type": "array",
|
||||
"items": { "${'$'}ref": "#" }
|
||||
},
|
||||
"attributes": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"value": { "type": "string" }
|
||||
},
|
||||
"required": ["name","value"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["type","label","children","attributes"],
|
||||
"additionalProperties": false
|
||||
}""".trimIndent();
|
||||
|
||||
try {
|
||||
val response = generateStructuredOutputs(
|
||||
providerName = providerName,
|
||||
providerConfig = providerConfig,
|
||||
systemPrompt = systemPrompt,
|
||||
messages = messages,
|
||||
schema = schema
|
||||
)
|
||||
println("\nUI Extraction Output:\n${response}")
|
||||
} catch (e: ProviderException) {
|
||||
println("\nUI Extraction failed:\n${e.message}")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -775,6 +775,8 @@ internal interface IntegrityCheckingUniffiLib : Library {
|
||||
|
||||
fun uniffi_goose_llm_checksum_func_generate_session_name(): Short
|
||||
|
||||
fun uniffi_goose_llm_checksum_func_generate_structured_outputs(): Short
|
||||
|
||||
fun uniffi_goose_llm_checksum_func_generate_tooltip(): Short
|
||||
|
||||
fun uniffi_goose_llm_checksum_func_print_messages(): Short
|
||||
@@ -847,6 +849,14 @@ internal interface UniffiLib : Library {
|
||||
`messages`: RustBuffer.ByValue,
|
||||
): Long
|
||||
|
||||
fun uniffi_goose_llm_fn_func_generate_structured_outputs(
|
||||
`providerName`: RustBuffer.ByValue,
|
||||
`providerConfig`: RustBuffer.ByValue,
|
||||
`systemPrompt`: RustBuffer.ByValue,
|
||||
`messages`: RustBuffer.ByValue,
|
||||
`schema`: RustBuffer.ByValue,
|
||||
): Long
|
||||
|
||||
fun uniffi_goose_llm_fn_func_generate_tooltip(
|
||||
`providerName`: RustBuffer.ByValue,
|
||||
`providerConfig`: RustBuffer.ByValue,
|
||||
@@ -1090,16 +1100,19 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) {
|
||||
if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 51008.toShort()) {
|
||||
if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 39068.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 22809.toShort()) {
|
||||
if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 9810.toShort()) {
|
||||
if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 64087.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 15466.toShort()) {
|
||||
if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 43426.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 41121.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_goose_llm_checksum_func_print_messages() != 30278.toShort()) {
|
||||
@@ -1626,6 +1639,54 @@ public object FfiConverterTypeProviderCompleteResponse : FfiConverterRustBuffer<
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Response from a structured‐extraction call
|
||||
*/
|
||||
data class ProviderExtractResponse(
|
||||
/**
|
||||
* The extracted JSON object
|
||||
*/
|
||||
var `data`: Value,
|
||||
/**
|
||||
* Which model produced it
|
||||
*/
|
||||
var `model`: kotlin.String,
|
||||
/**
|
||||
* Token usage stats
|
||||
*/
|
||||
var `usage`: Usage,
|
||||
) {
|
||||
companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* @suppress
|
||||
*/
|
||||
public object FfiConverterTypeProviderExtractResponse : FfiConverterRustBuffer<ProviderExtractResponse> {
|
||||
override fun read(buf: ByteBuffer): ProviderExtractResponse =
|
||||
ProviderExtractResponse(
|
||||
FfiConverterTypeValue.read(buf),
|
||||
FfiConverterString.read(buf),
|
||||
FfiConverterTypeUsage.read(buf),
|
||||
)
|
||||
|
||||
override fun allocationSize(value: ProviderExtractResponse) =
|
||||
(
|
||||
FfiConverterTypeValue.allocationSize(value.`data`) +
|
||||
FfiConverterString.allocationSize(value.`model`) +
|
||||
FfiConverterTypeUsage.allocationSize(value.`usage`)
|
||||
)
|
||||
|
||||
override fun write(
|
||||
value: ProviderExtractResponse,
|
||||
buf: ByteBuffer,
|
||||
) {
|
||||
FfiConverterTypeValue.write(value.`data`, buf)
|
||||
FfiConverterString.write(value.`model`, buf)
|
||||
FfiConverterTypeUsage.write(value.`usage`, buf)
|
||||
}
|
||||
}
|
||||
|
||||
data class RedactedThinkingContent(
|
||||
var `data`: kotlin.String,
|
||||
) {
|
||||
@@ -1750,6 +1811,46 @@ public object FfiConverterTypeThinkingContent : FfiConverterRustBuffer<ThinkingC
|
||||
}
|
||||
}
|
||||
|
||||
data class ToolConfig(
|
||||
var `name`: kotlin.String,
|
||||
var `description`: kotlin.String,
|
||||
var `inputSchema`: Value,
|
||||
var `approvalMode`: ToolApprovalMode,
|
||||
) {
|
||||
companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* @suppress
|
||||
*/
|
||||
public object FfiConverterTypeToolConfig : FfiConverterRustBuffer<ToolConfig> {
|
||||
override fun read(buf: ByteBuffer): ToolConfig =
|
||||
ToolConfig(
|
||||
FfiConverterString.read(buf),
|
||||
FfiConverterString.read(buf),
|
||||
FfiConverterTypeValue.read(buf),
|
||||
FfiConverterTypeToolApprovalMode.read(buf),
|
||||
)
|
||||
|
||||
override fun allocationSize(value: ToolConfig) =
|
||||
(
|
||||
FfiConverterString.allocationSize(value.`name`) +
|
||||
FfiConverterString.allocationSize(value.`description`) +
|
||||
FfiConverterTypeValue.allocationSize(value.`inputSchema`) +
|
||||
FfiConverterTypeToolApprovalMode.allocationSize(value.`approvalMode`)
|
||||
)
|
||||
|
||||
override fun write(
|
||||
value: ToolConfig,
|
||||
buf: ByteBuffer,
|
||||
) {
|
||||
FfiConverterString.write(value.`name`, buf)
|
||||
FfiConverterString.write(value.`description`, buf)
|
||||
FfiConverterTypeValue.write(value.`inputSchema`, buf)
|
||||
FfiConverterTypeToolApprovalMode.write(value.`approvalMode`, buf)
|
||||
}
|
||||
}
|
||||
|
||||
data class ToolRequest(
|
||||
var `id`: kotlin.String,
|
||||
var `toolCall`: ToolRequestToolCall,
|
||||
@@ -2737,34 +2838,6 @@ public object FfiConverterSequenceTypeMessage : FfiConverterRustBuffer<List<Mess
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @suppress
|
||||
*/
|
||||
public object FfiConverterSequenceTypeMessageContent : FfiConverterRustBuffer<List<MessageContent>> {
|
||||
override fun read(buf: ByteBuffer): List<MessageContent> {
|
||||
val len = buf.getInt()
|
||||
return List<MessageContent>(len) {
|
||||
FfiConverterTypeMessageContent.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
override fun allocationSize(value: List<MessageContent>): ULong {
|
||||
val sizeForLength = 4UL
|
||||
val sizeForItems = value.map { FfiConverterTypeMessageContent.allocationSize(it) }.sum()
|
||||
return sizeForLength + sizeForItems
|
||||
}
|
||||
|
||||
override fun write(
|
||||
value: List<MessageContent>,
|
||||
buf: ByteBuffer,
|
||||
) {
|
||||
buf.putInt(value.size)
|
||||
value.iterator().forEach {
|
||||
FfiConverterTypeMessageContent.write(it, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @suppress
|
||||
*/
|
||||
@@ -2793,6 +2866,34 @@ public object FfiConverterSequenceTypeToolConfig : FfiConverterRustBuffer<List<T
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @suppress
|
||||
*/
|
||||
public object FfiConverterSequenceTypeMessageContent : FfiConverterRustBuffer<List<MessageContent>> {
|
||||
override fun read(buf: ByteBuffer): List<MessageContent> {
|
||||
val len = buf.getInt()
|
||||
return List<MessageContent>(len) {
|
||||
FfiConverterTypeMessageContent.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
override fun allocationSize(value: List<MessageContent>): ULong {
|
||||
val sizeForLength = 4UL
|
||||
val sizeForItems = value.map { FfiConverterTypeMessageContent.allocationSize(it) }.sum()
|
||||
return sizeForLength + sizeForItems
|
||||
}
|
||||
|
||||
override fun write(
|
||||
value: List<MessageContent>,
|
||||
buf: ByteBuffer,
|
||||
) {
|
||||
buf.putInt(value.size)
|
||||
value.iterator().forEach {
|
||||
FfiConverterTypeMessageContent.write(it, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Typealias from the type name used in the UDL file to the builtin type. This
|
||||
* is needed because the UDL type name is used in function/method signatures.
|
||||
@@ -2809,22 +2910,6 @@ public typealias FfiConverterTypeCompletionRequest = FfiConverterString
|
||||
public typealias Contents = List<MessageContent>
|
||||
public typealias FfiConverterTypeContents = FfiConverterSequenceTypeMessageContent
|
||||
|
||||
/**
|
||||
* Typealias from the type name used in the UDL file to the builtin type. This
|
||||
* is needed because the UDL type name is used in function/method signatures.
|
||||
* It's also what we have an external type that references a custom type.
|
||||
*/
|
||||
public typealias JsonValueFfi = kotlin.String
|
||||
public typealias FfiConverterTypeJsonValueFfi = FfiConverterString
|
||||
|
||||
/**
|
||||
* Typealias from the type name used in the UDL file to the builtin type. This
|
||||
* is needed because the UDL type name is used in function/method signatures.
|
||||
* It's also what we have an external type that references a custom type.
|
||||
*/
|
||||
public typealias ToolConfig = kotlin.String
|
||||
public typealias FfiConverterTypeToolConfig = FfiConverterString
|
||||
|
||||
/**
|
||||
* Typealias from the type name used in the UDL file to the builtin type. This
|
||||
* is needed because the UDL type name is used in function/method signatures.
|
||||
@@ -2841,6 +2926,14 @@ public typealias FfiConverterTypeToolRequestToolCall = FfiConverterString
|
||||
public typealias ToolResponseToolResult = kotlin.String
|
||||
public typealias FfiConverterTypeToolResponseToolResult = FfiConverterString
|
||||
|
||||
/**
|
||||
* Typealias from the type name used in the UDL file to the builtin type. This
|
||||
* is needed because the UDL type name is used in function/method signatures.
|
||||
* It's also what we have an external type that references a custom type.
|
||||
*/
|
||||
public typealias Value = kotlin.String
|
||||
public typealias FfiConverterTypeValue = FfiConverterString
|
||||
|
||||
/**
|
||||
* Public API for the Goose LLM completion function
|
||||
*/
|
||||
@@ -2860,7 +2953,7 @@ suspend fun `completion`(`req`: CompletionRequest): CompletionResponse =
|
||||
|
||||
fun `createCompletionRequest`(
|
||||
`providerName`: kotlin.String,
|
||||
`providerConfig`: JsonValueFfi,
|
||||
`providerConfig`: Value,
|
||||
`modelConfig`: ModelConfig,
|
||||
`systemPreamble`: kotlin.String,
|
||||
`messages`: List<Message>,
|
||||
@@ -2870,7 +2963,7 @@ fun `createCompletionRequest`(
|
||||
uniffiRustCall { _status ->
|
||||
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_create_completion_request(
|
||||
FfiConverterString.lower(`providerName`),
|
||||
FfiConverterTypeJsonValueFfi.lower(`providerConfig`),
|
||||
FfiConverterTypeValue.lower(`providerConfig`),
|
||||
FfiConverterTypeModelConfig.lower(`modelConfig`),
|
||||
FfiConverterString.lower(`systemPreamble`),
|
||||
FfiConverterSequenceTypeMessage.lower(`messages`),
|
||||
@@ -2883,7 +2976,7 @@ fun `createCompletionRequest`(
|
||||
fun `createToolConfig`(
|
||||
`name`: kotlin.String,
|
||||
`description`: kotlin.String,
|
||||
`inputSchema`: JsonValueFfi,
|
||||
`inputSchema`: Value,
|
||||
`approvalMode`: ToolApprovalMode,
|
||||
): ToolConfig =
|
||||
FfiConverterTypeToolConfig.lift(
|
||||
@@ -2891,7 +2984,7 @@ fun `createToolConfig`(
|
||||
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_create_tool_config(
|
||||
FfiConverterString.lower(`name`),
|
||||
FfiConverterString.lower(`description`),
|
||||
FfiConverterTypeJsonValueFfi.lower(`inputSchema`),
|
||||
FfiConverterTypeValue.lower(`inputSchema`),
|
||||
FfiConverterTypeToolApprovalMode.lower(`approvalMode`),
|
||||
_status,
|
||||
)
|
||||
@@ -2905,13 +2998,13 @@ fun `createToolConfig`(
|
||||
@Suppress("ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE")
|
||||
suspend fun `generateSessionName`(
|
||||
`providerName`: kotlin.String,
|
||||
`providerConfig`: JsonValueFfi,
|
||||
`providerConfig`: Value,
|
||||
`messages`: List<Message>,
|
||||
): kotlin.String =
|
||||
uniffiRustCallAsync(
|
||||
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_session_name(
|
||||
FfiConverterString.lower(`providerName`),
|
||||
FfiConverterTypeJsonValueFfi.lower(`providerConfig`),
|
||||
FfiConverterTypeValue.lower(`providerConfig`),
|
||||
FfiConverterSequenceTypeMessage.lower(`messages`),
|
||||
),
|
||||
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
|
||||
@@ -2923,6 +3016,36 @@ suspend fun `generateSessionName`(
|
||||
ProviderException.ErrorHandler,
|
||||
)
|
||||
|
||||
/**
|
||||
* Generates a structured output based on the provided schema,
|
||||
* system prompt and user messages.
|
||||
*/
|
||||
@Throws(ProviderException::class)
|
||||
@Suppress("ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE")
|
||||
suspend fun `generateStructuredOutputs`(
|
||||
`providerName`: kotlin.String,
|
||||
`providerConfig`: Value,
|
||||
`systemPrompt`: kotlin.String,
|
||||
`messages`: List<Message>,
|
||||
`schema`: Value,
|
||||
): ProviderExtractResponse =
|
||||
uniffiRustCallAsync(
|
||||
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_structured_outputs(
|
||||
FfiConverterString.lower(`providerName`),
|
||||
FfiConverterTypeValue.lower(`providerConfig`),
|
||||
FfiConverterString.lower(`systemPrompt`),
|
||||
FfiConverterSequenceTypeMessage.lower(`messages`),
|
||||
FfiConverterTypeValue.lower(`schema`),
|
||||
),
|
||||
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
|
||||
{ future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
|
||||
{ future -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_free_rust_buffer(future) },
|
||||
// lift function
|
||||
{ FfiConverterTypeProviderExtractResponse.lift(it) },
|
||||
// Error FFI converter
|
||||
ProviderException.ErrorHandler,
|
||||
)
|
||||
|
||||
/**
|
||||
* Generates a tooltip summarizing the last two messages in the session,
|
||||
* including any tool calls or results.
|
||||
@@ -2931,13 +3054,13 @@ suspend fun `generateSessionName`(
|
||||
@Suppress("ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE")
|
||||
suspend fun `generateTooltip`(
|
||||
`providerName`: kotlin.String,
|
||||
`providerConfig`: JsonValueFfi,
|
||||
`providerConfig`: Value,
|
||||
`messages`: List<Message>,
|
||||
): kotlin.String =
|
||||
uniffiRustCallAsync(
|
||||
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_tooltip(
|
||||
FfiConverterString.lower(`providerName`),
|
||||
FfiConverterTypeJsonValueFfi.lower(`providerConfig`),
|
||||
FfiConverterTypeValue.lower(`providerConfig`),
|
||||
FfiConverterSequenceTypeMessage.lower(`messages`),
|
||||
),
|
||||
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
|
||||
|
||||
@@ -31,34 +31,13 @@ Structure:
|
||||
│ └── goose_llm.kt ← auto-generated bindings
|
||||
```
|
||||
|
||||
#### Create Kotlin bindings:
|
||||
|
||||
```bash
|
||||
# run from project root directory
|
||||
cargo build -p goose-llm
|
||||
|
||||
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
|
||||
```
|
||||
|
||||
|
||||
#### Kotlin -> Rust: run example
|
||||
|
||||
The following `just` command creates kotlin bindings, then compiles and runs an example.
|
||||
|
||||
```bash
|
||||
pushd bindings/kotlin/
|
||||
|
||||
kotlinc \
|
||||
example/Usage.kt \
|
||||
uniffi/goose_llm/goose_llm.kt \
|
||||
-classpath "libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
|
||||
-include-runtime \
|
||||
-d example.jar
|
||||
|
||||
java \
|
||||
-Djna.library.path=$HOME/Development/goose/target/debug \
|
||||
-classpath "example.jar:libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
|
||||
UsageKt
|
||||
|
||||
popd
|
||||
just kotlin-example
|
||||
```
|
||||
|
||||
You will have to download jars in `bindings/kotlin/libs` directory (only the first time):
|
||||
@@ -70,6 +49,16 @@ curl -O https://repo1.maven.org/maven2/net/java/dev/jna/jna/5.13.0/jna-5.13.0.ja
|
||||
popd
|
||||
```
|
||||
|
||||
To just create the Kotlin bindings:
|
||||
|
||||
```bash
|
||||
# run from project root directory
|
||||
cargo build -p goose-llm
|
||||
|
||||
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
|
||||
```
|
||||
|
||||
|
||||
#### Python -> Rust: generate bindings, run example
|
||||
|
||||
```bash
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::create;
|
||||
use crate::generate_structured_outputs;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::types::core::Role;
|
||||
use crate::{message::Message, types::json_value_ffi::JsonValueFfi};
|
||||
@@ -54,15 +53,6 @@ pub async fn generate_session_name(
|
||||
provider_config: JsonValueFfi,
|
||||
messages: &[Message],
|
||||
) -> Result<String, ProviderError> {
|
||||
// Use OpenAI models specifically for this task
|
||||
let model_name = if provider_name == "databricks" {
|
||||
"goose-gpt-4-1"
|
||||
} else {
|
||||
"gpt-4.1"
|
||||
};
|
||||
let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
|
||||
let provider = create(provider_name, provider_config.into(), model_cfg)?;
|
||||
|
||||
// Collect up to the first 3 user messages (truncated to 300 chars each)
|
||||
let context: Vec<String> = messages
|
||||
.iter()
|
||||
@@ -96,10 +86,15 @@ pub async fn generate_session_name(
|
||||
"required": ["name"],
|
||||
"additionalProperties": false
|
||||
});
|
||||
let user_msg = Message::user().with_text(&user_msg_text);
|
||||
let resp = provider
|
||||
.extract(&system_prompt, &[user_msg], &schema)
|
||||
.await?;
|
||||
|
||||
let resp = generate_structured_outputs(
|
||||
provider_name,
|
||||
provider_config,
|
||||
&system_prompt,
|
||||
&[Message::user().with_text(&user_msg_text)],
|
||||
schema,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let obj = resp
|
||||
.data
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use crate::generate_structured_outputs;
|
||||
use crate::message::{Message, MessageContent};
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::create;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::types::core::{Content, Role};
|
||||
use crate::types::json_value_ffi::JsonValueFfi;
|
||||
@@ -59,16 +58,7 @@ pub async fn generate_tooltip(
|
||||
provider_config: JsonValueFfi,
|
||||
messages: &[Message],
|
||||
) -> Result<String, ProviderError> {
|
||||
// Use OpenAI models specifically for this task
|
||||
let model_name = if provider_name == "databricks" {
|
||||
"goose-gpt-4-1"
|
||||
} else {
|
||||
"gpt-4.1"
|
||||
};
|
||||
let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
|
||||
let provider = create(provider_name, provider_config.into(), model_cfg)?;
|
||||
|
||||
// Need at least two messages to summarize
|
||||
// Need at least two messages to generate a tooltip
|
||||
if messages.len() < 2 {
|
||||
return Err(ProviderError::ExecutionError(
|
||||
"Need at least two messages to generate a tooltip".to_string(),
|
||||
@@ -151,11 +141,15 @@ pub async fn generate_tooltip(
|
||||
"additionalProperties": false
|
||||
});
|
||||
|
||||
// Call extract
|
||||
let user_msg = Message::user().with_text(&user_msg_text);
|
||||
let resp = provider
|
||||
.extract(&system_prompt, &[user_msg], &schema)
|
||||
.await?;
|
||||
// Get the structured outputs
|
||||
let resp = generate_structured_outputs(
|
||||
provider_name,
|
||||
provider_config,
|
||||
&system_prompt,
|
||||
&[Message::user().with_text(&user_msg_text)],
|
||||
schema,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Pull out the tooltip field
|
||||
let obj = resp
|
||||
|
||||
@@ -6,8 +6,10 @@ pub mod message;
|
||||
mod model;
|
||||
mod prompt_template;
|
||||
pub mod providers;
|
||||
mod structured_outputs;
|
||||
pub mod types;
|
||||
|
||||
pub use completion::completion;
|
||||
pub use message::Message;
|
||||
pub use model::ModelConfig;
|
||||
pub use structured_outputs::generate_structured_outputs;
|
||||
|
||||
@@ -44,7 +44,7 @@ impl ProviderCompleteResponse {
|
||||
}
|
||||
|
||||
/// Response from a structured‐extraction call
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, uniffi::Record)]
|
||||
pub struct ProviderExtractResponse {
|
||||
/// The extracted JSON object
|
||||
pub data: serde_json::Value,
|
||||
|
||||
29
crates/goose-llm/src/structured_outputs.rs
Normal file
29
crates/goose-llm/src/structured_outputs.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use crate::{
|
||||
providers::{create, errors::ProviderError, ProviderExtractResponse},
|
||||
types::json_value_ffi::JsonValueFfi,
|
||||
Message, ModelConfig,
|
||||
};
|
||||
|
||||
/// Generates a structured output based on the provided schema,
|
||||
/// system prompt and user messages.
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
pub async fn generate_structured_outputs(
|
||||
provider_name: &str,
|
||||
provider_config: JsonValueFfi,
|
||||
system_prompt: &str,
|
||||
messages: &[Message],
|
||||
schema: JsonValueFfi,
|
||||
) -> Result<ProviderExtractResponse, ProviderError> {
|
||||
// Use OpenAI models specifically for this task
|
||||
let model_name = if provider_name == "databricks" {
|
||||
"goose-gpt-4-1"
|
||||
} else {
|
||||
"gpt-4.1"
|
||||
};
|
||||
let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
|
||||
let provider = create(provider_name, provider_config, model_cfg)?;
|
||||
|
||||
let resp = provider.extract(system_prompt, messages, &schema).await?;
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
@@ -52,7 +52,7 @@ pub fn create_completion_request(
|
||||
) -> CompletionRequest {
|
||||
CompletionRequest::new(
|
||||
provider_name.to_string(),
|
||||
provider_config.into(),
|
||||
provider_config,
|
||||
model_config,
|
||||
system_preamble.to_string(),
|
||||
messages,
|
||||
@@ -141,11 +141,11 @@ pub enum ToolApprovalMode {
|
||||
Smart,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
|
||||
pub struct ToolConfig {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: serde_json::Value,
|
||||
pub input_schema: JsonValueFfi,
|
||||
pub approval_mode: ToolApprovalMode,
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ impl ToolConfig {
|
||||
pub fn new(
|
||||
name: &str,
|
||||
description: &str,
|
||||
input_schema: serde_json::Value,
|
||||
input_schema: JsonValueFfi,
|
||||
approval_mode: ToolApprovalMode,
|
||||
) -> Self {
|
||||
Self {
|
||||
@@ -182,18 +182,9 @@ pub fn create_tool_config(
|
||||
input_schema: JsonValueFfi,
|
||||
approval_mode: ToolApprovalMode,
|
||||
) -> ToolConfig {
|
||||
ToolConfig::new(name, description, input_schema.into(), approval_mode)
|
||||
ToolConfig::new(name, description, input_schema, approval_mode)
|
||||
}
|
||||
|
||||
uniffi::custom_type!(ToolConfig, String, {
|
||||
lower: |tc: &ToolConfig| {
|
||||
serde_json::to_string(&tc).unwrap()
|
||||
},
|
||||
try_lift: |s: String| {
|
||||
Ok(serde_json::from_str(&s).unwrap())
|
||||
},
|
||||
});
|
||||
|
||||
// — Register the newtypes with UniFFI, converting via JSON strings —
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
|
||||
|
||||
@@ -1,84 +1,18 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// `serde_json::Value` gets converted to a `String` to pass across the FFI.
|
||||
// https://github.com/mozilla/uniffi-rs/blob/main/docs/manual/src/types/custom_types.md?plain=1
|
||||
// https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/examples/custom-types/src/lib.rs#L63-L69
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct JsonValueFfi(Value);
|
||||
|
||||
impl From<JsonValueFfi> for Value {
|
||||
fn from(val: JsonValueFfi) -> Self {
|
||||
val.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Value> for JsonValueFfi {
|
||||
fn from(val: Value) -> Self {
|
||||
JsonValueFfi(val)
|
||||
}
|
||||
}
|
||||
|
||||
uniffi::custom_type!(JsonValueFfi, String, {
|
||||
uniffi::custom_type!(Value, String, {
|
||||
// Remote is required since 'Value' is from a different crate
|
||||
remote,
|
||||
lower: |obj| {
|
||||
serde_json::to_string(&obj.0).unwrap()
|
||||
serde_json::to_string(&obj).unwrap()
|
||||
},
|
||||
try_lift: |val| {
|
||||
Ok(serde_json::from_str(&val).unwrap() )
|
||||
},
|
||||
});
|
||||
|
||||
// Write some tests to ensure that the conversion works as expected
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_conversion() {
|
||||
let original = JsonValueFfi(json!({"key": "value"}));
|
||||
let serialized = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: JsonValueFfi = serde_json::from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(original.0, deserialized.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_to_serde() {
|
||||
let original = JsonValueFfi(json!({"key": "value"}));
|
||||
let value: Value = original.into();
|
||||
assert_eq!(value, json!({"key": "value"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_from_serde() {
|
||||
let value = json!({"key": "value"});
|
||||
let original: JsonValueFfi = value.into();
|
||||
assert_eq!(original.0, json!({"key": "value"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_lower() {
|
||||
let original = JsonValueFfi(json!({"key": "value"}));
|
||||
let serialized = serde_json::to_string(&original).unwrap();
|
||||
|
||||
assert_eq!(serialized, "{\"key\":\"value\"}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_try_lift() {
|
||||
let json_str = "{\"key\":\"value\"}";
|
||||
let deserialized: JsonValueFfi = serde_json::from_str(json_str).unwrap();
|
||||
let expected = JsonValueFfi(json!({"key": "value"}));
|
||||
assert_eq!(deserialized.0, expected.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_custom_type() {
|
||||
let json_str = "{\"key\":\"value\"}";
|
||||
let deserialized: JsonValueFfi = serde_json::from_str(json_str).unwrap();
|
||||
let serialized = serde_json::to_string(&deserialized).unwrap();
|
||||
assert_eq!(serialized, json_str);
|
||||
}
|
||||
}
|
||||
pub type JsonValueFfi = Value;
|
||||
|
||||
@@ -22,7 +22,7 @@ async fn _generate_session_name(messages: &[Message]) -> Result<String, Provider
|
||||
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
|
||||
});
|
||||
|
||||
generate_session_name(provider_name, provider_config.into(), messages).await
|
||||
generate_session_name(provider_name, provider_config, messages).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -24,7 +24,7 @@ async fn _generate_tooltip(messages: &[Message]) -> Result<String, ProviderError
|
||||
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
|
||||
});
|
||||
|
||||
generate_tooltip(provider_name, provider_config.into(), messages).await
|
||||
generate_tooltip(provider_name, provider_config, messages).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Reference in New Issue
Block a user