diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index bec2a2db..825dd17c 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -23,6 +23,7 @@ use goose::agents::extension::{Envs, ExtensionConfig}; use goose::agents::{Agent, SessionConfig}; use goose::config::Config; use goose::message::{Message, MessageContent}; +use goose::providers::pricing::initialize_pricing_cache; use goose::session; use input::InputResult; use mcp_core::handler::ToolError; @@ -1305,11 +1306,40 @@ impl Session { let model_config = provider.get_model_config(); let context_limit = model_config.context_limit(); + let config = Config::global(); + let show_cost = config + .get_param::<bool>("GOOSE_CLI_SHOW_COST") + .unwrap_or(false); + + let provider_name = config + .get_param::<String>("GOOSE_PROVIDER") + .unwrap_or_else(|_| "unknown".to_string()); + + // Initialize pricing cache on startup + tracing::info!("Initializing pricing cache..."); + if let Err(e) = initialize_pricing_cache().await { + tracing::warn!( + "Failed to initialize pricing cache: {e}. Pricing data may not be available." 
+ ); + } + match self.get_metadata() { Ok(metadata) => { let total_tokens = metadata.total_tokens.unwrap_or(0) as usize; output::display_context_usage(total_tokens, context_limit); + + if show_cost { + let input_tokens = metadata.input_tokens.unwrap_or(0) as usize; + let output_tokens = metadata.output_tokens.unwrap_or(0) as usize; + output::display_cost_usage( + &provider_name, + &model_config.model_name, + input_tokens, + output_tokens, + ) + .await; + } } Err(_) => { output::display_context_usage(0, context_limit); diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index deeea706..fc8d81aa 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -2,9 +2,11 @@ use bat::WrappingMode; use console::{style, Color}; use goose::config::Config; use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; +use goose::providers::pricing::get_model_pricing; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use mcp_core::prompt::PromptArgument; use mcp_core::tool::ToolCall; +use regex::Regex; use serde_json::Value; use std::cell::RefCell; use std::collections::HashMap; @@ -668,6 +670,68 @@ pub fn display_context_usage(total_tokens: usize, context_limit: usize) { ); } +fn normalize_model_name(model: &str) -> String { + let mut result = model.to_string(); + + // Remove "-latest" suffix + if result.ends_with("-latest") { + result = result.strip_suffix("-latest").unwrap().to_string(); + } + + // Remove date-like suffixes: -YYYYMMDD + let re_date = Regex::new(r"-\d{8}$").unwrap(); + if re_date.is_match(&result) { + result = re_date.replace(&result, "").to_string(); + } + + // Convert version numbers like -3-5- to -3.5- (e.g., claude-3-5-haiku -> claude-3.5-haiku) + let re_version = Regex::new(r"-(\d+)-(\d+)-").unwrap(); + if re_version.is_match(&result) { + result = re_version.replace(&result, "-$1.$2-").to_string(); + } + + result +} + +async fn estimate_cost_usd( + 
provider: &str, + model: &str, + input_tokens: usize, + output_tokens: usize, +) -> Option<f64> { + // Use the pricing module's get_model_pricing which handles model name mapping internally + let cleaned_model = normalize_model_name(model); + let pricing_info = get_model_pricing(provider, &cleaned_model).await; + + match pricing_info { + Some(pricing) => { + let input_cost = pricing.input_cost * input_tokens as f64; + let output_cost = pricing.output_cost * output_tokens as f64; + Some(input_cost + output_cost) + } + None => None, + } +} + +/// Display cost information, if price data is available. +pub async fn display_cost_usage( + provider: &str, + model: &str, + input_tokens: usize, + output_tokens: usize, +) { + if let Some(cost) = estimate_cost_usd(provider, model, input_tokens, output_tokens).await { + use console::style; + println!( + "Cost: {} USD ({} tokens: in {}, out {})", + style(format!("${:.4}", cost)).cyan(), + input_tokens + output_tokens, + input_tokens, + output_tokens + ); + } +} + pub struct McpSpinners { bars: HashMap, log_spinner: Option, diff --git a/documentation/docs/guides/environment-variables.md b/documentation/docs/guides/environment-variables.md index 5313d4cd..4f0a2f6b 100644 --- a/documentation/docs/guides/environment-variables.md +++ b/documentation/docs/guides/environment-variables.md @@ -153,6 +153,7 @@ These variables control how Goose handles [tool permissions](/docs/guides/managi | `GOOSE_TOOLSHIM_OLLAMA_MODEL` | Specifies the model for [tool call interpretation](/docs/experimental/ollama) | Model name (e.g. 
llama3.2, qwen2.5) | System default | | `GOOSE_CLI_MIN_PRIORITY` | Controls verbosity of [tool output](/docs/guides/managing-tools/adjust-tool-output) | Float between 0.0 and 1.0 | 0.0 | | `GOOSE_CLI_TOOL_PARAMS_TRUNCATION_MAX_LENGTH` | Maximum length for tool parameter values before truncation in CLI output (not in debug mode) | Integer | 40 | +| `GOOSE_CLI_SHOW_COST` | Toggles display of model cost estimates in CLI output | "true", "1" (case insensitive) to enable | false | **Examples** @@ -163,6 +164,9 @@ export GOOSE_TOOLSHIM_OLLAMA_MODEL=llama3.2 export GOOSE_MODE="auto" export GOOSE_CLI_MIN_PRIORITY=0.2 # Show only medium and high importance output export GOOSE_CLI_TOOL_PARAMS_MAX_LENGTH=100 # Show up to 100 characters for tool parameters in CLI output + +# Enable model cost display in CLI +export GOOSE_CLI_SHOW_COST=true ``` ### Enhanced Code Editing