mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 06:34:26 +01:00
feat(cli): add cost estimation per provider for Goose CLI (#3330)
Co-authored-by: Chaitanya Rahalkar <chaitanyarahalkar4@gmail.com>
This commit is contained in:
@@ -23,6 +23,7 @@ use goose::agents::extension::{Envs, ExtensionConfig};
|
|||||||
use goose::agents::{Agent, SessionConfig};
|
use goose::agents::{Agent, SessionConfig};
|
||||||
use goose::config::Config;
|
use goose::config::Config;
|
||||||
use goose::message::{Message, MessageContent};
|
use goose::message::{Message, MessageContent};
|
||||||
|
use goose::providers::pricing::initialize_pricing_cache;
|
||||||
use goose::session;
|
use goose::session;
|
||||||
use input::InputResult;
|
use input::InputResult;
|
||||||
use mcp_core::handler::ToolError;
|
use mcp_core::handler::ToolError;
|
||||||
@@ -1305,11 +1306,40 @@ impl Session {
|
|||||||
let model_config = provider.get_model_config();
|
let model_config = provider.get_model_config();
|
||||||
let context_limit = model_config.context_limit();
|
let context_limit = model_config.context_limit();
|
||||||
|
|
||||||
|
let config = Config::global();
|
||||||
|
let show_cost = config
|
||||||
|
.get_param::<bool>("GOOSE_CLI_SHOW_COST")
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
let provider_name = config
|
||||||
|
.get_param::<String>("GOOSE_PROVIDER")
|
||||||
|
.unwrap_or_else(|_| "unknown".to_string());
|
||||||
|
|
||||||
|
// Initialize pricing cache on startup
|
||||||
|
tracing::info!("Initializing pricing cache...");
|
||||||
|
if let Err(e) = initialize_pricing_cache().await {
|
||||||
|
tracing::warn!(
|
||||||
|
"Failed to initialize pricing cache: {e}. Pricing data may not be available."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
match self.get_metadata() {
|
match self.get_metadata() {
|
||||||
Ok(metadata) => {
|
Ok(metadata) => {
|
||||||
let total_tokens = metadata.total_tokens.unwrap_or(0) as usize;
|
let total_tokens = metadata.total_tokens.unwrap_or(0) as usize;
|
||||||
|
|
||||||
output::display_context_usage(total_tokens, context_limit);
|
output::display_context_usage(total_tokens, context_limit);
|
||||||
|
|
||||||
|
if show_cost {
|
||||||
|
let input_tokens = metadata.input_tokens.unwrap_or(0) as usize;
|
||||||
|
let output_tokens = metadata.output_tokens.unwrap_or(0) as usize;
|
||||||
|
output::display_cost_usage(
|
||||||
|
&provider_name,
|
||||||
|
&model_config.model_name,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
output::display_context_usage(0, context_limit);
|
output::display_context_usage(0, context_limit);
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ use bat::WrappingMode;
|
|||||||
use console::{style, Color};
|
use console::{style, Color};
|
||||||
use goose::config::Config;
|
use goose::config::Config;
|
||||||
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
|
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
|
||||||
|
use goose::providers::pricing::get_model_pricing;
|
||||||
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
|
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
|
||||||
use mcp_core::prompt::PromptArgument;
|
use mcp_core::prompt::PromptArgument;
|
||||||
use mcp_core::tool::ToolCall;
|
use mcp_core::tool::ToolCall;
|
||||||
|
use regex::Regex;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@@ -668,6 +670,68 @@ pub fn display_context_usage(total_tokens: usize, context_limit: usize) {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn normalize_model_name(model: &str) -> String {
|
||||||
|
let mut result = model.to_string();
|
||||||
|
|
||||||
|
// Remove "-latest" suffix
|
||||||
|
if result.ends_with("-latest") {
|
||||||
|
result = result.strip_suffix("-latest").unwrap().to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove date-like suffixes: -YYYYMMDD
|
||||||
|
let re_date = Regex::new(r"-\d{8}$").unwrap();
|
||||||
|
if re_date.is_match(&result) {
|
||||||
|
result = re_date.replace(&result, "").to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert version numbers like -3-5- to -3.5- (e.g., claude-3-5-haiku -> claude-3.5-haiku)
|
||||||
|
let re_version = Regex::new(r"-(\d+)-(\d+)-").unwrap();
|
||||||
|
if re_version.is_match(&result) {
|
||||||
|
result = re_version.replace(&result, "-$1.$2-").to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn estimate_cost_usd(
|
||||||
|
provider: &str,
|
||||||
|
model: &str,
|
||||||
|
input_tokens: usize,
|
||||||
|
output_tokens: usize,
|
||||||
|
) -> Option<f64> {
|
||||||
|
// Use the pricing module's get_model_pricing which handles model name mapping internally
|
||||||
|
let cleaned_model = normalize_model_name(model);
|
||||||
|
let pricing_info = get_model_pricing(provider, &cleaned_model).await;
|
||||||
|
|
||||||
|
match pricing_info {
|
||||||
|
Some(pricing) => {
|
||||||
|
let input_cost = pricing.input_cost * input_tokens as f64;
|
||||||
|
let output_cost = pricing.output_cost * output_tokens as f64;
|
||||||
|
Some(input_cost + output_cost)
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Display cost information, if price data is available.
|
||||||
|
pub async fn display_cost_usage(
|
||||||
|
provider: &str,
|
||||||
|
model: &str,
|
||||||
|
input_tokens: usize,
|
||||||
|
output_tokens: usize,
|
||||||
|
) {
|
||||||
|
if let Some(cost) = estimate_cost_usd(provider, model, input_tokens, output_tokens).await {
|
||||||
|
use console::style;
|
||||||
|
println!(
|
||||||
|
"Cost: {} USD ({} tokens: in {}, out {})",
|
||||||
|
style(format!("${:.4}", cost)).cyan(),
|
||||||
|
input_tokens + output_tokens,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct McpSpinners {
|
pub struct McpSpinners {
|
||||||
bars: HashMap<String, ProgressBar>,
|
bars: HashMap<String, ProgressBar>,
|
||||||
log_spinner: Option<ProgressBar>,
|
log_spinner: Option<ProgressBar>,
|
||||||
|
|||||||
@@ -153,6 +153,7 @@ These variables control how Goose handles [tool permissions](/docs/guides/managi
|
|||||||
| `GOOSE_TOOLSHIM_OLLAMA_MODEL` | Specifies the model for [tool call interpretation](/docs/experimental/ollama) | Model name (e.g. llama3.2, qwen2.5) | System default |
|
| `GOOSE_TOOLSHIM_OLLAMA_MODEL` | Specifies the model for [tool call interpretation](/docs/experimental/ollama) | Model name (e.g. llama3.2, qwen2.5) | System default |
|
||||||
| `GOOSE_CLI_MIN_PRIORITY` | Controls verbosity of [tool output](/docs/guides/managing-tools/adjust-tool-output) | Float between 0.0 and 1.0 | 0.0 |
|
| `GOOSE_CLI_MIN_PRIORITY` | Controls verbosity of [tool output](/docs/guides/managing-tools/adjust-tool-output) | Float between 0.0 and 1.0 | 0.0 |
|
||||||
| `GOOSE_CLI_TOOL_PARAMS_TRUNCATION_MAX_LENGTH` | Maximum length for tool parameter values before truncation in CLI output (not in debug mode) | Integer | 40 |
|
| `GOOSE_CLI_TOOL_PARAMS_TRUNCATION_MAX_LENGTH` | Maximum length for tool parameter values before truncation in CLI output (not in debug mode) | Integer | 40 |
|
||||||
|
| `GOOSE_CLI_SHOW_COST` | Toggles display of model cost estimates in CLI output | "true", "1" (case insensitive) to enable | false |
|
||||||
|
|
||||||
**Examples**
|
**Examples**
|
||||||
|
|
||||||
@@ -163,6 +164,9 @@ export GOOSE_TOOLSHIM_OLLAMA_MODEL=llama3.2
|
|||||||
export GOOSE_MODE="auto"
|
export GOOSE_MODE="auto"
|
||||||
export GOOSE_CLI_MIN_PRIORITY=0.2 # Show only medium and high importance output
|
export GOOSE_CLI_MIN_PRIORITY=0.2 # Show only medium and high importance output
|
||||||
export GOOSE_CLI_TOOL_PARAMS_MAX_LENGTH=100 # Show up to 100 characters for tool parameters in CLI output
|
export GOOSE_CLI_TOOL_PARAMS_MAX_LENGTH=100 # Show up to 100 characters for tool parameters in CLI output
|
||||||
|
|
||||||
|
# Enable model cost display in CLI
|
||||||
|
export GOOSE_CLI_SHOW_COST=true
|
||||||
```
|
```
|
||||||
|
|
||||||
### Enhanced Code Editing
|
### Enhanced Code Editing
|
||||||
|
|||||||
Reference in New Issue
Block a user