# Mirror of https://github.com/aljazceru/enclava.git
# (synced 2025-12-17 07:24:34 +01:00; 649 lines, 26 KiB, Python)
"""
|
|
Budget enforcement service for managing spending limits and cost control
|
|
"""
|
|
|
|
from typing import Optional, List, Tuple, Dict, Any
|
|
from datetime import datetime
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import and_, or_, text, select, update
|
|
from sqlalchemy.exc import IntegrityError
|
|
import time
|
|
import random
|
|
|
|
from app.models.budget import Budget
|
|
from app.models.api_key import APIKey
|
|
from app.models.user import User
|
|
from app.services.cost_calculator import CostCalculator, estimate_request_cost
|
|
from app.core.logging import get_logger
|
|
|
|
# Module-level logger obtained from the project's get_logger helper, keyed by
# this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class BudgetEnforcementError(Exception):
    """Root of the budget-enforcement exception hierarchy.

    Catching this type handles every exception class defined for the budget
    enforcement service.
    """
|
|
|
|
|
|
class BudgetExceededError(BudgetEnforcementError):
    """Signals that a request's cost would push a budget past its limit.

    Attributes:
        budget: The budget whose limit would be exceeded.
        requested_cost: Cost of the offending request (presumably cents, like
            the service's other cost values -- confirm at the raise site).
    """

    def __init__(self, message: str, budget: Budget, requested_cost: int):
        super().__init__(message)
        self.budget, self.requested_cost = budget, requested_cost
|
|
|
|
|
|
class BudgetWarningError(BudgetEnforcementError):
    """Signals that a request would cross a budget's warning threshold.

    Attributes:
        budget: The budget whose warning threshold is reached.
        requested_cost: Cost of the request that triggered the warning
            (presumably cents -- confirm at the raise site).
    """

    def __init__(self, message: str, budget: Budget, requested_cost: int):
        super().__init__(message)
        self.budget, self.requested_cost = budget, requested_cost
|
|
|
|
|
|
class BudgetConcurrencyError(BudgetEnforcementError):
    """Signals that a budget update failed because of a concurrent writer.

    Attributes:
        retry_count: Identifies the failed attempt; the reservation code in
            this module passes its zero-based attempt index.
    """

    def __init__(self, message: str, retry_count: int = 0):
        # Attribute first, then the base initialiser; the order is irrelevant
        # because Exception.__init__ only stores ``args``.
        self.retry_count = retry_count
        super().__init__(message)
|
|
|
|
|
|
class BudgetAtomicError(BudgetEnforcementError):
    """Signals that an atomic budget operation could not be completed.

    Attributes:
        budget_id: ID of the budget the operation targeted.
        requested_amount: Amount the operation tried to apply (presumably
            cents -- confirm at the raise site).
    """

    def __init__(self, message: str, budget_id: int, requested_amount: int):
        self.budget_id = budget_id
        self.requested_amount = requested_amount
        super().__init__(message)
|
|
|
|
|
|
class BudgetEnforcementService:
    """Service for enforcing budget limits and tracking usage.

    Two enforcement flows are offered:

    * Atomic: ``atomic_check_and_reserve_budget`` locks every applicable
      budget row (``SELECT ... FOR UPDATE``) and adds the estimated cost in a
      single transaction; ``atomic_finalize_usage`` is called afterwards with
      the returned budget IDs.
    * Legacy: ``check_budget_compliance`` followed by ``record_usage``. These
      read and update budgets without row locks, so concurrent requests can
      both pass the check before either one records its usage.

    Most public methods commit or roll back ``self.db`` themselves.
    """

    def __init__(self, db: Session):
        """
        Args:
            db: SQLAlchemy session used for every budget read and write.
        """
        self.db = db
        # Attempts made by atomic_check_and_reserve_budget when a reservation
        # fails with BudgetConcurrencyError.
        self.max_retries = 3
        self.retry_delay_base = 0.1  # Base delay in seconds

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    def _ensure_transaction(self) -> None:
        """Begin a transaction only if the session is not already in one.

        ``Session.begin()`` raises ``InvalidRequestError`` when a transaction
        is already in progress (e.g. auto-begun by an earlier query), so an
        unconditional ``begin()`` would fail on any session already in use.
        """
        if not self.db.in_transaction():
            self.db.begin()

    @staticmethod
    def _usage_percentage(usage_cents: int, limit_cents: int) -> float:
        """Return ``usage_cents`` as a percentage of ``limit_cents``.

        A missing or non-positive limit yields 0 (the same convention
        ``get_budget_status`` uses) instead of raising ZeroDivisionError.
        """
        if limit_cents is None or limit_cents <= 0:
            return 0
        return usage_cents / limit_cents * 100

    @staticmethod
    def _format_exceeded_message(budget: Budget, estimated_cost: int) -> str:
        """Build the user-facing message for a request that exceeds ``budget``."""
        return (
            f"Request would exceed budget '{budget.name}' "
            f"(${budget.limit_cents/100:.2f}). "
            f"Current usage: ${budget.current_usage_cents/100:.2f}, "
            f"Requested: ${estimated_cost/100:.4f}, "
            f"Remaining: ${(budget.limit_cents - budget.current_usage_cents)/100:.2f}"
        )

    def _build_warning(self, budget: Budget, estimated_cost: int) -> Dict[str, Any]:
        """Build the warning dict returned when a request nears ``budget``'s limit.

        Values are computed from the usage *before* ``estimated_cost`` is
        applied, projected forward by ``estimated_cost``.
        """
        projected_cents = budget.current_usage_cents + estimated_cost
        percentage = self._usage_percentage(projected_cents, budget.limit_cents)
        message = (
            f"Budget '{budget.name}' approaching limit. "
            f"Usage will be ${projected_cents/100:.2f} "
            f"of ${budget.limit_cents/100:.2f} "
            f"({percentage:.1f}%)"
        )
        return {
            "type": "budget_warning",
            "budget_id": budget.id,
            "budget_name": budget.name,
            "message": message,
            "current_usage_cents": projected_cents,
            "limit_cents": budget.limit_cents,
            "usage_percentage": percentage,
        }

    # ------------------------------------------------------------------
    # Atomic (locking) enforcement path
    # ------------------------------------------------------------------

    def atomic_check_and_reserve_budget(
        self,
        api_key: APIKey,
        model_name: str,
        estimated_tokens: int,
        endpoint: Optional[str] = None
    ) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
        """
        Atomically check budget compliance and reserve spending.

        The estimated cost is added to every applicable active, in-period
        budget within one transaction; if any budget would be exceeded,
        nothing is reserved.

        Args:
            api_key: API key making the request.
            model_name: Model being used (also filters budgets by
                ``allowed_models``).
            estimated_tokens: Estimated token usage, used to estimate cost.
            endpoint: API endpoint being accessed (filters budgets by
                ``allowed_endpoints``).

        Returns:
            Tuple of (is_allowed, error_message, warnings, reserved_budget_ids).
            ``reserved_budget_ids`` is meant to be passed to
            ``atomic_finalize_usage`` after the request completes.
        """
        estimated_cost = estimate_request_cost(model_name, estimated_tokens)
        budgets = self._get_applicable_budgets(api_key, model_name, endpoint)

        if not budgets:
            logger.debug(f"No applicable budgets found for API key {api_key.id}")
            return True, None, [], []

        # Try atomic reservation with retries
        for attempt in range(self.max_retries):
            try:
                return self._attempt_atomic_reservation(budgets, estimated_cost, api_key.id, attempt)
            except BudgetConcurrencyError as e:
                if attempt == self.max_retries - 1:
                    logger.error(f"Atomic budget reservation failed after {self.max_retries} attempts: {e}")
                    return False, "Budget check temporarily unavailable (concurrency limit)", [], []

                # Exponential backoff with jitter
                delay = self.retry_delay_base * (2 ** attempt) + random.uniform(0, 0.1)
                time.sleep(delay)
                logger.info(f"Retrying atomic budget reservation (attempt {attempt + 2})")
            except Exception as e:
                logger.error(f"Unexpected error in atomic budget reservation: {e}")
                return False, f"Budget check failed: {e}", [], []

        return False, "Budget check failed after maximum retries", [], []

    def _attempt_atomic_reservation(
        self,
        budgets: List[Budget],
        estimated_cost: int,
        api_key_id: int,
        attempt: int
    ) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
        """Attempt to atomically reserve budget across all applicable budgets.

        Every budget row is locked and updated inside one transaction, which
        is committed only if all budgets accept the reservation and rolled
        back otherwise.

        Args:
            budgets: Candidate budgets from ``_get_applicable_budgets``.
            estimated_cost: Cost to reserve, in cents.
            api_key_id: ID of the requesting key (used for logging).
            attempt: Zero-based retry attempt, forwarded to
                ``BudgetConcurrencyError``.

        Returns:
            Same tuple shape as ``atomic_check_and_reserve_budget``.

        Raises:
            BudgetConcurrencyError: On an ``IntegrityError`` (retryable).
            BudgetAtomicError: If a budget row vanished or could not be updated.
        """
        warnings = []
        reserved_budget_ids = []

        try:
            # The caller's budget query normally auto-begins the transaction
            # already; only begin explicitly if it has not.
            self._ensure_transaction()

            for budget in budgets:
                # Lock budget row for update to prevent concurrent modifications
                locked_budget = self.db.query(Budget).filter(
                    Budget.id == budget.id
                ).with_for_update().first()

                if not locked_budget:
                    raise BudgetAtomicError(f"Budget {budget.id} not found", budget.id, estimated_cost)

                # Reset budget if expired and auto-renew enabled. Flush only:
                # committing here would release the row locks already held and
                # split the reservation across transactions.
                if locked_budget.is_expired() and locked_budget.auto_renew:
                    self._reset_expired_budget(locked_budget, commit=False)

                # Skip inactive or expired budgets
                if not locked_budget.is_active or locked_budget.is_expired():
                    continue

                # Check if request would exceed budget (row is locked, so the
                # usage read here cannot change before we update it)
                if not self._atomic_can_spend(locked_budget, estimated_cost):
                    error_msg = self._format_exceeded_message(locked_budget, estimated_cost)
                    logger.warning(f"Budget exceeded for API key {api_key_id}: {error_msg}")
                    self.db.rollback()
                    return False, error_msg, warnings, []

                # Check warning threshold (computed before the reservation)
                if locked_budget.would_exceed_warning(estimated_cost) and not locked_budget.is_warning_sent:
                    warning = self._build_warning(locked_budget, estimated_cost)
                    warnings.append(warning)
                    logger.info(f"Budget warning for API key {api_key_id}: {warning['message']}")

                # Reserve the budget (add estimated cost to current usage)
                self._atomic_reserve_usage(locked_budget, estimated_cost)
                reserved_budget_ids.append(locked_budget.id)

            # Commit the reservation (this also releases the row locks)
            self.db.commit()
            logger.debug(f"Successfully reserved budget for API key {api_key_id}, estimated cost: ${estimated_cost/100:.4f}")
            return True, None, warnings, reserved_budget_ids

        except IntegrityError as e:
            self.db.rollback()
            raise BudgetConcurrencyError(f"Database integrity error during budget reservation: {e}", attempt) from e
        except Exception as e:
            self.db.rollback()
            logger.error(f"Error in atomic budget reservation: {e}")
            raise

    def _atomic_can_spend(self, budget: Budget, amount_cents: int) -> bool:
        """Return whether ``budget`` can absorb ``amount_cents``.

        Inactive or out-of-period budgets never can; budgets without
        ``enforce_hard_limit`` always can; otherwise the new total must stay
        within ``limit_cents``.
        """
        if not budget.is_active or not budget.is_in_period():
            return False

        if not budget.enforce_hard_limit:
            return True

        return (budget.current_usage_cents + amount_cents) <= budget.limit_cents

    def _atomic_reserve_usage(self, budget: Budget, amount_cents: int):
        """Add ``amount_cents`` to a budget's usage with a single SQL UPDATE.

        The UPDATE also recomputes ``is_exceeded`` and ``is_warning_sent``.
        Afterwards ``budget`` is refreshed from the row, so the in-memory
        object holds exactly what the database holds and nothing extra is
        flushed on commit.

        Raises:
            BudgetAtomicError: If the UPDATE did not match exactly one row.
        """
        stmt = (
            update(Budget)
            .where(Budget.id == budget.id)
            .values(
                current_usage_cents=Budget.current_usage_cents + amount_cents,
                updated_at=datetime.utcnow(),
                is_exceeded=Budget.current_usage_cents + amount_cents >= Budget.limit_cents,
                is_warning_sent=(
                    Budget.is_warning_sent |
                    ((Budget.warning_threshold_cents.isnot(None)) &
                     (Budget.current_usage_cents + amount_cents >= Budget.warning_threshold_cents))
                )
            )
            # Don't let the ORM mirror the increment onto ``budget`` in
            # memory; combined with any in-memory change, a dirty value would
            # be flushed on commit and the reservation counted twice. The
            # object is refreshed from the database below instead.
            .execution_options(synchronize_session=False)
        )
        result = self.db.execute(stmt)

        if result.rowcount != 1:
            raise BudgetAtomicError(f"Failed to update budget {budget.id}", budget.id, amount_cents)

        # Reload the row (same transaction, so it sees the UPDATE above).
        self.db.refresh(budget)

    def atomic_finalize_usage(
        self,
        reserved_budget_ids: List[int],
        api_key: APIKey,
        model_name: str,
        input_tokens: int,
        output_tokens: int,
        endpoint: Optional[str] = None
    ) -> List[Budget]:
        """
        Finalize actual usage and adjust reservations.

        NOTE: ``_atomic_set_actual_usage`` is currently a no-op, so this
        re-locks the reserved budgets, logs the actual cost and commits; the
        stored usage keeps the estimate added at reservation time.

        Args:
            reserved_budget_ids: Budget IDs that had usage reserved
            api_key: API key that made the request
            model_name: Model that was used
            input_tokens: Actual input tokens used
            output_tokens: Actual output tokens used
            endpoint: API endpoint that was accessed

        Returns:
            List of budgets that were updated ([] on any error, after rollback)
        """
        if not reserved_budget_ids:
            return []

        try:
            actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
            updated_budgets = []

            # Begin a transaction only if one is not already in progress;
            # an unconditional begin() raises on a session already in use.
            self._ensure_transaction()

            for budget_id in reserved_budget_ids:
                # Lock budget for update
                budget = self.db.query(Budget).filter(
                    Budget.id == budget_id
                ).with_for_update().first()

                if not budget:
                    logger.warning(f"Budget {budget_id} not found during finalization")
                    continue

                if budget.is_active and budget.is_in_period():
                    self._atomic_set_actual_usage(budget, actual_cost, input_tokens, output_tokens)
                    updated_budgets.append(budget)

                    logger.debug(
                        f"Finalized usage for budget {budget.id}: "
                        f"${actual_cost/100:.4f} (total: ${budget.current_usage_cents/100:.2f})"
                    )

            # Commit finalization
            self.db.commit()
            return updated_budgets

        except Exception as e:
            logger.error(f"Error finalizing budget usage: {e}")
            self.db.rollback()
            return []

    def _atomic_set_actual_usage(self, budget: Budget, actual_cost: int, input_tokens: int, output_tokens: int):
        """Reconcile a reservation with the actual cost -- currently a no-op.

        The reserved estimate stays in ``current_usage_cents`` and
        ``actual_cost`` is not applied, because the reserved amount is not
        passed in here.

        TODO(review): to reconcile, the reserved estimate must be tracked
        (e.g. passed through from the reservation) so the difference between
        ``actual_cost`` and the estimate can be applied under the row lock.
        """
        pass

    # ------------------------------------------------------------------
    # Legacy (non-locking) enforcement path
    # ------------------------------------------------------------------

    def check_budget_compliance(
        self,
        api_key: APIKey,
        model_name: str,
        estimated_tokens: int,
        endpoint: Optional[str] = None
    ) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
        """
        Check if a request complies with budget limits (no reservation).

        Args:
            api_key: API key making the request
            model_name: Model being used
            estimated_tokens: Estimated token usage
            endpoint: API endpoint being accessed

        Returns:
            Tuple of (is_allowed, error_message, warnings). On any internal
            error the request is allowed (fail-open).
        """
        try:
            estimated_cost = estimate_request_cost(model_name, estimated_tokens)
            budgets = self._get_applicable_budgets(api_key, model_name, endpoint)

            if not budgets:
                logger.debug(f"No applicable budgets found for API key {api_key.id}")
                return True, None, []

            warnings = []

            for budget in budgets:
                # Reset budget if period expired and auto-renew is enabled
                if budget.is_expired() and budget.auto_renew:
                    self._reset_expired_budget(budget)

                # Skip inactive or expired budgets
                if not budget.is_active or budget.is_expired():
                    continue

                # Check if request would exceed budget
                if not budget.can_spend(estimated_cost):
                    error_msg = self._format_exceeded_message(budget, estimated_cost)
                    logger.warning(f"Budget exceeded for API key {api_key.id}: {error_msg}")
                    return False, error_msg, warnings

                # Check if request would trigger warning
                if budget.would_exceed_warning(estimated_cost) and not budget.is_warning_sent:
                    warning = self._build_warning(budget, estimated_cost)
                    warnings.append(warning)
                    logger.info(f"Budget warning for API key {api_key.id}: {warning['message']}")

            return True, None, warnings

        except Exception as e:
            logger.error(f"Error checking budget compliance: {e}")
            # Allow request on error to avoid blocking legitimate usage
            return True, None, []

    def record_usage(
        self,
        api_key: APIKey,
        model_name: str,
        input_tokens: int,
        output_tokens: int,
        endpoint: Optional[str] = None
    ) -> List[Budget]:
        """
        Record actual usage against applicable budgets (no row locks).

        Args:
            api_key: API key that made the request
            model_name: Model that was used
            input_tokens: Actual input tokens used
            output_tokens: Actual output tokens used
            endpoint: API endpoint that was accessed

        Returns:
            List of budgets that were updated ([] on any error, after rollback)
        """
        try:
            actual_cost = CostCalculator.calculate_cost_cents(model_name, input_tokens, output_tokens)
            budgets = self._get_applicable_budgets(api_key, model_name, endpoint)

            updated_budgets = []

            for budget in budgets:
                if budget.is_active and budget.is_in_period():
                    budget.add_usage(actual_cost)
                    updated_budgets.append(budget)

                    logger.debug(
                        f"Recorded usage for budget {budget.id}: "
                        f"${actual_cost/100:.4f} (total: ${budget.current_usage_cents/100:.2f})"
                    )

            self.db.commit()

            return updated_budgets

        except Exception as e:
            logger.error(f"Error recording budget usage: {e}")
            self.db.rollback()
            return []

    # ------------------------------------------------------------------
    # Lookup, maintenance and reporting
    # ------------------------------------------------------------------

    def _get_applicable_budgets(
        self,
        api_key: APIKey,
        model_name: Optional[str] = None,
        endpoint: Optional[str] = None
    ) -> List[Budget]:
        """Get active budgets that apply to the given request.

        A budget applies if it is the user's general budget (no
        ``api_key_id``) or is tied to this API key, and -- when the budget
        restricts them -- the model/endpoint are in its allowed lists. A
        ``None`` model or endpoint skips the corresponding restriction check.
        """
        conditions = [
            Budget.is_active == True,  # noqa: E712 (SQL expression, not a Python comparison)
            or_(
                and_(Budget.user_id == api_key.user_id, Budget.api_key_id.is_(None)),  # User budget
                Budget.api_key_id == api_key.id  # API key specific budget
            )
        ]

        budgets = self.db.query(Budget).filter(and_(*conditions)).all()

        applicable_budgets = []

        for budget in budgets:
            # Check model restrictions
            if model_name and budget.allowed_models:
                if model_name not in budget.allowed_models:
                    continue

            # Check endpoint restrictions
            if endpoint and budget.allowed_endpoints:
                if endpoint not in budget.allowed_endpoints:
                    continue

            applicable_budgets.append(budget)

        return applicable_budgets

    def _reset_expired_budget(self, budget: Budget, commit: bool = True):
        """Reset an expired budget for the next period via ``budget.reset_period()``.

        Args:
            budget: Budget to roll over.
            commit: When True (default), commit immediately; on error, log,
                roll back and continue (best effort). When False, only flush
                and re-raise errors, so a caller that holds row locks inside
                a larger transaction keeps control of commit/rollback.
        """
        try:
            budget.reset_period()
            if commit:
                self.db.commit()
            else:
                # Flush so later checks in the same transaction see the reset.
                self.db.flush()

            logger.info(
                f"Reset expired budget {budget.id} for new period: "
                f"{budget.period_start} to {budget.period_end}"
            )

        except Exception as e:
            logger.error(f"Error resetting expired budget {budget.id}: {e}")
            if not commit:
                raise
            self.db.rollback()

    def get_budget_status(self, api_key: APIKey) -> Dict[str, Any]:
        """Get comprehensive budget status for an API key.

        On error, returns a minimal dict with an ``"error"`` entry instead of
        raising.
        """
        try:
            budgets = self._get_applicable_budgets(api_key)

            status = {
                "total_budgets": len(budgets),
                "active_budgets": 0,
                "exceeded_budgets": 0,
                "warning_budgets": 0,
                "total_limit_cents": 0,
                "total_usage_cents": 0,
                "budgets": []
            }

            for budget in budgets:
                if not budget.is_active:
                    continue

                budget_info = budget.to_dict()
                budget_info.update({
                    "is_expired": budget.is_expired(),
                    "days_remaining": budget.get_period_days_remaining(),
                    "daily_burn_rate": budget.get_daily_burn_rate(),
                    "projected_spend": budget.get_projected_spend()
                })

                status["budgets"].append(budget_info)
                status["active_budgets"] += 1
                status["total_limit_cents"] += budget.limit_cents
                status["total_usage_cents"] += budget.current_usage_cents

                if budget.is_exceeded:
                    status["exceeded_budgets"] += 1
                elif budget.warning_threshold_cents and budget.current_usage_cents >= budget.warning_threshold_cents:
                    status["warning_budgets"] += 1

            # Calculate overall percentages
            if status["total_limit_cents"] > 0:
                status["overall_usage_percentage"] = (status["total_usage_cents"] / status["total_limit_cents"]) * 100
            else:
                status["overall_usage_percentage"] = 0

            status["total_limit_dollars"] = status["total_limit_cents"] / 100
            status["total_usage_dollars"] = status["total_usage_cents"] / 100
            status["total_remaining_cents"] = max(0, status["total_limit_cents"] - status["total_usage_cents"])
            status["total_remaining_dollars"] = status["total_remaining_cents"] / 100

            return status

        except Exception as e:
            logger.error(f"Error getting budget status: {e}")
            return {
                "error": str(e),
                "total_budgets": 0,
                "active_budgets": 0,
                "exceeded_budgets": 0,
                "warning_budgets": 0,
                "budgets": []
            }

    def create_default_user_budget(
        self,
        user_id: int,
        limit_dollars: float = 10.0,
        period_type: str = "monthly"
    ) -> Budget:
        """Create and commit a default budget for a new user.

        ``period_type == "monthly"`` creates a monthly budget; any other value
        creates a daily one. Errors are logged, rolled back and re-raised.
        """
        try:
            if period_type == "monthly":
                budget = Budget.create_monthly_budget(
                    user_id=user_id,
                    name="Default Monthly Budget",
                    limit_dollars=limit_dollars
                )
            else:
                budget = Budget.create_daily_budget(
                    user_id=user_id,
                    name="Default Daily Budget",
                    limit_dollars=limit_dollars
                )

            self.db.add(budget)
            self.db.commit()

            logger.info(f"Created default budget for user {user_id}: ${limit_dollars} {period_type}")

            return budget

        except Exception as e:
            logger.error(f"Error creating default budget: {e}")
            self.db.rollback()
            raise

    def check_and_reset_expired_budgets(self):
        """Background task to check and reset expired auto-renewing budgets.

        Each reset is committed individually (best effort); errors are logged.
        """
        try:
            expired_budgets = self.db.query(Budget).filter(
                and_(
                    Budget.is_active == True,  # noqa: E712
                    Budget.auto_renew == True,  # noqa: E712
                    Budget.period_end < datetime.utcnow()
                )
            ).all()

            for budget in expired_budgets:
                self._reset_expired_budget(budget)

            logger.info(f"Reset {len(expired_budgets)} expired budgets")

        except Exception as e:
            logger.error(f"Error in budget reset task: {e}")
|
|
|
|
|
|
# Convenience functions
|
|
|
|
# DEPRECATED: Use atomic versions for race-condition-free budget enforcement
def check_budget_for_request(
    db: Session,
    api_key: APIKey,
    model_name: str,
    estimated_tokens: int,
    endpoint: str = None
) -> Tuple[bool, Optional[str], List[Dict[str, Any]]]:
    """DEPRECATED: check budget compliance without reserving (races possible).

    Builds a throwaway ``BudgetEnforcementService`` on ``db`` and delegates to
    its ``check_budget_compliance``. Prefer ``atomic_check_and_reserve_budget``.

    Returns:
        Tuple of (is_allowed, error_message, warnings).
    """
    return BudgetEnforcementService(db).check_budget_compliance(
        api_key=api_key,
        model_name=model_name,
        estimated_tokens=estimated_tokens,
        endpoint=endpoint,
    )
|
|
|
|
|
|
def record_request_usage(
    db: Session,
    api_key: APIKey,
    model_name: str,
    input_tokens: int,
    output_tokens: int,
    endpoint: str = None
) -> List[Budget]:
    """DEPRECATED: record actual usage without row locks (races possible).

    Builds a throwaway ``BudgetEnforcementService`` on ``db`` and delegates to
    its ``record_usage``. Prefer ``atomic_finalize_usage``.

    Returns:
        The budgets whose usage was updated.
    """
    return BudgetEnforcementService(db).record_usage(
        api_key=api_key,
        model_name=model_name,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        endpoint=endpoint,
    )
|
|
|
|
|
|
# ATOMIC VERSIONS: Race-condition-free budget enforcement
def atomic_check_and_reserve_budget(
    db: Session,
    api_key: APIKey,
    model_name: str,
    estimated_tokens: int,
    endpoint: str = None
) -> Tuple[bool, Optional[str], List[Dict[str, Any]], List[int]]:
    """Check budget compliance and reserve the estimated spend atomically.

    Builds a throwaway ``BudgetEnforcementService`` on ``db`` and delegates to
    its ``atomic_check_and_reserve_budget``.

    Returns:
        Tuple of (is_allowed, error_message, warnings, reserved_budget_ids).
    """
    return BudgetEnforcementService(db).atomic_check_and_reserve_budget(
        api_key=api_key,
        model_name=model_name,
        estimated_tokens=estimated_tokens,
        endpoint=endpoint,
    )
|
|
|
|
|
|
def atomic_finalize_usage(
    db: Session,
    reserved_budget_ids: List[int],
    api_key: APIKey,
    model_name: str,
    input_tokens: int,
    output_tokens: int,
    endpoint: str = None
) -> List[Budget]:
    """Finalize actual usage after request completion (atomic flow).

    Builds a throwaway ``BudgetEnforcementService`` on ``db`` and delegates to
    its ``atomic_finalize_usage`` with the IDs returned by
    ``atomic_check_and_reserve_budget``.

    Returns:
        The budgets that were finalized.
    """
    # The original return line carried a stray trailing "|" (scrape residue),
    # which made the statement a SyntaxError.
    return BudgetEnforcementService(db).atomic_finalize_usage(
        reserved_budget_ids=reserved_budget_ids,
        api_key=api_key,
        model_name=model_name,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        endpoint=endpoint,
    )