[goose-llm] autogenate kotlin bindings using uniffi-rs proc macros (#2478)

This commit is contained in:
Salman Mohammed
2025-05-09 12:15:38 -04:00
committed by GitHub
parent 77146e5035
commit b7dd3aba73
27 changed files with 4262 additions and 651 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
*.jar
run_cli.sh
tokenizer_files/
.DS_Store

289
Cargo.lock generated
View File

@@ -201,6 +201,48 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "askama"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d4744ed2eef2645831b441d8f5459689ade2ab27c854488fbab1fbe94fce1a7"
dependencies = [
"askama_derive",
"itoa",
"percent-encoding",
"serde",
"serde_json",
]
[[package]]
name = "askama_derive"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d661e0f57be36a5c14c48f78d09011e67e0cb618f269cca9f2fd8d15b68c46ac"
dependencies = [
"askama_parser",
"basic-toml",
"memchr",
"proc-macro2",
"quote",
"rustc-hash 2.1.1",
"serde",
"serde_derive",
"syn 2.0.99",
]
[[package]]
name = "askama_parser"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf315ce6524c857bb129ff794935cf6d42c82a6cff60526fe2a63593de4d0d4f"
dependencies = [
"memchr",
"serde",
"serde_derive",
"winnow",
]
[[package]]
name = "assert-json-diff"
version = "2.0.2"
@@ -211,6 +253,19 @@ dependencies = [
"serde_json",
]
[[package]]
name = "async-compat"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bab94bde396a3f7b4962e396fdad640e241ed797d4d8d77fc8c237d14c58fc0"
dependencies = [
"futures-core",
"futures-io",
"once_cell",
"pin-project-lite",
"tokio",
]
[[package]]
name = "async-compression"
version = "0.4.20"
@@ -843,6 +898,15 @@ dependencies = [
"vsimd",
]
[[package]]
name = "basic-toml"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba62675e8242a4c4e806d12f11d136e626e6c8361d6b829310732241652a178a"
dependencies = [
"serde",
]
[[package]]
name = "bat"
version = "0.24.0"
@@ -1086,6 +1150,38 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d2c12f985c78475a6b8d629afd0c360260ef34cfef52efccdcfd31972f81c2e"
[[package]]
name = "camino"
version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3"
dependencies = [
"serde",
]
[[package]]
name = "cargo-platform"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e35af189006b9c0f00a064685c727031e3ed2d8020f7ba284d78cc2671bd36ea"
dependencies = [
"serde",
]
[[package]]
name = "cargo_metadata"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd5eb614ed4c27c5d706420e4320fbe3216ab31fa1c33cd8246ac36dae4479ba"
dependencies = [
"camino",
"cargo-platform",
"semver",
"serde",
"serde_json",
"thiserror 2.0.12",
]
[[package]]
name = "cast"
version = "0.3.0"
@@ -2100,6 +2196,15 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
[[package]]
name = "fs-err"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88a41f105fe1d5b6b34b2055e3dc59bb79b46b48b2040b9e6c7b4b5de097aa41"
dependencies = [
"autocfg",
]
[[package]]
name = "fs2"
version = "0.4.3"
@@ -2302,6 +2407,17 @@ dependencies = [
"regex-syntax 0.8.5",
]
[[package]]
name = "goblin"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b363a30c165f666402fe6a3024d3bec7ebc898f96a4a23bd1c99f8dbf3f4f47"
dependencies = [
"log",
"plain",
"scroll",
]
[[package]]
name = "google-apis-common"
version = "7.0.0"
@@ -2544,6 +2660,7 @@ dependencies = [
"thiserror 1.0.69",
"tokio",
"tracing",
"uniffi",
"url",
]
@@ -4413,6 +4530,12 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plain"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6"
[[package]]
name = "plist"
version = "1.7.0"
@@ -5279,6 +5402,26 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scroll"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ab8598aa408498679922eff7fa985c25d58a90771bd6be794434c5277eab1a6"
dependencies = [
"scroll_derive",
]
[[package]]
name = "scroll_derive"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1783eabc414609e28a5ba76aee5ddd52199f7107a0b24c2e9746a1ecc34a683d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.99",
]
[[package]]
name = "sct"
version = "0.7.1"
@@ -5342,6 +5485,9 @@ name = "semver"
version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
dependencies = [
"serde",
]
[[package]]
name = "serde"
@@ -5579,6 +5725,12 @@ dependencies = [
"time",
]
[[package]]
name = "siphasher"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d"
[[package]]
name = "slab"
version = "0.4.9"
@@ -5631,6 +5783,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "std_prelude"
version = "0.2.12"
@@ -6464,6 +6622,128 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "uniffi"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dcd1d240101ba3b9d7532ae86d9cb64d9a7ff63e13a2b7b9e94a32a601d8233"
dependencies = [
"anyhow",
"camino",
"cargo_metadata",
"clap 4.5.31",
"uniffi_bindgen",
"uniffi_core",
"uniffi_macros",
"uniffi_pipeline",
]
[[package]]
name = "uniffi_bindgen"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d0525f06d749ea80d8049dc0bb038bb87941e3d909eefa76b6f0a5589b59ac5"
dependencies = [
"anyhow",
"askama",
"camino",
"cargo_metadata",
"fs-err",
"glob",
"goblin",
"heck 0.5.0",
"indexmap 2.7.1",
"once_cell",
"serde",
"tempfile",
"textwrap",
"toml 0.5.11",
"uniffi_internal_macros",
"uniffi_meta",
"uniffi_pipeline",
"uniffi_udl",
]
[[package]]
name = "uniffi_core"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3fa8eb4d825b4ed095cb13483cba6927c3002b9eb603cef9b7688758cc3772e"
dependencies = [
"anyhow",
"async-compat",
"bytes",
"once_cell",
"static_assertions",
]
[[package]]
name = "uniffi_internal_macros"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83b547d69d699e52f2129fde4b57ae0d00b5216e59ed5b56097c95c86ba06095"
dependencies = [
"anyhow",
"indexmap 2.7.1",
"proc-macro2",
"quote",
"syn 2.0.99",
]
[[package]]
name = "uniffi_macros"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00f1de72edc8cb9201c7d650e3678840d143e4499004571aac49e6cb1b17da43"
dependencies = [
"camino",
"fs-err",
"once_cell",
"proc-macro2",
"quote",
"serde",
"syn 2.0.99",
"toml 0.5.11",
"uniffi_meta",
]
[[package]]
name = "uniffi_meta"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acc9204632f6a555b2cba7c8852c5523bc1aa5f3eff605c64af5054ea28b72e"
dependencies = [
"anyhow",
"siphasher",
"uniffi_internal_macros",
"uniffi_pipeline",
]
[[package]]
name = "uniffi_pipeline"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54b5336a9a925b358183837d31541d12590b7fcec373256d3770de02dff24c69"
dependencies = [
"anyhow",
"heck 0.5.0",
"indexmap 2.7.1",
"tempfile",
"uniffi_internal_macros",
]
[[package]]
name = "uniffi_udl"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f95e73373d85f04736bc51997d3e6855721144ec4384cae9ca8513c80615e129"
dependencies = [
"anyhow",
"textwrap",
"uniffi_meta",
"weedle2",
]
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
@@ -6757,6 +7037,15 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "weedle2"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "998d2c24ec099a87daf9467808859f9d82b61f1d9c9701251aea037f514eae0e"
dependencies = [
"nom",
]
[[package]]
name = "weezl"
version = "0.1.8"

View File

@@ -0,0 +1,130 @@
import kotlinx.coroutines.runBlocking
import uniffi.goose_llm.*
fun main() = runBlocking {
val now = System.currentTimeMillis() / 1000
val msgs = listOf(
// 1) User sends a plain-text prompt
Message(
role = Role.USER,
created = now,
content = listOf(
MessageContent.Text(
TextContent("What is 7 x 6?")
)
)
),
// 2) Assistant makes a tool request (ToolReq) to calculate 7×6
Message(
role = Role.ASSISTANT,
created = now + 2,
content = listOf(
MessageContent.ToolReq(
ToolRequest(
id = "calc1",
toolCall = """
{
"status": "success",
"value": {
"name": "calculator_extension__toolname",
"arguments": {
"operation": "multiply",
"numbers": [7, 6]
},
"needsApproval": false
}
}
""".trimIndent()
)
)
)
),
// 3) User (on behalf of the tool) responds with the tool result (ToolResp)
Message(
role = Role.USER,
created = now + 3,
content = listOf(
MessageContent.ToolResp(
ToolResponse(
id = "calc1",
toolResult = """
{
"status": "success",
"value": [
{"type": "text", "text": "42"}
]
}
""".trimIndent()
)
)
)
),
)
printMessages(msgs)
println("---\n")
val sessionName = generateSessionName(msgs)
println("Session Name: $sessionName")
val tooltip = generateTooltip(msgs)
println("Tooltip: $tooltip")
// Completion
val provider = "databricks"
val modelName = "goose-gpt-4-1"
val modelConfig = ModelConfig(
modelName,
100000u, // UInt
0.1f, // Float
200 // Int
)
val calculatorTool = createToolConfig(
name = "calculator",
description = "Perform basic arithmetic operations",
inputSchema = """
{
"type": "object",
"required": ["operation", "numbers"],
"properties": {
"operation": {
"type": "string",
"enum": ["add", "subtract", "multiply", "divide"],
"description": "The arithmetic operation to perform"
},
"numbers": {
"type": "array",
"items": { "type": "number" },
"description": "List of numbers to operate on in order"
}
}
}
""".trimIndent(),
approvalMode = ToolApprovalMode.AUTO
)
val calculator_extension = ExtensionConfig(
name = "calculator_extension",
instructions = "This extension provides a calculator tool.",
tools = listOf(calculatorTool)
)
val extensions = listOf(calculator_extension)
val systemPreamble = "You are a helpful assistant."
val req = CompletionRequest(
provider,
modelConfig,
systemPreamble,
msgs,
extensions
)
val response = completion(req)
println("\nCompletion Response:")
println(response.message)
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,8 +7,11 @@ license.workspace = true
repository.workspace = true
description.workspace = true
[lib]
crate-type = ["lib", "cdylib"]
name = "goose_llm"
[dependencies]
tokio = { version = "1.43", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
anyhow = "1.0"
@@ -36,6 +39,9 @@ regex = "1.11.1"
tracing = "0.1"
smallvec = { version = "1.13", features = ["serde"] }
indoc = "1.0"
# https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/fixtures/futures/Cargo.toml#L22
uniffi = { version = "0.29", features = ["tokio", "cli", "scaffolding-ffi-buffer-fns"] }
tokio = { version = "1.43", features = ["time", "sync"] }
[dev-dependencies]
criterion = "0.5"
@@ -43,7 +49,12 @@ tempfile = "3.15.0"
dotenv = "0.15"
lazy_static = "1.5"
ctor = "0.2.7"
tokio = { version = "1.43", features = ["full"] }
[[bin]]
# https://mozilla.github.io/uniffi-rs/latest/tutorial/foreign_language_bindings.html
name = "uniffi-bindgen"
path = "uniffi-bindgen.rs"
[[example]]
name = "simple"

View File

@@ -1,4 +1,4 @@
### goose-llm
## goose-llm
This crate is meant to be used for foreign function interface (FFI). It's meant to be
stateless and contain logic related to providers and prompts:
@@ -12,3 +12,59 @@ Run:
cargo run -p goose-llm --example simple
```
## Kotlin bindings
Structure:
```
.
└── crates
└── goose-llm/...
└── target
└── debug/libgoose_llm.dylib
├── bindings
│ └── kotlin
│ ├── example
│ │ └── Usage.kt ← your demo app
│ └── uniffi
│ └── goose_llm
│ └── goose_llm.kt ← auto-generated bindings
```
Create Kotlin bindings:
```
# run from project root directory
cargo build -p goose-llm
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
```
Download jars in `bindings/kotlin/libs` directory (only need to do this once):
```
pushd bindings/kotlin/libs/
curl -O https://repo1.maven.org/maven2/org/jetbrains/kotlin/kotlin-stdlib/1.9.0/kotlin-stdlib-1.9.0.jar
curl -O https://repo1.maven.org/maven2/org/jetbrains/kotlinx/kotlinx-coroutines-core-jvm/1.7.3/kotlinx-coroutines-core-jvm-1.7.3.jar
curl -O https://repo1.maven.org/maven2/net/java/dev/jna/jna/5.13.0/jna-5.13.0.jar
popd
```
Compile & Run usage example from Kotlin -> Rust:
```
pushd bindings/kotlin/
kotlinc \
example/Usage.kt \
uniffi/goose_llm/goose_llm.kt \
-classpath "libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
-include-runtime \
-d example.jar
java \
-Djna.library.path=$HOME/Development/goose/target/debug \
-classpath "example.jar:libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
UsageKt
popd
```

View File

@@ -93,13 +93,13 @@ async fn main() -> Result<()> {
println!("\n---------------\n");
println!("User Input: {text}");
let messages = vec![Message::user().with_text(text)];
let completion_response: CompletionResponse = completion(CompletionRequest::new(
provider,
model_config.clone(),
system_preamble,
&messages,
&extensions,
))
let completion_response: CompletionResponse = completion(CompletionRequest {
provider_name: provider.to_string(),
model_config: model_config.clone(),
system_preamble: system_preamble.to_string(),
messages: messages,
extensions: extensions.clone(),
})
.await?;
// Print the response
println!("\nCompletion Response:");

View File

@@ -17,32 +17,40 @@ use crate::{
},
};
#[uniffi::export]
pub fn print_messages(messages: Vec<Message>) {
for msg in messages {
println!("[{:?} @ {}] {:?}", msg.role, msg.created, msg.content);
}
}
/// Public API for the Goose LLM completion function
pub async fn completion(req: CompletionRequest<'_>) -> Result<CompletionResponse, CompletionError> {
#[uniffi::export(async_runtime = "tokio")]
pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, CompletionError> {
let start_total = Instant::now();
let provider = create(req.provider_name, req.model_config)
let provider = create(&req.provider_name, req.model_config)
.map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?;
let system_prompt = construct_system_prompt(req.system_preamble, req.extensions)?;
let tools = collect_prefixed_tools(req.extensions);
let system_prompt = construct_system_prompt(&req.system_preamble, &req.extensions)?;
let tools = collect_prefixed_tools(&req.extensions);
// Call the LLM provider
let start_provider = Instant::now();
let mut response = provider
.complete(&system_prompt, req.messages, &tools)
.complete(&system_prompt, &req.messages, &tools)
.await?;
let provider_elapsed_ms = start_provider.elapsed().as_millis();
let provider_elapsed_sec = start_provider.elapsed().as_secs_f32();
let usage_tokens = response.usage.total_tokens;
let tool_configs = collect_prefixed_tool_configs(req.extensions);
let tool_configs = collect_prefixed_tool_configs(&req.extensions);
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?;
Ok(CompletionResponse::new(
response.message,
response.model,
response.usage,
calculate_runtime_metrics(start_total, provider_elapsed_ms, usage_tokens),
calculate_runtime_metrics(start_total, provider_elapsed_sec, usage_tokens),
))
}
@@ -81,8 +89,8 @@ pub fn update_needs_approval_for_tool_calls(
tool_configs: &HashMap<String, ToolConfig>,
) -> Result<(), CompletionError> {
for content in &mut message.content.iter_mut() {
if let MessageContent::ToolRequest(req) = content {
if let Ok(call) = &mut req.tool_call {
if let MessageContent::ToolReq(req) = content {
if let Ok(call) = &mut req.tool_call.0 {
// Provide a clear error message when the tool config is missing
let config = tool_configs.get(&call.name).ok_or_else(|| {
CompletionError::ToolNotFound(format!(
@@ -117,16 +125,16 @@ fn collect_prefixed_tool_configs(extensions: &[ExtensionConfig]) -> HashMap<Stri
/// Compute runtime metrics for the request.
fn calculate_runtime_metrics(
total_start: Instant,
provider_elapsed_ms: u128,
provider_elapsed_sec: f32,
token_count: Option<i32>,
) -> RuntimeMetrics {
let total_ms = total_start.elapsed().as_millis();
let total_ms = total_start.elapsed().as_secs_f32();
let tokens_per_sec = token_count.and_then(|toks| {
if provider_elapsed_ms > 0 {
Some(toks as f64 / (provider_elapsed_ms as f64 / 1_000.0))
if provider_elapsed_sec > 0.0 {
Some(toks as f64 / (provider_elapsed_sec as f64))
} else {
None
}
});
RuntimeMetrics::new(total_ms, provider_elapsed_ms, tokens_per_sec)
RuntimeMetrics::new(total_ms, provider_elapsed_sec, tokens_per_sec)
}

View File

@@ -49,6 +49,7 @@ fn build_system_prompt() -> String {
}
/// Generates a short (≤4 words) session name
#[uniffi::export(async_runtime = "tokio")]
pub async fn generate_session_name(messages: &[Message]) -> Result<String, ProviderError> {
// Collect up to the first 3 user messages (truncated to 300 chars each)
let context: Vec<String> = messages

View File

@@ -53,6 +53,7 @@ fn build_system_prompt() -> String {
/// Generates a tooltip summarizing the last two messages in the session,
/// including any tool calls or results.
#[uniffi::export(async_runtime = "tokio")]
pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderError> {
// Need at least two messages to summarize
if messages.len() < 2 {
@@ -72,17 +73,17 @@ pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderEr
parts.push(txt.to_string());
}
}
MessageContent::ToolRequest(req) => {
if let Ok(tool_call) = &req.tool_call {
MessageContent::ToolReq(req) => {
if let Ok(tool_call) = &req.tool_call.0 {
parts.push(format!(
"called tool '{}' with args {}",
tool_call.name, tool_call.arguments
));
} else if let Err(e) = &req.tool_call {
} else if let Err(e) = &req.tool_call.0 {
parts.push(format!("tool request error: {}", e));
}
}
MessageContent::ToolResponse(resp) => match &resp.tool_result {
MessageContent::ToolResp(resp) => match &resp.tool_result.0 {
Ok(contents) => {
let results: Vec<String> = contents
.iter()

View File

@@ -1,3 +1,5 @@
uniffi::setup_scaffolding!();
mod completion;
pub mod extractors;
pub mod message;

View File

@@ -1,531 +0,0 @@
use std::{collections::HashSet, iter::FromIterator, ops::Deref};
/// Messages which represent the content sent back and forth to LLM provider
///
/// We use these messages in the agent code, and interfaces which interact with
/// the agent. That let's us reuse message histories across different interfaces.
///
/// The content of the messages uses MCP types to avoid additional conversions
/// when interacting with MCP servers.
use chrono::Utc;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use crate::types::core::{Content, ImageContent, Role, TextContent, ToolCall, ToolResult};
mod tool_result_serde;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolRequest {
pub id: String,
#[serde(with = "tool_result_serde")]
pub tool_call: ToolResult<ToolCall>,
}
impl ToolRequest {
pub fn to_readable_string(&self) -> String {
match &self.tool_call {
Ok(tool_call) => {
format!(
"Tool: {}, Args: {}",
tool_call.name,
serde_json::to_string_pretty(&tool_call.arguments)
.unwrap_or_else(|_| "<<invalid json>>".to_string())
)
}
Err(e) => format!("Invalid tool call: {}", e),
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResponse {
pub id: String,
#[serde(with = "tool_result_serde")]
pub tool_result: ToolResult<Vec<Content>>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ThinkingContent {
pub thinking: String,
pub signature: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RedactedThinkingContent {
pub data: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
/// Content passed inside a message, which can be both simple content and tool content
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MessageContent {
Text(TextContent),
Image(ImageContent),
ToolRequest(ToolRequest),
ToolResponse(ToolResponse),
Thinking(ThinkingContent),
RedactedThinking(RedactedThinkingContent),
}
impl MessageContent {
pub fn text<S: Into<String>>(text: S) -> Self {
MessageContent::Text(TextContent { text: text.into() })
}
pub fn image<S: Into<String>, T: Into<String>>(data: S, mime_type: T) -> Self {
MessageContent::Image(ImageContent {
data: data.into(),
mime_type: mime_type.into(),
})
}
pub fn tool_request<S: Into<String>>(id: S, tool_call: ToolResult<ToolCall>) -> Self {
MessageContent::ToolRequest(ToolRequest {
id: id.into(),
tool_call,
})
}
pub fn tool_response<S: Into<String>>(id: S, tool_result: ToolResult<Vec<Content>>) -> Self {
MessageContent::ToolResponse(ToolResponse {
id: id.into(),
tool_result,
})
}
pub fn thinking<S1: Into<String>, S2: Into<String>>(thinking: S1, signature: S2) -> Self {
MessageContent::Thinking(ThinkingContent {
thinking: thinking.into(),
signature: signature.into(),
})
}
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
}
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
if let MessageContent::ToolRequest(ref tool_request) = self {
Some(tool_request)
} else {
None
}
}
pub fn as_tool_response(&self) -> Option<&ToolResponse> {
if let MessageContent::ToolResponse(ref tool_response) = self {
Some(tool_response)
} else {
None
}
}
pub fn as_tool_response_text(&self) -> Option<String> {
if let Some(tool_response) = self.as_tool_response() {
if let Ok(contents) = &tool_response.tool_result {
let texts: Vec<String> = contents
.iter()
.filter_map(|content| content.as_text().map(String::from))
.collect();
if !texts.is_empty() {
return Some(texts.join("\n"));
}
}
}
None
}
pub fn as_tool_request_id(&self) -> Option<&str> {
if let Self::ToolRequest(r) = self {
Some(&r.id)
} else {
None
}
}
pub fn as_tool_response_id(&self) -> Option<&str> {
if let Self::ToolResponse(r) = self {
Some(&r.id)
} else {
None
}
}
/// Get the text content if this is a TextContent variant
pub fn as_text(&self) -> Option<&str> {
match self {
MessageContent::Text(text) => Some(&text.text),
_ => None,
}
}
/// Get the thinking content if this is a ThinkingContent variant
pub fn as_thinking(&self) -> Option<&ThinkingContent> {
match self {
MessageContent::Thinking(thinking) => Some(thinking),
_ => None,
}
}
/// Get the redacted thinking content if this is a RedactedThinkingContent variant
pub fn as_redacted_thinking(&self) -> Option<&RedactedThinkingContent> {
match self {
MessageContent::RedactedThinking(redacted) => Some(redacted),
_ => None,
}
}
pub fn is_text(&self) -> bool {
matches!(self, Self::Text(_))
}
pub fn is_image(&self) -> bool {
matches!(self, Self::Image(_))
}
pub fn is_tool_request(&self) -> bool {
matches!(self, Self::ToolRequest(_))
}
pub fn is_tool_response(&self) -> bool {
matches!(self, Self::ToolResponse(_))
}
}
impl From<Content> for MessageContent {
fn from(content: Content) -> Self {
match content {
Content::Text(text) => MessageContent::Text(text),
Content::Image(image) => MessageContent::Image(image),
}
}
}
// ────────────────────────────────────────────────────────────────────────────
// 2. Contents a new-type wrapper around SmallVec
// ────────────────────────────────────────────────────────────────────────────
/// Holds the heterogeneous fragments that make up one chat message.
///
/// * Up to two items are stored inline on the stack.
/// * Falls back to a heap allocation only when necessary.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct Contents(SmallVec<[MessageContent; 2]>);
impl Contents {
/*----------------------------------------------------------
* 1-line ergonomic helpers
*---------------------------------------------------------*/
pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, MessageContent> {
self.0.iter_mut()
}
pub fn push(&mut self, item: impl Into<MessageContent>) {
self.0.push(item.into());
}
pub fn texts(&self) -> impl Iterator<Item = &str> {
self.0.iter().filter_map(|c| c.as_text())
}
pub fn concat_text_str(&self) -> String {
self.texts().collect::<Vec<_>>().join("\n")
}
/// Returns `true` if *any* item satisfies the predicate.
pub fn any_is<P>(&self, pred: P) -> bool
where
P: FnMut(&MessageContent) -> bool,
{
self.iter().any(pred)
}
/// Returns `true` if *every* item satisfies the predicate.
pub fn all_are<P>(&self, pred: P) -> bool
where
P: FnMut(&MessageContent) -> bool,
{
self.iter().all(pred)
}
}
impl From<Vec<MessageContent>> for Contents {
fn from(v: Vec<MessageContent>) -> Self {
Contents(SmallVec::from_vec(v))
}
}
impl FromIterator<MessageContent> for Contents {
fn from_iter<I: IntoIterator<Item = MessageContent>>(iter: I) -> Self {
Contents(SmallVec::from_iter(iter))
}
}
/*--------------------------------------------------------------
* Allow &message.content to behave like a slice of fragments.
*-------------------------------------------------------------*/
impl Deref for Contents {
type Target = [MessageContent];
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
pub struct Message {
pub role: Role,
pub created: i64,
pub content: Contents,
}
impl Message {
pub fn new(role: Role) -> Self {
Self {
role,
created: Utc::now().timestamp_millis(),
content: Contents::default(),
}
}
/// Create a new user message with the current timestamp
pub fn user() -> Self {
Self::new(Role::User)
}
/// Create a new assistant message with the current timestamp
pub fn assistant() -> Self {
Self::new(Role::Assistant)
}
/// Add any item that implements Into<MessageContent> to the message
pub fn with_content(mut self, item: impl Into<MessageContent>) -> Self {
self.content.push(item);
self
}
/// Add text content to the message
pub fn with_text<S: Into<String>>(self, text: S) -> Self {
self.with_content(MessageContent::text(text))
}
/// Add image content to the message
pub fn with_image<S: Into<String>, T: Into<String>>(self, data: S, mime_type: T) -> Self {
self.with_content(MessageContent::image(data, mime_type))
}
/// Add a tool request to the message
pub fn with_tool_request<S: Into<String>>(
self,
id: S,
tool_call: ToolResult<ToolCall>,
) -> Self {
self.with_content(MessageContent::tool_request(id, tool_call))
}
/// Add a tool response to the message
pub fn with_tool_response<S: Into<String>>(
self,
id: S,
result: ToolResult<Vec<Content>>,
) -> Self {
self.with_content(MessageContent::tool_response(id, result))
}
/// Add thinking content to the message
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
self,
thinking: S1,
signature: S2,
) -> Self {
self.with_content(MessageContent::thinking(thinking, signature))
}
/// Add redacted thinking content to the message
pub fn with_redacted_thinking<S: Into<String>>(self, data: S) -> Self {
self.with_content(MessageContent::redacted_thinking(data))
}
/// Check if the message is a tool call
pub fn contains_tool_call(&self) -> bool {
self.content.any_is(MessageContent::is_tool_request)
}
/// Check if the message is a tool response
pub fn contains_tool_response(&self) -> bool {
self.content.any_is(MessageContent::is_tool_response)
}
/// Check if the message contains only text content
pub fn has_only_text_content(&self) -> bool {
self.content.all_are(MessageContent::is_text)
}
/// Retrieves all tool `id` from ToolRequest messages
pub fn tool_request_ids(&self) -> HashSet<&str> {
self.content
.iter()
.filter_map(MessageContent::as_tool_request_id)
.collect()
}
/// Retrieves all tool `id` from ToolResponse messages
pub fn tool_response_ids(&self) -> HashSet<&str> {
self.content
.iter()
.filter_map(MessageContent::as_tool_response_id)
.collect()
}
/// Retrieves all tool `id` from the message
pub fn tool_ids(&self) -> HashSet<&str> {
self.tool_request_ids()
.into_iter()
.chain(self.tool_response_ids())
.collect()
}
}
#[cfg(test)]
mod tests {
use serde_json::{json, Value};
use super::*;
use crate::types::core::ToolError;
#[test]
fn test_message_serialization() {
let message = Message::assistant()
.with_text("Hello, I'll help you with that.")
.with_tool_request(
"tool123",
Ok(ToolCall::new("test_tool", json!({"param": "value"}))),
);
let json_str = serde_json::to_string_pretty(&message).unwrap();
println!("Serialized message: {}", json_str);
// Parse back to Value to check structure
let value: Value = serde_json::from_str(&json_str).unwrap();
// Check top-level fields
assert_eq!(value["role"], "assistant");
assert!(value["created"].is_i64());
assert!(value["content"].is_array());
// Check content items
let content = &value["content"];
// First item should be text
assert_eq!(content[0]["type"], "text");
assert_eq!(content[0]["text"], "Hello, I'll help you with that.");
// Second item should be toolRequest
assert_eq!(content[1]["type"], "toolRequest");
assert_eq!(content[1]["id"], "tool123");
// Check tool_call serialization
assert_eq!(content[1]["toolCall"]["status"], "success");
assert_eq!(content[1]["toolCall"]["value"]["name"], "test_tool");
assert_eq!(
content[1]["toolCall"]["value"]["arguments"]["param"],
"value"
);
}
#[test]
fn test_error_serialization() {
let message = Message::assistant().with_tool_request(
"tool123",
Err(ToolError::ExecutionError(
"Something went wrong".to_string(),
)),
);
let json_str = serde_json::to_string_pretty(&message).unwrap();
println!("Serialized error: {}", json_str);
// Parse back to Value to check structure
let value: Value = serde_json::from_str(&json_str).unwrap();
// Check tool_call serialization with error
let tool_call = &value["content"][0]["toolCall"];
assert_eq!(tool_call["status"], "error");
assert_eq!(tool_call["error"], "Execution failed: Something went wrong");
}
#[test]
fn test_deserialization() {
// Create a JSON string with our new format
let json_str = r#"{
"role": "assistant",
"created": 1740171566,
"content": [
{
"type": "text",
"text": "I'll help you with that."
},
{
"type": "toolRequest",
"id": "tool123",
"toolCall": {
"status": "success",
"value": {
"name": "test_tool",
"arguments": {"param": "value"},
"needsApproval": false
}
}
}
]
}"#;
let message: Message = serde_json::from_str(json_str).unwrap();
assert_eq!(message.role, Role::Assistant);
assert_eq!(message.created, 1740171566);
assert_eq!(message.content.len(), 2);
// Check first content item
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "I'll help you with that.");
} else {
panic!("Expected Text content");
}
// Check second content item
if let MessageContent::ToolRequest(req) = &message.content[1] {
assert_eq!(req.id, "tool123");
if let Ok(tool_call) = &req.tool_call {
assert_eq!(tool_call.name, "test_tool");
assert_eq!(tool_call.arguments, json!({"param": "value"}));
} else {
panic!("Expected successful tool call");
}
} else {
panic!("Expected ToolRequest content");
}
}
#[test]
fn test_message_with_text() {
let message = Message::user().with_text("Hello");
assert_eq!(message.content.concat_text_str(), "Hello");
}
#[test]
fn test_message_with_tool_request() {
let tool_call = Ok(ToolCall::new("test_tool", json!({})));
let message = Message::assistant().with_tool_request("req1", tool_call);
assert!(message.contains_tool_call());
assert!(!message.contains_tool_response());
let ids = message.tool_ids();
assert_eq!(ids.len(), 1);
assert!(ids.contains("req1"));
}
}

View File

@@ -0,0 +1,84 @@
use std::{iter::FromIterator, ops::Deref};
use crate::message::MessageContent;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
/// Holds the heterogeneous fragments that make up one chat message.
///
/// * Up to two items are stored inline on the stack.
/// * Falls back to a heap allocation only when necessary.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct Contents(SmallVec<[MessageContent; 2]>);
impl Contents {
/*----------------------------------------------------------
* 1-line ergonomic helpers
*---------------------------------------------------------*/
pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, MessageContent> {
self.0.iter_mut()
}
pub fn push(&mut self, item: impl Into<MessageContent>) {
self.0.push(item.into());
}
pub fn texts(&self) -> impl Iterator<Item = &str> {
self.0.iter().filter_map(|c| c.as_text())
}
pub fn concat_text_str(&self) -> String {
self.texts().collect::<Vec<_>>().join("\n")
}
/// Returns `true` if *any* item satisfies the predicate.
pub fn any_is<P>(&self, pred: P) -> bool
where
P: FnMut(&MessageContent) -> bool,
{
self.iter().any(pred)
}
/// Returns `true` if *every* item satisfies the predicate.
pub fn all_are<P>(&self, pred: P) -> bool
where
P: FnMut(&MessageContent) -> bool,
{
self.iter().all(pred)
}
}
impl From<Vec<MessageContent>> for Contents {
fn from(v: Vec<MessageContent>) -> Self {
Contents(SmallVec::from_vec(v))
}
}
impl FromIterator<MessageContent> for Contents {
fn from_iter<I: IntoIterator<Item = MessageContent>>(iter: I) -> Self {
Contents(SmallVec::from_iter(iter))
}
}
/*--------------------------------------------------------------
* Allow &message.content to behave like a slice of fragments.
*-------------------------------------------------------------*/
impl Deref for Contents {
type Target = [MessageContent];
fn deref(&self) -> &Self::Target {
&self.0
}
}
// — Register the contents type with UniFFI, converting to/from Vec<MessageContent> —
// We need to do this because UniFFIs FFI layer supports only primitive buffers (here Vec<u8>),
uniffi::custom_type!(Contents, Vec<MessageContent>, {
lower: |contents: &Contents| {
contents.0.to_vec()
},
try_lift: |contents: Vec<MessageContent>| {
Ok(Contents::from(contents))
},
});

View File

@@ -0,0 +1,240 @@
use serde::{Deserialize, Serialize};
use serde_json;
use crate::message::tool_result_serde;
use crate::types::core::{Content, ImageContent, TextContent, ToolCall, ToolResult};
// — Newtype wrappers (local structs) so we satisfy Rusts orphan rules —
// We need these because we cant implement UniFFIs FfiConverter directly on a type alias.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolRequestToolCall(#[serde(with = "tool_result_serde")] pub ToolResult<ToolCall>);
impl ToolRequestToolCall {
pub fn as_result(&self) -> &ToolResult<ToolCall> {
&self.0
}
}
impl std::ops::Deref for ToolRequestToolCall {
type Target = ToolResult<ToolCall>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Result<ToolCall, crate::types::core::ToolError>> for ToolRequestToolCall {
fn from(res: Result<ToolCall, crate::types::core::ToolError>) -> Self {
ToolRequestToolCall(res)
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolResponseToolResult(
#[serde(with = "tool_result_serde")] pub ToolResult<Vec<Content>>,
);
impl ToolResponseToolResult {
pub fn as_result(&self) -> &ToolResult<Vec<Content>> {
&self.0
}
}
impl std::ops::Deref for ToolResponseToolResult {
type Target = ToolResult<Vec<Content>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Result<Vec<Content>, crate::types::core::ToolError>> for ToolResponseToolResult {
fn from(res: Result<Vec<Content>, crate::types::core::ToolError>) -> Self {
ToolResponseToolResult(res)
}
}
// — Register the newtypes with UniFFI, converting via JSON strings —
// UniFFIs FFI layer supports only primitive buffers (here String), so we JSON-serialize
// through our `tool_result_serde` to preserve the same success/error schema on both sides.
uniffi::custom_type!(ToolRequestToolCall, String, {
lower: |obj| {
serde_json::to_string(&obj.0).unwrap()
},
try_lift: |val| {
Ok(serde_json::from_str(&val).unwrap() )
},
});
uniffi::custom_type!(ToolResponseToolResult, String, {
lower: |obj| {
serde_json::to_string(&obj.0).unwrap()
},
try_lift: |val| {
Ok(serde_json::from_str(&val).unwrap() )
},
});
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ToolRequest {
pub id: String,
pub tool_call: ToolRequestToolCall,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ToolResponse {
pub id: String,
pub tool_result: ToolResponseToolResult,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct ThinkingContent {
pub thinking: String,
pub signature: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct RedactedThinkingContent {
pub data: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
/// Content passed inside a message, which can be both simple content and tool content
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MessageContent {
Text(TextContent),
Image(ImageContent),
ToolReq(ToolRequest),
ToolResp(ToolResponse),
Thinking(ThinkingContent),
RedactedThinking(RedactedThinkingContent),
}
impl MessageContent {
pub fn text<S: Into<String>>(text: S) -> Self {
MessageContent::Text(TextContent { text: text.into() })
}
pub fn image<S: Into<String>, T: Into<String>>(data: S, mime_type: T) -> Self {
MessageContent::Image(ImageContent {
data: data.into(),
mime_type: mime_type.into(),
})
}
pub fn tool_request<S: Into<String>>(id: S, tool_call: ToolRequestToolCall) -> Self {
MessageContent::ToolReq(ToolRequest {
id: id.into(),
tool_call,
})
}
pub fn tool_response<S: Into<String>>(id: S, tool_result: ToolResponseToolResult) -> Self {
MessageContent::ToolResp(ToolResponse {
id: id.into(),
tool_result,
})
}
pub fn thinking<S1: Into<String>, S2: Into<String>>(thinking: S1, signature: S2) -> Self {
MessageContent::Thinking(ThinkingContent {
thinking: thinking.into(),
signature: signature.into(),
})
}
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
}
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
if let MessageContent::ToolReq(ref tool_request) = self {
Some(tool_request)
} else {
None
}
}
pub fn as_tool_response(&self) -> Option<&ToolResponse> {
if let MessageContent::ToolResp(ref tool_response) = self {
Some(tool_response)
} else {
None
}
}
pub fn as_tool_response_text(&self) -> Option<String> {
if let Some(tool_response) = self.as_tool_response() {
if let Ok(contents) = &tool_response.tool_result.0 {
let texts: Vec<String> = contents
.iter()
.filter_map(|content| content.as_text().map(String::from))
.collect();
if !texts.is_empty() {
return Some(texts.join("\n"));
}
}
}
None
}
pub fn as_tool_request_id(&self) -> Option<&str> {
if let Self::ToolReq(r) = self {
Some(&r.id)
} else {
None
}
}
pub fn as_tool_response_id(&self) -> Option<&str> {
if let Self::ToolResp(r) = self {
Some(&r.id)
} else {
None
}
}
/// Get the text content if this is a TextContent variant
pub fn as_text(&self) -> Option<&str> {
match self {
MessageContent::Text(text) => Some(&text.text),
_ => None,
}
}
/// Get the thinking content if this is a ThinkingContent variant
pub fn as_thinking(&self) -> Option<&ThinkingContent> {
match self {
MessageContent::Thinking(thinking) => Some(thinking),
_ => None,
}
}
/// Get the redacted thinking content if this is a RedactedThinkingContent variant
pub fn as_redacted_thinking(&self) -> Option<&RedactedThinkingContent> {
match self {
MessageContent::RedactedThinking(redacted) => Some(redacted),
_ => None,
}
}
pub fn is_text(&self) -> bool {
matches!(self, Self::Text(_))
}
pub fn is_image(&self) -> bool {
matches!(self, Self::Image(_))
}
pub fn is_tool_request(&self) -> bool {
matches!(self, Self::ToolReq(_))
}
pub fn is_tool_response(&self) -> bool {
matches!(self, Self::ToolResp(_))
}
}
impl From<Content> for MessageContent {
fn from(content: Content) -> Self {
match content {
Content::Text(text) => MessageContent::Text(text),
Content::Image(image) => MessageContent::Image(image),
}
}
}

View File

@@ -0,0 +1,284 @@
//! Messages which represent the content sent back and forth to LLM provider
//!
//! We use these messages in the agent code, and interfaces which interact with
//! the agent. That let's us reuse message histories across different interfaces.
//!
//! The content of the messages uses MCP types to avoid additional conversions
//! when interacting with MCP servers.
mod contents;
mod message_content;
mod tool_result_serde;
pub use contents::Contents;
pub use message_content::{
MessageContent, RedactedThinkingContent, ThinkingContent, ToolRequest, ToolRequestToolCall,
ToolResponse, ToolResponseToolResult,
};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use crate::types::core::Role;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
pub struct Message {
pub role: Role,
pub created: i64,
pub content: Contents,
}
impl Message {
pub fn new(role: Role) -> Self {
Self {
role,
created: Utc::now().timestamp_millis(),
content: Contents::default(),
}
}
/// Create a new user message with the current timestamp
pub fn user() -> Self {
Self::new(Role::User)
}
/// Create a new assistant message with the current timestamp
pub fn assistant() -> Self {
Self::new(Role::Assistant)
}
/// Add any item that implements Into<MessageContent> to the message
pub fn with_content(mut self, item: impl Into<MessageContent>) -> Self {
self.content.push(item);
self
}
/// Add text content to the message
pub fn with_text<S: Into<String>>(self, text: S) -> Self {
self.with_content(MessageContent::text(text))
}
/// Add image content to the message
pub fn with_image<S: Into<String>, T: Into<String>>(self, data: S, mime_type: T) -> Self {
self.with_content(MessageContent::image(data, mime_type))
}
/// Add a tool request to the message
pub fn with_tool_request<S: Into<String>, T: Into<ToolRequestToolCall>>(
self,
id: S,
tool_call: T,
) -> Self {
self.with_content(MessageContent::tool_request(id, tool_call.into()))
}
/// Add a tool response to the message
pub fn with_tool_response<S: Into<String>>(
self,
id: S,
result: ToolResponseToolResult,
) -> Self {
self.with_content(MessageContent::tool_response(id, result))
}
/// Add thinking content to the message
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
self,
thinking: S1,
signature: S2,
) -> Self {
self.with_content(MessageContent::thinking(thinking, signature))
}
/// Add redacted thinking content to the message
pub fn with_redacted_thinking<S: Into<String>>(self, data: S) -> Self {
self.with_content(MessageContent::redacted_thinking(data))
}
/// Check if the message is a tool call
pub fn contains_tool_call(&self) -> bool {
self.content.any_is(MessageContent::is_tool_request)
}
/// Check if the message is a tool response
pub fn contains_tool_response(&self) -> bool {
self.content.any_is(MessageContent::is_tool_response)
}
/// Check if the message contains only text content
pub fn has_only_text_content(&self) -> bool {
self.content.all_are(MessageContent::is_text)
}
/// Retrieves all tool `id` from ToolRequest messages
pub fn tool_request_ids(&self) -> HashSet<&str> {
self.content
.iter()
.filter_map(MessageContent::as_tool_request_id)
.collect()
}
/// Retrieves all tool `id` from ToolResponse messages
pub fn tool_response_ids(&self) -> HashSet<&str> {
self.content
.iter()
.filter_map(MessageContent::as_tool_response_id)
.collect()
}
/// Retrieves all tool `id` from the message
pub fn tool_ids(&self) -> HashSet<&str> {
self.tool_request_ids()
.into_iter()
.chain(self.tool_response_ids())
.collect()
}
}
#[cfg(test)]
mod tests {
use serde_json::{json, Value};
use super::*;
use crate::types::core::{ToolCall, ToolError};
#[test]
fn test_message_serialization() {
let message = Message::assistant()
.with_text("Hello, I'll help you with that.")
.with_tool_request(
"tool123",
Ok(ToolCall::new("test_tool", json!({"param": "value"})).into()),
);
let json_str = serde_json::to_string_pretty(&message).unwrap();
println!("Serialized message: {}", json_str);
// Parse back to Value to check structure
let value: Value = serde_json::from_str(&json_str).unwrap();
println!(
"Read back serialized message: {}",
serde_json::to_string_pretty(&value).unwrap()
);
// Check top-level fields
assert_eq!(value["role"], "assistant");
assert!(value["created"].is_i64());
assert!(value["content"].is_array());
// Check content items
let content = &value["content"];
// First item should be text
assert_eq!(content[0]["type"], "text");
assert_eq!(content[0]["text"], "Hello, I'll help you with that.");
// Second item should be toolRequest
assert_eq!(content[1]["type"], "toolReq");
assert_eq!(content[1]["id"], "tool123");
// Check tool_call serialization
assert_eq!(content[1]["toolCall"]["status"], "success");
assert_eq!(content[1]["toolCall"]["value"]["name"], "test_tool");
assert_eq!(
content[1]["toolCall"]["value"]["arguments"]["param"],
"value"
);
}
#[test]
fn test_error_serialization() {
let message = Message::assistant().with_tool_request(
"tool123",
Err(ToolError::ExecutionError(
"Something went wrong".to_string(),
)),
);
let json_str = serde_json::to_string_pretty(&message).unwrap();
println!("Serialized error: {}", json_str);
// Parse back to Value to check structure
let value: Value = serde_json::from_str(&json_str).unwrap();
// Check tool_call serialization with error
let tool_call = &value["content"][0]["toolCall"];
assert_eq!(tool_call["status"], "error");
assert_eq!(tool_call["error"], "Execution failed: Something went wrong");
}
#[test]
fn test_deserialization() {
// Create a JSON string with our new format
let json_str = r#"{
"role": "assistant",
"created": 1740171566,
"content": [
{
"type": "text",
"text": "I'll help you with that."
},
{
"type": "toolReq",
"id": "tool123",
"toolCall": {
"status": "success",
"value": {
"name": "test_tool",
"arguments": {"param": "value"},
"needsApproval": false
}
}
}
]
}"#;
let message: Message = serde_json::from_str(json_str).unwrap();
assert_eq!(message.role, Role::Assistant);
assert_eq!(message.created, 1740171566);
assert_eq!(message.content.len(), 2);
// Check first content item
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "I'll help you with that.");
} else {
panic!("Expected Text content");
}
// Check second content item
if let MessageContent::ToolReq(req) = &message.content[1] {
assert_eq!(req.id, "tool123");
if let Ok(tool_call) = req.tool_call.as_result() {
assert_eq!(tool_call.name, "test_tool");
assert_eq!(tool_call.arguments, json!({"param": "value"}));
} else {
panic!("Expected successful tool call");
}
} else {
panic!("Expected ToolRequest content");
}
}
#[test]
fn test_message_with_text() {
let message = Message::user().with_text("Hello");
assert_eq!(message.content.concat_text_str(), "Hello");
}
#[test]
fn test_message_with_tool_request() {
let tool_call = Ok(ToolCall::new("test_tool", json!({})));
let message = Message::assistant().with_tool_request("req1", tool_call);
assert!(message.contains_tool_call());
assert!(!message.contains_tool_response());
let ids = message.tool_ids();
assert_eq!(ids.len(), 1);
assert!(ids.contains("req1"));
}
}

View File

@@ -1,14 +1,14 @@
use serde::{Deserialize, Serialize};
const DEFAULT_CONTEXT_LIMIT: usize = 128_000;
const DEFAULT_CONTEXT_LIMIT: u32 = 128_000;
/// Configuration for model-specific settings and limits
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct ModelConfig {
/// The name of the model to use
pub model_name: String,
/// Optional explicit context limit that overrides any defaults
pub context_limit: Option<usize>,
pub context_limit: Option<u32>,
/// Optional temperature setting (0.0 - 1.0)
pub temperature: Option<f32>,
/// Optional maximum tokens to generate
@@ -34,7 +34,7 @@ impl ModelConfig {
}
/// Get model-specific context limit based on model name
fn get_model_specific_limit(model_name: &str) -> Option<usize> {
fn get_model_specific_limit(model_name: &str) -> Option<u32> {
// Implement some sensible defaults
match model_name {
// OpenAI models, https://platform.openai.com/docs/models#models-overview
@@ -52,7 +52,7 @@ impl ModelConfig {
}
/// Set an explicit context limit
pub fn with_context_limit(mut self, limit: Option<usize>) -> Self {
pub fn with_context_limit(mut self, limit: Option<u32>) -> Self {
// Default is None and therefore DEFAULT_CONTEXT_LIMIT, only set
// if input is Some to allow passing through with_context_limit in
// configuration cases
@@ -76,7 +76,7 @@ impl ModelConfig {
/// Get the context_limit for the current model
/// If none are defined, use the DEFAULT_CONTEXT_LIMIT
pub fn context_limit(&self) -> usize {
pub fn context_limit(&self) -> u32 {
self.context_limit.unwrap_or(DEFAULT_CONTEXT_LIMIT)
}
}

View File

@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use super::errors::ProviderError;
use crate::{message::Message, types::core::Tool};
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize, uniffi::Record)]
pub struct Usage {
pub input_tokens: Option<i32>,
pub output_tokens: Option<i32>,
@@ -26,7 +26,7 @@ impl Usage {
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, uniffi::Record)]
pub struct ProviderCompleteResponse {
pub message: Message,
pub model: String,

View File

@@ -1,6 +1,6 @@
use thiserror::Error;
#[derive(Error, Debug)]
#[derive(Error, Debug, uniffi::Error)]
pub enum ProviderError {
#[error("Authentication error: {0}")]
Authentication(String),

View File

@@ -83,9 +83,9 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
]
}));
}
MessageContent::ToolRequest(request) => {
MessageContent::ToolReq(request) => {
has_tool_calls = true;
match &request.tool_call {
match &request.tool_call.as_result() {
Ok(tool_call) => {
let sanitized_name = sanitize_function_name(&tool_call.name);
@@ -114,8 +114,8 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
}
}
}
MessageContent::ToolResponse(response) => {
match &response.tool_result {
MessageContent::ToolResp(response) => {
match &response.tool_result.0 {
Ok(contents) => {
// Process all content, replacing images with placeholder text
let mut tool_content = Vec::new();
@@ -300,13 +300,13 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
"The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+",
function_name
));
content.push(MessageContent::tool_request(id, Err(error)));
content.push(MessageContent::tool_request(id, Err(error).into()));
} else {
match serde_json::from_str::<Value>(&arguments) {
Ok(params) => {
content.push(MessageContent::tool_request(
id,
Ok(ToolCall::new(&function_name, params)),
Ok(ToolCall::new(&function_name, params)).into(),
));
}
Err(e) => {
@@ -314,7 +314,7 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
"Could not interpret tool use parameters for id {}: {}",
id, e
));
content.push(MessageContent::tool_request(id, Err(error)));
content.push(MessageContent::tool_request(id, Err(error).into()));
}
}
}
@@ -681,19 +681,20 @@ mod tests {
Message::user().with_text("How are you?"),
Message::assistant().with_tool_request(
"tool1",
Ok(ToolCall::new("example", json!({"param1": "value1"}))),
Ok(ToolCall::new("example", json!({"param1": "value1"})).into()),
),
];
// Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] {
let tool_id = if let MessageContent::ToolReq(request) = &messages[2].content[0] {
request.id.clone()
} else {
panic!("should be tool request");
};
messages
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
messages.push(
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
);
let spec = format_messages(&messages, &ImageFormat::OpenAi);
@@ -719,14 +720,15 @@ mod tests {
)];
// Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolRequest(request) = &messages[0].content[0] {
let tool_id = if let MessageContent::ToolReq(request) = &messages[0].content[0] {
request.id.clone()
} else {
panic!("should be tool request");
};
messages
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
messages.push(
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
);
let spec = format_messages(&messages, &ImageFormat::OpenAi);
@@ -857,7 +859,7 @@ mod tests {
let message = response_to_message(response)?;
assert_eq!(message.content.len(), 1);
if let MessageContent::ToolRequest(request) = &message.content[0] {
if let MessageContent::ToolReq(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "example_fn");
assert_eq!(tool_call.arguments, json!({"param": "value"}));
@@ -876,8 +878,8 @@ mod tests {
let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] {
match &request.tool_call {
if let MessageContent::ToolReq(request) = &message.content[0] {
match &request.tool_call.as_result() {
Err(ToolError::NotFound(msg)) => {
assert!(msg.starts_with("The provided function name"));
}
@@ -898,8 +900,8 @@ mod tests {
let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] {
match &request.tool_call {
if let MessageContent::ToolReq(request) = &message.content[0] {
match &request.tool_call.as_result() {
Err(ToolError::InvalidParameters(msg)) => {
assert!(msg.starts_with("Could not interpret tool use parameters"));
}
@@ -920,7 +922,7 @@ mod tests {
let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] {
if let MessageContent::ToolReq(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "example_fn");
assert_eq!(tool_call.arguments, json!({}));

View File

@@ -56,7 +56,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
// Redacted thinking blocks are not directly used in OpenAI format
continue;
}
MessageContent::ToolRequest(request) => match &request.tool_call {
MessageContent::ToolReq(request) => match &request.tool_call.as_result() {
Ok(tool_call) => {
let sanitized_name = sanitize_function_name(&tool_call.name);
let tool_calls = converted
@@ -82,8 +82,8 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
}));
}
},
MessageContent::ToolResponse(response) => {
match &response.tool_result {
MessageContent::ToolResp(response) => {
match &response.tool_result.0 {
Ok(contents) => {
// Process all content, replacing images with placeholder text
let mut tool_content = Vec::new();
@@ -210,13 +210,13 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
"The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+",
function_name
));
content.push(MessageContent::tool_request(id, Err(error)));
content.push(MessageContent::tool_request(id, Err(error).into()));
} else {
match serde_json::from_str::<Value>(&arguments) {
Ok(params) => {
content.push(MessageContent::tool_request(
id,
Ok(ToolCall::new(&function_name, params)),
Ok(ToolCall::new(&function_name, params)).into(),
));
}
Err(e) => {
@@ -224,7 +224,7 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
"Could not interpret tool use parameters for id {}: {}",
id, e
));
content.push(MessageContent::tool_request(id, Err(error)));
content.push(MessageContent::tool_request(id, Err(error).into()));
}
}
}
@@ -559,14 +559,15 @@ mod tests {
];
// Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] {
let tool_id = if let MessageContent::ToolReq(request) = &messages[2].content[0] {
request.id.clone()
} else {
panic!("should be tool request");
};
messages
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
messages.push(
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
);
let spec = format_messages(&messages, &ImageFormat::OpenAi);
@@ -592,14 +593,15 @@ mod tests {
)];
// Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolRequest(request) = &messages[0].content[0] {
let tool_id = if let MessageContent::ToolReq(request) = &messages[0].content[0] {
request.id.clone()
} else {
panic!("should be tool request");
};
messages
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
messages.push(
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
);
let spec = format_messages(&messages, &ImageFormat::OpenAi);
@@ -730,7 +732,7 @@ mod tests {
let message = response_to_message(response)?;
assert_eq!(message.content.len(), 1);
if let MessageContent::ToolRequest(request) = &message.content[0] {
if let MessageContent::ToolReq(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "example_fn");
assert_eq!(tool_call.arguments, json!({"param": "value"}));
@@ -749,8 +751,8 @@ mod tests {
let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] {
match &request.tool_call {
if let MessageContent::ToolReq(request) = &message.content[0] {
match &request.tool_call.as_result() {
Err(ToolError::NotFound(msg)) => {
assert!(msg.starts_with("The provided function name"));
}
@@ -771,8 +773,8 @@ mod tests {
let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] {
match &request.tool_call {
if let MessageContent::ToolReq(request) = &message.content[0] {
match &request.tool_call.as_result() {
Err(ToolError::InvalidParameters(msg)) => {
assert!(msg.starts_with("Could not interpret tool use parameters"));
}
@@ -793,7 +795,7 @@ mod tests {
let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] {
if let MessageContent::ToolReq(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "example_fn");
assert_eq!(tool_call.arguments, json!({}));

View File

@@ -7,36 +7,24 @@ use thiserror::Error;
use serde::{Deserialize, Serialize};
use crate::types::json_value_ffi::JsonValueFfi;
use crate::{message::Message, providers::Usage};
use crate::{model::ModelConfig, providers::errors::ProviderError};
pub struct CompletionRequest<'a> {
pub provider_name: &'a str,
// Lifetimes are not supported in Uniffi, cause other languages don't have them
// https://github.com/mozilla/uniffi-rs/issues/1526#issuecomment-1528851837
#[derive(uniffi::Record)]
pub struct CompletionRequest {
pub provider_name: String,
pub model_config: ModelConfig,
pub system_preamble: &'a str,
pub messages: &'a [Message],
pub extensions: &'a [ExtensionConfig],
pub system_preamble: String,
pub messages: Vec<Message>,
pub extensions: Vec<ExtensionConfig>,
}
impl<'a> CompletionRequest<'a> {
pub fn new(
provider_name: &'a str,
model_config: ModelConfig,
system_preamble: &'a str,
messages: &'a [Message],
extensions: &'a [ExtensionConfig],
) -> Self {
Self {
provider_name,
model_config,
system_preamble,
messages,
extensions,
}
}
}
#[derive(Debug, Error)]
// https://mozilla.github.io/uniffi-rs/latest/proc_macro/errors.html
#[derive(Debug, Error, uniffi::Error)]
#[uniffi(flat_error)]
pub enum CompletionError {
#[error("failed to create provider: {0}")]
UnknownProvider(String),
@@ -54,7 +42,7 @@ pub enum CompletionError {
ToolNotFound(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct CompletionResponse {
pub message: Message,
pub model: String,
@@ -78,35 +66,35 @@ impl CompletionResponse {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct RuntimeMetrics {
pub total_time_ms: u128,
pub total_time_ms_provider: u128,
pub total_time_sec: f32,
pub total_time_sec_provider: f32,
pub tokens_per_second: Option<f64>,
}
impl RuntimeMetrics {
pub fn new(
total_time_ms: u128,
total_time_ms_provider: u128,
total_time_sec: f32,
total_time_sec_provider: f32,
tokens_per_second: Option<f64>,
) -> Self {
Self {
total_time_ms,
total_time_ms_provider,
total_time_sec,
total_time_sec_provider,
tokens_per_second,
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
pub enum ToolApprovalMode {
Auto,
Manual,
Smart,
}
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolConfig {
pub name: String,
pub description: String,
@@ -140,7 +128,28 @@ impl ToolConfig {
}
}
#[derive(Debug, Clone, Serialize)]
#[uniffi::export]
pub fn create_tool_config(
name: &str,
description: &str,
input_schema: JsonValueFfi,
approval_mode: ToolApprovalMode,
) -> ToolConfig {
ToolConfig::new(name, description, input_schema.into(), approval_mode)
}
uniffi::custom_type!(ToolConfig, String, {
lower: |tc: &ToolConfig| {
serde_json::to_string(&tc).unwrap()
},
try_lift: |s: String| {
Ok(serde_json::from_str(&s).unwrap())
},
});
// — Register the newtypes with UniFFI, converting via JSON strings —
#[derive(Debug, Clone, Serialize, uniffi::Record)]
pub struct ExtensionConfig {
name: String,
instructions: Option<String>,

View File

@@ -4,14 +4,14 @@
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum Content {
Text(TextContent),
@@ -47,13 +47,13 @@ impl Content {
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct TextContent {
pub text: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ImageContent {
pub data: String,
@@ -116,7 +116,7 @@ impl ToolCall {
}
#[non_exhaustive]
#[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq)]
#[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq, uniffi::Error)]
pub enum ToolError {
#[error("Invalid parameters: {0}")]
InvalidParameters(String),

View File

@@ -0,0 +1,84 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
// `serde_json::Value` gets converted to a `String` to pass across the FFI.
// https://github.com/mozilla/uniffi-rs/blob/main/docs/manual/src/types/custom_types.md?plain=1
// https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/examples/custom-types/src/lib.rs#L63-L69
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JsonValueFfi(Value);
impl From<JsonValueFfi> for Value {
fn from(val: JsonValueFfi) -> Self {
val.0
}
}
impl From<Value> for JsonValueFfi {
fn from(val: Value) -> Self {
JsonValueFfi(val)
}
}
uniffi::custom_type!(JsonValueFfi, String, {
lower: |obj| {
serde_json::to_string(&obj.0).unwrap()
},
try_lift: |val| {
Ok(serde_json::from_str(&val).unwrap() )
},
});
// Write some tests to ensure that the conversion works as expected
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_json_value_ffi_conversion() {
let original = JsonValueFfi(json!({"key": "value"}));
let serialized = serde_json::to_string(&original).unwrap();
let deserialized: JsonValueFfi = serde_json::from_str(&serialized).unwrap();
assert_eq!(original.0, deserialized.0);
}
#[test]
fn test_json_value_ffi_to_serde() {
let original = JsonValueFfi(json!({"key": "value"}));
let value: Value = original.into();
assert_eq!(value, json!({"key": "value"}));
}
#[test]
fn test_json_value_ffi_from_serde() {
let value = json!({"key": "value"});
let original: JsonValueFfi = value.into();
assert_eq!(original.0, json!({"key": "value"}));
}
#[test]
fn test_json_value_ffi_lower() {
let original = JsonValueFfi(json!({"key": "value"}));
let serialized = serde_json::to_string(&original).unwrap();
assert_eq!(serialized, "{\"key\":\"value\"}");
}
#[test]
fn test_json_value_ffi_try_lift() {
let json_str = "{\"key\":\"value\"}";
let deserialized: JsonValueFfi = serde_json::from_str(json_str).unwrap();
let expected = JsonValueFfi(json!({"key": "value"}));
assert_eq!(deserialized.0, expected.0);
}
#[test]
fn test_json_value_ffi_custom_type() {
let json_str = "{\"key\":\"value\"}";
let deserialized: JsonValueFfi = serde_json::from_str(json_str).unwrap();
let serialized = serde_json::to_string(&deserialized).unwrap();
assert_eq!(serialized, json_str);
}
}

View File

@@ -1,2 +1,3 @@
pub mod completion;
pub mod core;
pub mod json_value_ffi;

View File

@@ -45,14 +45,14 @@ async fn test_generate_tooltip_with_tools() -> Result<(), ProviderError> {
let mut tool_req_msg = Message::assistant();
let req = ToolRequest {
id: "1".to_string(),
tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))),
tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))).into(),
};
tool_req_msg.content.push(MessageContent::ToolRequest(req));
tool_req_msg.content.push(MessageContent::ToolReq(req));
// 2) User message with the tool response
let tool_resp_msg = Message::user().with_tool_response(
"1",
Ok(vec![Content::text("The current time is 12:00 UTC")]),
Ok(vec![Content::text("The current time is 12:00 UTC")]).into(),
);
let messages = vec![tool_req_msg, tool_resp_msg];

View File

@@ -147,7 +147,7 @@ impl ProviderTester {
.message
.content
.iter()
.any(|content| matches!(content, MessageContent::ToolRequest(_))),
.any(|content| matches!(content, MessageContent::ToolReq(_))),
"Expected tool request in response"
);
@@ -171,7 +171,8 @@ impl ProviderTester {
Weather
Saturday 9:00 PM
Clear",
)]),
)])
.into(),
);
// Verify we construct a valid payload including the request/response pair for the next inference

View File

@@ -0,0 +1,3 @@
fn main() {
uniffi::uniffi_bindgen_main()
}