mirror of https://github.com/aljazceru/enclava.git, synced 2025-12-17 07:24:34 +01:00
clean commit
304 backend/app/services/litellm_client.py Normal file
@@ -0,0 +1,304 @@
"""
LiteLLM Client Service
Handles communication with the LiteLLM proxy service
"""

import asyncio
import json
import logging
from typing import Dict, Any, Optional, List
from datetime import datetime

import aiohttp
from fastapi import HTTPException, status

from app.core.config import settings

logger = logging.getLogger(__name__)


class LiteLLMClient:
    """Client for communicating with the LiteLLM proxy service"""

    def __init__(self):
        self.base_url = settings.LITELLM_BASE_URL
        self.master_key = settings.LITELLM_MASTER_KEY
        self.session: Optional[aiohttp.ClientSession] = None
        self.timeout = aiohttp.ClientTimeout(total=600)  # 10-minute overall timeout

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create the shared aiohttp session"""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession(
                timeout=self.timeout,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.master_key}"
                }
            )
        return self.session

    async def close(self):
        """Close the HTTP session"""
        if self.session and not self.session.closed:
            await self.session.close()

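    # Note: a long-lived application should call close() on shutdown so the
    # shared aiohttp session is released. A minimal FastAPI hook might look
    # like this (an illustrative sketch; the `app` object is an assumption,
    # not part of this module):
    #
    #     @app.on_event("shutdown")
    #     async def _close_litellm_client() -> None:
    #         await litellm_client.close()
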
    async def health_check(self) -> Dict[str, Any]:
        """Check LiteLLM proxy health"""
        try:
            session = await self._get_session()
            async with session.get(f"{self.base_url}/health") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    logger.error(f"LiteLLM health check failed: {response.status}")
                    return {"status": "unhealthy", "error": f"HTTP {response.status}"}
        except Exception as e:
            logger.error(f"LiteLLM health check error: {e}")
            return {"status": "unhealthy", "error": str(e)}

    async def create_chat_completion(
        self,
        model: str,
        messages: List[Dict[str, str]],
        user_id: str,
        api_key_id: int,
        **kwargs
    ) -> Dict[str, Any]:
        """Create chat completion via LiteLLM proxy"""
        try:
            # Prepare request payload
            payload = {
                "model": model,
                "messages": messages,
                "user": f"user_{user_id}",  # User identifier for tracking
                "metadata": {
                    "api_key_id": api_key_id,
                    "user_id": user_id,
                    "timestamp": datetime.utcnow().isoformat()
                },
                **kwargs
            }

            session = await self._get_session()
            async with session.post(
                f"{self.base_url}/chat/completions",
                json=payload
            ) as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_text = await response.text()
                    logger.error(f"LiteLLM chat completion failed: {response.status} - {error_text}")

                    # Handle specific error cases
                    if response.status == 401:
                        raise HTTPException(
                            status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Invalid API key"
                        )
                    elif response.status == 429:
                        raise HTTPException(
                            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                            detail="Rate limit exceeded"
                        )
                    elif response.status == 400:
                        # Parse the error body first, then raise outside the
                        # parse block; raising inside a bare except would
                        # swallow the detailed message.
                        try:
                            error_data = json.loads(error_text)
                            detail = error_data.get("error", {}).get("message", "Bad request")
                        except (json.JSONDecodeError, AttributeError):
                            detail = "Invalid request"
                        raise HTTPException(
                            status_code=status.HTTP_400_BAD_REQUEST,
                            detail=detail
                        )
                    else:
                        raise HTTPException(
                            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="LiteLLM service error"
                        )
        except aiohttp.ClientError as e:
            logger.error(f"LiteLLM chat completion request error: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LiteLLM service unavailable"
            )

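    # Illustrative call (a sketch, not exercised anywhere in this module;
    # the model name is an assumption and must match one configured in the
    # LiteLLM proxy):
    #
    #     result = await litellm_client.create_chat_completion(
    #         model="gpt-4o-mini",
    #         messages=[{"role": "user", "content": "Hello"}],
    #         user_id="42",
    #         api_key_id=1,
    #         temperature=0.2,  # extra options are forwarded via **kwargs
    #     )
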
    async def create_embedding(
        self,
        model: str,
        input_text: str,
        user_id: str,
        api_key_id: int,
        **kwargs
    ) -> Dict[str, Any]:
        """Create embedding via LiteLLM proxy"""
        try:
            payload = {
                "model": model,
                "input": input_text,
                "user": f"user_{user_id}",
                "metadata": {
                    "api_key_id": api_key_id,
                    "user_id": user_id,
                    "timestamp": datetime.utcnow().isoformat()
                },
                **kwargs
            }

            session = await self._get_session()
            async with session.post(
                f"{self.base_url}/embeddings",
                json=payload
            ) as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_text = await response.text()
                    logger.error(f"LiteLLM embedding failed: {response.status} - {error_text}")

                    if response.status == 401:
                        raise HTTPException(
                            status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Invalid API key"
                        )
                    elif response.status == 429:
                        raise HTTPException(
                            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                            detail="Rate limit exceeded"
                        )
                    else:
                        raise HTTPException(
                            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                            detail="LiteLLM service error"
                        )
        except aiohttp.ClientError as e:
            logger.error(f"LiteLLM embedding request error: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LiteLLM service unavailable"
            )

    async def get_models(self) -> List[Dict[str, Any]]:
        """Get available models from LiteLLM proxy"""
        try:
            session = await self._get_session()
            async with session.get(f"{self.base_url}/models") as response:
                if response.status == 200:
                    data = await response.json()
                    models = data.get("data", [])

                    # Pass model IDs through exactly as the upstream
                    # providers report them; do not rename or remap them.
                    processed_models = []
                    for model in models:
                        processed_models.append({
                            "id": model.get("id", ""),
                            "object": model.get("object", "model"),
                            "created": model.get("created", 1677610602),
                            "owned_by": model.get("owned_by", "openai")
                        })

                    return processed_models
                else:
                    error_text = await response.text()
                    logger.error(f"LiteLLM models request failed: {response.status} - {error_text}")
                    return []
        except aiohttp.ClientError as e:
            logger.error(f"LiteLLM models request error: {e}")
            return []

    async def proxy_request(
        self,
        method: str,
        endpoint: str,
        payload: Dict[str, Any],
        user_id: str,
        api_key_id: int
    ) -> Dict[str, Any]:
        """Generic proxy request to LiteLLM"""
        try:
            has_json_body = method.upper() in ("POST", "PUT", "PATCH")

            # Attach tracking metadata only when the payload is sent as a
            # JSON body; aiohttp query-string params cannot carry nested
            # objects.
            if isinstance(payload, dict) and has_json_body:
                payload["metadata"] = {
                    "api_key_id": api_key_id,
                    "user_id": user_id,
                    "timestamp": datetime.utcnow().isoformat()
                }
                if "user" not in payload:
                    payload["user"] = f"user_{user_id}"

            session = await self._get_session()

            # Make the request
            async with session.request(
                method,
                f"{self.base_url}/{endpoint.lstrip('/')}",
                json=payload if has_json_body else None,
                params=payload if method.upper() == "GET" else None
            ) as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_text = await response.text()
                    logger.error(f"LiteLLM proxy request failed: {response.status} - {error_text}")

                    if response.status == 401:
                        raise HTTPException(
                            status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Invalid API key"
                        )
                    elif response.status == 429:
                        raise HTTPException(
                            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                            detail="Rate limit exceeded"
                        )
                    else:
                        raise HTTPException(
                            status_code=response.status,
                            detail=f"LiteLLM service error: {error_text}"
                        )
        except aiohttp.ClientError as e:
            logger.error(f"LiteLLM proxy request error: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LiteLLM service unavailable"
            )

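    # Illustrative use of the generic passthrough (a sketch; the endpoint
    # and model shown are assumptions about what the proxy exposes):
    #
    #     completion = await litellm_client.proxy_request(
    #         method="POST",
    #         endpoint="/completions",
    #         payload={"model": "gpt-3.5-turbo-instruct", "prompt": "Hi"},
    #         user_id="42",
    #         api_key_id=1,
    #     )
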
    async def list_models(self) -> List[str]:
        """Get the list of available model names/IDs"""
        try:
            models_data = await self.get_models()
            return [
                model.get("id", model.get("model", ""))
                for model in models_data
                if model.get("id") or model.get("model")
            ]
        except Exception as e:
            logger.error(f"Error listing model names: {str(e)}")
            return []


# Global LiteLLM client instance
litellm_client = LiteLLMClient()
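

# --- Usage sketch ---
# A minimal smoke test for the client: a hedged, illustrative sketch, not
# part of the service itself. It assumes a LiteLLM proxy is reachable at
# settings.LITELLM_BASE_URL with a valid LITELLM_MASTER_KEY.
if __name__ == "__main__":
    async def _demo() -> None:
        try:
            print("health:", await litellm_client.health_check())
            print("models:", await litellm_client.list_models())
        finally:
            # Always release the shared aiohttp session.
            await litellm_client.close()

    asyncio.run(_demo())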