Mirror of https://github.com/aljazceru/enclava.git (synced 2025-12-17 23:44:24 +01:00)

Commit: removing LiteLLM and going directly for privatemode
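Orientation before the diff: the commit swaps dict-based calls to `litellm_client` for typed requests against the in-repo LLM service (`app.services.llm.service.llm_service`). The sketch below only illustrates the new call shape for chat completions, using the imports and constructor fields that appear in the hunks; the model id, user id, and API-key values are placeholders, and the assumption that `llm_service` is already initialized at import time is not confirmed by this commit.

```python
# Minimal sketch of the call pattern introduced by this commit (not code from the repo).
# Field names come from the hunks below; concrete values are placeholder assumptions.
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage


async def example_chat_call() -> str:
    # A typed request object replaces the old dict payload plus **kwargs.
    request = ChatRequest(
        model="example-model",                 # placeholder model id
        messages=[ChatMessage(role="user", content="Say hello.")],
        temperature=0.3,
        max_tokens=100,
        stream=False,
        user_id="42",                          # caller identity, as in the endpoint code
        api_key_id=0,                          # 0 is the JWT-user fallback used in the diff
    )
    response = await llm_service.create_chat_completion(request)
    # Responses are typed objects, so fields are read via attributes, not dict keys.
    return response.choices[0].message.content
```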
@@ -1,5 +1,5 @@
"""
LLM API endpoints - proxy to LiteLLM service with authentication and budget enforcement
LLM API endpoints - interface to secure LLM service with authentication and budget enforcement
"""

import logging
@@ -16,7 +16,9 @@ from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthS
from app.core.security import get_current_user
from app.models.user import User
from app.core.config import settings
from app.services.litellm_client import litellm_client
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
from app.services.budget_enforcement import (
check_budget_for_request, record_request_usage, BudgetEnforcementService,
atomic_check_and_reserve_budget, atomic_finalize_usage
@@ -38,7 +40,7 @@ router = APIRouter()


async def get_cached_models() -> List[Dict[str, Any]]:
"""Get models from cache or fetch from LiteLLM if cache is stale"""
"""Get models from cache or fetch from LLM service if cache is stale"""
current_time = time.time()

# Check if cache is still valid
@@ -47,10 +49,20 @@ async def get_cached_models() -> List[Dict[str, Any]]:
logger.debug("Returning cached models list")
return _models_cache["data"]

# Cache miss or stale - fetch from LiteLLM
# Cache miss or stale - fetch from LLM service
try:
logger.debug("Fetching fresh models list from LiteLLM")
models = await litellm_client.get_models()
logger.debug("Fetching fresh models list from LLM service")
model_infos = await llm_service.get_models()

# Convert ModelInfo objects to dict format for compatibility
models = []
for model_info in model_infos:
models.append({
"id": model_info.id,
"object": model_info.object,
"created": model_info.created or int(time.time()),
"owned_by": model_info.owned_by
})

# Update cache
_models_cache["data"] = models
@@ -58,7 +70,7 @@ async def get_cached_models() -> List[Dict[str, Any]]:

return models
except Exception as e:
logger.error(f"Failed to fetch models from LiteLLM: {e}")
logger.error(f"Failed to fetch models from LLM service: {e}")

# Return stale cache if available, otherwise empty list
if _models_cache["data"] is not None:
@@ -75,7 +87,7 @@ def invalidate_models_cache():
logger.info("Models cache invalidated")


# Request/Response Models
# Request/Response Models (API layer)
class ChatMessage(BaseModel):
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
@@ -183,7 +195,7 @@ async def list_models(
detail="Insufficient permissions to list models"
)

# Get models from cache or LiteLLM
# Get models from cache or LLM service
models = await get_cached_models()

# Filter models based on API key permissions
@@ -309,35 +321,55 @@ async def create_chat_completion(
warnings = budget_warnings
reserved_budget_ids = budget_ids

# Convert messages to dict format
messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]
# Convert messages to LLM service format
llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages]

# Prepare additional parameters
kwargs = {}
if chat_request.max_tokens is not None:
kwargs["max_tokens"] = chat_request.max_tokens
if chat_request.temperature is not None:
kwargs["temperature"] = chat_request.temperature
if chat_request.top_p is not None:
kwargs["top_p"] = chat_request.top_p
if chat_request.frequency_penalty is not None:
kwargs["frequency_penalty"] = chat_request.frequency_penalty
if chat_request.presence_penalty is not None:
kwargs["presence_penalty"] = chat_request.presence_penalty
if chat_request.stop is not None:
kwargs["stop"] = chat_request.stop
if chat_request.stream is not None:
kwargs["stream"] = chat_request.stream

# Make request to LiteLLM
response = await litellm_client.create_chat_completion(
# Create LLM service request
llm_request = ChatRequest(
model=chat_request.model,
messages=messages,
messages=llm_messages,
temperature=chat_request.temperature,
max_tokens=chat_request.max_tokens,
top_p=chat_request.top_p,
frequency_penalty=chat_request.frequency_penalty,
presence_penalty=chat_request.presence_penalty,
stop=chat_request.stop,
stream=chat_request.stream or False,
user_id=str(context.get("user_id", "anonymous")),
api_key_id=context.get("api_key_id", "jwt_user"),
**kwargs
api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0
)

# Make request to LLM service
llm_response = await llm_service.create_chat_completion(llm_request)

# Convert LLM service response to API format
response = {
"id": llm_response.id,
"object": llm_response.object,
"created": llm_response.created,
"model": llm_response.model,
"choices": [
{
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content
},
"finish_reason": choice.finish_reason
}
for choice in llm_response.choices
],
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}

# Calculate actual cost and update usage
usage = response.get("usage", {})
input_tokens = usage.get("prompt_tokens", 0)
@@ -382,8 +414,38 @@ async def create_chat_completion(

except HTTPException:
raise
except SecurityError as e:
logger.warning(f"Security error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
)
except ValidationError as e:
logger.warning(f"Validation error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
)
except ProviderError as e:
logger.error(f"Provider error in chat completion: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
)
except LLMError as e:
logger.error(f"LLM service error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
)
except Exception as e:
logger.error(f"Error creating chat completion: {e}")
logger.error(f"Unexpected error creating chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create chat completion"
@@ -438,15 +500,39 @@ async def create_embedding(
detail=f"Budget exceeded: {error_message}"
)

# Make request to LiteLLM
response = await litellm_client.create_embedding(
# Create LLM service request
llm_request = LLMEmbeddingRequest(
model=request.model,
input_text=request.input,
input=request.input,
encoding_format=request.encoding_format,
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"],
encoding_format=request.encoding_format
api_key_id=context["api_key_id"]
)

# Make request to LLM service
llm_response = await llm_service.create_embedding(llm_request)

# Convert LLM service response to API format
response = {
"object": llm_response.object,
"data": [
{
"object": emb.object,
"index": emb.index,
"embedding": emb.embedding
}
for emb in llm_response.data
],
"model": llm_response.model,
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": int(estimated_tokens),
"total_tokens": int(estimated_tokens)
}
}

# Calculate actual cost and update usage
usage = response.get("usage", {})
total_tokens = usage.get("total_tokens", int(estimated_tokens))
@@ -475,8 +561,38 @@ async def create_embedding(

except HTTPException:
raise
except SecurityError as e:
logger.warning(f"Security error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
)
except ValidationError as e:
logger.warning(f"Validation error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
)
except ProviderError as e:
logger.error(f"Provider error in embedding: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
)
except LLMError as e:
logger.error(f"LLM service error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
)
except Exception as e:
logger.error(f"Error creating embedding: {e}")
logger.error(f"Unexpected error creating embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create embedding"
@@ -489,11 +605,28 @@ async def llm_health_check(
):
"""Health check for LLM service"""
try:
health_status = await litellm_client.health_check()
health_summary = llm_service.get_health_summary()
provider_status = await llm_service.get_provider_status()

# Determine overall health
overall_status = "healthy"
if health_summary["service_status"] != "healthy":
overall_status = "degraded"

for provider, status in provider_status.items():
if status.status == "unavailable":
overall_status = "degraded"
break

return {
"status": "healthy",
"service": "LLM Proxy",
"litellm_status": health_status,
"status": overall_status,
"service": "LLM Service",
"service_status": health_summary,
"provider_status": {name: {
"status": status.status,
"latency_ms": status.latency_ms,
"error_message": status.error_message
} for name, status in provider_status.items()},
"user_id": context["user_id"],
"api_key_name": context["api_key_name"]
}
@@ -501,7 +634,7 @@ async def llm_health_check(
logger.error(f"LLM health check error: {e}")
return {
"status": "unhealthy",
"service": "LLM Proxy",
"service": "LLM Service",
"error": str(e)
}
@@ -626,50 +759,83 @@ async def get_budget_status(
)


# Generic proxy endpoint for other LiteLLM endpoints
@router.api_route("/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_endpoint(
endpoint: str,
request: Request,
# Generic endpoint for additional LLM service functionality
@router.get("/metrics")
async def get_llm_metrics(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
):
"""Generic proxy endpoint for LiteLLM requests"""
"""Get LLM service metrics (admin only)"""
try:
# Check for admin permissions
auth_service = APIKeyAuthService(db)

# Check endpoint permission
if not await auth_service.check_endpoint_permission(context, endpoint):
if not await auth_service.check_scope_permission(context, "admin.metrics"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Endpoint '{endpoint}' not allowed"
detail="Admin permissions required to view metrics"
)

# Get request body
if request.method in ["POST", "PUT", "PATCH"]:
try:
payload = await request.json()
except:
payload = {}
else:
payload = dict(request.query_params)

# Make request to LiteLLM
response = await litellm_client.proxy_request(
method=request.method,
endpoint=endpoint,
payload=payload,
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"]
)

return response
metrics = llm_service.get_metrics()
return {
"object": "llm_metrics",
"data": {
"total_requests": metrics.total_requests,
"successful_requests": metrics.successful_requests,
"failed_requests": metrics.failed_requests,
"security_blocked_requests": metrics.security_blocked_requests,
"average_latency_ms": metrics.average_latency_ms,
"average_risk_score": metrics.average_risk_score,
"provider_metrics": metrics.provider_metrics,
"last_updated": metrics.last_updated.isoformat()
}
}

except HTTPException:
raise
except Exception as e:
logger.error(f"Error proxying request to {endpoint}: {e}")
logger.error(f"Error getting LLM metrics: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to proxy request"
detail="Failed to get LLM metrics"
)


@router.get("/providers/status")
async def get_provider_status(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
):
"""Get status of all LLM providers"""
try:
auth_service = APIKeyAuthService(db)
if not await auth_service.check_scope_permission(context, "admin.status"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to view provider status"
)

provider_status = await llm_service.get_provider_status()
return {
"object": "provider_status",
"data": {
name: {
"provider": status.provider,
"status": status.status,
"latency_ms": status.latency_ms,
"success_rate": status.success_rate,
"last_check": status.last_check.isoformat(),
"error_message": status.error_message,
"models_available": status.models_available
}
for name, status in provider_status.items()
}
}

except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting provider status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get provider status"
)
@@ -447,7 +447,7 @@ async def get_module_config(module_name: str):
log_api_request("get_module_config", {"module_name": module_name})

from app.services.module_config_manager import module_config_manager
from app.services.litellm_client import litellm_client
from app.services.llm.service import llm_service
import copy

# Get module manifest and schema
@@ -461,9 +461,9 @@ async def get_module_config(module_name: str):
# For Signal module, populate model options dynamically
if module_name == "signal" and schema:
try:
# Get available models from LiteLLM
models_data = await litellm_client.get_models()
model_ids = [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")]
# Get available models from LLM service
models_data = await llm_service.get_models()
model_ids = [model.id for model in models_data]

if model_ids:
# Create a copy of the schema to avoid modifying the original
@@ -15,7 +15,8 @@ from app.models.prompt_template import PromptTemplate, ChatbotPromptVariable
from app.core.security import get_current_user
from app.models.user import User
from app.core.logging import log_api_request
from app.services.litellm_client import litellm_client
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage

router = APIRouter()

@@ -394,25 +395,28 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
]

# Get available models to use a default model
models = await litellm_client.get_models()
models = await llm_service.get_models()
if not models:
raise HTTPException(status_code=503, detail="No LLM models available")

# Use the first available model (you might want to make this configurable)
default_model = models[0]["id"]
default_model = models[0].id

# Make the AI call
response = await litellm_client.create_chat_completion(
# Prepare the chat request for the new LLM service
chat_request = LLMChatRequest(
model=default_model,
messages=messages,
user_id=str(user_id),
api_key_id=1, # Using default API key, you might want to make this dynamic
messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
temperature=0.3,
max_tokens=1000
max_tokens=1000,
user_id=str(user_id),
api_key_id=1 # Using default API key, you might want to make this dynamic
)

# Make the AI call
response = await llm_service.create_chat_completion(chat_request)

# Extract the improved prompt from the response
improved_prompt = response["choices"][0]["message"]["content"].strip()
improved_prompt = response.choices[0].message.content.strip()

return {
"improved_prompt": improved_prompt,
@@ -51,7 +51,7 @@ class SystemInfoResponse(BaseModel):
environment: str
database_status: str
redis_status: str
litellm_status: str
llm_service_status: str
modules_loaded: int
active_users: int
total_api_keys: int
@@ -227,8 +227,13 @@ async def get_system_info(
# Get Redis status (simplified check)
redis_status = "healthy" # Would implement actual Redis check

# Get LiteLLM status (simplified check)
litellm_status = "healthy" # Would implement actual LiteLLM check
# Get LLM service status
try:
from app.services.llm.service import llm_service
health_summary = llm_service.get_health_summary()
llm_service_status = health_summary.get("service_status", "unknown")
except Exception:
llm_service_status = "unavailable"

# Get modules loaded (from module manager)
modules_loaded = 8 # Would get from actual module manager
@@ -261,7 +266,7 @@ async def get_system_info(
environment="production",
database_status=database_status,
redis_status=redis_status,
litellm_status=litellm_status,
llm_service_status=llm_service_status,
modules_loaded=modules_loaded,
active_users=active_users,
total_api_keys=total_api_keys,