Files
enclava/backend/app/middleware/rate_limiting.py
2025-09-21 06:49:55 +02:00

371 lines
13 KiB
Python

"""
Rate limiting middleware
"""
import asyncio
import functools
import time
import uuid
from datetime import datetime, timedelta
from typing import Dict, Optional

import redis
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.config import settings
from app.core.logging import get_logger
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
class RateLimiter:
    """Sliding-window rate limiter using Redis, with an in-memory fallback.

    Every accepted request is recorded as one timestamped entry under its
    rate-limit key; a request is allowed while fewer than ``limit`` entries
    fall inside the last ``window_seconds`` seconds.

    Attributes:
        redis_client: Redis client, or ``None`` when Redis could not be
            reached during construction.
        memory_store: Per-key mapping of ``member -> timestamp`` used when
            Redis is unavailable.
    """

    def __init__(self):
        # Always create the in-memory store so that a Redis failure after
        # start-up can still degrade to it.
        self.memory_store: Dict[str, Dict[str, float]] = {}
        try:
            self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
            self.redis_client.ping()  # Test connection
            logger.info("Rate limiter initialized with Redis backend")
        except Exception as e:
            logger.warning(f"Redis not available for rate limiting: {e}")
            # Fall back to in-memory rate limiting
            self.redis_client = None

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        identifier: str = "default"
    ) -> tuple[bool, Dict[str, int]]:
        """
        Check if request is within rate limit

        Args:
            key: Rate limiting key (e.g., IP address, API key)
            limit: Maximum number of requests allowed
            window_seconds: Time window in seconds
            identifier: Additional identifier for the rate limit

        Returns:
            Tuple of (is_allowed, headers_dict)
        """
        full_key = f"rate_limit:{identifier}:{key}"
        current_time = int(time.time())
        window_start = current_time - window_seconds

        if self.redis_client:
            try:
                return await self._check_redis_rate_limit(
                    full_key, limit, window_seconds, current_time, window_start
                )
            except redis.RedisError as exc:
                # A Redis outage must not turn every request into a 500;
                # degrade to the in-memory store for this check instead.
                logger.warning(
                    f"Redis rate limit check failed, using in-memory fallback: {exc}"
                )

        return self._check_memory_rate_limit(
            full_key, limit, window_seconds, current_time, window_start
        )

    @staticmethod
    def _new_member(current_time: int) -> str:
        """Return a unique entry name for one request.

        The timestamp alone has one-second resolution, so it cannot be used
        as the entry name: all requests within the same second would
        overwrite each other and be counted as a single request.
        """
        return f"{current_time}:{uuid.uuid4().hex}"

    @staticmethod
    def _build_headers(
        limit: int,
        current_requests: int,
        current_time: int,
        window_seconds: int
    ) -> Dict[str, int]:
        """Build the X-RateLimit-* header values shared by both backends."""
        return {
            "X-RateLimit-Limit": limit,
            # "- 1" accounts for the request currently being checked
            "X-RateLimit-Remaining": max(0, limit - current_requests - 1),
            "X-RateLimit-Reset": current_time + window_seconds,
            "X-RateLimit-Window": window_seconds
        }

    async def _check_redis_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        current_time: int,
        window_start: int
    ) -> tuple[bool, Dict[str, int]]:
        """Check rate limit using Redis"""
        member = self._new_member(current_time)

        pipe = self.redis_client.pipeline()
        # Remove old entries
        pipe.zremrangebyscore(key, 0, window_start)
        # Count current requests in window (before adding this one)
        pipe.zcard(key)
        # Add current request
        pipe.zadd(key, {member: current_time})
        # Set expiration
        pipe.expire(key, window_seconds + 1)
        # The client is synchronous; run the round trip in a worker thread
        # so the event loop is not blocked while waiting on Redis.
        results = await asyncio.to_thread(pipe.execute)

        current_requests = results[1]
        is_allowed = current_requests < limit
        if not is_allowed:
            # Rejected requests must not consume quota (same behaviour as
            # the in-memory backend), so take the entry back out.
            await asyncio.to_thread(self.redis_client.zrem, key, member)
            logger.warning(f"Rate limit exceeded for key: {key}")

        headers = self._build_headers(limit, current_requests, current_time, window_seconds)
        return is_allowed, headers

    def _check_memory_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        current_time: int,
        window_start: int
    ) -> tuple[bool, Dict[str, int]]:
        """Check rate limit using in-memory storage"""
        store = self.memory_store.setdefault(key, {})

        # Clean old entries
        expired = [member for member, stamp in store.items() if stamp < window_start]
        for member in expired:
            del store[member]

        current_requests = len(store)
        is_allowed = current_requests < limit
        if is_allowed:
            # Add current request under a unique name
            store[self._new_member(current_time)] = current_time
        else:
            logger.warning(f"Rate limit exceeded for key: {key}")

        headers = self._build_headers(limit, current_requests, current_time, window_seconds)
        return is_allowed, headers
# Global rate limiter instance, created at import time and used by the
# module-level helper functions in this file.
rate_limiter = RateLimiter()
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware for FastAPI.

    Only the privatemode.ai proxy endpoints (OpenAI-compatible API and LLM
    service) are rate limited; every other path is passed straight through.
    Requests carrying an API key are checked against organization-wide
    limits and then per-key limits; requests without one are limited per
    client IP with stricter hardcoded limits.
    """

    # Path prefixes that are subject to rate limiting. Anything else
    # (internal API, health checks, docs, static files) is skipped.
    _RATE_LIMITED_PREFIXES = (
        "/api/v1/chat/completions",
        "/api/v1/embeddings",
        "/api/v1/models",
        "/api/v1/llm/",
    )

    # More restrictive limits for unauthenticated requests
    _UNAUTHENTICATED_LIMIT_PER_MINUTE = 20
    _UNAUTHENTICATED_LIMIT_PER_HOUR = 100

    def __init__(self, app):
        super().__init__(app)
        self.rate_limiter = RateLimiter()
        logger.info("RateLimitMiddleware initialized")

    async def dispatch(self, request: Request, call_next):
        """Process request through rate limiting.

        Returns the downstream response with X-RateLimit-* headers added,
        or a 429 JSON response when a limit is exceeded.
        """
        # Skip rate limiting if disabled in settings
        if not settings.API_RATE_LIMITING_ENABLED:
            return await call_next(request)

        # Only the allow-listed proxy endpoints are rate limited
        if not request.url.path.startswith(self._RATE_LIMITED_PREFIXES):
            return await call_next(request)

        api_key = self._extract_api_key(request)

        if api_key:
            # First check organization-wide limits (PrivateMode limits are org-wide)
            org_key = "organization:privatemode"
            org_allowed, org_headers = await self._check_windows(
                org_key,
                settings.PRIVATEMODE_REQUESTS_PER_MINUTE,
                settings.PRIVATEMODE_REQUESTS_PER_HOUR,
            )
            if not org_allowed:
                logger.warning(f"Organization rate limit exceeded for {org_key}")
                return JSONResponse(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    content={"detail": "Organization rate limit exceeded"},
                    # Header values must be strings, the limiter returns ints
                    headers={k: str(v) for k, v in org_headers.items()}
                )

            # Then check per-API key limits
            # NOTE(review): the raw API key becomes part of the limiter key,
            # which the limiter logs when a limit is exceeded - consider
            # using a hash of the key instead.
            is_allowed, headers = await self._check_windows(
                f"api_key:{api_key}",
                settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
                settings.API_RATE_LIMIT_API_KEY_PER_HOUR,
            )
        else:
            # IP-based rate limiting for unauthenticated requests
            is_allowed, headers = await self._check_windows(
                f"ip:{self._client_ip(request)}",
                self._UNAUTHENTICATED_LIMIT_PER_MINUTE,
                self._UNAUTHENTICATED_LIMIT_PER_HOUR,
            )

        # If rate limit exceeded, return 429
        if not is_allowed:
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "error": "RATE_LIMIT_EXCEEDED",
                    "message": "Rate limit exceeded. Please try again later.",
                    "details": {
                        "limit": headers["X-RateLimit-Limit"],
                        "reset_time": headers["X-RateLimit-Reset"]
                    }
                },
                headers={k: str(v) for k, v in headers.items()}
            )

        # Continue with request
        response = await call_next(request)

        # Add rate limit headers to response
        for name, value in headers.items():
            response.headers[name] = str(value)
        return response

    async def _check_windows(
        self,
        key: str,
        limit_per_minute: int,
        limit_per_hour: int
    ) -> tuple[bool, Dict[str, int]]:
        """Check *key* against its per-minute and per-hour limits.

        Both windows are always checked. The returned headers describe the
        minute window, unless only the hour limit is exceeded, in which case
        they describe the hour window so the 429 reports the violated limit.
        """
        minute_ok, minute_headers = await self.rate_limiter.check_rate_limit(
            key, limit_per_minute, 60, "minute"
        )
        hour_ok, hour_headers = await self.rate_limiter.check_rate_limit(
            key, limit_per_hour, 3600, "hour"
        )
        headers = hour_headers if (minute_ok and not hour_ok) else minute_headers
        return minute_ok and hour_ok, headers

    @staticmethod
    def _extract_api_key(request: Request) -> Optional[str]:
        """Return the API key from the Authorization or X-API-Key header, if any."""
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            return auth_header[7:]
        return request.headers.get("X-API-Key") or None

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Return the client IP, preferring the first X-Forwarded-For entry."""
        # NOTE(review): X-Forwarded-For is supplied by the client unless a
        # trusted proxy overwrites it, so it can be spoofed to dodge the
        # per-IP limit - confirm the deployment sits behind such a proxy.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            first_hop = forwarded_for.split(",")[0].strip()
            if first_hop:
                return first_hop
        # request.client can be None, so guard it
        return request.client.host if request.client else "unknown"
# Keep the old function for backward compatibility
_legacy_middleware: Optional["RateLimitMiddleware"] = None


async def rate_limit_middleware(request: Request, call_next):
    """Legacy function - use RateLimitMiddleware class instead.

    A single RateLimitMiddleware is created on first use and reused for all
    later calls. Building a new one per request would create a new
    RateLimiter each time, so the in-memory fallback would start empty on
    every request and never limit anything.
    """
    global _legacy_middleware
    if _legacy_middleware is None:
        _legacy_middleware = RateLimitMiddleware(None)
    return await _legacy_middleware.dispatch(request, call_next)
class RateLimitExceeded(HTTPException):
    """HTTP 429 exception signalling that a rate limit was exceeded."""

    def __init__(self, limit: int, reset_time: int):
        # The detail message carries both the limit and its reset time.
        message = f"Rate limit exceeded. Limit: {limit}, Reset: {reset_time}"
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=message,
        )
# Decorator for applying rate limits to specific endpoints
def rate_limit(requests_per_minute: int = 60, requests_per_hour: int = 1000):
    """
    Decorator to apply rate limiting to specific endpoints

    The decorator is currently a placeholder: the wrapped coroutine function
    is called unchanged and the limits are not enforced yet.

    Args:
        requests_per_minute: Maximum requests per minute
        requests_per_hour: Maximum requests per hour

    Returns:
        A decorator for coroutine functions.
    """
    def decorator(func):
        # wraps() keeps the endpoint's name, docstring and signature
        # (via __wrapped__), so signature inspection still sees the real
        # parameters instead of *args/**kwargs.
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # This would be implemented to work with FastAPI dependencies
            # For now, this is a placeholder for endpoint-specific rate limiting
            return await func(*args, **kwargs)
        return wrapper
    return decorator
# Helper functions for different rate limiting strategies
async def check_api_key_rate_limit(api_key: str, endpoint: str) -> bool:
    """Return whether this API key may call this endpoint right now.

    Default limits of 100 requests per 60 seconds are applied; key-specific
    limits are not looked up from the database yet.
    """
    allowed, _headers = await rate_limiter.check_rate_limit(
        f"api_key:{api_key}:endpoint:{endpoint}", 100, 60, "endpoint"
    )
    return allowed
async def check_user_rate_limit(user_id: str, action: str) -> bool:
    """Return whether this user may perform this action right now.

    Allows 50 requests per 60 seconds for each user/action pair.
    """
    limiter_key = f"user:{user_id}:action:{action}"
    outcome = await rate_limiter.check_rate_limit(
        limiter_key, 50, 60, "user_action"
    )
    # First element of the result is the allow/deny flag
    return outcome[0]
async def apply_burst_protection(key: str) -> bool:
    """Return whether *key* is still within its burst allowance.

    The allowance is a burst of 10 requests in 10 seconds.
    """
    burst_limit = 10
    burst_window_seconds = 10
    permitted, _headers = await rate_limiter.check_rate_limit(
        key,
        limit=burst_limit,
        window_seconds=burst_window_seconds,
        identifier="burst",
    )
    return permitted