mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
143 lines
5.3 KiB
Python
143 lines
5.3 KiB
Python
"""
|
|
Analytics middleware for automatic request tracking
|
|
"""
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
from fastapi import Request, Response
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from sqlalchemy.orm import Session
|
|
from contextvars import ContextVar
|
|
|
|
from app.core.logging import get_logger
|
|
from app.services.analytics import RequestEvent, get_analytics_service
|
|
from app.db.database import get_db
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
# Context variable to pass analytics data from endpoints to middleware
|
|
analytics_context: ContextVar[dict] = ContextVar('analytics_context', default={})
|
|
|
|
|
|
class AnalyticsMiddleware(BaseHTTPMiddleware):
|
|
"""Middleware to automatically track all requests for analytics"""
|
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
# Start timing
|
|
start_time = time.time()
|
|
|
|
# Skip analytics for health checks and static files
|
|
if request.url.path in ["/health", "/docs", "/redoc", "/openapi.json"] or request.url.path.startswith("/static"):
|
|
return await call_next(request)
|
|
|
|
# Get user info if available from token
|
|
user_id = None
|
|
api_key_id = None
|
|
|
|
try:
|
|
authorization = request.headers.get("Authorization")
|
|
if authorization and authorization.startswith("Bearer "):
|
|
token = authorization.split(" ")[1]
|
|
# Try to extract user info from token without full validation
|
|
# This is a lightweight check for analytics purposes
|
|
from app.core.security import verify_token
|
|
try:
|
|
payload = verify_token(token)
|
|
user_id = int(payload.get("sub"))
|
|
except:
|
|
# Token might be invalid, but we still want to track the request
|
|
pass
|
|
except Exception:
|
|
# Don't let analytics break the request
|
|
pass
|
|
|
|
# Get client IP
|
|
client_ip = request.client.host if request.client else None
|
|
if not client_ip:
|
|
# Check for forwarded headers
|
|
client_ip = request.headers.get("X-Forwarded-For", "").split(",")[0].strip()
|
|
if not client_ip:
|
|
client_ip = request.headers.get("X-Real-IP", "unknown")
|
|
|
|
# Get user agent
|
|
user_agent = request.headers.get("User-Agent", "")
|
|
|
|
# Get request size
|
|
request_size = int(request.headers.get("Content-Length", 0))
|
|
|
|
# Process the request
|
|
response = None
|
|
error_message = None
|
|
|
|
try:
|
|
response = await call_next(request)
|
|
except Exception as e:
|
|
logger.error(f"Request failed: {e}")
|
|
error_message = str(e)
|
|
response = JSONResponse(
|
|
status_code=500,
|
|
content={"error": "INTERNAL_ERROR", "message": "Internal server error"}
|
|
)
|
|
|
|
# Calculate timing
|
|
end_time = time.time()
|
|
response_time = (end_time - start_time) * 1000 # Convert to milliseconds
|
|
|
|
# Get response size
|
|
response_size = 0
|
|
if hasattr(response, 'body'):
|
|
response_size = len(response.body) if response.body else 0
|
|
|
|
# Get analytics data from context (set by endpoints)
|
|
context_data = analytics_context.get({})
|
|
|
|
# Create analytics event
|
|
event = RequestEvent(
|
|
timestamp=datetime.utcnow(),
|
|
method=request.method,
|
|
path=request.url.path,
|
|
status_code=response.status_code if response else 500,
|
|
response_time=response_time,
|
|
user_id=user_id,
|
|
api_key_id=api_key_id,
|
|
ip_address=client_ip,
|
|
user_agent=user_agent,
|
|
request_size=request_size,
|
|
response_size=response_size,
|
|
error_message=error_message,
|
|
# Token/cost info populated by LLM endpoints via context
|
|
model=context_data.get('model'),
|
|
request_tokens=context_data.get('request_tokens', 0),
|
|
response_tokens=context_data.get('response_tokens', 0),
|
|
total_tokens=context_data.get('total_tokens', 0),
|
|
cost_cents=context_data.get('cost_cents', 0),
|
|
budget_ids=context_data.get('budget_ids', []),
|
|
budget_warnings=context_data.get('budget_warnings', [])
|
|
)
|
|
|
|
# Track the event
|
|
try:
|
|
from app.services.analytics import analytics_service
|
|
if analytics_service is not None:
|
|
await analytics_service.track_request(event)
|
|
else:
|
|
logger.warning("Analytics service not initialized, skipping event tracking")
|
|
except Exception as e:
|
|
logger.error(f"Failed to track analytics event: {e}")
|
|
# Don't let analytics failures break the request
|
|
|
|
return response
|
|
|
|
|
|
def set_analytics_data(**kwargs):
|
|
"""Helper function for endpoints to set analytics data"""
|
|
current_context = analytics_context.get({})
|
|
current_context.update(kwargs)
|
|
analytics_context.set(current_context)
|
|
|
|
|
|
def setup_analytics_middleware(app):
|
|
"""Add analytics middleware to the FastAPI app"""
|
|
app.add_middleware(AnalyticsMiddleware)
|
|
logger.info("Analytics middleware configured") |