Removing LiteLLM and going directly to Privatemode

2025-08-21 08:44:05 +02:00
parent be581b28f8
commit 27ee8b4cdb
16 changed files with 1775 additions and 677 deletions
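
In short: handlers that previously built dict payloads and proxied them through litellm_client now construct typed request objects and call the in-process llm_service directly. An abridged before/after sketch of the chat-completion path, using only names that appear in the diff below (budget checks and error handling elided):

# Before: dict payload proxied to LiteLLM
messages = [{"role": m.role, "content": m.content} for m in chat_request.messages]
response = await litellm_client.create_chat_completion(
    model=chat_request.model,
    messages=messages,
    user_id=str(context["user_id"]),
    api_key_id=context["api_key_id"],
    **kwargs,  # max_tokens, temperature, top_p, ...
)

# After: typed ChatRequest handled by llm_service, then converted back to an OpenAI-style dict
llm_messages = [LLMChatMessage(role=m.role, content=m.content) for m in chat_request.messages]
llm_request = ChatRequest(
    model=chat_request.model,
    messages=llm_messages,
    temperature=chat_request.temperature,
    max_tokens=chat_request.max_tokens,
    stream=chat_request.stream or False,
    user_id=str(context.get("user_id", "anonymous")),
    api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0,
)
llm_response = await llm_service.create_chat_completion(llm_request)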


@@ -1,5 +1,5 @@
"""
LLM API endpoints - proxy to LiteLLM service with authentication and budget enforcement
LLM API endpoints - interface to secure LLM service with authentication and budget enforcement
"""
import logging
@@ -16,7 +16,9 @@ from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthS
from app.core.security import get_current_user
from app.models.user import User
from app.core.config import settings
from app.services.litellm_client import litellm_client
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
from app.services.budget_enforcement import (
check_budget_for_request, record_request_usage, BudgetEnforcementService,
atomic_check_and_reserve_budget, atomic_finalize_usage
@@ -38,7 +40,7 @@ router = APIRouter()
async def get_cached_models() -> List[Dict[str, Any]]:
"""Get models from cache or fetch from LiteLLM if cache is stale"""
"""Get models from cache or fetch from LLM service if cache is stale"""
current_time = time.time()
# Check if cache is still valid
@@ -47,10 +49,20 @@ async def get_cached_models() -> List[Dict[str, Any]]:
logger.debug("Returning cached models list")
return _models_cache["data"]
# Cache miss or stale - fetch from LiteLLM
# Cache miss or stale - fetch from LLM service
try:
logger.debug("Fetching fresh models list from LiteLLM")
models = await litellm_client.get_models()
logger.debug("Fetching fresh models list from LLM service")
model_infos = await llm_service.get_models()
# Convert ModelInfo objects to dict format for compatibility
models = []
for model_info in model_infos:
models.append({
"id": model_info.id,
"object": model_info.object,
"created": model_info.created or int(time.time()),
"owned_by": model_info.owned_by
})
# Update cache
_models_cache["data"] = models
@@ -58,7 +70,7 @@ async def get_cached_models() -> List[Dict[str, Any]]:
return models
except Exception as e:
logger.error(f"Failed to fetch models from LiteLLM: {e}")
logger.error(f"Failed to fetch models from LLM service: {e}")
# Return stale cache if available, otherwise empty list
if _models_cache["data"] is not None:
@@ -75,7 +87,7 @@ def invalidate_models_cache():
logger.info("Models cache invalidated")
# Request/Response Models
# Request/Response Models (API layer)
class ChatMessage(BaseModel):
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
@@ -183,7 +195,7 @@ async def list_models(
detail="Insufficient permissions to list models"
)
# Get models from cache or LiteLLM
# Get models from cache or LLM service
models = await get_cached_models()
# Filter models based on API key permissions
@@ -309,35 +321,55 @@ async def create_chat_completion(
warnings = budget_warnings
reserved_budget_ids = budget_ids
# Convert messages to dict format
messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]
# Convert messages to LLM service format
llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages]
# Prepare additional parameters
kwargs = {}
if chat_request.max_tokens is not None:
kwargs["max_tokens"] = chat_request.max_tokens
if chat_request.temperature is not None:
kwargs["temperature"] = chat_request.temperature
if chat_request.top_p is not None:
kwargs["top_p"] = chat_request.top_p
if chat_request.frequency_penalty is not None:
kwargs["frequency_penalty"] = chat_request.frequency_penalty
if chat_request.presence_penalty is not None:
kwargs["presence_penalty"] = chat_request.presence_penalty
if chat_request.stop is not None:
kwargs["stop"] = chat_request.stop
if chat_request.stream is not None:
kwargs["stream"] = chat_request.stream
# Make request to LiteLLM
response = await litellm_client.create_chat_completion(
# Create LLM service request
llm_request = ChatRequest(
model=chat_request.model,
messages=messages,
messages=llm_messages,
temperature=chat_request.temperature,
max_tokens=chat_request.max_tokens,
top_p=chat_request.top_p,
frequency_penalty=chat_request.frequency_penalty,
presence_penalty=chat_request.presence_penalty,
stop=chat_request.stop,
stream=chat_request.stream or False,
user_id=str(context.get("user_id", "anonymous")),
api_key_id=context.get("api_key_id", "jwt_user"),
**kwargs
api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0
)
# Make request to LLM service
llm_response = await llm_service.create_chat_completion(llm_request)
# Convert LLM service response to API format
response = {
"id": llm_response.id,
"object": llm_response.object,
"created": llm_response.created,
"model": llm_response.model,
"choices": [
{
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content
},
"finish_reason": choice.finish_reason
}
for choice in llm_response.choices
],
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
# Calculate actual cost and update usage
usage = response.get("usage", {})
input_tokens = usage.get("prompt_tokens", 0)
@@ -382,8 +414,38 @@ async def create_chat_completion(
except HTTPException:
raise
except SecurityError as e:
logger.warning(f"Security error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
)
except ValidationError as e:
logger.warning(f"Validation error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
)
except ProviderError as e:
logger.error(f"Provider error in chat completion: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
)
except LLMError as e:
logger.error(f"LLM service error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
)
except Exception as e:
logger.error(f"Error creating chat completion: {e}")
logger.error(f"Unexpected error creating chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create chat completion"
@@ -438,15 +500,39 @@ async def create_embedding(
detail=f"Budget exceeded: {error_message}"
)
# Make request to LiteLLM
response = await litellm_client.create_embedding(
# Create LLM service request
llm_request = LLMEmbeddingRequest(
model=request.model,
input_text=request.input,
input=request.input,
encoding_format=request.encoding_format,
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"],
encoding_format=request.encoding_format
api_key_id=context["api_key_id"]
)
# Make request to LLM service
llm_response = await llm_service.create_embedding(llm_request)
# Convert LLM service response to API format
response = {
"object": llm_response.object,
"data": [
{
"object": emb.object,
"index": emb.index,
"embedding": emb.embedding
}
for emb in llm_response.data
],
"model": llm_response.model,
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": int(estimated_tokens),
"total_tokens": int(estimated_tokens)
}
}
# Calculate actual cost and update usage
usage = response.get("usage", {})
total_tokens = usage.get("total_tokens", int(estimated_tokens))
@@ -475,8 +561,38 @@ async def create_embedding(
except HTTPException:
raise
except SecurityError as e:
logger.warning(f"Security error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
)
except ValidationError as e:
logger.warning(f"Validation error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
)
except ProviderError as e:
logger.error(f"Provider error in embedding: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
)
except LLMError as e:
logger.error(f"LLM service error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
)
except Exception as e:
logger.error(f"Error creating embedding: {e}")
logger.error(f"Unexpected error creating embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create embedding"
@@ -489,11 +605,28 @@ async def llm_health_check(
):
"""Health check for LLM service"""
try:
health_status = await litellm_client.health_check()
health_summary = llm_service.get_health_summary()
provider_status = await llm_service.get_provider_status()
# Determine overall health
overall_status = "healthy"
if health_summary["service_status"] != "healthy":
overall_status = "degraded"
for provider, status in provider_status.items():
if status.status == "unavailable":
overall_status = "degraded"
break
return {
"status": "healthy",
"service": "LLM Proxy",
"litellm_status": health_status,
"status": overall_status,
"service": "LLM Service",
"service_status": health_summary,
"provider_status": {name: {
"status": status.status,
"latency_ms": status.latency_ms,
"error_message": status.error_message
} for name, status in provider_status.items()},
"user_id": context["user_id"],
"api_key_name": context["api_key_name"]
}
@@ -501,7 +634,7 @@ async def llm_health_check(
logger.error(f"LLM health check error: {e}")
return {
"status": "unhealthy",
"service": "LLM Proxy",
"service": "LLM Service",
"error": str(e)
}
@@ -626,50 +759,83 @@ async def get_budget_status(
)
# Generic proxy endpoint for other LiteLLM endpoints
@router.api_route("/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_endpoint(
endpoint: str,
request: Request,
# Generic endpoint for additional LLM service functionality
@router.get("/metrics")
async def get_llm_metrics(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
):
"""Generic proxy endpoint for LiteLLM requests"""
"""Get LLM service metrics (admin only)"""
try:
# Check for admin permissions
auth_service = APIKeyAuthService(db)
# Check endpoint permission
if not await auth_service.check_endpoint_permission(context, endpoint):
if not await auth_service.check_scope_permission(context, "admin.metrics"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Endpoint '{endpoint}' not allowed"
detail="Admin permissions required to view metrics"
)
# Get request body
if request.method in ["POST", "PUT", "PATCH"]:
try:
payload = await request.json()
except:
payload = {}
else:
payload = dict(request.query_params)
# Make request to LiteLLM
response = await litellm_client.proxy_request(
method=request.method,
endpoint=endpoint,
payload=payload,
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"]
)
return response
metrics = llm_service.get_metrics()
return {
"object": "llm_metrics",
"data": {
"total_requests": metrics.total_requests,
"successful_requests": metrics.successful_requests,
"failed_requests": metrics.failed_requests,
"security_blocked_requests": metrics.security_blocked_requests,
"average_latency_ms": metrics.average_latency_ms,
"average_risk_score": metrics.average_risk_score,
"provider_metrics": metrics.provider_metrics,
"last_updated": metrics.last_updated.isoformat()
}
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error proxying request to {endpoint}: {e}")
logger.error(f"Error getting LLM metrics: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to proxy request"
detail="Failed to get LLM metrics"
)
@router.get("/providers/status")
async def get_provider_status(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
):
"""Get status of all LLM providers"""
try:
auth_service = APIKeyAuthService(db)
if not await auth_service.check_scope_permission(context, "admin.status"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to view provider status"
)
provider_status = await llm_service.get_provider_status()
return {
"object": "provider_status",
"data": {
name: {
"provider": status.provider,
"status": status.status,
"latency_ms": status.latency_ms,
"success_rate": status.success_rate,
"last_check": status.last_check.isoformat(),
"error_message": status.error_message,
"models_available": status.models_available
}
for name, status in provider_status.items()
}
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting provider status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get provider status"
)
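
For completeness, a minimal client sketch against the two new admin endpoints defined at the end of this file. The base URL, router prefix, and API-key header are assumptions for illustration, not taken from this commit; adjust them to the actual deployment and to keys carrying the admin.metrics and admin.status scopes.

import asyncio
import httpx

BASE_URL = "http://localhost:8000/api/v1/llm"      # assumed router prefix, not from this diff
HEADERS = {"Authorization": "Bearer <api-key>"}    # assumed header consumed by require_api_key

async def show_llm_admin_views() -> None:
    async with httpx.AsyncClient(base_url=BASE_URL, headers=HEADERS) as client:
        # GET /metrics requires the admin.metrics scope checked in get_llm_metrics
        metrics = (await client.get("/metrics")).json()
        print("requests:", metrics["data"]["total_requests"],
              "avg latency ms:", metrics["data"]["average_latency_ms"])

        # GET /providers/status requires the admin.status scope checked in get_provider_status
        providers = (await client.get("/providers/status")).json()
        for name, info in providers["data"].items():
            print(name, info["status"], info["latency_ms"])

if __name__ == "__main__":
    asyncio.run(show_llm_admin_views())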