Mirror of https://github.com/aljazceru/enclava.git

Commit: ratelimiting and rag
@@ -83,17 +83,24 @@ class Settings(BaseSettings):
     # Rate Limiting Configuration
     API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"
 
-    # Authenticated users (JWT token)
-    API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "300"))
-    API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "5000"))
+    # PrivateMode Standard tier limits (organization-level, not per user)
+    # These are shared across all API keys and users in the organization
+    PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
+    PRIVATEMODE_REQUESTS_PER_HOUR: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200"))
+    PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000"))
+    PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000"))
+
+    # Per-user limits (additional protection on top of organization limits)
+    API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20"))  # Match PrivateMode
+    API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200"))
 
     # API key users (programmatic access)
-    API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "1000"))
-    API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "20000"))
+    API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20"))  # Match PrivateMode
+    API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200"))
 
     # Premium/Enterprise API keys
-    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "5000"))
-    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "100000"))
+    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20"))  # Match PrivateMode
+    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
 
     # Security Thresholds
     API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8"))  # Block requests above this risk score
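Note: since these are pydantic BaseSettings fields backed by os.getenv, operators can pin the limits per deployment without code changes. A hypothetical .env fragment for illustration (the variable names come from the code above; the values simply restate the new defaults):

    API_RATE_LIMITING_ENABLED=true
    PRIVATEMODE_REQUESTS_PER_MINUTE=20
    PRIVATEMODE_REQUESTS_PER_HOUR=1200
    API_RATE_LIMIT_API_KEY_PER_MINUTE=20
    API_RATE_LIMIT_API_KEY_PER_HOUR=1200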
@@ -139,6 +139,10 @@ setup_analytics_middleware(app)
 from app.middleware.security import setup_security_middleware
 setup_security_middleware(app, enabled=settings.API_SECURITY_ENABLED)
 
+# Add rate limiting middleware only for specific endpoints
+from app.middleware.rate_limiting import RateLimitMiddleware
+app.add_middleware(RateLimitMiddleware)
+
 
 # Exception handlers
 @app.exception_handler(CustomHTTPException)
@@ -7,6 +7,7 @@ import redis
 from typing import Dict, Optional
 from fastapi import Request, HTTPException, status
 from fastapi.responses import JSONResponse
+from starlette.middleware.base import BaseHTTPMiddleware
 import asyncio
 from datetime import datetime, timedelta
@@ -155,96 +156,153 @@ class RateLimiter:
 rate_limiter = RateLimiter()
 
 
-async def rate_limit_middleware(request: Request, call_next):
-    """
-    Rate limiting middleware for FastAPI
-    """
-    # Skip rate limiting for health checks and static files
-    if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
-        response = await call_next(request)
-        return response
-
-    # Get client IP
-    client_ip = request.client.host
-    forwarded_for = request.headers.get("X-Forwarded-For")
-    if forwarded_for:
-        client_ip = forwarded_for.split(",")[0].strip()
-
-    # Check for API key in headers
-    api_key = None
-    auth_header = request.headers.get("Authorization")
-    if auth_header and auth_header.startswith("Bearer "):
-        api_key = auth_header[7:]
-    elif request.headers.get("X-API-Key"):
-        api_key = request.headers.get("X-API-Key")
-
-    # Determine rate limiting strategy
-    if api_key:
-        # API key-based rate limiting
-        rate_limit_key = f"api_key:{api_key}"
-
-        # Get API key limits from database (simplified - would implement proper lookup)
-        limit_per_minute = 100  # Default limit
-        limit_per_hour = 1000  # Default limit
-
-        # Check per-minute limit
-        is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_minute, 60, "minute"
-        )
-
-        # Check per-hour limit
-        is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_hour, 3600, "hour"
-        )
-
-        is_allowed = is_allowed_minute and is_allowed_hour
-        headers = headers_minute  # Use minute headers for response
-
-    else:
-        # IP-based rate limiting for unauthenticated requests
-        rate_limit_key = f"ip:{client_ip}"
-
-        # More restrictive limits for unauthenticated requests
-        limit_per_minute = 20
-        limit_per_hour = 100
-
-        # Check per-minute limit
-        is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_minute, 60, "minute"
-        )
-
-        # Check per-hour limit
-        is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_hour, 3600, "hour"
-        )
-
-        is_allowed = is_allowed_minute and is_allowed_hour
-        headers = headers_minute  # Use minute headers for response
-
-    # If rate limit exceeded, return 429
-    if not is_allowed:
-        return JSONResponse(
-            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
-            content={
-                "error": "RATE_LIMIT_EXCEEDED",
-                "message": "Rate limit exceeded. Please try again later.",
-                "details": {
-                    "limit": headers["X-RateLimit-Limit"],
-                    "reset_time": headers["X-RateLimit-Reset"]
-                }
-            },
-            headers={k: str(v) for k, v in headers.items()}
-        )
-
-    # Continue with request
-    response = await call_next(request)
-
-    # Add rate limit headers to response
-    for key, value in headers.items():
-        response.headers[key] = str(value)
-
-    return response
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """Rate limiting middleware for FastAPI"""
+
+    def __init__(self, app):
+        super().__init__(app)
+        self.rate_limiter = RateLimiter()
+        logger.info("RateLimitMiddleware initialized")
+
+    async def dispatch(self, request: Request, call_next):
+        """Process request through rate limiting"""
+
+        # Skip rate limiting if disabled in settings
+        if not settings.API_RATE_LIMITING_ENABLED:
+            response = await call_next(request)
+            return response
+
+        # Skip rate limiting for all internal API endpoints (platform operations)
+        if request.url.path.startswith("/api-internal/v1/"):
+            response = await call_next(request)
+            return response
+
+        # Only apply rate limiting to privatemode.ai proxy endpoints (OpenAI-compatible API and LLM service)
+        # Skip for all other endpoints
+        if not (request.url.path.startswith("/api/v1/chat/completions") or
+                request.url.path.startswith("/api/v1/embeddings") or
+                request.url.path.startswith("/api/v1/models") or
+                request.url.path.startswith("/api/v1/llm/")):
+            response = await call_next(request)
+            return response
+
+        # Skip rate limiting for health checks and static files
+        if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
+            response = await call_next(request)
+            return response
+
+        # Get client IP
+        client_ip = request.client.host
+        forwarded_for = request.headers.get("X-Forwarded-For")
+        if forwarded_for:
+            client_ip = forwarded_for.split(",")[0].strip()
+
+        # Check for API key in headers
+        api_key = None
+        auth_header = request.headers.get("Authorization")
+        if auth_header and auth_header.startswith("Bearer "):
+            api_key = auth_header[7:]
+        elif request.headers.get("X-API-Key"):
+            api_key = request.headers.get("X-API-Key")
+
+        # Determine rate limiting strategy
+        headers = {}
+        is_allowed = True
+
+        if api_key:
+            # API key-based rate limiting
+            api_key_key = f"api_key:{api_key}"
+
+            # First check organization-wide limits (PrivateMode limits are org-wide)
+            org_key = "organization:privatemode"
+
+            # Check organization per-minute limit
+            org_allowed_minute, org_headers_minute = await self.rate_limiter.check_rate_limit(
+                org_key, settings.PRIVATEMODE_REQUESTS_PER_MINUTE, 60, "minute"
+            )
+
+            # Check organization per-hour limit
+            org_allowed_hour, org_headers_hour = await self.rate_limiter.check_rate_limit(
+                org_key, settings.PRIVATEMODE_REQUESTS_PER_HOUR, 3600, "hour"
+            )
+
+            # If organization limits are exceeded, return 429
+            if not (org_allowed_minute and org_allowed_hour):
+                logger.warning(f"Organization rate limit exceeded for {org_key}")
+                return JSONResponse(
+                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+                    content={"detail": "Organization rate limit exceeded"},
+                    headers=org_headers_minute
+                )
+
+            # Then check per-API key limits
+            limit_per_minute = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE
+            limit_per_hour = settings.API_RATE_LIMIT_API_KEY_PER_HOUR
+
+            # Check per-minute limit
+            is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
+                api_key_key, limit_per_minute, 60, "minute"
+            )
+
+            # Check per-hour limit
+            is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
+                api_key_key, limit_per_hour, 3600, "hour"
+            )
+
+            is_allowed = is_allowed_minute and is_allowed_hour
+            headers = headers_minute  # Use minute headers for response
+
+        else:
+            # IP-based rate limiting for unauthenticated requests
+            rate_limit_key = f"ip:{client_ip}"
+
+            # More restrictive limits for unauthenticated requests
+            limit_per_minute = 20  # Hardcoded for unauthenticated users
+            limit_per_hour = 100
+
+            # Check per-minute limit
+            is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
+                rate_limit_key, limit_per_minute, 60, "minute"
+            )
+
+            # Check per-hour limit
+            is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
+                rate_limit_key, limit_per_hour, 3600, "hour"
+            )
+
+            is_allowed = is_allowed_minute and is_allowed_hour
+            headers = headers_minute  # Use minute headers for response
+
+        # If rate limit exceeded, return 429
+        if not is_allowed:
+            return JSONResponse(
+                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+                content={
+                    "error": "RATE_LIMIT_EXCEEDED",
+                    "message": "Rate limit exceeded. Please try again later.",
+                    "details": {
+                        "limit": headers["X-RateLimit-Limit"],
+                        "reset_time": headers["X-RateLimit-Reset"]
+                    }
+                },
+                headers={k: str(v) for k, v in headers.items()}
+            )
+
+        # Continue with request
+        response = await call_next(request)
+
+        # Add rate limit headers to response
+        for key, value in headers.items():
+            response.headers[key] = str(value)
+
+        return response
+
+
+# Keep the old function for backward compatibility
+async def rate_limit_middleware(request: Request, call_next):
+    """Legacy function - use RateLimitMiddleware class instead"""
+    middleware = RateLimitMiddleware(None)
+    return await middleware.dispatch(request, call_next)
 
 
 class RateLimitExceeded(HTTPException):
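Note: this hunk assumes a RateLimiter.check_rate_limit(key, limit, window_seconds, window_name) coroutine, defined earlier in rate_limiting.py and not shown in the diff, that returns an (allowed, headers) tuple. A minimal sketch of that contract, assuming a Redis fixed-window counter and the X-RateLimit-* header names used above (hypothetical, not the project's actual implementation):

    import time
    import redis.asyncio as redis

    class RateLimiterSketch:
        """Fixed-window counter; one Redis key per (client key, window)."""

        def __init__(self, url: str = "redis://localhost:6379/0"):
            self.redis = redis.from_url(url)

        async def check_rate_limit(self, key: str, limit: int, window_seconds: int, window_name: str):
            now = int(time.time())
            # Bucket key changes every window_seconds, resetting the count
            bucket = f"rate:{key}:{window_name}:{now // window_seconds}"
            count = await self.redis.incr(bucket)  # atomic increment
            if count == 1:
                await self.redis.expire(bucket, window_seconds)  # first hit sets the TTL
            reset = (now // window_seconds + 1) * window_seconds
            headers = {
                "X-RateLimit-Limit": limit,
                "X-RateLimit-Remaining": max(0, limit - count),
                "X-RateLimit-Reset": reset,
            }
            return count <= limit, headers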
@@ -61,8 +61,8 @@ class SecurityMiddleware(BaseHTTPMiddleware):
         if analysis.is_threat and (analysis.should_block or analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD):
             await self._log_security_event(request, analysis)
 
-        # Check if request should be blocked
-        if analysis.should_block:
+        # Check if request should be blocked (excluding rate limiting)
+        if analysis.should_block and not analysis.rate_limit_exceeded:
             threat_detection_service.stats['threats_blocked'] += 1
             logger.warning(f"Blocked request from {request.client.host if request.client else 'unknown'}: "
                            f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}")
@@ -137,10 +137,6 @@ class SecurityMiddleware(BaseHTTPMiddleware):
         # Determine status code based on threat type
         status_code = 403  # Forbidden by default
 
-        # Rate limiting gets 429
-        if analysis.rate_limit_exceeded:
-            status_code = 429
-
         # Critical threats get 403
         for threat in analysis.threats:
             if threat.threat_type in ["command_injection", "sql_injection"]:
@@ -156,23 +152,11 @@ class SecurityMiddleware(BaseHTTPMiddleware):
             "recommendations": analysis.recommendations[:3]  # Limit to first 3 recommendations
         }
 
-        # Add rate limiting info if applicable
-        if analysis.rate_limit_exceeded:
-            response_data["error"] = "Rate Limit Exceeded"
-            response_data["message"] = f"Rate limit exceeded for {analysis.auth_level.value} user"
-            response_data["retry_after"] = "60"  # Suggest retry after 60 seconds
-
         response = JSONResponse(
             content=response_data,
             status_code=status_code
         )
 
-        # Add rate limiting headers
-        if analysis.rate_limit_exceeded:
-            response.headers["Retry-After"] = "60"
-            response.headers["X-RateLimit-Limit"] = "See API documentation"
-            response.headers["X-RateLimit-Reset"] = str(int(time.time() + 60))
-
         return response
 
     def _add_security_headers(self, response: Response) -> Response:
backend/app/services/enhanced_embedding_service.py (new file, 201 lines)
@@ -0,0 +1,201 @@
# Enhanced Embedding Service with Rate Limiting Handling
"""
Enhanced embedding service with robust rate limiting and retry logic
"""

import asyncio
import logging
import time
from typing import List, Dict, Any, Optional
import numpy as np
from datetime import datetime, timedelta

from .embedding_service import EmbeddingService
from app.core.config import settings

logger = logging.getLogger(__name__)


class EnhancedEmbeddingService(EmbeddingService):
    """Enhanced embedding service with rate limiting handling"""

    def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
        super().__init__(model_name)
        self.rate_limit_tracker = {
            'requests_count': 0,
            'window_start': time.time(),
            'window_size': 60,  # 1 minute window
            'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 60)),  # Configurable
            'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')],  # Exponential backoff
            'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 0.5)),
            'last_rate_limit_error': None
        }

    async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]:
        """
        Get embeddings with rate limiting and retry logic
        """
        if max_retries is None:
            max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))

        batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 5))

        if not self.initialized:
            logger.warning("Embedding service not initialized, using fallback")
            return self._generate_fallback_embeddings(texts), False

        embeddings = []
        success = True

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries)
            embeddings.extend(batch_embeddings)
            success = success and batch_success

            # Add delay between batches to avoid rate limiting
            if i + batch_size < len(texts):
                delay = self.rate_limit_tracker['delay_between_batches']
                await asyncio.sleep(delay)  # Configurable delay between batches

        return embeddings, success

    async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]:
        """Get embeddings for a batch with retry logic"""
        last_error = None

        for attempt in range(max_retries + 1):
            try:
                # Check rate limit before making request
                if self._is_rate_limited():
                    delay = self._get_rate_limit_delay()
                    logger.warning(f"Rate limit detected, waiting {delay} seconds")
                    await asyncio.sleep(delay)
                    continue

                # Make the request
                embeddings = await self._get_embeddings_batch_impl(texts)

                # Update rate limit tracker on success
                self._update_rate_limit_tracker(success=True)

                return embeddings, True

            except Exception as e:
                last_error = e
                error_msg = str(e).lower()

                # Check if it's a rate limit error
                if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']):
                    logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}")
                    self._update_rate_limit_tracker(success=False)

                    if attempt < max_retries:
                        delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
                        logger.info(f"Retrying in {delay} seconds...")
                        await asyncio.sleep(delay)
                        continue
                    else:
                        logger.error("Max retries exceeded for rate limit, using fallback embeddings")
                        return self._generate_fallback_embeddings(texts), False
                else:
                    # Non-rate-limit error
                    logger.error(f"Error generating embeddings: {e}")
                    if attempt < max_retries:
                        delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
                        await asyncio.sleep(delay)
                    else:
                        logger.error("Max retries exceeded, using fallback embeddings")
                        return self._generate_fallback_embeddings(texts), False

        # If we get here, all retries failed
        logger.error(f"All retries failed, last error: {last_error}")
        return self._generate_fallback_embeddings(texts), False

    async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]:
        """Implementation of getting embeddings for a batch"""
        from app.services.llm.service import llm_service
        from app.services.llm.models import EmbeddingRequest

        embeddings = []

        for text in texts:
            # Truncate text if needed
            max_chars = 1600
            truncated_text = text[:max_chars] if len(text) > max_chars else text

            llm_request = EmbeddingRequest(
                model=self.model_name,
                input=truncated_text,
                user_id="rag_system",
                api_key_id=0
            )

            response = await llm_service.create_embedding(llm_request)

            if response.data and len(response.data) > 0:
                embedding = response.data[0].embedding
                if embedding:
                    embeddings.append(embedding)
                    if not hasattr(self, '_dimension_confirmed'):
                        self.dimension = len(embedding)
                        self._dimension_confirmed = True
                else:
                    raise ValueError("Empty embedding in response")
            else:
                raise ValueError("Invalid response structure")

        return embeddings

    def _is_rate_limited(self) -> bool:
        """Check if we're currently rate limited"""
        now = time.time()
        window_start = self.rate_limit_tracker['window_start']

        # Reset window if it's expired
        if now - window_start > self.rate_limit_tracker['window_size']:
            self.rate_limit_tracker['requests_count'] = 0
            self.rate_limit_tracker['window_start'] = now
            return False

        # Check if we've exceeded the limit
        return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute']

    def _get_rate_limit_delay(self) -> float:
        """Get delay to wait for rate limit reset"""
        now = time.time()
        window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size']
        return max(0, window_end - now)

    def _update_rate_limit_tracker(self, success: bool):
        """Update the rate limit tracker"""
        now = time.time()

        # Reset window if it's expired
        if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']:
            self.rate_limit_tracker['requests_count'] = 0
            self.rate_limit_tracker['window_start'] = now

        # Increment counter on successful requests
        if success:
            self.rate_limit_tracker['requests_count'] += 1

    async def get_embedding_stats(self) -> Dict[str, Any]:
        """Get embedding service statistics including rate limiting info"""
        base_stats = await self.get_stats()

        return {
            **base_stats,
            "rate_limit_info": {
                "requests_in_current_window": self.rate_limit_tracker['requests_count'],
                "max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'],
                "window_reset_in_seconds": max(0,
                    self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time()
                ),
                "last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error']
            }
        }


# Global enhanced embedding service instance
enhanced_embedding_service = EnhancedEmbeddingService()
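A short usage sketch of the new service (hypothetical caller; it mirrors how the RAG module invokes it in the hunks below, and assumes the inherited EmbeddingService provides initialized and _generate_fallback_embeddings):

    import asyncio
    from app.services.enhanced_embedding_service import enhanced_embedding_service

    async def embed_chunks(chunks: list[str]) -> list[list[float]]:
        embeddings, ok = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
        if not ok:
            # At least one batch fell back to locally generated vectors;
            # search quality for those chunks may be degraded.
            print("warning: fallback embeddings in use")
        return embeddings

    # asyncio.run(embed_chunks(["first chunk", "second chunk"]))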
@@ -66,6 +66,15 @@ class LLMServiceConfig(BaseModel):
     # Provider configurations
     providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")
 
+    # Token rate limiting (organization-wide)
+    token_limits_per_minute: Dict[str, int] = Field(
+        default_factory=lambda: {
+            "prompt_tokens": 20000,  # PrivateMode Standard tier
+            "completion_tokens": 10000  # PrivateMode Standard tier
+        },
+        description="Token rate limits per minute (organization-wide)"
+    )
+
     # Model routing (model_name -> provider_name)
     model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
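The enforcement code for these token limits is not part of this diff; a minimal per-minute budget sketch under that assumption (hypothetical class and method names), keyed the same way as token_limits_per_minute:

    import time
    from collections import defaultdict

    class TokenBudget:
        """Hypothetical per-minute token budget matching token_limits_per_minute."""

        def __init__(self, limits: dict[str, int]):
            self.limits = limits  # e.g. {"prompt_tokens": 20000, "completion_tokens": 10000}
            self.used = defaultdict(int)
            self.window_start = time.time()

        def consume(self, kind: str, tokens: int) -> bool:
            """Return True if the organization still has budget for this minute."""
            now = time.time()
            if now - self.window_start >= 60:  # new minute: reset counters
                self.used.clear()
                self.window_start = now
            if self.used[kind] + tokens > self.limits.get(kind, 0):
                return False  # would exceed the per-minute cap
            self.used[kind] += tokens
            return True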
@@ -91,8 +100,8 @@ def create_default_config() -> LLMServiceConfig:
         supported_models=[],  # Will be populated dynamically from proxy
         capabilities=["chat", "embeddings", "tee"],
         priority=1,
-        max_requests_per_minute=100,
-        max_requests_per_hour=2000,
+        max_requests_per_minute=20,  # PrivateMode Standard tier limit: 20 req/min
+        max_requests_per_hour=1200,  # 20 req/min * 60 min
         supports_streaming=True,
         supports_function_calling=True,
         max_context_window=128000,
@@ -60,6 +60,7 @@ import tiktoken
 from app.core.config import settings
 from app.core.logging import log_module_event
 from app.services.base_module import BaseModule, Permission
+from app.services.enhanced_embedding_service import enhanced_embedding_service
 
 
 @dataclass
@@ -1125,8 +1126,16 @@ class RAGModule(BaseModule):
         # Chunk the document
         chunks = self._chunk_text(content)
 
-        # Generate embeddings for all chunks in batch (more efficient)
-        embeddings = await self._generate_embeddings(chunks)
+        # Generate embeddings with enhanced rate limiting handling
+        embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
+
+        # Log if fallback embeddings were used
+        if not success:
+            logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded")
+            log_module_event("rag", "fallback_embeddings_used", {
+                "document_id": doc_id,
+                "content_preview": content[:100] + "..." if len(content) > 100 else content
+            })
 
         # Create document points
         points = []
@@ -1188,8 +1197,16 @@ class RAGModule(BaseModule):
         # Chunk the document
         chunks = self._chunk_text(processed_doc.content)
 
-        # Generate embeddings for all chunks in batch (more efficient)
-        embeddings = await self._generate_embeddings(chunks)
+        # Generate embeddings with enhanced rate limiting handling
+        embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
+
+        # Log if fallback embeddings were used
+        if not success:
+            logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded")
+            log_module_event("rag", "fallback_embeddings_used", {
+                "document_id": processed_doc.id,
+                "filename": processed_doc.original_filename
+            })
 
         # Create document points with enhanced metadata
         points = []