""" LLM API endpoints - interface to secure LLM service with authentication and budget enforcement """ import logging import time from typing import Dict, Any, List, Optional from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app.db.database import get_db from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthService, get_api_key_context from app.core.security import get_current_user from app.models.user import User from app.core.config import settings from app.services.llm.service import llm_service from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError from app.services.budget_enforcement import ( check_budget_for_request, record_request_usage, BudgetEnforcementService, atomic_check_and_reserve_budget, atomic_finalize_usage ) from app.services.cost_calculator import CostCalculator, estimate_request_cost from app.utils.exceptions import AuthenticationError, AuthorizationError from app.middleware.analytics import set_analytics_data logger = logging.getLogger(__name__) # Models response cache - simple in-memory cache for performance _models_cache = { "data": None, "cached_at": 0, "cache_ttl": 900 # 15 minutes cache TTL } router = APIRouter() async def get_cached_models() -> List[Dict[str, Any]]: """Get models from cache or fetch from LLM service if cache is stale""" current_time = time.time() # Check if cache is still valid if (_models_cache["data"] is not None and current_time - _models_cache["cached_at"] < _models_cache["cache_ttl"]): logger.debug("Returning cached models list") return _models_cache["data"] # Cache miss or stale - fetch from LLM service try: logger.debug("Fetching fresh models list from LLM service") model_infos = await llm_service.get_models() # Convert ModelInfo objects to dict format for compatibility models = [] for model_info in model_infos: models.append({ "id": model_info.id, "object": model_info.object, "created": model_info.created or int(time.time()), "owned_by": model_info.owned_by, # Add frontend-expected fields "name": getattr(model_info, 'name', model_info.id), # Use name if available, fallback to id "provider": getattr(model_info, 'provider', model_info.owned_by) # Use provider if available, fallback to owned_by }) # Update cache _models_cache["data"] = models _models_cache["cached_at"] = current_time return models except Exception as e: logger.error(f"Failed to fetch models from LLM service: {e}") # Return stale cache if available, otherwise empty list if _models_cache["data"] is not None: logger.warning("Returning stale cached models due to fetch error") return _models_cache["data"] return [] def invalidate_models_cache(): """Invalidate the models cache (useful for admin operations)""" _models_cache["data"] = None _models_cache["cached_at"] = 0 logger.info("Models cache invalidated") # Request/Response Models (API layer) class ChatMessage(BaseModel): role: str = Field(..., description="Message role (system, user, assistant)") content: str = Field(..., description="Message content") class ChatCompletionRequest(BaseModel): model: str = Field(..., description="Model name") messages: List[ChatMessage] = Field(..., description="List of messages") max_tokens: Optional[int] = Field(None, description="Maximum tokens to generate") temperature: Optional[float] = Field(None, description="Temperature for sampling") top_p: Optional[float] = Field(None, description="Top-p sampling parameter") frequency_penalty: Optional[float] = Field(None, description="Frequency penalty") presence_penalty: Optional[float] = Field(None, description="Presence penalty") stop: Optional[List[str]] = Field(None, description="Stop sequences") stream: Optional[bool] = Field(False, description="Stream response") class EmbeddingRequest(BaseModel): model: str = Field(..., description="Model name") input: str = Field(..., description="Input text to embed") encoding_format: Optional[str] = Field("float", description="Encoding format") class ModelInfo(BaseModel): id: str object: str = "model" created: int owned_by: str class ModelsResponse(BaseModel): object: str = "list" data: List[ModelInfo] # Hybrid authentication function async def get_auth_context( request: Request, db: AsyncSession = Depends(get_db) ) -> Dict[str, Any]: """Get authentication context from either API key or JWT token""" # Try API key authentication first auth_header = request.headers.get("Authorization") if auth_header and auth_header.startswith("Bearer "): token = auth_header[7:] # Check if it's an API key (starts with ce_ prefix) if token.startswith(settings.API_KEY_PREFIX): try: context = await get_api_key_context(request, db) if context: return context except Exception as e: logger.warning(f"API key authentication failed: {e}") else: # Try JWT token authentication try: from app.core.security import get_current_user # Create a fake credentials object for JWT validation from fastapi.security import HTTPAuthorizationCredentials credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) user = await get_current_user(credentials, db) if user: return { "user": user, "auth_type": "jwt", "api_key": None } except Exception as e: logger.warning(f"JWT authentication failed: {e}") # Try X-API-Key header api_key = request.headers.get("X-API-Key") if api_key: try: context = await get_api_key_context(request, db) if context: return context except Exception as e: logger.warning(f"X-API-Key authentication failed: {e}") # No valid authentication found raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Valid API key or authentication token required" ) # Endpoints @router.get("/models", response_model=ModelsResponse) async def list_models( context: Dict[str, Any] = Depends(get_auth_context), db: AsyncSession = Depends(get_db) ): """List available models""" try: # For JWT users, allow access to list models if context.get("auth_type") == "jwt": pass # JWT users can list models else: # For API key users, check permissions auth_service = APIKeyAuthService(db) if not await auth_service.check_scope_permission(context, "models.list"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions to list models" ) # Get models from cache or LLM service models = await get_cached_models() # Filter models based on API key permissions api_key = context.get("api_key") if api_key and api_key.allowed_models: models = [model for model in models if model.get("id") in api_key.allowed_models] return ModelsResponse(data=models) except HTTPException: raise except Exception as e: logger.error(f"Error listing models: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to list models" ) @router.post("/models/invalidate-cache") async def invalidate_models_cache_endpoint( context: Dict[str, Any] = Depends(get_auth_context), db: AsyncSession = Depends(get_db) ): """Invalidate models cache (admin only)""" # Check for admin permissions if context.get("auth_type") == "jwt": user = context.get("user") if not user or not user.get("is_superuser"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required" ) else: # For API key users, check admin permissions auth_service = APIKeyAuthService(db) if not await auth_service.check_scope_permission(context, "admin.cache"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Admin permissions required to invalidate cache" ) invalidate_models_cache() return {"message": "Models cache invalidated successfully"} @router.post("/chat/completions") async def create_chat_completion( request_body: Request, chat_request: ChatCompletionRequest, context: Dict[str, Any] = Depends(get_auth_context), db: AsyncSession = Depends(get_db) ): """Create chat completion with budget enforcement""" try: auth_type = context.get("auth_type", "api_key") # Handle different authentication types if auth_type == "api_key": auth_service = APIKeyAuthService(db) # Check permissions if not await auth_service.check_scope_permission(context, "chat.completions"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions for chat completions" ) if not await auth_service.check_model_permission(context, chat_request.model): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Model '{chat_request.model}' not allowed" ) api_key = context.get("api_key") if not api_key: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="API key information not available" ) elif auth_type == "jwt": # For JWT authentication, we'll skip the detailed permission checks for now # and create a dummy API key context for budget tracking user = context.get("user") if not user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="User information not available" ) api_key = None # JWT users don't have API keys else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication type" ) # Estimate token usage for budget checking messages_text = " ".join([msg.content for msg in chat_request.messages]) estimated_tokens = len(messages_text.split()) * 1.3 # Rough token estimation if chat_request.max_tokens: estimated_tokens += chat_request.max_tokens else: estimated_tokens += 150 # Default response length estimate # Get a synchronous session for budget enforcement from app.db.database import SessionLocal sync_db = SessionLocal() try: # Atomic budget check and reservation (only for API key users) warnings = [] reserved_budget_ids = [] if auth_type == "api_key" and api_key: is_allowed, error_message, budget_warnings, budget_ids = atomic_check_and_reserve_budget( sync_db, api_key, chat_request.model, int(estimated_tokens), "chat/completions" ) if not is_allowed: raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Budget exceeded: {error_message}" ) warnings = budget_warnings reserved_budget_ids = budget_ids # Convert messages to LLM service format llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages] # Create LLM service request llm_request = ChatRequest( model=chat_request.model, messages=llm_messages, temperature=chat_request.temperature, max_tokens=chat_request.max_tokens, top_p=chat_request.top_p, frequency_penalty=chat_request.frequency_penalty, presence_penalty=chat_request.presence_penalty, stop=chat_request.stop, stream=chat_request.stream or False, user_id=str(context.get("user_id", "anonymous")), api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0 ) # Make request to LLM service llm_response = await llm_service.create_chat_completion(llm_request) # Convert LLM service response to API format response = { "id": llm_response.id, "object": llm_response.object, "created": llm_response.created, "model": llm_response.model, "choices": [ { "index": choice.index, "message": { "role": choice.message.role, "content": choice.message.content }, "finish_reason": choice.finish_reason } for choice in llm_response.choices ], "usage": { "prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0, "completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0, "total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0 } if llm_response.usage else { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 } } # Calculate actual cost and update usage usage = response.get("usage", {}) input_tokens = usage.get("prompt_tokens", 0) output_tokens = usage.get("completion_tokens", 0) total_tokens = usage.get("total_tokens", input_tokens + output_tokens) # Calculate accurate cost actual_cost_cents = CostCalculator.calculate_cost_cents( chat_request.model, input_tokens, output_tokens ) # Finalize actual usage in budgets (only for API key users) if auth_type == "api_key" and api_key: atomic_finalize_usage( sync_db, reserved_budget_ids, api_key, chat_request.model, input_tokens, output_tokens, "chat/completions" ) # Update API key usage statistics auth_service = APIKeyAuthService(db) await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents) # Set analytics data for middleware set_analytics_data( model=chat_request.model, request_tokens=input_tokens, response_tokens=output_tokens, total_tokens=total_tokens, cost_cents=actual_cost_cents, budget_ids=reserved_budget_ids, budget_warnings=warnings ) # Add budget warnings to response if any if warnings: response["budget_warnings"] = warnings return response finally: sync_db.close() except HTTPException: raise except SecurityError as e: logger.warning(f"Security error in chat completion: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Security validation failed: {e.message}" ) except ValidationError as e: logger.warning(f"Validation error in chat completion: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Request validation failed: {e.message}" ) except ProviderError as e: logger.error(f"Provider error in chat completion: {e}") if "rate limit" in str(e).lower(): raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Rate limit exceeded" ) else: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="LLM service temporarily unavailable" ) except LLMError as e: logger.error(f"LLM service error in chat completion: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="LLM service error" ) except Exception as e: logger.error(f"Unexpected error creating chat completion: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create chat completion" ) @router.post("/embeddings") async def create_embedding( request: EmbeddingRequest, context: Dict[str, Any] = Depends(require_api_key), db: AsyncSession = Depends(get_db) ): """Create embedding with budget enforcement""" try: auth_service = APIKeyAuthService(db) # Check permissions if not await auth_service.check_scope_permission(context, "embeddings.create"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions for embeddings" ) if not await auth_service.check_model_permission(context, request.model): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Model '{request.model}' not allowed" ) api_key = context.get("api_key") if not api_key: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="API key information not available" ) # Estimate token usage for budget checking estimated_tokens = len(request.input.split()) * 1.3 # Rough token estimation # Convert AsyncSession to Session for budget enforcement sync_db = Session(bind=db.bind.sync_engine) try: # Check budget compliance before making request is_allowed, error_message, warnings = check_budget_for_request( sync_db, api_key, request.model, int(estimated_tokens), "embeddings" ) if not is_allowed: raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Budget exceeded: {error_message}" ) # Create LLM service request llm_request = LLMEmbeddingRequest( model=request.model, input=request.input, encoding_format=request.encoding_format, user_id=str(context["user_id"]), api_key_id=context["api_key_id"] ) # Make request to LLM service llm_response = await llm_service.create_embedding(llm_request) # Convert LLM service response to API format response = { "object": llm_response.object, "data": [ { "object": emb.object, "index": emb.index, "embedding": emb.embedding } for emb in llm_response.data ], "model": llm_response.model, "usage": { "prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0, "total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0 } if llm_response.usage else { "prompt_tokens": int(estimated_tokens), "total_tokens": int(estimated_tokens) } } # Calculate actual cost and update usage usage = response.get("usage", {}) total_tokens = usage.get("total_tokens", int(estimated_tokens)) # Calculate accurate cost (embeddings typically use input tokens only) actual_cost_cents = CostCalculator.calculate_cost_cents( request.model, total_tokens, 0 ) # Record actual usage in budgets record_request_usage( sync_db, api_key, request.model, total_tokens, 0, "embeddings" ) # Update API key usage statistics await auth_service.update_usage_stats(context, total_tokens, actual_cost_cents) # Add budget warnings to response if any if warnings: response["budget_warnings"] = warnings return response finally: sync_db.close() except HTTPException: raise except SecurityError as e: logger.warning(f"Security error in embedding: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Security validation failed: {e.message}" ) except ValidationError as e: logger.warning(f"Validation error in embedding: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Request validation failed: {e.message}" ) except ProviderError as e: logger.error(f"Provider error in embedding: {e}") if "rate limit" in str(e).lower(): raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Rate limit exceeded" ) else: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="LLM service temporarily unavailable" ) except LLMError as e: logger.error(f"LLM service error in embedding: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="LLM service error" ) except Exception as e: logger.error(f"Unexpected error creating embedding: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create embedding" ) @router.get("/health") async def llm_health_check( context: Dict[str, Any] = Depends(get_auth_context) ): """Health check for LLM service""" try: health_summary = llm_service.get_health_summary() provider_status = await llm_service.get_provider_status() # Determine overall health overall_status = "healthy" if health_summary["service_status"] != "healthy": overall_status = "degraded" for provider, status in provider_status.items(): if status.status == "unavailable": overall_status = "degraded" break return { "status": overall_status, "service": "LLM Service", "service_status": health_summary, "provider_status": {name: { "status": status.status, "latency_ms": status.latency_ms, "error_message": status.error_message } for name, status in provider_status.items()}, "user_id": context["user_id"], "api_key_name": context["api_key_name"] } except Exception as e: logger.error(f"LLM health check error: {e}") return { "status": "unhealthy", "service": "LLM Service", "error": str(e) } @router.get("/usage") async def get_usage_stats( context: Dict[str, Any] = Depends(require_api_key) ): """Get usage statistics for the API key""" try: api_key = context.get("api_key") if not api_key: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="API key information not available" ) return { "api_key_id": api_key.id, "api_key_name": api_key.name, "total_requests": api_key.total_requests, "total_tokens": api_key.total_tokens, "total_cost_cents": api_key.total_cost, "created_at": api_key.created_at.isoformat(), "last_used_at": api_key.last_used_at.isoformat() if api_key.last_used_at else None, "rate_limits": { "per_minute": api_key.rate_limit_per_minute, "per_hour": api_key.rate_limit_per_hour, "per_day": api_key.rate_limit_per_day }, "permissions": api_key.permissions, "scopes": api_key.scopes, "allowed_models": api_key.allowed_models } except HTTPException: raise except Exception as e: logger.error(f"Error getting usage stats: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get usage statistics" ) @router.get("/budget/status") async def get_budget_status( request: Request, context: Dict[str, Any] = Depends(get_auth_context), db: AsyncSession = Depends(get_db) ): """Get current budget status and usage analytics""" try: auth_type = context.get("auth_type", "api_key") # Check permissions based on auth type if auth_type == "api_key": auth_service = APIKeyAuthService(db) if not await auth_service.check_scope_permission(context, "budget.read"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions to read budget information" ) api_key = context.get("api_key") if not api_key: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="API key information not available" ) # Convert AsyncSession to Session for budget enforcement sync_db = Session(bind=db.bind.sync_engine) try: budget_service = BudgetEnforcementService(sync_db) budget_status = budget_service.get_budget_status(api_key) return { "object": "budget_status", "data": budget_status } finally: sync_db.close() elif auth_type == "jwt": # For JWT authentication, return user-level budget information user = context.get("user") if not user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="User information not available" ) # Return basic budget info for JWT users return { "object": "budget_status", "data": { "budgets": [], "total_usage": 0.0, "warnings": [], "projections": { "daily_burn_rate": 0.0, "projected_monthly": 0.0, "days_remaining": 30 } } } else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication type" ) except HTTPException: raise except Exception as e: logger.error(f"Error getting budget status: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get budget status" ) # Generic endpoint for additional LLM service functionality @router.get("/metrics") async def get_llm_metrics( context: Dict[str, Any] = Depends(require_api_key), db: AsyncSession = Depends(get_db) ): """Get LLM service metrics (admin only)""" try: # Check for admin permissions auth_service = APIKeyAuthService(db) if not await auth_service.check_scope_permission(context, "admin.metrics"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Admin permissions required to view metrics" ) metrics = llm_service.get_metrics() return { "object": "llm_metrics", "data": { "total_requests": metrics.total_requests, "successful_requests": metrics.successful_requests, "failed_requests": metrics.failed_requests, "security_blocked_requests": metrics.security_blocked_requests, "average_latency_ms": metrics.average_latency_ms, "average_risk_score": metrics.average_risk_score, "provider_metrics": metrics.provider_metrics, "last_updated": metrics.last_updated.isoformat() } } except HTTPException: raise except Exception as e: logger.error(f"Error getting LLM metrics: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get LLM metrics" ) @router.get("/providers/status") async def get_provider_status( context: Dict[str, Any] = Depends(require_api_key), db: AsyncSession = Depends(get_db) ): """Get status of all LLM providers""" try: auth_service = APIKeyAuthService(db) if not await auth_service.check_scope_permission(context, "admin.status"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Admin permissions required to view provider status" ) provider_status = await llm_service.get_provider_status() return { "object": "provider_status", "data": { name: { "provider": status.provider, "status": status.status, "latency_ms": status.latency_ms, "success_rate": status.success_rate, "last_check": status.last_check.isoformat(), "error_message": status.error_message, "models_available": status.models_available } for name, status in provider_status.items() } } except HTTPException: raise except Exception as e: logger.error(f"Error getting provider status: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get provider status" )