From fa93f5fbec46ef3d144ef183c0aa6932ec6551bc Mon Sep 17 00:00:00 2001
From: Salman Mohammed
Date: Fri, 2 May 2025 13:02:26 -0300
Subject: [PATCH] feat: add runtime metrics to completion response (#2404)

---
 crates/goose-llm/src/completion.rs | 65 +++++++++------------------
 crates/goose-llm/src/lib.rs        |  5 ++-
 crates/goose-llm/src/types.rs      | 70 ++++++++++++++++++++++++++++++
 3 files changed, 95 insertions(+), 45 deletions(-)
 create mode 100644 crates/goose-llm/src/types.rs

diff --git a/crates/goose-llm/src/completion.rs b/crates/goose-llm/src/completion.rs
index 649111b0..999d7748 100644
--- a/crates/goose-llm/src/completion.rs
+++ b/crates/goose-llm/src/completion.rs
@@ -5,53 +5,13 @@ use std::collections::HashMap;
 
 use goose::message::Message;
 use goose::model::ModelConfig;
-use goose::providers::base::ProviderUsage;
 use goose::providers::create;
 use goose::providers::errors::ProviderError;
-use mcp_core::tool::Tool;
+
+use std::time::Instant;
 
 use crate::prompt_template;
-use serde::{Deserialize, Serialize};
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct CompletionResponse {
-    message: Message,
-    usage: ProviderUsage,
-}
-
-impl CompletionResponse {
-    pub fn new(message: Message, usage: ProviderUsage) -> Self {
-        Self { message, usage }
-    }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Extension {
-    name: String,
-    instructions: Option<String>,
-    tools: Vec<Tool>,
-}
-
-impl Extension {
-    pub fn new(name: String, instructions: Option<String>, tools: Vec<Tool>) -> Self {
-        Self {
-            name,
-            instructions,
-            tools,
-        }
-    }
-
-    pub fn get_prefixed_tools(&self) -> Vec<Tool> {
-        self.tools
-            .iter()
-            .map(|tool| {
-                let mut prefixed_tool = tool.clone();
-                prefixed_tool.name = format!("{}__{}", self.name, tool.name);
-                prefixed_tool
-            })
-            .collect()
-    }
-}
+use crate::{CompletionResponse, Extension, RuntimeMetrics};
 
 /// Public API for the Goose LLM completion function
 pub async fn completion(
@@ -61,6 +21,7 @@ pub async fn completion(
     messages: &[Message],
     extensions: &[Extension],
 ) -> Result<CompletionResponse, ProviderError> {
+    let start_total = Instant::now();
     let provider = create(provider, model_config).unwrap();
     let system_prompt = construct_system_prompt(system_preamble, extensions);
     // println!("\nSystem prompt: {}\n", system_prompt);
@@ -69,8 +30,24 @@ pub async fn completion(
         .iter()
         .flat_map(|ext| ext.get_prefixed_tools())
         .collect::<Vec<_>>();
+
+    let start_provider = Instant::now();
     let (response, usage) = provider.complete(&system_prompt, messages, &tools).await?;
-    let result = CompletionResponse::new(response.clone(), usage.clone());
+    let total_time_ms_provider = start_provider.elapsed().as_millis();
+    let total_time_ms = start_total.elapsed().as_millis();
+
+    let tokens_per_second = usage.usage.total_tokens.and_then(|toks| {
+        if total_time_ms_provider > 0 {
+            Some(toks as f64 / (total_time_ms_provider as f64 / 1000.0))
+        } else {
+            None
+        }
+    });
+
+    let runtime_metrics =
+        RuntimeMetrics::new(total_time_ms, total_time_ms_provider, tokens_per_second);
+
+    let result = CompletionResponse::new(response.clone(), usage.clone(), runtime_metrics);
     Ok(result)
 }
 
diff --git a/crates/goose-llm/src/lib.rs b/crates/goose-llm/src/lib.rs
index ef3b28bc..9fbdbc4b 100644
--- a/crates/goose-llm/src/lib.rs
+++ b/crates/goose-llm/src/lib.rs
@@ -1,3 +1,6 @@
 mod completion;
 mod prompt_template;
-pub use completion::{completion, CompletionResponse, Extension};
+mod types;
+
+pub use completion::completion;
+pub use types::{CompletionResponse, Extension, RuntimeMetrics};
diff --git a/crates/goose-llm/src/types.rs b/crates/goose-llm/src/types.rs
new file mode 100644
index 00000000..cdc52232
--- /dev/null
+++ b/crates/goose-llm/src/types.rs
@@ -0,0 +1,70 @@
+use goose::message::Message;
+use goose::providers::base::ProviderUsage;
+use mcp_core::tool::Tool;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CompletionResponse {
+    message: Message,
+    usage: ProviderUsage,
+    runtime_metrics: RuntimeMetrics,
+}
+
+impl CompletionResponse {
+    pub fn new(message: Message, usage: ProviderUsage, runtime_metrics: RuntimeMetrics) -> Self {
+        Self {
+            message,
+            usage,
+            runtime_metrics,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RuntimeMetrics {
+    pub total_time_ms: u128,
+    pub total_time_ms_provider: u128,
+    pub tokens_per_second: Option<f64>,
+}
+
+impl RuntimeMetrics {
+    pub fn new(
+        total_time_ms: u128,
+        total_time_ms_provider: u128,
+        tokens_per_second: Option<f64>,
+    ) -> Self {
+        Self {
+            total_time_ms,
+            total_time_ms_provider,
+            tokens_per_second,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Extension {
+    name: String,
+    instructions: Option<String>,
+    tools: Vec<Tool>,
+}
+
+impl Extension {
+    pub fn new(name: String, instructions: Option<String>, tools: Vec<Tool>) -> Self {
+        Self {
+            name,
+            instructions,
+            tools,
+        }
+    }
+
+    pub fn get_prefixed_tools(&self) -> Vec<Tool> {
+        self.tools
+            .iter()
+            .map(|tool| {
+                let mut prefixed_tool = tool.clone();
+                prefixed_tool.name = format!("{}__{}", self.name, tool.name);
+                prefixed_tool
+            })
+            .collect()
+    }
+}