enclava/backend/app/services/llm/token_rate_limiter.py
"""
Token-based rate limiting for LLM service
"""
import time
import redis
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
from ..core.config import settings
from ..core.logging import get_logger
logger = get_logger(__name__)
class TokenRateLimiter:
    """Token-based rate limiting implementation"""

    def __init__(self):
        # Always create the in-memory store so the fallback paths in
        # _get_token_usage()/_update_token_usage() still work when Redis
        # is reachable at start-up but fails later.
        self.in_memory_store = {}
        try:
            self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
            self.redis_client.ping()
            logger.info("Token rate limiter initialized with Redis backend")
        except Exception as e:
            logger.warning(f"Redis not available for token rate limiting: {e}")
            self.redis_client = None
            logger.info("Token rate limiter using in-memory fallback")
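
    # Note: the in-memory fallback is per-process state; when the service runs
    # with multiple workers each process tracks its own counters, so limits are
    # only enforced globally when Redis is reachable.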

    async def check_token_limits(
        self,
        provider: str,
        prompt_tokens: int,
        completion_tokens: int = 0
    ) -> Tuple[bool, Dict[str, str]]:
        """
        Check if token usage is within limits

        Args:
            provider: Provider name (e.g., "privatemode")
            prompt_tokens: Number of prompt tokens to use
            completion_tokens: Number of completion tokens to use

        Returns:
            Tuple of (is_allowed, headers)
        """
        # Get token limits from configuration
        from .config import get_config
        config = get_config()
        token_limits = config.token_limits_per_minute

        # Check organization-wide limits
        org_key = f"tokens:org:{provider}"

        # Get current usage
        current_usage = await self._get_token_usage(org_key)

        # Calculate new usage
        new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens
        new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens

        # Check limits
        prompt_limit = token_limits.get("prompt_tokens", 20000)
        completion_limit = token_limits.get("completion_tokens", 10000)

        is_allowed = (
            new_prompt_tokens <= prompt_limit and
            new_completion_tokens <= completion_limit
        )

        if is_allowed:
            # Update usage
            await self._update_token_usage(org_key, prompt_tokens, completion_tokens)
            logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, "
                         f"{new_completion_tokens}/{completion_limit} completion")

        # Calculate remaining tokens
        remaining_prompt = max(0, prompt_limit - new_prompt_tokens)
        remaining_completion = max(0, completion_limit - new_completion_tokens)

        # Create headers
        headers = {
            "X-TokenLimit-Prompt-Remaining": str(remaining_prompt),
            "X-TokenLimit-Completion-Remaining": str(remaining_completion),
            "X-TokenLimit-Prompt-Limit": str(prompt_limit),
            "X-TokenLimit-Completion-Limit": str(completion_limit),
            "X-TokenLimit-Reset": str(int(time.time() + 60))  # Reset in 1 minute
        }

        if not is_allowed:
            logger.warning(f"Token rate limit exceeded for {provider}. "
                           f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. "
                           f"Current: {current_usage}")

        return is_allowed, headers

    async def _get_token_usage(self, key: str) -> Dict[str, int]:
        """Get current token usage"""
        if self.redis_client:
            try:
                data = self.redis_client.hgetall(key)
                if data:
                    return {
                        "prompt_tokens": int(data.get("prompt_tokens", 0)),
                        "completion_tokens": int(data.get("completion_tokens", 0)),
                        "updated_at": float(data.get("updated_at", time.time()))
                    }
            except Exception as e:
                logger.error(f"Error getting token usage from Redis: {e}")

        # Fallback to in-memory
        return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0})

    async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int):
        """Update token usage"""
        if self.redis_client:
            try:
                pipe = self.redis_client.pipeline()
                pipe.hincrby(key, "prompt_tokens", prompt_tokens)
                pipe.hincrby(key, "completion_tokens", completion_tokens)
                pipe.hset(key, "updated_at", time.time())
                # Expire after 1 minute; note the TTL is refreshed on every update,
                # so the counters reset 60 seconds after the most recent request
                # rather than on fixed minute boundaries.
                pipe.expire(key, 60)
                pipe.execute()
            except Exception as e:
                logger.error(f"Error updating token usage in Redis: {e}")
                # Fallback to in-memory
                self._update_in_memory(key, prompt_tokens, completion_tokens)
        else:
            self._update_in_memory(key, prompt_tokens, completion_tokens)

    def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int):
        """Update in-memory token usage"""
        if key not in self.in_memory_store:
            self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0}
        self.in_memory_store[key]["prompt_tokens"] += prompt_tokens
        self.in_memory_store[key]["completion_tokens"] += completion_tokens
        self.in_memory_store[key]["updated_at"] = time.time()

    def cleanup_expired(self):
        """Clean up expired entries (for in-memory store)"""
        if not self.redis_client:
            current_time = time.time()
            expired_keys = [
                key for key, data in self.in_memory_store.items()
                if current_time - data.get("updated_at", 0) > 60
            ]
            for key in expired_keys:
                del self.in_memory_store[key]
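
    # Note: cleanup_expired() is not invoked anywhere in this module; when
    # running on the in-memory fallback, a caller would need to schedule it
    # periodically (e.g. from a background task) so stale per-minute counters
    # are dropped.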


# Global token rate limiter instance
token_rate_limiter = TokenRateLimiter()
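
Usage sketch (illustrative, not part of the file): a caller would typically check the limiter before forwarding a request and attach the returned X-TokenLimit-* headers to the response. The import path, provider name, and token counts below are assumptions; only check_token_limits() and its (is_allowed, headers) return value come from the module above.

import asyncio

from app.services.llm.token_rate_limiter import token_rate_limiter  # assumed package path

async def handle_completion_request() -> None:
    allowed, headers = await token_rate_limiter.check_token_limits(
        provider="privatemode",      # hypothetical provider name
        prompt_tokens=1200,          # tokens counted from the incoming prompt
        completion_tokens=512,       # estimated completion budget
    )
    if not allowed:
        # A real handler would return HTTP 429 with the X-TokenLimit-* headers
        # so clients know how long to back off.
        print("rate limited until", headers["X-TokenLimit-Reset"])
        return
    # ...forward the request to the provider here...
    print("allowed; remaining prompt tokens:", headers["X-TokenLimit-Prompt-Remaining"])

asyncio.run(handle_completion_request())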