diff --git a/bindings/kotlin/example/Usage.kt b/bindings/kotlin/example/Usage.kt index 7c3fc4a3..99089515 100644 --- a/bindings/kotlin/example/Usage.kt +++ b/bindings/kotlin/example/Usage.kt @@ -130,8 +130,8 @@ fun main() = runBlocking { providerConfig, modelConfig, systemPreamble, - msgs, - extensions + messages = msgs, + extensions = extensions ) val response = completion(req) @@ -140,6 +140,32 @@ fun main() = runBlocking { // ---- UI Extraction (custom schema) ---- runUiExtraction(providerName, providerConfig) + + + // --- Prompt Override --- + val prompt_req = createCompletionRequest( + providerName, + providerConfig, + modelConfig, + systemPreamble = null, + systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.", + messages=listOf( + Message( + role = Role.USER, + created = now, + content = listOf( + MessageContent.Text( + TextContent("What's your name?") + ) + ) + ) + ), + extensions=emptyList() + ) + + val prompt_resp = completion(prompt_req) + + println("\nPrompt Override Response:\n${prompt_resp.message}") } diff --git a/bindings/kotlin/uniffi/goose_llm/goose_llm.kt b/bindings/kotlin/uniffi/goose_llm/goose_llm.kt index b5b16337..76e60aaf 100644 --- a/bindings/kotlin/uniffi/goose_llm/goose_llm.kt +++ b/bindings/kotlin/uniffi/goose_llm/goose_llm.kt @@ -830,6 +830,7 @@ internal interface UniffiLib : Library { `providerConfig`: RustBuffer.ByValue, `modelConfig`: RustBuffer.ByValue, `systemPreamble`: RustBuffer.ByValue, + `systemPromptOverride`: RustBuffer.ByValue, `messages`: RustBuffer.ByValue, `extensions`: RustBuffer.ByValue, uniffi_out_err: UniffiRustCallStatus, @@ -1100,7 +1101,7 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) { if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } - if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 39068.toShort()) { + if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 50798.toShort()) { throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project") } if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) { @@ -2955,7 +2956,8 @@ fun `createCompletionRequest`( `providerName`: kotlin.String, `providerConfig`: Value, `modelConfig`: ModelConfig, - `systemPreamble`: kotlin.String, + `systemPreamble`: kotlin.String? = null, + `systemPromptOverride`: kotlin.String? = null, `messages`: List, `extensions`: List, ): CompletionRequest = @@ -2965,7 +2967,8 @@ fun `createCompletionRequest`( FfiConverterString.lower(`providerName`), FfiConverterTypeValue.lower(`providerConfig`), FfiConverterTypeModelConfig.lower(`modelConfig`), - FfiConverterString.lower(`systemPreamble`), + FfiConverterOptionalString.lower(`systemPreamble`), + FfiConverterOptionalString.lower(`systemPromptOverride`), FfiConverterSequenceTypeMessage.lower(`messages`), FfiConverterSequenceTypeExtensionConfig.lower(`extensions`), _status, diff --git a/crates/goose-llm/Cargo.toml b/crates/goose-llm/Cargo.toml index 2323d29e..be240600 100644 --- a/crates/goose-llm/Cargo.toml +++ b/crates/goose-llm/Cargo.toml @@ -59,3 +59,7 @@ path = "uniffi-bindgen.rs" [[example]] name = "simple" path = "examples/simple.rs" + +[[example]] +name = "prompt_override" +path = "examples/prompt_override.rs" diff --git a/crates/goose-llm/README.md b/crates/goose-llm/README.md index 08d806b8..f32b61c3 100644 --- a/crates/goose-llm/README.md +++ b/crates/goose-llm/README.md @@ -49,7 +49,7 @@ curl -O https://repo1.maven.org/maven2/net/java/dev/jna/jna/5.13.0/jna-5.13.0.ja popd ``` -To just create the Kotlin bindings: +To just create the Kotlin bindings (for MacOS): ```bash # run from project root directory @@ -58,6 +58,18 @@ cargo build -p goose-llm cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin ``` +Creating `libgoose_llm.so` for Linux distribution: + +Use `cross` to build for the specific target and then create the bindings: +``` +# x86-64 GNU/Linux (kGoose uses this) +rustup target add x86_64-unknown-linux-gnu +cross build --release --target x86_64-unknown-linux-gnu -p goose-llm + +# The goose_llm.kt bindings produced should be the same whether we use 'libgoose_llm.dylib' or 'libgoose_llm.so' +cross run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/x86_64-unknown-linux-gnu/release/libgoose_llm.so --language kotlin --out-dir bindings/kotlin +``` + #### Python -> Rust: generate bindings, run example diff --git a/crates/goose-llm/examples/prompt_override.rs b/crates/goose-llm/examples/prompt_override.rs new file mode 100644 index 00000000..3cebffc5 --- /dev/null +++ b/crates/goose-llm/examples/prompt_override.rs @@ -0,0 +1,48 @@ +use std::vec; + +use anyhow::Result; +use goose_llm::{ + completion, + types::completion::{CompletionRequest, CompletionResponse}, + Message, ModelConfig, +}; +use serde_json::json; + +#[tokio::main] +async fn main() -> Result<()> { + let provider = "databricks"; + let provider_config = json!({ + "host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"), + "token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"), + }); + // let model_name = "goose-gpt-4-1"; // parallel tool calls + let model_name = "claude-3-5-haiku"; + let model_config = ModelConfig::new(model_name.to_string()); + + let system_prompt_override = "You are a helpful assistant. Talk in the style of pirates."; + + for text in ["How was your day?"] { + println!("\n---------------\n"); + println!("User Input: {text}"); + let messages = vec![ + Message::user().with_text("Hi there!"), + Message::assistant().with_text("How can I help?"), + Message::user().with_text(text), + ]; + let completion_response: CompletionResponse = completion(CompletionRequest::new( + provider.to_string(), + provider_config.clone(), + model_config.clone(), + None, + Some(system_prompt_override.to_string()), + messages.clone(), + vec![], + )) + .await?; + // Print the response + println!("\nCompletion Response:"); + println!("{}", serde_json::to_string_pretty(&completion_response)?); + } + + Ok(()) +} diff --git a/crates/goose-llm/examples/simple.rs b/crates/goose-llm/examples/simple.rs index 9544f7ce..99ecea31 100644 --- a/crates/goose-llm/examples/simple.rs +++ b/crates/goose-llm/examples/simple.rs @@ -106,7 +106,8 @@ async fn main() -> Result<()> { provider.to_string(), provider_config.clone(), model_config.clone(), - system_preamble.to_string(), + Some(system_preamble.to_string()), + None, messages.clone(), extensions.clone(), )) diff --git a/crates/goose-llm/src/completion.rs b/crates/goose-llm/src/completion.rs index 2095df73..d39b1b8d 100644 --- a/crates/goose-llm/src/completion.rs +++ b/crates/goose-llm/src/completion.rs @@ -36,7 +36,11 @@ pub async fn completion(req: CompletionRequest) -> Result Result, + prompt_override: &Option, extensions: &[ExtensionConfig], ) -> Result { + // If both system_preamble and system_prompt_override are provided, then prompt_override takes precedence + // and we don't render the template using preamble and extensions. Just return the prompt_override as is. + if prompt_override.is_some() { + return Ok(prompt_override.clone().unwrap()); + } + + let system_preamble = { + if let Some(p) = preamble { + p + } else { + "You are a helpful assistant." + } + }; + let mut context: HashMap<&str, Value> = HashMap::new(); context.insert("system_preamble", Value::String(system_preamble.to_owned())); context.insert("extensions", serde_json::to_value(extensions)?); diff --git a/crates/goose-llm/src/types/completion.rs b/crates/goose-llm/src/types/completion.rs index def6a6d2..21e0bcd9 100644 --- a/crates/goose-llm/src/types/completion.rs +++ b/crates/goose-llm/src/types/completion.rs @@ -16,7 +16,8 @@ pub struct CompletionRequest { pub provider_name: String, pub provider_config: serde_json::Value, pub model_config: ModelConfig, - pub system_preamble: String, + pub system_preamble: Option, + pub system_prompt_override: Option, pub messages: Vec, pub extensions: Vec, } @@ -26,7 +27,8 @@ impl CompletionRequest { provider_name: String, provider_config: serde_json::Value, model_config: ModelConfig, - system_preamble: String, + system_preamble: Option, + system_prompt_override: Option, messages: Vec, extensions: Vec, ) -> Self { @@ -34,6 +36,7 @@ impl CompletionRequest { provider_name, provider_config, model_config, + system_prompt_override, system_preamble, messages, extensions, @@ -41,12 +44,13 @@ impl CompletionRequest { } } -#[uniffi::export] +#[uniffi::export(default(system_preamble = None, system_prompt_override = None))] pub fn create_completion_request( provider_name: &str, provider_config: JsonValueFfi, model_config: ModelConfig, - system_preamble: &str, + system_preamble: Option, + system_prompt_override: Option, messages: Vec, extensions: Vec, ) -> CompletionRequest { @@ -54,7 +58,8 @@ pub fn create_completion_request( provider_name.to_string(), provider_config, model_config, - system_preamble.to_string(), + system_preamble, + system_prompt_override, messages, extensions, )