diff --git a/Cargo.lock b/Cargo.lock index 693e68a5..fd8b7f4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2517,6 +2517,22 @@ dependencies = [ "tokio", ] +[[package]] +name = "goose-llm" +version = "1.0.20" +dependencies = [ + "anyhow", + "chrono", + "goose", + "include_dir", + "mcp-core", + "minijinja", + "once_cell", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "goose-mcp" version = "1.0.20" diff --git a/crates/goose-llm/Cargo.toml b/crates/goose-llm/Cargo.toml new file mode 100644 index 00000000..1d73bff6 --- /dev/null +++ b/crates/goose-llm/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "goose-llm" +edition.workspace = true +version.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +description.workspace = true + +[dependencies] +goose = { path = "../goose" } +mcp-core = { path = "../mcp-core" } +tokio = { version = "1.43", features = ["full"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +anyhow = "1.0" +minijinja = "2.8.0" +include_dir = "0.7.4" +once_cell = "1.20.2" +chrono = { version = "0.4.38", features = ["serde"] } + +[[example]] +name = "simple" +path = "examples/simple.rs" diff --git a/crates/goose-llm/README.md b/crates/goose-llm/README.md new file mode 100644 index 00000000..09e971d3 --- /dev/null +++ b/crates/goose-llm/README.md @@ -0,0 +1,14 @@ +### goose-llm + +This crate is meant to be used for foreign function interface (FFI). 
It's meant to be
+stateless and contain logic related to providers and prompts:
+- chat completion with model providers
+- detecting read-only tools for smart approval
+- methods for summarization / truncation
+
+
+Run:
+```
+cargo run -p goose-llm --example simple
+```
+
diff --git a/crates/goose-llm/examples/simple.rs b/crates/goose-llm/examples/simple.rs
new file mode 100644
index 00000000..97e2d830
--- /dev/null
+++ b/crates/goose-llm/examples/simple.rs
@@ -0,0 +1,92 @@
+use std::vec;
+
+use anyhow::Result;
+use goose::message::Message;
+use goose::model::ModelConfig;
+use goose_llm::{completion, CompletionResponse, Extension};
+use mcp_core::tool::Tool;
+use serde_json::json;
+
+/// Sends each sample prompt through `completion` with two extensions and prints the JSON reply.
+#[tokio::main]
+async fn main() -> Result<()> {
+    let provider = "databricks";
+    let model_name = "goose-claude-3-5-sonnet";
+    let model_config = ModelConfig::new(model_name.to_string());
+
+    let calculator_tool = Tool::new(
+        "calculator",
+        "Perform basic arithmetic operations",
+        json!({
+            "type": "object",
+            "required": ["operation", "numbers"],
+            "properties": {
+                "operation": {
+                    "type": "string",
+                    "enum": ["add", "subtract", "multiply", "divide"],
+                    "description": "The arithmetic operation to perform",
+                },
+                "numbers": {
+                    "type": "array",
+                    "items": {"type": "number"},
+                    "description": "List of numbers to operate on in order",
+                }
+            }
+        }),
+        None,
+    );
+
+    let bash_tool = Tool::new(
+        "bash_shell",
+        "Run a shell command",
+        json!({
+            "type": "object",
+            "required": ["command"],
+            "properties": {
+                "command": {
+                    "type": "string",
+                    "description": "The shell command to execute"
+                }
+            }
+        }),
+        None,
+    );
+
+    let extensions = vec![
+        Extension::new(
+            "calculator_extension".to_string(),
+            Some("This extension provides a calculator tool.".to_string()),
+            vec![calculator_tool],
+        ),
+        Extension::new(
+            "bash_extension".to_string(),
+            Some("This extension provides a bash shell tool.".to_string()),
+            vec![bash_tool],
+        ),
+    ];
+
+    let system_preamble = "You are a helpful assistant.";
+
+    for text in [
+        "Add 10037 + 23123",
+        // "Write some random bad words to end of words.txt",
+        // "List all json files in the current directory and then multiply the count of the files by 7",
+    ] {
+        println!("\n---------------\n");
+        println!("User Input: {text}");
+        let messages = vec![Message::user().with_text(text)];
+        let completion_response: CompletionResponse = completion(
+            provider,
+            model_config.clone(),
+            system_preamble,
+            &messages,
+            &extensions,
+        )
+        .await?;
+        // Print the response
+        println!("\nCompletion Response:");
+        println!("{}", serde_json::to_string_pretty(&completion_response)?);
+    }
+
+    Ok(())
+}
diff --git a/crates/goose-llm/src/completion.rs b/crates/goose-llm/src/completion.rs
new file mode 100644
index 00000000..649111b0
--- /dev/null
+++ b/crates/goose-llm/src/completion.rs
@@ -0,0 +1,96 @@
+use anyhow::Result;
+use chrono::Utc;
+use serde_json::Value;
+use std::collections::HashMap;
+
+use goose::message::Message;
+use goose::model::ModelConfig;
+use goose::providers::base::ProviderUsage;
+use goose::providers::create;
+use goose::providers::errors::ProviderError;
+use mcp_core::tool::Tool;
+
+use crate::prompt_template;
+use serde::{Deserialize, Serialize};
+
+/// The provider's reply to a completion request together with its reported usage.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CompletionResponse {
+    message: Message,
+    usage: ProviderUsage,
+}
+
+impl CompletionResponse {
+    /// Pairs a provider message with its usage record.
+    pub fn new(message: Message, usage: ProviderUsage) -> Self {
+        Self { message, usage }
+    }
+}
+
+/// A named set of tools, with optional instructions, offered to the model.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Extension {
+    name: String,
+    instructions: Option<String>,
+    tools: Vec<Tool>,
+}
+
+impl Extension {
+    /// Creates an extension from a name, optional instructions and its tools.
+    pub fn new(name: String, instructions: Option<String>, tools: Vec<Tool>) -> Self {
+        Self {
+            name,
+            instructions,
+            tools,
+        }
+    }
+
+    /// Returns clones of the tools, each renamed to `{extension name}__{tool name}`.
+    pub fn get_prefixed_tools(&self) -> Vec<Tool> {
+        self.tools
+            .iter()
+            .map(|tool| {
+                let mut prefixed_tool = tool.clone();
+                prefixed_tool.name = format!("{}__{}", self.name, tool.name);
+                prefixed_tool
+            })
+            .collect()
+    }
+}
+
+/// 
Public API for the Goose LLM completion function.
+///
+/// Returns an error if the provider cannot be created or the completion request fails.
+pub async fn completion(
+    provider: &str,
+    model_config: ModelConfig,
+    system_preamble: &str,
+    messages: &[Message],
+    extensions: &[Extension],
+) -> Result<CompletionResponse, ProviderError> {
+    // NOTE(review): assumes `ProviderError::ExecutionError(String)` exists; confirm.
+    let provider = create(provider, model_config)
+        .map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
+    let system_prompt = construct_system_prompt(system_preamble, extensions);
+    let tools = extensions
+        .iter()
+        .flat_map(|ext| ext.get_prefixed_tools())
+        .collect::<Vec<_>>();
+    let (message, usage) = provider.complete(&system_prompt, messages, &tools).await?;
+    Ok(CompletionResponse::new(message, usage))
+}
+
+/// Renders `system.md` with the preamble, the extensions and the current UTC time.
+///
+/// # Panics
+/// Panics if `extensions` cannot be serialized or the embedded template fails to render.
+fn construct_system_prompt(system_preamble: &str, extensions: &[Extension]) -> String {
+    let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
+
+    let mut context: HashMap<&str, Value> = HashMap::new();
+    context.insert("system_preamble", Value::String(system_preamble.to_string()));
+    context.insert("extensions", serde_json::to_value(extensions).unwrap());
+    context.insert("current_date_time", Value::String(current_date_time));
+
+    prompt_template::render_global_file("system.md", &context).expect("Prompt should render")
+}
diff --git a/crates/goose-llm/src/lib.rs b/crates/goose-llm/src/lib.rs
new file mode 100644
index 00000000..ef3b28bc
--- /dev/null
+++ b/crates/goose-llm/src/lib.rs
@@ -0,0 +1,3 @@
+mod completion;
+mod prompt_template;
+pub use completion::{completion, CompletionResponse, Extension};
diff --git a/crates/goose-llm/src/prompt_template.rs b/crates/goose-llm/src/prompt_template.rs
new file mode 100644
index 00000000..8ed1c3da
--- /dev/null
+++ b/crates/goose-llm/src/prompt_template.rs
@@ -0,0 +1,112 @@
+use include_dir::{include_dir, Dir};
+use minijinja::{Environment, Error as MiniJinjaError, Value as MJValue};
+use once_cell::sync::Lazy;
+use serde::Serialize;
+use std::path::PathBuf;
+use std::sync::{Arc, RwLock};
+
+/// This directory will be
embedded into the final binary.
+/// Typically used to store "core" or "system" prompts.
+static CORE_PROMPTS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/prompts");
+
+/// A global MiniJinja environment storing the "core" prompts.
+///
+/// - Built lazily, on first access, from the files in `CORE_PROMPTS_DIR`.
+/// - Templates that fail to load are reported on stderr and skipped.
+/// - Ideal for "system" templates that don't change often.
+/// - *Not* used for extension prompts (which are ephemeral).
+static GLOBAL_ENV: Lazy<Arc<RwLock<Environment<'static>>>> = Lazy::new(|| {
+    let mut env = Environment::new();
+    // Pre-load all core templates from the embedded dir.
+    for file in CORE_PROMPTS_DIR.files() {
+        let name = file.path().to_string_lossy().to_string();
+        let source = String::from_utf8_lossy(file.contents()).to_string();
+
+        // The Environment has a 'static lifetime, so the template strings must live for the
+        // whole program; leaking them once during initialization is acceptable.
+        let static_name: &'static str = Box::leak(name.into_boxed_str());
+        let static_source: &'static str = Box::leak(source.into_boxed_str());
+
+        if let Err(e) = env.add_template(static_name, static_source) {
+            eprintln!("Failed to add template {}: {}", static_name, e);
+        }
+    }
+    Arc::new(RwLock::new(env))
+});
+
+/// Renders a prompt from the global environment by name.
+///
+/// # Arguments
+/// * `template_name` - The name of the template (usually the file path or a custom ID).
+/// * `context_data` - Data to be inserted into the template (must be `Serialize`).
+///
+/// # Errors
+/// Returns an error if no template has that name or if rendering fails.
+pub fn render_global_template<T: Serialize>(
+    template_name: &str,
+    context_data: &T,
+) -> Result<String, MiniJinjaError> {
+    let env = GLOBAL_ENV.read().expect("GLOBAL_ENV lock poisoned");
+    let tmpl = env.get_template(template_name)?;
+    let rendered = tmpl.render(MJValue::from_serialize(context_data))?;
+    Ok(rendered.trim().to_string())
+}
+
+/// Renders a file from `CORE_PROMPTS_DIR` within the global environment.
+///
+/// # Arguments
+/// * `template_file` - The file path within the embedded directory (e.g. "system.md").
+/// * `context_data` - Data to be inserted into the template (must be `Serialize`).
+///
+/// # Errors
+/// Fails if the template was not loaded into the global environment or rendering fails.
+pub fn render_global_file<T: Serialize>(
+    template_file: impl Into<PathBuf>,
+    context_data: &T,
+) -> Result<String, MiniJinjaError> {
+    // Templates are registered under their lossy path string, so look up the same key.
+    let path: PathBuf = template_file.into();
+    let template_name = path.to_string_lossy().into_owned();
+    render_global_template(&template_name, context_data)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// For convenience in tests, define a small struct or use a HashMap to provide context.
+    #[derive(Serialize)]
+    struct TestContext {
+        name: String,
+        age: u32,
+    }
+
+    #[test]
+    fn test_global_file_render() {
+        // "mock.md" should exist in the embedded CORE_PROMPTS_DIR
+        // and have placeholders for `name` and `age`.
+        let context = TestContext {
+            name: "Alice".to_string(),
+            age: 30,
+        };
+
+        let result = render_global_file("mock.md", &context).unwrap();
+        // Assume mock.md content is something like:
+        // "This prompt is only used for testing.\n\nHello, {{ name }}! You are {{ age }} years old."
+        assert_eq!(
+            result,
+            "This prompt is only used for testing.\n\nHello, Alice! You are 30 years old."
+        );
+    }
+
+    #[test]
+    fn test_global_file_not_found() {
+        let context = TestContext {
+            name: "Unused".to_string(),
+            age: 99,
+        };
+
+        let result = render_global_file("non_existent.md", &context);
+        assert!(result.is_err(), "Should fail because file is missing");
+    }
+}
diff --git a/crates/goose-llm/src/prompts/mock.md b/crates/goose-llm/src/prompts/mock.md
new file mode 100644
index 00000000..46c1e708
--- /dev/null
+++ b/crates/goose-llm/src/prompts/mock.md
@@ -0,0 +1,3 @@
+This prompt is only used for testing.
+
+Hello, {{ name }}! You are {{ age }} years old.
\ No newline at end of file diff --git a/crates/goose-llm/src/prompts/system.md b/crates/goose-llm/src/prompts/system.md new file mode 100644 index 00000000..e08ce2b3 --- /dev/null +++ b/crates/goose-llm/src/prompts/system.md @@ -0,0 +1,34 @@ +{{system_preamble}} + +The current date is {{current_date_time}}. + +Goose uses LLM providers with tool calling capability. You can be used with different language models (gpt-4o, claude-3.5-sonnet, o1, llama-3.2, deepseek-r1, etc). +These models have varying knowledge cut-off dates depending on when they were trained, but typically it's between 5-10 months prior to the current date. + +# Extensions + +Extensions allow other applications to provide context to Goose. Extensions connect Goose to different data sources and tools. + +{% if (extensions is defined) and extensions %} +Because you dynamically load extensions, your conversation history may refer +to interactions with extensions that are not currently active. The currently +active extensions are below. Each of these extensions provides tools that are +in your tool specification. + +{% for extension in extensions %} +## {{extension.name}} +{% if extension.instructions %}### Instructions +{{extension.instructions}}{% endif %} +{% endfor %} +{% else %} +No extensions are defined. You should let the user know that they should add extensions.{% endif %} + +# Response Guidelines + +- Use Markdown formatting for all responses. +- Follow best practices for Markdown, including: + - Using headers for organization. + - Bullet points for lists. + - Links formatted correctly, either as linked text (e.g., [this is linked text](https://example.com)) or automatic links using angle brackets (e.g., <http://example.com/>). +- For code examples, use fenced code blocks by placing triple backticks (` ``` `) before and after the code. Include the language identifier after the opening backticks (e.g., ` ```python `) to enable syntax highlighting. 
+- Ensure clarity, conciseness, and proper formatting to enhance readability and usability. diff --git a/crates/goose/src/permission/permission_judge.rs b/crates/goose/src/permission/permission_judge.rs index 204dcb56..368ce82f 100644 --- a/crates/goose/src/permission/permission_judge.rs +++ b/crates/goose/src/permission/permission_judge.rs @@ -7,6 +7,7 @@ use chrono::Utc; use indoc::indoc; use mcp_core::tool::ToolAnnotations; use mcp_core::{tool::Tool, TextContent}; +use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::collections::HashSet; use std::sync::Arc; @@ -152,6 +153,7 @@ pub async fn detect_read_only_tools( } // Define return structure +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PermissionCheckResult { pub approved: Vec, pub needs_approval: Vec,