mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 23:44:24 +01:00
clean commit
This commit is contained in:
606
backend/app/services/api_proxy.py
Normal file
606
backend/app/services/api_proxy.py
Normal file
@@ -0,0 +1,606 @@
|
||||
"""
|
||||
API Proxy with comprehensive security interceptors
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from typing import Dict, List, Any, Optional
|
||||
from fastapi import Request, Response, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
import httpx
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.services.api_key_auth import get_api_key_info
|
||||
from app.services.budget_enforcement import check_budget_and_record_usage
|
||||
from app.middleware.rate_limiting import rate_limiter
|
||||
from app.utils.exceptions import ValidationError, AuthenticationError, RateLimitExceeded
|
||||
from app.services.audit_service import create_audit_log
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SecurityConfiguration:
|
||||
"""Security configuration for API proxy"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = self._load_security_config()
|
||||
|
||||
def _load_security_config(self) -> Dict[str, Any]:
|
||||
"""Load security configuration"""
|
||||
return {
|
||||
"rate_limits": {
|
||||
"global": 10000, # per hour
|
||||
"per_key": 1000, # per hour
|
||||
"per_endpoint": {
|
||||
"/api/llm/v1/chat/completions": 100, # per minute
|
||||
"/api/modules/v1/rag/search": 500, # per hour
|
||||
}
|
||||
},
|
||||
"max_request_size": 10 * 1024 * 1024, # 10MB
|
||||
"max_string_length": 50000,
|
||||
"timeout": 30, # seconds
|
||||
"required_headers": ["X-API-Key"],
|
||||
"ip_whitelist_enabled": False,
|
||||
"ip_whitelist": [],
|
||||
"ip_blacklist": [],
|
||||
"forbidden_patterns": [
|
||||
"<script", "javascript:", "data:text/html", "vbscript:",
|
||||
"union select", "drop table", "insert into", "delete from"
|
||||
],
|
||||
"audit": {
|
||||
"enabled": True,
|
||||
"include_request_body": False,
|
||||
"include_response_body": False,
|
||||
"sensitive_paths": ["/api/platform/v1/auth"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class RequestValidator:
|
||||
"""Validates API requests against schemas and security policies"""
|
||||
|
||||
def __init__(self, config: SecurityConfiguration):
|
||||
self.config = config
|
||||
self.schemas = self._load_openapi_schemas()
|
||||
|
||||
def _load_openapi_schemas(self) -> Dict[str, Any]:
|
||||
"""Load OpenAPI schemas for validation"""
|
||||
# Would load actual OpenAPI schemas in production
|
||||
return {
|
||||
"POST /api/llm/v1/chat/completions": {
|
||||
"requestBody": {
|
||||
"type": "object",
|
||||
"required": ["model", "messages"],
|
||||
"properties": {
|
||||
"model": {"type": "string"},
|
||||
"messages": {"type": "array"},
|
||||
"temperature": {"type": "number", "minimum": 0, "maximum": 2},
|
||||
"max_tokens": {"type": "integer", "minimum": 1, "maximum": 32000}
|
||||
}
|
||||
}
|
||||
},
|
||||
"POST /api/modules/v1/rag/search": {
|
||||
"requestBody": {
|
||||
"type": "object",
|
||||
"required": ["query"],
|
||||
"properties": {
|
||||
"query": {"type": "string", "maxLength": 1000},
|
||||
"limit": {"type": "integer", "minimum": 1, "maximum": 100}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async def validate(self, path: str, method: str, body: Dict, headers: Dict) -> Dict:
|
||||
"""Validate request against schema and security policies"""
|
||||
|
||||
# Check request size
|
||||
body_str = json.dumps(body)
|
||||
if len(body_str.encode()) > self.config.config["max_request_size"]:
|
||||
raise ValidationError(f"Request size exceeds maximum allowed")
|
||||
|
||||
# Check required headers
|
||||
for header in self.config.config["required_headers"]:
|
||||
if header not in headers:
|
||||
raise ValidationError(f"Missing required header: {header}")
|
||||
|
||||
# Validate against schema if available
|
||||
schema_key = f"{method.upper()} {path}"
|
||||
if schema_key in self.schemas:
|
||||
await self._validate_against_schema(body, self.schemas[schema_key])
|
||||
|
||||
# Security validation
|
||||
self._validate_security_patterns(body)
|
||||
|
||||
return body
|
||||
|
||||
async def _validate_against_schema(self, body: Dict, schema: Dict):
|
||||
"""Validate request body against OpenAPI schema"""
|
||||
request_schema = schema.get("requestBody", {})
|
||||
|
||||
# Basic validation (would use proper JSON schema validator in production)
|
||||
if "required" in request_schema:
|
||||
for field in request_schema["required"]:
|
||||
if field not in body:
|
||||
raise ValidationError(f"Missing required field: {field}")
|
||||
|
||||
if "properties" in request_schema:
|
||||
for field, constraints in request_schema["properties"].items():
|
||||
if field in body:
|
||||
await self._validate_field(field, body[field], constraints)
|
||||
|
||||
async def _validate_field(self, field_name: str, value: Any, constraints: Dict):
|
||||
"""Validate individual field against constraints"""
|
||||
field_type = constraints.get("type")
|
||||
|
||||
if field_type == "string":
|
||||
if not isinstance(value, str):
|
||||
raise ValidationError(f"Field {field_name} must be a string")
|
||||
if "maxLength" in constraints and len(value) > constraints["maxLength"]:
|
||||
raise ValidationError(f"Field {field_name} exceeds maximum length")
|
||||
|
||||
elif field_type == "integer":
|
||||
if not isinstance(value, int):
|
||||
raise ValidationError(f"Field {field_name} must be an integer")
|
||||
if "minimum" in constraints and value < constraints["minimum"]:
|
||||
raise ValidationError(f"Field {field_name} below minimum value")
|
||||
if "maximum" in constraints and value > constraints["maximum"]:
|
||||
raise ValidationError(f"Field {field_name} exceeds maximum value")
|
||||
|
||||
elif field_type == "number":
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValidationError(f"Field {field_name} must be a number")
|
||||
if "minimum" in constraints and value < constraints["minimum"]:
|
||||
raise ValidationError(f"Field {field_name} below minimum value")
|
||||
if "maximum" in constraints and value > constraints["maximum"]:
|
||||
raise ValidationError(f"Field {field_name} exceeds maximum value")
|
||||
|
||||
def _validate_security_patterns(self, body: Dict):
|
||||
"""Check for forbidden security patterns"""
|
||||
body_str = json.dumps(body).lower()
|
||||
|
||||
for pattern in self.config.config["forbidden_patterns"]:
|
||||
if pattern.lower() in body_str:
|
||||
raise ValidationError(f"Request contains forbidden pattern: {pattern}")
|
||||
|
||||
|
||||
class APISecurityProxy:
|
||||
"""Main API security proxy with interceptor pattern"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = SecurityConfiguration()
|
||||
self.request_validator = RequestValidator(self.config)
|
||||
|
||||
async def proxy_request(self, request: Request, path: str) -> Response:
|
||||
"""
|
||||
Main proxy method that implements the full interceptor pattern
|
||||
"""
|
||||
start_time = time.time()
|
||||
api_key_info = None
|
||||
user_permissions = []
|
||||
|
||||
try:
|
||||
# 1. Extract and validate API key
|
||||
api_key_info = await self._extract_and_validate_api_key(request)
|
||||
if api_key_info:
|
||||
user_permissions = api_key_info.get("permissions", [])
|
||||
|
||||
# 2. IP validation (if enabled)
|
||||
await self._validate_ip_address(request)
|
||||
|
||||
# 3. Rate limiting
|
||||
await self._check_rate_limits(request, path, api_key_info)
|
||||
|
||||
# 4. Request validation and sanitization
|
||||
request_body = await self._get_request_body(request)
|
||||
validated_body = await self.request_validator.validate(
|
||||
path=path,
|
||||
method=request.method,
|
||||
body=request_body,
|
||||
headers=dict(request.headers)
|
||||
)
|
||||
|
||||
# 5. Sanitize request
|
||||
sanitized_body = self._sanitize_request(validated_body)
|
||||
|
||||
# 6. Budget checking (for LLM endpoints)
|
||||
if path.startswith("/api/llm/"):
|
||||
await self._check_budget_constraints(api_key_info, sanitized_body)
|
||||
|
||||
# 7. Build proxy headers
|
||||
proxy_headers = self._build_proxy_headers(request, api_key_info)
|
||||
|
||||
# 8. Log security event
|
||||
await self._log_security_event(
|
||||
request=request,
|
||||
path=path,
|
||||
api_key_info=api_key_info,
|
||||
sanitized_body=sanitized_body
|
||||
)
|
||||
|
||||
# 9. Forward request to appropriate backend
|
||||
response = await self._forward_request(
|
||||
path=path,
|
||||
method=request.method,
|
||||
body=sanitized_body,
|
||||
headers=proxy_headers
|
||||
)
|
||||
|
||||
# 10. Validate and sanitize response
|
||||
validated_response = await self._process_response(path, response)
|
||||
|
||||
# 11. Record usage metrics
|
||||
await self._record_usage_metrics(
|
||||
api_key_info=api_key_info,
|
||||
path=path,
|
||||
duration=time.time() - start_time,
|
||||
success=True
|
||||
)
|
||||
|
||||
return validated_response
|
||||
|
||||
except Exception as e:
|
||||
# Error handling and logging
|
||||
await self._handle_error(
|
||||
request=request,
|
||||
path=path,
|
||||
api_key_info=api_key_info,
|
||||
error=e,
|
||||
duration=time.time() - start_time
|
||||
)
|
||||
|
||||
# Return appropriate error response
|
||||
return await self._create_error_response(e)
|
||||
|
||||
async def _extract_and_validate_api_key(self, request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""Extract and validate API key from request"""
|
||||
|
||||
# Try different auth methods
|
||||
api_key = None
|
||||
|
||||
# Bearer token
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
api_key = auth_header[7:]
|
||||
|
||||
# X-API-Key header
|
||||
elif request.headers.get("X-API-Key"):
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
|
||||
if not api_key:
|
||||
raise AuthenticationError("Missing API key")
|
||||
|
||||
# Validate API key
|
||||
api_key_info = await get_api_key_info(api_key)
|
||||
if not api_key_info:
|
||||
raise AuthenticationError("Invalid API key")
|
||||
|
||||
if not api_key_info.get("is_active", False):
|
||||
raise AuthenticationError("API key is disabled")
|
||||
|
||||
return api_key_info
|
||||
|
||||
async def _validate_ip_address(self, request: Request):
|
||||
"""Validate client IP address against whitelist/blacklist"""
|
||||
client_ip = request.client.host
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
client_ip = forwarded_for.split(",")[0].strip()
|
||||
|
||||
config = self.config.config
|
||||
|
||||
# Check blacklist
|
||||
if client_ip in config["ip_blacklist"]:
|
||||
raise AuthenticationError(f"IP address {client_ip} is blacklisted")
|
||||
|
||||
# Check whitelist (if enabled)
|
||||
if config["ip_whitelist_enabled"] and client_ip not in config["ip_whitelist"]:
|
||||
raise AuthenticationError(f"IP address {client_ip} is not whitelisted")
|
||||
|
||||
async def _check_rate_limits(self, request: Request, path: str, api_key_info: Optional[Dict]):
|
||||
"""Check rate limits for the request"""
|
||||
client_ip = request.client.host
|
||||
api_key = api_key_info.get("key_prefix", "") if api_key_info else None
|
||||
|
||||
# Use existing rate limiter
|
||||
if api_key:
|
||||
# API key-based rate limiting
|
||||
rate_limit_key = f"api_key:{api_key}"
|
||||
limit_per_minute = api_key_info.get("rate_limit_per_minute", 100)
|
||||
limit_per_hour = api_key_info.get("rate_limit_per_hour", 1000)
|
||||
|
||||
# Check per-minute limit
|
||||
is_allowed_minute, _ = await rate_limiter.check_rate_limit(
|
||||
rate_limit_key, limit_per_minute, 60, "minute"
|
||||
)
|
||||
|
||||
# Check per-hour limit
|
||||
is_allowed_hour, _ = await rate_limiter.check_rate_limit(
|
||||
rate_limit_key, limit_per_hour, 3600, "hour"
|
||||
)
|
||||
|
||||
if not (is_allowed_minute and is_allowed_hour):
|
||||
raise RateLimitExceeded("API key rate limit exceeded")
|
||||
|
||||
else:
|
||||
# IP-based rate limiting for unauthenticated requests
|
||||
rate_limit_key = f"ip:{client_ip}"
|
||||
|
||||
is_allowed_minute, _ = await rate_limiter.check_rate_limit(
|
||||
rate_limit_key, 20, 60, "minute"
|
||||
)
|
||||
|
||||
if not is_allowed_minute:
|
||||
raise RateLimitExceeded("IP rate limit exceeded")
|
||||
|
||||
async def _get_request_body(self, request: Request) -> Dict[str, Any]:
|
||||
"""Extract request body"""
|
||||
try:
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
return await request.json()
|
||||
else:
|
||||
return {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def _sanitize_request(self, body: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sanitize request data"""
|
||||
def sanitize_value(value):
|
||||
if isinstance(value, str):
|
||||
# Remove forbidden patterns
|
||||
for pattern in self.config.config["forbidden_patterns"]:
|
||||
value = re.sub(re.escape(pattern), "", value, flags=re.IGNORECASE)
|
||||
|
||||
# Limit string length
|
||||
max_length = self.config.config["max_string_length"]
|
||||
if len(value) > max_length:
|
||||
value = value[:max_length]
|
||||
logger.warning(f"Truncated long string in request: {len(value)} chars")
|
||||
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return {k: sanitize_value(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [sanitize_value(item) for item in value]
|
||||
else:
|
||||
return value
|
||||
|
||||
return sanitize_value(body)
|
||||
|
||||
async def _check_budget_constraints(self, api_key_info: Dict, body: Dict):
|
||||
"""Check budget constraints for LLM requests"""
|
||||
if not api_key_info:
|
||||
return
|
||||
|
||||
# Estimate cost based on request
|
||||
estimated_cost = self._estimate_request_cost(body)
|
||||
|
||||
# Check budget
|
||||
user_id = api_key_info.get("user_id")
|
||||
api_key_id = api_key_info.get("id")
|
||||
|
||||
budget_ok = await check_budget_and_record_usage(
|
||||
user_id=user_id,
|
||||
api_key_id=api_key_id,
|
||||
estimated_cost=estimated_cost,
|
||||
actual_cost=0, # Will be updated after response
|
||||
metadata={"endpoint": "llm_proxy", "model": body.get("model", "unknown")}
|
||||
)
|
||||
|
||||
if not budget_ok:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail="Budget limit exceeded"
|
||||
)
|
||||
|
||||
def _estimate_request_cost(self, body: Dict) -> float:
|
||||
"""Estimate cost of LLM request"""
|
||||
# Rough estimation based on model and tokens
|
||||
model = body.get("model", "gpt-3.5-turbo")
|
||||
messages = body.get("messages", [])
|
||||
max_tokens = body.get("max_tokens", 1000)
|
||||
|
||||
# Estimate input tokens
|
||||
input_text = " ".join([msg.get("content", "") for msg in messages if isinstance(msg, dict)])
|
||||
input_tokens = len(input_text.split()) * 1.3 # Rough approximation
|
||||
|
||||
# Model pricing (simplified)
|
||||
pricing = {
|
||||
"gpt-4": {"input": 0.03, "output": 0.06}, # per 1K tokens
|
||||
"gpt-3.5-turbo": {"input": 0.001, "output": 0.002},
|
||||
"claude-3-sonnet": {"input": 0.003, "output": 0.015},
|
||||
"claude-3-haiku": {"input": 0.00025, "output": 0.00125}
|
||||
}
|
||||
|
||||
model_pricing = pricing.get(model, pricing["gpt-3.5-turbo"])
|
||||
|
||||
estimated_cost = (
|
||||
(input_tokens / 1000) * model_pricing["input"] +
|
||||
(max_tokens / 1000) * model_pricing["output"]
|
||||
)
|
||||
|
||||
return estimated_cost
|
||||
|
||||
def _build_proxy_headers(self, request: Request, api_key_info: Optional[Dict]) -> Dict[str, str]:
|
||||
"""Build headers for proxy request"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"ConfidentialEmpire-Proxy/1.0",
|
||||
"X-Forwarded-For": request.client.host,
|
||||
"X-Request-ID": f"req_{int(time.time() * 1000)}"
|
||||
}
|
||||
|
||||
if api_key_info:
|
||||
headers["X-User-ID"] = str(api_key_info.get("user_id", ""))
|
||||
headers["X-API-Key-ID"] = str(api_key_info.get("id", ""))
|
||||
|
||||
return headers
|
||||
|
||||
async def _log_security_event(self, request: Request, path: str, api_key_info: Optional[Dict], sanitized_body: Dict):
|
||||
"""Log security event for audit trail"""
|
||||
await create_audit_log(
|
||||
action=f"api_proxy_{request.method.lower()}",
|
||||
resource_type="api_endpoint",
|
||||
resource_id=path,
|
||||
user_id=api_key_info.get("user_id") if api_key_info else None,
|
||||
success=True,
|
||||
ip_address=request.client.host,
|
||||
user_agent=request.headers.get("User-Agent", ""),
|
||||
metadata={
|
||||
"endpoint": path,
|
||||
"method": request.method,
|
||||
"api_key_id": api_key_info.get("id") if api_key_info else None,
|
||||
"request_size": len(json.dumps(sanitized_body))
|
||||
}
|
||||
)
|
||||
|
||||
async def _forward_request(self, path: str, method: str, body: Dict, headers: Dict) -> Dict:
|
||||
"""Forward request to appropriate backend service"""
|
||||
|
||||
# Determine target service based on path
|
||||
if path.startswith("/api/llm/"):
|
||||
target_url = f"{settings.LITELLM_BASE_URL}{path}"
|
||||
target_headers = {**headers, "Authorization": f"Bearer {settings.LITELLM_MASTER_KEY}"}
|
||||
elif path.startswith("/api/modules/"):
|
||||
# Route to module system
|
||||
return await self._route_to_module(path, method, body, headers)
|
||||
else:
|
||||
raise ValidationError(f"Unknown endpoint: {path}")
|
||||
|
||||
# Make HTTP request to target service
|
||||
timeout = self.config.config["timeout"]
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(target_url, headers=target_headers)
|
||||
elif method == "POST":
|
||||
response = await client.post(target_url, json=body, headers=target_headers)
|
||||
elif method == "PUT":
|
||||
response = await client.put(target_url, json=body, headers=target_headers)
|
||||
elif method == "DELETE":
|
||||
response = await client.delete(target_url, headers=target_headers)
|
||||
else:
|
||||
raise ValidationError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
if response.status_code >= 400:
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
|
||||
return response.json()
|
||||
|
||||
async def _route_to_module(self, path: str, method: str, body: Dict, headers: Dict) -> Dict:
|
||||
"""Route request to module system"""
|
||||
# Extract module name from path
|
||||
# e.g., /api/modules/v1/rag/search -> module: rag, action: search
|
||||
path_parts = path.strip("/").split("/")
|
||||
if len(path_parts) >= 4:
|
||||
module_name = path_parts[3]
|
||||
action = path_parts[4] if len(path_parts) > 4 else "execute"
|
||||
else:
|
||||
raise ValidationError("Invalid module path")
|
||||
|
||||
# Import module manager
|
||||
from app.services.module_manager import module_manager
|
||||
|
||||
if module_name not in module_manager.modules:
|
||||
raise ValidationError(f"Module not found: {module_name}")
|
||||
|
||||
module = module_manager.modules[module_name]
|
||||
|
||||
# Prepare context
|
||||
context = {
|
||||
"user_id": headers.get("X-User-ID"),
|
||||
"api_key_id": headers.get("X-API-Key-ID"),
|
||||
"ip_address": headers.get("X-Forwarded-For"),
|
||||
"user_permissions": [] # Would be populated from API key info
|
||||
}
|
||||
|
||||
# Prepare request
|
||||
module_request = {
|
||||
"action": action,
|
||||
"method": method,
|
||||
**body
|
||||
}
|
||||
|
||||
# Execute through module's interceptor chain
|
||||
if hasattr(module, 'execute_with_interceptors'):
|
||||
return await module.execute_with_interceptors(module_request, context)
|
||||
else:
|
||||
# Fallback for legacy modules
|
||||
if hasattr(module, action):
|
||||
return await getattr(module, action)(module_request)
|
||||
else:
|
||||
raise ValidationError(f"Action not supported: {action}")
|
||||
|
||||
async def _process_response(self, path: str, response: Dict) -> JSONResponse:
|
||||
"""Process and validate response"""
|
||||
# Add security headers
|
||||
headers = {
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains"
|
||||
}
|
||||
|
||||
return JSONResponse(content=response, headers=headers)
|
||||
|
||||
async def _record_usage_metrics(self, api_key_info: Optional[Dict], path: str, duration: float, success: bool):
|
||||
"""Record usage metrics"""
|
||||
if api_key_info:
|
||||
# Record API key usage
|
||||
# This would update database metrics
|
||||
pass
|
||||
|
||||
async def _handle_error(self, request: Request, path: str, api_key_info: Optional[Dict], error: Exception, duration: float):
|
||||
"""Handle and log errors"""
|
||||
await create_audit_log(
|
||||
action=f"api_proxy_{request.method.lower()}",
|
||||
resource_type="api_endpoint",
|
||||
resource_id=path,
|
||||
user_id=api_key_info.get("user_id") if api_key_info else None,
|
||||
success=False,
|
||||
error_message=str(error),
|
||||
ip_address=request.client.host,
|
||||
user_agent=request.headers.get("User-Agent", ""),
|
||||
metadata={
|
||||
"endpoint": path,
|
||||
"method": request.method,
|
||||
"duration_ms": int(duration * 1000),
|
||||
"error_type": type(error).__name__
|
||||
}
|
||||
)
|
||||
|
||||
async def _create_error_response(self, error: Exception) -> JSONResponse:
|
||||
"""Create appropriate error response"""
|
||||
if isinstance(error, AuthenticationError):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"error": "AUTHENTICATION_ERROR", "message": str(error)}
|
||||
)
|
||||
elif isinstance(error, ValidationError):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"error": "VALIDATION_ERROR", "message": str(error)}
|
||||
)
|
||||
elif isinstance(error, RateLimitExceeded):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={"error": "RATE_LIMIT_EXCEEDED", "message": str(error)}
|
||||
)
|
||||
elif isinstance(error, HTTPException):
|
||||
return JSONResponse(
|
||||
status_code=error.status_code,
|
||||
content={"error": "HTTP_ERROR", "message": error.detail}
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unexpected error in API proxy: {error}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={"error": "INTERNAL_ERROR", "message": "An unexpected error occurred"}
|
||||
)
|
||||
|
||||
|
||||
# Global proxy instance
|
||||
api_security_proxy = APISecurityProxy()
|
||||
Reference in New Issue
Block a user