removing LiteLLM and going directly to PrivateMode

2025-08-21 08:44:05 +02:00
parent be581b28f8
commit 27ee8b4cdb
16 changed files with 1775 additions and 677 deletions

View File

@@ -1,5 +1,5 @@
"""
LLM API endpoints - proxy to LiteLLM service with authentication and budget enforcement
LLM API endpoints - interface to secure LLM service with authentication and budget enforcement
"""
import logging
@@ -16,7 +16,9 @@ from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthS
from app.core.security import get_current_user
from app.models.user import User
from app.core.config import settings
from app.services.litellm_client import litellm_client
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
from app.services.budget_enforcement import (
check_budget_for_request, record_request_usage, BudgetEnforcementService,
atomic_check_and_reserve_budget, atomic_finalize_usage
@@ -38,7 +40,7 @@ router = APIRouter()
async def get_cached_models() -> List[Dict[str, Any]]:
"""Get models from cache or fetch from LiteLLM if cache is stale"""
"""Get models from cache or fetch from LLM service if cache is stale"""
current_time = time.time()
# Check if cache is still valid
@@ -47,10 +49,20 @@ async def get_cached_models() -> List[Dict[str, Any]]:
logger.debug("Returning cached models list")
return _models_cache["data"]
# Cache miss or stale - fetch from LiteLLM
# Cache miss or stale - fetch from LLM service
try:
logger.debug("Fetching fresh models list from LiteLLM")
models = await litellm_client.get_models()
logger.debug("Fetching fresh models list from LLM service")
model_infos = await llm_service.get_models()
# Convert ModelInfo objects to dict format for compatibility
models = []
for model_info in model_infos:
models.append({
"id": model_info.id,
"object": model_info.object,
"created": model_info.created or int(time.time()),
"owned_by": model_info.owned_by
})
# Update cache
_models_cache["data"] = models
@@ -58,7 +70,7 @@ async def get_cached_models() -> List[Dict[str, Any]]:
return models
except Exception as e:
logger.error(f"Failed to fetch models from LiteLLM: {e}")
logger.error(f"Failed to fetch models from LLM service: {e}")
# Return stale cache if available, otherwise empty list
if _models_cache["data"] is not None:
@@ -75,7 +87,7 @@ def invalidate_models_cache():
logger.info("Models cache invalidated")
# Request/Response Models
# Request/Response Models (API layer)
class ChatMessage(BaseModel):
role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content")
@@ -183,7 +195,7 @@ async def list_models(
detail="Insufficient permissions to list models"
)
# Get models from cache or LiteLLM
# Get models from cache or LLM service
models = await get_cached_models()
# Filter models based on API key permissions
@@ -309,35 +321,55 @@ async def create_chat_completion(
warnings = budget_warnings
reserved_budget_ids = budget_ids
# Convert messages to dict format
messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]
# Convert messages to LLM service format
llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages]
# Prepare additional parameters
kwargs = {}
if chat_request.max_tokens is not None:
kwargs["max_tokens"] = chat_request.max_tokens
if chat_request.temperature is not None:
kwargs["temperature"] = chat_request.temperature
if chat_request.top_p is not None:
kwargs["top_p"] = chat_request.top_p
if chat_request.frequency_penalty is not None:
kwargs["frequency_penalty"] = chat_request.frequency_penalty
if chat_request.presence_penalty is not None:
kwargs["presence_penalty"] = chat_request.presence_penalty
if chat_request.stop is not None:
kwargs["stop"] = chat_request.stop
if chat_request.stream is not None:
kwargs["stream"] = chat_request.stream
# Make request to LiteLLM
response = await litellm_client.create_chat_completion(
# Create LLM service request
llm_request = ChatRequest(
model=chat_request.model,
messages=messages,
messages=llm_messages,
temperature=chat_request.temperature,
max_tokens=chat_request.max_tokens,
top_p=chat_request.top_p,
frequency_penalty=chat_request.frequency_penalty,
presence_penalty=chat_request.presence_penalty,
stop=chat_request.stop,
stream=chat_request.stream or False,
user_id=str(context.get("user_id", "anonymous")),
api_key_id=context.get("api_key_id", "jwt_user"),
**kwargs
api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0
)
# Make request to LLM service
llm_response = await llm_service.create_chat_completion(llm_request)
# Convert LLM service response to API format
response = {
"id": llm_response.id,
"object": llm_response.object,
"created": llm_response.created,
"model": llm_response.model,
"choices": [
{
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content
},
"finish_reason": choice.finish_reason
}
for choice in llm_response.choices
],
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
# Calculate actual cost and update usage
usage = response.get("usage", {})
input_tokens = usage.get("prompt_tokens", 0)
@@ -382,8 +414,38 @@ async def create_chat_completion(
except HTTPException:
raise
except SecurityError as e:
logger.warning(f"Security error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
)
except ValidationError as e:
logger.warning(f"Validation error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
)
except ProviderError as e:
logger.error(f"Provider error in chat completion: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
)
except LLMError as e:
logger.error(f"LLM service error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
)
except Exception as e:
logger.error(f"Error creating chat completion: {e}")
logger.error(f"Unexpected error creating chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create chat completion"
@@ -438,15 +500,39 @@ async def create_embedding(
detail=f"Budget exceeded: {error_message}"
)
# Make request to LiteLLM
response = await litellm_client.create_embedding(
# Create LLM service request
llm_request = LLMEmbeddingRequest(
model=request.model,
input_text=request.input,
input=request.input,
encoding_format=request.encoding_format,
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"],
encoding_format=request.encoding_format
api_key_id=context["api_key_id"]
)
# Make request to LLM service
llm_response = await llm_service.create_embedding(llm_request)
# Convert LLM service response to API format
response = {
"object": llm_response.object,
"data": [
{
"object": emb.object,
"index": emb.index,
"embedding": emb.embedding
}
for emb in llm_response.data
],
"model": llm_response.model,
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": int(estimated_tokens),
"total_tokens": int(estimated_tokens)
}
}
# Calculate actual cost and update usage
usage = response.get("usage", {})
total_tokens = usage.get("total_tokens", int(estimated_tokens))
@@ -475,8 +561,38 @@ async def create_embedding(
except HTTPException:
raise
except SecurityError as e:
logger.warning(f"Security error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
)
except ValidationError as e:
logger.warning(f"Validation error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
)
except ProviderError as e:
logger.error(f"Provider error in embedding: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
)
except LLMError as e:
logger.error(f"LLM service error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
)
except Exception as e:
logger.error(f"Error creating embedding: {e}")
logger.error(f"Unexpected error creating embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create embedding"
@@ -489,11 +605,28 @@ async def llm_health_check(
):
"""Health check for LLM service"""
try:
health_status = await litellm_client.health_check()
health_summary = llm_service.get_health_summary()
provider_status = await llm_service.get_provider_status()
# Determine overall health
overall_status = "healthy"
if health_summary["service_status"] != "healthy":
overall_status = "degraded"
for provider, status in provider_status.items():
if status.status == "unavailable":
overall_status = "degraded"
break
return {
"status": "healthy",
"service": "LLM Proxy",
"litellm_status": health_status,
"status": overall_status,
"service": "LLM Service",
"service_status": health_summary,
"provider_status": {name: {
"status": status.status,
"latency_ms": status.latency_ms,
"error_message": status.error_message
} for name, status in provider_status.items()},
"user_id": context["user_id"],
"api_key_name": context["api_key_name"]
}
@@ -501,7 +634,7 @@ async def llm_health_check(
logger.error(f"LLM health check error: {e}")
return {
"status": "unhealthy",
"service": "LLM Proxy",
"service": "LLM Service",
"error": str(e)
}
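For reference, the reworked health endpoint now nests service-level and per-provider detail instead of echoing a flat LiteLLM status. A purely illustrative example of the new success payload; the "privatemode" provider name and all values here are assumptions, not taken from the diff:

# Illustrative payload only; provider name and values are invented.
example_health_response = {
    "status": "degraded",
    "service": "LLM Service",
    "service_status": {"service_status": "healthy"},
    "provider_status": {
        "privatemode": {
            "status": "unavailable",
            "latency_ms": None,
            "error_message": "connection refused",
        },
    },
    "user_id": 42,
    "api_key_name": "ci-key",
}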
@@ -626,50 +759,83 @@ async def get_budget_status(
)
# Generic proxy endpoint for other LiteLLM endpoints
@router.api_route("/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_endpoint(
endpoint: str,
request: Request,
# Generic endpoint for additional LLM service functionality
@router.get("/metrics")
async def get_llm_metrics(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
):
"""Generic proxy endpoint for LiteLLM requests"""
"""Get LLM service metrics (admin only)"""
try:
# Check for admin permissions
auth_service = APIKeyAuthService(db)
# Check endpoint permission
if not await auth_service.check_endpoint_permission(context, endpoint):
if not await auth_service.check_scope_permission(context, "admin.metrics"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Endpoint '{endpoint}' not allowed"
detail="Admin permissions required to view metrics"
)
# Get request body
if request.method in ["POST", "PUT", "PATCH"]:
try:
payload = await request.json()
except:
payload = {}
else:
payload = dict(request.query_params)
# Make request to LiteLLM
response = await litellm_client.proxy_request(
method=request.method,
endpoint=endpoint,
payload=payload,
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"]
)
return response
metrics = llm_service.get_metrics()
return {
"object": "llm_metrics",
"data": {
"total_requests": metrics.total_requests,
"successful_requests": metrics.successful_requests,
"failed_requests": metrics.failed_requests,
"security_blocked_requests": metrics.security_blocked_requests,
"average_latency_ms": metrics.average_latency_ms,
"average_risk_score": metrics.average_risk_score,
"provider_metrics": metrics.provider_metrics,
"last_updated": metrics.last_updated.isoformat()
}
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error proxying request to {endpoint}: {e}")
logger.error(f"Error getting LLM metrics: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to proxy request"
detail="Failed to get LLM metrics"
)
@router.get("/providers/status")
async def get_provider_status(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
):
"""Get status of all LLM providers"""
try:
auth_service = APIKeyAuthService(db)
if not await auth_service.check_scope_permission(context, "admin.status"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to view provider status"
)
provider_status = await llm_service.get_provider_status()
return {
"object": "provider_status",
"data": {
name: {
"provider": status.provider,
"status": status.status,
"latency_ms": status.latency_ms,
"success_rate": status.success_rate,
"last_check": status.last_check.isoformat(),
"error_message": status.error_message,
"models_available": status.models_available
}
for name, status in provider_status.items()
}
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting provider status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get provider status"
)
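The request/response models and exceptions imported at the top of this file are not shown in the diff. What follows is a minimal sketch of what app.services.llm.models and app.services.llm.exceptions might look like, reconstructed only from the attributes the endpoint code above touches; field names, defaults, and the use of dataclasses are assumptions.

# Hypothetical sketch of app/services/llm/models.py, inferred from usage
# in the endpoints above; the real definitions may differ.
from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class ChatMessage:
    role: str      # "system", "user", or "assistant"
    content: str

@dataclass
class ChatRequest:
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: bool = False
    user_id: str = "anonymous"
    api_key_id: int = 0

@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

# Hypothetical sketch of app/services/llm/exceptions.py, matching how the
# handlers above map each class to an HTTP status code.
class LLMError(Exception):            # handlers map this to 500
    def __init__(self, message: str):
        self.message = message
        super().__init__(message)

class SecurityError(LLMError):        # mapped to 400
    pass

class ValidationError(LLMError):      # mapped to 400
    pass

class ProviderError(LLMError):        # mapped to 429 on rate limit, else 503
    pass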

View File

@@ -447,7 +447,7 @@ async def get_module_config(module_name: str):
log_api_request("get_module_config", {"module_name": module_name})
from app.services.module_config_manager import module_config_manager
from app.services.litellm_client import litellm_client
from app.services.llm.service import llm_service
import copy
# Get module manifest and schema
@@ -461,9 +461,9 @@ async def get_module_config(module_name: str):
# For Signal module, populate model options dynamically
if module_name == "signal" and schema:
try:
# Get available models from LiteLLM
models_data = await litellm_client.get_models()
model_ids = [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")]
# Get available models from LLM service
models_data = await llm_service.get_models()
model_ids = [model.id for model in models_data]
if model_ids:
# Create a copy of the schema to avoid modifying the original

View File

@@ -15,7 +15,8 @@ from app.models.prompt_template import PromptTemplate, ChatbotPromptVariable
from app.core.security import get_current_user
from app.models.user import User
from app.core.logging import log_api_request
from app.services.litellm_client import litellm_client
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
router = APIRouter()
@@ -394,25 +395,28 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
]
# Get available models to use a default model
models = await litellm_client.get_models()
models = await llm_service.get_models()
if not models:
raise HTTPException(status_code=503, detail="No LLM models available")
# Use the first available model (you might want to make this configurable)
default_model = models[0]["id"]
default_model = models[0].id
# Make the AI call
response = await litellm_client.create_chat_completion(
# Prepare the chat request for the new LLM service
chat_request = LLMChatRequest(
model=default_model,
messages=messages,
user_id=str(user_id),
api_key_id=1, # Using default API key, you might want to make this dynamic
messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
temperature=0.3,
max_tokens=1000
max_tokens=1000,
user_id=str(user_id),
api_key_id=1 # Using default API key, you might want to make this dynamic
)
# Make the AI call
response = await llm_service.create_chat_completion(chat_request)
# Extract the improved prompt from the response
improved_prompt = response["choices"][0]["message"]["content"].strip()
improved_prompt = response.choices[0].message.content.strip()
return {
"improved_prompt": improved_prompt,

View File

@@ -51,7 +51,7 @@ class SystemInfoResponse(BaseModel):
environment: str
database_status: str
redis_status: str
litellm_status: str
llm_service_status: str
modules_loaded: int
active_users: int
total_api_keys: int
@@ -227,8 +227,13 @@ async def get_system_info(
# Get Redis status (simplified check)
redis_status = "healthy" # Would implement actual Redis check
# Get LiteLLM status (simplified check)
litellm_status = "healthy" # Would implement actual LiteLLM check
# Get LLM service status
try:
from app.services.llm.service import llm_service
health_summary = llm_service.get_health_summary()
llm_service_status = health_summary.get("service_status", "unknown")
except Exception:
llm_service_status = "unavailable"
# Get modules loaded (from module manager)
modules_loaded = 8 # Would get from actual module manager
@@ -261,7 +266,7 @@ async def get_system_info(
environment="production",
database_status=database_status,
redis_status=redis_status,
litellm_status=litellm_status,
llm_service_status=llm_service_status,
modules_loaded=modules_loaded,
active_users=active_users,
total_api_keys=total_api_keys,
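The admin endpoint above falls back to "unavailable" when the service cannot be imported or queried. Combining it with the provider probe from the health endpoint earlier in this commit, a caller could compute an overall status roughly as sketched below; the method signatures follow the diff (get_health_summary sync, get_provider_status async), but the exact fields are assumptions.

# Usage sketch combining the two probes this commit introduces.
from app.services.llm.service import llm_service

async def llm_overall_status() -> str:
    try:
        summary = llm_service.get_health_summary()
        if summary.get("service_status") != "healthy":
            return "degraded"
        providers = await llm_service.get_provider_status()
        if any(p.status == "unavailable" for p in providers.values()):
            return "degraded"
        return "healthy"
    except Exception:
        return "unavailable"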