mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
304 lines
12 KiB
Python
304 lines
12 KiB
Python
"""
|
|
LiteLLM Client Service
|
|
Handles communication with the LiteLLM proxy service
|
|
"""
|
|
|
|
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Dict, Any, Optional, List

import aiohttp
from fastapi import HTTPException, status

from app.core.config import settings
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LiteLLMClient:
    """Client for communicating with LiteLLM proxy service.

    One ``aiohttp`` session is created lazily on first use and reused for
    later requests; call :meth:`close` to release it. Only HTTP 200 is
    treated as success. ``create_chat_completion``, ``create_embedding`` and
    ``proxy_request`` translate upstream failures into ``HTTPException``;
    ``health_check``, ``get_models`` and ``list_models`` report failure
    through their return value instead of raising.
    """

    def __init__(self):
        """Read proxy settings; no network connection is opened here."""
        self.base_url = settings.LITELLM_BASE_URL
        self.master_key = settings.LITELLM_MASTER_KEY
        # Created on demand by _get_session().
        self.session: Optional[aiohttp.ClientSession] = None
        self.timeout = aiohttp.ClientTimeout(total=600)  # 10 minutes timeout

    # ------------------------------------------------------------------
    # Pure helpers (no I/O)
    # ------------------------------------------------------------------

    @staticmethod
    def _tracking_metadata(user_id: str, api_key_id: int) -> Dict[str, Any]:
        """Build the metadata block attached to every upstream payload."""
        # Equivalent to the deprecated datetime.utcnow(): a naive UTC
        # timestamp, so the emitted ISO string keeps its offset-free format.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        return {
            "api_key_id": api_key_id,
            "user_id": user_id,
            "timestamp": now_utc.isoformat(),
        }

    @staticmethod
    def _with_tracking(
        payload: Dict[str, Any],
        user_id: str,
        api_key_id: int
    ) -> Dict[str, Any]:
        """Return a copy of ``payload`` with tracking fields added.

        ``metadata`` is always replaced; ``user`` is only filled in when the
        payload does not already carry one. The caller's dict is not touched.
        """
        tracked = dict(payload)
        tracked["metadata"] = LiteLLMClient._tracking_metadata(user_id, api_key_id)
        tracked.setdefault("user", f"user_{user_id}")
        return tracked

    @staticmethod
    def _bad_request_detail(error_text: str) -> Any:
        """Extract the upstream error message from a 400 response body.

        Returns ``error.message`` when the body is a JSON object with an
        ``error`` object, "Bad request" when that object has no message, and
        "Invalid request" when the body is not JSON of that shape.
        """
        try:
            parsed = json.loads(error_text)
        except ValueError:  # json.JSONDecodeError is a ValueError
            return "Invalid request"
        error = parsed.get("error", {}) if isinstance(parsed, dict) else None
        if not isinstance(error, dict):
            return "Invalid request"
        return error.get("message", "Bad request")

    @staticmethod
    def _normalize_model(model: Dict[str, Any]) -> Dict[str, Any]:
        """Reduce an upstream model entry to id/object/created/owned_by.

        The model ID is passed through exactly as the upstream provider
        reports it; missing fields get fixed defaults.
        """
        return {
            "id": model.get("id", ""),  # Exact model name from provider
            "object": model.get("object", "model"),
            "created": model.get("created", 1677610602),
            "owned_by": model.get("owned_by", "openai"),
        }

    @staticmethod
    def _raise_for_auth_or_rate_limit(status_code: int) -> None:
        """Raise the mapped ``HTTPException`` for 401/429; otherwise return."""
        if status_code == 401:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key"
            )
        if status_code == 429:
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Rate limit exceeded"
            )

    # ------------------------------------------------------------------
    # Session management
    # ------------------------------------------------------------------

    async def _get_session(self) -> "aiohttp.ClientSession":
        """Get or create aiohttp session"""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession(
                timeout=self.timeout,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.master_key}"
                }
            )
        return self.session

    async def close(self):
        """Close the HTTP session"""
        if self.session and not self.session.closed:
            await self.session.close()

    # ------------------------------------------------------------------
    # Proxy operations
    # ------------------------------------------------------------------

    async def health_check(self) -> Dict[str, Any]:
        """Check LiteLLM proxy health.

        Returns the proxy's JSON body on HTTP 200, otherwise a dict of the
        form ``{"status": "unhealthy", "error": ...}``. Never raises.
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self.base_url}/health") as response:
                if response.status == 200:
                    return await response.json()
                logger.error("LiteLLM health check failed: %s", response.status)
                return {"status": "unhealthy", "error": f"HTTP {response.status}"}
        except Exception as e:
            # Deliberately broad: a health probe reports failure, not raises.
            logger.error("LiteLLM health check error: %s", e)
            return {"status": "unhealthy", "error": str(e)}

    async def get_models(self) -> List[Dict[str, Any]]:
        """Get available models from LiteLLM proxy.

        Returns one normalized dict per upstream model, or an empty list if
        the proxy answers with a non-200 status or the request fails.
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self.base_url}/models") as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(
                        "LiteLLM models request failed: %s - %s",
                        response.status, error_text
                    )
                    return []
                data = await response.json()
                return [self._normalize_model(model) for model in data.get("data", [])]
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.error("LiteLLM models request error: %s", e)
            return []

    async def create_chat_completion(
        self,
        model: str,
        messages: List[Dict[str, str]],
        user_id: str,
        api_key_id: int,
        **kwargs
    ) -> Dict[str, Any]:
        """Create chat completion via LiteLLM proxy.

        Args:
            model: Model identifier forwarded to the proxy.
            messages: Chat messages forwarded to the proxy.
            user_id: Used for the ``user`` field and tracking metadata.
            api_key_id: Recorded in the tracking metadata.
            **kwargs: Extra payload fields; they override the defaults above.

        Returns:
            The proxy's JSON response on HTTP 200.

        Raises:
            HTTPException: 401/429 mirrored from upstream, 400 with the
                upstream error message, 503 for any other upstream status or
                a transport failure/timeout.
        """
        payload = {
            "model": model,
            "messages": messages,
            "user": f"user_{user_id}",  # User identifier for tracking
            "metadata": self._tracking_metadata(user_id, api_key_id),
            **kwargs
        }
        try:
            session = await self._get_session()
            async with session.post(
                f"{self.base_url}/chat/completions",
                json=payload
            ) as response:
                if response.status == 200:
                    return await response.json()

                error_text = await response.text()
                logger.error(
                    "LiteLLM chat completion failed: %s - %s",
                    response.status, error_text
                )

                self._raise_for_auth_or_rate_limit(response.status)
                if response.status == 400:
                    # Parse the already-read body outside any try block that
                    # could swallow the HTTPException raised here.
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=self._bad_request_detail(error_text)
                    )
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="LiteLLM service error"
                )
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.error("LiteLLM chat completion request error: %s", e)
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LiteLLM service unavailable"
            ) from e

    async def create_embedding(
        self,
        model: str,
        input_text: str,
        user_id: str,
        api_key_id: int,
        **kwargs
    ) -> Dict[str, Any]:
        """Create embedding via LiteLLM proxy.

        Args:
            model: Model identifier forwarded to the proxy.
            input_text: Sent as the ``input`` field.
            user_id: Used for the ``user`` field and tracking metadata.
            api_key_id: Recorded in the tracking metadata.
            **kwargs: Extra payload fields; they override the defaults above.

        Returns:
            The proxy's JSON response on HTTP 200.

        Raises:
            HTTPException: 401/429 mirrored from upstream, 503 for any other
                upstream status or a transport failure/timeout.
        """
        payload = {
            "model": model,
            "input": input_text,
            "user": f"user_{user_id}",
            "metadata": self._tracking_metadata(user_id, api_key_id),
            **kwargs
        }
        try:
            session = await self._get_session()
            async with session.post(
                f"{self.base_url}/embeddings",
                json=payload
            ) as response:
                if response.status == 200:
                    return await response.json()

                error_text = await response.text()
                logger.error(
                    "LiteLLM embedding failed: %s - %s",
                    response.status, error_text
                )

                self._raise_for_auth_or_rate_limit(response.status)
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="LiteLLM service error"
                )
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.error("LiteLLM embedding request error: %s", e)
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LiteLLM service unavailable"
            ) from e

    async def proxy_request(
        self,
        method: str,
        endpoint: str,
        payload: Dict[str, Any],
        user_id: str,
        api_key_id: int
    ) -> Dict[str, Any]:
        """Generic proxy request to LiteLLM.

        Dict payloads get tracking metadata (and a default ``user``) added
        on a copy. The payload is sent as a JSON body for POST/PUT/PATCH, as
        query parameters for GET, and not at all for other methods.

        Returns:
            The proxy's JSON response on HTTP 200.

        Raises:
            HTTPException: 401/429 mirrored from upstream; any other non-200
                status is re-raised with the upstream status code and body;
                503 on a transport failure/timeout.
        """
        verb = method.upper()
        if isinstance(payload, dict):
            # Copy so the caller's dict is left unmodified.
            payload = self._with_tracking(payload, user_id, api_key_id)

        try:
            session = await self._get_session()

            # NOTE(review): for GET the nested "metadata" dict ends up in the
            # query parameters; query values are normally scalars, so this
            # path may be rejected when the URL is built - confirm.
            async with session.request(
                method,
                f"{self.base_url}/{endpoint.lstrip('/')}",
                json=payload if verb in ('POST', 'PUT', 'PATCH') else None,
                params=payload if verb == 'GET' else None
            ) as response:
                if response.status == 200:
                    return await response.json()

                error_text = await response.text()
                logger.error(
                    "LiteLLM proxy request failed: %s - %s",
                    response.status, error_text
                )

                self._raise_for_auth_or_rate_limit(response.status)
                raise HTTPException(
                    status_code=response.status,
                    detail=f"LiteLLM service error: {error_text}"
                )
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.error("LiteLLM proxy request error: %s", e)
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="LiteLLM service unavailable"
            ) from e

    async def list_models(self) -> List[str]:
        """Get list of available model names/IDs.

        Entries with neither an ``id`` nor a ``model`` value are skipped.
        Returns an empty list if anything goes wrong.
        """
        try:
            models_data = await self.get_models()
            return [
                model.get("id", model.get("model", ""))
                for model in models_data
                if model.get("id") or model.get("model")
            ]
        except Exception as e:
            logger.error("Error listing model names: %s", e)
            return []
|
|
|
|
|
|
# Global LiteLLM client instance (module-level singleton).
litellm_client = LiteLLMClient()