diff --git a/Cargo.lock b/Cargo.lock index 66eecd72..57eebcc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3427,6 +3427,7 @@ dependencies = [ "chrono", "criterion", "ctor", + "dirs 5.0.1", "dotenv", "etcetera", "fs2", diff --git a/crates/goose-server/src/commands/agent.rs b/crates/goose-server/src/commands/agent.rs index 104e3b81..5fdfa89a 100644 --- a/crates/goose-server/src/commands/agent.rs +++ b/crates/goose-server/src/commands/agent.rs @@ -10,12 +10,23 @@ use goose::scheduler_factory::SchedulerFactory; use tower_http::cors::{Any, CorsLayer}; use tracing::info; +use goose::providers::pricing::initialize_pricing_cache; + pub async fn run() -> Result<()> { // Initialize logging crate::logging::setup_logging(Some("goosed"))?; let settings = configuration::Settings::new()?; + // Initialize pricing cache on startup + tracing::info!("Initializing pricing cache..."); + if let Err(e) = initialize_pricing_cache().await { + tracing::warn!( + "Failed to initialize pricing cache: {}. Pricing data may not be available.", + e + ); + } + let secret_key = std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string()); diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 06e15198..e2e919b9 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -13,6 +13,7 @@ use goose::config::{extensions::name_to_key, PermissionManager}; use goose::config::{ExtensionConfigManager, ExtensionEntry}; use goose::model::ModelConfig; use goose::providers::base::ProviderMetadata; +use goose::providers::pricing::{get_all_pricing, get_model_pricing, refresh_pricing}; use goose::providers::providers as get_providers; use goose::{agents::ExtensionConfig, config::permission::PermissionLevel}; use http::{HeaderMap, StatusCode}; @@ -314,6 +315,128 @@ pub async fn providers( Ok(Json(providers_response)) } +#[derive(Serialize, ToSchema)] +pub struct PricingData { + pub provider: String, + pub model: String, + pub input_token_cost: f64, + pub output_token_cost: f64, + pub currency: String, + pub context_length: Option, +} + +#[derive(Serialize, ToSchema)] +pub struct PricingResponse { + pub pricing: Vec, + pub source: String, +} + +#[derive(Deserialize, ToSchema)] +pub struct PricingQuery { + /// If true, only return pricing for configured providers. If false, return all. + pub configured_only: Option, +} + +#[utoipa::path( + post, + path = "/config/pricing", + request_body = PricingQuery, + responses( + (status = 200, description = "Model pricing data retrieved successfully", body = PricingResponse) + ) +)] +pub async fn get_pricing( + State(state): State>, + headers: HeaderMap, + Json(query): Json, +) -> Result, StatusCode> { + verify_secret_key(&headers, &state)?; + + let configured_only = query.configured_only.unwrap_or(true); + + // If refresh requested (configured_only = false), refresh the cache + if !configured_only { + if let Err(e) = refresh_pricing().await { + tracing::error!("Failed to refresh pricing data: {}", e); + } + } + + let mut pricing_data = Vec::new(); + + if !configured_only { + // Get ALL pricing data from the cache + let all_pricing = get_all_pricing().await; + + for (provider, models) in all_pricing { + for (model, pricing) in models { + pricing_data.push(PricingData { + provider: provider.clone(), + model: model.clone(), + input_token_cost: pricing.input_cost, + output_token_cost: pricing.output_cost, + currency: "$".to_string(), + context_length: pricing.context_length, + }); + } + } + } else { + // Get only configured providers' pricing + let providers_metadata = get_providers(); + + for metadata in providers_metadata { + // Skip unconfigured providers if filtering + if !check_provider_configured(&metadata) { + continue; + } + + for model_info in &metadata.known_models { + // Try to get pricing from cache + if let Some(pricing) = get_model_pricing(&metadata.name, &model_info.name).await { + pricing_data.push(PricingData { + provider: metadata.name.clone(), + model: model_info.name.clone(), + input_token_cost: pricing.input_cost, + output_token_cost: pricing.output_cost, + currency: "$".to_string(), + context_length: pricing.context_length, + }); + } + // Check if the model has embedded pricing data + else if let (Some(input_cost), Some(output_cost)) = + (model_info.input_token_cost, model_info.output_token_cost) + { + pricing_data.push(PricingData { + provider: metadata.name.clone(), + model: model_info.name.clone(), + input_token_cost: input_cost, + output_token_cost: output_cost, + currency: model_info + .currency + .clone() + .unwrap_or_else(|| "$".to_string()), + context_length: Some(model_info.context_limit as u32), + }); + } + } + } + } + + tracing::info!( + "Returning pricing for {} models{}", + pricing_data.len(), + if configured_only { + " (configured providers only)" + } else { + " (all cached models)" + } + ); + + Ok(Json(PricingResponse { + pricing: pricing_data, + source: "openrouter".to_string(), + })) +} + #[utoipa::path( post, path = "/config/init", @@ -471,6 +594,7 @@ pub fn routes(state: Arc) -> Router { .route("/config/extensions", post(add_extension)) .route("/config/extensions/{name}", delete(remove_extension)) .route("/config/providers", get(providers)) + .route("/config/pricing", post(get_pricing)) .route("/config/init", post(init_config)) .route("/config/backup", post(backup_config)) .route("/config/permissions", post(upsert_permissions)) diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index bf2a23d9..c435449a 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -17,6 +17,7 @@ mcp-core = { path = "../mcp-core" } anyhow = "1.0" thiserror = "1.0" futures = "0.3" +dirs = "5.0" reqwest = { version = "0.12.9", features = [ "rustls-tls-native-roots", "json", diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 8a06372e..89b7223b 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -5,7 +5,7 @@ use reqwest::{Client, StatusCode}; use serde_json::Value; use std::time::Duration; -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; +use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; use super::formats::anthropic::{create_request, get_usage, response_to_message}; use super::utils::{emit_debug_trace, get_model}; @@ -122,12 +122,18 @@ impl AnthropicProvider { #[async_trait] impl Provider for AnthropicProvider { fn metadata() -> ProviderMetadata { - ProviderMetadata::new( + ProviderMetadata::with_models( "anthropic", "Anthropic", "Claude and other models from Anthropic", ANTHROPIC_DEFAULT_MODEL, - ANTHROPIC_KNOWN_MODELS.to_vec(), + vec![ + ModelInfo::with_cost("claude-3-5-sonnet-20241022", 200000, 0.000003, 0.000015), + ModelInfo::with_cost("claude-3-5-haiku-20241022", 200000, 0.000001, 0.000005), + ModelInfo::with_cost("claude-3-opus-20240229", 200000, 0.000015, 0.000075), + ModelInfo::with_cost("claude-3-sonnet-20240229", 200000, 0.000003, 0.000015), + ModelInfo::with_cost("claude-3-haiku-20240307", 200000, 0.00000025, 0.00000125), + ], ANTHROPIC_DOC_URL, vec![ ConfigKey::new("ANTHROPIC_API_KEY", true, true, None), diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index d0e8cc43..09a5ef08 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -32,6 +32,41 @@ pub struct ModelInfo { pub name: String, /// The maximum context length this model supports pub context_limit: usize, + /// Cost per token for input (optional) + pub input_token_cost: Option, + /// Cost per token for output (optional) + pub output_token_cost: Option, + /// Currency for the costs (default: "$") + pub currency: Option, +} + +impl ModelInfo { + /// Create a new ModelInfo with just name and context limit + pub fn new(name: impl Into, context_limit: usize) -> Self { + Self { + name: name.into(), + context_limit, + input_token_cost: None, + output_token_cost: None, + currency: None, + } + } + + /// Create a new ModelInfo with cost information (per token) + pub fn with_cost( + name: impl Into, + context_limit: usize, + input_cost: f64, + output_cost: f64, + ) -> Self { + Self { + name: name.into(), + context_limit, + input_token_cost: Some(input_cost), + output_token_cost: Some(output_cost), + currency: Some("$".to_string()), + } + } } /// Metadata about a provider's configuration requirements and capabilities @@ -74,6 +109,9 @@ impl ProviderMetadata { .map(|&name| ModelInfo { name: name.to_string(), context_limit: ModelConfig::new(name.to_string()).context_limit(), + input_token_cost: None, + output_token_cost: None, + currency: None, }) .collect(), model_doc_link: model_doc_link.to_string(), @@ -81,6 +119,27 @@ impl ProviderMetadata { } } + /// Create a new ProviderMetadata with ModelInfo objects that include cost data + pub fn with_models( + name: &str, + display_name: &str, + description: &str, + default_model: &str, + models: Vec, + model_doc_link: &str, + config_keys: Vec, + ) -> Self { + Self { + name: name.to_string(), + display_name: display_name.to_string(), + description: description.to_string(), + default_model: default_model.to_string(), + known_models: models, + model_doc_link: model_doc_link.to_string(), + config_keys, + } + } + pub fn empty() -> Self { Self { name: "".to_string(), @@ -313,6 +372,9 @@ mod tests { let info = ModelInfo { name: "test-model".to_string(), context_limit: 1000, + input_token_cost: None, + output_token_cost: None, + currency: None, }; assert_eq!(info.context_limit, 1000); @@ -320,6 +382,9 @@ mod tests { let info2 = ModelInfo { name: "test-model".to_string(), context_limit: 1000, + input_token_cost: None, + output_token_cost: None, + currency: None, }; assert_eq!(info, info2); @@ -327,7 +392,20 @@ mod tests { let info3 = ModelInfo { name: "test-model".to_string(), context_limit: 2000, + input_token_cost: None, + output_token_cost: None, + currency: None, }; assert_ne!(info, info3); } + + #[test] + fn test_model_info_with_cost() { + let info = ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001); + assert_eq!(info.name, "gpt-4o"); + assert_eq!(info.context_limit, 128000); + assert_eq!(info.input_token_cost, Some(0.0000025)); + assert_eq!(info.output_token_cost, Some(0.00001)); + assert_eq!(info.currency, Some("$".to_string())); + } } diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index decd346a..a9f02b1c 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -18,6 +18,7 @@ pub mod oauth; pub mod ollama; pub mod openai; pub mod openrouter; +pub mod pricing; pub mod sagemaker_tgi; pub mod snowflake; pub mod toolshim; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index fd46fbcb..a3bbeef2 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -5,7 +5,7 @@ use serde_json::Value; use std::collections::HashMap; use std::time::Duration; -use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; @@ -126,12 +126,20 @@ impl OpenAiProvider { #[async_trait] impl Provider for OpenAiProvider { fn metadata() -> ProviderMetadata { - ProviderMetadata::new( + ProviderMetadata::with_models( "openai", "OpenAI", "GPT-4 and other OpenAI models, including OpenAI compatible ones", OPEN_AI_DEFAULT_MODEL, - OPEN_AI_KNOWN_MODELS.to_vec(), + vec![ + ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001), + ModelInfo::with_cost("gpt-4o-mini", 128000, 0.00000015, 0.0000006), + ModelInfo::with_cost("gpt-4-turbo", 128000, 0.00001, 0.00003), + ModelInfo::with_cost("gpt-3.5-turbo", 16385, 0.0000005, 0.0000015), + ModelInfo::with_cost("o1", 200000, 0.000015, 0.00006), + ModelInfo::with_cost("o3", 200000, 0.000015, 0.00006), // Using o1 pricing as placeholder + ModelInfo::with_cost("o4-mini", 128000, 0.000003, 0.000012), // Using o1-mini pricing as placeholder + ], OPEN_AI_DOC_URL, vec![ ConfigKey::new("OPENAI_API_KEY", true, true, None), diff --git a/crates/goose/src/providers/pricing.rs b/crates/goose/src/providers/pricing.rs new file mode 100644 index 00000000..ed24983d --- /dev/null +++ b/crates/goose/src/providers/pricing.rs @@ -0,0 +1,387 @@ +use anyhow::Result; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::sync::RwLock; + +/// Disk cache configuration +const CACHE_FILE_NAME: &str = "pricing_cache.json"; +const CACHE_TTL_DAYS: u64 = 7; // Cache for 7 days + +/// Get the cache directory path +fn get_cache_dir() -> Result { + let cache_dir = if let Ok(goose_dir) = std::env::var("GOOSE_CACHE_DIR") { + PathBuf::from(goose_dir) + } else { + dirs::cache_dir() + .ok_or_else(|| anyhow::anyhow!("Could not determine cache directory"))? + .join("goose") + }; + Ok(cache_dir) +} + +/// Cached pricing data structure for disk storage +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CachedPricingData { + /// Nested HashMap: provider -> model -> pricing info + pub pricing: HashMap>, + /// Unix timestamp when data was fetched + pub fetched_at: u64, +} + +/// Simplified pricing info for efficient storage +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PricingInfo { + pub input_cost: f64, // Cost per token + pub output_cost: f64, // Cost per token + pub context_length: Option, +} + +/// Cache for OpenRouter pricing data with disk persistence +pub struct PricingCache { + /// In-memory cache + memory_cache: Arc>>, +} + +impl PricingCache { + pub fn new() -> Self { + Self { + memory_cache: Arc::new(RwLock::new(None)), + } + } + + /// Load pricing from disk cache + async fn load_from_disk(&self) -> Result> { + let cache_path = get_cache_dir()?.join(CACHE_FILE_NAME); + + if !cache_path.exists() { + return Ok(None); + } + + match tokio::fs::read(&cache_path).await { + Ok(data) => { + match serde_json::from_slice::(&data) { + Ok(cached) => { + // Check if cache is still valid + let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); + let age_days = (now - cached.fetched_at) / (24 * 60 * 60); + + if age_days < CACHE_TTL_DAYS { + tracing::info!( + "Loaded pricing data from disk cache (age: {} days)", + age_days + ); + Ok(Some(cached)) + } else { + tracing::info!("Disk cache expired (age: {} days)", age_days); + Ok(None) + } + } + Err(e) => { + tracing::warn!("Failed to parse pricing cache: {}", e); + Ok(None) + } + } + } + Err(e) => { + tracing::warn!("Failed to read pricing cache: {}", e); + Ok(None) + } + } + } + + /// Save pricing data to disk + async fn save_to_disk(&self, data: &CachedPricingData) -> Result<()> { + let cache_dir = get_cache_dir()?; + tokio::fs::create_dir_all(&cache_dir).await?; + + let cache_path = cache_dir.join(CACHE_FILE_NAME); + let json_data = serde_json::to_vec_pretty(data)?; + tokio::fs::write(&cache_path, json_data).await?; + + tracing::info!("Saved pricing data to disk cache"); + Ok(()) + } + + /// Get pricing for a specific model + pub async fn get_model_pricing(&self, provider: &str, model: &str) -> Option { + // Try memory cache first + { + let cache = self.memory_cache.read().await; + if let Some(cached) = &*cache { + return cached + .pricing + .get(&provider.to_lowercase()) + .and_then(|models| models.get(model)) + .cloned(); + } + } + + // Try loading from disk + if let Ok(Some(disk_cache)) = self.load_from_disk().await { + // Update memory cache + { + let mut cache = self.memory_cache.write().await; + *cache = Some(disk_cache.clone()); + } + + return disk_cache + .pricing + .get(&provider.to_lowercase()) + .and_then(|models| models.get(model)) + .cloned(); + } + + None + } + + /// Force refresh pricing data from OpenRouter + pub async fn refresh(&self) -> Result<()> { + let pricing = fetch_openrouter_pricing_internal().await?; + + // Convert to our efficient structure + let mut structured_pricing: HashMap> = HashMap::new(); + + for (model_id, model) in pricing { + if let Some((provider, model_name)) = parse_model_id(&model_id) { + if let (Some(input_cost), Some(output_cost)) = ( + convert_pricing(&model.pricing.prompt), + convert_pricing(&model.pricing.completion), + ) { + let provider_lower = provider.to_lowercase(); + let provider_models = structured_pricing.entry(provider_lower).or_default(); + + provider_models.insert( + model_name, + PricingInfo { + input_cost, + output_cost, + context_length: model.context_length, + }, + ); + } + } + } + + let cached_data = CachedPricingData { + pricing: structured_pricing, + fetched_at: SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(), + }; + + // Log how many models we fetched + let total_models: usize = cached_data + .pricing + .values() + .map(|models| models.len()) + .sum(); + tracing::info!( + "Fetched pricing for {} providers with {} total models from OpenRouter", + cached_data.pricing.len(), + total_models + ); + + // Save to disk + self.save_to_disk(&cached_data).await?; + + // Update memory cache + { + let mut cache = self.memory_cache.write().await; + *cache = Some(cached_data); + } + + Ok(()) + } + + /// Initialize cache (load from disk or fetch if needed) + pub async fn initialize(&self) -> Result<()> { + // Try loading from disk first + if let Ok(Some(cached)) = self.load_from_disk().await { + // Log how many models we have cached + let total_models: usize = cached.pricing.values().map(|models| models.len()).sum(); + tracing::info!( + "Loaded {} providers with {} total models from disk cache", + cached.pricing.len(), + total_models + ); + + // Update memory cache + { + let mut cache = self.memory_cache.write().await; + *cache = Some(cached); + } + + return Ok(()); + } + + // If no disk cache, fetch from OpenRouter + tracing::info!("No valid disk cache found, fetching from OpenRouter"); + self.refresh().await + } +} + +impl Default for PricingCache { + fn default() -> Self { + Self::new() + } +} + +// Global cache instance +lazy_static::lazy_static! { + static ref PRICING_CACHE: PricingCache = PricingCache::new(); + static ref HTTP_CLIENT: Client = Client::builder() + .timeout(Duration::from_secs(30)) + .pool_idle_timeout(Duration::from_secs(90)) + .pool_max_idle_per_host(10) + .build() + .unwrap(); +} + +/// OpenRouter model pricing information +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenRouterModel { + pub id: String, + pub name: String, + pub pricing: OpenRouterPricing, + pub context_length: Option, + pub architecture: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenRouterPricing { + pub prompt: String, // Cost per token for input (in USD) + pub completion: String, // Cost per token for output (in USD) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Architecture { + pub modality: String, + pub tokenizer: String, + pub instruct_type: Option, +} + +/// Response from OpenRouter models endpoint +#[derive(Debug, Deserialize)] +pub struct OpenRouterModelsResponse { + pub data: Vec, +} + +/// Internal function to fetch pricing data +async fn fetch_openrouter_pricing_internal() -> Result> { + let response = HTTP_CLIENT + .get("https://openrouter.ai/api/v1/models") + .send() + .await?; + + if !response.status().is_success() { + anyhow::bail!( + "Failed to fetch OpenRouter models: HTTP {}", + response.status() + ); + } + + let models_response: OpenRouterModelsResponse = response.json().await?; + + // Create a map for easy lookup + let mut pricing_map = HashMap::new(); + for model in models_response.data { + pricing_map.insert(model.id.clone(), model); + } + + Ok(pricing_map) +} + +/// Initialize pricing cache on startup +pub async fn initialize_pricing_cache() -> Result<()> { + PRICING_CACHE.initialize().await +} + +/// Get pricing for a specific model +pub async fn get_model_pricing(provider: &str, model: &str) -> Option { + PRICING_CACHE.get_model_pricing(provider, model).await +} + +/// Force refresh pricing data +pub async fn refresh_pricing() -> Result<()> { + PRICING_CACHE.refresh().await +} + +/// Get all cached pricing data +pub async fn get_all_pricing() -> HashMap> { + let cache = PRICING_CACHE.memory_cache.read().await; + if let Some(cached) = &*cache { + cached.pricing.clone() + } else { + // Try loading from disk + if let Ok(Some(disk_cache)) = PRICING_CACHE.load_from_disk().await { + // Update memory cache + drop(cache); + let mut write_cache = PRICING_CACHE.memory_cache.write().await; + *write_cache = Some(disk_cache.clone()); + disk_cache.pricing + } else { + HashMap::new() + } + } +} + +/// Convert OpenRouter model ID to provider/model format +/// e.g., "anthropic/claude-3.5-sonnet" -> ("anthropic", "claude-3.5-sonnet") +pub fn parse_model_id(model_id: &str) -> Option<(String, String)> { + let parts: Vec<&str> = model_id.splitn(2, '/').collect(); + if parts.len() == 2 { + // Normalize provider names to match our internal naming + let provider = match parts[0] { + "openai" => "openai", + "anthropic" => "anthropic", + "google" => "google", + "meta-llama" => "ollama", // Meta models often run via Ollama + "mistralai" => "mistral", + "cohere" => "cohere", + "perplexity" => "perplexity", + "deepseek" => "deepseek", + "groq" => "groq", + "nvidia" => "nvidia", + "microsoft" => "azure", + "replicate" => "replicate", + "huggingface" => "huggingface", + _ => parts[0], + }; + Some((provider.to_string(), parts[1].to_string())) + } else { + None + } +} + +/// Convert OpenRouter pricing to cost per token (already in that format) +pub fn convert_pricing(price_str: &str) -> Option { + // OpenRouter prices are already in USD per token + price_str.parse::().ok() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_model_id() { + assert_eq!( + parse_model_id("anthropic/claude-3.5-sonnet"), + Some(("anthropic".to_string(), "claude-3.5-sonnet".to_string())) + ); + assert_eq!( + parse_model_id("openai/gpt-4"), + Some(("openai".to_string(), "gpt-4".to_string())) + ); + assert_eq!(parse_model_id("invalid-format"), None); + } + + #[test] + fn test_convert_pricing() { + assert_eq!(convert_pricing("0.000003"), Some(0.000003)); + assert_eq!(convert_pricing("0.015"), Some(0.015)); + assert_eq!(convert_pricing("invalid"), None); + } +} diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index a791f817..bd5092dd 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -1700,9 +1700,26 @@ "description": "The maximum context length this model supports", "minimum": 0 }, + "currency": { + "type": "string", + "description": "Currency for the costs (default: \"$\")", + "nullable": true + }, + "input_token_cost": { + "type": "number", + "format": "double", + "description": "Cost per token for input (optional)", + "nullable": true + }, "name": { "type": "string", "description": "The name of the model" + }, + "output_token_cost": { + "type": "number", + "format": "double", + "description": "Cost per token for output (optional)", + "nullable": true } } }, diff --git a/ui/desktop/src/App.tsx b/ui/desktop/src/App.tsx index 67e4d687..5b2e5c38 100644 --- a/ui/desktop/src/App.tsx +++ b/ui/desktop/src/App.tsx @@ -3,6 +3,7 @@ import { IpcRendererEvent } from 'electron'; import { openSharedSessionFromDeepLink, type SessionLinksViewOptions } from './sessionLinks'; import { type SharedSessionDetails } from './sharedSessions'; import { initializeSystem } from './utils/providerUtils'; +import { initializeCostDatabase } from './utils/costDatabase'; import { ErrorUI } from './components/ErrorBoundary'; import { ConfirmationModal } from './components/ui/ConfirmationModal'; import { ToastContainer } from 'react-toastify'; @@ -158,6 +159,11 @@ export default function App() { const initializeApp = async () => { try { + // Initialize cost database early to pre-load pricing data + initializeCostDatabase().catch((error) => { + console.error('Failed to initialize cost database:', error); + }); + await initConfig(); try { await readAllConfig({ throwOnError: true }); diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index c5ed03a6..56eb1f27 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -229,10 +229,22 @@ export type ModelInfo = { * The maximum context length this model supports */ context_limit: number; + /** + * Currency for the costs (default: "$") + */ + currency?: string | null; + /** + * Cost per token for input (optional) + */ + input_token_cost?: number | null; /** * The name of the model */ name: string; + /** + * Cost per token for output (optional) + */ + output_token_cost?: number | null; }; export type PermissionConfirmationRequest = { diff --git a/ui/desktop/src/components/ChatInput.tsx b/ui/desktop/src/components/ChatInput.tsx index 09cb2ade..5e1c3daa 100644 --- a/ui/desktop/src/components/ChatInput.tsx +++ b/ui/desktop/src/components/ChatInput.tsx @@ -29,9 +29,18 @@ interface ChatInputProps { droppedFiles?: string[]; setView: (view: View) => void; numTokens?: number; + inputTokens?: number; + outputTokens?: number; hasMessages?: boolean; messages?: Message[]; setMessages: (messages: Message[]) => void; + sessionCosts?: { + [key: string]: { + inputTokens: number; + outputTokens: number; + totalCost: number; + }; + }; } export default function ChatInput({ @@ -42,9 +51,12 @@ export default function ChatInput({ initialValue = '', setView, numTokens, + inputTokens, + outputTokens, droppedFiles = [], messages = [], setMessages, + sessionCosts, }: ChatInputProps) { const [_value, setValue] = useState(initialValue); const [displayValue, setDisplayValue] = useState(initialValue); // For immediate visual feedback @@ -557,9 +569,12 @@ export default function ChatInput({ diff --git a/ui/desktop/src/components/ChatView.tsx b/ui/desktop/src/components/ChatView.tsx index 6d674ed4..23407e8b 100644 --- a/ui/desktop/src/components/ChatView.tsx +++ b/ui/desktop/src/components/ChatView.tsx @@ -33,6 +33,8 @@ import { } from './context_management/ChatContextManager'; import { ContextHandler } from './context_management/ContextHandler'; import { LocalMessageStorage } from '../utils/localMessageStorage'; +import { useModelAndProvider } from './ModelAndProviderContext'; +import { getCostForModel } from '../utils/costDatabase'; import { Message, createUserMessage, @@ -106,11 +108,25 @@ function ChatContent({ const [showGame, setShowGame] = useState(false); const [isGeneratingRecipe, setIsGeneratingRecipe] = useState(false); const [sessionTokenCount, setSessionTokenCount] = useState(0); + const [sessionInputTokens, setSessionInputTokens] = useState(0); + const [sessionOutputTokens, setSessionOutputTokens] = useState(0); + const [localInputTokens, setLocalInputTokens] = useState(0); + const [localOutputTokens, setLocalOutputTokens] = useState(0); const [ancestorMessages, setAncestorMessages] = useState([]); const [droppedFiles, setDroppedFiles] = useState([]); + const [sessionCosts, setSessionCosts] = useState<{ + [key: string]: { + inputTokens: number; + outputTokens: number; + totalCost: number; + }; + }>({}); const [readyForAutoUserPrompt, setReadyForAutoUserPrompt] = useState(false); const scrollRef = useRef(null); + const { currentModel, currentProvider } = useModelAndProvider(); + const prevModelRef = useRef(); + const prevProviderRef = useRef(); const { summaryContent, @@ -160,6 +176,7 @@ function ChatContent({ updateMessageStreamBody, notifications, currentModelInfo, + sessionMetadata, } = useMessageStream({ api: getApiUrl('/reply'), initialMessages: chat.messages, @@ -518,12 +535,40 @@ function ChatContent({ .reverse(); }, [filteredMessages]); + // Simple token estimation function (roughly 4 characters per token) + const estimateTokens = (text: string): number => { + return Math.ceil(text.length / 4); + }; + + // Calculate token counts from messages + useEffect(() => { + let inputTokens = 0; + let outputTokens = 0; + + messages.forEach((message) => { + const textContent = getTextContent(message); + if (textContent) { + const tokens = estimateTokens(textContent); + if (message.role === 'user') { + inputTokens += tokens; + } else if (message.role === 'assistant') { + outputTokens += tokens; + } + } + }); + + setLocalInputTokens(inputTokens); + setLocalOutputTokens(outputTokens); + }, [messages]); + // Fetch session metadata to get token count useEffect(() => { const fetchSessionTokens = async () => { try { const sessionDetails = await fetchSessionDetails(chat.id); setSessionTokenCount(sessionDetails.metadata.total_tokens || 0); + setSessionInputTokens(sessionDetails.metadata.accumulated_input_tokens || 0); + setSessionOutputTokens(sessionDetails.metadata.accumulated_output_tokens || 0); } catch (err) { console.error('Error fetching session token count:', err); } @@ -533,6 +578,74 @@ function ChatContent({ } }, [chat.id, messages]); + // Update token counts when sessionMetadata changes from the message stream + useEffect(() => { + console.log('Session metadata received:', sessionMetadata); + if (sessionMetadata) { + setSessionTokenCount(sessionMetadata.totalTokens || 0); + setSessionInputTokens(sessionMetadata.accumulatedInputTokens || 0); + setSessionOutputTokens(sessionMetadata.accumulatedOutputTokens || 0); + } + }, [sessionMetadata]); + + // Handle model changes and accumulate costs + useEffect(() => { + if ( + prevModelRef.current !== undefined && + prevProviderRef.current !== undefined && + (prevModelRef.current !== currentModel || prevProviderRef.current !== currentProvider) + ) { + // Model/provider has changed, save the costs for the previous model + const prevKey = `${prevProviderRef.current}/${prevModelRef.current}`; + + // Get pricing info for the previous model + const prevCostInfo = getCostForModel(prevProviderRef.current, prevModelRef.current); + + if (prevCostInfo) { + const prevInputCost = + (sessionInputTokens || localInputTokens) * (prevCostInfo.input_token_cost || 0); + const prevOutputCost = + (sessionOutputTokens || localOutputTokens) * (prevCostInfo.output_token_cost || 0); + const prevTotalCost = prevInputCost + prevOutputCost; + + // Save the accumulated costs for this model + setSessionCosts((prev) => ({ + ...prev, + [prevKey]: { + inputTokens: sessionInputTokens || localInputTokens, + outputTokens: sessionOutputTokens || localOutputTokens, + totalCost: prevTotalCost, + }, + })); + } + + // Reset token counters for the new model + setSessionTokenCount(0); + setSessionInputTokens(0); + setSessionOutputTokens(0); + setLocalInputTokens(0); + setLocalOutputTokens(0); + + console.log( + 'Model changed from', + `${prevProviderRef.current}/${prevModelRef.current}`, + 'to', + `${currentProvider}/${currentModel}`, + '- saved costs and reset token counters' + ); + } + + prevModelRef.current = currentModel || undefined; + prevProviderRef.current = currentProvider || undefined; + }, [ + currentModel, + currentProvider, + sessionInputTokens, + sessionOutputTokens, + localInputTokens, + localOutputTokens, + ]); + const handleDrop = (e: React.DragEvent) => { e.preventDefault(); const files = e.dataTransfer.files; @@ -684,9 +797,12 @@ function ChatContent({ setView={setView} hasMessages={hasMessages} numTokens={sessionTokenCount} + inputTokens={sessionInputTokens || localInputTokens} + outputTokens={sessionOutputTokens || localOutputTokens} droppedFiles={droppedFiles} messages={messages} setMessages={setMessages} + sessionCosts={sessionCosts} /> diff --git a/ui/desktop/src/components/bottom_menu/BottomMenu.tsx b/ui/desktop/src/components/bottom_menu/BottomMenu.tsx index 587670ec..bd3a2d4a 100644 --- a/ui/desktop/src/components/bottom_menu/BottomMenu.tsx +++ b/ui/desktop/src/components/bottom_menu/BottomMenu.tsx @@ -9,6 +9,7 @@ import { useConfig } from '../ConfigContext'; import { useModelAndProvider } from '../ModelAndProviderContext'; import { Message } from '../../types/message'; import { ManualSummarizeButton } from '../context_management/ManualSummaryButton'; +import { CostTracker } from './CostTracker'; const TOKEN_LIMIT_DEFAULT = 128000; // fallback for custom models that the backend doesn't know about const TOKEN_WARNING_THRESHOLD = 0.8; // warning shows at 80% of the token limit @@ -22,15 +23,27 @@ interface ModelLimit { export default function BottomMenu({ setView, numTokens = 0, + inputTokens = 0, + outputTokens = 0, messages = [], isLoading = false, setMessages, + sessionCosts, }: { setView: (view: View, viewOptions?: ViewOptions) => void; numTokens?: number; + inputTokens?: number; + outputTokens?: number; messages?: Message[]; isLoading?: boolean; setMessages: (messages: Message[]) => void; + sessionCosts?: { + [key: string]: { + inputTokens: number; + outputTokens: number; + totalCost: number; + }; + }; }) { const [isModelMenuOpen, setIsModelMenuOpen] = useState(false); const { alerts, addAlert, clearAlerts } = useAlerts(); @@ -202,29 +215,45 @@ export default function BottomMenu({ }, [isModelMenuOpen]); return ( -
-
+
+
{/* Tool and Token count */} - {} +
+ {} +
+ + {/* Cost Tracker - no separator before it */} +
+ +
+ + {/* Separator between cost and model */} +
{/* Model Selector Dropdown */} - +
+ +
{/* Separator */} -
+
{/* Goose Mode Selector Dropdown */} - +
+ +
- {/* Summarize Context Button - ADD THIS */} + {/* Summarize Context Button */} {messages.length > 0 && ( <> -
- +
+
+ +
)}
diff --git a/ui/desktop/src/components/bottom_menu/CostTracker.tsx b/ui/desktop/src/components/bottom_menu/CostTracker.tsx new file mode 100644 index 00000000..2f57e94f --- /dev/null +++ b/ui/desktop/src/components/bottom_menu/CostTracker.tsx @@ -0,0 +1,310 @@ +import { useState, useEffect } from 'react'; +import { useModelAndProvider } from '../ModelAndProviderContext'; +import { useConfig } from '../ConfigContext'; +import { + getCostForModel, + initializeCostDatabase, + updateAllModelCosts, + fetchAndCachePricing, +} from '../../utils/costDatabase'; + +interface CostTrackerProps { + inputTokens?: number; + outputTokens?: number; + sessionCosts?: { + [key: string]: { + inputTokens: number; + outputTokens: number; + totalCost: number; + }; + }; +} + +export function CostTracker({ inputTokens = 0, outputTokens = 0, sessionCosts }: CostTrackerProps) { + const { currentModel, currentProvider } = useModelAndProvider(); + const { getProviders } = useConfig(); + const [costInfo, setCostInfo] = useState<{ + input_token_cost?: number; + output_token_cost?: number; + currency?: string; + } | null>(null); + const [isLoading, setIsLoading] = useState(true); + const [showPricing, setShowPricing] = useState(true); + const [pricingFailed, setPricingFailed] = useState(false); + const [modelNotFound, setModelNotFound] = useState(false); + const [hasAttemptedFetch, setHasAttemptedFetch] = useState(false); + const [initialLoadComplete, setInitialLoadComplete] = useState(false); + + // Check if pricing is enabled + useEffect(() => { + const checkPricingSetting = () => { + const stored = localStorage.getItem('show_pricing'); + setShowPricing(stored !== 'false'); + }; + + // Check on mount + checkPricingSetting(); + + // Listen for storage changes + window.addEventListener('storage', checkPricingSetting); + return () => window.removeEventListener('storage', checkPricingSetting); + }, []); + + // Set initial load complete after a short delay + useEffect(() => { + const timer = setTimeout(() => { + setInitialLoadComplete(true); + }, 3000); // Give 3 seconds for initial load + + return () => window.clearTimeout(timer); + }, []); + + // Debug log props removed + + // Initialize cost database on mount + useEffect(() => { + initializeCostDatabase(); + + // Update costs for all models in background + updateAllModelCosts().catch((error) => { + console.error('Failed to update model costs:', error); + }); + }, [getProviders]); + + useEffect(() => { + const loadCostInfo = async () => { + if (!currentModel || !currentProvider) { + setIsLoading(false); + return; + } + + console.log(`CostTracker: Loading cost info for ${currentProvider}/${currentModel}`); + + try { + // First check sync cache + let costData = getCostForModel(currentProvider, currentModel); + + if (costData) { + // We have cached data + console.log( + `CostTracker: Found cached data for ${currentProvider}/${currentModel}:`, + costData + ); + setCostInfo(costData); + setPricingFailed(false); + setModelNotFound(false); + setIsLoading(false); + setHasAttemptedFetch(true); + } else { + // Need to fetch from backend + console.log( + `CostTracker: No cached data, fetching from backend for ${currentProvider}/${currentModel}` + ); + setIsLoading(true); + const result = await fetchAndCachePricing(currentProvider, currentModel); + setHasAttemptedFetch(true); + + if (result && result.costInfo) { + console.log( + `CostTracker: Fetched data for ${currentProvider}/${currentModel}:`, + result.costInfo + ); + setCostInfo(result.costInfo); + setPricingFailed(false); + setModelNotFound(false); + } else if (result && result.error === 'model_not_found') { + console.log( + `CostTracker: Model not found in pricing data for ${currentProvider}/${currentModel}` + ); + // Model not found in pricing database, but API call succeeded + setModelNotFound(true); + setPricingFailed(false); + } else { + console.log(`CostTracker: API failed for ${currentProvider}/${currentModel}`); + // API call failed or other error + const freeProviders = ['ollama', 'local', 'localhost']; + if (!freeProviders.includes(currentProvider.toLowerCase())) { + setPricingFailed(true); + setModelNotFound(false); + } + } + setIsLoading(false); + } + } catch (error) { + console.error('Error loading cost info:', error); + setHasAttemptedFetch(true); + // Only set pricing failed if we're not dealing with a known free provider + const freeProviders = ['ollama', 'local', 'localhost']; + if (!freeProviders.includes(currentProvider.toLowerCase())) { + setPricingFailed(true); + setModelNotFound(false); + } + setIsLoading(false); + } + }; + + loadCostInfo(); + }, [currentModel, currentProvider]); + + // Return null early if pricing is disabled + if (!showPricing) { + return null; + } + + const calculateCost = (): number => { + // If we have session costs, calculate the total across all models + if (sessionCosts) { + let totalCost = 0; + + // Add up all historical costs from different models + Object.values(sessionCosts).forEach((modelCost) => { + totalCost += modelCost.totalCost; + }); + + // Add current model cost if we have pricing info + if ( + costInfo && + (costInfo.input_token_cost !== undefined || costInfo.output_token_cost !== undefined) + ) { + const currentInputCost = inputTokens * (costInfo.input_token_cost || 0); + const currentOutputCost = outputTokens * (costInfo.output_token_cost || 0); + totalCost += currentInputCost + currentOutputCost; + } + + return totalCost; + } + + // Fallback to simple calculation for current model only + if ( + !costInfo || + (costInfo.input_token_cost === undefined && costInfo.output_token_cost === undefined) + ) { + return 0; + } + + const inputCost = inputTokens * (costInfo.input_token_cost || 0); + const outputCost = outputTokens * (costInfo.output_token_cost || 0); + const total = inputCost + outputCost; + + return total; + }; + + const formatCost = (cost: number): string => { + // Always show 6 decimal places for consistency + return cost.toFixed(6); + }; + + // Debug logging removed + + // Show loading state or when we don't have model/provider info + if (!currentModel || !currentProvider) { + return null; + } + + // If still loading, show a placeholder + if (isLoading) { + return ( +
+ ... +
+ ); + } + + // If no cost info found, try to return a default + if ( + !costInfo || + (costInfo.input_token_cost === undefined && costInfo.output_token_cost === undefined) + ) { + // If it's a known free/local provider, show $0.000000 without "not available" message + const freeProviders = ['ollama', 'local', 'localhost']; + if (freeProviders.includes(currentProvider.toLowerCase())) { + return ( +
+ $0.000000 +
+ ); + } + + // Otherwise show as unavailable + const getUnavailableTooltip = () => { + if (pricingFailed && hasAttemptedFetch && initialLoadComplete) { + return `Pricing data unavailable - OpenRouter connection failed. Click refresh in settings to retry.`; + } + // If we reach here, it must be modelNotFound (since we only get here after attempting fetch) + return `Cost data not available for ${currentModel} (${inputTokens.toLocaleString()} input, ${outputTokens.toLocaleString()} output tokens)`; + }; + + return ( +
+ $0.000000 +
+ ); + } + + const totalCost = calculateCost(); + + // Build tooltip content + const getTooltipContent = (): string => { + // Handle error states first + if (pricingFailed && hasAttemptedFetch && initialLoadComplete) { + return `Pricing data unavailable - OpenRouter connection failed. Click refresh in settings to retry.`; + } + + if (modelNotFound && hasAttemptedFetch && initialLoadComplete) { + return `Pricing not available for ${currentProvider}/${currentModel}. This model may not be supported by the pricing service.`; + } + + // Handle session costs + if (sessionCosts && Object.keys(sessionCosts).length > 0) { + // Show session breakdown + let tooltip = 'Session cost breakdown:\n'; + + Object.entries(sessionCosts).forEach(([modelKey, cost]) => { + const costStr = `${costInfo?.currency || '$'}${cost.totalCost.toFixed(6)}`; + tooltip += `${modelKey}: ${costStr} (${cost.inputTokens.toLocaleString()} in, ${cost.outputTokens.toLocaleString()} out)\n`; + }); + + // Add current model if it has costs + if (costInfo && (inputTokens > 0 || outputTokens > 0)) { + const currentCost = + inputTokens * (costInfo.input_token_cost || 0) + + outputTokens * (costInfo.output_token_cost || 0); + if (currentCost > 0) { + tooltip += `${currentProvider}/${currentModel} (current): ${costInfo.currency || '$'}${currentCost.toFixed(6)} (${inputTokens.toLocaleString()} in, ${outputTokens.toLocaleString()} out)\n`; + } + } + + tooltip += `\nTotal session cost: ${costInfo?.currency || '$'}${totalCost.toFixed(6)}`; + return tooltip; + } + + // Default tooltip for single model + return `Input: ${inputTokens.toLocaleString()} tokens (${costInfo?.currency || '$'}${(inputTokens * (costInfo?.input_token_cost || 0)).toFixed(6)}) | Output: ${outputTokens.toLocaleString()} tokens (${costInfo?.currency || '$'}${(outputTokens * (costInfo?.output_token_cost || 0)).toFixed(6)})`; + }; + + return ( +
+ + {costInfo.currency || '$'} + {formatCost(totalCost)} + +
+ ); +} diff --git a/ui/desktop/src/components/settings/app/AppSettingsSection.tsx b/ui/desktop/src/components/settings/app/AppSettingsSection.tsx index 1a69d965..c19558c0 100644 --- a/ui/desktop/src/components/settings/app/AppSettingsSection.tsx +++ b/ui/desktop/src/components/settings/app/AppSettingsSection.tsx @@ -1,10 +1,11 @@ import { useState, useEffect, useRef } from 'react'; import { Switch } from '../../ui/switch'; import { Button } from '../../ui/button'; -import { Settings } from 'lucide-react'; +import { Settings, RefreshCw, ExternalLink } from 'lucide-react'; import Modal from '../../Modal'; import UpdateSection from './UpdateSection'; import { UPDATES_ENABLED } from '../../../updates'; +import { getApiUrl, getSecretKey } from '../../../config'; interface AppSettingsSectionProps { scrollToSection?: string; @@ -17,6 +18,10 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti const [isMacOS, setIsMacOS] = useState(false); const [isDockSwitchDisabled, setIsDockSwitchDisabled] = useState(false); const [showNotificationModal, setShowNotificationModal] = useState(false); + const [pricingStatus, setPricingStatus] = useState<'loading' | 'success' | 'error'>('loading'); + const [lastFetchTime, setLastFetchTime] = useState(null); + const [isRefreshing, setIsRefreshing] = useState(false); + const [showPricing, setShowPricing] = useState(true); const updateSectionRef = useRef(null); // Check if running on macOS @@ -24,6 +29,77 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti setIsMacOS(window.electron.platform === 'darwin'); }, []); + // Load show pricing setting + useEffect(() => { + const stored = localStorage.getItem('show_pricing'); + setShowPricing(stored !== 'false'); + }, []); + + // Check pricing status on mount + useEffect(() => { + checkPricingStatus(); + }, []); + + const checkPricingStatus = async () => { + try { + const apiUrl = getApiUrl('/config/pricing'); + const secretKey = getSecretKey(); + + const headers: HeadersInit = { 'Content-Type': 'application/json' }; + if (secretKey) { + headers['X-Secret-Key'] = secretKey; + } + + const response = await fetch(apiUrl, { + method: 'POST', + headers, + body: JSON.stringify({ configured_only: true }), + }); + + if (response.ok) { + await response.json(); // Consume the response + setPricingStatus('success'); + setLastFetchTime(new Date()); + } else { + setPricingStatus('error'); + } + } catch (error) { + setPricingStatus('error'); + } + }; + + const handleRefreshPricing = async () => { + setIsRefreshing(true); + try { + const apiUrl = getApiUrl('/config/pricing'); + const secretKey = getSecretKey(); + + const headers: HeadersInit = { 'Content-Type': 'application/json' }; + if (secretKey) { + headers['X-Secret-Key'] = secretKey; + } + + const response = await fetch(apiUrl, { + method: 'POST', + headers, + body: JSON.stringify({ configured_only: false }), + }); + + if (response.ok) { + setPricingStatus('success'); + setLastFetchTime(new Date()); + // Trigger a reload of the cost database + window.dispatchEvent(new CustomEvent('pricing-updated')); + } else { + setPricingStatus('error'); + } + } catch (error) { + setPricingStatus('error'); + } finally { + setIsRefreshing(false); + } + }; + // Handle scrolling to update section useEffect(() => { if (scrollToSection === 'update' && updateSectionRef.current) { @@ -99,6 +175,13 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti } }; + const handleShowPricingToggle = (checked: boolean) => { + setShowPricing(checked); + localStorage.setItem('show_pricing', String(checked)); + // Trigger storage event for other components + window.dispatchEvent(new CustomEvent('storage')); + }; + return (
@@ -173,6 +256,7 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
)} + {/* Quit Confirmation */}

Quit Confirmation

@@ -188,6 +272,87 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti />
+ + {/* Cost Tracking */} +
+
+

Cost Tracking

+

+ Show model pricing and usage costs +

+
+
+ +
+
+ + {/* Pricing Status - only show if cost tracking is enabled */} + {showPricing && ( + <> +
+ Pricing Source: + + OpenRouter Docs + + +
+ +
+ Status: +
+ + {pricingStatus === 'success' + ? '✓ Connected' + : pricingStatus === 'error' + ? '✗ Failed' + : '... Checking'} + + +
+
+ + {lastFetchTime && ( +
+ Last updated: + {lastFetchTime.toLocaleTimeString()} +
+ )} + + {pricingStatus === 'error' && ( +

+ Unable to fetch pricing data. Costs will not be displayed. +

+ )} + + )}
{/* Help & Feedback Section */} diff --git a/ui/desktop/src/components/settings/models/bottom_bar/ModelsBottomBar.tsx b/ui/desktop/src/components/settings/models/bottom_bar/ModelsBottomBar.tsx index 12911d80..48c55bb1 100644 --- a/ui/desktop/src/components/settings/models/bottom_bar/ModelsBottomBar.tsx +++ b/ui/desktop/src/components/settings/models/bottom_bar/ModelsBottomBar.tsx @@ -43,7 +43,10 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa }, [read]); // Determine which model to display - activeModel takes priority when lead/worker is active - const displayModel = (isLeadWorkerActive && currentModelInfo?.model) ? currentModelInfo.model : (currentModel || 'Select Model'); + const displayModel = + isLeadWorkerActive && currentModelInfo?.model + ? currentModelInfo.model + : currentModel || 'Select Model'; const modelMode = currentModelInfo?.mode; // Update display provider when current provider changes @@ -106,9 +109,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa > {displayModel} {isLeadWorkerActive && modelMode && ( - - ({modelMode}) - + ({modelMode}) )} @@ -116,9 +117,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa {displayModel} {isLeadWorkerActive && modelMode && ( - - ({modelMode}) - + ({modelMode}) )} )} @@ -164,7 +163,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa {isAddModelModalOpen ? ( setIsAddModelModalOpen(false)} /> ) : null} - + {isLeadWorkerModalOpen ? ( setIsLeadWorkerModalOpen(false)}> setIsLeadWorkerModalOpen(false)} /> diff --git a/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx b/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx index c87a0624..ec5c5717 100644 --- a/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx +++ b/ui/desktop/src/components/settings/models/subcomponents/LeadWorkerSettings.tsx @@ -21,7 +21,9 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) { const [failureThreshold, setFailureThreshold] = useState(2); const [fallbackTurns, setFallbackTurns] = useState(2); const [isEnabled, setIsEnabled] = useState(false); - const [modelOptions, setModelOptions] = useState<{ value: string; label: string; provider: string }[]>([]); + const [modelOptions, setModelOptions] = useState< + { value: string; label: string; provider: string }[] + >([]); const [isLoading, setIsLoading] = useState(true); // Load current configuration @@ -51,7 +53,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) { if (leadTurnsConfig) setLeadTurns(Number(leadTurnsConfig)); if (failureThresholdConfig) setFailureThreshold(Number(failureThresholdConfig)); if (fallbackTurnsConfig) setFallbackTurns(Number(fallbackTurnsConfig)); - + // Set worker model to current model or from config const workerModelConfig = await read('GOOSE_MODEL', false); if (workerModelConfig) { @@ -59,7 +61,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) { } else if (currentModel) { setWorkerModel(currentModel as string); } - + const workerProviderConfig = await read('GOOSE_PROVIDER', false); if (workerProviderConfig) { setWorkerProvider(workerProviderConfig as string); @@ -69,7 +71,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) { const providers = await getProviders(false); const activeProviders = providers.filter((p) => p.is_configured); const options: { value: string; label: string; provider: string }[] = []; - + activeProviders.forEach(({ metadata, name }) => { if (metadata.known_models) { metadata.known_models.forEach((model) => { @@ -81,7 +83,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) { }); } }); - + setModelOptions(options); } catch (error) { console.error('Error loading configuration:', error); @@ -184,9 +186,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) { placeholder="Select worker model..." isDisabled={!isEnabled} /> -

- Fast model for routine execution tasks -

+

Fast model for routine execution tasks

@@ -242,9 +242,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) { className="w-20" disabled={!isEnabled} /> -

- Turns to use lead model during fallback -

+

Turns to use lead model during fallback

@@ -259,4 +257,4 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
); -} \ No newline at end of file +} diff --git a/ui/desktop/src/hooks/useCurrentModel.ts b/ui/desktop/src/hooks/useCurrentModel.ts index 4de74a2b..57f7f3d2 100644 --- a/ui/desktop/src/hooks/useCurrentModel.ts +++ b/ui/desktop/src/hooks/useCurrentModel.ts @@ -2,9 +2,9 @@ import { useCurrentModelInfo } from '../components/ChatView'; export function useCurrentModel() { const modelInfo = useCurrentModelInfo(); - - return { + + return { currentModel: modelInfo?.model || null, - isLoading: false + isLoading: false, }; -} \ No newline at end of file +} diff --git a/ui/desktop/src/hooks/useMessageStream.ts b/ui/desktop/src/hooks/useMessageStream.ts index f0177c17..2195035b 100644 --- a/ui/desktop/src/hooks/useMessageStream.ts +++ b/ui/desktop/src/hooks/useMessageStream.ts @@ -2,12 +2,26 @@ import { useState, useCallback, useEffect, useRef, useId } from 'react'; import useSWR from 'swr'; import { getSecretKey } from '../config'; import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message'; +import { getSessionHistory } from '../api'; // Ensure TextDecoder is available in the global scope const TextDecoder = globalThis.TextDecoder; type JsonValue = string | number | boolean | null | JsonValue[] | { [key: string]: JsonValue }; +export interface SessionMetadata { + workingDir: string; + description: string; + scheduleId: string | null; + messageCount: number; + totalTokens: number | null; + inputTokens: number | null; + outputTokens: number | null; + accumulatedTotalTokens: number | null; + accumulatedInputTokens: number | null; + accumulatedOutputTokens: number | null; +} + export interface NotificationEvent { type: 'Notification'; request_id: string; @@ -141,9 +155,12 @@ export interface UseMessageStreamHelpers { updateMessageStreamBody?: (newBody: object) => void; notifications: NotificationEvent[]; - + /** Current model info from the backend */ currentModelInfo: { model: string; mode: string } | null; + + /** Session metadata including token counts */ + sessionMetadata: SessionMetadata | null; } /** @@ -172,7 +189,10 @@ export function useMessageStream({ }); const [notifications, setNotifications] = useState([]); - const [currentModelInfo, setCurrentModelInfo] = useState<{ model: string; mode: string } | null>(null); + const [currentModelInfo, setCurrentModelInfo] = useState<{ model: string; mode: string } | null>( + null + ); + const [sessionMetadata, setSessionMetadata] = useState(null); // expose a way to update the body so we can update the session id when CLE occurs const updateMessageStreamBody = useCallback((newBody: object) => { @@ -291,13 +311,41 @@ export function useMessageStream({ case 'Error': throw new Error(parsedEvent.error); - case 'Finish': + case 'Finish': { // Call onFinish with the last message if available if (onFinish && currentMessages.length > 0) { const lastMessage = currentMessages[currentMessages.length - 1]; onFinish(lastMessage, parsedEvent.reason); } + + // Fetch updated session metadata with token counts + const sessionId = (extraMetadataRef.current.body as Record)?.session_id as string; + if (sessionId) { + try { + const sessionResponse = await getSessionHistory({ + path: { session_id: sessionId }, + }); + + if (sessionResponse.data?.metadata) { + setSessionMetadata({ + workingDir: sessionResponse.data.metadata.working_dir, + description: sessionResponse.data.metadata.description, + scheduleId: sessionResponse.data.metadata.schedule_id || null, + messageCount: sessionResponse.data.metadata.message_count, + totalTokens: sessionResponse.data.metadata.total_tokens || null, + inputTokens: sessionResponse.data.metadata.input_tokens || null, + outputTokens: sessionResponse.data.metadata.output_tokens || null, + accumulatedTotalTokens: sessionResponse.data.metadata.accumulated_total_tokens || null, + accumulatedInputTokens: sessionResponse.data.metadata.accumulated_input_tokens || null, + accumulatedOutputTokens: sessionResponse.data.metadata.accumulated_output_tokens || null, + }); + } + } catch (error) { + console.error('Failed to fetch session metadata:', error); + } + } break; + } } } catch (e) { console.error('Error parsing SSE event:', e); @@ -559,5 +607,6 @@ export function useMessageStream({ updateMessageStreamBody, notifications, currentModelInfo, + sessionMetadata, }; } diff --git a/ui/desktop/src/sessions.ts b/ui/desktop/src/sessions.ts index b4ba4374..2ebb1e5d 100644 --- a/ui/desktop/src/sessions.ts +++ b/ui/desktop/src/sessions.ts @@ -7,6 +7,10 @@ export interface SessionMetadata { message_count: number; total_tokens: number | null; working_dir: string; // Required in type, but may be missing in old sessions + // Add the accumulated token fields from the API + accumulated_input_tokens?: number | null; + accumulated_output_tokens?: number | null; + accumulated_total_tokens?: number | null; } // Helper function to ensure working directory is set @@ -16,6 +20,9 @@ export function ensureWorkingDir(metadata: Partial): SessionMet message_count: metadata.message_count || 0, total_tokens: metadata.total_tokens || null, working_dir: metadata.working_dir || process.env.HOME || '', + accumulated_input_tokens: metadata.accumulated_input_tokens || null, + accumulated_output_tokens: metadata.accumulated_output_tokens || null, + accumulated_total_tokens: metadata.accumulated_total_tokens || null, }; } diff --git a/ui/desktop/src/utils/costDatabase.ts b/ui/desktop/src/utils/costDatabase.ts new file mode 100644 index 00000000..fcbb603e --- /dev/null +++ b/ui/desktop/src/utils/costDatabase.ts @@ -0,0 +1,593 @@ +// Import the proper type from ConfigContext +import { getApiUrl, getSecretKey } from '../config'; + +export interface ModelCostInfo { + input_token_cost: number; // Cost per token for input (in USD) + output_token_cost: number; // Cost per token for output (in USD) + currency: string; // Currency symbol +} + +// In-memory cache for current model pricing only +let currentModelPricing: { + provider: string; + model: string; + costInfo: ModelCostInfo | null; +} | null = null; + +// LocalStorage keys +const PRICING_CACHE_KEY = 'goose_pricing_cache'; +const PRICING_CACHE_TIMESTAMP_KEY = 'goose_pricing_cache_timestamp'; +const RECENTLY_USED_MODELS_KEY = 'goose_recently_used_models'; +const CACHE_TTL_MS = 7 * 24 * 60 * 60 * 1000; // 7 days in milliseconds +const MAX_RECENTLY_USED_MODELS = 20; // Keep only the last 20 used models in cache + +interface PricingItem { + provider: string; + model: string; + input_token_cost: number; + output_token_cost: number; + currency: string; +} + +interface PricingCacheData { + pricing: PricingItem[]; + timestamp: number; +} + +interface RecentlyUsedModel { + provider: string; + model: string; + lastUsed: number; +} + +/** + * Get recently used models from localStorage + */ +function getRecentlyUsedModels(): RecentlyUsedModel[] { + try { + const stored = localStorage.getItem(RECENTLY_USED_MODELS_KEY); + return stored ? JSON.parse(stored) : []; + } catch (error) { + console.error('Error loading recently used models:', error); + return []; + } +} + +/** + * Add a model to the recently used list + */ +function addToRecentlyUsed(provider: string, model: string): void { + try { + let recentModels = getRecentlyUsedModels(); + + // Remove existing entry if present + recentModels = recentModels.filter((m) => !(m.provider === provider && m.model === model)); + + // Add to front + recentModels.unshift({ provider, model, lastUsed: Date.now() }); + + // Keep only the most recent models + recentModels = recentModels.slice(0, MAX_RECENTLY_USED_MODELS); + + localStorage.setItem(RECENTLY_USED_MODELS_KEY, JSON.stringify(recentModels)); + } catch (error) { + console.error('Error saving recently used models:', error); + } +} + +/** + * Load pricing data from localStorage cache - only for recently used models + */ +function loadPricingFromLocalStorage(): PricingCacheData | null { + try { + const cached = localStorage.getItem(PRICING_CACHE_KEY); + const timestamp = localStorage.getItem(PRICING_CACHE_TIMESTAMP_KEY); + + if (cached && timestamp) { + const cacheAge = Date.now() - parseInt(timestamp, 10); + if (cacheAge < CACHE_TTL_MS) { + const fullCache = JSON.parse(cached) as PricingCacheData; + const recentModels = getRecentlyUsedModels(); + + // Filter to only include recently used models + const filteredPricing = fullCache.pricing.filter((p) => + recentModels.some((r) => r.provider === p.provider && r.model === p.model) + ); + + console.log( + `Loading ${filteredPricing.length} recently used models from cache (out of ${fullCache.pricing.length} total)` + ); + + return { + pricing: filteredPricing, + timestamp: fullCache.timestamp, + }; + } else { + console.log('LocalStorage pricing cache expired'); + } + } + } catch (error) { + console.error('Error loading pricing from localStorage:', error); + } + return null; +} + +/** + * Save pricing data to localStorage - merge with existing data + */ +function savePricingToLocalStorage(data: PricingCacheData, mergeWithExisting = true): void { + try { + if (mergeWithExisting) { + // Load existing full cache + const existingCached = localStorage.getItem(PRICING_CACHE_KEY); + if (existingCached) { + const existingData = JSON.parse(existingCached) as PricingCacheData; + + // Create a map of existing pricing for quick lookup + const pricingMap = new Map(); + existingData.pricing.forEach((p) => { + pricingMap.set(`${p.provider}/${p.model}`, p); + }); + + // Update with new data + data.pricing.forEach((p) => { + pricingMap.set(`${p.provider}/${p.model}`, p); + }); + + // Convert back to array + data = { + pricing: Array.from(pricingMap.values()), + timestamp: data.timestamp, + }; + } + } + + localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(data)); + localStorage.setItem(PRICING_CACHE_TIMESTAMP_KEY, data.timestamp.toString()); + console.log(`Saved ${data.pricing.length} models to localStorage cache`); + } catch (error) { + console.error('Error saving pricing to localStorage:', error); + } +} + +/** + * Fetch pricing data from backend for specific provider/model + */ +async function fetchPricingForModel( + provider: string, + model: string +): Promise { + try { + const apiUrl = getApiUrl('/config/pricing'); + const secretKey = getSecretKey(); + + console.log(`Fetching pricing for ${provider}/${model} from ${apiUrl}`); + + const headers: HeadersInit = { 'Content-Type': 'application/json' }; + if (secretKey) { + headers['X-Secret-Key'] = secretKey; + } + + const response = await fetch(apiUrl, { + method: 'POST', + headers, + body: JSON.stringify({ configured_only: false }), + }); + + if (!response.ok) { + console.error('Failed to fetch pricing data:', response.status); + throw new Error(`API request failed with status ${response.status}`); + } + + const data = await response.json(); + console.log('Pricing response:', data); + + // Find the specific model pricing + const pricing = data.pricing?.find( + (p: { + provider: string; + model: string; + input_token_cost: number; + output_token_cost: number; + currency: string; + }) => { + const providerMatch = p.provider.toLowerCase() === provider.toLowerCase(); + + // More flexible model matching - handle versioned models + let modelMatch = p.model === model; + + // If exact match fails, try matching without version suffix + if (!modelMatch && model.includes('-20')) { + // Remove date suffix like -20241022 + const modelWithoutDate = model.replace(/-20\d{6}$/, ''); + modelMatch = p.model === modelWithoutDate; + + // Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet) + if (!modelMatch) { + const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.'); + modelMatch = p.model === modelWithDots; + } + } + + console.log( + `Comparing: ${p.provider}/${p.model} with ${provider}/${model} - Provider match: ${providerMatch}, Model match: ${modelMatch}` + ); + return providerMatch && modelMatch; + } + ); + + console.log(`Found pricing for ${provider}/${model}:`, pricing); + + if (pricing) { + return { + input_token_cost: pricing.input_token_cost, + output_token_cost: pricing.output_token_cost, + currency: pricing.currency || '$', + }; + } + + console.log( + `No pricing found for ${provider}/${model} in:`, + data.pricing?.map((p: { provider: string; model: string }) => `${p.provider}/${p.model}`) + ); + + // API call succeeded but model not found in pricing data + return null; + } catch (error) { + console.error('Error fetching pricing data:', error); + // Re-throw the error so the caller can distinguish between API failure and model not found + throw error; + } +} + +/** + * Initialize the cost database - only load commonly used models on startup + */ +export async function initializeCostDatabase(): Promise { + try { + // Clean up any existing large caches first + cleanupPricingCache(); + + // First check if we have valid cached data + const cachedData = loadPricingFromLocalStorage(); + if (cachedData && cachedData.pricing.length > 0) { + console.log('Using cached pricing data from localStorage'); + return; + } + + // List of commonly used models to pre-fetch + const commonModels = [ + { provider: 'openai', model: 'gpt-4o' }, + { provider: 'openai', model: 'gpt-4o-mini' }, + { provider: 'openai', model: 'gpt-4-turbo' }, + { provider: 'openai', model: 'gpt-4' }, + { provider: 'openai', model: 'gpt-3.5-turbo' }, + { provider: 'anthropic', model: 'claude-3-5-sonnet' }, + { provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' }, + { provider: 'anthropic', model: 'claude-3-opus' }, + { provider: 'anthropic', model: 'claude-3-sonnet' }, + { provider: 'anthropic', model: 'claude-3-haiku' }, + { provider: 'google', model: 'gemini-1.5-pro' }, + { provider: 'google', model: 'gemini-1.5-flash' }, + { provider: 'deepseek', model: 'deepseek-chat' }, + { provider: 'deepseek', model: 'deepseek-reasoner' }, + { provider: 'meta-llama', model: 'llama-3.2-90b-text-preview' }, + { provider: 'meta-llama', model: 'llama-3.1-405b-instruct' }, + ]; + + // Get recently used models + const recentModels = getRecentlyUsedModels(); + + // Combine common and recent models (deduplicated) + const modelsToFetch = new Map(); + + // Add common models + commonModels.forEach((m) => { + modelsToFetch.set(`${m.provider}/${m.model}`, m); + }); + + // Add recent models + recentModels.forEach((m) => { + modelsToFetch.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model }); + }); + + console.log(`Initializing cost database with ${modelsToFetch.size} models...`); + + // Fetch only the pricing we need + const apiUrl = getApiUrl('/config/pricing'); + const secretKey = getSecretKey(); + + const headers: HeadersInit = { 'Content-Type': 'application/json' }; + if (secretKey) { + headers['X-Secret-Key'] = secretKey; + } + + const response = await fetch(apiUrl, { + method: 'POST', + headers, + body: JSON.stringify({ + configured_only: false, + models: Array.from(modelsToFetch.values()), // Send specific models if API supports it + }), + }); + + if (!response.ok) { + console.error('Failed to fetch initial pricing data:', response.status); + return; + } + + const data = await response.json(); + console.log(`Fetched pricing for ${data.pricing?.length || 0} models`); + + if (data.pricing && data.pricing.length > 0) { + // Filter to only the models we requested (in case API returns all) + const filteredPricing = data.pricing.filter((p: PricingItem) => + modelsToFetch.has(`${p.provider}/${p.model}`) + ); + + // Save to localStorage + const cacheData: PricingCacheData = { + pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50), // Fallback to first 50 if filtering didn't work + timestamp: Date.now(), + }; + savePricingToLocalStorage(cacheData, false); // Don't merge on initial load + } + } catch (error) { + console.error('Error initializing cost database:', error); + } +} + +/** + * Update model costs from providers - no longer needed + */ +export async function updateAllModelCosts(): Promise { + // No-op - we fetch on demand now +} + +/** + * Get cost information for a specific model with caching + */ +export function getCostForModel(provider: string, model: string): ModelCostInfo | null { + // Track this model as recently used + addToRecentlyUsed(provider, model); + + // Check if it's the same model we already have cached in memory + if ( + currentModelPricing && + currentModelPricing.provider === provider && + currentModelPricing.model === model + ) { + return currentModelPricing.costInfo; + } + + // For local/free providers, return zero cost immediately + const freeProviders = ['ollama', 'local', 'localhost']; + if (freeProviders.includes(provider.toLowerCase())) { + const zeroCost = { + input_token_cost: 0, + output_token_cost: 0, + currency: '$', + }; + currentModelPricing = { provider, model, costInfo: zeroCost }; + return zeroCost; + } + + // Check localStorage cache (which now only contains recently used models) + const cachedData = loadPricingFromLocalStorage(); + if (cachedData) { + const pricing = cachedData.pricing.find((p) => { + const providerMatch = p.provider.toLowerCase() === provider.toLowerCase(); + + // More flexible model matching - handle versioned models + let modelMatch = p.model === model; + + // If exact match fails, try matching without version suffix + if (!modelMatch && model.includes('-20')) { + // Remove date suffix like -20241022 + const modelWithoutDate = model.replace(/-20\d{6}$/, ''); + modelMatch = p.model === modelWithoutDate; + + // Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet) + if (!modelMatch) { + const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.'); + modelMatch = p.model === modelWithDots; + } + } + + return providerMatch && modelMatch; + }); + + if (pricing) { + const costInfo = { + input_token_cost: pricing.input_token_cost, + output_token_cost: pricing.output_token_cost, + currency: pricing.currency || '$', + }; + currentModelPricing = { provider, model, costInfo }; + return costInfo; + } + } + + // Need to fetch new pricing - return null for now + // The component will handle the async fetch + return null; +} + +/** + * Fetch and cache pricing for a model + */ +export async function fetchAndCachePricing( + provider: string, + model: string +): Promise<{ costInfo: ModelCostInfo | null; error?: string } | null> { + try { + const costInfo = await fetchPricingForModel(provider, model); + + if (costInfo) { + // Cache the result in memory + currentModelPricing = { provider, model, costInfo }; + + // Update localStorage cache with this new data + const cachedData = loadPricingFromLocalStorage(); + if (cachedData) { + // Check if this model already exists in cache + const existingIndex = cachedData.pricing.findIndex( + (p) => p.provider.toLowerCase() === provider.toLowerCase() && p.model === model + ); + + const newPricing = { + provider, + model, + input_token_cost: costInfo.input_token_cost, + output_token_cost: costInfo.output_token_cost, + currency: costInfo.currency, + }; + + if (existingIndex >= 0) { + // Update existing + cachedData.pricing[existingIndex] = newPricing; + } else { + // Add new + cachedData.pricing.push(newPricing); + } + + // Save updated cache + savePricingToLocalStorage(cachedData); + } + + return { costInfo }; + } else { + // Cache the null result in memory + currentModelPricing = { provider, model, costInfo: null }; + + // Check if the API call succeeded but model wasn't found + // We can determine this by checking if we got a response but no matching model + return { costInfo: null, error: 'model_not_found' }; + } + } catch (error) { + console.error('Error in fetchAndCachePricing:', error); + // This is a real API/network error + return null; + } +} + +/** + * Refresh pricing data from backend - only refresh recently used models + */ +export async function refreshPricing(): Promise { + try { + const apiUrl = getApiUrl('/config/pricing'); + const secretKey = getSecretKey(); + + const headers: HeadersInit = { 'Content-Type': 'application/json' }; + if (secretKey) { + headers['X-Secret-Key'] = secretKey; + } + + // Get recently used models to refresh + const recentModels = getRecentlyUsedModels(); + + // Add some common models as well + const commonModels = [ + { provider: 'openai', model: 'gpt-4o' }, + { provider: 'openai', model: 'gpt-4o-mini' }, + { provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' }, + { provider: 'google', model: 'gemini-1.5-pro' }, + ]; + + // Combine and deduplicate + const modelsToRefresh = new Map(); + + commonModels.forEach((m) => { + modelsToRefresh.set(`${m.provider}/${m.model}`, m); + }); + + recentModels.forEach((m) => { + modelsToRefresh.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model }); + }); + + console.log(`Refreshing pricing for ${modelsToRefresh.size} models...`); + + const response = await fetch(apiUrl, { + method: 'POST', + headers, + body: JSON.stringify({ + configured_only: false, + models: Array.from(modelsToRefresh.values()), // Send specific models if API supports it + }), + }); + + if (response.ok) { + const data = await response.json(); + + if (data.pricing && data.pricing.length > 0) { + // Filter to only the models we requested (in case API returns all) + const filteredPricing = data.pricing.filter((p: PricingItem) => + modelsToRefresh.has(`${p.provider}/${p.model}`) + ); + + // Save fresh data to localStorage (merge with existing) + const cacheData: PricingCacheData = { + pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50), + timestamp: Date.now(), + }; + savePricingToLocalStorage(cacheData, true); // Merge with existing + } + + // Clear current memory cache to force re-fetch + currentModelPricing = null; + return true; + } + + return false; + } catch (error) { + console.error('Error refreshing pricing data:', error); + return false; + } +} + +/** + * Clean up old/unused models from the cache + */ +export function cleanupPricingCache(): void { + try { + const recentModels = getRecentlyUsedModels(); + const cachedData = localStorage.getItem(PRICING_CACHE_KEY); + + if (!cachedData) return; + + const fullCache = JSON.parse(cachedData) as PricingCacheData; + const recentModelKeys = new Set(recentModels.map((m) => `${m.provider}/${m.model}`)); + + // Keep only recently used models and common models + const commonModelKeys = new Set([ + 'openai/gpt-4o', + 'openai/gpt-4o-mini', + 'openai/gpt-4-turbo', + 'anthropic/claude-3-5-sonnet', + 'anthropic/claude-3-5-sonnet-20241022', + 'google/gemini-1.5-pro', + 'google/gemini-1.5-flash', + ]); + + const filteredPricing = fullCache.pricing.filter((p) => { + const key = `${p.provider}/${p.model}`; + return recentModelKeys.has(key) || commonModelKeys.has(key); + }); + + if (filteredPricing.length < fullCache.pricing.length) { + console.log( + `Cleaned up pricing cache: reduced from ${fullCache.pricing.length} to ${filteredPricing.length} models` + ); + + const cleanedCache: PricingCacheData = { + pricing: filteredPricing, + timestamp: fullCache.timestamp, + }; + + localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(cleanedCache)); + } + } catch (error) { + console.error('Error cleaning up pricing cache:', error); + } +} diff --git a/ui/desktop/tailwind.config.ts b/ui/desktop/tailwind.config.ts index 1a62cc50..5b0a92d1 100644 --- a/ui/desktop/tailwind.config.ts +++ b/ui/desktop/tailwind.config.ts @@ -44,6 +44,10 @@ export default { '0%': { transform: 'rotate(0deg)' }, '100%': { transform: 'rotate(360deg)' }, }, + 'spin-fast': { + '0%': { transform: 'rotate(0deg)' }, + '100%': { transform: 'rotate(360deg)' }, + }, indeterminate: { '0%': { left: '-40%', width: '40%' }, '50%': { left: '20%', width: '60%' }, @@ -54,6 +58,7 @@ export default { 'shimmer-pulse': 'shimmer 4s ease-in-out infinite', 'gradient-loader': 'loader 750ms ease-in-out infinite', indeterminate: 'indeterminate 1.5s infinite linear', + 'spin-fast': 'spin-fast 0.5s linear infinite', }, colors: { bgApp: 'var(--background-app)',