Files
enclava/backend/app/services/budget_enforcement.py
2025-08-19 09:50:15 +02:00

649 lines
26 KiB
Python

"""
Budget enforcement service for managing spending limits and cost control
"""
from typing import Optional, List, Tuple, Dict, Any
from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, text, select, update
from sqlalchemy.exc import IntegrityError
import time
import random
from app.models.budget import Budget
from app.models.api_key import APIKey
from app.models.user import User
from app.services.cost_calculator import CostCalculator, estimate_request_cost
from app.core.logging import get_logger
# Module-level logger from the application's logging helper, named after this module.
logger = get_logger(__name__)
class BudgetEnforcementError(Exception):
    """Root of the budget-enforcement exception hierarchy.

    Every more specific budget error in this module derives from it, so
    callers can catch this single type to handle any enforcement failure.
    """
class BudgetExceededError(BudgetEnforcementError):
    """Signals that charging a request would push a budget past its limit.

    Attributes:
        budget: The budget that would be exceeded.
        requested_cost: Cost of the rejected request (presumably in cents,
            like the other cost values in this module — confirm at raise site).
    """

    def __init__(self, message: str, budget: Budget, requested_cost: int):
        self.budget, self.requested_cost = budget, requested_cost
        super().__init__(message)
class BudgetWarningError(BudgetEnforcementError):
    """Signals that a request would cross a budget's warning threshold.

    Attributes:
        budget: The budget whose warning threshold is reached.
        requested_cost: Cost of the request that triggered the warning
            (presumably in cents — confirm at raise site).
    """

    def __init__(self, message: str, budget: Budget, requested_cost: int):
        self.budget, self.requested_cost = budget, requested_cost
        super().__init__(message)
class BudgetConcurrencyError(BudgetEnforcementError):
    """Signals that a budget update lost a race with a concurrent writer.

    Attributes:
        retry_count: The (0-based) reservation attempt on which the failure
            happened; defaults to 0.
    """

    def __init__(self, message: str, retry_count: int = 0):
        self.retry_count = retry_count
        super().__init__(message)
class BudgetAtomicError(BudgetEnforcementError):
    """Signals that an atomic (row-locked) budget operation could not complete.

    Attributes:
        budget_id: Primary key of the budget the operation targeted.
        requested_amount: Amount the operation tried to reserve.
    """

    def __init__(self, message: str, budget_id: int, requested_amount: int):
        self.budget_id, self.requested_amount = budget_id, requested_amount
        super().__init__(message)
class BudgetEnforcementService:
    """Service for enforcing budget limits and tracking usage.

    Two enforcement paths are provided:

    * ``atomic_check_and_reserve_budget`` / ``atomic_finalize_usage`` lock the
      applicable budget rows (``SELECT ... FOR UPDATE``) and reserve the
      estimated cost inside a single transaction.
    * ``check_budget_compliance`` / ``record_usage`` are the older non-locking
      pair; concurrent requests can both pass the check before either one
      records its usage.

    The service commits and rolls back the injected ``Session`` itself.
    """

    def __init__(self, db: Session):
        self.db = db
        # Retry policy used when a reservation attempt hits a concurrency error.
        self.max_retries = 3
        self.retry_delay_base = 0.1  # Base delay in seconds

    # ------------------------------------------------------------------
    # Atomic (row-locking) enforcement
    # ------------------------------------------------------------------

    def atomic_check_and_reserve_budget(
        self,
        api_key: APIKey,
        model_name: str,
        estimated_tokens: int,
        endpoint: str = None
    ) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
        """
        Atomically check budget compliance and reserve spending.

        The estimated cost is reserved on every applicable, active, in-period
        budget, or on none of them if a budget cannot absorb it. Attempts that
        fail with ``BudgetConcurrencyError`` are retried with exponential
        backoff plus jitter.

        Args:
            api_key: API key making the request
            model_name: Model being used
            estimated_tokens: Estimated token usage
            endpoint: API endpoint being accessed

        Returns:
            Tuple of (is_allowed, error_message, warnings, reserved_budget_ids).
            Unexpected errors fail closed (``is_allowed`` is False).
        """
        estimated_cost = estimate_request_cost(model_name, estimated_tokens)
        budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
        if not budgets:
            logger.debug(f"No applicable budgets found for API key {api_key.id}")
            return True, None, [], []

        # Try atomic reservation with retries
        for attempt in range(self.max_retries):
            try:
                return self._attempt_atomic_reservation(budgets, estimated_cost, api_key.id, attempt)
            except BudgetConcurrencyError as e:
                if attempt == self.max_retries - 1:
                    logger.error(f"Atomic budget reservation failed after {self.max_retries} attempts: {e}")
                    return False, "Budget check temporarily unavailable (concurrency limit)", [], []
                # Exponential backoff with jitter
                delay = self.retry_delay_base * (2 ** attempt) + random.uniform(0, 0.1)
                time.sleep(delay)
                logger.info(f"Retrying atomic budget reservation (attempt {attempt + 2})")
            except Exception as e:
                logger.error(f"Unexpected error in atomic budget reservation: {e}")
                return False, f"Budget check failed: {str(e)}", [], []
        # Only reachable when max_retries < 1.
        return False, "Budget check failed after maximum retries", [], []

    def _attempt_atomic_reservation(
        self,
        budgets: List[Budget],
        estimated_cost: int,
        api_key_id: int,
        attempt: int
    ) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
        """Attempt to atomically reserve budget across all applicable budgets.

        Runs in the session's current transaction, locks each budget row, and
        either commits a reservation on every eligible budget or rolls the
        whole attempt back.

        Raises:
            BudgetConcurrencyError: on ``IntegrityError`` (the caller retries).
            Exception: any other error is re-raised after rolling back.
        """
        warnings = []
        reserved_budget_ids = []
        try:
            # No explicit ``self.db.begin()``: the session already holds a
            # transaction (``_get_applicable_budgets`` just queried it) and
            # begin() on an active transaction raises InvalidRequestError.
            for budget in budgets:
                locked_budget = self._lock_budget(budget.id)
                if not locked_budget:
                    raise BudgetAtomicError(f"Budget {budget.id} not found", budget.id, estimated_cost)

                # Reset budget if expired and auto-renew enabled. This must stay
                # inside the transaction: committing here would release the row
                # locks and persist reservations made earlier in this loop.
                if locked_budget.is_expired() and locked_budget.auto_renew:
                    self._reset_locked_budget(locked_budget)

                # Skip inactive or expired budgets
                if not locked_budget.is_active or locked_budget.is_expired():
                    continue

                if not self._atomic_can_spend(locked_budget, estimated_cost):
                    error_msg = self._build_exceeded_message(locked_budget, estimated_cost)
                    logger.warning(f"Budget exceeded for API key {api_key_id}: {error_msg}")
                    self.db.rollback()
                    return False, error_msg, warnings, []

                if locked_budget.would_exceed_warning(estimated_cost) and not locked_budget.is_warning_sent:
                    warning = self._build_warning(locked_budget, estimated_cost)
                    warnings.append(warning)
                    logger.info(f"Budget warning for API key {api_key_id}: {warning['message']}")

                # Reserve the budget (temporarily add estimated cost)
                self._atomic_reserve_usage(locked_budget, estimated_cost)
                reserved_budget_ids.append(locked_budget.id)

            # Commit the reservation (also releases the row locks)
            self.db.commit()
            logger.debug(f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}")
            return True, None, warnings, reserved_budget_ids
        except IntegrityError as e:
            self.db.rollback()
            raise BudgetConcurrencyError(
                f"Database integrity error during budget reservation: {e}", attempt
            ) from e
        except Exception as e:
            self.db.rollback()
            logger.error(f"Error in atomic budget reservation: {e}")
            raise

    def _lock_budget(self, budget_id: int) -> Optional[Budget]:
        """Load one budget row with ``SELECT ... FOR UPDATE``.

        ``populate_existing()`` overwrites an instance already in the session's
        identity map with the freshly locked row; without it the ORM hands back
        the copy loaded before the lock was taken, so checks would run on stale
        usage figures.
        """
        return (
            self.db.query(Budget)
            .filter(Budget.id == budget_id)
            .with_for_update()
            .populate_existing()
            .first()
        )

    def _reset_locked_budget(self, budget: Budget):
        """Reset an expired, row-locked budget without ending the transaction.

        Unlike ``_reset_expired_budget`` this neither commits nor swallows
        errors, so a failed reset aborts the whole reservation attempt.
        """
        budget.reset_period()
        self.db.flush()  # Ensure reset is applied before checking
        logger.info(
            f"Reset expired budget {budget.id} for new period: "
            f"{budget.period_start} to {budget.period_end}"
        )

    def _atomic_can_spend(self, budget: Budget, amount_cents: int) -> bool:
        """Check whether a (locked) budget can absorb ``amount_cents``.

        Inactive or out-of-period budgets never can; budgets without a hard
        limit always can; otherwise usage plus amount must stay within limit.
        """
        if not budget.is_active or not budget.is_in_period():
            return False
        if not budget.enforce_hard_limit:
            return True
        return (budget.current_usage_cents + amount_cents) <= budget.limit_cents

    def _atomic_reserve_usage(self, budget: Budget, amount_cents: int):
        """Atomically add ``amount_cents`` to the budget's current usage.

        The increment (and the derived ``is_exceeded`` / ``is_warning_sent``
        flags) is applied by a single SQL UPDATE. Session synchronization is
        disabled so the ORM does not also apply the increment to the in-memory
        instance (ORM-enabled UPDATEs synchronize by default in SQLAlchemy
        1.4+); the instance is then refreshed from the database so it reflects
        exactly what was written and nothing is double-counted on flush.

        Raises:
            BudgetAtomicError: if the UPDATE did not match exactly one row.
        """
        result = self.db.execute(
            update(Budget)
            .where(Budget.id == budget.id)
            .values(
                current_usage_cents=Budget.current_usage_cents + amount_cents,
                updated_at=datetime.utcnow(),
                is_exceeded=Budget.current_usage_cents + amount_cents >= Budget.limit_cents,
                is_warning_sent=(
                    Budget.is_warning_sent |
                    ((Budget.warning_threshold_cents.isnot(None)) &
                     (Budget.current_usage_cents + amount_cents >= Budget.warning_threshold_cents))
                )
            )
            .execution_options(synchronize_session=False)
        )
        if result.rowcount != 1:
            raise BudgetAtomicError(f"Failed to update budget {budget.id}", budget.id, amount_cents)
        # Reload the row (still locked by this transaction) into the instance.
        self.db.refresh(budget)

    def atomic_finalize_usage(
        self,
        reserved_budget_ids: List[int],
        api_key: APIKey,
        model_name: str,
        input_tokens: int,
        output_tokens: int,
        endpoint: str = None
    ) -> List[Budget]:
        """
        Finalize actual usage and adjust reservations.

        Args:
            reserved_budget_ids: Budget IDs that had usage reserved
            api_key: API key that made the request
            model_name: Model that was used
            input_tokens: Actual input tokens used
            output_tokens: Actual output tokens used
            endpoint: API endpoint that was accessed

        Returns:
            List of budgets that were updated ([] on any error, after rollback)
        """
        if not reserved_budget_ids:
            return []
        try:
            actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
            updated_budgets = []
            # No explicit begin(): the session autobegins a transaction, and
            # begin() on an already-active one raises InvalidRequestError.
            for budget_id in reserved_budget_ids:
                budget = self._lock_budget(budget_id)
                if not budget:
                    logger.warning(f"Budget {budget_id} not found during finalization")
                    continue
                if budget.is_active and budget.is_in_period():
                    # The reserved estimate is not known here, so no delta can
                    # be applied; see _atomic_set_actual_usage.
                    self._atomic_set_actual_usage(budget, actual_cost, input_tokens, output_tokens)
                    updated_budgets.append(budget)
                    logger.debug(
                        f"Finalized usage for budget {budget.id}: "
                        f"${actual_cost/100:.4f} (total: ${budget.current_usage_cents/100:.2f})"
                    )
            # Commit finalization (releases the row locks)
            self.db.commit()
            return updated_budgets
        except Exception as e:
            logger.error(f"Error finalizing budget usage: {e}")
            self.db.rollback()
            return []

    def _atomic_set_actual_usage(self, budget: Budget, actual_cost: int, input_tokens: int, output_tokens: int):
        """Reconcile a reservation with the actual cost — currently a no-op.

        The estimate added by ``_atomic_reserve_usage`` is left in place,
        because the reserved amount is not passed to finalization.
        TODO: thread the reserved estimate through so the difference
        (actual - estimated) can be applied.
        """
        return None

    # ------------------------------------------------------------------
    # Non-locking enforcement (deprecated path)
    # ------------------------------------------------------------------

    def check_budget_compliance(
        self,
        api_key: APIKey,
        model_name: str,
        estimated_tokens: int,
        endpoint: str = None
    ) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
        """
        Check if a request complies with budget limits (no locking).

        Args:
            api_key: API key making the request
            model_name: Model being used
            estimated_tokens: Estimated token usage
            endpoint: API endpoint being accessed

        Returns:
            Tuple of (is_allowed, error_message, warnings). Errors fail open.
        """
        try:
            estimated_cost = estimate_request_cost(model_name, estimated_tokens)
            budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
            if not budgets:
                logger.debug(f"No applicable budgets found for API key {api_key.id}")
                return True, None, []

            warnings = []
            for budget in budgets:
                # Reset budget if period expired and auto-renew is enabled
                if budget.is_expired() and budget.auto_renew:
                    self._reset_expired_budget(budget)

                # Skip inactive or expired budgets
                if not budget.is_active or budget.is_expired():
                    continue

                if not budget.can_spend(estimated_cost):
                    error_msg = self._build_exceeded_message(budget, estimated_cost)
                    logger.warning(f"Budget exceeded for API key {api_key.id}: {error_msg}")
                    return False, error_msg, warnings

                if budget.would_exceed_warning(estimated_cost) and not budget.is_warning_sent:
                    warning = self._build_warning(budget, estimated_cost)
                    warnings.append(warning)
                    logger.info(f"Budget warning for API key {api_key.id}: {warning['message']}")

            return True, None, warnings
        except Exception as e:
            logger.error(f"Error checking budget compliance: {e}")
            # Allow request on error to avoid blocking legitimate usage
            return True, None, []

    def record_usage(
        self,
        api_key: APIKey,
        model_name: str,
        input_tokens: int,
        output_tokens: int,
        endpoint: str = None
    ) -> List[Budget]:
        """
        Record actual usage against applicable budgets (no locking).

        Args:
            api_key: API key that made the request
            model_name: Model that was used
            input_tokens: Actual input tokens used
            output_tokens: Actual output tokens used
            endpoint: API endpoint that was accessed

        Returns:
            List of budgets that were updated ([] on any error, after rollback)
        """
        try:
            actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
            budgets = self._get_applicable_budgets(api_key, model_name, endpoint)
            updated_budgets = []
            for budget in budgets:
                if budget.is_active and budget.is_in_period():
                    budget.add_usage(actual_cost)
                    updated_budgets.append(budget)
                    logger.debug(
                        f"Recorded usage for budget {budget.id}: "
                        f"${actual_cost/100:.4f} (total: ${budget.current_usage_cents/100:.2f})"
                    )
            self.db.commit()
            return updated_budgets
        except Exception as e:
            logger.error(f"Error recording budget usage: {e}")
            self.db.rollback()
            return []

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _usage_percentage(usage_cents: int, limit_cents: int) -> float:
        """Return ``usage_cents`` as a percentage of ``limit_cents``.

        A non-positive limit yields 0.0 (the same convention as
        ``get_budget_status``) instead of raising ZeroDivisionError.
        """
        if limit_cents <= 0:
            return 0.0
        return usage_cents / limit_cents * 100

    @staticmethod
    def _build_exceeded_message(budget: Budget, requested_cost: int) -> str:
        """Build the user-facing error for a request that exceeds ``budget``."""
        return (
            f"Request would exceed budget '{budget.name}' "
            f"(${budget.limit_cents/100:.2f}). "
            f"Current usage: ${budget.current_usage_cents/100:.2f}, "
            f"Requested: ${requested_cost/100:.4f}, "
            f"Remaining: ${(budget.limit_cents - budget.current_usage_cents)/100:.2f}"
        )

    def _build_warning(self, budget: Budget, requested_cost: int) -> Dict[str, Any]:
        """Build the warning payload for a request crossing the warning threshold."""
        projected_usage = budget.current_usage_cents + requested_cost
        percentage = self._usage_percentage(projected_usage, budget.limit_cents)
        message = (
            f"Budget '{budget.name}' approaching limit. "
            f"Usage will be ${projected_usage/100:.2f} "
            f"of ${budget.limit_cents/100:.2f} "
            f"({percentage:.1f}%)"
        )
        return {
            "type": "budget_warning",
            "budget_id": budget.id,
            "budget_name": budget.name,
            "message": message,
            "current_usage_cents": projected_usage,
            "limit_cents": budget.limit_cents,
            "usage_percentage": percentage,
        }

    def _get_applicable_budgets(
        self,
        api_key: APIKey,
        model_name: str = None,
        endpoint: str = None
    ) -> List[Budget]:
        """Get active budgets that apply to the given request.

        A budget applies when it is the user's general budget (no API key) or
        is bound to this API key, and its model/endpoint allow-lists (when
        set) contain the requested model/endpoint.
        """
        conditions = [
            Budget.is_active == True,  # noqa: E712 - SQL expression, not a Python comparison
            or_(
                and_(Budget.user_id == api_key.user_id, Budget.api_key_id.is_(None)),  # User budget
                Budget.api_key_id == api_key.id  # API key specific budget
            )
        ]
        budgets = self.db.query(Budget).filter(and_(*conditions)).all()

        applicable_budgets = []
        for budget in budgets:
            # Check model restrictions
            if model_name and budget.allowed_models and model_name not in budget.allowed_models:
                continue
            # Check endpoint restrictions
            if endpoint and budget.allowed_endpoints and endpoint not in budget.allowed_endpoints:
                continue
            applicable_budgets.append(budget)
        return applicable_budgets

    def _reset_expired_budget(self, budget: Budget) -> bool:
        """Reset an expired budget for the next period and commit.

        Best-effort: on failure the error is logged and the session rolled back.

        Returns:
            True if the reset was committed, False otherwise.
        """
        try:
            budget.reset_period()
            self.db.commit()
            logger.info(
                f"Reset expired budget {budget.id} for new period: "
                f"{budget.period_start} to {budget.period_end}"
            )
            return True
        except Exception as e:
            logger.error(f"Error resetting expired budget {budget.id}: {e}")
            self.db.rollback()
            return False

    def get_budget_status(self, api_key: APIKey) -> Dict[str, Any]:
        """Get comprehensive budget status for an API key.

        On error, returns a reduced dict carrying an ``"error"`` key.
        """
        try:
            budgets = self._get_applicable_budgets(api_key)
            status = {
                "total_budgets": len(budgets),
                "active_budgets": 0,
                "exceeded_budgets": 0,
                "warning_budgets": 0,
                "total_limit_cents": 0,
                "total_usage_cents": 0,
                "budgets": []
            }
            for budget in budgets:
                if not budget.is_active:
                    continue
                budget_info = budget.to_dict()
                budget_info.update({
                    "is_expired": budget.is_expired(),
                    "days_remaining": budget.get_period_days_remaining(),
                    "daily_burn_rate": budget.get_daily_burn_rate(),
                    "projected_spend": budget.get_projected_spend()
                })
                status["budgets"].append(budget_info)
                status["active_budgets"] += 1
                status["total_limit_cents"] += budget.limit_cents
                status["total_usage_cents"] += budget.current_usage_cents
                if budget.is_exceeded:
                    status["exceeded_budgets"] += 1
                elif budget.warning_threshold_cents and budget.current_usage_cents >= budget.warning_threshold_cents:
                    status["warning_budgets"] += 1

            # Calculate overall percentages
            if status["total_limit_cents"] > 0:
                status["overall_usage_percentage"] = (status["total_usage_cents"] / status["total_limit_cents"]) * 100
            else:
                status["overall_usage_percentage"] = 0
            status["total_limit_dollars"] = status["total_limit_cents"] / 100
            status["total_usage_dollars"] = status["total_usage_cents"] / 100
            status["total_remaining_cents"] = max(0, status["total_limit_cents"] - status["total_usage_cents"])
            status["total_remaining_dollars"] = status["total_remaining_cents"] / 100
            return status
        except Exception as e:
            logger.error(f"Error getting budget status: {e}")
            return {
                "error": str(e),
                "total_budgets": 0,
                "active_budgets": 0,
                "exceeded_budgets": 0,
                "warning_budgets": 0,
                "budgets": []
            }

    def create_default_user_budget(
        self,
        user_id: int,
        limit_dollars: float = 10.0,
        period_type: str = "monthly"
    ) -> Budget:
        """Create and commit a default budget for a new user.

        ``period_type == "monthly"`` creates a monthly budget; any other value
        creates a daily one. Errors are logged, rolled back and re-raised.
        """
        try:
            if period_type == "monthly":
                budget = Budget.create_monthly_budget(
                    user_id=user_id,
                    name="Default Monthly Budget",
                    limit_dollars=limit_dollars
                )
            else:
                budget = Budget.create_daily_budget(
                    user_id=user_id,
                    name="Default Daily Budget",
                    limit_dollars=limit_dollars
                )
            self.db.add(budget)
            self.db.commit()
            logger.info(f"Created default budget for user {user_id}: ${limit_dollars} {period_type}")
            return budget
        except Exception as e:
            logger.error(f"Error creating default budget: {e}")
            self.db.rollback()
            raise

    def check_and_reset_expired_budgets(self):
        """Background task to check and reset expired auto-renewing budgets.

        Logs the number of budgets whose reset was actually committed.
        """
        try:
            expired_budgets = self.db.query(Budget).filter(
                and_(
                    Budget.is_active == True,  # noqa: E712 - SQL expression
                    Budget.auto_renew == True,  # noqa: E712 - SQL expression
                    Budget.period_end < datetime.utcnow()
                )
            ).all()
            reset_count = sum(1 for budget in expired_budgets if self._reset_expired_budget(budget))
            logger.info(f"Reset {reset_count} expired budgets")
        except Exception as e:
            logger.error(f"Error in budget reset task: {e}")
# Convenience functions
# DEPRECATED: Use atomic versions for race-condition-free budget enforcement
def check_budget_for_request(
    db: Session,
    api_key: APIKey,
    model_name: str,
    estimated_tokens: int,
    endpoint: str = None
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
    """DEPRECATED: non-locking budget check (race conditions possible).

    Wraps ``db`` in a fresh ``BudgetEnforcementService`` and delegates to
    ``check_budget_compliance``. Prefer ``atomic_check_and_reserve_budget``.
    """
    return BudgetEnforcementService(db).check_budget_compliance(
        api_key=api_key,
        model_name=model_name,
        estimated_tokens=estimated_tokens,
        endpoint=endpoint,
    )
def record_request_usage(
    db: Session,
    api_key: APIKey,
    model_name: str,
    input_tokens: int,
    output_tokens: int,
    endpoint: str = None
) -> List[Budget]:
    """DEPRECATED: non-locking usage recording (race conditions possible).

    Wraps ``db`` in a fresh ``BudgetEnforcementService`` and delegates to
    ``record_usage``. Prefer ``atomic_finalize_usage``.
    """
    return BudgetEnforcementService(db).record_usage(
        api_key=api_key,
        model_name=model_name,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        endpoint=endpoint,
    )
# ATOMIC VERSIONS: Race-condition-free budget enforcement
def atomic_check_and_reserve_budget(
    db: Session,
    api_key: APIKey,
    model_name: str,
    estimated_tokens: int,
    endpoint: str = None
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
    """Check budget compliance and reserve the estimated spend atomically.

    Wraps ``db`` in a fresh ``BudgetEnforcementService`` and delegates to
    its ``atomic_check_and_reserve_budget`` method.

    Returns:
        Tuple of (is_allowed, error_message, warnings, reserved_budget_ids).
    """
    return BudgetEnforcementService(db).atomic_check_and_reserve_budget(
        api_key=api_key,
        model_name=model_name,
        estimated_tokens=estimated_tokens,
        endpoint=endpoint,
    )
def atomic_finalize_usage(
    db: Session,
    reserved_budget_ids: List[int],
    api_key: APIKey,
    model_name: str,
    input_tokens: int,
    output_tokens: int,
    endpoint: str = None
) -> List[Budget]:
    """Finalize actual usage for budgets reserved by the atomic check.

    Wraps ``db`` in a fresh ``BudgetEnforcementService`` and delegates to
    its ``atomic_finalize_usage`` method, passing the reserved budget IDs.
    """
    return BudgetEnforcementService(db).atomic_finalize_usage(
        reserved_budget_ids=reserved_budget_ids,
        api_key=api_key,
        model_name=model_name,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        endpoint=endpoint,
    )