From f58a76ac596439bf6f69963fa278fd7d47f2d288 Mon Sep 17 00:00:00 2001
From: Aljaz Ceru
Date: Sun, 21 Sep 2025 06:49:55 +0200
Subject: [PATCH] ratelimiting and rag

---
 backend/app/core/config.py                 |  27 +-
 backend/app/main.py                        |   4 +
 backend/app/middleware/rate_limiting.py    | 234 +++++++++++-------
 backend/app/middleware/security.py         |  30 +--
 .../services/enhanced_embedding_service.py | 201 +++++++++++++++
 backend/app/services/llm/config.py         |  15 +-
 backend/modules/rag/main.py                |  29 ++-
 7 files changed, 410 insertions(+), 130 deletions(-)
 create mode 100644 backend/app/services/enhanced_embedding_service.py

diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index c5cb8c3..f3ac614 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -82,18 +82,25 @@ class Settings(BaseSettings):
     # Rate Limiting Configuration
     API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"
-
-    # Authenticated users (JWT token)
-    API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "300"))
-    API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "5000"))
-
+
+    # PrivateMode Standard tier limits (organization-level, not per user)
+    # These are shared across all API keys and users in the organization
+    PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
+    PRIVATEMODE_REQUESTS_PER_HOUR: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200"))
+    PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000"))
+    PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000"))
+
+    # Per-user limits (additional protection on top of organization limits)
+    API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20"))  # Match PrivateMode
+    API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200"))
+
     # API key users (programmatic access)
-    API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "1000"))
-    API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "20000"))
-
+    API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20"))  # Match PrivateMode
+    API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200"))
+
     # Premium/Enterprise API keys
-    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "5000"))
-    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "100000"))
+    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20"))  # Match PrivateMode
+    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
 
     # Security Thresholds
     API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8"))  # Block requests above this risk score
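Since every limit above is read from the environment in a class-level default, overrides must be in place before `app.core.config` is imported. A minimal sketch of that constraint (assuming, as is typical, that a `settings` instance is created at module import time; the values are illustrative):

```python
# Hypothetical override of the PrivateMode org-wide limits. The variable names
# are the ones defined in the diff above; "40"/"2400" are example values.
import os

os.environ["PRIVATEMODE_REQUESTS_PER_MINUTE"] = "40"
os.environ["PRIVATEMODE_REQUESTS_PER_HOUR"] = "2400"

from app.core.config import settings  # import only after the env is set

assert settings.PRIVATEMODE_REQUESTS_PER_MINUTE == 40
```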
diff --git a/backend/app/main.py b/backend/app/main.py
index e0466e6..40d51a3 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -139,6 +139,10 @@ setup_analytics_middleware(app)
 from app.middleware.security import setup_security_middleware
 setup_security_middleware(app, enabled=settings.API_SECURITY_ENABLED)
 
+# Add rate limiting middleware only for specific endpoints
+from app.middleware.rate_limiting import RateLimitMiddleware
+app.add_middleware(RateLimitMiddleware)
+
 # Exception handlers
 @app.exception_handler(CustomHTTPException)
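One thing to keep in mind with `add_middleware`: Starlette middleware wraps the application, so the last middleware added is the first to see a request. Because `RateLimitMiddleware` is registered after the security middleware here, it runs before security checks do. A small self-contained sketch of that ordering (class names are illustrative):

```python
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware

class Inner(BaseHTTPMiddleware):  # added first -> wraps closer to the app, runs second
    async def dispatch(self, request, call_next):
        print("inner: before")
        response = await call_next(request)
        print("inner: after")
        return response

class Outer(BaseHTTPMiddleware):  # added last -> outermost, runs first
    async def dispatch(self, request, call_next):
        print("outer: before")
        response = await call_next(request)
        print("outer: after")
        return response

app = FastAPI()
app.add_middleware(Inner)
app.add_middleware(Outer)
# A request prints: outer: before, inner: before, inner: after, outer: after
```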
diff --git a/backend/app/middleware/rate_limiting.py b/backend/app/middleware/rate_limiting.py
index 611a67a..f6e1901 100644
--- a/backend/app/middleware/rate_limiting.py
+++ b/backend/app/middleware/rate_limiting.py
@@ -7,6 +7,7 @@ import redis
 from typing import Dict, Optional
 from fastapi import Request, HTTPException, status
 from fastapi.responses import JSONResponse
+from starlette.middleware.base import BaseHTTPMiddleware
 import asyncio
 from datetime import datetime, timedelta
 
@@ -155,96 +156,153 @@ class RateLimiter:
 
 rate_limiter = RateLimiter()
 
-async def rate_limit_middleware(request: Request, call_next):
-    """
-    Rate limiting middleware for FastAPI
-    """
-
-    # Skip rate limiting for health checks and static files
-    if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """Rate limiting middleware for FastAPI"""
+
+    def __init__(self, app):
+        super().__init__(app)
+        self.rate_limiter = RateLimiter()
+        logger.info("RateLimitMiddleware initialized")
+
+    async def dispatch(self, request: Request, call_next):
+        """Process request through rate limiting"""
+
+        # Skip rate limiting if disabled in settings
+        if not settings.API_RATE_LIMITING_ENABLED:
+            response = await call_next(request)
+            return response
+
+        # Skip rate limiting for all internal API endpoints (platform operations)
+        if request.url.path.startswith("/api-internal/v1/"):
+            response = await call_next(request)
+            return response
+
+        # Only apply rate limiting to privatemode.ai proxy endpoints (OpenAI-compatible API and LLM service)
+        # Skip for all other endpoints
+        if not (request.url.path.startswith("/api/v1/chat/completions") or
+                request.url.path.startswith("/api/v1/embeddings") or
+                request.url.path.startswith("/api/v1/models") or
+                request.url.path.startswith("/api/v1/llm/")):
+            response = await call_next(request)
+            return response
+
+        # Skip rate limiting for health checks and static files
+        if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
+            response = await call_next(request)
+            return response
+
+        # Get client IP
+        client_ip = request.client.host
+        forwarded_for = request.headers.get("X-Forwarded-For")
+        if forwarded_for:
+            client_ip = forwarded_for.split(",")[0].strip()
+
+        # Check for API key in headers
+        api_key = None
+        auth_header = request.headers.get("Authorization")
+        if auth_header and auth_header.startswith("Bearer "):
+            api_key = auth_header[7:]
+        elif request.headers.get("X-API-Key"):
+            api_key = request.headers.get("X-API-Key")
+
+        # Determine rate limiting strategy
+        headers = {}
+        is_allowed = True
+
+        if api_key:
+            # API key-based rate limiting
+            api_key_key = f"api_key:{api_key}"
+
+            # First check organization-wide limits (PrivateMode limits are org-wide)
+            org_key = "organization:privatemode"
+
+            # Check organization per-minute limit
+            org_allowed_minute, org_headers_minute = await self.rate_limiter.check_rate_limit(
+                org_key, settings.PRIVATEMODE_REQUESTS_PER_MINUTE, 60, "minute"
+            )
+
+            # Check organization per-hour limit
+            org_allowed_hour, org_headers_hour = await self.rate_limiter.check_rate_limit(
+                org_key, settings.PRIVATEMODE_REQUESTS_PER_HOUR, 3600, "hour"
+            )
+
+            # If organization limits are exceeded, return 429
+            if not (org_allowed_minute and org_allowed_hour):
+                logger.warning(f"Organization rate limit exceeded for {org_key}")
+                return JSONResponse(
+                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+                    content={"detail": "Organization rate limit exceeded"},
+                    headers={k: str(v) for k, v in org_headers_minute.items()}  # header values must be strings
+                )
+
+            # Then check per-API-key limits
+            # Note: a request rejected by the per-key check below still counts
+            # against the organization window incremented above
+            limit_per_minute = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE
+            limit_per_hour = settings.API_RATE_LIMIT_API_KEY_PER_HOUR
+
+            # Check per-minute limit
+            is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
+                api_key_key, limit_per_minute, 60, "minute"
+            )
+
+            # Check per-hour limit
+            is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
+                api_key_key, limit_per_hour, 3600, "hour"
+            )
+
+            is_allowed = is_allowed_minute and is_allowed_hour
+            headers = headers_minute  # Use minute headers for response
+
+        else:
+            # IP-based rate limiting for unauthenticated requests
+            rate_limit_key = f"ip:{client_ip}"
+
+            # More restrictive limits for unauthenticated requests
+            limit_per_minute = 20  # Hardcoded for unauthenticated users
+            limit_per_hour = 100
+
+            # Check per-minute limit
+            is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
+                rate_limit_key, limit_per_minute, 60, "minute"
+            )
+
+            # Check per-hour limit
+            is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
+                rate_limit_key, limit_per_hour, 3600, "hour"
+            )
+
+            is_allowed = is_allowed_minute and is_allowed_hour
+            headers = headers_minute  # Use minute headers for response
+
+        # If rate limit exceeded, return 429
+        if not is_allowed:
+            return JSONResponse(
+                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+                content={
+                    "error": "RATE_LIMIT_EXCEEDED",
+                    "message": "Rate limit exceeded. Please try again later.",
+                    "details": {
+                        "limit": headers["X-RateLimit-Limit"],
+                        "reset_time": headers["X-RateLimit-Reset"]
+                    }
+                },
+                headers={k: str(v) for k, v in headers.items()}
+            )
+
+        # Continue with request
         response = await call_next(request)
+
+        # Add rate limit headers to response
+        for key, value in headers.items():
+            response.headers[key] = str(value)
+
         return response
-
-    # Get client IP
-    client_ip = request.client.host
-    forwarded_for = request.headers.get("X-Forwarded-For")
-    if forwarded_for:
-        client_ip = forwarded_for.split(",")[0].strip()
-
-    # Check for API key in headers
-    api_key = None
-    auth_header = request.headers.get("Authorization")
-    if auth_header and auth_header.startswith("Bearer "):
-        api_key = auth_header[7:]
-    elif request.headers.get("X-API-Key"):
-        api_key = request.headers.get("X-API-Key")
-
-    # Determine rate limiting strategy
-    if api_key:
-        # API key-based rate limiting
-        rate_limit_key = f"api_key:{api_key}"
-
-        # Get API key limits from database (simplified - would implement proper lookup)
-        limit_per_minute = 100  # Default limit
-        limit_per_hour = 1000  # Default limit
-
-        # Check per-minute limit
-        is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_minute, 60, "minute"
-        )
-
-        # Check per-hour limit
-        is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_hour, 3600, "hour"
-        )
-
-        is_allowed = is_allowed_minute and is_allowed_hour
-        headers = headers_minute  # Use minute headers for response
-
-    else:
-        # IP-based rate limiting for unauthenticated requests
-        rate_limit_key = f"ip:{client_ip}"
-
-        # More restrictive limits for unauthenticated requests
-        limit_per_minute = 20
-        limit_per_hour = 100
-
-        # Check per-minute limit
-        is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_minute, 60, "minute"
-        )
-
-        # Check per-hour limit
-        is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
-            rate_limit_key, limit_per_hour, 3600, "hour"
-        )
-
-        is_allowed = is_allowed_minute and is_allowed_hour
-        headers = headers_minute  # Use minute headers for response
-
-    # If rate limit exceeded, return 429
-    if not is_allowed:
-        return JSONResponse(
-            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
-            content={
-                "error": "RATE_LIMIT_EXCEEDED",
-                "message": "Rate limit exceeded. Please try again later.",
-                "details": {
-                    "limit": headers["X-RateLimit-Limit"],
-                    "reset_time": headers["X-RateLimit-Reset"]
-                }
-            },
-            headers={k: str(v) for k, v in headers.items()}
-        )
-
-    # Continue with request
-    response = await call_next(request)
-
-    # Add rate limit headers to response
-    for key, value in headers.items():
-        response.headers[key] = str(value)
-
-    return response
+
+
+# Keep the old function for backward compatibility
+async def rate_limit_middleware(request: Request, call_next):
+    """Legacy function - use RateLimitMiddleware class instead"""
+    middleware = RateLimitMiddleware(None)  # dispatch() never touches self.app, so None is acceptable here
+    return await middleware.dispatch(request, call_next)
 
 
 class RateLimitExceeded(HTTPException):
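`RateLimiter.check_rate_limit()` itself is outside this diff; from the call sites it takes `(key, limit, window_seconds, window_name)` and returns `(allowed, headers)`. A plausible fixed-window Redis counter with that shape (an assumption - the real class may use a sliding window and different key names):

```python
import time
import redis.asyncio as redis

r = redis.Redis()

async def check_rate_limit(key: str, limit: int, window: int, window_name: str):
    # One counter per (key, window_name, window index); the TTL starts on first hit.
    bucket = f"rate:{key}:{window_name}:{int(time.time() // window)}"
    count = await r.incr(bucket)
    if count == 1:
        await r.expire(bucket, window)
    reset = (int(time.time() // window) + 1) * window  # epoch second the window rolls over
    headers = {
        "X-RateLimit-Limit": limit,
        "X-RateLimit-Remaining": max(0, limit - count),
        "X-RateLimit-Reset": reset,
    }
    return count <= limit, headers
```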
diff --git a/backend/app/middleware/security.py b/backend/app/middleware/security.py
index 6efc1f4..57d2ebe 100644
--- a/backend/app/middleware/security.py
+++ b/backend/app/middleware/security.py
@@ -61,12 +61,12 @@ class SecurityMiddleware(BaseHTTPMiddleware):
         if analysis.is_threat and (analysis.should_block or analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD):
             await self._log_security_event(request, analysis)
 
-        # Check if request should be blocked
-        if analysis.should_block:
+        # Check if request should be blocked (excluding rate limiting)
+        if analysis.should_block and not analysis.rate_limit_exceeded:
             threat_detection_service.stats['threats_blocked'] += 1
             logger.warning(f"Blocked request from {request.client.host if request.client else 'unknown'}: "
                            f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}")
-
+
             # Return security block response
             return self._create_block_response(analysis)
 
@@ -136,17 +136,13 @@ class SecurityMiddleware(BaseHTTPMiddleware):
         """Create response for blocked requests"""
         # Determine status code based on threat type
         status_code = 403  # Forbidden by default
-
-        # Rate limiting gets 429
-        if analysis.rate_limit_exceeded:
-            status_code = 429
-
+
         # Critical threats get 403
         for threat in analysis.threats:
             if threat.threat_type in ["command_injection", "sql_injection"]:
                 status_code = 403
                 break
-
+
         response_data = {
             "error": "Security Policy Violation",
             "message": "Request blocked due to security policy violation",
             "risk_score": round(analysis.risk_score, 3),
             "threat_count": len(analysis.threats),
             "recommendations": analysis.recommendations[:3]  # Limit to first 3 recommendations
         }
-
-        # Add rate limiting info if applicable
-        if analysis.rate_limit_exceeded:
-            response_data["error"] = "Rate Limit Exceeded"
-            response_data["message"] = f"Rate limit exceeded for {analysis.auth_level.value} user"
-            response_data["retry_after"] = "60"  # Suggest retry after 60 seconds
-
+
         response = JSONResponse(
             content=response_data,
             status_code=status_code
         )
-
-        # Add rate limiting headers
-        if analysis.rate_limit_exceeded:
-            response.headers["Retry-After"] = "60"
-            response.headers["X-RateLimit-Limit"] = "See API documentation"
-            response.headers["X-RateLimit-Reset"] = str(int(time.time() + 60))
-
+
         return response
 
     def _add_security_headers(self, response: Response) -> Response:
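With rate limiting pulled out of the security middleware, 429s now come only from `RateLimitMiddleware`, which advertises `X-RateLimit-Reset` instead of `Retry-After`. A hypothetical client-side backoff sketch against that header (httpx; the URL and payload are placeholders):

```python
import asyncio
import time
import httpx

async def post_with_backoff(url: str, payload: dict, attempts: int = 5) -> httpx.Response:
    async with httpx.AsyncClient() as client:
        for _ in range(attempts):
            resp = await client.post(url, json=payload)
            if resp.status_code != 429:
                return resp
            # Sleep until the advertised window reset; fall back to 60s if absent.
            reset = int(resp.headers.get("X-RateLimit-Reset", int(time.time()) + 60))
            await asyncio.sleep(max(1, reset - int(time.time())))
        return resp
```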
diff --git a/backend/app/services/enhanced_embedding_service.py b/backend/app/services/enhanced_embedding_service.py
new file mode 100644
index 0000000..284773f
--- /dev/null
+++ b/backend/app/services/enhanced_embedding_service.py
@@ -0,0 +1,201 @@
+# Enhanced Embedding Service with Rate Limiting Handling
+"""
+Enhanced embedding service with robust rate limiting and retry logic
+"""
+
+import asyncio
+import logging
+import time
+from typing import List, Dict, Any, Optional
+import numpy as np
+from datetime import datetime, timedelta
+
+from .embedding_service import EmbeddingService
+from app.core.config import settings
+
+logger = logging.getLogger(__name__)
+
+
+class EnhancedEmbeddingService(EmbeddingService):
+    """Enhanced embedding service with rate limiting handling"""
+
+    def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
+        super().__init__(model_name)
+        self.rate_limit_tracker = {
+            'requests_count': 0,
+            'window_start': time.time(),
+            'window_size': 60,  # 1 minute window
+            'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 60)),  # Configurable
+            'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')],  # Exponential backoff
+            'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 0.5)),
+            'last_rate_limit_error': None
+        }
+
+    async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]:
+        """
+        Get embeddings with rate limiting and retry logic
+        """
+        if max_retries is None:
+            max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))
+
+        batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 5))
+
+        if not self.initialized:
+            logger.warning("Embedding service not initialized, using fallback")
+            return self._generate_fallback_embeddings(texts), False
+
+        embeddings = []
+        success = True
+
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i:i+batch_size]
+            batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries)
+            embeddings.extend(batch_embeddings)
+            success = success and batch_success
+
+            # Add delay between batches to avoid rate limiting
+            if i + batch_size < len(texts):
+                delay = self.rate_limit_tracker['delay_between_batches']
+                await asyncio.sleep(delay)  # Configurable delay between batches
+
+        return embeddings, success
+
+    async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]:
+        """Get embeddings for a batch with retry logic"""
+        last_error = None
+
+        for attempt in range(max_retries + 1):
+            try:
+                # Check rate limit before making request
+                if self._is_rate_limited():
+                    delay = self._get_rate_limit_delay()
+                    logger.warning(f"Rate limit detected, waiting {delay} seconds")
+                    await asyncio.sleep(delay)
+                    continue
+
+                # Make the request
+                embeddings = await self._get_embeddings_batch_impl(texts)
+
+                # Update rate limit tracker on success
+                self._update_rate_limit_tracker(success=True)
+
+                return embeddings, True
+
+            except Exception as e:
+                last_error = e
+                error_msg = str(e).lower()
+
+                # Check if it's a rate limit error
+                if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']):
+                    logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}")
+                    self._update_rate_limit_tracker(success=False)
+
+                    if attempt < max_retries:
+                        delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
+                        logger.info(f"Retrying in {delay} seconds...")
+                        await asyncio.sleep(delay)
+                        continue
+                    else:
+                        logger.error("Max retries exceeded for rate limit, using fallback embeddings")
+                        return self._generate_fallback_embeddings(texts), False
+                else:
+                    # Non-rate-limit error
+                    logger.error(f"Error generating embeddings: {e}")
+                    if attempt < max_retries:
+                        delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
+                        await asyncio.sleep(delay)
+                    else:
+                        logger.error("Max retries exceeded, using fallback embeddings")
+                        return self._generate_fallback_embeddings(texts), False
+
+        # If we get here, all retries failed
+        logger.error(f"All retries failed, last error: {last_error}")
+        return self._generate_fallback_embeddings(texts), False
+
+    async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]:
+        """Implementation of getting embeddings for a batch"""
+        from app.services.llm.service import llm_service
+        from app.services.llm.models import EmbeddingRequest
+
+        embeddings = []
+
+        for text in texts:
+            # Truncate text if needed
+            max_chars = 1600
+            truncated_text = text[:max_chars] if len(text) > max_chars else text
+
+            llm_request = EmbeddingRequest(
+                model=self.model_name,
+                input=truncated_text,
+                user_id="rag_system",
+                api_key_id=0
+            )
+
+            response = await llm_service.create_embedding(llm_request)
+
+            if response.data and len(response.data) > 0:
+                embedding = response.data[0].embedding
+                if embedding:
+                    embeddings.append(embedding)
+                    if not hasattr(self, '_dimension_confirmed'):
+                        self.dimension = len(embedding)
+                        self._dimension_confirmed = True
+                else:
+                    raise ValueError("Empty embedding in response")
+            else:
+                raise ValueError("Invalid response structure")
+
+        return embeddings
+
+    def _is_rate_limited(self) -> bool:
+        """Check if we're currently rate limited"""
+        now = time.time()
+        window_start = self.rate_limit_tracker['window_start']
+
+        # Reset window if it's expired
+        if now - window_start > self.rate_limit_tracker['window_size']:
+            self.rate_limit_tracker['requests_count'] = 0
+            self.rate_limit_tracker['window_start'] = now
+            return False
+
+        # Check if we've exceeded the limit
+        return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute']
+
+    def _get_rate_limit_delay(self) -> float:
+        """Get delay to wait for rate limit reset"""
+        now = time.time()
+        window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size']
+        return max(0, window_end - now)
+
+    def _update_rate_limit_tracker(self, success: bool):
+        """Update the rate limit tracker"""
+        now = time.time()
+
+        # Reset window if it's expired
+        if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']:
+            self.rate_limit_tracker['requests_count'] = 0
+            self.rate_limit_tracker['window_start'] = now
+
+        # Increment counter on successful requests
+        if success:
+            self.rate_limit_tracker['requests_count'] += 1
+        else:
+            # Record when we last hit a rate limit error (surfaced in get_embedding_stats)
+            self.rate_limit_tracker['last_rate_limit_error'] = datetime.utcnow().isoformat()
+
+    async def get_embedding_stats(self) -> Dict[str, Any]:
+        """Get embedding service statistics including rate limiting info"""
+        base_stats = await self.get_stats()
+
+        return {
+            **base_stats,
+            "rate_limit_info": {
+                "requests_in_current_window": self.rate_limit_tracker['requests_count'],
+                "max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'],
+                "window_reset_in_seconds": max(0,
+                    self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time()
+                ),
+                "last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error']
+            }
+        }
+
+
+# Global enhanced embedding service instance
+enhanced_embedding_service = EnhancedEmbeddingService()
\ No newline at end of file
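Callers of the new service get an `(embeddings, success)` tuple, where `success=False` means fallback vectors were returned. A minimal usage sketch:

```python
import asyncio
from app.services.enhanced_embedding_service import enhanced_embedding_service

async def main():
    chunks = ["first chunk of a document", "second chunk of a document"]
    embeddings, ok = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
    if not ok:
        print("fallback embeddings in use - search quality may be degraded")
    print(f"{len(embeddings)} vectors, dimension {len(embeddings[0])}")

asyncio.run(main())
```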
diff --git a/backend/app/services/llm/config.py b/backend/app/services/llm/config.py
index 8ac8fb8..61a8576 100644
--- a/backend/app/services/llm/config.py
+++ b/backend/app/services/llm/config.py
@@ -65,7 +65,16 @@ class LLMServiceConfig(BaseModel):
 
     # Provider configurations
     providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")
-
+
+    # Token rate limiting (organization-wide)
+    token_limits_per_minute: Dict[str, int] = Field(
+        default_factory=lambda: {
+            "prompt_tokens": 20000,  # PrivateMode Standard tier
+            "completion_tokens": 10000  # PrivateMode Standard tier
+        },
+        description="Token rate limits per minute (organization-wide)"
+    )
+
     # Model routing (model_name -> provider_name)
     model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
 
@@ -91,8 +100,8 @@ def create_default_config() -> LLMServiceConfig:
         supported_models=[],  # Will be populated dynamically from proxy
         capabilities=["chat", "embeddings", "tee"],
         priority=1,
-        max_requests_per_minute=100,
-        max_requests_per_hour=2000,
+        max_requests_per_minute=20,  # PrivateMode Standard tier limit: 20 req/min
+        max_requests_per_hour=1200,  # 20 req/min * 60 min
         supports_streaming=True,
        supports_function_calling=True,
         max_context_window=128000,
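The new `token_limits_per_minute` field only declares the budgets; the enforcement point is not part of this patch. As a sanity check on the numbers, at the 20 req/min cap they average out to 1000 prompt and 500 completion tokens per request:

```python
from app.services.llm.config import create_default_config

config = create_default_config()
budget = config.token_limits_per_minute  # {"prompt_tokens": 20000, "completion_tokens": 10000}

requests_per_minute = 20  # PrivateMode Standard tier request cap
print(budget["prompt_tokens"] // requests_per_minute)      # 1000 prompt tokens/request
print(budget["completion_tokens"] // requests_per_minute)  # 500 completion tokens/request
```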
diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py
index 1871b0d..b6c90b7 100644
--- a/backend/modules/rag/main.py
+++ b/backend/modules/rag/main.py
@@ -60,6 +60,7 @@ import tiktoken
 from app.core.config import settings
 from app.core.logging import log_module_event
 from app.services.base_module import BaseModule, Permission
+from app.services.enhanced_embedding_service import enhanced_embedding_service
 
 
 @dataclass
@@ -1125,9 +1126,17 @@ class RAGModule(BaseModule):
             # Chunk the document
             chunks = self._chunk_text(content)
 
-            # Generate embeddings for all chunks in batch (more efficient)
-            embeddings = await self._generate_embeddings(chunks)
-
+            # Generate embeddings with enhanced rate limiting handling
+            embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
+
+            # Log if fallback embeddings were used
+            if not success:
+                logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded")
+                log_module_event("rag", "fallback_embeddings_used", {
+                    "document_id": doc_id,
+                    "content_preview": content[:100] + "..." if len(content) > 100 else content
+                })
+
             # Create document points
             points = []
             for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
@@ -1188,9 +1197,17 @@ class RAGModule(BaseModule):
             # Chunk the document
             chunks = self._chunk_text(processed_doc.content)
 
-            # Generate embeddings for all chunks in batch (more efficient)
-            embeddings = await self._generate_embeddings(chunks)
-
+            # Generate embeddings with enhanced rate limiting handling
+            embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
+
+            # Log if fallback embeddings were used
+            if not success:
+                logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded")
+                log_module_event("rag", "fallback_embeddings_used", {
+                    "document_id": processed_doc.id,
+                    "filename": processed_doc.original_filename
+                })
+
             # Create document points with enhanced metadata
             points = []
             for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
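The rate-limit telemetry added in `get_embedding_stats()` is easiest to consume from an internal endpoint, which the new middleware deliberately skips under `/api-internal/v1/`. A hypothetical route (the path is illustrative; only `get_embedding_stats()` is defined by this patch):

```python
from fastapi import APIRouter
from app.services.enhanced_embedding_service import enhanced_embedding_service

router = APIRouter()

@router.get("/api-internal/v1/embedding-stats")
async def embedding_stats():
    # Surfaces requests_in_current_window, window_reset_in_seconds,
    # last_rate_limit_error, etc., alongside the base service stats.
    return await enhanced_embedding_service.get_embedding_stats()
```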