mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
248 lines
8.6 KiB
Python
248 lines
8.6 KiB
Python
"""
|
|
API Key Authentication Service
|
|
Validates API keys and authenticates the users they belong to, using Redis caching to speed up lookups.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Optional, Dict, Any
|
|
from datetime import datetime
|
|
|
|
from fastapi import HTTPException, Request, status, Depends
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
|
|
from app.core.security import verify_api_key
|
|
from app.db.database import get_db
|
|
from app.models.api_key import APIKey
|
|
from app.models.user import User
|
|
from app.utils.exceptions import AuthenticationError, AuthorizationError
|
|
from app.services.cached_api_key import cached_api_key_service
|
|
|
|
# Module-scoped logger, keyed on this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class APIKeyAuthService:
    """Service for API key authentication and validation.

    Key lookups and verification results go through ``cached_api_key_service``;
    the ``AsyncSession`` given to the constructor is used for fallback lookups
    and for committing usage updates.
    """

    def __init__(self, db: AsyncSession):
        """Bind the service to a database session.

        Args:
            db: Session used by ``validate_api_key`` (fallback hash lookup and
                last-used update) and ``update_usage_stats`` (commit). The
                ``check_*_permission`` helpers do not use it.
        """
        self.db = db

    async def validate_api_key(self, api_key: str, request: Request) -> Optional[Dict[str, Any]]:
        """Validate ``api_key`` for ``request`` and return its user context.

        Steps: reject empty or too-short keys, look the key up by its
        8-character prefix through the cache service, verify the secret
        against the stored hash (skipped when a cached verification exists),
        then enforce ``is_valid()`` and the IP restriction. On success the
        key's last-used timestamp is updated.

        Args:
            api_key: Raw key string supplied by the client.
            request: Incoming request; only ``request.client.host`` is read.

        Returns:
            The context dict returned by ``get_cached_api_key`` (it carries at
            least ``"api_key"`` and ``"api_key_id"``), or ``None`` when the key
            is missing, unknown, fails verification, is invalid, is not
            allowed from the client IP, or any unexpected error occurs.
        """
        try:
            if not api_key:
                return None

            # Keys are looked up by their first 8 characters, so anything
            # shorter cannot match a stored key.
            if len(api_key) < 8:
                logger.warning("Invalid API key format: too short")
                return None

            key_prefix = api_key[:8]

            # A cached successful verification lets us skip hashing below.
            cached_verification = await cached_api_key_service.verify_api_key_cached(api_key, key_prefix)

            # Get API key data from cache or database.
            context = await cached_api_key_service.get_cached_api_key(key_prefix, self.db)
            if not context:
                logger.warning(f"API key not found: {key_prefix}")
                return None

            api_key_obj = context["api_key"]

            if not cached_verification:
                # The cached context may hold a lightweight object without the
                # hash; in that case load the full row from the database.
                if hasattr(api_key_obj, 'key_hash'):
                    key_hash = api_key_obj.key_hash
                else:
                    stmt = select(APIKey).where(APIKey.key_prefix == key_prefix)
                    result = await self.db.execute(stmt)
                    db_api_key = result.scalar_one_or_none()
                    if not db_api_key:
                        logger.warning(f"API key not found in database: {key_prefix}")
                        return None
                    key_hash = db_api_key.key_hash

                if not verify_api_key(api_key, key_hash):
                    logger.warning(f"Invalid API key hash: {key_prefix}")
                    return None

                # Remember the successful verification for later requests.
                await cached_api_key_service.cache_verification_result(api_key, key_prefix, key_hash, True)

            # Check if key is valid (expiry, active status).
            if not api_key_obj.is_valid():
                logger.warning(f"API key expired or inactive: {key_prefix}")
                # Drop the cached entry so stale data is not served again.
                await cached_api_key_service.invalidate_api_key_cache(key_prefix)
                return None

            # Check IP restrictions.
            client_ip = request.client.host if request.client else "unknown"
            if not api_key_obj.can_access_from_ip(client_ip):
                logger.warning(f"IP not allowed for API key {key_prefix}: {client_ip}")
                return None

            # Record last use. This call is awaited inline (not fire-and-forget).
            await cached_api_key_service.update_last_used(context["api_key_id"], self.db)

            return context

        except Exception as e:
            # NOTE(review): any unexpected failure (including DB/cache outages)
            # is reported to callers as "no valid key". logger.exception keeps
            # the traceback so such failures can be diagnosed.
            logger.exception(f"API key validation error: {e}")
            return None

    async def check_endpoint_permission(self, context: Dict[str, Any], endpoint: str) -> bool:
        """Return True if the context's API key may access ``endpoint``; False if there is no key."""
        api_key: APIKey = context.get("api_key")
        if not api_key:
            return False

        return api_key.can_access_endpoint(endpoint)

    async def check_model_permission(self, context: Dict[str, Any], model: str) -> bool:
        """Return True if the context's API key may use ``model``; False if there is no key."""
        api_key: APIKey = context.get("api_key")
        if not api_key:
            return False

        return api_key.can_access_model(model)

    async def check_scope_permission(self, context: Dict[str, Any], scope: str) -> bool:
        """Return True if the context's API key has ``scope``; False if there is no key."""
        api_key: APIKey = context.get("api_key")
        if not api_key:
            return False

        return api_key.has_scope(scope)

    async def update_usage_stats(self, context: Dict[str, Any], tokens_used: int = 0, cost_cents: int = 0) -> None:
        """Add usage to the context's API key and commit it (best-effort).

        Args:
            context: Auth context; ``context["api_key"]`` is updated if present.
            tokens_used: Tokens to add to the key's usage counters.
            cost_cents: Cost, in cents, to add to the key's usage counters.

        Errors are logged, never raised. If the commit fails, the session is
        rolled back so it remains usable for the rest of the request.
        """
        try:
            api_key: APIKey = context.get("api_key")
            if api_key:
                api_key.update_usage(tokens_used, cost_cents)
                try:
                    await self.db.commit()
                except Exception:
                    # A failed commit leaves the session unusable until it is
                    # rolled back; do that before reporting the failure.
                    await self._rollback_after_failed_commit()
                    raise
                logger.info(f"Updated usage for API key {api_key.key_prefix}: +{tokens_used} tokens, +{cost_cents} cents")
        except Exception as e:
            logger.error(f"Failed to update usage stats: {e}")

    async def _rollback_after_failed_commit(self) -> None:
        """Roll back ``self.db`` after a failed commit, logging (not raising) any rollback error."""
        try:
            await self.db.rollback()
        except Exception as rollback_error:
            logger.error(f"Rollback after failed usage update also failed: {rollback_error}")
|
|
|
|
|
|
async def get_api_key_context(
    request: Request,
    db: AsyncSession = Depends(get_db)
) -> Optional[Dict[str, Any]]:
    """Dependency to get API key context from request.

    The key is taken from the first source that yields a non-empty value:

    1. ``Authorization: <scheme> <key>`` where the scheme is ``Bearer``
       (matched case-insensitively, surrounding whitespace ignored),
    2. the ``X-API-Key`` header,
    3. the ``api_key`` query parameter.

    Returns:
        The context produced by ``APIKeyAuthService.validate_api_key``, or
        ``None`` when no key is supplied or validation fails. Use
        ``require_api_key`` to turn a missing context into a 401.
    """
    api_key = None

    # 1. Authorization header. Auth schemes are case-insensitive (RFC 7235),
    #    so accept "Bearer", "bearer", etc., and tolerate extra whitespace.
    auth_header = request.headers.get("Authorization")
    if auth_header:
        scheme, _, credentials = auth_header.strip().partition(" ")
        if scheme.lower() == "bearer":
            api_key = credentials.strip() or None

    # 2. X-API-Key header.
    if not api_key:
        api_key = request.headers.get("X-API-Key")

    # 3. Query parameter.
    # NOTE(review): keys sent in the URL tend to end up in access logs and
    # proxy caches; kept for backward compatibility with existing clients.
    if not api_key:
        api_key = request.query_params.get("api_key")

    if not api_key:
        return None

    auth_service = APIKeyAuthService(db)
    return await auth_service.validate_api_key(api_key, request)
|
|
|
|
|
|
async def require_api_key(
    context: Dict[str, Any] = Depends(get_api_key_context)
) -> Dict[str, Any]:
    """Dependency that insists on a valid API key.

    Returns:
        The (truthy) context produced by ``get_api_key_context``.

    Raises:
        HTTPException: 401, with a ``WWW-Authenticate: Bearer`` challenge,
            when no valid key context was resolved.
    """
    if context:
        return context

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Valid API key required",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
|
async def get_current_api_key_user(
    context: Dict[str, Any] = Depends(require_api_key)
) -> tuple:
    """Dependency yielding the authenticated user together with their API key.

    Returns:
        tuple: ``(user, api_key)`` read from the context's ``"user"`` and
        ``"api_key"`` entries.

    Raises:
        HTTPException: 401 when either entry is missing or falsy.
    """
    owner = context.get("user")
    key_obj = context.get("api_key")

    # Both pieces must be present; De Morgan of "missing user or missing key".
    if owner and key_obj:
        return owner, key_obj

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="User or API key not found in context",
    )
|
|
|
|
|
|
async def get_api_key_auth(
    context: Dict[str, Any] = Depends(require_api_key)
) -> APIKey:
    """Dependency yielding only the authenticated API key object.

    Returns:
        APIKey: The object stored under the context's ``"api_key"`` entry.

    Raises:
        HTTPException: 401 when the context has no (truthy) ``"api_key"``.
    """
    authenticated_key = context.get("api_key")
    if authenticated_key:
        return authenticated_key

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="API key not found in context",
    )
|
|
|
|
|
|
class RequireScope:
    """Dependency class for scope checking.

    An instance is a callable whose ``__call__`` declares a FastAPI
    ``Depends(require_api_key)`` parameter, so it can itself be used as a
    dependency that enforces one scope.
    """

    def __init__(self, scope: str):
        # Name of the scope the authenticated key must hold.
        self.scope = scope

    async def __call__(self, context: Dict[str, Any] = Depends(require_api_key)):
        """Return ``context`` when its key has ``self.scope``; otherwise raise 403."""
        # NOTE(review): the context may not carry a "db" entry, in which case
        # the service receives None; the scope check does not use the session.
        checker = APIKeyAuthService(context.get("db"))
        has_required_scope = await checker.check_scope_permission(context, self.scope)

        if has_required_scope:
            return context

        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Scope '{self.scope}' required",
        )
|
|
|
|
|
|
class RequireModel:
    """Dependency class for model access checking.

    An instance is a callable whose ``__call__`` declares a FastAPI
    ``Depends(require_api_key)`` parameter, so it can itself be used as a
    dependency that enforces access to one model.
    """

    def __init__(self, model: str):
        # Model identifier the authenticated key must be allowed to use.
        self.model = model

    async def __call__(self, context: Dict[str, Any] = Depends(require_api_key)):
        """Return ``context`` when its key may use ``self.model``; otherwise raise 403."""
        # NOTE(review): the context may not carry a "db" entry, in which case
        # the service receives None; the model check does not use the session.
        checker = APIKeyAuthService(context.get("db"))
        model_allowed = await checker.check_model_permission(context, self.model)

        if model_allowed:
            return context

        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Model '{self.model}' not allowed",
        )