mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 22:24:21 +01:00
feat(cli): add cost estimation per provider for Goose CLI (#3330)
Co-authored-by: Chaitanya Rahalkar <chaitanyarahalkar4@gmail.com>
This commit is contained in:
@@ -23,6 +23,7 @@ use goose::agents::extension::{Envs, ExtensionConfig};
|
||||
use goose::agents::{Agent, SessionConfig};
|
||||
use goose::config::Config;
|
||||
use goose::message::{Message, MessageContent};
|
||||
use goose::providers::pricing::initialize_pricing_cache;
|
||||
use goose::session;
|
||||
use input::InputResult;
|
||||
use mcp_core::handler::ToolError;
|
||||
@@ -1305,11 +1306,40 @@ impl Session {
|
||||
let model_config = provider.get_model_config();
|
||||
let context_limit = model_config.context_limit();
|
||||
|
||||
let config = Config::global();
|
||||
let show_cost = config
|
||||
.get_param::<bool>("GOOSE_CLI_SHOW_COST")
|
||||
.unwrap_or(false);
|
||||
|
||||
let provider_name = config
|
||||
.get_param::<String>("GOOSE_PROVIDER")
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
|
||||
// Initialize pricing cache on startup
|
||||
tracing::info!("Initializing pricing cache...");
|
||||
if let Err(e) = initialize_pricing_cache().await {
|
||||
tracing::warn!(
|
||||
"Failed to initialize pricing cache: {e}. Pricing data may not be available."
|
||||
);
|
||||
}
|
||||
|
||||
match self.get_metadata() {
|
||||
Ok(metadata) => {
|
||||
let total_tokens = metadata.total_tokens.unwrap_or(0) as usize;
|
||||
|
||||
output::display_context_usage(total_tokens, context_limit);
|
||||
|
||||
if show_cost {
|
||||
let input_tokens = metadata.input_tokens.unwrap_or(0) as usize;
|
||||
let output_tokens = metadata.output_tokens.unwrap_or(0) as usize;
|
||||
output::display_cost_usage(
|
||||
&provider_name,
|
||||
&model_config.model_name,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
output::display_context_usage(0, context_limit);
|
||||
|
||||
@@ -2,9 +2,11 @@ use bat::WrappingMode;
|
||||
use console::{style, Color};
|
||||
use goose::config::Config;
|
||||
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
|
||||
use goose::providers::pricing::get_model_pricing;
|
||||
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
|
||||
use mcp_core::prompt::PromptArgument;
|
||||
use mcp_core::tool::ToolCall;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
@@ -668,6 +670,68 @@ pub fn display_context_usage(total_tokens: usize, context_limit: usize) {
|
||||
);
|
||||
}
|
||||
|
||||
fn normalize_model_name(model: &str) -> String {
|
||||
let mut result = model.to_string();
|
||||
|
||||
// Remove "-latest" suffix
|
||||
if result.ends_with("-latest") {
|
||||
result = result.strip_suffix("-latest").unwrap().to_string();
|
||||
}
|
||||
|
||||
// Remove date-like suffixes: -YYYYMMDD
|
||||
let re_date = Regex::new(r"-\d{8}$").unwrap();
|
||||
if re_date.is_match(&result) {
|
||||
result = re_date.replace(&result, "").to_string();
|
||||
}
|
||||
|
||||
// Convert version numbers like -3-5- to -3.5- (e.g., claude-3-5-haiku -> claude-3.5-haiku)
|
||||
let re_version = Regex::new(r"-(\d+)-(\d+)-").unwrap();
|
||||
if re_version.is_match(&result) {
|
||||
result = re_version.replace(&result, "-$1.$2-").to_string();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
async fn estimate_cost_usd(
|
||||
provider: &str,
|
||||
model: &str,
|
||||
input_tokens: usize,
|
||||
output_tokens: usize,
|
||||
) -> Option<f64> {
|
||||
// Use the pricing module's get_model_pricing which handles model name mapping internally
|
||||
let cleaned_model = normalize_model_name(model);
|
||||
let pricing_info = get_model_pricing(provider, &cleaned_model).await;
|
||||
|
||||
match pricing_info {
|
||||
Some(pricing) => {
|
||||
let input_cost = pricing.input_cost * input_tokens as f64;
|
||||
let output_cost = pricing.output_cost * output_tokens as f64;
|
||||
Some(input_cost + output_cost)
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Display cost information, if price data is available.
|
||||
pub async fn display_cost_usage(
|
||||
provider: &str,
|
||||
model: &str,
|
||||
input_tokens: usize,
|
||||
output_tokens: usize,
|
||||
) {
|
||||
if let Some(cost) = estimate_cost_usd(provider, model, input_tokens, output_tokens).await {
|
||||
use console::style;
|
||||
println!(
|
||||
"Cost: {} USD ({} tokens: in {}, out {})",
|
||||
style(format!("${:.4}", cost)).cyan(),
|
||||
input_tokens + output_tokens,
|
||||
input_tokens,
|
||||
output_tokens
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct McpSpinners {
|
||||
bars: HashMap<String, ProgressBar>,
|
||||
log_spinner: Option<ProgressBar>,
|
||||
|
||||
@@ -153,6 +153,7 @@ These variables control how Goose handles [tool permissions](/docs/guides/managi
|
||||
| `GOOSE_TOOLSHIM_OLLAMA_MODEL` | Specifies the model for [tool call interpretation](/docs/experimental/ollama) | Model name (e.g. llama3.2, qwen2.5) | System default |
|
||||
| `GOOSE_CLI_MIN_PRIORITY` | Controls verbosity of [tool output](/docs/guides/managing-tools/adjust-tool-output) | Float between 0.0 and 1.0 | 0.0 |
|
||||
| `GOOSE_CLI_TOOL_PARAMS_TRUNCATION_MAX_LENGTH` | Maximum length for tool parameter values before truncation in CLI output (not in debug mode) | Integer | 40 |
|
||||
| `GOOSE_CLI_SHOW_COST` | Toggles display of model cost estimates in CLI output | "true", "1" (case insensitive) to enable | false |
|
||||
|
||||
**Examples**
|
||||
|
||||
@@ -163,6 +164,9 @@ export GOOSE_TOOLSHIM_OLLAMA_MODEL=llama3.2
|
||||
export GOOSE_MODE="auto"
|
||||
export GOOSE_CLI_MIN_PRIORITY=0.2 # Show only medium and high importance output
|
||||
export GOOSE_CLI_TOOL_PARAMS_TRUNCATION_MAX_LENGTH=100 # Show up to 100 characters for tool parameters in CLI output
|
||||
|
||||
# Enable model cost display in CLI
|
||||
export GOOSE_CLI_SHOW_COST=true
|
||||
```
|
||||
|
||||
### Enhanced Code Editing
|
||||
|
||||
Reference in New Issue
Block a user