Files
enclava/backend/app/middleware/analytics.py
2025-08-19 09:50:15 +02:00

143 lines
5.3 KiB
Python

"""
Analytics middleware for automatic request tracking
"""
import time
from datetime import datetime
from typing import Optional
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from sqlalchemy.orm import Session
from contextvars import ContextVar
from app.core.logging import get_logger
from app.services.analytics import RequestEvent, get_analytics_service
from app.db.database import get_db
logger = get_logger(__name__)

# Per-request analytics payload: endpoints write into it via
# set_analytics_data(), and AnalyticsMiddleware reads token/cost fields from it.
# NOTE(review): the ``default={}`` dict is ONE object shared by every context
# that has no value bound - never mutate the result of a bare
# ``analytics_context.get()``; bind a fresh dict first.
analytics_context: ContextVar[dict] = ContextVar("analytics_context", default={})
class AnalyticsMiddleware(BaseHTTPMiddleware):
    """Middleware to automatically track all requests for analytics.

    For every request outside the excluded paths it measures latency,
    collects caller metadata (user id from a bearer token, client IP, user
    agent, request/response sizes) plus any LLM token/cost data endpoints
    recorded in ``analytics_context``, builds a ``RequestEvent`` and hands it
    to the analytics service. Analytics failures never break the request.
    """

    # Health probes, API docs and static assets are never tracked.
    _EXCLUDED_PATHS = frozenset({"/health", "/docs", "/redoc", "/openapi.json"})
    _EXCLUDED_PREFIX = "/static"

    async def dispatch(self, request: Request, call_next):
        """Run the downstream app and record an analytics event for the call.

        Args:
            request: The incoming request.
            call_next: Starlette callable that runs the rest of the app.

        Returns:
            The downstream response, or a generic 500 JSON response if the
            downstream app raised.
        """
        # perf_counter is monotonic, unlike time.time(), so clock adjustments
        # cannot produce negative or inflated latencies.
        start_time = time.perf_counter()

        path = request.url.path
        if path in self._EXCLUDED_PATHS or path.startswith(self._EXCLUDED_PREFIX):
            return await call_next(request)

        user_id = self._extract_user_id(request)
        api_key_id = None  # not resolved by this middleware
        client_ip = self._client_ip(request)
        user_agent = request.headers.get("User-Agent", "")
        request_size = self._request_size(request)

        # Bind a fresh, request-local dict BEFORE calling the endpoint.
        # BaseHTTPMiddleware runs the endpoint in a copied context, so a
        # ContextVar.set() made there never propagates back here; in-place
        # updates to this shared dict object (as set_analytics_data does via
        # dict.update) do. It also keeps requests from seeing each other's data.
        context_data: dict = {}
        context_token = analytics_context.set(context_data)

        error_message = None
        try:
            response = await call_next(request)
        except Exception as e:
            logger.error(f"Request failed: {e}")
            error_message = str(e)
            response = JSONResponse(
                status_code=500,
                content={"error": "INTERNAL_ERROR", "message": "Internal server error"},
            )
        finally:
            analytics_context.reset(context_token)

        response_time = (time.perf_counter() - start_time) * 1000  # milliseconds

        # Building and tracking the event are both best-effort: nothing in
        # here may turn a successful request into a failed one.
        try:
            event = RequestEvent(
                timestamp=datetime.utcnow(),
                method=request.method,
                path=path,
                status_code=response.status_code,
                response_time=response_time,
                user_id=user_id,
                api_key_id=api_key_id,
                ip_address=client_ip,
                user_agent=user_agent,
                request_size=request_size,
                response_size=self._response_size(response),
                error_message=error_message,
                # Token/cost info populated by LLM endpoints via context
                model=context_data.get("model"),
                request_tokens=context_data.get("request_tokens", 0),
                response_tokens=context_data.get("response_tokens", 0),
                total_tokens=context_data.get("total_tokens", 0),
                cost_cents=context_data.get("cost_cents", 0),
                budget_ids=context_data.get("budget_ids", []),
                budget_warnings=context_data.get("budget_warnings", []),
            )
            # Module attribute read at call time so a service initialised after
            # import is picked up.
            from app.services.analytics import analytics_service

            if analytics_service is not None:
                await analytics_service.track_request(event)
            else:
                logger.warning("Analytics service not initialized, skipping event tracking")
        except Exception as e:
            logger.error(f"Failed to track analytics event: {e}")
        return response

    @staticmethod
    def _extract_user_id(request: Request) -> Optional[int]:
        """Return the ``sub`` claim of a ``Bearer`` token as an int, else None.

        Any failure (missing header, token rejected by ``verify_token``,
        missing or non-numeric ``sub``) yields None so the request is still
        tracked, just without a user.
        """
        authorization = request.headers.get("Authorization")
        if not authorization or not authorization.startswith("Bearer "):
            return None
        token = authorization.split(" ")[1]
        try:
            # Imported lazily as in the original; presumably to avoid an
            # import cycle - TODO confirm.
            from app.core.security import verify_token

            payload = verify_token(token)
            return int(payload.get("sub"))
        except Exception:
            return None

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Peer address, else first X-Forwarded-For hop, else X-Real-IP/'unknown'."""
        if request.client and request.client.host:
            return request.client.host
        forwarded = request.headers.get("X-Forwarded-For", "").split(",")[0].strip()
        return forwarded or request.headers.get("X-Real-IP", "unknown")

    @staticmethod
    def _request_size(request: Request) -> int:
        """Declared request body size; 0 when Content-Length is absent or malformed."""
        try:
            return max(int(request.headers.get("Content-Length", 0)), 0)
        except (TypeError, ValueError):
            return 0

    @staticmethod
    def _response_size(response: Response) -> int:
        """Response body size in bytes, best effort.

        Responses returned by ``call_next`` are streamed and carry no
        ``body`` attribute, so fall back to their Content-Length header.
        """
        body = getattr(response, "body", None)
        if body:
            return len(body)
        try:
            return max(int(response.headers.get("content-length", 0)), 0)
        except (TypeError, ValueError):
            return 0
def set_analytics_data(**kwargs):
    """Merge keyword arguments into the current context's analytics dict.

    Intended for endpoints to record fields that AnalyticsMiddleware reads
    from ``analytics_context`` (``model``, ``request_tokens``,
    ``response_tokens``, ``total_tokens``, ``cost_cents``, ``budget_ids``,
    ``budget_warnings``). If no dict is bound in the current context, a new
    empty one is created, filled and bound.
    """
    merged = analytics_context.get({})
    for key, value in kwargs.items():
        merged[key] = value
    analytics_context.set(merged)
def setup_analytics_middleware(app):
    """Register AnalyticsMiddleware on *app*, then log that it is configured.

    Args:
        app: The FastAPI application receiving the middleware.
    """
    middleware_class = AnalyticsMiddleware
    app.add_middleware(middleware_class)
    logger.info("Analytics middleware configured")