ratelimiting and rag

This commit is contained in:
2025-09-21 06:49:55 +02:00
parent 0c20de4ca1
commit f58a76ac59
7 changed files with 410 additions and 130 deletions

View File

@@ -7,6 +7,7 @@ import redis
from typing import Dict, Optional
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import asyncio
from datetime import datetime, timedelta
@@ -155,96 +156,153 @@ class RateLimiter:
# Module-level RateLimiter instance.
# NOTE(review): RateLimitMiddleware builds its own RateLimiter in __init__ instead of
# reusing this one - confirm whether anything else still depends on this global.
rate_limiter = RateLimiter()
# NOTE(review): this header has no body here and the same name is defined again further
# down as the legacy wrapper. It looks like the pre-refactor function that this commit
# replaced with RateLimitMiddleware (diff residue) - confirm before relying on it.
async def rate_limit_middleware(request: Request, call_next):
"""
Rate limiting middleware for FastAPI
"""
# Skip rate limiting for health checks and static files
if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware for FastAPI.

    Only the privatemode.ai proxy endpoints (OpenAI-compatible API and LLM
    service) are rate limited; every other path is passed straight through.

    Requests carrying an API key (``Authorization: Bearer ...`` or
    ``X-API-Key``) are checked against the organization-wide limits first and
    then against the per-key limits. Requests without a key are limited per
    client IP with stricter, hard-coded limits.
    """

    # Platform-internal endpoints are never rate limited.
    _INTERNAL_PREFIX = "/api-internal/v1/"

    # Allow-list of path prefixes that are subject to rate limiting. Health,
    # root, docs and OpenAPI paths are not on it, so they need no extra check.
    _LIMITED_PREFIXES = (
        "/api/v1/chat/completions",
        "/api/v1/embeddings",
        "/api/v1/models",
        "/api/v1/llm/",
    )

    # Hard-coded limits for unauthenticated (IP-keyed) requests.
    _ANON_LIMIT_PER_MINUTE = 20
    _ANON_LIMIT_PER_HOUR = 100

    def __init__(self, app):
        """Wrap *app* and create the RateLimiter used by ``dispatch``."""
        super().__init__(app)
        self.rate_limiter = RateLimiter()
        logger.info("RateLimitMiddleware initialized")

    @staticmethod
    def _extract_api_key(request):
        """Return the API key from the request headers, or a falsy value."""
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            return auth_header[7:]
        return request.headers.get("X-API-Key") or None

    @staticmethod
    def _client_ip(request):
        """Return the client address used to key unauthenticated requests.

        NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        proxy overwrites it - confirm the deployment guarantees that.
        """
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            first_hop = forwarded_for.split(",")[0].strip()
            if first_hop:
                return first_hop
        # request.client can be None, so guard the attribute access.
        return request.client.host if request.client else "unknown"

    @staticmethod
    def _stringify(headers):
        """Return *headers* with every value converted to ``str``."""
        return {name: str(value) for name, value in headers.items()}

    async def _check_windows(self, key, limit_per_minute, limit_per_hour):
        """Check the per-minute and per-hour windows for *key*.

        Both windows are always checked, minute first. Returns
        ``(allowed, headers)``; the headers come from the hour window when it
        is the only one that rejected, and from the minute window otherwise.
        """
        minute_ok, minute_headers = await self.rate_limiter.check_rate_limit(
            key, limit_per_minute, 60, "minute"
        )
        hour_ok, hour_headers = await self.rate_limiter.check_rate_limit(
            key, limit_per_hour, 3600, "hour"
        )
        if minute_ok and not hour_ok:
            return False, hour_headers
        return minute_ok and hour_ok, minute_headers

    async def dispatch(self, request: Request, call_next):
        """Process request through rate limiting.

        Returns the downstream response with rate limit headers added, or a
        429 ``JSONResponse`` when a limit is exceeded.
        """
        # Skip rate limiting if disabled in settings
        if not settings.API_RATE_LIMITING_ENABLED:
            return await call_next(request)

        path = request.url.path
        if path.startswith(self._INTERNAL_PREFIX) or not path.startswith(
            self._LIMITED_PREFIXES
        ):
            return await call_next(request)

        api_key = self._extract_api_key(request)
        if api_key:
            # NOTE(review): the key is not validated here, so any bearer value
            # counts against the organization limits, and the raw key becomes
            # part of the limiter key - confirm both are acceptable.
            org_key = "organization:privatemode"
            org_allowed, org_headers = await self._check_windows(
                org_key,
                settings.PRIVATEMODE_REQUESTS_PER_MINUTE,
                settings.PRIVATEMODE_REQUESTS_PER_HOUR,
            )
            if not org_allowed:
                logger.warning(f"Organization rate limit exceeded for {org_key}")
                return JSONResponse(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    content={"detail": "Organization rate limit exceeded"},
                    # Same str() conversion as every other response path.
                    headers=self._stringify(org_headers),
                )

            is_allowed, headers = await self._check_windows(
                f"api_key:{api_key}",
                settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
                settings.API_RATE_LIMIT_API_KEY_PER_HOUR,
            )
        else:
            # IP-based rate limiting for unauthenticated requests
            is_allowed, headers = await self._check_windows(
                f"ip:{self._client_ip(request)}",
                self._ANON_LIMIT_PER_MINUTE,
                self._ANON_LIMIT_PER_HOUR,
            )

        # If rate limit exceeded, return 429
        if not is_allowed:
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "error": "RATE_LIMIT_EXCEEDED",
                    "message": "Rate limit exceeded. Please try again later.",
                    "details": {
                        "limit": headers.get("X-RateLimit-Limit"),
                        "reset_time": headers.get("X-RateLimit-Reset"),
                    },
                },
                headers=self._stringify(headers),
            )

        # Continue with request
        response = await call_next(request)

        # Add rate limit headers to response
        for key, value in headers.items():
            response.headers[key] = str(value)
        return response
# NOTE(review): the statements from here to the final `return response` have no enclosing
# definition in this view. They use the module-level `rate_limiter` and fixed default
# limits, and look like the body of the pre-refactor rate_limit_middleware that this
# commit replaced with RateLimitMiddleware.dispatch (diff residue) - confirm.
# Get client IP
client_ip = request.client.host
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# Use the first entry of the comma-separated X-Forwarded-For list
client_ip = forwarded_for.split(",")[0].strip()
# Check for API key in headers
api_key = None
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
# Strip the 7-character "Bearer " prefix
api_key = auth_header[7:]
elif request.headers.get("X-API-Key"):
api_key = request.headers.get("X-API-Key")
# Determine rate limiting strategy
if api_key:
# API key-based rate limiting
rate_limit_key = f"api_key:{api_key}"
# Get API key limits from database (simplified - would implement proper lookup)
limit_per_minute = 100 # Default limit
limit_per_hour = 1000 # Default limit
# Check per-minute limit
is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
rate_limit_key, limit_per_minute, 60, "minute"
)
# Check per-hour limit
is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
rate_limit_key, limit_per_hour, 3600, "hour"
)
is_allowed = is_allowed_minute and is_allowed_hour
headers = headers_minute # Use minute headers for response
else:
# IP-based rate limiting for unauthenticated requests
rate_limit_key = f"ip:{client_ip}"
# More restrictive limits for unauthenticated requests
limit_per_minute = 20
limit_per_hour = 100
# Check per-minute limit
is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
rate_limit_key, limit_per_minute, 60, "minute"
)
# Check per-hour limit
is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
rate_limit_key, limit_per_hour, 3600, "hour"
)
is_allowed = is_allowed_minute and is_allowed_hour
headers = headers_minute # Use minute headers for response
# If rate limit exceeded, return 429
# (the minute-window headers are reported even when only the hour window rejected)
if not is_allowed:
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={
"error": "RATE_LIMIT_EXCEEDED",
"message": "Rate limit exceeded. Please try again later.",
"details": {
"limit": headers["X-RateLimit-Limit"],
"reset_time": headers["X-RateLimit-Reset"]
}
},
headers={k: str(v) for k, v in headers.items()}
)
# Continue with request
response = await call_next(request)
# Add rate limit headers to response
for key, value in headers.items():
response.headers[key] = str(value)
return response
# Keep the old function for backward compatibility
_legacy_middleware = None


async def rate_limit_middleware(request: Request, call_next):
    """Legacy function - use RateLimitMiddleware class instead.

    Delegates to one lazily created ``RateLimitMiddleware`` instance.
    Constructing the middleware creates a ``RateLimiter`` and logs an
    "initialized" line, so it is done once here, not on every request.
    """
    global _legacy_middleware
    if _legacy_middleware is None:
        # No wrapped app is needed: only dispatch() is used, and it is given
        # call_next explicitly.
        _legacy_middleware = RateLimitMiddleware(None)
    return await _legacy_middleware.dispatch(request, call_next)
class RateLimitExceeded(HTTPException):