[goose-llm] add completion request & error (#2451)

This commit is contained in:
Salman Mohammed
2025-05-07 14:50:58 -04:00
committed by GitHub
parent 97ff0cdd03
commit c37ea4e99e
4 changed files with 153 additions and 71 deletions

View File

@@ -3,7 +3,9 @@ use std::vec;
use anyhow::Result;
use goose_llm::{
completion,
types::completion::{CompletionResponse, ExtensionConfig, ToolApprovalMode, ToolConfig},
types::completion::{
CompletionRequest, CompletionResponse, ExtensionConfig, ToolApprovalMode, ToolConfig,
},
Message, ModelConfig,
};
use serde_json::json;
@@ -91,13 +93,13 @@ async fn main() -> Result<()> {
println!("\n---------------\n");
println!("User Input: {text}");
let messages = vec![Message::user().with_text(text)];
let completion_response: CompletionResponse = completion(
let completion_response: CompletionResponse = completion(CompletionRequest::new(
provider,
model_config.clone(),
system_preamble,
&messages,
&extensions,
)
))
.await?;
// Print the response
println!("\nCompletion Response:");

View File

@@ -6,93 +6,127 @@ use serde_json::Value;
use crate::{
message::{Message, MessageContent},
model::ModelConfig,
prompt_template,
providers::{create, errors::ProviderError},
types::completion::{
CompletionResponse, ExtensionConfig, RuntimeMetrics, ToolApprovalMode, ToolConfig,
providers::create,
types::{
completion::{
CompletionError, CompletionRequest, CompletionResponse, ExtensionConfig,
RuntimeMetrics, ToolApprovalMode, ToolConfig,
},
core::ToolCall,
},
};
/// Refresh the `needs_approval` flag of each tool call carried by `message`.
///
/// Only `MessageContent::ToolRequest` entries whose `tool_call` is `Ok` are
/// visited. The flag is derived from the approval mode of the `ToolConfig`
/// stored under the call's name; a call with no stored config is re-set to
/// the value it already had.
pub fn update_needs_approval_for_tool_calls(
    message: &mut Message,
    tool_configs: &HashMap<String, ToolConfig>,
) {
    // Narrow the message content down to the successfully parsed tool calls.
    let calls = message.content.iter_mut().filter_map(|content| match content {
        MessageContent::ToolRequest(req) => req.tool_call.as_mut().ok(),
        _ => None,
    });

    for call in calls {
        let flag = tool_configs
            .get(&call.name)
            // Unknown tool: fall back to the flag the call already carries.
            .map_or(call.needs_approval, |cfg| match cfg.approval_mode {
                ToolApprovalMode::Auto => false,
                // TODO: implement smart approval later; until then `Smart` acts like `Manual`.
                ToolApprovalMode::Manual | ToolApprovalMode::Smart => true,
            });
        call.set_needs_approval(flag);
    }
}
/// Public API for the Goose LLM completion function
///
/// Creates the provider, renders the system prompt, gathers the extensions'
/// prefixed tools, awaits `provider.complete`, updates the approval flag on
/// the tool calls in the returned message, and returns a `CompletionResponse`
/// that also carries runtime metrics.
///
/// NOTE(review): this listing looks like a rendered diff — the signature and
/// several statements appear twice (an older positional-argument form that
/// returns `ProviderError`, and a newer `CompletionRequest` form that returns
/// `CompletionError`). Confirm against the checked-out file which lines are live.
pub async fn completion(
provider: &str,
model_config: ModelConfig,
system_preamble: &str,
messages: &[Message],
extensions: &[ExtensionConfig],
) -> Result<CompletionResponse, ProviderError> {
pub async fn completion(req: CompletionRequest<'_>) -> Result<CompletionResponse, CompletionError> {
// Start of the total elapsed-time measurement reported in the metrics.
let start_total = Instant::now();
let provider = create(provider, model_config).unwrap();
let system_prompt = construct_system_prompt(system_preamble, extensions);
// println!("\nSystem prompt: {}\n", system_prompt);
let tools = extensions
.iter()
.flat_map(|ext| ext.get_prefixed_tools())
.collect::<Vec<_>>();
// NOTE(review): `map_err(|_| ...)` drops whatever error `create` returned;
// only the provider name survives inside `UnknownProvider`.
let provider = create(req.provider_name, req.model_config)
.map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?;
let system_prompt = construct_system_prompt(req.system_preamble, req.extensions)?;
let tools = collect_prefixed_tools(req.extensions);
// Call the LLM provider
let start_provider = Instant::now();
let mut response = provider.complete(&system_prompt, messages, &tools).await?;
let total_time_ms_provider = start_provider.elapsed().as_millis();
let tokens_per_second = response.usage.total_tokens.and_then(|toks| {
if total_time_ms_provider > 0 {
Some(toks as f64 / (total_time_ms_provider as f64 / 1000.0))
} else {
None
}
});
let mut response = provider
.complete(&system_prompt, req.messages, &tools)
.await?;
// Milliseconds spent awaiting the provider call alone.
let provider_elapsed_ms = start_provider.elapsed().as_millis();
let usage_tokens = response.usage.total_tokens;
let tool_configs: HashMap<String, ToolConfig> = extensions
.iter()
.flat_map(|ext| ext.get_prefixed_tool_configs().into_iter())
.collect();
// Tool configs are looked up by `call.name` when the approval flag is set.
let tool_configs = collect_prefixed_tool_configs(req.extensions);
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?;
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs);
let total_time_ms = start_total.elapsed().as_millis();
Ok(CompletionResponse::new(
response.message,
response.model,
response.usage,
RuntimeMetrics::new(total_time_ms, total_time_ms_provider, tokens_per_second),
calculate_runtime_metrics(start_total, provider_elapsed_ms, usage_tokens),
))
}
// NOTE(review): this listing looks like a rendered diff — an older signature
// returning `String` and a newer one returning `Result<String, CompletionError>`
// both appear, as do both forms of several statements. Confirm which are live.
fn construct_system_prompt(system_preamble: &str, extensions: &[ExtensionConfig]) -> String {
/// Render the global `system.md` template with the provided context.
///
/// The context passed to the template holds `system_preamble` (the caller's
/// text), `extensions` (the extension configs converted with
/// `serde_json::to_value`), and the current UTC date — under `current_date`
/// formatted as `%Y-%m-%d` in one form, under `current_date_time` formatted
/// as `%Y-%m-%d %H:%M:%S` in the other.
///
/// In the `Result`-returning form, serialization and rendering failures are
/// propagated with `?`; the other form panics via `unwrap`/`expect` instead.
fn construct_system_prompt(
system_preamble: &str,
extensions: &[ExtensionConfig],
) -> Result<String, CompletionError> {
// Template context, keyed by the variable names the template refers to.
let mut context: HashMap<&str, Value> = HashMap::new();
context.insert("system_preamble", Value::String(system_preamble.to_owned()));
context.insert("extensions", serde_json::to_value(extensions)?);
context.insert(
"system_preamble",
Value::String(system_preamble.to_string()),
"current_date",
Value::String(Utc::now().format("%Y-%m-%d").to_string()),
);
context.insert("extensions", serde_json::to_value(extensions).unwrap());
let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
context.insert("current_date_time", Value::String(current_date_time));
// Render `system.md` with the assembled context.
prompt_template::render_global_file("system.md", &context).expect("Prompt should render")
Ok(prompt_template::render_global_file("system.md", &context)?)
}
/// Decide whether a tool call has to be approved manually.
///
/// The answer currently depends only on `config.approval_mode`; `_call` is
/// accepted but not inspected.
fn determine_needs_approval(config: &ToolConfig, _call: &ToolCall) -> bool {
    match config.approval_mode {
        ToolApprovalMode::Auto => false,
        // TODO: Implement smart approval logic later; for now `Smart` is treated like `Manual`.
        ToolApprovalMode::Manual | ToolApprovalMode::Smart => true,
    }
}
/// Refresh the `needs_approval` flag of each tool call carried by `message`.
///
/// Only `MessageContent::ToolRequest` entries whose `tool_call` is `Ok` are
/// visited, in message order.
///
/// # Errors
///
/// Returns `CompletionError::ToolNotFound` as soon as a visited call has no
/// entry in `tool_configs`. Calls visited before that one keep their updated
/// flag.
pub fn update_needs_approval_for_tool_calls(
    message: &mut Message,
    tool_configs: &HashMap<String, ToolConfig>,
) -> Result<(), CompletionError> {
    // Narrow the message content down to the successfully parsed tool calls.
    let calls = message.content.iter_mut().filter_map(|content| match content {
        MessageContent::ToolRequest(req) => req.tool_call.as_mut().ok(),
        _ => None,
    });

    for call in calls {
        // Provide a clear error message when the tool config is missing.
        let config = tool_configs.get(&call.name).ok_or_else(|| {
            CompletionError::ToolNotFound(format!(
                "could not find tool config for '{}'",
                call.name
            ))
        })?;
        let flag = determine_needs_approval(config, call);
        call.set_needs_approval(flag);
    }

    Ok(())
}
/// Gather the prefixed `Tool`s of every extension into one list, keeping the
/// order in which the extensions were given.
fn collect_prefixed_tools(extensions: &[ExtensionConfig]) -> Vec<crate::types::core::Tool> {
    let mut tools = Vec::new();
    for extension in extensions {
        tools.extend(extension.get_prefixed_tools());
    }
    tools
}
/// Gather the prefixed `ToolConfig` entries of every extension into one map.
///
/// When two entries share a key, the one produced later replaces the earlier.
fn collect_prefixed_tool_configs(extensions: &[ExtensionConfig]) -> HashMap<String, ToolConfig> {
    let mut configs = HashMap::new();
    for extension in extensions {
        configs.extend(extension.get_prefixed_tool_configs());
    }
    configs
}
/// Build the `RuntimeMetrics` for a finished request.
///
/// * `total_start` – instant at which the whole request began; the total
///   duration is read from it here, in milliseconds.
/// * `provider_elapsed_ms` – milliseconds spent in the provider call.
/// * `token_count` – token total reported for the request, if any.
///
/// The tokens-per-second figure is `None` when no token count is available
/// or when the provider time is zero milliseconds.
fn calculate_runtime_metrics(
    total_start: Instant,
    provider_elapsed_ms: u128,
    token_count: Option<i32>,
) -> RuntimeMetrics {
    let total_elapsed_ms = total_start.elapsed().as_millis();

    let throughput = match token_count {
        // Guard against dividing by a zero-length provider interval.
        Some(tokens) if provider_elapsed_ms > 0 => {
            let provider_secs = provider_elapsed_ms as f64 / 1_000.0;
            Some(f64::from(tokens) / provider_secs)
        }
        _ => None,
    };

    RuntimeMetrics::new(total_elapsed_ms, provider_elapsed_ms, throughput)
}

View File

@@ -1,6 +1,6 @@
{{system_preamble}}
The current date is {{current_date_time}}.
The current date is {{current_date}}.
Goose uses LLM providers with tool calling capability. You can be used with different language models (gpt-4o, claude-3.5-sonnet, o1, llama-3.2, deepseek-r1, etc).
These models have varying knowledge cut-off dates depending on when they were trained, but typically it's between 5-10 months prior to the current date.

View File

@@ -3,10 +3,56 @@
// https://docs.google.com/document/d/1r5vjSK3nBQU1cIRf0WKysDigqMlzzrzl_bxEE4msOiw/edit?tab=t.0
use std::collections::HashMap;
use thiserror::Error;
use serde::{Deserialize, Serialize};
use crate::{message::Message, providers::Usage};
use crate::{model::ModelConfig, providers::errors::ProviderError};
/// Input for one `completion` call.
///
/// All fields except `model_config` are borrowed for `'a`; `model_config` is
/// owned by the request.
pub struct CompletionRequest<'a> {
/// Name handed to the provider factory (`create`); it is also the text
/// reported in `CompletionError::UnknownProvider` when creation fails.
pub provider_name: &'a str,
/// Model configuration handed to the provider factory together with the name.
pub model_config: ModelConfig,
/// Text exposed to the system prompt template as `system_preamble`.
pub system_preamble: &'a str,
/// Messages passed to the provider's `complete` call.
pub messages: &'a [Message],
/// Extensions whose prefixed tools and tool configs are used for the call;
/// they are also serialized into the system prompt template context.
pub extensions: &'a [ExtensionConfig],
}
impl<'a> CompletionRequest<'a> {
/// Build a request from its parts.
///
/// Each argument is stored unchanged in the field of the same name.
pub fn new(
provider_name: &'a str,
model_config: ModelConfig,
system_preamble: &'a str,
messages: &'a [Message],
extensions: &'a [ExtensionConfig],
) -> Self {
Self {
provider_name,
model_config,
system_preamble,
messages,
extensions,
}
}
}
/// Errors returned by the `completion` entry point and its helpers.
#[derive(Debug, Error)]
pub enum CompletionError {
/// The provider could not be created; the payload is the requested provider name.
#[error("failed to create provider: {0}")]
UnknownProvider(String),
/// An error from the provider, converted automatically from `ProviderError`.
#[error("provider error: {0}")]
Provider(#[from] ProviderError),
/// A template error, converted automatically from `minijinja::Error`.
#[error("template rendering error: {0}")]
Template(#[from] minijinja::Error),
/// A JSON error, converted automatically from `serde_json::Error`.
#[error("json serialization error: {0}")]
Json(#[from] serde_json::Error),
/// A tool call named a tool that has no matching tool config; the payload
/// is a human-readable description.
#[error("tool not found error: {0}")]
ToolNotFound(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {