diff --git a/backend/app/api/v1/llm.py b/backend/app/api/v1/llm.py index 0a1ddab..cfd8b8a 100644 --- a/backend/app/api/v1/llm.py +++ b/backend/app/api/v1/llm.py @@ -1,5 +1,5 @@ """ -LLM API endpoints - proxy to LiteLLM service with authentication and budget enforcement +LLM API endpoints - interface to secure LLM service with authentication and budget enforcement """ import logging @@ -16,7 +16,9 @@ from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthS from app.core.security import get_current_user from app.models.user import User from app.core.config import settings -from app.services.litellm_client import litellm_client +from app.services.llm.service import llm_service +from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest +from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError from app.services.budget_enforcement import ( check_budget_for_request, record_request_usage, BudgetEnforcementService, atomic_check_and_reserve_budget, atomic_finalize_usage @@ -38,7 +40,7 @@ router = APIRouter() async def get_cached_models() -> List[Dict[str, Any]]: - """Get models from cache or fetch from LiteLLM if cache is stale""" + """Get models from cache or fetch from LLM service if cache is stale""" current_time = time.time() # Check if cache is still valid @@ -47,10 +49,20 @@ async def get_cached_models() -> List[Dict[str, Any]]: logger.debug("Returning cached models list") return _models_cache["data"] - # Cache miss or stale - fetch from LiteLLM + # Cache miss or stale - fetch from LLM service try: - logger.debug("Fetching fresh models list from LiteLLM") - models = await litellm_client.get_models() + logger.debug("Fetching fresh models list from LLM service") + model_infos = await llm_service.get_models() + + # Convert ModelInfo objects to dict format for compatibility + models = [] + for model_info in model_infos: + models.append({ + "id": model_info.id, + "object": model_info.object, + "created": model_info.created or int(time.time()), + "owned_by": model_info.owned_by + }) # Update cache _models_cache["data"] = models @@ -58,7 +70,7 @@ async def get_cached_models() -> List[Dict[str, Any]]: return models except Exception as e: - logger.error(f"Failed to fetch models from LiteLLM: {e}") + logger.error(f"Failed to fetch models from LLM service: {e}") # Return stale cache if available, otherwise empty list if _models_cache["data"] is not None: @@ -75,7 +87,7 @@ def invalidate_models_cache(): logger.info("Models cache invalidated") -# Request/Response Models +# Request/Response Models (API layer) class ChatMessage(BaseModel): role: str = Field(..., description="Message role (system, user, assistant)") content: str = Field(..., description="Message content") @@ -183,7 +195,7 @@ async def list_models( detail="Insufficient permissions to list models" ) - # Get models from cache or LiteLLM + # Get models from cache or LLM service models = await get_cached_models() # Filter models based on API key permissions @@ -309,35 +321,55 @@ async def create_chat_completion( warnings = budget_warnings reserved_budget_ids = budget_ids - # Convert messages to dict format - messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages] + # Convert messages to LLM service format + llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages] - # Prepare additional parameters - kwargs = {} - if chat_request.max_tokens is not 
None: - kwargs["max_tokens"] = chat_request.max_tokens - if chat_request.temperature is not None: - kwargs["temperature"] = chat_request.temperature - if chat_request.top_p is not None: - kwargs["top_p"] = chat_request.top_p - if chat_request.frequency_penalty is not None: - kwargs["frequency_penalty"] = chat_request.frequency_penalty - if chat_request.presence_penalty is not None: - kwargs["presence_penalty"] = chat_request.presence_penalty - if chat_request.stop is not None: - kwargs["stop"] = chat_request.stop - if chat_request.stream is not None: - kwargs["stream"] = chat_request.stream - - # Make request to LiteLLM - response = await litellm_client.create_chat_completion( + # Create LLM service request + llm_request = ChatRequest( model=chat_request.model, - messages=messages, + messages=llm_messages, + temperature=chat_request.temperature, + max_tokens=chat_request.max_tokens, + top_p=chat_request.top_p, + frequency_penalty=chat_request.frequency_penalty, + presence_penalty=chat_request.presence_penalty, + stop=chat_request.stop, + stream=chat_request.stream or False, user_id=str(context.get("user_id", "anonymous")), - api_key_id=context.get("api_key_id", "jwt_user"), - **kwargs + api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0 ) + # Make request to LLM service + llm_response = await llm_service.create_chat_completion(llm_request) + + # Convert LLM service response to API format + response = { + "id": llm_response.id, + "object": llm_response.object, + "created": llm_response.created, + "model": llm_response.model, + "choices": [ + { + "index": choice.index, + "message": { + "role": choice.message.role, + "content": choice.message.content + }, + "finish_reason": choice.finish_reason + } + for choice in llm_response.choices + ], + "usage": { + "prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0, + "completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0, + "total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0 + } if llm_response.usage else { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + } + # Calculate actual cost and update usage usage = response.get("usage", {}) input_tokens = usage.get("prompt_tokens", 0) @@ -382,8 +414,38 @@ async def create_chat_completion( except HTTPException: raise + except SecurityError as e: + logger.warning(f"Security error in chat completion: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Security validation failed: {e.message}" + ) + except ValidationError as e: + logger.warning(f"Validation error in chat completion: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Request validation failed: {e.message}" + ) + except ProviderError as e: + logger.error(f"Provider error in chat completion: {e}") + if "rate limit" in str(e).lower(): + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Rate limit exceeded" + ) + else: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="LLM service temporarily unavailable" + ) + except LLMError as e: + logger.error(f"LLM service error in chat completion: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="LLM service error" + ) except Exception as e: - logger.error(f"Error creating chat completion: {e}") + logger.error(f"Unexpected error creating chat completion: {e}") raise HTTPException( 
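
For quick reference, a minimal sketch of the call pattern this endpoint now delegates to, using only names visible in this diff (llm_service, ChatRequest, ChatMessage); it is illustrative, not the canonical service API:

# Illustrative sketch of the new llm_service chat-completion pattern (names taken from this diff).
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage

async def ask(model: str, prompt: str, user_id: str, api_key_id: int = 0) -> str:
    request = ChatRequest(
        model=model,
        messages=[ChatMessage(role="user", content=prompt)],
        temperature=0.7,
        max_tokens=256,
        stream=False,
        user_id=user_id,
        api_key_id=api_key_id,
    )
    response = await llm_service.create_chat_completion(request)
    # Responses are typed objects rather than dicts: choices[0].message.content, response.usage, etc.
    return response.choices[0].message.content if response.choices else ""
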
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create chat completion" @@ -438,15 +500,39 @@ async def create_embedding( detail=f"Budget exceeded: {error_message}" ) - # Make request to LiteLLM - response = await litellm_client.create_embedding( + # Create LLM service request + llm_request = LLMEmbeddingRequest( model=request.model, - input_text=request.input, + input=request.input, + encoding_format=request.encoding_format, user_id=str(context["user_id"]), - api_key_id=context["api_key_id"], - encoding_format=request.encoding_format + api_key_id=context["api_key_id"] ) + # Make request to LLM service + llm_response = await llm_service.create_embedding(llm_request) + + # Convert LLM service response to API format + response = { + "object": llm_response.object, + "data": [ + { + "object": emb.object, + "index": emb.index, + "embedding": emb.embedding + } + for emb in llm_response.data + ], + "model": llm_response.model, + "usage": { + "prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0, + "total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0 + } if llm_response.usage else { + "prompt_tokens": int(estimated_tokens), + "total_tokens": int(estimated_tokens) + } + } + # Calculate actual cost and update usage usage = response.get("usage", {}) total_tokens = usage.get("total_tokens", int(estimated_tokens)) @@ -475,8 +561,38 @@ async def create_embedding( except HTTPException: raise + except SecurityError as e: + logger.warning(f"Security error in embedding: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Security validation failed: {e.message}" + ) + except ValidationError as e: + logger.warning(f"Validation error in embedding: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Request validation failed: {e.message}" + ) + except ProviderError as e: + logger.error(f"Provider error in embedding: {e}") + if "rate limit" in str(e).lower(): + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Rate limit exceeded" + ) + else: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="LLM service temporarily unavailable" + ) + except LLMError as e: + logger.error(f"LLM service error in embedding: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="LLM service error" + ) except Exception as e: - logger.error(f"Error creating embedding: {e}") + logger.error(f"Unexpected error creating embedding: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create embedding" @@ -489,11 +605,28 @@ async def llm_health_check( ): """Health check for LLM service""" try: - health_status = await litellm_client.health_check() + health_summary = llm_service.get_health_summary() + provider_status = await llm_service.get_provider_status() + + # Determine overall health + overall_status = "healthy" + if health_summary["service_status"] != "healthy": + overall_status = "degraded" + + for provider, status in provider_status.items(): + if status.status == "unavailable": + overall_status = "degraded" + break + return { - "status": "healthy", - "service": "LLM Proxy", - "litellm_status": health_status, + "status": overall_status, + "service": "LLM Service", + "service_status": health_summary, + "provider_status": {name: { + "status": status.status, + "latency_ms": status.latency_ms, + "error_message": status.error_message + } for name, status in 
provider_status.items()}, "user_id": context["user_id"], "api_key_name": context["api_key_name"] } @@ -501,7 +634,7 @@ async def llm_health_check( logger.error(f"LLM health check error: {e}") return { "status": "unhealthy", - "service": "LLM Proxy", + "service": "LLM Service", "error": str(e) } @@ -626,50 +759,83 @@ async def get_budget_status( ) -# Generic proxy endpoint for other LiteLLM endpoints -@router.api_route("/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) -async def proxy_endpoint( - endpoint: str, - request: Request, +# Generic endpoint for additional LLM service functionality +@router.get("/metrics") +async def get_llm_metrics( context: Dict[str, Any] = Depends(require_api_key), db: AsyncSession = Depends(get_db) ): - """Generic proxy endpoint for LiteLLM requests""" + """Get LLM service metrics (admin only)""" try: + # Check for admin permissions auth_service = APIKeyAuthService(db) - - # Check endpoint permission - if not await auth_service.check_endpoint_permission(context, endpoint): + if not await auth_service.check_scope_permission(context, "admin.metrics"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Endpoint '{endpoint}' not allowed" + detail="Admin permissions required to view metrics" ) - # Get request body - if request.method in ["POST", "PUT", "PATCH"]: - try: - payload = await request.json() - except: - payload = {} - else: - payload = dict(request.query_params) - - # Make request to LiteLLM - response = await litellm_client.proxy_request( - method=request.method, - endpoint=endpoint, - payload=payload, - user_id=str(context["user_id"]), - api_key_id=context["api_key_id"] - ) - - return response + metrics = llm_service.get_metrics() + return { + "object": "llm_metrics", + "data": { + "total_requests": metrics.total_requests, + "successful_requests": metrics.successful_requests, + "failed_requests": metrics.failed_requests, + "security_blocked_requests": metrics.security_blocked_requests, + "average_latency_ms": metrics.average_latency_ms, + "average_risk_score": metrics.average_risk_score, + "provider_metrics": metrics.provider_metrics, + "last_updated": metrics.last_updated.isoformat() + } + } except HTTPException: raise except Exception as e: - logger.error(f"Error proxying request to {endpoint}: {e}") + logger.error(f"Error getting LLM metrics: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to proxy request" + detail="Failed to get LLM metrics" + ) + + +@router.get("/providers/status") +async def get_provider_status( + context: Dict[str, Any] = Depends(require_api_key), + db: AsyncSession = Depends(get_db) +): + """Get status of all LLM providers""" + try: + auth_service = APIKeyAuthService(db) + if not await auth_service.check_scope_permission(context, "admin.status"): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin permissions required to view provider status" + ) + + provider_status = await llm_service.get_provider_status() + return { + "object": "provider_status", + "data": { + name: { + "provider": status.provider, + "status": status.status, + "latency_ms": status.latency_ms, + "success_rate": status.success_rate, + "last_check": status.last_check.isoformat(), + "error_message": status.error_message, + "models_available": status.models_available + } + for name, status in provider_status.items() + } + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting provider status: {e}") + raise 
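
A short client-side sketch of how the new provider-status endpoint added here might be consumed; the base URL and API key are placeholders, and the admin.status scope requirement comes from the permission check in this hunk:

# Hypothetical client for GET /api/v1/llm/providers/status; URL and key are placeholders.
import httpx

async def show_provider_status(base_url: str, api_key: str) -> None:
    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.get(
            "/api/v1/llm/providers/status",
            headers={"Authorization": f"Bearer {api_key}"},  # key must carry the admin.status scope
        )
        resp.raise_for_status()
        for name, info in resp.json()["data"].items():
            print(name, info["status"], info["latency_ms"], info["models_available"])

# e.g. asyncio.run(show_provider_status("http://localhost:8000", "sk-..."))
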
HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to get provider status" ) \ No newline at end of file diff --git a/backend/app/api/v1/modules.py b/backend/app/api/v1/modules.py index 81bcef8..b2814e5 100644 --- a/backend/app/api/v1/modules.py +++ b/backend/app/api/v1/modules.py @@ -447,7 +447,7 @@ async def get_module_config(module_name: str): log_api_request("get_module_config", {"module_name": module_name}) from app.services.module_config_manager import module_config_manager - from app.services.litellm_client import litellm_client + from app.services.llm.service import llm_service import copy # Get module manifest and schema @@ -461,9 +461,9 @@ async def get_module_config(module_name: str): # For Signal module, populate model options dynamically if module_name == "signal" and schema: try: - # Get available models from LiteLLM - models_data = await litellm_client.get_models() - model_ids = [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")] + # Get available models from LLM service + models_data = await llm_service.get_models() + model_ids = [model.id for model in models_data] if model_ids: # Create a copy of the schema to avoid modifying the original diff --git a/backend/app/api/v1/prompt_templates.py b/backend/app/api/v1/prompt_templates.py index 71983c8..df41d31 100644 --- a/backend/app/api/v1/prompt_templates.py +++ b/backend/app/api/v1/prompt_templates.py @@ -15,7 +15,8 @@ from app.models.prompt_template import PromptTemplate, ChatbotPromptVariable from app.core.security import get_current_user from app.models.user import User from app.core.logging import log_api_request -from app.services.litellm_client import litellm_client +from app.services.llm.service import llm_service +from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage router = APIRouter() @@ -394,25 +395,28 @@ Please improve this prompt to make it more effective for a {request.chatbot_type ] # Get available models to use a default model - models = await litellm_client.get_models() + models = await llm_service.get_models() if not models: raise HTTPException(status_code=503, detail="No LLM models available") # Use the first available model (you might want to make this configurable) - default_model = models[0]["id"] + default_model = models[0].id - # Make the AI call - response = await litellm_client.create_chat_completion( + # Prepare the chat request for the new LLM service + chat_request = LLMChatRequest( model=default_model, - messages=messages, - user_id=str(user_id), - api_key_id=1, # Using default API key, you might want to make this dynamic + messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages], temperature=0.3, - max_tokens=1000 + max_tokens=1000, + user_id=str(user_id), + api_key_id=1 # Using default API key, you might want to make this dynamic ) + # Make the AI call + response = await llm_service.create_chat_completion(chat_request) + # Extract the improved prompt from the response - improved_prompt = response["choices"][0]["message"]["content"].strip() + improved_prompt = response.choices[0].message.content.strip() return { "improved_prompt": improved_prompt, diff --git a/backend/app/api/v1/settings.py b/backend/app/api/v1/settings.py index ba5b982..58a1f8b 100644 --- a/backend/app/api/v1/settings.py +++ b/backend/app/api/v1/settings.py @@ -51,7 +51,7 @@ class SystemInfoResponse(BaseModel): environment: str database_status: str redis_status: str 
- litellm_status: str + llm_service_status: str modules_loaded: int active_users: int total_api_keys: int @@ -227,8 +227,13 @@ async def get_system_info( # Get Redis status (simplified check) redis_status = "healthy" # Would implement actual Redis check - # Get LiteLLM status (simplified check) - litellm_status = "healthy" # Would implement actual LiteLLM check + # Get LLM service status + try: + from app.services.llm.service import llm_service + health_summary = llm_service.get_health_summary() + llm_service_status = health_summary.get("service_status", "unknown") + except Exception: + llm_service_status = "unavailable" # Get modules loaded (from module manager) modules_loaded = 8 # Would get from actual module manager @@ -261,7 +266,7 @@ async def get_system_info( environment="production", database_status=database_status, redis_status=redis_status, - litellm_status=litellm_status, + llm_service_status=llm_service_status, modules_loaded=modules_loaded, active_users=active_users, total_api_keys=total_api_keys, diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 62a1011..6448b71 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -43,15 +43,18 @@ class Settings(BaseSettings): # CORS CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"] - # LiteLLM - LITELLM_BASE_URL: str = "http://localhost:4000" - LITELLM_MASTER_KEY: str = "enclava-master-key" + # LLM Service Configuration (replaced LiteLLM) + # LLM service configuration is now handled in app/services/llm/config.py + + # LLM Service Security + LLM_ENCRYPTION_KEY: Optional[str] = None # Key for encrypting LLM provider API keys # API Keys for LLM providers OPENAI_API_KEY: Optional[str] = None ANTHROPIC_API_KEY: Optional[str] = None GOOGLE_API_KEY: Optional[str] = None PRIVATEMODE_API_KEY: Optional[str] = None + PRIVATEMODE_PROXY_URL: str = "http://privatemode-proxy:8080/v1" # Qdrant QDRANT_HOST: str = "localhost" diff --git a/backend/app/services/embedding_service.py b/backend/app/services/embedding_service.py index a8da70e..9a9f457 100644 --- a/backend/app/services/embedding_service.py +++ b/backend/app/services/embedding_service.py @@ -1,6 +1,6 @@ """ Embedding Service -Provides text embedding functionality using LiteLLM proxy +Provides text embedding functionality using LLM service """ import logging @@ -11,32 +11,34 @@ logger = logging.getLogger(__name__) class EmbeddingService: - """Service for generating text embeddings using LiteLLM""" + """Service for generating text embeddings using LLM service""" def __init__(self, model_name: str = "privatemode-embeddings"): self.model_name = model_name - self.litellm_client = None self.dimension = 1024 # Actual dimension for privatemode-embeddings self.initialized = False async def initialize(self): - """Initialize the embedding service with LiteLLM""" + """Initialize the embedding service with LLM service""" try: - from app.services.litellm_client import litellm_client - self.litellm_client = litellm_client + from app.services.llm.service import llm_service - # Test connection to LiteLLM - health = await self.litellm_client.health_check() - if health.get("status") == "unhealthy": - logger.error(f"LiteLLM service unhealthy: {health.get('error')}") + # Initialize LLM service if not already done + if not llm_service._initialized: + await llm_service.initialize() + + # Test LLM service health + health_summary = llm_service.get_health_summary() + if health_summary.get("service_status") != "healthy": + logger.error(f"LLM service 
unhealthy: {health_summary}") return False self.initialized = True - logger.info(f"Embedding service initialized with LiteLLM: {self.model_name} (dimension: {self.dimension})") + logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})") return True except Exception as e: - logger.error(f"Failed to initialize LiteLLM embedding service: {e}") + logger.error(f"Failed to initialize LLM embedding service: {e}") logger.warning("Using fallback random embeddings") return False @@ -46,10 +48,10 @@ class EmbeddingService: return embeddings[0] async def get_embeddings(self, texts: List[str]) -> List[List[float]]: - """Get embeddings for multiple texts using LiteLLM""" - if not self.initialized or not self.litellm_client: + """Get embeddings for multiple texts using LLM service""" + if not self.initialized: # Fallback to random embeddings if not initialized - logger.warning("LiteLLM not available, using random embeddings") + logger.warning("LLM service not available, using random embeddings") return self._generate_fallback_embeddings(texts) try: @@ -73,17 +75,22 @@ class EmbeddingService: else: truncated_text = text - # Call LiteLLM embedding endpoint - response = await self.litellm_client.create_embedding( + # Call LLM service embedding endpoint + from app.services.llm.service import llm_service + from app.services.llm.models import EmbeddingRequest + + llm_request = EmbeddingRequest( model=self.model_name, - input_text=truncated_text, + input=truncated_text, user_id="rag_system", api_key_id=0 # System API key ) + response = await llm_service.create_embedding(llm_request) + # Extract embedding from response - if "data" in response and len(response["data"]) > 0: - embedding = response["data"][0].get("embedding", []) + if response.data and len(response.data) > 0: + embedding = response.data[0].embedding if embedding: batch_embeddings.append(embedding) # Update dimension based on actual embedding size @@ -106,7 +113,7 @@ class EmbeddingService: return embeddings except Exception as e: - logger.error(f"Error generating embeddings with LiteLLM: {e}") + logger.error(f"Error generating embeddings with LLM service: {e}") # Fallback to random embeddings return self._generate_fallback_embeddings(texts) @@ -146,14 +153,13 @@ class EmbeddingService: "model_name": self.model_name, "model_loaded": self.initialized, "dimension": self.dimension, - "backend": "LiteLLM", + "backend": "LLM Service", "initialized": self.initialized } async def cleanup(self): """Cleanup resources""" self.initialized = False - self.litellm_client = None # Global embedding service instance diff --git a/backend/app/services/litellm_client.py b/backend/app/services/litellm_client.py deleted file mode 100644 index f636043..0000000 --- a/backend/app/services/litellm_client.py +++ /dev/null @@ -1,304 +0,0 @@ -""" -LiteLLM Client Service -Handles communication with the LiteLLM proxy service -""" - -import asyncio -import json -import logging -from typing import Dict, Any, Optional, List -from datetime import datetime - -import aiohttp -from fastapi import HTTPException, status - -from app.core.config import settings - -logger = logging.getLogger(__name__) - - -class LiteLLMClient: - """Client for communicating with LiteLLM proxy service""" - - def __init__(self): - self.base_url = settings.LITELLM_BASE_URL - self.master_key = settings.LITELLM_MASTER_KEY - self.session: Optional[aiohttp.ClientSession] = None - self.timeout = aiohttp.ClientTimeout(total=600) # 10 minutes timeout - - async def 
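
A minimal sketch of the embedding path the embedding service above now uses, based on the EmbeddingRequest fields shown in this diff; the model name mirrors the default used there:

# Illustrative embedding call using only names present in this diff.
from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest

async def embed_text(text: str) -> list[float]:
    request = EmbeddingRequest(
        model="privatemode-embeddings",  # default model of EmbeddingService above
        input=text,
        user_id="rag_system",
        api_key_id=0,  # system/internal request, as in embedding_service.py
    )
    response = await llm_service.create_embedding(request)
    return response.data[0].embedding if response.data else []
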
_get_session(self) -> aiohttp.ClientSession: - """Get or create aiohttp session""" - if self.session is None or self.session.closed: - self.session = aiohttp.ClientSession( - timeout=self.timeout, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {self.master_key}" - } - ) - return self.session - - async def close(self): - """Close the HTTP session""" - if self.session and not self.session.closed: - await self.session.close() - - async def health_check(self) -> Dict[str, Any]: - """Check LiteLLM proxy health""" - try: - session = await self._get_session() - async with session.get(f"{self.base_url}/health") as response: - if response.status == 200: - return await response.json() - else: - logger.error(f"LiteLLM health check failed: {response.status}") - return {"status": "unhealthy", "error": f"HTTP {response.status}"} - except Exception as e: - logger.error(f"LiteLLM health check error: {e}") - return {"status": "unhealthy", "error": str(e)} - - async def get_models(self) -> List[Dict[str, Any]]: - """Get available models from LiteLLM""" - try: - session = await self._get_session() - async with session.get(f"{self.base_url}/models") as response: - if response.status == 200: - data = await response.json() - return data.get("data", []) - else: - logger.error(f"Failed to get models: {response.status}") - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="LiteLLM service unavailable" - ) - except aiohttp.ClientError as e: - logger.error(f"LiteLLM models request error: {e}") - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="LiteLLM service unavailable" - ) - - async def create_chat_completion( - self, - model: str, - messages: List[Dict[str, str]], - user_id: str, - api_key_id: int, - **kwargs - ) -> Dict[str, Any]: - """Create chat completion via LiteLLM proxy""" - try: - # Prepare request payload - payload = { - "model": model, - "messages": messages, - "user": f"user_{user_id}", # User identifier for tracking - "metadata": { - "api_key_id": api_key_id, - "user_id": user_id, - "timestamp": datetime.utcnow().isoformat() - }, - **kwargs - } - - session = await self._get_session() - async with session.post( - f"{self.base_url}/chat/completions", - json=payload - ) as response: - if response.status == 200: - return await response.json() - else: - error_text = await response.text() - logger.error(f"LiteLLM chat completion failed: {response.status} - {error_text}") - - # Handle specific error cases - if response.status == 401: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API key" - ) - elif response.status == 429: - raise HTTPException( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - detail="Rate limit exceeded" - ) - elif response.status == 400: - try: - error_data = await response.json() - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=error_data.get("error", {}).get("message", "Bad request") - ) - except: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid request" - ) - else: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="LiteLLM service error" - ) - except aiohttp.ClientError as e: - logger.error(f"LiteLLM chat completion request error: {e}") - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="LiteLLM service unavailable" - ) - - async def create_embedding( - self, - model: str, - input_text: str, - user_id: str, - api_key_id: int, - 
**kwargs - ) -> Dict[str, Any]: - """Create embedding via LiteLLM proxy""" - try: - payload = { - "model": model, - "input": input_text, - "user": f"user_{user_id}", - "metadata": { - "api_key_id": api_key_id, - "user_id": user_id, - "timestamp": datetime.utcnow().isoformat() - }, - **kwargs - } - - session = await self._get_session() - async with session.post( - f"{self.base_url}/embeddings", - json=payload - ) as response: - if response.status == 200: - return await response.json() - else: - error_text = await response.text() - logger.error(f"LiteLLM embedding failed: {response.status} - {error_text}") - - if response.status == 401: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API key" - ) - elif response.status == 429: - raise HTTPException( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - detail="Rate limit exceeded" - ) - else: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="LiteLLM service error" - ) - except aiohttp.ClientError as e: - logger.error(f"LiteLLM embedding request error: {e}") - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="LiteLLM service unavailable" - ) - - async def get_models(self) -> List[Dict[str, Any]]: - """Get available models from LiteLLM proxy""" - try: - session = await self._get_session() - async with session.get(f"{self.base_url}/models") as response: - if response.status == 200: - data = await response.json() - # Return models with exact names from upstream providers - models = data.get("data", []) - - # Pass through model names exactly as they come from upstream - # Don't modify model IDs - keep them as the original provider names - processed_models = [] - for model in models: - # Keep the exact model ID from upstream provider - processed_models.append({ - "id": model.get("id", ""), # Exact model name from provider - "object": model.get("object", "model"), - "created": model.get("created", 1677610602), - "owned_by": model.get("owned_by", "openai") - }) - - return processed_models - else: - error_text = await response.text() - logger.error(f"LiteLLM models request failed: {response.status} - {error_text}") - return [] - except aiohttp.ClientError as e: - logger.error(f"LiteLLM models request error: {e}") - return [] - - async def proxy_request( - self, - method: str, - endpoint: str, - payload: Dict[str, Any], - user_id: str, - api_key_id: int - ) -> Dict[str, Any]: - """Generic proxy request to LiteLLM""" - try: - # Add metadata to payload - if isinstance(payload, dict): - payload["metadata"] = { - "api_key_id": api_key_id, - "user_id": user_id, - "timestamp": datetime.utcnow().isoformat() - } - if "user" not in payload: - payload["user"] = f"user_{user_id}" - - session = await self._get_session() - - # Make the request - async with session.request( - method, - f"{self.base_url}/{endpoint.lstrip('/')}", - json=payload if method.upper() in ['POST', 'PUT', 'PATCH'] else None, - params=payload if method.upper() == 'GET' else None - ) as response: - if response.status == 200: - return await response.json() - else: - error_text = await response.text() - logger.error(f"LiteLLM proxy request failed: {response.status} - {error_text}") - - if response.status == 401: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API key" - ) - elif response.status == 429: - raise HTTPException( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - detail="Rate limit exceeded" - ) - else: - raise HTTPException( - 
status_code=response.status, - detail=f"LiteLLM service error: {error_text}" - ) - except aiohttp.ClientError as e: - logger.error(f"LiteLLM proxy request error: {e}") - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="LiteLLM service unavailable" - ) - - async def list_models(self) -> List[str]: - """Get list of available model names/IDs""" - try: - models_data = await self.get_models() - return [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")] - except Exception as e: - logger.error(f"Error listing model names: {str(e)}") - return [] - - -# Global LiteLLM client instance -litellm_client = LiteLLMClient() \ No newline at end of file diff --git a/backend/modules/chatbot/main.py b/backend/modules/chatbot/main.py index 68a7fb5..75d3418 100644 --- a/backend/modules/chatbot/main.py +++ b/backend/modules/chatbot/main.py @@ -23,7 +23,9 @@ from fastapi import APIRouter, HTTPException, Depends from sqlalchemy.orm import Session from app.core.logging import get_logger -from app.services.litellm_client import LiteLLMClient +from app.services.llm.service import llm_service +from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage +from app.services.llm.exceptions import LLMError, ProviderError, SecurityError from app.services.base_module import BaseModule, Permission from app.models.user import User from app.models.chatbot import ChatbotInstance as DBChatbotInstance, ChatbotConversation as DBConversation, ChatbotMessage as DBMessage, ChatbotAnalytics @@ -32,7 +34,8 @@ from app.db.database import get_db from app.core.config import settings # Import protocols for type hints and dependency injection -from ..protocols import RAGServiceProtocol, LiteLLMClientProtocol +from ..protocols import RAGServiceProtocol +# Note: LiteLLMClientProtocol replaced with direct LLM service usage logger = get_logger(__name__) @@ -131,10 +134,8 @@ class ChatbotInstance(BaseModel): class ChatbotModule(BaseModule): """Main chatbot module implementation""" - def __init__(self, litellm_client: Optional[LiteLLMClientProtocol] = None, - rag_service: Optional[RAGServiceProtocol] = None): + def __init__(self, rag_service: Optional[RAGServiceProtocol] = None): super().__init__("chatbot") - self.litellm_client = litellm_client self.rag_module = rag_service # Keep same name for compatibility self.db_session = None @@ -145,15 +146,10 @@ class ChatbotModule(BaseModule): """Initialize the chatbot module""" await super().initialize(**kwargs) - # Get dependencies from global services if not already injected - if not self.litellm_client: - try: - from app.services.litellm_client import litellm_client - self.litellm_client = litellm_client - logger.info("LiteLLM client injected from global service") - except Exception as e: - logger.warning(f"Could not inject LiteLLM client: {e}") + # Initialize the LLM service + await llm_service.initialize() + # Get RAG module dependency if not already injected if not self.rag_module: try: # Try to get RAG module from module manager @@ -168,19 +164,16 @@ class ChatbotModule(BaseModule): await self._load_prompt_templates() logger.info("Chatbot module initialized") - logger.info(f"LiteLLM client available after init: {self.litellm_client is not None}") + logger.info(f"LLM service available: {llm_service._initialized}") logger.info(f"RAG module available after init: {self.rag_module is not None}") logger.info(f"Loaded {len(self.system_prompts)} prompt templates") async def 
_ensure_dependencies(self): """Lazy load dependencies if not available""" - if not self.litellm_client: - try: - from app.services.litellm_client import litellm_client - self.litellm_client = litellm_client - logger.info("LiteLLM client lazy loaded") - except Exception as e: - logger.warning(f"Could not lazy load LiteLLM client: {e}") + # Ensure LLM service is initialized + if not llm_service._initialized: + await llm_service.initialize() + logger.info("LLM service lazy loaded") if not self.rag_module: try: @@ -468,45 +461,58 @@ class ChatbotModule(BaseModule): logger.info(msg['content']) logger.info("=== END COMPREHENSIVE LLM REQUEST ===") - if self.litellm_client: - try: - logger.info("Calling LiteLLM client create_chat_completion...") - response = await self.litellm_client.create_chat_completion( - model=config.model, - messages=messages, - user_id="chatbot_user", - api_key_id="chatbot_api_key", - temperature=config.temperature, - max_tokens=config.max_tokens - ) - logger.info(f"LiteLLM response received, response keys: {list(response.keys())}") + try: + logger.info("Calling LLM service create_chat_completion...") + + # Convert messages to LLM service format + llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages] + + # Create LLM service request + llm_request = LLMChatRequest( + model=config.model, + messages=llm_messages, + temperature=config.temperature, + max_tokens=config.max_tokens, + user_id="chatbot_user", + api_key_id=0 # Chatbot module uses internal service + ) + + # Make request to LLM service + llm_response = await llm_service.create_chat_completion(llm_request) + + # Extract response content + if llm_response.choices: + content = llm_response.choices[0].message.content + logger.info(f"Response content length: {len(content)}") - # Extract response content from the LiteLLM response format - if 'choices' in response and response['choices']: - content = response['choices'][0]['message']['content'] - logger.info(f"Response content length: {len(content)}") - - # Always log response for debugging - logger.info("=== COMPREHENSIVE LLM RESPONSE ===") - logger.info(f"Response content ({len(content)} chars):") - logger.info(content) - if 'usage' in response: - usage = response['usage'] - logger.info(f"Token usage - Prompt: {usage.get('prompt_tokens', 'N/A')}, Completion: {usage.get('completion_tokens', 'N/A')}, Total: {usage.get('total_tokens', 'N/A')}") - if sources: - logger.info(f"RAG sources included: {len(sources)} documents") - logger.info("=== END COMPREHENSIVE LLM RESPONSE ===") - - return content, sources - else: - logger.warning("No choices in LiteLLM response") - return "I received an empty response from the AI model.", sources - except Exception as e: - logger.error(f"LiteLLM completion failed: {e}") - raise e - else: - logger.warning("No LiteLLM client available, using fallback") - # Fallback if no LLM client + # Always log response for debugging + logger.info("=== COMPREHENSIVE LLM RESPONSE ===") + logger.info(f"Response content ({len(content)} chars):") + logger.info(content) + if llm_response.usage: + usage = llm_response.usage + logger.info(f"Token usage - Prompt: {usage.prompt_tokens}, Completion: {usage.completion_tokens}, Total: {usage.total_tokens}") + if sources: + logger.info(f"RAG sources included: {len(sources)} documents") + logger.info("=== END COMPREHENSIVE LLM RESPONSE ===") + + return content, sources + else: + logger.warning("No choices in LLM response") + return "I received an empty response from the AI model.", 
sources + + except SecurityError as e: + logger.error(f"Security error in LLM completion: {e}") + raise HTTPException(status_code=400, detail=f"Security validation failed: {e.message}") + except ProviderError as e: + logger.error(f"Provider error in LLM completion: {e}") + raise HTTPException(status_code=503, detail="LLM service temporarily unavailable") + except LLMError as e: + logger.error(f"LLM service error: {e}") + raise HTTPException(status_code=500, detail="LLM service error") + except Exception as e: + logger.error(f"LLM completion failed: {e}") + # Return fallback if available return "I'm currently unable to process your request. Please try again later.", None def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig, @@ -685,7 +691,7 @@ class ChatbotModule(BaseModule): # Lazy load dependencies await self._ensure_dependencies() - logger.info(f"LiteLLM client available: {self.litellm_client is not None}") + logger.info(f"LLM service available: {llm_service._initialized}") logger.info(f"RAG module available: {self.rag_module is not None}") try: @@ -884,10 +890,9 @@ class ChatbotModule(BaseModule): # Module factory function -def create_module(litellm_client: Optional[LiteLLMClientProtocol] = None, - rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule: +def create_module(rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule: """Factory function to create chatbot module instance""" - return ChatbotModule(litellm_client=litellm_client, rag_service=rag_service) + return ChatbotModule(rag_service=rag_service) # Create module instance (dependencies will be injected via factory) chatbot_module = ChatbotModule() \ No newline at end of file diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py index acd2de1..d52bbd5 100644 --- a/backend/modules/rag/main.py +++ b/backend/modules/rag/main.py @@ -401,7 +401,7 @@ class RAGModule(BaseModule): """Initialize embedding model""" from app.services.embedding_service import embedding_service - # Use privatemode-embeddings for LiteLLM integration + # Use privatemode-embeddings for LLM service integration model_name = self.config.get("embedding_model", "privatemode-embeddings") embedding_service.model_name = model_name diff --git a/backend/modules/workflow/main.py b/backend/modules/workflow/main.py index 1a5d413..adfcd03 100644 --- a/backend/modules/workflow/main.py +++ b/backend/modules/workflow/main.py @@ -22,13 +22,16 @@ from fastapi import APIRouter, HTTPException, Depends from sqlalchemy.orm import Session from sqlalchemy import select from app.core.logging import get_logger -from app.services.litellm_client import LiteLLMClient +from app.services.llm.service import llm_service +from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage +from app.services.llm.exceptions import LLMError, ProviderError, SecurityError from app.services.base_module import Permission from app.db.database import SessionLocal from app.models.workflow import WorkflowDefinition as DBWorkflowDefinition, WorkflowExecution as DBWorkflowExecution # Import protocols for type hints and dependency injection -from ..protocols import ChatbotServiceProtocol, LiteLLMClientProtocol +from ..protocols import ChatbotServiceProtocol +# Note: LiteLLMClientProtocol replaced with direct LLM service usage logger = get_logger(__name__) @@ -234,8 +237,7 @@ class WorkflowExecution(BaseModel): class WorkflowEngine: """Core workflow execution engine""" - def __init__(self, 
litellm_client: LiteLLMClient, chatbot_service: Optional[ChatbotServiceProtocol] = None): - self.litellm_client = litellm_client + def __init__(self, chatbot_service: Optional[ChatbotServiceProtocol] = None): self.chatbot_service = chatbot_service self.executions: Dict[str, WorkflowExecution] = {} self.workflows: Dict[str, WorkflowDefinition] = {} @@ -343,15 +345,23 @@ class WorkflowEngine: # Template message content with context variables messages = self._template_messages(llm_step.messages, context.variables) - # Make LLM call - response = await self.litellm_client.chat_completion( + # Convert messages to LLM service format + llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages] + + # Create LLM service request + llm_request = LLMChatRequest( model=llm_step.model, - messages=messages, - **llm_step.parameters + messages=llm_messages, + user_id="workflow_user", + api_key_id=0, # Workflow module uses internal service + **{k: v for k, v in llm_step.parameters.items() if k in ['temperature', 'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty', 'stop']} ) + # Make LLM call + response = await llm_service.create_chat_completion(llm_request) + # Store result - result = response.get("choices", [{}])[0].get("message", {}).get("content", "") + result = response.choices[0].message.content if response.choices else "" context.variables[llm_step.output_variable] = result context.results[step.id] = result @@ -631,16 +641,21 @@ class WorkflowEngine: messages = [{"role": "user", "content": self._template_string(prompt, variables)}] - response = await self.litellm_client.create_chat_completion( + # Convert to LLM service format + llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages] + + llm_request = LLMChatRequest( model=step.model, - messages=messages, + messages=llm_messages, user_id="workflow_system", - api_key_id="workflow", + api_key_id=0, temperature=step.temperature, max_tokens=step.max_tokens ) - return response.get("choices", [{}])[0].get("message", {}).get("content", "") + response = await llm_service.create_chat_completion(llm_request) + + return response.choices[0].message.content if response.choices else "" async def _generate_brand_names(self, variables: Dict[str, Any], step: AIGenerationStep) -> List[Dict[str, str]]: """Generate brand names for a specific category""" @@ -687,16 +702,21 @@ class WorkflowEngine: messages = [{"role": "user", "content": self._template_string(prompt, variables)}] - response = await self.litellm_client.create_chat_completion( + # Convert to LLM service format + llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages] + + llm_request = LLMChatRequest( model=step.model, - messages=messages, + messages=llm_messages, user_id="workflow_system", - api_key_id="workflow", + api_key_id=0, temperature=step.temperature, max_tokens=step.max_tokens ) - return response.get("choices", [{}])[0].get("message", {}).get("content", "") + response = await llm_service.create_chat_completion(llm_request) + + return response.choices[0].message.content if response.choices else "" async def _generate_custom_prompt(self, variables: Dict[str, Any], step: AIGenerationStep) -> str: """Generate content using custom prompt template""" @@ -705,16 +725,21 @@ class WorkflowEngine: messages = [{"role": "user", "content": self._template_string(step.prompt_template, variables)}] - response = await self.litellm_client.create_chat_completion( + # Convert to LLM service format + 
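
The AI-generation helpers in this file now repeat the same request-building block three times; a possible consolidation, sketched here under the assumption that step objects expose model/temperature/max_tokens as shown in these hunks (not part of the diff itself):

# Sketch of a shared helper method on WorkflowEngine for the repeated LLM-call blocks.
async def _run_llm_prompt(self, prompt: str, variables: Dict[str, Any], step: "AIGenerationStep") -> str:
    llm_request = LLMChatRequest(
        model=step.model,
        messages=[LLMChatMessage(role="user", content=self._template_string(prompt, variables))],
        temperature=step.temperature,
        max_tokens=step.max_tokens,
        user_id="workflow_system",
        api_key_id=0,  # internal workflow requests, as elsewhere in this diff
    )
    response = await llm_service.create_chat_completion(llm_request)
    return response.choices[0].message.content if response.choices else ""
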
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages] + + llm_request = LLMChatRequest( model=step.model, - messages=messages, + messages=llm_messages, user_id="workflow_system", - api_key_id="workflow", + api_key_id=0, temperature=step.temperature, max_tokens=step.max_tokens ) - return response.get("choices", [{}])[0].get("message", {}).get("content", "") + response = await llm_service.create_chat_completion(llm_request) + + return response.choices[0].message.content if response.choices else "" async def _execute_aggregate_step(self, step: WorkflowStep, context: WorkflowContext): """Execute aggregate step to combine multiple inputs""" diff --git a/backend/modules/zammad/main.py b/backend/modules/zammad/main.py index 2ffd743..b5964b4 100644 --- a/backend/modules/zammad/main.py +++ b/backend/modules/zammad/main.py @@ -23,7 +23,8 @@ from app.core.config import settings from app.db.database import async_session_factory from app.models.user import User from app.models.chatbot import ChatbotInstance -from app.services.litellm_client import LiteLLMClient +from app.services.llm.service import llm_service +from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage from cryptography.fernet import Fernet import base64 import os @@ -65,8 +66,8 @@ class ZammadModule(BaseModule): try: logger.info("Initializing Zammad module...") - # Initialize LLM client for chatbot integration - self.llm_client = LiteLLMClient() + # Initialize LLM service for chatbot integration + # Note: llm_service is already a global singleton, no need to create instance # Create HTTP session pool for Zammad API calls timeout = aiohttp.ClientTimeout(total=60, connect=10) @@ -597,19 +598,21 @@ class ZammadModule(BaseModule): } ] - # Generate summary using LLM client - response = await self.llm_client.create_chat_completion( - messages=messages, + # Generate summary using new LLM service + chat_request = LLMChatRequest( model=await self._get_chatbot_model(config.chatbot_id), - user_id=str(config.user_id), - api_key_id=0, # Using 0 for module requests + messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages], temperature=0.3, - max_tokens=500 + max_tokens=500, + user_id=str(config.user_id), + api_key_id=0 # Using 0 for module requests ) - # Extract content from LiteLLM response - if "choices" in response and len(response["choices"]) > 0: - return response["choices"][0]["message"]["content"].strip() + response = await llm_service.create_chat_completion(chat_request) + + # Extract content from new LLM service response + if response.choices and len(response.choices) > 0: + return response.choices[0].message.content.strip() return "Unable to generate summary." diff --git a/backend/tests/api/test_llm_endpoints.py b/backend/tests/api/test_llm_endpoints.py deleted file mode 100644 index 6e83c64..0000000 --- a/backend/tests/api/test_llm_endpoints.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Test LLM API endpoints. -""" -import pytest -from httpx import AsyncClient -from unittest.mock import patch, AsyncMock - - -class TestLLMEndpoints: - """Test LLM API endpoints.""" - - @pytest.mark.asyncio - async def test_chat_completion_success(self, client: AsyncClient): - """Test successful chat completion.""" - # Mock the LiteLLM client response - mock_response = { - "choices": [ - { - "message": { - "content": "Hello! 
How can I help you today?", - "role": "assistant" - } - } - ], - "usage": { - "prompt_tokens": 10, - "completion_tokens": 15, - "total_tokens": 25 - } - } - - with patch("app.services.litellm_client.LiteLLMClient.create_chat_completion") as mock_chat: - mock_chat.return_value = mock_response - - response = await client.post( - "/api/v1/llm/chat/completions", - json={ - "model": "gpt-3.5-turbo", - "messages": [ - {"role": "user", "content": "Hello"} - ] - }, - headers={"Authorization": "Bearer test-api-key"} - ) - - assert response.status_code == 200 - data = response.json() - assert "choices" in data - assert data["choices"][0]["message"]["content"] == "Hello! How can I help you today?" - - @pytest.mark.asyncio - async def test_chat_completion_unauthorized(self, client: AsyncClient): - """Test chat completion without API key.""" - response = await client.post( - "/api/v1/llm/chat/completions", - json={ - "model": "gpt-3.5-turbo", - "messages": [ - {"role": "user", "content": "Hello"} - ] - } - ) - - assert response.status_code == 401 - - @pytest.mark.asyncio - async def test_embeddings_success(self, client: AsyncClient): - """Test successful embeddings generation.""" - mock_response = { - "data": [ - { - "embedding": [0.1, 0.2, 0.3], - "index": 0 - } - ], - "usage": { - "prompt_tokens": 5, - "total_tokens": 5 - } - } - - with patch("app.services.litellm_client.LiteLLMClient.create_embedding") as mock_embeddings: - mock_embeddings.return_value = mock_response - - response = await client.post( - "/api/v1/llm/embeddings", - json={ - "model": "text-embedding-ada-002", - "input": "Hello world" - }, - headers={"Authorization": "Bearer test-api-key"} - ) - - assert response.status_code == 200 - data = response.json() - assert "data" in data - assert len(data["data"][0]["embedding"]) == 3 - - @pytest.mark.asyncio - async def test_budget_exceeded(self, client: AsyncClient): - """Test budget exceeded scenario.""" - with patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_check: - mock_check.side_effect = Exception("Budget exceeded") - - response = await client.post( - "/api/v1/llm/chat/completions", - json={ - "model": "gpt-3.5-turbo", - "messages": [ - {"role": "user", "content": "Hello"} - ] - }, - headers={"Authorization": "Bearer test-api-key"} - ) - - assert response.status_code == 402 # Payment required - - @pytest.mark.asyncio - async def test_model_validation(self, client: AsyncClient): - """Test model validation.""" - response = await client.post( - "/api/v1/llm/chat/completions", - json={ - "model": "invalid-model", - "messages": [ - {"role": "user", "content": "Hello"} - ] - }, - headers={"Authorization": "Bearer test-api-key"} - ) - - assert response.status_code == 400 \ No newline at end of file diff --git a/backend/tests/integration/test_llm_service_integration.py b/backend/tests/integration/test_llm_service_integration.py new file mode 100644 index 0000000..b9a3dd8 --- /dev/null +++ b/backend/tests/integration/test_llm_service_integration.py @@ -0,0 +1,496 @@ +""" +Integration tests for the new LLM service. +Tests end-to-end functionality including provider integration, security, and performance. 
+""" +import pytest +import asyncio +import time +from httpx import AsyncClient +from unittest.mock import patch, AsyncMock, MagicMock +import json + + +class TestLLMServiceIntegration: + """Integration tests for LLM service.""" + + @pytest.mark.asyncio + async def test_full_chat_flow(self, client: AsyncClient): + """Test complete chat completion flow with security and budget checks.""" + from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage + + # Mock successful LLM service response + mock_response = ChatCompletionResponse( + id="test-completion-123", + object="chat.completion", + created=int(time.time()), + model="privatemode-llama-3-70b", + choices=[ + ChatChoice( + index=0, + message=ChatMessage( + role="assistant", + content="Hello! I'm a TEE-protected AI assistant. How can I help you today?" + ), + finish_reason="stop" + ) + ], + usage=Usage( + prompt_tokens=25, + completion_tokens=15, + total_tokens=40 + ), + security_analysis={ + "risk_score": 0.1, + "threats_detected": [], + "risk_level": "low", + "analysis_time_ms": 12.5 + } + ) + + with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat, \ + patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget: + + mock_chat.return_value = mock_response + mock_budget.return_value = True # Budget check passes + + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "privatemode-llama-3-70b", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, what are your capabilities?"} + ], + "temperature": 0.7, + "max_tokens": 150 + }, + headers={"Authorization": "Bearer test-api-key"} + ) + + # Verify response structure + assert response.status_code == 200 + data = response.json() + + # Check standard OpenAI-compatible fields + assert "id" in data + assert "object" in data + assert "created" in data + assert "model" in data + assert "choices" in data + assert "usage" in data + + # Check security integration + assert "security_analysis" in data + assert data["security_analysis"]["risk_level"] == "low" + + # Verify content + assert len(data["choices"]) == 1 + assert data["choices"][0]["message"]["role"] == "assistant" + assert "TEE-protected" in data["choices"][0]["message"]["content"] + + # Verify usage tracking + assert data["usage"]["total_tokens"] == 40 + assert data["usage"]["prompt_tokens"] == 25 + assert data["usage"]["completion_tokens"] == 15 + + @pytest.mark.asyncio + async def test_embedding_integration(self, client: AsyncClient): + """Test embedding generation with fallback handling.""" + from app.services.llm.models import EmbeddingResponse, EmbeddingData, Usage + + # Create realistic 1024-dimensional embedding + embedding_vector = [0.1 * i for i in range(1024)] + + mock_response = EmbeddingResponse( + object="list", + data=[ + EmbeddingData( + object="embedding", + embedding=embedding_vector, + index=0 + ) + ], + model="privatemode-embeddings", + usage=Usage( + prompt_tokens=8, + total_tokens=8 + ) + ) + + with patch("app.services.llm.service.llm_service.create_embedding") as mock_embedding: + mock_embedding.return_value = mock_response + + response = await client.post( + "/api/v1/llm/embeddings", + json={ + "model": "privatemode-embeddings", + "input": "This is a test document for embedding generation." 
+ }, + headers={"Authorization": "Bearer test-api-key"} + ) + + assert response.status_code == 200 + data = response.json() + + # Verify embedding structure + assert "object" in data + assert "data" in data + assert "usage" in data + assert len(data["data"]) == 1 + assert len(data["data"][0]["embedding"]) == 1024 + assert data["data"][0]["index"] == 0 + + @pytest.mark.asyncio + async def test_provider_health_integration(self, client: AsyncClient): + """Test provider health monitoring integration.""" + mock_status = { + "privatemode": { + "provider": "PrivateMode.ai", + "status": "healthy", + "latency_ms": 245.8, + "success_rate": 0.987, + "last_check": "2025-01-01T12:00:00Z", + "error_message": None, + "models_available": [ + "privatemode-llama-3-70b", + "privatemode-claude-3-sonnet", + "privatemode-gpt-4o", + "privatemode-embeddings" + ] + } + } + + with patch("app.services.llm.service.llm_service.get_provider_status") as mock_provider: + mock_provider.return_value = mock_status + + response = await client.get( + "/api/v1/llm/providers/status", + headers={"Authorization": "Bearer test-api-key"} + ) + + assert response.status_code == 200 + data = response.json() + + # Check response structure + assert "data" in data + assert "privatemode" in data["data"] + + provider_data = data["data"]["privatemode"] + assert provider_data["status"] == "healthy" + assert provider_data["latency_ms"] < 300 # Reasonable latency + assert provider_data["success_rate"] > 0.95 # High success rate + assert len(provider_data["models_available"]) >= 4 + + @pytest.mark.asyncio + async def test_error_handling_and_fallback(self, client: AsyncClient): + """Test error handling and fallback scenarios.""" + # Test provider unavailable scenario + with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat: + mock_chat.side_effect = Exception("Provider temporarily unavailable") + + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "privatemode-llama-3-70b", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }, + headers={"Authorization": "Bearer test-api-key"} + ) + + # Should return error but not crash + assert response.status_code in [500, 503] # Server error or service unavailable + + @pytest.mark.asyncio + async def test_security_threat_detection(self, client: AsyncClient): + """Test security threat detection integration.""" + from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage + + # Mock response with security threat detected + mock_response = ChatCompletionResponse( + id="test-completion-security", + object="chat.completion", + created=int(time.time()), + model="privatemode-llama-3-70b", + choices=[ + ChatChoice( + index=0, + message=ChatMessage( + role="assistant", + content="I cannot help with that request as it violates security policies." 
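
One path the new handlers in llm.py cover but these tests do not yet exercise is ValidationError mapping to HTTP 400; a sketch of such a test, assuming ValidationError can be constructed with a plain message (it is read via e.message in the handler) and that budget enforcement is patched as in test_full_chat_flow if needed:

    # Sketch of an extra test for the ValidationError -> 400 mapping added in llm.py.
    @pytest.mark.asyncio
    async def test_validation_error_returns_400(self, client: AsyncClient):
        from app.services.llm.exceptions import ValidationError

        with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
            # Assumes ValidationError takes a message; adjust to its real constructor.
            mock_chat.side_effect = ValidationError("unsupported parameter: foo")

            response = await client.post(
                "/api/v1/llm/chat/completions",
                json={
                    "model": "privatemode-llama-3-70b",
                    "messages": [{"role": "user", "content": "Hello"}]
                },
                headers={"Authorization": "Bearer test-api-key"}
            )

        assert response.status_code == 400
        assert "Request validation failed" in response.json()["detail"]
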
+ ), + finish_reason="stop" + ) + ], + usage=Usage( + prompt_tokens=15, + completion_tokens=12, + total_tokens=27 + ), + security_analysis={ + "risk_score": 0.8, + "threats_detected": ["potential_malicious_code"], + "risk_level": "high", + "blocked": True, + "analysis_time_ms": 45.2 + } + ) + + with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat: + mock_chat.return_value = mock_response + + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "privatemode-llama-3-70b", + "messages": [ + {"role": "user", "content": "How to create malicious code?"} + ] + }, + headers={"Authorization": "Bearer test-api-key"} + ) + + assert response.status_code == 200 # Request succeeds but content is filtered + data = response.json() + + # Verify security analysis + assert "security_analysis" in data + assert data["security_analysis"]["risk_level"] == "high" + assert data["security_analysis"]["blocked"] is True + assert "malicious" in data["security_analysis"]["threats_detected"][0] + + @pytest.mark.asyncio + async def test_performance_characteristics(self, client: AsyncClient): + """Test performance characteristics of the LLM service.""" + from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage + + # Mock fast response + mock_response = ChatCompletionResponse( + id="test-perf", + object="chat.completion", + created=int(time.time()), + model="privatemode-llama-3-70b", + choices=[ + ChatChoice( + index=0, + message=ChatMessage( + role="assistant", + content="Quick response for performance testing." + ), + finish_reason="stop" + ) + ], + usage=Usage( + prompt_tokens=10, + completion_tokens=8, + total_tokens=18 + ) + ) + + with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat: + mock_chat.return_value = mock_response + + # Measure response time + start_time = time.time() + + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "privatemode-llama-3-70b", + "messages": [ + {"role": "user", "content": "Quick test"} + ] + }, + headers={"Authorization": "Bearer test-api-key"} + ) + + response_time = time.time() - start_time + + assert response.status_code == 200 + # API should respond quickly (mocked, so should be very fast) + assert response_time < 1.0 # Less than 1 second for mocked response + + @pytest.mark.asyncio + async def test_model_capabilities_detection(self, client: AsyncClient): + """Test model capabilities detection and reporting.""" + from app.services.llm.models import Model + + mock_models = [ + Model( + id="privatemode-llama-3-70b", + object="model", + created=1234567890, + owned_by="PrivateMode.ai", + provider="PrivateMode.ai", + capabilities=["tee", "chat", "function_calling"], + context_window=32768, + max_output_tokens=4096, + supports_streaming=True, + supports_function_calling=True + ), + Model( + id="privatemode-embeddings", + object="model", + created=1234567890, + owned_by="PrivateMode.ai", + provider="PrivateMode.ai", + capabilities=["tee", "embeddings"], + context_window=512, + supports_streaming=False, + supports_function_calling=False + ) + ] + + with patch("app.services.llm.service.llm_service.get_models") as mock_models_call: + mock_models_call.return_value = mock_models + + response = await client.get( + "/api/v1/llm/models", + headers={"Authorization": "Bearer test-api-key"} + ) + + assert response.status_code == 200 + data = response.json() + + # Verify model capabilities + assert len(data["data"]) == 2 + + # Check chat 
model capabilities + chat_model = next(m for m in data["data"] if m["id"] == "privatemode-llama-3-70b") + assert "tee" in chat_model["capabilities"] + assert "chat" in chat_model["capabilities"] + assert chat_model["supports_streaming"] is True + assert chat_model["supports_function_calling"] is True + assert chat_model["context_window"] == 32768 + + # Check embedding model capabilities + embed_model = next(m for m in data["data"] if m["id"] == "privatemode-embeddings") + assert "tee" in embed_model["capabilities"] + assert "embeddings" in embed_model["capabilities"] + assert embed_model["supports_streaming"] is False + assert embed_model["context_window"] == 512 + + @pytest.mark.asyncio + async def test_concurrent_requests(self, client: AsyncClient): + """Test handling of concurrent requests.""" + from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage + + mock_response = ChatCompletionResponse( + id="test-concurrent", + object="chat.completion", + created=int(time.time()), + model="privatemode-llama-3-70b", + choices=[ + ChatChoice( + index=0, + message=ChatMessage( + role="assistant", + content="Concurrent response" + ), + finish_reason="stop" + ) + ], + usage=Usage( + prompt_tokens=5, + completion_tokens=3, + total_tokens=8 + ) + ) + + with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat: + mock_chat.return_value = mock_response + + # Create multiple concurrent requests + tasks = [] + for i in range(5): + task = client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "privatemode-llama-3-70b", + "messages": [ + {"role": "user", "content": f"Concurrent test {i}"} + ] + }, + headers={"Authorization": "Bearer test-api-key"} + ) + tasks.append(task) + + # Execute all requests concurrently + responses = await asyncio.gather(*tasks) + + # Verify all requests succeeded + for response in responses: + assert response.status_code == 200 + data = response.json() + assert "choices" in data + assert data["choices"][0]["message"]["content"] == "Concurrent response" + + @pytest.mark.asyncio + async def test_budget_enforcement_integration(self, client: AsyncClient): + """Test budget enforcement integration with LLM service.""" + # Test budget exceeded scenario + with patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget: + mock_budget.side_effect = Exception("Monthly budget limit exceeded") + + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "privatemode-llama-3-70b", + "messages": [ + {"role": "user", "content": "Test budget enforcement"} + ] + }, + headers={"Authorization": "Bearer test-api-key"} + ) + + assert response.status_code == 402 # Payment required + + # Test budget warning scenario + from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage + + mock_response = ChatCompletionResponse( + id="test-budget-warning", + object="chat.completion", + created=int(time.time()), + model="privatemode-llama-3-70b", + choices=[ + ChatChoice( + index=0, + message=ChatMessage( + role="assistant", + content="Response with budget warning" + ), + finish_reason="stop" + ) + ], + usage=Usage( + prompt_tokens=10, + completion_tokens=8, + total_tokens=18 + ), + budget_warnings=["Approaching monthly budget limit (85% used)"] + ) + + with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat, \ + patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as 
mock_budget: + + mock_chat.return_value = mock_response + mock_budget.return_value = True # Budget check passes but with warning + + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "privatemode-llama-3-70b", + "messages": [ + {"role": "user", "content": "Test budget warning"} + ] + }, + headers={"Authorization": "Bearer test-api-key"} + ) + + assert response.status_code == 200 + data = response.json() + assert "budget_warnings" in data + assert len(data["budget_warnings"]) > 0 + assert "85%" in data["budget_warnings"][0] \ No newline at end of file diff --git a/backend/tests/integration/test_llm_validation.py b/backend/tests/integration/test_llm_validation.py new file mode 100644 index 0000000..00cbc51 --- /dev/null +++ b/backend/tests/integration/test_llm_validation.py @@ -0,0 +1,296 @@ +""" +Simple validation tests for the new LLM service integration. +Tests basic functionality without complex mocking. +""" +import pytest +from httpx import AsyncClient +from unittest.mock import patch, MagicMock + + +class TestLLMServiceValidation: + """Basic validation tests for LLM service integration.""" + + @pytest.mark.asyncio + async def test_llm_models_endpoint_exists(self, client: AsyncClient): + """Test that the LLM models endpoint exists and is accessible.""" + response = await client.get( + "/api/v1/llm/models", + headers={"Authorization": "Bearer test-api-key"} + ) + + # Should not return 404 (endpoint exists) + assert response.status_code != 404 + # May return 500 or other error due to missing LLM service, but endpoint exists + + @pytest.mark.asyncio + async def test_llm_chat_endpoint_exists(self, client: AsyncClient): + """Test that the LLM chat endpoint exists and is accessible.""" + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }, + headers={"Authorization": "Bearer test-api-key"} + ) + + # Should not return 404 (endpoint exists) + assert response.status_code != 404 + + @pytest.mark.asyncio + async def test_llm_embeddings_endpoint_exists(self, client: AsyncClient): + """Test that the LLM embeddings endpoint exists and is accessible.""" + response = await client.post( + "/api/v1/llm/embeddings", + json={ + "model": "test-embedding-model", + "input": "Test text" + }, + headers={"Authorization": "Bearer test-api-key"} + ) + + # Should not return 404 (endpoint exists) + assert response.status_code != 404 + + @pytest.mark.asyncio + async def test_llm_provider_status_endpoint_exists(self, client: AsyncClient): + """Test that the provider status endpoint exists and is accessible.""" + response = await client.get( + "/api/v1/llm/providers/status", + headers={"Authorization": "Bearer test-api-key"} + ) + + # Should not return 404 (endpoint exists) + assert response.status_code != 404 + + @pytest.mark.asyncio + async def test_chat_with_mocked_service(self, client: AsyncClient): + """Test chat completion with mocked LLM service.""" + from app.services.llm.models import ChatResponse, ChatChoice, ChatMessage, TokenUsage + + # Mock successful response + mock_response = ChatResponse( + id="test-123", + object="chat.completion", + created=1234567890, + model="privatemode-llama-3-70b", + provider="PrivateMode.ai", + choices=[ + ChatChoice( + index=0, + message=ChatMessage( + role="assistant", + content="Hello! How can I help you?" 
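These validation tests patch "app.services.llm.service.llm_service.<method>" rather than anything in the API module. That works because the API layer imports the llm_service singleton object, so replacing an attribute on that shared instance is visible to every caller. The self-contained snippet below demonstrates the same pattern with stand-in names; none of the identifiers are project code.

# Illustrative sketch only -- shows why patching an attribute on a shared singleton
# is seen by every module holding a reference to that object.
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

# Stand-in for the shared service object (hypothetical, not project code).
llm_service = SimpleNamespace(create_chat_completion=None)


async def call_through_api():
    # Stand-in for the API layer, which references the same object.
    return await llm_service.create_chat_completion()


async def main():
    mocked = AsyncMock(return_value={"ok": True})
    with patch.object(llm_service, "create_chat_completion", mocked):
        assert await call_through_api() == {"ok": True}


asyncio.run(main())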
+ ), + finish_reason="stop" + ) + ], + usage=TokenUsage( + prompt_tokens=10, + completion_tokens=8, + total_tokens=18 + ), + security_check=True, + risk_score=0.1, + detected_patterns=[], + latency_ms=250.5 + ) + + with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat: + mock_chat.return_value = mock_response + + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "privatemode-llama-3-70b", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }, + headers={"Authorization": "Bearer test-api-key"} + ) + + assert response.status_code == 200 + data = response.json() + + # Verify basic response structure + assert "id" in data + assert "choices" in data + assert len(data["choices"]) == 1 + assert data["choices"][0]["message"]["content"] == "Hello! How can I help you?" + + @pytest.mark.asyncio + async def test_embedding_with_mocked_service(self, client: AsyncClient): + """Test embedding generation with mocked LLM service.""" + from app.services.llm.models import EmbeddingResponse, EmbeddingData, TokenUsage + + # Create a simple embedding vector + embedding_vector = [0.1, 0.2, 0.3] * 341 + [0.1] # 1024 dimensions (3 * 341 + 1) + + mock_response = EmbeddingResponse( + object="list", + data=[ + EmbeddingData( + object="embedding", + index=0, + embedding=embedding_vector + ) + ], + model="privatemode-embeddings", + provider="PrivateMode.ai", + usage=TokenUsage( + prompt_tokens=5, + completion_tokens=0, + total_tokens=5 + ), + security_check=True, + risk_score=0.0, + latency_ms=150.0 + ) + + with patch("app.services.llm.service.llm_service.create_embedding") as mock_embedding: + mock_embedding.return_value = mock_response + + response = await client.post( + "/api/v1/llm/embeddings", + json={ + "model": "privatemode-embeddings", + "input": "Test text for embedding" + }, + headers={"Authorization": "Bearer test-api-key"} + ) + + assert response.status_code == 200 + data = response.json() + + # Verify basic response structure + assert "data" in data + assert len(data["data"]) == 1 + assert len(data["data"][0]["embedding"]) == 1024 + + @pytest.mark.asyncio + async def test_models_with_mocked_service(self, client: AsyncClient): + """Test models listing with mocked LLM service.""" + from app.services.llm.models import ModelInfo + + mock_models = [ + ModelInfo( + id="privatemode-llama-3-70b", + object="model", + created=1234567890, + owned_by="PrivateMode.ai", + provider="PrivateMode.ai", + capabilities=["tee", "chat"], + context_window=32768, + supports_streaming=True + ), + ModelInfo( + id="privatemode-embeddings", + object="model", + created=1234567890, + owned_by="PrivateMode.ai", + provider="PrivateMode.ai", + capabilities=["tee", "embeddings"], + context_window=512, + supports_streaming=False + ) + ] + + with patch("app.services.llm.service.llm_service.get_models") as mock_models_call: + mock_models_call.return_value = mock_models + + response = await client.get( + "/api/v1/llm/models", + headers={"Authorization": "Bearer test-api-key"} + ) + + assert response.status_code == 200 + data = response.json() + + # Verify basic response structure + assert "data" in data + assert len(data["data"]) == 2 + assert data["data"][0]["id"] == "privatemode-llama-3-70b" + + @pytest.mark.asyncio + async def test_provider_status_with_mocked_service(self, client: AsyncClient): + """Test provider status with mocked LLM service.""" + mock_status = { + "privatemode": { + "provider": "PrivateMode.ai", + "status": "healthy", + "latency_ms": 250.5, +
"success_rate": 0.98, + "last_check": "2025-01-01T12:00:00Z", + "models_available": ["privatemode-llama-3-70b", "privatemode-embeddings"] + } + } + + with patch("app.services.llm.service.llm_service.get_provider_status") as mock_provider: + mock_provider.return_value = mock_status + + response = await client.get( + "/api/v1/llm/providers/status", + headers={"Authorization": "Bearer test-api-key"} + ) + + assert response.status_code == 200 + data = response.json() + + # Verify basic response structure + assert "data" in data + assert "privatemode" in data["data"] + assert data["data"]["privatemode"]["status"] == "healthy" + + @pytest.mark.asyncio + async def test_unauthorized_access(self, client: AsyncClient): + """Test that unauthorized requests are properly rejected.""" + # Test without authorization header + response = await client.get("/api/v1/llm/models") + assert response.status_code == 401 + + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}] + } + ) + assert response.status_code == 401 + + response = await client.post( + "/api/v1/llm/embeddings", + json={ + "model": "test-model", + "input": "Hello" + } + ) + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_invalid_request_data(self, client: AsyncClient): + """Test that invalid request data is properly handled.""" + # Test invalid JSON structure + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + # Missing required fields + "model": "test-model" + # messages field is missing + }, + headers={"Authorization": "Bearer test-api-key"} + ) + assert response.status_code == 422 # Unprocessable Entity + + # Test empty messages + response = await client.post( + "/api/v1/llm/chat/completions", + json={ + "model": "test-model", + "messages": [] # Empty messages + }, + headers={"Authorization": "Bearer test-api-key"} + ) + assert response.status_code == 422 # Unprocessable Entity \ No newline at end of file diff --git a/frontend/src/components/playground/ModelSelector.tsx b/frontend/src/components/playground/ModelSelector.tsx index 81309dc..674bd8e 100644 --- a/frontend/src/components/playground/ModelSelector.tsx +++ b/frontend/src/components/playground/ModelSelector.tsx @@ -6,7 +6,7 @@ import { Badge } from '@/components/ui/badge' import { Button } from '@/components/ui/button' import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' import { Alert, AlertDescription } from '@/components/ui/alert' -import { RefreshCw, Zap, Info, AlertCircle } from 'lucide-react' +import { RefreshCw, Zap, Info, AlertCircle, CheckCircle, XCircle, Clock } from 'lucide-react' interface Model { id: string @@ -16,6 +16,22 @@ interface Model { permission?: any[] root?: string parent?: string + provider?: string + capabilities?: string[] + context_window?: number + max_output_tokens?: number + supports_streaming?: boolean + supports_function_calling?: boolean +} + +interface ProviderStatus { + provider: string + status: 'healthy' | 'degraded' | 'unavailable' + latency_ms?: number + success_rate?: number + last_check: string + error_message?: string + models_available: string[] } interface ModelSelectorProps { @@ -27,6 +43,7 @@ interface ModelSelectorProps { export default function ModelSelector({ value, onValueChange, filter = 'all', className }: ModelSelectorProps) { const [models, setModels] = useState([]) + const [providerStatus, setProviderStatus] = useState>({}) const 
[loading, setLoading] = useState(true) const [error, setError] = useState(null) const [showDetails, setShowDetails] = useState(false) @@ -37,20 +54,31 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl // Get the auth token from localStorage const token = localStorage.getItem('token') + const headers = { + 'Authorization': token ? `Bearer ${token}` : '', + 'Content-Type': 'application/json' + } - const response = await fetch('/api/llm/models', { - headers: { - 'Authorization': token ? `Bearer ${token}` : '', - 'Content-Type': 'application/json' - } - }) - - if (!response.ok) { + // Fetch models and provider status in parallel + const [modelsResponse, statusResponse] = await Promise.allSettled([ + fetch('/api/llm/models', { headers }), + fetch('/api/llm/providers/status', { headers }) + ]) + + // Handle models response + if (modelsResponse.status === 'fulfilled' && modelsResponse.value.ok) { + const modelsData = await modelsResponse.value.json() + setModels(modelsData.data || []) + } else { throw new Error('Failed to fetch models') } - - const data = await response.json() - setModels(data.data || []) + + // Handle provider status response (optional) + if (statusResponse.status === 'fulfilled' && statusResponse.value.ok) { + const statusData = await statusResponse.value.json() + setProviderStatus(statusData.data || {}) + } + setError(null) } catch (err) { setError(err instanceof Error ? err.message : 'Failed to load models') @@ -64,30 +92,39 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl }, []) const getProviderFromModel = (modelId: string): string => { + // PrivateMode models have specific prefixes + if (modelId.startsWith('privatemode-')) return 'PrivateMode.ai' + + // Legacy detection for other providers if (modelId.startsWith('gpt-') || modelId.includes('openai')) return 'OpenAI' if (modelId.startsWith('claude-') || modelId.includes('anthropic')) return 'Anthropic' if (modelId.startsWith('gemini-') || modelId.includes('google')) return 'Google' - if (modelId.includes('privatemode')) return 'Privatemode.ai' if (modelId.includes('cohere')) return 'Cohere' if (modelId.includes('mistral')) return 'Mistral' - if (modelId.includes('llama')) return 'Meta' + if (modelId.includes('llama') && !modelId.startsWith('privatemode-')) return 'Meta' return 'Unknown' } const getModelType = (modelId: string): 'chat' | 'embedding' | 'other' => { - if (modelId.includes('embedding')) return 'embedding' - if (modelId.includes('whisper')) return 'other' // Audio transcription models + if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding' + if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models + + // PrivateMode and other chat models if ( + modelId.startsWith('privatemode-llama') || + modelId.startsWith('privatemode-claude') || + modelId.startsWith('privatemode-gpt') || + modelId.startsWith('privatemode-gemini') || modelId.includes('text-') || modelId.includes('gpt-') || modelId.includes('claude-') || modelId.includes('gemini-') || - modelId.includes('privatemode-') || modelId.includes('llama') || modelId.includes('gemma') || modelId.includes('qwen') || modelId.includes('latest') ) return 'chat' + return 'other' } @@ -112,6 +149,28 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl acc[provider].push(model) return acc }, {} as Record) + + const getProviderStatusIcon = (provider: string) => { + const status = 
providerStatus[provider.toLowerCase()]?.status || 'unknown' + switch (status) { + case 'healthy': + return + case 'degraded': + return + case 'unavailable': + return + default: + return + } + } + + const getProviderStatusText = (provider: string) => { + const status = providerStatus[provider.toLowerCase()] + if (!status) return 'Status unknown' + + const latencyText = status.latency_ms ? ` (${Math.round(status.latency_ms)}ms)` : '' + return `${status.status.charAt(0).toUpperCase() + status.status.slice(1)}${latencyText}` + } const selectedModel = models.find(m => m.id === value) @@ -191,16 +250,32 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl {Object.entries(groupedModels).map(([provider, providerModels]) => (
-
- {provider} +
+ {getProviderStatusIcon(provider)} + {provider} + + {getProviderStatusText(provider)} +
{providerModels.map((model) => (
{model.id} - - {getModelCategory(model.id)} - +
+ + {getModelCategory(model.id)} + + {model.supports_streaming && ( + + Streaming + + )} + {model.supports_function_calling && ( + + Functions + + )} +
))} @@ -217,7 +292,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl Model Details - +
ID: @@ -225,7 +300,10 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
Provider: -
{getProviderFromModel(selectedModel.id)}
+
+ {getProviderStatusIcon(getProviderFromModel(selectedModel.id))} + {getProviderFromModel(selectedModel.id)} +
Type: @@ -237,6 +315,40 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
+ {(selectedModel.context_window || selectedModel.max_output_tokens) && ( +
+ {selectedModel.context_window && ( +
+ Context Window: +
{selectedModel.context_window.toLocaleString()} tokens
+
+ )} + {selectedModel.max_output_tokens && ( +
+ Max Output: +
{selectedModel.max_output_tokens.toLocaleString()} tokens
+
+ )} +
+ )} + + {(selectedModel.supports_streaming || selectedModel.supports_function_calling) && ( +
+ Capabilities: +
+ {selectedModel.supports_streaming && ( + Streaming + )} + {selectedModel.supports_function_calling && ( + Function Calling + )} + {selectedModel.capabilities?.includes('tee') && ( + TEE Protected + )} +
+
+ )} + {selectedModel.created && (
Created: @@ -252,6 +364,46 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
{selectedModel.owned_by}
)} + + {/* Provider Status Details */} + {providerStatus[getProviderFromModel(selectedModel.id).toLowerCase()] && ( +
+ Provider Status: +
+ {(() => { + const status = providerStatus[getProviderFromModel(selectedModel.id).toLowerCase()] + return ( + <> +
+ Status: + {status.status} +
+ {status.latency_ms && ( +
+ Latency: + {Math.round(status.latency_ms)}ms +
+ )} + {status.success_rate && ( +
+ Success Rate: + {Math.round(status.success_rate * 100)}% +
+ )} +
+ Last Check: + {new Date(status.last_check).toLocaleTimeString()} +
+ + ) + })()} +
+
+ )}
)} diff --git a/frontend/src/components/playground/ProviderHealthDashboard.tsx b/frontend/src/components/playground/ProviderHealthDashboard.tsx new file mode 100644 index 0000000..85db898 --- /dev/null +++ b/frontend/src/components/playground/ProviderHealthDashboard.tsx @@ -0,0 +1,373 @@ +"use client" + +import { useState, useEffect } from 'react' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { Badge } from '@/components/ui/badge' +import { Button } from '@/components/ui/button' +import { Alert, AlertDescription } from '@/components/ui/alert' +import { Progress } from '@/components/ui/progress' +import { + CheckCircle, + XCircle, + Clock, + AlertCircle, + RefreshCw, + Activity, + Zap, + Shield, + Server +} from 'lucide-react' + +interface ProviderStatus { + provider: string + status: 'healthy' | 'degraded' | 'unavailable' + latency_ms?: number + success_rate?: number + last_check: string + error_message?: string + models_available: string[] +} + +interface LLMMetrics { + total_requests: number + successful_requests: number + failed_requests: number + security_blocked_requests: number + average_latency_ms: number + average_risk_score: number + provider_metrics: Record + last_updated: string +} + +export default function ProviderHealthDashboard() { + const [providers, setProviders] = useState>({}) + const [metrics, setMetrics] = useState(null) + const [loading, setLoading] = useState(true) + const [error, setError] = useState(null) + + const fetchData = async () => { + try { + setLoading(true) + const token = localStorage.getItem('token') + const headers = { + 'Authorization': token ? `Bearer ${token}` : '', + 'Content-Type': 'application/json' + } + + const [statusResponse, metricsResponse] = await Promise.allSettled([ + fetch('/api/llm/providers/status', { headers }), + fetch('/api/llm/metrics', { headers }) + ]) + + // Handle provider status + if (statusResponse.status === 'fulfilled' && statusResponse.value.ok) { + const statusData = await statusResponse.value.json() + setProviders(statusData.data || {}) + } + + // Handle metrics (optional, might require admin permissions) + if (metricsResponse.status === 'fulfilled' && metricsResponse.value.ok) { + const metricsData = await metricsResponse.value.json() + setMetrics(metricsData.data) + } + + setError(null) + } catch (err) { + setError(err instanceof Error ? err.message : 'Failed to load provider data') + } finally { + setLoading(false) + } + } + + useEffect(() => { + fetchData() + + // Auto-refresh every 30 seconds + const interval = setInterval(fetchData, 30000) + return () => clearInterval(interval) + }, []) + + const getStatusIcon = (status: string) => { + switch (status) { + case 'healthy': + return + case 'degraded': + return + case 'unavailable': + return + default: + return + } + } + + const getStatusColor = (status: string) => { + switch (status) { + case 'healthy': + return 'text-green-600 bg-green-50 border-green-200' + case 'degraded': + return 'text-yellow-600 bg-yellow-50 border-yellow-200' + case 'unavailable': + return 'text-red-600 bg-red-50 border-red-200' + default: + return 'text-gray-600 bg-gray-50 border-gray-200' + } + } + + const getLatencyColor = (latency: number) => { + if (latency < 500) return 'text-green-600' + if (latency < 2000) return 'text-yellow-600' + return 'text-red-600' + } + + if (loading) { + return ( +
+
+

Provider Health Dashboard

+ +
+
+ {[1, 2, 3].map(i => ( + + +
+
+
+ +
+
+
+
+
+
+ ))} +
+
+ ) + } + + if (error) { + return ( +
+
+

Provider Health Dashboard

+ +
+ + + {error} + +
+ ) + } + + const totalProviders = Object.keys(providers).length + const healthyProviders = Object.values(providers).filter(p => p.status === 'healthy').length + const overallHealth = totalProviders > 0 ? (healthyProviders / totalProviders) * 100 : 0 + + return ( +
+
+

Provider Health Dashboard

+ +
+ + {/* Overall Health Summary */} +
+ + + Overall Health + + + +
{Math.round(overallHealth)}%
+ +
+
+ + + + Healthy Providers + + + +
{healthyProviders}
+

of {totalProviders} providers

+
+
+ + {metrics && ( + <> + + + Success Rate + + + +
+ {metrics.total_requests > 0 + ? Math.round((metrics.successful_requests / metrics.total_requests) * 100) + : 0}% +
+

+ {metrics.successful_requests.toLocaleString()} / {metrics.total_requests.toLocaleString()} requests +

+
+
+ + + + Security Score + + + +
+ {Math.round((1 - metrics.average_risk_score) * 100)}% +
+

+ {metrics.security_blocked_requests} blocked requests +

+
+
+ + )} +
+ + {/* Provider Details */} +
+ {Object.entries(providers).map(([name, provider]) => ( + + +
+ + {getStatusIcon(provider.status)} + {provider.provider} + + + {provider.status} + +
+ + {provider.models_available.length} models available + +
+ + + {/* Performance Metrics */} + {provider.latency_ms && ( +
+ Latency + + {Math.round(provider.latency_ms)}ms + +
+ )} + + {provider.success_rate !== undefined && ( +
+ Success Rate + + {Math.round(provider.success_rate * 100)}% + +
+ )} + +
+ Last Check + + {new Date(provider.last_check).toLocaleTimeString()} + +
+ + {/* Error Message */} + {provider.error_message && ( + + + + {provider.error_message} + + + )} + + {/* Models */} +
+ Available Models +
+ {provider.models_available.slice(0, 3).map(model => ( + + {model} + + ))} + {provider.models_available.length > 3 && ( + + +{provider.models_available.length - 3} more + + )} +
+
+
+
+ ))} +
+ + {/* Provider Metrics Details */} + {metrics && Object.keys(metrics.provider_metrics).length > 0 && ( + + + + + Provider Performance Metrics + + + Detailed performance statistics for each provider + + + +
+ {Object.entries(metrics.provider_metrics).map(([provider, data]: [string, any]) => ( +
+

{provider}

+
+
+ Total Requests: + {data.total_requests?.toLocaleString() || 0} +
+
+ Success Rate: + + {data.success_rate ? Math.round(data.success_rate * 100) : 0}% + +
+
+ Avg Latency: + + {Math.round(data.average_latency_ms || 0)}ms + +
+ {data.token_usage && ( +
+ Total Tokens: + + {data.token_usage.total_tokens?.toLocaleString() || 0} + +
+ )} +
+
+ ))} +
+
+
+ )} +
+ ) +} \ No newline at end of file
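For reference, the response shapes the dashboard above expects from /api/llm/providers/status and /api/llm/metrics can be written down as pydantic models mirroring the TypeScript ProviderStatus and LLMMetrics interfaces. This is only a sketch of the implied contract; the backend may already define equivalent models under different names.

# Illustrative sketch only -- mirrors the TypeScript interfaces above, not existing backend code.
from typing import Dict, List, Optional

from pydantic import BaseModel


class ProviderStatusModel(BaseModel):
    provider: str
    status: str  # "healthy" | "degraded" | "unavailable"
    latency_ms: Optional[float] = None
    success_rate: Optional[float] = None
    last_check: str
    error_message: Optional[str] = None
    models_available: List[str] = []


class LLMMetricsModel(BaseModel):
    total_requests: int
    successful_requests: int
    failed_requests: int
    security_blocked_requests: int
    average_latency_ms: float
    average_risk_score: float
    provider_metrics: Dict[str, dict]
    last_updated: str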