[goose-llm] autogenate kotlin bindings using uniffi-rs proc macros (#2478)

This commit is contained in:
Salman Mohammed
2025-05-09 12:15:38 -04:00
committed by GitHub
parent 77146e5035
commit b7dd3aba73
27 changed files with 4262 additions and 651 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
*.jar
run_cli.sh run_cli.sh
tokenizer_files/ tokenizer_files/
.DS_Store .DS_Store

289
Cargo.lock generated
View File

@@ -201,6 +201,48 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "askama"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d4744ed2eef2645831b441d8f5459689ade2ab27c854488fbab1fbe94fce1a7"
dependencies = [
"askama_derive",
"itoa",
"percent-encoding",
"serde",
"serde_json",
]
[[package]]
name = "askama_derive"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d661e0f57be36a5c14c48f78d09011e67e0cb618f269cca9f2fd8d15b68c46ac"
dependencies = [
"askama_parser",
"basic-toml",
"memchr",
"proc-macro2",
"quote",
"rustc-hash 2.1.1",
"serde",
"serde_derive",
"syn 2.0.99",
]
[[package]]
name = "askama_parser"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf315ce6524c857bb129ff794935cf6d42c82a6cff60526fe2a63593de4d0d4f"
dependencies = [
"memchr",
"serde",
"serde_derive",
"winnow",
]
[[package]] [[package]]
name = "assert-json-diff" name = "assert-json-diff"
version = "2.0.2" version = "2.0.2"
@@ -211,6 +253,19 @@ dependencies = [
"serde_json", "serde_json",
] ]
[[package]]
name = "async-compat"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bab94bde396a3f7b4962e396fdad640e241ed797d4d8d77fc8c237d14c58fc0"
dependencies = [
"futures-core",
"futures-io",
"once_cell",
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "async-compression" name = "async-compression"
version = "0.4.20" version = "0.4.20"
@@ -843,6 +898,15 @@ dependencies = [
"vsimd", "vsimd",
] ]
[[package]]
name = "basic-toml"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba62675e8242a4c4e806d12f11d136e626e6c8361d6b829310732241652a178a"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "bat" name = "bat"
version = "0.24.0" version = "0.24.0"
@@ -1086,6 +1150,38 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d2c12f985c78475a6b8d629afd0c360260ef34cfef52efccdcfd31972f81c2e" checksum = "2d2c12f985c78475a6b8d629afd0c360260ef34cfef52efccdcfd31972f81c2e"
[[package]]
name = "camino"
version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3"
dependencies = [
"serde",
]
[[package]]
name = "cargo-platform"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e35af189006b9c0f00a064685c727031e3ed2d8020f7ba284d78cc2671bd36ea"
dependencies = [
"serde",
]
[[package]]
name = "cargo_metadata"
version = "0.19.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd5eb614ed4c27c5d706420e4320fbe3216ab31fa1c33cd8246ac36dae4479ba"
dependencies = [
"camino",
"cargo-platform",
"semver",
"serde",
"serde_json",
"thiserror 2.0.12",
]
[[package]] [[package]]
name = "cast" name = "cast"
version = "0.3.0" version = "0.3.0"
@@ -2100,6 +2196,15 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
[[package]]
name = "fs-err"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88a41f105fe1d5b6b34b2055e3dc59bb79b46b48b2040b9e6c7b4b5de097aa41"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "fs2" name = "fs2"
version = "0.4.3" version = "0.4.3"
@@ -2302,6 +2407,17 @@ dependencies = [
"regex-syntax 0.8.5", "regex-syntax 0.8.5",
] ]
[[package]]
name = "goblin"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b363a30c165f666402fe6a3024d3bec7ebc898f96a4a23bd1c99f8dbf3f4f47"
dependencies = [
"log",
"plain",
"scroll",
]
[[package]] [[package]]
name = "google-apis-common" name = "google-apis-common"
version = "7.0.0" version = "7.0.0"
@@ -2544,6 +2660,7 @@ dependencies = [
"thiserror 1.0.69", "thiserror 1.0.69",
"tokio", "tokio",
"tracing", "tracing",
"uniffi",
"url", "url",
] ]
@@ -4413,6 +4530,12 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plain"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6"
[[package]] [[package]]
name = "plist" name = "plist"
version = "1.7.0" version = "1.7.0"
@@ -5279,6 +5402,26 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scroll"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ab8598aa408498679922eff7fa985c25d58a90771bd6be794434c5277eab1a6"
dependencies = [
"scroll_derive",
]
[[package]]
name = "scroll_derive"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1783eabc414609e28a5ba76aee5ddd52199f7107a0b24c2e9746a1ecc34a683d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.99",
]
[[package]] [[package]]
name = "sct" name = "sct"
version = "0.7.1" version = "0.7.1"
@@ -5342,6 +5485,9 @@ name = "semver"
version = "1.0.26" version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde" name = "serde"
@@ -5579,6 +5725,12 @@ dependencies = [
"time", "time",
] ]
[[package]]
name = "siphasher"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d"
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.9" version = "0.4.9"
@@ -5631,6 +5783,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]] [[package]]
name = "std_prelude" name = "std_prelude"
version = "0.2.12" version = "0.2.12"
@@ -6464,6 +6622,128 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "uniffi"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dcd1d240101ba3b9d7532ae86d9cb64d9a7ff63e13a2b7b9e94a32a601d8233"
dependencies = [
"anyhow",
"camino",
"cargo_metadata",
"clap 4.5.31",
"uniffi_bindgen",
"uniffi_core",
"uniffi_macros",
"uniffi_pipeline",
]
[[package]]
name = "uniffi_bindgen"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d0525f06d749ea80d8049dc0bb038bb87941e3d909eefa76b6f0a5589b59ac5"
dependencies = [
"anyhow",
"askama",
"camino",
"cargo_metadata",
"fs-err",
"glob",
"goblin",
"heck 0.5.0",
"indexmap 2.7.1",
"once_cell",
"serde",
"tempfile",
"textwrap",
"toml 0.5.11",
"uniffi_internal_macros",
"uniffi_meta",
"uniffi_pipeline",
"uniffi_udl",
]
[[package]]
name = "uniffi_core"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3fa8eb4d825b4ed095cb13483cba6927c3002b9eb603cef9b7688758cc3772e"
dependencies = [
"anyhow",
"async-compat",
"bytes",
"once_cell",
"static_assertions",
]
[[package]]
name = "uniffi_internal_macros"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83b547d69d699e52f2129fde4b57ae0d00b5216e59ed5b56097c95c86ba06095"
dependencies = [
"anyhow",
"indexmap 2.7.1",
"proc-macro2",
"quote",
"syn 2.0.99",
]
[[package]]
name = "uniffi_macros"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00f1de72edc8cb9201c7d650e3678840d143e4499004571aac49e6cb1b17da43"
dependencies = [
"camino",
"fs-err",
"once_cell",
"proc-macro2",
"quote",
"serde",
"syn 2.0.99",
"toml 0.5.11",
"uniffi_meta",
]
[[package]]
name = "uniffi_meta"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acc9204632f6a555b2cba7c8852c5523bc1aa5f3eff605c64af5054ea28b72e"
dependencies = [
"anyhow",
"siphasher",
"uniffi_internal_macros",
"uniffi_pipeline",
]
[[package]]
name = "uniffi_pipeline"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54b5336a9a925b358183837d31541d12590b7fcec373256d3770de02dff24c69"
dependencies = [
"anyhow",
"heck 0.5.0",
"indexmap 2.7.1",
"tempfile",
"uniffi_internal_macros",
]
[[package]]
name = "uniffi_udl"
version = "0.29.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f95e73373d85f04736bc51997d3e6855721144ec4384cae9ca8513c80615e129"
dependencies = [
"anyhow",
"textwrap",
"uniffi_meta",
"weedle2",
]
[[package]] [[package]]
name = "unsafe-libyaml" name = "unsafe-libyaml"
version = "0.2.11" version = "0.2.11"
@@ -6757,6 +7037,15 @@ dependencies = [
"rustls-pki-types", "rustls-pki-types",
] ]
[[package]]
name = "weedle2"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "998d2c24ec099a87daf9467808859f9d82b61f1d9c9701251aea037f514eae0e"
dependencies = [
"nom",
]
[[package]] [[package]]
name = "weezl" name = "weezl"
version = "0.1.8" version = "0.1.8"

View File

@@ -0,0 +1,130 @@
import kotlinx.coroutines.runBlocking
import uniffi.goose_llm.*
fun main() = runBlocking {
val now = System.currentTimeMillis() / 1000
val msgs = listOf(
// 1) User sends a plain-text prompt
Message(
role = Role.USER,
created = now,
content = listOf(
MessageContent.Text(
TextContent("What is 7 x 6?")
)
)
),
// 2) Assistant makes a tool request (ToolReq) to calculate 7×6
Message(
role = Role.ASSISTANT,
created = now + 2,
content = listOf(
MessageContent.ToolReq(
ToolRequest(
id = "calc1",
toolCall = """
{
"status": "success",
"value": {
"name": "calculator_extension__toolname",
"arguments": {
"operation": "multiply",
"numbers": [7, 6]
},
"needsApproval": false
}
}
""".trimIndent()
)
)
)
),
// 3) User (on behalf of the tool) responds with the tool result (ToolResp)
Message(
role = Role.USER,
created = now + 3,
content = listOf(
MessageContent.ToolResp(
ToolResponse(
id = "calc1",
toolResult = """
{
"status": "success",
"value": [
{"type": "text", "text": "42"}
]
}
""".trimIndent()
)
)
)
),
)
printMessages(msgs)
println("---\n")
val sessionName = generateSessionName(msgs)
println("Session Name: $sessionName")
val tooltip = generateTooltip(msgs)
println("Tooltip: $tooltip")
// Completion
val provider = "databricks"
val modelName = "goose-gpt-4-1"
val modelConfig = ModelConfig(
modelName,
100000u, // UInt
0.1f, // Float
200 // Int
)
val calculatorTool = createToolConfig(
name = "calculator",
description = "Perform basic arithmetic operations",
inputSchema = """
{
"type": "object",
"required": ["operation", "numbers"],
"properties": {
"operation": {
"type": "string",
"enum": ["add", "subtract", "multiply", "divide"],
"description": "The arithmetic operation to perform"
},
"numbers": {
"type": "array",
"items": { "type": "number" },
"description": "List of numbers to operate on in order"
}
}
}
""".trimIndent(),
approvalMode = ToolApprovalMode.AUTO
)
val calculator_extension = ExtensionConfig(
name = "calculator_extension",
instructions = "This extension provides a calculator tool.",
tools = listOf(calculatorTool)
)
val extensions = listOf(calculator_extension)
val systemPreamble = "You are a helpful assistant."
val req = CompletionRequest(
provider,
modelConfig,
systemPreamble,
msgs,
extensions
)
val response = completion(req)
println("\nCompletion Response:")
println(response.message)
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,8 +7,11 @@ license.workspace = true
repository.workspace = true repository.workspace = true
description.workspace = true description.workspace = true
[lib]
crate-type = ["lib", "cdylib"]
name = "goose_llm"
[dependencies] [dependencies]
tokio = { version = "1.43", features = ["full"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
anyhow = "1.0" anyhow = "1.0"
@@ -36,6 +39,9 @@ regex = "1.11.1"
tracing = "0.1" tracing = "0.1"
smallvec = { version = "1.13", features = ["serde"] } smallvec = { version = "1.13", features = ["serde"] }
indoc = "1.0" indoc = "1.0"
# https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/fixtures/futures/Cargo.toml#L22
uniffi = { version = "0.29", features = ["tokio", "cli", "scaffolding-ffi-buffer-fns"] }
tokio = { version = "1.43", features = ["time", "sync"] }
[dev-dependencies] [dev-dependencies]
criterion = "0.5" criterion = "0.5"
@@ -43,7 +49,12 @@ tempfile = "3.15.0"
dotenv = "0.15" dotenv = "0.15"
lazy_static = "1.5" lazy_static = "1.5"
ctor = "0.2.7" ctor = "0.2.7"
tokio = { version = "1.43", features = ["full"] }
[[bin]]
# https://mozilla.github.io/uniffi-rs/latest/tutorial/foreign_language_bindings.html
name = "uniffi-bindgen"
path = "uniffi-bindgen.rs"
[[example]] [[example]]
name = "simple" name = "simple"

View File

@@ -1,4 +1,4 @@
### goose-llm ## goose-llm
This crate is meant to be used for foreign function interface (FFI). It's meant to be This crate is meant to be used for foreign function interface (FFI). It's meant to be
stateless and contain logic related to providers and prompts: stateless and contain logic related to providers and prompts:
@@ -12,3 +12,59 @@ Run:
cargo run -p goose-llm --example simple cargo run -p goose-llm --example simple
``` ```
## Kotlin bindings
Structure:
```
.
└── crates
└── goose-llm/...
└── target
└── debug/libgoose_llm.dylib
├── bindings
│ └── kotlin
│ ├── example
│ │ └── Usage.kt ← your demo app
│ └── uniffi
│ └── goose_llm
│ └── goose_llm.kt ← auto-generated bindings
```
Create Kotlin bindings:
```
# run from project root directory
cargo build -p goose-llm
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
```
Download jars in `bindings/kotlin/libs` directory (only need to do this once):
```
pushd bindings/kotlin/libs/
curl -O https://repo1.maven.org/maven2/org/jetbrains/kotlin/kotlin-stdlib/1.9.0/kotlin-stdlib-1.9.0.jar
curl -O https://repo1.maven.org/maven2/org/jetbrains/kotlinx/kotlinx-coroutines-core-jvm/1.7.3/kotlinx-coroutines-core-jvm-1.7.3.jar
curl -O https://repo1.maven.org/maven2/net/java/dev/jna/jna/5.13.0/jna-5.13.0.jar
popd
```
Compile & Run usage example from Kotlin -> Rust:
```
pushd bindings/kotlin/
kotlinc \
example/Usage.kt \
uniffi/goose_llm/goose_llm.kt \
-classpath "libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
-include-runtime \
-d example.jar
java \
-Djna.library.path=$HOME/Development/goose/target/debug \
-classpath "example.jar:libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
UsageKt
popd
```

View File

@@ -93,13 +93,13 @@ async fn main() -> Result<()> {
println!("\n---------------\n"); println!("\n---------------\n");
println!("User Input: {text}"); println!("User Input: {text}");
let messages = vec![Message::user().with_text(text)]; let messages = vec![Message::user().with_text(text)];
let completion_response: CompletionResponse = completion(CompletionRequest::new( let completion_response: CompletionResponse = completion(CompletionRequest {
provider, provider_name: provider.to_string(),
model_config.clone(), model_config: model_config.clone(),
system_preamble, system_preamble: system_preamble.to_string(),
&messages, messages: messages,
&extensions, extensions: extensions.clone(),
)) })
.await?; .await?;
// Print the response // Print the response
println!("\nCompletion Response:"); println!("\nCompletion Response:");

View File

@@ -17,32 +17,40 @@ use crate::{
}, },
}; };
#[uniffi::export]
pub fn print_messages(messages: Vec<Message>) {
for msg in messages {
println!("[{:?} @ {}] {:?}", msg.role, msg.created, msg.content);
}
}
/// Public API for the Goose LLM completion function /// Public API for the Goose LLM completion function
pub async fn completion(req: CompletionRequest<'_>) -> Result<CompletionResponse, CompletionError> { #[uniffi::export(async_runtime = "tokio")]
pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, CompletionError> {
let start_total = Instant::now(); let start_total = Instant::now();
let provider = create(req.provider_name, req.model_config) let provider = create(&req.provider_name, req.model_config)
.map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?; .map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?;
let system_prompt = construct_system_prompt(req.system_preamble, req.extensions)?; let system_prompt = construct_system_prompt(&req.system_preamble, &req.extensions)?;
let tools = collect_prefixed_tools(req.extensions); let tools = collect_prefixed_tools(&req.extensions);
// Call the LLM provider // Call the LLM provider
let start_provider = Instant::now(); let start_provider = Instant::now();
let mut response = provider let mut response = provider
.complete(&system_prompt, req.messages, &tools) .complete(&system_prompt, &req.messages, &tools)
.await?; .await?;
let provider_elapsed_ms = start_provider.elapsed().as_millis(); let provider_elapsed_sec = start_provider.elapsed().as_secs_f32();
let usage_tokens = response.usage.total_tokens; let usage_tokens = response.usage.total_tokens;
let tool_configs = collect_prefixed_tool_configs(req.extensions); let tool_configs = collect_prefixed_tool_configs(&req.extensions);
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?; update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?;
Ok(CompletionResponse::new( Ok(CompletionResponse::new(
response.message, response.message,
response.model, response.model,
response.usage, response.usage,
calculate_runtime_metrics(start_total, provider_elapsed_ms, usage_tokens), calculate_runtime_metrics(start_total, provider_elapsed_sec, usage_tokens),
)) ))
} }
@@ -81,8 +89,8 @@ pub fn update_needs_approval_for_tool_calls(
tool_configs: &HashMap<String, ToolConfig>, tool_configs: &HashMap<String, ToolConfig>,
) -> Result<(), CompletionError> { ) -> Result<(), CompletionError> {
for content in &mut message.content.iter_mut() { for content in &mut message.content.iter_mut() {
if let MessageContent::ToolRequest(req) = content { if let MessageContent::ToolReq(req) = content {
if let Ok(call) = &mut req.tool_call { if let Ok(call) = &mut req.tool_call.0 {
// Provide a clear error message when the tool config is missing // Provide a clear error message when the tool config is missing
let config = tool_configs.get(&call.name).ok_or_else(|| { let config = tool_configs.get(&call.name).ok_or_else(|| {
CompletionError::ToolNotFound(format!( CompletionError::ToolNotFound(format!(
@@ -117,16 +125,16 @@ fn collect_prefixed_tool_configs(extensions: &[ExtensionConfig]) -> HashMap<Stri
/// Compute runtime metrics for the request. /// Compute runtime metrics for the request.
fn calculate_runtime_metrics( fn calculate_runtime_metrics(
total_start: Instant, total_start: Instant,
provider_elapsed_ms: u128, provider_elapsed_sec: f32,
token_count: Option<i32>, token_count: Option<i32>,
) -> RuntimeMetrics { ) -> RuntimeMetrics {
let total_ms = total_start.elapsed().as_millis(); let total_ms = total_start.elapsed().as_secs_f32();
let tokens_per_sec = token_count.and_then(|toks| { let tokens_per_sec = token_count.and_then(|toks| {
if provider_elapsed_ms > 0 { if provider_elapsed_sec > 0.0 {
Some(toks as f64 / (provider_elapsed_ms as f64 / 1_000.0)) Some(toks as f64 / (provider_elapsed_sec as f64))
} else { } else {
None None
} }
}); });
RuntimeMetrics::new(total_ms, provider_elapsed_ms, tokens_per_sec) RuntimeMetrics::new(total_ms, provider_elapsed_sec, tokens_per_sec)
} }

View File

@@ -49,6 +49,7 @@ fn build_system_prompt() -> String {
} }
/// Generates a short (≤4 words) session name /// Generates a short (≤4 words) session name
#[uniffi::export(async_runtime = "tokio")]
pub async fn generate_session_name(messages: &[Message]) -> Result<String, ProviderError> { pub async fn generate_session_name(messages: &[Message]) -> Result<String, ProviderError> {
// Collect up to the first 3 user messages (truncated to 300 chars each) // Collect up to the first 3 user messages (truncated to 300 chars each)
let context: Vec<String> = messages let context: Vec<String> = messages

View File

@@ -53,6 +53,7 @@ fn build_system_prompt() -> String {
/// Generates a tooltip summarizing the last two messages in the session, /// Generates a tooltip summarizing the last two messages in the session,
/// including any tool calls or results. /// including any tool calls or results.
#[uniffi::export(async_runtime = "tokio")]
pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderError> { pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderError> {
// Need at least two messages to summarize // Need at least two messages to summarize
if messages.len() < 2 { if messages.len() < 2 {
@@ -72,17 +73,17 @@ pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderEr
parts.push(txt.to_string()); parts.push(txt.to_string());
} }
} }
MessageContent::ToolRequest(req) => { MessageContent::ToolReq(req) => {
if let Ok(tool_call) = &req.tool_call { if let Ok(tool_call) = &req.tool_call.0 {
parts.push(format!( parts.push(format!(
"called tool '{}' with args {}", "called tool '{}' with args {}",
tool_call.name, tool_call.arguments tool_call.name, tool_call.arguments
)); ));
} else if let Err(e) = &req.tool_call { } else if let Err(e) = &req.tool_call.0 {
parts.push(format!("tool request error: {}", e)); parts.push(format!("tool request error: {}", e));
} }
} }
MessageContent::ToolResponse(resp) => match &resp.tool_result { MessageContent::ToolResp(resp) => match &resp.tool_result.0 {
Ok(contents) => { Ok(contents) => {
let results: Vec<String> = contents let results: Vec<String> = contents
.iter() .iter()

View File

@@ -1,3 +1,5 @@
uniffi::setup_scaffolding!();
mod completion; mod completion;
pub mod extractors; pub mod extractors;
pub mod message; pub mod message;

View File

@@ -1,531 +0,0 @@
use std::{collections::HashSet, iter::FromIterator, ops::Deref};
/// Messages which represent the content sent back and forth to LLM provider
///
/// We use these messages in the agent code, and interfaces which interact with
/// the agent. That let's us reuse message histories across different interfaces.
///
/// The content of the messages uses MCP types to avoid additional conversions
/// when interacting with MCP servers.
use chrono::Utc;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use crate::types::core::{Content, ImageContent, Role, TextContent, ToolCall, ToolResult};
mod tool_result_serde;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolRequest {
pub id: String,
#[serde(with = "tool_result_serde")]
pub tool_call: ToolResult<ToolCall>,
}
impl ToolRequest {
pub fn to_readable_string(&self) -> String {
match &self.tool_call {
Ok(tool_call) => {
format!(
"Tool: {}, Args: {}",
tool_call.name,
serde_json::to_string_pretty(&tool_call.arguments)
.unwrap_or_else(|_| "<<invalid json>>".to_string())
)
}
Err(e) => format!("Invalid tool call: {}", e),
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResponse {
pub id: String,
#[serde(with = "tool_result_serde")]
pub tool_result: ToolResult<Vec<Content>>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ThinkingContent {
pub thinking: String,
pub signature: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RedactedThinkingContent {
pub data: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
/// Content passed inside a message, which can be both simple content and tool content
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MessageContent {
Text(TextContent),
Image(ImageContent),
ToolRequest(ToolRequest),
ToolResponse(ToolResponse),
Thinking(ThinkingContent),
RedactedThinking(RedactedThinkingContent),
}
impl MessageContent {
pub fn text<S: Into<String>>(text: S) -> Self {
MessageContent::Text(TextContent { text: text.into() })
}
pub fn image<S: Into<String>, T: Into<String>>(data: S, mime_type: T) -> Self {
MessageContent::Image(ImageContent {
data: data.into(),
mime_type: mime_type.into(),
})
}
pub fn tool_request<S: Into<String>>(id: S, tool_call: ToolResult<ToolCall>) -> Self {
MessageContent::ToolRequest(ToolRequest {
id: id.into(),
tool_call,
})
}
pub fn tool_response<S: Into<String>>(id: S, tool_result: ToolResult<Vec<Content>>) -> Self {
MessageContent::ToolResponse(ToolResponse {
id: id.into(),
tool_result,
})
}
pub fn thinking<S1: Into<String>, S2: Into<String>>(thinking: S1, signature: S2) -> Self {
MessageContent::Thinking(ThinkingContent {
thinking: thinking.into(),
signature: signature.into(),
})
}
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
}
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
if let MessageContent::ToolRequest(ref tool_request) = self {
Some(tool_request)
} else {
None
}
}
pub fn as_tool_response(&self) -> Option<&ToolResponse> {
if let MessageContent::ToolResponse(ref tool_response) = self {
Some(tool_response)
} else {
None
}
}
pub fn as_tool_response_text(&self) -> Option<String> {
if let Some(tool_response) = self.as_tool_response() {
if let Ok(contents) = &tool_response.tool_result {
let texts: Vec<String> = contents
.iter()
.filter_map(|content| content.as_text().map(String::from))
.collect();
if !texts.is_empty() {
return Some(texts.join("\n"));
}
}
}
None
}
pub fn as_tool_request_id(&self) -> Option<&str> {
if let Self::ToolRequest(r) = self {
Some(&r.id)
} else {
None
}
}
pub fn as_tool_response_id(&self) -> Option<&str> {
if let Self::ToolResponse(r) = self {
Some(&r.id)
} else {
None
}
}
/// Get the text content if this is a TextContent variant
pub fn as_text(&self) -> Option<&str> {
match self {
MessageContent::Text(text) => Some(&text.text),
_ => None,
}
}
/// Get the thinking content if this is a ThinkingContent variant
pub fn as_thinking(&self) -> Option<&ThinkingContent> {
match self {
MessageContent::Thinking(thinking) => Some(thinking),
_ => None,
}
}
/// Get the redacted thinking content if this is a RedactedThinkingContent variant
pub fn as_redacted_thinking(&self) -> Option<&RedactedThinkingContent> {
match self {
MessageContent::RedactedThinking(redacted) => Some(redacted),
_ => None,
}
}
pub fn is_text(&self) -> bool {
matches!(self, Self::Text(_))
}
pub fn is_image(&self) -> bool {
matches!(self, Self::Image(_))
}
pub fn is_tool_request(&self) -> bool {
matches!(self, Self::ToolRequest(_))
}
pub fn is_tool_response(&self) -> bool {
matches!(self, Self::ToolResponse(_))
}
}
impl From<Content> for MessageContent {
fn from(content: Content) -> Self {
match content {
Content::Text(text) => MessageContent::Text(text),
Content::Image(image) => MessageContent::Image(image),
}
}
}
// ────────────────────────────────────────────────────────────────────────────
// 2. Contents a new-type wrapper around SmallVec
// ────────────────────────────────────────────────────────────────────────────
/// Holds the heterogeneous fragments that make up one chat message.
///
/// * Up to two items are stored inline on the stack.
/// * Falls back to a heap allocation only when necessary.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct Contents(SmallVec<[MessageContent; 2]>);
impl Contents {
/*----------------------------------------------------------
* 1-line ergonomic helpers
*---------------------------------------------------------*/
pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, MessageContent> {
self.0.iter_mut()
}
pub fn push(&mut self, item: impl Into<MessageContent>) {
self.0.push(item.into());
}
pub fn texts(&self) -> impl Iterator<Item = &str> {
self.0.iter().filter_map(|c| c.as_text())
}
pub fn concat_text_str(&self) -> String {
self.texts().collect::<Vec<_>>().join("\n")
}
/// Returns `true` if *any* item satisfies the predicate.
pub fn any_is<P>(&self, pred: P) -> bool
where
P: FnMut(&MessageContent) -> bool,
{
self.iter().any(pred)
}
/// Returns `true` if *every* item satisfies the predicate.
pub fn all_are<P>(&self, pred: P) -> bool
where
P: FnMut(&MessageContent) -> bool,
{
self.iter().all(pred)
}
}
impl From<Vec<MessageContent>> for Contents {
fn from(v: Vec<MessageContent>) -> Self {
Contents(SmallVec::from_vec(v))
}
}
impl FromIterator<MessageContent> for Contents {
fn from_iter<I: IntoIterator<Item = MessageContent>>(iter: I) -> Self {
Contents(SmallVec::from_iter(iter))
}
}
/*--------------------------------------------------------------
* Allow &message.content to behave like a slice of fragments.
*-------------------------------------------------------------*/
impl Deref for Contents {
type Target = [MessageContent];
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
pub struct Message {
pub role: Role,
pub created: i64,
pub content: Contents,
}
impl Message {
pub fn new(role: Role) -> Self {
Self {
role,
created: Utc::now().timestamp_millis(),
content: Contents::default(),
}
}
/// Create a new user message with the current timestamp
pub fn user() -> Self {
Self::new(Role::User)
}
/// Create a new assistant message with the current timestamp
pub fn assistant() -> Self {
Self::new(Role::Assistant)
}
/// Add any item that implements Into<MessageContent> to the message
pub fn with_content(mut self, item: impl Into<MessageContent>) -> Self {
self.content.push(item);
self
}
/// Add text content to the message
pub fn with_text<S: Into<String>>(self, text: S) -> Self {
self.with_content(MessageContent::text(text))
}
/// Add image content to the message
pub fn with_image<S: Into<String>, T: Into<String>>(self, data: S, mime_type: T) -> Self {
self.with_content(MessageContent::image(data, mime_type))
}
/// Add a tool request to the message
pub fn with_tool_request<S: Into<String>>(
self,
id: S,
tool_call: ToolResult<ToolCall>,
) -> Self {
self.with_content(MessageContent::tool_request(id, tool_call))
}
/// Add a tool response to the message
pub fn with_tool_response<S: Into<String>>(
self,
id: S,
result: ToolResult<Vec<Content>>,
) -> Self {
self.with_content(MessageContent::tool_response(id, result))
}
/// Add thinking content to the message
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
self,
thinking: S1,
signature: S2,
) -> Self {
self.with_content(MessageContent::thinking(thinking, signature))
}
/// Add redacted thinking content to the message
pub fn with_redacted_thinking<S: Into<String>>(self, data: S) -> Self {
self.with_content(MessageContent::redacted_thinking(data))
}
/// Check if the message is a tool call
pub fn contains_tool_call(&self) -> bool {
self.content.any_is(MessageContent::is_tool_request)
}
/// Check if the message is a tool response
pub fn contains_tool_response(&self) -> bool {
self.content.any_is(MessageContent::is_tool_response)
}
/// Check if the message contains only text content
pub fn has_only_text_content(&self) -> bool {
self.content.all_are(MessageContent::is_text)
}
/// Retrieves all tool `id` from ToolRequest messages
pub fn tool_request_ids(&self) -> HashSet<&str> {
self.content
.iter()
.filter_map(MessageContent::as_tool_request_id)
.collect()
}
/// Retrieves all tool `id` from ToolResponse messages
pub fn tool_response_ids(&self) -> HashSet<&str> {
self.content
.iter()
.filter_map(MessageContent::as_tool_response_id)
.collect()
}
/// Retrieves all tool `id` from the message
pub fn tool_ids(&self) -> HashSet<&str> {
self.tool_request_ids()
.into_iter()
.chain(self.tool_response_ids())
.collect()
}
}
#[cfg(test)]
mod tests {
use serde_json::{json, Value};
use super::*;
use crate::types::core::ToolError;
#[test]
fn test_message_serialization() {
let message = Message::assistant()
.with_text("Hello, I'll help you with that.")
.with_tool_request(
"tool123",
Ok(ToolCall::new("test_tool", json!({"param": "value"}))),
);
let json_str = serde_json::to_string_pretty(&message).unwrap();
println!("Serialized message: {}", json_str);
// Parse back to Value to check structure
let value: Value = serde_json::from_str(&json_str).unwrap();
// Check top-level fields
assert_eq!(value["role"], "assistant");
assert!(value["created"].is_i64());
assert!(value["content"].is_array());
// Check content items
let content = &value["content"];
// First item should be text
assert_eq!(content[0]["type"], "text");
assert_eq!(content[0]["text"], "Hello, I'll help you with that.");
// Second item should be toolRequest
assert_eq!(content[1]["type"], "toolRequest");
assert_eq!(content[1]["id"], "tool123");
// Check tool_call serialization
assert_eq!(content[1]["toolCall"]["status"], "success");
assert_eq!(content[1]["toolCall"]["value"]["name"], "test_tool");
assert_eq!(
content[1]["toolCall"]["value"]["arguments"]["param"],
"value"
);
}
#[test]
fn test_error_serialization() {
let message = Message::assistant().with_tool_request(
"tool123",
Err(ToolError::ExecutionError(
"Something went wrong".to_string(),
)),
);
let json_str = serde_json::to_string_pretty(&message).unwrap();
println!("Serialized error: {}", json_str);
// Parse back to Value to check structure
let value: Value = serde_json::from_str(&json_str).unwrap();
// Check tool_call serialization with error
let tool_call = &value["content"][0]["toolCall"];
assert_eq!(tool_call["status"], "error");
assert_eq!(tool_call["error"], "Execution failed: Something went wrong");
}
#[test]
fn test_deserialization() {
// Create a JSON string with our new format
let json_str = r#"{
"role": "assistant",
"created": 1740171566,
"content": [
{
"type": "text",
"text": "I'll help you with that."
},
{
"type": "toolRequest",
"id": "tool123",
"toolCall": {
"status": "success",
"value": {
"name": "test_tool",
"arguments": {"param": "value"},
"needsApproval": false
}
}
}
]
}"#;
let message: Message = serde_json::from_str(json_str).unwrap();
assert_eq!(message.role, Role::Assistant);
assert_eq!(message.created, 1740171566);
assert_eq!(message.content.len(), 2);
// Check first content item
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "I'll help you with that.");
} else {
panic!("Expected Text content");
}
// Check second content item
if let MessageContent::ToolRequest(req) = &message.content[1] {
assert_eq!(req.id, "tool123");
if let Ok(tool_call) = &req.tool_call {
assert_eq!(tool_call.name, "test_tool");
assert_eq!(tool_call.arguments, json!({"param": "value"}));
} else {
panic!("Expected successful tool call");
}
} else {
panic!("Expected ToolRequest content");
}
}
#[test]
fn test_message_with_text() {
let message = Message::user().with_text("Hello");
assert_eq!(message.content.concat_text_str(), "Hello");
}
#[test]
fn test_message_with_tool_request() {
let tool_call = Ok(ToolCall::new("test_tool", json!({})));
let message = Message::assistant().with_tool_request("req1", tool_call);
assert!(message.contains_tool_call());
assert!(!message.contains_tool_response());
let ids = message.tool_ids();
assert_eq!(ids.len(), 1);
assert!(ids.contains("req1"));
}
}

View File

@@ -0,0 +1,84 @@
use std::{iter::FromIterator, ops::Deref};
use crate::message::MessageContent;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
/// Holds the heterogeneous fragments that make up one chat message.
///
/// * Up to two items are stored inline on the stack.
/// * Falls back to a heap allocation only when necessary.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(transparent)]
pub struct Contents(SmallVec<[MessageContent; 2]>);
impl Contents {
/*----------------------------------------------------------
* 1-line ergonomic helpers
*---------------------------------------------------------*/
pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, MessageContent> {
self.0.iter_mut()
}
pub fn push(&mut self, item: impl Into<MessageContent>) {
self.0.push(item.into());
}
pub fn texts(&self) -> impl Iterator<Item = &str> {
self.0.iter().filter_map(|c| c.as_text())
}
pub fn concat_text_str(&self) -> String {
self.texts().collect::<Vec<_>>().join("\n")
}
/// Returns `true` if *any* item satisfies the predicate.
pub fn any_is<P>(&self, pred: P) -> bool
where
P: FnMut(&MessageContent) -> bool,
{
self.iter().any(pred)
}
/// Returns `true` if *every* item satisfies the predicate.
pub fn all_are<P>(&self, pred: P) -> bool
where
P: FnMut(&MessageContent) -> bool,
{
self.iter().all(pred)
}
}
impl From<Vec<MessageContent>> for Contents {
fn from(v: Vec<MessageContent>) -> Self {
Contents(SmallVec::from_vec(v))
}
}
impl FromIterator<MessageContent> for Contents {
fn from_iter<I: IntoIterator<Item = MessageContent>>(iter: I) -> Self {
Contents(SmallVec::from_iter(iter))
}
}
/*--------------------------------------------------------------
* Allow &message.content to behave like a slice of fragments.
*-------------------------------------------------------------*/
impl Deref for Contents {
type Target = [MessageContent];
fn deref(&self) -> &Self::Target {
&self.0
}
}
// — Register the contents type with UniFFI, converting to/from Vec<MessageContent> —
// We need to do this because UniFFIs FFI layer supports only primitive buffers (here Vec<u8>),
uniffi::custom_type!(Contents, Vec<MessageContent>, {
lower: |contents: &Contents| {
contents.0.to_vec()
},
try_lift: |contents: Vec<MessageContent>| {
Ok(Contents::from(contents))
},
});

View File

@@ -0,0 +1,240 @@
use serde::{Deserialize, Serialize};
use serde_json;
use crate::message::tool_result_serde;
use crate::types::core::{Content, ImageContent, TextContent, ToolCall, ToolResult};
// — Newtype wrappers (local structs) so we satisfy Rusts orphan rules —
// We need these because we cant implement UniFFIs FfiConverter directly on a type alias.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolRequestToolCall(#[serde(with = "tool_result_serde")] pub ToolResult<ToolCall>);
impl ToolRequestToolCall {
pub fn as_result(&self) -> &ToolResult<ToolCall> {
&self.0
}
}
impl std::ops::Deref for ToolRequestToolCall {
type Target = ToolResult<ToolCall>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Result<ToolCall, crate::types::core::ToolError>> for ToolRequestToolCall {
fn from(res: Result<ToolCall, crate::types::core::ToolError>) -> Self {
ToolRequestToolCall(res)
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolResponseToolResult(
#[serde(with = "tool_result_serde")] pub ToolResult<Vec<Content>>,
);
impl ToolResponseToolResult {
pub fn as_result(&self) -> &ToolResult<Vec<Content>> {
&self.0
}
}
impl std::ops::Deref for ToolResponseToolResult {
type Target = ToolResult<Vec<Content>>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Result<Vec<Content>, crate::types::core::ToolError>> for ToolResponseToolResult {
fn from(res: Result<Vec<Content>, crate::types::core::ToolError>) -> Self {
ToolResponseToolResult(res)
}
}
// — Register the newtypes with UniFFI, converting via JSON strings —
// UniFFIs FFI layer supports only primitive buffers (here String), so we JSON-serialize
// through our `tool_result_serde` to preserve the same success/error schema on both sides.
uniffi::custom_type!(ToolRequestToolCall, String, {
lower: |obj| {
serde_json::to_string(&obj.0).unwrap()
},
try_lift: |val| {
Ok(serde_json::from_str(&val).unwrap() )
},
});
uniffi::custom_type!(ToolResponseToolResult, String, {
lower: |obj| {
serde_json::to_string(&obj.0).unwrap()
},
try_lift: |val| {
Ok(serde_json::from_str(&val).unwrap() )
},
});
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ToolRequest {
pub id: String,
pub tool_call: ToolRequestToolCall,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")]
pub struct ToolResponse {
pub id: String,
pub tool_result: ToolResponseToolResult,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct ThinkingContent {
pub thinking: String,
pub signature: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
pub struct RedactedThinkingContent {
pub data: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
/// Content passed inside a message, which can be both simple content and tool content
#[serde(tag = "type", rename_all = "camelCase")]
pub enum MessageContent {
Text(TextContent),
Image(ImageContent),
ToolReq(ToolRequest),
ToolResp(ToolResponse),
Thinking(ThinkingContent),
RedactedThinking(RedactedThinkingContent),
}
impl MessageContent {
pub fn text<S: Into<String>>(text: S) -> Self {
MessageContent::Text(TextContent { text: text.into() })
}
pub fn image<S: Into<String>, T: Into<String>>(data: S, mime_type: T) -> Self {
MessageContent::Image(ImageContent {
data: data.into(),
mime_type: mime_type.into(),
})
}
pub fn tool_request<S: Into<String>>(id: S, tool_call: ToolRequestToolCall) -> Self {
MessageContent::ToolReq(ToolRequest {
id: id.into(),
tool_call,
})
}
pub fn tool_response<S: Into<String>>(id: S, tool_result: ToolResponseToolResult) -> Self {
MessageContent::ToolResp(ToolResponse {
id: id.into(),
tool_result,
})
}
pub fn thinking<S1: Into<String>, S2: Into<String>>(thinking: S1, signature: S2) -> Self {
MessageContent::Thinking(ThinkingContent {
thinking: thinking.into(),
signature: signature.into(),
})
}
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
}
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
if let MessageContent::ToolReq(ref tool_request) = self {
Some(tool_request)
} else {
None
}
}
pub fn as_tool_response(&self) -> Option<&ToolResponse> {
if let MessageContent::ToolResp(ref tool_response) = self {
Some(tool_response)
} else {
None
}
}
pub fn as_tool_response_text(&self) -> Option<String> {
if let Some(tool_response) = self.as_tool_response() {
if let Ok(contents) = &tool_response.tool_result.0 {
let texts: Vec<String> = contents
.iter()
.filter_map(|content| content.as_text().map(String::from))
.collect();
if !texts.is_empty() {
return Some(texts.join("\n"));
}
}
}
None
}
pub fn as_tool_request_id(&self) -> Option<&str> {
if let Self::ToolReq(r) = self {
Some(&r.id)
} else {
None
}
}
pub fn as_tool_response_id(&self) -> Option<&str> {
if let Self::ToolResp(r) = self {
Some(&r.id)
} else {
None
}
}
/// Get the text content if this is a TextContent variant
pub fn as_text(&self) -> Option<&str> {
match self {
MessageContent::Text(text) => Some(&text.text),
_ => None,
}
}
/// Get the thinking content if this is a ThinkingContent variant
pub fn as_thinking(&self) -> Option<&ThinkingContent> {
match self {
MessageContent::Thinking(thinking) => Some(thinking),
_ => None,
}
}
/// Get the redacted thinking content if this is a RedactedThinkingContent variant
pub fn as_redacted_thinking(&self) -> Option<&RedactedThinkingContent> {
match self {
MessageContent::RedactedThinking(redacted) => Some(redacted),
_ => None,
}
}
pub fn is_text(&self) -> bool {
matches!(self, Self::Text(_))
}
pub fn is_image(&self) -> bool {
matches!(self, Self::Image(_))
}
pub fn is_tool_request(&self) -> bool {
matches!(self, Self::ToolReq(_))
}
pub fn is_tool_response(&self) -> bool {
matches!(self, Self::ToolResp(_))
}
}
impl From<Content> for MessageContent {
fn from(content: Content) -> Self {
match content {
Content::Text(text) => MessageContent::Text(text),
Content::Image(image) => MessageContent::Image(image),
}
}
}

View File

@@ -0,0 +1,284 @@
//! Messages which represent the content sent back and forth to LLM provider
//!
//! We use these messages in the agent code, and interfaces which interact with
//! the agent. That let's us reuse message histories across different interfaces.
//!
//! The content of the messages uses MCP types to avoid additional conversions
//! when interacting with MCP servers.
mod contents;
mod message_content;
mod tool_result_serde;
pub use contents::Contents;
pub use message_content::{
MessageContent, RedactedThinkingContent, ThinkingContent, ToolRequest, ToolRequestToolCall,
ToolResponse, ToolResponseToolResult,
};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use crate::types::core::Role;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
pub struct Message {
pub role: Role,
pub created: i64,
pub content: Contents,
}
impl Message {
pub fn new(role: Role) -> Self {
Self {
role,
created: Utc::now().timestamp_millis(),
content: Contents::default(),
}
}
/// Create a new user message with the current timestamp
pub fn user() -> Self {
Self::new(Role::User)
}
/// Create a new assistant message with the current timestamp
pub fn assistant() -> Self {
Self::new(Role::Assistant)
}
/// Add any item that implements Into<MessageContent> to the message
pub fn with_content(mut self, item: impl Into<MessageContent>) -> Self {
self.content.push(item);
self
}
/// Add text content to the message
pub fn with_text<S: Into<String>>(self, text: S) -> Self {
self.with_content(MessageContent::text(text))
}
/// Add image content to the message
pub fn with_image<S: Into<String>, T: Into<String>>(self, data: S, mime_type: T) -> Self {
self.with_content(MessageContent::image(data, mime_type))
}
/// Add a tool request to the message
pub fn with_tool_request<S: Into<String>, T: Into<ToolRequestToolCall>>(
self,
id: S,
tool_call: T,
) -> Self {
self.with_content(MessageContent::tool_request(id, tool_call.into()))
}
/// Add a tool response to the message
pub fn with_tool_response<S: Into<String>>(
self,
id: S,
result: ToolResponseToolResult,
) -> Self {
self.with_content(MessageContent::tool_response(id, result))
}
/// Add thinking content to the message
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
self,
thinking: S1,
signature: S2,
) -> Self {
self.with_content(MessageContent::thinking(thinking, signature))
}
/// Add redacted thinking content to the message
pub fn with_redacted_thinking<S: Into<String>>(self, data: S) -> Self {
self.with_content(MessageContent::redacted_thinking(data))
}
/// Check if the message is a tool call
pub fn contains_tool_call(&self) -> bool {
self.content.any_is(MessageContent::is_tool_request)
}
/// Check if the message is a tool response
pub fn contains_tool_response(&self) -> bool {
self.content.any_is(MessageContent::is_tool_response)
}
/// Check if the message contains only text content
pub fn has_only_text_content(&self) -> bool {
self.content.all_are(MessageContent::is_text)
}
/// Retrieves all tool `id` from ToolRequest messages
pub fn tool_request_ids(&self) -> HashSet<&str> {
self.content
.iter()
.filter_map(MessageContent::as_tool_request_id)
.collect()
}
/// Retrieves all tool `id` from ToolResponse messages
pub fn tool_response_ids(&self) -> HashSet<&str> {
self.content
.iter()
.filter_map(MessageContent::as_tool_response_id)
.collect()
}
/// Retrieves all tool `id` from the message
pub fn tool_ids(&self) -> HashSet<&str> {
self.tool_request_ids()
.into_iter()
.chain(self.tool_response_ids())
.collect()
}
}
#[cfg(test)]
mod tests {
use serde_json::{json, Value};
use super::*;
use crate::types::core::{ToolCall, ToolError};
#[test]
fn test_message_serialization() {
let message = Message::assistant()
.with_text("Hello, I'll help you with that.")
.with_tool_request(
"tool123",
Ok(ToolCall::new("test_tool", json!({"param": "value"})).into()),
);
let json_str = serde_json::to_string_pretty(&message).unwrap();
println!("Serialized message: {}", json_str);
// Parse back to Value to check structure
let value: Value = serde_json::from_str(&json_str).unwrap();
println!(
"Read back serialized message: {}",
serde_json::to_string_pretty(&value).unwrap()
);
// Check top-level fields
assert_eq!(value["role"], "assistant");
assert!(value["created"].is_i64());
assert!(value["content"].is_array());
// Check content items
let content = &value["content"];
// First item should be text
assert_eq!(content[0]["type"], "text");
assert_eq!(content[0]["text"], "Hello, I'll help you with that.");
// Second item should be toolRequest
assert_eq!(content[1]["type"], "toolReq");
assert_eq!(content[1]["id"], "tool123");
// Check tool_call serialization
assert_eq!(content[1]["toolCall"]["status"], "success");
assert_eq!(content[1]["toolCall"]["value"]["name"], "test_tool");
assert_eq!(
content[1]["toolCall"]["value"]["arguments"]["param"],
"value"
);
}
#[test]
fn test_error_serialization() {
let message = Message::assistant().with_tool_request(
"tool123",
Err(ToolError::ExecutionError(
"Something went wrong".to_string(),
)),
);
let json_str = serde_json::to_string_pretty(&message).unwrap();
println!("Serialized error: {}", json_str);
// Parse back to Value to check structure
let value: Value = serde_json::from_str(&json_str).unwrap();
// Check tool_call serialization with error
let tool_call = &value["content"][0]["toolCall"];
assert_eq!(tool_call["status"], "error");
assert_eq!(tool_call["error"], "Execution failed: Something went wrong");
}
#[test]
fn test_deserialization() {
// Create a JSON string with our new format
let json_str = r#"{
"role": "assistant",
"created": 1740171566,
"content": [
{
"type": "text",
"text": "I'll help you with that."
},
{
"type": "toolReq",
"id": "tool123",
"toolCall": {
"status": "success",
"value": {
"name": "test_tool",
"arguments": {"param": "value"},
"needsApproval": false
}
}
}
]
}"#;
let message: Message = serde_json::from_str(json_str).unwrap();
assert_eq!(message.role, Role::Assistant);
assert_eq!(message.created, 1740171566);
assert_eq!(message.content.len(), 2);
// Check first content item
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "I'll help you with that.");
} else {
panic!("Expected Text content");
}
// Check second content item
if let MessageContent::ToolReq(req) = &message.content[1] {
assert_eq!(req.id, "tool123");
if let Ok(tool_call) = req.tool_call.as_result() {
assert_eq!(tool_call.name, "test_tool");
assert_eq!(tool_call.arguments, json!({"param": "value"}));
} else {
panic!("Expected successful tool call");
}
} else {
panic!("Expected ToolRequest content");
}
}
#[test]
fn test_message_with_text() {
let message = Message::user().with_text("Hello");
assert_eq!(message.content.concat_text_str(), "Hello");
}
#[test]
fn test_message_with_tool_request() {
let tool_call = Ok(ToolCall::new("test_tool", json!({})));
let message = Message::assistant().with_tool_request("req1", tool_call);
assert!(message.contains_tool_call());
assert!(!message.contains_tool_response());
let ids = message.tool_ids();
assert_eq!(ids.len(), 1);
assert!(ids.contains("req1"));
}
}

View File

@@ -1,14 +1,14 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
const DEFAULT_CONTEXT_LIMIT: usize = 128_000; const DEFAULT_CONTEXT_LIMIT: u32 = 128_000;
/// Configuration for model-specific settings and limits /// Configuration for model-specific settings and limits
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct ModelConfig { pub struct ModelConfig {
/// The name of the model to use /// The name of the model to use
pub model_name: String, pub model_name: String,
/// Optional explicit context limit that overrides any defaults /// Optional explicit context limit that overrides any defaults
pub context_limit: Option<usize>, pub context_limit: Option<u32>,
/// Optional temperature setting (0.0 - 1.0) /// Optional temperature setting (0.0 - 1.0)
pub temperature: Option<f32>, pub temperature: Option<f32>,
/// Optional maximum tokens to generate /// Optional maximum tokens to generate
@@ -34,7 +34,7 @@ impl ModelConfig {
} }
/// Get model-specific context limit based on model name /// Get model-specific context limit based on model name
fn get_model_specific_limit(model_name: &str) -> Option<usize> { fn get_model_specific_limit(model_name: &str) -> Option<u32> {
// Implement some sensible defaults // Implement some sensible defaults
match model_name { match model_name {
// OpenAI models, https://platform.openai.com/docs/models#models-overview // OpenAI models, https://platform.openai.com/docs/models#models-overview
@@ -52,7 +52,7 @@ impl ModelConfig {
} }
/// Set an explicit context limit /// Set an explicit context limit
pub fn with_context_limit(mut self, limit: Option<usize>) -> Self { pub fn with_context_limit(mut self, limit: Option<u32>) -> Self {
// Default is None and therefore DEFAULT_CONTEXT_LIMIT, only set // Default is None and therefore DEFAULT_CONTEXT_LIMIT, only set
// if input is Some to allow passing through with_context_limit in // if input is Some to allow passing through with_context_limit in
// configuration cases // configuration cases
@@ -76,7 +76,7 @@ impl ModelConfig {
/// Get the context_limit for the current model /// Get the context_limit for the current model
/// If none are defined, use the DEFAULT_CONTEXT_LIMIT /// If none are defined, use the DEFAULT_CONTEXT_LIMIT
pub fn context_limit(&self) -> usize { pub fn context_limit(&self) -> u32 {
self.context_limit.unwrap_or(DEFAULT_CONTEXT_LIMIT) self.context_limit.unwrap_or(DEFAULT_CONTEXT_LIMIT)
} }
} }

View File

@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use super::errors::ProviderError; use super::errors::ProviderError;
use crate::{message::Message, types::core::Tool}; use crate::{message::Message, types::core::Tool};
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize, uniffi::Record)]
pub struct Usage { pub struct Usage {
pub input_tokens: Option<i32>, pub input_tokens: Option<i32>,
pub output_tokens: Option<i32>, pub output_tokens: Option<i32>,
@@ -26,7 +26,7 @@ impl Usage {
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, uniffi::Record)]
pub struct ProviderCompleteResponse { pub struct ProviderCompleteResponse {
pub message: Message, pub message: Message,
pub model: String, pub model: String,

View File

@@ -1,6 +1,6 @@
use thiserror::Error; use thiserror::Error;
#[derive(Error, Debug)] #[derive(Error, Debug, uniffi::Error)]
pub enum ProviderError { pub enum ProviderError {
#[error("Authentication error: {0}")] #[error("Authentication error: {0}")]
Authentication(String), Authentication(String),

View File

@@ -83,9 +83,9 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
] ]
})); }));
} }
MessageContent::ToolRequest(request) => { MessageContent::ToolReq(request) => {
has_tool_calls = true; has_tool_calls = true;
match &request.tool_call { match &request.tool_call.as_result() {
Ok(tool_call) => { Ok(tool_call) => {
let sanitized_name = sanitize_function_name(&tool_call.name); let sanitized_name = sanitize_function_name(&tool_call.name);
@@ -114,8 +114,8 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
} }
} }
} }
MessageContent::ToolResponse(response) => { MessageContent::ToolResp(response) => {
match &response.tool_result { match &response.tool_result.0 {
Ok(contents) => { Ok(contents) => {
// Process all content, replacing images with placeholder text // Process all content, replacing images with placeholder text
let mut tool_content = Vec::new(); let mut tool_content = Vec::new();
@@ -300,13 +300,13 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
"The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+",
function_name function_name
)); ));
content.push(MessageContent::tool_request(id, Err(error))); content.push(MessageContent::tool_request(id, Err(error).into()));
} else { } else {
match serde_json::from_str::<Value>(&arguments) { match serde_json::from_str::<Value>(&arguments) {
Ok(params) => { Ok(params) => {
content.push(MessageContent::tool_request( content.push(MessageContent::tool_request(
id, id,
Ok(ToolCall::new(&function_name, params)), Ok(ToolCall::new(&function_name, params)).into(),
)); ));
} }
Err(e) => { Err(e) => {
@@ -314,7 +314,7 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
"Could not interpret tool use parameters for id {}: {}", "Could not interpret tool use parameters for id {}: {}",
id, e id, e
)); ));
content.push(MessageContent::tool_request(id, Err(error))); content.push(MessageContent::tool_request(id, Err(error).into()));
} }
} }
} }
@@ -681,19 +681,20 @@ mod tests {
Message::user().with_text("How are you?"), Message::user().with_text("How are you?"),
Message::assistant().with_tool_request( Message::assistant().with_tool_request(
"tool1", "tool1",
Ok(ToolCall::new("example", json!({"param1": "value1"}))), Ok(ToolCall::new("example", json!({"param1": "value1"})).into()),
), ),
]; ];
// Get the ID from the tool request to use in the response // Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] { let tool_id = if let MessageContent::ToolReq(request) = &messages[2].content[0] {
request.id.clone() request.id.clone()
} else { } else {
panic!("should be tool request"); panic!("should be tool request");
}; };
messages messages.push(
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]))); Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
);
let spec = format_messages(&messages, &ImageFormat::OpenAi); let spec = format_messages(&messages, &ImageFormat::OpenAi);
@@ -719,14 +720,15 @@ mod tests {
)]; )];
// Get the ID from the tool request to use in the response // Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolRequest(request) = &messages[0].content[0] { let tool_id = if let MessageContent::ToolReq(request) = &messages[0].content[0] {
request.id.clone() request.id.clone()
} else { } else {
panic!("should be tool request"); panic!("should be tool request");
}; };
messages messages.push(
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]))); Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
);
let spec = format_messages(&messages, &ImageFormat::OpenAi); let spec = format_messages(&messages, &ImageFormat::OpenAi);
@@ -857,7 +859,7 @@ mod tests {
let message = response_to_message(response)?; let message = response_to_message(response)?;
assert_eq!(message.content.len(), 1); assert_eq!(message.content.len(), 1);
if let MessageContent::ToolRequest(request) = &message.content[0] { if let MessageContent::ToolReq(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap(); let tool_call = request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "example_fn"); assert_eq!(tool_call.name, "example_fn");
assert_eq!(tool_call.arguments, json!({"param": "value"})); assert_eq!(tool_call.arguments, json!({"param": "value"}));
@@ -876,8 +878,8 @@ mod tests {
let message = response_to_message(response)?; let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] { if let MessageContent::ToolReq(request) = &message.content[0] {
match &request.tool_call { match &request.tool_call.as_result() {
Err(ToolError::NotFound(msg)) => { Err(ToolError::NotFound(msg)) => {
assert!(msg.starts_with("The provided function name")); assert!(msg.starts_with("The provided function name"));
} }
@@ -898,8 +900,8 @@ mod tests {
let message = response_to_message(response)?; let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] { if let MessageContent::ToolReq(request) = &message.content[0] {
match &request.tool_call { match &request.tool_call.as_result() {
Err(ToolError::InvalidParameters(msg)) => { Err(ToolError::InvalidParameters(msg)) => {
assert!(msg.starts_with("Could not interpret tool use parameters")); assert!(msg.starts_with("Could not interpret tool use parameters"));
} }
@@ -920,7 +922,7 @@ mod tests {
let message = response_to_message(response)?; let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] { if let MessageContent::ToolReq(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap(); let tool_call = request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "example_fn"); assert_eq!(tool_call.name, "example_fn");
assert_eq!(tool_call.arguments, json!({})); assert_eq!(tool_call.arguments, json!({}));

View File

@@ -56,7 +56,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
// Redacted thinking blocks are not directly used in OpenAI format // Redacted thinking blocks are not directly used in OpenAI format
continue; continue;
} }
MessageContent::ToolRequest(request) => match &request.tool_call { MessageContent::ToolReq(request) => match &request.tool_call.as_result() {
Ok(tool_call) => { Ok(tool_call) => {
let sanitized_name = sanitize_function_name(&tool_call.name); let sanitized_name = sanitize_function_name(&tool_call.name);
let tool_calls = converted let tool_calls = converted
@@ -82,8 +82,8 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
})); }));
} }
}, },
MessageContent::ToolResponse(response) => { MessageContent::ToolResp(response) => {
match &response.tool_result { match &response.tool_result.0 {
Ok(contents) => { Ok(contents) => {
// Process all content, replacing images with placeholder text // Process all content, replacing images with placeholder text
let mut tool_content = Vec::new(); let mut tool_content = Vec::new();
@@ -210,13 +210,13 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
"The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", "The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+",
function_name function_name
)); ));
content.push(MessageContent::tool_request(id, Err(error))); content.push(MessageContent::tool_request(id, Err(error).into()));
} else { } else {
match serde_json::from_str::<Value>(&arguments) { match serde_json::from_str::<Value>(&arguments) {
Ok(params) => { Ok(params) => {
content.push(MessageContent::tool_request( content.push(MessageContent::tool_request(
id, id,
Ok(ToolCall::new(&function_name, params)), Ok(ToolCall::new(&function_name, params)).into(),
)); ));
} }
Err(e) => { Err(e) => {
@@ -224,7 +224,7 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
"Could not interpret tool use parameters for id {}: {}", "Could not interpret tool use parameters for id {}: {}",
id, e id, e
)); ));
content.push(MessageContent::tool_request(id, Err(error))); content.push(MessageContent::tool_request(id, Err(error).into()));
} }
} }
} }
@@ -559,14 +559,15 @@ mod tests {
]; ];
// Get the ID from the tool request to use in the response // Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] { let tool_id = if let MessageContent::ToolReq(request) = &messages[2].content[0] {
request.id.clone() request.id.clone()
} else { } else {
panic!("should be tool request"); panic!("should be tool request");
}; };
messages messages.push(
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]))); Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
);
let spec = format_messages(&messages, &ImageFormat::OpenAi); let spec = format_messages(&messages, &ImageFormat::OpenAi);
@@ -592,14 +593,15 @@ mod tests {
)]; )];
// Get the ID from the tool request to use in the response // Get the ID from the tool request to use in the response
let tool_id = if let MessageContent::ToolRequest(request) = &messages[0].content[0] { let tool_id = if let MessageContent::ToolReq(request) = &messages[0].content[0] {
request.id.clone() request.id.clone()
} else { } else {
panic!("should be tool request"); panic!("should be tool request");
}; };
messages messages.push(
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]))); Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
);
let spec = format_messages(&messages, &ImageFormat::OpenAi); let spec = format_messages(&messages, &ImageFormat::OpenAi);
@@ -730,7 +732,7 @@ mod tests {
let message = response_to_message(response)?; let message = response_to_message(response)?;
assert_eq!(message.content.len(), 1); assert_eq!(message.content.len(), 1);
if let MessageContent::ToolRequest(request) = &message.content[0] { if let MessageContent::ToolReq(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap(); let tool_call = request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "example_fn"); assert_eq!(tool_call.name, "example_fn");
assert_eq!(tool_call.arguments, json!({"param": "value"})); assert_eq!(tool_call.arguments, json!({"param": "value"}));
@@ -749,8 +751,8 @@ mod tests {
let message = response_to_message(response)?; let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] { if let MessageContent::ToolReq(request) = &message.content[0] {
match &request.tool_call { match &request.tool_call.as_result() {
Err(ToolError::NotFound(msg)) => { Err(ToolError::NotFound(msg)) => {
assert!(msg.starts_with("The provided function name")); assert!(msg.starts_with("The provided function name"));
} }
@@ -771,8 +773,8 @@ mod tests {
let message = response_to_message(response)?; let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] { if let MessageContent::ToolReq(request) = &message.content[0] {
match &request.tool_call { match &request.tool_call.as_result() {
Err(ToolError::InvalidParameters(msg)) => { Err(ToolError::InvalidParameters(msg)) => {
assert!(msg.starts_with("Could not interpret tool use parameters")); assert!(msg.starts_with("Could not interpret tool use parameters"));
} }
@@ -793,7 +795,7 @@ mod tests {
let message = response_to_message(response)?; let message = response_to_message(response)?;
if let MessageContent::ToolRequest(request) = &message.content[0] { if let MessageContent::ToolReq(request) = &message.content[0] {
let tool_call = request.tool_call.as_ref().unwrap(); let tool_call = request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, "example_fn"); assert_eq!(tool_call.name, "example_fn");
assert_eq!(tool_call.arguments, json!({})); assert_eq!(tool_call.arguments, json!({}));

View File

@@ -7,36 +7,24 @@ use thiserror::Error;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::types::json_value_ffi::JsonValueFfi;
use crate::{message::Message, providers::Usage}; use crate::{message::Message, providers::Usage};
use crate::{model::ModelConfig, providers::errors::ProviderError}; use crate::{model::ModelConfig, providers::errors::ProviderError};
pub struct CompletionRequest<'a> { // Lifetimes are not supported in Uniffi, cause other languages don't have them
pub provider_name: &'a str, // https://github.com/mozilla/uniffi-rs/issues/1526#issuecomment-1528851837
#[derive(uniffi::Record)]
pub struct CompletionRequest {
pub provider_name: String,
pub model_config: ModelConfig, pub model_config: ModelConfig,
pub system_preamble: &'a str, pub system_preamble: String,
pub messages: &'a [Message], pub messages: Vec<Message>,
pub extensions: &'a [ExtensionConfig], pub extensions: Vec<ExtensionConfig>,
} }
impl<'a> CompletionRequest<'a> { // https://mozilla.github.io/uniffi-rs/latest/proc_macro/errors.html
pub fn new( #[derive(Debug, Error, uniffi::Error)]
provider_name: &'a str, #[uniffi(flat_error)]
model_config: ModelConfig,
system_preamble: &'a str,
messages: &'a [Message],
extensions: &'a [ExtensionConfig],
) -> Self {
Self {
provider_name,
model_config,
system_preamble,
messages,
extensions,
}
}
}
#[derive(Debug, Error)]
pub enum CompletionError { pub enum CompletionError {
#[error("failed to create provider: {0}")] #[error("failed to create provider: {0}")]
UnknownProvider(String), UnknownProvider(String),
@@ -54,7 +42,7 @@ pub enum CompletionError {
ToolNotFound(String), ToolNotFound(String),
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct CompletionResponse { pub struct CompletionResponse {
pub message: Message, pub message: Message,
pub model: String, pub model: String,
@@ -78,35 +66,35 @@ impl CompletionResponse {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
pub struct RuntimeMetrics { pub struct RuntimeMetrics {
pub total_time_ms: u128, pub total_time_sec: f32,
pub total_time_ms_provider: u128, pub total_time_sec_provider: f32,
pub tokens_per_second: Option<f64>, pub tokens_per_second: Option<f64>,
} }
impl RuntimeMetrics { impl RuntimeMetrics {
pub fn new( pub fn new(
total_time_ms: u128, total_time_sec: f32,
total_time_ms_provider: u128, total_time_sec_provider: f32,
tokens_per_second: Option<f64>, tokens_per_second: Option<f64>,
) -> Self { ) -> Self {
Self { Self {
total_time_ms, total_time_sec,
total_time_ms_provider, total_time_sec_provider,
tokens_per_second, tokens_per_second,
} }
} }
} }
#[derive(Debug, Clone, PartialEq, Serialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
pub enum ToolApprovalMode { pub enum ToolApprovalMode {
Auto, Auto,
Manual, Manual,
Smart, Smart,
} }
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolConfig { pub struct ToolConfig {
pub name: String, pub name: String,
pub description: String, pub description: String,
@@ -140,7 +128,28 @@ impl ToolConfig {
} }
} }
#[derive(Debug, Clone, Serialize)] #[uniffi::export]
pub fn create_tool_config(
name: &str,
description: &str,
input_schema: JsonValueFfi,
approval_mode: ToolApprovalMode,
) -> ToolConfig {
ToolConfig::new(name, description, input_schema.into(), approval_mode)
}
uniffi::custom_type!(ToolConfig, String, {
lower: |tc: &ToolConfig| {
serde_json::to_string(&tc).unwrap()
},
try_lift: |s: String| {
Ok(serde_json::from_str(&s).unwrap())
},
});
// — Register the newtypes with UniFFI, converting via JSON strings —
#[derive(Debug, Clone, Serialize, uniffi::Record)]
pub struct ExtensionConfig { pub struct ExtensionConfig {
name: String, name: String,
instructions: Option<String>, instructions: Option<String>,

View File

@@ -4,14 +4,14 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error; use thiserror::Error;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum Role { pub enum Role {
User, User,
Assistant, Assistant,
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
#[serde(tag = "type", rename_all = "camelCase")] #[serde(tag = "type", rename_all = "camelCase")]
pub enum Content { pub enum Content {
Text(TextContent), Text(TextContent),
@@ -47,13 +47,13 @@ impl Content {
} }
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct TextContent { pub struct TextContent {
pub text: String, pub text: String,
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ImageContent { pub struct ImageContent {
pub data: String, pub data: String,
@@ -116,7 +116,7 @@ impl ToolCall {
} }
#[non_exhaustive] #[non_exhaustive]
#[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq)] #[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq, uniffi::Error)]
pub enum ToolError { pub enum ToolError {
#[error("Invalid parameters: {0}")] #[error("Invalid parameters: {0}")]
InvalidParameters(String), InvalidParameters(String),

View File

@@ -0,0 +1,84 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
// `serde_json::Value` gets converted to a `String` to pass across the FFI.
// https://github.com/mozilla/uniffi-rs/blob/main/docs/manual/src/types/custom_types.md?plain=1
// https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/examples/custom-types/src/lib.rs#L63-L69
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct JsonValueFfi(Value);
impl From<JsonValueFfi> for Value {
fn from(val: JsonValueFfi) -> Self {
val.0
}
}
impl From<Value> for JsonValueFfi {
fn from(val: Value) -> Self {
JsonValueFfi(val)
}
}
uniffi::custom_type!(JsonValueFfi, String, {
lower: |obj| {
serde_json::to_string(&obj.0).unwrap()
},
try_lift: |val| {
Ok(serde_json::from_str(&val).unwrap() )
},
});
// Write some tests to ensure that the conversion works as expected
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_json_value_ffi_conversion() {
let original = JsonValueFfi(json!({"key": "value"}));
let serialized = serde_json::to_string(&original).unwrap();
let deserialized: JsonValueFfi = serde_json::from_str(&serialized).unwrap();
assert_eq!(original.0, deserialized.0);
}
#[test]
fn test_json_value_ffi_to_serde() {
let original = JsonValueFfi(json!({"key": "value"}));
let value: Value = original.into();
assert_eq!(value, json!({"key": "value"}));
}
#[test]
fn test_json_value_ffi_from_serde() {
let value = json!({"key": "value"});
let original: JsonValueFfi = value.into();
assert_eq!(original.0, json!({"key": "value"}));
}
#[test]
fn test_json_value_ffi_lower() {
let original = JsonValueFfi(json!({"key": "value"}));
let serialized = serde_json::to_string(&original).unwrap();
assert_eq!(serialized, "{\"key\":\"value\"}");
}
#[test]
fn test_json_value_ffi_try_lift() {
let json_str = "{\"key\":\"value\"}";
let deserialized: JsonValueFfi = serde_json::from_str(json_str).unwrap();
let expected = JsonValueFfi(json!({"key": "value"}));
assert_eq!(deserialized.0, expected.0);
}
#[test]
fn test_json_value_ffi_custom_type() {
let json_str = "{\"key\":\"value\"}";
let deserialized: JsonValueFfi = serde_json::from_str(json_str).unwrap();
let serialized = serde_json::to_string(&deserialized).unwrap();
assert_eq!(serialized, json_str);
}
}

View File

@@ -1,2 +1,3 @@
pub mod completion; pub mod completion;
pub mod core; pub mod core;
pub mod json_value_ffi;

View File

@@ -45,14 +45,14 @@ async fn test_generate_tooltip_with_tools() -> Result<(), ProviderError> {
let mut tool_req_msg = Message::assistant(); let mut tool_req_msg = Message::assistant();
let req = ToolRequest { let req = ToolRequest {
id: "1".to_string(), id: "1".to_string(),
tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))), tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))).into(),
}; };
tool_req_msg.content.push(MessageContent::ToolRequest(req)); tool_req_msg.content.push(MessageContent::ToolReq(req));
// 2) User message with the tool response // 2) User message with the tool response
let tool_resp_msg = Message::user().with_tool_response( let tool_resp_msg = Message::user().with_tool_response(
"1", "1",
Ok(vec![Content::text("The current time is 12:00 UTC")]), Ok(vec![Content::text("The current time is 12:00 UTC")]).into(),
); );
let messages = vec![tool_req_msg, tool_resp_msg]; let messages = vec![tool_req_msg, tool_resp_msg];

View File

@@ -147,7 +147,7 @@ impl ProviderTester {
.message .message
.content .content
.iter() .iter()
.any(|content| matches!(content, MessageContent::ToolRequest(_))), .any(|content| matches!(content, MessageContent::ToolReq(_))),
"Expected tool request in response" "Expected tool request in response"
); );
@@ -171,7 +171,8 @@ impl ProviderTester {
Weather Weather
Saturday 9:00 PM Saturday 9:00 PM
Clear", Clear",
)]), )])
.into(),
); );
// Verify we construct a valid payload including the request/response pair for the next inference // Verify we construct a valid payload including the request/response pair for the next inference

View File

@@ -0,0 +1,3 @@
fn main() {
uniffi::uniffi_bindgen_main()
}