mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
[goose-llm] add completion request & error (#2451)
This commit is contained in:
@@ -3,7 +3,9 @@ use std::vec;
|
||||
use anyhow::Result;
|
||||
use goose_llm::{
|
||||
completion,
|
||||
types::completion::{CompletionResponse, ExtensionConfig, ToolApprovalMode, ToolConfig},
|
||||
types::completion::{
|
||||
CompletionRequest, CompletionResponse, ExtensionConfig, ToolApprovalMode, ToolConfig,
|
||||
},
|
||||
Message, ModelConfig,
|
||||
};
|
||||
use serde_json::json;
|
||||
@@ -91,13 +93,13 @@ async fn main() -> Result<()> {
|
||||
println!("\n---------------\n");
|
||||
println!("User Input: {text}");
|
||||
let messages = vec![Message::user().with_text(text)];
|
||||
let completion_response: CompletionResponse = completion(
|
||||
let completion_response: CompletionResponse = completion(CompletionRequest::new(
|
||||
provider,
|
||||
model_config.clone(),
|
||||
system_preamble,
|
||||
&messages,
|
||||
&extensions,
|
||||
)
|
||||
))
|
||||
.await?;
|
||||
// Print the response
|
||||
println!("\nCompletion Response:");
|
||||
|
||||
@@ -6,93 +6,127 @@ use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
message::{Message, MessageContent},
|
||||
model::ModelConfig,
|
||||
prompt_template,
|
||||
providers::{create, errors::ProviderError},
|
||||
types::completion::{
|
||||
CompletionResponse, ExtensionConfig, RuntimeMetrics, ToolApprovalMode, ToolConfig,
|
||||
providers::create,
|
||||
types::{
|
||||
completion::{
|
||||
CompletionError, CompletionRequest, CompletionResponse, ExtensionConfig,
|
||||
RuntimeMetrics, ToolApprovalMode, ToolConfig,
|
||||
},
|
||||
core::ToolCall,
|
||||
},
|
||||
};
|
||||
|
||||
/// Set `needs_approval` on *every* tool call in the message based on approval mode.
|
||||
pub fn update_needs_approval_for_tool_calls(
|
||||
message: &mut Message,
|
||||
tool_configs: &HashMap<String, ToolConfig>,
|
||||
) {
|
||||
for content in message.content.iter_mut() {
|
||||
if let MessageContent::ToolRequest(req) = content {
|
||||
if let Ok(call) = &mut req.tool_call {
|
||||
let needs = match tool_configs.get(&call.name) {
|
||||
Some(cfg) => match cfg.approval_mode {
|
||||
ToolApprovalMode::Auto => false,
|
||||
ToolApprovalMode::Manual => true,
|
||||
ToolApprovalMode::Smart => true, // TODO: implement smart approval later
|
||||
},
|
||||
None => call.needs_approval, // unknown tool: leave flag unchanged
|
||||
};
|
||||
|
||||
call.set_needs_approval(needs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Public API for the Goose LLM completion function
|
||||
pub async fn completion(
|
||||
provider: &str,
|
||||
model_config: ModelConfig,
|
||||
system_preamble: &str,
|
||||
messages: &[Message],
|
||||
extensions: &[ExtensionConfig],
|
||||
) -> Result<CompletionResponse, ProviderError> {
|
||||
pub async fn completion(req: CompletionRequest<'_>) -> Result<CompletionResponse, CompletionError> {
|
||||
let start_total = Instant::now();
|
||||
let provider = create(provider, model_config).unwrap();
|
||||
let system_prompt = construct_system_prompt(system_preamble, extensions);
|
||||
// println!("\nSystem prompt: {}\n", system_prompt);
|
||||
|
||||
let tools = extensions
|
||||
.iter()
|
||||
.flat_map(|ext| ext.get_prefixed_tools())
|
||||
.collect::<Vec<_>>();
|
||||
let provider = create(req.provider_name, req.model_config)
|
||||
.map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?;
|
||||
|
||||
let system_prompt = construct_system_prompt(req.system_preamble, req.extensions)?;
|
||||
let tools = collect_prefixed_tools(req.extensions);
|
||||
|
||||
// Call the LLM provider
|
||||
let start_provider = Instant::now();
|
||||
let mut response = provider.complete(&system_prompt, messages, &tools).await?;
|
||||
let total_time_ms_provider = start_provider.elapsed().as_millis();
|
||||
let tokens_per_second = response.usage.total_tokens.and_then(|toks| {
|
||||
if total_time_ms_provider > 0 {
|
||||
Some(toks as f64 / (total_time_ms_provider as f64 / 1000.0))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
let mut response = provider
|
||||
.complete(&system_prompt, req.messages, &tools)
|
||||
.await?;
|
||||
let provider_elapsed_ms = start_provider.elapsed().as_millis();
|
||||
let usage_tokens = response.usage.total_tokens;
|
||||
|
||||
let tool_configs: HashMap<String, ToolConfig> = extensions
|
||||
.iter()
|
||||
.flat_map(|ext| ext.get_prefixed_tool_configs().into_iter())
|
||||
.collect();
|
||||
let tool_configs = collect_prefixed_tool_configs(req.extensions);
|
||||
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?;
|
||||
|
||||
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs);
|
||||
|
||||
let total_time_ms = start_total.elapsed().as_millis();
|
||||
Ok(CompletionResponse::new(
|
||||
response.message,
|
||||
response.model,
|
||||
response.usage,
|
||||
RuntimeMetrics::new(total_time_ms, total_time_ms_provider, tokens_per_second),
|
||||
calculate_runtime_metrics(start_total, provider_elapsed_ms, usage_tokens),
|
||||
))
|
||||
}
|
||||
|
||||
fn construct_system_prompt(system_preamble: &str, extensions: &[ExtensionConfig]) -> String {
|
||||
/// Render the global `system.md` template with the provided context.
|
||||
fn construct_system_prompt(
|
||||
system_preamble: &str,
|
||||
extensions: &[ExtensionConfig],
|
||||
) -> Result<String, CompletionError> {
|
||||
let mut context: HashMap<&str, Value> = HashMap::new();
|
||||
|
||||
context.insert("system_preamble", Value::String(system_preamble.to_owned()));
|
||||
context.insert("extensions", serde_json::to_value(extensions)?);
|
||||
context.insert(
|
||||
"system_preamble",
|
||||
Value::String(system_preamble.to_string()),
|
||||
"current_date",
|
||||
Value::String(Utc::now().format("%Y-%m-%d").to_string()),
|
||||
);
|
||||
context.insert("extensions", serde_json::to_value(extensions).unwrap());
|
||||
|
||||
let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
|
||||
context.insert("current_date_time", Value::String(current_date_time));
|
||||
|
||||
prompt_template::render_global_file("system.md", &context).expect("Prompt should render")
|
||||
Ok(prompt_template::render_global_file("system.md", &context)?)
|
||||
}
|
||||
|
||||
/// Determine if a tool call requires manual approval.
|
||||
fn determine_needs_approval(config: &ToolConfig, _call: &ToolCall) -> bool {
|
||||
match config.approval_mode {
|
||||
ToolApprovalMode::Auto => false,
|
||||
ToolApprovalMode::Manual => true,
|
||||
ToolApprovalMode::Smart => {
|
||||
// TODO: Implement smart approval logic later
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set `needs_approval` on every tool call in the message.
|
||||
/// Returns a `ToolNotFound` error if the corresponding `ToolConfig` is missing.
|
||||
pub fn update_needs_approval_for_tool_calls(
|
||||
message: &mut Message,
|
||||
tool_configs: &HashMap<String, ToolConfig>,
|
||||
) -> Result<(), CompletionError> {
|
||||
for content in &mut message.content.iter_mut() {
|
||||
if let MessageContent::ToolRequest(req) = content {
|
||||
if let Ok(call) = &mut req.tool_call {
|
||||
// Provide a clear error message when the tool config is missing
|
||||
let config = tool_configs.get(&call.name).ok_or_else(|| {
|
||||
CompletionError::ToolNotFound(format!(
|
||||
"could not find tool config for '{}'",
|
||||
call.name
|
||||
))
|
||||
})?;
|
||||
let needs_approval = determine_needs_approval(config, call);
|
||||
call.set_needs_approval(needs_approval);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Collect all `Tool` instances from the extensions.
|
||||
fn collect_prefixed_tools(extensions: &[ExtensionConfig]) -> Vec<crate::types::core::Tool> {
|
||||
extensions
|
||||
.iter()
|
||||
.flat_map(|ext| ext.get_prefixed_tools())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Collect all `ToolConfig` entries from the extensions into a map.
|
||||
fn collect_prefixed_tool_configs(extensions: &[ExtensionConfig]) -> HashMap<String, ToolConfig> {
|
||||
extensions
|
||||
.iter()
|
||||
.flat_map(|ext| ext.get_prefixed_tool_configs())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute runtime metrics for the request.
|
||||
fn calculate_runtime_metrics(
|
||||
total_start: Instant,
|
||||
provider_elapsed_ms: u128,
|
||||
token_count: Option<i32>,
|
||||
) -> RuntimeMetrics {
|
||||
let total_ms = total_start.elapsed().as_millis();
|
||||
let tokens_per_sec = token_count.and_then(|toks| {
|
||||
if provider_elapsed_ms > 0 {
|
||||
Some(toks as f64 / (provider_elapsed_ms as f64 / 1_000.0))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
RuntimeMetrics::new(total_ms, provider_elapsed_ms, tokens_per_sec)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{{system_preamble}}
|
||||
|
||||
The current date is {{current_date_time}}.
|
||||
The current date is {{current_date}}.
|
||||
|
||||
Goose uses LLM providers with tool calling capability. You can be used with different language models (gpt-4o, claude-3.5-sonnet, o1, llama-3.2, deepseek-r1, etc).
|
||||
These models have varying knowledge cut-off dates depending on when they were trained, but typically it's between 5-10 months prior to the current date.
|
||||
|
||||
@@ -3,10 +3,56 @@
|
||||
// https://docs.google.com/document/d/1r5vjSK3nBQU1cIRf0WKysDigqMlzzrzl_bxEE4msOiw/edit?tab=t.0
|
||||
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{message::Message, providers::Usage};
|
||||
use crate::{model::ModelConfig, providers::errors::ProviderError};
|
||||
|
||||
pub struct CompletionRequest<'a> {
|
||||
pub provider_name: &'a str,
|
||||
pub model_config: ModelConfig,
|
||||
pub system_preamble: &'a str,
|
||||
pub messages: &'a [Message],
|
||||
pub extensions: &'a [ExtensionConfig],
|
||||
}
|
||||
|
||||
impl<'a> CompletionRequest<'a> {
|
||||
pub fn new(
|
||||
provider_name: &'a str,
|
||||
model_config: ModelConfig,
|
||||
system_preamble: &'a str,
|
||||
messages: &'a [Message],
|
||||
extensions: &'a [ExtensionConfig],
|
||||
) -> Self {
|
||||
Self {
|
||||
provider_name,
|
||||
model_config,
|
||||
system_preamble,
|
||||
messages,
|
||||
extensions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CompletionError {
|
||||
#[error("failed to create provider: {0}")]
|
||||
UnknownProvider(String),
|
||||
|
||||
#[error("provider error: {0}")]
|
||||
Provider(#[from] ProviderError),
|
||||
|
||||
#[error("template rendering error: {0}")]
|
||||
Template(#[from] minijinja::Error),
|
||||
|
||||
#[error("json serialization error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
#[error("tool not found error: {0}")]
|
||||
ToolNotFound(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CompletionResponse {
|
||||
|
||||
Reference in New Issue
Block a user