mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
212 lines
9.0 KiB
Python
212 lines
9.0 KiB
Python
# Enhanced Embedding Service with Rate Limiting Handling
|
|
"""
|
|
Enhanced embedding service with robust rate limiting and retry logic
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from typing import List, Dict, Any, Optional
|
|
import numpy as np
|
|
from datetime import datetime, timedelta
|
|
|
|
from .embedding_service import EmbeddingService
|
|
from app.core.config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EnhancedEmbeddingService(EmbeddingService):
    """Embedding service with client-side rate limiting and retry handling.

    Extends ``EmbeddingService`` with:
      * a fixed one-minute request window capped at a configurable number
        of requests,
      * exponential-backoff retries when the upstream reports a rate limit,
      * configurable delays between batches and between individual requests.
    """

    def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
        super().__init__(model_name)
        # All rate-limit state is kept in one dict so it can be surfaced
        # verbatim by get_embedding_stats().
        self.rate_limit_tracker = {
            'requests_count': 0,          # successful requests in the current window
            'window_start': time.time(),  # epoch seconds at window start
            'window_size': 60,            # 1 minute window
            'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)),  # Configurable
            # Exponential backoff schedule (seconds); the last entry is reused
            # for any attempt beyond the schedule's length.
            'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')],
            'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)),
            'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)),
            # Timestamped description of the most recent rate-limit error.
            # Set in _get_batch_embeddings_with_retry(); reported in stats.
            'last_rate_limit_error': None,
        }

    def _retry_delay(self, attempt: int) -> int:
        """Return the backoff delay (seconds) for *attempt*, clamped to the
        last entry of the configured schedule."""
        delays = self.rate_limit_tracker['retry_delays']
        return delays[min(attempt, len(delays) - 1)]

    def _maybe_reset_window(self, now: float) -> bool:
        """Reset the request window if it has expired; return True if reset."""
        tracker = self.rate_limit_tracker
        if now - tracker['window_start'] > tracker['window_size']:
            tracker['requests_count'] = 0
            tracker['window_start'] = now
            return True
        return False

    async def get_embeddings_with_retry(
        self, texts: List[str], max_retries: Optional[int] = None
    ) -> tuple[List[List[float]], bool]:
        """Embed *texts* in rate-limited batches.

        Args:
            texts: Input strings to embed.
            max_retries: Per-batch retry budget; defaults to the
                ``RAG_EMBEDDING_RETRY_COUNT`` setting (3).

        Returns:
            ``(embeddings, success)`` where ``success`` is False if any
            batch fell back to locally generated embeddings.
        """
        if max_retries is None:
            max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))

        batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3))

        if not self.initialized:
            logger.warning("Embedding service not initialized, using fallback")
            return self._generate_fallback_embeddings(texts), False

        embeddings: List[List[float]] = []
        success = True

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries)
            embeddings.extend(batch_embeddings)
            success = success and batch_success

            # Add delay between batches to avoid rate limiting (skipped
            # after the final batch).
            if i + batch_size < len(texts):
                await asyncio.sleep(self.rate_limit_tracker['delay_between_batches'])

        return embeddings, success

    async def _get_batch_embeddings_with_retry(
        self, texts: List[str], max_retries: int
    ) -> tuple[List[List[float]], bool]:
        """Embed one batch, retrying on failure.

        Rate-limit errors (HTTP 429 and similar) and other errors are both
        retried with exponential backoff; once the retry budget is exhausted
        the batch falls back to locally generated embeddings and the success
        flag is False.
        """
        last_error = None

        for attempt in range(max_retries + 1):
            try:
                # If the local window is already exhausted, wait it out.
                # NOTE: this wait consumes one attempt of the retry budget.
                if self._is_rate_limited():
                    delay = self._get_rate_limit_delay()
                    logger.warning(f"Rate limit detected, waiting {delay} seconds")
                    await asyncio.sleep(delay)
                    continue

                # Make the request
                embeddings = await self._get_embeddings_batch_impl(texts)
                return embeddings, True

            except Exception as e:
                last_error = e
                error_msg = str(e).lower()

                # Heuristic: treat these substrings as provider rate limiting.
                is_rate_limit = any(
                    indicator in error_msg
                    for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']
                )

                if is_rate_limit:
                    logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}")
                    # Record the error so get_embedding_stats() can report it
                    # (previously this field was never written).
                    self.rate_limit_tracker['last_rate_limit_error'] = f"{datetime.now().isoformat()}: {e}"
                    self._update_rate_limit_tracker(success=False)
                else:
                    logger.error(f"Error generating embeddings: {e}")

                if attempt < max_retries:
                    delay = self._retry_delay(attempt)
                    if is_rate_limit:
                        logger.info(f"Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)
                elif is_rate_limit:
                    logger.error("Max retries exceeded for rate limit, using fallback embeddings")
                    return self._generate_fallback_embeddings(texts), False
                else:
                    logger.error("Max retries exceeded, using fallback embeddings")
                    return self._generate_fallback_embeddings(texts), False

        # Reached only when every attempt was spent waiting on the local
        # rate-limit window without ever issuing a request.
        logger.error(f"All retries failed, last error: {last_error}")
        return self._generate_fallback_embeddings(texts), False

    async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]:
        """Fetch embeddings via the LLM service, one request per text,
        honouring the local rate-limit window.

        Raises:
            ValueError: if the LLM response is malformed or contains an
                empty embedding.
        """
        # Imported lazily to avoid a circular import at module load time —
        # TODO confirm that is the reason; matches original placement.
        from app.services.llm.service import llm_service
        from app.services.llm.models import EmbeddingRequest

        embeddings: List[List[float]] = []

        for text in texts:
            # Block until the current window has capacity.
            while self._is_rate_limited():
                delay = self._get_rate_limit_delay()
                logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request")
                await asyncio.sleep(delay)

            # Truncate overly long inputs; 1600 chars presumably maps to the
            # upstream model's effective context — verify against provider.
            max_chars = 1600
            truncated_text = text[:max_chars] if len(text) > max_chars else text

            llm_request = EmbeddingRequest(
                model=self.model_name,
                input=truncated_text,
                user_id="rag_system",
                api_key_id=0
            )

            response = await llm_service.create_embedding(llm_request)

            if not (response.data and len(response.data) > 0):
                raise ValueError("Invalid response structure")
            embedding = response.data[0].embedding
            if not embedding:
                raise ValueError("Empty embedding in response")

            embeddings.append(embedding)
            # Lock in the embedding dimension from the first real response.
            if not hasattr(self, '_dimension_confirmed'):
                self.dimension = len(embedding)
                self._dimension_confirmed = True

            # Count this successful request and optionally delay between requests
            self._update_rate_limit_tracker(success=True)
            per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0)
            if per_req_delay and per_req_delay > 0:
                await asyncio.sleep(per_req_delay)

        return embeddings

    def _is_rate_limited(self) -> bool:
        """Return True when the current window's request budget is spent."""
        if self._maybe_reset_window(time.time()):
            # Fresh window: budget is available again.
            return False
        return (self.rate_limit_tracker['requests_count']
                >= self.rate_limit_tracker['max_requests_per_minute'])

    def _get_rate_limit_delay(self) -> float:
        """Return seconds until the current rate-limit window resets (>= 0)."""
        tracker = self.rate_limit_tracker
        window_end = tracker['window_start'] + tracker['window_size']
        return max(0, window_end - time.time())

    def _update_rate_limit_tracker(self, success: bool) -> None:
        """Roll the window forward if expired; count a successful request."""
        self._maybe_reset_window(time.time())
        # Only successful requests count against the window budget.
        if success:
            self.rate_limit_tracker['requests_count'] += 1

    async def get_embedding_stats(self) -> Dict[str, Any]:
        """Return base service statistics augmented with rate-limit info."""
        base_stats = await self.get_stats()
        tracker = self.rate_limit_tracker

        return {
            **base_stats,
            "rate_limit_info": {
                "requests_in_current_window": tracker['requests_count'],
                "max_requests_per_minute": tracker['max_requests_per_minute'],
                "window_reset_in_seconds": max(
                    0, tracker['window_start'] + tracker['window_size'] - time.time()
                ),
                "last_rate_limit_error": tracker['last_rate_limit_error'],
            },
        }
|
|
|
|
|
|
# Global enhanced embedding service instance
# NOTE(review): created at import time; __init__ only reads settings and
# seeds the rate-limit tracker (no network I/O visible here) — confirm the
# base EmbeddingService constructor is equally side-effect free.
enhanced_embedding_service = EnhancedEmbeddingService()
|