feat: Add comprehensive cost tracking display for LLM usage (#2992)
Co-authored-by: jack <jack@deck.local>
Co-authored-by: Bradley Axen <baxen@squareup.com>
Cargo.lock (generated) | 1 +
@@ -3427,6 +3427,7 @@ dependencies = [
  "chrono",
  "criterion",
  "ctor",
+ "dirs 5.0.1",
  "dotenv",
  "etcetera",
  "fs2",
@@ -10,12 +10,23 @@ use goose::scheduler_factory::SchedulerFactory;
 use tower_http::cors::{Any, CorsLayer};
 use tracing::info;
 
+use goose::providers::pricing::initialize_pricing_cache;
+
 pub async fn run() -> Result<()> {
     // Initialize logging
     crate::logging::setup_logging(Some("goosed"))?;
 
     let settings = configuration::Settings::new()?;
 
+    // Initialize pricing cache on startup
+    tracing::info!("Initializing pricing cache...");
+    if let Err(e) = initialize_pricing_cache().await {
+        tracing::warn!(
+            "Failed to initialize pricing cache: {}. Pricing data may not be available.",
+            e
+        );
+    }
+
     let secret_key =
         std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());
@@ -13,6 +13,7 @@ use goose::config::{extensions::name_to_key, PermissionManager};
 use goose::config::{ExtensionConfigManager, ExtensionEntry};
 use goose::model::ModelConfig;
 use goose::providers::base::ProviderMetadata;
+use goose::providers::pricing::{get_all_pricing, get_model_pricing, refresh_pricing};
 use goose::providers::providers as get_providers;
 use goose::{agents::ExtensionConfig, config::permission::PermissionLevel};
 use http::{HeaderMap, StatusCode};
@@ -314,6 +315,128 @@ pub async fn providers(
     Ok(Json(providers_response))
 }
 
+#[derive(Serialize, ToSchema)]
+pub struct PricingData {
+    pub provider: String,
+    pub model: String,
+    pub input_token_cost: f64,
+    pub output_token_cost: f64,
+    pub currency: String,
+    pub context_length: Option<u32>,
+}
+
+#[derive(Serialize, ToSchema)]
+pub struct PricingResponse {
+    pub pricing: Vec<PricingData>,
+    pub source: String,
+}
+
+#[derive(Deserialize, ToSchema)]
+pub struct PricingQuery {
+    /// If true, only return pricing for configured providers. If false, return all.
+    pub configured_only: Option<bool>,
+}
+
+#[utoipa::path(
+    post,
+    path = "/config/pricing",
+    request_body = PricingQuery,
+    responses(
+        (status = 200, description = "Model pricing data retrieved successfully", body = PricingResponse)
+    )
+)]
+pub async fn get_pricing(
+    State(state): State<Arc<AppState>>,
+    headers: HeaderMap,
+    Json(query): Json<PricingQuery>,
+) -> Result<Json<PricingResponse>, StatusCode> {
+    verify_secret_key(&headers, &state)?;
+
+    let configured_only = query.configured_only.unwrap_or(true);
+
+    // If refresh requested (configured_only = false), refresh the cache
+    if !configured_only {
+        if let Err(e) = refresh_pricing().await {
+            tracing::error!("Failed to refresh pricing data: {}", e);
+        }
+    }
+
+    let mut pricing_data = Vec::new();
+
+    if !configured_only {
+        // Get ALL pricing data from the cache
+        let all_pricing = get_all_pricing().await;
+
+        for (provider, models) in all_pricing {
+            for (model, pricing) in models {
+                pricing_data.push(PricingData {
+                    provider: provider.clone(),
+                    model: model.clone(),
+                    input_token_cost: pricing.input_cost,
+                    output_token_cost: pricing.output_cost,
+                    currency: "$".to_string(),
+                    context_length: pricing.context_length,
+                });
+            }
+        }
+    } else {
+        // Get only configured providers' pricing
+        let providers_metadata = get_providers();
+
+        for metadata in providers_metadata {
+            // Skip unconfigured providers if filtering
+            if !check_provider_configured(&metadata) {
+                continue;
+            }
+
+            for model_info in &metadata.known_models {
+                // Try to get pricing from cache
+                if let Some(pricing) = get_model_pricing(&metadata.name, &model_info.name).await {
+                    pricing_data.push(PricingData {
+                        provider: metadata.name.clone(),
+                        model: model_info.name.clone(),
+                        input_token_cost: pricing.input_cost,
+                        output_token_cost: pricing.output_cost,
+                        currency: "$".to_string(),
+                        context_length: pricing.context_length,
+                    });
+                }
+                // Check if the model has embedded pricing data
+                else if let (Some(input_cost), Some(output_cost)) =
+                    (model_info.input_token_cost, model_info.output_token_cost)
+                {
+                    pricing_data.push(PricingData {
+                        provider: metadata.name.clone(),
+                        model: model_info.name.clone(),
+                        input_token_cost: input_cost,
+                        output_token_cost: output_cost,
+                        currency: model_info
+                            .currency
+                            .clone()
+                            .unwrap_or_else(|| "$".to_string()),
+                        context_length: Some(model_info.context_limit as u32),
+                    });
+                }
+            }
+        }
+    }
+
+    tracing::info!(
+        "Returning pricing for {} models{}",
+        pricing_data.len(),
+        if configured_only {
+            " (configured providers only)"
+        } else {
+            " (all cached models)"
+        }
+    );
+
+    Ok(Json(PricingResponse {
+        pricing: pricing_data,
+        source: "openrouter".to_string(),
+    }))
+}
+
 #[utoipa::path(
     post,
     path = "/config/init",
@@ -471,6 +594,7 @@ pub fn routes(state: Arc<AppState>) -> Router {
         .route("/config/extensions", post(add_extension))
         .route("/config/extensions/{name}", delete(remove_extension))
         .route("/config/providers", get(providers))
+        .route("/config/pricing", post(get_pricing))
         .route("/config/init", post(init_config))
         .route("/config/backup", post(backup_config))
         .route("/config/permissions", post(upsert_permissions))
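For orientation, here is a minimal sketch of how a client might exercise the newly registered POST /config/pricing route. The path, the configured_only request field, and the pricing/source response fields come from the hunks above; the local address, port, and the X-Secret-Key header name are assumptions made only for this illustration.

// Hypothetical client call against a locally running goosed (illustration only).
use serde_json::{json, Value};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let client = reqwest::Client::new();
    let resp: Value = client
        .post("http://127.0.0.1:3000/config/pricing") // assumed host/port
        .header("X-Secret-Key", "test") // assumed header name checked by verify_secret_key
        .json(&json!({ "configured_only": true }))
        .send()
        .await?
        .json()
        .await?;

    // PricingResponse shape: { "pricing": [PricingData, ...], "source": "openrouter" }
    if let Some(entries) = resp["pricing"].as_array() {
        for e in entries {
            println!(
                "{}/{} -> input {} / output {} per token",
                e["provider"], e["model"], e["input_token_cost"], e["output_token_cost"]
            );
        }
    }
    Ok(())
}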
@@ -17,6 +17,7 @@ mcp-core = { path = "../mcp-core" }
 anyhow = "1.0"
 thiserror = "1.0"
 futures = "0.3"
+dirs = "5.0"
 reqwest = { version = "0.12.9", features = [
     "rustls-tls-native-roots",
     "json",
@@ -5,7 +5,7 @@ use reqwest::{Client, StatusCode};
 use serde_json::Value;
 use std::time::Duration;
 
-use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
+use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
 use super::errors::ProviderError;
 use super::formats::anthropic::{create_request, get_usage, response_to_message};
 use super::utils::{emit_debug_trace, get_model};
@@ -122,12 +122,18 @@ impl AnthropicProvider {
 #[async_trait]
 impl Provider for AnthropicProvider {
     fn metadata() -> ProviderMetadata {
-        ProviderMetadata::new(
+        ProviderMetadata::with_models(
             "anthropic",
             "Anthropic",
             "Claude and other models from Anthropic",
             ANTHROPIC_DEFAULT_MODEL,
-            ANTHROPIC_KNOWN_MODELS.to_vec(),
+            vec![
+                ModelInfo::with_cost("claude-3-5-sonnet-20241022", 200000, 0.000003, 0.000015),
+                ModelInfo::with_cost("claude-3-5-haiku-20241022", 200000, 0.000001, 0.000005),
+                ModelInfo::with_cost("claude-3-opus-20240229", 200000, 0.000015, 0.000075),
+                ModelInfo::with_cost("claude-3-sonnet-20240229", 200000, 0.000003, 0.000015),
+                ModelInfo::with_cost("claude-3-haiku-20240307", 200000, 0.00000025, 0.00000125),
+            ],
             ANTHROPIC_DOC_URL,
             vec![
                 ConfigKey::new("ANTHROPIC_API_KEY", true, true, None),
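Because the Anthropic metadata now uses ProviderMetadata::with_models, each entry in known_models carries optional per-token costs. A minimal sketch of reading them back (assuming the Provider trait and the provider type are already in scope; not part of the diff):

// Illustration: inspect the cost data now embedded in provider metadata.
let meta = AnthropicProvider::metadata();
for model in &meta.known_models {
    println!(
        "{} (ctx {}): input={:?} output={:?} {}",
        model.name,
        model.context_limit,
        model.input_token_cost,
        model.output_token_cost,
        model.currency.as_deref().unwrap_or("$"),
    );
}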
@@ -32,6 +32,41 @@ pub struct ModelInfo {
     pub name: String,
     /// The maximum context length this model supports
     pub context_limit: usize,
+    /// Cost per token for input (optional)
+    pub input_token_cost: Option<f64>,
+    /// Cost per token for output (optional)
+    pub output_token_cost: Option<f64>,
+    /// Currency for the costs (default: "$")
+    pub currency: Option<String>,
+}
+
+impl ModelInfo {
+    /// Create a new ModelInfo with just name and context limit
+    pub fn new(name: impl Into<String>, context_limit: usize) -> Self {
+        Self {
+            name: name.into(),
+            context_limit,
+            input_token_cost: None,
+            output_token_cost: None,
+            currency: None,
+        }
+    }
+
+    /// Create a new ModelInfo with cost information (per token)
+    pub fn with_cost(
+        name: impl Into<String>,
+        context_limit: usize,
+        input_cost: f64,
+        output_cost: f64,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            context_limit,
+            input_token_cost: Some(input_cost),
+            output_token_cost: Some(output_cost),
+            currency: Some("$".to_string()),
+        }
+    }
 }
 
 /// Metadata about a provider's configuration requirements and capabilities
@@ -74,6 +109,9 @@ impl ProviderMetadata {
                 .map(|&name| ModelInfo {
                     name: name.to_string(),
                     context_limit: ModelConfig::new(name.to_string()).context_limit(),
+                    input_token_cost: None,
+                    output_token_cost: None,
+                    currency: None,
                 })
                 .collect(),
             model_doc_link: model_doc_link.to_string(),
@@ -81,6 +119,27 @@ impl ProviderMetadata {
         }
     }
+
+    /// Create a new ProviderMetadata with ModelInfo objects that include cost data
+    pub fn with_models(
+        name: &str,
+        display_name: &str,
+        description: &str,
+        default_model: &str,
+        models: Vec<ModelInfo>,
+        model_doc_link: &str,
+        config_keys: Vec<ConfigKey>,
+    ) -> Self {
+        Self {
+            name: name.to_string(),
+            display_name: display_name.to_string(),
+            description: description.to_string(),
+            default_model: default_model.to_string(),
+            known_models: models,
+            model_doc_link: model_doc_link.to_string(),
+            config_keys,
+        }
+    }
 
     pub fn empty() -> Self {
         Self {
             name: "".to_string(),
@@ -313,6 +372,9 @@ mod tests {
         let info = ModelInfo {
             name: "test-model".to_string(),
             context_limit: 1000,
+            input_token_cost: None,
+            output_token_cost: None,
+            currency: None,
         };
         assert_eq!(info.context_limit, 1000);
 
@@ -320,6 +382,9 @@ mod tests {
         let info2 = ModelInfo {
             name: "test-model".to_string(),
             context_limit: 1000,
+            input_token_cost: None,
+            output_token_cost: None,
+            currency: None,
         };
         assert_eq!(info, info2);
 
@@ -327,7 +392,20 @@ mod tests {
         let info3 = ModelInfo {
             name: "test-model".to_string(),
             context_limit: 2000,
+            input_token_cost: None,
+            output_token_cost: None,
+            currency: None,
         };
         assert_ne!(info, info3);
     }
+
+    #[test]
+    fn test_model_info_with_cost() {
+        let info = ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001);
+        assert_eq!(info.name, "gpt-4o");
+        assert_eq!(info.context_limit, 128000);
+        assert_eq!(info.input_token_cost, Some(0.0000025));
+        assert_eq!(info.output_token_cost, Some(0.00001));
+        assert_eq!(info.currency, Some("$".to_string()));
+    }
 }
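Since both costs are stored per token (not per 1K or per 1M tokens), turning token counts into dollars is a single multiply per direction. A small worked sketch using the constructor and fields added above (the token counts are arbitrary example values):

// Illustration: per-token costs from ModelInfo turned into a dollar estimate.
let gpt4o = ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001);

let input_tokens = 12_000.0_f64;
let output_tokens = 1_500.0_f64;

let cost = input_tokens * gpt4o.input_token_cost.unwrap_or(0.0)
    + output_tokens * gpt4o.output_token_cost.unwrap_or(0.0);

// 12_000 * 0.0000025 + 1_500 * 0.00001 = 0.03 + 0.015 = $0.045
assert!((cost - 0.045).abs() < 1e-9);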
@@ -18,6 +18,7 @@ pub mod oauth;
 pub mod ollama;
 pub mod openai;
 pub mod openrouter;
+pub mod pricing;
 pub mod sagemaker_tgi;
 pub mod snowflake;
 pub mod toolshim;
@@ -5,7 +5,7 @@ use serde_json::Value;
 use std::collections::HashMap;
 use std::time::Duration;
 
-use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
+use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage};
 use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse};
 use super::errors::ProviderError;
 use super::formats::openai::{create_request, get_usage, response_to_message};
@@ -126,12 +126,20 @@ impl OpenAiProvider {
 #[async_trait]
 impl Provider for OpenAiProvider {
     fn metadata() -> ProviderMetadata {
-        ProviderMetadata::new(
+        ProviderMetadata::with_models(
             "openai",
             "OpenAI",
             "GPT-4 and other OpenAI models, including OpenAI compatible ones",
             OPEN_AI_DEFAULT_MODEL,
-            OPEN_AI_KNOWN_MODELS.to_vec(),
+            vec![
+                ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001),
+                ModelInfo::with_cost("gpt-4o-mini", 128000, 0.00000015, 0.0000006),
+                ModelInfo::with_cost("gpt-4-turbo", 128000, 0.00001, 0.00003),
+                ModelInfo::with_cost("gpt-3.5-turbo", 16385, 0.0000005, 0.0000015),
+                ModelInfo::with_cost("o1", 200000, 0.000015, 0.00006),
+                ModelInfo::with_cost("o3", 200000, 0.000015, 0.00006), // Using o1 pricing as placeholder
+                ModelInfo::with_cost("o4-mini", 128000, 0.000003, 0.000012), // Using o1-mini pricing as placeholder
+            ],
             OPEN_AI_DOC_URL,
             vec![
                 ConfigKey::new("OPENAI_API_KEY", true, true, None),
crates/goose/src/providers/pricing.rs (new file, 387 lines)
@@ -0,0 +1,387 @@
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;

/// Disk cache configuration
const CACHE_FILE_NAME: &str = "pricing_cache.json";
const CACHE_TTL_DAYS: u64 = 7; // Cache for 7 days

/// Get the cache directory path
fn get_cache_dir() -> Result<PathBuf> {
    let cache_dir = if let Ok(goose_dir) = std::env::var("GOOSE_CACHE_DIR") {
        PathBuf::from(goose_dir)
    } else {
        dirs::cache_dir()
            .ok_or_else(|| anyhow::anyhow!("Could not determine cache directory"))?
            .join("goose")
    };
    Ok(cache_dir)
}

/// Cached pricing data structure for disk storage
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedPricingData {
    /// Nested HashMap: provider -> model -> pricing info
    pub pricing: HashMap<String, HashMap<String, PricingInfo>>,
    /// Unix timestamp when data was fetched
    pub fetched_at: u64,
}

/// Simplified pricing info for efficient storage
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingInfo {
    pub input_cost: f64,  // Cost per token
    pub output_cost: f64, // Cost per token
    pub context_length: Option<u32>,
}

/// Cache for OpenRouter pricing data with disk persistence
pub struct PricingCache {
    /// In-memory cache
    memory_cache: Arc<RwLock<Option<CachedPricingData>>>,
}

impl PricingCache {
    pub fn new() -> Self {
        Self {
            memory_cache: Arc::new(RwLock::new(None)),
        }
    }

    /// Load pricing from disk cache
    async fn load_from_disk(&self) -> Result<Option<CachedPricingData>> {
        let cache_path = get_cache_dir()?.join(CACHE_FILE_NAME);

        if !cache_path.exists() {
            return Ok(None);
        }

        match tokio::fs::read(&cache_path).await {
            Ok(data) => {
                match serde_json::from_slice::<CachedPricingData>(&data) {
                    Ok(cached) => {
                        // Check if cache is still valid
                        let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
                        let age_days = (now - cached.fetched_at) / (24 * 60 * 60);

                        if age_days < CACHE_TTL_DAYS {
                            tracing::info!(
                                "Loaded pricing data from disk cache (age: {} days)",
                                age_days
                            );
                            Ok(Some(cached))
                        } else {
                            tracing::info!("Disk cache expired (age: {} days)", age_days);
                            Ok(None)
                        }
                    }
                    Err(e) => {
                        tracing::warn!("Failed to parse pricing cache: {}", e);
                        Ok(None)
                    }
                }
            }
            Err(e) => {
                tracing::warn!("Failed to read pricing cache: {}", e);
                Ok(None)
            }
        }
    }

    /// Save pricing data to disk
    async fn save_to_disk(&self, data: &CachedPricingData) -> Result<()> {
        let cache_dir = get_cache_dir()?;
        tokio::fs::create_dir_all(&cache_dir).await?;

        let cache_path = cache_dir.join(CACHE_FILE_NAME);
        let json_data = serde_json::to_vec_pretty(data)?;
        tokio::fs::write(&cache_path, json_data).await?;

        tracing::info!("Saved pricing data to disk cache");
        Ok(())
    }

    /// Get pricing for a specific model
    pub async fn get_model_pricing(&self, provider: &str, model: &str) -> Option<PricingInfo> {
        // Try memory cache first
        {
            let cache = self.memory_cache.read().await;
            if let Some(cached) = &*cache {
                return cached
                    .pricing
                    .get(&provider.to_lowercase())
                    .and_then(|models| models.get(model))
                    .cloned();
            }
        }

        // Try loading from disk
        if let Ok(Some(disk_cache)) = self.load_from_disk().await {
            // Update memory cache
            {
                let mut cache = self.memory_cache.write().await;
                *cache = Some(disk_cache.clone());
            }

            return disk_cache
                .pricing
                .get(&provider.to_lowercase())
                .and_then(|models| models.get(model))
                .cloned();
        }

        None
    }

    /// Force refresh pricing data from OpenRouter
    pub async fn refresh(&self) -> Result<()> {
        let pricing = fetch_openrouter_pricing_internal().await?;

        // Convert to our efficient structure
        let mut structured_pricing: HashMap<String, HashMap<String, PricingInfo>> = HashMap::new();

        for (model_id, model) in pricing {
            if let Some((provider, model_name)) = parse_model_id(&model_id) {
                if let (Some(input_cost), Some(output_cost)) = (
                    convert_pricing(&model.pricing.prompt),
                    convert_pricing(&model.pricing.completion),
                ) {
                    let provider_lower = provider.to_lowercase();
                    let provider_models = structured_pricing.entry(provider_lower).or_default();

                    provider_models.insert(
                        model_name,
                        PricingInfo {
                            input_cost,
                            output_cost,
                            context_length: model.context_length,
                        },
                    );
                }
            }
        }

        let cached_data = CachedPricingData {
            pricing: structured_pricing,
            fetched_at: SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(),
        };

        // Log how many models we fetched
        let total_models: usize = cached_data
            .pricing
            .values()
            .map(|models| models.len())
            .sum();
        tracing::info!(
            "Fetched pricing for {} providers with {} total models from OpenRouter",
            cached_data.pricing.len(),
            total_models
        );

        // Save to disk
        self.save_to_disk(&cached_data).await?;

        // Update memory cache
        {
            let mut cache = self.memory_cache.write().await;
            *cache = Some(cached_data);
        }

        Ok(())
    }

    /// Initialize cache (load from disk or fetch if needed)
    pub async fn initialize(&self) -> Result<()> {
        // Try loading from disk first
        if let Ok(Some(cached)) = self.load_from_disk().await {
            // Log how many models we have cached
            let total_models: usize = cached.pricing.values().map(|models| models.len()).sum();
            tracing::info!(
                "Loaded {} providers with {} total models from disk cache",
                cached.pricing.len(),
                total_models
            );

            // Update memory cache
            {
                let mut cache = self.memory_cache.write().await;
                *cache = Some(cached);
            }

            return Ok(());
        }

        // If no disk cache, fetch from OpenRouter
        tracing::info!("No valid disk cache found, fetching from OpenRouter");
        self.refresh().await
    }
}

impl Default for PricingCache {
    fn default() -> Self {
        Self::new()
    }
}

// Global cache instance
lazy_static::lazy_static! {
    static ref PRICING_CACHE: PricingCache = PricingCache::new();
    static ref HTTP_CLIENT: Client = Client::builder()
        .timeout(Duration::from_secs(30))
        .pool_idle_timeout(Duration::from_secs(90))
        .pool_max_idle_per_host(10)
        .build()
        .unwrap();
}

/// OpenRouter model pricing information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenRouterModel {
    pub id: String,
    pub name: String,
    pub pricing: OpenRouterPricing,
    pub context_length: Option<u32>,
    pub architecture: Option<Architecture>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenRouterPricing {
    pub prompt: String,     // Cost per token for input (in USD)
    pub completion: String, // Cost per token for output (in USD)
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Architecture {
    pub modality: String,
    pub tokenizer: String,
    pub instruct_type: Option<String>,
}

/// Response from OpenRouter models endpoint
#[derive(Debug, Deserialize)]
pub struct OpenRouterModelsResponse {
    pub data: Vec<OpenRouterModel>,
}

/// Internal function to fetch pricing data
async fn fetch_openrouter_pricing_internal() -> Result<HashMap<String, OpenRouterModel>> {
    let response = HTTP_CLIENT
        .get("https://openrouter.ai/api/v1/models")
        .send()
        .await?;

    if !response.status().is_success() {
        anyhow::bail!(
            "Failed to fetch OpenRouter models: HTTP {}",
            response.status()
        );
    }

    let models_response: OpenRouterModelsResponse = response.json().await?;

    // Create a map for easy lookup
    let mut pricing_map = HashMap::new();
    for model in models_response.data {
        pricing_map.insert(model.id.clone(), model);
    }

    Ok(pricing_map)
}

/// Initialize pricing cache on startup
pub async fn initialize_pricing_cache() -> Result<()> {
    PRICING_CACHE.initialize().await
}

/// Get pricing for a specific model
pub async fn get_model_pricing(provider: &str, model: &str) -> Option<PricingInfo> {
    PRICING_CACHE.get_model_pricing(provider, model).await
}

/// Force refresh pricing data
pub async fn refresh_pricing() -> Result<()> {
    PRICING_CACHE.refresh().await
}

/// Get all cached pricing data
pub async fn get_all_pricing() -> HashMap<String, HashMap<String, PricingInfo>> {
    let cache = PRICING_CACHE.memory_cache.read().await;
    if let Some(cached) = &*cache {
        cached.pricing.clone()
    } else {
        // Try loading from disk
        if let Ok(Some(disk_cache)) = PRICING_CACHE.load_from_disk().await {
            // Update memory cache
            drop(cache);
            let mut write_cache = PRICING_CACHE.memory_cache.write().await;
            *write_cache = Some(disk_cache.clone());
            disk_cache.pricing
        } else {
            HashMap::new()
        }
    }
}

/// Convert OpenRouter model ID to provider/model format
/// e.g., "anthropic/claude-3.5-sonnet" -> ("anthropic", "claude-3.5-sonnet")
pub fn parse_model_id(model_id: &str) -> Option<(String, String)> {
    let parts: Vec<&str> = model_id.splitn(2, '/').collect();
    if parts.len() == 2 {
        // Normalize provider names to match our internal naming
        let provider = match parts[0] {
            "openai" => "openai",
            "anthropic" => "anthropic",
            "google" => "google",
            "meta-llama" => "ollama", // Meta models often run via Ollama
            "mistralai" => "mistral",
            "cohere" => "cohere",
            "perplexity" => "perplexity",
            "deepseek" => "deepseek",
            "groq" => "groq",
            "nvidia" => "nvidia",
            "microsoft" => "azure",
            "replicate" => "replicate",
            "huggingface" => "huggingface",
            _ => parts[0],
        };
        Some((provider.to_string(), parts[1].to_string()))
    } else {
        None
    }
}

/// Convert OpenRouter pricing to cost per token (already in that format)
pub fn convert_pricing(price_str: &str) -> Option<f64> {
    // OpenRouter prices are already in USD per token
    price_str.parse::<f64>().ok()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_model_id() {
        assert_eq!(
            parse_model_id("anthropic/claude-3.5-sonnet"),
            Some(("anthropic".to_string(), "claude-3.5-sonnet".to_string()))
        );
        assert_eq!(
            parse_model_id("openai/gpt-4"),
            Some(("openai".to_string(), "gpt-4".to_string()))
        );
        assert_eq!(parse_model_id("invalid-format"), None);
    }

    #[test]
    fn test_convert_pricing() {
        assert_eq!(convert_pricing("0.000003"), Some(0.000003));
        assert_eq!(convert_pricing("0.015"), Some(0.015));
        assert_eq!(convert_pricing("invalid"), None);
    }
}
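Taken together, the module exposes four async entry points backed by one lazily initialized global cache. A usage sketch follows; the surrounding tokio runtime and binary setup are assumptions for illustration, while the function names, module path, cache file, TTL, and OpenRouter URL come from the file above.

// Illustration: typical call sequence for the pricing cache.
use goose::providers::pricing::{get_model_pricing, initialize_pricing_cache, refresh_pricing};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Warm the cache: loads pricing_cache.json if it is younger than 7 days,
    // otherwise fetches https://openrouter.ai/api/v1/models.
    initialize_pricing_cache().await?;

    if let Some(p) = get_model_pricing("anthropic", "claude-3.5-sonnet").await {
        let cost = 10_000.0 * p.input_cost + 2_000.0 * p.output_cost;
        println!("estimated cost for 10k in / 2k out tokens: ${cost:.6}");
    }

    // Force a re-fetch, e.g. after OpenRouter updates its catalog.
    refresh_pricing().await?;
    Ok(())
}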
@@ -1700,9 +1700,26 @@
         "description": "The maximum context length this model supports",
         "minimum": 0
       },
+      "currency": {
+        "type": "string",
+        "description": "Currency for the costs (default: \"$\")",
+        "nullable": true
+      },
+      "input_token_cost": {
+        "type": "number",
+        "format": "double",
+        "description": "Cost per token for input (optional)",
+        "nullable": true
+      },
       "name": {
         "type": "string",
         "description": "The name of the model"
+      },
+      "output_token_cost": {
+        "type": "number",
+        "format": "double",
+        "description": "Cost per token for output (optional)",
+        "nullable": true
       }
     }
   },
@@ -3,6 +3,7 @@ import { IpcRendererEvent } from 'electron';
 import { openSharedSessionFromDeepLink, type SessionLinksViewOptions } from './sessionLinks';
 import { type SharedSessionDetails } from './sharedSessions';
 import { initializeSystem } from './utils/providerUtils';
+import { initializeCostDatabase } from './utils/costDatabase';
 import { ErrorUI } from './components/ErrorBoundary';
 import { ConfirmationModal } from './components/ui/ConfirmationModal';
 import { ToastContainer } from 'react-toastify';
@@ -158,6 +159,11 @@ export default function App() {
 
     const initializeApp = async () => {
       try {
+        // Initialize cost database early to pre-load pricing data
+        initializeCostDatabase().catch((error) => {
+          console.error('Failed to initialize cost database:', error);
+        });
+
         await initConfig();
         try {
           await readAllConfig({ throwOnError: true });
@@ -229,10 +229,22 @@ export type ModelInfo = {
    * The maximum context length this model supports
    */
   context_limit: number;
+  /**
+   * Currency for the costs (default: "$")
+   */
+  currency?: string | null;
+  /**
+   * Cost per token for input (optional)
+   */
+  input_token_cost?: number | null;
   /**
    * The name of the model
    */
   name: string;
+  /**
+   * Cost per token for output (optional)
+   */
+  output_token_cost?: number | null;
 };
 
 export type PermissionConfirmationRequest = {
@@ -29,9 +29,18 @@ interface ChatInputProps {
   droppedFiles?: string[];
   setView: (view: View) => void;
   numTokens?: number;
+  inputTokens?: number;
+  outputTokens?: number;
   hasMessages?: boolean;
   messages?: Message[];
   setMessages: (messages: Message[]) => void;
+  sessionCosts?: {
+    [key: string]: {
+      inputTokens: number;
+      outputTokens: number;
+      totalCost: number;
+    };
+  };
 }
 
 export default function ChatInput({
@@ -42,9 +51,12 @@ export default function ChatInput({
   initialValue = '',
   setView,
   numTokens,
+  inputTokens,
+  outputTokens,
   droppedFiles = [],
   messages = [],
   setMessages,
+  sessionCosts,
 }: ChatInputProps) {
   const [_value, setValue] = useState(initialValue);
   const [displayValue, setDisplayValue] = useState(initialValue); // For immediate visual feedback
@@ -557,9 +569,12 @@ export default function ChatInput({
           <BottomMenu
             setView={setView}
             numTokens={numTokens}
+            inputTokens={inputTokens}
+            outputTokens={outputTokens}
             messages={messages}
             isLoading={isLoading}
             setMessages={setMessages}
+            sessionCosts={sessionCosts}
           />
         </div>
       </div>
@@ -33,6 +33,8 @@ import {
 } from './context_management/ChatContextManager';
 import { ContextHandler } from './context_management/ContextHandler';
 import { LocalMessageStorage } from '../utils/localMessageStorage';
+import { useModelAndProvider } from './ModelAndProviderContext';
+import { getCostForModel } from '../utils/costDatabase';
 import {
   Message,
   createUserMessage,
@@ -106,11 +108,25 @@ function ChatContent({
   const [showGame, setShowGame] = useState(false);
   const [isGeneratingRecipe, setIsGeneratingRecipe] = useState(false);
   const [sessionTokenCount, setSessionTokenCount] = useState<number>(0);
+  const [sessionInputTokens, setSessionInputTokens] = useState<number>(0);
+  const [sessionOutputTokens, setSessionOutputTokens] = useState<number>(0);
+  const [localInputTokens, setLocalInputTokens] = useState<number>(0);
+  const [localOutputTokens, setLocalOutputTokens] = useState<number>(0);
   const [ancestorMessages, setAncestorMessages] = useState<Message[]>([]);
   const [droppedFiles, setDroppedFiles] = useState<string[]>([]);
+  const [sessionCosts, setSessionCosts] = useState<{
+    [key: string]: {
+      inputTokens: number;
+      outputTokens: number;
+      totalCost: number;
+    };
+  }>({});
   const [readyForAutoUserPrompt, setReadyForAutoUserPrompt] = useState(false);
 
   const scrollRef = useRef<ScrollAreaHandle>(null);
+  const { currentModel, currentProvider } = useModelAndProvider();
+  const prevModelRef = useRef<string | undefined>();
+  const prevProviderRef = useRef<string | undefined>();
 
   const {
     summaryContent,
@@ -160,6 +176,7 @@ function ChatContent({
     updateMessageStreamBody,
     notifications,
     currentModelInfo,
+    sessionMetadata,
   } = useMessageStream({
     api: getApiUrl('/reply'),
     initialMessages: chat.messages,
@@ -518,12 +535,40 @@ function ChatContent({
       .reverse();
   }, [filteredMessages]);
 
+  // Simple token estimation function (roughly 4 characters per token)
+  const estimateTokens = (text: string): number => {
+    return Math.ceil(text.length / 4);
+  };
+
+  // Calculate token counts from messages
+  useEffect(() => {
+    let inputTokens = 0;
+    let outputTokens = 0;
+
+    messages.forEach((message) => {
+      const textContent = getTextContent(message);
+      if (textContent) {
+        const tokens = estimateTokens(textContent);
+        if (message.role === 'user') {
+          inputTokens += tokens;
+        } else if (message.role === 'assistant') {
+          outputTokens += tokens;
+        }
+      }
+    });
+
+    setLocalInputTokens(inputTokens);
+    setLocalOutputTokens(outputTokens);
+  }, [messages]);
+
   // Fetch session metadata to get token count
   useEffect(() => {
     const fetchSessionTokens = async () => {
       try {
         const sessionDetails = await fetchSessionDetails(chat.id);
         setSessionTokenCount(sessionDetails.metadata.total_tokens || 0);
+        setSessionInputTokens(sessionDetails.metadata.accumulated_input_tokens || 0);
+        setSessionOutputTokens(sessionDetails.metadata.accumulated_output_tokens || 0);
       } catch (err) {
         console.error('Error fetching session token count:', err);
       }
@@ -533,6 +578,74 @@ function ChatContent({
     }
   }, [chat.id, messages]);
 
+  // Update token counts when sessionMetadata changes from the message stream
+  useEffect(() => {
+    console.log('Session metadata received:', sessionMetadata);
+    if (sessionMetadata) {
+      setSessionTokenCount(sessionMetadata.totalTokens || 0);
+      setSessionInputTokens(sessionMetadata.accumulatedInputTokens || 0);
+      setSessionOutputTokens(sessionMetadata.accumulatedOutputTokens || 0);
+    }
+  }, [sessionMetadata]);
+
+  // Handle model changes and accumulate costs
+  useEffect(() => {
+    if (
+      prevModelRef.current !== undefined &&
+      prevProviderRef.current !== undefined &&
+      (prevModelRef.current !== currentModel || prevProviderRef.current !== currentProvider)
+    ) {
+      // Model/provider has changed, save the costs for the previous model
+      const prevKey = `${prevProviderRef.current}/${prevModelRef.current}`;
+
+      // Get pricing info for the previous model
+      const prevCostInfo = getCostForModel(prevProviderRef.current, prevModelRef.current);
+
+      if (prevCostInfo) {
+        const prevInputCost =
+          (sessionInputTokens || localInputTokens) * (prevCostInfo.input_token_cost || 0);
+        const prevOutputCost =
+          (sessionOutputTokens || localOutputTokens) * (prevCostInfo.output_token_cost || 0);
+        const prevTotalCost = prevInputCost + prevOutputCost;
+
+        // Save the accumulated costs for this model
+        setSessionCosts((prev) => ({
+          ...prev,
+          [prevKey]: {
+            inputTokens: sessionInputTokens || localInputTokens,
+            outputTokens: sessionOutputTokens || localOutputTokens,
+            totalCost: prevTotalCost,
+          },
+        }));
+      }
+
+      // Reset token counters for the new model
+      setSessionTokenCount(0);
+      setSessionInputTokens(0);
+      setSessionOutputTokens(0);
+      setLocalInputTokens(0);
+      setLocalOutputTokens(0);
+
+      console.log(
+        'Model changed from',
+        `${prevProviderRef.current}/${prevModelRef.current}`,
+        'to',
+        `${currentProvider}/${currentModel}`,
+        '- saved costs and reset token counters'
+      );
+    }
+
+    prevModelRef.current = currentModel || undefined;
+    prevProviderRef.current = currentProvider || undefined;
+  }, [
+    currentModel,
+    currentProvider,
+    sessionInputTokens,
+    sessionOutputTokens,
+    localInputTokens,
+    localOutputTokens,
+  ]);
+
   const handleDrop = (e: React.DragEvent<HTMLDivElement>) => {
     e.preventDefault();
     const files = e.dataTransfer.files;
@@ -684,9 +797,12 @@ function ChatContent({
             setView={setView}
             hasMessages={hasMessages}
             numTokens={sessionTokenCount}
+            inputTokens={sessionInputTokens || localInputTokens}
+            outputTokens={sessionOutputTokens || localOutputTokens}
             droppedFiles={droppedFiles}
             messages={messages}
             setMessages={setMessages}
+            sessionCosts={sessionCosts}
           />
         </div>
       </Card>
@@ -9,6 +9,7 @@ import { useConfig } from '../ConfigContext';
 import { useModelAndProvider } from '../ModelAndProviderContext';
 import { Message } from '../../types/message';
 import { ManualSummarizeButton } from '../context_management/ManualSummaryButton';
+import { CostTracker } from './CostTracker';
 
 const TOKEN_LIMIT_DEFAULT = 128000; // fallback for custom models that the backend doesn't know about
 const TOKEN_WARNING_THRESHOLD = 0.8; // warning shows at 80% of the token limit
@@ -22,15 +23,27 @@ interface ModelLimit {
 export default function BottomMenu({
   setView,
   numTokens = 0,
+  inputTokens = 0,
+  outputTokens = 0,
   messages = [],
   isLoading = false,
   setMessages,
+  sessionCosts,
 }: {
   setView: (view: View, viewOptions?: ViewOptions) => void;
   numTokens?: number;
+  inputTokens?: number;
+  outputTokens?: number;
   messages?: Message[];
   isLoading?: boolean;
   setMessages: (messages: Message[]) => void;
+  sessionCosts?: {
+    [key: string]: {
+      inputTokens: number;
+      outputTokens: number;
+      totalCost: number;
+    };
+  };
 }) {
   const [isModelMenuOpen, setIsModelMenuOpen] = useState(false);
   const { alerts, addAlert, clearAlerts } = useAlerts();
@@ -202,29 +215,45 @@ export default function BottomMenu({
   }, [isModelMenuOpen]);
 
   return (
-    <div className="flex justify-between items-center transition-colors text-textSubtle relative text-xs align-middle">
-      <div className="flex items-center pl-2">
+    <div className="flex justify-between items-center transition-colors text-textSubtle relative text-xs h-6">
+      <div className="flex items-center h-full">
         {/* Tool and Token count */}
-        {<BottomMenuAlertPopover alerts={alerts} />}
+        <div className="flex items-center h-full pl-2">
+          {<BottomMenuAlertPopover alerts={alerts} />}
+        </div>
+
+        {/* Cost Tracker - no separator before it */}
+        <div className="flex items-center h-full ml-1">
+          <CostTracker inputTokens={inputTokens} outputTokens={outputTokens} sessionCosts={sessionCosts} />
+        </div>
+
+        {/* Separator between cost and model */}
+        <div className="w-[1px] h-4 bg-borderSubtle mx-1.5" />
 
         {/* Model Selector Dropdown */}
-        <ModelsBottomBar dropdownRef={dropdownRef} setView={setView} />
+        <div className="flex items-center h-full">
+          <ModelsBottomBar dropdownRef={dropdownRef} setView={setView} />
+        </div>
 
         {/* Separator */}
-        <div className="w-[1px] h-4 bg-borderSubtle mx-2" />
+        <div className="w-[1px] h-4 bg-borderSubtle mx-1.5" />
 
         {/* Goose Mode Selector Dropdown */}
-        <BottomMenuModeSelection setView={setView} />
+        <div className="flex items-center h-full">
+          <BottomMenuModeSelection setView={setView} />
+        </div>
 
-        {/* Summarize Context Button - ADD THIS */}
+        {/* Summarize Context Button */}
         {messages.length > 0 && (
           <>
-            <div className="w-[1px] h-4 bg-borderSubtle mx-2" />
-            <ManualSummarizeButton
-              messages={messages}
-              isLoading={isLoading}
-              setMessages={setMessages}
-            />
+            <div className="w-[1px] h-4 bg-borderSubtle mx-1.5" />
+            <div className="flex items-center h-full">
+              <ManualSummarizeButton
+                messages={messages}
+                isLoading={isLoading}
+                setMessages={setMessages}
+              />
+            </div>
           </>
         )}
       </div>
ui/desktop/src/components/bottom_menu/CostTracker.tsx (new file, 310 lines)
@@ -0,0 +1,310 @@
import { useState, useEffect } from 'react';
import { useModelAndProvider } from '../ModelAndProviderContext';
import { useConfig } from '../ConfigContext';
import {
  getCostForModel,
  initializeCostDatabase,
  updateAllModelCosts,
  fetchAndCachePricing,
} from '../../utils/costDatabase';

interface CostTrackerProps {
  inputTokens?: number;
  outputTokens?: number;
  sessionCosts?: {
    [key: string]: {
      inputTokens: number;
      outputTokens: number;
      totalCost: number;
    };
  };
}

export function CostTracker({ inputTokens = 0, outputTokens = 0, sessionCosts }: CostTrackerProps) {
  const { currentModel, currentProvider } = useModelAndProvider();
  const { getProviders } = useConfig();
  const [costInfo, setCostInfo] = useState<{
    input_token_cost?: number;
    output_token_cost?: number;
    currency?: string;
  } | null>(null);
  const [isLoading, setIsLoading] = useState(true);
  const [showPricing, setShowPricing] = useState(true);
  const [pricingFailed, setPricingFailed] = useState(false);
  const [modelNotFound, setModelNotFound] = useState(false);
  const [hasAttemptedFetch, setHasAttemptedFetch] = useState(false);
  const [initialLoadComplete, setInitialLoadComplete] = useState(false);

  // Check if pricing is enabled
  useEffect(() => {
    const checkPricingSetting = () => {
      const stored = localStorage.getItem('show_pricing');
      setShowPricing(stored !== 'false');
    };

    // Check on mount
    checkPricingSetting();

    // Listen for storage changes
    window.addEventListener('storage', checkPricingSetting);
    return () => window.removeEventListener('storage', checkPricingSetting);
  }, []);

  // Set initial load complete after a short delay
  useEffect(() => {
    const timer = setTimeout(() => {
      setInitialLoadComplete(true);
    }, 3000); // Give 3 seconds for initial load

    return () => window.clearTimeout(timer);
  }, []);

  // Debug log props removed

  // Initialize cost database on mount
  useEffect(() => {
    initializeCostDatabase();

    // Update costs for all models in background
    updateAllModelCosts().catch((error) => {
      console.error('Failed to update model costs:', error);
    });
  }, [getProviders]);

  useEffect(() => {
    const loadCostInfo = async () => {
      if (!currentModel || !currentProvider) {
        setIsLoading(false);
        return;
      }

      console.log(`CostTracker: Loading cost info for ${currentProvider}/${currentModel}`);

      try {
        // First check sync cache
        let costData = getCostForModel(currentProvider, currentModel);

        if (costData) {
          // We have cached data
          console.log(
            `CostTracker: Found cached data for ${currentProvider}/${currentModel}:`,
            costData
          );
          setCostInfo(costData);
          setPricingFailed(false);
          setModelNotFound(false);
          setIsLoading(false);
          setHasAttemptedFetch(true);
        } else {
          // Need to fetch from backend
          console.log(
            `CostTracker: No cached data, fetching from backend for ${currentProvider}/${currentModel}`
          );
          setIsLoading(true);
          const result = await fetchAndCachePricing(currentProvider, currentModel);
          setHasAttemptedFetch(true);

          if (result && result.costInfo) {
            console.log(
              `CostTracker: Fetched data for ${currentProvider}/${currentModel}:`,
              result.costInfo
            );
            setCostInfo(result.costInfo);
            setPricingFailed(false);
            setModelNotFound(false);
          } else if (result && result.error === 'model_not_found') {
            console.log(
              `CostTracker: Model not found in pricing data for ${currentProvider}/${currentModel}`
            );
            // Model not found in pricing database, but API call succeeded
            setModelNotFound(true);
            setPricingFailed(false);
          } else {
            console.log(`CostTracker: API failed for ${currentProvider}/${currentModel}`);
            // API call failed or other error
            const freeProviders = ['ollama', 'local', 'localhost'];
            if (!freeProviders.includes(currentProvider.toLowerCase())) {
              setPricingFailed(true);
              setModelNotFound(false);
            }
          }
          setIsLoading(false);
        }
      } catch (error) {
        console.error('Error loading cost info:', error);
        setHasAttemptedFetch(true);
        // Only set pricing failed if we're not dealing with a known free provider
        const freeProviders = ['ollama', 'local', 'localhost'];
        if (!freeProviders.includes(currentProvider.toLowerCase())) {
          setPricingFailed(true);
          setModelNotFound(false);
        }
        setIsLoading(false);
      }
    };

    loadCostInfo();
  }, [currentModel, currentProvider]);

  // Return null early if pricing is disabled
  if (!showPricing) {
    return null;
  }

  const calculateCost = (): number => {
    // If we have session costs, calculate the total across all models
    if (sessionCosts) {
      let totalCost = 0;

      // Add up all historical costs from different models
      Object.values(sessionCosts).forEach((modelCost) => {
        totalCost += modelCost.totalCost;
      });

      // Add current model cost if we have pricing info
      if (
        costInfo &&
        (costInfo.input_token_cost !== undefined || costInfo.output_token_cost !== undefined)
      ) {
        const currentInputCost = inputTokens * (costInfo.input_token_cost || 0);
        const currentOutputCost = outputTokens * (costInfo.output_token_cost || 0);
        totalCost += currentInputCost + currentOutputCost;
      }

      return totalCost;
    }

    // Fallback to simple calculation for current model only
    if (
      !costInfo ||
      (costInfo.input_token_cost === undefined && costInfo.output_token_cost === undefined)
    ) {
      return 0;
    }

    const inputCost = inputTokens * (costInfo.input_token_cost || 0);
    const outputCost = outputTokens * (costInfo.output_token_cost || 0);
    const total = inputCost + outputCost;

    return total;
  };

  const formatCost = (cost: number): string => {
    // Always show 6 decimal places for consistency
    return cost.toFixed(6);
  };

  // Debug logging removed

  // Show loading state or when we don't have model/provider info
  if (!currentModel || !currentProvider) {
    return null;
  }

  // If still loading, show a placeholder
  if (isLoading) {
    return (
      <div className="flex items-center justify-center h-full text-textSubtle translate-y-[1px]">
        <span className="text-xs font-mono">...</span>
      </div>
    );
  }

  // If no cost info found, try to return a default
  if (
    !costInfo ||
    (costInfo.input_token_cost === undefined && costInfo.output_token_cost === undefined)
  ) {
    // If it's a known free/local provider, show $0.000000 without "not available" message
    const freeProviders = ['ollama', 'local', 'localhost'];
    if (freeProviders.includes(currentProvider.toLowerCase())) {
      return (
        <div
          className="flex items-center justify-center h-full text-textSubtle hover:text-textStandard transition-colors cursor-default translate-y-[1px]"
          title={`Local model (${inputTokens.toLocaleString()} input, ${outputTokens.toLocaleString()} output tokens)`}
        >
          <span className="text-xs font-mono">$0.000000</span>
        </div>
      );
    }

    // Otherwise show as unavailable
    const getUnavailableTooltip = () => {
      if (pricingFailed && hasAttemptedFetch && initialLoadComplete) {
|
||||||
|
return `Pricing data unavailable - OpenRouter connection failed. Click refresh in settings to retry.`;
|
||||||
|
}
|
||||||
|
// If we reach here, it must be modelNotFound (since we only get here after attempting fetch)
|
||||||
|
return `Cost data not available for ${currentModel} (${inputTokens.toLocaleString()} input, ${outputTokens.toLocaleString()} output tokens)`;
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={`flex items-center justify-center h-full transition-colors cursor-default translate-y-[1px] ${
|
||||||
|
(pricingFailed || modelNotFound) && hasAttemptedFetch && initialLoadComplete
|
||||||
|
? 'text-red-500 hover:text-red-400'
|
||||||
|
: 'text-textSubtle hover:text-textStandard'
|
||||||
|
}`}
|
||||||
|
title={getUnavailableTooltip()}
|
||||||
|
>
|
||||||
|
<span className="text-xs font-mono">$0.000000</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const totalCost = calculateCost();
|
||||||
|
|
||||||
|
// Build tooltip content
|
||||||
|
const getTooltipContent = (): string => {
|
||||||
|
// Handle error states first
|
||||||
|
if (pricingFailed && hasAttemptedFetch && initialLoadComplete) {
|
||||||
|
return `Pricing data unavailable - OpenRouter connection failed. Click refresh in settings to retry.`;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (modelNotFound && hasAttemptedFetch && initialLoadComplete) {
|
||||||
|
return `Pricing not available for ${currentProvider}/${currentModel}. This model may not be supported by the pricing service.`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle session costs
|
||||||
|
if (sessionCosts && Object.keys(sessionCosts).length > 0) {
|
||||||
|
// Show session breakdown
|
||||||
|
let tooltip = 'Session cost breakdown:\n';
|
||||||
|
|
||||||
|
Object.entries(sessionCosts).forEach(([modelKey, cost]) => {
|
||||||
|
const costStr = `${costInfo?.currency || '$'}${cost.totalCost.toFixed(6)}`;
|
||||||
|
tooltip += `${modelKey}: ${costStr} (${cost.inputTokens.toLocaleString()} in, ${cost.outputTokens.toLocaleString()} out)\n`;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add current model if it has costs
|
||||||
|
if (costInfo && (inputTokens > 0 || outputTokens > 0)) {
|
||||||
|
const currentCost =
|
||||||
|
inputTokens * (costInfo.input_token_cost || 0) +
|
||||||
|
outputTokens * (costInfo.output_token_cost || 0);
|
||||||
|
if (currentCost > 0) {
|
||||||
|
tooltip += `${currentProvider}/${currentModel} (current): ${costInfo.currency || '$'}${currentCost.toFixed(6)} (${inputTokens.toLocaleString()} in, ${outputTokens.toLocaleString()} out)\n`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tooltip += `\nTotal session cost: ${costInfo?.currency || '$'}${totalCost.toFixed(6)}`;
|
||||||
|
return tooltip;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default tooltip for single model
|
||||||
|
return `Input: ${inputTokens.toLocaleString()} tokens (${costInfo?.currency || '$'}${(inputTokens * (costInfo?.input_token_cost || 0)).toFixed(6)}) | Output: ${outputTokens.toLocaleString()} tokens (${costInfo?.currency || '$'}${(outputTokens * (costInfo?.output_token_cost || 0)).toFixed(6)})`;
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={`flex items-center justify-center h-full transition-colors cursor-default translate-y-[1px] ${
|
||||||
|
(pricingFailed || modelNotFound) && hasAttemptedFetch && initialLoadComplete
|
||||||
|
? 'text-red-500 hover:text-red-400'
|
||||||
|
: 'text-textSubtle hover:text-textStandard'
|
||||||
|
}`}
|
||||||
|
title={getTooltipContent()}
|
||||||
|
>
|
||||||
|
<span className="text-xs font-mono">
|
||||||
|
{costInfo.currency || '$'}
|
||||||
|
{formatCost(totalCost)}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,10 +1,11 @@
import { useState, useEffect, useRef } from 'react';
import { Switch } from '../../ui/switch';
import { Button } from '../../ui/button';
import { Settings, RefreshCw, ExternalLink } from 'lucide-react';
import Modal from '../../Modal';
import UpdateSection from './UpdateSection';
import { UPDATES_ENABLED } from '../../../updates';
import { getApiUrl, getSecretKey } from '../../../config';

interface AppSettingsSectionProps {
  scrollToSection?: string;
@@ -17,6 +18,10 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
  const [isMacOS, setIsMacOS] = useState(false);
  const [isDockSwitchDisabled, setIsDockSwitchDisabled] = useState(false);
  const [showNotificationModal, setShowNotificationModal] = useState(false);
  const [pricingStatus, setPricingStatus] = useState<'loading' | 'success' | 'error'>('loading');
  const [lastFetchTime, setLastFetchTime] = useState<Date | null>(null);
  const [isRefreshing, setIsRefreshing] = useState(false);
  const [showPricing, setShowPricing] = useState(true);
  const updateSectionRef = useRef<HTMLDivElement>(null);

  // Check if running on macOS
@@ -24,6 +29,77 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
    setIsMacOS(window.electron.platform === 'darwin');
  }, []);

  // Load show pricing setting
  useEffect(() => {
    const stored = localStorage.getItem('show_pricing');
    setShowPricing(stored !== 'false');
  }, []);

  // Check pricing status on mount
  useEffect(() => {
    checkPricingStatus();
  }, []);

  const checkPricingStatus = async () => {
    try {
      const apiUrl = getApiUrl('/config/pricing');
      const secretKey = getSecretKey();

      const headers: HeadersInit = { 'Content-Type': 'application/json' };
      if (secretKey) {
        headers['X-Secret-Key'] = secretKey;
      }

      const response = await fetch(apiUrl, {
        method: 'POST',
        headers,
        body: JSON.stringify({ configured_only: true }),
      });

      if (response.ok) {
        await response.json(); // Consume the response
        setPricingStatus('success');
        setLastFetchTime(new Date());
      } else {
        setPricingStatus('error');
      }
    } catch (error) {
      setPricingStatus('error');
    }
  };

  const handleRefreshPricing = async () => {
    setIsRefreshing(true);
    try {
      const apiUrl = getApiUrl('/config/pricing');
      const secretKey = getSecretKey();

      const headers: HeadersInit = { 'Content-Type': 'application/json' };
      if (secretKey) {
        headers['X-Secret-Key'] = secretKey;
      }

      const response = await fetch(apiUrl, {
        method: 'POST',
        headers,
        body: JSON.stringify({ configured_only: false }),
      });

      if (response.ok) {
        setPricingStatus('success');
        setLastFetchTime(new Date());
        // Trigger a reload of the cost database
        window.dispatchEvent(new CustomEvent('pricing-updated'));
      } else {
        setPricingStatus('error');
      }
    } catch (error) {
      setPricingStatus('error');
    } finally {
      setIsRefreshing(false);
    }
  };

  // Handle scrolling to update section
  useEffect(() => {
    if (scrollToSection === 'update' && updateSectionRef.current) {
@@ -99,6 +175,13 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
    }
  };

  const handleShowPricingToggle = (checked: boolean) => {
    setShowPricing(checked);
    localStorage.setItem('show_pricing', String(checked));
    // Trigger storage event for other components
    window.dispatchEvent(new CustomEvent('storage'));
  };

  return (
    <section id="appSettings" className="px-8">
      <div className="flex justify-between items-center mb-2">
@@ -173,6 +256,7 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
          </div>
        )}

        {/* Quit Confirmation */}
        <div className="flex items-center justify-between mb-4">
          <div>
            <h3 className="text-textStandard">Quit Confirmation</h3>
@@ -188,6 +272,87 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
          />
        </div>
      </div>

      {/* Cost Tracking */}
      <div className="flex items-center justify-between mb-4">
        <div>
          <h3 className="text-textStandard">Cost Tracking</h3>
          <p className="text-xs text-textSubtle max-w-md mt-[2px]">
            Show model pricing and usage costs
          </p>
        </div>
        <div className="flex items-center">
          <Switch
            checked={showPricing}
            onCheckedChange={handleShowPricingToggle}
            variant="mono"
          />
        </div>
      </div>

      {/* Pricing Status - only show if cost tracking is enabled */}
      {showPricing && (
        <>
          <div className="flex items-center justify-between text-xs mb-2 px-4">
            <span className="text-textSubtle">Pricing Source:</span>
            <a
              href="https://openrouter.ai/docs#models"
              target="_blank"
              rel="noopener noreferrer"
              className="text-blue-600 dark:text-blue-400 hover:underline flex items-center gap-1"
            >
              OpenRouter Docs
              <ExternalLink size={10} />
            </a>
          </div>

          <div className="flex items-center justify-between text-xs mb-2 px-4">
            <span className="text-textSubtle">Status:</span>
            <div className="flex items-center gap-2">
              <span
                className={`font-medium ${
                  pricingStatus === 'success'
                    ? 'text-green-600 dark:text-green-400'
                    : pricingStatus === 'error'
                      ? 'text-red-600 dark:text-red-400'
                      : 'text-textSubtle'
                }`}
              >
                {pricingStatus === 'success'
                  ? '✓ Connected'
                  : pricingStatus === 'error'
                    ? '✗ Failed'
                    : '... Checking'}
              </span>
              <button
                className="p-0.5 hover:bg-gray-200 dark:hover:bg-gray-700 rounded transition-colors disabled:opacity-50"
                onClick={handleRefreshPricing}
                disabled={isRefreshing}
                title="Refresh pricing data"
                type="button"
              >
                <RefreshCw
                  size={8}
                  className={`text-textSubtle hover:text-textStandard ${isRefreshing ? 'animate-spin-fast' : ''}`}
                />
              </button>
            </div>
          </div>

          {lastFetchTime && (
            <div className="flex items-center justify-between text-xs mb-2 px-4">
              <span className="text-textSubtle">Last updated:</span>
              <span className="text-textSubtle">{lastFetchTime.toLocaleTimeString()}</span>
            </div>
          )}

          {pricingStatus === 'error' && (
            <p className="text-xs text-red-600 dark:text-red-400 px-4">
              Unable to fetch pricing data. Costs will not be displayed.
            </p>
          )}
        </>
      )}
    </div>

    {/* Help & Feedback Section */}

@@ -43,7 +43,10 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
  }, [read]);

  // Determine which model to display - activeModel takes priority when lead/worker is active
  const displayModel =
    isLeadWorkerActive && currentModelInfo?.model
      ? currentModelInfo.model
      : currentModel || 'Select Model';
  const modelMode = currentModelInfo?.mode;

  // Update display provider when current provider changes
@@ -106,9 +109,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
            >
              {displayModel}
              {isLeadWorkerActive && modelMode && (
                <span className="ml-1 text-[10px] opacity-60">({modelMode})</span>
              )}
            </span>
          </TooltipTrigger>
@@ -116,9 +117,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
            <TooltipContent className="max-w-96 overflow-auto scrollbar-thin" side="top">
              {displayModel}
              {isLeadWorkerActive && modelMode && (
                <span className="ml-1 text-[10px] opacity-60">({modelMode})</span>
              )}
            </TooltipContent>
          )}
@@ -164,7 +163,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
      {isAddModelModalOpen ? (
        <AddModelModal setView={setView} onClose={() => setIsAddModelModalOpen(false)} />
      ) : null}

      {isLeadWorkerModalOpen ? (
        <Modal onClose={() => setIsLeadWorkerModalOpen(false)}>
          <LeadWorkerSettings onClose={() => setIsLeadWorkerModalOpen(false)} />

@@ -21,7 +21,9 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
  const [failureThreshold, setFailureThreshold] = useState<number>(2);
  const [fallbackTurns, setFallbackTurns] = useState<number>(2);
  const [isEnabled, setIsEnabled] = useState(false);
  const [modelOptions, setModelOptions] = useState<
    { value: string; label: string; provider: string }[]
  >([]);
  const [isLoading, setIsLoading] = useState(true);

  // Load current configuration
@@ -51,7 +53,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
        if (leadTurnsConfig) setLeadTurns(Number(leadTurnsConfig));
        if (failureThresholdConfig) setFailureThreshold(Number(failureThresholdConfig));
        if (fallbackTurnsConfig) setFallbackTurns(Number(fallbackTurnsConfig));

        // Set worker model to current model or from config
        const workerModelConfig = await read('GOOSE_MODEL', false);
        if (workerModelConfig) {
@@ -59,7 +61,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
        } else if (currentModel) {
          setWorkerModel(currentModel as string);
        }

        const workerProviderConfig = await read('GOOSE_PROVIDER', false);
        if (workerProviderConfig) {
          setWorkerProvider(workerProviderConfig as string);
@@ -69,7 +71,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
        const providers = await getProviders(false);
        const activeProviders = providers.filter((p) => p.is_configured);
        const options: { value: string; label: string; provider: string }[] = [];

        activeProviders.forEach(({ metadata, name }) => {
          if (metadata.known_models) {
            metadata.known_models.forEach((model) => {
@@ -81,7 +83,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
            });
          }
        });

        setModelOptions(options);
      } catch (error) {
        console.error('Error loading configuration:', error);
@@ -184,9 +186,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
              placeholder="Select worker model..."
              isDisabled={!isEnabled}
            />
            <p className="text-xs text-textSubtle">Fast model for routine execution tasks</p>
          </div>

          <div className="space-y-4 pt-4 border-t border-borderSubtle">
@@ -242,9 +242,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
                className="w-20"
                disabled={!isEnabled}
              />
              <p className="text-xs text-textSubtle">Turns to use lead model during fallback</p>
            </div>
          </div>
        </div>
@@ -259,4 +257,4 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
        </div>
      </div>
    </div>
  );
}

@@ -2,9 +2,9 @@ import { useCurrentModelInfo } from '../components/ChatView';

export function useCurrentModel() {
  const modelInfo = useCurrentModelInfo();

  return {
    currentModel: modelInfo?.model || null,
    isLoading: false,
  };
}

@@ -2,12 +2,26 @@ import { useState, useCallback, useEffect, useRef, useId } from 'react';
import useSWR from 'swr';
import { getSecretKey } from '../config';
import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message';
import { getSessionHistory } from '../api';

// Ensure TextDecoder is available in the global scope
const TextDecoder = globalThis.TextDecoder;

type JsonValue = string | number | boolean | null | JsonValue[] | { [key: string]: JsonValue };

export interface SessionMetadata {
  workingDir: string;
  description: string;
  scheduleId: string | null;
  messageCount: number;
  totalTokens: number | null;
  inputTokens: number | null;
  outputTokens: number | null;
  accumulatedTotalTokens: number | null;
  accumulatedInputTokens: number | null;
  accumulatedOutputTokens: number | null;
}

export interface NotificationEvent {
  type: 'Notification';
  request_id: string;
@@ -141,9 +155,12 @@ export interface UseMessageStreamHelpers {
  updateMessageStreamBody?: (newBody: object) => void;

  notifications: NotificationEvent[];

  /** Current model info from the backend */
  currentModelInfo: { model: string; mode: string } | null;

  /** Session metadata including token counts */
  sessionMetadata: SessionMetadata | null;
}

/**
@@ -172,7 +189,10 @@ export function useMessageStream({
  });

  const [notifications, setNotifications] = useState<NotificationEvent[]>([]);
  const [currentModelInfo, setCurrentModelInfo] = useState<{ model: string; mode: string } | null>(
    null
  );
  const [sessionMetadata, setSessionMetadata] = useState<SessionMetadata | null>(null);

  // expose a way to update the body so we can update the session id when CLE occurs
  const updateMessageStreamBody = useCallback((newBody: object) => {
@@ -291,13 +311,41 @@ export function useMessageStream({
            case 'Error':
              throw new Error(parsedEvent.error);

            case 'Finish': {
              // Call onFinish with the last message if available
              if (onFinish && currentMessages.length > 0) {
                const lastMessage = currentMessages[currentMessages.length - 1];
                onFinish(lastMessage, parsedEvent.reason);
              }

              // Fetch updated session metadata with token counts
              const sessionId = (extraMetadataRef.current.body as Record<string, unknown>)?.session_id as string;
              if (sessionId) {
                try {
                  const sessionResponse = await getSessionHistory({
                    path: { session_id: sessionId },
                  });

                  if (sessionResponse.data?.metadata) {
                    setSessionMetadata({
                      workingDir: sessionResponse.data.metadata.working_dir,
                      description: sessionResponse.data.metadata.description,
                      scheduleId: sessionResponse.data.metadata.schedule_id || null,
                      messageCount: sessionResponse.data.metadata.message_count,
                      totalTokens: sessionResponse.data.metadata.total_tokens || null,
                      inputTokens: sessionResponse.data.metadata.input_tokens || null,
                      outputTokens: sessionResponse.data.metadata.output_tokens || null,
                      accumulatedTotalTokens: sessionResponse.data.metadata.accumulated_total_tokens || null,
                      accumulatedInputTokens: sessionResponse.data.metadata.accumulated_input_tokens || null,
                      accumulatedOutputTokens: sessionResponse.data.metadata.accumulated_output_tokens || null,
                    });
                  }
                } catch (error) {
                  console.error('Failed to fetch session metadata:', error);
                }
              }
              break;
            }
          }
        } catch (e) {
          console.error('Error parsing SSE event:', e);
@@ -559,5 +607,6 @@ export function useMessageStream({
    updateMessageStreamBody,
    notifications,
    currentModelInfo,
    sessionMetadata,
  };
}

@@ -7,6 +7,10 @@ export interface SessionMetadata {
  message_count: number;
  total_tokens: number | null;
  working_dir: string; // Required in type, but may be missing in old sessions
  // Add the accumulated token fields from the API
  accumulated_input_tokens?: number | null;
  accumulated_output_tokens?: number | null;
  accumulated_total_tokens?: number | null;
}

// Helper function to ensure working directory is set
@@ -16,6 +20,9 @@ export function ensureWorkingDir(metadata: Partial<SessionMetadata>): SessionMet
    message_count: metadata.message_count || 0,
    total_tokens: metadata.total_tokens || null,
    working_dir: metadata.working_dir || process.env.HOME || '',
    accumulated_input_tokens: metadata.accumulated_input_tokens || null,
    accumulated_output_tokens: metadata.accumulated_output_tokens || null,
    accumulated_total_tokens: metadata.accumulated_total_tokens || null,
  };
}

593
ui/desktop/src/utils/costDatabase.ts
Normal file
@@ -0,0 +1,593 @@
// Import the proper type from ConfigContext
import { getApiUrl, getSecretKey } from '../config';

export interface ModelCostInfo {
  input_token_cost: number; // Cost per token for input (in USD)
  output_token_cost: number; // Cost per token for output (in USD)
  currency: string; // Currency symbol
}

// In-memory cache for current model pricing only
let currentModelPricing: {
  provider: string;
  model: string;
  costInfo: ModelCostInfo | null;
} | null = null;

// LocalStorage keys
const PRICING_CACHE_KEY = 'goose_pricing_cache';
const PRICING_CACHE_TIMESTAMP_KEY = 'goose_pricing_cache_timestamp';
const RECENTLY_USED_MODELS_KEY = 'goose_recently_used_models';
const CACHE_TTL_MS = 7 * 24 * 60 * 60 * 1000; // 7 days in milliseconds
const MAX_RECENTLY_USED_MODELS = 20; // Keep only the last 20 used models in cache

interface PricingItem {
  provider: string;
  model: string;
  input_token_cost: number;
  output_token_cost: number;
  currency: string;
}

interface PricingCacheData {
  pricing: PricingItem[];
  timestamp: number;
}

interface RecentlyUsedModel {
  provider: string;
  model: string;
  lastUsed: number;
}

/**
 * Get recently used models from localStorage
 */
function getRecentlyUsedModels(): RecentlyUsedModel[] {
  try {
    const stored = localStorage.getItem(RECENTLY_USED_MODELS_KEY);
    return stored ? JSON.parse(stored) : [];
  } catch (error) {
    console.error('Error loading recently used models:', error);
    return [];
  }
}

/**
 * Add a model to the recently used list
 */
function addToRecentlyUsed(provider: string, model: string): void {
  try {
    let recentModels = getRecentlyUsedModels();

    // Remove existing entry if present
    recentModels = recentModels.filter((m) => !(m.provider === provider && m.model === model));

    // Add to front
    recentModels.unshift({ provider, model, lastUsed: Date.now() });

    // Keep only the most recent models
    recentModels = recentModels.slice(0, MAX_RECENTLY_USED_MODELS);

    localStorage.setItem(RECENTLY_USED_MODELS_KEY, JSON.stringify(recentModels));
  } catch (error) {
    console.error('Error saving recently used models:', error);
  }
}

/**
 * Load pricing data from localStorage cache - only for recently used models
 */
function loadPricingFromLocalStorage(): PricingCacheData | null {
  try {
    const cached = localStorage.getItem(PRICING_CACHE_KEY);
    const timestamp = localStorage.getItem(PRICING_CACHE_TIMESTAMP_KEY);

    if (cached && timestamp) {
      const cacheAge = Date.now() - parseInt(timestamp, 10);
      if (cacheAge < CACHE_TTL_MS) {
        const fullCache = JSON.parse(cached) as PricingCacheData;
        const recentModels = getRecentlyUsedModels();

        // Filter to only include recently used models
        const filteredPricing = fullCache.pricing.filter((p) =>
          recentModels.some((r) => r.provider === p.provider && r.model === p.model)
        );

        console.log(
          `Loading ${filteredPricing.length} recently used models from cache (out of ${fullCache.pricing.length} total)`
        );

        return {
          pricing: filteredPricing,
          timestamp: fullCache.timestamp,
        };
      } else {
        console.log('LocalStorage pricing cache expired');
      }
    }
  } catch (error) {
    console.error('Error loading pricing from localStorage:', error);
  }
  return null;
}

/**
 * Save pricing data to localStorage - merge with existing data
 */
function savePricingToLocalStorage(data: PricingCacheData, mergeWithExisting = true): void {
  try {
    if (mergeWithExisting) {
      // Load existing full cache
      const existingCached = localStorage.getItem(PRICING_CACHE_KEY);
      if (existingCached) {
        const existingData = JSON.parse(existingCached) as PricingCacheData;

        // Create a map of existing pricing for quick lookup
        const pricingMap = new Map<string, (typeof data.pricing)[0]>();
        existingData.pricing.forEach((p) => {
          pricingMap.set(`${p.provider}/${p.model}`, p);
        });

        // Update with new data
        data.pricing.forEach((p) => {
          pricingMap.set(`${p.provider}/${p.model}`, p);
        });

        // Convert back to array
        data = {
          pricing: Array.from(pricingMap.values()),
          timestamp: data.timestamp,
        };
      }
    }

    localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(data));
    localStorage.setItem(PRICING_CACHE_TIMESTAMP_KEY, data.timestamp.toString());
    console.log(`Saved ${data.pricing.length} models to localStorage cache`);
  } catch (error) {
    console.error('Error saving pricing to localStorage:', error);
  }
}

/**
 * Fetch pricing data from backend for specific provider/model
 */
async function fetchPricingForModel(
  provider: string,
  model: string
): Promise<ModelCostInfo | null> {
  try {
    const apiUrl = getApiUrl('/config/pricing');
    const secretKey = getSecretKey();

    console.log(`Fetching pricing for ${provider}/${model} from ${apiUrl}`);

    const headers: HeadersInit = { 'Content-Type': 'application/json' };
    if (secretKey) {
      headers['X-Secret-Key'] = secretKey;
    }

    const response = await fetch(apiUrl, {
      method: 'POST',
      headers,
      body: JSON.stringify({ configured_only: false }),
    });

    if (!response.ok) {
      console.error('Failed to fetch pricing data:', response.status);
      throw new Error(`API request failed with status ${response.status}`);
    }

    const data = await response.json();
    console.log('Pricing response:', data);

    // Find the specific model pricing
    const pricing = data.pricing?.find(
      (p: {
        provider: string;
        model: string;
        input_token_cost: number;
        output_token_cost: number;
        currency: string;
      }) => {
        const providerMatch = p.provider.toLowerCase() === provider.toLowerCase();

        // More flexible model matching - handle versioned models
        let modelMatch = p.model === model;

        // If exact match fails, try matching without version suffix
        if (!modelMatch && model.includes('-20')) {
          // Remove date suffix like -20241022
          const modelWithoutDate = model.replace(/-20\d{6}$/, '');
          modelMatch = p.model === modelWithoutDate;

          // Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet)
          if (!modelMatch) {
            const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.');
            modelMatch = p.model === modelWithDots;
          }
        }

        console.log(
          `Comparing: ${p.provider}/${p.model} with ${provider}/${model} - Provider match: ${providerMatch}, Model match: ${modelMatch}`
        );
        return providerMatch && modelMatch;
      }
    );

    console.log(`Found pricing for ${provider}/${model}:`, pricing);

    if (pricing) {
      return {
        input_token_cost: pricing.input_token_cost,
        output_token_cost: pricing.output_token_cost,
        currency: pricing.currency || '$',
      };
    }

    console.log(
      `No pricing found for ${provider}/${model} in:`,
      data.pricing?.map((p: { provider: string; model: string }) => `${p.provider}/${p.model}`)
    );

    // API call succeeded but model not found in pricing data
    return null;
  } catch (error) {
    console.error('Error fetching pricing data:', error);
    // Re-throw the error so the caller can distinguish between API failure and model not found
    throw error;
  }
}

/**
 * Initialize the cost database - only load commonly used models on startup
 */
export async function initializeCostDatabase(): Promise<void> {
  try {
    // Clean up any existing large caches first
    cleanupPricingCache();

    // First check if we have valid cached data
    const cachedData = loadPricingFromLocalStorage();
    if (cachedData && cachedData.pricing.length > 0) {
      console.log('Using cached pricing data from localStorage');
      return;
    }

    // List of commonly used models to pre-fetch
    const commonModels = [
      { provider: 'openai', model: 'gpt-4o' },
      { provider: 'openai', model: 'gpt-4o-mini' },
      { provider: 'openai', model: 'gpt-4-turbo' },
      { provider: 'openai', model: 'gpt-4' },
      { provider: 'openai', model: 'gpt-3.5-turbo' },
      { provider: 'anthropic', model: 'claude-3-5-sonnet' },
      { provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' },
      { provider: 'anthropic', model: 'claude-3-opus' },
      { provider: 'anthropic', model: 'claude-3-sonnet' },
      { provider: 'anthropic', model: 'claude-3-haiku' },
      { provider: 'google', model: 'gemini-1.5-pro' },
      { provider: 'google', model: 'gemini-1.5-flash' },
      { provider: 'deepseek', model: 'deepseek-chat' },
      { provider: 'deepseek', model: 'deepseek-reasoner' },
      { provider: 'meta-llama', model: 'llama-3.2-90b-text-preview' },
      { provider: 'meta-llama', model: 'llama-3.1-405b-instruct' },
    ];

    // Get recently used models
    const recentModels = getRecentlyUsedModels();

    // Combine common and recent models (deduplicated)
    const modelsToFetch = new Map<string, { provider: string; model: string }>();

    // Add common models
    commonModels.forEach((m) => {
      modelsToFetch.set(`${m.provider}/${m.model}`, m);
    });

    // Add recent models
    recentModels.forEach((m) => {
      modelsToFetch.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model });
    });

    console.log(`Initializing cost database with ${modelsToFetch.size} models...`);

    // Fetch only the pricing we need
    const apiUrl = getApiUrl('/config/pricing');
    const secretKey = getSecretKey();

    const headers: HeadersInit = { 'Content-Type': 'application/json' };
    if (secretKey) {
      headers['X-Secret-Key'] = secretKey;
    }

    const response = await fetch(apiUrl, {
      method: 'POST',
      headers,
      body: JSON.stringify({
        configured_only: false,
        models: Array.from(modelsToFetch.values()), // Send specific models if API supports it
      }),
    });

    if (!response.ok) {
      console.error('Failed to fetch initial pricing data:', response.status);
      return;
    }

    const data = await response.json();
    console.log(`Fetched pricing for ${data.pricing?.length || 0} models`);

    if (data.pricing && data.pricing.length > 0) {
      // Filter to only the models we requested (in case API returns all)
      const filteredPricing = data.pricing.filter((p: PricingItem) =>
        modelsToFetch.has(`${p.provider}/${p.model}`)
      );

      // Save to localStorage
      const cacheData: PricingCacheData = {
        pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50), // Fallback to first 50 if filtering didn't work
        timestamp: Date.now(),
      };
      savePricingToLocalStorage(cacheData, false); // Don't merge on initial load
    }
  } catch (error) {
    console.error('Error initializing cost database:', error);
  }
}

/**
 * Update model costs from providers - no longer needed
 */
export async function updateAllModelCosts(): Promise<void> {
  // No-op - we fetch on demand now
}

/**
 * Get cost information for a specific model with caching
 */
export function getCostForModel(provider: string, model: string): ModelCostInfo | null {
  // Track this model as recently used
  addToRecentlyUsed(provider, model);

  // Check if it's the same model we already have cached in memory
  if (
    currentModelPricing &&
    currentModelPricing.provider === provider &&
    currentModelPricing.model === model
  ) {
    return currentModelPricing.costInfo;
  }

  // For local/free providers, return zero cost immediately
  const freeProviders = ['ollama', 'local', 'localhost'];
  if (freeProviders.includes(provider.toLowerCase())) {
    const zeroCost = {
      input_token_cost: 0,
      output_token_cost: 0,
      currency: '$',
    };
    currentModelPricing = { provider, model, costInfo: zeroCost };
    return zeroCost;
  }

  // Check localStorage cache (which now only contains recently used models)
  const cachedData = loadPricingFromLocalStorage();
  if (cachedData) {
    const pricing = cachedData.pricing.find((p) => {
      const providerMatch = p.provider.toLowerCase() === provider.toLowerCase();

      // More flexible model matching - handle versioned models
      let modelMatch = p.model === model;

      // If exact match fails, try matching without version suffix
      if (!modelMatch && model.includes('-20')) {
        // Remove date suffix like -20241022
        const modelWithoutDate = model.replace(/-20\d{6}$/, '');
        modelMatch = p.model === modelWithoutDate;

        // Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet)
        if (!modelMatch) {
          const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.');
          modelMatch = p.model === modelWithDots;
        }
      }

      return providerMatch && modelMatch;
    });

    if (pricing) {
      const costInfo = {
        input_token_cost: pricing.input_token_cost,
        output_token_cost: pricing.output_token_cost,
        currency: pricing.currency || '$',
      };
      currentModelPricing = { provider, model, costInfo };
      return costInfo;
    }
  }

  // Need to fetch new pricing - return null for now
  // The component will handle the async fetch
  return null;
}

/**
 * Fetch and cache pricing for a model
 */
export async function fetchAndCachePricing(
  provider: string,
  model: string
): Promise<{ costInfo: ModelCostInfo | null; error?: string } | null> {
  try {
    const costInfo = await fetchPricingForModel(provider, model);

    if (costInfo) {
      // Cache the result in memory
      currentModelPricing = { provider, model, costInfo };

      // Update localStorage cache with this new data
      const cachedData = loadPricingFromLocalStorage();
      if (cachedData) {
        // Check if this model already exists in cache
        const existingIndex = cachedData.pricing.findIndex(
          (p) => p.provider.toLowerCase() === provider.toLowerCase() && p.model === model
        );

        const newPricing = {
          provider,
          model,
          input_token_cost: costInfo.input_token_cost,
          output_token_cost: costInfo.output_token_cost,
          currency: costInfo.currency,
        };

        if (existingIndex >= 0) {
          // Update existing
          cachedData.pricing[existingIndex] = newPricing;
        } else {
          // Add new
          cachedData.pricing.push(newPricing);
        }

        // Save updated cache
        savePricingToLocalStorage(cachedData);
      }

      return { costInfo };
    } else {
      // Cache the null result in memory
      currentModelPricing = { provider, model, costInfo: null };

      // Check if the API call succeeded but model wasn't found
      // We can determine this by checking if we got a response but no matching model
      return { costInfo: null, error: 'model_not_found' };
    }
  } catch (error) {
    console.error('Error in fetchAndCachePricing:', error);
    // This is a real API/network error
    return null;
  }
}

/**
 * Refresh pricing data from backend - only refresh recently used models
 */
export async function refreshPricing(): Promise<boolean> {
  try {
    const apiUrl = getApiUrl('/config/pricing');
    const secretKey = getSecretKey();

    const headers: HeadersInit = { 'Content-Type': 'application/json' };
    if (secretKey) {
      headers['X-Secret-Key'] = secretKey;
    }

    // Get recently used models to refresh
    const recentModels = getRecentlyUsedModels();

    // Add some common models as well
    const commonModels = [
      { provider: 'openai', model: 'gpt-4o' },
      { provider: 'openai', model: 'gpt-4o-mini' },
      { provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' },
      { provider: 'google', model: 'gemini-1.5-pro' },
    ];

    // Combine and deduplicate
    const modelsToRefresh = new Map<string, { provider: string; model: string }>();

    commonModels.forEach((m) => {
      modelsToRefresh.set(`${m.provider}/${m.model}`, m);
    });

    recentModels.forEach((m) => {
      modelsToRefresh.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model });
    });

    console.log(`Refreshing pricing for ${modelsToRefresh.size} models...`);

    const response = await fetch(apiUrl, {
      method: 'POST',
      headers,
      body: JSON.stringify({
        configured_only: false,
        models: Array.from(modelsToRefresh.values()), // Send specific models if API supports it
      }),
    });

    if (response.ok) {
      const data = await response.json();

      if (data.pricing && data.pricing.length > 0) {
        // Filter to only the models we requested (in case API returns all)
        const filteredPricing = data.pricing.filter((p: PricingItem) =>
          modelsToRefresh.has(`${p.provider}/${p.model}`)
        );

        // Save fresh data to localStorage (merge with existing)
        const cacheData: PricingCacheData = {
          pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50),
          timestamp: Date.now(),
        };
        savePricingToLocalStorage(cacheData, true); // Merge with existing
      }

      // Clear current memory cache to force re-fetch
      currentModelPricing = null;
      return true;
    }

    return false;
  } catch (error) {
    console.error('Error refreshing pricing data:', error);
    return false;
  }
}

/**
 * Clean up old/unused models from the cache
 */
export function cleanupPricingCache(): void {
  try {
    const recentModels = getRecentlyUsedModels();
    const cachedData = localStorage.getItem(PRICING_CACHE_KEY);

    if (!cachedData) return;

    const fullCache = JSON.parse(cachedData) as PricingCacheData;
    const recentModelKeys = new Set(recentModels.map((m) => `${m.provider}/${m.model}`));

    // Keep only recently used models and common models
    const commonModelKeys = new Set([
      'openai/gpt-4o',
      'openai/gpt-4o-mini',
      'openai/gpt-4-turbo',
      'anthropic/claude-3-5-sonnet',
      'anthropic/claude-3-5-sonnet-20241022',
      'google/gemini-1.5-pro',
      'google/gemini-1.5-flash',
    ]);

    const filteredPricing = fullCache.pricing.filter((p) => {
      const key = `${p.provider}/${p.model}`;
      return recentModelKeys.has(key) || commonModelKeys.has(key);
    });

    if (filteredPricing.length < fullCache.pricing.length) {
      console.log(
        `Cleaned up pricing cache: reduced from ${fullCache.pricing.length} to ${filteredPricing.length} models`
      );

      const cleanedCache: PricingCacheData = {
        pricing: filteredPricing,
        timestamp: fullCache.timestamp,
      };

      localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(cleanedCache));
    }
  } catch (error) {
    console.error('Error cleaning up pricing cache:', error);
  }
}

@@ -44,6 +44,10 @@ export default {
        '0%': { transform: 'rotate(0deg)' },
        '100%': { transform: 'rotate(360deg)' },
      },
      'spin-fast': {
        '0%': { transform: 'rotate(0deg)' },
        '100%': { transform: 'rotate(360deg)' },
      },
      indeterminate: {
        '0%': { left: '-40%', width: '40%' },
        '50%': { left: '20%', width: '60%' },
@@ -54,6 +58,7 @@ export default {
      'shimmer-pulse': 'shimmer 4s ease-in-out infinite',
      'gradient-loader': 'loader 750ms ease-in-out infinite',
      indeterminate: 'indeterminate 1.5s infinite linear',
      'spin-fast': 'spin-fast 0.5s linear infinite',
    },
    colors: {
      bgApp: 'var(--background-app)',