Removing LiteLLM and going directly to Privatemode

2025-08-21 08:44:05 +02:00
parent be581b28f8
commit 27ee8b4cdb
16 changed files with 1775 additions and 677 deletions

View File

@@ -1,5 +1,5 @@
""" """
LLM API endpoints - proxy to LiteLLM service with authentication and budget enforcement LLM API endpoints - interface to secure LLM service with authentication and budget enforcement
""" """
import logging import logging
@@ -16,7 +16,9 @@ from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthS
from app.core.security import get_current_user from app.core.security import get_current_user
from app.models.user import User from app.models.user import User
from app.core.config import settings from app.core.config import settings
from app.services.litellm_client import litellm_client from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
from app.services.budget_enforcement import ( from app.services.budget_enforcement import (
check_budget_for_request, record_request_usage, BudgetEnforcementService, check_budget_for_request, record_request_usage, BudgetEnforcementService,
atomic_check_and_reserve_budget, atomic_finalize_usage atomic_check_and_reserve_budget, atomic_finalize_usage
@@ -38,7 +40,7 @@ router = APIRouter()
async def get_cached_models() -> List[Dict[str, Any]]: async def get_cached_models() -> List[Dict[str, Any]]:
"""Get models from cache or fetch from LiteLLM if cache is stale""" """Get models from cache or fetch from LLM service if cache is stale"""
current_time = time.time() current_time = time.time()
# Check if cache is still valid # Check if cache is still valid
@@ -47,10 +49,20 @@ async def get_cached_models() -> List[Dict[str, Any]]:
logger.debug("Returning cached models list") logger.debug("Returning cached models list")
return _models_cache["data"] return _models_cache["data"]
# Cache miss or stale - fetch from LiteLLM # Cache miss or stale - fetch from LLM service
try: try:
logger.debug("Fetching fresh models list from LiteLLM") logger.debug("Fetching fresh models list from LLM service")
models = await litellm_client.get_models() model_infos = await llm_service.get_models()
# Convert ModelInfo objects to dict format for compatibility
models = []
for model_info in model_infos:
models.append({
"id": model_info.id,
"object": model_info.object,
"created": model_info.created or int(time.time()),
"owned_by": model_info.owned_by
})
# Update cache # Update cache
_models_cache["data"] = models _models_cache["data"] = models
@@ -58,7 +70,7 @@ async def get_cached_models() -> List[Dict[str, Any]]:
return models return models
except Exception as e: except Exception as e:
logger.error(f"Failed to fetch models from LiteLLM: {e}") logger.error(f"Failed to fetch models from LLM service: {e}")
# Return stale cache if available, otherwise empty list # Return stale cache if available, otherwise empty list
if _models_cache["data"] is not None: if _models_cache["data"] is not None:
@@ -75,7 +87,7 @@ def invalidate_models_cache():
logger.info("Models cache invalidated") logger.info("Models cache invalidated")
# Request/Response Models # Request/Response Models (API layer)
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: str = Field(..., description="Message role (system, user, assistant)") role: str = Field(..., description="Message role (system, user, assistant)")
content: str = Field(..., description="Message content") content: str = Field(..., description="Message content")
@@ -183,7 +195,7 @@ async def list_models(
detail="Insufficient permissions to list models" detail="Insufficient permissions to list models"
) )
# Get models from cache or LiteLLM # Get models from cache or LLM service
models = await get_cached_models() models = await get_cached_models()
# Filter models based on API key permissions # Filter models based on API key permissions
@@ -309,35 +321,55 @@ async def create_chat_completion(
warnings = budget_warnings warnings = budget_warnings
reserved_budget_ids = budget_ids reserved_budget_ids = budget_ids
# Convert messages to dict format # Convert messages to LLM service format
messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages] llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages]
# Prepare additional parameters # Create LLM service request
kwargs = {} llm_request = ChatRequest(
if chat_request.max_tokens is not None:
kwargs["max_tokens"] = chat_request.max_tokens
if chat_request.temperature is not None:
kwargs["temperature"] = chat_request.temperature
if chat_request.top_p is not None:
kwargs["top_p"] = chat_request.top_p
if chat_request.frequency_penalty is not None:
kwargs["frequency_penalty"] = chat_request.frequency_penalty
if chat_request.presence_penalty is not None:
kwargs["presence_penalty"] = chat_request.presence_penalty
if chat_request.stop is not None:
kwargs["stop"] = chat_request.stop
if chat_request.stream is not None:
kwargs["stream"] = chat_request.stream
# Make request to LiteLLM
response = await litellm_client.create_chat_completion(
model=chat_request.model, model=chat_request.model,
messages=messages, messages=llm_messages,
temperature=chat_request.temperature,
max_tokens=chat_request.max_tokens,
top_p=chat_request.top_p,
frequency_penalty=chat_request.frequency_penalty,
presence_penalty=chat_request.presence_penalty,
stop=chat_request.stop,
stream=chat_request.stream or False,
user_id=str(context.get("user_id", "anonymous")), user_id=str(context.get("user_id", "anonymous")),
api_key_id=context.get("api_key_id", "jwt_user"), api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0
**kwargs
) )
# Make request to LLM service
llm_response = await llm_service.create_chat_completion(llm_request)
# Convert LLM service response to API format
response = {
"id": llm_response.id,
"object": llm_response.object,
"created": llm_response.created,
"model": llm_response.model,
"choices": [
{
"index": choice.index,
"message": {
"role": choice.message.role,
"content": choice.message.content
},
"finish_reason": choice.finish_reason
}
for choice in llm_response.choices
],
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
# Calculate actual cost and update usage # Calculate actual cost and update usage
usage = response.get("usage", {}) usage = response.get("usage", {})
input_tokens = usage.get("prompt_tokens", 0) input_tokens = usage.get("prompt_tokens", 0)
@@ -382,8 +414,38 @@ async def create_chat_completion(
except HTTPException: except HTTPException:
raise raise
except SecurityError as e:
logger.warning(f"Security error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
)
except ValidationError as e:
logger.warning(f"Validation error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
)
except ProviderError as e:
logger.error(f"Provider error in chat completion: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
)
except LLMError as e:
logger.error(f"LLM service error in chat completion: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
)
except Exception as e: except Exception as e:
logger.error(f"Error creating chat completion: {e}") logger.error(f"Unexpected error creating chat completion: {e}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create chat completion" detail="Failed to create chat completion"
@@ -438,15 +500,39 @@ async def create_embedding(
detail=f"Budget exceeded: {error_message}" detail=f"Budget exceeded: {error_message}"
) )
# Make request to LiteLLM # Create LLM service request
response = await litellm_client.create_embedding( llm_request = LLMEmbeddingRequest(
model=request.model, model=request.model,
input_text=request.input, input=request.input,
encoding_format=request.encoding_format,
user_id=str(context["user_id"]), user_id=str(context["user_id"]),
api_key_id=context["api_key_id"], api_key_id=context["api_key_id"]
encoding_format=request.encoding_format
) )
# Make request to LLM service
llm_response = await llm_service.create_embedding(llm_request)
# Convert LLM service response to API format
response = {
"object": llm_response.object,
"data": [
{
"object": emb.object,
"index": emb.index,
"embedding": emb.embedding
}
for emb in llm_response.data
],
"model": llm_response.model,
"usage": {
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
} if llm_response.usage else {
"prompt_tokens": int(estimated_tokens),
"total_tokens": int(estimated_tokens)
}
}
# Calculate actual cost and update usage # Calculate actual cost and update usage
usage = response.get("usage", {}) usage = response.get("usage", {})
total_tokens = usage.get("total_tokens", int(estimated_tokens)) total_tokens = usage.get("total_tokens", int(estimated_tokens))
@@ -475,8 +561,38 @@ async def create_embedding(
except HTTPException: except HTTPException:
raise raise
except SecurityError as e:
logger.warning(f"Security error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Security validation failed: {e.message}"
)
except ValidationError as e:
logger.warning(f"Validation error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request validation failed: {e.message}"
)
except ProviderError as e:
logger.error(f"Provider error in embedding: {e}")
if "rate limit" in str(e).lower():
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LLM service temporarily unavailable"
)
except LLMError as e:
logger.error(f"LLM service error in embedding: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="LLM service error"
)
except Exception as e: except Exception as e:
logger.error(f"Error creating embedding: {e}") logger.error(f"Unexpected error creating embedding: {e}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create embedding" detail="Failed to create embedding"
@@ -489,11 +605,28 @@ async def llm_health_check(
): ):
"""Health check for LLM service""" """Health check for LLM service"""
try: try:
health_status = await litellm_client.health_check() health_summary = llm_service.get_health_summary()
provider_status = await llm_service.get_provider_status()
# Determine overall health
overall_status = "healthy"
if health_summary["service_status"] != "healthy":
overall_status = "degraded"
for provider, status in provider_status.items():
if status.status == "unavailable":
overall_status = "degraded"
break
return { return {
"status": "healthy", "status": overall_status,
"service": "LLM Proxy", "service": "LLM Service",
"litellm_status": health_status, "service_status": health_summary,
"provider_status": {name: {
"status": status.status,
"latency_ms": status.latency_ms,
"error_message": status.error_message
} for name, status in provider_status.items()},
"user_id": context["user_id"], "user_id": context["user_id"],
"api_key_name": context["api_key_name"] "api_key_name": context["api_key_name"]
} }
@@ -501,7 +634,7 @@ async def llm_health_check(
logger.error(f"LLM health check error: {e}") logger.error(f"LLM health check error: {e}")
return { return {
"status": "unhealthy", "status": "unhealthy",
"service": "LLM Proxy", "service": "LLM Service",
"error": str(e) "error": str(e)
} }
@@ -626,50 +759,83 @@ async def get_budget_status(
) )
# Generic proxy endpoint for other LiteLLM endpoints # Generic endpoint for additional LLM service functionality
@router.api_route("/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) @router.get("/metrics")
async def proxy_endpoint( async def get_llm_metrics(
endpoint: str,
request: Request,
context: Dict[str, Any] = Depends(require_api_key), context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db) db: AsyncSession = Depends(get_db)
): ):
"""Generic proxy endpoint for LiteLLM requests""" """Get LLM service metrics (admin only)"""
try: try:
# Check for admin permissions
auth_service = APIKeyAuthService(db) auth_service = APIKeyAuthService(db)
if not await auth_service.check_scope_permission(context, "admin.metrics"):
# Check endpoint permission
if not await auth_service.check_endpoint_permission(context, endpoint):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Endpoint '{endpoint}' not allowed" detail="Admin permissions required to view metrics"
) )
# Get request body metrics = llm_service.get_metrics()
if request.method in ["POST", "PUT", "PATCH"]: return {
try: "object": "llm_metrics",
payload = await request.json() "data": {
except: "total_requests": metrics.total_requests,
payload = {} "successful_requests": metrics.successful_requests,
else: "failed_requests": metrics.failed_requests,
payload = dict(request.query_params) "security_blocked_requests": metrics.security_blocked_requests,
"average_latency_ms": metrics.average_latency_ms,
# Make request to LiteLLM "average_risk_score": metrics.average_risk_score,
response = await litellm_client.proxy_request( "provider_metrics": metrics.provider_metrics,
method=request.method, "last_updated": metrics.last_updated.isoformat()
endpoint=endpoint, }
payload=payload, }
user_id=str(context["user_id"]),
api_key_id=context["api_key_id"]
)
return response
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error proxying request to {endpoint}: {e}") logger.error(f"Error getting LLM metrics: {e}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to proxy request" detail="Failed to get LLM metrics"
)
@router.get("/providers/status")
async def get_provider_status(
context: Dict[str, Any] = Depends(require_api_key),
db: AsyncSession = Depends(get_db)
):
"""Get status of all LLM providers"""
try:
auth_service = APIKeyAuthService(db)
if not await auth_service.check_scope_permission(context, "admin.status"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin permissions required to view provider status"
)
provider_status = await llm_service.get_provider_status()
return {
"object": "provider_status",
"data": {
name: {
"provider": status.provider,
"status": status.status,
"latency_ms": status.latency_ms,
"success_rate": status.success_rate,
"last_check": status.last_check.isoformat(),
"error_message": status.error_message,
"models_available": status.models_available
}
for name, status in provider_status.items()
}
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting provider status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get provider status"
) )
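A sketch (not part of the commit) of a shared helper for the exception-to-HTTP mapping that the chat-completion and embedding handlers above now repeat; it assumes the exception classes and their .message attribute exactly as used in this diff.

from fastapi import HTTPException, status

from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError


def llm_error_to_http(exc: Exception, fallback_detail: str) -> HTTPException:
    """Map LLM service exceptions to the HTTP errors used by the endpoints above."""
    if isinstance(exc, SecurityError):
        return HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                             detail=f"Security validation failed: {exc.message}")
    if isinstance(exc, ValidationError):
        return HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                             detail=f"Request validation failed: {exc.message}")
    if isinstance(exc, ProviderError):
        if "rate limit" in str(exc).lower():
            return HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                                 detail="Rate limit exceeded")
        return HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                             detail="LLM service temporarily unavailable")
    if isinstance(exc, LLMError):
        return HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                             detail="LLM service error")
    return HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                         detail=fallback_detail)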

View File

@@ -447,7 +447,7 @@ async def get_module_config(module_name: str):
log_api_request("get_module_config", {"module_name": module_name}) log_api_request("get_module_config", {"module_name": module_name})
from app.services.module_config_manager import module_config_manager from app.services.module_config_manager import module_config_manager
from app.services.litellm_client import litellm_client from app.services.llm.service import llm_service
import copy import copy
# Get module manifest and schema # Get module manifest and schema
@@ -461,9 +461,9 @@ async def get_module_config(module_name: str):
# For Signal module, populate model options dynamically # For Signal module, populate model options dynamically
if module_name == "signal" and schema: if module_name == "signal" and schema:
try: try:
# Get available models from LiteLLM # Get available models from LLM service
models_data = await litellm_client.get_models() models_data = await llm_service.get_models()
model_ids = [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")] model_ids = [model.id for model in models_data]
if model_ids: if model_ids:
# Create a copy of the schema to avoid modifying the original # Create a copy of the schema to avoid modifying the original

View File

@@ -15,7 +15,8 @@ from app.models.prompt_template import PromptTemplate, ChatbotPromptVariable
from app.core.security import get_current_user from app.core.security import get_current_user
from app.models.user import User from app.models.user import User
from app.core.logging import log_api_request from app.core.logging import log_api_request
from app.services.litellm_client import litellm_client from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
router = APIRouter() router = APIRouter()
@@ -394,25 +395,28 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
] ]
# Get available models to use a default model # Get available models to use a default model
models = await litellm_client.get_models() models = await llm_service.get_models()
if not models: if not models:
raise HTTPException(status_code=503, detail="No LLM models available") raise HTTPException(status_code=503, detail="No LLM models available")
# Use the first available model (you might want to make this configurable) # Use the first available model (you might want to make this configurable)
default_model = models[0]["id"] default_model = models[0].id
# Make the AI call # Prepare the chat request for the new LLM service
response = await litellm_client.create_chat_completion( chat_request = LLMChatRequest(
model=default_model, model=default_model,
messages=messages, messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
user_id=str(user_id),
api_key_id=1, # Using default API key, you might want to make this dynamic
temperature=0.3, temperature=0.3,
max_tokens=1000 max_tokens=1000,
user_id=str(user_id),
api_key_id=1 # Using default API key, you might want to make this dynamic
) )
# Make the AI call
response = await llm_service.create_chat_completion(chat_request)
# Extract the improved prompt from the response # Extract the improved prompt from the response
improved_prompt = response["choices"][0]["message"]["content"].strip() improved_prompt = response.choices[0].message.content.strip()
return { return {
"improved_prompt": improved_prompt, "improved_prompt": improved_prompt,

View File

@@ -51,7 +51,7 @@ class SystemInfoResponse(BaseModel):
environment: str environment: str
database_status: str database_status: str
redis_status: str redis_status: str
litellm_status: str llm_service_status: str
modules_loaded: int modules_loaded: int
active_users: int active_users: int
total_api_keys: int total_api_keys: int
@@ -227,8 +227,13 @@ async def get_system_info(
# Get Redis status (simplified check) # Get Redis status (simplified check)
redis_status = "healthy" # Would implement actual Redis check redis_status = "healthy" # Would implement actual Redis check
# Get LiteLLM status (simplified check) # Get LLM service status
litellm_status = "healthy" # Would implement actual LiteLLM check try:
from app.services.llm.service import llm_service
health_summary = llm_service.get_health_summary()
llm_service_status = health_summary.get("service_status", "unknown")
except Exception:
llm_service_status = "unavailable"
# Get modules loaded (from module manager) # Get modules loaded (from module manager)
modules_loaded = 8 # Would get from actual module manager modules_loaded = 8 # Would get from actual module manager
@@ -261,7 +266,7 @@ async def get_system_info(
environment="production", environment="production",
database_status=database_status, database_status=database_status,
redis_status=redis_status, redis_status=redis_status,
litellm_status=litellm_status, llm_service_status=llm_service_status,
modules_loaded=modules_loaded, modules_loaded=modules_loaded,
active_users=active_users, active_users=active_users,
total_api_keys=total_api_keys, total_api_keys=total_api_keys,
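A hedged standalone check, combining get_health_summary() with get_provider_status() the same way the /health endpoint and the system-info view above do; this is an illustration assuming the service API shown in the diff, not code from the commit.

import asyncio

from app.services.llm.service import llm_service


async def overall_llm_status() -> str:
    if not llm_service._initialized:
        await llm_service.initialize()
    # Service-level summary is synchronous; per-provider probing is async.
    summary = llm_service.get_health_summary()
    providers = await llm_service.get_provider_status()
    if summary.get("service_status") != "healthy":
        return "degraded"
    if any(p.status == "unavailable" for p in providers.values()):
        return "degraded"
    return "healthy"


if __name__ == "__main__":
    print(asyncio.run(overall_llm_status()))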

View File

@@ -43,15 +43,18 @@ class Settings(BaseSettings):
# CORS # CORS
CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"] CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"]
# LiteLLM # LLM Service Configuration (replaced LiteLLM)
LITELLM_BASE_URL: str = "http://localhost:4000" # LLM service configuration is now handled in app/services/llm/config.py
LITELLM_MASTER_KEY: str = "enclava-master-key"
# LLM Service Security
LLM_ENCRYPTION_KEY: Optional[str] = None # Key for encrypting LLM provider API keys
# API Keys for LLM providers # API Keys for LLM providers
OPENAI_API_KEY: Optional[str] = None OPENAI_API_KEY: Optional[str] = None
ANTHROPIC_API_KEY: Optional[str] = None ANTHROPIC_API_KEY: Optional[str] = None
GOOGLE_API_KEY: Optional[str] = None GOOGLE_API_KEY: Optional[str] = None
PRIVATEMODE_API_KEY: Optional[str] = None PRIVATEMODE_API_KEY: Optional[str] = None
PRIVATEMODE_PROXY_URL: str = "http://privatemode-proxy:8080/v1"
# Qdrant # Qdrant
QDRANT_HOST: str = "localhost" QDRANT_HOST: str = "localhost"

View File

@@ -1,6 +1,6 @@
""" """
Embedding Service Embedding Service
Provides text embedding functionality using LiteLLM proxy Provides text embedding functionality using LLM service
""" """
import logging import logging
@@ -11,32 +11,34 @@ logger = logging.getLogger(__name__)
class EmbeddingService: class EmbeddingService:
"""Service for generating text embeddings using LiteLLM""" """Service for generating text embeddings using LLM service"""
def __init__(self, model_name: str = "privatemode-embeddings"): def __init__(self, model_name: str = "privatemode-embeddings"):
self.model_name = model_name self.model_name = model_name
self.litellm_client = None
self.dimension = 1024 # Actual dimension for privatemode-embeddings self.dimension = 1024 # Actual dimension for privatemode-embeddings
self.initialized = False self.initialized = False
async def initialize(self): async def initialize(self):
"""Initialize the embedding service with LiteLLM""" """Initialize the embedding service with LLM service"""
try: try:
from app.services.litellm_client import litellm_client from app.services.llm.service import llm_service
self.litellm_client = litellm_client
# Test connection to LiteLLM # Initialize LLM service if not already done
health = await self.litellm_client.health_check() if not llm_service._initialized:
if health.get("status") == "unhealthy": await llm_service.initialize()
logger.error(f"LiteLLM service unhealthy: {health.get('error')}")
# Test LLM service health
health_summary = llm_service.get_health_summary()
if health_summary.get("service_status") != "healthy":
logger.error(f"LLM service unhealthy: {health_summary}")
return False return False
self.initialized = True self.initialized = True
logger.info(f"Embedding service initialized with LiteLLM: {self.model_name} (dimension: {self.dimension})") logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})")
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize LiteLLM embedding service: {e}") logger.error(f"Failed to initialize LLM embedding service: {e}")
logger.warning("Using fallback random embeddings") logger.warning("Using fallback random embeddings")
return False return False
@@ -46,10 +48,10 @@ class EmbeddingService:
return embeddings[0] return embeddings[0]
async def get_embeddings(self, texts: List[str]) -> List[List[float]]: async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embeddings for multiple texts using LiteLLM""" """Get embeddings for multiple texts using LLM service"""
if not self.initialized or not self.litellm_client: if not self.initialized:
# Fallback to random embeddings if not initialized # Fallback to random embeddings if not initialized
logger.warning("LiteLLM not available, using random embeddings") logger.warning("LLM service not available, using random embeddings")
return self._generate_fallback_embeddings(texts) return self._generate_fallback_embeddings(texts)
try: try:
@@ -73,17 +75,22 @@ class EmbeddingService:
else: else:
truncated_text = text truncated_text = text
# Call LiteLLM embedding endpoint # Call LLM service embedding endpoint
response = await self.litellm_client.create_embedding( from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest
llm_request = EmbeddingRequest(
model=self.model_name, model=self.model_name,
input_text=truncated_text, input=truncated_text,
user_id="rag_system", user_id="rag_system",
api_key_id=0 # System API key api_key_id=0 # System API key
) )
response = await llm_service.create_embedding(llm_request)
# Extract embedding from response # Extract embedding from response
if "data" in response and len(response["data"]) > 0: if response.data and len(response.data) > 0:
embedding = response["data"][0].get("embedding", []) embedding = response.data[0].embedding
if embedding: if embedding:
batch_embeddings.append(embedding) batch_embeddings.append(embedding)
# Update dimension based on actual embedding size # Update dimension based on actual embedding size
@@ -106,7 +113,7 @@ class EmbeddingService:
return embeddings return embeddings
except Exception as e: except Exception as e:
logger.error(f"Error generating embeddings with LiteLLM: {e}") logger.error(f"Error generating embeddings with LLM service: {e}")
# Fallback to random embeddings # Fallback to random embeddings
return self._generate_fallback_embeddings(texts) return self._generate_fallback_embeddings(texts)
@@ -146,14 +153,13 @@ class EmbeddingService:
"model_name": self.model_name, "model_name": self.model_name,
"model_loaded": self.initialized, "model_loaded": self.initialized,
"dimension": self.dimension, "dimension": self.dimension,
"backend": "LiteLLM", "backend": "LLM Service",
"initialized": self.initialized "initialized": self.initialized
} }
async def cleanup(self): async def cleanup(self):
"""Cleanup resources""" """Cleanup resources"""
self.initialized = False self.initialized = False
self.litellm_client = None
# Global embedding service instance # Global embedding service instance
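A minimal sketch of requesting a single embedding through the new service outside the EmbeddingService wrapper, mirroring the rag_system call above; the model name and identifiers are taken from the diff, the rest is illustrative.

import asyncio

from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest


async def embed(text: str) -> list:
    if not llm_service._initialized:
        await llm_service.initialize()
    request = EmbeddingRequest(
        model="privatemode-embeddings",
        input=text,
        user_id="rag_system",
        api_key_id=0,  # system API key, as in the diff
    )
    response = await llm_service.create_embedding(request)
    return response.data[0].embedding if response.data else []


if __name__ == "__main__":
    vector = asyncio.run(embed("hello world"))
    print(len(vector))  # expected 1024 for privatemode-embeddings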

View File

@@ -1,304 +0,0 @@
"""
LiteLLM Client Service
Handles communication with the LiteLLM proxy service
"""
import asyncio
import json
import logging
from typing import Dict, Any, Optional, List
from datetime import datetime
import aiohttp
from fastapi import HTTPException, status
from app.core.config import settings
logger = logging.getLogger(__name__)
class LiteLLMClient:
"""Client for communicating with LiteLLM proxy service"""
def __init__(self):
self.base_url = settings.LITELLM_BASE_URL
self.master_key = settings.LITELLM_MASTER_KEY
self.session: Optional[aiohttp.ClientSession] = None
self.timeout = aiohttp.ClientTimeout(total=600) # 10 minutes timeout
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create aiohttp session"""
if self.session is None or self.session.closed:
self.session = aiohttp.ClientSession(
timeout=self.timeout,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.master_key}"
}
)
return self.session
async def close(self):
"""Close the HTTP session"""
if self.session and not self.session.closed:
await self.session.close()
async def health_check(self) -> Dict[str, Any]:
"""Check LiteLLM proxy health"""
try:
session = await self._get_session()
async with session.get(f"{self.base_url}/health") as response:
if response.status == 200:
return await response.json()
else:
logger.error(f"LiteLLM health check failed: {response.status}")
return {"status": "unhealthy", "error": f"HTTP {response.status}"}
except Exception as e:
logger.error(f"LiteLLM health check error: {e}")
return {"status": "unhealthy", "error": str(e)}
async def get_models(self) -> List[Dict[str, Any]]:
"""Get available models from LiteLLM"""
try:
session = await self._get_session()
async with session.get(f"{self.base_url}/models") as response:
if response.status == 200:
data = await response.json()
return data.get("data", [])
else:
logger.error(f"Failed to get models: {response.status}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service unavailable"
)
except aiohttp.ClientError as e:
logger.error(f"LiteLLM models request error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service unavailable"
)
async def create_chat_completion(
self,
model: str,
messages: List[Dict[str, str]],
user_id: str,
api_key_id: int,
**kwargs
) -> Dict[str, Any]:
"""Create chat completion via LiteLLM proxy"""
try:
# Prepare request payload
payload = {
"model": model,
"messages": messages,
"user": f"user_{user_id}", # User identifier for tracking
"metadata": {
"api_key_id": api_key_id,
"user_id": user_id,
"timestamp": datetime.utcnow().isoformat()
},
**kwargs
}
session = await self._get_session()
async with session.post(
f"{self.base_url}/chat/completions",
json=payload
) as response:
if response.status == 200:
return await response.json()
else:
error_text = await response.text()
logger.error(f"LiteLLM chat completion failed: {response.status} - {error_text}")
# Handle specific error cases
if response.status == 401:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
)
elif response.status == 429:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
elif response.status == 400:
try:
error_data = await response.json()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_data.get("error", {}).get("message", "Bad request")
)
except:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid request"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service error"
)
except aiohttp.ClientError as e:
logger.error(f"LiteLLM chat completion request error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service unavailable"
)
async def create_embedding(
self,
model: str,
input_text: str,
user_id: str,
api_key_id: int,
**kwargs
) -> Dict[str, Any]:
"""Create embedding via LiteLLM proxy"""
try:
payload = {
"model": model,
"input": input_text,
"user": f"user_{user_id}",
"metadata": {
"api_key_id": api_key_id,
"user_id": user_id,
"timestamp": datetime.utcnow().isoformat()
},
**kwargs
}
session = await self._get_session()
async with session.post(
f"{self.base_url}/embeddings",
json=payload
) as response:
if response.status == 200:
return await response.json()
else:
error_text = await response.text()
logger.error(f"LiteLLM embedding failed: {response.status} - {error_text}")
if response.status == 401:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
)
elif response.status == 429:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service error"
)
except aiohttp.ClientError as e:
logger.error(f"LiteLLM embedding request error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service unavailable"
)
async def get_models(self) -> List[Dict[str, Any]]:
"""Get available models from LiteLLM proxy"""
try:
session = await self._get_session()
async with session.get(f"{self.base_url}/models") as response:
if response.status == 200:
data = await response.json()
# Return models with exact names from upstream providers
models = data.get("data", [])
# Pass through model names exactly as they come from upstream
# Don't modify model IDs - keep them as the original provider names
processed_models = []
for model in models:
# Keep the exact model ID from upstream provider
processed_models.append({
"id": model.get("id", ""), # Exact model name from provider
"object": model.get("object", "model"),
"created": model.get("created", 1677610602),
"owned_by": model.get("owned_by", "openai")
})
return processed_models
else:
error_text = await response.text()
logger.error(f"LiteLLM models request failed: {response.status} - {error_text}")
return []
except aiohttp.ClientError as e:
logger.error(f"LiteLLM models request error: {e}")
return []
async def proxy_request(
self,
method: str,
endpoint: str,
payload: Dict[str, Any],
user_id: str,
api_key_id: int
) -> Dict[str, Any]:
"""Generic proxy request to LiteLLM"""
try:
# Add metadata to payload
if isinstance(payload, dict):
payload["metadata"] = {
"api_key_id": api_key_id,
"user_id": user_id,
"timestamp": datetime.utcnow().isoformat()
}
if "user" not in payload:
payload["user"] = f"user_{user_id}"
session = await self._get_session()
# Make the request
async with session.request(
method,
f"{self.base_url}/{endpoint.lstrip('/')}",
json=payload if method.upper() in ['POST', 'PUT', 'PATCH'] else None,
params=payload if method.upper() == 'GET' else None
) as response:
if response.status == 200:
return await response.json()
else:
error_text = await response.text()
logger.error(f"LiteLLM proxy request failed: {response.status} - {error_text}")
if response.status == 401:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
)
elif response.status == 429:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Rate limit exceeded"
)
else:
raise HTTPException(
status_code=response.status,
detail=f"LiteLLM service error: {error_text}"
)
except aiohttp.ClientError as e:
logger.error(f"LiteLLM proxy request error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="LiteLLM service unavailable"
)
async def list_models(self) -> List[str]:
"""Get list of available model names/IDs"""
try:
models_data = await self.get_models()
return [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")]
except Exception as e:
logger.error(f"Error listing model names: {str(e)}")
return []
# Global LiteLLM client instance
litellm_client = LiteLLMClient()

View File

@@ -23,7 +23,9 @@ from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.logging import get_logger from app.core.logging import get_logger
from app.services.litellm_client import LiteLLMClient from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError
from app.services.base_module import BaseModule, Permission from app.services.base_module import BaseModule, Permission
from app.models.user import User from app.models.user import User
from app.models.chatbot import ChatbotInstance as DBChatbotInstance, ChatbotConversation as DBConversation, ChatbotMessage as DBMessage, ChatbotAnalytics from app.models.chatbot import ChatbotInstance as DBChatbotInstance, ChatbotConversation as DBConversation, ChatbotMessage as DBMessage, ChatbotAnalytics
@@ -32,7 +34,8 @@ from app.db.database import get_db
from app.core.config import settings from app.core.config import settings
# Import protocols for type hints and dependency injection # Import protocols for type hints and dependency injection
from ..protocols import RAGServiceProtocol, LiteLLMClientProtocol from ..protocols import RAGServiceProtocol
# Note: LiteLLMClientProtocol replaced with direct LLM service usage
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -131,10 +134,8 @@ class ChatbotInstance(BaseModel):
class ChatbotModule(BaseModule): class ChatbotModule(BaseModule):
"""Main chatbot module implementation""" """Main chatbot module implementation"""
def __init__(self, litellm_client: Optional[LiteLLMClientProtocol] = None, def __init__(self, rag_service: Optional[RAGServiceProtocol] = None):
rag_service: Optional[RAGServiceProtocol] = None):
super().__init__("chatbot") super().__init__("chatbot")
self.litellm_client = litellm_client
self.rag_module = rag_service # Keep same name for compatibility self.rag_module = rag_service # Keep same name for compatibility
self.db_session = None self.db_session = None
@@ -145,15 +146,10 @@ class ChatbotModule(BaseModule):
"""Initialize the chatbot module""" """Initialize the chatbot module"""
await super().initialize(**kwargs) await super().initialize(**kwargs)
# Get dependencies from global services if not already injected # Initialize the LLM service
if not self.litellm_client: await llm_service.initialize()
try:
from app.services.litellm_client import litellm_client
self.litellm_client = litellm_client
logger.info("LiteLLM client injected from global service")
except Exception as e:
logger.warning(f"Could not inject LiteLLM client: {e}")
# Get RAG module dependency if not already injected
if not self.rag_module: if not self.rag_module:
try: try:
# Try to get RAG module from module manager # Try to get RAG module from module manager
@@ -168,19 +164,16 @@ class ChatbotModule(BaseModule):
await self._load_prompt_templates() await self._load_prompt_templates()
logger.info("Chatbot module initialized") logger.info("Chatbot module initialized")
logger.info(f"LiteLLM client available after init: {self.litellm_client is not None}") logger.info(f"LLM service available: {llm_service._initialized}")
logger.info(f"RAG module available after init: {self.rag_module is not None}") logger.info(f"RAG module available after init: {self.rag_module is not None}")
logger.info(f"Loaded {len(self.system_prompts)} prompt templates") logger.info(f"Loaded {len(self.system_prompts)} prompt templates")
async def _ensure_dependencies(self): async def _ensure_dependencies(self):
"""Lazy load dependencies if not available""" """Lazy load dependencies if not available"""
if not self.litellm_client: # Ensure LLM service is initialized
try: if not llm_service._initialized:
from app.services.litellm_client import litellm_client await llm_service.initialize()
self.litellm_client = litellm_client logger.info("LLM service lazy loaded")
logger.info("LiteLLM client lazy loaded")
except Exception as e:
logger.warning(f"Could not lazy load LiteLLM client: {e}")
if not self.rag_module: if not self.rag_module:
try: try:
@@ -468,45 +461,58 @@ class ChatbotModule(BaseModule):
logger.info(msg['content']) logger.info(msg['content'])
logger.info("=== END COMPREHENSIVE LLM REQUEST ===") logger.info("=== END COMPREHENSIVE LLM REQUEST ===")
if self.litellm_client:
try: try:
logger.info("Calling LiteLLM client create_chat_completion...") logger.info("Calling LLM service create_chat_completion...")
response = await self.litellm_client.create_chat_completion(
model=config.model,
messages=messages,
user_id="chatbot_user",
api_key_id="chatbot_api_key",
temperature=config.temperature,
max_tokens=config.max_tokens
)
logger.info(f"LiteLLM response received, response keys: {list(response.keys())}")
# Extract response content from the LiteLLM response format # Convert messages to LLM service format
if 'choices' in response and response['choices']: llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
content = response['choices'][0]['message']['content']
# Create LLM service request
llm_request = LLMChatRequest(
model=config.model,
messages=llm_messages,
temperature=config.temperature,
max_tokens=config.max_tokens,
user_id="chatbot_user",
api_key_id=0 # Chatbot module uses internal service
)
# Make request to LLM service
llm_response = await llm_service.create_chat_completion(llm_request)
# Extract response content
if llm_response.choices:
content = llm_response.choices[0].message.content
logger.info(f"Response content length: {len(content)}") logger.info(f"Response content length: {len(content)}")
# Always log response for debugging # Always log response for debugging
logger.info("=== COMPREHENSIVE LLM RESPONSE ===") logger.info("=== COMPREHENSIVE LLM RESPONSE ===")
logger.info(f"Response content ({len(content)} chars):") logger.info(f"Response content ({len(content)} chars):")
logger.info(content) logger.info(content)
if 'usage' in response: if llm_response.usage:
usage = response['usage'] usage = llm_response.usage
logger.info(f"Token usage - Prompt: {usage.get('prompt_tokens', 'N/A')}, Completion: {usage.get('completion_tokens', 'N/A')}, Total: {usage.get('total_tokens', 'N/A')}") logger.info(f"Token usage - Prompt: {usage.prompt_tokens}, Completion: {usage.completion_tokens}, Total: {usage.total_tokens}")
if sources: if sources:
logger.info(f"RAG sources included: {len(sources)} documents") logger.info(f"RAG sources included: {len(sources)} documents")
logger.info("=== END COMPREHENSIVE LLM RESPONSE ===") logger.info("=== END COMPREHENSIVE LLM RESPONSE ===")
return content, sources return content, sources
else: else:
logger.warning("No choices in LiteLLM response") logger.warning("No choices in LLM response")
return "I received an empty response from the AI model.", sources return "I received an empty response from the AI model.", sources
except SecurityError as e:
logger.error(f"Security error in LLM completion: {e}")
raise HTTPException(status_code=400, detail=f"Security validation failed: {e.message}")
except ProviderError as e:
logger.error(f"Provider error in LLM completion: {e}")
raise HTTPException(status_code=503, detail="LLM service temporarily unavailable")
except LLMError as e:
logger.error(f"LLM service error: {e}")
raise HTTPException(status_code=500, detail="LLM service error")
except Exception as e: except Exception as e:
logger.error(f"LiteLLM completion failed: {e}") logger.error(f"LLM completion failed: {e}")
raise e # Return fallback if available
else:
logger.warning("No LiteLLM client available, using fallback")
# Fallback if no LLM client
return "I'm currently unable to process your request. Please try again later.", None return "I'm currently unable to process your request. Please try again later.", None
def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig, def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig,
@@ -685,7 +691,7 @@ class ChatbotModule(BaseModule):
# Lazy load dependencies # Lazy load dependencies
await self._ensure_dependencies() await self._ensure_dependencies()
logger.info(f"LiteLLM client available: {self.litellm_client is not None}") logger.info(f"LLM service available: {llm_service._initialized}")
logger.info(f"RAG module available: {self.rag_module is not None}") logger.info(f"RAG module available: {self.rag_module is not None}")
try: try:
@@ -884,10 +890,9 @@ class ChatbotModule(BaseModule):
# Module factory function # Module factory function
def create_module(litellm_client: Optional[LiteLLMClientProtocol] = None, def create_module(rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule:
rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule:
"""Factory function to create chatbot module instance""" """Factory function to create chatbot module instance"""
return ChatbotModule(litellm_client=litellm_client, rag_service=rag_service) return ChatbotModule(rag_service=rag_service)
# Create module instance (dependencies will be injected via factory) # Create module instance (dependencies will be injected via factory)
chatbot_module = ChatbotModule() chatbot_module = ChatbotModule()

View File

@@ -401,7 +401,7 @@ class RAGModule(BaseModule):
"""Initialize embedding model""" """Initialize embedding model"""
from app.services.embedding_service import embedding_service from app.services.embedding_service import embedding_service
# Use privatemode-embeddings for LiteLLM integration # Use privatemode-embeddings for LLM service integration
model_name = self.config.get("embedding_model", "privatemode-embeddings") model_name = self.config.get("embedding_model", "privatemode-embeddings")
embedding_service.model_name = model_name embedding_service.model_name = model_name

View File

@@ -22,13 +22,16 @@ from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import select from sqlalchemy import select
from app.core.logging import get_logger from app.core.logging import get_logger
from app.services.litellm_client import LiteLLMClient from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError
from app.services.base_module import Permission from app.services.base_module import Permission
from app.db.database import SessionLocal from app.db.database import SessionLocal
from app.models.workflow import WorkflowDefinition as DBWorkflowDefinition, WorkflowExecution as DBWorkflowExecution from app.models.workflow import WorkflowDefinition as DBWorkflowDefinition, WorkflowExecution as DBWorkflowExecution
# Import protocols for type hints and dependency injection # Import protocols for type hints and dependency injection
from ..protocols import ChatbotServiceProtocol, LiteLLMClientProtocol from ..protocols import ChatbotServiceProtocol
# Note: LiteLLMClientProtocol replaced with direct LLM service usage
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -234,8 +237,7 @@ class WorkflowExecution(BaseModel):
class WorkflowEngine: class WorkflowEngine:
"""Core workflow execution engine""" """Core workflow execution engine"""
def __init__(self, litellm_client: LiteLLMClient, chatbot_service: Optional[ChatbotServiceProtocol] = None): def __init__(self, chatbot_service: Optional[ChatbotServiceProtocol] = None):
self.litellm_client = litellm_client
self.chatbot_service = chatbot_service self.chatbot_service = chatbot_service
self.executions: Dict[str, WorkflowExecution] = {} self.executions: Dict[str, WorkflowExecution] = {}
self.workflows: Dict[str, WorkflowDefinition] = {} self.workflows: Dict[str, WorkflowDefinition] = {}
@@ -343,15 +345,23 @@ class WorkflowEngine:
# Template message content with context variables # Template message content with context variables
messages = self._template_messages(llm_step.messages, context.variables) messages = self._template_messages(llm_step.messages, context.variables)
# Make LLM call # Convert messages to LLM service format
response = await self.litellm_client.chat_completion( llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
# Create LLM service request
llm_request = LLMChatRequest(
model=llm_step.model, model=llm_step.model,
messages=messages, messages=llm_messages,
**llm_step.parameters user_id="workflow_user",
api_key_id=0, # Workflow module uses internal service
**{k: v for k, v in llm_step.parameters.items() if k in ['temperature', 'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty', 'stop']}
) )
# Make LLM call
response = await llm_service.create_chat_completion(llm_request)
# Store result # Store result
result = response.get("choices", [{}])[0].get("message", {}).get("content", "") result = response.choices[0].message.content if response.choices else ""
context.variables[llm_step.output_variable] = result context.variables[llm_step.output_variable] = result
context.results[step.id] = result context.results[step.id] = result
@@ -631,16 +641,21 @@ class WorkflowEngine:
messages = [{"role": "user", "content": self._template_string(prompt, variables)}] messages = [{"role": "user", "content": self._template_string(prompt, variables)}]
response = await self.litellm_client.create_chat_completion( # Convert to LLM service format
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
llm_request = LLMChatRequest(
model=step.model, model=step.model,
messages=messages, messages=llm_messages,
user_id="workflow_system", user_id="workflow_system",
api_key_id="workflow", api_key_id=0,
temperature=step.temperature, temperature=step.temperature,
max_tokens=step.max_tokens max_tokens=step.max_tokens
) )
return response.get("choices", [{}])[0].get("message", {}).get("content", "") response = await llm_service.create_chat_completion(llm_request)
return response.choices[0].message.content if response.choices else ""
async def _generate_brand_names(self, variables: Dict[str, Any], step: AIGenerationStep) -> List[Dict[str, str]]: async def _generate_brand_names(self, variables: Dict[str, Any], step: AIGenerationStep) -> List[Dict[str, str]]:
"""Generate brand names for a specific category""" """Generate brand names for a specific category"""
@@ -687,16 +702,21 @@ class WorkflowEngine:
messages = [{"role": "user", "content": self._template_string(prompt, variables)}] messages = [{"role": "user", "content": self._template_string(prompt, variables)}]
response = await self.litellm_client.create_chat_completion( # Convert to LLM service format
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
llm_request = LLMChatRequest(
model=step.model, model=step.model,
messages=messages, messages=llm_messages,
user_id="workflow_system", user_id="workflow_system",
api_key_id="workflow", api_key_id=0,
temperature=step.temperature, temperature=step.temperature,
max_tokens=step.max_tokens max_tokens=step.max_tokens
) )
return response.get("choices", [{}])[0].get("message", {}).get("content", "") response = await llm_service.create_chat_completion(llm_request)
return response.choices[0].message.content if response.choices else ""
async def _generate_custom_prompt(self, variables: Dict[str, Any], step: AIGenerationStep) -> str: async def _generate_custom_prompt(self, variables: Dict[str, Any], step: AIGenerationStep) -> str:
"""Generate content using custom prompt template""" """Generate content using custom prompt template"""
@@ -705,16 +725,21 @@ class WorkflowEngine:
messages = [{"role": "user", "content": self._template_string(step.prompt_template, variables)}] messages = [{"role": "user", "content": self._template_string(step.prompt_template, variables)}]
response = await self.litellm_client.create_chat_completion( # Convert to LLM service format
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
llm_request = LLMChatRequest(
model=step.model, model=step.model,
messages=messages, messages=llm_messages,
user_id="workflow_system", user_id="workflow_system",
api_key_id="workflow", api_key_id=0,
temperature=step.temperature, temperature=step.temperature,
max_tokens=step.max_tokens max_tokens=step.max_tokens
) )
return response.get("choices", [{}])[0].get("message", {}).get("content", "") response = await llm_service.create_chat_completion(llm_request)
return response.choices[0].message.content if response.choices else ""
async def _execute_aggregate_step(self, step: WorkflowStep, context: WorkflowContext): async def _execute_aggregate_step(self, step: WorkflowStep, context: WorkflowContext):
"""Execute aggregate step to combine multiple inputs""" """Execute aggregate step to combine multiple inputs"""

View File

@@ -23,7 +23,8 @@ from app.core.config import settings
from app.db.database import async_session_factory from app.db.database import async_session_factory
from app.models.user import User from app.models.user import User
from app.models.chatbot import ChatbotInstance from app.models.chatbot import ChatbotInstance
from app.services.litellm_client import LiteLLMClient from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
import base64 import base64
import os import os
@@ -65,8 +66,8 @@ class ZammadModule(BaseModule):
try: try:
logger.info("Initializing Zammad module...") logger.info("Initializing Zammad module...")
# Initialize LLM client for chatbot integration # Initialize LLM service for chatbot integration
self.llm_client = LiteLLMClient() # Note: llm_service is already a global singleton, no need to create instance
# Create HTTP session pool for Zammad API calls # Create HTTP session pool for Zammad API calls
timeout = aiohttp.ClientTimeout(total=60, connect=10) timeout = aiohttp.ClientTimeout(total=60, connect=10)
@@ -597,19 +598,21 @@ class ZammadModule(BaseModule):
} }
] ]
# Generate summary using LLM client # Generate summary using new LLM service
response = await self.llm_client.create_chat_completion( chat_request = LLMChatRequest(
messages=messages,
model=await self._get_chatbot_model(config.chatbot_id), model=await self._get_chatbot_model(config.chatbot_id),
user_id=str(config.user_id), messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
api_key_id=0, # Using 0 for module requests
temperature=0.3, temperature=0.3,
max_tokens=500 max_tokens=500,
user_id=str(config.user_id),
api_key_id=0 # Using 0 for module requests
) )
# Extract content from LiteLLM response response = await llm_service.create_chat_completion(chat_request)
if "choices" in response and len(response["choices"]) > 0:
return response["choices"][0]["message"]["content"].strip() # Extract content from new LLM service response
if response.choices and len(response.choices) > 0:
return response.choices[0].message.content.strip()
return "Unable to generate summary." return "Unable to generate summary."

View File

@@ -1,132 +0,0 @@
"""
Test LLM API endpoints.
"""
import pytest
from httpx import AsyncClient
from unittest.mock import patch, AsyncMock
class TestLLMEndpoints:
"""Test LLM API endpoints."""
@pytest.mark.asyncio
async def test_chat_completion_success(self, client: AsyncClient):
"""Test successful chat completion."""
# Mock the LiteLLM client response
mock_response = {
"choices": [
{
"message": {
"content": "Hello! How can I help you today?",
"role": "assistant"
}
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 15,
"total_tokens": 25
}
}
with patch("app.services.litellm_client.LiteLLMClient.create_chat_completion") as mock_chat:
mock_chat.return_value = mock_response
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
assert "choices" in data
assert data["choices"][0]["message"]["content"] == "Hello! How can I help you today?"
@pytest.mark.asyncio
async def test_chat_completion_unauthorized(self, client: AsyncClient):
"""Test chat completion without API key."""
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
]
}
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_embeddings_success(self, client: AsyncClient):
"""Test successful embeddings generation."""
mock_response = {
"data": [
{
"embedding": [0.1, 0.2, 0.3],
"index": 0
}
],
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}
with patch("app.services.litellm_client.LiteLLMClient.create_embedding") as mock_embeddings:
mock_embeddings.return_value = mock_response
response = await client.post(
"/api/v1/llm/embeddings",
json={
"model": "text-embedding-ada-002",
"input": "Hello world"
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
assert "data" in data
assert len(data["data"][0]["embedding"]) == 3
@pytest.mark.asyncio
async def test_budget_exceeded(self, client: AsyncClient):
"""Test budget exceeded scenario."""
with patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_check:
mock_check.side_effect = Exception("Budget exceeded")
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 402 # Payment required
@pytest.mark.asyncio
async def test_model_validation(self, client: AsyncClient):
"""Test model validation."""
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "invalid-model",
"messages": [
{"role": "user", "content": "Hello"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 400

View File

@@ -0,0 +1,496 @@
"""
Integration tests for the new LLM service.
Tests end-to-end functionality including provider integration, security, and performance.
"""
import pytest
import asyncio
import time
from httpx import AsyncClient
from unittest.mock import patch, AsyncMock, MagicMock
import json
class TestLLMServiceIntegration:
"""Integration tests for LLM service."""
@pytest.mark.asyncio
async def test_full_chat_flow(self, client: AsyncClient):
"""Test complete chat completion flow with security and budget checks."""
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
# Mock successful LLM service response
mock_response = ChatCompletionResponse(
id="test-completion-123",
object="chat.completion",
created=int(time.time()),
model="privatemode-llama-3-70b",
choices=[
ChatChoice(
index=0,
message=ChatMessage(
role="assistant",
content="Hello! I'm a TEE-protected AI assistant. How can I help you today?"
),
finish_reason="stop"
)
],
usage=Usage(
prompt_tokens=25,
completion_tokens=15,
total_tokens=40
),
security_analysis={
"risk_score": 0.1,
"threats_detected": [],
"risk_level": "low",
"analysis_time_ms": 12.5
}
)
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat, \
patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget:
mock_chat.return_value = mock_response
mock_budget.return_value = True # Budget check passes
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "privatemode-llama-3-70b",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, what are your capabilities?"}
],
"temperature": 0.7,
"max_tokens": 150
},
headers={"Authorization": "Bearer test-api-key"}
)
# Verify response structure
assert response.status_code == 200
data = response.json()
# Check standard OpenAI-compatible fields
assert "id" in data
assert "object" in data
assert "created" in data
assert "model" in data
assert "choices" in data
assert "usage" in data
# Check security integration
assert "security_analysis" in data
assert data["security_analysis"]["risk_level"] == "low"
# Verify content
assert len(data["choices"]) == 1
assert data["choices"][0]["message"]["role"] == "assistant"
assert "TEE-protected" in data["choices"][0]["message"]["content"]
# Verify usage tracking
assert data["usage"]["total_tokens"] == 40
assert data["usage"]["prompt_tokens"] == 25
assert data["usage"]["completion_tokens"] == 15
@pytest.mark.asyncio
async def test_embedding_integration(self, client: AsyncClient):
"""Test embedding generation with fallback handling."""
from app.services.llm.models import EmbeddingResponse, EmbeddingData, Usage
# Create realistic 1024-dimensional embedding
embedding_vector = [0.1 * i for i in range(1024)]
mock_response = EmbeddingResponse(
object="list",
data=[
EmbeddingData(
object="embedding",
embedding=embedding_vector,
index=0
)
],
model="privatemode-embeddings",
usage=Usage(
prompt_tokens=8,
total_tokens=8
)
)
with patch("app.services.llm.service.llm_service.create_embedding") as mock_embedding:
mock_embedding.return_value = mock_response
response = await client.post(
"/api/v1/llm/embeddings",
json={
"model": "privatemode-embeddings",
"input": "This is a test document for embedding generation."
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
# Verify embedding structure
assert "object" in data
assert "data" in data
assert "usage" in data
assert len(data["data"]) == 1
assert len(data["data"][0]["embedding"]) == 1024
assert data["data"][0]["index"] == 0
@pytest.mark.asyncio
async def test_provider_health_integration(self, client: AsyncClient):
"""Test provider health monitoring integration."""
mock_status = {
"privatemode": {
"provider": "PrivateMode.ai",
"status": "healthy",
"latency_ms": 245.8,
"success_rate": 0.987,
"last_check": "2025-01-01T12:00:00Z",
"error_message": None,
"models_available": [
"privatemode-llama-3-70b",
"privatemode-claude-3-sonnet",
"privatemode-gpt-4o",
"privatemode-embeddings"
]
}
}
with patch("app.services.llm.service.llm_service.get_provider_status") as mock_provider:
mock_provider.return_value = mock_status
response = await client.get(
"/api/v1/llm/providers/status",
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
# Check response structure
assert "data" in data
assert "privatemode" in data["data"]
provider_data = data["data"]["privatemode"]
assert provider_data["status"] == "healthy"
assert provider_data["latency_ms"] < 300 # Reasonable latency
assert provider_data["success_rate"] > 0.95 # High success rate
assert len(provider_data["models_available"]) >= 4
@pytest.mark.asyncio
async def test_error_handling_and_fallback(self, client: AsyncClient):
"""Test error handling and fallback scenarios."""
# Test provider unavailable scenario
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
mock_chat.side_effect = Exception("Provider temporarily unavailable")
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "privatemode-llama-3-70b",
"messages": [
{"role": "user", "content": "Hello"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
# Should return error but not crash
assert response.status_code in [500, 503] # Server error or service unavailable
@pytest.mark.asyncio
async def test_security_threat_detection(self, client: AsyncClient):
"""Test security threat detection integration."""
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
# Mock response with security threat detected
mock_response = ChatCompletionResponse(
id="test-completion-security",
object="chat.completion",
created=int(time.time()),
model="privatemode-llama-3-70b",
choices=[
ChatChoice(
index=0,
message=ChatMessage(
role="assistant",
content="I cannot help with that request as it violates security policies."
),
finish_reason="stop"
)
],
usage=Usage(
prompt_tokens=15,
completion_tokens=12,
total_tokens=27
),
security_analysis={
"risk_score": 0.8,
"threats_detected": ["potential_malicious_code"],
"risk_level": "high",
"blocked": True,
"analysis_time_ms": 45.2
}
)
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
mock_chat.return_value = mock_response
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "privatemode-llama-3-70b",
"messages": [
{"role": "user", "content": "How to create malicious code?"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200 # Request succeeds but content is filtered
data = response.json()
# Verify security analysis
assert "security_analysis" in data
assert data["security_analysis"]["risk_level"] == "high"
assert data["security_analysis"]["blocked"] is True
assert "malicious" in data["security_analysis"]["threats_detected"][0]
@pytest.mark.asyncio
async def test_performance_characteristics(self, client: AsyncClient):
"""Test performance characteristics of the LLM service."""
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
# Mock fast response
mock_response = ChatCompletionResponse(
id="test-perf",
object="chat.completion",
created=int(time.time()),
model="privatemode-llama-3-70b",
choices=[
ChatChoice(
index=0,
message=ChatMessage(
role="assistant",
content="Quick response for performance testing."
),
finish_reason="stop"
)
],
usage=Usage(
prompt_tokens=10,
completion_tokens=8,
total_tokens=18
)
)
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
mock_chat.return_value = mock_response
# Measure response time
start_time = time.time()
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "privatemode-llama-3-70b",
"messages": [
{"role": "user", "content": "Quick test"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
response_time = time.time() - start_time
assert response.status_code == 200
# API should respond quickly (mocked, so should be very fast)
assert response_time < 1.0 # Less than 1 second for mocked response
@pytest.mark.asyncio
async def test_model_capabilities_detection(self, client: AsyncClient):
"""Test model capabilities detection and reporting."""
from app.services.llm.models import Model
mock_models = [
Model(
id="privatemode-llama-3-70b",
object="model",
created=1234567890,
owned_by="PrivateMode.ai",
provider="PrivateMode.ai",
capabilities=["tee", "chat", "function_calling"],
context_window=32768,
max_output_tokens=4096,
supports_streaming=True,
supports_function_calling=True
),
Model(
id="privatemode-embeddings",
object="model",
created=1234567890,
owned_by="PrivateMode.ai",
provider="PrivateMode.ai",
capabilities=["tee", "embeddings"],
context_window=512,
supports_streaming=False,
supports_function_calling=False
)
]
with patch("app.services.llm.service.llm_service.get_models") as mock_models_call:
mock_models_call.return_value = mock_models
response = await client.get(
"/api/v1/llm/models",
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
# Verify model capabilities
assert len(data["data"]) == 2
# Check chat model capabilities
chat_model = next(m for m in data["data"] if m["id"] == "privatemode-llama-3-70b")
assert "tee" in chat_model["capabilities"]
assert "chat" in chat_model["capabilities"]
assert chat_model["supports_streaming"] is True
assert chat_model["supports_function_calling"] is True
assert chat_model["context_window"] == 32768
# Check embedding model capabilities
embed_model = next(m for m in data["data"] if m["id"] == "privatemode-embeddings")
assert "tee" in embed_model["capabilities"]
assert "embeddings" in embed_model["capabilities"]
assert embed_model["supports_streaming"] is False
assert embed_model["context_window"] == 512
@pytest.mark.asyncio
async def test_concurrent_requests(self, client: AsyncClient):
"""Test handling of concurrent requests."""
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
mock_response = ChatCompletionResponse(
id="test-concurrent",
object="chat.completion",
created=int(time.time()),
model="privatemode-llama-3-70b",
choices=[
ChatChoice(
index=0,
message=ChatMessage(
role="assistant",
content="Concurrent response"
),
finish_reason="stop"
)
],
usage=Usage(
prompt_tokens=5,
completion_tokens=3,
total_tokens=8
)
)
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
mock_chat.return_value = mock_response
# Create multiple concurrent requests
tasks = []
for i in range(5):
task = client.post(
"/api/v1/llm/chat/completions",
json={
"model": "privatemode-llama-3-70b",
"messages": [
{"role": "user", "content": f"Concurrent test {i}"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
tasks.append(task)
# Execute all requests concurrently
responses = await asyncio.gather(*tasks)
# Verify all requests succeeded
for response in responses:
assert response.status_code == 200
data = response.json()
assert "choices" in data
assert data["choices"][0]["message"]["content"] == "Concurrent response"
@pytest.mark.asyncio
async def test_budget_enforcement_integration(self, client: AsyncClient):
"""Test budget enforcement integration with LLM service."""
# Test budget exceeded scenario
with patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget:
mock_budget.side_effect = Exception("Monthly budget limit exceeded")
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "privatemode-llama-3-70b",
"messages": [
{"role": "user", "content": "Test budget enforcement"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 402 # Payment required
# Test budget warning scenario
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
mock_response = ChatCompletionResponse(
id="test-budget-warning",
object="chat.completion",
created=int(time.time()),
model="privatemode-llama-3-70b",
choices=[
ChatChoice(
index=0,
message=ChatMessage(
role="assistant",
content="Response with budget warning"
),
finish_reason="stop"
)
],
usage=Usage(
prompt_tokens=10,
completion_tokens=8,
total_tokens=18
),
budget_warnings=["Approaching monthly budget limit (85% used)"]
)
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat, \
patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget:
mock_chat.return_value = mock_response
mock_budget.return_value = True # Budget check passes but with warning
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "privatemode-llama-3-70b",
"messages": [
{"role": "user", "content": "Test budget warning"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
assert "budget_warnings" in data
assert len(data["budget_warnings"]) > 0
assert "85%" in data["budget_warnings"][0]

View File

@@ -0,0 +1,296 @@
"""
Simple validation tests for the new LLM service integration.
Tests basic functionality without complex mocking.
"""
import pytest
from httpx import AsyncClient
from unittest.mock import patch, MagicMock
class TestLLMServiceValidation:
"""Basic validation tests for LLM service integration."""
@pytest.mark.asyncio
async def test_llm_models_endpoint_exists(self, client: AsyncClient):
"""Test that the LLM models endpoint exists and is accessible."""
response = await client.get(
"/api/v1/llm/models",
headers={"Authorization": "Bearer test-api-key"}
)
# Should not return 404 (endpoint exists)
assert response.status_code != 404
# May return 500 or other error due to missing LLM service, but endpoint exists
@pytest.mark.asyncio
async def test_llm_chat_endpoint_exists(self, client: AsyncClient):
"""Test that the LLM chat endpoint exists and is accessible."""
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
# Should not return 404 (endpoint exists)
assert response.status_code != 404
@pytest.mark.asyncio
async def test_llm_embeddings_endpoint_exists(self, client: AsyncClient):
"""Test that the LLM embeddings endpoint exists and is accessible."""
response = await client.post(
"/api/v1/llm/embeddings",
json={
"model": "test-embedding-model",
"input": "Test text"
},
headers={"Authorization": "Bearer test-api-key"}
)
# Should not return 404 (endpoint exists)
assert response.status_code != 404
@pytest.mark.asyncio
async def test_llm_provider_status_endpoint_exists(self, client: AsyncClient):
"""Test that the provider status endpoint exists and is accessible."""
response = await client.get(
"/api/v1/llm/providers/status",
headers={"Authorization": "Bearer test-api-key"}
)
# Should not return 404 (endpoint exists)
assert response.status_code != 404
@pytest.mark.asyncio
async def test_chat_with_mocked_service(self, client: AsyncClient):
"""Test chat completion with mocked LLM service."""
from app.services.llm.models import ChatResponse, ChatChoice, ChatMessage, TokenUsage
# Mock successful response
mock_response = ChatResponse(
id="test-123",
object="chat.completion",
created=1234567890,
model="privatemode-llama-3-70b",
provider="PrivateMode.ai",
choices=[
ChatChoice(
index=0,
message=ChatMessage(
role="assistant",
content="Hello! How can I help you?"
),
finish_reason="stop"
)
],
usage=TokenUsage(
prompt_tokens=10,
completion_tokens=8,
total_tokens=18
),
security_check=True,
risk_score=0.1,
detected_patterns=[],
latency_ms=250.5
)
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
mock_chat.return_value = mock_response
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "privatemode-llama-3-70b",
"messages": [
{"role": "user", "content": "Hello"}
]
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
# Verify basic response structure
assert "id" in data
assert "choices" in data
assert len(data["choices"]) == 1
assert data["choices"][0]["message"]["content"] == "Hello! How can I help you?"
@pytest.mark.asyncio
async def test_embedding_with_mocked_service(self, client: AsyncClient):
"""Test embedding generation with mocked LLM service."""
from app.services.llm.models import EmbeddingResponse, EmbeddingData, TokenUsage
# Create a simple embedding vector
embedding_vector = [0.1, 0.2, 0.3] * 341 + [0.1]  # 1024 dimensions
mock_response = EmbeddingResponse(
object="list",
data=[
EmbeddingData(
object="embedding",
index=0,
embedding=embedding_vector
)
],
model="privatemode-embeddings",
provider="PrivateMode.ai",
usage=TokenUsage(
prompt_tokens=5,
completion_tokens=0,
total_tokens=5
),
security_check=True,
risk_score=0.0,
latency_ms=150.0
)
with patch("app.services.llm.service.llm_service.create_embedding") as mock_embedding:
mock_embedding.return_value = mock_response
response = await client.post(
"/api/v1/llm/embeddings",
json={
"model": "privatemode-embeddings",
"input": "Test text for embedding"
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
# Verify basic response structure
assert "data" in data
assert len(data["data"]) == 1
assert len(data["data"][0]["embedding"]) == 1024
@pytest.mark.asyncio
async def test_models_with_mocked_service(self, client: AsyncClient):
"""Test models listing with mocked LLM service."""
from app.services.llm.models import ModelInfo
mock_models = [
ModelInfo(
id="privatemode-llama-3-70b",
object="model",
created=1234567890,
owned_by="PrivateMode.ai",
provider="PrivateMode.ai",
capabilities=["tee", "chat"],
context_window=32768,
supports_streaming=True
),
ModelInfo(
id="privatemode-embeddings",
object="model",
created=1234567890,
owned_by="PrivateMode.ai",
provider="PrivateMode.ai",
capabilities=["tee", "embeddings"],
context_window=512,
supports_streaming=False
)
]
with patch("app.services.llm.service.llm_service.get_models") as mock_models_call:
mock_models_call.return_value = mock_models
response = await client.get(
"/api/v1/llm/models",
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
# Verify basic response structure
assert "data" in data
assert len(data["data"]) == 2
assert data["data"][0]["id"] == "privatemode-llama-3-70b"
@pytest.mark.asyncio
async def test_provider_status_with_mocked_service(self, client: AsyncClient):
"""Test provider status with mocked LLM service."""
mock_status = {
"privatemode": {
"provider": "PrivateMode.ai",
"status": "healthy",
"latency_ms": 250.5,
"success_rate": 0.98,
"last_check": "2025-01-01T12:00:00Z",
"models_available": ["privatemode-llama-3-70b", "privatemode-embeddings"]
}
}
with patch("app.services.llm.service.llm_service.get_provider_status") as mock_provider:
mock_provider.return_value = mock_status
response = await client.get(
"/api/v1/llm/providers/status",
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 200
data = response.json()
# Verify basic response structure
assert "data" in data
assert "privatemode" in data["data"]
assert data["data"]["privatemode"]["status"] == "healthy"
@pytest.mark.asyncio
async def test_unauthorized_access(self, client: AsyncClient):
"""Test that unauthorized requests are properly rejected."""
# Test without authorization header
response = await client.get("/api/v1/llm/models")
assert response.status_code == 401
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}]
}
)
assert response.status_code == 401
response = await client.post(
"/api/v1/llm/embeddings",
json={
"model": "test-model",
"input": "Hello"
}
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_invalid_request_data(self, client: AsyncClient):
"""Test that invalid request data is properly handled."""
# Test request body missing required fields
response = await client.post(
"/api/v1/llm/chat/completions",
json={
# Missing required fields
"model": "test-model"
# messages field is missing
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 422 # Unprocessable Entity
# Test empty messages
response = await client.post(
"/api/v1/llm/chat/completions",
json={
"model": "test-model",
"messages": [] # Empty messages
},
headers={"Authorization": "Bearer test-api-key"}
)
assert response.status_code == 422 # Unprocessable Entity

View File

@@ -6,7 +6,7 @@ import { Badge } from '@/components/ui/badge'
import { Button } from '@/components/ui/button'
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
import { Alert, AlertDescription } from '@/components/ui/alert'
-import { RefreshCw, Zap, Info, AlertCircle } from 'lucide-react'
import { RefreshCw, Zap, Info, AlertCircle, CheckCircle, XCircle, Clock } from 'lucide-react'
interface Model {
id: string
@@ -16,6 +16,22 @@ interface Model {
permission?: any[]
root?: string
parent?: string
provider?: string
capabilities?: string[]
context_window?: number
max_output_tokens?: number
supports_streaming?: boolean
supports_function_calling?: boolean
}
interface ProviderStatus {
provider: string
status: 'healthy' | 'degraded' | 'unavailable'
latency_ms?: number
success_rate?: number
last_check: string
error_message?: string
models_available: string[]
}
interface ModelSelectorProps {
@@ -27,6 +43,7 @@ interface ModelSelectorProps {
export default function ModelSelector({ value, onValueChange, filter = 'all', className }: ModelSelectorProps) {
const [models, setModels] = useState<Model[]>([])
const [providerStatus, setProviderStatus] = useState<Record<string, ProviderStatus>>({})
const [loading, setLoading] = useState(true)
const [error, setError] = useState<string | null>(null)
const [showDetails, setShowDetails] = useState(false)
@@ -37,20 +54,31 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
// Get the auth token from localStorage
const token = localStorage.getItem('token')
-const response = await fetch('/api/llm/models', {
-headers: {
const headers = {
'Authorization': token ? `Bearer ${token}` : '',
'Content-Type': 'application/json'
}
-})
-if (!response.ok) {
// Fetch models and provider status in parallel
const [modelsResponse, statusResponse] = await Promise.allSettled([
fetch('/api/llm/models', { headers }),
fetch('/api/llm/providers/status', { headers })
])
// Handle models response
if (modelsResponse.status === 'fulfilled' && modelsResponse.value.ok) {
const modelsData = await modelsResponse.value.json()
setModels(modelsData.data || [])
} else {
throw new Error('Failed to fetch models')
}
-const data = await response.json()
-setModels(data.data || [])
// Handle provider status response (optional)
if (statusResponse.status === 'fulfilled' && statusResponse.value.ok) {
const statusData = await statusResponse.value.json()
setProviderStatus(statusData.data || {})
}
setError(null)
} catch (err) {
setError(err instanceof Error ? err.message : 'Failed to load models')
@@ -64,30 +92,39 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
}, [])
const getProviderFromModel = (modelId: string): string => {
// PrivateMode models have specific prefixes
if (modelId.startsWith('privatemode-')) return 'PrivateMode.ai'
// Legacy detection for other providers
if (modelId.startsWith('gpt-') || modelId.includes('openai')) return 'OpenAI'
if (modelId.startsWith('claude-') || modelId.includes('anthropic')) return 'Anthropic'
if (modelId.startsWith('gemini-') || modelId.includes('google')) return 'Google'
-if (modelId.includes('privatemode')) return 'Privatemode.ai'
if (modelId.includes('cohere')) return 'Cohere'
if (modelId.includes('mistral')) return 'Mistral'
-if (modelId.includes('llama')) return 'Meta'
if (modelId.includes('llama') && !modelId.startsWith('privatemode-')) return 'Meta'
return 'Unknown'
}
const getModelType = (modelId: string): 'chat' | 'embedding' | 'other' => {
-if (modelId.includes('embedding')) return 'embedding'
if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding'
-if (modelId.includes('whisper')) return 'other' // Audio transcription models
if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models
// PrivateMode and other chat models
if (
modelId.startsWith('privatemode-llama') ||
modelId.startsWith('privatemode-claude') ||
modelId.startsWith('privatemode-gpt') ||
modelId.startsWith('privatemode-gemini') ||
modelId.includes('text-') ||
modelId.includes('gpt-') ||
modelId.includes('claude-') ||
modelId.includes('gemini-') ||
-modelId.includes('privatemode-') ||
modelId.includes('llama') ||
modelId.includes('gemma') ||
modelId.includes('qwen') ||
modelId.includes('latest')
) return 'chat'
return 'other'
}
@@ -113,6 +150,28 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
return acc
}, {} as Record<string, Model[]>)
const getProviderStatusIcon = (provider: string) => {
const status = providerStatus[provider.toLowerCase()]?.status || 'unknown'
switch (status) {
case 'healthy':
return <CheckCircle className="h-3 w-3 text-green-500" />
case 'degraded':
return <Clock className="h-3 w-3 text-yellow-500" />
case 'unavailable':
return <XCircle className="h-3 w-3 text-red-500" />
default:
return <AlertCircle className="h-3 w-3 text-gray-400" />
}
}
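// Renders a human-readable summary such as "Healthy (245ms)" next to the provider name in the dropdown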
const getProviderStatusText = (provider: string) => {
const status = providerStatus[provider.toLowerCase()]
if (!status) return 'Status unknown'
const latencyText = status.latency_ms ? ` (${Math.round(status.latency_ms)}ms)` : ''
return `${status.status.charAt(0).toUpperCase() + status.status.slice(1)}${latencyText}`
}
const selectedModel = models.find(m => m.id === value)
if (loading) {
@@ -191,16 +250,32 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
<SelectContent>
{Object.entries(groupedModels).map(([provider, providerModels]) => (
<div key={provider}>
-<div className="px-2 py-1.5 text-sm font-semibold text-muted-foreground">
-{provider}
<div className="px-2 py-1.5 text-sm font-semibold text-muted-foreground flex items-center gap-2">
{getProviderStatusIcon(provider)}
<span>{provider}</span>
<span className="text-xs font-normal text-muted-foreground">
{getProviderStatusText(provider)}
</span>
</div>
{providerModels.map((model) => (
<SelectItem key={model.id} value={model.id}>
<div className="flex items-center gap-2">
<span>{model.id}</span>
<div className="flex gap-1">
<Badge variant="outline" className="text-xs"> <Badge variant="outline" className="text-xs">
{getModelCategory(model.id)} {getModelCategory(model.id)}
</Badge> </Badge>
{model.supports_streaming && (
<Badge variant="secondary" className="text-xs">
Streaming
</Badge>
)}
{model.supports_function_calling && (
<Badge variant="secondary" className="text-xs">
Functions
</Badge>
)}
</div>
</div>
</SelectItem>
))}
@@ -217,7 +292,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
Model Details
</CardTitle>
</CardHeader>
-<CardContent className="space-y-2 text-sm">
<CardContent className="space-y-3 text-sm">
<div className="grid grid-cols-2 gap-4">
<div>
<span className="font-medium">ID:</span>
@@ -225,7 +300,10 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
</div>
<div>
<span className="font-medium">Provider:</span>
-<div className="text-muted-foreground">{getProviderFromModel(selectedModel.id)}</div>
<div className="text-muted-foreground flex items-center gap-1">
{getProviderStatusIcon(getProviderFromModel(selectedModel.id))}
{getProviderFromModel(selectedModel.id)}
</div>
</div>
<div>
<span className="font-medium">Type:</span>
@@ -237,6 +315,40 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
</div>
</div>
{(selectedModel.context_window || selectedModel.max_output_tokens) && (
<div className="grid grid-cols-2 gap-4">
{selectedModel.context_window && (
<div>
<span className="font-medium">Context Window:</span>
<div className="text-muted-foreground">{selectedModel.context_window.toLocaleString()} tokens</div>
</div>
)}
{selectedModel.max_output_tokens && (
<div>
<span className="font-medium">Max Output:</span>
<div className="text-muted-foreground">{selectedModel.max_output_tokens.toLocaleString()} tokens</div>
</div>
)}
</div>
)}
{(selectedModel.supports_streaming || selectedModel.supports_function_calling) && (
<div>
<span className="font-medium">Capabilities:</span>
<div className="flex gap-1 mt-1">
{selectedModel.supports_streaming && (
<Badge variant="secondary" className="text-xs">Streaming</Badge>
)}
{selectedModel.supports_function_calling && (
<Badge variant="secondary" className="text-xs">Function Calling</Badge>
)}
{selectedModel.capabilities?.includes('tee') && (
<Badge variant="outline" className="text-xs border-green-500 text-green-700">TEE Protected</Badge>
)}
</div>
</div>
)}
{selectedModel.created && (
<div>
<span className="font-medium">Created:</span>
@@ -252,6 +364,46 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
<div className="text-muted-foreground">{selectedModel.owned_by}</div> <div className="text-muted-foreground">{selectedModel.owned_by}</div>
</div> </div>
)} )}
{/* Provider Status Details */}
{providerStatus[getProviderFromModel(selectedModel.id).toLowerCase()] && (
<div className="border-t pt-3">
<span className="font-medium">Provider Status:</span>
<div className="mt-1 text-xs space-y-1">
{(() => {
const status = providerStatus[getProviderFromModel(selectedModel.id).toLowerCase()]
return (
<>
<div className="flex justify-between">
<span>Status:</span>
<span className={`font-medium ${
status.status === 'healthy' ? 'text-green-600' :
status.status === 'degraded' ? 'text-yellow-600' :
'text-red-600'
}`}>{status.status}</span>
</div>
{status.latency_ms && (
<div className="flex justify-between">
<span>Latency:</span>
<span>{Math.round(status.latency_ms)}ms</span>
</div>
)}
{status.success_rate && (
<div className="flex justify-between">
<span>Success Rate:</span>
<span>{Math.round(status.success_rate * 100)}%</span>
</div>
)}
<div className="flex justify-between">
<span>Last Check:</span>
<span>{new Date(status.last_check).toLocaleTimeString()}</span>
</div>
</>
)
})()}
</div>
</div>
)}
</CardContent>
</Card>
)}

View File

@@ -0,0 +1,373 @@
"use client"
import { useState, useEffect } from 'react'
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
import { Badge } from '@/components/ui/badge'
import { Button } from '@/components/ui/button'
import { Alert, AlertDescription } from '@/components/ui/alert'
import { Progress } from '@/components/ui/progress'
import {
CheckCircle,
XCircle,
Clock,
AlertCircle,
RefreshCw,
Activity,
Zap,
Shield,
Server
} from 'lucide-react'
interface ProviderStatus {
provider: string
status: 'healthy' | 'degraded' | 'unavailable'
latency_ms?: number
success_rate?: number
last_check: string
error_message?: string
models_available: string[]
}
interface LLMMetrics {
total_requests: number
successful_requests: number
failed_requests: number
security_blocked_requests: number
average_latency_ms: number
average_risk_score: number
provider_metrics: Record<string, any>
last_updated: string
}
export default function ProviderHealthDashboard() {
const [providers, setProviders] = useState<Record<string, ProviderStatus>>({})
const [metrics, setMetrics] = useState<LLMMetrics | null>(null)
const [loading, setLoading] = useState(true)
const [error, setError] = useState<string | null>(null)
const fetchData = async () => {
try {
setLoading(true)
const token = localStorage.getItem('token')
const headers = {
'Authorization': token ? `Bearer ${token}` : '',
'Content-Type': 'application/json'
}
const [statusResponse, metricsResponse] = await Promise.allSettled([
fetch('/api/llm/providers/status', { headers }),
fetch('/api/llm/metrics', { headers })
])
// Handle provider status
if (statusResponse.status === 'fulfilled' && statusResponse.value.ok) {
const statusData = await statusResponse.value.json()
setProviders(statusData.data || {})
}
// Handle metrics (optional, might require admin permissions)
if (metricsResponse.status === 'fulfilled' && metricsResponse.value.ok) {
const metricsData = await metricsResponse.value.json()
setMetrics(metricsData.data)
}
setError(null)
} catch (err) {
setError(err instanceof Error ? err.message : 'Failed to load provider data')
} finally {
setLoading(false)
}
}
useEffect(() => {
fetchData()
// Auto-refresh every 30 seconds
const interval = setInterval(fetchData, 30000)
return () => clearInterval(interval)
}, [])
const getStatusIcon = (status: string) => {
switch (status) {
case 'healthy':
return <CheckCircle className="h-5 w-5 text-green-500" />
case 'degraded':
return <Clock className="h-5 w-5 text-yellow-500" />
case 'unavailable':
return <XCircle className="h-5 w-5 text-red-500" />
default:
return <AlertCircle className="h-5 w-5 text-gray-400" />
}
}
const getStatusColor = (status: string) => {
switch (status) {
case 'healthy':
return 'text-green-600 bg-green-50 border-green-200'
case 'degraded':
return 'text-yellow-600 bg-yellow-50 border-yellow-200'
case 'unavailable':
return 'text-red-600 bg-red-50 border-red-200'
default:
return 'text-gray-600 bg-gray-50 border-gray-200'
}
}
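// Latency buckets: under 500ms renders green, under 2000ms yellow, anything slower red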
const getLatencyColor = (latency: number) => {
if (latency < 500) return 'text-green-600'
if (latency < 2000) return 'text-yellow-600'
return 'text-red-600'
}
if (loading) {
return (
<div className="space-y-4">
<div className="flex items-center justify-between">
<h2 className="text-2xl font-bold">Provider Health Dashboard</h2>
<RefreshCw className="h-5 w-5 animate-spin" />
</div>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{[1, 2, 3].map(i => (
<Card key={i} className="animate-pulse">
<CardHeader className="space-y-2">
<div className="h-4 bg-gray-200 rounded w-3/4"></div>
<div className="h-3 bg-gray-200 rounded w-1/2"></div>
</CardHeader>
<CardContent>
<div className="space-y-2">
<div className="h-3 bg-gray-200 rounded"></div>
<div className="h-3 bg-gray-200 rounded w-2/3"></div>
</div>
</CardContent>
</Card>
))}
</div>
</div>
)
}
if (error) {
return (
<div className="space-y-4">
<div className="flex items-center justify-between">
<h2 className="text-2xl font-bold">Provider Health Dashboard</h2>
<Button onClick={fetchData} size="sm">
<RefreshCw className="h-4 w-4 mr-2" />
Retry
</Button>
</div>
<Alert variant="destructive">
<AlertCircle className="h-4 w-4" />
<AlertDescription>{error}</AlertDescription>
</Alert>
</div>
)
}
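// Overall health = share of providers currently reporting "healthy", shown as a percentage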
const totalProviders = Object.keys(providers).length
const healthyProviders = Object.values(providers).filter(p => p.status === 'healthy').length
const overallHealth = totalProviders > 0 ? (healthyProviders / totalProviders) * 100 : 0
return (
<div className="space-y-6">
<div className="flex items-center justify-between">
<h2 className="text-2xl font-bold">Provider Health Dashboard</h2>
<Button onClick={fetchData} size="sm" disabled={loading}>
<RefreshCw className={`h-4 w-4 mr-2 ${loading ? 'animate-spin' : ''}`} />
Refresh
</Button>
</div>
{/* Overall Health Summary */}
<div className="grid grid-cols-1 md:grid-cols-4 gap-4">
<Card>
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
<CardTitle className="text-sm font-medium">Overall Health</CardTitle>
<Activity className="h-4 w-4 text-muted-foreground" />
</CardHeader>
<CardContent>
<div className="text-2xl font-bold">{Math.round(overallHealth)}%</div>
<Progress value={overallHealth} className="mt-2" />
</CardContent>
</Card>
<Card>
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
<CardTitle className="text-sm font-medium">Healthy Providers</CardTitle>
<CheckCircle className="h-4 w-4 text-green-500" />
</CardHeader>
<CardContent>
<div className="text-2xl font-bold">{healthyProviders}</div>
<p className="text-xs text-muted-foreground">of {totalProviders} providers</p>
</CardContent>
</Card>
{metrics && (
<>
<Card>
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
<CardTitle className="text-sm font-medium">Success Rate</CardTitle>
<Zap className="h-4 w-4 text-muted-foreground" />
</CardHeader>
<CardContent>
<div className="text-2xl font-bold">
{metrics.total_requests > 0
? Math.round((metrics.successful_requests / metrics.total_requests) * 100)
: 0}%
</div>
<p className="text-xs text-muted-foreground">
{metrics.successful_requests.toLocaleString()} / {metrics.total_requests.toLocaleString()} requests
</p>
</CardContent>
</Card>
<Card>
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
<CardTitle className="text-sm font-medium">Security Score</CardTitle>
<Shield className="h-4 w-4 text-muted-foreground" />
</CardHeader>
<CardContent>
<div className="text-2xl font-bold">
{Math.round((1 - metrics.average_risk_score) * 100)}%
</div>
<p className="text-xs text-muted-foreground">
{metrics.security_blocked_requests} blocked requests
</p>
</CardContent>
</Card>
</>
)}
</div>
{/* Provider Details */}
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{Object.entries(providers).map(([name, provider]) => (
<Card key={name} className={`border-2 ${getStatusColor(provider.status)}`}>
<CardHeader>
<div className="flex items-center justify-between">
<CardTitle className="text-lg flex items-center gap-2">
{getStatusIcon(provider.status)}
{provider.provider}
</CardTitle>
<Badge
variant={provider.status === 'healthy' ? 'default' : 'destructive'}
className="capitalize"
>
{provider.status}
</Badge>
</div>
<CardDescription>
{provider.models_available.length} models available
</CardDescription>
</CardHeader>
<CardContent className="space-y-3">
{/* Performance Metrics */}
{provider.latency_ms && (
<div className="flex justify-between items-center">
<span className="text-sm font-medium">Latency</span>
<span className={`text-sm font-mono ${getLatencyColor(provider.latency_ms)}`}>
{Math.round(provider.latency_ms)}ms
</span>
</div>
)}
{provider.success_rate !== undefined && (
<div className="flex justify-between items-center">
<span className="text-sm font-medium">Success Rate</span>
<span className="text-sm font-mono">
{Math.round(provider.success_rate * 100)}%
</span>
</div>
)}
<div className="flex justify-between items-center">
<span className="text-sm font-medium">Last Check</span>
<span className="text-sm text-muted-foreground">
{new Date(provider.last_check).toLocaleTimeString()}
</span>
</div>
{/* Error Message */}
{provider.error_message && (
<Alert variant="destructive" className="mt-3">
<AlertCircle className="h-4 w-4" />
<AlertDescription className="text-xs">
{provider.error_message}
</AlertDescription>
</Alert>
)}
{/* Models */}
<div className="space-y-2">
<span className="text-sm font-medium">Available Models</span>
<div className="flex flex-wrap gap-1">
{provider.models_available.slice(0, 3).map(model => (
<Badge key={model} variant="outline" className="text-xs">
{model}
</Badge>
))}
{provider.models_available.length > 3 && (
<Badge variant="outline" className="text-xs">
+{provider.models_available.length - 3} more
</Badge>
)}
</div>
</div>
</CardContent>
</Card>
))}
</div>
{/* Provider Metrics Details */}
{metrics && Object.keys(metrics.provider_metrics).length > 0 && (
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<Server className="h-5 w-5" />
Provider Performance Metrics
</CardTitle>
<CardDescription>
Detailed performance statistics for each provider
</CardDescription>
</CardHeader>
<CardContent>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{Object.entries(metrics.provider_metrics).map(([provider, data]: [string, any]) => (
<div key={provider} className="border rounded-lg p-4">
<h4 className="font-semibold mb-3 capitalize">{provider}</h4>
<div className="space-y-2 text-sm">
<div className="flex justify-between">
<span>Total Requests:</span>
<span className="font-mono">{data.total_requests?.toLocaleString() || 0}</span>
</div>
<div className="flex justify-between">
<span>Success Rate:</span>
<span className="font-mono">
{data.success_rate ? Math.round(data.success_rate * 100) : 0}%
</span>
</div>
<div className="flex justify-between">
<span>Avg Latency:</span>
<span className={`font-mono ${getLatencyColor(data.average_latency_ms || 0)}`}>
{Math.round(data.average_latency_ms || 0)}ms
</span>
</div>
{data.token_usage && (
<div className="flex justify-between">
<span>Total Tokens:</span>
<span className="font-mono">
{data.token_usage.total_tokens?.toLocaleString() || 0}
</span>
</div>
)}
</div>
</div>
))}
</div>
</CardContent>
</Card>
)}
</div>
)
}