enclava/backend/app/services/litellm_client.py

"""
LiteLLM Client Service
Handles communication with the LiteLLM proxy service
"""
import asyncio
import json
import logging
from typing import Dict, Any, Optional, List
from datetime import datetime
import aiohttp
from fastapi import HTTPException, status
from app.core.config import settings
logger = logging.getLogger(__name__)


class LiteLLMClient:
    """Client for communicating with LiteLLM proxy service"""

    def __init__(self):
        self.base_url = settings.LITELLM_BASE_URL
        self.master_key = settings.LITELLM_MASTER_KEY
        self.session: Optional[aiohttp.ClientSession] = None
        self.timeout = aiohttp.ClientTimeout(total=600)  # 10 minute request timeout
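
    # The aiohttp session is created lazily on first use and recreated if it
    # has been closed; the proxy master key is attached as a bearer token on
    # every request made through it.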
    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create aiohttp session"""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession(
                timeout=self.timeout,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.master_key}"
                }
            )
        return self.session
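
    # Callers are expected to invoke close() once at application shutdown
    # (e.g. from a FastAPI lifespan/shutdown hook) so the shared session is
    # released cleanly.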
    async def close(self):
        """Close the HTTP session"""
        if self.session and not self.session.closed:
            await self.session.close()

    async def health_check(self) -> Dict[str, Any]:
        """Check LiteLLM proxy health"""
        try:
            session = await self._get_session()
            async with session.get(f"{self.base_url}/health") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    logger.error(f"LiteLLM health check failed: {response.status}")
                    return {"status": "unhealthy", "error": f"HTTP {response.status}"}
        except Exception as e:
            logger.error(f"LiteLLM health check error: {e}")
            return {"status": "unhealthy", "error": str(e)}

    async def create_chat_completion(
        self,
        model: str,
        messages: List[Dict[str, str]],
        user_id: str,
        api_key_id: int,
        **kwargs
    ) -> Dict[str, Any]:
        """Create chat completion via LiteLLM proxy"""
        try:
            # Prepare request payload
            payload = {
                "model": model,
                "messages": messages,
                "user": f"user_{user_id}",  # User identifier for tracking
                "metadata": {
                    "api_key_id": api_key_id,
                    "user_id": user_id,
                    "timestamp": datetime.utcnow().isoformat()
                },
                **kwargs
            }
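            # Any extra OpenAI-compatible parameters (e.g. temperature,
            # max_tokens) arrive via **kwargs and are forwarded unchanged.
            # Note that streaming (stream=True) is not handled here: the full
            # JSON body is awaited below.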
            session = await self._get_session()
            async with session.post(
                f"{self.base_url}/chat/completions",
                json=payload
            ) as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_text = await response.text()
                    logger.error(f"LiteLLM chat completion failed: {response.status} - {error_text}")
                    # Handle specific error cases
                    if response.status == 401:
                        raise HTTPException(
                            status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Invalid API key"
                        )
                    elif response.status == 429:
                        raise HTTPException(
                            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                            detail="Rate limit exceeded"
                        )
                    elif response.status == 400:
                        # Parse the upstream error message if the body is JSON;
                        # fall back to a generic message otherwise. The message
                        # is extracted before raising so the HTTPException is
                        # not swallowed by the except clause.
                        try:
                            error_data = json.loads(error_text)
                            detail = error_data.get("error", {}).get("message", "Bad request")
                        except (ValueError, AttributeError):
                            detail = "Invalid request"
                        raise HTTPException(
                            status_code=status.HTTP_400_BAD_REQUEST,
                            detail=detail
                        )
                    else:
                        raise HTTPException(
                            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="LiteLLM service error"
                        )
        except aiohttp.ClientError as e:
            logger.error(f"LiteLLM chat completion request error: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LiteLLM service unavailable"
            )

    async def create_embedding(
        self,
        model: str,
        input_text: str,
        user_id: str,
        api_key_id: int,
        **kwargs
    ) -> Dict[str, Any]:
        """Create embedding via LiteLLM proxy"""
        try:
            payload = {
                "model": model,
                "input": input_text,
                "user": f"user_{user_id}",
                "metadata": {
                    "api_key_id": api_key_id,
                    "user_id": user_id,
                    "timestamp": datetime.utcnow().isoformat()
                },
                **kwargs
            }
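            # The OpenAI-compatible /embeddings endpoint also accepts a list of
            # strings as "input"; this helper currently assumes a single string.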
            session = await self._get_session()
            async with session.post(
                f"{self.base_url}/embeddings",
                json=payload
            ) as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_text = await response.text()
                    logger.error(f"LiteLLM embedding failed: {response.status} - {error_text}")
                    if response.status == 401:
                        raise HTTPException(
                            status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Invalid API key"
                        )
                    elif response.status == 429:
                        raise HTTPException(
                            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                            detail="Rate limit exceeded"
                        )
                    else:
                        raise HTTPException(
                            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="LiteLLM service error"
                        )
        except aiohttp.ClientError as e:
            logger.error(f"LiteLLM embedding request error: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LiteLLM service unavailable"
            )

    async def get_models(self) -> List[Dict[str, Any]]:
        """Get available models from LiteLLM proxy"""
        try:
            session = await self._get_session()
            async with session.get(f"{self.base_url}/models") as response:
                if response.status == 200:
                    data = await response.json()
                    # Return models with exact names from upstream providers
                    models = data.get("data", [])
                    # Pass through model names exactly as they come from upstream;
                    # don't modify model IDs - keep the original provider names
                    processed_models = []
                    for model in models:
                        processed_models.append({
                            "id": model.get("id", ""),  # Exact model name from provider
                            "object": model.get("object", "model"),
                            "created": model.get("created", 1677610602),
                            "owned_by": model.get("owned_by", "openai")
                        })
                    return processed_models
                else:
                    error_text = await response.text()
                    logger.error(f"LiteLLM models request failed: {response.status} - {error_text}")
                    return []
        except aiohttp.ClientError as e:
            logger.error(f"LiteLLM models request error: {e}")
            return []

    async def proxy_request(
        self,
        method: str,
        endpoint: str,
        payload: Dict[str, Any],
        user_id: str,
        api_key_id: int
    ) -> Dict[str, Any]:
        """Generic proxy request to LiteLLM"""
        try:
            # Add metadata to payload
            if isinstance(payload, dict):
                payload["metadata"] = {
                    "api_key_id": api_key_id,
                    "user_id": user_id,
                    "timestamp": datetime.utcnow().isoformat()
                }
                if "user" not in payload:
                    payload["user"] = f"user_{user_id}"
            session = await self._get_session()
            # Make the request
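            # Caveat: for GET requests the payload is sent as query parameters,
            # and nested values such as the "metadata" dict are not valid
            # query-string values; GET callers are assumed to pass flat payloads.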
            async with session.request(
                method,
                f"{self.base_url}/{endpoint.lstrip('/')}",
                json=payload if method.upper() in ['POST', 'PUT', 'PATCH'] else None,
                params=payload if method.upper() == 'GET' else None
            ) as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_text = await response.text()
                    logger.error(f"LiteLLM proxy request failed: {response.status} - {error_text}")
                    if response.status == 401:
                        raise HTTPException(
                            status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Invalid API key"
                        )
                    elif response.status == 429:
                        raise HTTPException(
                            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                            detail="Rate limit exceeded"
                        )
                    else:
                        raise HTTPException(
                            status_code=response.status,
                            detail=f"LiteLLM service error: {error_text}"
                        )
        except aiohttp.ClientError as e:
            logger.error(f"LiteLLM proxy request error: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LiteLLM service unavailable"
            )

    async def list_models(self) -> List[str]:
        """Get list of available model names/IDs"""
        try:
            models_data = await self.get_models()
            return [
                model.get("id", model.get("model", ""))
                for model in models_data
                if model.get("id") or model.get("model")
            ]
        except Exception as e:
            logger.error(f"Error listing model names: {str(e)}")
            return []


# Global LiteLLM client instance
litellm_client = LiteLLMClient()
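
# Example usage (illustrative sketch only; the router, request model and
# dependency names below are assumptions, not part of this module):
#
#   from app.services.litellm_client import litellm_client
#
#   @router.post("/v1/chat/completions")
#   async def chat_completions(body: ChatRequest, auth=Depends(get_api_key_auth)):
#       return await litellm_client.create_chat_completion(
#           model=body.model,
#           messages=body.messages,
#           user_id=str(auth.user_id),
#           api_key_id=auth.api_key_id,
#           temperature=body.temperature,
#       )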