mirror of
https://github.com/aljazceru/goose.git
synced 2026-01-31 20:24:27 +01:00
[goose-llm] system prompt override (#2791)
This commit is contained in:
@@ -130,8 +130,8 @@ fun main() = runBlocking {
|
||||
providerConfig,
|
||||
modelConfig,
|
||||
systemPreamble,
|
||||
msgs,
|
||||
extensions
|
||||
messages = msgs,
|
||||
extensions = extensions
|
||||
)
|
||||
|
||||
val response = completion(req)
|
||||
@@ -140,6 +140,32 @@ fun main() = runBlocking {
|
||||
|
||||
// ---- UI Extraction (custom schema) ----
|
||||
runUiExtraction(providerName, providerConfig)
|
||||
|
||||
|
||||
// --- Prompt Override ---
|
||||
val prompt_req = createCompletionRequest(
|
||||
providerName,
|
||||
providerConfig,
|
||||
modelConfig,
|
||||
systemPreamble = null,
|
||||
systemPromptOverride = "You are a bot named Tile Creator. Your task is to create a tile based on the user's input.",
|
||||
messages=listOf(
|
||||
Message(
|
||||
role = Role.USER,
|
||||
created = now,
|
||||
content = listOf(
|
||||
MessageContent.Text(
|
||||
TextContent("What's your name?")
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
extensions=emptyList()
|
||||
)
|
||||
|
||||
val prompt_resp = completion(prompt_req)
|
||||
|
||||
println("\nPrompt Override Response:\n${prompt_resp.message}")
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -830,6 +830,7 @@ internal interface UniffiLib : Library {
|
||||
`providerConfig`: RustBuffer.ByValue,
|
||||
`modelConfig`: RustBuffer.ByValue,
|
||||
`systemPreamble`: RustBuffer.ByValue,
|
||||
`systemPromptOverride`: RustBuffer.ByValue,
|
||||
`messages`: RustBuffer.ByValue,
|
||||
`extensions`: RustBuffer.ByValue,
|
||||
uniffi_out_err: UniffiRustCallStatus,
|
||||
@@ -1100,7 +1101,7 @@ private fun uniffiCheckApiChecksums(lib: IntegrityCheckingUniffiLib) {
|
||||
if (lib.uniffi_goose_llm_checksum_func_completion() != 47457.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 39068.toShort()) {
|
||||
if (lib.uniffi_goose_llm_checksum_func_create_completion_request() != 50798.toShort()) {
|
||||
throw RuntimeException("UniFFI API checksum mismatch: try cleaning and rebuilding your project")
|
||||
}
|
||||
if (lib.uniffi_goose_llm_checksum_func_create_tool_config() != 49910.toShort()) {
|
||||
@@ -2955,7 +2956,8 @@ fun `createCompletionRequest`(
|
||||
`providerName`: kotlin.String,
|
||||
`providerConfig`: Value,
|
||||
`modelConfig`: ModelConfig,
|
||||
`systemPreamble`: kotlin.String,
|
||||
`systemPreamble`: kotlin.String? = null,
|
||||
`systemPromptOverride`: kotlin.String? = null,
|
||||
`messages`: List<Message>,
|
||||
`extensions`: List<ExtensionConfig>,
|
||||
): CompletionRequest =
|
||||
@@ -2965,7 +2967,8 @@ fun `createCompletionRequest`(
|
||||
FfiConverterString.lower(`providerName`),
|
||||
FfiConverterTypeValue.lower(`providerConfig`),
|
||||
FfiConverterTypeModelConfig.lower(`modelConfig`),
|
||||
FfiConverterString.lower(`systemPreamble`),
|
||||
FfiConverterOptionalString.lower(`systemPreamble`),
|
||||
FfiConverterOptionalString.lower(`systemPromptOverride`),
|
||||
FfiConverterSequenceTypeMessage.lower(`messages`),
|
||||
FfiConverterSequenceTypeExtensionConfig.lower(`extensions`),
|
||||
_status,
|
||||
|
||||
@@ -59,3 +59,7 @@ path = "uniffi-bindgen.rs"
|
||||
[[example]]
|
||||
name = "simple"
|
||||
path = "examples/simple.rs"
|
||||
|
||||
[[example]]
|
||||
name = "prompt_override"
|
||||
path = "examples/prompt_override.rs"
|
||||
|
||||
@@ -49,7 +49,7 @@ curl -O https://repo1.maven.org/maven2/net/java/dev/jna/jna/5.13.0/jna-5.13.0.ja
|
||||
popd
|
||||
```
|
||||
|
||||
To just create the Kotlin bindings:
|
||||
To just create the Kotlin bindings (for MacOS):
|
||||
|
||||
```bash
|
||||
# run from project root directory
|
||||
@@ -58,6 +58,18 @@ cargo build -p goose-llm
|
||||
cargo run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/debug/libgoose_llm.dylib --language kotlin --out-dir bindings/kotlin
|
||||
```
|
||||
|
||||
Creating `libgoose_llm.so` for Linux distribution:
|
||||
|
||||
Use `cross` to build for the specific target and then create the bindings:
|
||||
```
|
||||
# x86-64 GNU/Linux (kGoose uses this)
|
||||
rustup target add x86_64-unknown-linux-gnu
|
||||
cross build --release --target x86_64-unknown-linux-gnu -p goose-llm
|
||||
|
||||
# The goose_llm.kt bindings produced should be the same whether we use 'libgoose_llm.dylib' or 'libgoose_llm.so'
|
||||
cross run --features=uniffi/cli --bin uniffi-bindgen generate --library ./target/x86_64-unknown-linux-gnu/release/libgoose_llm.so --language kotlin --out-dir bindings/kotlin
|
||||
```
|
||||
|
||||
|
||||
#### Python -> Rust: generate bindings, run example
|
||||
|
||||
|
||||
48
crates/goose-llm/examples/prompt_override.rs
Normal file
48
crates/goose-llm/examples/prompt_override.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use std::vec;
|
||||
|
||||
use anyhow::Result;
|
||||
use goose_llm::{
|
||||
completion,
|
||||
types::completion::{CompletionRequest, CompletionResponse},
|
||||
Message, ModelConfig,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let provider = "databricks";
|
||||
let provider_config = json!({
|
||||
"host": std::env::var("DATABRICKS_HOST").expect("Missing DATABRICKS_HOST"),
|
||||
"token": std::env::var("DATABRICKS_TOKEN").expect("Missing DATABRICKS_TOKEN"),
|
||||
});
|
||||
// let model_name = "goose-gpt-4-1"; // parallel tool calls
|
||||
let model_name = "claude-3-5-haiku";
|
||||
let model_config = ModelConfig::new(model_name.to_string());
|
||||
|
||||
let system_prompt_override = "You are a helpful assistant. Talk in the style of pirates.";
|
||||
|
||||
for text in ["How was your day?"] {
|
||||
println!("\n---------------\n");
|
||||
println!("User Input: {text}");
|
||||
let messages = vec![
|
||||
Message::user().with_text("Hi there!"),
|
||||
Message::assistant().with_text("How can I help?"),
|
||||
Message::user().with_text(text),
|
||||
];
|
||||
let completion_response: CompletionResponse = completion(CompletionRequest::new(
|
||||
provider.to_string(),
|
||||
provider_config.clone(),
|
||||
model_config.clone(),
|
||||
None,
|
||||
Some(system_prompt_override.to_string()),
|
||||
messages.clone(),
|
||||
vec![],
|
||||
))
|
||||
.await?;
|
||||
// Print the response
|
||||
println!("\nCompletion Response:");
|
||||
println!("{}", serde_json::to_string_pretty(&completion_response)?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -106,7 +106,8 @@ async fn main() -> Result<()> {
|
||||
provider.to_string(),
|
||||
provider_config.clone(),
|
||||
model_config.clone(),
|
||||
system_preamble.to_string(),
|
||||
Some(system_preamble.to_string()),
|
||||
None,
|
||||
messages.clone(),
|
||||
extensions.clone(),
|
||||
))
|
||||
|
||||
@@ -36,7 +36,11 @@ pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, Co
|
||||
)
|
||||
.map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?;
|
||||
|
||||
let system_prompt = construct_system_prompt(&req.system_preamble, &req.extensions)?;
|
||||
let system_prompt = construct_system_prompt(
|
||||
&req.system_preamble,
|
||||
&req.system_prompt_override,
|
||||
&req.extensions,
|
||||
)?;
|
||||
let tools = collect_prefixed_tools(&req.extensions);
|
||||
|
||||
// Call the LLM provider
|
||||
@@ -60,9 +64,24 @@ pub async fn completion(req: CompletionRequest) -> Result<CompletionResponse, Co
|
||||
|
||||
/// Render the global `system.md` template with the provided context.
|
||||
fn construct_system_prompt(
|
||||
system_preamble: &str,
|
||||
preamble: &Option<String>,
|
||||
prompt_override: &Option<String>,
|
||||
extensions: &[ExtensionConfig],
|
||||
) -> Result<String, CompletionError> {
|
||||
// If both system_preamble and system_prompt_override are provided, then prompt_override takes precedence
|
||||
// and we don't render the template using preamble and extensions. Just return the prompt_override as is.
|
||||
if prompt_override.is_some() {
|
||||
return Ok(prompt_override.clone().unwrap());
|
||||
}
|
||||
|
||||
let system_preamble = {
|
||||
if let Some(p) = preamble {
|
||||
p
|
||||
} else {
|
||||
"You are a helpful assistant."
|
||||
}
|
||||
};
|
||||
|
||||
let mut context: HashMap<&str, Value> = HashMap::new();
|
||||
context.insert("system_preamble", Value::String(system_preamble.to_owned()));
|
||||
context.insert("extensions", serde_json::to_value(extensions)?);
|
||||
|
||||
@@ -16,7 +16,8 @@ pub struct CompletionRequest {
|
||||
pub provider_name: String,
|
||||
pub provider_config: serde_json::Value,
|
||||
pub model_config: ModelConfig,
|
||||
pub system_preamble: String,
|
||||
pub system_preamble: Option<String>,
|
||||
pub system_prompt_override: Option<String>,
|
||||
pub messages: Vec<Message>,
|
||||
pub extensions: Vec<ExtensionConfig>,
|
||||
}
|
||||
@@ -26,7 +27,8 @@ impl CompletionRequest {
|
||||
provider_name: String,
|
||||
provider_config: serde_json::Value,
|
||||
model_config: ModelConfig,
|
||||
system_preamble: String,
|
||||
system_preamble: Option<String>,
|
||||
system_prompt_override: Option<String>,
|
||||
messages: Vec<Message>,
|
||||
extensions: Vec<ExtensionConfig>,
|
||||
) -> Self {
|
||||
@@ -34,6 +36,7 @@ impl CompletionRequest {
|
||||
provider_name,
|
||||
provider_config,
|
||||
model_config,
|
||||
system_prompt_override,
|
||||
system_preamble,
|
||||
messages,
|
||||
extensions,
|
||||
@@ -41,12 +44,13 @@ impl CompletionRequest {
|
||||
}
|
||||
}
|
||||
|
||||
#[uniffi::export]
|
||||
#[uniffi::export(default(system_preamble = None, system_prompt_override = None))]
|
||||
pub fn create_completion_request(
|
||||
provider_name: &str,
|
||||
provider_config: JsonValueFfi,
|
||||
model_config: ModelConfig,
|
||||
system_preamble: &str,
|
||||
system_preamble: Option<String>,
|
||||
system_prompt_override: Option<String>,
|
||||
messages: Vec<Message>,
|
||||
extensions: Vec<ExtensionConfig>,
|
||||
) -> CompletionRequest {
|
||||
@@ -54,7 +58,8 @@ pub fn create_completion_request(
|
||||
provider_name.to_string(),
|
||||
provider_config,
|
||||
model_config,
|
||||
system_preamble.to_string(),
|
||||
system_preamble,
|
||||
system_prompt_override,
|
||||
messages,
|
||||
extensions,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user