mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
clean commit
This commit is contained in:
278
backend/app/middleware/security.py
Normal file
278
backend/app/middleware/security.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Security middleware for request/response processing
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Callable, Optional, Dict, Any
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.core.threat_detection import threat_detection_service, SecurityAnalysis
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SecurityMiddleware(BaseHTTPMiddleware):
|
||||
"""Security middleware for threat detection and request filtering"""
|
||||
|
||||
def __init__(self, app, enabled: bool = True):
|
||||
super().__init__(app)
|
||||
self.enabled = enabled and settings.API_SECURITY_ENABLED
|
||||
logger.info(f"SecurityMiddleware initialized, enabled: {self.enabled}")
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""Process request through security analysis"""
|
||||
if not self.enabled:
|
||||
# Security disabled, pass through
|
||||
return await call_next(request)
|
||||
|
||||
# Skip security analysis for certain endpoints
|
||||
if self._should_skip_security(request):
|
||||
response = await call_next(request)
|
||||
return self._add_security_headers(response)
|
||||
|
||||
# Simple authentication check - drop requests without valid auth
|
||||
if not self._has_valid_auth(request):
|
||||
return JSONResponse(
|
||||
content={"error": "Authentication required", "message": "Valid API key or authentication token required"},
|
||||
status_code=401,
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
try:
|
||||
# Get user context if available
|
||||
user_context = getattr(request.state, 'user', None)
|
||||
|
||||
# Perform security analysis
|
||||
start_time = time.time()
|
||||
analysis = await threat_detection_service.analyze_request(request, user_context)
|
||||
analysis_time = time.time() - start_time
|
||||
|
||||
# Store analysis in request state for later use
|
||||
request.state.security_analysis = analysis
|
||||
|
||||
# Log security events
|
||||
if analysis.is_threat:
|
||||
await self._log_security_event(request, analysis)
|
||||
|
||||
# Check if request should be blocked
|
||||
if analysis.should_block:
|
||||
threat_detection_service.stats['threats_blocked'] += 1
|
||||
logger.warning(f"Blocked request from {request.client.host if request.client else 'unknown'}: "
|
||||
f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}")
|
||||
|
||||
# Return security block response
|
||||
return self._create_block_response(analysis)
|
||||
|
||||
# Log warnings for medium-risk requests
|
||||
if analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
|
||||
logger.warning(f"High-risk request detected from {request.client.host if request.client else 'unknown'}: "
|
||||
f"risk_score={analysis.risk_score:.3f}, auth_level={analysis.auth_level.value}")
|
||||
|
||||
# Continue with request processing
|
||||
response = await call_next(request)
|
||||
|
||||
# Add security headers and metrics
|
||||
response = self._add_security_headers(response)
|
||||
response = self._add_security_metrics(response, analysis, analysis_time)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Security middleware error: {e}")
|
||||
# Continue with request on security middleware errors to avoid breaking the app
|
||||
response = await call_next(request)
|
||||
return self._add_security_headers(response)
|
||||
|
||||
def _should_skip_security(self, request: Request) -> bool:
|
||||
"""Determine if security analysis should be skipped for this request"""
|
||||
path = request.url.path
|
||||
|
||||
# Skip for health checks and static assets
|
||||
skip_paths = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/api/v1/docs",
|
||||
"/api/v1/openapi.json",
|
||||
"/api/v1/redoc",
|
||||
"/favicon.ico"
|
||||
]
|
||||
|
||||
# Skip for static file extensions
|
||||
static_extensions = [".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".woff", ".woff2"]
|
||||
|
||||
return (
|
||||
path in skip_paths or
|
||||
any(path.endswith(ext) for ext in static_extensions) or
|
||||
path.startswith("/static/")
|
||||
)
|
||||
|
||||
def _has_valid_auth(self, request: Request) -> bool:
|
||||
"""Check if request has valid authentication"""
|
||||
# Check Authorization header
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
api_key_header = request.headers.get("X-API-Key", "")
|
||||
|
||||
# Has some form of auth token/key
|
||||
return (
|
||||
auth_header.startswith("Bearer ") and len(auth_header) > 7 or
|
||||
len(api_key_header.strip()) > 0
|
||||
)
|
||||
|
||||
def _create_block_response(self, analysis: SecurityAnalysis) -> JSONResponse:
|
||||
"""Create response for blocked requests"""
|
||||
# Determine status code based on threat type
|
||||
status_code = 403 # Forbidden by default
|
||||
|
||||
# Rate limiting gets 429
|
||||
if analysis.rate_limit_exceeded:
|
||||
status_code = 429
|
||||
|
||||
# Critical threats get 403
|
||||
for threat in analysis.threats:
|
||||
if threat.threat_type in ["command_injection", "sql_injection"]:
|
||||
status_code = 403
|
||||
break
|
||||
|
||||
response_data = {
|
||||
"error": "Security Policy Violation",
|
||||
"message": "Request blocked due to security policy violation",
|
||||
"risk_score": round(analysis.risk_score, 3),
|
||||
"auth_level": analysis.auth_level.value,
|
||||
"threat_count": len(analysis.threats),
|
||||
"recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations
|
||||
}
|
||||
|
||||
# Add rate limiting info if applicable
|
||||
if analysis.rate_limit_exceeded:
|
||||
response_data["error"] = "Rate Limit Exceeded"
|
||||
response_data["message"] = f"Rate limit exceeded for {analysis.auth_level.value} user"
|
||||
response_data["retry_after"] = "60" # Suggest retry after 60 seconds
|
||||
|
||||
response = JSONResponse(
|
||||
content=response_data,
|
||||
status_code=status_code
|
||||
)
|
||||
|
||||
# Add rate limiting headers
|
||||
if analysis.rate_limit_exceeded:
|
||||
response.headers["Retry-After"] = "60"
|
||||
response.headers["X-RateLimit-Limit"] = "See API documentation"
|
||||
response.headers["X-RateLimit-Reset"] = str(int(time.time() + 60))
|
||||
|
||||
return response
|
||||
|
||||
def _add_security_headers(self, response: Response) -> Response:
|
||||
"""Add security headers to response"""
|
||||
if not settings.API_SECURITY_HEADERS_ENABLED:
|
||||
return response
|
||||
|
||||
# Standard security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
# Only add HSTS for HTTPS
|
||||
if hasattr(response, 'headers') and response.headers.get("X-Forwarded-Proto") == "https":
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
|
||||
# Content Security Policy
|
||||
if settings.API_CSP_HEADER:
|
||||
response.headers["Content-Security-Policy"] = settings.API_CSP_HEADER
|
||||
|
||||
return response
|
||||
|
||||
def _add_security_metrics(self, response: Response, analysis: SecurityAnalysis, analysis_time: float) -> Response:
|
||||
"""Add security metrics to response headers (for debugging/monitoring)"""
|
||||
# Only add in debug mode or for admin users
|
||||
if settings.APP_DEBUG:
|
||||
response.headers["X-Security-Risk-Score"] = str(round(analysis.risk_score, 3))
|
||||
response.headers["X-Security-Threats"] = str(len(analysis.threats))
|
||||
response.headers["X-Security-Auth-Level"] = analysis.auth_level.value
|
||||
response.headers["X-Security-Analysis-Time"] = f"{analysis_time*1000:.1f}ms"
|
||||
|
||||
return response
|
||||
|
||||
async def _log_security_event(self, request: Request, analysis: SecurityAnalysis):
|
||||
"""Log security events for audit and monitoring"""
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
|
||||
# Create security event log
|
||||
event_data = {
|
||||
"timestamp": analysis.timestamp.isoformat(),
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent,
|
||||
"path": str(request.url.path),
|
||||
"method": request.method,
|
||||
"risk_score": round(analysis.risk_score, 3),
|
||||
"auth_level": analysis.auth_level.value,
|
||||
"threat_count": len(analysis.threats),
|
||||
"rate_limit_exceeded": analysis.rate_limit_exceeded,
|
||||
"should_block": analysis.should_block,
|
||||
"threats": [
|
||||
{
|
||||
"type": threat.threat_type,
|
||||
"level": threat.level.value,
|
||||
"confidence": round(threat.confidence, 3),
|
||||
"description": threat.description
|
||||
}
|
||||
for threat in analysis.threats[:5] # Limit to first 5 threats
|
||||
],
|
||||
"recommendations": analysis.recommendations
|
||||
}
|
||||
|
||||
# Log at appropriate level based on risk
|
||||
if analysis.should_block:
|
||||
logger.warning(f"SECURITY_BLOCK: {json.dumps(event_data)}")
|
||||
elif analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
|
||||
logger.warning(f"SECURITY_WARNING: {json.dumps(event_data)}")
|
||||
else:
|
||||
logger.info(f"SECURITY_THREAT: {json.dumps(event_data)}")
|
||||
|
||||
|
||||
def setup_security_middleware(app, enabled: bool = True) -> None:
|
||||
"""Setup security middleware on FastAPI app"""
|
||||
if enabled and settings.API_SECURITY_ENABLED:
|
||||
app.add_middleware(SecurityMiddleware, enabled=enabled)
|
||||
logger.info("Security middleware enabled")
|
||||
else:
|
||||
logger.info("Security middleware disabled")
|
||||
|
||||
|
||||
# Helper functions for manual security checks
|
||||
async def analyze_request_security(request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
|
||||
"""Manually analyze request security (for use in route handlers)"""
|
||||
return await threat_detection_service.analyze_request(request, user_context)
|
||||
|
||||
|
||||
def get_security_stats() -> Dict[str, Any]:
|
||||
"""Get security statistics"""
|
||||
return threat_detection_service.get_stats()
|
||||
|
||||
|
||||
def is_request_blocked(request: Request) -> bool:
|
||||
"""Check if request was blocked by security analysis"""
|
||||
if hasattr(request.state, 'security_analysis'):
|
||||
return request.state.security_analysis.should_block
|
||||
return False
|
||||
|
||||
|
||||
def get_request_risk_score(request: Request) -> float:
|
||||
"""Get risk score for request"""
|
||||
if hasattr(request.state, 'security_analysis'):
|
||||
return request.state.security_analysis.risk_score
|
||||
return 0.0
|
||||
|
||||
|
||||
def get_request_auth_level(request: Request) -> str:
|
||||
"""Get authentication level for request"""
|
||||
if hasattr(request.state, 'security_analysis'):
|
||||
return request.state.security_analysis.auth_level.value
|
||||
return "unknown"
|
||||
Reference in New Issue
Block a user