Files
enclava/backend/app/services/api_key_auth.py
2025-08-19 09:50:15 +02:00

248 lines
8.6 KiB
Python

"""
API Key Authentication Service
Handles API key validation and user authentication with Redis caching for performance
"""
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from fastapi import HTTPException, Request, status, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.security import verify_api_key
from app.db.database import get_db
from app.models.api_key import APIKey
from app.models.user import User
from app.utils.exceptions import AuthenticationError, AuthorizationError
from app.services.cached_api_key import cached_api_key_service
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class APIKeyAuthService:
    """Service for API key authentication and validation.

    Key lookup and hash verification go through ``cached_api_key_service``,
    with a database fallback for the stored hash. Permission checks delegate
    to methods on the API key object stored in the validated context.
    """

    # Number of leading characters of a raw key used as its lookup prefix.
    KEY_PREFIX_LENGTH = 8

    def __init__(self, db: AsyncSession):
        # Session used for the DB hash fallback and for usage-stat commits.
        # NOTE(review): some callers pass ``context.get("db")``, which may be
        # None; only the permission-check methods are safe in that case.
        self.db = db

    async def validate_api_key(self, api_key: str, request: Request) -> Optional[Dict[str, Any]]:
        """Validate a raw API key and return its context.

        Args:
            api_key: The raw key presented by the client.
            request: The incoming request; its client host is checked against
                the key's IP restrictions.

        Returns:
            The context dict from ``cached_api_key_service.get_cached_api_key``
            (read here for its ``"api_key"`` and ``"api_key_id"`` entries). It
            is returned only when the key is found, its hash verifies (or a
            cached verification exists), ``is_valid()`` passes, and the client
            IP is allowed. Otherwise the result is ``None``.

        This method never raises. Unexpected errors are logged with a
        traceback and treated as an authentication failure (fail closed).
        """
        try:
            if not api_key:
                return None

            # The first KEY_PREFIX_LENGTH characters identify the key record.
            if len(api_key) < self.KEY_PREFIX_LENGTH:
                logger.warning("Invalid API key format: too short")
                return None
            key_prefix = api_key[: self.KEY_PREFIX_LENGTH]

            # A cached successful verification lets us skip the hash check below.
            cached_verification = await cached_api_key_service.verify_api_key_cached(api_key, key_prefix)

            # Key data comes from the cache or, failing that, the database.
            context = await cached_api_key_service.get_cached_api_key(key_prefix, self.db)
            if not context:
                logger.warning("API key not found: %s", key_prefix)
                return None
            api_key_obj = context["api_key"]

            if not cached_verification:
                if hasattr(api_key_obj, "key_hash"):
                    key_hash = api_key_obj.key_hash
                else:
                    # The cached object carries no hash, so read it from the DB row.
                    # NOTE(review): scalar_one_or_none() raises if several keys share
                    # this prefix; that is caught below and fails closed.
                    stmt = select(APIKey).where(APIKey.key_prefix == key_prefix)
                    result = await self.db.execute(stmt)
                    db_api_key = result.scalar_one_or_none()
                    if not db_api_key:
                        return None
                    key_hash = db_api_key.key_hash

                if not verify_api_key(api_key, key_hash):
                    logger.warning("Invalid API key hash: %s", key_prefix)
                    return None

                # Remember the successful verification for subsequent requests.
                await cached_api_key_service.cache_verification_result(api_key, key_prefix, key_hash, True)

            # Expiry / active-status check.
            if not api_key_obj.is_valid():
                logger.warning("API key expired or inactive: %s", key_prefix)
                # Invalidate the cached entry for this prefix.
                await cached_api_key_service.invalidate_api_key_cache(key_prefix)
                return None

            client_ip = request.client.host if request.client else "unknown"
            if not api_key_obj.can_access_from_ip(client_ip):
                logger.warning("IP not allowed for API key %s: %s", key_prefix, client_ip)
                return None

            # Awaited inline: the last-used update completes before we return.
            await cached_api_key_service.update_last_used(context["api_key_id"], self.db)
            return context
        except Exception as e:
            logger.exception("API key validation error: %s", e)
            return None

    @staticmethod
    def _api_key_from(context: Optional[Dict[str, Any]]) -> Any:
        """Return the API key object stored in *context*, or None if absent."""
        return context.get("api_key") if context else None

    async def check_endpoint_permission(self, context: Dict[str, Any], endpoint: str) -> bool:
        """Check whether the context's API key may access *endpoint* (False if no key)."""
        api_key = self._api_key_from(context)
        if not api_key:
            return False
        return api_key.can_access_endpoint(endpoint)

    async def check_model_permission(self, context: Dict[str, Any], model: str) -> bool:
        """Check whether the context's API key may access *model* (False if no key)."""
        api_key = self._api_key_from(context)
        if not api_key:
            return False
        return api_key.can_access_model(model)

    async def check_scope_permission(self, context: Dict[str, Any], scope: str) -> bool:
        """Check whether the context's API key has *scope* (False if no key)."""
        api_key = self._api_key_from(context)
        if not api_key:
            return False
        return api_key.has_scope(scope)

    async def update_usage_stats(self, context: Dict[str, Any], tokens_used: int = 0, cost_cents: int = 0):
        """Add usage to the context's API key and commit.

        This is best-effort. On failure the error is logged with a traceback
        and the session is rolled back so later use of ``self.db`` does not
        hit a failed transaction. Nothing is raised to the caller.

        NOTE(review): if ``context["api_key"]`` came from the cache rather
        than ``self.db``, commit() may not persist the change. Confirm the
        object is attached to this session.
        """
        api_key = self._api_key_from(context)
        if not api_key:
            return
        try:
            api_key.update_usage(tokens_used, cost_cents)
            await self.db.commit()
        except Exception as e:
            logger.exception("Failed to update usage stats: %s", e)
            await self._rollback_quietly()
            return
        logger.info(
            "Updated usage for API key %s: +%s tokens, +%s cents",
            api_key.key_prefix, tokens_used, cost_cents,
        )

    async def _rollback_quietly(self) -> None:
        """Roll back ``self.db`` after a failed write, logging (not raising) any error."""
        try:
            await self.db.rollback()
        except Exception:
            logger.exception("Rollback after failed usage-stats update failed")
async def get_api_key_context(
    request: Request,
    db: AsyncSession = Depends(get_db)
) -> Optional[Dict[str, Any]]:
    """Dependency to get API key context from request.

    The key is taken from the first source that yields a non-empty value:

    1. ``Authorization: Bearer <key>``. The scheme is matched
       case-insensitively (RFC 7235) and whitespace around the key is ignored.
    2. The ``X-API-Key`` header.
    3. The ``api_key`` query parameter.

    Returns:
        The context from ``APIKeyAuthService.validate_api_key``, or ``None``
        when no key is supplied or validation fails.
    """
    api_key = _extract_api_key(request)
    if not api_key:
        return None
    return await APIKeyAuthService(db).validate_api_key(api_key, request)


def _extract_api_key(request: Request) -> Optional[str]:
    """Return the raw API key from the request's headers or query string, or None."""
    # 1. Authorization header (Bearer token); auth schemes are case-insensitive.
    auth_header = request.headers.get("Authorization")
    if auth_header:
        scheme, _, token = auth_header.partition(" ")
        token = token.strip()
        if scheme.lower() == "bearer" and token:
            return token

    # 2. X-API-Key header
    api_key = request.headers.get("X-API-Key")
    if api_key:
        return api_key

    # 3. Query parameter.
    # NOTE(review): keys in URLs tend to leak into access logs and proxy
    # caches; consider disallowing this source.
    return request.query_params.get("api_key") or None
async def require_api_key(
    context: Dict[str, Any] = Depends(get_api_key_context)
) -> Dict[str, Any]:
    """Dependency that insists on a successfully validated API key.

    Returns:
        The (truthy) context produced by ``get_api_key_context``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the context is missing or empty.
    """
    if context:
        return context
    challenge_headers = {"WWW-Authenticate": "Bearer"}
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Valid API key required",
        headers=challenge_headers,
    )
async def get_current_api_key_user(
    context: Dict[str, Any] = Depends(require_api_key)
) -> tuple:
    """Dependency returning the authenticated ``(user, api_key)`` pair.

    Returns:
        tuple: ``(context["user"], context["api_key"])``.

    Raises:
        HTTPException: 401 if either entry is missing or falsy in the context.
    """
    pair = (context.get("user"), context.get("api_key"))
    if all(pair):
        return pair
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="User or API key not found in context",
    )
async def get_api_key_auth(
    context: Dict[str, Any] = Depends(require_api_key)
) -> APIKey:
    """Dependency returning the authenticated API key object.

    Returns:
        APIKey: ``context["api_key"]`` from the validated context.

    Raises:
        HTTPException: 401 if the context holds no (or a falsy) API key.
    """
    authenticated_key = context.get("api_key")
    if authenticated_key:
        return authenticated_key
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="API key not found in context",
    )
class RequireScope:
    """Callable dependency that requires the API key to carry a given scope.

    An instance is meant to be used as a dependency, e.g.
    ``Depends(RequireScope("<scope>"))``.
    """

    def __init__(self, scope: str):
        # Scope the caller's API key must have.
        self.scope = scope

    async def __call__(self, context: Dict[str, Any] = Depends(require_api_key)):
        """Return *context* unchanged if its key has ``self.scope``; otherwise raise 403."""
        # NOTE(review): whether the context carries a "db" entry isn't visible
        # here; the scope check itself only reads the context.
        checker = APIKeyAuthService(context.get("db"))
        has_scope = await checker.check_scope_permission(context, self.scope)
        if has_scope:
            return context
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Scope '{self.scope}' required",
        )
class RequireModel:
    """Callable dependency that requires the API key to be allowed a given model.

    An instance is meant to be used as a dependency, e.g.
    ``Depends(RequireModel("<model>"))``.
    """

    def __init__(self, model: str):
        # Model name the caller's API key must be permitted to use.
        self.model = model

    async def __call__(self, context: Dict[str, Any] = Depends(require_api_key)):
        """Return *context* unchanged if its key may use ``self.model``; otherwise raise 403."""
        # NOTE(review): whether the context carries a "db" entry isn't visible
        # here; the model check itself only reads the context.
        checker = APIKeyAuthService(context.get("db"))
        model_allowed = await checker.check_model_permission(context, self.model)
        if model_allowed:
            return context
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Model '{self.model}' not allowed",
        )