# File: enclava/backend/app/services/api_proxy.py
# Timestamp: 2025-08-19 09:50:15 +02:00
# Size: 606 lines, 24 KiB
# Language: Python
"""
API Proxy with comprehensive security interceptors
"""
import json
import re
import time
import uuid
from pathlib import Path
from typing import Dict, List, Any, Optional

import httpx
import yaml
from fastapi import Request, Response, HTTPException, status
from fastapi.responses import JSONResponse

from app.core.config import settings
from app.core.logging import get_logger
from app.services.api_key_auth import get_api_key_info
from app.services.budget_enforcement import check_budget_and_record_usage
from app.middleware.rate_limiting import rate_limiter
from app.utils.exceptions import ValidationError, AuthenticationError, RateLimitExceeded
from app.services.audit_service import create_audit_log
# Module-level logger obtained from the project's logging helper, keyed by this module's name.
logger = get_logger(__name__)
class SecurityConfiguration:
    """Static security policy used by the API proxy.

    The policy is assembled once in ``__init__`` and exposed as the plain
    dictionary ``self.config``.
    """

    def __init__(self):
        self.config = self._load_security_config()

    def _load_security_config(self) -> Dict[str, Any]:
        """Assemble and return a new dictionary holding the hard-coded policy."""
        endpoint_limits = {
            "/api/llm/v1/chat/completions": 100,  # per minute
            "/api/modules/v1/rag/search": 500,  # per hour
        }
        rate_limits = {
            "global": 10000,  # per hour
            "per_key": 1000,  # per hour
            "per_endpoint": endpoint_limits,
        }
        # Substrings treated as script-injection or SQL-injection markers.
        forbidden_patterns = [
            "<script", "javascript:", "data:text/html", "vbscript:",
            "union select", "drop table", "insert into", "delete from",
        ]
        audit_settings = {
            "enabled": True,
            "include_request_body": False,
            "include_response_body": False,
            "sensitive_paths": ["/api/platform/v1/auth"],
        }

        policy: Dict[str, Any] = {}
        policy["rate_limits"] = rate_limits
        policy["max_request_size"] = 10 * 1024 * 1024  # 10MB
        policy["max_string_length"] = 50000
        policy["timeout"] = 30  # seconds
        policy["required_headers"] = ["X-API-Key"]
        policy["ip_whitelist_enabled"] = False
        policy["ip_whitelist"] = []
        policy["ip_blacklist"] = []
        policy["forbidden_patterns"] = forbidden_patterns
        policy["audit"] = audit_settings
        return policy
class RequestValidator:
    """Validates API requests against schemas and security policies.

    The validator reads its limits from ``config.config`` (a dictionary with
    at least ``max_request_size``, ``required_headers`` and
    ``forbidden_patterns``) and keeps a small table of request schemas keyed
    by ``"<METHOD> <path>"``.
    """

    def __init__(self, config: "SecurityConfiguration"):
        """Store the security policy holder and load the request schemas."""
        self.config = config
        self.schemas = self._load_openapi_schemas()

    def _load_openapi_schemas(self) -> Dict[str, Any]:
        """Load OpenAPI schemas for validation."""
        # Would load actual OpenAPI schemas in production
        return {
            "POST /api/llm/v1/chat/completions": {
                "requestBody": {
                    "type": "object",
                    "required": ["model", "messages"],
                    "properties": {
                        "model": {"type": "string"},
                        "messages": {"type": "array"},
                        "temperature": {"type": "number", "minimum": 0, "maximum": 2},
                        "max_tokens": {"type": "integer", "minimum": 1, "maximum": 32000},
                    },
                }
            },
            "POST /api/modules/v1/rag/search": {
                "requestBody": {
                    "type": "object",
                    "required": ["query"],
                    "properties": {
                        "query": {"type": "string", "maxLength": 1000},
                        "limit": {"type": "integer", "minimum": 1, "maximum": 100},
                    },
                }
            },
        }

    async def validate(self, path: str, method: str, body: Dict, headers: Dict) -> Dict:
        """Validate a request against the size limit, headers, schema and patterns.

        Args:
            path: Request path, used together with ``method`` to find a schema.
            method: HTTP method; upper-cased before the schema lookup.
            body: Parsed JSON request body.
            headers: Mapping of request header names to values.

        Returns:
            ``body`` unchanged when every check passes.

        Raises:
            ValidationError: If any check fails.
        """
        policy = self.config.config

        # Serialize once; the same text is reused for the pattern scan below.
        body_str = json.dumps(body)
        if len(body_str.encode()) > policy["max_request_size"]:
            raise ValidationError("Request size exceeds maximum allowed")

        # HTTP header names are case-insensitive and callers may hand over
        # lower-cased keys, so compare on the lower-cased form.
        # NOTE(review): the default policy requires X-API-Key while the proxy
        # also accepts "Authorization: Bearer"; bearer-only clients are
        # rejected here. Confirm that this is intended.
        present_headers = {str(name).lower() for name in headers}
        for header in policy["required_headers"]:
            if header.lower() not in present_headers:
                raise ValidationError(f"Missing required header: {header}")

        schema_key = f"{method.upper()} {path}"
        if schema_key in self.schemas:
            await self._validate_against_schema(body, self.schemas[schema_key])

        self._validate_security_patterns(body, serialized=body_str)
        return body

    async def _validate_against_schema(self, body: Dict, schema: Dict):
        """Validate the request body against one schema entry.

        Raises:
            ValidationError: If the body is not an object when the schema
                demands one, a required field is missing, or a field violates
                its constraints.
        """
        request_schema = schema.get("requestBody", {})

        # Basic validation (would use proper JSON schema validator in production)
        if request_schema.get("type") == "object" and not isinstance(body, dict):
            raise ValidationError("Request body must be a JSON object")

        for field in request_schema.get("required", []):
            if field not in body:
                raise ValidationError(f"Missing required field: {field}")

        for field, constraints in request_schema.get("properties", {}).items():
            if field in body:
                await self._validate_field(field, body[field], constraints)

    async def _validate_field(self, field_name: str, value: Any, constraints: Dict):
        """Validate a single field against its ``type`` and range constraints.

        Raises:
            ValidationError: If the value has the wrong type or is out of range.
        """
        field_type = constraints.get("type")

        if field_type == "string":
            if not isinstance(value, str):
                raise ValidationError(f"Field {field_name} must be a string")
            if "maxLength" in constraints and len(value) > constraints["maxLength"]:
                raise ValidationError(f"Field {field_name} exceeds maximum length")
        elif field_type == "integer":
            # bool is a subclass of int in Python; JSON true/false is not an integer.
            if isinstance(value, bool) or not isinstance(value, int):
                raise ValidationError(f"Field {field_name} must be an integer")
            self._check_range(field_name, value, constraints)
        elif field_type == "number":
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValidationError(f"Field {field_name} must be a number")
            self._check_range(field_name, value, constraints)
        elif field_type == "array":
            if not isinstance(value, list):
                raise ValidationError(f"Field {field_name} must be an array")
        elif field_type == "object":
            if not isinstance(value, dict):
                raise ValidationError(f"Field {field_name} must be an object")
        elif field_type == "boolean":
            if not isinstance(value, bool):
                raise ValidationError(f"Field {field_name} must be a boolean")

    @staticmethod
    def _check_range(field_name: str, value: Any, constraints: Dict) -> None:
        """Enforce the optional ``minimum`` / ``maximum`` constraints."""
        # Phrased as "not (value >= minimum)" so NaN, which compares False
        # against everything, is rejected instead of slipping through.
        if "minimum" in constraints and not (value >= constraints["minimum"]):
            raise ValidationError(f"Field {field_name} below minimum value")
        if "maximum" in constraints and not (value <= constraints["maximum"]):
            raise ValidationError(f"Field {field_name} exceeds maximum value")

    def _validate_security_patterns(self, body: Dict, serialized: Optional[str] = None):
        """Reject bodies whose JSON text contains a forbidden pattern.

        Args:
            body: Parsed request body.
            serialized: Optional pre-computed ``json.dumps(body)``; when
                omitted the body is serialized here.

        Raises:
            ValidationError: If any forbidden pattern occurs
                (case-insensitively) in the serialized body.
        """
        body_str = (json.dumps(body) if serialized is None else serialized).lower()
        for pattern in self.config.config["forbidden_patterns"]:
            if pattern.lower() in body_str:
                raise ValidationError(f"Request contains forbidden pattern: {pattern}")
class APISecurityProxy:
    """Main API security proxy with interceptor pattern.

    ``proxy_request`` runs each incoming request through authentication, IP
    checks, rate limiting, validation, sanitization, budget checks and audit
    logging before forwarding it to the LLM backend or to a module, and turns
    any failure into a JSON error response.
    """

    def __init__(self):
        self.config = SecurityConfiguration()
        self.request_validator = RequestValidator(self.config)

    @staticmethod
    def _client_host(request: Request) -> str:
        """Return the direct peer address, or ``"unknown"`` when none is available."""
        # request.client can be None (e.g. some test clients / transports).
        client = request.client
        return client.host if client else "unknown"

    async def proxy_request(self, request: Request, path: str) -> Response:
        """Run the full interceptor chain for one request.

        Args:
            request: Incoming request.
            path: Target path used for routing, schemas and auditing.

        Returns:
            The backend's JSON response with security headers, or a JSON
            error response when any step raises.
        """
        start_time = time.time()
        api_key_info = None
        try:
            # 1. Extract and validate API key
            api_key_info = await self._extract_and_validate_api_key(request)
            user_permissions = api_key_info.get("permissions", []) if api_key_info else []

            # 2. IP validation (if enabled)
            await self._validate_ip_address(request)

            # 3. Rate limiting
            await self._check_rate_limits(request, path, api_key_info)

            # 4. Request validation
            request_body = await self._get_request_body(request)
            validated_body = await self.request_validator.validate(
                path=path,
                method=request.method,
                body=request_body,
                headers=dict(request.headers),
            )

            # 5. Sanitize request
            sanitized_body = self._sanitize_request(validated_body)

            # 6. Budget checking (for LLM endpoints)
            if path.startswith("/api/llm/"):
                await self._check_budget_constraints(api_key_info, sanitized_body)

            # 7. Build proxy headers
            proxy_headers = self._build_proxy_headers(request, api_key_info)

            # 8. Log security event
            await self._log_security_event(
                request=request,
                path=path,
                api_key_info=api_key_info,
                sanitized_body=sanitized_body,
            )

            # 9. Forward request to appropriate backend
            response = await self._forward_request(
                path=path,
                method=request.method,
                body=sanitized_body,
                headers=proxy_headers,
                user_permissions=user_permissions,
            )

            # 10. Validate and sanitize response
            validated_response = await self._process_response(path, response)

            # 11. Record usage metrics
            await self._record_usage_metrics(
                api_key_info=api_key_info,
                path=path,
                duration=time.time() - start_time,
                success=True,
            )
            return validated_response
        except Exception as e:
            # A failure while writing the audit record must not replace the
            # error response the client is waiting for.
            try:
                await self._handle_error(
                    request=request,
                    path=path,
                    api_key_info=api_key_info,
                    error=e,
                    duration=time.time() - start_time,
                )
            except Exception as audit_error:
                logger.error(f"Failed to record API proxy error: {audit_error}")
            return await self._create_error_response(e)

    async def _extract_and_validate_api_key(self, request: Request) -> Optional[Dict[str, Any]]:
        """Extract the API key from the request and look it up.

        The key is taken from ``Authorization: Bearer <key>`` when that
        carries a non-empty token, otherwise from ``X-API-Key``.

        Raises:
            AuthenticationError: If no key is supplied, the key is unknown,
                or the key is not active.
        """
        api_key = None

        # Bearer token (the auth scheme name is case-insensitive)
        auth_header = request.headers.get("Authorization", "")
        if auth_header[:7].lower() == "bearer ":
            api_key = auth_header[7:].strip()

        # X-API-Key header; also used when the Bearer token is empty
        if not api_key:
            api_key = request.headers.get("X-API-Key")

        if not api_key:
            raise AuthenticationError("Missing API key")

        api_key_info = await get_api_key_info(api_key)
        if not api_key_info:
            raise AuthenticationError("Invalid API key")
        if not api_key_info.get("is_active", False):
            raise AuthenticationError("API key is disabled")
        return api_key_info

    async def _validate_ip_address(self, request: Request):
        """Validate the client IP address against the blacklist and whitelist.

        The effective client IP is the first ``X-Forwarded-For`` entry when
        present, otherwise the direct peer address.

        Raises:
            AuthenticationError: If the address is blacklisted, or the
                whitelist is enabled and the address is not on it.
        """
        peer_ip = self._client_host(request)
        client_ip = peer_ip
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            first_hop = forwarded_for.split(",")[0].strip()
            if first_hop:
                client_ip = first_hop

        config = self.config.config

        # Check blacklist. X-Forwarded-For is client-controlled, so the direct
        # peer is checked too; otherwise a spoofed header bypasses the list.
        for candidate in (client_ip, peer_ip):
            if candidate in config["ip_blacklist"]:
                raise AuthenticationError(f"IP address {candidate} is blacklisted")

        # Check whitelist (if enabled)
        # NOTE(review): this trusts X-Forwarded-For; safe only if a trusted
        # reverse proxy overwrites the header - confirm the deployment.
        if config["ip_whitelist_enabled"] and client_ip not in config["ip_whitelist"]:
            raise AuthenticationError(f"IP address {client_ip} is not whitelisted")

    async def _check_rate_limits(self, request: Request, path: str, api_key_info: Optional[Dict]):
        """Apply per-key (minute and hour) or per-IP (minute) rate limits.

        Raises:
            RateLimitExceeded: If the applicable limit is exhausted.
        """
        client_ip = self._client_host(request)
        api_key = api_key_info.get("key_prefix", "") if api_key_info else None

        if api_key:
            # API key-based rate limiting
            rate_limit_key = f"api_key:{api_key}"
            limit_per_minute = api_key_info.get("rate_limit_per_minute", 100)
            limit_per_hour = api_key_info.get("rate_limit_per_hour", 1000)

            is_allowed_minute, _ = await rate_limiter.check_rate_limit(
                rate_limit_key, limit_per_minute, 60, "minute"
            )
            is_allowed_hour, _ = await rate_limiter.check_rate_limit(
                rate_limit_key, limit_per_hour, 3600, "hour"
            )
            if not (is_allowed_minute and is_allowed_hour):
                raise RateLimitExceeded("API key rate limit exceeded")
        else:
            # IP-based rate limiting when no key prefix is available
            rate_limit_key = f"ip:{client_ip}"
            is_allowed_minute, _ = await rate_limiter.check_rate_limit(
                rate_limit_key, 20, 60, "minute"
            )
            if not is_allowed_minute:
                raise RateLimitExceeded("IP rate limit exceeded")

    async def _get_request_body(self, request: Request) -> Dict[str, Any]:
        """Read and parse the JSON request body.

        Returns:
            The parsed JSON object, or ``{}`` for methods other than
            POST/PUT/PATCH and for empty bodies.

        Raises:
            ValidationError: If the raw body exceeds ``max_request_size``, is
                not valid JSON, or is valid JSON but not an object.
        """
        if request.method not in ("POST", "PUT", "PATCH"):
            return {}

        raw_body = await request.body()
        if not raw_body.strip():
            return {}

        # Enforce the size limit before spending time parsing the payload.
        if len(raw_body) > self.config.config["max_request_size"]:
            raise ValidationError("Request size exceeds maximum allowed")

        try:
            parsed = json.loads(raw_body)
        except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
            raise ValidationError("Request body is not valid JSON") from exc

        # Later steps use dict-only operations (.get, ** unpacking).
        if not isinstance(parsed, dict):
            raise ValidationError("Request body must be a JSON object")
        return parsed

    def _sanitize_request(self, body: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``body`` with every string value sanitized.

        Each forbidden pattern is removed case-insensitively in one pass, in
        configuration order, and the result is truncated to
        ``max_string_length``. Dict values and list items are processed
        recursively; dict keys and non-string scalars are left untouched.
        """
        policy = self.config.config
        max_length = policy["max_string_length"]
        # Compile once per call instead of once per string value.
        # NOTE(review): removal is single-pass per pattern, so nested input
        # such as "<scr<scriptipt" still contains a pattern afterwards.
        pattern_regexes = [
            re.compile(re.escape(pattern), re.IGNORECASE)
            for pattern in policy["forbidden_patterns"]
        ]

        def sanitize_value(value):
            if isinstance(value, str):
                for regex in pattern_regexes:
                    value = regex.sub("", value)
                original_length = len(value)
                if original_length > max_length:
                    value = value[:max_length]
                    # Report the pre-truncation length, not the limit.
                    logger.warning(f"Truncated long string in request: {original_length} chars")
                return value
            if isinstance(value, dict):
                return {k: sanitize_value(v) for k, v in value.items()}
            if isinstance(value, list):
                return [sanitize_value(item) for item in value]
            return value

        return sanitize_value(body)

    async def _check_budget_constraints(self, api_key_info: Dict, body: Dict):
        """Check budget constraints for LLM requests.

        Does nothing when ``api_key_info`` is empty.

        Raises:
            HTTPException: 402 when the budget check reports the budget is
                exceeded.
        """
        if not api_key_info:
            return

        estimated_cost = self._estimate_request_cost(body)
        budget_ok = await check_budget_and_record_usage(
            user_id=api_key_info.get("user_id"),
            api_key_id=api_key_info.get("id"),
            estimated_cost=estimated_cost,
            actual_cost=0,  # Will be updated after response
            metadata={"endpoint": "llm_proxy", "model": body.get("model", "unknown")},
        )
        if not budget_ok:
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail="Budget limit exceeded",
            )

    @staticmethod
    def _message_text(content: Any) -> str:
        """Extract the text of a message ``content`` that may be a string, a list of parts, or null."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return " ".join(
                part["text"]
                for part in content
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            )
        return ""

    def _estimate_request_cost(self, body: Dict) -> float:
        """Roughly estimate the cost of an LLM request.

        Input tokens are approximated as 1.3 per whitespace-separated word of
        the message contents; output tokens as ``max_tokens`` (1000 when
        absent or not numeric). Models missing from the price table use the
        ``gpt-3.5-turbo`` prices.
        """
        model = body.get("model", "gpt-3.5-turbo")
        messages = body.get("messages", [])
        if not isinstance(messages, list):
            messages = []
        max_tokens = body.get("max_tokens", 1000)
        if isinstance(max_tokens, bool) or not isinstance(max_tokens, (int, float)):
            max_tokens = 1000

        # Content can be null or a list of parts, not only a plain string.
        input_text = " ".join(
            self._message_text(msg.get("content", ""))
            for msg in messages
            if isinstance(msg, dict)
        )
        input_tokens = len(input_text.split()) * 1.3  # Rough approximation

        # Model pricing (simplified), per 1K tokens
        pricing = {
            "gpt-4": {"input": 0.03, "output": 0.06},
            "gpt-3.5-turbo": {"input": 0.001, "output": 0.002},
            "claude-3-sonnet": {"input": 0.003, "output": 0.015},
            "claude-3-haiku": {"input": 0.00025, "output": 0.00125},
        }
        model_pricing = pricing.get(model) if isinstance(model, str) else None
        if model_pricing is None:
            model_pricing = pricing["gpt-3.5-turbo"]

        return (
            (input_tokens / 1000) * model_pricing["input"]
            + (max_tokens / 1000) * model_pricing["output"]
        )

    def _build_proxy_headers(self, request: Request, api_key_info: Optional[Dict]) -> Dict[str, str]:
        """Build the headers sent to the backend for this request."""
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "ConfidentialEmpire-Proxy/1.0",
            "X-Forwarded-For": self._client_host(request),
            # A random id; a millisecond timestamp repeats for concurrent requests.
            "X-Request-ID": f"req_{uuid.uuid4().hex}",
        }
        if api_key_info:
            headers["X-User-ID"] = str(api_key_info.get("user_id", ""))
            headers["X-API-Key-ID"] = str(api_key_info.get("id", ""))
        return headers

    async def _log_security_event(self, request: Request, path: str, api_key_info: Optional[Dict], sanitized_body: Dict):
        """Write a success audit record for a request about to be forwarded."""
        await create_audit_log(
            action=f"api_proxy_{request.method.lower()}",
            resource_type="api_endpoint",
            resource_id=path,
            user_id=api_key_info.get("user_id") if api_key_info else None,
            success=True,
            ip_address=self._client_host(request),
            user_agent=request.headers.get("User-Agent", ""),
            metadata={
                "endpoint": path,
                "method": request.method,
                "api_key_id": api_key_info.get("id") if api_key_info else None,
                "request_size": len(json.dumps(sanitized_body)),
            },
        )

    async def _forward_request(
        self,
        path: str,
        method: str,
        body: Dict,
        headers: Dict,
        user_permissions: Optional[List[str]] = None,
    ) -> Dict:
        """Forward the request to the backend selected by ``path``.

        ``/api/llm/`` paths go over HTTP to the LiteLLM base URL;
        ``/api/modules/`` paths are dispatched to the module system.

        Args:
            path: Target path.
            method: HTTP method (GET, POST, PUT, PATCH or DELETE for HTTP backends).
            body: Sanitized request body; sent only with POST/PUT/PATCH.
            headers: Proxy headers built for this request.
            user_permissions: Permissions of the calling API key, passed on to modules.

        Returns:
            The backend's decoded JSON response.

        Raises:
            ValidationError: For unknown paths or unsupported methods.
            HTTPException: With the upstream status on an upstream error,
                504 on timeout, 502 when the upstream is unreachable or
                returns a non-JSON body.
        """
        if path.startswith("/api/llm/"):
            target_url = f"{settings.LITELLM_BASE_URL}{path}"
            target_headers = {**headers, "Authorization": f"Bearer {settings.LITELLM_MASTER_KEY}"}
        elif path.startswith("/api/modules/"):
            return await self._route_to_module(
                path, method, body, headers, user_permissions=user_permissions
            )
        else:
            raise ValidationError(f"Unknown endpoint: {path}")

        if method not in ("GET", "POST", "PUT", "PATCH", "DELETE"):
            raise ValidationError(f"Unsupported HTTP method: {method}")

        request_kwargs: Dict[str, Any] = {"headers": target_headers}
        if method in ("POST", "PUT", "PATCH"):
            request_kwargs["json"] = body

        timeout = self.config.config["timeout"]
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.request(method, target_url, **request_kwargs)
        except httpx.TimeoutException as exc:
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Upstream service timed out",
            ) from exc
        except httpx.RequestError as exc:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Upstream service unreachable",
            ) from exc

        if response.status_code >= 400:
            raise HTTPException(status_code=response.status_code, detail=response.text)

        try:
            return response.json()
        except ValueError as exc:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="Upstream service returned a non-JSON response",
            ) from exc

    async def _route_to_module(
        self,
        path: str,
        method: str,
        body: Dict,
        headers: Dict,
        user_permissions: Optional[List[str]] = None,
    ) -> Dict:
        """Route a ``/api/modules/...`` request to the module system.

        The module name is the fourth path segment and the action the fifth
        (``"execute"`` when absent), e.g. ``/api/modules/v1/rag/search`` gives
        module ``rag`` and action ``search``.

        Raises:
            ValidationError: If the path is too short, the module is unknown,
                or a legacy module has no public callable for the action.
        """
        path_parts = path.strip("/").split("/")
        if len(path_parts) < 4:
            raise ValidationError("Invalid module path")
        module_name = path_parts[3]
        action = path_parts[4] if len(path_parts) > 4 else "execute"

        # Imported here rather than at module import time.
        from app.services.module_manager import module_manager

        if module_name not in module_manager.modules:
            raise ValidationError(f"Module not found: {module_name}")
        module = module_manager.modules[module_name]

        context = {
            "user_id": headers.get("X-User-ID"),
            "api_key_id": headers.get("X-API-Key-ID"),
            "ip_address": headers.get("X-Forwarded-For"),
            "user_permissions": list(user_permissions or []),
        }

        # Routing keys go last so a body containing "action" or "method"
        # cannot override what the path and HTTP method say.
        module_request = {**body, "action": action, "method": method}

        # Execute through module's interceptor chain
        if hasattr(module, "execute_with_interceptors"):
            return await module.execute_with_interceptors(module_request, context)

        # Fallback for legacy modules. The action comes from the URL, so
        # never resolve private/dunder attributes or non-callables.
        handler = None if action.startswith("_") else getattr(module, action, None)
        if not callable(handler):
            raise ValidationError(f"Action not supported: {action}")
        return await handler(module_request)

    async def _process_response(self, path: str, response: Dict) -> JSONResponse:
        """Wrap the backend payload in a JSON response with security headers."""
        headers = {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
        }
        return JSONResponse(content=response, headers=headers)

    async def _record_usage_metrics(self, api_key_info: Optional[Dict], path: str, duration: float, success: bool):
        """Record usage metrics (currently a placeholder that records nothing)."""
        if api_key_info:
            # Record API key usage
            # This would update database metrics
            pass

    async def _handle_error(self, request: Request, path: str, api_key_info: Optional[Dict], error: Exception, duration: float):
        """Write a failure audit record describing ``error``."""
        await create_audit_log(
            action=f"api_proxy_{request.method.lower()}",
            resource_type="api_endpoint",
            resource_id=path,
            user_id=api_key_info.get("user_id") if api_key_info else None,
            success=False,
            error_message=str(error),
            ip_address=self._client_host(request),
            user_agent=request.headers.get("User-Agent", ""),
            metadata={
                "endpoint": path,
                "method": request.method,
                "duration_ms": int(duration * 1000),
                "error_type": type(error).__name__,
            },
        )

    async def _create_error_response(self, error: Exception) -> JSONResponse:
        """Map an exception to a JSON error response.

        Authentication, validation, rate-limit and HTTP exceptions map to
        401, 400, 429 and the exception's own status code. Anything else is
        logged and reported as a generic 500 without exposing its message.
        """
        if isinstance(error, AuthenticationError):
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={"error": "AUTHENTICATION_ERROR", "message": str(error)},
            )
        if isinstance(error, ValidationError):
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"error": "VALIDATION_ERROR", "message": str(error)},
            )
        if isinstance(error, RateLimitExceeded):
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={"error": "RATE_LIMIT_EXCEEDED", "message": str(error)},
            )
        if isinstance(error, HTTPException):
            return JSONResponse(
                status_code=error.status_code,
                content={"error": "HTTP_ERROR", "message": error.detail},
            )

        logger.error(f"Unexpected error in API proxy: {error}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"error": "INTERNAL_ERROR", "message": "An unexpected error occurred"},
        )
# Global proxy instance, constructed once when this module is imported.
api_security_proxy = APISecurityProxy()