Rate limiting and RAG

2025-09-21 06:49:55 +02:00
parent 0c20de4ca1
commit f58a76ac59
7 changed files with 410 additions and 130 deletions

View File

@@ -82,18 +82,25 @@ class Settings(BaseSettings):
     # Rate Limiting Configuration
     API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"
-    # Authenticated users (JWT token)
-    API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "300"))
-    API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "5000"))
+    # PrivateMode Standard tier limits (organization-level, not per user)
+    # These are shared across all API keys and users in the organization
+    PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
+    PRIVATEMODE_REQUESTS_PER_HOUR: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200"))
+    PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000"))
+    PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000"))
+    # Per-user limits (additional protection on top of organization limits)
+    API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20"))  # Match PrivateMode
+    API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200"))
     # API key users (programmatic access)
-    API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "1000"))
-    API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "20000"))
+    API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20"))  # Match PrivateMode
+    API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200"))
     # Premium/Enterprise API keys
-    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "5000"))
-    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "100000"))
+    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20"))  # Match PrivateMode
+    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
     # Security Thresholds
     API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8"))  # Block requests above this risk score
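
All of the limits above are read through os.getenv, so a deployment can retune them without code changes. A minimal sketch of an override, assuming the app.core.config module path used elsewhere in this commit; the values are illustrative:

import os

# Illustrative values only - must be set before the Settings module is imported.
os.environ["PRIVATEMODE_REQUESTS_PER_MINUTE"] = "10"    # tighter org-wide cap
os.environ["API_RATE_LIMIT_API_KEY_PER_MINUTE"] = "10"  # keep the per-key cap in sync

from app.core.config import settings  # module path assumed from later imports

assert settings.PRIVATEMODE_REQUESTS_PER_MINUTE == 10
assert settings.API_RATE_LIMIT_API_KEY_PER_MINUTE == 10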

View File

@@ -139,6 +139,10 @@ setup_analytics_middleware(app)
 from app.middleware.security import setup_security_middleware
 setup_security_middleware(app, enabled=settings.API_SECURITY_ENABLED)
+
+# Add rate limiting middleware only for specific endpoints
+from app.middleware.rate_limiting import RateLimitMiddleware
+app.add_middleware(RateLimitMiddleware)

 # Exception handlers
 @app.exception_handler(CustomHTTPException)
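
One ordering caveat worth keeping in mind: Starlette applies add_middleware in reverse registration order, so RateLimitMiddleware, added last here, runs first on each request, before the security and analytics middleware. A self-contained sketch demonstrating that behavior (toy middleware names, not from this codebase):

from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware

app = FastAPI()

class First(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        print("First: before")
        response = await call_next(request)
        print("First: after")
        return response

class Second(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        print("Second: before")
        response = await call_next(request)
        print("Second: after")
        return response

app.add_middleware(First)
app.add_middleware(Second)  # added last -> outermost -> runs first
# A request prints: Second: before, First: before, First: after, Second: after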

View File

@@ -7,6 +7,7 @@ import redis
 from typing import Dict, Optional
 from fastapi import Request, HTTPException, status
 from fastapi.responses import JSONResponse
+from starlette.middleware.base import BaseHTTPMiddleware
 import asyncio
 from datetime import datetime, timedelta
@@ -155,96 +156,153 @@ class RateLimiter:
 rate_limiter = RateLimiter()

-async def rate_limit_middleware(request: Request, call_next):
-    """
-    Rate limiting middleware for FastAPI
-    """
-    # Skip rate limiting for health checks and static files
-    if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
-        response = await call_next(request)
-        return response
-    # Get client IP
-    client_ip = request.client.host
-    forwarded_for = request.headers.get("X-Forwarded-For")
-    if forwarded_for:
-        client_ip = forwarded_for.split(",")[0].strip()
-    # Check for API key in headers
-    api_key = None
-    auth_header = request.headers.get("Authorization")
-    if auth_header and auth_header.startswith("Bearer "):
-        api_key = auth_header[7:]
-    elif request.headers.get("X-API-Key"):
-        api_key = request.headers.get("X-API-Key")
-    # Determine rate limiting strategy
-    if api_key:
-        # API key-based rate limiting
-        rate_limit_key = f"api_key:{api_key}"
-        # Get API key limits from database (simplified - would implement proper lookup)
-        limit_per_minute = 100  # Default limit
-        limit_per_hour = 1000  # Default limit
-        # Check per-minute limit
-        is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_minute, 60, "minute"
-        )
-        # Check per-hour limit
-        is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_hour, 3600, "hour"
-        )
-        is_allowed = is_allowed_minute and is_allowed_hour
-        headers = headers_minute  # Use minute headers for response
-    else:
-        # IP-based rate limiting for unauthenticated requests
-        rate_limit_key = f"ip:{client_ip}"
-        # More restrictive limits for unauthenticated requests
-        limit_per_minute = 20
-        limit_per_hour = 100
-        # Check per-minute limit
-        is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_minute, 60, "minute"
-        )
-        # Check per-hour limit
-        is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_hour, 3600, "hour"
-        )
-        is_allowed = is_allowed_minute and is_allowed_hour
-        headers = headers_minute  # Use minute headers for response
-    # If rate limit exceeded, return 429
-    if not is_allowed:
-        return JSONResponse(
-            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
-            content={
-                "error": "RATE_LIMIT_EXCEEDED",
-                "message": "Rate limit exceeded. Please try again later.",
-                "details": {
-                    "limit": headers["X-RateLimit-Limit"],
-                    "reset_time": headers["X-RateLimit-Reset"]
-                }
-            },
-            headers={k: str(v) for k, v in headers.items()}
-        )
-    # Continue with request
-    response = await call_next(request)
-    # Add rate limit headers to response
-    for key, value in headers.items():
-        response.headers[key] = str(value)
-    return response
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """Rate limiting middleware for FastAPI"""
+
+    def __init__(self, app):
+        super().__init__(app)
+        self.rate_limiter = RateLimiter()
+        logger.info("RateLimitMiddleware initialized")
+
+    async def dispatch(self, request: Request, call_next):
+        """Process request through rate limiting"""
+        # Skip rate limiting if disabled in settings
+        if not settings.API_RATE_LIMITING_ENABLED:
+            response = await call_next(request)
+            return response
+        # Skip rate limiting for all internal API endpoints (platform operations)
+        if request.url.path.startswith("/api-internal/v1/"):
+            response = await call_next(request)
+            return response
+        # Only apply rate limiting to privatemode.ai proxy endpoints
+        # (OpenAI-compatible API and LLM service); skip for all other endpoints
+        if not (request.url.path.startswith("/api/v1/chat/completions") or
+                request.url.path.startswith("/api/v1/embeddings") or
+                request.url.path.startswith("/api/v1/models") or
+                request.url.path.startswith("/api/v1/llm/")):
+            response = await call_next(request)
+            return response
+        # Skip rate limiting for health checks and static files
+        if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
+            response = await call_next(request)
+            return response
+        # Get client IP
+        client_ip = request.client.host
+        forwarded_for = request.headers.get("X-Forwarded-For")
+        if forwarded_for:
+            client_ip = forwarded_for.split(",")[0].strip()
+        # Check for API key in headers
+        api_key = None
+        auth_header = request.headers.get("Authorization")
+        if auth_header and auth_header.startswith("Bearer "):
+            api_key = auth_header[7:]
+        elif request.headers.get("X-API-Key"):
+            api_key = request.headers.get("X-API-Key")
+        # Determine rate limiting strategy
+        headers = {}
+        is_allowed = True
+        if api_key:
+            # API key-based rate limiting
+            api_key_key = f"api_key:{api_key}"
+            # First check organization-wide limits (PrivateMode limits are org-wide)
+            org_key = "organization:privatemode"
+            # Check organization per-minute limit
+            org_allowed_minute, org_headers_minute = await self.rate_limiter.check_rate_limit(
+                org_key, settings.PRIVATEMODE_REQUESTS_PER_MINUTE, 60, "minute"
+            )
+            # Check organization per-hour limit
+            org_allowed_hour, org_headers_hour = await self.rate_limiter.check_rate_limit(
+                org_key, settings.PRIVATEMODE_REQUESTS_PER_HOUR, 3600, "hour"
+            )
+            # If organization limits are exceeded, return 429
+            if not (org_allowed_minute and org_allowed_hour):
+                logger.warning(f"Organization rate limit exceeded for {org_key}")
+                return JSONResponse(
+                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+                    content={"detail": "Organization rate limit exceeded"},
+                    headers=org_headers_minute
+                )
+            # Then check per-API key limits
+            limit_per_minute = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE
+            limit_per_hour = settings.API_RATE_LIMIT_API_KEY_PER_HOUR
+            # Check per-minute limit
+            is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
+                api_key_key, limit_per_minute, 60, "minute"
+            )
+            # Check per-hour limit
+            is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
+                api_key_key, limit_per_hour, 3600, "hour"
+            )
+            is_allowed = is_allowed_minute and is_allowed_hour
+            headers = headers_minute  # Use minute headers for response
+        else:
+            # IP-based rate limiting for unauthenticated requests
+            rate_limit_key = f"ip:{client_ip}"
+            # More restrictive limits for unauthenticated requests
+            limit_per_minute = 20  # Hardcoded for unauthenticated users
+            limit_per_hour = 100
+            # Check per-minute limit
+            is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
+                rate_limit_key, limit_per_minute, 60, "minute"
+            )
+            # Check per-hour limit
+            is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
+                rate_limit_key, limit_per_hour, 3600, "hour"
+            )
+            is_allowed = is_allowed_minute and is_allowed_hour
+            headers = headers_minute  # Use minute headers for response
+        # If rate limit exceeded, return 429
+        if not is_allowed:
+            return JSONResponse(
+                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+                content={
+                    "error": "RATE_LIMIT_EXCEEDED",
+                    "message": "Rate limit exceeded. Please try again later.",
+                    "details": {
+                        "limit": headers["X-RateLimit-Limit"],
+                        "reset_time": headers["X-RateLimit-Reset"]
+                    }
+                },
+                headers={k: str(v) for k, v in headers.items()}
+            )
+        # Continue with request
+        response = await call_next(request)
+        # Add rate limit headers to response
+        for key, value in headers.items():
+            response.headers[key] = str(value)
+        return response
+
+# Keep the old function for backward compatibility
+async def rate_limit_middleware(request: Request, call_next):
+    """Legacy function - use RateLimitMiddleware class instead"""
+    middleware = RateLimitMiddleware(None)
+    return await middleware.dispatch(request, call_next)

 class RateLimitExceeded(HTTPException):
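
The RateLimiter.check_rate_limit method used throughout dispatch is defined earlier in this file and not shown in the diff. A minimal sketch of the contract the calls above imply - a fixed-window counter in Redis that returns an allow flag plus X-RateLimit-* headers; the real class may well use a different algorithm or key layout:

import time
import redis.asyncio as redis  # async client; the file itself imports the sync `redis`

class FixedWindowLimiter:
    """Illustrative stand-in for RateLimiter, not the actual implementation."""

    def __init__(self, url: str = "redis://localhost:6379/0"):
        self.redis = redis.from_url(url)

    async def check_rate_limit(self, key: str, limit: int,
                               window_seconds: int, window_name: str):
        now = int(time.time())
        window_start = now - (now % window_seconds)
        redis_key = f"ratelimit:{key}:{window_name}:{window_start}"
        count = await self.redis.incr(redis_key)
        if count == 1:
            # First hit in this window: expire the counter with the window
            await self.redis.expire(redis_key, window_seconds)
        headers = {
            "X-RateLimit-Limit": limit,
            "X-RateLimit-Remaining": max(0, limit - count),
            "X-RateLimit-Reset": window_start + window_seconds,
        }
        return count <= limit, headers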

View File

@@ -61,12 +61,12 @@ class SecurityMiddleware(BaseHTTPMiddleware):
         if analysis.is_threat and (analysis.should_block or analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD):
             await self._log_security_event(request, analysis)

-        # Check if request should be blocked
-        if analysis.should_block:
+        # Check if request should be blocked (excluding rate limiting)
+        if analysis.should_block and not analysis.rate_limit_exceeded:
             threat_detection_service.stats['threats_blocked'] += 1
             logger.warning(f"Blocked request from {request.client.host if request.client else 'unknown'}: "
                            f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}")
             # Return security block response
             return self._create_block_response(analysis)
@@ -136,17 +136,13 @@ class SecurityMiddleware(BaseHTTPMiddleware):
"""Create response for blocked requests""" """Create response for blocked requests"""
# Determine status code based on threat type # Determine status code based on threat type
status_code = 403 # Forbidden by default status_code = 403 # Forbidden by default
# Rate limiting gets 429
if analysis.rate_limit_exceeded:
status_code = 429
# Critical threats get 403 # Critical threats get 403
for threat in analysis.threats: for threat in analysis.threats:
if threat.threat_type in ["command_injection", "sql_injection"]: if threat.threat_type in ["command_injection", "sql_injection"]:
status_code = 403 status_code = 403
break break
response_data = { response_data = {
"error": "Security Policy Violation", "error": "Security Policy Violation",
"message": "Request blocked due to security policy violation", "message": "Request blocked due to security policy violation",
@@ -155,24 +151,12 @@ class SecurityMiddleware(BaseHTTPMiddleware):
"threat_count": len(analysis.threats), "threat_count": len(analysis.threats),
"recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations "recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations
} }
# Add rate limiting info if applicable
if analysis.rate_limit_exceeded:
response_data["error"] = "Rate Limit Exceeded"
response_data["message"] = f"Rate limit exceeded for {analysis.auth_level.value} user"
response_data["retry_after"] = "60" # Suggest retry after 60 seconds
response = JSONResponse( response = JSONResponse(
content=response_data, content=response_data,
status_code=status_code status_code=status_code
) )
# Add rate limiting headers
if analysis.rate_limit_exceeded:
response.headers["Retry-After"] = "60"
response.headers["X-RateLimit-Limit"] = "See API documentation"
response.headers["X-RateLimit-Reset"] = str(int(time.time() + 60))
return response return response
def _add_security_headers(self, response: Response) -> Response: def _add_security_headers(self, response: Response) -> Response:
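
With the rate limiting branches removed here, 429 responses come exclusively from RateLimitMiddleware, which advertises X-RateLimit-* headers instead of Retry-After. A client-side retry sketch under that assumption (httpx is used for illustration and is not part of this diff):

import asyncio
import time

import httpx

async def post_with_retry(url: str, payload: dict, retries: int = 3) -> httpx.Response:
    async with httpx.AsyncClient() as client:
        response = await client.post(url, json=payload)
        for _ in range(retries):
            if response.status_code != 429:
                break
            # Sleep until the advertised window reset (fall back to 60s)
            reset = float(response.headers.get("X-RateLimit-Reset", time.time() + 60))
            await asyncio.sleep(max(1.0, reset - time.time()))
            response = await client.post(url, json=payload)
        return response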

View File

@@ -0,0 +1,201 @@
# Enhanced Embedding Service with Rate Limiting Handling
"""
Enhanced embedding service with robust rate limiting and retry logic
"""
import asyncio
import logging
import time
from typing import List, Dict, Any, Optional
import numpy as np
from datetime import datetime, timedelta

from .embedding_service import EmbeddingService
from app.core.config import settings

logger = logging.getLogger(__name__)


class EnhancedEmbeddingService(EmbeddingService):
    """Enhanced embedding service with rate limiting handling"""

    def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
        super().__init__(model_name)
        self.rate_limit_tracker = {
            'requests_count': 0,
            'window_start': time.time(),
            'window_size': 60,  # 1 minute window
            'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 60)),  # Configurable
            'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')],  # Exponential backoff
            'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 0.5)),
            'last_rate_limit_error': None
        }

    async def get_embeddings_with_retry(self, texts: List[str], max_retries: Optional[int] = None) -> tuple[List[List[float]], bool]:
        """
        Get embeddings with rate limiting and retry logic
        """
        if max_retries is None:
            max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))
        batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 5))

        if not self.initialized:
            logger.warning("Embedding service not initialized, using fallback")
            return self._generate_fallback_embeddings(texts), False

        embeddings = []
        success = True
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries)
            embeddings.extend(batch_embeddings)
            success = success and batch_success
            # Add delay between batches to avoid rate limiting
            if i + batch_size < len(texts):
                delay = self.rate_limit_tracker['delay_between_batches']
                await asyncio.sleep(delay)  # Configurable delay between batches
        return embeddings, success

    async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]:
        """Get embeddings for a batch with retry logic"""
        last_error = None
        for attempt in range(max_retries + 1):
            try:
                # Check rate limit before making request
                if self._is_rate_limited():
                    delay = self._get_rate_limit_delay()
                    logger.warning(f"Rate limit detected, waiting {delay} seconds")
                    await asyncio.sleep(delay)
                    continue
                # Make the request
                embeddings = await self._get_embeddings_batch_impl(texts)
                # Update rate limit tracker on success
                self._update_rate_limit_tracker(success=True)
                return embeddings, True
            except Exception as e:
                last_error = e
                error_msg = str(e).lower()
                # Check if it's a rate limit error
                if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']):
                    logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}")
                    self._update_rate_limit_tracker(success=False)
                    if attempt < max_retries:
                        delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
                        logger.info(f"Retrying in {delay} seconds...")
                        await asyncio.sleep(delay)
                        continue
                    else:
                        logger.error("Max retries exceeded for rate limit, using fallback embeddings")
                        return self._generate_fallback_embeddings(texts), False
                else:
                    # Non-rate-limit error
                    logger.error(f"Error generating embeddings: {e}")
                    if attempt < max_retries:
                        delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
                        await asyncio.sleep(delay)
                    else:
                        logger.error("Max retries exceeded, using fallback embeddings")
                        return self._generate_fallback_embeddings(texts), False
        # If we get here, all retries failed
        logger.error(f"All retries failed, last error: {last_error}")
        return self._generate_fallback_embeddings(texts), False

    async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]:
        """Implementation of getting embeddings for a batch"""
        from app.services.llm.service import llm_service
        from app.services.llm.models import EmbeddingRequest

        embeddings = []
        for text in texts:
            # Truncate text if needed
            max_chars = 1600
            truncated_text = text[:max_chars] if len(text) > max_chars else text
            llm_request = EmbeddingRequest(
                model=self.model_name,
                input=truncated_text,
                user_id="rag_system",
                api_key_id=0
            )
            response = await llm_service.create_embedding(llm_request)
            if response.data and len(response.data) > 0:
                embedding = response.data[0].embedding
                if embedding:
                    embeddings.append(embedding)
                    if not hasattr(self, '_dimension_confirmed'):
                        self.dimension = len(embedding)
                        self._dimension_confirmed = True
                else:
                    raise ValueError("Empty embedding in response")
            else:
                raise ValueError("Invalid response structure")
        return embeddings

    def _is_rate_limited(self) -> bool:
        """Check if we're currently rate limited"""
        now = time.time()
        window_start = self.rate_limit_tracker['window_start']
        # Reset window if it's expired
        if now - window_start > self.rate_limit_tracker['window_size']:
            self.rate_limit_tracker['requests_count'] = 0
            self.rate_limit_tracker['window_start'] = now
            return False
        # Check if we've exceeded the limit
        return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute']

    def _get_rate_limit_delay(self) -> float:
        """Get delay to wait for rate limit reset"""
        now = time.time()
        window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size']
        return max(0, window_end - now)

    def _update_rate_limit_tracker(self, success: bool):
        """Update the rate limit tracker"""
        now = time.time()
        # Reset window if it's expired
        if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']:
            self.rate_limit_tracker['requests_count'] = 0
            self.rate_limit_tracker['window_start'] = now
        # Increment counter on successful requests
        if success:
            self.rate_limit_tracker['requests_count'] += 1
        else:
            # Record the most recent rate limit error for stats reporting
            self.rate_limit_tracker['last_rate_limit_error'] = datetime.now().isoformat()

    async def get_embedding_stats(self) -> Dict[str, Any]:
        """Get embedding service statistics including rate limiting info"""
        base_stats = await self.get_stats()
        return {
            **base_stats,
            "rate_limit_info": {
                "requests_in_current_window": self.rate_limit_tracker['requests_count'],
                "max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'],
                "window_reset_in_seconds": max(0,
                    self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time()
                ),
                "last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error']
            }
        }


# Global enhanced embedding service instance
enhanced_embedding_service = EnhancedEmbeddingService()
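
A hypothetical caller, showing the two-value contract of get_embeddings_with_retry: a vector list always comes back (real or fallback), and the boolean flags degraded quality:

import logging

from app.services.enhanced_embedding_service import enhanced_embedding_service

logger = logging.getLogger(__name__)

async def embed_for_ingestion(chunks: list[str]) -> list[list[float]]:
    embeddings, ok = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
    if not ok:
        # At least one batch fell back to locally generated vectors;
        # search quality for these chunks may be degraded.
        logger.warning("Fallback embeddings used for %d chunks", len(chunks))
    return embeddings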

View File

@@ -65,7 +65,16 @@ class LLMServiceConfig(BaseModel):
     # Provider configurations
     providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")

+    # Token rate limiting (organization-wide)
+    token_limits_per_minute: Dict[str, int] = Field(
+        default_factory=lambda: {
+            "prompt_tokens": 20000,      # PrivateMode Standard tier
+            "completion_tokens": 10000   # PrivateMode Standard tier
+        },
+        description="Token rate limits per minute (organization-wide)"
+    )

     # Model routing (model_name -> provider_name)
     model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
@@ -91,8 +100,8 @@ def create_default_config() -> LLMServiceConfig:
         supported_models=[],  # Will be populated dynamically from proxy
         capabilities=["chat", "embeddings", "tee"],
         priority=1,
-        max_requests_per_minute=100,
-        max_requests_per_hour=2000,
+        max_requests_per_minute=20,    # PrivateMode Standard tier limit: 20 req/min
+        max_requests_per_hour=1200,    # 20 req/min * 60 min
         supports_streaming=True,
         supports_function_calling=True,
         max_context_window=128000,
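
The diff adds token_limits_per_minute to the config, but the enforcement point is outside this commit. A minimal in-process sketch of how a per-minute token window could consume against those limits (illustrative only; the actual service may enforce this elsewhere):

import time
from collections import defaultdict

class TokenWindow:
    """Illustrative per-minute token budget; not the service's actual enforcement."""

    def __init__(self, limits: dict[str, int]):
        self.limits = limits  # e.g. {"prompt_tokens": 20000, "completion_tokens": 10000}
        self.used = defaultdict(int)
        self.window_start = time.monotonic()

    def try_consume(self, prompt_tokens: int, completion_tokens: int) -> bool:
        now = time.monotonic()
        if now - self.window_start >= 60:
            # New minute window: reset usage counters
            self.used.clear()
            self.window_start = now
        if (self.used["prompt_tokens"] + prompt_tokens > self.limits["prompt_tokens"]
                or self.used["completion_tokens"] + completion_tokens > self.limits["completion_tokens"]):
            return False  # would exceed the organization-wide budget
        self.used["prompt_tokens"] += prompt_tokens
        self.used["completion_tokens"] += completion_tokens
        return True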

View File

@@ -60,6 +60,7 @@ import tiktoken
 from app.core.config import settings
 from app.core.logging import log_module_event
 from app.services.base_module import BaseModule, Permission
+from app.services.enhanced_embedding_service import enhanced_embedding_service

 @dataclass
@@ -1125,9 +1126,17 @@ class RAGModule(BaseModule):
         # Chunk the document
         chunks = self._chunk_text(content)

-        # Generate embeddings for all chunks in batch (more efficient)
-        embeddings = await self._generate_embeddings(chunks)
+        # Generate embeddings with enhanced rate limiting handling
+        embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
+
+        # Log if fallback embeddings were used
+        if not success:
+            logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded")
+            log_module_event("rag", "fallback_embeddings_used", {
+                "document_id": doc_id,
+                "content_preview": content[:100] + "..." if len(content) > 100 else content
+            })

         # Create document points
         points = []
         for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
@@ -1188,9 +1197,17 @@ class RAGModule(BaseModule):
         # Chunk the document
         chunks = self._chunk_text(processed_doc.content)

-        # Generate embeddings for all chunks in batch (more efficient)
-        embeddings = await self._generate_embeddings(chunks)
+        # Generate embeddings with enhanced rate limiting handling
+        embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
+
+        # Log if fallback embeddings were used
+        if not success:
+            logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded")
+            log_module_event("rag", "fallback_embeddings_used", {
+                "document_id": processed_doc.id,
+                "filename": processed_doc.original_filename
+            })

         # Create document points with enhanced metadata
         points = []
         for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
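
_generate_fallback_embeddings is inherited from the base EmbeddingService and is not part of this diff. A plausible sketch of such a fallback - deterministic, text-seeded unit vectors so re-ingesting the same chunk yields the same vector; the real implementation, and the 1024-dimension assumption for multilingual-e5-large, may differ:

import hashlib

import numpy as np

def generate_fallback_embeddings(texts: list[str], dim: int = 1024) -> list[list[float]]:
    """Deterministic pseudo-embeddings: stable per text, but carry no semantics."""
    vectors = []
    for text in texts:
        seed = int.from_bytes(hashlib.sha256(text.encode("utf-8")).digest()[:8], "big")
        rng = np.random.default_rng(seed)
        vec = rng.standard_normal(dim)
        vec /= np.linalg.norm(vec)  # unit-normalize for cosine similarity search
        vectors.append(vec.tolist())
    return vectors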