mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
[goose-llm] autogenate kotlin bindings using uniffi-rs proc macros (#2478)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
*.jar
|
||||
run_cli.sh
|
||||
tokenizer_files/
|
||||
.DS_Store
|
||||
|
||||
289
Cargo.lock
generated
289
Cargo.lock
generated
@@ -201,6 +201,48 @@ version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
|
||||
|
||||
[[package]]
|
||||
name = "askama"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d4744ed2eef2645831b441d8f5459689ade2ab27c854488fbab1fbe94fce1a7"
|
||||
dependencies = [
|
||||
"askama_derive",
|
||||
"itoa",
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "askama_derive"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d661e0f57be36a5c14c48f78d09011e67e0cb618f269cca9f2fd8d15b68c46ac"
|
||||
dependencies = [
|
||||
"askama_parser",
|
||||
"basic-toml",
|
||||
"memchr",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustc-hash 2.1.1",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"syn 2.0.99",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "askama_parser"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf315ce6524c857bb129ff794935cf6d42c82a6cff60526fe2a63593de4d0d4f"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "assert-json-diff"
|
||||
version = "2.0.2"
|
||||
@@ -211,6 +253,19 @@ dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-compat"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7bab94bde396a3f7b4962e396fdad640e241ed797d4d8d77fc8c237d14c58fc0"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-compression"
|
||||
version = "0.4.20"
|
||||
@@ -843,6 +898,15 @@ dependencies = [
|
||||
"vsimd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "basic-toml"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba62675e8242a4c4e806d12f11d136e626e6c8361d6b829310732241652a178a"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bat"
|
||||
version = "0.24.0"
|
||||
@@ -1086,6 +1150,38 @@ version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d2c12f985c78475a6b8d629afd0c360260ef34cfef52efccdcfd31972f81c2e"
|
||||
|
||||
[[package]]
|
||||
name = "camino"
|
||||
version = "1.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cargo-platform"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e35af189006b9c0f00a064685c727031e3ed2d8020f7ba284d78cc2671bd36ea"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cargo_metadata"
|
||||
version = "0.19.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd5eb614ed4c27c5d706420e4320fbe3216ab31fa1c33cd8246ac36dae4479ba"
|
||||
dependencies = [
|
||||
"camino",
|
||||
"cargo-platform",
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cast"
|
||||
version = "0.3.0"
|
||||
@@ -2100,6 +2196,15 @@ version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
|
||||
|
||||
[[package]]
|
||||
name = "fs-err"
|
||||
version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "88a41f105fe1d5b6b34b2055e3dc59bb79b46b48b2040b9e6c7b4b5de097aa41"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fs2"
|
||||
version = "0.4.3"
|
||||
@@ -2302,6 +2407,17 @@ dependencies = [
|
||||
"regex-syntax 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "goblin"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b363a30c165f666402fe6a3024d3bec7ebc898f96a4a23bd1c99f8dbf3f4f47"
|
||||
dependencies = [
|
||||
"log",
|
||||
"plain",
|
||||
"scroll",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "google-apis-common"
|
||||
version = "7.0.0"
|
||||
@@ -2544,6 +2660,7 @@ dependencies = [
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"uniffi",
|
||||
"url",
|
||||
]
|
||||
|
||||
@@ -4413,6 +4530,12 @@ version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
|
||||
|
||||
[[package]]
|
||||
name = "plain"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6"
|
||||
|
||||
[[package]]
|
||||
name = "plist"
|
||||
version = "1.7.0"
|
||||
@@ -5279,6 +5402,26 @@ version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "scroll"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ab8598aa408498679922eff7fa985c25d58a90771bd6be794434c5277eab1a6"
|
||||
dependencies = [
|
||||
"scroll_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scroll_derive"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1783eabc414609e28a5ba76aee5ddd52199f7107a0b24c2e9746a1ecc34a683d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.99",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sct"
|
||||
version = "0.7.1"
|
||||
@@ -5342,6 +5485,9 @@ name = "semver"
|
||||
version = "1.0.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
@@ -5579,6 +5725,12 @@ dependencies = [
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "siphasher"
|
||||
version = "0.3.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d"
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.9"
|
||||
@@ -5631,6 +5783,12 @@ version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
||||
|
||||
[[package]]
|
||||
name = "static_assertions"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "std_prelude"
|
||||
version = "0.2.12"
|
||||
@@ -6464,6 +6622,128 @@ version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
|
||||
|
||||
[[package]]
|
||||
name = "uniffi"
|
||||
version = "0.29.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4dcd1d240101ba3b9d7532ae86d9cb64d9a7ff63e13a2b7b9e94a32a601d8233"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"camino",
|
||||
"cargo_metadata",
|
||||
"clap 4.5.31",
|
||||
"uniffi_bindgen",
|
||||
"uniffi_core",
|
||||
"uniffi_macros",
|
||||
"uniffi_pipeline",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uniffi_bindgen"
|
||||
version = "0.29.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d0525f06d749ea80d8049dc0bb038bb87941e3d909eefa76b6f0a5589b59ac5"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"askama",
|
||||
"camino",
|
||||
"cargo_metadata",
|
||||
"fs-err",
|
||||
"glob",
|
||||
"goblin",
|
||||
"heck 0.5.0",
|
||||
"indexmap 2.7.1",
|
||||
"once_cell",
|
||||
"serde",
|
||||
"tempfile",
|
||||
"textwrap",
|
||||
"toml 0.5.11",
|
||||
"uniffi_internal_macros",
|
||||
"uniffi_meta",
|
||||
"uniffi_pipeline",
|
||||
"uniffi_udl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uniffi_core"
|
||||
version = "0.29.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3fa8eb4d825b4ed095cb13483cba6927c3002b9eb603cef9b7688758cc3772e"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-compat",
|
||||
"bytes",
|
||||
"once_cell",
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uniffi_internal_macros"
|
||||
version = "0.29.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83b547d69d699e52f2129fde4b57ae0d00b5216e59ed5b56097c95c86ba06095"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"indexmap 2.7.1",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.99",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uniffi_macros"
|
||||
version = "0.29.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "00f1de72edc8cb9201c7d650e3678840d143e4499004571aac49e6cb1b17da43"
|
||||
dependencies = [
|
||||
"camino",
|
||||
"fs-err",
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde",
|
||||
"syn 2.0.99",
|
||||
"toml 0.5.11",
|
||||
"uniffi_meta",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uniffi_meta"
|
||||
version = "0.29.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3acc9204632f6a555b2cba7c8852c5523bc1aa5f3eff605c64af5054ea28b72e"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"siphasher",
|
||||
"uniffi_internal_macros",
|
||||
"uniffi_pipeline",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uniffi_pipeline"
|
||||
version = "0.29.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "54b5336a9a925b358183837d31541d12590b7fcec373256d3770de02dff24c69"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"heck 0.5.0",
|
||||
"indexmap 2.7.1",
|
||||
"tempfile",
|
||||
"uniffi_internal_macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uniffi_udl"
|
||||
version = "0.29.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f95e73373d85f04736bc51997d3e6855721144ec4384cae9ca8513c80615e129"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"textwrap",
|
||||
"uniffi_meta",
|
||||
"weedle2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unsafe-libyaml"
|
||||
version = "0.2.11"
|
||||
@@ -6757,6 +7037,15 @@ dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weedle2"
|
||||
version = "5.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "998d2c24ec099a87daf9467808859f9d82b61f1d9c9701251aea037f514eae0e"
|
||||
dependencies = [
|
||||
"nom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weezl"
|
||||
version = "0.1.8"
|
||||
|
||||
130
bindings/kotlin/example/Usage.kt
Normal file
130
bindings/kotlin/example/Usage.kt
Normal file
@@ -0,0 +1,130 @@
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import uniffi.goose_llm.*
|
||||
|
||||
fun main() = runBlocking {
|
||||
val now = System.currentTimeMillis() / 1000
|
||||
val msgs = listOf(
|
||||
// 1) User sends a plain-text prompt
|
||||
Message(
|
||||
role = Role.USER,
|
||||
created = now,
|
||||
content = listOf(
|
||||
MessageContent.Text(
|
||||
TextContent("What is 7 x 6?")
|
||||
)
|
||||
)
|
||||
),
|
||||
|
||||
// 2) Assistant makes a tool request (ToolReq) to calculate 7×6
|
||||
Message(
|
||||
role = Role.ASSISTANT,
|
||||
created = now + 2,
|
||||
content = listOf(
|
||||
MessageContent.ToolReq(
|
||||
ToolRequest(
|
||||
id = "calc1",
|
||||
toolCall = """
|
||||
{
|
||||
"status": "success",
|
||||
"value": {
|
||||
"name": "calculator_extension__toolname",
|
||||
"arguments": {
|
||||
"operation": "multiply",
|
||||
"numbers": [7, 6]
|
||||
},
|
||||
"needsApproval": false
|
||||
}
|
||||
}
|
||||
""".trimIndent()
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
|
||||
// 3) User (on behalf of the tool) responds with the tool result (ToolResp)
|
||||
Message(
|
||||
role = Role.USER,
|
||||
created = now + 3,
|
||||
content = listOf(
|
||||
MessageContent.ToolResp(
|
||||
ToolResponse(
|
||||
id = "calc1",
|
||||
toolResult = """
|
||||
{
|
||||
"status": "success",
|
||||
"value": [
|
||||
{"type": "text", "text": "42"}
|
||||
]
|
||||
}
|
||||
""".trimIndent()
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
printMessages(msgs)
|
||||
println("---\n")
|
||||
|
||||
val sessionName = generateSessionName(msgs)
|
||||
println("Session Name: $sessionName")
|
||||
|
||||
val tooltip = generateTooltip(msgs)
|
||||
println("Tooltip: $tooltip")
|
||||
|
||||
// Completion
|
||||
val provider = "databricks"
|
||||
val modelName = "goose-gpt-4-1"
|
||||
val modelConfig = ModelConfig(
|
||||
modelName,
|
||||
100000u, // UInt
|
||||
0.1f, // Float
|
||||
200 // Int
|
||||
)
|
||||
|
||||
val calculatorTool = createToolConfig(
|
||||
name = "calculator",
|
||||
description = "Perform basic arithmetic operations",
|
||||
inputSchema = """
|
||||
{
|
||||
"type": "object",
|
||||
"required": ["operation", "numbers"],
|
||||
"properties": {
|
||||
"operation": {
|
||||
"type": "string",
|
||||
"enum": ["add", "subtract", "multiply", "divide"],
|
||||
"description": "The arithmetic operation to perform"
|
||||
},
|
||||
"numbers": {
|
||||
"type": "array",
|
||||
"items": { "type": "number" },
|
||||
"description": "List of numbers to operate on in order"
|
||||
}
|
||||
}
|
||||
}
|
||||
""".trimIndent(),
|
||||
approvalMode = ToolApprovalMode.AUTO
|
||||
)
|
||||
|
||||
val calculator_extension = ExtensionConfig(
|
||||
name = "calculator_extension",
|
||||
instructions = "This extension provides a calculator tool.",
|
||||
tools = listOf(calculatorTool)
|
||||
)
|
||||
|
||||
val extensions = listOf(calculator_extension)
|
||||
val systemPreamble = "You are a helpful assistant."
|
||||
|
||||
|
||||
val req = CompletionRequest(
|
||||
provider,
|
||||
modelConfig,
|
||||
systemPreamble,
|
||||
msgs,
|
||||
extensions
|
||||
)
|
||||
|
||||
val response = completion(req)
|
||||
println("\nCompletion Response:")
|
||||
println(response.message)
|
||||
}
|
||||
2933
bindings/kotlin/uniffi/goose_llm/goose_llm.kt
Normal file
2933
bindings/kotlin/uniffi/goose_llm/goose_llm.kt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,8 +7,11 @@ license.workspace = true
|
||||
repository.workspace = true
|
||||
description.workspace = true
|
||||
|
||||
[lib]
|
||||
crate-type = ["lib", "cdylib"]
|
||||
name = "goose_llm"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.43", features = ["full"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
anyhow = "1.0"
|
||||
@@ -36,6 +39,9 @@ regex = "1.11.1"
|
||||
tracing = "0.1"
|
||||
smallvec = { version = "1.13", features = ["serde"] }
|
||||
indoc = "1.0"
|
||||
# https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/fixtures/futures/Cargo.toml#L22
|
||||
uniffi = { version = "0.29", features = ["tokio", "cli", "scaffolding-ffi-buffer-fns"] }
|
||||
tokio = { version = "1.43", features = ["time", "sync"] }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5"
|
||||
@@ -43,7 +49,12 @@ tempfile = "3.15.0"
|
||||
dotenv = "0.15"
|
||||
lazy_static = "1.5"
|
||||
ctor = "0.2.7"
|
||||
tokio = { version = "1.43", features = ["full"] }
|
||||
|
||||
[[bin]]
|
||||
# https://mozilla.github.io/uniffi-rs/latest/tutorial/foreign_language_bindings.html
|
||||
name = "uniffi-bindgen"
|
||||
path = "uniffi-bindgen.rs"
|
||||
|
||||
[[example]]
|
||||
name = "simple"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
### goose-llm
|
||||
## goose-llm
|
||||
|
||||
This crate is meant to be used for foreign function interface (FFI). It's meant to be
|
||||
stateless and contain logic related to providers and prompts:
|
||||
@@ -12,3 +12,59 @@ Run:
|
||||
cargo run -p goose-llm --example simple
|
||||
```
|
||||
|
||||
|
||||
## Kotlin bindings
|
||||
|
||||
Structure:
|
||||
```
|
||||
.
|
||||
└── crates
|
||||
└── goose-llm/...
|
||||
└── target
|
||||
└── debug/libgoose_llm.dylib
|
||||
├── bindings
|
||||
│ └── kotlin
|
||||
│ ├── example
|
||||
│ │ └── Usage.kt ← your demo app
|
||||
│ └── uniffi
|
||||
│ └── goose_llm
|
||||
│ └── goose_llm.kt ← auto-generated bindings
|
||||
```
|
||||
|
||||
Create Kotlin bindings:
|
||||
```
|
||||
# run from project root directory
|
||||
cargo build -p goose-llm
|
||||
|
||||
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
|
||||
```
|
||||
|
||||
Download jars in `bindings/kotlin/libs` directory (only need to do this once):
|
||||
```
|
||||
pushd bindings/kotlin/libs/
|
||||
curl -O https://repo1.maven.org/maven2/org/jetbrains/kotlin/kotlin-stdlib/1.9.0/kotlin-stdlib-1.9.0.jar
|
||||
curl -O https://repo1.maven.org/maven2/org/jetbrains/kotlinx/kotlinx-coroutines-core-jvm/1.7.3/kotlinx-coroutines-core-jvm-1.7.3.jar
|
||||
curl -O https://repo1.maven.org/maven2/net/java/dev/jna/jna/5.13.0/jna-5.13.0.jar
|
||||
popd
|
||||
```
|
||||
|
||||
|
||||
Compile & Run usage example from Kotlin -> Rust:
|
||||
```
|
||||
pushd bindings/kotlin/
|
||||
|
||||
kotlinc \
|
||||
example/Usage.kt \
|
||||
uniffi/goose_llm/goose_llm.kt \
|
||||
-classpath "libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
|
||||
-include-runtime \
|
||||
-d example.jar
|
||||
|
||||
java \
|
||||
-Djna.library.path=$HOME/Development/goose/target/debug \
|
||||
-classpath "example.jar:libs/kotlin-stdlib-1.9.0.jar:libs/kotlinx-coroutines-core-jvm-1.7.3.jar:libs/jna-5.13.0.jar" \
|
||||
UsageKt
|
||||
|
||||
popd
|
||||
```
|
||||
|
||||
|
||||
@@ -93,13 +93,13 @@ async fn main() -> Result<()> {
|
||||
println!("\n---------------\n");
|
||||
println!("User Input: {text}");
|
||||
let messages = vec![Message::user().with_text(text)];
|
||||
let completion_response: CompletionResponse = completion(CompletionRequest::new(
|
||||
provider,
|
||||
model_config.clone(),
|
||||
system_preamble,
|
||||
&messages,
|
||||
&extensions,
|
||||
))
|
||||
let completion_response: CompletionResponse = completion(CompletionRequest {
|
||||
provider_name: provider.to_string(),
|
||||
model_config: model_config.clone(),
|
||||
system_preamble: system_preamble.to_string(),
|
||||
messages: messages,
|
||||
extensions: extensions.clone(),
|
||||
})
|
||||
.await?;
|
||||
// Print the response
|
||||
println!("\nCompletion Response:");
|
||||
|
||||
@@ -17,32 +17,40 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
#[uniffi::export]
|
||||
pub fn print_messages(messages: Vec<Message>) {
|
||||
for msg in messages {
|
||||
println!("[{:?} @ {}] {:?}", msg.role, msg.created, msg.content);
|
||||
}
|
||||
}
|
||||
|
||||
/// Public API for the Goose LLM completion function
|
||||
pub async fn completion(req: CompletionRequest<'_>) -> Result<CompletionResponse, CompletionError> {
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, CompletionError> {
|
||||
let start_total = Instant::now();
|
||||
|
||||
let provider = create(req.provider_name, req.model_config)
|
||||
let provider = create(&req.provider_name, req.model_config)
|
||||
.map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?;
|
||||
|
||||
let system_prompt = construct_system_prompt(req.system_preamble, req.extensions)?;
|
||||
let tools = collect_prefixed_tools(req.extensions);
|
||||
let system_prompt = construct_system_prompt(&req.system_preamble, &req.extensions)?;
|
||||
let tools = collect_prefixed_tools(&req.extensions);
|
||||
|
||||
// Call the LLM provider
|
||||
let start_provider = Instant::now();
|
||||
let mut response = provider
|
||||
.complete(&system_prompt, req.messages, &tools)
|
||||
.complete(&system_prompt, &req.messages, &tools)
|
||||
.await?;
|
||||
let provider_elapsed_ms = start_provider.elapsed().as_millis();
|
||||
let provider_elapsed_sec = start_provider.elapsed().as_secs_f32();
|
||||
let usage_tokens = response.usage.total_tokens;
|
||||
|
||||
let tool_configs = collect_prefixed_tool_configs(req.extensions);
|
||||
let tool_configs = collect_prefixed_tool_configs(&req.extensions);
|
||||
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?;
|
||||
|
||||
Ok(CompletionResponse::new(
|
||||
response.message,
|
||||
response.model,
|
||||
response.usage,
|
||||
calculate_runtime_metrics(start_total, provider_elapsed_ms, usage_tokens),
|
||||
calculate_runtime_metrics(start_total, provider_elapsed_sec, usage_tokens),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -81,8 +89,8 @@ pub fn update_needs_approval_for_tool_calls(
|
||||
tool_configs: &HashMap<String, ToolConfig>,
|
||||
) -> Result<(), CompletionError> {
|
||||
for content in &mut message.content.iter_mut() {
|
||||
if let MessageContent::ToolRequest(req) = content {
|
||||
if let Ok(call) = &mut req.tool_call {
|
||||
if let MessageContent::ToolReq(req) = content {
|
||||
if let Ok(call) = &mut req.tool_call.0 {
|
||||
// Provide a clear error message when the tool config is missing
|
||||
let config = tool_configs.get(&call.name).ok_or_else(|| {
|
||||
CompletionError::ToolNotFound(format!(
|
||||
@@ -117,16 +125,16 @@ fn collect_prefixed_tool_configs(extensions: &[ExtensionConfig]) -> HashMap<Stri
|
||||
/// Compute runtime metrics for the request.
|
||||
fn calculate_runtime_metrics(
|
||||
total_start: Instant,
|
||||
provider_elapsed_ms: u128,
|
||||
provider_elapsed_sec: f32,
|
||||
token_count: Option<i32>,
|
||||
) -> RuntimeMetrics {
|
||||
let total_ms = total_start.elapsed().as_millis();
|
||||
let total_ms = total_start.elapsed().as_secs_f32();
|
||||
let tokens_per_sec = token_count.and_then(|toks| {
|
||||
if provider_elapsed_ms > 0 {
|
||||
Some(toks as f64 / (provider_elapsed_ms as f64 / 1_000.0))
|
||||
if provider_elapsed_sec > 0.0 {
|
||||
Some(toks as f64 / (provider_elapsed_sec as f64))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
RuntimeMetrics::new(total_ms, provider_elapsed_ms, tokens_per_sec)
|
||||
RuntimeMetrics::new(total_ms, provider_elapsed_sec, tokens_per_sec)
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ fn build_system_prompt() -> String {
|
||||
}
|
||||
|
||||
/// Generates a short (≤4 words) session name
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
pub async fn generate_session_name(messages: &[Message]) -> Result<String, ProviderError> {
|
||||
// Collect up to the first 3 user messages (truncated to 300 chars each)
|
||||
let context: Vec<String> = messages
|
||||
|
||||
@@ -53,6 +53,7 @@ fn build_system_prompt() -> String {
|
||||
|
||||
/// Generates a tooltip summarizing the last two messages in the session,
|
||||
/// including any tool calls or results.
|
||||
#[uniffi::export(async_runtime = "tokio")]
|
||||
pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderError> {
|
||||
// Need at least two messages to summarize
|
||||
if messages.len() < 2 {
|
||||
@@ -72,17 +73,17 @@ pub async fn generate_tooltip(messages: &[Message]) -> Result<String, ProviderEr
|
||||
parts.push(txt.to_string());
|
||||
}
|
||||
}
|
||||
MessageContent::ToolRequest(req) => {
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
MessageContent::ToolReq(req) => {
|
||||
if let Ok(tool_call) = &req.tool_call.0 {
|
||||
parts.push(format!(
|
||||
"called tool '{}' with args {}",
|
||||
tool_call.name, tool_call.arguments
|
||||
));
|
||||
} else if let Err(e) = &req.tool_call {
|
||||
} else if let Err(e) = &req.tool_call.0 {
|
||||
parts.push(format!("tool request error: {}", e));
|
||||
}
|
||||
}
|
||||
MessageContent::ToolResponse(resp) => match &resp.tool_result {
|
||||
MessageContent::ToolResp(resp) => match &resp.tool_result.0 {
|
||||
Ok(contents) => {
|
||||
let results: Vec<String> = contents
|
||||
.iter()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
uniffi::setup_scaffolding!();
|
||||
|
||||
mod completion;
|
||||
pub mod extractors;
|
||||
pub mod message;
|
||||
|
||||
@@ -1,531 +0,0 @@
|
||||
use std::{collections::HashSet, iter::FromIterator, ops::Deref};
|
||||
|
||||
/// Messages which represent the content sent back and forth to LLM provider
|
||||
///
|
||||
/// We use these messages in the agent code, and interfaces which interact with
|
||||
/// the agent. That let's us reuse message histories across different interfaces.
|
||||
///
|
||||
/// The content of the messages uses MCP types to avoid additional conversions
|
||||
/// when interacting with MCP servers.
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smallvec::SmallVec;
|
||||
|
||||
use crate::types::core::{Content, ImageContent, Role, TextContent, ToolCall, ToolResult};
|
||||
|
||||
mod tool_result_serde;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolRequest {
|
||||
pub id: String,
|
||||
#[serde(with = "tool_result_serde")]
|
||||
pub tool_call: ToolResult<ToolCall>,
|
||||
}
|
||||
|
||||
impl ToolRequest {
|
||||
pub fn to_readable_string(&self) -> String {
|
||||
match &self.tool_call {
|
||||
Ok(tool_call) => {
|
||||
format!(
|
||||
"Tool: {}, Args: {}",
|
||||
tool_call.name,
|
||||
serde_json::to_string_pretty(&tool_call.arguments)
|
||||
.unwrap_or_else(|_| "<<invalid json>>".to_string())
|
||||
)
|
||||
}
|
||||
Err(e) => format!("Invalid tool call: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolResponse {
|
||||
pub id: String,
|
||||
#[serde(with = "tool_result_serde")]
|
||||
pub tool_result: ToolResult<Vec<Content>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ThinkingContent {
|
||||
pub thinking: String,
|
||||
pub signature: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct RedactedThinkingContent {
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
/// Content passed inside a message, which can be both simple content and tool content
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
pub enum MessageContent {
|
||||
Text(TextContent),
|
||||
Image(ImageContent),
|
||||
ToolRequest(ToolRequest),
|
||||
ToolResponse(ToolResponse),
|
||||
Thinking(ThinkingContent),
|
||||
RedactedThinking(RedactedThinkingContent),
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn text<S: Into<String>>(text: S) -> Self {
|
||||
MessageContent::Text(TextContent { text: text.into() })
|
||||
}
|
||||
|
||||
pub fn image<S: Into<String>, T: Into<String>>(data: S, mime_type: T) -> Self {
|
||||
MessageContent::Image(ImageContent {
|
||||
data: data.into(),
|
||||
mime_type: mime_type.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn tool_request<S: Into<String>>(id: S, tool_call: ToolResult<ToolCall>) -> Self {
|
||||
MessageContent::ToolRequest(ToolRequest {
|
||||
id: id.into(),
|
||||
tool_call,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn tool_response<S: Into<String>>(id: S, tool_result: ToolResult<Vec<Content>>) -> Self {
|
||||
MessageContent::ToolResponse(ToolResponse {
|
||||
id: id.into(),
|
||||
tool_result,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn thinking<S1: Into<String>, S2: Into<String>>(thinking: S1, signature: S2) -> Self {
|
||||
MessageContent::Thinking(ThinkingContent {
|
||||
thinking: thinking.into(),
|
||||
signature: signature.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
|
||||
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
|
||||
}
|
||||
|
||||
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
|
||||
if let MessageContent::ToolRequest(ref tool_request) = self {
|
||||
Some(tool_request)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tool_response(&self) -> Option<&ToolResponse> {
|
||||
if let MessageContent::ToolResponse(ref tool_response) = self {
|
||||
Some(tool_response)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tool_response_text(&self) -> Option<String> {
|
||||
if let Some(tool_response) = self.as_tool_response() {
|
||||
if let Ok(contents) = &tool_response.tool_result {
|
||||
let texts: Vec<String> = contents
|
||||
.iter()
|
||||
.filter_map(|content| content.as_text().map(String::from))
|
||||
.collect();
|
||||
if !texts.is_empty() {
|
||||
return Some(texts.join("\n"));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn as_tool_request_id(&self) -> Option<&str> {
|
||||
if let Self::ToolRequest(r) = self {
|
||||
Some(&r.id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tool_response_id(&self) -> Option<&str> {
|
||||
if let Self::ToolResponse(r) = self {
|
||||
Some(&r.id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the text content if this is a TextContent variant
|
||||
pub fn as_text(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessageContent::Text(text) => Some(&text.text),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the thinking content if this is a ThinkingContent variant
|
||||
pub fn as_thinking(&self) -> Option<&ThinkingContent> {
|
||||
match self {
|
||||
MessageContent::Thinking(thinking) => Some(thinking),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the redacted thinking content if this is a RedactedThinkingContent variant
|
||||
pub fn as_redacted_thinking(&self) -> Option<&RedactedThinkingContent> {
|
||||
match self {
|
||||
MessageContent::RedactedThinking(redacted) => Some(redacted),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_text(&self) -> bool {
|
||||
matches!(self, Self::Text(_))
|
||||
}
|
||||
pub fn is_image(&self) -> bool {
|
||||
matches!(self, Self::Image(_))
|
||||
}
|
||||
pub fn is_tool_request(&self) -> bool {
|
||||
matches!(self, Self::ToolRequest(_))
|
||||
}
|
||||
pub fn is_tool_response(&self) -> bool {
|
||||
matches!(self, Self::ToolResponse(_))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Content> for MessageContent {
|
||||
fn from(content: Content) -> Self {
|
||||
match content {
|
||||
Content::Text(text) => MessageContent::Text(text),
|
||||
Content::Image(image) => MessageContent::Image(image),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// 2. Contents – a new-type wrapper around SmallVec
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Holds the heterogeneous fragments that make up one chat message.
|
||||
///
|
||||
/// * Up to two items are stored inline on the stack.
|
||||
/// * Falls back to a heap allocation only when necessary.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||
#[serde(transparent)]
|
||||
pub struct Contents(SmallVec<[MessageContent; 2]>);
|
||||
|
||||
impl Contents {
|
||||
/*----------------------------------------------------------
|
||||
* 1-line ergonomic helpers
|
||||
*---------------------------------------------------------*/
|
||||
|
||||
pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, MessageContent> {
|
||||
self.0.iter_mut()
|
||||
}
|
||||
|
||||
pub fn push(&mut self, item: impl Into<MessageContent>) {
|
||||
self.0.push(item.into());
|
||||
}
|
||||
|
||||
pub fn texts(&self) -> impl Iterator<Item = &str> {
|
||||
self.0.iter().filter_map(|c| c.as_text())
|
||||
}
|
||||
|
||||
pub fn concat_text_str(&self) -> String {
|
||||
self.texts().collect::<Vec<_>>().join("\n")
|
||||
}
|
||||
|
||||
/// Returns `true` if *any* item satisfies the predicate.
|
||||
pub fn any_is<P>(&self, pred: P) -> bool
|
||||
where
|
||||
P: FnMut(&MessageContent) -> bool,
|
||||
{
|
||||
self.iter().any(pred)
|
||||
}
|
||||
|
||||
/// Returns `true` if *every* item satisfies the predicate.
|
||||
pub fn all_are<P>(&self, pred: P) -> bool
|
||||
where
|
||||
P: FnMut(&MessageContent) -> bool,
|
||||
{
|
||||
self.iter().all(pred)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<MessageContent>> for Contents {
|
||||
fn from(v: Vec<MessageContent>) -> Self {
|
||||
Contents(SmallVec::from_vec(v))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<MessageContent> for Contents {
|
||||
fn from_iter<I: IntoIterator<Item = MessageContent>>(iter: I) -> Self {
|
||||
Contents(SmallVec::from_iter(iter))
|
||||
}
|
||||
}
|
||||
|
||||
/*--------------------------------------------------------------
|
||||
* Allow &message.content to behave like a slice of fragments.
|
||||
*-------------------------------------------------------------*/
|
||||
impl Deref for Contents {
|
||||
type Target = [MessageContent];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
/// A message to or from an LLM
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub created: i64,
|
||||
pub content: Contents,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(role: Role) -> Self {
|
||||
Self {
|
||||
role,
|
||||
created: Utc::now().timestamp_millis(),
|
||||
content: Contents::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new user message with the current timestamp
|
||||
pub fn user() -> Self {
|
||||
Self::new(Role::User)
|
||||
}
|
||||
|
||||
/// Create a new assistant message with the current timestamp
|
||||
pub fn assistant() -> Self {
|
||||
Self::new(Role::Assistant)
|
||||
}
|
||||
|
||||
/// Add any item that implements Into<MessageContent> to the message
|
||||
pub fn with_content(mut self, item: impl Into<MessageContent>) -> Self {
|
||||
self.content.push(item);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add text content to the message
|
||||
pub fn with_text<S: Into<String>>(self, text: S) -> Self {
|
||||
self.with_content(MessageContent::text(text))
|
||||
}
|
||||
|
||||
/// Add image content to the message
|
||||
pub fn with_image<S: Into<String>, T: Into<String>>(self, data: S, mime_type: T) -> Self {
|
||||
self.with_content(MessageContent::image(data, mime_type))
|
||||
}
|
||||
|
||||
/// Add a tool request to the message
|
||||
pub fn with_tool_request<S: Into<String>>(
|
||||
self,
|
||||
id: S,
|
||||
tool_call: ToolResult<ToolCall>,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::tool_request(id, tool_call))
|
||||
}
|
||||
|
||||
/// Add a tool response to the message
|
||||
pub fn with_tool_response<S: Into<String>>(
|
||||
self,
|
||||
id: S,
|
||||
result: ToolResult<Vec<Content>>,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::tool_response(id, result))
|
||||
}
|
||||
|
||||
/// Add thinking content to the message
|
||||
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
|
||||
self,
|
||||
thinking: S1,
|
||||
signature: S2,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::thinking(thinking, signature))
|
||||
}
|
||||
|
||||
/// Add redacted thinking content to the message
|
||||
pub fn with_redacted_thinking<S: Into<String>>(self, data: S) -> Self {
|
||||
self.with_content(MessageContent::redacted_thinking(data))
|
||||
}
|
||||
|
||||
/// Check if the message is a tool call
|
||||
pub fn contains_tool_call(&self) -> bool {
|
||||
self.content.any_is(MessageContent::is_tool_request)
|
||||
}
|
||||
|
||||
/// Check if the message is a tool response
|
||||
pub fn contains_tool_response(&self) -> bool {
|
||||
self.content.any_is(MessageContent::is_tool_response)
|
||||
}
|
||||
|
||||
/// Check if the message contains only text content
|
||||
pub fn has_only_text_content(&self) -> bool {
|
||||
self.content.all_are(MessageContent::is_text)
|
||||
}
|
||||
|
||||
/// Retrieves all tool `id` from ToolRequest messages
|
||||
pub fn tool_request_ids(&self) -> HashSet<&str> {
|
||||
self.content
|
||||
.iter()
|
||||
.filter_map(MessageContent::as_tool_request_id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Retrieves all tool `id` from ToolResponse messages
|
||||
pub fn tool_response_ids(&self) -> HashSet<&str> {
|
||||
self.content
|
||||
.iter()
|
||||
.filter_map(MessageContent::as_tool_response_id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Retrieves all tool `id` from the message
|
||||
pub fn tool_ids(&self) -> HashSet<&str> {
|
||||
self.tool_request_ids()
|
||||
.into_iter()
|
||||
.chain(self.tool_response_ids())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::*;
|
||||
use crate::types::core::ToolError;
|
||||
|
||||
#[test]
|
||||
fn test_message_serialization() {
|
||||
let message = Message::assistant()
|
||||
.with_text("Hello, I'll help you with that.")
|
||||
.with_tool_request(
|
||||
"tool123",
|
||||
Ok(ToolCall::new("test_tool", json!({"param": "value"}))),
|
||||
);
|
||||
|
||||
let json_str = serde_json::to_string_pretty(&message).unwrap();
|
||||
println!("Serialized message: {}", json_str);
|
||||
|
||||
// Parse back to Value to check structure
|
||||
let value: Value = serde_json::from_str(&json_str).unwrap();
|
||||
|
||||
// Check top-level fields
|
||||
assert_eq!(value["role"], "assistant");
|
||||
assert!(value["created"].is_i64());
|
||||
assert!(value["content"].is_array());
|
||||
|
||||
// Check content items
|
||||
let content = &value["content"];
|
||||
|
||||
// First item should be text
|
||||
assert_eq!(content[0]["type"], "text");
|
||||
assert_eq!(content[0]["text"], "Hello, I'll help you with that.");
|
||||
|
||||
// Second item should be toolRequest
|
||||
assert_eq!(content[1]["type"], "toolRequest");
|
||||
assert_eq!(content[1]["id"], "tool123");
|
||||
|
||||
// Check tool_call serialization
|
||||
assert_eq!(content[1]["toolCall"]["status"], "success");
|
||||
assert_eq!(content[1]["toolCall"]["value"]["name"], "test_tool");
|
||||
assert_eq!(
|
||||
content[1]["toolCall"]["value"]["arguments"]["param"],
|
||||
"value"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_serialization() {
|
||||
let message = Message::assistant().with_tool_request(
|
||||
"tool123",
|
||||
Err(ToolError::ExecutionError(
|
||||
"Something went wrong".to_string(),
|
||||
)),
|
||||
);
|
||||
|
||||
let json_str = serde_json::to_string_pretty(&message).unwrap();
|
||||
println!("Serialized error: {}", json_str);
|
||||
|
||||
// Parse back to Value to check structure
|
||||
let value: Value = serde_json::from_str(&json_str).unwrap();
|
||||
|
||||
// Check tool_call serialization with error
|
||||
let tool_call = &value["content"][0]["toolCall"];
|
||||
assert_eq!(tool_call["status"], "error");
|
||||
assert_eq!(tool_call["error"], "Execution failed: Something went wrong");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialization() {
|
||||
// Create a JSON string with our new format
|
||||
let json_str = r#"{
|
||||
"role": "assistant",
|
||||
"created": 1740171566,
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "I'll help you with that."
|
||||
},
|
||||
{
|
||||
"type": "toolRequest",
|
||||
"id": "tool123",
|
||||
"toolCall": {
|
||||
"status": "success",
|
||||
"value": {
|
||||
"name": "test_tool",
|
||||
"arguments": {"param": "value"},
|
||||
"needsApproval": false
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let message: Message = serde_json::from_str(json_str).unwrap();
|
||||
|
||||
assert_eq!(message.role, Role::Assistant);
|
||||
assert_eq!(message.created, 1740171566);
|
||||
assert_eq!(message.content.len(), 2);
|
||||
|
||||
// Check first content item
|
||||
if let MessageContent::Text(text) = &message.content[0] {
|
||||
assert_eq!(text.text, "I'll help you with that.");
|
||||
} else {
|
||||
panic!("Expected Text content");
|
||||
}
|
||||
|
||||
// Check second content item
|
||||
if let MessageContent::ToolRequest(req) = &message.content[1] {
|
||||
assert_eq!(req.id, "tool123");
|
||||
if let Ok(tool_call) = &req.tool_call {
|
||||
assert_eq!(tool_call.name, "test_tool");
|
||||
assert_eq!(tool_call.arguments, json!({"param": "value"}));
|
||||
} else {
|
||||
panic!("Expected successful tool call");
|
||||
}
|
||||
} else {
|
||||
panic!("Expected ToolRequest content");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_with_text() {
|
||||
let message = Message::user().with_text("Hello");
|
||||
assert_eq!(message.content.concat_text_str(), "Hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_with_tool_request() {
|
||||
let tool_call = Ok(ToolCall::new("test_tool", json!({})));
|
||||
|
||||
let message = Message::assistant().with_tool_request("req1", tool_call);
|
||||
assert!(message.contains_tool_call());
|
||||
assert!(!message.contains_tool_response());
|
||||
|
||||
let ids = message.tool_ids();
|
||||
assert_eq!(ids.len(), 1);
|
||||
assert!(ids.contains("req1"));
|
||||
}
|
||||
}
|
||||
84
crates/goose-llm/src/message/contents.rs
Normal file
84
crates/goose-llm/src/message/contents.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use std::{iter::FromIterator, ops::Deref};
|
||||
|
||||
use crate::message::MessageContent;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smallvec::SmallVec;
|
||||
|
||||
/// Holds the heterogeneous fragments that make up one chat message.
|
||||
///
|
||||
/// * Up to two items are stored inline on the stack.
|
||||
/// * Falls back to a heap allocation only when necessary.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||
#[serde(transparent)]
|
||||
pub struct Contents(SmallVec<[MessageContent; 2]>);
|
||||
|
||||
impl Contents {
|
||||
/*----------------------------------------------------------
|
||||
* 1-line ergonomic helpers
|
||||
*---------------------------------------------------------*/
|
||||
|
||||
pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, MessageContent> {
|
||||
self.0.iter_mut()
|
||||
}
|
||||
|
||||
pub fn push(&mut self, item: impl Into<MessageContent>) {
|
||||
self.0.push(item.into());
|
||||
}
|
||||
|
||||
pub fn texts(&self) -> impl Iterator<Item = &str> {
|
||||
self.0.iter().filter_map(|c| c.as_text())
|
||||
}
|
||||
|
||||
pub fn concat_text_str(&self) -> String {
|
||||
self.texts().collect::<Vec<_>>().join("\n")
|
||||
}
|
||||
|
||||
/// Returns `true` if *any* item satisfies the predicate.
|
||||
pub fn any_is<P>(&self, pred: P) -> bool
|
||||
where
|
||||
P: FnMut(&MessageContent) -> bool,
|
||||
{
|
||||
self.iter().any(pred)
|
||||
}
|
||||
|
||||
/// Returns `true` if *every* item satisfies the predicate.
|
||||
pub fn all_are<P>(&self, pred: P) -> bool
|
||||
where
|
||||
P: FnMut(&MessageContent) -> bool,
|
||||
{
|
||||
self.iter().all(pred)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<MessageContent>> for Contents {
|
||||
fn from(v: Vec<MessageContent>) -> Self {
|
||||
Contents(SmallVec::from_vec(v))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<MessageContent> for Contents {
|
||||
fn from_iter<I: IntoIterator<Item = MessageContent>>(iter: I) -> Self {
|
||||
Contents(SmallVec::from_iter(iter))
|
||||
}
|
||||
}
|
||||
|
||||
/*--------------------------------------------------------------
|
||||
* Allow &message.content to behave like a slice of fragments.
|
||||
*-------------------------------------------------------------*/
|
||||
impl Deref for Contents {
|
||||
type Target = [MessageContent];
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
// — Register the contents type with UniFFI, converting to/from Vec<MessageContent> —
|
||||
// We need to do this because UniFFI’s FFI layer supports only primitive buffers (here Vec<u8>),
|
||||
uniffi::custom_type!(Contents, Vec<MessageContent>, {
|
||||
lower: |contents: &Contents| {
|
||||
contents.0.to_vec()
|
||||
},
|
||||
try_lift: |contents: Vec<MessageContent>| {
|
||||
Ok(Contents::from(contents))
|
||||
},
|
||||
});
|
||||
240
crates/goose-llm/src/message/message_content.rs
Normal file
240
crates/goose-llm/src/message/message_content.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
|
||||
use crate::message::tool_result_serde;
|
||||
use crate::types::core::{Content, ImageContent, TextContent, ToolCall, ToolResult};
|
||||
|
||||
// — Newtype wrappers (local structs) so we satisfy Rust’s orphan rules —
|
||||
// We need these because we can’t implement UniFFI’s FfiConverter directly on a type alias.
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ToolRequestToolCall(#[serde(with = "tool_result_serde")] pub ToolResult<ToolCall>);
|
||||
|
||||
impl ToolRequestToolCall {
|
||||
pub fn as_result(&self) -> &ToolResult<ToolCall> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
impl std::ops::Deref for ToolRequestToolCall {
|
||||
type Target = ToolResult<ToolCall>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
impl From<Result<ToolCall, crate::types::core::ToolError>> for ToolRequestToolCall {
|
||||
fn from(res: Result<ToolCall, crate::types::core::ToolError>) -> Self {
|
||||
ToolRequestToolCall(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ToolResponseToolResult(
|
||||
#[serde(with = "tool_result_serde")] pub ToolResult<Vec<Content>>,
|
||||
);
|
||||
|
||||
impl ToolResponseToolResult {
|
||||
pub fn as_result(&self) -> &ToolResult<Vec<Content>> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
impl std::ops::Deref for ToolResponseToolResult {
|
||||
type Target = ToolResult<Vec<Content>>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
impl From<Result<Vec<Content>, crate::types::core::ToolError>> for ToolResponseToolResult {
|
||||
fn from(res: Result<Vec<Content>, crate::types::core::ToolError>) -> Self {
|
||||
ToolResponseToolResult(res)
|
||||
}
|
||||
}
|
||||
|
||||
// — Register the newtypes with UniFFI, converting via JSON strings —
|
||||
// UniFFI’s FFI layer supports only primitive buffers (here String), so we JSON-serialize
|
||||
// through our `tool_result_serde` to preserve the same success/error schema on both sides.
|
||||
|
||||
uniffi::custom_type!(ToolRequestToolCall, String, {
|
||||
lower: |obj| {
|
||||
serde_json::to_string(&obj.0).unwrap()
|
||||
},
|
||||
try_lift: |val| {
|
||||
Ok(serde_json::from_str(&val).unwrap() )
|
||||
},
|
||||
});
|
||||
|
||||
uniffi::custom_type!(ToolResponseToolResult, String, {
|
||||
lower: |obj| {
|
||||
serde_json::to_string(&obj.0).unwrap()
|
||||
},
|
||||
try_lift: |val| {
|
||||
Ok(serde_json::from_str(&val).unwrap() )
|
||||
},
|
||||
});
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolRequest {
|
||||
pub id: String,
|
||||
pub tool_call: ToolRequestToolCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolResponse {
|
||||
pub id: String,
|
||||
pub tool_result: ToolResponseToolResult,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
|
||||
pub struct ThinkingContent {
|
||||
pub thinking: String,
|
||||
pub signature: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
|
||||
pub struct RedactedThinkingContent {
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
|
||||
/// Content passed inside a message, which can be both simple content and tool content
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
pub enum MessageContent {
|
||||
Text(TextContent),
|
||||
Image(ImageContent),
|
||||
ToolReq(ToolRequest),
|
||||
ToolResp(ToolResponse),
|
||||
Thinking(ThinkingContent),
|
||||
RedactedThinking(RedactedThinkingContent),
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn text<S: Into<String>>(text: S) -> Self {
|
||||
MessageContent::Text(TextContent { text: text.into() })
|
||||
}
|
||||
|
||||
pub fn image<S: Into<String>, T: Into<String>>(data: S, mime_type: T) -> Self {
|
||||
MessageContent::Image(ImageContent {
|
||||
data: data.into(),
|
||||
mime_type: mime_type.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn tool_request<S: Into<String>>(id: S, tool_call: ToolRequestToolCall) -> Self {
|
||||
MessageContent::ToolReq(ToolRequest {
|
||||
id: id.into(),
|
||||
tool_call,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn tool_response<S: Into<String>>(id: S, tool_result: ToolResponseToolResult) -> Self {
|
||||
MessageContent::ToolResp(ToolResponse {
|
||||
id: id.into(),
|
||||
tool_result,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn thinking<S1: Into<String>, S2: Into<String>>(thinking: S1, signature: S2) -> Self {
|
||||
MessageContent::Thinking(ThinkingContent {
|
||||
thinking: thinking.into(),
|
||||
signature: signature.into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn redacted_thinking<S: Into<String>>(data: S) -> Self {
|
||||
MessageContent::RedactedThinking(RedactedThinkingContent { data: data.into() })
|
||||
}
|
||||
|
||||
pub fn as_tool_request(&self) -> Option<&ToolRequest> {
|
||||
if let MessageContent::ToolReq(ref tool_request) = self {
|
||||
Some(tool_request)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tool_response(&self) -> Option<&ToolResponse> {
|
||||
if let MessageContent::ToolResp(ref tool_response) = self {
|
||||
Some(tool_response)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tool_response_text(&self) -> Option<String> {
|
||||
if let Some(tool_response) = self.as_tool_response() {
|
||||
if let Ok(contents) = &tool_response.tool_result.0 {
|
||||
let texts: Vec<String> = contents
|
||||
.iter()
|
||||
.filter_map(|content| content.as_text().map(String::from))
|
||||
.collect();
|
||||
if !texts.is_empty() {
|
||||
return Some(texts.join("\n"));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn as_tool_request_id(&self) -> Option<&str> {
|
||||
if let Self::ToolReq(r) = self {
|
||||
Some(&r.id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tool_response_id(&self) -> Option<&str> {
|
||||
if let Self::ToolResp(r) = self {
|
||||
Some(&r.id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the text content if this is a TextContent variant
|
||||
pub fn as_text(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessageContent::Text(text) => Some(&text.text),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the thinking content if this is a ThinkingContent variant
|
||||
pub fn as_thinking(&self) -> Option<&ThinkingContent> {
|
||||
match self {
|
||||
MessageContent::Thinking(thinking) => Some(thinking),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the redacted thinking content if this is a RedactedThinkingContent variant
|
||||
pub fn as_redacted_thinking(&self) -> Option<&RedactedThinkingContent> {
|
||||
match self {
|
||||
MessageContent::RedactedThinking(redacted) => Some(redacted),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_text(&self) -> bool {
|
||||
matches!(self, Self::Text(_))
|
||||
}
|
||||
pub fn is_image(&self) -> bool {
|
||||
matches!(self, Self::Image(_))
|
||||
}
|
||||
pub fn is_tool_request(&self) -> bool {
|
||||
matches!(self, Self::ToolReq(_))
|
||||
}
|
||||
pub fn is_tool_response(&self) -> bool {
|
||||
matches!(self, Self::ToolResp(_))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Content> for MessageContent {
|
||||
fn from(content: Content) -> Self {
|
||||
match content {
|
||||
Content::Text(text) => MessageContent::Text(text),
|
||||
Content::Image(image) => MessageContent::Image(image),
|
||||
}
|
||||
}
|
||||
}
|
||||
284
crates/goose-llm/src/message/mod.rs
Normal file
284
crates/goose-llm/src/message/mod.rs
Normal file
@@ -0,0 +1,284 @@
|
||||
//! Messages which represent the content sent back and forth to LLM provider
|
||||
//!
|
||||
//! We use these messages in the agent code, and interfaces which interact with
|
||||
//! the agent. That let's us reuse message histories across different interfaces.
|
||||
//!
|
||||
//! The content of the messages uses MCP types to avoid additional conversions
|
||||
//! when interacting with MCP servers.
|
||||
|
||||
mod contents;
|
||||
mod message_content;
|
||||
mod tool_result_serde;
|
||||
|
||||
pub use contents::Contents;
|
||||
pub use message_content::{
|
||||
MessageContent, RedactedThinkingContent, ThinkingContent, ToolRequest, ToolRequestToolCall,
|
||||
ToolResponse, ToolResponseToolResult,
|
||||
};
|
||||
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::types::core::Role;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
|
||||
/// A message to or from an LLM
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub created: i64,
|
||||
pub content: Contents,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(role: Role) -> Self {
|
||||
Self {
|
||||
role,
|
||||
created: Utc::now().timestamp_millis(),
|
||||
content: Contents::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new user message with the current timestamp
|
||||
pub fn user() -> Self {
|
||||
Self::new(Role::User)
|
||||
}
|
||||
|
||||
/// Create a new assistant message with the current timestamp
|
||||
pub fn assistant() -> Self {
|
||||
Self::new(Role::Assistant)
|
||||
}
|
||||
|
||||
/// Add any item that implements Into<MessageContent> to the message
|
||||
pub fn with_content(mut self, item: impl Into<MessageContent>) -> Self {
|
||||
self.content.push(item);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add text content to the message
|
||||
pub fn with_text<S: Into<String>>(self, text: S) -> Self {
|
||||
self.with_content(MessageContent::text(text))
|
||||
}
|
||||
|
||||
/// Add image content to the message
|
||||
pub fn with_image<S: Into<String>, T: Into<String>>(self, data: S, mime_type: T) -> Self {
|
||||
self.with_content(MessageContent::image(data, mime_type))
|
||||
}
|
||||
|
||||
/// Add a tool request to the message
|
||||
pub fn with_tool_request<S: Into<String>, T: Into<ToolRequestToolCall>>(
|
||||
self,
|
||||
id: S,
|
||||
tool_call: T,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::tool_request(id, tool_call.into()))
|
||||
}
|
||||
|
||||
/// Add a tool response to the message
|
||||
pub fn with_tool_response<S: Into<String>>(
|
||||
self,
|
||||
id: S,
|
||||
result: ToolResponseToolResult,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::tool_response(id, result))
|
||||
}
|
||||
|
||||
/// Add thinking content to the message
|
||||
pub fn with_thinking<S1: Into<String>, S2: Into<String>>(
|
||||
self,
|
||||
thinking: S1,
|
||||
signature: S2,
|
||||
) -> Self {
|
||||
self.with_content(MessageContent::thinking(thinking, signature))
|
||||
}
|
||||
|
||||
/// Add redacted thinking content to the message
|
||||
pub fn with_redacted_thinking<S: Into<String>>(self, data: S) -> Self {
|
||||
self.with_content(MessageContent::redacted_thinking(data))
|
||||
}
|
||||
|
||||
/// Check if the message is a tool call
|
||||
pub fn contains_tool_call(&self) -> bool {
|
||||
self.content.any_is(MessageContent::is_tool_request)
|
||||
}
|
||||
|
||||
/// Check if the message is a tool response
|
||||
pub fn contains_tool_response(&self) -> bool {
|
||||
self.content.any_is(MessageContent::is_tool_response)
|
||||
}
|
||||
|
||||
/// Check if the message contains only text content
|
||||
pub fn has_only_text_content(&self) -> bool {
|
||||
self.content.all_are(MessageContent::is_text)
|
||||
}
|
||||
|
||||
/// Retrieves all tool `id` from ToolRequest messages
|
||||
pub fn tool_request_ids(&self) -> HashSet<&str> {
|
||||
self.content
|
||||
.iter()
|
||||
.filter_map(MessageContent::as_tool_request_id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Retrieves all tool `id` from ToolResponse messages
|
||||
pub fn tool_response_ids(&self) -> HashSet<&str> {
|
||||
self.content
|
||||
.iter()
|
||||
.filter_map(MessageContent::as_tool_response_id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Retrieves all tool `id` from the message
|
||||
pub fn tool_ids(&self) -> HashSet<&str> {
|
||||
self.tool_request_ids()
|
||||
.into_iter()
|
||||
.chain(self.tool_response_ids())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::*;
|
||||
use crate::types::core::{ToolCall, ToolError};
|
||||
|
||||
#[test]
|
||||
fn test_message_serialization() {
|
||||
let message = Message::assistant()
|
||||
.with_text("Hello, I'll help you with that.")
|
||||
.with_tool_request(
|
||||
"tool123",
|
||||
Ok(ToolCall::new("test_tool", json!({"param": "value"})).into()),
|
||||
);
|
||||
|
||||
let json_str = serde_json::to_string_pretty(&message).unwrap();
|
||||
println!("Serialized message: {}", json_str);
|
||||
|
||||
// Parse back to Value to check structure
|
||||
let value: Value = serde_json::from_str(&json_str).unwrap();
|
||||
println!(
|
||||
"Read back serialized message: {}",
|
||||
serde_json::to_string_pretty(&value).unwrap()
|
||||
);
|
||||
|
||||
// Check top-level fields
|
||||
assert_eq!(value["role"], "assistant");
|
||||
assert!(value["created"].is_i64());
|
||||
assert!(value["content"].is_array());
|
||||
|
||||
// Check content items
|
||||
let content = &value["content"];
|
||||
|
||||
// First item should be text
|
||||
assert_eq!(content[0]["type"], "text");
|
||||
assert_eq!(content[0]["text"], "Hello, I'll help you with that.");
|
||||
|
||||
// Second item should be toolRequest
|
||||
assert_eq!(content[1]["type"], "toolReq");
|
||||
assert_eq!(content[1]["id"], "tool123");
|
||||
|
||||
// Check tool_call serialization
|
||||
assert_eq!(content[1]["toolCall"]["status"], "success");
|
||||
assert_eq!(content[1]["toolCall"]["value"]["name"], "test_tool");
|
||||
assert_eq!(
|
||||
content[1]["toolCall"]["value"]["arguments"]["param"],
|
||||
"value"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_serialization() {
|
||||
let message = Message::assistant().with_tool_request(
|
||||
"tool123",
|
||||
Err(ToolError::ExecutionError(
|
||||
"Something went wrong".to_string(),
|
||||
)),
|
||||
);
|
||||
|
||||
let json_str = serde_json::to_string_pretty(&message).unwrap();
|
||||
println!("Serialized error: {}", json_str);
|
||||
|
||||
// Parse back to Value to check structure
|
||||
let value: Value = serde_json::from_str(&json_str).unwrap();
|
||||
|
||||
// Check tool_call serialization with error
|
||||
let tool_call = &value["content"][0]["toolCall"];
|
||||
assert_eq!(tool_call["status"], "error");
|
||||
assert_eq!(tool_call["error"], "Execution failed: Something went wrong");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialization() {
|
||||
// Create a JSON string with our new format
|
||||
let json_str = r#"{
|
||||
"role": "assistant",
|
||||
"created": 1740171566,
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "I'll help you with that."
|
||||
},
|
||||
{
|
||||
"type": "toolReq",
|
||||
"id": "tool123",
|
||||
"toolCall": {
|
||||
"status": "success",
|
||||
"value": {
|
||||
"name": "test_tool",
|
||||
"arguments": {"param": "value"},
|
||||
"needsApproval": false
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let message: Message = serde_json::from_str(json_str).unwrap();
|
||||
|
||||
assert_eq!(message.role, Role::Assistant);
|
||||
assert_eq!(message.created, 1740171566);
|
||||
assert_eq!(message.content.len(), 2);
|
||||
|
||||
// Check first content item
|
||||
if let MessageContent::Text(text) = &message.content[0] {
|
||||
assert_eq!(text.text, "I'll help you with that.");
|
||||
} else {
|
||||
panic!("Expected Text content");
|
||||
}
|
||||
|
||||
// Check second content item
|
||||
if let MessageContent::ToolReq(req) = &message.content[1] {
|
||||
assert_eq!(req.id, "tool123");
|
||||
if let Ok(tool_call) = req.tool_call.as_result() {
|
||||
assert_eq!(tool_call.name, "test_tool");
|
||||
assert_eq!(tool_call.arguments, json!({"param": "value"}));
|
||||
} else {
|
||||
panic!("Expected successful tool call");
|
||||
}
|
||||
} else {
|
||||
panic!("Expected ToolRequest content");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_with_text() {
|
||||
let message = Message::user().with_text("Hello");
|
||||
assert_eq!(message.content.concat_text_str(), "Hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_with_tool_request() {
|
||||
let tool_call = Ok(ToolCall::new("test_tool", json!({})));
|
||||
|
||||
let message = Message::assistant().with_tool_request("req1", tool_call);
|
||||
assert!(message.contains_tool_call());
|
||||
assert!(!message.contains_tool_response());
|
||||
|
||||
let ids = message.tool_ids();
|
||||
assert_eq!(ids.len(), 1);
|
||||
assert!(ids.contains("req1"));
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,14 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const DEFAULT_CONTEXT_LIMIT: usize = 128_000;
|
||||
const DEFAULT_CONTEXT_LIMIT: u32 = 128_000;
|
||||
|
||||
/// Configuration for model-specific settings and limits
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
|
||||
pub struct ModelConfig {
|
||||
/// The name of the model to use
|
||||
pub model_name: String,
|
||||
/// Optional explicit context limit that overrides any defaults
|
||||
pub context_limit: Option<usize>,
|
||||
pub context_limit: Option<u32>,
|
||||
/// Optional temperature setting (0.0 - 1.0)
|
||||
pub temperature: Option<f32>,
|
||||
/// Optional maximum tokens to generate
|
||||
@@ -34,7 +34,7 @@ impl ModelConfig {
|
||||
}
|
||||
|
||||
/// Get model-specific context limit based on model name
|
||||
fn get_model_specific_limit(model_name: &str) -> Option<usize> {
|
||||
fn get_model_specific_limit(model_name: &str) -> Option<u32> {
|
||||
// Implement some sensible defaults
|
||||
match model_name {
|
||||
// OpenAI models, https://platform.openai.com/docs/models#models-overview
|
||||
@@ -52,7 +52,7 @@ impl ModelConfig {
|
||||
}
|
||||
|
||||
/// Set an explicit context limit
|
||||
pub fn with_context_limit(mut self, limit: Option<usize>) -> Self {
|
||||
pub fn with_context_limit(mut self, limit: Option<u32>) -> Self {
|
||||
// Default is None and therefore DEFAULT_CONTEXT_LIMIT, only set
|
||||
// if input is Some to allow passing through with_context_limit in
|
||||
// configuration cases
|
||||
@@ -76,7 +76,7 @@ impl ModelConfig {
|
||||
|
||||
/// Get the context_limit for the current model
|
||||
/// If none are defined, use the DEFAULT_CONTEXT_LIMIT
|
||||
pub fn context_limit(&self) -> usize {
|
||||
pub fn context_limit(&self) -> u32 {
|
||||
self.context_limit.unwrap_or(DEFAULT_CONTEXT_LIMIT)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
|
||||
use super::errors::ProviderError;
|
||||
use crate::{message::Message, types::core::Tool};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize, uniffi::Record)]
|
||||
pub struct Usage {
|
||||
pub input_tokens: Option<i32>,
|
||||
pub output_tokens: Option<i32>,
|
||||
@@ -26,7 +26,7 @@ impl Usage {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, uniffi::Record)]
|
||||
pub struct ProviderCompleteResponse {
|
||||
pub message: Message,
|
||||
pub model: String,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[derive(Error, Debug, uniffi::Error)]
|
||||
pub enum ProviderError {
|
||||
#[error("Authentication error: {0}")]
|
||||
Authentication(String),
|
||||
|
||||
@@ -83,9 +83,9 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
]
|
||||
}));
|
||||
}
|
||||
MessageContent::ToolRequest(request) => {
|
||||
MessageContent::ToolReq(request) => {
|
||||
has_tool_calls = true;
|
||||
match &request.tool_call {
|
||||
match &request.tool_call.as_result() {
|
||||
Ok(tool_call) => {
|
||||
let sanitized_name = sanitize_function_name(&tool_call.name);
|
||||
|
||||
@@ -114,8 +114,8 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
}
|
||||
}
|
||||
}
|
||||
MessageContent::ToolResponse(response) => {
|
||||
match &response.tool_result {
|
||||
MessageContent::ToolResp(response) => {
|
||||
match &response.tool_result.0 {
|
||||
Ok(contents) => {
|
||||
// Process all content, replacing images with placeholder text
|
||||
let mut tool_content = Vec::new();
|
||||
@@ -300,13 +300,13 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
|
||||
"The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+",
|
||||
function_name
|
||||
));
|
||||
content.push(MessageContent::tool_request(id, Err(error)));
|
||||
content.push(MessageContent::tool_request(id, Err(error).into()));
|
||||
} else {
|
||||
match serde_json::from_str::<Value>(&arguments) {
|
||||
Ok(params) => {
|
||||
content.push(MessageContent::tool_request(
|
||||
id,
|
||||
Ok(ToolCall::new(&function_name, params)),
|
||||
Ok(ToolCall::new(&function_name, params)).into(),
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -314,7 +314,7 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
|
||||
"Could not interpret tool use parameters for id {}: {}",
|
||||
id, e
|
||||
));
|
||||
content.push(MessageContent::tool_request(id, Err(error)));
|
||||
content.push(MessageContent::tool_request(id, Err(error).into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -681,19 +681,20 @@ mod tests {
|
||||
Message::user().with_text("How are you?"),
|
||||
Message::assistant().with_tool_request(
|
||||
"tool1",
|
||||
Ok(ToolCall::new("example", json!({"param1": "value1"}))),
|
||||
Ok(ToolCall::new("example", json!({"param1": "value1"})).into()),
|
||||
),
|
||||
];
|
||||
|
||||
// Get the ID from the tool request to use in the response
|
||||
let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] {
|
||||
let tool_id = if let MessageContent::ToolReq(request) = &messages[2].content[0] {
|
||||
request.id.clone()
|
||||
} else {
|
||||
panic!("should be tool request");
|
||||
};
|
||||
|
||||
messages
|
||||
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
|
||||
messages.push(
|
||||
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
|
||||
);
|
||||
|
||||
let spec = format_messages(&messages, &ImageFormat::OpenAi);
|
||||
|
||||
@@ -719,14 +720,15 @@ mod tests {
|
||||
)];
|
||||
|
||||
// Get the ID from the tool request to use in the response
|
||||
let tool_id = if let MessageContent::ToolRequest(request) = &messages[0].content[0] {
|
||||
let tool_id = if let MessageContent::ToolReq(request) = &messages[0].content[0] {
|
||||
request.id.clone()
|
||||
} else {
|
||||
panic!("should be tool request");
|
||||
};
|
||||
|
||||
messages
|
||||
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
|
||||
messages.push(
|
||||
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
|
||||
);
|
||||
|
||||
let spec = format_messages(&messages, &ImageFormat::OpenAi);
|
||||
|
||||
@@ -857,7 +859,7 @@ mod tests {
|
||||
let message = response_to_message(response)?;
|
||||
|
||||
assert_eq!(message.content.len(), 1);
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
if let MessageContent::ToolReq(request) = &message.content[0] {
|
||||
let tool_call = request.tool_call.as_ref().unwrap();
|
||||
assert_eq!(tool_call.name, "example_fn");
|
||||
assert_eq!(tool_call.arguments, json!({"param": "value"}));
|
||||
@@ -876,8 +878,8 @@ mod tests {
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
match &request.tool_call {
|
||||
if let MessageContent::ToolReq(request) = &message.content[0] {
|
||||
match &request.tool_call.as_result() {
|
||||
Err(ToolError::NotFound(msg)) => {
|
||||
assert!(msg.starts_with("The provided function name"));
|
||||
}
|
||||
@@ -898,8 +900,8 @@ mod tests {
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
match &request.tool_call {
|
||||
if let MessageContent::ToolReq(request) = &message.content[0] {
|
||||
match &request.tool_call.as_result() {
|
||||
Err(ToolError::InvalidParameters(msg)) => {
|
||||
assert!(msg.starts_with("Could not interpret tool use parameters"));
|
||||
}
|
||||
@@ -920,7 +922,7 @@ mod tests {
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
if let MessageContent::ToolReq(request) = &message.content[0] {
|
||||
let tool_call = request.tool_call.as_ref().unwrap();
|
||||
assert_eq!(tool_call.name, "example_fn");
|
||||
assert_eq!(tool_call.arguments, json!({}));
|
||||
|
||||
@@ -56,7 +56,7 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
// Redacted thinking blocks are not directly used in OpenAI format
|
||||
continue;
|
||||
}
|
||||
MessageContent::ToolRequest(request) => match &request.tool_call {
|
||||
MessageContent::ToolReq(request) => match &request.tool_call.as_result() {
|
||||
Ok(tool_call) => {
|
||||
let sanitized_name = sanitize_function_name(&tool_call.name);
|
||||
let tool_calls = converted
|
||||
@@ -82,8 +82,8 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
}));
|
||||
}
|
||||
},
|
||||
MessageContent::ToolResponse(response) => {
|
||||
match &response.tool_result {
|
||||
MessageContent::ToolResp(response) => {
|
||||
match &response.tool_result.0 {
|
||||
Ok(contents) => {
|
||||
// Process all content, replacing images with placeholder text
|
||||
let mut tool_content = Vec::new();
|
||||
@@ -210,13 +210,13 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
|
||||
"The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+",
|
||||
function_name
|
||||
));
|
||||
content.push(MessageContent::tool_request(id, Err(error)));
|
||||
content.push(MessageContent::tool_request(id, Err(error).into()));
|
||||
} else {
|
||||
match serde_json::from_str::<Value>(&arguments) {
|
||||
Ok(params) => {
|
||||
content.push(MessageContent::tool_request(
|
||||
id,
|
||||
Ok(ToolCall::new(&function_name, params)),
|
||||
Ok(ToolCall::new(&function_name, params)).into(),
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -224,7 +224,7 @@ pub fn response_to_message(response: Value) -> anyhow::Result<Message> {
|
||||
"Could not interpret tool use parameters for id {}: {}",
|
||||
id, e
|
||||
));
|
||||
content.push(MessageContent::tool_request(id, Err(error)));
|
||||
content.push(MessageContent::tool_request(id, Err(error).into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -559,14 +559,15 @@ mod tests {
|
||||
];
|
||||
|
||||
// Get the ID from the tool request to use in the response
|
||||
let tool_id = if let MessageContent::ToolRequest(request) = &messages[2].content[0] {
|
||||
let tool_id = if let MessageContent::ToolReq(request) = &messages[2].content[0] {
|
||||
request.id.clone()
|
||||
} else {
|
||||
panic!("should be tool request");
|
||||
};
|
||||
|
||||
messages
|
||||
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
|
||||
messages.push(
|
||||
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
|
||||
);
|
||||
|
||||
let spec = format_messages(&messages, &ImageFormat::OpenAi);
|
||||
|
||||
@@ -592,14 +593,15 @@ mod tests {
|
||||
)];
|
||||
|
||||
// Get the ID from the tool request to use in the response
|
||||
let tool_id = if let MessageContent::ToolRequest(request) = &messages[0].content[0] {
|
||||
let tool_id = if let MessageContent::ToolReq(request) = &messages[0].content[0] {
|
||||
request.id.clone()
|
||||
} else {
|
||||
panic!("should be tool request");
|
||||
};
|
||||
|
||||
messages
|
||||
.push(Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")])));
|
||||
messages.push(
|
||||
Message::user().with_tool_response(tool_id, Ok(vec![Content::text("Result")]).into()),
|
||||
);
|
||||
|
||||
let spec = format_messages(&messages, &ImageFormat::OpenAi);
|
||||
|
||||
@@ -730,7 +732,7 @@ mod tests {
|
||||
let message = response_to_message(response)?;
|
||||
|
||||
assert_eq!(message.content.len(), 1);
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
if let MessageContent::ToolReq(request) = &message.content[0] {
|
||||
let tool_call = request.tool_call.as_ref().unwrap();
|
||||
assert_eq!(tool_call.name, "example_fn");
|
||||
assert_eq!(tool_call.arguments, json!({"param": "value"}));
|
||||
@@ -749,8 +751,8 @@ mod tests {
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
match &request.tool_call {
|
||||
if let MessageContent::ToolReq(request) = &message.content[0] {
|
||||
match &request.tool_call.as_result() {
|
||||
Err(ToolError::NotFound(msg)) => {
|
||||
assert!(msg.starts_with("The provided function name"));
|
||||
}
|
||||
@@ -771,8 +773,8 @@ mod tests {
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
match &request.tool_call {
|
||||
if let MessageContent::ToolReq(request) = &message.content[0] {
|
||||
match &request.tool_call.as_result() {
|
||||
Err(ToolError::InvalidParameters(msg)) => {
|
||||
assert!(msg.starts_with("Could not interpret tool use parameters"));
|
||||
}
|
||||
@@ -793,7 +795,7 @@ mod tests {
|
||||
|
||||
let message = response_to_message(response)?;
|
||||
|
||||
if let MessageContent::ToolRequest(request) = &message.content[0] {
|
||||
if let MessageContent::ToolReq(request) = &message.content[0] {
|
||||
let tool_call = request.tool_call.as_ref().unwrap();
|
||||
assert_eq!(tool_call.name, "example_fn");
|
||||
assert_eq!(tool_call.arguments, json!({}));
|
||||
|
||||
@@ -7,36 +7,24 @@ use thiserror::Error;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::types::json_value_ffi::JsonValueFfi;
|
||||
use crate::{message::Message, providers::Usage};
|
||||
use crate::{model::ModelConfig, providers::errors::ProviderError};
|
||||
|
||||
pub struct CompletionRequest<'a> {
|
||||
pub provider_name: &'a str,
|
||||
// Lifetimes are not supported in Uniffi, cause other languages don't have them
|
||||
// https://github.com/mozilla/uniffi-rs/issues/1526#issuecomment-1528851837
|
||||
#[derive(uniffi::Record)]
|
||||
pub struct CompletionRequest {
|
||||
pub provider_name: String,
|
||||
pub model_config: ModelConfig,
|
||||
pub system_preamble: &'a str,
|
||||
pub messages: &'a [Message],
|
||||
pub extensions: &'a [ExtensionConfig],
|
||||
pub system_preamble: String,
|
||||
pub messages: Vec<Message>,
|
||||
pub extensions: Vec<ExtensionConfig>,
|
||||
}
|
||||
|
||||
impl<'a> CompletionRequest<'a> {
|
||||
pub fn new(
|
||||
provider_name: &'a str,
|
||||
model_config: ModelConfig,
|
||||
system_preamble: &'a str,
|
||||
messages: &'a [Message],
|
||||
extensions: &'a [ExtensionConfig],
|
||||
) -> Self {
|
||||
Self {
|
||||
provider_name,
|
||||
model_config,
|
||||
system_preamble,
|
||||
messages,
|
||||
extensions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
// https://mozilla.github.io/uniffi-rs/latest/proc_macro/errors.html
|
||||
#[derive(Debug, Error, uniffi::Error)]
|
||||
#[uniffi(flat_error)]
|
||||
pub enum CompletionError {
|
||||
#[error("failed to create provider: {0}")]
|
||||
UnknownProvider(String),
|
||||
@@ -54,7 +42,7 @@ pub enum CompletionError {
|
||||
ToolNotFound(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
|
||||
pub struct CompletionResponse {
|
||||
pub message: Message,
|
||||
pub model: String,
|
||||
@@ -78,35 +66,35 @@ impl CompletionResponse {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, uniffi::Record)]
|
||||
pub struct RuntimeMetrics {
|
||||
pub total_time_ms: u128,
|
||||
pub total_time_ms_provider: u128,
|
||||
pub total_time_sec: f32,
|
||||
pub total_time_sec_provider: f32,
|
||||
pub tokens_per_second: Option<f64>,
|
||||
}
|
||||
|
||||
impl RuntimeMetrics {
|
||||
pub fn new(
|
||||
total_time_ms: u128,
|
||||
total_time_ms_provider: u128,
|
||||
total_time_sec: f32,
|
||||
total_time_sec_provider: f32,
|
||||
tokens_per_second: Option<f64>,
|
||||
) -> Self {
|
||||
Self {
|
||||
total_time_ms,
|
||||
total_time_ms_provider,
|
||||
total_time_sec,
|
||||
total_time_sec_provider,
|
||||
tokens_per_second,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
|
||||
pub enum ToolApprovalMode {
|
||||
Auto,
|
||||
Manual,
|
||||
Smart,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ToolConfig {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
@@ -140,7 +128,28 @@ impl ToolConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[uniffi::export]
|
||||
pub fn create_tool_config(
|
||||
name: &str,
|
||||
description: &str,
|
||||
input_schema: JsonValueFfi,
|
||||
approval_mode: ToolApprovalMode,
|
||||
) -> ToolConfig {
|
||||
ToolConfig::new(name, description, input_schema.into(), approval_mode)
|
||||
}
|
||||
|
||||
uniffi::custom_type!(ToolConfig, String, {
|
||||
lower: |tc: &ToolConfig| {
|
||||
serde_json::to_string(&tc).unwrap()
|
||||
},
|
||||
try_lift: |s: String| {
|
||||
Ok(serde_json::from_str(&s).unwrap())
|
||||
},
|
||||
});
|
||||
|
||||
// — Register the newtypes with UniFFI, converting via JSON strings —
|
||||
|
||||
#[derive(Debug, Clone, Serialize, uniffi::Record)]
|
||||
pub struct ExtensionConfig {
|
||||
name: String,
|
||||
instructions: Option<String>,
|
||||
|
||||
@@ -4,14 +4,14 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Enum)]
|
||||
#[serde(tag = "type", rename_all = "camelCase")]
|
||||
pub enum Content {
|
||||
Text(TextContent),
|
||||
@@ -47,13 +47,13 @@ impl Content {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TextContent {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, uniffi::Record)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ImageContent {
|
||||
pub data: String,
|
||||
@@ -116,7 +116,7 @@ impl ToolCall {
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq)]
|
||||
#[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq, uniffi::Error)]
|
||||
pub enum ToolError {
|
||||
#[error("Invalid parameters: {0}")]
|
||||
InvalidParameters(String),
|
||||
|
||||
84
crates/goose-llm/src/types/json_value_ffi.rs
Normal file
84
crates/goose-llm/src/types/json_value_ffi.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// `serde_json::Value` gets converted to a `String` to pass across the FFI.
|
||||
// https://github.com/mozilla/uniffi-rs/blob/main/docs/manual/src/types/custom_types.md?plain=1
|
||||
// https://github.com/mozilla/uniffi-rs/blob/c7f6caa3d1bf20f934346cefd8e82b5093f0dc6f/examples/custom-types/src/lib.rs#L63-L69
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct JsonValueFfi(Value);
|
||||
|
||||
impl From<JsonValueFfi> for Value {
|
||||
fn from(val: JsonValueFfi) -> Self {
|
||||
val.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Value> for JsonValueFfi {
|
||||
fn from(val: Value) -> Self {
|
||||
JsonValueFfi(val)
|
||||
}
|
||||
}
|
||||
|
||||
uniffi::custom_type!(JsonValueFfi, String, {
|
||||
lower: |obj| {
|
||||
serde_json::to_string(&obj.0).unwrap()
|
||||
},
|
||||
try_lift: |val| {
|
||||
Ok(serde_json::from_str(&val).unwrap() )
|
||||
},
|
||||
});
|
||||
|
||||
// Write some tests to ensure that the conversion works as expected
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_conversion() {
|
||||
let original = JsonValueFfi(json!({"key": "value"}));
|
||||
let serialized = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: JsonValueFfi = serde_json::from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(original.0, deserialized.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_to_serde() {
|
||||
let original = JsonValueFfi(json!({"key": "value"}));
|
||||
let value: Value = original.into();
|
||||
assert_eq!(value, json!({"key": "value"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_from_serde() {
|
||||
let value = json!({"key": "value"});
|
||||
let original: JsonValueFfi = value.into();
|
||||
assert_eq!(original.0, json!({"key": "value"}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_lower() {
|
||||
let original = JsonValueFfi(json!({"key": "value"}));
|
||||
let serialized = serde_json::to_string(&original).unwrap();
|
||||
|
||||
assert_eq!(serialized, "{\"key\":\"value\"}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_try_lift() {
|
||||
let json_str = "{\"key\":\"value\"}";
|
||||
let deserialized: JsonValueFfi = serde_json::from_str(json_str).unwrap();
|
||||
let expected = JsonValueFfi(json!({"key": "value"}));
|
||||
assert_eq!(deserialized.0, expected.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_value_ffi_custom_type() {
|
||||
let json_str = "{\"key\":\"value\"}";
|
||||
let deserialized: JsonValueFfi = serde_json::from_str(json_str).unwrap();
|
||||
let serialized = serde_json::to_string(&deserialized).unwrap();
|
||||
assert_eq!(serialized, json_str);
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,3 @@
|
||||
pub mod completion;
|
||||
pub mod core;
|
||||
pub mod json_value_ffi;
|
||||
|
||||
@@ -45,14 +45,14 @@ async fn test_generate_tooltip_with_tools() -> Result<(), ProviderError> {
|
||||
let mut tool_req_msg = Message::assistant();
|
||||
let req = ToolRequest {
|
||||
id: "1".to_string(),
|
||||
tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))),
|
||||
tool_call: Ok(ToolCall::new("get_time", json!({"timezone": "UTC"}))).into(),
|
||||
};
|
||||
tool_req_msg.content.push(MessageContent::ToolRequest(req));
|
||||
tool_req_msg.content.push(MessageContent::ToolReq(req));
|
||||
|
||||
// 2) User message with the tool response
|
||||
let tool_resp_msg = Message::user().with_tool_response(
|
||||
"1",
|
||||
Ok(vec![Content::text("The current time is 12:00 UTC")]),
|
||||
Ok(vec![Content::text("The current time is 12:00 UTC")]).into(),
|
||||
);
|
||||
|
||||
let messages = vec![tool_req_msg, tool_resp_msg];
|
||||
|
||||
@@ -147,7 +147,7 @@ impl ProviderTester {
|
||||
.message
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| matches!(content, MessageContent::ToolRequest(_))),
|
||||
.any(|content| matches!(content, MessageContent::ToolReq(_))),
|
||||
"Expected tool request in response"
|
||||
);
|
||||
|
||||
@@ -171,7 +171,8 @@ impl ProviderTester {
|
||||
Weather
|
||||
Saturday 9:00 PM
|
||||
Clear",
|
||||
)]),
|
||||
)])
|
||||
.into(),
|
||||
);
|
||||
|
||||
// Verify we construct a valid payload including the request/response pair for the next inference
|
||||
|
||||
3
crates/goose-llm/uniffi-bindgen.rs
Normal file
3
crates/goose-llm/uniffi-bindgen.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
fn main() {
|
||||
uniffi::uniffi_bindgen_main()
|
||||
}
|
||||
Reference in New Issue
Block a user