[goose-llm] system prompt override (#2791)

Author: Salman Mohammed (committed by GitHub)
Date: 2025-06-05 12:51:51 -04:00
parent c04373fdb2
commit 0bcea3d130
8 changed files with 132 additions and 14 deletions

View File

@@ -130,8 +130,8 @@ fun main() = runBlocking {
        providerConfig,
        modelConfig,
        systemPreamble,
-        msgs,
-        extensions
+        messages = msgs,
+        extensions = extensions
    )

    val response = completion(req)
@@ -140,6 +140,32 @@ fun main() = runBlocking {
    // ---- UI Extraction (custom schema) ----
    runUiExtraction(providerName, providerConfig)

+    // --- Prompt Override ---
+    val prompt_req = createCompletionRequest(
+        providerName,
+        providerConfig,
+        modelConfig,
+        systemPreamble = null,
+        systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.",
+        messages=listOf(
+            Message(
+                role = Role.USER,
+                created = now,
+                content = listOf(
+                    MessageContent.Text(
+                        TextContent("What's your name?")
+                    )
+                )
+            )
+        ),
+        extensions=emptyList()
+    )
+
+    val prompt_resp = completion(prompt_req)
+    println("\nPrompt Override Response:\n${prompt_resp.message}")
}

View File

@@ -830,6 +830,7 @@ internal interface UniffiLib : Library {
        `providerConfig`: RustBuffer.ByValue,
        `modelConfig`: RustBuffer.ByValue,
        `systemPreamble`: RustBuffer.ByValue,
+        `systemPromptOverride`: RustBuffer.ByValue,
        `messages`: RustBuffer.ByValue,
        `extensions`: RustBuffer.ByValue,
        uniffi_out_err: UniffiRustCallStatus,
@@ -1100,7 +1101,7 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) {
    if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) {
        throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
    }
-    if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 39068.toShort()) {
+    if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 50798.toShort()) {
        throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
    }
    if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) {
@@ -2955,7 +2956,8 @@ fun `createCompletionRequest`(
    `providerName`: kotlin.String,
    `providerConfig`: Value,
    `modelConfig`: ModelConfig,
-    `systemPreamble`: kotlin.String,
+    `systemPreamble`: kotlin.String? = null,
+    `systemPromptOverride`: kotlin.String? = null,
    `messages`: List<Message>,
    `extensions`: List<ExtensionConfig>,
): CompletionRequest =
@@ -2965,7 +2967,8 @@ fun `createCompletionRequest`(
            FfiConverterString.lower(`providerName`),
            FfiConverterTypeValue.lower(`providerConfig`),
            FfiConverterTypeModelConfig.lower(`modelConfig`),
-            FfiConverterString.lower(`systemPreamble`),
+            FfiConverterOptionalString.lower(`systemPreamble`),
+            FfiConverterOptionalString.lower(`systemPromptOverride`),
            FfiConverterSequenceTypeMessage.lower(`messages`),
            FfiConverterSequenceTypeExtensionConfig.lower(`extensions`),
            _status,

View File

@@ -59,3 +59,7 @@ path = "uniffi-bindgen.rs"
[[example]]
name = "simple"
path = "examples/simple.rs"
+
+[[example]]
+name = "prompt_override"
+path = "examples/prompt_override.rs"

View File

@@ -49,7 +49,7 @@ curl -O https://repo1.maven.org/maven2/net/java/dev/jna/jna/5.13.0/jna-5.13.0.ja
popd
```

-To just create the Kotlin bindings:
+To just create the Kotlin bindings (for MacOS):

```bash
# run from project root directory
@@ -58,6 +58,18 @@ cargo build -p goose-llm
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
```

+Creating `libgoose_llm.so` for Linux distribution:
+Use `cross` to build for the specific target and then create the bindings:
+
+```
+# x86-64 GNU/Linux (kGoose uses this)
+rustup target add x86_64-unknown-linux-gnu
+cross build --release --target x86_64-unknown-linux-gnu -p goose-llm
+
+# The goose_llm.kt bindings produced should be the same whether we use 'libgoose_llm.dylib' or 'libgoose_llm.so'
+cross run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/x86_64-unknown-linux-gnu/release/libgoose_llm.so --language kotlin --out-dir bindings/kotlin
+```
+
#### Python -> Rust: generate bindings, run example

View File

@@ -0,0 +1,48 @@
+use std::vec;
+
+use anyhow::Result;
+use goose_llm::{
+    completion,
+    types::completion::{CompletionRequest, CompletionResponse},
+    Message, ModelConfig,
+};
+use serde_json::json;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let provider = "databricks";
+    let provider_config = json!({
+        "host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
+        "token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
+    });
+
+    // let model_name = "goose-gpt-4-1"; // parallel tool calls
+    let model_name = "claude-3-5-haiku";
+    let model_config = ModelConfig::new(model_name.to_string());
+
+    let system_prompt_override = "You are a helpful assistant. Talk in the style of pirates.";
+
+    for text in ["How was your day?"] {
+        println!("\n---------------\n");
+        println!("User Input: {text}");
+
+        let messages = vec![
+            Message::user().with_text("Hi there!"),
+            Message::assistant().with_text("How can I help?"),
+            Message::user().with_text(text),
+        ];
+
+        let completion_response: CompletionResponse = completion(CompletionRequest::new(
+            provider.to_string(),
+            provider_config.clone(),
+            model_config.clone(),
+            None,
+            Some(system_prompt_override.to_string()),
+            messages.clone(),
+            vec![],
+        ))
+        .await?;
+
+        // Print the response
+        println!("\nCompletion Response:");
+        println!("{}", serde_json::to_string_pretty(&completion_response)?);
+    }
+
+    Ok(())
+}

View File

@@ -106,7 +106,8 @@ async fn main() -> Result<()> {
            provider.to_string(),
            provider_config.clone(),
            model_config.clone(),
-            system_preamble.to_string(),
+            Some(system_preamble.to_string()),
+            None,
            messages.clone(),
            extensions.clone(),
        ))

View File

@@ -36,7 +36,11 @@ pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, Co
    )
    .map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?;

-    let system_prompt = construct_system_prompt(&req.system_preamble, &req.extensions)?;
+    let system_prompt = construct_system_prompt(
+        &req.system_preamble,
+        &req.system_prompt_override,
+        &req.extensions,
+    )?;
    let tools = collect_prefixed_tools(&req.extensions);

    // Call the LLM provider
@@ -60,9 +64,24 @@ pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, Co
/// Render the global `system.md` template with the provided context.
fn construct_system_prompt(
-    system_preamble: &str,
+    preamble: &Option<String>,
+    prompt_override: &Option<String>,
    extensions: &[ExtensionConfig],
) -> Result<String, CompletionError> {
+    // If both system_preamble and system_prompt_override are provided, then prompt_override takes precedence
+    // and we don't render the template using preamble and extensions. Just return the prompt_override as is.
+    if prompt_override.is_some() {
+        return Ok(prompt_override.clone().unwrap());
+    }
+
+    let system_preamble = {
+        if let Some(p) = preamble {
+            p
+        } else {
+            "You are a helpful assistant."
+        }
+    };
+
    let mut context: HashMap<&str, Value> = HashMap::new();
    context.insert("system_preamble", Value::String(system_preamble.to_owned()));
    context.insert("extensions", serde_json::to_value(extensions)?);

View File

@@ -16,7 +16,8 @@ pub struct CompletionRequest {
    pub provider_name: String,
    pub provider_config: serde_json::Value,
    pub model_config: ModelConfig,
-    pub system_preamble: String,
+    pub system_preamble: Option<String>,
+    pub system_prompt_override: Option<String>,
    pub messages: Vec<Message>,
    pub extensions: Vec<ExtensionConfig>,
}
@@ -26,7 +27,8 @@ impl CompletionRequest {
        provider_name: String,
        provider_config: serde_json::Value,
        model_config: ModelConfig,
-        system_preamble: String,
+        system_preamble: Option<String>,
+        system_prompt_override: Option<String>,
        messages: Vec<Message>,
        extensions: Vec<ExtensionConfig>,
    ) -> Self {
@@ -34,6 +36,7 @@ impl CompletionRequest {
            provider_name,
            provider_config,
            model_config,
+            system_prompt_override,
            system_preamble,
            messages,
            extensions,
@@ -41,12 +44,13 @@ impl CompletionRequest {
    }
}

-#[uniffi::export]
+#[uniffi::export(default(system_preamble = None, system_prompt_override = None))]
pub fn create_completion_request(
    provider_name: &str,
    provider_config: JsonValueFfi,
    model_config: ModelConfig,
-    system_preamble: &str,
+    system_preamble: Option<String>,
+    system_prompt_override: Option<String>,
    messages: Vec<Message>,
    extensions: Vec<ExtensionConfig>,
) -> CompletionRequest {
@@ -54,7 +58,8 @@ pub fn create_completion_request(
        provider_name.to_string(),
        provider_config,
        model_config,
-        system_preamble.to_string(),
+        system_preamble,
+        system_prompt_override,
        messages,
        extensions,
    )
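
With this change, `CompletionRequest::new` takes two optional system-prompt arguments in the order `(system_preamble, system_prompt_override)`, and call sites pick one (or neither, relying on the default preamble). Here is a small self-contained sketch of the two call shapes; `CompletionRequestSketch` and the plain-String fields are hypothetical stand-ins for this illustration, not the real goose_llm types.

```rust
// Hypothetical mirror of the updated constructor, showing the two call shapes.
// CompletionRequestSketch and its String fields are stand-ins for the sketch;
// the real CompletionRequest carries goose_llm's provider/model/message types.
#[derive(Debug)]
struct CompletionRequestSketch {
    system_preamble: Option<String>,
    system_prompt_override: Option<String>,
    messages: Vec<String>,
}

impl CompletionRequestSketch {
    fn new(
        system_preamble: Option<String>,
        system_prompt_override: Option<String>,
        messages: Vec<String>,
    ) -> Self {
        Self {
            // Named field-init shorthand: the initializer can list the
            // override first (as in the hunk above) without affecting which
            // value lands in which field.
            system_prompt_override,
            system_preamble,
            messages,
        }
    }
}

fn main() {
    // Preamble path: the system prompt is rendered from preamble + extensions.
    let with_preamble = CompletionRequestSketch::new(
        Some("You are goose.".to_string()),
        None,
        vec!["hello".to_string()],
    );

    // Override path: the override string is used verbatim as the system prompt.
    let with_override = CompletionRequestSketch::new(
        None,
        Some("You are a bot named Tile Creator.".to_string()),
        vec!["hello".to_string()],
    );

    println!("{with_preamble:?}");
    println!("{with_override:?}");
}
```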