Fix cost tracking accuracy and OpenRouter model pricing (#3189)
Co-authored-by: jack <> Co-authored-by: angiejones <jones.angie@gmail.com>
@@ -13,7 +13,9 @@ use goose::config::{extensions::name_to_key, PermissionManager};
use goose::config::{ExtensionConfigManager, ExtensionEntry};
use goose::model::ModelConfig;
use goose::providers::base::ProviderMetadata;
use goose::providers::pricing::{get_all_pricing, get_model_pricing, refresh_pricing};
use goose::providers::pricing::{
    get_all_pricing, get_model_pricing, parse_model_id, refresh_pricing,
};
use goose::providers::providers as get_providers;
use goose::{agents::ExtensionConfig, config::permission::PermissionLevel};
use http::{HeaderMap, StatusCode};
@@ -390,8 +392,22 @@ pub async fn get_pricing(
        }

        for model_info in &metadata.known_models {
            // Try to get pricing from cache
            if let Some(pricing) = get_model_pricing(&metadata.name, &model_info.name).await {
            // Handle OpenRouter models specially - they store full provider/model names
            let (lookup_provider, lookup_model) = if metadata.name == "openrouter" {
                // For OpenRouter, parse the model name to extract real provider/model
                if let Some((provider, model)) = parse_model_id(&model_info.name) {
                    (provider, model)
                } else {
                    // Fallback if parsing fails
                    (metadata.name.clone(), model_info.name.clone())
                }
            } else {
                // For other providers, use names as-is
                (metadata.name.clone(), model_info.name.clone())
            };

            // Only get pricing from OpenRouter cache
            if let Some(pricing) = get_model_pricing(&lookup_provider, &lookup_model).await {
                pricing_data.push(PricingData {
                    provider: metadata.name.clone(),
                    model: model_info.name.clone(),
@@ -401,27 +417,12 @@ pub async fn get_pricing(
                    context_length: pricing.context_length,
                });
            }
            // Check if the model has embedded pricing data
            else if let (Some(input_cost), Some(output_cost)) =
                (model_info.input_token_cost, model_info.output_token_cost)
            {
                pricing_data.push(PricingData {
                    provider: metadata.name.clone(),
                    model: model_info.name.clone(),
                    input_token_cost: input_cost,
                    output_token_cost: output_cost,
                    currency: model_info
                        .currency
                        .clone()
                        .unwrap_or_else(|| "$".to_string()),
                    context_length: Some(model_info.context_limit as u32),
                });
            }
            // No fallback to hardcoded prices
        }
    }
}

    tracing::info!(
    tracing::debug!(
        "Returning pricing for {} models{}",
        pricing_data.len(),
        if configured_only {
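Note on the hunk above: every pricing lookup is now keyed by a (lookup_provider, lookup_model) pair. For OpenRouter entries the stored id is "realProvider/model" and is split first; everything else is looked up as-is. A minimal TypeScript sketch of that decision (illustrative helper, not part of the commit, mirroring parse_model_id):

// Sketch only: how lookup names are derived for the pricing cache.
function lookupNames(providerName: string, modelName: string): [string, string] {
  if (providerName === 'openrouter') {
    // OpenRouter stores "realProvider/model", e.g. "anthropic/claude-sonnet-4"
    const parts = modelName.split('/');
    if (parts.length === 2) return [parts[0], parts[1]];
  }
  // Any other provider (or an unparseable id) is looked up as-is
  return [providerName, modelName];
}

// lookupNames('openrouter', 'anthropic/claude-sonnet-4') -> ['anthropic', 'claude-sonnet-4']
// lookupNames('openai', 'gpt-4o')                        -> ['openai', 'gpt-4o']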
crates/goose-server/tests/pricing_api_test.rs (new file, 41 lines)
@@ -0,0 +1,41 @@
use axum::http::StatusCode;
use axum::Router;
use axum::{body::Body, http::Request};
use etcetera::AppStrategy;
use serde_json::json;
use std::sync::Arc;
use tower::ServiceExt;

async fn create_test_app() -> Router {
    let agent = Arc::new(goose::agents::Agent::default());
    let state = goose_server::AppState::new(agent, "test".to_string()).await;

    // Add scheduler setup like in the existing tests
    let sched_storage_path = etcetera::choose_app_strategy(goose::config::APP_STRATEGY.clone())
        .unwrap()
        .data_dir()
        .join("schedules.json");
    let sched = goose::scheduler_factory::SchedulerFactory::create_legacy(sched_storage_path)
        .await
        .unwrap();
    state.set_scheduler(sched).await;

    goose_server::routes::config_management::routes(state)
}

#[tokio::test]
async fn test_pricing_endpoint_basic() {
    // Basic test to ensure pricing endpoint responds correctly
    let app = create_test_app().await;

    let request = Request::builder()
        .uri("/config/pricing")
        .method("POST")
        .header("content-type", "application/json")
        .header("x-secret-key", "test")
        .body(Body::from(json!({"configured_only": true}).to_string()))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
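The test above drives the same endpoint a frontend client would. A hedged TypeScript sketch of the equivalent request (base URL and secret value are assumptions for illustration; the path, headers, and body shape come from the test):

// Sketch only: POST /config/pricing the way the test does.
async function fetchConfiguredPricing(baseUrl: string, secretKey: string) {
  const response = await fetch(`${baseUrl}/config/pricing`, {
    method: 'POST',
    headers: {
      'content-type': 'application/json',
      'x-secret-key': secretKey, // same header the test sends
    },
    body: JSON.stringify({ configured_only: true }),
  });
  if (!response.ok) throw new Error(`pricing request failed: ${response.status}`);
  return response.json();
}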
@@ -130,17 +130,17 @@ impl Provider for AnthropicProvider {
            "Claude and other models from Anthropic",
            ANTHROPIC_DEFAULT_MODEL,
            vec![
                ModelInfo::with_cost("claude-sonnet-4-latest", 200000, 0.000015, 0.000075),
                ModelInfo::with_cost("claude-sonnet-4-20250514", 200000, 0.000015, 0.000075),
                ModelInfo::with_cost("claude-opus-4-latest", 200000, 0.000025, 0.000125),
                ModelInfo::with_cost("claude-opus-4-20250514", 200000, 0.000025, 0.000125),
                ModelInfo::with_cost("claude-3-7-sonnet-latest", 200000, 0.000008, 0.000024),
                ModelInfo::with_cost("claude-3-7-sonnet-20250219", 200000, 0.000008, 0.000024),
                ModelInfo::with_cost("claude-3-5-sonnet-20241022", 200000, 0.000003, 0.000015),
                ModelInfo::with_cost("claude-3-5-haiku-20241022", 200000, 0.000001, 0.000005),
                ModelInfo::with_cost("claude-3-opus-20240229", 200000, 0.000015, 0.000075),
                ModelInfo::with_cost("claude-3-sonnet-20240229", 200000, 0.000003, 0.000015),
                ModelInfo::with_cost("claude-3-haiku-20240307", 200000, 0.00000025, 0.00000125),
                ModelInfo::new("claude-sonnet-4-latest", 200000),
                ModelInfo::new("claude-sonnet-4-20250514", 200000),
                ModelInfo::new("claude-opus-4-latest", 200000),
                ModelInfo::new("claude-opus-4-20250514", 200000),
                ModelInfo::new("claude-3-7-sonnet-latest", 200000),
                ModelInfo::new("claude-3-7-sonnet-20250219", 200000),
                ModelInfo::new("claude-3-5-sonnet-20241022", 200000),
                ModelInfo::new("claude-3-5-haiku-20241022", 200000),
                ModelInfo::new("claude-3-opus-20240229", 200000),
                ModelInfo::new("claude-3-sonnet-20240229", 200000),
                ModelInfo::new("claude-3-haiku-20240307", 200000),
            ],
            ANTHROPIC_DOC_URL,
            vec![
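The removed with_cost entries encoded USD-per-token prices (input, output) next to the context limit, so a call's cost is just tokens times the per-token rate. An illustrative TypeScript sketch of that arithmetic (the rates quoted are the ones deleted above):

// Sketch of the arithmetic the with_cost entries encoded.
function callCostUsd(
  inputTokens: number,
  outputTokens: number,
  inputCostPerToken: number, // e.g. 0.000003 for claude-3-5-sonnet-20241022
  outputCostPerToken: number // e.g. 0.000015
): number {
  return inputTokens * inputCostPerToken + outputTokens * outputCostPerToken;
}

// 1,000 input + 500 output tokens on claude-3-5-sonnet-20241022:
// 1000 * 0.000003 + 500 * 0.000015 = 0.003 + 0.0075 = $0.0105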
@@ -237,33 +237,61 @@ pub fn response_to_message(response: Value) -> Result<Message>
pub fn get_usage(data: &Value) -> Result<Usage> {
    // Extract usage data if available
    if let Some(usage) = data.get("usage") {
        // Sum up all input token types:
        // - input_tokens (fresh/uncached)
        // - cache_creation_input_tokens (being written to cache)
        // - cache_read_input_tokens (read from cache)
        let total_input_tokens = usage
        // Get all token fields for analysis
        let input_tokens = usage
            .get("input_tokens")
            .and_then(|v| v.as_u64())
            .unwrap_or(0)
            + usage
                .get("cache_creation_input_tokens")
                .and_then(|v| v.as_u64())
                .unwrap_or(0)
            + usage
                .get("cache_read_input_tokens")
                .and_then(|v| v.as_u64())
                .unwrap_or(0);
            .unwrap_or(0);

        let input_tokens = Some(total_input_tokens as i32);
        let cache_creation_tokens = usage
            .get("cache_creation_input_tokens")
            .and_then(|v| v.as_u64())
            .unwrap_or(0);

        let cache_read_tokens = usage
            .get("cache_read_input_tokens")
            .and_then(|v| v.as_u64())
            .unwrap_or(0);

        let output_tokens = usage
            .get("output_tokens")
            .and_then(|v| v.as_u64())
            .map(|v| v as i32);
            .unwrap_or(0);

        let total_tokens = output_tokens.map(|o| total_input_tokens as i32 + o);
        // IMPORTANT: Based on the API responses, when caching is used:
        // - input_tokens is ONLY the new/fresh tokens (can be very small, like 7)
        // - cache_creation_input_tokens and cache_read_input_tokens are the cached content
        // - These cached tokens are charged at different rates:
        //   * Fresh input tokens: 100% of regular price
        //   * Cache creation tokens: 125% of regular price
        //   * Cache read tokens: 10% of regular price
        //
        // Calculate effective input tokens for cost calculation based on Anthropic's pricing:
        // - Fresh input tokens: 100% of regular price (1.0x)
        // - Cache creation tokens: 125% of regular price (1.25x)
        // - Cache read tokens: 10% of regular price (0.10x)
        //
        // The effective input tokens represent the cost-equivalent tokens when multiplied
        // by the regular input price, ensuring accurate cost calculations in the frontend.
        let effective_input_tokens = input_tokens as f64 * 1.0
            + cache_creation_tokens as f64 * 1.25
            + cache_read_tokens as f64 * 0.10;

        Ok(Usage::new(input_tokens, output_tokens, total_tokens))
        // For token counting purposes, we still want to show the actual total count
        let _total_actual_tokens = input_tokens + cache_creation_tokens + cache_read_tokens;

        // Return the effective input tokens for cost calculation
        // This ensures the frontend cost calculation is accurate when multiplying by regular prices
        let effective_input_i32 = effective_input_tokens.round().clamp(0.0, i32::MAX as f64) as i32;
        let output_tokens_i32 = output_tokens.min(i32::MAX as u64) as i32;
        let total_tokens_i32 =
            (effective_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32;

        Ok(Usage::new(
            Some(effective_input_i32),
            Some(output_tokens_i32),
            Some(total_tokens_i32),
        ))
    } else {
        tracing::debug!(
            "Failed to get usage data: {}",
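The weighting above turns three token counts into one cost-equivalent number. A TypeScript sketch of the same arithmetic, using the multipliers from the comments (1.0x fresh, 1.25x cache writes, 0.10x cache reads):

// Sketch only: effective input tokens so that multiplying by the regular
// input price yields an accurate cost even when prompt caching is in play.
function effectiveInputTokens(
  inputTokens: number,        // fresh/uncached tokens
  cacheCreationTokens: number, // tokens being written to cache (125% rate)
  cacheReadTokens: number      // tokens read from cache (10% rate)
): number {
  return Math.round(
    inputTokens * 1.0 + cacheCreationTokens * 1.25 + cacheReadTokens * 0.1
  );
}

// effectiveInputTokens(7, 10000, 5000) === 13007  (7 + 12500 + 500)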
@@ -387,9 +415,9 @@ mod tests {
            panic!("Expected Text content");
        }

        assert_eq!(usage.input_tokens, Some(24)); // 12 + 12 + 0
        assert_eq!(usage.input_tokens, Some(27)); // 12 * 1.0 + 12 * 1.25 = 27 effective tokens
        assert_eq!(usage.output_tokens, Some(15));
        assert_eq!(usage.total_tokens, Some(39)); // 24 + 15
        assert_eq!(usage.total_tokens, Some(42)); // 27 + 15

        Ok(())
    }
@@ -430,9 +458,9 @@ mod tests {
            panic!("Expected ToolRequest content");
        }

        assert_eq!(usage.input_tokens, Some(30)); // 15 + 15 + 0
        assert_eq!(usage.input_tokens, Some(34)); // 15 * 1.0 + 15 * 1.25 = 33.75 → 34 effective tokens
        assert_eq!(usage.output_tokens, Some(20));
        assert_eq!(usage.total_tokens, Some(50)); // 30 + 20
        assert_eq!(usage.total_tokens, Some(54)); // 34 + 20

        Ok(())
    }
@@ -631,4 +659,37 @@ mod tests {
        // Return the test result
        result
    }

    #[test]
    fn test_cache_pricing_calculation() -> Result<()> {
        // Test realistic cache scenario: small fresh input, large cached content
        let response = json!({
            "id": "msg_cache_test",
            "type": "message",
            "role": "assistant",
            "content": [{
                "type": "text",
                "text": "Based on the cached context, here's my response."
            }],
            "model": "claude-3-5-sonnet-latest",
            "stop_reason": "end_turn",
            "stop_sequence": null,
            "usage": {
                "input_tokens": 7, // Small fresh input
                "output_tokens": 50, // Output tokens
                "cache_creation_input_tokens": 10000, // Large cache creation
                "cache_read_input_tokens": 5000 // Large cache read
            }
        });

        let usage = get_usage(&response)?;

        // Effective input tokens should be:
        // 7 * 1.0 + 10000 * 1.25 + 5000 * 0.10 = 7 + 12500 + 500 = 13007
        assert_eq!(usage.input_tokens, Some(13007));
        assert_eq!(usage.output_tokens, Some(50));
        assert_eq!(usage.total_tokens, Some(13057)); // 13007 + 50

        Ok(())
    }
}
@@ -132,13 +132,13 @@ impl Provider for OpenAiProvider {
            "GPT-4 and other OpenAI models, including OpenAI compatible ones",
            OPEN_AI_DEFAULT_MODEL,
            vec![
                ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001),
                ModelInfo::with_cost("gpt-4o-mini", 128000, 0.00000015, 0.0000006),
                ModelInfo::with_cost("gpt-4-turbo", 128000, 0.00001, 0.00003),
                ModelInfo::with_cost("gpt-3.5-turbo", 16385, 0.0000005, 0.0000015),
                ModelInfo::with_cost("o1", 200000, 0.000015, 0.00006),
                ModelInfo::with_cost("o3", 200000, 0.000015, 0.00006), // Using o1 pricing as placeholder
                ModelInfo::with_cost("o4-mini", 128000, 0.000003, 0.000012), // Using o1-mini pricing as placeholder
                ModelInfo::new("gpt-4o", 128000),
                ModelInfo::new("gpt-4o-mini", 128000),
                ModelInfo::new("gpt-4-turbo", 128000),
                ModelInfo::new("gpt-3.5-turbo", 16385),
                ModelInfo::new("o1", 200000),
                ModelInfo::new("o3", 200000),
                ModelInfo::new("o4-mini", 128000),
            ],
            OPEN_AI_DOC_URL,
            vec![
@@ -70,13 +70,13 @@ impl PricingCache {
        let age_days = (now - cached.fetched_at) / (24 * 60 * 60);

        if age_days < CACHE_TTL_DAYS {
            tracing::info!(
            tracing::debug!(
                "Loaded pricing data from disk cache (age: {} days)",
                age_days
            );
            Ok(Some(cached))
        } else {
            tracing::info!("Disk cache expired (age: {} days)", age_days);
            tracing::debug!("Disk cache expired (age: {} days)", age_days);
            Ok(None)
        }
    }
@@ -102,7 +102,7 @@ impl PricingCache {
        let json_data = serde_json::to_vec_pretty(data)?;
        tokio::fs::write(&cache_path, json_data).await?;

        tracing::info!("Saved pricing data to disk cache");
        tracing::debug!("Saved pricing data to disk cache");
        Ok(())
    }

@@ -177,7 +177,7 @@ impl PricingCache {
            .values()
            .map(|models| models.len())
            .sum();
        tracing::info!(
        tracing::debug!(
            "Fetched pricing for {} providers with {} total models from OpenRouter",
            cached_data.pricing.len(),
            total_models
@@ -201,7 +201,7 @@ impl PricingCache {
        if let Ok(Some(cached)) = self.load_from_disk().await {
            // Log how many models we have cached
            let total_models: usize = cached.pricing.values().map(|models| models.len()).sum();
            tracing::info!(
            tracing::debug!(
                "Loaded {} providers with {} total models from disk cache",
                cached.pricing.len(),
                total_models
@@ -217,7 +217,7 @@ impl PricingCache {
        }

        // If no disk cache, fetch from OpenRouter
        tracing::info!("No valid disk cache found, fetching from OpenRouter");
        tracing::info!("Fetching pricing data from OpenRouter API");
        self.refresh().await
    }
}
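The freshness check above works in whole days on unix-second timestamps. An equivalent TypeScript sketch (the TTL value mirrors the frontend's 7-day constant elsewhere in this commit; the Rust CACHE_TTL_DAYS value itself is an assumption here):

// Sketch only: disk-cache freshness check on unix-second timestamps.
const CACHE_TTL_DAYS = 7; // assumed value, for illustration

function isCacheFresh(fetchedAtSecs: number, nowSecs: number): boolean {
  // Integer day count, matching Rust's integer division
  const ageDays = Math.floor((nowSecs - fetchedAtSecs) / (24 * 60 * 60));
  return ageDays < CACHE_TTL_DAYS;
}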
@@ -376,6 +376,12 @@ mod tests {
            Some(("openai".to_string(), "gpt-4".to_string()))
        );
        assert_eq!(parse_model_id("invalid-format"), None);

        // Test the specific model causing issues
        assert_eq!(
            parse_model_id("anthropic/claude-sonnet-4"),
            Some(("anthropic".to_string(), "claude-sonnet-4".to_string()))
        );
    }

    #[test]
@@ -384,4 +390,41 @@ mod tests {
        assert_eq!(convert_pricing("0.015"), Some(0.015));
        assert_eq!(convert_pricing("invalid"), None);
    }

    #[tokio::test]
    async fn test_claude_sonnet_4_pricing_lookup() {
        // Initialize the cache to load from disk
        if let Err(e) = initialize_pricing_cache().await {
            println!("Failed to initialize pricing cache: {}", e);
            return;
        }

        // Test lookup for the specific model
        let pricing = get_model_pricing("anthropic", "claude-sonnet-4").await;

        println!(
            "Pricing lookup result for anthropic/claude-sonnet-4: {:?}",
            pricing
        );

        // Should find pricing data
        if let Some(pricing_info) = pricing {
            assert!(pricing_info.input_cost > 0.0);
            assert!(pricing_info.output_cost > 0.0);
            println!(
                "Found pricing: input={}, output={}",
                pricing_info.input_cost, pricing_info.output_cost
            );
        } else {
            // Print debug info
            let all_pricing = get_all_pricing().await;
            if let Some(anthropic_models) = all_pricing.get("anthropic") {
                println!("Available anthropic models in cache:");
                for model_name in anthropic_models.keys() {
                    println!("  {}", model_name);
                }
            }
            panic!("Expected to find pricing for anthropic/claude-sonnet-4");
        }
    }
}
crates/goose/tests/pricing_integration_test.rs (new file, 136 lines)
@@ -0,0 +1,136 @@
use goose::providers::pricing::{get_model_pricing, initialize_pricing_cache, refresh_pricing};
use std::time::Instant;

#[tokio::test]
async fn test_pricing_cache_performance() {
    // Initialize the cache
    let start = Instant::now();
    initialize_pricing_cache()
        .await
        .expect("Failed to initialize pricing cache");
    let init_duration = start.elapsed();
    println!("Cache initialization took: {:?}", init_duration);

    // Test fetching pricing for common models (using actual model names from OpenRouter)
    let models = vec![
        ("anthropic", "claude-3.5-sonnet"),
        ("openai", "gpt-4o"),
        ("openai", "gpt-4o-mini"),
        ("google", "gemini-flash-1.5"),
        ("anthropic", "claude-sonnet-4"),
    ];

    // First fetch (should hit cache)
    let start = Instant::now();
    for (provider, model) in &models {
        let pricing = get_model_pricing(provider, model).await;
        assert!(
            pricing.is_some(),
            "Expected pricing for {}/{}",
            provider,
            model
        );
    }
    let first_fetch_duration = start.elapsed();
    println!(
        "First fetch of {} models took: {:?}",
        models.len(),
        first_fetch_duration
    );

    // Second fetch (definitely from cache)
    let start = Instant::now();
    for (provider, model) in &models {
        let pricing = get_model_pricing(provider, model).await;
        assert!(
            pricing.is_some(),
            "Expected pricing for {}/{}",
            provider,
            model
        );
    }
    let second_fetch_duration = start.elapsed();
    println!(
        "Second fetch of {} models took: {:?}",
        models.len(),
        second_fetch_duration
    );

    // Cache fetch should be significantly faster
    // Note: Both fetches are already very fast (microseconds), so we just ensure
    // the second fetch is not slower than the first (allowing for some variance)
    assert!(
        second_fetch_duration <= first_fetch_duration * 2,
        "Cache fetch should not be significantly slower than initial fetch. First: {:?}, Second: {:?}",
        first_fetch_duration,
        second_fetch_duration
    );
}

#[tokio::test]
async fn test_pricing_refresh() {
    // Initialize first
    initialize_pricing_cache()
        .await
        .expect("Failed to initialize pricing cache");

    // Get initial pricing (using a model that actually exists)
    let initial_pricing = get_model_pricing("anthropic", "claude-3.5-sonnet").await;
    assert!(initial_pricing.is_some(), "Expected initial pricing");

    // Force refresh
    let start = Instant::now();
    refresh_pricing().await.expect("Failed to refresh pricing");
    let refresh_duration = start.elapsed();
    println!("Pricing refresh took: {:?}", refresh_duration);

    // Get pricing after refresh
    let refreshed_pricing = get_model_pricing("anthropic", "claude-3.5-sonnet").await;
    assert!(
        refreshed_pricing.is_some(),
        "Expected pricing after refresh"
    );
}

#[tokio::test]
async fn test_model_not_in_openrouter() {
    initialize_pricing_cache()
        .await
        .expect("Failed to initialize pricing cache");

    // Test a model that likely doesn't exist
    let pricing = get_model_pricing("fake-provider", "fake-model").await;
    assert!(
        pricing.is_none(),
        "Should return None for non-existent model"
    );
}

#[tokio::test]
async fn test_concurrent_access() {
    use tokio::task;

    initialize_pricing_cache()
        .await
        .expect("Failed to initialize pricing cache");

    // Spawn multiple tasks to access pricing concurrently
    let mut handles = vec![];

    for i in 0..10 {
        let handle = task::spawn(async move {
            let start = Instant::now();
            let pricing = get_model_pricing("openai", "gpt-4o").await;
            let duration = start.elapsed();
            (i, pricing.is_some(), duration)
        });
        handles.push(handle);
    }

    // Wait for all tasks
    for handle in handles {
        let (task_id, has_pricing, duration) = handle.await.unwrap();
        assert!(has_pricing, "Task {} should have gotten pricing", task_id);
        println!("Task {} took: {:?}", task_id, duration);
    }
}
@@ -10,7 +10,7 @@
    "license": {
      "name": "Apache-2.0"
    },
    "version": "1.0.30"
    "version": "1.0.31"
  },
  "paths": {
    "/agent/tools": {
@@ -619,10 +619,22 @@ function ChatContent({
        }));
      }

      // Reset token counters for the new model
      setSessionTokenCount(0);
      setSessionInputTokens(0);
      setSessionOutputTokens(0);
      // Restore token counters from session metadata instead of resetting to 0
      // This preserves the accumulated session tokens when switching models
      // and ensures cost tracking remains accurate across model changes
      if (sessionMetadata) {
        // Use Math.max to ensure non-negative values and handle potential data issues
        setSessionTokenCount(Math.max(0, sessionMetadata.totalTokens || 0));
        setSessionInputTokens(Math.max(0, sessionMetadata.accumulatedInputTokens || 0));
        setSessionOutputTokens(Math.max(0, sessionMetadata.accumulatedOutputTokens || 0));
      } else {
        // Fallback: if no session metadata, preserve current session tokens instead of resetting
        // This handles edge cases where metadata might not be available yet
        console.warn(
          'No session metadata available during model change, preserving current tokens'
        );
      }
      // Only reset local token estimation counters since they're model-specific
      setLocalInputTokens(0);
      setLocalOutputTokens(0);

@@ -631,7 +643,7 @@ function ChatContent({
        `${prevProviderRef.current}/${prevModelRef.current}`,
        'to',
        `${currentProvider}/${currentModel}`,
        '- saved costs and reset token counters'
        '- saved costs and restored session token counters'
      );
    }

@@ -644,6 +656,7 @@ function ChatContent({
    sessionOutputTokens,
    localInputTokens,
    localOutputTokens,
    sessionMetadata,
  ]);

  const handleDrop = (e: React.DragEvent<HTMLDivElement>) => {
@@ -66,9 +66,7 @@ export function CostTracker({ inputTokens = 0, outputTokens = 0, sessionCosts }:
    initializeCostDatabase();

    // Update costs for all models in background
    updateAllModelCosts().catch((error) => {
      console.error('Failed to update model costs:', error);
    });
    updateAllModelCosts().catch(() => {});
  }, [getProviders]);

  useEffect(() => {
@@ -78,18 +76,12 @@ export function CostTracker({ inputTokens = 0, outputTokens = 0, sessionCosts }:
        return;
      }

      console.log(`CostTracker: Loading cost info for ${currentProvider}/${currentModel}`);

      try {
        // First check sync cache
        let costData = getCostForModel(currentProvider, currentModel);

        if (costData) {
          // We have cached data
          console.log(
            `CostTracker: Found cached data for ${currentProvider}/${currentModel}:`,
            costData
          );
          setCostInfo(costData);
          setPricingFailed(false);
          setModelNotFound(false);
@@ -97,30 +89,19 @@ export function CostTracker({ inputTokens = 0, outputTokens = 0, sessionCosts }:
          setHasAttemptedFetch(true);
        } else {
          // Need to fetch from backend
          console.log(
            `CostTracker: No cached data, fetching from backend for ${currentProvider}/${currentModel}`
          );
          setIsLoading(true);
          const result = await fetchAndCachePricing(currentProvider, currentModel);
          setHasAttemptedFetch(true);

          if (result && result.costInfo) {
            console.log(
              `CostTracker: Fetched data for ${currentProvider}/${currentModel}:`,
              result.costInfo
            );
            setCostInfo(result.costInfo);
            setPricingFailed(false);
            setModelNotFound(false);
          } else if (result && result.error === 'model_not_found') {
            console.log(
              `CostTracker: Model not found in pricing data for ${currentProvider}/${currentModel}`
            );
            // Model not found in pricing database, but API call succeeded
            setModelNotFound(true);
            setPricingFailed(false);
          } else {
            console.log(`CostTracker: API failed for ${currentProvider}/${currentModel}`);
            // API call failed or other error
            const freeProviders = ['ollama', 'local', 'localhost'];
            if (!freeProviders.includes(currentProvider.toLowerCase())) {
@@ -131,7 +112,6 @@ export function CostTracker({ inputTokens = 0, outputTokens = 0, sessionCosts }:
          setIsLoading(false);
        }
      } catch (error) {
        console.error('Error loading cost info:', error);
        setHasAttemptedFetch(true);
        // Only set pricing failed if we're not dealing with a known free provider
        const freeProviders = ['ollama', 'local', 'localhost'];
@@ -194,8 +174,6 @@ export function CostTracker({ inputTokens = 0, outputTokens = 0, sessionCosts }:
    return cost.toFixed(6);
  };

  // Debug logging removed

  // Show loading state or when we don't have model/provider info
  if (!currentModel || !currentProvider) {
    return null;
@@ -7,148 +7,8 @@ export interface ModelCostInfo {
  currency: string; // Currency symbol
}

// In-memory cache for current model pricing only
let currentModelPricing: {
  provider: string;
  model: string;
  costInfo: ModelCostInfo | null;
} | null = null;

// LocalStorage keys
const PRICING_CACHE_KEY = 'goose_pricing_cache';
const PRICING_CACHE_TIMESTAMP_KEY = 'goose_pricing_cache_timestamp';
const RECENTLY_USED_MODELS_KEY = 'goose_recently_used_models';
const CACHE_TTL_MS = 7 * 24 * 60 * 60 * 1000; // 7 days in milliseconds
const MAX_RECENTLY_USED_MODELS = 20; // Keep only the last 20 used models in cache

interface PricingItem {
  provider: string;
  model: string;
  input_token_cost: number;
  output_token_cost: number;
  currency: string;
}

interface PricingCacheData {
  pricing: PricingItem[];
  timestamp: number;
}

interface RecentlyUsedModel {
  provider: string;
  model: string;
  lastUsed: number;
}

/**
 * Get recently used models from localStorage
 */
function getRecentlyUsedModels(): RecentlyUsedModel[] {
  try {
    const stored = localStorage.getItem(RECENTLY_USED_MODELS_KEY);
    return stored ? JSON.parse(stored) : [];
  } catch (error) {
    console.error('Error loading recently used models:', error);
    return [];
  }
}

/**
 * Add a model to the recently used list
 */
function addToRecentlyUsed(provider: string, model: string): void {
  try {
    let recentModels = getRecentlyUsedModels();

    // Remove existing entry if present
    recentModels = recentModels.filter((m) => !(m.provider === provider && m.model === model));

    // Add to front
    recentModels.unshift({ provider, model, lastUsed: Date.now() });

    // Keep only the most recent models
    recentModels = recentModels.slice(0, MAX_RECENTLY_USED_MODELS);

    localStorage.setItem(RECENTLY_USED_MODELS_KEY, JSON.stringify(recentModels));
  } catch (error) {
    console.error('Error saving recently used models:', error);
  }
}

/**
 * Load pricing data from localStorage cache - only for recently used models
 */
function loadPricingFromLocalStorage(): PricingCacheData | null {
  try {
    const cached = localStorage.getItem(PRICING_CACHE_KEY);
    const timestamp = localStorage.getItem(PRICING_CACHE_TIMESTAMP_KEY);

    if (cached && timestamp) {
      const cacheAge = Date.now() - parseInt(timestamp, 10);
      if (cacheAge < CACHE_TTL_MS) {
        const fullCache = JSON.parse(cached) as PricingCacheData;
        const recentModels = getRecentlyUsedModels();

        // Filter to only include recently used models
        const filteredPricing = fullCache.pricing.filter((p) =>
          recentModels.some((r) => r.provider === p.provider && r.model === p.model)
        );

        console.log(
          `Loading ${filteredPricing.length} recently used models from cache (out of ${fullCache.pricing.length} total)`
        );

        return {
          pricing: filteredPricing,
          timestamp: fullCache.timestamp,
        };
      } else {
        console.log('LocalStorage pricing cache expired');
      }
    }
  } catch (error) {
    console.error('Error loading pricing from localStorage:', error);
  }
  return null;
}

/**
 * Save pricing data to localStorage - merge with existing data
 */
function savePricingToLocalStorage(data: PricingCacheData, mergeWithExisting = true): void {
  try {
    if (mergeWithExisting) {
      // Load existing full cache
      const existingCached = localStorage.getItem(PRICING_CACHE_KEY);
      if (existingCached) {
        const existingData = JSON.parse(existingCached) as PricingCacheData;

        // Create a map of existing pricing for quick lookup
        const pricingMap = new Map<string, (typeof data.pricing)[0]>();
        existingData.pricing.forEach((p) => {
          pricingMap.set(`${p.provider}/${p.model}`, p);
        });

        // Update with new data
        data.pricing.forEach((p) => {
          pricingMap.set(`${p.provider}/${p.model}`, p);
        });

        // Convert back to array
        data = {
          pricing: Array.from(pricingMap.values()),
          timestamp: data.timestamp,
        };
      }
    }

    localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(data));
    localStorage.setItem(PRICING_CACHE_TIMESTAMP_KEY, data.timestamp.toString());
    console.log(`Saved ${data.pricing.length} models to localStorage cache`);
  } catch (error) {
    console.error('Error saving pricing to localStorage:', error);
  }
}
// In-memory cache for current session only
const sessionPricingCache = new Map<string, ModelCostInfo | null>();
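A note on the session cache introduced above: it deliberately stores null for models that were looked up and not found, so callers must check has() before get() to tell a cached negative result apart from an entry that was never fetched. Illustrative sketch (not part of the commit):

// Sketch only: why has() matters with cached-null entries.
const cache = new Map<string, { input_token_cost: number } | null>();
cache.set('openai/unknown-model', null); // looked up once, not found

cache.has('openai/unknown-model'); // true  -> reuse the negative result
cache.get('openai/unknown-model'); // null  -> do not re-fetch
cache.has('openai/gpt-4o');        // false -> fetch from the backend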
/**
 * Fetch pricing data from backend for specific provider/model
@@ -157,12 +17,199 @@ async function fetchPricingForModel(
  provider: string,
  model: string
): Promise<ModelCostInfo | null> {
  // For OpenRouter models, we need to use the parsed provider and model for the API lookup
  let lookupProvider = provider;
  let lookupModel = model;

  if (provider.toLowerCase() === 'openrouter') {
    const parsed = parseOpenRouterModel(model);
    if (parsed) {
      lookupProvider = parsed[0];
      lookupModel = parsed[1];
    }
  }

  const apiUrl = getApiUrl('/config/pricing');
  const secretKey = getSecretKey();

  const headers: HeadersInit = { 'Content-Type': 'application/json' };
  if (secretKey) {
    headers['X-Secret-Key'] = secretKey;
  }

  const response = await fetch(apiUrl, {
    method: 'POST',
    headers,
    body: JSON.stringify({ configured_only: false }),
  });

  if (!response.ok) {
    throw new Error(`API request failed with status ${response.status}`);
  }

  const data = await response.json();

  // Find the specific model pricing using the lookup provider/model
  const pricing = data.pricing?.find(
    (p: {
      provider: string;
      model: string;
      input_token_cost: number;
      output_token_cost: number;
      currency: string;
    }) => {
      const providerMatch = p.provider.toLowerCase() === lookupProvider.toLowerCase();

      // More flexible model matching - handle versioned models
      let modelMatch = p.model === lookupModel;

      // If exact match fails, try matching without version suffix
      if (!modelMatch && lookupModel.includes('-20')) {
        // Remove date suffix like -20241022
        const modelWithoutDate = lookupModel.replace(/-20\d{6}$/, '');
        modelMatch = p.model === modelWithoutDate;

        // Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet)
        if (!modelMatch) {
          const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.');
          modelMatch = p.model === modelWithDots;
        }
      }

      return providerMatch && modelMatch;
    }
  );

  if (pricing) {
    return {
      input_token_cost: pricing.input_token_cost,
      output_token_cost: pricing.output_token_cost,
      currency: pricing.currency || '$',
    };
  }

  // API call succeeded but model not found in pricing data
  return null;
}

/**
 * Initialize the cost database - no-op since we fetch on demand now
 */
export async function initializeCostDatabase(): Promise<void> {
  // Clear session cache on init
  sessionPricingCache.clear();
}

/**
 * Update model costs from providers - no-op since we fetch on demand
 */
export async function updateAllModelCosts(): Promise<void> {
  // No-op - we fetch on demand now
}

/**
 * Parse OpenRouter model ID to extract provider and model
 * e.g., "anthropic/claude-sonnet-4" -> ["anthropic", "claude-sonnet-4"]
 */
function parseOpenRouterModel(modelId: string): [string, string] | null {
  const parts = modelId.split('/');
  if (parts.length === 2) {
    return [parts[0], parts[1]];
  }
  return null;
}
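Worked example of the flexible model matching used in the lookup above, showing the actual outputs of the two regex passes:

// The date-suffix pass strips a trailing "-20YYMMDD"-style pin:
const lookupModel = 'claude-3-5-sonnet-20241022';
const withoutDate = lookupModel.replace(/-20\d{6}$/, '');
// withoutDate === 'claude-3-5-sonnet'

// The dots-vs-dashes pass then rewrites "-digit-" version separators.
// Note what the global replace literally produces for this input:
const withDots = withoutDate.replace(/-(\d)-/g, '.$1.');
// withDots === 'claude.3.5-sonnet'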
/**
 * Get cost information for a specific model with session caching
 */
export function getCostForModel(provider: string, model: string): ModelCostInfo | null {
  const cacheKey = `${provider}/${model}`;

  // Check session cache first
  if (sessionPricingCache.has(cacheKey)) {
    return sessionPricingCache.get(cacheKey) || null;
  }

  // For OpenRouter models, also check if we have cached data under the parsed provider/model
  if (provider.toLowerCase() === 'openrouter') {
    const parsed = parseOpenRouterModel(model);
    if (parsed) {
      const [parsedProvider, parsedModel] = parsed;
      const parsedCacheKey = `${parsedProvider}/${parsedModel}`;
      if (sessionPricingCache.has(parsedCacheKey)) {
        const cachedData = sessionPricingCache.get(parsedCacheKey) || null;
        // Also cache it under the original OpenRouter key for future lookups
        sessionPricingCache.set(cacheKey, cachedData);
        return cachedData;
      }
    }
  }

  // For local/free providers, return zero cost immediately
  const freeProviders = ['ollama', 'local', 'localhost'];
  if (freeProviders.includes(provider.toLowerCase())) {
    const zeroCost = {
      input_token_cost: 0,
      output_token_cost: 0,
      currency: '$',
    };
    sessionPricingCache.set(cacheKey, zeroCost);
    return zeroCost;
  }

  // Need to fetch - return null and let component handle async fetch
  return null;
}
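Usage sketch for the OpenRouter fallback above: the same pricing entry is reachable through two cache keys.

getCostForModel('openrouter', 'anthropic/claude-sonnet-4');
// 1. checks the key 'openrouter/anthropic/claude-sonnet-4'
// 2. falls back to the parsed key 'anthropic/claude-sonnet-4'
//    and, on a hit, re-caches the entry under the openrouter key as well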
/**
 * Fetch and cache pricing for a model
 */
export async function fetchAndCachePricing(
  provider: string,
  model: string
): Promise<{ costInfo: ModelCostInfo | null; error?: string } | null> {
  try {
    const cacheKey = `${provider}/${model}`;
    const costInfo = await fetchPricingForModel(provider, model);

    // Cache the result in session cache under the original key
    sessionPricingCache.set(cacheKey, costInfo);

    // For OpenRouter models, also cache under the parsed provider/model key
    // This helps with cross-referencing between frontend requests and backend responses
    if (provider.toLowerCase() === 'openrouter') {
      const parsed = parseOpenRouterModel(model);
      if (parsed) {
        const [parsedProvider, parsedModel] = parsed;
        const parsedCacheKey = `${parsedProvider}/${parsedModel}`;
        sessionPricingCache.set(parsedCacheKey, costInfo);
      }
    }

    if (costInfo) {
      return { costInfo };
    } else {
      // Model not found in pricing data
      return { costInfo: null, error: 'model_not_found' };
    }
  } catch (error) {
    // This is a real API/network error
    return null;
  }
}
/**
 * Refresh pricing data from backend
 */
export async function refreshPricing(): Promise<boolean> {
  try {
    // Clear session cache to force re-fetch
    sessionPricingCache.clear();

    // The actual refresh happens on the backend when we call with configured_only: false
    const apiUrl = getApiUrl('/config/pricing');
    const secretKey = getSecretKey();

    console.log(`Fetching pricing for ${provider}/${model} from ${apiUrl}`);

    const headers: HeadersInit = { 'Content-Type': 'application/json' };
    if (secretKey) {
      headers['X-Secret-Key'] = secretKey;
@@ -174,420 +221,25 @@ async function fetchPricingForModel(
      body: JSON.stringify({ configured_only: false }),
    });

    if (!response.ok) {
      console.error('Failed to fetch pricing data:', response.status);
      throw new Error(`API request failed with status ${response.status}`);
    }

    const data = await response.json();
    console.log('Pricing response:', data);

    // Find the specific model pricing
    const pricing = data.pricing?.find(
      (p: {
        provider: string;
        model: string;
        input_token_cost: number;
        output_token_cost: number;
        currency: string;
      }) => {
        const providerMatch = p.provider.toLowerCase() === provider.toLowerCase();

        // More flexible model matching - handle versioned models
        let modelMatch = p.model === model;

        // If exact match fails, try matching without version suffix
        if (!modelMatch && model.includes('-20')) {
          // Remove date suffix like -20241022
          const modelWithoutDate = model.replace(/-20\d{6}$/, '');
          modelMatch = p.model === modelWithoutDate;

          // Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet)
          if (!modelMatch) {
            const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.');
            modelMatch = p.model === modelWithDots;
          }
        }

        console.log(
          `Comparing: ${p.provider}/${p.model} with ${provider}/${model} - Provider match: ${providerMatch}, Model match: ${modelMatch}`
        );
        return providerMatch && modelMatch;
      }
    );

    console.log(`Found pricing for ${provider}/${model}:`, pricing);

    if (pricing) {
      return {
        input_token_cost: pricing.input_token_cost,
        output_token_cost: pricing.output_token_cost,
        currency: pricing.currency || '$',
      };
    }

    console.log(
      `No pricing found for ${provider}/${model} in:`,
      data.pricing?.map((p: { provider: string; model: string }) => `${p.provider}/${p.model}`)
    );

    // API call succeeded but model not found in pricing data
    return null;
    return response.ok;
  } catch (error) {
    console.error('Error fetching pricing data:', error);
    // Re-throw the error so the caller can distinguish between API failure and model not found
    throw error;
  }
}

/**
 * Initialize the cost database - only load commonly used models on startup
 */
export async function initializeCostDatabase(): Promise<void> {
  try {
    // Clean up any existing large caches first
    cleanupPricingCache();

    // First check if we have valid cached data
    const cachedData = loadPricingFromLocalStorage();
    if (cachedData && cachedData.pricing.length > 0) {
      console.log('Using cached pricing data from localStorage');
      return;
    }

    // List of commonly used models to pre-fetch
    const commonModels = [
      { provider: 'openai', model: 'gpt-4o' },
      { provider: 'openai', model: 'gpt-4o-mini' },
      { provider: 'openai', model: 'gpt-4-turbo' },
      { provider: 'openai', model: 'gpt-4' },
      { provider: 'openai', model: 'gpt-3.5-turbo' },
      { provider: 'anthropic', model: 'claude-3-5-sonnet' },
      { provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' },
      { provider: 'anthropic', model: 'claude-3-opus' },
      { provider: 'anthropic', model: 'claude-3-sonnet' },
      { provider: 'anthropic', model: 'claude-3-haiku' },
      { provider: 'google', model: 'gemini-1.5-pro' },
      { provider: 'google', model: 'gemini-1.5-flash' },
      { provider: 'deepseek', model: 'deepseek-chat' },
      { provider: 'deepseek', model: 'deepseek-reasoner' },
      { provider: 'meta-llama', model: 'llama-3.2-90b-text-preview' },
      { provider: 'meta-llama', model: 'llama-3.1-405b-instruct' },
    ];

    // Get recently used models
    const recentModels = getRecentlyUsedModels();

    // Combine common and recent models (deduplicated)
    const modelsToFetch = new Map<string, { provider: string; model: string }>();

    // Add common models
    commonModels.forEach((m) => {
      modelsToFetch.set(`${m.provider}/${m.model}`, m);
    });

    // Add recent models
    recentModels.forEach((m) => {
      modelsToFetch.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model });
    });

    console.log(`Initializing cost database with ${modelsToFetch.size} models...`);

    // Fetch only the pricing we need
    const apiUrl = getApiUrl('/config/pricing');
    const secretKey = getSecretKey();

    const headers: HeadersInit = { 'Content-Type': 'application/json' };
    if (secretKey) {
      headers['X-Secret-Key'] = secretKey;
    }

    const response = await fetch(apiUrl, {
      method: 'POST',
      headers,
      body: JSON.stringify({
        configured_only: false,
        models: Array.from(modelsToFetch.values()), // Send specific models if API supports it
      }),
    });

    if (!response.ok) {
      console.error('Failed to fetch initial pricing data:', response.status);
      return;
    }

    const data = await response.json();
    console.log(`Fetched pricing for ${data.pricing?.length || 0} models`);

    if (data.pricing && data.pricing.length > 0) {
      // Filter to only the models we requested (in case API returns all)
      const filteredPricing = data.pricing.filter((p: PricingItem) =>
        modelsToFetch.has(`${p.provider}/${p.model}`)
      );

      // Save to localStorage
      const cacheData: PricingCacheData = {
        pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50), // Fallback to first 50 if filtering didn't work
        timestamp: Date.now(),
      };
      savePricingToLocalStorage(cacheData, false); // Don't merge on initial load
    }
  } catch (error) {
    console.error('Error initializing cost database:', error);
  }
}

/**
 * Update model costs from providers - no longer needed
 */
export async function updateAllModelCosts(): Promise<void> {
  // No-op - we fetch on demand now
}

/**
 * Get cost information for a specific model with caching
 */
export function getCostForModel(provider: string, model: string): ModelCostInfo | null {
  // Track this model as recently used
  addToRecentlyUsed(provider, model);

  // Check if it's the same model we already have cached in memory
  if (
    currentModelPricing &&
    currentModelPricing.provider === provider &&
    currentModelPricing.model === model
  ) {
    return currentModelPricing.costInfo;
  }

  // For local/free providers, return zero cost immediately
  const freeProviders = ['ollama', 'local', 'localhost'];
  if (freeProviders.includes(provider.toLowerCase())) {
    const zeroCost = {
      input_token_cost: 0,
      output_token_cost: 0,
      currency: '$',
    };
    currentModelPricing = { provider, model, costInfo: zeroCost };
    return zeroCost;
  }

  // Check localStorage cache (which now only contains recently used models)
  const cachedData = loadPricingFromLocalStorage();
  if (cachedData) {
    const pricing = cachedData.pricing.find((p) => {
      const providerMatch = p.provider.toLowerCase() === provider.toLowerCase();

      // More flexible model matching - handle versioned models
      let modelMatch = p.model === model;

      // If exact match fails, try matching without version suffix
      if (!modelMatch && model.includes('-20')) {
        // Remove date suffix like -20241022
        const modelWithoutDate = model.replace(/-20\d{6}$/, '');
        modelMatch = p.model === modelWithoutDate;

        // Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet)
        if (!modelMatch) {
          const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.');
          modelMatch = p.model === modelWithDots;
        }
      }

      return providerMatch && modelMatch;
    });

    if (pricing) {
      const costInfo = {
        input_token_cost: pricing.input_token_cost,
        output_token_cost: pricing.output_token_cost,
        currency: pricing.currency || '$',
      };
      currentModelPricing = { provider, model, costInfo };
      return costInfo;
    }
  }

  // Need to fetch new pricing - return null for now
  // The component will handle the async fetch
  return null;
}

/**
 * Fetch and cache pricing for a model
 */
export async function fetchAndCachePricing(
  provider: string,
  model: string
): Promise<{ costInfo: ModelCostInfo | null; error?: string } | null> {
  try {
    const costInfo = await fetchPricingForModel(provider, model);

    if (costInfo) {
      // Cache the result in memory
      currentModelPricing = { provider, model, costInfo };

      // Update localStorage cache with this new data
      const cachedData = loadPricingFromLocalStorage();
      if (cachedData) {
        // Check if this model already exists in cache
        const existingIndex = cachedData.pricing.findIndex(
          (p) => p.provider.toLowerCase() === provider.toLowerCase() && p.model === model
        );

        const newPricing = {
          provider,
          model,
          input_token_cost: costInfo.input_token_cost,
          output_token_cost: costInfo.output_token_cost,
          currency: costInfo.currency,
        };

        if (existingIndex >= 0) {
          // Update existing
          cachedData.pricing[existingIndex] = newPricing;
        } else {
          // Add new
          cachedData.pricing.push(newPricing);
        }

        // Save updated cache
        savePricingToLocalStorage(cachedData);
      }

      return { costInfo };
    } else {
      // Cache the null result in memory
      currentModelPricing = { provider, model, costInfo: null };

      // Check if the API call succeeded but model wasn't found
      // We can determine this by checking if we got a response but no matching model
      return { costInfo: null, error: 'model_not_found' };
    }
  } catch (error) {
    console.error('Error in fetchAndCachePricing:', error);
    // This is a real API/network error
    return null;
  }
}

/**
 * Refresh pricing data from backend - only refresh recently used models
 */
export async function refreshPricing(): Promise<boolean> {
  try {
    const apiUrl = getApiUrl('/config/pricing');
    const secretKey = getSecretKey();

    const headers: HeadersInit = { 'Content-Type': 'application/json' };
    if (secretKey) {
      headers['X-Secret-Key'] = secretKey;
    }

    // Get recently used models to refresh
    const recentModels = getRecentlyUsedModels();

    // Add some common models as well
    const commonModels = [
      { provider: 'openai', model: 'gpt-4o' },
      { provider: 'openai', model: 'gpt-4o-mini' },
      { provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' },
      { provider: 'google', model: 'gemini-1.5-pro' },
    ];

    // Combine and deduplicate
    const modelsToRefresh = new Map<string, { provider: string; model: string }>();

    commonModels.forEach((m) => {
      modelsToRefresh.set(`${m.provider}/${m.model}`, m);
    });

    recentModels.forEach((m) => {
      modelsToRefresh.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model });
    });

    console.log(`Refreshing pricing for ${modelsToRefresh.size} models...`);

    const response = await fetch(apiUrl, {
      method: 'POST',
      headers,
      body: JSON.stringify({
        configured_only: false,
        models: Array.from(modelsToRefresh.values()), // Send specific models if API supports it
      }),
    });

    if (response.ok) {
      const data = await response.json();

      if (data.pricing && data.pricing.length > 0) {
        // Filter to only the models we requested (in case API returns all)
        const filteredPricing = data.pricing.filter((p: PricingItem) =>
          modelsToRefresh.has(`${p.provider}/${p.model}`)
        );

        // Save fresh data to localStorage (merge with existing)
        const cacheData: PricingCacheData = {
          pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50),
          timestamp: Date.now(),
        };
        savePricingToLocalStorage(cacheData, true); // Merge with existing
      }

      // Clear current memory cache to force re-fetch
      currentModelPricing = null;
      return true;
    }

    return false;
  } catch (error) {
    console.error('Error refreshing pricing data:', error);
    return false;
  }
}

/**
 * Clean up old/unused models from the cache
 */
export function cleanupPricingCache(): void {
  try {
    const recentModels = getRecentlyUsedModels();
    const cachedData = localStorage.getItem(PRICING_CACHE_KEY);

    if (!cachedData) return;

    const fullCache = JSON.parse(cachedData) as PricingCacheData;
    const recentModelKeys = new Set(recentModels.map((m) => `${m.provider}/${m.model}`));

    // Keep only recently used models and common models
    const commonModelKeys = new Set([
      'openai/gpt-4o',
      'openai/gpt-4o-mini',
      'openai/gpt-4-turbo',
      'anthropic/claude-3-5-sonnet',
      'anthropic/claude-3-5-sonnet-20241022',
      'google/gemini-1.5-pro',
      'google/gemini-1.5-flash',
    ]);

    const filteredPricing = fullCache.pricing.filter((p) => {
      const key = `${p.provider}/${p.model}`;
      return recentModelKeys.has(key) || commonModelKeys.has(key);
    });

    if (filteredPricing.length < fullCache.pricing.length) {
      console.log(
        `Cleaned up pricing cache: reduced from ${fullCache.pricing.length} to ${filteredPricing.length} models`
      );

      const cleanedCache: PricingCacheData = {
        pricing: filteredPricing,
        timestamp: fullCache.timestamp,
      };

      localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(cleanedCache));
    }
  } catch (error) {
    console.error('Error cleaning up pricing cache:', error);
// Expose functions for testing in development mode
declare global {
  interface Window {
    getCostForModel?: typeof getCostForModel;
    fetchAndCachePricing?: typeof fetchAndCachePricing;
    refreshPricing?: typeof refreshPricing;
    sessionPricingCache?: typeof sessionPricingCache;
  }
}

if (process.env.NODE_ENV === 'development' || typeof window !== 'undefined') {
  window.getCostForModel = getCostForModel;
  window.fetchAndCachePricing = fetchAndCachePricing;
  window.refreshPricing = refreshPricing;
  window.sessionPricingCache = sessionPricingCache;
}
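With the exposure block above applied, the cache helpers can be exercised directly from the browser devtools console, e.g.:

await window.refreshPricing?.();                    // clear and re-fetch pricing
window.getCostForModel?.('openai', 'gpt-4o');       // inspect a cached entry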