mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
ratelimiting and rag
This commit is contained in:
@@ -83,17 +83,24 @@ class Settings(BaseSettings):
|
|||||||
# Rate Limiting Configuration
|
# Rate Limiting Configuration
|
||||||
API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"
|
API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"
|
||||||
|
|
||||||
# Authenticated users (JWT token)
|
# PrivateMode Standard tier limits (organization-level, not per user)
|
||||||
API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "300"))
|
# These are shared across all API keys and users in the organization
|
||||||
API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "5000"))
|
PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
|
||||||
|
PRIVATEMODE_REQUESTS_PER_HOUR: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200"))
|
||||||
|
PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000"))
|
||||||
|
PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000"))
|
||||||
|
|
||||||
|
# Per-user limits (additional protection on top of organization limits)
|
||||||
|
API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20")) # Match PrivateMode
|
||||||
|
API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200"))
|
||||||
|
|
||||||
# API key users (programmatic access)
|
# API key users (programmatic access)
|
||||||
API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "1000"))
|
API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20")) # Match PrivateMode
|
||||||
API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "20000"))
|
API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200"))
|
||||||
|
|
||||||
# Premium/Enterprise API keys
|
# Premium/Enterprise API keys
|
||||||
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "5000"))
|
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode
|
||||||
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "100000"))
|
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
|
||||||
|
|
||||||
# Security Thresholds
|
# Security Thresholds
|
||||||
API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8")) # Block requests above this risk score
|
API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8")) # Block requests above this risk score
|
||||||
|
|||||||
@@ -139,6 +139,10 @@ setup_analytics_middleware(app)
|
|||||||
from app.middleware.security import setup_security_middleware
|
from app.middleware.security import setup_security_middleware
|
||||||
setup_security_middleware(app, enabled=settings.API_SECURITY_ENABLED)
|
setup_security_middleware(app, enabled=settings.API_SECURITY_ENABLED)
|
||||||
|
|
||||||
|
# Add rate limiting middleware only for specific endpoints
|
||||||
|
from app.middleware.rate_limiting import RateLimitMiddleware
|
||||||
|
app.add_middleware(RateLimitMiddleware)
|
||||||
|
|
||||||
|
|
||||||
# Exception handlers
|
# Exception handlers
|
||||||
@app.exception_handler(CustomHTTPException)
|
@app.exception_handler(CustomHTTPException)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import redis
|
|||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
from fastapi import Request, HTTPException, status
|
from fastapi import Request, HTTPException, status
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
@@ -155,10 +156,35 @@ class RateLimiter:
|
|||||||
rate_limiter = RateLimiter()
|
rate_limiter = RateLimiter()
|
||||||
|
|
||||||
|
|
||||||
async def rate_limit_middleware(request: Request, call_next):
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
"""
|
"""Rate limiting middleware for FastAPI"""
|
||||||
Rate limiting middleware for FastAPI
|
|
||||||
"""
|
def __init__(self, app):
|
||||||
|
super().__init__(app)
|
||||||
|
self.rate_limiter = RateLimiter()
|
||||||
|
logger.info("RateLimitMiddleware initialized")
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
"""Process request through rate limiting"""
|
||||||
|
|
||||||
|
# Skip rate limiting if disabled in settings
|
||||||
|
if not settings.API_RATE_LIMITING_ENABLED:
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Skip rate limiting for all internal API endpoints (platform operations)
|
||||||
|
if request.url.path.startswith("/api-internal/v1/"):
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Only apply rate limiting to privatemode.ai proxy endpoints (OpenAI-compatible API and LLM service)
|
||||||
|
# Skip for all other endpoints
|
||||||
|
if not (request.url.path.startswith("/api/v1/chat/completions") or
|
||||||
|
request.url.path.startswith("/api/v1/embeddings") or
|
||||||
|
request.url.path.startswith("/api/v1/models") or
|
||||||
|
request.url.path.startswith("/api/v1/llm/")):
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
# Skip rate limiting for health checks and static files
|
# Skip rate limiting for health checks and static files
|
||||||
if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
|
if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
|
||||||
@@ -180,22 +206,47 @@ async def rate_limit_middleware(request: Request, call_next):
|
|||||||
api_key = request.headers.get("X-API-Key")
|
api_key = request.headers.get("X-API-Key")
|
||||||
|
|
||||||
# Determine rate limiting strategy
|
# Determine rate limiting strategy
|
||||||
|
headers = {}
|
||||||
|
is_allowed = True
|
||||||
|
|
||||||
if api_key:
|
if api_key:
|
||||||
# API key-based rate limiting
|
# API key-based rate limiting
|
||||||
rate_limit_key = f"api_key:{api_key}"
|
api_key_key = f"api_key:{api_key}"
|
||||||
|
|
||||||
# Get API key limits from database (simplified - would implement proper lookup)
|
# First check organization-wide limits (PrivateMode limits are org-wide)
|
||||||
limit_per_minute = 100 # Default limit
|
org_key = "organization:privatemode"
|
||||||
limit_per_hour = 1000 # Default limit
|
|
||||||
|
# Check organization per-minute limit
|
||||||
|
org_allowed_minute, org_headers_minute = await self.rate_limiter.check_rate_limit(
|
||||||
|
org_key, settings.PRIVATEMODE_REQUESTS_PER_MINUTE, 60, "minute"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check organization per-hour limit
|
||||||
|
org_allowed_hour, org_headers_hour = await self.rate_limiter.check_rate_limit(
|
||||||
|
org_key, settings.PRIVATEMODE_REQUESTS_PER_HOUR, 3600, "hour"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If organization limits are exceeded, return 429
|
||||||
|
if not (org_allowed_minute and org_allowed_hour):
|
||||||
|
logger.warning(f"Organization rate limit exceeded for {org_key}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
|
content={"detail": "Organization rate limit exceeded"},
|
||||||
|
headers=org_headers_minute
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then check per-API key limits
|
||||||
|
limit_per_minute = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE
|
||||||
|
limit_per_hour = settings.API_RATE_LIMIT_API_KEY_PER_HOUR
|
||||||
|
|
||||||
# Check per-minute limit
|
# Check per-minute limit
|
||||||
is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
|
is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
|
||||||
rate_limit_key, limit_per_minute, 60, "minute"
|
api_key_key, limit_per_minute, 60, "minute"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check per-hour limit
|
# Check per-hour limit
|
||||||
is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
|
is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
|
||||||
rate_limit_key, limit_per_hour, 3600, "hour"
|
api_key_key, limit_per_hour, 3600, "hour"
|
||||||
)
|
)
|
||||||
|
|
||||||
is_allowed = is_allowed_minute and is_allowed_hour
|
is_allowed = is_allowed_minute and is_allowed_hour
|
||||||
@@ -206,16 +257,16 @@ async def rate_limit_middleware(request: Request, call_next):
|
|||||||
rate_limit_key = f"ip:{client_ip}"
|
rate_limit_key = f"ip:{client_ip}"
|
||||||
|
|
||||||
# More restrictive limits for unauthenticated requests
|
# More restrictive limits for unauthenticated requests
|
||||||
limit_per_minute = 20
|
limit_per_minute = 20 # Hardcoded for unauthenticated users
|
||||||
limit_per_hour = 100
|
limit_per_hour = 100
|
||||||
|
|
||||||
# Check per-minute limit
|
# Check per-minute limit
|
||||||
is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
|
is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
|
||||||
rate_limit_key, limit_per_minute, 60, "minute"
|
rate_limit_key, limit_per_minute, 60, "minute"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check per-hour limit
|
# Check per-hour limit
|
||||||
is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
|
is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
|
||||||
rate_limit_key, limit_per_hour, 3600, "hour"
|
rate_limit_key, limit_per_hour, 3600, "hour"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -247,6 +298,13 @@ async def rate_limit_middleware(request: Request, call_next):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
# Keep the old function for backward compatibility
|
||||||
|
async def rate_limit_middleware(request: Request, call_next):
|
||||||
|
"""Legacy function - use RateLimitMiddleware class instead"""
|
||||||
|
middleware = RateLimitMiddleware(None)
|
||||||
|
return await middleware.dispatch(request, call_next)
|
||||||
|
|
||||||
|
|
||||||
class RateLimitExceeded(HTTPException):
|
class RateLimitExceeded(HTTPException):
|
||||||
"""Exception raised when rate limit is exceeded"""
|
"""Exception raised when rate limit is exceeded"""
|
||||||
|
|
||||||
|
|||||||
@@ -61,8 +61,8 @@ class SecurityMiddleware(BaseHTTPMiddleware):
|
|||||||
if analysis.is_threat and (analysis.should_block or analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD):
|
if analysis.is_threat and (analysis.should_block or analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD):
|
||||||
await self._log_security_event(request, analysis)
|
await self._log_security_event(request, analysis)
|
||||||
|
|
||||||
# Check if request should be blocked
|
# Check if request should be blocked (excluding rate limiting)
|
||||||
if analysis.should_block:
|
if analysis.should_block and not analysis.rate_limit_exceeded:
|
||||||
threat_detection_service.stats['threats_blocked'] += 1
|
threat_detection_service.stats['threats_blocked'] += 1
|
||||||
logger.warning(f"Blocked request from {request.client.host if request.client else 'unknown'}: "
|
logger.warning(f"Blocked request from {request.client.host if request.client else 'unknown'}: "
|
||||||
f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}")
|
f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}")
|
||||||
@@ -137,10 +137,6 @@ class SecurityMiddleware(BaseHTTPMiddleware):
|
|||||||
# Determine status code based on threat type
|
# Determine status code based on threat type
|
||||||
status_code = 403 # Forbidden by default
|
status_code = 403 # Forbidden by default
|
||||||
|
|
||||||
# Rate limiting gets 429
|
|
||||||
if analysis.rate_limit_exceeded:
|
|
||||||
status_code = 429
|
|
||||||
|
|
||||||
# Critical threats get 403
|
# Critical threats get 403
|
||||||
for threat in analysis.threats:
|
for threat in analysis.threats:
|
||||||
if threat.threat_type in ["command_injection", "sql_injection"]:
|
if threat.threat_type in ["command_injection", "sql_injection"]:
|
||||||
@@ -156,23 +152,11 @@ class SecurityMiddleware(BaseHTTPMiddleware):
|
|||||||
"recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations
|
"recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add rate limiting info if applicable
|
|
||||||
if analysis.rate_limit_exceeded:
|
|
||||||
response_data["error"] = "Rate Limit Exceeded"
|
|
||||||
response_data["message"] = f"Rate limit exceeded for {analysis.auth_level.value} user"
|
|
||||||
response_data["retry_after"] = "60" # Suggest retry after 60 seconds
|
|
||||||
|
|
||||||
response = JSONResponse(
|
response = JSONResponse(
|
||||||
content=response_data,
|
content=response_data,
|
||||||
status_code=status_code
|
status_code=status_code
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add rate limiting headers
|
|
||||||
if analysis.rate_limit_exceeded:
|
|
||||||
response.headers["Retry-After"] = "60"
|
|
||||||
response.headers["X-RateLimit-Limit"] = "See API documentation"
|
|
||||||
response.headers["X-RateLimit-Reset"] = str(int(time.time() + 60))
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _add_security_headers(self, response: Response) -> Response:
|
def _add_security_headers(self, response: Response) -> Response:
|
||||||
|
|||||||
201
backend/app/services/enhanced_embedding_service.py
Normal file
201
backend/app/services/enhanced_embedding_service.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
# Enhanced Embedding Service with Rate Limiting Handling
|
||||||
|
"""
|
||||||
|
Enhanced embedding service with robust rate limiting and retry logic
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from .embedding_service import EmbeddingService
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EnhancedEmbeddingService(EmbeddingService):
|
||||||
|
"""Enhanced embedding service with rate limiting handling"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
|
||||||
|
super().__init__(model_name)
|
||||||
|
self.rate_limit_tracker = {
|
||||||
|
'requests_count': 0,
|
||||||
|
'window_start': time.time(),
|
||||||
|
'window_size': 60, # 1 minute window
|
||||||
|
'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 60)), # Configurable
|
||||||
|
'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff
|
||||||
|
'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 0.5)),
|
||||||
|
'last_rate_limit_error': None
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]:
|
||||||
|
"""
|
||||||
|
Get embeddings with rate limiting and retry logic
|
||||||
|
"""
|
||||||
|
if max_retries is None:
|
||||||
|
max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))
|
||||||
|
|
||||||
|
batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 5))
|
||||||
|
|
||||||
|
if not self.initialized:
|
||||||
|
logger.warning("Embedding service not initialized, using fallback")
|
||||||
|
return self._generate_fallback_embeddings(texts), False
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
success = True
|
||||||
|
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch = texts[i:i+batch_size]
|
||||||
|
batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries)
|
||||||
|
embeddings.extend(batch_embeddings)
|
||||||
|
success = success and batch_success
|
||||||
|
|
||||||
|
# Add delay between batches to avoid rate limiting
|
||||||
|
if i + batch_size < len(texts):
|
||||||
|
delay = self.rate_limit_tracker['delay_between_batches']
|
||||||
|
await asyncio.sleep(delay) # Configurable delay between batches
|
||||||
|
|
||||||
|
return embeddings, success
|
||||||
|
|
||||||
|
async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]:
|
||||||
|
"""Get embeddings for a batch with retry logic"""
|
||||||
|
last_error = None
|
||||||
|
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
# Check rate limit before making request
|
||||||
|
if self._is_rate_limited():
|
||||||
|
delay = self._get_rate_limit_delay()
|
||||||
|
logger.warning(f"Rate limit detected, waiting {delay} seconds")
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Make the request
|
||||||
|
embeddings = await self._get_embeddings_batch_impl(texts)
|
||||||
|
|
||||||
|
# Update rate limit tracker on success
|
||||||
|
self._update_rate_limit_tracker(success=True)
|
||||||
|
|
||||||
|
return embeddings, True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
error_msg = str(e).lower()
|
||||||
|
|
||||||
|
# Check if it's a rate limit error
|
||||||
|
if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']):
|
||||||
|
logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}")
|
||||||
|
self._update_rate_limit_tracker(success=False)
|
||||||
|
|
||||||
|
if attempt < max_retries:
|
||||||
|
delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
|
||||||
|
logger.info(f"Retrying in {delay} seconds...")
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.error(f"Max retries exceeded for rate limit, using fallback embeddings")
|
||||||
|
return self._generate_fallback_embeddings(texts), False
|
||||||
|
else:
|
||||||
|
# Non-rate-limit error
|
||||||
|
logger.error(f"Error generating embeddings: {e}")
|
||||||
|
if attempt < max_retries:
|
||||||
|
delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
else:
|
||||||
|
logger.error("Max retries exceeded, using fallback embeddings")
|
||||||
|
return self._generate_fallback_embeddings(texts), False
|
||||||
|
|
||||||
|
# If we get here, all retries failed
|
||||||
|
logger.error(f"All retries failed, last error: {last_error}")
|
||||||
|
return self._generate_fallback_embeddings(texts), False
|
||||||
|
|
||||||
|
async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Implementation of getting embeddings for a batch"""
|
||||||
|
from app.services.llm.service import llm_service
|
||||||
|
from app.services.llm.models import EmbeddingRequest
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
# Truncate text if needed
|
||||||
|
max_chars = 1600
|
||||||
|
truncated_text = text[:max_chars] if len(text) > max_chars else text
|
||||||
|
|
||||||
|
llm_request = EmbeddingRequest(
|
||||||
|
model=self.model_name,
|
||||||
|
input=truncated_text,
|
||||||
|
user_id="rag_system",
|
||||||
|
api_key_id=0
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await llm_service.create_embedding(llm_request)
|
||||||
|
|
||||||
|
if response.data and len(response.data) > 0:
|
||||||
|
embedding = response.data[0].embedding
|
||||||
|
if embedding:
|
||||||
|
embeddings.append(embedding)
|
||||||
|
if not hasattr(self, '_dimension_confirmed'):
|
||||||
|
self.dimension = len(embedding)
|
||||||
|
self._dimension_confirmed = True
|
||||||
|
else:
|
||||||
|
raise ValueError("Empty embedding in response")
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid response structure")
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def _is_rate_limited(self) -> bool:
|
||||||
|
"""Check if we're currently rate limited"""
|
||||||
|
now = time.time()
|
||||||
|
window_start = self.rate_limit_tracker['window_start']
|
||||||
|
|
||||||
|
# Reset window if it's expired
|
||||||
|
if now - window_start > self.rate_limit_tracker['window_size']:
|
||||||
|
self.rate_limit_tracker['requests_count'] = 0
|
||||||
|
self.rate_limit_tracker['window_start'] = now
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if we've exceeded the limit
|
||||||
|
return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute']
|
||||||
|
|
||||||
|
def _get_rate_limit_delay(self) -> float:
|
||||||
|
"""Get delay to wait for rate limit reset"""
|
||||||
|
now = time.time()
|
||||||
|
window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size']
|
||||||
|
return max(0, window_end - now)
|
||||||
|
|
||||||
|
def _update_rate_limit_tracker(self, success: bool):
|
||||||
|
"""Update the rate limit tracker"""
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# Reset window if it's expired
|
||||||
|
if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']:
|
||||||
|
self.rate_limit_tracker['requests_count'] = 0
|
||||||
|
self.rate_limit_tracker['window_start'] = now
|
||||||
|
|
||||||
|
# Increment counter on successful requests
|
||||||
|
if success:
|
||||||
|
self.rate_limit_tracker['requests_count'] += 1
|
||||||
|
|
||||||
|
async def get_embedding_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get embedding service statistics including rate limiting info"""
|
||||||
|
base_stats = await self.get_stats()
|
||||||
|
|
||||||
|
return {
|
||||||
|
**base_stats,
|
||||||
|
"rate_limit_info": {
|
||||||
|
"requests_in_current_window": self.rate_limit_tracker['requests_count'],
|
||||||
|
"max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'],
|
||||||
|
"window_reset_in_seconds": max(0,
|
||||||
|
self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time()
|
||||||
|
),
|
||||||
|
"last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Global enhanced embedding service instance
|
||||||
|
enhanced_embedding_service = EnhancedEmbeddingService()
|
||||||
@@ -66,6 +66,15 @@ class LLMServiceConfig(BaseModel):
|
|||||||
# Provider configurations
|
# Provider configurations
|
||||||
providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")
|
providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")
|
||||||
|
|
||||||
|
# Token rate limiting (organization-wide)
|
||||||
|
token_limits_per_minute: Dict[str, int] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"prompt_tokens": 20000, # PrivateMode Standard tier
|
||||||
|
"completion_tokens": 10000 # PrivateMode Standard tier
|
||||||
|
},
|
||||||
|
description="Token rate limits per minute (organization-wide)"
|
||||||
|
)
|
||||||
|
|
||||||
# Model routing (model_name -> provider_name)
|
# Model routing (model_name -> provider_name)
|
||||||
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
|
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
|
||||||
|
|
||||||
@@ -91,8 +100,8 @@ def create_default_config() -> LLMServiceConfig:
|
|||||||
supported_models=[], # Will be populated dynamically from proxy
|
supported_models=[], # Will be populated dynamically from proxy
|
||||||
capabilities=["chat", "embeddings", "tee"],
|
capabilities=["chat", "embeddings", "tee"],
|
||||||
priority=1,
|
priority=1,
|
||||||
max_requests_per_minute=100,
|
max_requests_per_minute=20, # PrivateMode Standard tier limit: 20 req/min
|
||||||
max_requests_per_hour=2000,
|
max_requests_per_hour=1200, # 20 req/min * 60 min
|
||||||
supports_streaming=True,
|
supports_streaming=True,
|
||||||
supports_function_calling=True,
|
supports_function_calling=True,
|
||||||
max_context_window=128000,
|
max_context_window=128000,
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ import tiktoken
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging import log_module_event
|
from app.core.logging import log_module_event
|
||||||
from app.services.base_module import BaseModule, Permission
|
from app.services.base_module import BaseModule, Permission
|
||||||
|
from app.services.enhanced_embedding_service import enhanced_embedding_service
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -1125,8 +1126,16 @@ class RAGModule(BaseModule):
|
|||||||
# Chunk the document
|
# Chunk the document
|
||||||
chunks = self._chunk_text(content)
|
chunks = self._chunk_text(content)
|
||||||
|
|
||||||
# Generate embeddings for all chunks in batch (more efficient)
|
# Generate embeddings with enhanced rate limiting handling
|
||||||
embeddings = await self._generate_embeddings(chunks)
|
embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
|
||||||
|
|
||||||
|
# Log if fallback embeddings were used
|
||||||
|
if not success:
|
||||||
|
logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded")
|
||||||
|
log_module_event("rag", "fallback_embeddings_used", {
|
||||||
|
"document_id": doc_id,
|
||||||
|
"content_preview": content[:100] + "..." if len(content) > 100 else content
|
||||||
|
})
|
||||||
|
|
||||||
# Create document points
|
# Create document points
|
||||||
points = []
|
points = []
|
||||||
@@ -1188,8 +1197,16 @@ class RAGModule(BaseModule):
|
|||||||
# Chunk the document
|
# Chunk the document
|
||||||
chunks = self._chunk_text(processed_doc.content)
|
chunks = self._chunk_text(processed_doc.content)
|
||||||
|
|
||||||
# Generate embeddings for all chunks in batch (more efficient)
|
# Generate embeddings with enhanced rate limiting handling
|
||||||
embeddings = await self._generate_embeddings(chunks)
|
embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
|
||||||
|
|
||||||
|
# Log if fallback embeddings were used
|
||||||
|
if not success:
|
||||||
|
logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded")
|
||||||
|
log_module_event("rag", "fallback_embeddings_used", {
|
||||||
|
"document_id": processed_doc.id,
|
||||||
|
"filename": processed_doc.original_filename
|
||||||
|
})
|
||||||
|
|
||||||
# Create document points with enhanced metadata
|
# Create document points with enhanced metadata
|
||||||
points = []
|
points = []
|
||||||
|
|||||||
Reference in New Issue
Block a user