mega changes

This commit is contained in:
2025-11-20 11:11:18 +01:00
parent e070c95190
commit 841d79f26b
138 changed files with 21499 additions and 8844 deletions

View File

@@ -0,0 +1,393 @@
"""
Audit Logging Middleware
Automatically logs user actions and system events
"""
import json
import logging
import time
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional

from fastapi import Request, Response
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.security import verify_token
from app.db.database import get_db_session
from app.models.audit_log import AuditAction, AuditLog, AuditSeverity
# Module-level logger named after this module; the middlewares in this file
# use it to report audit-logging failures without failing the request.
logger = logging.getLogger(__name__)
class AuditLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware to automatically log user actions and API calls.

    For every request whose path is not excluded, one ``AuditLog`` row is
    written once the downstream handler has returned or raised. Any failure
    while building or saving the audit entry is logged and swallowed, so
    audit logging never changes the outcome of the request itself.
    """

    # Substrings that mark a query-parameter name as secret-bearing. Values of
    # matching parameters are replaced before the query string is persisted so
    # credentials passed in URLs do not end up in the audit table.
    _SENSITIVE_PARAM_MARKERS = (
        "password",
        "passwd",
        "secret",
        "token",
        "api_key",
        "apikey",
        "api-key",
        "authorization",
        "signature",
    )
    _REDACTED = "[REDACTED]"

    # Ordered (category, keywords) pairs. The first pair with a keyword
    # contained in the lower-cased path wins, so the order is significant
    # (e.g. "login" must be matched before the shorter "log").
    _CATEGORY_KEYWORDS = (
        ("authentication", ("auth", "login", "logout", "token")),
        ("user_management", ("user", "admin", "role", "permission")),
        ("security", ("api-key", "key")),
        ("financial", ("budget", "billing", "usage")),
        ("audit", ("audit", "log")),
        ("configuration", ("setting", "config")),
    )

    # Path substrings that raise the severity of an otherwise ordinary request.
    _SENSITIVE_PATH_KEYWORDS = ("delete", "password", "admin", "key")

    _WRITE_METHODS = ("POST", "PUT", "PATCH", "DELETE")

    def __init__(self, app, exclude_paths: Optional[list] = None):
        """Wrap ``app`` and configure which path prefixes are not audited.

        Args:
            app: The ASGI application passed through to ``BaseHTTPMiddleware``.
            exclude_paths: Path prefixes to skip. When omitted (or empty), a
                default list of documentation, health, metrics and static
                prefixes is used.
        """
        super().__init__(app)
        # Paths to exclude from audit logging
        self.exclude_paths = exclude_paths or [
            "/docs",
            "/redoc",
            "/openapi.json",
            "/health",
            "/metrics",
            "/static",
            "/favicon.ico",
        ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the request and record one audit entry describing it.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            The downstream response, unmodified.

        Raises:
            Exception: Whatever the downstream handler raises is re-raised
                after an audit entry for the failure has been attempted.
        """
        path = request.url.path
        if self._should_skip(path):
            return await call_next(request)

        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump when the wall clock is adjusted.
        started = time.perf_counter()

        user_info = await self._extract_user_info(request)
        audit_data = {
            "method": request.method,
            "path": path,
            "query_params": self._redact_query_params(dict(request.query_params)),
            "ip_address": self._get_client_ip(request),
            "user_agent": request.headers.get("user-agent"),
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }

        try:
            response = await call_next(request)
        except Exception:
            # An exception escaping the downstream app would otherwise leave
            # no audit trail at all. Record it as a server error (500 is the
            # status an unhandled exception is normally turned into) and let
            # the exception continue to propagate unchanged.
            self._record_outcome(audit_data, started, 500)
            await self._log_safely(user_info, audit_data, request)
            raise

        self._record_outcome(audit_data, started, response.status_code)
        await self._log_safely(user_info, audit_data, request)
        return response

    def _should_skip(self, path: str) -> bool:
        """Return True when ``path`` must not be audited."""
        # str.startswith accepts a tuple of prefixes; an empty tuple is False.
        if path.startswith(tuple(self.exclude_paths)):
            return True
        # Root, health checks and static assets are skipped even when a
        # caller-supplied ``exclude_paths`` does not list them.
        return path in ("/", "/health") or "/static/" in path

    @staticmethod
    def _record_outcome(audit_data: Dict[str, Any], started: float, status_code: int) -> None:
        """Store response time (ms), status code and success flag in ``audit_data``."""
        elapsed = time.perf_counter() - started
        audit_data["response_time"] = round(elapsed * 1000, 2)  # milliseconds
        audit_data["status_code"] = status_code
        audit_data["success"] = 200 <= status_code < 400

    async def _log_safely(
        self,
        user_info: Optional[Dict[str, Any]],
        audit_data: Dict[str, Any],
        request: Request,
    ) -> None:
        """Write the audit entry; never let a logging failure reach the caller."""
        try:
            await self._log_audit_event(user_info, audit_data, request)
        except Exception as exc:
            # Don't fail the request if audit logging fails
            logger.exception("Failed to log audit event: %s", exc)

    def _redact_query_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``params`` with secret-looking values masked.

        A parameter is masked when its lower-cased name contains any entry of
        ``_SENSITIVE_PARAM_MARKERS``. Parameter names are kept so the audit
        entry still shows that the parameter was supplied.
        """
        redacted = {}
        for name, value in params.items():
            lowered = str(name).lower()
            is_sensitive = any(marker in lowered for marker in self._SENSITIVE_PARAM_MARKERS)
            redacted[name] = self._REDACTED if is_sensitive else value
        return redacted

    async def _extract_user_info(self, request: Request) -> Optional[Dict[str, Any]]:
        """Extract user information from request headers.

        A Bearer token is tried first; if it is absent or cannot be verified,
        the ``x-api-key`` header is checked.

        Returns:
            A dict describing the caller, or ``None`` when the request carries
            no usable credentials.
        """
        auth_header = request.headers.get("authorization")
        # The HTTP auth scheme name is case-insensitive, so accept "bearer" too.
        if auth_header and auth_header[:7].lower() == "bearer ":
            try:
                payload = verify_token(auth_header[7:].strip())
                subject = payload.get("sub")
                return {
                    "user_id": int(subject) if subject else None,
                    "email": payload.get("email"),
                    "is_superuser": payload.get("is_superuser", False),
                    "role": payload.get("role"),
                }
            except Exception as exc:
                # Best effort: an invalid token must not block the request, the
                # entry is simply recorded without user info. Only the
                # exception type is logged to keep token material out of logs.
                logger.debug(
                    "Bearer token could not be verified (%s); continuing without user info",
                    type(exc).__name__,
                )

        # Try to get user info from API key header
        if request.headers.get("x-api-key"):
            # Would need to implement API key lookup here
            # For now, just indicate it's an API key request
            return {
                "user_id": None,
                "email": "api_key_user",
                "is_superuser": False,
                "role": "api_user",
                "auth_type": "api_key",
            }
        return None

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP address with proxy support.

        Checks ``x-forwarded-for`` (first entry), ``x-forwarded`` and
        ``x-real-ip`` in that order before falling back to the socket peer.

        NOTE(review): these headers are supplied by the client unless a
        trusted reverse proxy overwrites them, so the recorded address can be
        spoofed in deployments without such a proxy — confirm the deployment.
        """
        headers = request.headers

        forwarded_for = headers.get("x-forwarded-for")
        if forwarded_for:
            # Take the first IP in the chain
            return forwarded_for.split(",")[0].strip()

        for header_name in ("x-forwarded", "x-real-ip"):
            value = headers.get(header_name)
            if value:
                return value

        # Fall back to direct client IP
        return request.client.host if request.client else "unknown"

    async def _log_audit_event(
        self,
        user_info: Optional[Dict[str, Any]],
        audit_data: Dict[str, Any],
        request: Request,
    ):
        """Log the audit event to database.

        Args:
            user_info: Caller description from ``_extract_user_info`` or None.
            audit_data: Request/response facts collected by ``dispatch``; must
                contain ``success`` and ``status_code``.
            request: The request being audited.
        """
        method = request.method
        path = request.url.path

        action = self._determine_action(method, path)
        resource_type, resource_id = self._parse_resource_from_path(path)
        description = self._create_description(method, path, audit_data["success"])
        severity = self._determine_severity(method, audit_data["status_code"], path)

        try:
            async with get_db_session() as db:
                audit_log = AuditLog(
                    user_id=user_info.get("user_id") if user_info else None,
                    action=action,
                    resource_type=resource_type,
                    resource_id=resource_id,
                    description=description,
                    details={
                        "request": {
                            "method": audit_data["method"],
                            "path": audit_data["path"],
                            "query_params": audit_data["query_params"],
                            "response_time_ms": audit_data["response_time"],
                        },
                        "user_info": user_info,
                    },
                    ip_address=audit_data["ip_address"],
                    user_agent=audit_data["user_agent"],
                    severity=severity,
                    category=self._determine_category(path),
                    success=audit_data["success"],
                    tags=self._generate_tags(method, path),
                )
                db.add(audit_log)
                await db.commit()
        except Exception as exc:
            logger.exception("Failed to save audit log to database: %s", exc)
            # Could implement fallback logging to file here

    def _determine_action(self, method: str, path: str) -> str:
        """Determine action type from HTTP method and path.

        POST requests whose path contains "login"/"logout" map to the
        dedicated login/logout actions; unknown methods fall back to the
        lower-cased method name.
        """
        method = method.upper()
        if method == "GET":
            return AuditAction.READ
        if method == "POST":
            lowered = path.lower()
            if "login" in lowered:
                return AuditAction.LOGIN
            if "logout" in lowered:
                return AuditAction.LOGOUT
            return AuditAction.CREATE
        if method in ("PUT", "PATCH"):
            return AuditAction.UPDATE
        if method == "DELETE":
            return AuditAction.DELETE
        return method.lower()

    @staticmethod
    def _is_version_segment(segment: str) -> bool:
        """Return True for path segments shaped like ``v1``, ``V2``, ``v10``."""
        return len(segment) > 1 and segment[0] in "vV" and segment[1:].isdigit()

    def _parse_resource_from_path(self, path: str) -> tuple[str, Optional[str]]:
        """Parse resource type and ID from URL path.

        The leading ``api``/``api-internal`` segment is dropped, followed by a
        version segment only if one is actually present, so ``/api/users/5``
        and ``/api/v1/users/5`` both yield ``("users", "5")``.

        Returns:
            ``(resource_type, resource_id)``; ``("system", None)`` when no
            resource segment remains. ``resource_id`` is the first all-digit
            segment after the resource type, or ``None``.
        """
        # Drop empty segments so "/", "" and doubled slashes do not produce
        # an empty-string resource type.
        segments = [segment for segment in path.split("/") if segment]

        if segments and segments[0] in ("api", "api-internal"):
            segments = segments[1:]
            if segments and self._is_version_segment(segments[0]):
                segments = segments[1:]

        if not segments:
            return "system", None

        resource_type = segments[0]
        resource_id = next((segment for segment in segments[1:] if segment.isdigit()), None)
        return resource_type, resource_id

    def _create_description(self, method: str, path: str, success: bool) -> str:
        """Create human-readable description of the action."""
        action_verbs = {
            "GET": "accessed" if success else "attempted to access",
            "POST": "created" if success else "attempted to create",
            "PUT": "updated" if success else "attempted to update",
            "PATCH": "modified" if success else "attempted to modify",
            "DELETE": "deleted" if success else "attempted to delete",
        }
        # Normalise case the same way _determine_action does.
        verb = action_verbs.get(method.upper(), method.lower())
        resource = path.strip("/").split("/")[-1] if "/" in path else path
        return f"User {verb} {resource}"

    def _determine_severity(self, method: str, status_code: int, path: str) -> str:
        """Determine severity level based on action and outcome.

        Server errors are checked first so a 5xx on a sensitive path is not
        downgraded from CRITICAL to HIGH by the path-keyword rule.
        """
        if status_code >= 500:
            return AuditSeverity.CRITICAL

        # Sensitive endpoints and authorization failures
        lowered = path.lower()
        if any(keyword in lowered for keyword in self._SENSITIVE_PATH_KEYWORDS):
            return AuditSeverity.HIGH
        if status_code in (401, 403):
            return AuditSeverity.HIGH

        # Other client errors, and all write operations
        if status_code >= 400 or method in self._WRITE_METHODS:
            return AuditSeverity.MEDIUM

        # Read operations
        return AuditSeverity.LOW

    def _determine_category(self, path: str) -> str:
        """Determine category based on path.

        Returns the first category in ``_CATEGORY_KEYWORDS`` with a keyword
        contained in the lower-cased path, or ``"general"``.
        """
        lowered = path.lower()
        for category, keywords in self._CATEGORY_KEYWORDS:
            if any(keyword in lowered for keyword in keywords):
                return category
        return "general"

    def _generate_tags(self, method: str, path: str) -> list[str]:
        """Generate tags for the audit log.

        Tags are the lower-cased method, the first non-empty path segment (if
        any), plus ``admin_action`` / ``security_action`` markers.
        """
        tags = [method.lower()]

        # Skip empty segments so "/" does not contribute an empty-string tag.
        segments = [segment for segment in path.split("/") if segment]
        if segments:
            tags.append(segments[0])

        lowered = path.lower()
        if "admin" in lowered:
            tags.append("admin_action")
        if any(keyword in lowered for keyword in ("password", "auth", "login")):
            tags.append("security_action")
        return tags
class LoginAuditMiddleware(BaseHTTPMiddleware):
    """Specialized middleware for login/logout events.

    Requests whose path contains one of ``_AUTH_PATH_FRAGMENTS`` are passed
    downstream and then recorded through ``AuditLog.create_login_event`` /
    ``AuditLog.create_logout_event``. All other requests pass straight
    through. Failures while recording are logged and never affect the
    response.
    """

    _AUTH_PATH_FRAGMENTS = ("/auth/login", "/auth/logout", "/auth/refresh")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward the request and record an auth event for auth endpoints.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            The downstream response, unmodified.
        """
        path = request.url.path
        # Only process auth-related endpoints
        if not any(fragment in path for fragment in self._AUTH_PATH_FRAGMENTS):
            return await call_next(request)

        # perf_counter is monotonic, unlike time.time().
        started = time.perf_counter()

        # The body has to be read before the handler runs because the login
        # identifier is only available in the request payload.
        request_body = None
        if request.method == "POST" and "/login" in path:
            request_body = await self._read_login_body(request)

        response = await call_next(request)

        try:
            await self._log_auth_event(
                request, response, request_body, time.perf_counter() - started
            )
        except Exception as exc:
            logger.exception("Failed to log auth event: %s", exc)
        return response

    async def _read_login_body(self, request: Request) -> Optional[dict]:
        """Return the login payload as a dict, or None if it is unusable.

        ``Request.body()`` caches the raw bytes on the request object itself,
        so no manual re-assignment of the body is needed here.
        NOTE(review): whether downstream handlers can still read a body that
        was consumed inside a ``BaseHTTPMiddleware`` depends on the installed
        Starlette version — confirm against the pinned version.
        """
        try:
            raw = await request.body()
            if not raw:
                return None
            parsed = json.loads(raw.decode())
        except Exception as exc:
            # Best effort: e.g. form-encoded or malformed bodies are not JSON.
            logger.warning("Failed to parse login request body: %s", exc)
            return None
        # Valid JSON that is not an object (list, string, number) has no
        # "email"/"username" keys; treating it as a dict would raise later
        # and lose the audit entry.
        return parsed if isinstance(parsed, dict) else None

    async def _log_auth_event(
        self,
        request: Request,
        response: Response,
        request_body: Optional[dict],
        process_time: float,
    ):
        """Log authentication events.

        Args:
            request: The auth request that was processed.
            response: The downstream response; a 2xx status counts as success.
            request_body: Parsed login payload, or None when unavailable.
            process_time: Handling time in seconds.

        NOTE: ``/auth/refresh`` requests reach this method but match neither
        branch below, so no event is recorded for them.
        """
        path = request.url.path

        if "/login" in path:
            success = 200 <= response.status_code < 300

            # Extract email/username from request
            identifier = None
            if isinstance(request_body, dict):
                identifier = request_body.get("email") or request_body.get("username")

            async with get_db_session() as db:
                audit_log = AuditLog.create_login_event(
                    user_id=None,  # Would need to extract from response for successful logins
                    success=success,
                    ip_address=self._get_client_ip(request),
                    user_agent=request.headers.get("user-agent"),
                    error_message=f"HTTP {response.status_code}" if not success else None,
                )
                # Merge instead of mutating in place so a missing/None
                # ``details`` does not raise and drop the entry.
                audit_log.details = {
                    **(audit_log.details or {}),
                    "identifier": identifier,
                    "response_time_ms": round(process_time * 1000, 2),
                }
                db.add(audit_log)
                await db.commit()

        elif "/logout" in path:
            user_id = self._user_id_from_bearer(request)
            async with get_db_session() as db:
                audit_log = AuditLog.create_logout_event(
                    user_id=user_id,
                    session_id=None,  # Could extract from token if stored
                )
                db.add(audit_log)
                await db.commit()

    @staticmethod
    def _user_id_from_bearer(request: Request) -> Optional[int]:
        """Return the integer ``sub`` claim of the Bearer token, if verifiable."""
        auth_header = request.headers.get("authorization")
        # The HTTP auth scheme name is case-insensitive.
        if not auth_header or auth_header[:7].lower() != "bearer ":
            return None
        try:
            payload = verify_token(auth_header[7:].strip())
            subject = payload.get("sub")
            return int(subject) if subject else None
        except Exception as exc:
            # Best effort: a logout with an invalid/expired token is still
            # recorded, just without a user id.
            logger.debug(
                "Bearer token could not be verified (%s); logging logout without user id",
                type(exc).__name__,
            )
            return None

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP address with proxy support.

        Checks ``x-forwarded-for`` (first entry), ``x-forwarded`` and
        ``x-real-ip`` in that order — the same headers
        ``AuditLoggingMiddleware._get_client_ip`` consults — before falling
        back to the socket peer address.
        """
        headers = request.headers

        forwarded_for = headers.get("x-forwarded-for")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()

        for header_name in ("x-forwarded", "x-real-ip"):
            value = headers.get(header_name)
            if value:
                return value

        return request.client.host if request.client else "unknown"