feat: Add comprehensive cost tracking display for LLM usage (#2992)

Co-authored-by: jack <jack@deck.local>
Co-authored-by: Bradley Axen <baxen@squareup.com>
This commit is contained in:
jack
2025-06-26 10:46:14 +01:00
committed by GitHub
parent 9d48ed980f
commit d2ff4f3746
24 changed files with 1985 additions and 47 deletions

1
Cargo.lock generated
View File

@@ -3427,6 +3427,7 @@ dependencies = [
"chrono",
"criterion",
"ctor",
"dirs 5.0.1",
"dotenv",
"etcetera",
"fs2",

View File

@@ -10,12 +10,23 @@ use goose::scheduler_factory::SchedulerFactory;
use tower_http::cors::{Any, CorsLayer};
use tracing::info;
use goose::providers::pricing::initialize_pricing_cache;
pub async fn run() -> Result<()> {
// Initialize logging
crate::logging::setup_logging(Some("goosed"))?;
let settings = configuration::Settings::new()?;
// Initialize pricing cache on startup
tracing::info!("Initializing pricing cache...");
if let Err(e) = initialize_pricing_cache().await {
tracing::warn!(
"Failed to initialize pricing cache: {}. Pricing data may not be available.",
e
);
}
let secret_key =
std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string());

View File

@@ -13,6 +13,7 @@ use goose::config::{extensions::name_to_key, PermissionManager};
use goose::config::{ExtensionConfigManager, ExtensionEntry};
use goose::model::ModelConfig;
use goose::providers::base::ProviderMetadata;
use goose::providers::pricing::{get_all_pricing, get_model_pricing, refresh_pricing};
use goose::providers::providers as get_providers;
use goose::{agents::ExtensionConfig, config::permission::PermissionLevel};
use http::{HeaderMap, StatusCode};
@@ -314,6 +315,128 @@ pub async fn providers(
Ok(Json(providers_response))
}
#[derive(Serialize, ToSchema)]
pub struct PricingData {
pub provider: String,
pub model: String,
pub input_token_cost: f64,
pub output_token_cost: f64,
pub currency: String,
pub context_length: Option<u32>,
}
#[derive(Serialize, ToSchema)]
pub struct PricingResponse {
pub pricing: Vec<PricingData>,
pub source: String,
}
#[derive(Deserialize, ToSchema)]
pub struct PricingQuery {
/// If true, only return pricing for configured providers. If false, return all.
pub configured_only: Option<bool>,
}
#[utoipa::path(
post,
path = "/config/pricing",
request_body = PricingQuery,
responses(
(status = 200, description = "Model pricing data retrieved successfully", body = PricingResponse)
)
)]
pub async fn get_pricing(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(query): Json<PricingQuery>,
) -> Result<Json<PricingResponse>, StatusCode> {
verify_secret_key(&headers, &state)?;
let configured_only = query.configured_only.unwrap_or(true);
// If refresh requested (configured_only = false), refresh the cache
if !configured_only {
if let Err(e) = refresh_pricing().await {
tracing::error!("Failed to refresh pricing data: {}", e);
}
}
let mut pricing_data = Vec::new();
if !configured_only {
// Get ALL pricing data from the cache
let all_pricing = get_all_pricing().await;
for (provider, models) in all_pricing {
for (model, pricing) in models {
pricing_data.push(PricingData {
provider: provider.clone(),
model: model.clone(),
input_token_cost: pricing.input_cost,
output_token_cost: pricing.output_cost,
currency: "$".to_string(),
context_length: pricing.context_length,
});
}
}
} else {
// Get only configured providers' pricing
let providers_metadata = get_providers();
for metadata in providers_metadata {
// Skip unconfigured providers if filtering
if !check_provider_configured(&metadata) {
continue;
}
for model_info in &metadata.known_models {
// Try to get pricing from cache
if let Some(pricing) = get_model_pricing(&metadata.name, &model_info.name).await {
pricing_data.push(PricingData {
provider: metadata.name.clone(),
model: model_info.name.clone(),
input_token_cost: pricing.input_cost,
output_token_cost: pricing.output_cost,
currency: "$".to_string(),
context_length: pricing.context_length,
});
}
// Check if the model has embedded pricing data
else if let (Some(input_cost), Some(output_cost)) =
(model_info.input_token_cost, model_info.output_token_cost)
{
pricing_data.push(PricingData {
provider: metadata.name.clone(),
model: model_info.name.clone(),
input_token_cost: input_cost,
output_token_cost: output_cost,
currency: model_info
.currency
.clone()
.unwrap_or_else(|| "$".to_string()),
context_length: Some(model_info.context_limit as u32),
});
}
}
}
}
tracing::info!(
"Returning pricing for {} models{}",
pricing_data.len(),
if configured_only {
" (configured providers only)"
} else {
" (all cached models)"
}
);
Ok(Json(PricingResponse {
pricing: pricing_data,
source: "openrouter".to_string(),
}))
}
#[utoipa::path(
post,
path = "/config/init",
@@ -471,6 +594,7 @@ pub fn routes(state: Arc<AppState>) -> Router {
.route("/config/extensions", post(add_extension))
.route("/config/extensions/{name}", delete(remove_extension))
.route("/config/providers", get(providers))
.route("/config/pricing", post(get_pricing))
.route("/config/init", post(init_config))
.route("/config/backup", post(backup_config))
.route("/config/permissions", post(upsert_permissions))

View File

@@ -17,6 +17,7 @@ mcp-core = { path = "../mcp-core" }
anyhow = "1.0"
thiserror = "1.0"
futures = "0.3"
dirs = "5.0"
reqwest = { version = "0.12.9", features = [
"rustls-tls-native-roots",
"json",

View File

@@ -5,7 +5,7 @@ use reqwest::{Client, StatusCode};
use serde_json::Value;
use std::time::Duration;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use super::formats::anthropic::{create_request, get_usage, response_to_message};
use super::utils::{emit_debug_trace, get_model};
@@ -122,12 +122,18 @@ impl AnthropicProvider {
#[async_trait]
impl Provider for AnthropicProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
ProviderMetadata::with_models(
"anthropic",
"Anthropic",
"Claude and other models from Anthropic",
ANTHROPIC_DEFAULT_MODEL,
ANTHROPIC_KNOWN_MODELS.to_vec(),
vec![
ModelInfo::with_cost("claude-3-5-sonnet-20241022", 200000, 0.000003, 0.000015),
ModelInfo::with_cost("claude-3-5-haiku-20241022", 200000, 0.000001, 0.000005),
ModelInfo::with_cost("claude-3-opus-20240229", 200000, 0.000015, 0.000075),
ModelInfo::with_cost("claude-3-sonnet-20240229", 200000, 0.000003, 0.000015),
ModelInfo::with_cost("claude-3-haiku-20240307", 200000, 0.00000025, 0.00000125),
],
ANTHROPIC_DOC_URL,
vec![
ConfigKey::new("ANTHROPIC_API_KEY", true, true, None),

View File

@@ -32,6 +32,41 @@ pub struct ModelInfo {
pub name: String,
/// The maximum context length this model supports
pub context_limit: usize,
/// Cost per token for input (optional)
pub input_token_cost: Option<f64>,
/// Cost per token for output (optional)
pub output_token_cost: Option<f64>,
/// Currency for the costs (default: "$")
pub currency: Option<String>,
}
impl ModelInfo {
/// Create a new ModelInfo with just name and context limit
pub fn new(name: impl Into<String>, context_limit: usize) -> Self {
Self {
name: name.into(),
context_limit,
input_token_cost: None,
output_token_cost: None,
currency: None,
}
}
/// Create a new ModelInfo with cost information (per token)
pub fn with_cost(
name: impl Into<String>,
context_limit: usize,
input_cost: f64,
output_cost: f64,
) -> Self {
Self {
name: name.into(),
context_limit,
input_token_cost: Some(input_cost),
output_token_cost: Some(output_cost),
currency: Some("$".to_string()),
}
}
}
/// Metadata about a provider's configuration requirements and capabilities
@@ -74,6 +109,9 @@ impl ProviderMetadata {
.map(|&name| ModelInfo {
name: name.to_string(),
context_limit: ModelConfig::new(name.to_string()).context_limit(),
input_token_cost: None,
output_token_cost: None,
currency: None,
})
.collect(),
model_doc_link: model_doc_link.to_string(),
@@ -81,6 +119,27 @@ impl ProviderMetadata {
}
}
/// Create a new ProviderMetadata with ModelInfo objects that include cost data
pub fn with_models(
name: &str,
display_name: &str,
description: &str,
default_model: &str,
models: Vec<ModelInfo>,
model_doc_link: &str,
config_keys: Vec<ConfigKey>,
) -> Self {
Self {
name: name.to_string(),
display_name: display_name.to_string(),
description: description.to_string(),
default_model: default_model.to_string(),
known_models: models,
model_doc_link: model_doc_link.to_string(),
config_keys,
}
}
pub fn empty() -> Self {
Self {
name: "".to_string(),
@@ -313,6 +372,9 @@ mod tests {
let info = ModelInfo {
name: "test-model".to_string(),
context_limit: 1000,
input_token_cost: None,
output_token_cost: None,
currency: None,
};
assert_eq!(info.context_limit, 1000);
@@ -320,6 +382,9 @@ mod tests {
let info2 = ModelInfo {
name: "test-model".to_string(),
context_limit: 1000,
input_token_cost: None,
output_token_cost: None,
currency: None,
};
assert_eq!(info, info2);
@@ -327,7 +392,20 @@ mod tests {
let info3 = ModelInfo {
name: "test-model".to_string(),
context_limit: 2000,
input_token_cost: None,
output_token_cost: None,
currency: None,
};
assert_ne!(info, info3);
}
#[test]
fn test_model_info_with_cost() {
let info = ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001);
assert_eq!(info.name, "gpt-4o");
assert_eq!(info.context_limit, 128000);
assert_eq!(info.input_token_cost, Some(0.0000025));
assert_eq!(info.output_token_cost, Some(0.00001));
assert_eq!(info.currency, Some("$".to_string()));
}
}

View File

@@ -18,6 +18,7 @@ pub mod oauth;
pub mod ollama;
pub mod openai;
pub mod openrouter;
pub mod pricing;
pub mod sagemaker_tgi;
pub mod snowflake;
pub mod toolshim;

View File

@@ -5,7 +5,7 @@ use serde_json::Value;
use std::collections::HashMap;
use std::time::Duration;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse};
use super::errors::ProviderError;
use super::formats::openai::{create_request, get_usage, response_to_message};
@@ -126,12 +126,20 @@ impl OpenAiProvider {
#[async_trait]
impl Provider for OpenAiProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
ProviderMetadata::with_models(
"openai",
"OpenAI",
"GPT-4 and other OpenAI models, including OpenAI compatible ones",
OPEN_AI_DEFAULT_MODEL,
OPEN_AI_KNOWN_MODELS.to_vec(),
vec![
ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001),
ModelInfo::with_cost("gpt-4o-mini", 128000, 0.00000015, 0.0000006),
ModelInfo::with_cost("gpt-4-turbo", 128000, 0.00001, 0.00003),
ModelInfo::with_cost("gpt-3.5-turbo", 16385, 0.0000005, 0.0000015),
ModelInfo::with_cost("o1", 200000, 0.000015, 0.00006),
ModelInfo::with_cost("o3", 200000, 0.000015, 0.00006), // Using o1 pricing as placeholder
ModelInfo::with_cost("o4-mini", 128000, 0.000003, 0.000012), // Using o1-mini pricing as placeholder
],
OPEN_AI_DOC_URL,
vec![
ConfigKey::new("OPENAI_API_KEY", true, true, None),

View File

@@ -0,0 +1,387 @@
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
/// Disk cache configuration
const CACHE_FILE_NAME: &str = "pricing_cache.json";
const CACHE_TTL_DAYS: u64 = 7; // Cache for 7 days
/// Get the cache directory path
fn get_cache_dir() -> Result<PathBuf> {
let cache_dir = if let Ok(goose_dir) = std::env::var("GOOSE_CACHE_DIR") {
PathBuf::from(goose_dir)
} else {
dirs::cache_dir()
.ok_or_else(|| anyhow::anyhow!("Could not determine cache directory"))?
.join("goose")
};
Ok(cache_dir)
}
/// Cached pricing data structure for disk storage
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedPricingData {
/// Nested HashMap: provider -> model -> pricing info
pub pricing: HashMap<String, HashMap<String, PricingInfo>>,
/// Unix timestamp when data was fetched
pub fetched_at: u64,
}
/// Simplified pricing info for efficient storage
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingInfo {
pub input_cost: f64, // Cost per token
pub output_cost: f64, // Cost per token
pub context_length: Option<u32>,
}
/// Cache for OpenRouter pricing data with disk persistence
pub struct PricingCache {
/// In-memory cache
memory_cache: Arc<RwLock<Option<CachedPricingData>>>,
}
impl PricingCache {
pub fn new() -> Self {
Self {
memory_cache: Arc::new(RwLock::new(None)),
}
}
/// Load pricing from disk cache
async fn load_from_disk(&self) -> Result<Option<CachedPricingData>> {
let cache_path = get_cache_dir()?.join(CACHE_FILE_NAME);
if !cache_path.exists() {
return Ok(None);
}
match tokio::fs::read(&cache_path).await {
Ok(data) => {
match serde_json::from_slice::<CachedPricingData>(&data) {
Ok(cached) => {
// Check if cache is still valid
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
let age_days = (now - cached.fetched_at) / (24 * 60 * 60);
if age_days < CACHE_TTL_DAYS {
tracing::info!(
"Loaded pricing data from disk cache (age: {} days)",
age_days
);
Ok(Some(cached))
} else {
tracing::info!("Disk cache expired (age: {} days)", age_days);
Ok(None)
}
}
Err(e) => {
tracing::warn!("Failed to parse pricing cache: {}", e);
Ok(None)
}
}
}
Err(e) => {
tracing::warn!("Failed to read pricing cache: {}", e);
Ok(None)
}
}
}
/// Save pricing data to disk
async fn save_to_disk(&self, data: &CachedPricingData) -> Result<()> {
let cache_dir = get_cache_dir()?;
tokio::fs::create_dir_all(&cache_dir).await?;
let cache_path = cache_dir.join(CACHE_FILE_NAME);
let json_data = serde_json::to_vec_pretty(data)?;
tokio::fs::write(&cache_path, json_data).await?;
tracing::info!("Saved pricing data to disk cache");
Ok(())
}
/// Get pricing for a specific model
pub async fn get_model_pricing(&self, provider: &str, model: &str) -> Option<PricingInfo> {
// Try memory cache first
{
let cache = self.memory_cache.read().await;
if let Some(cached) = &*cache {
return cached
.pricing
.get(&provider.to_lowercase())
.and_then(|models| models.get(model))
.cloned();
}
}
// Try loading from disk
if let Ok(Some(disk_cache)) = self.load_from_disk().await {
// Update memory cache
{
let mut cache = self.memory_cache.write().await;
*cache = Some(disk_cache.clone());
}
return disk_cache
.pricing
.get(&provider.to_lowercase())
.and_then(|models| models.get(model))
.cloned();
}
None
}
/// Force refresh pricing data from OpenRouter
pub async fn refresh(&self) -> Result<()> {
let pricing = fetch_openrouter_pricing_internal().await?;
// Convert to our efficient structure
let mut structured_pricing: HashMap<String, HashMap<String, PricingInfo>> = HashMap::new();
for (model_id, model) in pricing {
if let Some((provider, model_name)) = parse_model_id(&model_id) {
if let (Some(input_cost), Some(output_cost)) = (
convert_pricing(&model.pricing.prompt),
convert_pricing(&model.pricing.completion),
) {
let provider_lower = provider.to_lowercase();
let provider_models = structured_pricing.entry(provider_lower).or_default();
provider_models.insert(
model_name,
PricingInfo {
input_cost,
output_cost,
context_length: model.context_length,
},
);
}
}
}
let cached_data = CachedPricingData {
pricing: structured_pricing,
fetched_at: SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(),
};
// Log how many models we fetched
let total_models: usize = cached_data
.pricing
.values()
.map(|models| models.len())
.sum();
tracing::info!(
"Fetched pricing for {} providers with {} total models from OpenRouter",
cached_data.pricing.len(),
total_models
);
// Save to disk
self.save_to_disk(&cached_data).await?;
// Update memory cache
{
let mut cache = self.memory_cache.write().await;
*cache = Some(cached_data);
}
Ok(())
}
/// Initialize cache (load from disk or fetch if needed)
pub async fn initialize(&self) -> Result<()> {
// Try loading from disk first
if let Ok(Some(cached)) = self.load_from_disk().await {
// Log how many models we have cached
let total_models: usize = cached.pricing.values().map(|models| models.len()).sum();
tracing::info!(
"Loaded {} providers with {} total models from disk cache",
cached.pricing.len(),
total_models
);
// Update memory cache
{
let mut cache = self.memory_cache.write().await;
*cache = Some(cached);
}
return Ok(());
}
// If no disk cache, fetch from OpenRouter
tracing::info!("No valid disk cache found, fetching from OpenRouter");
self.refresh().await
}
}
impl Default for PricingCache {
fn default() -> Self {
Self::new()
}
}
// Global cache instance
lazy_static::lazy_static! {
static ref PRICING_CACHE: PricingCache = PricingCache::new();
static ref HTTP_CLIENT: Client = Client::builder()
.timeout(Duration::from_secs(30))
.pool_idle_timeout(Duration::from_secs(90))
.pool_max_idle_per_host(10)
.build()
.unwrap();
}
/// OpenRouter model pricing information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenRouterModel {
pub id: String,
pub name: String,
pub pricing: OpenRouterPricing,
pub context_length: Option<u32>,
pub architecture: Option<Architecture>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenRouterPricing {
pub prompt: String, // Cost per token for input (in USD)
pub completion: String, // Cost per token for output (in USD)
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Architecture {
pub modality: String,
pub tokenizer: String,
pub instruct_type: Option<String>,
}
/// Response from OpenRouter models endpoint
#[derive(Debug, Deserialize)]
pub struct OpenRouterModelsResponse {
pub data: Vec<OpenRouterModel>,
}
/// Internal function to fetch pricing data
async fn fetch_openrouter_pricing_internal() -> Result<HashMap<String, OpenRouterModel>> {
let response = HTTP_CLIENT
.get("https://openrouter.ai/api/v1/models")
.send()
.await?;
if !response.status().is_success() {
anyhow::bail!(
"Failed to fetch OpenRouter models: HTTP {}",
response.status()
);
}
let models_response: OpenRouterModelsResponse = response.json().await?;
// Create a map for easy lookup
let mut pricing_map = HashMap::new();
for model in models_response.data {
pricing_map.insert(model.id.clone(), model);
}
Ok(pricing_map)
}
/// Initialize pricing cache on startup
pub async fn initialize_pricing_cache() -> Result<()> {
PRICING_CACHE.initialize().await
}
/// Get pricing for a specific model
pub async fn get_model_pricing(provider: &str, model: &str) -> Option<PricingInfo> {
PRICING_CACHE.get_model_pricing(provider, model).await
}
/// Force refresh pricing data
pub async fn refresh_pricing() -> Result<()> {
PRICING_CACHE.refresh().await
}
/// Get all cached pricing data
pub async fn get_all_pricing() -> HashMap<String, HashMap<String, PricingInfo>> {
let cache = PRICING_CACHE.memory_cache.read().await;
if let Some(cached) = &*cache {
cached.pricing.clone()
} else {
// Try loading from disk
if let Ok(Some(disk_cache)) = PRICING_CACHE.load_from_disk().await {
// Update memory cache
drop(cache);
let mut write_cache = PRICING_CACHE.memory_cache.write().await;
*write_cache = Some(disk_cache.clone());
disk_cache.pricing
} else {
HashMap::new()
}
}
}
/// Convert OpenRouter model ID to provider/model format
/// e.g., "anthropic/claude-3.5-sonnet" -> ("anthropic", "claude-3.5-sonnet")
pub fn parse_model_id(model_id: &str) -> Option<(String, String)> {
let parts: Vec<&str> = model_id.splitn(2, '/').collect();
if parts.len() == 2 {
// Normalize provider names to match our internal naming
let provider = match parts[0] {
"openai" => "openai",
"anthropic" => "anthropic",
"google" => "google",
"meta-llama" => "ollama", // Meta models often run via Ollama
"mistralai" => "mistral",
"cohere" => "cohere",
"perplexity" => "perplexity",
"deepseek" => "deepseek",
"groq" => "groq",
"nvidia" => "nvidia",
"microsoft" => "azure",
"replicate" => "replicate",
"huggingface" => "huggingface",
_ => parts[0],
};
Some((provider.to_string(), parts[1].to_string()))
} else {
None
}
}
/// Convert OpenRouter pricing to cost per token (already in that format)
pub fn convert_pricing(price_str: &str) -> Option<f64> {
// OpenRouter prices are already in USD per token
price_str.parse::<f64>().ok()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_model_id() {
assert_eq!(
parse_model_id("anthropic/claude-3.5-sonnet"),
Some(("anthropic".to_string(), "claude-3.5-sonnet".to_string()))
);
assert_eq!(
parse_model_id("openai/gpt-4"),
Some(("openai".to_string(), "gpt-4".to_string()))
);
assert_eq!(parse_model_id("invalid-format"), None);
}
#[test]
fn test_convert_pricing() {
assert_eq!(convert_pricing("0.000003"), Some(0.000003));
assert_eq!(convert_pricing("0.015"), Some(0.015));
assert_eq!(convert_pricing("invalid"), None);
}
}

View File

@@ -1700,9 +1700,26 @@
"description": "The maximum context length this model supports",
"minimum": 0
},
"currency": {
"type": "string",
"description": "Currency for the costs (default: \"$\")",
"nullable": true
},
"input_token_cost": {
"type": "number",
"format": "double",
"description": "Cost per token for input (optional)",
"nullable": true
},
"name": {
"type": "string",
"description": "The name of the model"
},
"output_token_cost": {
"type": "number",
"format": "double",
"description": "Cost per token for output (optional)",
"nullable": true
}
}
},

View File

@@ -3,6 +3,7 @@ import { IpcRendererEvent } from 'electron';
import { openSharedSessionFromDeepLink, type SessionLinksViewOptions } from './sessionLinks';
import { type SharedSessionDetails } from './sharedSessions';
import { initializeSystem } from './utils/providerUtils';
import { initializeCostDatabase } from './utils/costDatabase';
import { ErrorUI } from './components/ErrorBoundary';
import { ConfirmationModal } from './components/ui/ConfirmationModal';
import { ToastContainer } from 'react-toastify';
@@ -158,6 +159,11 @@ export default function App() {
const initializeApp = async () => {
try {
// Initialize cost database early to pre-load pricing data
initializeCostDatabase().catch((error) => {
console.error('Failed to initialize cost database:', error);
});
await initConfig();
try {
await readAllConfig({ throwOnError: true });

View File

@@ -229,10 +229,22 @@ export type ModelInfo = {
* The maximum context length this model supports
*/
context_limit: number;
/**
* Currency for the costs (default: "$")
*/
currency?: string | null;
/**
* Cost per token for input (optional)
*/
input_token_cost?: number | null;
/**
* The name of the model
*/
name: string;
/**
* Cost per token for output (optional)
*/
output_token_cost?: number | null;
};
export type PermissionConfirmationRequest = {

View File

@@ -29,9 +29,18 @@ interface ChatInputProps {
droppedFiles?: string[];
setView: (view: View) => void;
numTokens?: number;
inputTokens?: number;
outputTokens?: number;
hasMessages?: boolean;
messages?: Message[];
setMessages: (messages: Message[]) => void;
sessionCosts?: {
[key: string]: {
inputTokens: number;
outputTokens: number;
totalCost: number;
};
};
}
export default function ChatInput({
@@ -42,9 +51,12 @@ export default function ChatInput({
initialValue = '',
setView,
numTokens,
inputTokens,
outputTokens,
droppedFiles = [],
messages = [],
setMessages,
sessionCosts,
}: ChatInputProps) {
const [_value, setValue] = useState(initialValue);
const [displayValue, setDisplayValue] = useState(initialValue); // For immediate visual feedback
@@ -557,9 +569,12 @@ export default function ChatInput({
<BottomMenu
setView={setView}
numTokens={numTokens}
inputTokens={inputTokens}
outputTokens={outputTokens}
messages={messages}
isLoading={isLoading}
setMessages={setMessages}
sessionCosts={sessionCosts}
/>
</div>
</div>

View File

@@ -33,6 +33,8 @@ import {
} from './context_management/ChatContextManager';
import { ContextHandler } from './context_management/ContextHandler';
import { LocalMessageStorage } from '../utils/localMessageStorage';
import { useModelAndProvider } from './ModelAndProviderContext';
import { getCostForModel } from '../utils/costDatabase';
import {
Message,
createUserMessage,
@@ -106,11 +108,25 @@ function ChatContent({
const [showGame, setShowGame] = useState(false);
const [isGeneratingRecipe, setIsGeneratingRecipe] = useState(false);
const [sessionTokenCount, setSessionTokenCount] = useState<number>(0);
const [sessionInputTokens, setSessionInputTokens] = useState<number>(0);
const [sessionOutputTokens, setSessionOutputTokens] = useState<number>(0);
const [localInputTokens, setLocalInputTokens] = useState<number>(0);
const [localOutputTokens, setLocalOutputTokens] = useState<number>(0);
const [ancestorMessages, setAncestorMessages] = useState<Message[]>([]);
const [droppedFiles, setDroppedFiles] = useState<string[]>([]);
const [sessionCosts, setSessionCosts] = useState<{
[key: string]: {
inputTokens: number;
outputTokens: number;
totalCost: number;
};
}>({});
const [readyForAutoUserPrompt, setReadyForAutoUserPrompt] = useState(false);
const scrollRef = useRef<ScrollAreaHandle>(null);
const { currentModel, currentProvider } = useModelAndProvider();
const prevModelRef = useRef<string | undefined>();
const prevProviderRef = useRef<string | undefined>();
const {
summaryContent,
@@ -160,6 +176,7 @@ function ChatContent({
updateMessageStreamBody,
notifications,
currentModelInfo,
sessionMetadata,
} = useMessageStream({
api: getApiUrl('/reply'),
initialMessages: chat.messages,
@@ -518,12 +535,40 @@ function ChatContent({
.reverse();
}, [filteredMessages]);
// Simple token estimation function (roughly 4 characters per token)
const estimateTokens = (text: string): number => {
return Math.ceil(text.length / 4);
};
// Calculate token counts from messages
useEffect(() => {
let inputTokens = 0;
let outputTokens = 0;
messages.forEach((message) => {
const textContent = getTextContent(message);
if (textContent) {
const tokens = estimateTokens(textContent);
if (message.role === 'user') {
inputTokens += tokens;
} else if (message.role === 'assistant') {
outputTokens += tokens;
}
}
});
setLocalInputTokens(inputTokens);
setLocalOutputTokens(outputTokens);
}, [messages]);
// Fetch session metadata to get token count
useEffect(() => {
const fetchSessionTokens = async () => {
try {
const sessionDetails = await fetchSessionDetails(chat.id);
setSessionTokenCount(sessionDetails.metadata.total_tokens || 0);
setSessionInputTokens(sessionDetails.metadata.accumulated_input_tokens || 0);
setSessionOutputTokens(sessionDetails.metadata.accumulated_output_tokens || 0);
} catch (err) {
console.error('Error fetching session token count:', err);
}
@@ -533,6 +578,74 @@ function ChatContent({
}
}, [chat.id, messages]);
// Update token counts when sessionMetadata changes from the message stream
useEffect(() => {
console.log('Session metadata received:', sessionMetadata);
if (sessionMetadata) {
setSessionTokenCount(sessionMetadata.totalTokens || 0);
setSessionInputTokens(sessionMetadata.accumulatedInputTokens || 0);
setSessionOutputTokens(sessionMetadata.accumulatedOutputTokens || 0);
}
}, [sessionMetadata]);
// Handle model changes and accumulate costs
useEffect(() => {
if (
prevModelRef.current !== undefined &&
prevProviderRef.current !== undefined &&
(prevModelRef.current !== currentModel || prevProviderRef.current !== currentProvider)
) {
// Model/provider has changed, save the costs for the previous model
const prevKey = `${prevProviderRef.current}/${prevModelRef.current}`;
// Get pricing info for the previous model
const prevCostInfo = getCostForModel(prevProviderRef.current, prevModelRef.current);
if (prevCostInfo) {
const prevInputCost =
(sessionInputTokens || localInputTokens) * (prevCostInfo.input_token_cost || 0);
const prevOutputCost =
(sessionOutputTokens || localOutputTokens) * (prevCostInfo.output_token_cost || 0);
const prevTotalCost = prevInputCost + prevOutputCost;
// Save the accumulated costs for this model
setSessionCosts((prev) => ({
...prev,
[prevKey]: {
inputTokens: sessionInputTokens || localInputTokens,
outputTokens: sessionOutputTokens || localOutputTokens,
totalCost: prevTotalCost,
},
}));
}
// Reset token counters for the new model
setSessionTokenCount(0);
setSessionInputTokens(0);
setSessionOutputTokens(0);
setLocalInputTokens(0);
setLocalOutputTokens(0);
console.log(
'Model changed from',
`${prevProviderRef.current}/${prevModelRef.current}`,
'to',
`${currentProvider}/${currentModel}`,
'- saved costs and reset token counters'
);
}
prevModelRef.current = currentModel || undefined;
prevProviderRef.current = currentProvider || undefined;
}, [
currentModel,
currentProvider,
sessionInputTokens,
sessionOutputTokens,
localInputTokens,
localOutputTokens,
]);
const handleDrop = (e: React.DragEvent<HTMLDivElement>) => {
e.preventDefault();
const files = e.dataTransfer.files;
@@ -684,9 +797,12 @@ function ChatContent({
setView={setView}
hasMessages={hasMessages}
numTokens={sessionTokenCount}
inputTokens={sessionInputTokens || localInputTokens}
outputTokens={sessionOutputTokens || localOutputTokens}
droppedFiles={droppedFiles}
messages={messages}
setMessages={setMessages}
sessionCosts={sessionCosts}
/>
</div>
</Card>

View File

@@ -9,6 +9,7 @@ import { useConfig } from '../ConfigContext';
import { useModelAndProvider } from '../ModelAndProviderContext';
import { Message } from '../../types/message';
import { ManualSummarizeButton } from '../context_management/ManualSummaryButton';
import { CostTracker } from './CostTracker';
const TOKEN_LIMIT_DEFAULT = 128000; // fallback for custom models that the backend doesn't know about
const TOKEN_WARNING_THRESHOLD = 0.8; // warning shows at 80% of the token limit
@@ -22,15 +23,27 @@ interface ModelLimit {
export default function BottomMenu({
setView,
numTokens = 0,
inputTokens = 0,
outputTokens = 0,
messages = [],
isLoading = false,
setMessages,
sessionCosts,
}: {
setView: (view: View, viewOptions?: ViewOptions) => void;
numTokens?: number;
inputTokens?: number;
outputTokens?: number;
messages?: Message[];
isLoading?: boolean;
setMessages: (messages: Message[]) => void;
sessionCosts?: {
[key: string]: {
inputTokens: number;
outputTokens: number;
totalCost: number;
};
};
}) {
const [isModelMenuOpen, setIsModelMenuOpen] = useState(false);
const { alerts, addAlert, clearAlerts } = useAlerts();
@@ -202,29 +215,45 @@ export default function BottomMenu({
}, [isModelMenuOpen]);
return (
<div className="flex justify-between items-center transition-colors text-textSubtle relative text-xs align-middle">
<div className="flex items-center pl-2">
<div className="flex justify-between items-center transition-colors text-textSubtle relative text-xs h-6">
<div className="flex items-center h-full">
{/* Tool and Token count */}
{<BottomMenuAlertPopover alerts={alerts} />}
<div className="flex items-center h-full pl-2">
{<BottomMenuAlertPopover alerts={alerts} />}
</div>
{/* Cost Tracker - no separator before it */}
<div className="flex items-center h-full ml-1">
<CostTracker inputTokens={inputTokens} outputTokens={outputTokens} sessionCosts={sessionCosts} />
</div>
{/* Separator between cost and model */}
<div className="w-[1px] h-4 bg-borderSubtle mx-1.5" />
{/* Model Selector Dropdown */}
<ModelsBottomBar dropdownRef={dropdownRef} setView={setView} />
<div className="flex items-center h-full">
<ModelsBottomBar dropdownRef={dropdownRef} setView={setView} />
</div>
{/* Separator */}
<div className="w-[1px] h-4 bg-borderSubtle mx-2" />
<div className="w-[1px] h-4 bg-borderSubtle mx-1.5" />
{/* Goose Mode Selector Dropdown */}
<BottomMenuModeSelection setView={setView} />
<div className="flex items-center h-full">
<BottomMenuModeSelection setView={setView} />
</div>
{/* Summarize Context Button - ADD THIS */}
{/* Summarize Context Button */}
{messages.length > 0 && (
<>
<div className="w-[1px] h-4 bg-borderSubtle mx-2" />
<ManualSummarizeButton
messages={messages}
isLoading={isLoading}
setMessages={setMessages}
/>
<div className="w-[1px] h-4 bg-borderSubtle mx-1.5" />
<div className="flex items-center h-full">
<ManualSummarizeButton
messages={messages}
isLoading={isLoading}
setMessages={setMessages}
/>
</div>
</>
)}
</div>

View File

@@ -0,0 +1,310 @@
import { useState, useEffect } from 'react';
import { useModelAndProvider } from '../ModelAndProviderContext';
import { useConfig } from '../ConfigContext';
import {
getCostForModel,
initializeCostDatabase,
updateAllModelCosts,
fetchAndCachePricing,
} from '../../utils/costDatabase';
interface CostTrackerProps {
inputTokens?: number;
outputTokens?: number;
sessionCosts?: {
[key: string]: {
inputTokens: number;
outputTokens: number;
totalCost: number;
};
};
}
export function CostTracker({ inputTokens = 0, outputTokens = 0, sessionCosts }: CostTrackerProps) {
const { currentModel, currentProvider } = useModelAndProvider();
const { getProviders } = useConfig();
const [costInfo, setCostInfo] = useState<{
input_token_cost?: number;
output_token_cost?: number;
currency?: string;
} | null>(null);
const [isLoading, setIsLoading] = useState(true);
const [showPricing, setShowPricing] = useState(true);
const [pricingFailed, setPricingFailed] = useState(false);
const [modelNotFound, setModelNotFound] = useState(false);
const [hasAttemptedFetch, setHasAttemptedFetch] = useState(false);
const [initialLoadComplete, setInitialLoadComplete] = useState(false);
// Check if pricing is enabled
useEffect(() => {
const checkPricingSetting = () => {
const stored = localStorage.getItem('show_pricing');
setShowPricing(stored !== 'false');
};
// Check on mount
checkPricingSetting();
// Listen for storage changes
window.addEventListener('storage', checkPricingSetting);
return () => window.removeEventListener('storage', checkPricingSetting);
}, []);
// Set initial load complete after a short delay
useEffect(() => {
const timer = setTimeout(() => {
setInitialLoadComplete(true);
}, 3000); // Give 3 seconds for initial load
return () => window.clearTimeout(timer);
}, []);
// Debug log props removed
// Initialize cost database on mount
useEffect(() => {
initializeCostDatabase();
// Update costs for all models in background
updateAllModelCosts().catch((error) => {
console.error('Failed to update model costs:', error);
});
}, [getProviders]);
useEffect(() => {
const loadCostInfo = async () => {
if (!currentModel || !currentProvider) {
setIsLoading(false);
return;
}
console.log(`CostTracker: Loading cost info for ${currentProvider}/${currentModel}`);
try {
// First check sync cache
let costData = getCostForModel(currentProvider, currentModel);
if (costData) {
// We have cached data
console.log(
`CostTracker: Found cached data for ${currentProvider}/${currentModel}:`,
costData
);
setCostInfo(costData);
setPricingFailed(false);
setModelNotFound(false);
setIsLoading(false);
setHasAttemptedFetch(true);
} else {
// Need to fetch from backend
console.log(
`CostTracker: No cached data, fetching from backend for ${currentProvider}/${currentModel}`
);
setIsLoading(true);
const result = await fetchAndCachePricing(currentProvider, currentModel);
setHasAttemptedFetch(true);
if (result && result.costInfo) {
console.log(
`CostTracker: Fetched data for ${currentProvider}/${currentModel}:`,
result.costInfo
);
setCostInfo(result.costInfo);
setPricingFailed(false);
setModelNotFound(false);
} else if (result && result.error === 'model_not_found') {
console.log(
`CostTracker: Model not found in pricing data for ${currentProvider}/${currentModel}`
);
// Model not found in pricing database, but API call succeeded
setModelNotFound(true);
setPricingFailed(false);
} else {
console.log(`CostTracker: API failed for ${currentProvider}/${currentModel}`);
// API call failed or other error
const freeProviders = ['ollama', 'local', 'localhost'];
if (!freeProviders.includes(currentProvider.toLowerCase())) {
setPricingFailed(true);
setModelNotFound(false);
}
}
setIsLoading(false);
}
} catch (error) {
console.error('Error loading cost info:', error);
setHasAttemptedFetch(true);
// Only set pricing failed if we're not dealing with a known free provider
const freeProviders = ['ollama', 'local', 'localhost'];
if (!freeProviders.includes(currentProvider.toLowerCase())) {
setPricingFailed(true);
setModelNotFound(false);
}
setIsLoading(false);
}
};
loadCostInfo();
}, [currentModel, currentProvider]);
// Return null early if pricing is disabled
if (!showPricing) {
return null;
}
const calculateCost = (): number => {
// If we have session costs, calculate the total across all models
if (sessionCosts) {
let totalCost = 0;
// Add up all historical costs from different models
Object.values(sessionCosts).forEach((modelCost) => {
totalCost += modelCost.totalCost;
});
// Add current model cost if we have pricing info
if (
costInfo &&
(costInfo.input_token_cost !== undefined || costInfo.output_token_cost !== undefined)
) {
const currentInputCost = inputTokens * (costInfo.input_token_cost || 0);
const currentOutputCost = outputTokens * (costInfo.output_token_cost || 0);
totalCost += currentInputCost + currentOutputCost;
}
return totalCost;
}
// Fallback to simple calculation for current model only
if (
!costInfo ||
(costInfo.input_token_cost === undefined && costInfo.output_token_cost === undefined)
) {
return 0;
}
const inputCost = inputTokens * (costInfo.input_token_cost || 0);
const outputCost = outputTokens * (costInfo.output_token_cost || 0);
const total = inputCost + outputCost;
return total;
};
const formatCost = (cost: number): string => {
// Always show 6 decimal places for consistency
return cost.toFixed(6);
};
// Debug logging removed
// Show loading state or when we don't have model/provider info
if (!currentModel || !currentProvider) {
return null;
}
// If still loading, show a placeholder
if (isLoading) {
return (
<div className="flex items-center justify-center h-full text-textSubtle translate-y-[1px]">
<span className="text-xs font-mono">...</span>
</div>
);
}
// If no cost info found, try to return a default
if (
!costInfo ||
(costInfo.input_token_cost === undefined && costInfo.output_token_cost === undefined)
) {
// If it's a known free/local provider, show $0.000000 without "not available" message
const freeProviders = ['ollama', 'local', 'localhost'];
if (freeProviders.includes(currentProvider.toLowerCase())) {
return (
<div
className="flex items-center justify-center h-full text-textSubtle hover:text-textStandard transition-colors cursor-default translate-y-[1px]"
title={`Local model (${inputTokens.toLocaleString()} input, ${outputTokens.toLocaleString()} output tokens)`}
>
<span className="text-xs font-mono">$0.000000</span>
</div>
);
}
// Otherwise show as unavailable
const getUnavailableTooltip = () => {
if (pricingFailed && hasAttemptedFetch && initialLoadComplete) {
return `Pricing data unavailable - OpenRouter connection failed. Click refresh in settings to retry.`;
}
// If we reach here, it must be modelNotFound (since we only get here after attempting fetch)
return `Cost data not available for ${currentModel} (${inputTokens.toLocaleString()} input, ${outputTokens.toLocaleString()} output tokens)`;
};
return (
<div
className={`flex items-center justify-center h-full transition-colors cursor-default translate-y-[1px] ${
(pricingFailed || modelNotFound) && hasAttemptedFetch && initialLoadComplete
? 'text-red-500 hover:text-red-400'
: 'text-textSubtle hover:text-textStandard'
}`}
title={getUnavailableTooltip()}
>
<span className="text-xs font-mono">$0.000000</span>
</div>
);
}
const totalCost = calculateCost();
// Build tooltip content
const getTooltipContent = (): string => {
// Handle error states first
if (pricingFailed && hasAttemptedFetch && initialLoadComplete) {
return `Pricing data unavailable - OpenRouter connection failed. Click refresh in settings to retry.`;
}
if (modelNotFound && hasAttemptedFetch && initialLoadComplete) {
return `Pricing not available for ${currentProvider}/${currentModel}. This model may not be supported by the pricing service.`;
}
// Handle session costs
if (sessionCosts && Object.keys(sessionCosts).length > 0) {
// Show session breakdown
let tooltip = 'Session cost breakdown:\n';
Object.entries(sessionCosts).forEach(([modelKey, cost]) => {
const costStr = `${costInfo?.currency || '$'}${cost.totalCost.toFixed(6)}`;
tooltip += `${modelKey}: ${costStr} (${cost.inputTokens.toLocaleString()} in, ${cost.outputTokens.toLocaleString()} out)\n`;
});
// Add current model if it has costs
if (costInfo && (inputTokens > 0 || outputTokens > 0)) {
const currentCost =
inputTokens * (costInfo.input_token_cost || 0) +
outputTokens * (costInfo.output_token_cost || 0);
if (currentCost > 0) {
tooltip += `${currentProvider}/${currentModel} (current): ${costInfo.currency || '$'}${currentCost.toFixed(6)} (${inputTokens.toLocaleString()} in, ${outputTokens.toLocaleString()} out)\n`;
}
}
tooltip += `\nTotal session cost: ${costInfo?.currency || '$'}${totalCost.toFixed(6)}`;
return tooltip;
}
// Default tooltip for single model
return `Input: ${inputTokens.toLocaleString()} tokens (${costInfo?.currency || '$'}${(inputTokens * (costInfo?.input_token_cost || 0)).toFixed(6)}) | Output: ${outputTokens.toLocaleString()} tokens (${costInfo?.currency || '$'}${(outputTokens * (costInfo?.output_token_cost || 0)).toFixed(6)})`;
};
return (
<div
className={`flex items-center justify-center h-full transition-colors cursor-default translate-y-[1px] ${
(pricingFailed || modelNotFound) && hasAttemptedFetch && initialLoadComplete
? 'text-red-500 hover:text-red-400'
: 'text-textSubtle hover:text-textStandard'
}`}
title={getTooltipContent()}
>
<span className="text-xs font-mono">
{costInfo.currency || '$'}
{formatCost(totalCost)}
</span>
</div>
);
}

View File

@@ -1,10 +1,11 @@
import { useState, useEffect, useRef } from 'react';
import { Switch } from '../../ui/switch';
import { Button } from '../../ui/button';
import { Settings } from 'lucide-react';
import { Settings, RefreshCw, ExternalLink } from 'lucide-react';
import Modal from '../../Modal';
import UpdateSection from './UpdateSection';
import { UPDATES_ENABLED } from '../../../updates';
import { getApiUrl, getSecretKey } from '../../../config';
interface AppSettingsSectionProps {
scrollToSection?: string;
@@ -17,6 +18,10 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
const [isMacOS, setIsMacOS] = useState(false);
const [isDockSwitchDisabled, setIsDockSwitchDisabled] = useState(false);
const [showNotificationModal, setShowNotificationModal] = useState(false);
const [pricingStatus, setPricingStatus] = useState<'loading' | 'success' | 'error'>('loading');
const [lastFetchTime, setLastFetchTime] = useState<Date | null>(null);
const [isRefreshing, setIsRefreshing] = useState(false);
const [showPricing, setShowPricing] = useState(true);
const updateSectionRef = useRef<HTMLDivElement>(null);
// Check if running on macOS
@@ -24,6 +29,77 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
setIsMacOS(window.electron.platform === 'darwin');
}, []);
// Load show pricing setting
useEffect(() => {
const stored = localStorage.getItem('show_pricing');
setShowPricing(stored !== 'false');
}, []);
// Check pricing status on mount
useEffect(() => {
checkPricingStatus();
}, []);
const checkPricingStatus = async () => {
try {
const apiUrl = getApiUrl('/config/pricing');
const secretKey = getSecretKey();
const headers: HeadersInit = { 'Content-Type': 'application/json' };
if (secretKey) {
headers['X-Secret-Key'] = secretKey;
}
const response = await fetch(apiUrl, {
method: 'POST',
headers,
body: JSON.stringify({ configured_only: true }),
});
if (response.ok) {
await response.json(); // Consume the response
setPricingStatus('success');
setLastFetchTime(new Date());
} else {
setPricingStatus('error');
}
} catch (error) {
setPricingStatus('error');
}
};
const handleRefreshPricing = async () => {
setIsRefreshing(true);
try {
const apiUrl = getApiUrl('/config/pricing');
const secretKey = getSecretKey();
const headers: HeadersInit = { 'Content-Type': 'application/json' };
if (secretKey) {
headers['X-Secret-Key'] = secretKey;
}
const response = await fetch(apiUrl, {
method: 'POST',
headers,
body: JSON.stringify({ configured_only: false }),
});
if (response.ok) {
setPricingStatus('success');
setLastFetchTime(new Date());
// Trigger a reload of the cost database
window.dispatchEvent(new CustomEvent('pricing-updated'));
} else {
setPricingStatus('error');
}
} catch (error) {
setPricingStatus('error');
} finally {
setIsRefreshing(false);
}
};
// Handle scrolling to update section
useEffect(() => {
if (scrollToSection === 'update' && updateSectionRef.current) {
@@ -99,6 +175,13 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
}
};
const handleShowPricingToggle = (checked: boolean) => {
setShowPricing(checked);
localStorage.setItem('show_pricing', String(checked));
// Trigger storage event for other components
window.dispatchEvent(new CustomEvent('storage'));
};
return (
<section id="appSettings" className="px-8">
<div className="flex justify-between items-center mb-2">
@@ -173,6 +256,7 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
</div>
)}
{/* Quit Confirmation */}
<div className="flex items-center justify-between mb-4">
<div>
<h3 className="text-textStandard">Quit Confirmation</h3>
@@ -188,6 +272,87 @@ export default function AppSettingsSection({ scrollToSection }: AppSettingsSecti
/>
</div>
</div>
{/* Cost Tracking */}
<div className="flex items-center justify-between mb-4">
<div>
<h3 className="text-textStandard">Cost Tracking</h3>
<p className="text-xs text-textSubtle max-w-md mt-[2px]">
Show model pricing and usage costs
</p>
</div>
<div className="flex items-center">
<Switch
checked={showPricing}
onCheckedChange={handleShowPricingToggle}
variant="mono"
/>
</div>
</div>
{/* Pricing Status - only show if cost tracking is enabled */}
{showPricing && (
<>
<div className="flex items-center justify-between text-xs mb-2 px-4">
<span className="text-textSubtle">Pricing Source:</span>
<a
href="https://openrouter.ai/docs#models"
target="_blank"
rel="noopener noreferrer"
className="text-blue-600 dark:text-blue-400 hover:underline flex items-center gap-1"
>
OpenRouter Docs
<ExternalLink size={10} />
</a>
</div>
<div className="flex items-center justify-between text-xs mb-2 px-4">
<span className="text-textSubtle">Status:</span>
<div className="flex items-center gap-2">
<span
className={`font-medium ${
pricingStatus === 'success'
? 'text-green-600 dark:text-green-400'
: pricingStatus === 'error'
? 'text-red-600 dark:text-red-400'
: 'text-textSubtle'
}`}
>
{pricingStatus === 'success'
? '✓ Connected'
: pricingStatus === 'error'
? '✗ Failed'
: '... Checking'}
</span>
<button
className="p-0.5 hover:bg-gray-200 dark:hover:bg-gray-700 rounded transition-colors disabled:opacity-50"
onClick={handleRefreshPricing}
disabled={isRefreshing}
title="Refresh pricing data"
type="button"
>
<RefreshCw
size={8}
className={`text-textSubtle hover:text-textStandard ${isRefreshing ? 'animate-spin-fast' : ''}`}
/>
</button>
</div>
</div>
{lastFetchTime && (
<div className="flex items-center justify-between text-xs mb-2 px-4">
<span className="text-textSubtle">Last updated:</span>
<span className="text-textSubtle">{lastFetchTime.toLocaleTimeString()}</span>
</div>
)}
{pricingStatus === 'error' && (
<p className="text-xs text-red-600 dark:text-red-400 px-4">
Unable to fetch pricing data. Costs will not be displayed.
</p>
)}
</>
)}
</div>
{/* Help & Feedback Section */}

View File

@@ -43,7 +43,10 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
}, [read]);
// Determine which model to display - activeModel takes priority when lead/worker is active
const displayModel = (isLeadWorkerActive && currentModelInfo?.model) ? currentModelInfo.model : (currentModel || 'Select Model');
const displayModel =
isLeadWorkerActive && currentModelInfo?.model
? currentModelInfo.model
: currentModel || 'Select Model';
const modelMode = currentModelInfo?.mode;
// Update display provider when current provider changes
@@ -106,9 +109,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
>
{displayModel}
{isLeadWorkerActive && modelMode && (
<span className="ml-1 text-[10px] opacity-60">
({modelMode})
</span>
<span className="ml-1 text-[10px] opacity-60">({modelMode})</span>
)}
</span>
</TooltipTrigger>
@@ -116,9 +117,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
<TooltipContent className="max-w-96 overflow-auto scrollbar-thin" side="top">
{displayModel}
{isLeadWorkerActive && modelMode && (
<span className="ml-1 text-[10px] opacity-60">
({modelMode})
</span>
<span className="ml-1 text-[10px] opacity-60">({modelMode})</span>
)}
</TooltipContent>
)}
@@ -164,7 +163,7 @@ export default function ModelsBottomBar({ dropdownRef, setView }: ModelsBottomBa
{isAddModelModalOpen ? (
<AddModelModal setView={setView} onClose={() => setIsAddModelModalOpen(false)} />
) : null}
{isLeadWorkerModalOpen ? (
<Modal onClose={() => setIsLeadWorkerModalOpen(false)}>
<LeadWorkerSettings onClose={() => setIsLeadWorkerModalOpen(false)} />

View File

@@ -21,7 +21,9 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
const [failureThreshold, setFailureThreshold] = useState<number>(2);
const [fallbackTurns, setFallbackTurns] = useState<number>(2);
const [isEnabled, setIsEnabled] = useState(false);
const [modelOptions, setModelOptions] = useState<{ value: string; label: string; provider: string }[]>([]);
const [modelOptions, setModelOptions] = useState<
{ value: string; label: string; provider: string }[]
>([]);
const [isLoading, setIsLoading] = useState(true);
// Load current configuration
@@ -51,7 +53,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
if (leadTurnsConfig) setLeadTurns(Number(leadTurnsConfig));
if (failureThresholdConfig) setFailureThreshold(Number(failureThresholdConfig));
if (fallbackTurnsConfig) setFallbackTurns(Number(fallbackTurnsConfig));
// Set worker model to current model or from config
const workerModelConfig = await read('GOOSE_MODEL', false);
if (workerModelConfig) {
@@ -59,7 +61,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
} else if (currentModel) {
setWorkerModel(currentModel as string);
}
const workerProviderConfig = await read('GOOSE_PROVIDER', false);
if (workerProviderConfig) {
setWorkerProvider(workerProviderConfig as string);
@@ -69,7 +71,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
const providers = await getProviders(false);
const activeProviders = providers.filter((p) => p.is_configured);
const options: { value: string; label: string; provider: string }[] = [];
activeProviders.forEach(({ metadata, name }) => {
if (metadata.known_models) {
metadata.known_models.forEach((model) => {
@@ -81,7 +83,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
});
}
});
setModelOptions(options);
} catch (error) {
console.error('Error loading configuration:', error);
@@ -184,9 +186,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
placeholder="Select worker model..."
isDisabled={!isEnabled}
/>
<p className="text-xs text-textSubtle">
Fast model for routine execution tasks
</p>
<p className="text-xs text-textSubtle">Fast model for routine execution tasks</p>
</div>
<div className="space-y-4 pt-4 border-t border-borderSubtle">
@@ -242,9 +242,7 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
className="w-20"
disabled={!isEnabled}
/>
<p className="text-xs text-textSubtle">
Turns to use lead model during fallback
</p>
<p className="text-xs text-textSubtle">Turns to use lead model during fallback</p>
</div>
</div>
</div>
@@ -259,4 +257,4 @@ export function LeadWorkerSettings({ onClose }: LeadWorkerSettingsProps) {
</div>
</div>
);
}
}

View File

@@ -2,9 +2,9 @@ import { useCurrentModelInfo } from '../components/ChatView';
export function useCurrentModel() {
const modelInfo = useCurrentModelInfo();
return {
return {
currentModel: modelInfo?.model || null,
isLoading: false
isLoading: false,
};
}
}

View File

@@ -2,12 +2,26 @@ import { useState, useCallback, useEffect, useRef, useId } from 'react';
import useSWR from 'swr';
import { getSecretKey } from '../config';
import { Message, createUserMessage, hasCompletedToolCalls } from '../types/message';
import { getSessionHistory } from '../api';
// Ensure TextDecoder is available in the global scope
const TextDecoder = globalThis.TextDecoder;
type JsonValue = string | number | boolean | null | JsonValue[] | { [key: string]: JsonValue };
export interface SessionMetadata {
workingDir: string;
description: string;
scheduleId: string | null;
messageCount: number;
totalTokens: number | null;
inputTokens: number | null;
outputTokens: number | null;
accumulatedTotalTokens: number | null;
accumulatedInputTokens: number | null;
accumulatedOutputTokens: number | null;
}
export interface NotificationEvent {
type: 'Notification';
request_id: string;
@@ -141,9 +155,12 @@ export interface UseMessageStreamHelpers {
updateMessageStreamBody?: (newBody: object) => void;
notifications: NotificationEvent[];
/** Current model info from the backend */
currentModelInfo: { model: string; mode: string } | null;
/** Session metadata including token counts */
sessionMetadata: SessionMetadata | null;
}
/**
@@ -172,7 +189,10 @@ export function useMessageStream({
});
const [notifications, setNotifications] = useState<NotificationEvent[]>([]);
const [currentModelInfo, setCurrentModelInfo] = useState<{ model: string; mode: string } | null>(null);
const [currentModelInfo, setCurrentModelInfo] = useState<{ model: string; mode: string } | null>(
null
);
const [sessionMetadata, setSessionMetadata] = useState<SessionMetadata | null>(null);
// expose a way to update the body so we can update the session id when CLE occurs
const updateMessageStreamBody = useCallback((newBody: object) => {
@@ -291,13 +311,41 @@ export function useMessageStream({
case 'Error':
throw new Error(parsedEvent.error);
case 'Finish':
case 'Finish': {
// Call onFinish with the last message if available
if (onFinish && currentMessages.length > 0) {
const lastMessage = currentMessages[currentMessages.length - 1];
onFinish(lastMessage, parsedEvent.reason);
}
// Fetch updated session metadata with token counts
const sessionId = (extraMetadataRef.current.body as Record<string, unknown>)?.session_id as string;
if (sessionId) {
try {
const sessionResponse = await getSessionHistory({
path: { session_id: sessionId },
});
if (sessionResponse.data?.metadata) {
setSessionMetadata({
workingDir: sessionResponse.data.metadata.working_dir,
description: sessionResponse.data.metadata.description,
scheduleId: sessionResponse.data.metadata.schedule_id || null,
messageCount: sessionResponse.data.metadata.message_count,
totalTokens: sessionResponse.data.metadata.total_tokens || null,
inputTokens: sessionResponse.data.metadata.input_tokens || null,
outputTokens: sessionResponse.data.metadata.output_tokens || null,
accumulatedTotalTokens: sessionResponse.data.metadata.accumulated_total_tokens || null,
accumulatedInputTokens: sessionResponse.data.metadata.accumulated_input_tokens || null,
accumulatedOutputTokens: sessionResponse.data.metadata.accumulated_output_tokens || null,
});
}
} catch (error) {
console.error('Failed to fetch session metadata:', error);
}
}
break;
}
}
} catch (e) {
console.error('Error parsing SSE event:', e);
@@ -559,5 +607,6 @@ export function useMessageStream({
updateMessageStreamBody,
notifications,
currentModelInfo,
sessionMetadata,
};
}

View File

@@ -7,6 +7,10 @@ export interface SessionMetadata {
message_count: number;
total_tokens: number | null;
working_dir: string; // Required in type, but may be missing in old sessions
// Add the accumulated token fields from the API
accumulated_input_tokens?: number | null;
accumulated_output_tokens?: number | null;
accumulated_total_tokens?: number | null;
}
// Helper function to ensure working directory is set
@@ -16,6 +20,9 @@ export function ensureWorkingDir(metadata: Partial<SessionMetadata>): SessionMet
message_count: metadata.message_count || 0,
total_tokens: metadata.total_tokens || null,
working_dir: metadata.working_dir || process.env.HOME || '',
accumulated_input_tokens: metadata.accumulated_input_tokens || null,
accumulated_output_tokens: metadata.accumulated_output_tokens || null,
accumulated_total_tokens: metadata.accumulated_total_tokens || null,
};
}

View File

@@ -0,0 +1,593 @@
// Import the proper type from ConfigContext
import { getApiUrl, getSecretKey } from '../config';
export interface ModelCostInfo {
input_token_cost: number; // Cost per token for input (in USD)
output_token_cost: number; // Cost per token for output (in USD)
currency: string; // Currency symbol
}
// In-memory cache for current model pricing only
let currentModelPricing: {
provider: string;
model: string;
costInfo: ModelCostInfo | null;
} | null = null;
// LocalStorage keys
const PRICING_CACHE_KEY = 'goose_pricing_cache';
const PRICING_CACHE_TIMESTAMP_KEY = 'goose_pricing_cache_timestamp';
const RECENTLY_USED_MODELS_KEY = 'goose_recently_used_models';
const CACHE_TTL_MS = 7 * 24 * 60 * 60 * 1000; // 7 days in milliseconds
const MAX_RECENTLY_USED_MODELS = 20; // Keep only the last 20 used models in cache
interface PricingItem {
provider: string;
model: string;
input_token_cost: number;
output_token_cost: number;
currency: string;
}
interface PricingCacheData {
pricing: PricingItem[];
timestamp: number;
}
interface RecentlyUsedModel {
provider: string;
model: string;
lastUsed: number;
}
/**
* Get recently used models from localStorage
*/
function getRecentlyUsedModels(): RecentlyUsedModel[] {
try {
const stored = localStorage.getItem(RECENTLY_USED_MODELS_KEY);
return stored ? JSON.parse(stored) : [];
} catch (error) {
console.error('Error loading recently used models:', error);
return [];
}
}
/**
* Add a model to the recently used list
*/
function addToRecentlyUsed(provider: string, model: string): void {
try {
let recentModels = getRecentlyUsedModels();
// Remove existing entry if present
recentModels = recentModels.filter((m) => !(m.provider === provider && m.model === model));
// Add to front
recentModels.unshift({ provider, model, lastUsed: Date.now() });
// Keep only the most recent models
recentModels = recentModels.slice(0, MAX_RECENTLY_USED_MODELS);
localStorage.setItem(RECENTLY_USED_MODELS_KEY, JSON.stringify(recentModels));
} catch (error) {
console.error('Error saving recently used models:', error);
}
}
/**
* Load pricing data from localStorage cache - only for recently used models
*/
function loadPricingFromLocalStorage(): PricingCacheData | null {
try {
const cached = localStorage.getItem(PRICING_CACHE_KEY);
const timestamp = localStorage.getItem(PRICING_CACHE_TIMESTAMP_KEY);
if (cached && timestamp) {
const cacheAge = Date.now() - parseInt(timestamp, 10);
if (cacheAge < CACHE_TTL_MS) {
const fullCache = JSON.parse(cached) as PricingCacheData;
const recentModels = getRecentlyUsedModels();
// Filter to only include recently used models
const filteredPricing = fullCache.pricing.filter((p) =>
recentModels.some((r) => r.provider === p.provider && r.model === p.model)
);
console.log(
`Loading ${filteredPricing.length} recently used models from cache (out of ${fullCache.pricing.length} total)`
);
return {
pricing: filteredPricing,
timestamp: fullCache.timestamp,
};
} else {
console.log('LocalStorage pricing cache expired');
}
}
} catch (error) {
console.error('Error loading pricing from localStorage:', error);
}
return null;
}
/**
* Save pricing data to localStorage - merge with existing data
*/
function savePricingToLocalStorage(data: PricingCacheData, mergeWithExisting = true): void {
try {
if (mergeWithExisting) {
// Load existing full cache
const existingCached = localStorage.getItem(PRICING_CACHE_KEY);
if (existingCached) {
const existingData = JSON.parse(existingCached) as PricingCacheData;
// Create a map of existing pricing for quick lookup
const pricingMap = new Map<string, (typeof data.pricing)[0]>();
existingData.pricing.forEach((p) => {
pricingMap.set(`${p.provider}/${p.model}`, p);
});
// Update with new data
data.pricing.forEach((p) => {
pricingMap.set(`${p.provider}/${p.model}`, p);
});
// Convert back to array
data = {
pricing: Array.from(pricingMap.values()),
timestamp: data.timestamp,
};
}
}
localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(data));
localStorage.setItem(PRICING_CACHE_TIMESTAMP_KEY, data.timestamp.toString());
console.log(`Saved ${data.pricing.length} models to localStorage cache`);
} catch (error) {
console.error('Error saving pricing to localStorage:', error);
}
}
/**
* Fetch pricing data from backend for specific provider/model
*/
async function fetchPricingForModel(
provider: string,
model: string
): Promise<ModelCostInfo | null> {
try {
const apiUrl = getApiUrl('/config/pricing');
const secretKey = getSecretKey();
console.log(`Fetching pricing for ${provider}/${model} from ${apiUrl}`);
const headers: HeadersInit = { 'Content-Type': 'application/json' };
if (secretKey) {
headers['X-Secret-Key'] = secretKey;
}
const response = await fetch(apiUrl, {
method: 'POST',
headers,
body: JSON.stringify({ configured_only: false }),
});
if (!response.ok) {
console.error('Failed to fetch pricing data:', response.status);
throw new Error(`API request failed with status ${response.status}`);
}
const data = await response.json();
console.log('Pricing response:', data);
// Find the specific model pricing
const pricing = data.pricing?.find(
(p: {
provider: string;
model: string;
input_token_cost: number;
output_token_cost: number;
currency: string;
}) => {
const providerMatch = p.provider.toLowerCase() === provider.toLowerCase();
// More flexible model matching - handle versioned models
let modelMatch = p.model === model;
// If exact match fails, try matching without version suffix
if (!modelMatch && model.includes('-20')) {
// Remove date suffix like -20241022
const modelWithoutDate = model.replace(/-20\d{6}$/, '');
modelMatch = p.model === modelWithoutDate;
// Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet)
if (!modelMatch) {
const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.');
modelMatch = p.model === modelWithDots;
}
}
console.log(
`Comparing: ${p.provider}/${p.model} with ${provider}/${model} - Provider match: ${providerMatch}, Model match: ${modelMatch}`
);
return providerMatch && modelMatch;
}
);
console.log(`Found pricing for ${provider}/${model}:`, pricing);
if (pricing) {
return {
input_token_cost: pricing.input_token_cost,
output_token_cost: pricing.output_token_cost,
currency: pricing.currency || '$',
};
}
console.log(
`No pricing found for ${provider}/${model} in:`,
data.pricing?.map((p: { provider: string; model: string }) => `${p.provider}/${p.model}`)
);
// API call succeeded but model not found in pricing data
return null;
} catch (error) {
console.error('Error fetching pricing data:', error);
// Re-throw the error so the caller can distinguish between API failure and model not found
throw error;
}
}
/**
* Initialize the cost database - only load commonly used models on startup
*/
export async function initializeCostDatabase(): Promise<void> {
try {
// Clean up any existing large caches first
cleanupPricingCache();
// First check if we have valid cached data
const cachedData = loadPricingFromLocalStorage();
if (cachedData && cachedData.pricing.length > 0) {
console.log('Using cached pricing data from localStorage');
return;
}
// List of commonly used models to pre-fetch
const commonModels = [
{ provider: 'openai', model: 'gpt-4o' },
{ provider: 'openai', model: 'gpt-4o-mini' },
{ provider: 'openai', model: 'gpt-4-turbo' },
{ provider: 'openai', model: 'gpt-4' },
{ provider: 'openai', model: 'gpt-3.5-turbo' },
{ provider: 'anthropic', model: 'claude-3-5-sonnet' },
{ provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' },
{ provider: 'anthropic', model: 'claude-3-opus' },
{ provider: 'anthropic', model: 'claude-3-sonnet' },
{ provider: 'anthropic', model: 'claude-3-haiku' },
{ provider: 'google', model: 'gemini-1.5-pro' },
{ provider: 'google', model: 'gemini-1.5-flash' },
{ provider: 'deepseek', model: 'deepseek-chat' },
{ provider: 'deepseek', model: 'deepseek-reasoner' },
{ provider: 'meta-llama', model: 'llama-3.2-90b-text-preview' },
{ provider: 'meta-llama', model: 'llama-3.1-405b-instruct' },
];
// Get recently used models
const recentModels = getRecentlyUsedModels();
// Combine common and recent models (deduplicated)
const modelsToFetch = new Map<string, { provider: string; model: string }>();
// Add common models
commonModels.forEach((m) => {
modelsToFetch.set(`${m.provider}/${m.model}`, m);
});
// Add recent models
recentModels.forEach((m) => {
modelsToFetch.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model });
});
console.log(`Initializing cost database with ${modelsToFetch.size} models...`);
// Fetch only the pricing we need
const apiUrl = getApiUrl('/config/pricing');
const secretKey = getSecretKey();
const headers: HeadersInit = { 'Content-Type': 'application/json' };
if (secretKey) {
headers['X-Secret-Key'] = secretKey;
}
const response = await fetch(apiUrl, {
method: 'POST',
headers,
body: JSON.stringify({
configured_only: false,
models: Array.from(modelsToFetch.values()), // Send specific models if API supports it
}),
});
if (!response.ok) {
console.error('Failed to fetch initial pricing data:', response.status);
return;
}
const data = await response.json();
console.log(`Fetched pricing for ${data.pricing?.length || 0} models`);
if (data.pricing && data.pricing.length > 0) {
// Filter to only the models we requested (in case API returns all)
const filteredPricing = data.pricing.filter((p: PricingItem) =>
modelsToFetch.has(`${p.provider}/${p.model}`)
);
// Save to localStorage
const cacheData: PricingCacheData = {
pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50), // Fallback to first 50 if filtering didn't work
timestamp: Date.now(),
};
savePricingToLocalStorage(cacheData, false); // Don't merge on initial load
}
} catch (error) {
console.error('Error initializing cost database:', error);
}
}
/**
* Update model costs from providers - no longer needed
*/
export async function updateAllModelCosts(): Promise<void> {
// No-op - we fetch on demand now
}
/**
* Get cost information for a specific model with caching
*/
export function getCostForModel(provider: string, model: string): ModelCostInfo | null {
// Track this model as recently used
addToRecentlyUsed(provider, model);
// Check if it's the same model we already have cached in memory
if (
currentModelPricing &&
currentModelPricing.provider === provider &&
currentModelPricing.model === model
) {
return currentModelPricing.costInfo;
}
// For local/free providers, return zero cost immediately
const freeProviders = ['ollama', 'local', 'localhost'];
if (freeProviders.includes(provider.toLowerCase())) {
const zeroCost = {
input_token_cost: 0,
output_token_cost: 0,
currency: '$',
};
currentModelPricing = { provider, model, costInfo: zeroCost };
return zeroCost;
}
// Check localStorage cache (which now only contains recently used models)
const cachedData = loadPricingFromLocalStorage();
if (cachedData) {
const pricing = cachedData.pricing.find((p) => {
const providerMatch = p.provider.toLowerCase() === provider.toLowerCase();
// More flexible model matching - handle versioned models
let modelMatch = p.model === model;
// If exact match fails, try matching without version suffix
if (!modelMatch && model.includes('-20')) {
// Remove date suffix like -20241022
const modelWithoutDate = model.replace(/-20\d{6}$/, '');
modelMatch = p.model === modelWithoutDate;
// Also try with dots instead of dashes (claude-3-5-sonnet vs claude-3.5-sonnet)
if (!modelMatch) {
const modelWithDots = modelWithoutDate.replace(/-(\d)-/g, '.$1.');
modelMatch = p.model === modelWithDots;
}
}
return providerMatch && modelMatch;
});
if (pricing) {
const costInfo = {
input_token_cost: pricing.input_token_cost,
output_token_cost: pricing.output_token_cost,
currency: pricing.currency || '$',
};
currentModelPricing = { provider, model, costInfo };
return costInfo;
}
}
// Need to fetch new pricing - return null for now
// The component will handle the async fetch
return null;
}
/**
* Fetch and cache pricing for a model
*/
export async function fetchAndCachePricing(
provider: string,
model: string
): Promise<{ costInfo: ModelCostInfo | null; error?: string } | null> {
try {
const costInfo = await fetchPricingForModel(provider, model);
if (costInfo) {
// Cache the result in memory
currentModelPricing = { provider, model, costInfo };
// Update localStorage cache with this new data
const cachedData = loadPricingFromLocalStorage();
if (cachedData) {
// Check if this model already exists in cache
const existingIndex = cachedData.pricing.findIndex(
(p) => p.provider.toLowerCase() === provider.toLowerCase() && p.model === model
);
const newPricing = {
provider,
model,
input_token_cost: costInfo.input_token_cost,
output_token_cost: costInfo.output_token_cost,
currency: costInfo.currency,
};
if (existingIndex >= 0) {
// Update existing
cachedData.pricing[existingIndex] = newPricing;
} else {
// Add new
cachedData.pricing.push(newPricing);
}
// Save updated cache
savePricingToLocalStorage(cachedData);
}
return { costInfo };
} else {
// Cache the null result in memory
currentModelPricing = { provider, model, costInfo: null };
// Check if the API call succeeded but model wasn't found
// We can determine this by checking if we got a response but no matching model
return { costInfo: null, error: 'model_not_found' };
}
} catch (error) {
console.error('Error in fetchAndCachePricing:', error);
// This is a real API/network error
return null;
}
}
/**
* Refresh pricing data from backend - only refresh recently used models
*/
export async function refreshPricing(): Promise<boolean> {
try {
const apiUrl = getApiUrl('/config/pricing');
const secretKey = getSecretKey();
const headers: HeadersInit = { 'Content-Type': 'application/json' };
if (secretKey) {
headers['X-Secret-Key'] = secretKey;
}
// Get recently used models to refresh
const recentModels = getRecentlyUsedModels();
// Add some common models as well
const commonModels = [
{ provider: 'openai', model: 'gpt-4o' },
{ provider: 'openai', model: 'gpt-4o-mini' },
{ provider: 'anthropic', model: 'claude-3-5-sonnet-20241022' },
{ provider: 'google', model: 'gemini-1.5-pro' },
];
// Combine and deduplicate
const modelsToRefresh = new Map<string, { provider: string; model: string }>();
commonModels.forEach((m) => {
modelsToRefresh.set(`${m.provider}/${m.model}`, m);
});
recentModels.forEach((m) => {
modelsToRefresh.set(`${m.provider}/${m.model}`, { provider: m.provider, model: m.model });
});
console.log(`Refreshing pricing for ${modelsToRefresh.size} models...`);
const response = await fetch(apiUrl, {
method: 'POST',
headers,
body: JSON.stringify({
configured_only: false,
models: Array.from(modelsToRefresh.values()), // Send specific models if API supports it
}),
});
if (response.ok) {
const data = await response.json();
if (data.pricing && data.pricing.length > 0) {
// Filter to only the models we requested (in case API returns all)
const filteredPricing = data.pricing.filter((p: PricingItem) =>
modelsToRefresh.has(`${p.provider}/${p.model}`)
);
// Save fresh data to localStorage (merge with existing)
const cacheData: PricingCacheData = {
pricing: filteredPricing.length > 0 ? filteredPricing : data.pricing.slice(0, 50),
timestamp: Date.now(),
};
savePricingToLocalStorage(cacheData, true); // Merge with existing
}
// Clear current memory cache to force re-fetch
currentModelPricing = null;
return true;
}
return false;
} catch (error) {
console.error('Error refreshing pricing data:', error);
return false;
}
}
/**
* Clean up old/unused models from the cache
*/
export function cleanupPricingCache(): void {
try {
const recentModels = getRecentlyUsedModels();
const cachedData = localStorage.getItem(PRICING_CACHE_KEY);
if (!cachedData) return;
const fullCache = JSON.parse(cachedData) as PricingCacheData;
const recentModelKeys = new Set(recentModels.map((m) => `${m.provider}/${m.model}`));
// Keep only recently used models and common models
const commonModelKeys = new Set([
'openai/gpt-4o',
'openai/gpt-4o-mini',
'openai/gpt-4-turbo',
'anthropic/claude-3-5-sonnet',
'anthropic/claude-3-5-sonnet-20241022',
'google/gemini-1.5-pro',
'google/gemini-1.5-flash',
]);
const filteredPricing = fullCache.pricing.filter((p) => {
const key = `${p.provider}/${p.model}`;
return recentModelKeys.has(key) || commonModelKeys.has(key);
});
if (filteredPricing.length < fullCache.pricing.length) {
console.log(
`Cleaned up pricing cache: reduced from ${fullCache.pricing.length} to ${filteredPricing.length} models`
);
const cleanedCache: PricingCacheData = {
pricing: filteredPricing,
timestamp: fullCache.timestamp,
};
localStorage.setItem(PRICING_CACHE_KEY, JSON.stringify(cleanedCache));
}
} catch (error) {
console.error('Error cleaning up pricing cache:', error);
}
}

View File

@@ -44,6 +44,10 @@ export default {
'0%': { transform: 'rotate(0deg)' },
'100%': { transform: 'rotate(360deg)' },
},
'spin-fast': {
'0%': { transform: 'rotate(0deg)' },
'100%': { transform: 'rotate(360deg)' },
},
indeterminate: {
'0%': { left: '-40%', width: '40%' },
'50%': { left: '20%', width: '60%' },
@@ -54,6 +58,7 @@ export default {
'shimmer-pulse': 'shimmer 4s ease-in-out infinite',
'gradient-loader': 'loader 750ms ease-in-out infinite',
indeterminate: 'indeterminate 1.5s infinite linear',
'spin-fast': 'spin-fast 0.5s linear infinite',
},
colors: {
bgApp: 'var(--background-app)',