feat(cli): add cost estimation per provider for Goose CLI (#3330)

Co-authored-by: Chaitanya Rahalkar <chaitanyarahalkar4@gmail.com>
This commit is contained in:
Gary Zhou
2025-07-11 14:54:45 -04:00
committed by GitHub
parent 4ba991e52e
commit c45e0ef62d
3 changed files with 98 additions and 0 deletions

View File

@@ -23,6 +23,7 @@ use goose::agents::extension::{Envs, ExtensionConfig};
use goose::agents::{Agent, SessionConfig};
use goose::config::Config;
use goose::message::{Message, MessageContent};
use goose::providers::pricing::initialize_pricing_cache;
use goose::session;
use input::InputResult;
use mcp_core::handler::ToolError;
@@ -1305,11 +1306,40 @@ impl Session {
let model_config = provider.get_model_config();
let context_limit = model_config.context_limit();
let config = Config::global();
let show_cost = config
.get_param::<bool>("GOOSE_CLI_SHOW_COST")
.unwrap_or(false);
let provider_name = config
.get_param::<String>("GOOSE_PROVIDER")
.unwrap_or_else(|_| "unknown".to_string());
// Initialize pricing cache on startup
tracing::info!("Initializing pricing cache...");
if let Err(e) = initialize_pricing_cache().await {
tracing::warn!(
"Failed to initialize pricing cache: {e}. Pricing data may not be available."
);
}
match self.get_metadata() {
Ok(metadata) => {
let total_tokens = metadata.total_tokens.unwrap_or(0) as usize;
output::display_context_usage(total_tokens, context_limit);
if show_cost {
let input_tokens = metadata.input_tokens.unwrap_or(0) as usize;
let output_tokens = metadata.output_tokens.unwrap_or(0) as usize;
output::display_cost_usage(
&provider_name,
&model_config.model_name,
input_tokens,
output_tokens,
)
.await;
}
}
Err(_) => {
output::display_context_usage(0, context_limit);

View File

@@ -2,9 +2,11 @@ use bat::WrappingMode;
use console::{style, Color};
use goose::config::Config;
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
use goose::providers::pricing::get_model_pricing;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use mcp_core::prompt::PromptArgument;
use mcp_core::tool::ToolCall;
use regex::Regex;
use serde_json::Value;
use std::cell::RefCell;
use std::collections::HashMap;
@@ -668,6 +670,68 @@ pub fn display_context_usage(total_tokens: usize, context_limit: usize) {
);
}
fn normalize_model_name(model: &str) -> String {
let mut result = model.to_string();
// Remove "-latest" suffix
if result.ends_with("-latest") {
result = result.strip_suffix("-latest").unwrap().to_string();
}
// Remove date-like suffixes: -YYYYMMDD
let re_date = Regex::new(r"-\d{8}$").unwrap();
if re_date.is_match(&result) {
result = re_date.replace(&result, "").to_string();
}
// Convert version numbers like -3-5- to -3.5- (e.g., claude-3-5-haiku -> claude-3.5-haiku)
let re_version = Regex::new(r"-(\d+)-(\d+)-").unwrap();
if re_version.is_match(&result) {
result = re_version.replace(&result, "-$1.$2-").to_string();
}
result
}
async fn estimate_cost_usd(
provider: &str,
model: &str,
input_tokens: usize,
output_tokens: usize,
) -> Option<f64> {
// Use the pricing module's get_model_pricing which handles model name mapping internally
let cleaned_model = normalize_model_name(model);
let pricing_info = get_model_pricing(provider, &cleaned_model).await;
match pricing_info {
Some(pricing) => {
let input_cost = pricing.input_cost * input_tokens as f64;
let output_cost = pricing.output_cost * output_tokens as f64;
Some(input_cost + output_cost)
}
None => None,
}
}
/// Display cost information, if price data is available.
pub async fn display_cost_usage(
provider: &str,
model: &str,
input_tokens: usize,
output_tokens: usize,
) {
if let Some(cost) = estimate_cost_usd(provider, model, input_tokens, output_tokens).await {
use console::style;
println!(
"Cost: {} USD ({} tokens: in {}, out {})",
style(format!("${:.4}", cost)).cyan(),
input_tokens + output_tokens,
input_tokens,
output_tokens
);
}
}
pub struct McpSpinners {
bars: HashMap<String, ProgressBar>,
log_spinner: Option<ProgressBar>,

View File

@@ -153,6 +153,7 @@ These variables control how Goose handles [tool permissions](/docs/guides/managi
| `GOOSE_TOOLSHIM_OLLAMA_MODEL` | Specifies the model for [tool call interpretation](/docs/experimental/ollama) | Model name (e.g. llama3.2, qwen2.5) | System default |
| `GOOSE_CLI_MIN_PRIORITY` | Controls verbosity of [tool output](/docs/guides/managing-tools/adjust-tool-output) | Float between 0.0 and 1.0 | 0.0 |
| `GOOSE_CLI_TOOL_PARAMS_TRUNCATION_MAX_LENGTH` | Maximum length for tool parameter values before truncation in CLI output (not in debug mode) | Integer | 40 |
| `GOOSE_CLI_SHOW_COST` | Toggles display of model cost estimates in CLI output | "true", "1" (case insensitive) to enable | false |
**Examples**
@@ -163,6 +164,9 @@ export GOOSE_TOOLSHIM_OLLAMA_MODEL=llama3.2
export GOOSE_MODE="auto"
export GOOSE_CLI_MIN_PRIORITY=0.2 # Show only medium and high importance output
export GOOSE_CLI_TOOL_PARAMS_MAX_LENGTH=100 # Show up to 100 characters for tool parameters in CLI output
# Enable model cost display in CLI
export GOOSE_CLI_SHOW_COST=true
```
### Enhanced Code Editing