mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-21 16:14:21 +01:00
feat: Add comprehensive cost tracking display for LLM usage (#2992)
Co-authored-by: jack <jack@deck.local> Co-authored-by: Bradley Axen <baxen@squareup.com>
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3427,6 +3427,7 @@ dependencies = [
|
||||
"chrono",
|
||||
"criterion",
|
||||
"ctor",
|
||||
"dirs 5.0.1",
|
||||
"dotenv",
|
||||
"etcetera",
|
||||
"fs2",
|
||||
|
||||
@@ -10,12 +10,23 @@ use goose::scheduler_factory::SchedulerFactory;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tracing::info;
|
||||
|
||||
use goose::providers::pricing::initialize_pricing_cache;
|
||||
|
||||
pub async fn run() -> Result<()> {
|
||||
// Initialize logging
|
||||
crate::logging::setup_logging(Some("goosed"))?;
|
||||
|
||||
let settings = configuration::Settings::new()?;
|
||||
|
||||
// Initialize pricing cache on startup
|
||||
tracing::info!("Initializing pricing cache...");
|
||||
if let Err(e) = initialize_pricing_cache().await {
|
||||
tracing::warn!(
|
||||
"Failed to initialize pricing cache: {}. Pricing data may not be available.",
|
||||
e
|
||||
);
|
||||
}
|
||||
|
||||
let secret_key =
|
||||
std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ use goose::config::{extensions::name_to_key, PermissionManager};
|
||||
use goose::config::{ExtensionConfigManager, ExtensionEntry};
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::base::ProviderMetadata;
|
||||
use goose::providers::pricing::{get_all_pricing, get_model_pricing, refresh_pricing};
|
||||
use goose::providers::providers as get_providers;
|
||||
use goose::{agents::ExtensionConfig, config::permission::PermissionLevel};
|
||||
use http::{HeaderMap, StatusCode};
|
||||
@@ -314,6 +315,128 @@ pub async fn providers(
|
||||
Ok(Json(providers_response))
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub struct PricingData {
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
pub input_token_cost: f64,
|
||||
pub output_token_cost: f64,
|
||||
pub currency: String,
|
||||
pub context_length: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
pub struct PricingResponse {
|
||||
pub pricing: Vec<PricingData>,
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, ToSchema)]
|
||||
pub struct PricingQuery {
|
||||
/// If true, only return pricing for configured providers. If false, return all.
|
||||
pub configured_only: Option<bool>,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/config/pricing",
|
||||
request_body = PricingQuery,
|
||||
responses(
|
||||
(status = 200, description = "Model pricing data retrieved successfully", body = PricingResponse)
|
||||
)
|
||||
)]
|
||||
pub async fn get_pricing(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Json(query): Json<PricingQuery>,
|
||||
) -> Result<Json<PricingResponse>, StatusCode> {
|
||||
verify_secret_key(&headers, &state)?;
|
||||
|
||||
let configured_only = query.configured_only.unwrap_or(true);
|
||||
|
||||
// If refresh requested (configured_only = false), refresh the cache
|
||||
if !configured_only {
|
||||
if let Err(e) = refresh_pricing().await {
|
||||
tracing::error!("Failed to refresh pricing data: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
let mut pricing_data = Vec::new();
|
||||
|
||||
if !configured_only {
|
||||
// Get ALL pricing data from the cache
|
||||
let all_pricing = get_all_pricing().await;
|
||||
|
||||
for (provider, models) in all_pricing {
|
||||
for (model, pricing) in models {
|
||||
pricing_data.push(PricingData {
|
||||
provider: provider.clone(),
|
||||
model: model.clone(),
|
||||
input_token_cost: pricing.input_cost,
|
||||
output_token_cost: pricing.output_cost,
|
||||
currency: "$".to_string(),
|
||||
context_length: pricing.context_length,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Get only configured providers' pricing
|
||||
let providers_metadata = get_providers();
|
||||
|
||||
for metadata in providers_metadata {
|
||||
// Skip unconfigured providers if filtering
|
||||
if !check_provider_configured(&metadata) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for model_info in &metadata.known_models {
|
||||
// Try to get pricing from cache
|
||||
if let Some(pricing) = get_model_pricing(&metadata.name, &model_info.name).await {
|
||||
pricing_data.push(PricingData {
|
||||
provider: metadata.name.clone(),
|
||||
model: model_info.name.clone(),
|
||||
input_token_cost: pricing.input_cost,
|
||||
output_token_cost: pricing.output_cost,
|
||||
currency: "$".to_string(),
|
||||
context_length: pricing.context_length,
|
||||
});
|
||||
}
|
||||
// Check if the model has embedded pricing data
|
||||
else if let (Some(input_cost), Some(output_cost)) =
|
||||
(model_info.input_token_cost, model_info.output_token_cost)
|
||||
{
|
||||
pricing_data.push(PricingData {
|
||||
provider: metadata.name.clone(),
|
||||
model: model_info.name.clone(),
|
||||
input_token_cost: input_cost,
|
||||
output_token_cost: output_cost,
|
||||
currency: model_info
|
||||
.currency
|
||||
.clone()
|
||||
.unwrap_or_else(|| "$".to_string()),
|
||||
context_length: Some(model_info.context_limit as u32),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"Returning pricing for {} models{}",
|
||||
pricing_data.len(),
|
||||
if configured_only {
|
||||
" (configured providers only)"
|
||||
} else {
|
||||
" (all cached models)"
|
||||
}
|
||||
);
|
||||
|
||||
Ok(Json(PricingResponse {
|
||||
pricing: pricing_data,
|
||||
source: "openrouter".to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/config/init",
|
||||
@@ -471,6 +594,7 @@ pub fn routes(state: Arc<AppState>) -> Router {
|
||||
.route("/config/extensions", post(add_extension))
|
||||
.route("/config/extensions/{name}", delete(remove_extension))
|
||||
.route("/config/providers", get(providers))
|
||||
.route("/config/pricing", post(get_pricing))
|
||||
.route("/config/init", post(init_config))
|
||||
.route("/config/backup", post(backup_config))
|
||||
.route("/config/permissions", post(upsert_permissions))
|
||||
|
||||
@@ -17,6 +17,7 @@ mcp-core = { path = "../mcp-core" }
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
futures = "0.3"
|
||||
dirs = "5.0"
|
||||
reqwest = { version = "0.12.9", features = [
|
||||
"rustls-tls-native-roots",
|
||||
"json",
|
||||
|
||||
@@ -5,7 +5,7 @@ use reqwest::{Client, StatusCode};
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
|
||||
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
|
||||
use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
|
||||
use super::errors::ProviderError;
|
||||
use super::formats::anthropic::{create_request, get_usage, response_to_message};
|
||||
use super::utils::{emit_debug_trace, get_model};
|
||||
@@ -122,12 +122,18 @@ impl AnthropicProvider {
|
||||
#[async_trait]
|
||||
impl Provider for AnthropicProvider {
|
||||
fn metadata() -> ProviderMetadata {
|
||||
ProviderMetadata::new(
|
||||
ProviderMetadata::with_models(
|
||||
"anthropic",
|
||||
"Anthropic",
|
||||
"Claude and other models from Anthropic",
|
||||
ANTHROPIC_DEFAULT_MODEL,
|
||||
ANTHROPIC_KNOWN_MODELS.to_vec(),
|
||||
vec![
|
||||
ModelInfo::with_cost("claude-3-5-sonnet-20241022", 200000, 0.000003, 0.000015),
|
||||
ModelInfo::with_cost("claude-3-5-haiku-20241022", 200000, 0.000001, 0.000005),
|
||||
ModelInfo::with_cost("claude-3-opus-20240229", 200000, 0.000015, 0.000075),
|
||||
ModelInfo::with_cost("claude-3-sonnet-20240229", 200000, 0.000003, 0.000015),
|
||||
ModelInfo::with_cost("claude-3-haiku-20240307", 200000, 0.00000025, 0.00000125),
|
||||
],
|
||||
ANTHROPIC_DOC_URL,
|
||||
vec![
|
||||
ConfigKey::new("ANTHROPIC_API_KEY", true, true, None),
|
||||
|
||||
@@ -32,6 +32,41 @@ pub struct ModelInfo {
|
||||
pub name: String,
|
||||
/// The maximum context length this model supports
|
||||
pub context_limit: usize,
|
||||
/// Cost per token for input (optional)
|
||||
pub input_token_cost: Option<f64>,
|
||||
/// Cost per token for output (optional)
|
||||
pub output_token_cost: Option<f64>,
|
||||
/// Currency for the costs (default: "$")
|
||||
pub currency: Option<String>,
|
||||
}
|
||||
|
||||
impl ModelInfo {
|
||||
/// Create a new ModelInfo with just name and context limit
|
||||
pub fn new(name: impl Into<String>, context_limit: usize) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
context_limit,
|
||||
input_token_cost: None,
|
||||
output_token_cost: None,
|
||||
currency: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new ModelInfo with cost information (per token)
|
||||
pub fn with_cost(
|
||||
name: impl Into<String>,
|
||||
context_limit: usize,
|
||||
input_cost: f64,
|
||||
output_cost: f64,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
context_limit,
|
||||
input_token_cost: Some(input_cost),
|
||||
output_token_cost: Some(output_cost),
|
||||
currency: Some("$".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata about a provider's configuration requirements and capabilities
|
||||
@@ -74,6 +109,9 @@ impl ProviderMetadata {
|
||||
.map(|&name| ModelInfo {
|
||||
name: name.to_string(),
|
||||
context_limit: ModelConfig::new(name.to_string()).context_limit(),
|
||||
input_token_cost: None,
|
||||
output_token_cost: None,
|
||||
currency: None,
|
||||
})
|
||||
.collect(),
|
||||
model_doc_link: model_doc_link.to_string(),
|
||||
@@ -81,6 +119,27 @@ impl ProviderMetadata {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new ProviderMetadata with ModelInfo objects that include cost data
|
||||
pub fn with_models(
|
||||
name: &str,
|
||||
display_name: &str,
|
||||
description: &str,
|
||||
default_model: &str,
|
||||
models: Vec<ModelInfo>,
|
||||
model_doc_link: &str,
|
||||
config_keys: Vec<ConfigKey>,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
display_name: display_name.to_string(),
|
||||
description: description.to_string(),
|
||||
default_model: default_model.to_string(),
|
||||
known_models: models,
|
||||
model_doc_link: model_doc_link.to_string(),
|
||||
config_keys,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
name: "".to_string(),
|
||||
@@ -313,6 +372,9 @@ mod tests {
|
||||
let info = ModelInfo {
|
||||
name: "test-model".to_string(),
|
||||
context_limit: 1000,
|
||||
input_token_cost: None,
|
||||
output_token_cost: None,
|
||||
currency: None,
|
||||
};
|
||||
assert_eq!(info.context_limit, 1000);
|
||||
|
||||
@@ -320,6 +382,9 @@ mod tests {
|
||||
let info2 = ModelInfo {
|
||||
name: "test-model".to_string(),
|
||||
context_limit: 1000,
|
||||
input_token_cost: None,
|
||||
output_token_cost: None,
|
||||
currency: None,
|
||||
};
|
||||
assert_eq!(info, info2);
|
||||
|
||||
@@ -327,7 +392,20 @@ mod tests {
|
||||
let info3 = ModelInfo {
|
||||
name: "test-model".to_string(),
|
||||
context_limit: 2000,
|
||||
input_token_cost: None,
|
||||
output_token_cost: None,
|
||||
currency: None,
|
||||
};
|
||||
assert_ne!(info, info3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_info_with_cost() {
|
||||
let info = ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001);
|
||||
assert_eq!(info.name, "gpt-4o");
|
||||
assert_eq!(info.context_limit, 128000);
|
||||
assert_eq!(info.input_token_cost, Some(0.0000025));
|
||||
assert_eq!(info.output_token_cost, Some(0.00001));
|
||||
assert_eq!(info.currency, Some("$".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ pub mod oauth;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
pub mod openrouter;
|
||||
pub mod pricing;
|
||||
pub mod sagemaker_tgi;
|
||||
pub mod snowflake;
|
||||
pub mod toolshim;
|
||||
|
||||
@@ -5,7 +5,7 @@ use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
|
||||
use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage};
|
||||
use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse};
|
||||
use super::errors::ProviderError;
|
||||
use super::formats::openai::{create_request, get_usage, response_to_message};
|
||||
@@ -126,12 +126,20 @@ impl OpenAiProvider {
|
||||
#[async_trait]
|
||||
impl Provider for OpenAiProvider {
|
||||
fn metadata() -> ProviderMetadata {
|
||||
ProviderMetadata::new(
|
||||
ProviderMetadata::with_models(
|
||||
"openai",
|
||||
"OpenAI",
|
||||
"GPT-4 and other OpenAI models, including OpenAI compatible ones",
|
||||
OPEN_AI_DEFAULT_MODEL,
|
||||
OPEN_AI_KNOWN_MODELS.to_vec(),
|
||||
vec![
|
||||
ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001),
|
||||
ModelInfo::with_cost("gpt-4o-mini", 128000, 0.00000015, 0.0000006),
|
||||
ModelInfo::with_cost("gpt-4-turbo", 128000, 0.00001, 0.00003),
|
||||
ModelInfo::with_cost("gpt-3.5-turbo", 16385, 0.0000005, 0.0000015),
|
||||
ModelInfo::with_cost("o1", 200000, 0.000015, 0.00006),
|
||||
ModelInfo::with_cost("o3", 200000, 0.000015, 0.00006), // Using o1 pricing as placeholder
|
||||
ModelInfo::with_cost("o4-mini", 128000, 0.000003, 0.000012), // Using o1-mini pricing as placeholder
|
||||
],
|
||||
OPEN_AI_DOC_URL,
|
||||
vec![
|
||||
ConfigKey::new("OPENAI_API_KEY", true, true, None),
|
||||
|
||||
387
crates/goose/src/providers/pricing.rs
Normal file
387
crates/goose/src/providers/pricing.rs
Normal file
@@ -0,0 +1,387 @@
|
||||
use anyhow::Result;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// Disk cache configuration
|
||||
const CACHE_FILE_NAME: &str = "pricing_cache.json";
|
||||
const CACHE_TTL_DAYS: u64 = 7; // Cache for 7 days
|
||||
|
||||
/// Get the cache directory path
|
||||
fn get_cache_dir() -> Result<PathBuf> {
|
||||
let cache_dir = if let Ok(goose_dir) = std::env::var("GOOSE_CACHE_DIR") {
|
||||
PathBuf::from(goose_dir)
|
||||
} else {
|
||||
dirs::cache_dir()
|
||||
.ok_or_else(|| anyhow::anyhow!("Could not determine cache directory"))?
|
||||
.join("goose")
|
||||
};
|
||||
Ok(cache_dir)
|
||||
}
|
||||
|
||||
/// Cached pricing data structure for disk storage
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CachedPricingData {
|
||||
/// Nested HashMap: provider -> model -> pricing info
|
||||
pub pricing: HashMap<String, HashMap<String, PricingInfo>>,
|
||||
/// Unix timestamp when data was fetched
|
||||
pub fetched_at: u64,
|
||||
}
|
||||
|
||||
/// Simplified pricing info for efficient storage
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PricingInfo {
|
||||
pub input_cost: f64, // Cost per token
|
||||
pub output_cost: f64, // Cost per token
|
||||
pub context_length: Option<u32>,
|
||||
}
|
||||
|
||||
/// Cache for OpenRouter pricing data with disk persistence
|
||||
pub struct PricingCache {
|
||||
/// In-memory cache
|
||||
memory_cache: Arc<RwLock<Option<CachedPricingData>>>,
|
||||
}
|
||||
|
||||
impl PricingCache {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
memory_cache: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load pricing from disk cache
|
||||
async fn load_from_disk(&self) -> Result<Option<CachedPricingData>> {
|
||||
let cache_path = get_cache_dir()?.join(CACHE_FILE_NAME);
|
||||
|
||||
if !cache_path.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
match tokio::fs::read(&cache_path).await {
|
||||
Ok(data) => {
|
||||
match serde_json::from_slice::<CachedPricingData>(&data) {
|
||||
Ok(cached) => {
|
||||
// Check if cache is still valid
|
||||
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
|
||||
let age_days = (now - cached.fetched_at) / (24 * 60 * 60);
|
||||
|
||||
if age_days < CACHE_TTL_DAYS {
|
||||
tracing::info!(
|
||||
"Loaded pricing data from disk cache (age: {} days)",
|
||||
age_days
|
||||
);
|
||||
Ok(Some(cached))
|
||||
} else {
|
||||
tracing::info!("Disk cache expired (age: {} days)", age_days);
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse pricing cache: {}", e);
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to read pricing cache: {}", e);
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Save pricing data to disk
|
||||
async fn save_to_disk(&self, data: &CachedPricingData) -> Result<()> {
|
||||
let cache_dir = get_cache_dir()?;
|
||||
tokio::fs::create_dir_all(&cache_dir).await?;
|
||||
|
||||
let cache_path = cache_dir.join(CACHE_FILE_NAME);
|
||||
let json_data = serde_json::to_vec_pretty(data)?;
|
||||
tokio::fs::write(&cache_path, json_data).await?;
|
||||
|
||||
tracing::info!("Saved pricing data to disk cache");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get pricing for a specific model
|
||||
pub async fn get_model_pricing(&self, provider: &str, model: &str) -> Option<PricingInfo> {
|
||||
// Try memory cache first
|
||||
{
|
||||
let cache = self.memory_cache.read().await;
|
||||
if let Some(cached) = &*cache {
|
||||
return cached
|
||||
.pricing
|
||||
.get(&provider.to_lowercase())
|
||||
.and_then(|models| models.get(model))
|
||||
.cloned();
|
||||
}
|
||||
}
|
||||
|
||||
// Try loading from disk
|
||||
if let Ok(Some(disk_cache)) = self.load_from_disk().await {
|
||||
// Update memory cache
|
||||
{
|
||||
let mut cache = self.memory_cache.write().await;
|
||||
*cache = Some(disk_cache.clone());
|
||||
}
|
||||
|
||||
return disk_cache
|
||||
.pricing
|
||||
.get(&provider.to_lowercase())
|
||||
.and_then(|models| models.get(model))
|
||||
.cloned();
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Force refresh pricing data from OpenRouter
|
||||
pub async fn refresh(&self) -> Result<()> {
|
||||
let pricing = fetch_openrouter_pricing_internal().await?;
|
||||
|
||||
// Convert to our efficient structure
|
||||
let mut structured_pricing: HashMap<String, HashMap<String, PricingInfo>> = HashMap::new();
|
||||
|
||||
for (model_id, model) in pricing {
|
||||
if let Some((provider, model_name)) = parse_model_id(&model_id) {
|
||||
if let (Some(input_cost), Some(output_cost)) = (
|
||||
convert_pricing(&model.pricing.prompt),
|
||||
convert_pricing(&model.pricing.completion),
|
||||
) {
|
||||
let provider_lower = provider.to_lowercase();
|
||||
let provider_models = structured_pricing.entry(provider_lower).or_default();
|
||||
|
||||
provider_models.insert(
|
||||
model_name,
|
||||
PricingInfo {
|
||||
input_cost,
|
||||
output_cost,
|
||||
context_length: model.context_length,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let cached_data = CachedPricingData {
|
||||
pricing: structured_pricing,
|
||||
fetched_at: SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(),
|
||||
};
|
||||
|
||||
// Log how many models we fetched
|
||||
let total_models: usize = cached_data
|
||||
.pricing
|
||||
.values()
|
||||
.map(|models| models.len())
|
||||
.sum();
|
||||
tracing::info!(
|
||||
"Fetched pricing for {} providers with {} total models from OpenRouter",
|
||||
cached_data.pricing.len(),
|
||||
total_models
|
||||
);
|
||||
|
||||
// Save to disk
|
||||
self.save_to_disk(&cached_data).await?;
|
||||
|
||||
// Update memory cache
|
||||
{
|
||||
let mut cache = self.memory_cache.write().await;
|
||||
*cache = Some(cached_data);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize cache (load from disk or fetch if needed)
|
||||
pub async fn initialize(&self) -> Result<()> {
|
||||
// Try loading from disk first
|
||||
if let Ok(Some(cached)) = self.load_from_disk().await {
|
||||
// Log how many models we have cached
|
||||
let total_models: usize = cached.pricing.values().map(|models| models.len()).sum();
|
||||
tracing::info!(
|
||||
"Loaded {} providers with {} total models from disk cache",
|
||||
cached.pricing.len(),
|
||||
total_models
|
||||
);
|
||||
|
||||
// Update memory cache
|
||||
{
|
||||
let mut cache = self.memory_cache.write().await;
|
||||
*cache = Some(cached);
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// If no disk cache, fetch from OpenRouter
|
||||
tracing::info!("No valid disk cache found, fetching from OpenRouter");
|
||||
self.refresh().await
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PricingCache {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// Global cache instance
|
||||
lazy_static::lazy_static! {
|
||||
static ref PRICING_CACHE: PricingCache = PricingCache::new();
|
||||
static ref HTTP_CLIENT: Client = Client::builder()
|
||||
.timeout(Duration::from_secs(30))
|
||||
.pool_idle_timeout(Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(10)
|
||||
.build()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// OpenRouter model pricing information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenRouterModel {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub pricing: OpenRouterPricing,
|
||||
pub context_length: Option<u32>,
|
||||
pub architecture: Option<Architecture>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenRouterPricing {
|
||||
pub prompt: String, // Cost per token for input (in USD)
|
||||
pub completion: String, // Cost per token for output (in USD)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Architecture {
|
||||
pub modality: String,
|
||||
pub tokenizer: String,
|
||||
pub instruct_type: Option<String>,
|
||||
}
|
||||
|
||||
/// Response from OpenRouter models endpoint
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OpenRouterModelsResponse {
|
||||
pub data: Vec<OpenRouterModel>,
|
||||
}
|
||||
|
||||
/// Internal function to fetch pricing data
|
||||
async fn fetch_openrouter_pricing_internal() -> Result<HashMap<String, OpenRouterModel>> {
|
||||
let response = HTTP_CLIENT
|
||||
.get("https://openrouter.ai/api/v1/models")
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!(
|
||||
"Failed to fetch OpenRouter models: HTTP {}",
|
||||
response.status()
|
||||
);
|
||||
}
|
||||
|
||||
let models_response: OpenRouterModelsResponse = response.json().await?;
|
||||
|
||||
// Create a map for easy lookup
|
||||
let mut pricing_map = HashMap::new();
|
||||
for model in models_response.data {
|
||||
pricing_map.insert(model.id.clone(), model);
|
||||
}
|
||||
|
||||
Ok(pricing_map)
|
||||
}
|
||||
|
||||
/// Initialize pricing cache on startup
|
||||
pub async fn initialize_pricing_cache() -> Result<()> {
|
||||
PRICING_CACHE.initialize().await
|
||||
}
|
||||
|
||||
/// Get pricing for a specific model
|
||||
pub async fn get_model_pricing(provider: &str, model: &str) -> Option<PricingInfo> {
|
||||
PRICING_CACHE.get_model_pricing(provider, model).await
|
||||
}
|
||||
|
||||
/// Force refresh pricing data
|
||||
pub async fn refresh_pricing() -> Result<()> {
|
||||
PRICING_CACHE.refresh().await
|
||||
}
|
||||
|
||||
/// Get all cached pricing data
|
||||
pub async fn get_all_pricing() -> HashMap<String, HashMap<String, PricingInfo>> {
|
||||
let cache = PRICING_CACHE.memory_cache.read().await;
|
||||
if let Some(cached) = &*cache {
|
||||
cached.pricing.clone()
|
||||
} else {
|
||||
// Try loading from disk
|
||||
if let Ok(Some(disk_cache)) = PRICING_CACHE.load_from_disk().await {
|
||||
// Update memory cache
|
||||
drop(cache);
|
||||
let mut write_cache = PRICING_CACHE.memory_cache.write().await;
|
||||
*write_cache = Some(disk_cache.clone());
|
||||
disk_cache.pricing
|
||||
} else {
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert OpenRouter model ID to provider/model format
|
||||
/// e.g., "anthropic/claude-3.5-sonnet" -> ("anthropic", "claude-3.5-sonnet")
|
||||
pub fn parse_model_id(model_id: &str) -> Option<(String, String)> {
|
||||
let parts: Vec<&str> = model_id.splitn(2, '/').collect();
|
||||
if parts.len() == 2 {
|
||||
// Normalize provider names to match our internal naming
|
||||
let provider = match parts[0] {
|
||||
"openai" => "openai",
|
||||
"anthropic" => "anthropic",
|
||||
"google" => "google",
|
||||
"meta-llama" => "ollama", // Meta models often run via Ollama
|
||||
"mistralai" => "mistral",
|
||||
"cohere" => "cohere",
|
||||
"perplexity" => "perplexity",
|
||||
"deepseek" => "deepseek",
|
||||
"groq" => "groq",
|
||||
"nvidia" => "nvidia",
|
||||
"microsoft" => "azure",
|
||||
"replicate" => "replicate",
|
||||
"huggingface" => "huggingface",
|
||||
_ => parts[0],
|
||||
};
|
||||
Some((provider.to_string(), parts[1].to_string()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert OpenRouter pricing to cost per token (already in that format)
|
||||
pub fn convert_pricing(price_str: &str) -> Option<f64> {
|
||||
// OpenRouter prices are already in USD per token
|
||||
price_str.parse::<f64>().ok()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_model_id() {
|
||||
assert_eq!(
|
||||
parse_model_id("anthropic/claude-3.5-sonnet"),
|
||||
Some(("anthropic".to_string(), "claude-3.5-sonnet".to_string()))
|
||||
);
|
||||
assert_eq!(
|
||||
parse_model_id("openai/gpt-4"),
|
||||
Some(("openai".to_string(), "gpt-4".to_string()))
|
||||
);
|
||||
assert_eq!(parse_model_id("invalid-format"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_pricing() {
|
||||
assert_eq!(convert_pricing("0.000003"), Some(0.000003));
|
||||
assert_eq!(convert_pricing("0.015"), Some(0.015));
|
||||
assert_eq!(convert_pricing("invalid"), None);
|
||||
}
|
||||
}
|
||||
@@ -1700,9 +1700,26 @@
|
||||
"description": "The maximum context length this model supports",
|
||||
"minimum": 0
|
||||
},
|
||||
"currency": {
|
||||
"type": "string",
|
||||
"description": "Currency for the costs (default: \"$\")",
|
||||
"nullable": true
|
||||
},
|
||||
"input_token_cost": {
|
||||
"type": "number",
|
||||
"format": "double",
|
||||
"description": "Cost per token for input (optional)",
|
||||
"nullable": true
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "The name of the model"
|
||||
},
|
||||
"output_token_cost": {
|
||||
"type": "number",
|
||||
"format": "double",
|
||||
"description": "Cost per token for output (optional)",
|
||||
"nullable": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -3,6 +3,7 @@ import { IpcRendererEvent } from 'electron';
|
||||
import { openSharedSessionFromDeepLink, type SessionLinksViewOptions } from './sessionLinks';
|
||||
import { type SharedSessionDetails } from './sharedSessions';
|
||||
import { initializeSystem } from './utils/providerUtils';
|
||||
import { initializeCostDatabase } from './utils/costDatabase';
|
||||
import { ErrorUI } from './components/ErrorBoundary';
|
||||
import { ConfirmationModal } from './components/ui/ConfirmationModal';
|
||||
import { ToastContainer } from 'react-toastify';
|
||||
@@ -158,6 +159,11 @@ export default function App() {
|
||||
|
||||
const initializeApp = async () => {
|
||||
try {
|
||||
// Initialize cost database early to pre-load pricing data
|
||||
initializeCostDatabase().catch((error) => {
|
||||
console.error('Failed to initialize cost database:', error);
|
||||
});
|
||||
|
||||
await initConfig();
|
||||
try {
|
||||
await readAllConfig({ throwOnError: true });
|
||||
|
||||
@@ -229,10 +229,22 @@ export type ModelInfo = {
|
||||
* The maximum context length this model supports
|
||||
*/
|
||||
context_limit: number;
|
||||
/**
|
||||
* Currency for the costs (default: "$")
|
||||
*/
|
||||
currency?: string | null;
|
||||
/**
|
||||
* Cost per token for input (optional)
|
||||
*/
|
||||
input_token_cost?: number | null;
|
||||
/**
|
||||
* The name of the model
|
||||
*/
|
||||
name: string;
|
||||
/**
|
||||
* Cost per token for output (optional)
|
||||
*/
|
||||
output_token_cost?: number | null;
|
||||
};
|
||||
|
||||
export type PermissionConfirmationRequest = {
|
||||
|
||||
@@ -29,9 +29,18 @@ interface ChatInputProps {
|
||||
droppedFiles?: string[];
|
||||
setView: (view: View) => void;
|
||||
numTokens?: number;
|
||||
inputTokens?: number;
|
||||
outputTokens?: number;
|
||||
hasMessages?: boolean;
|
||||
messages?: Message[];
|
||||
setMessages: (messages: Message[]) => void;
|
||||
sessionCosts?: {
|
||||
[key: string]: {
|
||||
inputTokens: number;
|
||||
outputTokens: number;
|
||||
totalCost: number;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
export default function ChatInput({
|
||||
@@ -42,9 +51,12 @@ export default function ChatInput({
|
||||
initialValue = '',
|
||||
setView,
|
||||
numTokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
droppedFiles = [],
|
||||
messages = [],
|
||||
setMessages,
|
||||
sessionCosts,
|
||||
}: ChatInputProps) {
|
||||
const [_value, setValue] = useState(initialValue);
|
||||
const [displayValue, setDisplayValue] = useState(initialValue); // For immediate visual feedback
|
||||
@@ -557,9 +569,12 @@ export default function ChatInput({
|
||||
<BottomMenu
|
||||
setView={setView}
|
||||
numTokens={numTokens}
|
||||
inputTokens={inputTokens}
|
||||
outputTokens={outputTokens}
|
||||
messages={messages}
|
||||
isLoading={isLoading}
|
||||
setMessages={setMessages}
|
||||
sessionCosts={sessionCosts}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -33,6 +33,8 @@ import {
|
||||
} from './context_management/ChatContextManager';
|
||||
import { ContextHandler } from './context_management/ContextHandler';
|
||||
import { LocalMessageStorage } from '../utils/localMessageStorage';
|
||||
import { useModelAndProvider } from './ModelAndProviderContext';
|
||||
import { getCostForModel } from '../utils/costDatabase';
|
||||
import {
|
||||
Message,
|
||||
createUserMessage,
|
||||
@@ -106,11 +108,25 @@ function ChatContent({
|
||||
const [showGame, setShowGame] = useState(false);
|
||||
const [isGeneratingRecipe, setIsGeneratingRecipe] = useState(false);
|
||||
const [sessionTokenCount, setSessionTokenCount] = useState<number>(0);
|
||||
const [sessionInputTokens, setSessionInputTokens] = useState<number>(0);
|
||||
const [sessionOutputTokens, setSessionOutputTokens] = useState<number>(0);
|
||||
const [localInputTokens, setLocalInputTokens] = useState<number>(0);
|
||||
const [localOutputTokens, setLocalOutputTokens] = useState<number>(0);
|
||||
const [ancestorMessages, setAncestorMessages] = useState<Message[]>([]);
|
||||
const [droppedFiles, setDroppedFiles] = useState<string[]>([]);
|
||||
const [sessionCosts, setSessionCosts] = useState<{
|
||||
[key: string]: {
|
||||
inputTokens: number;
|
||||
outputTokens: number;
|
||||
totalCost: number;
|
||||
};
|
||||
}>({});
|
||||
const [readyForAutoUserPrompt, setReadyForAutoUserPrompt] = useState(false);
|
||||
|
||||
const scrollRef = useRef<ScrollAreaHandle>(null);
|
||||
const { currentModel, currentProvider } = useModelAndProvider();
|
||||
const prevModelRef = useRef<string | undefined>();
|
||||
const prevProviderRef = useRef<string | undefined>();
|
||||
|
||||
const {
|
||||
summaryContent,
|
||||
@@ -160,6 +176,7 @@ function ChatContent({
|
||||
updateMessageStreamBody,
|
||||
notifications,
|
||||
currentModelInfo,
|
||||
sessionMetadata,
|
||||
} = useMessageStream({
|
||||
api: getApiUrl('/reply'),
|
||||
initialMessages: chat.messages,
|
||||
@@ -518,12 +535,40 @@ function ChatContent({
|
||||
.reverse();
|
||||
}, [filteredMessages]);
|
||||
|
||||
// Simple token estimation function (roughly 4 characters per token)
|
||||
const estimateTokens = (text: string): number => {
|
||||
return Math.ceil(text.length / 4);
|
||||
};
|
||||
|
||||
// Calculate token counts from messages
|
||||
useEffect(() => {
|
||||
let inputTokens = 0;
|
||||
let outputTokens = 0;
|
||||
|
||||
messages.forEach((message) => {
|
||||
const textContent = getTextContent(message);
|
||||
if (textContent) {
|
||||
const tokens = estimateTokens(textContent);
|
||||
if (message.role === 'user') {
|
||||
inputTokens += tokens;
|
||||
} else if (message.role === 'assistant') {
|
||||
outputTokens += tokens;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
setLocalInputTokens(inputTokens);
|
||||
setLocalOutputTokens(outputTokens);
|
||||
}, [messages]);
|
||||
|
||||
// Fetch session metadata to get token count
|
||||
useEffect(() => {
|
||||
const fetchSessionTokens = async () => {
|
||||
try {
|
||||
const sessionDetails = await fetchSessionDetails(chat.id);
|
||||
setSessionTokenCount(sessionDetails.metadata.total_tokens || 0);
|
||||
setSessionInputTokens(sessionDetails.metadata.accumulated_input_tokens || 0);
|
||||
setSessionOutputTokens(sessionDetails.metadata.accumulated_output_tokens || 0);
|
||||
} catch (err) {
|
||||
console.error('Error fetching session token count:', err);
|
||||
}
|
||||
@@ -533,6 +578,74 @@ function ChatContent({
|
||||
}
|
||||
}, [chat.id, messages]);
|
||||
|
||||
// Update token counts when sessionMetadata changes from the message stream
|
||||
useEffect(() => {
|
||||
console.log('Session metadata received:', sessionMetadata);
|
||||
if (sessionMetadata) {
|
||||
setSessionTokenCount(sessionMetadata.totalTokens || 0);
|
||||
setSessionInputTokens(sessionMetadata.accumulatedInputTokens || 0);
|
||||
setSessionOutputTokens(sessionMetadata.accumulatedOutputTokens || 0);
|
||||
}
|
||||
}, [sessionMetadata]);
|
||||
|
||||
// Handle model changes and accumulate costs
|
||||
useEffect(() => {
|
||||
if (
|
||||
prevModelRef.current !== undefined &&
|
||||
prevProviderRef.current !== undefined &&
|
||||
(prevModelRef.current !== currentModel || prevProviderRef.current !== currentProvider)
|
||||
) {
|
||||
// Model/provider has changed, save the costs for the previous model
|
||||
const prevKey = `${prevProviderRef.current}/${prevModelRef.current}`;
|
||||
|
||||
// Get pricing info for the previous model
|
||||
const prevCostInfo = getCostForModel(prevProviderRef.current, prevModelRef.current);
|
||||
|
||||
if (prevCostInfo) {
|
||||
const prevInputCost =
|
||||
(sessionInputTokens || localInputTokens) * (prevCostInfo.input_token_cost || 0);
|
||||
const prevOutputCost =
|
||||
(sessionOutputTokens || localOutputTokens) * (prevCostInfo.output_token_cost || 0);
|
||||
const prevTotalCost = prevInputCost + prevOutputCost;
|
||||
|
||||
// Save the accumulated costs for this model
|
||||
setSessionCosts((prev) => ({
|
||||
...prev,
|
||||
[prevKey]: {
|
||||
inputTokens: sessionInputTokens || localInputTokens,
|
||||
outputTokens: sessionOutputTokens || localOutputTokens,
|
||||
totalCost: prevTotalCost,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
// Reset token counters for the new model
|
||||
setSessionTokenCount(0);
|
||||
setSessionInputTokens(0);
|
||||
setSessionOutputTokens(0);
|
||||
setLocalInputTokens(0);
|
||||
setLocalOutputTokens(0);
|
||||
|
||||
console.log(
|
||||
'Model changed from',
|
||||
`${prevProviderRef.current}/${prevModelRef.current}`,
|
||||
'to',
|
||||
`${currentProvider}/${currentModel}`,
|
||||
'- saved costs and reset token counters'
|
||||
);
|
||||
}
|
||||
|
||||
prevModelRef.current = currentModel || undefined;
|
||||
prevProviderRef.current = currentProvider || undefined;
|
||||
}, [
|
||||
currentModel,
|
||||
currentProvider,
|
||||
sessionInputTokens,
|
||||
sessionOutputTokens,
|
||||
localInputTokens,
|
||||
localOutputTokens,
|
||||
]);
|
||||
|
||||
const handleDrop = (e: React.DragEvent<HTMLDivElement>) => {
|
||||
e.preventDefault();
|
||||
const files = e.dataTransfer.files;
|
||||
@@ -684,9 +797,12 @@ function ChatContent({
|
||||
setView={setView}
|
||||
hasMessages={hasMessages}
|
||||
numTokens={sessionTokenCount}
|
||||
inputTokens={sessionInputTokens || localInputTokens}
|
||||
outputTokens={sessionOutputTokens || localOutputTokens}
|
||||
droppedFiles={droppedFiles}
|
||||
messages={messages}
|
||||
setMessages={setMessages}
|
||||
sessionCosts={sessionCosts}
|
||||
/>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
@@ -9,6 +9,7 @@ import { useConfig } from '../ConfigContext';
|
||||
import { useModelAndProvider } from '../ModelAndProviderContext';
|
||||
import { Message } from '../../types/message';
|
||||
import { ManualSummarizeButton } from '../context_management/ManualSummaryButton';
|
||||
import { CostTracker } from './CostTracker';
|
||||
|
||||
const TOKEN_LIMIT_DEFAULT = 128000; // fallback for custom models that the backend doesn't know about
|
||||
const TOKEN_WARNING_THRESHOLD = 0.8; // warning shows at 80% of the token limit
|
||||
@@ -22,15 +23,27 @@ interface ModelLimit {
|
||||
export default function BottomMenu({
|
||||
setView,
|
||||
numTokens = 0,
|
||||
inputTokens = 0,
|
||||
outputTokens = 0,
|
||||
messages = [],
|
||||
isLoading = false,
|
||||
setMessages,
|
||||
sessionCosts,
|
||||
}: {
|
||||
setView: (view: View, viewOptions?: ViewOptions) => void;
|
||||
numTokens?: number;
|
||||
inputTokens?: number;
|
||||
outputTokens?: number;
|
||||
messages?: Message[];
|
||||
isLoading?: boolean;
|
||||
setMessages: (messages: Message[]) => void;
|
||||
sessionCosts?: {
|
||||
[key: string]: {
|
||||
inputTokens: number;
|
||||
outputTokens: number;
|
||||
totalCost: number;
|
||||
};
|
||||
};
|
||||
}) {
|
||||
const [isModelMenuOpen, setIsModelMenuOpen] = useState(false);
|
||||
const { alerts, addAlert, clearAlerts } = useAlerts();
|
||||
@@ -202,29 +215,45 @@ export default function BottomMenu({
|
||||
}, [isModelMenuOpen]);
|
||||
|
||||
return (
|
||||
<div className="flex justify-between items-center transition-colors text-textSubtle relative text-xs align-middle">
|
||||
<div className="flex items-center pl-2">
|
||||
<div className="flex justify-between items-center transition-colors text-textSubtle relative text-xs h-6">
|
||||
<div className="flex items-center h-full">
|
||||
{/* Tool and Token count */}
|
||||
{<BottomMenuAlertPopover alerts={alerts} />}
|
||||
<div className="flex items-center h-full pl-2">
|
||||
{<BottomMenuAlertPopover alerts={alerts} />}
|
||||
</div>
|
||||
|
||||
{/* Cost Tracker - no separator before it */}
|
||||
<div className="flex items-center h-full ml-1">
|
||||
<CostTracker inputTokens={inputTokens} outputTokens={outputTokens} sessionCosts={sessionCosts} />
|
||||
</div>
|
||||
|
||||
{/* Separator between cost and model */}
|
||||
<div className="w-[1px] h-4 bg-borderSubtle mx-1.5" />
|
||||
|
||||
{/* Model Selector Dropdown */}
|
||||
<ModelsBottomBar dropdownRef={dropdownRef} setView={setView} />
|
||||
<div className="flex items-center h-full">
|
||||
<ModelsBottomBar dropdownRef={dropdownRef} setView={setView} />
|
||||
</div>
|
||||
|
||||
{/* Separator */}
|
||||
<div className="w-[1px] h-4 bg-borderSubtle mx-2" />
|
||||
<div className="w-[1px] h-4 bg-borderSubtle mx-1.5" />
|
||||
|
||||
{/* Goose Mode Selector Dropdown */}
|
||||
<BottomMenuModeSelection setView={setView} />
|
||||
<div className="flex items-center h-full">
|
||||
<BottomMenuModeSelection setView={setView} />
|
||||
</div>
|
||||
|
||||
{/* Summarize Context Button - ADD THIS */}
|
||||
{/* Summarize Context Button */}
|
||||
{messages.length > 0 && (
|
||||
<>
|
||||
<div className="w-[1px] h-4 bg-borderSubtle mx-2" />
|
||||
<ManualSummarizeButton
|
||||
messages={messages}
|
||||
isLoading={isLoading}
|
||||
setMessages={setMessages}
|
||||
/>
|
||||
<div className="w-[1px] h-4 bg-borderSubtle mx-1.5" />
|
||||
<div className="flex items-center h-full">
|
||||
<ManualSummarizeButton
|
||||
messages={messages}
|
||||
isLoading={isLoading}
|
||||
setMessages={setMessages}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
310
ui/desktop/src/components/bottom_menu/CostTracker.tsx
Normal file
310
ui/desktop/src/components/bottom_menu/CostTracker.tsx
Normal file
@@ -0,0 +1,310 @@
|
||||
import { useState, useEffect } from 'react';
|
||||
import { useModelAndProvider } from '../ModelAndProviderContext';
|
||||
import { useConfig } from '../ConfigContext';
|
||||
import {
|
||||
getCostForModel,
|
||||
initializeCostDatabase,
|
||||
updateAllModelCosts,
|
||||
fetchAndCachePricing,
|
||||
} from '../../utils/costDatabase';
|
||||
|
||||
interface CostTrackerProps {
|
||||
inputTokens?: number;
|
||||
outputTokens?: number;
|
||||
sessionCosts?: {
|
||||
[key: string]: {
|
||||
inputTokens: number;
|
||||
outputTokens: number;
|
||||
totalCost: number;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
export function CostTracker({ inputTokens = 0, outputTokens = 0, sessionCosts }: CostTrackerProps) {
|
||||
const { currentModel, currentProvider } = useModelAndProvider();
|
||||
const { getProviders } = useConfig();
|
||||
const [costInfo, setCostInfo] = useState<{
|
||||
input_token_cost?: number;
|
||||
output_token_cost?: number;
|
||||
currency?: string;
|
||||
} | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [showPricing, setShowPricing] = useState(true);
|
||||
const [pricingFailed, setPricingFailed] = useState(false);
|
||||
const [modelNotFound, setModelNotFound] = useState(false);
|
||||
const [hasAttemptedFetch, setHasAttemptedFetch] = useState(false);
|
||||
const [initialLoadComplete, setInitialLoadComplete] = useState(false);
|
||||
|
||||
// Check if pricing is enabled
|
||||
useEffect(() => {
|
||||
const checkPricingSetting = () => {
|
||||
const stored = localStorage.getItem('show_pricing');
|
||||
setShowPricing(stored !== 'false');
|
||||
};
|
||||
|
||||
// Check on mount
|
||||
checkPricingSetting();
|
||||
|
||||
// Listen for storage changes
|
||||
window.addEventListener('storage', checkPricingSetting);
|
||||
return () => window.removeEventListener('storage', checkPricingSetting);
|
||||
}, []);
|
||||
|
||||
// Set initial load complete after a short delay
|
||||
useEffect(() => {
|
||||
const timer = setTimeout(() => {
|
||||
setInitialLoadComplete(true);
|
||||
}, 3000); // Give 3 seconds for initial load
|
||||
|
||||
return () => window.clearTimeout(timer);
|
||||
}, []);
|
||||
|
||||
// Debug log props removed
|
||||
|
||||
// Initialize cost database on mount
|
||||
useEffect(() => {
|
||||
initializeCostDatabase();
|
||||
|
||||
// Update costs for all models in background
|
||||
updateAllModelCosts().catch((error) => {
|
||||
console.error('Failed to update model costs:', error);
|
||||
});
|
||||
}, [getProviders]);
|
||||
|
||||
useEffect(() => {
|
||||
const loadCostInfo = async () => {
|
||||
if (!currentModel || !currentProvider) {
|
||||
setIsLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`CostTracker: Loading cost info for ${currentProvider}/${currentModel}`);
|
||||
|
||||
try {
|
||||
// First check sync cache
|
||||
let costData = getCostForModel(currentProvider, currentModel);
|
||||
|
||||
if (costData) {
|
||||
// We have cached data
|
||||
console.log(
|
||||
`CostTracker: Found cached data for ${currentProvider}/${currentModel}:`,
|
||||
costData
|
||||
);
|
||||
setCostInfo(costData);
|
||||
setPricingFailed(false);
|
||||
setModelNotFound(false);
|
||||
setIsLoading(false);
|
||||
setHasAttemptedFetch(true);
|
||||
} else {
|
||||
// Need to fetch from backend
|
||||
console.log(
|
||||
`CostTracker: No cached data, fetching from backend for ${currentProvider}/${currentModel}`
|
||||
);
|
||||
setIsLoading(true);
|
||||
const result = await fetchAndCachePricing(currentProvider, currentModel);
|
||||
setHasAttemptedFetch(true);
|
||||
|
||||
if (result && result.costInfo) {
|
||||
console.log(
|
||||
`CostTracker: Fetched data for ${currentProvider}/${currentModel}:`,
|
||||
result.costInfo
|
||||
);
|
||||
setCostInfo(result.costInfo);
|
||||
setPricingFailed(false);
|
||||
setModelNotFound(false);
|
||||
} else if (result && result.error === 'model_not_found') {
|
||||
console.log(
|
||||
`CostTracker: Model not found in pricing data for ${currentProvider}/${currentModel}`
|
||||
);
|
||||
// Model not found in pricing database, but API call succeeded
|
||||
setModelNotFound(true);
|
||||
setPricingFailed(false);
|
||||
} else {
|
||||
console.log(`CostTracker: API failed for ${currentProvider}/${currentModel}`);
|
||||
// API call failed or other error
|
||||
const freeProviders = ['ollama', 'local', 'localhost'];
|
||||
if (!freeProviders.includes(currentProvider.toLowerCase())) {
|
||||
setPricingFailed(true);
|
||||
setModelNotFound(false);
|
||||
}
|
||||
}
|
||||
setIsLoading(false);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error loading cost info:', error);
|
||||
setHasAttemptedFetch(true);
|
||||
// Only set pricing failed if we're not dealing with a known free provider
|
||||
const freeProviders = ['ollama', 'local', 'localhost'];
|
||||
if (!freeProviders.includes(currentProvider.toLowerCase())) {
|
||||
setPricingFailed(true);
|
||||
setModelNotFound(false);
|
||||
}
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
loadCostInfo();
|
||||
}, [currentModel, currentProvider]);
|
||||
|
||||
// Return null early if pricing is disabled
|
||||
if (!showPricing) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const calculateCost = (): number => {
|
||||
// If we have session costs, calculate the total across all models
|
||||
if (sessionCosts) {
|
||||
let totalCost = 0;
|
||||
|
||||
// Add up all historical costs from different models
|
||||
Object.values(sessionCosts).forEach((modelCost) => {
|
||||
totalCost += modelCost.totalCost;
|
||||
});
|
||||
|
||||
// Add current model cost if we have pricing info
|
||||
if (
|
||||
costInfo &&
|
||||
(costInfo.input_token_cost !== undefined || costInfo.output_token_cost !== undefined)
|
||||
) {
|
||||
const currentInputCost = inputTokens * (costInfo.input_token_cost || 0);
|
||||
const currentOutputCost = outputTokens * (costInfo.output_token_cost || 0);
|
||||
totalCost += currentInputCost + currentOutputCost;
|
||||
}
|
||||
|
||||
return totalCost;
|
||||
}
|
||||
|
||||
// Fallback to simple calculation for current model only
|
||||
if (
|
||||
!costInfo ||
|
||||
(costInfo.input_token_cost === undefined && costInfo.output_token_cost === undefined)
|
||||
) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const inputCost = inputTokens * (costInfo.input_token_cost || 0);
|
||||
const outputCost = outputTokens * (costInfo.output_token_cost || 0);
|
||||
const total = inputCost + outputCost;
|
||||
|
||||
return total;
|
||||
};
|
||||
|
||||
const formatCost = (cost: number): string => {
|
||||
// Always show 6 decimal places for consistency
|
||||
return cost.toFixed(6);
|
||||
};
|
||||
|
||||
// Debug logging removed
|
||||
|
||||
// Show loading state or when we don't have model/provider info
|
||||
if (!currentModel || !currentProvider) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// If still loading, show a placeholder
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center h-full text-textSubtle translate-y-[1px]">
|
||||
<span className="text-xs font-mono">...</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// If no cost info found, try to return a default
|
||||
if (
|
||||
!costInfo ||
|
||||
(costInfo.input_token_cost === undefined && costInfo.output_token_cost === undefined)
|
||||
) {
|
||||
// If it's a known free/local provider, show $0.000000 without "not available" message
|
||||
const freeProviders = ['ollama', 'local', 'localhost'];
|
||||
if (freeProviders.includes(currentProvider.toLowerCase())) {
|
||||
return (
|
||||
<div
|
||||
className="flex items-center justify-center h-full text-textSubtle hover:text-textStandard transition-colors cursor-default translate-y-[1px]"
|
||||
title={`Local model (${inputTokens.toLocaleString()} input, ${outputTokens.toLocaleString()} output tokens)`}
|
||||
>
|
||||
<span className="text-xs font-mono">$0.000000</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Otherwise show as unavailable
|
||||
const getUnavailableTooltip = () => {
|
||||
if (pricingFailed && hasAttemptedFetch && initialLoadComplete) {
|
||||
return `Pricing data unavailable - OpenRouter connection failed. Click refresh in settings to retry.`;
|
||||
}
|
||||
// If we reach here, it must be modelNotFound (since we only get here after attempting fetch)
|
||||
return `Cost data not available for ${currentModel} (${inputTokens.toLocaleString()} input, ${outputTokens.toLocaleString()} output tokens)`;
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`flex items-center justify-center h-full transition-colors cursor-default translate-y-[1px] ${
|
||||
(pricingFailed || modelNotFound) && hasAttemptedFetch && initialLoadComplete
|
||||
? 'text-red-500 hover:text-red-400'
|
||||
: 'text-textSubtle hover:text-textStandard'
|
||||
}`}
|
||||
title={getUnavailableTooltip()}
|
||||
>
|
||||
<span className="text-xs font-mono">$0.000000</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const totalCost = calculateCost();
|
||||
|
||||
// Build tooltip content
|
||||
const getTooltipContent = (): string => {
|
||||
// Handle error states first
|
||||
if (pricingFailed && hasAttemptedFetch && initialLoadComplete) {
|
||||
return `Pricing data unavailable - OpenRouter connection failed. Click refresh in settings to retry.`;
|
||||
}
|
||||
|
||||
if (modelNotFound && hasAttemptedFetch && initialLoadComplete) {
|
||||
return `Pricing not available for ${currentProvider}/${currentModel}. This model may not be supported by the pricing service.`;
|
||||
}
|
||||
|
||||
// Handle session costs
|
||||
if (sessionCosts && Object.keys(sessionCosts).length > 0) {
|
||||
// Show session breakdown
|
||||
let tooltip = 'Session cost breakdown:\n';
|
||||
|
||||
Object.entries(sessionCosts).forEach(([modelKey, cost]) => {
|
||||
const costStr = `${costInfo?.currency || '$'}${cost.totalCost.toFixed(6)}`;
|
||||
tooltip += `${modelKey}: ${costStr} (${cost.inputTokens.toLocaleString()} in, ${cost.outputTokens.toLocaleString()} out)\n`;
|
||||
});
|
||||
|
||||
// Add current model if it has costs
|
||||
if (costInfo && (inputTokens > 0 || outputTokens > 0)) {
|
||||
const currentCost =
|
||||
inputTokens * (costInfo.input_token_cost || 0) +
|
||||
outputTokens * (costInfo.output_token_cost || 0);
|
||||
if (currentCost > 0) {
|
||||
tooltip += `${currentProvider}/${currentModel} (current): ${costInfo.currency || '$'}${currentCost.toFixed(6)} (${inputTokens.toLocaleString()} in, ${outputTokens.toLocaleString()} out)\n`;
|
||||
}
|
||||
}
|
||||
|
||||
tooltip += `\nTotal session cost: ${costInfo?.currency || '$'}${totalCost.toFixed(6)}`;
|
||||
return tooltip;
|
||||
}
|
||||
|
||||
// Default tooltip for single model
|
||||
return `Input: ${inputTokens.toLocaleString()} tokens (${costInfo?.currency || '$'}${(inputTokens * (costInfo?.input_token_cost || 0)).toFixed(6)}) | Output: ${outputTokens.toLocaleString()} tokens (${costInfo?.currency || '$'}${(outputTokens * (costInfo?.output_token_cost || 0)).toFixed(6)})`;
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`flex items-center justify-center h-full transition-colors cursor-default translate-y-[1px] ${
|
||||
(pricingFailed || modelNotFound) && hasAttemptedFetch && initialLoadComplete
|
||||
? 'text-red-500 hover:text-red-400'
|
||||
: 'text-textSubtle hover:text-textStandard'
|
||||
}`}
|
||||
title={getTooltipContent()}
|
||||
>
|
||||
<span className="text-xs font-mono">
|
||||
{costInfo.currency || '$'}
|
||||
{formatCost(totalCost)}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
import { useState, useEffect, useRef } from 'react';
|
||||
import { Switch } from '../../ui/switch';
|
||||
import { Button } from '../../ui/button';
|
||||
import { Settings } from 'lucide-react';
|
||||
import { Settings, RefreshCw, ExternalLink } from 'lucide-react';
|
||||
import Modal from '../../Modal';
|
||||
import UpdateSection from './UpdateSection';
|
||||
import { UPDATES_ENABLED } from '../../../updates';
|
||||
import { getApiUrl, getSecretKey } from '../../../config';
|
||||
|
||||
interface AppSettingsSectionProps {
|
||||
scrollToSection?: string;
|
||||
@@ -17,6 +18,10 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
|
||||
const [isMacOS, setIsMacOS] = useState(false);
|
||||
const [isDockSwitchDisabled, setIsDockSwitchDisabled] = useState(false);
|
||||
const [showNotificationModal, setShowNotificationModal] = useState(false);
|
||||
const [pricingStatus, setPricingStatus] = useState<'loading' | 'success' | 'error'>('loading');
|
||||
const [lastFetchTime, setLastFetchTime] = useState<Date | null>(null);
|
||||
const [isRefreshing, setIsRefreshing] = useState(false);
|
||||
const [showPricing, setShowPricing] = useState(true);
|
||||
const updateSectionRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Check if running on macOS
|
||||
@@ -24,6 +29,77 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
|
||||
setIsMacOS(window.electron.platform === 'darwin');
|
||||
}, []);
|
||||
|
||||
// Load show pricing setting
|
||||
useEffect(() => {
|
||||
const stored = localStorage.getItem('show_pricing');
|
||||
setShowPricing(stored !== 'false');
|
||||
}, []);
|
||||
|
||||
// Check pricing status on mount
|
||||
useEffect(() => {
|
||||
checkPricingStatus();
|
||||
}, []);
|
||||
|
||||
const checkPricingStatus = async () => {
|
||||
try {
|
||||
const apiUrl = getApiUrl('/config/pricing');
|
||||
const secretKey = getSecretKey();
|
||||
|
||||
const headers: HeadersInit = { 'Content-Type': 'application/json' };
|
||||
if (secretKey) {
|
||||
headers['X-Secret-Key'] = secretKey;
|
||||
}
|
||||
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({ configured_only: true }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
await response.json(); // Consume the response
|
||||
setPricingStatus('success');
|
||||
setLastFetchTime(new Date());
|
||||
} else {
|
||||
setPricingStatus('error');
|
||||
}
|
||||
} catch (error) {
|
||||
setPricingStatus('error');
|
||||
}
|
||||
};
|
||||
|
||||
const handleRefreshPricing = async () => {
|
||||
setIsRefreshing(true);
|
||||
try {
|
||||
const apiUrl = getApiUrl('/config/pricing');
|
||||
const secretKey = getSecretKey();
|
||||
|
||||
const headers: HeadersInit = { 'Content-Type': 'application/json' };
|
||||
if (secretKey) {
|
||||
headers['X-Secret-Key'] = secretKey;
|
||||
}
|
||||
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({ configured_only: false }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
setPricingStatus('success');
|
||||
setLastFetchTime(new Date());
|
||||
// Trigger a reload of the cost database
|
||||
window.dispatchEvent(new CustomEvent('pricing-updated'));
|
||||
} else {
|
||||
setPricingStatus('error');
|
||||
}
|
||||
} catch (error) {
|
||||
setPricingStatus('error');
|
||||
} finally {
|
||||
setIsRefreshing(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Handle scrolling to update section
|
||||
useEffect(() => {
|
||||
if (scrollToSection === 'update' && updateSectionRef.current) {
|
||||
@@ -99,6 +175,13 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
|
||||
}
|
||||
};
|
||||
|
||||
const handleShowPricingToggle = (checked: boolean) => {
|
||||
setShowPricing(checked);
|
||||
localStorage.setItem('show_pricing', String(checked));
|
||||
// Trigger storage event for other components
|
||||
window.dispatchEvent(new CustomEvent('storage'));
|
||||
};
|
||||
|
||||
return (
|
||||
<section id="appSettings" className="px-8">
|
||||
<div className="flex justify-between items-center mb-2">
|
||||
@@ -173,6 +256,7 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Quit Confirmation */}
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<div>
|
||||
<h3 className="text-textStandard">Quit Confirmation</h3>
|
||||
@@ -188,6 +272,87 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Cost Tracking */}
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<div>
|
||||
<h3 className="text-textStandard">Cost Tracking</h3>
|
||||
<p className="text-xs text-textSubtle max-w-md mt-[2px]">
|
||||
Show model pricing and usage costs
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
<Switch
|
||||
checked={showPricing}
|
||||
onCheckedChange={handleShowPricingToggle}
|
||||
variant="mono"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Pricing Status - only show if cost tracking is enabled */}
|
||||
{showPricing && (
|
||||
<>
|
||||
<div className="flex items-center justify-between text-xs mb-2 px-4">
|
||||
<span className="text-textSubtle">Pricing Source:</span>
|
||||
<a
|
||||
href="https://openrouter.ai/docs#models"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-blue-600 dark:text-blue-400 hover:underline flex items-center gap-1"
|
||||
>
|
||||
OpenRouter Docs
|
||||
<ExternalLink size={10} />
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div className="flex items-center justify-between text-xs mb-2 px-4">
|
||||
<span className="text-textSubtle">Status:</span>
|
||||
<div className="flex items-center gap-2">
|
||||
<span
|
||||
className={`font-medium ${
|
||||
pricingStatus === 'success'
|
||||
? 'text-green-600 dark:text-green-400'
|
||||
: pricingStatus === 'error'
|
||||
? 'text-red-600 dark:text-red-400'
|
||||
: 'text-textSubtle'
|
||||
}`}
|
||||
>
|
||||
{pricingStatus === 'success'
|
||||
? '✓ Connected'
|
||||
: pricingStatus === 'error'
|
||||
? '✗ Failed'
|
||||
: '... Checking'}
|
||||
</span>
|
||||
<button
|
||||
className="p-0.5 hover:bg-gray-200 dark:hover:bg-gray-700 rounded transition-colors disabled:opacity-50"
|
||||
onClick={handleRefreshPricing}
|
||||
disabled={isRefreshing}
|
||||
title="Refresh pricing data"
|
||||
type="button"
|
||||
>
|
||||
<RefreshCw
|
||||
size={8}
|
||||
className={`text-textSubtle hover:text-textStandard ${isRefreshing ? 'animate-spin-fast' : ''}`}
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{lastFetchTime && (
|
||||
<div className="flex items-center justify-between text-xs mb-2 px-4">
|
||||
<span className="text-textSubtle">Last updated:</span>
|
||||
<span className="text-textSubtle">{lastFetchTime.toLocaleTimeString()}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{pricingStatus === 'error' && (
|
||||
<p className="text-xs text-red-600 dark:text-red-400 px-4">
|
||||
Unable to fetch pricing data. Costs will not be displayed.
|
||||
</p>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Help & Feedback Section */}
|
||||
|
||||
@@ -43,7 +43,10 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
|
||||
}, [read]);
|
||||
|
||||
// Determine which model to display - activeModel takes priority when lead/worker is active
|
||||
const displayModel = (isLeadWorkerActive && currentModelInfo?.model) ? currentModelInfo.model : (currentModel || 'Select Model');
|
||||
const displayModel =
|
||||
isLeadWorkerActive && currentModelInfo?.model
|
||||
? currentModelInfo.model
|
||||
: currentModel || 'Select Model';
|
||||
const modelMode = currentModelInfo?.mode;
|
||||
|
||||
// Update display provider when current provider changes
|
||||
@@ -106,9 +109,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
|
||||
>
|
||||
{displayModel}
|
||||
{isLeadWorkerActive && modelMode && (
|
||||
<span className="ml-1 text-[10px] opacity-60">
|
||||
({modelMode})
|
||||
</span>
|
||||
<span className="ml-1 text-[10px] opacity-60">({modelMode})</span>
|
||||
)}
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
@@ -116,9 +117,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
|
||||
<TooltipContent className="max-w-96 overflow-auto scrollbar-thin" side="top">
|
||||
{displayModel}
|
||||
{isLeadWorkerActive && modelMode && (
|
||||
<span className="ml-1 text-[10px] opacity-60">
|
||||
({modelMode})
|
||||
</span>
|
||||
<span className="ml-1 text-[10px] opacity-60">({modelMode})</span>
|
||||
)}
|
||||
</TooltipContent>
|
||||
)}
|
||||
@@ -164,7 +163,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
|
||||
{isAddModelModalOpen ? (
|
||||
<AddModelModal setView={setView} onClose={() => setIsAddModelModalOpen(false)} />
|
||||
) : null}
|
||||
|
||||
|
||||
{isLeadWorkerModalOpen ? (
|
||||
<Modal onClose={() => setIsLeadWorkerModalOpen(false)}>
|
||||
<LeadWorkerSettings onClose={() => setIsLeadWorkerModalOpen(false)} />
|
||||
|
||||
@@ -21,7 +21,9 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
|
||||
const [failureThreshold, setFailureThreshold] = useState<number>(2);
|
||||
const [fallbackTurns, setFallbackTurns] = useState<number>(2);
|
||||
const [isEnabled, setIsEnabled] = useState(false);
|
||||
const [modelOptions, setModelOptions] = useState<{ value: string; label: string; provider: string }[]>([]);
|
||||
const [modelOptions, setModelOptions] = useState<
|
||||
{ value: string; label: string; provider: string }[]
|
||||
>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
|
||||
// Load current configuration
|
||||
@@ -51,7 +53,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
|
||||
if (leadTurnsConfig) setLeadTurns(Number(leadTurnsConfig));
|
||||
if (failureThresholdConfig) setFailureThreshold(Number(failureThresholdConfig));
|
||||
if (fallbackTurnsConfig) setFallbackTurns(Number(fallbackTurnsConfig));
|
||||
|
||||
|
||||
// Set worker model to current model or from config
|
||||
const workerModelConfig = await read('GOOSE_MODEL', false);
|
||||
if (workerModelConfig) {
|
||||
@@ -59,7 +61,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
|
||||
} else if (currentModel) {
|
||||
setWorkerModel(currentModel as string);
|
||||
}
|
||||
|
||||
|
||||
const workerProviderConfig = await read('GOOSE_PROVIDER', false);
|
||||
if (workerProviderConfig) {
|
||||
setWorkerProvider(workerProviderConfig as string);
|
||||
@@ -69,7 +71,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
|
||||
const providers = await getProviders(false);
|
||||
const activeProviders = providers.filter((p) => p.is_configured);
|
||||
const options: { value: string; label: string; provider: string }[] = [];
|
||||
|
||||
|
||||
activeProviders.forEach(({ metadata, name }) => {
|
||||
if (metadata.known_models) {
|
||||
metadata.known_models.forEach((model) => {
|
||||
@@ -81,7 +83,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
setModelOptions(options);
|
||||
} catch (error) {
|
||||
console.error('Error loading configuration:', error);
|
||||
@@ -184,9 +186,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
|
||||
placeholder="Select worker model..."
|
||||
isDisabled={!isEnabled}
|
||||
/>
|
||||
<p className="text-xs text-textSubtle">
|
||||
Fast model for routine execution tasks
|
||||
</p>
|
||||
<p className="text-xs text-textSubtle">Fast model for routine execution tasks</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4 pt-4 border-t border-borderSubtle">
|
||||
@@ -242,9 +242,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
|
||||
className="w-20"
|
||||
disabled={!isEnabled}
|
||||
/>
|
||||
<p className="text-xs text-textSubtle">
|
||||
Turns to use lead model during fallback
|
||||
</p>
|
||||
<p className="text-xs text-textSubtle">Turns to use lead model during fallback</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -259,4 +257,4 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ import { useCurrentModelInfo } from '../components/ChatView';
|
||||
|
||||
export function useCurrentModel() {
|
||||
const modelInfo = useCurrentModelInfo();
|
||||
|
||||
return {
|
||||
|
||||
return {
|
||||
currentModel: modelInfo?.model || null,
|
||||
isLoading: false
|
||||
isLoading: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,12 +2,26 @@ import { useState, useCallback, useEffect, useRef, useId } from 'react';
|
||||
import useSWR from 'swr';
|
||||
import { getSecretKey } from '../config';
|
||||
import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message';
|
||||
import { getSessionHistory } from '../api';
|
||||
|
||||
// Ensure TextDecoder is available in the global scope
|
||||
const TextDecoder = globalThis.TextDecoder;
|
||||
|
||||
type JsonValue = string | number | boolean | null | JsonValue[] | { [key: string]: JsonValue };
|
||||
|
||||
export interface SessionMetadata {
|
||||
workingDir: string;
|
||||
description: string;
|
||||
scheduleId: string | null;
|
||||
messageCount: number;
|
||||
totalTokens: number | null;
|
||||
inputTokens: number | null;
|
||||
outputTokens: number | null;
|
||||
accumulatedTotalTokens: number | null;
|
||||
accumulatedInputTokens: number | null;
|
||||
accumulatedOutputTokens: number | null;
|
||||
}
|
||||
|
||||
export interface NotificationEvent {
|
||||
type: 'Notification';
|
||||
request_id: string;
|
||||
@@ -141,9 +155,12 @@ export interface UseMessageStreamHelpers {
|
||||
updateMessageStreamBody?: (newBody: object) => void;
|
||||
|
||||
notifications: NotificationEvent[];
|
||||
|
||||
|
||||
/** Current model info from the backend */
|
||||
currentModelInfo: { model: string; mode: string } | null;
|
||||
|
||||
/** Session metadata including token counts */
|
||||
sessionMetadata: SessionMetadata | null;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -172,7 +189,10 @@ export function useMessageStream({
|
||||
});
|
||||
|
||||
const [notifications, setNotifications] = useState<NotificationEvent[]>([]);
|
||||
const [currentModelInfo, setCurrentModelInfo] = useState<{ model: string; mode: string } | null>(null);
|
||||
const [currentModelInfo, setCurrentModelInfo] = useState<{ model: string; mode: string } | null>(
|
||||
null
|
||||
);
|
||||
const [sessionMetadata, setSessionMetadata] = useState<SessionMetadata | null>(null);
|
||||
|
||||
// expose a way to update the body so we can update the session id when CLE occurs
|
||||
const updateMessageStreamBody = useCallback((newBody: object) => {
|
||||
@@ -291,13 +311,41 @@ export function useMessageStream({
|
||||
case 'Error':
|
||||
throw new Error(parsedEvent.error);
|
||||
|
||||
case 'Finish':
|
||||
case 'Finish': {
|
||||
// Call onFinish with the last message if available
|
||||
if (onFinish && currentMessages.length > 0) {
|
||||
const lastMessage = currentMessages[currentMessages.length - 1];
|
||||
onFinish(lastMessage, parsedEvent.reason);
|
||||
}
|
||||
|
||||
// Fetch updated session metadata with token counts
|
||||
const sessionId = (extraMetadataRef.current.body as Record<string, unknown>)?.session_id as string;
|
||||
if (sessionId) {
|
||||
try {
|
||||
const sessionResponse = await getSessionHistory({
|
||||
path: { session_id: sessionId },
|
||||
});
|
||||
|
||||
if (sessionResponse.data?.metadata) {
|
||||
setSessionMetadata({
|
||||
workingDir: sessionResponse.data.metadata.working_dir,
|
||||
description: sessionResponse.data.metadata.description,
|
||||
scheduleId: sessionResponse.data.metadata.schedule_id || null,
|
||||
messageCount: sessionResponse.data.metadata.message_count,
|
||||
totalTokens: sessionResponse.data.metadata.total_tokens || null,
|
||||
inputTokens: sessionResponse.data.metadata.input_tokens || null,
|
||||
outputTokens: sessionResponse.data.metadata.output_tokens || null,
|
||||
accumulatedTotalTokens: sessionResponse.data.metadata.accumulated_total_tokens || null,
|
||||
accumulatedInputTokens: sessionResponse.data.metadata.accumulated_input_tokens || null,
|
||||
accumulatedOutputTokens: sessionResponse.data.metadata.accumulated_output_tokens || null,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch session metadata:', error);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error parsing SSE event:', e);
|
||||
@@ -559,5 +607,6 @@ export function useMessageStream({
|
||||
updateMessageStreamBody,
|
||||
notifications,
|
||||
currentModelInfo,
|
||||
sessionMetadata,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -7,6 +7,10 @@ export interface SessionMetadata {
|
||||
message_count: number;
|
||||
total_tokens: number | null;
|
||||
working_dir: string; // Required in type, but may be missing in old sessions
|
||||
// Add the accumulated token fields from the API
|
||||
accumulated_input_tokens?: number | null;
|
||||
accumulated_output_tokens?: number | null;
|
||||
accumulated_total_tokens?: number | null;
|
||||
}
|
||||
|
||||
// Helper function to ensure working directory is set
|
||||
@@ -16,6 +20,9 @@ export function ensureWorkingDir(metadata: Partial<SessionMetadata>): SessionMet
|
||||
message_count: metadata.message_count || 0,
|
||||
total_tokens: metadata.total_tokens || null,
|
||||
working_dir: metadata.working_dir || process.env.HOME || '',
|
||||
accumulated_input_tokens: metadata.accumulated_input_tokens || null,
|
||||
accumulated_output_tokens: metadata.accumulated_output_tokens || null,
|
||||
accumulated_total_tokens: metadata.accumulated_total_tokens || null,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
593
ui/desktop/src/utils/costDatabase.ts
Normal file
593
ui/desktop/src/utils/costDatabase.ts
Normal file
@@ -0,0 +1,593 @@
|
||||
// Import the proper type from ConfigContext
|
||||
import { getApiUrl, getSecretKey } from '../config';
|
||||
|
||||
export interface ModelCostInfo {
|
||||
input_token_cost: number; // Cost per token for input (in USD)
|
||||
output_token_cost: number; // Cost per token for output (in USD)
|
||||
currency: string; // Currency symbol
|
||||
}
|
||||
|
||||
// In-memory cache for current model pricing only
|
||||
let currentModelPricing: {
|
||||
provider: string;
|
||||
model: string;
|
||||
costInfo: ModelCostInfo | null;
|
||||
} | null = null;
|
||||
|
||||
// LocalStorage keys
|
||||
const PRICING_CACHE_KEY = 'goose_pricing_cache';
|
||||
const PRICING_CACHE_TIMESTAMP_KEY = 'goose_pricing_cache_timestamp';
|
||||
const RECENTLY_USED_MODELS_KEY = 'goose_recently_used_models';
|
||||
const CACHE_TTL_MS = 7 * 24 * 60 * 60 * 1000; // 7 days in milliseconds
|
||||
const MAX_RECENTLY_USED_MODELS = 20; // Keep only the last 20 used models in cache
|
||||
|
||||
interface PricingItem {
|
||||
provider: string;
|
||||
model: string;
|
||||
input_token_cost: number;
|
||||
output_token_cost: number;
|
||||
currency: string;
|
||||
}
|
||||
|
||||
interface PricingCacheData {
|
||||
pricing: PricingItem[];
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
interface RecentlyUsedModel {
|
||||
provider: string;
|
||||
model: string;
|
||||
lastUsed: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get recently used models from localStorage
|
||||
*/
|
||||
function getRecentlyUsedModels(): RecentlyUsedModel[] {
|
||||
try {
|
||||
const stored = localStorage.getItem(RECENTLY_USED_MODELS_KEY);
|
||||
return stored ? JSON.parse(stored) : [];
|
||||
} catch (error) {
|
||||
console.error('Error loading recently used models:', error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a model to the recently used list
|
||||
*/
|
||||
function addToRecentlyUsed(provider: string, model: string): void {
|
||||
try {
|
||||
let recentModels = getRecentlyUsedModels();
|
||||
|
||||
// Remove existing entry if present
|
||||
recentModels = recentModels.filter((m) => !(m.provider === provider && m.model === model));
|
||||
|
||||
// Add to front
|
||||
recentModels.unshift({ provider, model, lastUsed: Date.now() });
|
||||
|
||||
// Keep only the most recent models
|
||||
recentModels = recentModels.slice(0, MAX_RECENTLY_USED_MODELS);
|
||||
|
||||
localStorage.setItem(RECENTLY_USED_MODELS_KEY, JSON.stringify(recentModels));
|
||||
} catch (error) {
|
||||
console.error('Error saving recently used models:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load pricing data from localStorage cache - only for recently used models
|
||||
*/
|
||||
function loadPricingFromLocalStorage(): PricingCacheData | null {
|
||||
try {
|
||||
const cached = localStorage.getItem(PRICING_CACHE_KEY);
|
||||
const timestamp = localStorage.getItem(PRICING_CACHE_TIMESTAMP_KEY);
|
||||
|
||||
if (cached && timestamp) {
|
||||
const cacheAge = Date.now() - parseInt(timestamp, 10);
|
||||
if (cacheAge < CACHE_TTL_MS) {
|
||||
const fullCache = JSON.parse(cached) as PricingCacheData;
|
||||
const recentModels = getRecentlyUsedModels();
|
||||
|
||||
// Filter to only include recently used models
|
||||
const filteredPricing = fullCache.pricing.filter((p) =>
|
||||
recentModels.some((r) => r.provider === p.provider && r.model === p.model)
|
||||
);
|
||||
|
||||
console.log(
|
||||
`Loading ${filteredPricing.length} recently used models from cache (out of ${fullCache.pricing.length} total)`
|
||||
);
|
||||
|
||||
return {
|
||||
pricing: filteredPricing,
|
||||
timestamp: fullCache.timestamp,
|
||||
};
|
||||
} else {
|
||||
console.log('LocalStorage pricing cache expired');
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error loading pricing from localStorage:', error);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save pricing data to localStorage - merge with existing data
|
||||
*/
|
||||
function savePricingToLocalStorage(data: PricingCacheData, mergeWithExisting = true): void {
|
||||
try {
|
||||
if (mergeWithExisting) {
|
||||
// Load existing full cache
|
||||
const existingCached = localStorage.getItem(PRICING_CACHE_KEY);
|
||||
if (existingCached) {
|
||||
const existingData = JSON.parse(existingCached) as PricingCacheData;
|
||||
|
||||
// Create a map of existing pricing for quick lookup
|
||||
const pricingMap = new Map<string, (typeof data.pricing)[0]>();
|
||||
existingData.pricing.forEach((p) => {
|
||||
pricingMap.set(`${p.provider}/${p.model}`, p);
|
||||
});
|
||||
|
||||
// Update with new data
|
||||
data.pricing.forEach((p) => {
|
||||
pricingMap.set(`${p.provider}/${p.model}`, p);
|
||||
});
|
||||
|
||||
// Convert back to array
|
||||
data = {
|
||||
pricing: Array.from(pricingMap.values()),
|
||||
timestamp: data.timestamp,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(data));
|
||||
localStorage.setItem(PRICING_CACHE_TIMESTAMP_KEY, data.timestamp.toString());
|
||||
console.log(`Saved ${data.pricing.length} models to localStorage cache`);
|
||||
} catch (error) {
|
||||
console.error('Error saving pricing to localStorage:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch pricing data from backend for specific provider/model
|
||||
*/
|
||||
async function fetchPricingForModel(
|
||||
provider: string,
|
||||
model: string
|
||||
): Promise<ModelCostInfo | null> {
|
||||
try {
|
||||
const apiUrl = getApiUrl('/config/pricing');
|
||||
const secretKey = getSecretKey();
|
||||
|
||||
console.log(`Fetching pricing for ${provider}/${model} from ${apiUrl}`);
|
||||
|
||||
const headers: HeadersInit = { 'Content-Type': 'application/json' };
|
||||
if (secretKey) {
|
||||
headers['X-Secret-Key'] = secretKey;
|
||||
}
|
||||
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({ configured_only: false }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error('Failed to fetch pricing data:', response.status);
|
||||
throw new Error(`API request failed with status ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log('Pricing response:', data);
|
||||
|
||||
// Find the specific model pricing
|
||||
const pricing = data.pricing?.find(
|
||||
(p: {
|
||||
provider: string;
|
||||
model: string;
|
||||
input_token_cost: number;
|
||||
output_token_cost: number;
|
||||
currency: string;
|
||||
}) => {
|
||||
const providerMatch = p.provider.toLowerCase() === provider.toLowerCase();
|
||||
|
||||
// More flexible model matching - handle versioned models
|
||||
let modelMatch = p.model === model;
|
||||
|
||||
// If exact match fails, try matching without version suffix
|
||||
if (!modelMatch && model.includes('-20')) {
|
||||
// Remove date suffix like -20241022
|
||||
const modelWithoutDate = model.replace(/-20\d{6}$/, '');
|
||||
modelMatch = p.model === modelWithoutDate;
|
||||
|
||||
// Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet)
|
||||
if (!modelMatch) {
|
||||
const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.');
|
||||
modelMatch = p.model === modelWithDots;
|
||||
}
|
||||
}
|
||||
|
||||
console.log(
|
||||
`Comparing: ${p.provider}/${p.model} with ${provider}/${model} - Provider match: ${providerMatch}, Model match: ${modelMatch}`
|
||||
);
|
||||
return providerMatch && modelMatch;
|
||||
}
|
||||
);
|
||||
|
||||
console.log(`Found pricing for ${provider}/${model}:`, pricing);
|
||||
|
||||
if (pricing) {
|
||||
return {
|
||||
input_token_cost: pricing.input_token_cost,
|
||||
output_token_cost: pricing.output_token_cost,
|
||||
currency: pricing.currency || '$',
|
||||
};
|
||||
}
|
||||
|
||||
console.log(
|
||||
`No pricing found for ${provider}/${model} in:`,
|
||||
data.pricing?.map((p: { provider: string; model: string }) => `${p.provider}/${p.model}`)
|
||||
);
|
||||
|
||||
// API call succeeded but model not found in pricing data
|
||||
return null;
|
||||
} catch (error) {
|
||||
console.error('Error fetching pricing data:', error);
|
||||
// Re-throw the error so the caller can distinguish between API failure and model not found
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the cost database - only load commonly used models on startup
|
||||
*/
|
||||
export async function initializeCostDatabase(): Promise<void> {
|
||||
try {
|
||||
// Clean up any existing large caches first
|
||||
cleanupPricingCache();
|
||||
|
||||
// First check if we have valid cached data
|
||||
const cachedData = loadPricingFromLocalStorage();
|
||||
if (cachedData && cachedData.pricing.length > 0) {
|
||||
console.log('Using cached pricing data from localStorage');
|
||||
return;
|
||||
}
|
||||
|
||||
// List of commonly used models to pre-fetch
|
||||
const commonModels = [
|
||||
{ provider: 'openai', model: 'gpt-4o' },
|
||||
{ provider: 'openai', model: 'gpt-4o-mini' },
|
||||
{ provider: 'openai', model: 'gpt-4-turbo' },
|
||||
{ provider: 'openai', model: 'gpt-4' },
|
||||
{ provider: 'openai', model: 'gpt-3.5-turbo' },
|
||||
{ provider: 'anthropic', model: 'claude-3-5-sonnet' },
|
||||
{ provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' },
|
||||
{ provider: 'anthropic', model: 'claude-3-opus' },
|
||||
{ provider: 'anthropic', model: 'claude-3-sonnet' },
|
||||
{ provider: 'anthropic', model: 'claude-3-haiku' },
|
||||
{ provider: 'google', model: 'gemini-1.5-pro' },
|
||||
{ provider: 'google', model: 'gemini-1.5-flash' },
|
||||
{ provider: 'deepseek', model: 'deepseek-chat' },
|
||||
{ provider: 'deepseek', model: 'deepseek-reasoner' },
|
||||
{ provider: 'meta-llama', model: 'llama-3.2-90b-text-preview' },
|
||||
{ provider: 'meta-llama', model: 'llama-3.1-405b-instruct' },
|
||||
];
|
||||
|
||||
// Get recently used models
|
||||
const recentModels = getRecentlyUsedModels();
|
||||
|
||||
// Combine common and recent models (deduplicated)
|
||||
const modelsToFetch = new Map<string, { provider: string; model: string }>();
|
||||
|
||||
// Add common models
|
||||
commonModels.forEach((m) => {
|
||||
modelsToFetch.set(`${m.provider}/${m.model}`, m);
|
||||
});
|
||||
|
||||
// Add recent models
|
||||
recentModels.forEach((m) => {
|
||||
modelsToFetch.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model });
|
||||
});
|
||||
|
||||
console.log(`Initializing cost database with ${modelsToFetch.size} models...`);
|
||||
|
||||
// Fetch only the pricing we need
|
||||
const apiUrl = getApiUrl('/config/pricing');
|
||||
const secretKey = getSecretKey();
|
||||
|
||||
const headers: HeadersInit = { 'Content-Type': 'application/json' };
|
||||
if (secretKey) {
|
||||
headers['X-Secret-Key'] = secretKey;
|
||||
}
|
||||
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
configured_only: false,
|
||||
models: Array.from(modelsToFetch.values()), // Send specific models if API supports it
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error('Failed to fetch initial pricing data:', response.status);
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log(`Fetched pricing for ${data.pricing?.length || 0} models`);
|
||||
|
||||
if (data.pricing && data.pricing.length > 0) {
|
||||
// Filter to only the models we requested (in case API returns all)
|
||||
const filteredPricing = data.pricing.filter((p: PricingItem) =>
|
||||
modelsToFetch.has(`${p.provider}/${p.model}`)
|
||||
);
|
||||
|
||||
// Save to localStorage
|
||||
const cacheData: PricingCacheData = {
|
||||
pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50), // Fallback to first 50 if filtering didn't work
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
savePricingToLocalStorage(cacheData, false); // Don't merge on initial load
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error initializing cost database:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update model costs from providers - no longer needed
|
||||
*/
|
||||
export async function updateAllModelCosts(): Promise<void> {
|
||||
// No-op - we fetch on demand now
|
||||
}
|
||||
|
||||
/**
|
||||
* Get cost information for a specific model with caching
|
||||
*/
|
||||
export function getCostForModel(provider: string, model: string): ModelCostInfo | null {
|
||||
// Track this model as recently used
|
||||
addToRecentlyUsed(provider, model);
|
||||
|
||||
// Check if it's the same model we already have cached in memory
|
||||
if (
|
||||
currentModelPricing &&
|
||||
currentModelPricing.provider === provider &&
|
||||
currentModelPricing.model === model
|
||||
) {
|
||||
return currentModelPricing.costInfo;
|
||||
}
|
||||
|
||||
// For local/free providers, return zero cost immediately
|
||||
const freeProviders = ['ollama', 'local', 'localhost'];
|
||||
if (freeProviders.includes(provider.toLowerCase())) {
|
||||
const zeroCost = {
|
||||
input_token_cost: 0,
|
||||
output_token_cost: 0,
|
||||
currency: '$',
|
||||
};
|
||||
currentModelPricing = { provider, model, costInfo: zeroCost };
|
||||
return zeroCost;
|
||||
}
|
||||
|
||||
// Check localStorage cache (which now only contains recently used models)
|
||||
const cachedData = loadPricingFromLocalStorage();
|
||||
if (cachedData) {
|
||||
const pricing = cachedData.pricing.find((p) => {
|
||||
const providerMatch = p.provider.toLowerCase() === provider.toLowerCase();
|
||||
|
||||
// More flexible model matching - handle versioned models
|
||||
let modelMatch = p.model === model;
|
||||
|
||||
// If exact match fails, try matching without version suffix
|
||||
if (!modelMatch && model.includes('-20')) {
|
||||
// Remove date suffix like -20241022
|
||||
const modelWithoutDate = model.replace(/-20\d{6}$/, '');
|
||||
modelMatch = p.model === modelWithoutDate;
|
||||
|
||||
// Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet)
|
||||
if (!modelMatch) {
|
||||
const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.');
|
||||
modelMatch = p.model === modelWithDots;
|
||||
}
|
||||
}
|
||||
|
||||
return providerMatch && modelMatch;
|
||||
});
|
||||
|
||||
if (pricing) {
|
||||
const costInfo = {
|
||||
input_token_cost: pricing.input_token_cost,
|
||||
output_token_cost: pricing.output_token_cost,
|
||||
currency: pricing.currency || '$',
|
||||
};
|
||||
currentModelPricing = { provider, model, costInfo };
|
||||
return costInfo;
|
||||
}
|
||||
}
|
||||
|
||||
// Need to fetch new pricing - return null for now
|
||||
// The component will handle the async fetch
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch and cache pricing for a model
|
||||
*/
|
||||
export async function fetchAndCachePricing(
|
||||
provider: string,
|
||||
model: string
|
||||
): Promise<{ costInfo: ModelCostInfo | null; error?: string } | null> {
|
||||
try {
|
||||
const costInfo = await fetchPricingForModel(provider, model);
|
||||
|
||||
if (costInfo) {
|
||||
// Cache the result in memory
|
||||
currentModelPricing = { provider, model, costInfo };
|
||||
|
||||
// Update localStorage cache with this new data
|
||||
const cachedData = loadPricingFromLocalStorage();
|
||||
if (cachedData) {
|
||||
// Check if this model already exists in cache
|
||||
const existingIndex = cachedData.pricing.findIndex(
|
||||
(p) => p.provider.toLowerCase() === provider.toLowerCase() && p.model === model
|
||||
);
|
||||
|
||||
const newPricing = {
|
||||
provider,
|
||||
model,
|
||||
input_token_cost: costInfo.input_token_cost,
|
||||
output_token_cost: costInfo.output_token_cost,
|
||||
currency: costInfo.currency,
|
||||
};
|
||||
|
||||
if (existingIndex >= 0) {
|
||||
// Update existing
|
||||
cachedData.pricing[existingIndex] = newPricing;
|
||||
} else {
|
||||
// Add new
|
||||
cachedData.pricing.push(newPricing);
|
||||
}
|
||||
|
||||
// Save updated cache
|
||||
savePricingToLocalStorage(cachedData);
|
||||
}
|
||||
|
||||
return { costInfo };
|
||||
} else {
|
||||
// Cache the null result in memory
|
||||
currentModelPricing = { provider, model, costInfo: null };
|
||||
|
||||
// Check if the API call succeeded but model wasn't found
|
||||
// We can determine this by checking if we got a response but no matching model
|
||||
return { costInfo: null, error: 'model_not_found' };
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error in fetchAndCachePricing:', error);
|
||||
// This is a real API/network error
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh pricing data from backend - only refresh recently used models
|
||||
*/
|
||||
export async function refreshPricing(): Promise<boolean> {
|
||||
try {
|
||||
const apiUrl = getApiUrl('/config/pricing');
|
||||
const secretKey = getSecretKey();
|
||||
|
||||
const headers: HeadersInit = { 'Content-Type': 'application/json' };
|
||||
if (secretKey) {
|
||||
headers['X-Secret-Key'] = secretKey;
|
||||
}
|
||||
|
||||
// Get recently used models to refresh
|
||||
const recentModels = getRecentlyUsedModels();
|
||||
|
||||
// Add some common models as well
|
||||
const commonModels = [
|
||||
{ provider: 'openai', model: 'gpt-4o' },
|
||||
{ provider: 'openai', model: 'gpt-4o-mini' },
|
||||
{ provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' },
|
||||
{ provider: 'google', model: 'gemini-1.5-pro' },
|
||||
];
|
||||
|
||||
// Combine and deduplicate
|
||||
const modelsToRefresh = new Map<string, { provider: string; model: string }>();
|
||||
|
||||
commonModels.forEach((m) => {
|
||||
modelsToRefresh.set(`${m.provider}/${m.model}`, m);
|
||||
});
|
||||
|
||||
recentModels.forEach((m) => {
|
||||
modelsToRefresh.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model });
|
||||
});
|
||||
|
||||
console.log(`Refreshing pricing for ${modelsToRefresh.size} models...`);
|
||||
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
configured_only: false,
|
||||
models: Array.from(modelsToRefresh.values()), // Send specific models if API supports it
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
|
||||
if (data.pricing && data.pricing.length > 0) {
|
||||
// Filter to only the models we requested (in case API returns all)
|
||||
const filteredPricing = data.pricing.filter((p: PricingItem) =>
|
||||
modelsToRefresh.has(`${p.provider}/${p.model}`)
|
||||
);
|
||||
|
||||
// Save fresh data to localStorage (merge with existing)
|
||||
const cacheData: PricingCacheData = {
|
||||
pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50),
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
savePricingToLocalStorage(cacheData, true); // Merge with existing
|
||||
}
|
||||
|
||||
// Clear current memory cache to force re-fetch
|
||||
currentModelPricing = null;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
} catch (error) {
|
||||
console.error('Error refreshing pricing data:', error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up old/unused models from the cache
|
||||
*/
|
||||
export function cleanupPricingCache(): void {
|
||||
try {
|
||||
const recentModels = getRecentlyUsedModels();
|
||||
const cachedData = localStorage.getItem(PRICING_CACHE_KEY);
|
||||
|
||||
if (!cachedData) return;
|
||||
|
||||
const fullCache = JSON.parse(cachedData) as PricingCacheData;
|
||||
const recentModelKeys = new Set(recentModels.map((m) => `${m.provider}/${m.model}`));
|
||||
|
||||
// Keep only recently used models and common models
|
||||
const commonModelKeys = new Set([
|
||||
'openai/gpt-4o',
|
||||
'openai/gpt-4o-mini',
|
||||
'openai/gpt-4-turbo',
|
||||
'anthropic/claude-3-5-sonnet',
|
||||
'anthropic/claude-3-5-sonnet-20241022',
|
||||
'google/gemini-1.5-pro',
|
||||
'google/gemini-1.5-flash',
|
||||
]);
|
||||
|
||||
const filteredPricing = fullCache.pricing.filter((p) => {
|
||||
const key = `${p.provider}/${p.model}`;
|
||||
return recentModelKeys.has(key) || commonModelKeys.has(key);
|
||||
});
|
||||
|
||||
if (filteredPricing.length < fullCache.pricing.length) {
|
||||
console.log(
|
||||
`Cleaned up pricing cache: reduced from ${fullCache.pricing.length} to ${filteredPricing.length} models`
|
||||
);
|
||||
|
||||
const cleanedCache: PricingCacheData = {
|
||||
pricing: filteredPricing,
|
||||
timestamp: fullCache.timestamp,
|
||||
};
|
||||
|
||||
localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(cleanedCache));
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error cleaning up pricing cache:', error);
|
||||
}
|
||||
}
|
||||
@@ -44,6 +44,10 @@ export default {
|
||||
'0%': { transform: 'rotate(0deg)' },
|
||||
'100%': { transform: 'rotate(360deg)' },
|
||||
},
|
||||
'spin-fast': {
|
||||
'0%': { transform: 'rotate(0deg)' },
|
||||
'100%': { transform: 'rotate(360deg)' },
|
||||
},
|
||||
indeterminate: {
|
||||
'0%': { left: '-40%', width: '40%' },
|
||||
'50%': { left: '20%', width: '60%' },
|
||||
@@ -54,6 +58,7 @@ export default {
|
||||
'shimmer-pulse': 'shimmer 4s ease-in-out infinite',
|
||||
'gradient-loader': 'loader 750ms ease-in-out infinite',
|
||||
indeterminate: 'indeterminate 1.5s infinite linear',
|
||||
'spin-fast': 'spin-fast 0.5s linear infinite',
|
||||
},
|
||||
colors: {
|
||||
bgApp: 'var(--background-app)',
|
||||
|
||||
Reference in New Issue
Block a user