mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
[goose-llm] add completion request & error (#2451)
This commit is contained in:
@@ -3,7 +3,9 @@ use std::vec;
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use goose_llm::{
|
use goose_llm::{
|
||||||
completion,
|
completion,
|
||||||
types::completion::{CompletionResponse, ExtensionConfig, ToolApprovalMode, ToolConfig},
|
types::completion::{
|
||||||
|
CompletionRequest, CompletionResponse, ExtensionConfig, ToolApprovalMode, ToolConfig,
|
||||||
|
},
|
||||||
Message, ModelConfig,
|
Message, ModelConfig,
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
@@ -91,13 +93,13 @@ async fn main() -> Result<()> {
|
|||||||
println!("\n---------------\n");
|
println!("\n---------------\n");
|
||||||
println!("User Input: {text}");
|
println!("User Input: {text}");
|
||||||
let messages = vec![Message::user().with_text(text)];
|
let messages = vec![Message::user().with_text(text)];
|
||||||
let completion_response: CompletionResponse = completion(
|
let completion_response: CompletionResponse = completion(CompletionRequest::new(
|
||||||
provider,
|
provider,
|
||||||
model_config.clone(),
|
model_config.clone(),
|
||||||
system_preamble,
|
system_preamble,
|
||||||
&messages,
|
&messages,
|
||||||
&extensions,
|
&extensions,
|
||||||
)
|
))
|
||||||
.await?;
|
.await?;
|
||||||
// Print the response
|
// Print the response
|
||||||
println!("\nCompletion Response:");
|
println!("\nCompletion Response:");
|
||||||
|
|||||||
@@ -6,93 +6,127 @@ use serde_json::Value;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
message::{Message, MessageContent},
|
message::{Message, MessageContent},
|
||||||
model::ModelConfig,
|
|
||||||
prompt_template,
|
prompt_template,
|
||||||
providers::{create, errors::ProviderError},
|
providers::create,
|
||||||
types::completion::{
|
types::{
|
||||||
CompletionResponse, ExtensionConfig, RuntimeMetrics, ToolApprovalMode, ToolConfig,
|
completion::{
|
||||||
|
CompletionError, CompletionRequest, CompletionResponse, ExtensionConfig,
|
||||||
|
RuntimeMetrics, ToolApprovalMode, ToolConfig,
|
||||||
|
},
|
||||||
|
core::ToolCall,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Set `needs_approval` on *every* tool call in the message based on approval mode.
|
|
||||||
pub fn update_needs_approval_for_tool_calls(
|
|
||||||
message: &mut Message,
|
|
||||||
tool_configs: &HashMap<String, ToolConfig>,
|
|
||||||
) {
|
|
||||||
for content in message.content.iter_mut() {
|
|
||||||
if let MessageContent::ToolRequest(req) = content {
|
|
||||||
if let Ok(call) = &mut req.tool_call {
|
|
||||||
let needs = match tool_configs.get(&call.name) {
|
|
||||||
Some(cfg) => match cfg.approval_mode {
|
|
||||||
ToolApprovalMode::Auto => false,
|
|
||||||
ToolApprovalMode::Manual => true,
|
|
||||||
ToolApprovalMode::Smart => true, // TODO: implement smart approval later
|
|
||||||
},
|
|
||||||
None => call.needs_approval, // unknown tool: leave flag unchanged
|
|
||||||
};
|
|
||||||
|
|
||||||
call.set_needs_approval(needs);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Public API for the Goose LLM completion function
|
/// Public API for the Goose LLM completion function
|
||||||
pub async fn completion(
|
pub async fn completion(req: CompletionRequest<'_>) -> Result<CompletionResponse, CompletionError> {
|
||||||
provider: &str,
|
|
||||||
model_config: ModelConfig,
|
|
||||||
system_preamble: &str,
|
|
||||||
messages: &[Message],
|
|
||||||
extensions: &[ExtensionConfig],
|
|
||||||
) -> Result<CompletionResponse, ProviderError> {
|
|
||||||
let start_total = Instant::now();
|
let start_total = Instant::now();
|
||||||
let provider = create(provider, model_config).unwrap();
|
|
||||||
let system_prompt = construct_system_prompt(system_preamble, extensions);
|
|
||||||
// println!("\nSystem prompt: {}\n", system_prompt);
|
|
||||||
|
|
||||||
let tools = extensions
|
let provider = create(req.provider_name, req.model_config)
|
||||||
.iter()
|
.map_err(|_| CompletionError::UnknownProvider(req.provider_name.to_string()))?;
|
||||||
.flat_map(|ext| ext.get_prefixed_tools())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
|
let system_prompt = construct_system_prompt(req.system_preamble, req.extensions)?;
|
||||||
|
let tools = collect_prefixed_tools(req.extensions);
|
||||||
|
|
||||||
|
// Call the LLM provider
|
||||||
let start_provider = Instant::now();
|
let start_provider = Instant::now();
|
||||||
let mut response = provider.complete(&system_prompt, messages, &tools).await?;
|
let mut response = provider
|
||||||
let total_time_ms_provider = start_provider.elapsed().as_millis();
|
.complete(&system_prompt, req.messages, &tools)
|
||||||
let tokens_per_second = response.usage.total_tokens.and_then(|toks| {
|
.await?;
|
||||||
if total_time_ms_provider > 0 {
|
let provider_elapsed_ms = start_provider.elapsed().as_millis();
|
||||||
Some(toks as f64 / (total_time_ms_provider as f64 / 1000.0))
|
let usage_tokens = response.usage.total_tokens;
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let tool_configs: HashMap<String, ToolConfig> = extensions
|
let tool_configs = collect_prefixed_tool_configs(req.extensions);
|
||||||
.iter()
|
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs)?;
|
||||||
.flat_map(|ext| ext.get_prefixed_tool_configs().into_iter())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
update_needs_approval_for_tool_calls(&mut response.message, &tool_configs);
|
|
||||||
|
|
||||||
let total_time_ms = start_total.elapsed().as_millis();
|
|
||||||
Ok(CompletionResponse::new(
|
Ok(CompletionResponse::new(
|
||||||
response.message,
|
response.message,
|
||||||
response.model,
|
response.model,
|
||||||
response.usage,
|
response.usage,
|
||||||
RuntimeMetrics::new(total_time_ms, total_time_ms_provider, tokens_per_second),
|
calculate_runtime_metrics(start_total, provider_elapsed_ms, usage_tokens),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn construct_system_prompt(system_preamble: &str, extensions: &[ExtensionConfig]) -> String {
|
/// Render the global `system.md` template with the provided context.
|
||||||
|
fn construct_system_prompt(
|
||||||
|
system_preamble: &str,
|
||||||
|
extensions: &[ExtensionConfig],
|
||||||
|
) -> Result<String, CompletionError> {
|
||||||
let mut context: HashMap<&str, Value> = HashMap::new();
|
let mut context: HashMap<&str, Value> = HashMap::new();
|
||||||
|
context.insert("system_preamble", Value::String(system_preamble.to_owned()));
|
||||||
|
context.insert("extensions", serde_json::to_value(extensions)?);
|
||||||
context.insert(
|
context.insert(
|
||||||
"system_preamble",
|
"current_date",
|
||||||
Value::String(system_preamble.to_string()),
|
Value::String(Utc::now().format("%Y-%m-%d").to_string()),
|
||||||
);
|
);
|
||||||
context.insert("extensions", serde_json::to_value(extensions).unwrap());
|
|
||||||
|
|
||||||
let current_date_time = Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
|
Ok(prompt_template::render_global_file("system.md", &context)?)
|
||||||
context.insert("current_date_time", Value::String(current_date_time));
|
}
|
||||||
|
|
||||||
prompt_template::render_global_file("system.md", &context).expect("Prompt should render")
|
/// Determine if a tool call requires manual approval.
|
||||||
|
fn determine_needs_approval(config: &ToolConfig, _call: &ToolCall) -> bool {
|
||||||
|
match config.approval_mode {
|
||||||
|
ToolApprovalMode::Auto => false,
|
||||||
|
ToolApprovalMode::Manual => true,
|
||||||
|
ToolApprovalMode::Smart => {
|
||||||
|
// TODO: Implement smart approval logic later
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set `needs_approval` on every tool call in the message.
|
||||||
|
/// Returns a `ToolNotFound` error if the corresponding `ToolConfig` is missing.
|
||||||
|
pub fn update_needs_approval_for_tool_calls(
|
||||||
|
message: &mut Message,
|
||||||
|
tool_configs: &HashMap<String, ToolConfig>,
|
||||||
|
) -> Result<(), CompletionError> {
|
||||||
|
for content in &mut message.content.iter_mut() {
|
||||||
|
if let MessageContent::ToolRequest(req) = content {
|
||||||
|
if let Ok(call) = &mut req.tool_call {
|
||||||
|
// Provide a clear error message when the tool config is missing
|
||||||
|
let config = tool_configs.get(&call.name).ok_or_else(|| {
|
||||||
|
CompletionError::ToolNotFound(format!(
|
||||||
|
"could not find tool config for '{}'",
|
||||||
|
call.name
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
let needs_approval = determine_needs_approval(config, call);
|
||||||
|
call.set_needs_approval(needs_approval);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Collect all `Tool` instances from the extensions.
|
||||||
|
fn collect_prefixed_tools(extensions: &[ExtensionConfig]) -> Vec<crate::types::core::Tool> {
|
||||||
|
extensions
|
||||||
|
.iter()
|
||||||
|
.flat_map(|ext| ext.get_prefixed_tools())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Collect all `ToolConfig` entries from the extensions into a map.
|
||||||
|
fn collect_prefixed_tool_configs(extensions: &[ExtensionConfig]) -> HashMap<String, ToolConfig> {
|
||||||
|
extensions
|
||||||
|
.iter()
|
||||||
|
.flat_map(|ext| ext.get_prefixed_tool_configs())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute runtime metrics for the request.
|
||||||
|
fn calculate_runtime_metrics(
|
||||||
|
total_start: Instant,
|
||||||
|
provider_elapsed_ms: u128,
|
||||||
|
token_count: Option<i32>,
|
||||||
|
) -> RuntimeMetrics {
|
||||||
|
let total_ms = total_start.elapsed().as_millis();
|
||||||
|
let tokens_per_sec = token_count.and_then(|toks| {
|
||||||
|
if provider_elapsed_ms > 0 {
|
||||||
|
Some(toks as f64 / (provider_elapsed_ms as f64 / 1_000.0))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
});
|
||||||
|
RuntimeMetrics::new(total_ms, provider_elapsed_ms, tokens_per_sec)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{{system_preamble}}
|
{{system_preamble}}
|
||||||
|
|
||||||
The current date is {{current_date_time}}.
|
The current date is {{current_date}}.
|
||||||
|
|
||||||
Goose uses LLM providers with tool calling capability. You can be used with different language models (gpt-4o, claude-3.5-sonnet, o1, llama-3.2, deepseek-r1, etc).
|
Goose uses LLM providers with tool calling capability. You can be used with different language models (gpt-4o, claude-3.5-sonnet, o1, llama-3.2, deepseek-r1, etc).
|
||||||
These models have varying knowledge cut-off dates depending on when they were trained, but typically it's between 5-10 months prior to the current date.
|
These models have varying knowledge cut-off dates depending on when they were trained, but typically it's between 5-10 months prior to the current date.
|
||||||
|
|||||||
@@ -3,10 +3,56 @@
|
|||||||
// https://docs.google.com/document/d/1r5vjSK3nBQU1cIRf0WKysDigqMlzzrzl_bxEE4msOiw/edit?tab=t.0
|
// https://docs.google.com/document/d/1r5vjSK3nBQU1cIRf0WKysDigqMlzzrzl_bxEE4msOiw/edit?tab=t.0
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::{message::Message, providers::Usage};
|
use crate::{message::Message, providers::Usage};
|
||||||
|
use crate::{model::ModelConfig, providers::errors::ProviderError};
|
||||||
|
|
||||||
|
pub struct CompletionRequest<'a> {
|
||||||
|
pub provider_name: &'a str,
|
||||||
|
pub model_config: ModelConfig,
|
||||||
|
pub system_preamble: &'a str,
|
||||||
|
pub messages: &'a [Message],
|
||||||
|
pub extensions: &'a [ExtensionConfig],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> CompletionRequest<'a> {
|
||||||
|
pub fn new(
|
||||||
|
provider_name: &'a str,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
system_preamble: &'a str,
|
||||||
|
messages: &'a [Message],
|
||||||
|
extensions: &'a [ExtensionConfig],
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
provider_name,
|
||||||
|
model_config,
|
||||||
|
system_preamble,
|
||||||
|
messages,
|
||||||
|
extensions,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum CompletionError {
|
||||||
|
#[error("failed to create provider: {0}")]
|
||||||
|
UnknownProvider(String),
|
||||||
|
|
||||||
|
#[error("provider error: {0}")]
|
||||||
|
Provider(#[from] ProviderError),
|
||||||
|
|
||||||
|
#[error("template rendering error: {0}")]
|
||||||
|
Template(#[from] minijinja::Error),
|
||||||
|
|
||||||
|
#[error("json serialization error: {0}")]
|
||||||
|
Json(#[from] serde_json::Error),
|
||||||
|
|
||||||
|
#[error("tool not found error: {0}")]
|
||||||
|
ToolNotFound(String),
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct CompletionResponse {
|
pub struct CompletionResponse {
|
||||||
|
|||||||
Reference in New Issue
Block a user