mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 23:44:24 +01:00
rag improvements
This commit is contained in:
153
backend/app/services/llm/token_rate_limiter.py
Normal file
153
backend/app/services/llm/token_rate_limiter.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Token-based rate limiting for LLM service
|
||||
"""
|
||||
|
||||
import time
|
||||
import redis
|
||||
from typing import Dict, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from ..core.config import settings
|
||||
from ..core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenRateLimiter:
|
||||
"""Token-based rate limiting implementation"""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
self.redis_client.ping()
|
||||
logger.info("Token rate limiter initialized with Redis backend")
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis not available for token rate limiting: {e}")
|
||||
self.redis_client = None
|
||||
# Fall back to in-memory rate limiting
|
||||
self.in_memory_store = {}
|
||||
logger.info("Token rate limiter using in-memory fallback")
|
||||
|
||||
async def check_token_limits(
|
||||
self,
|
||||
provider: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int = 0
|
||||
) -> Tuple[bool, Dict[str, str]]:
|
||||
"""
|
||||
Check if token usage is within limits
|
||||
|
||||
Args:
|
||||
provider: Provider name (e.g., "privatemode")
|
||||
prompt_tokens: Number of prompt tokens to use
|
||||
completion_tokens: Number of completion tokens to use
|
||||
|
||||
Returns:
|
||||
Tuple of (is_allowed, headers)
|
||||
"""
|
||||
# Get token limits from configuration
|
||||
from .config import get_config
|
||||
config = get_config()
|
||||
token_limits = config.token_limits_per_minute
|
||||
|
||||
# Check organization-wide limits
|
||||
org_key = f"tokens:org:{provider}"
|
||||
|
||||
# Get current usage
|
||||
current_usage = await self._get_token_usage(org_key)
|
||||
|
||||
# Calculate new usage
|
||||
new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens
|
||||
new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens
|
||||
|
||||
# Check limits
|
||||
prompt_limit = token_limits.get("prompt_tokens", 20000)
|
||||
completion_limit = token_limits.get("completion_tokens", 10000)
|
||||
|
||||
is_allowed = (
|
||||
new_prompt_tokens <= prompt_limit and
|
||||
new_completion_tokens <= completion_limit
|
||||
)
|
||||
|
||||
if is_allowed:
|
||||
# Update usage
|
||||
await self._update_token_usage(org_key, prompt_tokens, completion_tokens)
|
||||
logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, "
|
||||
f"{new_completion_tokens}/{completion_limit} completion")
|
||||
|
||||
# Calculate remaining tokens
|
||||
remaining_prompt = max(0, prompt_limit - new_prompt_tokens)
|
||||
remaining_completion = max(0, completion_limit - new_completion_tokens)
|
||||
|
||||
# Create headers
|
||||
headers = {
|
||||
"X-TokenLimit-Prompt-Remaining": str(remaining_prompt),
|
||||
"X-TokenLimit-Completion-Remaining": str(remaining_completion),
|
||||
"X-TokenLimit-Prompt-Limit": str(prompt_limit),
|
||||
"X-TokenLimit-Completion-Limit": str(completion_limit),
|
||||
"X-TokenLimit-Reset": str(int(time.time() + 60)) # Reset in 1 minute
|
||||
}
|
||||
|
||||
if not is_allowed:
|
||||
logger.warning(f"Token rate limit exceeded for {provider}. "
|
||||
f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. "
|
||||
f"Current: {current_usage}")
|
||||
|
||||
return is_allowed, headers
|
||||
|
||||
async def _get_token_usage(self, key: str) -> Dict[str, int]:
|
||||
"""Get current token usage"""
|
||||
if self.redis_client:
|
||||
try:
|
||||
data = self.redis_client.hgetall(key)
|
||||
if data:
|
||||
return {
|
||||
"prompt_tokens": int(data.get("prompt_tokens", 0)),
|
||||
"completion_tokens": int(data.get("completion_tokens", 0)),
|
||||
"updated_at": float(data.get("updated_at", time.time()))
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting token usage from Redis: {e}")
|
||||
|
||||
# Fallback to in-memory
|
||||
return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0})
|
||||
|
||||
async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int):
|
||||
"""Update token usage"""
|
||||
if self.redis_client:
|
||||
try:
|
||||
pipe = self.redis_client.pipeline()
|
||||
pipe.hincrby(key, "prompt_tokens", prompt_tokens)
|
||||
pipe.hincrby(key, "completion_tokens", completion_tokens)
|
||||
pipe.hset(key, "updated_at", time.time())
|
||||
pipe.expire(key, 60) # Expire after 1 minute
|
||||
pipe.execute()
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating token usage in Redis: {e}")
|
||||
# Fallback to in-memory
|
||||
self._update_in_memory(key, prompt_tokens, completion_tokens)
|
||||
else:
|
||||
self._update_in_memory(key, prompt_tokens, completion_tokens)
|
||||
|
||||
def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int):
|
||||
"""Update in-memory token usage"""
|
||||
if key not in self.in_memory_store:
|
||||
self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
|
||||
self.in_memory_store[key]["prompt_tokens"] += prompt_tokens
|
||||
self.in_memory_store[key]["completion_tokens"] += completion_tokens
|
||||
self.in_memory_store[key]["updated_at"] = time.time()
|
||||
|
||||
def cleanup_expired(self):
|
||||
"""Clean up expired entries (for in-memory store)"""
|
||||
if not self.redis_client:
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key for key, data in self.in_memory_store.items()
|
||||
if current_time - data.get("updated_at", 0) > 60
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self.in_memory_store[key]
|
||||
|
||||
|
||||
# Global token rate limiter instance
|
||||
token_rate_limiter = TokenRateLimiter()
|
||||
Reference in New Issue
Block a user