mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 15:34:36 +01:00
ratelimiting and rag
This commit is contained in:
@@ -7,6 +7,7 @@ import redis
|
||||
from typing import Dict, Optional
|
||||
from fastapi import Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
@@ -155,96 +156,153 @@ class RateLimiter:
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
|
||||
async def rate_limit_middleware(request: Request, call_next):
|
||||
"""
|
||||
Rate limiting middleware for FastAPI
|
||||
"""
|
||||
|
||||
# Skip rate limiting for health checks and static files
|
||||
if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware for FastAPI.

    Only requests whose path starts with one of ``RATE_LIMITED_PREFIXES``
    (the privatemode.ai proxy endpoints) are rate limited; every other
    request is forwarded untouched.

    Requests carrying an API key (``Authorization: Bearer ...`` or
    ``X-API-Key``) are checked against the organization-wide limits first and
    then against the per-API-key limits. Requests without a key are limited
    per client IP with stricter limits.
    """

    # Path prefixes of the privatemode.ai proxy endpoints (OpenAI-compatible
    # API and LLM service). Health, docs and internal paths never match these.
    RATE_LIMITED_PREFIXES = (
        "/api/v1/chat/completions",
        "/api/v1/embeddings",
        "/api/v1/models",
        "/api/v1/llm/",
    )

    # More restrictive limits for unauthenticated (IP-keyed) requests.
    UNAUTHENTICATED_LIMIT_PER_MINUTE = 20
    UNAUTHENTICATED_LIMIT_PER_HOUR = 100

    def __init__(self, app):
        super().__init__(app)
        self.rate_limiter = RateLimiter()
        logger.info("RateLimitMiddleware initialized")

    @staticmethod
    def _client_ip(request: Request) -> str:
        """Return the client address used as the IP rate-limit key."""
        # NOTE(review): X-Forwarded-For is supplied by the caller unless a
        # trusted proxy overwrites it, so it can be spoofed to dodge IP limits.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()
        # request.client can be None, so do not dereference it blindly.
        return request.client.host if request.client else "unknown"

    @staticmethod
    def _api_key(request: Request) -> Optional[str]:
        """Return the API key from the request headers, or None if absent."""
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            return auth_header[7:]
        return request.headers.get("X-API-Key") or None

    @staticmethod
    def _stringify(headers) -> dict:
        """Return ``headers`` with every value converted to ``str``."""
        return {name: str(value) for name, value in headers.items()}

    async def _check_windows(self, key, limit_per_minute, limit_per_hour):
        """Check the per-minute and per-hour windows for ``key``.

        Returns:
            ``(allowed, headers)``. ``headers`` are those of the minute
            window, except when only the hour window is exceeded; then the
            hour window's headers are returned so the reported limit and
            reset time describe the limit that was actually hit.
        """
        allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
            key, limit_per_minute, 60, "minute"
        )
        allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
            key, limit_per_hour, 3600, "hour"
        )
        if allowed_minute and not allowed_hour:
            return False, headers_hour
        return allowed_minute and allowed_hour, headers_minute

    async def dispatch(self, request: Request, call_next):
        """Process request through rate limiting.

        Returns the downstream response (with rate limit headers added) when
        the request is allowed, or a 429 ``JSONResponse`` when a limit is
        exceeded.
        """
        # Skip rate limiting if disabled in settings
        if not settings.API_RATE_LIMITING_ENABLED:
            return await call_next(request)

        # Skip internal API endpoints (platform operations) and everything
        # that is not a privatemode.ai proxy endpoint.
        path = request.url.path
        if path.startswith("/api-internal/v1/") or not path.startswith(
            self.RATE_LIMITED_PREFIXES
        ):
            return await call_next(request)

        api_key = self._api_key(request)
        if api_key:
            # First check organization-wide limits (PrivateMode limits are org-wide)
            org_key = "organization:privatemode"
            org_allowed, org_headers = await self._check_windows(
                org_key,
                settings.PRIVATEMODE_REQUESTS_PER_MINUTE,
                settings.PRIVATEMODE_REQUESTS_PER_HOUR,
            )
            if not org_allowed:
                logger.warning(f"Organization rate limit exceeded for {org_key}")
                return JSONResponse(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    content={"detail": "Organization rate limit exceeded"},
                    # Stringified like the other response paths in this method.
                    headers=self._stringify(org_headers),
                )

            # Then check per-API key limits
            is_allowed, headers = await self._check_windows(
                f"api_key:{api_key}",
                settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
                settings.API_RATE_LIMIT_API_KEY_PER_HOUR,
            )
        else:
            # IP-based rate limiting for unauthenticated requests
            is_allowed, headers = await self._check_windows(
                f"ip:{self._client_ip(request)}",
                self.UNAUTHENTICATED_LIMIT_PER_MINUTE,
                self.UNAUTHENTICATED_LIMIT_PER_HOUR,
            )

        # If rate limit exceeded, return 429
        if not is_allowed:
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "error": "RATE_LIMIT_EXCEEDED",
                    "message": "Rate limit exceeded. Please try again later.",
                    "details": {
                        "limit": headers["X-RateLimit-Limit"],
                        "reset_time": headers["X-RateLimit-Reset"],
                    },
                },
                headers=self._stringify(headers),
            )

        # Continue with request
        response = await call_next(request)

        # Add rate limit headers to response
        for key, value in headers.items():
            response.headers[key] = str(value)

        return response
|
||||
|
||||
# Get client IP
|
||||
client_ip = request.client.host
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
client_ip = forwarded_for.split(",")[0].strip()
|
||||
|
||||
# Check for API key in headers
|
||||
api_key = None
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
api_key = auth_header[7:]
|
||||
elif request.headers.get("X-API-Key"):
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
|
||||
# Determine rate limiting strategy
|
||||
if api_key:
|
||||
# API key-based rate limiting
|
||||
rate_limit_key = f"api_key:{api_key}"
|
||||
|
||||
# Get API key limits from database (simplified - would implement proper lookup)
|
||||
limit_per_minute = 100 # Default limit
|
||||
limit_per_hour = 1000 # Default limit
|
||||
|
||||
# Check per-minute limit
|
||||
is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
|
||||
rate_limit_key, limit_per_minute, 60, "minute"
|
||||
)
|
||||
|
||||
# Check per-hour limit
|
||||
is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
|
||||
rate_limit_key, limit_per_hour, 3600, "hour"
|
||||
)
|
||||
|
||||
is_allowed = is_allowed_minute and is_allowed_hour
|
||||
headers = headers_minute # Use minute headers for response
|
||||
|
||||
else:
|
||||
# IP-based rate limiting for unauthenticated requests
|
||||
rate_limit_key = f"ip:{client_ip}"
|
||||
|
||||
# More restrictive limits for unauthenticated requests
|
||||
limit_per_minute = 20
|
||||
limit_per_hour = 100
|
||||
|
||||
# Check per-minute limit
|
||||
is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
|
||||
rate_limit_key, limit_per_minute, 60, "minute"
|
||||
)
|
||||
|
||||
# Check per-hour limit
|
||||
is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
|
||||
rate_limit_key, limit_per_hour, 3600, "hour"
|
||||
)
|
||||
|
||||
is_allowed = is_allowed_minute and is_allowed_hour
|
||||
headers = headers_minute # Use minute headers for response
|
||||
|
||||
# If rate limit exceeded, return 429
|
||||
if not is_allowed:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={
|
||||
"error": "RATE_LIMIT_EXCEEDED",
|
||||
"message": "Rate limit exceeded. Please try again later.",
|
||||
"details": {
|
||||
"limit": headers["X-RateLimit-Limit"],
|
||||
"reset_time": headers["X-RateLimit-Reset"]
|
||||
}
|
||||
},
|
||||
headers={k: str(v) for k, v in headers.items()}
|
||||
)
|
||||
|
||||
# Continue with request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers to response
|
||||
for key, value in headers.items():
|
||||
response.headers[key] = str(value)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# Keep the old function for backward compatibility
|
||||
# Shared instance for the legacy entry point, created on first use.
_legacy_rate_limit_middleware = None


async def rate_limit_middleware(request: Request, call_next):
    """Legacy function - use RateLimitMiddleware class instead.

    Delegates to ``RateLimitMiddleware.dispatch``. One middleware instance is
    created lazily and reused, so a new ``RateLimiter`` is not constructed
    (and the initialization message is not logged) on every request.
    """
    global _legacy_rate_limit_middleware
    if _legacy_rate_limit_middleware is None:
        _legacy_rate_limit_middleware = RateLimitMiddleware(None)
    return await _legacy_rate_limit_middleware.dispatch(request, call_next)
|
||||
|
||||
|
||||
class RateLimitExceeded(HTTPException):
|
||||
|
||||
Reference in New Issue
Block a user