[goose-llm] kotlin fn for getting structured outputs (#2547)

This commit is contained in:
Salman Mohammed
2025-05-20 07:08:54 -07:00
committed by GitHub
parent 81332ab914
commit f153204dde
13 changed files with 340 additions and 202 deletions

View File

@@ -324,3 +324,24 @@ win-total-dbg *allparam:
win-total-rls *allparam:
just win-bld-rls{{allparam}}
just win-run-rls
### Build and run the Kotlin example with
### auto-generated bindings for goose-llm
kotlin-example:
# Build Rust dylib and generate Kotlin bindings
cargo build -p goose-llm
cargo run --features=uniffi/cli --bin uniffi-bindgen generate \
--library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
# Compile and run the Kotlin example
cd bindings/kotlin/ && kotlinc \
example/Usage.kt \
uniffi/goose_llm/goose_llm.kt \
-classpath "libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
-include-runtime \
-d example.jar
cd bindings/kotlin/ && java \
-Djna.library.path=$HOME/Development/goose/target/debug \
-classpath "example.jar:libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
UsageKt

View File

@@ -135,6 +135,66 @@ fun main() = runBlocking {
)
val response = completion(req)
println("\nCompletion Response:")
println(response.message)
println("\nCompletion Response:\n${response.message}")
println()
// ---- UI Extraction (custom schema) ----
runUiExtraction(providerName, providerConfig)
}
suspend fun runUiExtraction(providerName: String, providerConfig: String) {
val systemPrompt = "You are a UI generator AI. Convert the user input into a JSON-driven UI."
val messages = listOf(
Message(
role = Role.USER,
created = System.currentTimeMillis() / 1000,
content = listOf(
MessageContent.Text(
TextContent("Make a User Profile Form")
)
)
)
)
val schema = """{
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": ["div","button","header","section","field","form"]
},
"label": { "type": "string" },
"children": {
"type": "array",
"items": { "${'$'}ref": "#" }
},
"attributes": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"value": { "type": "string" }
},
"required": ["name","value"],
"additionalProperties": false
}
}
},
"required": ["type","label","children","attributes"],
"additionalProperties": false
}""".trimIndent();
try {
val response = generateStructuredOutputs(
providerName = providerName,
providerConfig = providerConfig,
systemPrompt = systemPrompt,
messages = messages,
schema = schema
)
println("\nUI Extraction Output:\n${response}")
} catch (e: ProviderException) {
println("\nUI Extraction failed:\n${e.message}")
}
}

View File

@@ -775,6 +775,8 @@ internal interface IntegrityCheckingUniffiLib : Library {
fun uniffi_goose_llm_checksum_func_generate_session_name(): Short
fun uniffi_goose_llm_checksum_func_generate_structured_outputs(): Short
fun uniffi_goose_llm_checksum_func_generate_tooltip(): Short
fun uniffi_goose_llm_checksum_func_print_messages(): Short
@@ -847,6 +849,14 @@ internal interface UniffiLib : Library {
`messages`: RustBuffer.ByValue,
): Long
fun uniffi_goose_llm_fn_func_generate_structured_outputs(
`providerName`: RustBuffer.ByValue,
`providerConfig`: RustBuffer.ByValue,
`systemPrompt`: RustBuffer.ByValue,
`messages`: RustBuffer.ByValue,
`schema`: RustBuffer.ByValue,
): Long
fun uniffi_goose_llm_fn_func_generate_tooltip(
`providerName`: RustBuffer.ByValue,
`providerConfig`: RustBuffer.ByValue,
@@ -1090,16 +1100,19 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) {
if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 51008.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 39068.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 22809.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 9810.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_generate_session_name() != 64087.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 15466.toShort()) {
if (lib.uniffi_goose_llm_checksum_func_generate_structured_outputs() != 43426.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_generate_tooltip() != 41121.toShort()) {
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
}
if (lib.uniffi_goose_llm_checksum_func_print_messages() != 30278.toShort()) {
@@ -1626,6 +1639,54 @@ public object FfiConverterTypeProviderCompleteResponse : FfiConverterRustBuffer<
}
}
/**
* Response from a structuredextraction call
*/
data class ProviderExtractResponse(
/**
* The extracted JSON object
*/
var `data`: Value,
/**
* Which model produced it
*/
var `model`: kotlin.String,
/**
* Token usage stats
*/
var `usage`: Usage,
) {
companion object
}
/**
* @suppress
*/
public object FfiConverterTypeProviderExtractResponse : FfiConverterRustBuffer<ProviderExtractResponse> {
override fun read(buf: ByteBuffer): ProviderExtractResponse =
ProviderExtractResponse(
FfiConverterTypeValue.read(buf),
FfiConverterString.read(buf),
FfiConverterTypeUsage.read(buf),
)
override fun allocationSize(value: ProviderExtractResponse) =
(
FfiConverterTypeValue.allocationSize(value.`data`) +
FfiConverterString.allocationSize(value.`model`) +
FfiConverterTypeUsage.allocationSize(value.`usage`)
)
override fun write(
value: ProviderExtractResponse,
buf: ByteBuffer,
) {
FfiConverterTypeValue.write(value.`data`, buf)
FfiConverterString.write(value.`model`, buf)
FfiConverterTypeUsage.write(value.`usage`, buf)
}
}
data class RedactedThinkingContent(
var `data`: kotlin.String,
) {
@@ -1750,6 +1811,46 @@ public object FfiConverterTypeThinkingContent : FfiConverterRustBuffer<ThinkingC
}
}
data class ToolConfig(
var `name`: kotlin.String,
var `description`: kotlin.String,
var `inputSchema`: Value,
var `approvalMode`: ToolApprovalMode,
) {
companion object
}
/**
* @suppress
*/
public object FfiConverterTypeToolConfig : FfiConverterRustBuffer<ToolConfig> {
override fun read(buf: ByteBuffer): ToolConfig =
ToolConfig(
FfiConverterString.read(buf),
FfiConverterString.read(buf),
FfiConverterTypeValue.read(buf),
FfiConverterTypeToolApprovalMode.read(buf),
)
override fun allocationSize(value: ToolConfig) =
(
FfiConverterString.allocationSize(value.`name`) +
FfiConverterString.allocationSize(value.`description`) +
FfiConverterTypeValue.allocationSize(value.`inputSchema`) +
FfiConverterTypeToolApprovalMode.allocationSize(value.`approvalMode`)
)
override fun write(
value: ToolConfig,
buf: ByteBuffer,
) {
FfiConverterString.write(value.`name`, buf)
FfiConverterString.write(value.`description`, buf)
FfiConverterTypeValue.write(value.`inputSchema`, buf)
FfiConverterTypeToolApprovalMode.write(value.`approvalMode`, buf)
}
}
data class ToolRequest(
var `id`: kotlin.String,
var `toolCall`: ToolRequestToolCall,
@@ -2737,34 +2838,6 @@ public object FfiConverterSequenceTypeMessage : FfiConverterRustBuffer<List<Mess
}
}
/**
* @suppress
*/
public object FfiConverterSequenceTypeMessageContent : FfiConverterRustBuffer<List<MessageContent>> {
override fun read(buf: ByteBuffer): List<MessageContent> {
val len = buf.getInt()
return List<MessageContent>(len) {
FfiConverterTypeMessageContent.read(buf)
}
}
override fun allocationSize(value: List<MessageContent>): ULong {
val sizeForLength = 4UL
val sizeForItems = value.map { FfiConverterTypeMessageContent.allocationSize(it) }.sum()
return sizeForLength + sizeForItems
}
override fun write(
value: List<MessageContent>,
buf: ByteBuffer,
) {
buf.putInt(value.size)
value.iterator().forEach {
FfiConverterTypeMessageContent.write(it, buf)
}
}
}
/**
* @suppress
*/
@@ -2793,6 +2866,34 @@ public object FfiConverterSequenceTypeToolConfig : FfiConverterRustBuffer<List<T
}
}
/**
* @suppress
*/
public object FfiConverterSequenceTypeMessageContent : FfiConverterRustBuffer<List<MessageContent>> {
override fun read(buf: ByteBuffer): List<MessageContent> {
val len = buf.getInt()
return List<MessageContent>(len) {
FfiConverterTypeMessageContent.read(buf)
}
}
override fun allocationSize(value: List<MessageContent>): ULong {
val sizeForLength = 4UL
val sizeForItems = value.map { FfiConverterTypeMessageContent.allocationSize(it) }.sum()
return sizeForLength + sizeForItems
}
override fun write(
value: List<MessageContent>,
buf: ByteBuffer,
) {
buf.putInt(value.size)
value.iterator().forEach {
FfiConverterTypeMessageContent.write(it, buf)
}
}
}
/**
* Typealias from the type name used in the UDL file to the builtin type. This
* is needed because the UDL type name is used in function/method signatures.
@@ -2809,22 +2910,6 @@ public typealias FfiConverterTypeCompletionRequest = FfiConverterString
public typealias Contents = List<MessageContent>
public typealias FfiConverterTypeContents = FfiConverterSequenceTypeMessageContent
/**
* Typealias from the type name used in the UDL file to the builtin type. This
* is needed because the UDL type name is used in function/method signatures.
* It's also what we have an external type that references a custom type.
*/
public typealias JsonValueFfi = kotlin.String
public typealias FfiConverterTypeJsonValueFfi = FfiConverterString
/**
* Typealias from the type name used in the UDL file to the builtin type. This
* is needed because the UDL type name is used in function/method signatures.
* It's also what we have an external type that references a custom type.
*/
public typealias ToolConfig = kotlin.String
public typealias FfiConverterTypeToolConfig = FfiConverterString
/**
* Typealias from the type name used in the UDL file to the builtin type. This
* is needed because the UDL type name is used in function/method signatures.
@@ -2841,6 +2926,14 @@ public typealias FfiConverterTypeToolRequestToolCall = FfiConverterString
public typealias ToolResponseToolResult = kotlin.String
public typealias FfiConverterTypeToolResponseToolResult = FfiConverterString
/**
* Typealias from the type name used in the UDL file to the builtin type. This
* is needed because the UDL type name is used in function/method signatures.
* It's also what we have an external type that references a custom type.
*/
public typealias Value = kotlin.String
public typealias FfiConverterTypeValue = FfiConverterString
/**
* Public API for the Goose LLM completion function
*/
@@ -2860,7 +2953,7 @@ suspend fun `completion`(`req`: CompletionRequest): CompletionResponse =
fun `createCompletionRequest`(
`providerName`: kotlin.String,
`providerConfig`: JsonValueFfi,
`providerConfig`: Value,
`modelConfig`: ModelConfig,
`systemPreamble`: kotlin.String,
`messages`: List<Message>,
@@ -2870,7 +2963,7 @@ fun `createCompletionRequest`(
uniffiRustCall { _status ->
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_create_completion_request(
FfiConverterString.lower(`providerName`),
FfiConverterTypeJsonValueFfi.lower(`providerConfig`),
FfiConverterTypeValue.lower(`providerConfig`),
FfiConverterTypeModelConfig.lower(`modelConfig`),
FfiConverterString.lower(`systemPreamble`),
FfiConverterSequenceTypeMessage.lower(`messages`),
@@ -2883,7 +2976,7 @@ fun `createCompletionRequest`(
fun `createToolConfig`(
`name`: kotlin.String,
`description`: kotlin.String,
`inputSchema`: JsonValueFfi,
`inputSchema`: Value,
`approvalMode`: ToolApprovalMode,
): ToolConfig =
FfiConverterTypeToolConfig.lift(
@@ -2891,7 +2984,7 @@ fun `createToolConfig`(
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_create_tool_config(
FfiConverterString.lower(`name`),
FfiConverterString.lower(`description`),
FfiConverterTypeJsonValueFfi.lower(`inputSchema`),
FfiConverterTypeValue.lower(`inputSchema`),
FfiConverterTypeToolApprovalMode.lower(`approvalMode`),
_status,
)
@@ -2905,13 +2998,13 @@ fun `createToolConfig`(
@Suppress("ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE")
suspend fun `generateSessionName`(
`providerName`: kotlin.String,
`providerConfig`: JsonValueFfi,
`providerConfig`: Value,
`messages`: List<Message>,
): kotlin.String =
uniffiRustCallAsync(
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_session_name(
FfiConverterString.lower(`providerName`),
FfiConverterTypeJsonValueFfi.lower(`providerConfig`),
FfiConverterTypeValue.lower(`providerConfig`),
FfiConverterSequenceTypeMessage.lower(`messages`),
),
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
@@ -2923,6 +3016,36 @@ suspend fun `generateSessionName`(
ProviderException.ErrorHandler,
)
/**
* Generates a structured output based on the provided schema,
* system prompt and user messages.
*/
@Throws(ProviderException::class)
@Suppress("ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE")
suspend fun `generateStructuredOutputs`(
`providerName`: kotlin.String,
`providerConfig`: Value,
`systemPrompt`: kotlin.String,
`messages`: List<Message>,
`schema`: Value,
): ProviderExtractResponse =
uniffiRustCallAsync(
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_structured_outputs(
FfiConverterString.lower(`providerName`),
FfiConverterTypeValue.lower(`providerConfig`),
FfiConverterString.lower(`systemPrompt`),
FfiConverterSequenceTypeMessage.lower(`messages`),
FfiConverterTypeValue.lower(`schema`),
),
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },
{ future, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_complete_rust_buffer(future, continuation) },
{ future -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_free_rust_buffer(future) },
// lift function
{ FfiConverterTypeProviderExtractResponse.lift(it) },
// Error FFI converter
ProviderException.ErrorHandler,
)
/**
* Generates a tooltip summarizing the last two messages in the session,
* including any tool calls or results.
@@ -2931,13 +3054,13 @@ suspend fun `generateSessionName`(
@Suppress("ASSIGNED_BUT_NEVER_ACCESSED_VARIABLE")
suspend fun `generateTooltip`(
`providerName`: kotlin.String,
`providerConfig`: JsonValueFfi,
`providerConfig`: Value,
`messages`: List<Message>,
): kotlin.String =
uniffiRustCallAsync(
UniffiLib.INSTANCE.uniffi_goose_llm_fn_func_generate_tooltip(
FfiConverterString.lower(`providerName`),
FfiConverterTypeJsonValueFfi.lower(`providerConfig`),
FfiConverterTypeValue.lower(`providerConfig`),
FfiConverterSequenceTypeMessage.lower(`messages`),
),
{ future, callback, continuation -> UniffiLib.INSTANCE.ffi_goose_llm_rust_future_poll_rust_buffer(future, callback, continuation) },

View File

@@ -31,34 +31,13 @@ Structure:
│ └── goose_llm.kt ← auto-generated bindings
```
#### Create Kotlin bindings:
```bash
# run from project root directory
cargo build -p goose-llm
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
```
#### Kotlin -> Rust: run example
The following `just` command creates kotlin bindings, then compiles and runs an example.
```bash
pushd bindings/kotlin/
kotlinc \
example/Usage.kt \
uniffi/goose_llm/goose_llm.kt \
-classpath "libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
-include-runtime \
-d example.jar
java \
-Djna.library.path=$HOME/Development/goose/target/debug \
-classpath "example.jar:libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
UsageKt
popd
just kotlin-example
```
You will have to download jars in `bindings/kotlin/libs` directory (only the first time):
@@ -70,6 +49,16 @@ curl -O https://repo1.maven.org/maven2/net/java/dev/jna/jna/5.13.0/jna-5.13.0.ja
popd
```
To just create the Kotlin bindings:
```bash
# run from project root directory
cargo build -p goose-llm
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
```
#### Python -> Rust: generate bindings, run example
```bash

View File

@@ -1,5 +1,4 @@
use crate::model::ModelConfig;
use crate::providers::create;
use crate::generate_structured_outputs;
use crate::providers::errors::ProviderError;
use crate::types::core::Role;
use crate::{message::Message, types::json_value_ffi::JsonValueFfi};
@@ -54,15 +53,6 @@ pub async fn generate_session_name(
provider_config: JsonValueFfi,
messages: &[Message],
) -> Result<String, ProviderError> {
// Use OpenAI models specifically for this task
let model_name = if provider_name == "databricks" {
"goose-gpt-4-1"
} else {
"gpt-4.1"
};
let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
let provider = create(provider_name, provider_config.into(), model_cfg)?;
// Collect up to the first 3 user messages (truncated to 300 chars each)
let context: Vec<String> = messages
.iter()
@@ -96,10 +86,15 @@ pub async fn generate_session_name(
"required": ["name"],
"additionalProperties": false
});
let user_msg = Message::user().with_text(&user_msg_text);
let resp = provider
.extract(&system_prompt, &[user_msg], &schema)
.await?;
let resp = generate_structured_outputs(
provider_name,
provider_config,
&system_prompt,
&[Message::user().with_text(&user_msg_text)],
schema,
)
.await?;
let obj = resp
.data

View File

@@ -1,6 +1,5 @@
use crate::generate_structured_outputs;
use crate::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::providers::create;
use crate::providers::errors::ProviderError;
use crate::types::core::{Content, Role};
use crate::types::json_value_ffi::JsonValueFfi;
@@ -59,16 +58,7 @@ pub async fn generate_tooltip(
provider_config: JsonValueFfi,
messages: &[Message],
) -> Result<String, ProviderError> {
// Use OpenAI models specifically for this task
let model_name = if provider_name == "databricks" {
"goose-gpt-4-1"
} else {
"gpt-4.1"
};
let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
let provider = create(provider_name, provider_config.into(), model_cfg)?;
// Need at least two messages to summarize
// Need at least two messages to generate a tooltip
if messages.len() < 2 {
return Err(ProviderError::ExecutionError(
"Need at least two messages to generate a tooltip".to_string(),
@@ -151,11 +141,15 @@ pub async fn generate_tooltip(
"additionalProperties": false
});
// Call extract
let user_msg = Message::user().with_text(&user_msg_text);
let resp = provider
.extract(&system_prompt, &[user_msg], &schema)
.await?;
// Get the structured outputs
let resp = generate_structured_outputs(
provider_name,
provider_config,
&system_prompt,
&[Message::user().with_text(&user_msg_text)],
schema,
)
.await?;
// Pull out the tooltip field
let obj = resp

View File

@@ -6,8 +6,10 @@ pub mod message;
mod model;
mod prompt_template;
pub mod providers;
mod structured_outputs;
pub mod types;
pub use completion::completion;
pub use message::Message;
pub use model::ModelConfig;
pub use structured_outputs::generate_structured_outputs;

View File

@@ -44,7 +44,7 @@ impl ProviderCompleteResponse {
}
/// Response from a structuredextraction call
#[derive(Debug, Clone)]
#[derive(Debug, Clone, uniffi::Record)]
pub struct ProviderExtractResponse {
/// The extracted JSON object
pub data: serde_json::Value,

View File

@@ -0,0 +1,29 @@
use crate::{
providers::{create, errors::ProviderError, ProviderExtractResponse},
types::json_value_ffi::JsonValueFfi,
Message, ModelConfig,
};
/// Generates a structured output based on the provided schema,
/// system prompt and user messages.
#[uniffi::export(async_runtime = "tokio")]
pub async fn generate_structured_outputs(
provider_name: &str,
provider_config: JsonValueFfi,
system_prompt: &str,
messages: &[Message],
schema: JsonValueFfi,
) -> Result<ProviderExtractResponse, ProviderError> {
// Use OpenAI models specifically for this task
let model_name = if provider_name == "databricks" {
"goose-gpt-4-1"
} else {
"gpt-4.1"
};
let model_cfg = ModelConfig::new(model_name.to_string()).with_temperature(Some(0.0));
let provider = create(provider_name, provider_config, model_cfg)?;
let resp = provider.extract(system_prompt, messages, &schema).await?;
Ok(resp)
}

View File

@@ -52,7 +52,7 @@ pub fn create_completion_request(
) -> CompletionRequest {
CompletionRequest::new(
provider_name.to_string(),
provider_config.into(),
provider_config,
model_config,
system_preamble.to_string(),
messages,
@@ -141,11 +141,11 @@ pub enum ToolApprovalMode {
Smart,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct ToolConfig {
pub name: String,
pub description: String,
pub input_schema: serde_json::Value,
pub input_schema: JsonValueFfi,
pub approval_mode: ToolApprovalMode,
}
@@ -153,7 +153,7 @@ impl ToolConfig {
pub fn new(
name: &str,
description: &str,
input_schema: serde_json::Value,
input_schema: JsonValueFfi,
approval_mode: ToolApprovalMode,
) -> Self {
Self {
@@ -182,18 +182,9 @@ pub fn create_tool_config(
input_schema: JsonValueFfi,
approval_mode: ToolApprovalMode,
) -> ToolConfig {
ToolConfig::new(name, description, input_schema.into(), approval_mode)
ToolConfig::new(name, description, input_schema, approval_mode)
}
uniffi::custom_type!(ToolConfig, String, {
lower: |tc: &ToolConfig| {
serde_json::to_string(&tc).unwrap()
},
try_lift: |s: String| {
Ok(serde_json::from_str(&s).unwrap())
},
});
// — Register the newtypes with UniFFI, converting via JSON strings —
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]

View File

@@ -1,84 +1,18 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
// `serde_json::Value` gets converted to a `String` to pass across the FFI.
// https://github.com/mozilla/uniffi-rs/blob/main/docs/manual/src/types/custom_types.md?plain=1
// https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/examples/custom-types/src/lib.rs#L63-L69
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JsonValueFfi(Value);
impl From<JsonValueFfi> for Value {
fn from(val: JsonValueFfi) -> Self {
val.0
}
}
impl From<Value> for JsonValueFfi {
fn from(val: Value) -> Self {
JsonValueFfi(val)
}
}
uniffi::custom_type!(JsonValueFfi, String, {
uniffi::custom_type!(Value, String, {
// Remote is required since 'Value' is from a different crate
remote,
lower: |obj| {
serde_json::to_string(&obj.0).unwrap()
serde_json::to_string(&obj).unwrap()
},
try_lift: |val| {
Ok(serde_json::from_str(&val).unwrap() )
},
});
// Write some tests to ensure that the conversion works as expected
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_json_value_ffi_conversion() {
let original = JsonValueFfi(json!({"key": "value"}));
let serialized = serde_json::to_string(&original).unwrap();
let deserialized: JsonValueFfi = serde_json::from_str(&serialized).unwrap();
assert_eq!(original.0, deserialized.0);
}
#[test]
fn test_json_value_ffi_to_serde() {
let original = JsonValueFfi(json!({"key": "value"}));
let value: Value = original.into();
assert_eq!(value, json!({"key": "value"}));
}
#[test]
fn test_json_value_ffi_from_serde() {
let value = json!({"key": "value"});
let original: JsonValueFfi = value.into();
assert_eq!(original.0, json!({"key": "value"}));
}
#[test]
fn test_json_value_ffi_lower() {
let original = JsonValueFfi(json!({"key": "value"}));
let serialized = serde_json::to_string(&original).unwrap();
assert_eq!(serialized, "{\"key\":\"value\"}");
}
#[test]
fn test_json_value_ffi_try_lift() {
let json_str = "{\"key\":\"value\"}";
let deserialized: JsonValueFfi = serde_json::from_str(json_str).unwrap();
let expected = JsonValueFfi(json!({"key": "value"}));
assert_eq!(deserialized.0, expected.0);
}
#[test]
fn test_json_value_ffi_custom_type() {
let json_str = "{\"key\":\"value\"}";
let deserialized: JsonValueFfi = serde_json::from_str(json_str).unwrap();
let serialized = serde_json::to_string(&deserialized).unwrap();
assert_eq!(serialized, json_str);
}
}
pub type JsonValueFfi = Value;

View File

@@ -22,7 +22,7 @@ async fn _generate_session_name(messages: &[Message]) -> Result<String, Provider
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
});
generate_session_name(provider_name, provider_config.into(), messages).await
generate_session_name(provider_name, provider_config, messages).await
}
#[tokio::test]

View File

@@ -24,7 +24,7 @@ async fn _generate_tooltip(messages: &[Message]) -> Result<String, ProviderError
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
});
generate_tooltip(provider_name, provider_config.into(), messages).await
generate_tooltip(provider_name, provider_config, messages).await
}
#[tokio::test]