diff --git a/crates/goose-llm/examples/simple.rs b/crates/goose-llm/examples/simple.rs index 0fbfa568..4d6d7044 100644 --- a/crates/goose-llm/examples/simple.rs +++ b/crates/goose-llm/examples/simple.rs @@ -3,7 +3,9 @@ use std::vec; use anyhow::Result; use goose_llm::{ completion, - types::completion::{CompletionResponse, ExtensionConfig, ToolApprovalMode, ToolConfig}, + types::completion::{ + CompletionRequest, CompletionResponse, ExtensionConfig, ToolApprovalMode, ToolConfig, + }, Message, ModelConfig, }; use serde_json::json; @@ -91,13 +93,13 @@ async fn main() -> Result<()> { println!("\n---------------\n"); println!("User Input: {text}"); let messages = vec![Message::user().with_text(text)]; - let completion_response: CompletionResponse = completion( + let completion_response: CompletionResponse = completion(CompletionRequest::new( provider, model_config.clone(), system_preamble, &messages, &extensions, - ) + )) .await?; // Print the response println!("\nCompletion Response:"); diff --git a/crates/goose-llm/src/completion.rs b/crates/goose-llm/src/completion.rs index 2a010217..cfa693d8 100644 --- a/crates/goose-llm/src/completion.rs +++ b/crates/goose-llm/src/completion.rs @@ -6,93 +6,127 @@ use serde_json::Value; use crate::{ message::{Message, MessageContent}, - model::ModelConfig, prompt_template, - providers::{create, errors::ProviderError}, - types::completion::{ - CompletionResponse, ExtensionConfig, RuntimeMetrics, ToolApprovalMode, ToolConfig, + providers::create, + types::{ + completion::{ + CompletionError, CompletionRequest, CompletionResponse, ExtensionConfig, + RuntimeMetrics, ToolApprovalMode, ToolConfig, + }, + core::ToolCall, }, }; -/// Set `needs_approval` on *every* tool call in the message based on approval mode. -pub fn update_needs_approval_for_tool_calls( - message: &mut Message, - tool_configs: &HashMap, -) { - for content in message.content.iter_mut() { - if let MessageContent::ToolRequest(req) = content { - if let Ok(call) = &mut req.tool_call { - let needs = match tool_configs.get(&call.name) { - Some(cfg) => match cfg.approval_mode { - ToolApprovalMode::Auto => false, - ToolApprovalMode::Manual => true, - ToolApprovalMode::Smart => true, // TODO: implement smart approval later - }, - None => call.needs_approval, // unknown tool: leave flag unchanged - }; - - call.set_needs_approval(needs); - } - } - } -} - /// Public API for the Goose LLM completion function -pub async fn completion( - provider: &str, - model_config: ModelConfig, - system_preamble: &str, - messages: &[Message], - extensions: &[ExtensionConfig], -) -> Result { +pub async fn completion(req: CompletionRequest<'_>) -> Result { let start_total = Instant::now(); - let provider = create(provider, model_config).unwrap(); - let system_prompt = construct_system_prompt(system_preamble, extensions); - // println!("\nSystem prompt: {}\n", system_prompt); - let tools = extensions - .iter() - .flat_map(|ext| ext.get_prefixed_tools()) - .collect::>(); + let provider = create(req.provider_name, req.model_config) + .map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?; + let system_prompt = construct_system_prompt(req.system_preamble, req.extensions)?; + let tools = collect_prefixed_tools(req.extensions); + + // Call the LLM provider let start_provider = Instant::now(); - let mut response = provider.complete(&system_prompt, messages, &tools).await?; - let total_time_ms_provider = start_provider.elapsed().as_millis(); - let tokens_per_second = response.usage.total_tokens.and_then(|toks| { - if total_time_ms_provider > 0 { - Some(toks as f64 / (total_time_ms_provider as f64 / 1000.0)) - } else { - None - } - }); + let mut response = provider + .complete(&system_prompt, req.messages, &tools) + .await?; + let provider_elapsed_ms = start_provider.elapsed().as_millis(); + let usage_tokens = response.usage.total_tokens; - let tool_configs: HashMap = extensions - .iter() - .flat_map(|ext| ext.get_prefixed_tool_configs().into_iter()) - .collect(); + let tool_configs = collect_prefixed_tool_configs(req.extensions); + update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?; - update_needs_approval_for_tool_calls(&mut response.message, &tool_configs); - - let total_time_ms = start_total.elapsed().as_millis(); Ok(CompletionResponse::new( response.message, response.model, response.usage, - RuntimeMetrics::new(total_time_ms, total_time_ms_provider, tokens_per_second), + calculate_runtime_metrics(start_total, provider_elapsed_ms, usage_tokens), )) } -fn construct_system_prompt(system_preamble: &str, extensions: &[ExtensionConfig]) -> String { +/// Render the global `system.md` template with the provided context. +fn construct_system_prompt( + system_preamble: &str, + extensions: &[ExtensionConfig], +) -> Result { let mut context: HashMap<&str, Value> = HashMap::new(); - + context.insert("system_preamble", Value::String(system_preamble.to_owned())); + context.insert("extensions", serde_json::to_value(extensions)?); context.insert( - "system_preamble", - Value::String(system_preamble.to_string()), + "current_date", + Value::String(Utc::now().format("%Y-%m-%d").to_string()), ); - context.insert("extensions", serde_json::to_value(extensions).unwrap()); - let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string(); - context.insert("current_date_time", Value::String(current_date_time)); - - prompt_template::render_global_file("system.md", &context).expect("Prompt should render") + Ok(prompt_template::render_global_file("system.md", &context)?) +} + +/// Determine if a tool call requires manual approval. +fn determine_needs_approval(config: &ToolConfig, _call: &ToolCall) -> bool { + match config.approval_mode { + ToolApprovalMode::Auto => false, + ToolApprovalMode::Manual => true, + ToolApprovalMode::Smart => { + // TODO: Implement smart approval logic later + true + } + } +} + +/// Set `needs_approval` on every tool call in the message. +/// Returns a `ToolNotFound` error if the corresponding `ToolConfig` is missing. +pub fn update_needs_approval_for_tool_calls( + message: &mut Message, + tool_configs: &HashMap, +) -> Result<(), CompletionError> { + for content in &mut message.content.iter_mut() { + if let MessageContent::ToolRequest(req) = content { + if let Ok(call) = &mut req.tool_call { + // Provide a clear error message when the tool config is missing + let config = tool_configs.get(&call.name).ok_or_else(|| { + CompletionError::ToolNotFound(format!( + "could not find tool config for '{}'", + call.name + )) + })?; + let needs_approval = determine_needs_approval(config, call); + call.set_needs_approval(needs_approval); + } + } + } + Ok(()) +} + +/// Collect all `Tool` instances from the extensions. +fn collect_prefixed_tools(extensions: &[ExtensionConfig]) -> Vec { + extensions + .iter() + .flat_map(|ext| ext.get_prefixed_tools()) + .collect() +} + +/// Collect all `ToolConfig` entries from the extensions into a map. +fn collect_prefixed_tool_configs(extensions: &[ExtensionConfig]) -> HashMap { + extensions + .iter() + .flat_map(|ext| ext.get_prefixed_tool_configs()) + .collect() +} + +/// Compute runtime metrics for the request. +fn calculate_runtime_metrics( + total_start: Instant, + provider_elapsed_ms: u128, + token_count: Option, +) -> RuntimeMetrics { + let total_ms = total_start.elapsed().as_millis(); + let tokens_per_sec = token_count.and_then(|toks| { + if provider_elapsed_ms > 0 { + Some(toks as f64 / (provider_elapsed_ms as f64 / 1_000.0)) + } else { + None + } + }); + RuntimeMetrics::new(total_ms, provider_elapsed_ms, tokens_per_sec) } diff --git a/crates/goose-llm/src/prompts/system.md b/crates/goose-llm/src/prompts/system.md index e08ce2b3..4a2aacde 100644 --- a/crates/goose-llm/src/prompts/system.md +++ b/crates/goose-llm/src/prompts/system.md @@ -1,6 +1,6 @@ {{system_preamble}} -The current date is {{current_date_time}}. +The current date is {{current_date}}. Goose uses LLM providers with tool calling capability. You can be used with different language models (gpt-4o, claude-3.5-sonnet, o1, llama-3.2, deepseek-r1, etc). These models have varying knowledge cut-off dates depending on when they were trained, but typically it's between 5-10 months prior to the current date. diff --git a/crates/goose-llm/src/types/completion.rs b/crates/goose-llm/src/types/completion.rs index 4f9e33da..039c68d7 100644 --- a/crates/goose-llm/src/types/completion.rs +++ b/crates/goose-llm/src/types/completion.rs @@ -3,10 +3,56 @@ // https://docs.google.com/document/d/1r5vjSK3nBQU1cIRf0WKysDigqMlzzrzl_bxEE4msOiw/edit?tab=t.0 use std::collections::HashMap; +use thiserror::Error; use serde::{Deserialize, Serialize}; use crate::{message::Message, providers::Usage}; +use crate::{model::ModelConfig, providers::errors::ProviderError}; + +pub struct CompletionRequest<'a> { + pub provider_name: &'a str, + pub model_config: ModelConfig, + pub system_preamble: &'a str, + pub messages: &'a [Message], + pub extensions: &'a [ExtensionConfig], +} + +impl<'a> CompletionRequest<'a> { + pub fn new( + provider_name: &'a str, + model_config: ModelConfig, + system_preamble: &'a str, + messages: &'a [Message], + extensions: &'a [ExtensionConfig], + ) -> Self { + Self { + provider_name, + model_config, + system_preamble, + messages, + extensions, + } + } +} + +#[derive(Debug, Error)] +pub enum CompletionError { + #[error("failed to create provider: {0}")] + UnknownProvider(String), + + #[error("provider error: {0}")] + Provider(#[from] ProviderError), + + #[error("template rendering error: {0}")] + Template(#[from] minijinja::Error), + + #[error("json serialization error: {0}")] + Json(#[from] serde_json::Error), + + #[error("tool not found error: {0}")] + ToolNotFound(String), +} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CompletionResponse {