feat: add runtime metrics to completion response (#2404)

This commit is contained in:
Salman Mohammed
2025-05-02 13:02:26 -03:00
committed by GitHub
parent a9ebf3a9c4
commit fa93f5fbec
3 changed files with 95 additions and 45 deletions

View File

@@ -5,53 +5,13 @@ use std::collections::HashMap;
use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::base::ProviderUsage;
use goose::providers::create;
use goose::providers::errors::ProviderError;
use mcp_core::tool::Tool;
use std::time::Instant;
use crate::prompt_template;
use serde::{Deserialize, Serialize};
/// Outcome of one completion call: the returned message paired with the
/// usage record that came back with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Message returned for the request.
    message: Message,
    /// Usage record returned alongside the message.
    usage: ProviderUsage,
}

impl CompletionResponse {
    /// Bundles a message and its usage record into a response.
    pub fn new(message: Message, usage: ProviderUsage) -> Self {
        CompletionResponse { message, usage }
    }
}
/// A named group of tools, optionally accompanied by instructions text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Extension {
    /// Extension name; used as the prefix in [`Extension::get_prefixed_tools`].
    name: String,
    /// Optional instructions associated with this extension.
    instructions: Option<String>,
    /// Tools provided by this extension, stored under their unprefixed names.
    tools: Vec<Tool>,
}

impl Extension {
    /// Creates an extension from its name, optional instructions, and tools.
    pub fn new(name: String, instructions: Option<String>, tools: Vec<Tool>) -> Self {
        Extension {
            name,
            instructions,
            tools,
        }
    }

    /// Returns a copy of every tool with its name rewritten to
    /// `<extension name>__<tool name>`.
    ///
    /// The stored tools are left untouched; order is preserved.
    pub fn get_prefixed_tools(&self) -> Vec<Tool> {
        let mut prefixed = Vec::with_capacity(self.tools.len());
        for original in &self.tools {
            let mut renamed = original.clone();
            renamed.name = format!("{}__{}", self.name, original.name);
            prefixed.push(renamed);
        }
        prefixed
    }
}
use crate::{CompletionResponse, Extension, RuntimeMetrics};
/// Public API for the Goose LLM completion function
pub async fn completion(
@@ -61,6 +21,7 @@ pub async fn completion(
messages: &[Message],
extensions: &[Extension],
) -> Result<CompletionResponse, ProviderError> {
let start_total = Instant::now();
let provider = create(provider, model_config).unwrap();
let system_prompt = construct_system_prompt(system_preamble, extensions);
// println!("\nSystem prompt: {}\n", system_prompt);
@@ -69,8 +30,24 @@ pub async fn completion(
.iter()
.flat_map(|ext| ext.get_prefixed_tools())
.collect::<Vec<_>>();
let start_provider = Instant::now();
let (response, usage) = provider.complete(&system_prompt, messages, &tools).await?;
let result = CompletionResponse::new(response.clone(), usage.clone());
let total_time_ms_provider = start_provider.elapsed().as_millis();
let total_time_ms = start_total.elapsed().as_millis();
let tokens_per_second = usage.usage.total_tokens.and_then(|toks| {
if total_time_ms_provider > 0 {
Some(toks as f64 / (total_time_ms_provider as f64 / 1000.0))
} else {
None
}
});
let runtime_metrics =
RuntimeMetrics::new(total_time_ms, total_time_ms_provider, tokens_per_second);
let result = CompletionResponse::new(response.clone(), usage.clone(), runtime_metrics);
Ok(result)
}

View File

@@ -1,3 +1,6 @@
mod completion;
mod prompt_template;
pub use completion::{completion, CompletionResponse, Extension};
mod types;
pub use completion::completion;
pub use types::{CompletionResponse, Extension, RuntimeMetrics};

View File

@@ -0,0 +1,70 @@
use goose::message::Message;
use goose::providers::base::ProviderUsage;
use mcp_core::tool::Tool;
use serde::{Deserialize, Serialize};
/// Outcome of one completion call: the returned message, the usage record
/// that came back with it, and timing metrics measured for the call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Message returned for the request.
    message: Message,
    /// Usage record returned alongside the message.
    usage: ProviderUsage,
    /// Timing and throughput figures recorded for the call.
    runtime_metrics: RuntimeMetrics,
}

impl CompletionResponse {
    /// Bundles a message, its usage record, and the runtime metrics into a
    /// response.
    pub fn new(message: Message, usage: ProviderUsage, runtime_metrics: RuntimeMetrics) -> Self {
        CompletionResponse {
            message,
            usage,
            runtime_metrics,
        }
    }
}
/// Timing and throughput figures attached to a completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeMetrics {
    /// Elapsed milliseconds for the whole completion call.
    pub total_time_ms: u128,
    /// Elapsed milliseconds spent in the provider request.
    pub total_time_ms_provider: u128,
    /// Tokens per second over the provider time, or `None` when no rate
    /// could be computed (e.g. token count unavailable or zero elapsed time).
    pub tokens_per_second: Option<f64>,
}

impl RuntimeMetrics {
    /// Builds a metrics record from already-measured values; no computation
    /// or validation is performed here.
    pub fn new(
        total_time_ms: u128,
        total_time_ms_provider: u128,
        tokens_per_second: Option<f64>,
    ) -> Self {
        RuntimeMetrics {
            total_time_ms,
            total_time_ms_provider,
            tokens_per_second,
        }
    }
}
/// A named group of tools, optionally accompanied by instructions text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Extension {
    /// Extension name; used as the prefix in [`Extension::get_prefixed_tools`].
    name: String,
    /// Optional instructions associated with this extension.
    instructions: Option<String>,
    /// Tools provided by this extension, stored under their unprefixed names.
    tools: Vec<Tool>,
}

impl Extension {
    /// Creates an extension from its name, optional instructions, and tools.
    pub fn new(name: String, instructions: Option<String>, tools: Vec<Tool>) -> Self {
        Extension {
            name,
            instructions,
            tools,
        }
    }

    /// Returns a copy of every tool with its name rewritten to
    /// `<extension name>__<tool name>`.
    ///
    /// The stored tools are left untouched; order is preserved.
    pub fn get_prefixed_tools(&self) -> Vec<Tool> {
        let mut prefixed = Vec::with_capacity(self.tools.len());
        for original in &self.tools {
            let mut renamed = original.clone();
            renamed.name = format!("{}__{}", self.name, original.name);
            prefixed.push(renamed);
        }
        prefixed
    }
}