Mirror of https://github.com/aljazceru/enclava.git (synced 2025-12-17 07:24:34 +01:00)

Commit: remove the LiteLLM proxy layer and talk to the PrivateMode-backed LLM service directly.
@@ -1,5 +1,5 @@
 """
-LLM API endpoints - proxy to LiteLLM service with authentication and budget enforcement
+LLM API endpoints - interface to secure LLM service with authentication and budget enforcement
 """
 
 import logging
@@ -16,7 +16,9 @@ from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthS
 from app.core.security import get_current_user
 from app.models.user import User
 from app.core.config import settings
-from app.services.litellm_client import litellm_client
+from app.services.llm.service import llm_service
+from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest
+from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
 from app.services.budget_enforcement import (
     check_budget_for_request, record_request_usage, BudgetEnforcementService,
     atomic_check_and_reserve_budget, atomic_finalize_usage
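The imports above pull in atomic_check_and_reserve_budget and atomic_finalize_usage, which the endpoints below use to reserve an estimated spend before calling the provider and to settle it against actual token usage afterwards. Their implementation is not part of this diff; the following is a minimal standalone sketch of that reserve-then-finalize pattern, with an in-memory store standing in for the real async, database-backed service.

# Minimal sketch of the reserve-then-finalize budget pattern (assumed design;
# the repository's service is async and database-backed, not a dict).
from dataclasses import dataclass, field

@dataclass
class Budget:
    limit_usd: float
    spent_usd: float = 0.0
    reserved_usd: float = 0.0

@dataclass
class BudgetStore:
    budgets: dict = field(default_factory=dict)  # budget_id -> Budget

    def atomic_check_and_reserve(self, budget_id: str, estimated_usd: float) -> bool:
        """Reserve the estimated cost up front; reject if it would exceed the limit."""
        b = self.budgets[budget_id]
        if b.spent_usd + b.reserved_usd + estimated_usd > b.limit_usd:
            return False
        b.reserved_usd += estimated_usd
        return True

    def atomic_finalize(self, budget_id: str, estimated_usd: float, actual_usd: float) -> None:
        """Release the reservation and record what the request actually cost."""
        b = self.budgets[budget_id]
        b.reserved_usd -= estimated_usd
        b.spent_usd += actual_usd

# Usage: reserve before calling the provider, finalize with real token counts after.
store = BudgetStore({"team-a": Budget(limit_usd=10.0)})
assert store.atomic_check_and_reserve("team-a", estimated_usd=0.02)
store.atomic_finalize("team-a", estimated_usd=0.02, actual_usd=0.017)
print(store.budgets["team-a"])

Reserving up front keeps concurrent requests from jointly overshooting a budget that each of them would individually fit under.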
@@ -38,7 +40,7 @@ router = APIRouter()
 
 
 async def get_cached_models() -> List[Dict[str, Any]]:
-    """Get models from cache or fetch from LiteLLM if cache is stale"""
+    """Get models from cache or fetch from LLM service if cache is stale"""
     current_time = time.time()
 
     # Check if cache is still valid
@@ -47,10 +49,20 @@ async def get_cached_models() -> List[Dict[str, Any]]:
         logger.debug("Returning cached models list")
         return _models_cache["data"]
 
-    # Cache miss or stale - fetch from LiteLLM
+    # Cache miss or stale - fetch from LLM service
     try:
-        logger.debug("Fetching fresh models list from LiteLLM")
-        models = await litellm_client.get_models()
+        logger.debug("Fetching fresh models list from LLM service")
+        model_infos = await llm_service.get_models()
+
+        # Convert ModelInfo objects to dict format for compatibility
+        models = []
+        for model_info in model_infos:
+            models.append({
+                "id": model_info.id,
+                "object": model_info.object,
+                "created": model_info.created or int(time.time()),
+                "owned_by": model_info.owned_by
+            })
 
         # Update cache
         _models_cache["data"] = models
@@ -58,7 +70,7 @@ async def get_cached_models() -> List[Dict[str, Any]]:
 
         return models
     except Exception as e:
-        logger.error(f"Failed to fetch models from LiteLLM: {e}")
+        logger.error(f"Failed to fetch models from LLM service: {e}")
 
         # Return stale cache if available, otherwise empty list
         if _models_cache["data"] is not None:
@@ -75,7 +87,7 @@ def invalidate_models_cache():
     logger.info("Models cache invalidated")
 
 
-# Request/Response Models
+# Request/Response Models (API layer)
 class ChatMessage(BaseModel):
     role: str = Field(..., description="Message role (system, user, assistant)")
     content: str = Field(..., description="Message content")
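get_cached_models() and invalidate_models_cache() above rely on a module-level _models_cache and a TTL that are outside this diff. For reference, a minimal standalone sketch of that time-based cache pattern might look like the following; the dict layout, the TTL value, and the fetch_models() stand-in are assumptions, not the repository's actual code.

import asyncio
import time
from typing import Any, Dict, List

# Assumed cache layout: payload plus the timestamp of the last successful fetch.
_models_cache: Dict[str, Any] = {"data": None, "fetched_at": 0.0}
MODELS_CACHE_TTL_SECONDS = 300  # assumption; the real TTL lives in the service

async def fetch_models() -> List[Dict[str, Any]]:
    """Stand-in for llm_service.get_models(); returns OpenAI-style model dicts."""
    return [{"id": "privatemode-llama-3-70b", "object": "model",
             "created": int(time.time()), "owned_by": "privatemode"}]

async def get_cached_models() -> List[Dict[str, Any]]:
    now = time.time()
    if _models_cache["data"] is not None and now - _models_cache["fetched_at"] < MODELS_CACHE_TTL_SECONDS:
        return _models_cache["data"]  # cache hit, still fresh
    try:
        models = await fetch_models()
        _models_cache.update(data=models, fetched_at=now)
        return models
    except Exception:
        # Serve stale data rather than failing hard, mirroring the endpoint above.
        return _models_cache["data"] or []

def invalidate_models_cache() -> None:
    _models_cache.update(data=None, fetched_at=0.0)

if __name__ == "__main__":
    print(asyncio.run(get_cached_models()))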
@@ -183,7 +195,7 @@ async def list_models(
             detail="Insufficient permissions to list models"
         )
 
-    # Get models from cache or LiteLLM
+    # Get models from cache or LLM service
     models = await get_cached_models()
 
     # Filter models based on API key permissions
@@ -309,35 +321,55 @@ async def create_chat_completion(
             warnings = budget_warnings
             reserved_budget_ids = budget_ids
 
-            # Convert messages to dict format
-            messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]
+            # Convert messages to LLM service format
+            llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages]
 
-            # Prepare additional parameters
-            kwargs = {}
-            if chat_request.max_tokens is not None:
-                kwargs["max_tokens"] = chat_request.max_tokens
-            if chat_request.temperature is not None:
-                kwargs["temperature"] = chat_request.temperature
-            if chat_request.top_p is not None:
-                kwargs["top_p"] = chat_request.top_p
-            if chat_request.frequency_penalty is not None:
-                kwargs["frequency_penalty"] = chat_request.frequency_penalty
-            if chat_request.presence_penalty is not None:
-                kwargs["presence_penalty"] = chat_request.presence_penalty
-            if chat_request.stop is not None:
-                kwargs["stop"] = chat_request.stop
-            if chat_request.stream is not None:
-                kwargs["stream"] = chat_request.stream
-
-            # Make request to LiteLLM
-            response = await litellm_client.create_chat_completion(
+            # Create LLM service request
+            llm_request = ChatRequest(
                 model=chat_request.model,
-                messages=messages,
+                messages=llm_messages,
+                temperature=chat_request.temperature,
+                max_tokens=chat_request.max_tokens,
+                top_p=chat_request.top_p,
+                frequency_penalty=chat_request.frequency_penalty,
+                presence_penalty=chat_request.presence_penalty,
+                stop=chat_request.stop,
+                stream=chat_request.stream or False,
                 user_id=str(context.get("user_id", "anonymous")),
-                api_key_id=context.get("api_key_id", "jwt_user"),
-                **kwargs
+                api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0
             )
+
+            # Make request to LLM service
+            llm_response = await llm_service.create_chat_completion(llm_request)
+
+            # Convert LLM service response to API format
+            response = {
+                "id": llm_response.id,
+                "object": llm_response.object,
+                "created": llm_response.created,
+                "model": llm_response.model,
+                "choices": [
+                    {
+                        "index": choice.index,
+                        "message": {
+                            "role": choice.message.role,
+                            "content": choice.message.content
+                        },
+                        "finish_reason": choice.finish_reason
+                    }
+                    for choice in llm_response.choices
+                ],
+                "usage": {
+                    "prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
+                    "completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0,
+                    "total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
+                } if llm_response.usage else {
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                    "total_tokens": 0
+                }
+            }
 
             # Calculate actual cost and update usage
             usage = response.get("usage", {})
             input_tokens = usage.get("prompt_tokens", 0)
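The hunk above is the heart of the migration: the API-layer Pydantic messages are converted into the service's typed ChatMessage objects, the old kwargs plumbing collapses into a single ChatRequest, and the typed response is flattened back into the OpenAI-compatible dict the route has always returned. A standalone sketch of that last mapping is shown below; the small dataclasses only imitate the shapes in app.services.llm.models, and any field not visible in this diff is an assumption.

# Sketch of the typed-response -> OpenAI-style dict adapter used above.
from dataclasses import dataclass
from typing import List, Optional
import time

@dataclass
class Msg:
    role: str
    content: str

@dataclass
class Choice:
    index: int
    message: Msg
    finish_reason: Optional[str] = "stop"

@dataclass
class Usage:
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

@dataclass
class ChatCompletionResponse:
    id: str
    object: str
    created: int
    model: str
    choices: List[Choice]
    usage: Optional[Usage] = None

def to_openai_dict(resp: ChatCompletionResponse) -> dict:
    """Flatten the typed response into the OpenAI-compatible payload the API returns."""
    usage = resp.usage or Usage()
    return {
        "id": resp.id,
        "object": resp.object,
        "created": resp.created,
        "model": resp.model,
        "choices": [
            {
                "index": c.index,
                "message": {"role": c.message.role, "content": c.message.content},
                "finish_reason": c.finish_reason,
            }
            for c in resp.choices
        ],
        "usage": {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        },
    }

if __name__ == "__main__":
    demo = ChatCompletionResponse(
        id="chatcmpl-demo", object="chat.completion", created=int(time.time()),
        model="privatemode-llama-3-70b",
        choices=[Choice(index=0, message=Msg("assistant", "Hello!"))],
        usage=Usage(5, 2, 7),
    )
    print(to_openai_dict(demo))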
@@ -382,8 +414,38 @@ async def create_chat_completion(
 
     except HTTPException:
         raise
+    except SecurityError as e:
+        logger.warning(f"Security error in chat completion: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Security validation failed: {e.message}"
+        )
+    except ValidationError as e:
+        logger.warning(f"Validation error in chat completion: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Request validation failed: {e.message}"
+        )
+    except ProviderError as e:
+        logger.error(f"Provider error in chat completion: {e}")
+        if "rate limit" in str(e).lower():
+            raise HTTPException(
+                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+                detail="Rate limit exceeded"
+            )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="LLM service temporarily unavailable"
+            )
+    except LLMError as e:
+        logger.error(f"LLM service error in chat completion: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="LLM service error"
+        )
     except Exception as e:
-        logger.error(f"Error creating chat completion: {e}")
+        logger.error(f"Unexpected error creating chat completion: {e}")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="Failed to create chat completion"
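The new except chain above, which reappears almost verbatim for the embeddings endpoint below, maps the typed service exceptions onto HTTP statuses: SecurityError and ValidationError become 400, ProviderError becomes 429 for rate limits and 503 otherwise, LLMError becomes 500, and anything else falls through to a generic 500. If the duplication grows, it could be folded into one helper; the following is a sketch of that idea, not code from the commit, and the exception classes are stand-ins for app.services.llm.exceptions.

# Sketch of factoring the repeated exception-to-HTTP mapping into one helper.
from fastapi import HTTPException, status

class LLMError(Exception):
    ...

class ProviderError(LLMError):
    ...

class SecurityError(LLMError):
    def __init__(self, message: str):
        super().__init__(message)
        self.message = message

class ValidationError(LLMError):
    def __init__(self, message: str):
        super().__init__(message)
        self.message = message

def raise_http_for_llm_error(exc: Exception, operation: str) -> None:
    """Translate LLM service exceptions into the HTTP errors the API exposes."""
    if isinstance(exc, SecurityError):
        raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Security validation failed: {exc.message}")
    if isinstance(exc, ValidationError):
        raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Request validation failed: {exc.message}")
    if isinstance(exc, ProviderError):
        if "rate limit" in str(exc).lower():
            raise HTTPException(status.HTTP_429_TOO_MANY_REQUESTS, "Rate limit exceeded")
        raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE, "LLM service temporarily unavailable")
    if isinstance(exc, LLMError):
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "LLM service error")
    raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, f"Failed to {operation}")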
@@ -438,15 +500,39 @@ async def create_embedding(
|
||||
detail=f"Budget exceeded: {error_message}"
|
||||
)
|
||||
|
||||
# Make request to LiteLLM
|
||||
response = await litellm_client.create_embedding(
|
||||
# Create LLM service request
|
||||
llm_request = LLMEmbeddingRequest(
|
||||
model=request.model,
|
||||
input_text=request.input,
|
||||
input=request.input,
|
||||
encoding_format=request.encoding_format,
|
||||
user_id=str(context["user_id"]),
|
||||
api_key_id=context["api_key_id"],
|
||||
encoding_format=request.encoding_format
|
||||
api_key_id=context["api_key_id"]
|
||||
)
|
||||
|
||||
# Make request to LLM service
|
||||
llm_response = await llm_service.create_embedding(llm_request)
|
||||
|
||||
# Convert LLM service response to API format
|
||||
response = {
|
||||
"object": llm_response.object,
|
||||
"data": [
|
||||
{
|
||||
"object": emb.object,
|
||||
"index": emb.index,
|
||||
"embedding": emb.embedding
|
||||
}
|
||||
for emb in llm_response.data
|
||||
],
|
||||
"model": llm_response.model,
|
||||
"usage": {
|
||||
"prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
|
||||
"total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
|
||||
} if llm_response.usage else {
|
||||
"prompt_tokens": int(estimated_tokens),
|
||||
"total_tokens": int(estimated_tokens)
|
||||
}
|
||||
}
|
||||
|
||||
# Calculate actual cost and update usage
|
||||
usage = response.get("usage", {})
|
||||
total_tokens = usage.get("total_tokens", int(estimated_tokens))
|
||||
@@ -475,8 +561,38 @@ async def create_embedding(
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except SecurityError as e:
|
||||
logger.warning(f"Security error in embedding: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Security validation failed: {e.message}"
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.warning(f"Validation error in embedding: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Request validation failed: {e.message}"
|
||||
)
|
||||
except ProviderError as e:
|
||||
logger.error(f"Provider error in embedding: {e}")
|
||||
if "rate limit" in str(e).lower():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Rate limit exceeded"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="LLM service temporarily unavailable"
|
||||
)
|
||||
except LLMError as e:
|
||||
logger.error(f"LLM service error in embedding: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="LLM service error"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating embedding: {e}")
|
||||
logger.error(f"Unexpected error creating embedding: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create embedding"
|
||||
@@ -489,11 +605,28 @@ async def llm_health_check(
|
||||
):
|
||||
"""Health check for LLM service"""
|
||||
try:
|
||||
health_status = await litellm_client.health_check()
|
||||
health_summary = llm_service.get_health_summary()
|
||||
provider_status = await llm_service.get_provider_status()
|
||||
|
||||
# Determine overall health
|
||||
overall_status = "healthy"
|
||||
if health_summary["service_status"] != "healthy":
|
||||
overall_status = "degraded"
|
||||
|
||||
for provider, status in provider_status.items():
|
||||
if status.status == "unavailable":
|
||||
overall_status = "degraded"
|
||||
break
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "LLM Proxy",
|
||||
"litellm_status": health_status,
|
||||
"status": overall_status,
|
||||
"service": "LLM Service",
|
||||
"service_status": health_summary,
|
||||
"provider_status": {name: {
|
||||
"status": status.status,
|
||||
"latency_ms": status.latency_ms,
|
||||
"error_message": status.error_message
|
||||
} for name, status in provider_status.items()},
|
||||
"user_id": context["user_id"],
|
||||
"api_key_name": context["api_key_name"]
|
||||
}
|
||||
@@ -501,7 +634,7 @@ async def llm_health_check(
|
||||
logger.error(f"LLM health check error: {e}")
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"service": "LLM Proxy",
|
||||
"service": "LLM Service",
|
||||
"error": str(e)
|
||||
}
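The health endpoint above derives one overall value from two inputs: the service's own health summary and the per-provider status map. Any non-healthy service status or any unavailable provider downgrades the result to "degraded", and an exception reports "unhealthy". A minimal sketch of that aggregation rule, with plain dicts standing in for the provider status objects:

# Sketch of the overall-health aggregation used by the health endpoint above.
from typing import Dict

def aggregate_status(service_status: str, providers: Dict[str, dict]) -> str:
    """healthy -> degraded as soon as the service or any provider is not fully up."""
    if service_status != "healthy":
        return "degraded"
    if any(p.get("status") == "unavailable" for p in providers.values()):
        return "degraded"
    return "healthy"

print(aggregate_status("healthy", {"privatemode": {"status": "available", "latency_ms": 42}}))  # healthy
print(aggregate_status("healthy", {"privatemode": {"status": "unavailable"}}))                   # degraded

One small hazard in the handler itself: the loop variable named status shadows the fastapi status module within that function, so renaming it (as the sketch does with p) may be worth considering.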
|
||||
|
||||
@@ -626,50 +759,83 @@ async def get_budget_status(
|
||||
)
|
||||
|
||||
|
||||
# Generic proxy endpoint for other LiteLLM endpoints
|
||||
@router.api_route("/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
|
||||
async def proxy_endpoint(
|
||||
endpoint: str,
|
||||
request: Request,
|
||||
# Generic endpoint for additional LLM service functionality
|
||||
@router.get("/metrics")
|
||||
async def get_llm_metrics(
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Generic proxy endpoint for LiteLLM requests"""
|
||||
"""Get LLM service metrics (admin only)"""
|
||||
try:
|
||||
# Check for admin permissions
|
||||
auth_service = APIKeyAuthService(db)
|
||||
|
||||
# Check endpoint permission
|
||||
if not await auth_service.check_endpoint_permission(context, endpoint):
|
||||
if not await auth_service.check_scope_permission(context, "admin.metrics"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Endpoint '{endpoint}' not allowed"
|
||||
detail="Admin permissions required to view metrics"
|
||||
)
|
||||
|
||||
# Get request body
|
||||
if request.method in ["POST", "PUT", "PATCH"]:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except:
|
||||
payload = {}
|
||||
else:
|
||||
payload = dict(request.query_params)
|
||||
|
||||
# Make request to LiteLLM
|
||||
response = await litellm_client.proxy_request(
|
||||
method=request.method,
|
||||
endpoint=endpoint,
|
||||
payload=payload,
|
||||
user_id=str(context["user_id"]),
|
||||
api_key_id=context["api_key_id"]
|
||||
)
|
||||
|
||||
return response
|
||||
metrics = llm_service.get_metrics()
|
||||
return {
|
||||
"object": "llm_metrics",
|
||||
"data": {
|
||||
"total_requests": metrics.total_requests,
|
||||
"successful_requests": metrics.successful_requests,
|
||||
"failed_requests": metrics.failed_requests,
|
||||
"security_blocked_requests": metrics.security_blocked_requests,
|
||||
"average_latency_ms": metrics.average_latency_ms,
|
||||
"average_risk_score": metrics.average_risk_score,
|
||||
"provider_metrics": metrics.provider_metrics,
|
||||
"last_updated": metrics.last_updated.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error proxying request to {endpoint}: {e}")
|
||||
logger.error(f"Error getting LLM metrics: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to proxy request"
|
||||
detail="Failed to get LLM metrics"
|
||||
)
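For reference, calling the new metrics endpoint from a client could look roughly like this; the base URL, the /api/v1/llm mount (inferred from the tests further down in this diff), and the key are placeholders rather than values from the repository.

# Hypothetical client call for the metrics endpoint; URL, path, and key are placeholders.
import httpx

resp = httpx.get(
    "http://localhost:8000/api/v1/llm/metrics",
    headers={"Authorization": "Bearer <api-key-with-admin.metrics-scope>"},
    timeout=10.0,
)
resp.raise_for_status()
print(resp.json()["data"]["total_requests"])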
|
||||
|
||||
|
||||
@router.get("/providers/status")
|
||||
async def get_provider_status(
|
||||
context: Dict[str, Any] = Depends(require_api_key),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get status of all LLM providers"""
|
||||
try:
|
||||
auth_service = APIKeyAuthService(db)
|
||||
if not await auth_service.check_scope_permission(context, "admin.status"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin permissions required to view provider status"
|
||||
)
|
||||
|
||||
provider_status = await llm_service.get_provider_status()
|
||||
return {
|
||||
"object": "provider_status",
|
||||
"data": {
|
||||
name: {
|
||||
"provider": status.provider,
|
||||
"status": status.status,
|
||||
"latency_ms": status.latency_ms,
|
||||
"success_rate": status.success_rate,
|
||||
"last_check": status.last_check.isoformat(),
|
||||
"error_message": status.error_message,
|
||||
"models_available": status.models_available
|
||||
}
|
||||
for name, status in provider_status.items()
|
||||
}
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting provider status: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get provider status"
|
||||
)
|
||||
@@ -447,7 +447,7 @@ async def get_module_config(module_name: str):
|
||||
log_api_request("get_module_config", {"module_name": module_name})
|
||||
|
||||
from app.services.module_config_manager import module_config_manager
|
||||
from app.services.litellm_client import litellm_client
|
||||
from app.services.llm.service import llm_service
|
||||
import copy
|
||||
|
||||
# Get module manifest and schema
|
||||
@@ -461,9 +461,9 @@ async def get_module_config(module_name: str):
|
||||
# For Signal module, populate model options dynamically
|
||||
if module_name == "signal" and schema:
|
||||
try:
|
||||
# Get available models from LiteLLM
|
||||
models_data = await litellm_client.get_models()
|
||||
model_ids = [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")]
|
||||
# Get available models from LLM service
|
||||
models_data = await llm_service.get_models()
|
||||
model_ids = [model.id for model in models_data]
|
||||
|
||||
if model_ids:
|
||||
# Create a copy of the schema to avoid modifying the original
|
||||
|
||||
@@ -15,7 +15,8 @@ from app.models.prompt_template import PromptTemplate, ChatbotPromptVariable
|
||||
from app.core.security import get_current_user
|
||||
from app.models.user import User
|
||||
from app.core.logging import log_api_request
|
||||
from app.services.litellm_client import litellm_client
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -394,25 +395,28 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
|
||||
]
|
||||
|
||||
# Get available models to use a default model
|
||||
models = await litellm_client.get_models()
|
||||
models = await llm_service.get_models()
|
||||
if not models:
|
||||
raise HTTPException(status_code=503, detail="No LLM models available")
|
||||
|
||||
# Use the first available model (you might want to make this configurable)
|
||||
default_model = models[0]["id"]
|
||||
default_model = models[0].id
|
||||
|
||||
# Make the AI call
|
||||
response = await litellm_client.create_chat_completion(
|
||||
# Prepare the chat request for the new LLM service
|
||||
chat_request = LLMChatRequest(
|
||||
model=default_model,
|
||||
messages=messages,
|
||||
user_id=str(user_id),
|
||||
api_key_id=1, # Using default API key, you might want to make this dynamic
|
||||
messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
|
||||
temperature=0.3,
|
||||
max_tokens=1000
|
||||
max_tokens=1000,
|
||||
user_id=str(user_id),
|
||||
api_key_id=1 # Using default API key, you might want to make this dynamic
|
||||
)
|
||||
|
||||
# Make the AI call
|
||||
response = await llm_service.create_chat_completion(chat_request)
|
||||
|
||||
# Extract the improved prompt from the response
|
||||
improved_prompt = response["choices"][0]["message"]["content"].strip()
|
||||
improved_prompt = response.choices[0].message.content.strip()
|
||||
|
||||
return {
|
||||
"improved_prompt": improved_prompt,
|
||||
|
||||
@@ -51,7 +51,7 @@ class SystemInfoResponse(BaseModel):
|
||||
environment: str
|
||||
database_status: str
|
||||
redis_status: str
|
||||
litellm_status: str
|
||||
llm_service_status: str
|
||||
modules_loaded: int
|
||||
active_users: int
|
||||
total_api_keys: int
|
||||
@@ -227,8 +227,13 @@ async def get_system_info(
|
||||
# Get Redis status (simplified check)
|
||||
redis_status = "healthy" # Would implement actual Redis check
|
||||
|
||||
# Get LiteLLM status (simplified check)
|
||||
litellm_status = "healthy" # Would implement actual LiteLLM check
|
||||
# Get LLM service status
|
||||
try:
|
||||
from app.services.llm.service import llm_service
|
||||
health_summary = llm_service.get_health_summary()
|
||||
llm_service_status = health_summary.get("service_status", "unknown")
|
||||
except Exception:
|
||||
llm_service_status = "unavailable"
|
||||
|
||||
# Get modules loaded (from module manager)
|
||||
modules_loaded = 8 # Would get from actual module manager
|
||||
@@ -261,7 +266,7 @@ async def get_system_info(
|
||||
environment="production",
|
||||
database_status=database_status,
|
||||
redis_status=redis_status,
|
||||
litellm_status=litellm_status,
|
||||
llm_service_status=llm_service_status,
|
||||
modules_loaded=modules_loaded,
|
||||
active_users=active_users,
|
||||
total_api_keys=total_api_keys,
|
||||
|
||||
@@ -43,15 +43,18 @@ class Settings(BaseSettings):
     # CORS
     CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"]
 
-    # LiteLLM
-    LITELLM_BASE_URL: str = "http://localhost:4000"
-    LITELLM_MASTER_KEY: str = "enclava-master-key"
+    # LLM Service Configuration (replaced LiteLLM)
+    # LLM service configuration is now handled in app/services/llm/config.py
+
+    # LLM Service Security
+    LLM_ENCRYPTION_KEY: Optional[str] = None  # Key for encrypting LLM provider API keys
 
     # API Keys for LLM providers
     OPENAI_API_KEY: Optional[str] = None
     ANTHROPIC_API_KEY: Optional[str] = None
     GOOGLE_API_KEY: Optional[str] = None
     PRIVATEMODE_API_KEY: Optional[str] = None
+    PRIVATEMODE_PROXY_URL: str = "http://privatemode-proxy:8080/v1"
 
     # Qdrant
     QDRANT_HOST: str = "localhost"

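LLM_ENCRYPTION_KEY is described only as the key for encrypting LLM provider API keys; how it is consumed is not part of this diff. Purely as an illustration of that kind of at-rest encryption, and assuming a Fernet key like the one the Zammad module imports elsewhere in this commit, a sketch might look like:

# Illustrative only: encrypt/decrypt a provider API key with a Fernet key taken
# from LLM_ENCRYPTION_KEY. The real consumption of this setting is not shown in
# the diff, so the names and flow here are assumptions.
import os
from cryptography.fernet import Fernet

raw = os.environ.get("LLM_ENCRYPTION_KEY")
key = raw.encode() if raw else Fernet.generate_key()
cipher = Fernet(key)

token = cipher.encrypt(b"sk-example-provider-key")  # store this ciphertext at rest
print(cipher.decrypt(token).decode())                # recover the key when building a provider client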
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Embedding Service
|
||||
Provides text embedding functionality using LiteLLM proxy
|
||||
Provides text embedding functionality using LLM service
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -11,32 +11,34 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingService:
|
||||
"""Service for generating text embeddings using LiteLLM"""
|
||||
"""Service for generating text embeddings using LLM service"""
|
||||
|
||||
def __init__(self, model_name: str = "privatemode-embeddings"):
|
||||
self.model_name = model_name
|
||||
self.litellm_client = None
|
||||
self.dimension = 1024 # Actual dimension for privatemode-embeddings
|
||||
self.initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the embedding service with LiteLLM"""
|
||||
"""Initialize the embedding service with LLM service"""
|
||||
try:
|
||||
from app.services.litellm_client import litellm_client
|
||||
self.litellm_client = litellm_client
|
||||
from app.services.llm.service import llm_service
|
||||
|
||||
# Test connection to LiteLLM
|
||||
health = await self.litellm_client.health_check()
|
||||
if health.get("status") == "unhealthy":
|
||||
logger.error(f"LiteLLM service unhealthy: {health.get('error')}")
|
||||
# Initialize LLM service if not already done
|
||||
if not llm_service._initialized:
|
||||
await llm_service.initialize()
|
||||
|
||||
# Test LLM service health
|
||||
health_summary = llm_service.get_health_summary()
|
||||
if health_summary.get("service_status") != "healthy":
|
||||
logger.error(f"LLM service unhealthy: {health_summary}")
|
||||
return False
|
||||
|
||||
self.initialized = True
|
||||
logger.info(f"Embedding service initialized with LiteLLM: {self.model_name} (dimension: {self.dimension})")
|
||||
logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize LiteLLM embedding service: {e}")
|
||||
logger.error(f"Failed to initialize LLM embedding service: {e}")
|
||||
logger.warning("Using fallback random embeddings")
|
||||
return False
|
||||
|
||||
@@ -46,10 +48,10 @@ class EmbeddingService:
|
||||
return embeddings[0]
|
||||
|
||||
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Get embeddings for multiple texts using LiteLLM"""
|
||||
if not self.initialized or not self.litellm_client:
|
||||
"""Get embeddings for multiple texts using LLM service"""
|
||||
if not self.initialized:
|
||||
# Fallback to random embeddings if not initialized
|
||||
logger.warning("LiteLLM not available, using random embeddings")
|
||||
logger.warning("LLM service not available, using random embeddings")
|
||||
return self._generate_fallback_embeddings(texts)
|
||||
|
||||
try:
|
||||
@@ -73,17 +75,22 @@ class EmbeddingService:
|
||||
else:
|
||||
truncated_text = text
|
||||
|
||||
# Call LiteLLM embedding endpoint
|
||||
response = await self.litellm_client.create_embedding(
|
||||
# Call LLM service embedding endpoint
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import EmbeddingRequest
|
||||
|
||||
llm_request = EmbeddingRequest(
|
||||
model=self.model_name,
|
||||
input_text=truncated_text,
|
||||
input=truncated_text,
|
||||
user_id="rag_system",
|
||||
api_key_id=0 # System API key
|
||||
)
|
||||
|
||||
response = await llm_service.create_embedding(llm_request)
|
||||
|
||||
# Extract embedding from response
|
||||
if "data" in response and len(response["data"]) > 0:
|
||||
embedding = response["data"][0].get("embedding", [])
|
||||
if response.data and len(response.data) > 0:
|
||||
embedding = response.data[0].embedding
|
||||
if embedding:
|
||||
batch_embeddings.append(embedding)
|
||||
# Update dimension based on actual embedding size
|
||||
@@ -106,7 +113,7 @@ class EmbeddingService:
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embeddings with LiteLLM: {e}")
|
||||
logger.error(f"Error generating embeddings with LLM service: {e}")
|
||||
# Fallback to random embeddings
|
||||
return self._generate_fallback_embeddings(texts)
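_generate_fallback_embeddings is referenced here and in the early return above, but its body is outside this diff. The log messages call the fallback random embeddings; a plausible standalone stand-in, a guess rather than the repository's implementation, that keeps the vector dimension stable would be:

# Hypothetical stand-in for EmbeddingService._generate_fallback_embeddings.
# It produces deterministic pseudo-random unit vectors of the configured
# dimension so downstream code keeps working when the LLM service is down.
import hashlib
import math
import random
from typing import List

def generate_fallback_embeddings(texts: List[str], dimension: int = 1024) -> List[List[float]]:
    vectors = []
    for text in texts:
        # Seed from the text so the same input always maps to the same vector.
        seed = int.from_bytes(hashlib.sha256(text.encode("utf-8")).digest()[:8], "big")
        rng = random.Random(seed)
        vec = [rng.uniform(-1.0, 1.0) for _ in range(dimension)]
        norm = math.sqrt(sum(v * v for v in vec)) or 1.0
        vectors.append([v / norm for v in vec])
    return vectors

print(len(generate_fallback_embeddings(["hello world"])[0]))  # 1024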
|
||||
|
||||
@@ -146,14 +153,13 @@ class EmbeddingService:
|
||||
"model_name": self.model_name,
|
||||
"model_loaded": self.initialized,
|
||||
"dimension": self.dimension,
|
||||
"backend": "LiteLLM",
|
||||
"backend": "LLM Service",
|
||||
"initialized": self.initialized
|
||||
}
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
self.initialized = False
|
||||
self.litellm_client = None
|
||||
|
||||
|
||||
# Global embedding service instance
|
||||
|
||||
@@ -1,304 +0,0 @@
|
||||
"""
|
||||
LiteLLM Client Service
|
||||
Handles communication with the LiteLLM proxy service
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LiteLLMClient:
|
||||
"""Client for communicating with LiteLLM proxy service"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_url = settings.LITELLM_BASE_URL
|
||||
self.master_key = settings.LITELLM_MASTER_KEY
|
||||
self.session: Optional[aiohttp.ClientSession] = None
|
||||
self.timeout = aiohttp.ClientTimeout(total=600) # 10 minutes timeout
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create aiohttp session"""
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession(
|
||||
timeout=self.timeout,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.master_key}"
|
||||
}
|
||||
)
|
||||
return self.session
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session"""
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Check LiteLLM proxy health"""
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(f"{self.base_url}/health") as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
logger.error(f"LiteLLM health check failed: {response.status}")
|
||||
return {"status": "unhealthy", "error": f"HTTP {response.status}"}
|
||||
except Exception as e:
|
||||
logger.error(f"LiteLLM health check error: {e}")
|
||||
return {"status": "unhealthy", "error": str(e)}
|
||||
|
||||
async def get_models(self) -> List[Dict[str, Any]]:
|
||||
"""Get available models from LiteLLM"""
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(f"{self.base_url}/models") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return data.get("data", [])
|
||||
else:
|
||||
logger.error(f"Failed to get models: {response.status}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="LiteLLM service unavailable"
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"LiteLLM models request error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="LiteLLM service unavailable"
|
||||
)
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict[str, str]],
|
||||
user_id: str,
|
||||
api_key_id: int,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""Create chat completion via LiteLLM proxy"""
|
||||
try:
|
||||
# Prepare request payload
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"user": f"user_{user_id}", # User identifier for tracking
|
||||
"metadata": {
|
||||
"api_key_id": api_key_id,
|
||||
"user_id": user_id,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
},
|
||||
**kwargs
|
||||
}
|
||||
|
||||
session = await self._get_session()
|
||||
async with session.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
json=payload
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"LiteLLM chat completion failed: {response.status} - {error_text}")
|
||||
|
||||
# Handle specific error cases
|
||||
if response.status == 401:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key"
|
||||
)
|
||||
elif response.status == 429:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Rate limit exceeded"
|
||||
)
|
||||
elif response.status == 400:
|
||||
try:
|
||||
error_data = await response.json()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=error_data.get("error", {}).get("message", "Bad request")
|
||||
)
|
||||
except:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid request"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="LiteLLM service error"
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"LiteLLM chat completion request error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="LiteLLM service unavailable"
|
||||
)
|
||||
|
||||
async def create_embedding(
|
||||
self,
|
||||
model: str,
|
||||
input_text: str,
|
||||
user_id: str,
|
||||
api_key_id: int,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""Create embedding via LiteLLM proxy"""
|
||||
try:
|
||||
payload = {
|
||||
"model": model,
|
||||
"input": input_text,
|
||||
"user": f"user_{user_id}",
|
||||
"metadata": {
|
||||
"api_key_id": api_key_id,
|
||||
"user_id": user_id,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
},
|
||||
**kwargs
|
||||
}
|
||||
|
||||
session = await self._get_session()
|
||||
async with session.post(
|
||||
f"{self.base_url}/embeddings",
|
||||
json=payload
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"LiteLLM embedding failed: {response.status} - {error_text}")
|
||||
|
||||
if response.status == 401:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key"
|
||||
)
|
||||
elif response.status == 429:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Rate limit exceeded"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="LiteLLM service error"
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"LiteLLM embedding request error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="LiteLLM service unavailable"
|
||||
)
|
||||
|
||||
async def get_models(self) -> List[Dict[str, Any]]:
|
||||
"""Get available models from LiteLLM proxy"""
|
||||
try:
|
||||
session = await self._get_session()
|
||||
async with session.get(f"{self.base_url}/models") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
# Return models with exact names from upstream providers
|
||||
models = data.get("data", [])
|
||||
|
||||
# Pass through model names exactly as they come from upstream
|
||||
# Don't modify model IDs - keep them as the original provider names
|
||||
processed_models = []
|
||||
for model in models:
|
||||
# Keep the exact model ID from upstream provider
|
||||
processed_models.append({
|
||||
"id": model.get("id", ""), # Exact model name from provider
|
||||
"object": model.get("object", "model"),
|
||||
"created": model.get("created", 1677610602),
|
||||
"owned_by": model.get("owned_by", "openai")
|
||||
})
|
||||
|
||||
return processed_models
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"LiteLLM models request failed: {response.status} - {error_text}")
|
||||
return []
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"LiteLLM models request error: {e}")
|
||||
return []
|
||||
|
||||
async def proxy_request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
payload: Dict[str, Any],
|
||||
user_id: str,
|
||||
api_key_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""Generic proxy request to LiteLLM"""
|
||||
try:
|
||||
# Add metadata to payload
|
||||
if isinstance(payload, dict):
|
||||
payload["metadata"] = {
|
||||
"api_key_id": api_key_id,
|
||||
"user_id": user_id,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
if "user" not in payload:
|
||||
payload["user"] = f"user_{user_id}"
|
||||
|
||||
session = await self._get_session()
|
||||
|
||||
# Make the request
|
||||
async with session.request(
|
||||
method,
|
||||
f"{self.base_url}/{endpoint.lstrip('/')}",
|
||||
json=payload if method.upper() in ['POST', 'PUT', 'PATCH'] else None,
|
||||
params=payload if method.upper() == 'GET' else None
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(f"LiteLLM proxy request failed: {response.status} - {error_text}")
|
||||
|
||||
if response.status == 401:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid API key"
|
||||
)
|
||||
elif response.status == 429:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Rate limit exceeded"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"LiteLLM service error: {error_text}"
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"LiteLLM proxy request error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="LiteLLM service unavailable"
|
||||
)
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
"""Get list of available model names/IDs"""
|
||||
try:
|
||||
models_data = await self.get_models()
|
||||
return [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")]
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing model names: {str(e)}")
|
||||
return []
|
||||
|
||||
|
||||
# Global LiteLLM client instance
|
||||
litellm_client = LiteLLMClient()
|
||||
@@ -23,7 +23,9 @@ from fastapi import APIRouter, HTTPException, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.litellm_client import LiteLLMClient
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
|
||||
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError
|
||||
from app.services.base_module import BaseModule, Permission
|
||||
from app.models.user import User
|
||||
from app.models.chatbot import ChatbotInstance as DBChatbotInstance, ChatbotConversation as DBConversation, ChatbotMessage as DBMessage, ChatbotAnalytics
|
||||
@@ -32,7 +34,8 @@ from app.db.database import get_db
|
||||
from app.core.config import settings
|
||||
|
||||
# Import protocols for type hints and dependency injection
|
||||
from ..protocols import RAGServiceProtocol, LiteLLMClientProtocol
|
||||
from ..protocols import RAGServiceProtocol
|
||||
# Note: LiteLLMClientProtocol replaced with direct LLM service usage
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -131,10 +134,8 @@ class ChatbotInstance(BaseModel):
|
||||
class ChatbotModule(BaseModule):
|
||||
"""Main chatbot module implementation"""
|
||||
|
||||
def __init__(self, litellm_client: Optional[LiteLLMClientProtocol] = None,
|
||||
rag_service: Optional[RAGServiceProtocol] = None):
|
||||
def __init__(self, rag_service: Optional[RAGServiceProtocol] = None):
|
||||
super().__init__("chatbot")
|
||||
self.litellm_client = litellm_client
|
||||
self.rag_module = rag_service # Keep same name for compatibility
|
||||
self.db_session = None
|
||||
|
||||
@@ -145,15 +146,10 @@ class ChatbotModule(BaseModule):
|
||||
"""Initialize the chatbot module"""
|
||||
await super().initialize(**kwargs)
|
||||
|
||||
# Get dependencies from global services if not already injected
|
||||
if not self.litellm_client:
|
||||
try:
|
||||
from app.services.litellm_client import litellm_client
|
||||
self.litellm_client = litellm_client
|
||||
logger.info("LiteLLM client injected from global service")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not inject LiteLLM client: {e}")
|
||||
# Initialize the LLM service
|
||||
await llm_service.initialize()
|
||||
|
||||
# Get RAG module dependency if not already injected
|
||||
if not self.rag_module:
|
||||
try:
|
||||
# Try to get RAG module from module manager
|
||||
@@ -168,19 +164,16 @@ class ChatbotModule(BaseModule):
|
||||
await self._load_prompt_templates()
|
||||
|
||||
logger.info("Chatbot module initialized")
|
||||
logger.info(f"LiteLLM client available after init: {self.litellm_client is not None}")
|
||||
logger.info(f"LLM service available: {llm_service._initialized}")
|
||||
logger.info(f"RAG module available after init: {self.rag_module is not None}")
|
||||
logger.info(f"Loaded {len(self.system_prompts)} prompt templates")
|
||||
|
||||
async def _ensure_dependencies(self):
|
||||
"""Lazy load dependencies if not available"""
|
||||
if not self.litellm_client:
|
||||
try:
|
||||
from app.services.litellm_client import litellm_client
|
||||
self.litellm_client = litellm_client
|
||||
logger.info("LiteLLM client lazy loaded")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not lazy load LiteLLM client: {e}")
|
||||
# Ensure LLM service is initialized
|
||||
if not llm_service._initialized:
|
||||
await llm_service.initialize()
|
||||
logger.info("LLM service lazy loaded")
|
||||
|
||||
if not self.rag_module:
|
||||
try:
|
||||
@@ -468,45 +461,58 @@ class ChatbotModule(BaseModule):
|
||||
logger.info(msg['content'])
|
||||
logger.info("=== END COMPREHENSIVE LLM REQUEST ===")
|
||||
|
||||
if self.litellm_client:
|
||||
try:
|
||||
logger.info("Calling LiteLLM client create_chat_completion...")
|
||||
response = await self.litellm_client.create_chat_completion(
|
||||
model=config.model,
|
||||
messages=messages,
|
||||
user_id="chatbot_user",
|
||||
api_key_id="chatbot_api_key",
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens
|
||||
)
|
||||
logger.info(f"LiteLLM response received, response keys: {list(response.keys())}")
|
||||
try:
|
||||
logger.info("Calling LLM service create_chat_completion...")
|
||||
|
||||
# Convert messages to LLM service format
|
||||
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
|
||||
|
||||
# Create LLM service request
|
||||
llm_request = LLMChatRequest(
|
||||
model=config.model,
|
||||
messages=llm_messages,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
user_id="chatbot_user",
|
||||
api_key_id=0 # Chatbot module uses internal service
|
||||
)
|
||||
|
||||
# Make request to LLM service
|
||||
llm_response = await llm_service.create_chat_completion(llm_request)
|
||||
|
||||
# Extract response content
|
||||
if llm_response.choices:
|
||||
content = llm_response.choices[0].message.content
|
||||
logger.info(f"Response content length: {len(content)}")
|
||||
|
||||
# Extract response content from the LiteLLM response format
|
||||
if 'choices' in response and response['choices']:
|
||||
content = response['choices'][0]['message']['content']
|
||||
logger.info(f"Response content length: {len(content)}")
|
||||
|
||||
# Always log response for debugging
|
||||
logger.info("=== COMPREHENSIVE LLM RESPONSE ===")
|
||||
logger.info(f"Response content ({len(content)} chars):")
|
||||
logger.info(content)
|
||||
if 'usage' in response:
|
||||
usage = response['usage']
|
||||
logger.info(f"Token usage - Prompt: {usage.get('prompt_tokens', 'N/A')}, Completion: {usage.get('completion_tokens', 'N/A')}, Total: {usage.get('total_tokens', 'N/A')}")
|
||||
if sources:
|
||||
logger.info(f"RAG sources included: {len(sources)} documents")
|
||||
logger.info("=== END COMPREHENSIVE LLM RESPONSE ===")
|
||||
|
||||
return content, sources
|
||||
else:
|
||||
logger.warning("No choices in LiteLLM response")
|
||||
return "I received an empty response from the AI model.", sources
|
||||
except Exception as e:
|
||||
logger.error(f"LiteLLM completion failed: {e}")
|
||||
raise e
|
||||
else:
|
||||
logger.warning("No LiteLLM client available, using fallback")
|
||||
# Fallback if no LLM client
|
||||
# Always log response for debugging
|
||||
logger.info("=== COMPREHENSIVE LLM RESPONSE ===")
|
||||
logger.info(f"Response content ({len(content)} chars):")
|
||||
logger.info(content)
|
||||
if llm_response.usage:
|
||||
usage = llm_response.usage
|
||||
logger.info(f"Token usage - Prompt: {usage.prompt_tokens}, Completion: {usage.completion_tokens}, Total: {usage.total_tokens}")
|
||||
if sources:
|
||||
logger.info(f"RAG sources included: {len(sources)} documents")
|
||||
logger.info("=== END COMPREHENSIVE LLM RESPONSE ===")
|
||||
|
||||
return content, sources
|
||||
else:
|
||||
logger.warning("No choices in LLM response")
|
||||
return "I received an empty response from the AI model.", sources
|
||||
|
||||
except SecurityError as e:
|
||||
logger.error(f"Security error in LLM completion: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Security validation failed: {e.message}")
|
||||
except ProviderError as e:
|
||||
logger.error(f"Provider error in LLM completion: {e}")
|
||||
raise HTTPException(status_code=503, detail="LLM service temporarily unavailable")
|
||||
except LLMError as e:
|
||||
logger.error(f"LLM service error: {e}")
|
||||
raise HTTPException(status_code=500, detail="LLM service error")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM completion failed: {e}")
|
||||
# Return fallback if available
|
||||
return "I'm currently unable to process your request. Please try again later.", None
|
||||
|
||||
def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig,
|
||||
@@ -685,7 +691,7 @@ class ChatbotModule(BaseModule):
|
||||
# Lazy load dependencies
|
||||
await self._ensure_dependencies()
|
||||
|
||||
logger.info(f"LiteLLM client available: {self.litellm_client is not None}")
|
||||
logger.info(f"LLM service available: {llm_service._initialized}")
|
||||
logger.info(f"RAG module available: {self.rag_module is not None}")
|
||||
|
||||
try:
|
||||
@@ -884,10 +890,9 @@ class ChatbotModule(BaseModule):
|
||||
|
||||
|
||||
# Module factory function
|
||||
def create_module(litellm_client: Optional[LiteLLMClientProtocol] = None,
|
||||
rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule:
|
||||
def create_module(rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule:
|
||||
"""Factory function to create chatbot module instance"""
|
||||
return ChatbotModule(litellm_client=litellm_client, rag_service=rag_service)
|
||||
return ChatbotModule(rag_service=rag_service)
|
||||
|
||||
# Create module instance (dependencies will be injected via factory)
|
||||
chatbot_module = ChatbotModule()
|
||||
@@ -401,7 +401,7 @@ class RAGModule(BaseModule):
|
||||
"""Initialize embedding model"""
|
||||
from app.services.embedding_service import embedding_service
|
||||
|
||||
# Use privatemode-embeddings for LiteLLM integration
|
||||
# Use privatemode-embeddings for LLM service integration
|
||||
model_name = self.config.get("embedding_model", "privatemode-embeddings")
|
||||
embedding_service.model_name = model_name
|
||||
|
||||
|
||||
@@ -22,13 +22,16 @@ from fastapi import APIRouter, HTTPException, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from app.core.logging import get_logger
|
||||
from app.services.litellm_client import LiteLLMClient
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
|
||||
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError
|
||||
from app.services.base_module import Permission
|
||||
from app.db.database import SessionLocal
|
||||
from app.models.workflow import WorkflowDefinition as DBWorkflowDefinition, WorkflowExecution as DBWorkflowExecution
|
||||
|
||||
# Import protocols for type hints and dependency injection
|
||||
from ..protocols import ChatbotServiceProtocol, LiteLLMClientProtocol
|
||||
from ..protocols import ChatbotServiceProtocol
|
||||
# Note: LiteLLMClientProtocol replaced with direct LLM service usage
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -234,8 +237,7 @@ class WorkflowExecution(BaseModel):
|
||||
class WorkflowEngine:
|
||||
"""Core workflow execution engine"""
|
||||
|
||||
def __init__(self, litellm_client: LiteLLMClient, chatbot_service: Optional[ChatbotServiceProtocol] = None):
|
||||
self.litellm_client = litellm_client
|
||||
def __init__(self, chatbot_service: Optional[ChatbotServiceProtocol] = None):
|
||||
self.chatbot_service = chatbot_service
|
||||
self.executions: Dict[str, WorkflowExecution] = {}
|
||||
self.workflows: Dict[str, WorkflowDefinition] = {}
|
||||
@@ -343,15 +345,23 @@ class WorkflowEngine:
|
||||
# Template message content with context variables
|
||||
messages = self._template_messages(llm_step.messages, context.variables)
|
||||
|
||||
# Make LLM call
|
||||
response = await self.litellm_client.chat_completion(
|
||||
# Convert messages to LLM service format
|
||||
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
|
||||
|
||||
# Create LLM service request
|
||||
llm_request = LLMChatRequest(
|
||||
model=llm_step.model,
|
||||
messages=messages,
|
||||
**llm_step.parameters
|
||||
messages=llm_messages,
|
||||
user_id="workflow_user",
|
||||
api_key_id=0, # Workflow module uses internal service
|
||||
**{k: v for k, v in llm_step.parameters.items() if k in ['temperature', 'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty', 'stop']}
|
||||
)
|
||||
|
||||
# Make LLM call
|
||||
response = await llm_service.create_chat_completion(llm_request)
|
||||
|
||||
# Store result
|
||||
result = response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
result = response.choices[0].message.content if response.choices else ""
|
||||
context.variables[llm_step.output_variable] = result
|
||||
context.results[step.id] = result
|
||||
|
||||
@@ -631,16 +641,21 @@ class WorkflowEngine:
|
||||
|
||||
messages = [{"role": "user", "content": self._template_string(prompt, variables)}]
|
||||
|
||||
response = await self.litellm_client.create_chat_completion(
|
||||
# Convert to LLM service format
|
||||
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
|
||||
|
||||
llm_request = LLMChatRequest(
|
||||
model=step.model,
|
||||
messages=messages,
|
||||
messages=llm_messages,
|
||||
user_id="workflow_system",
|
||||
api_key_id="workflow",
|
||||
api_key_id=0,
|
||||
temperature=step.temperature,
|
||||
max_tokens=step.max_tokens
|
||||
)
|
||||
|
||||
return response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
response = await llm_service.create_chat_completion(llm_request)
|
||||
|
||||
return response.choices[0].message.content if response.choices else ""
|
||||
|
||||
async def _generate_brand_names(self, variables: Dict[str, Any], step: AIGenerationStep) -> List[Dict[str, str]]:
|
||||
"""Generate brand names for a specific category"""
|
||||
@@ -687,16 +702,21 @@ class WorkflowEngine:
|
||||
|
||||
messages = [{"role": "user", "content": self._template_string(prompt, variables)}]
|
||||
|
||||
response = await self.litellm_client.create_chat_completion(
|
||||
# Convert to LLM service format
|
||||
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
|
||||
|
||||
llm_request = LLMChatRequest(
|
||||
model=step.model,
|
||||
messages=messages,
|
||||
messages=llm_messages,
|
||||
user_id="workflow_system",
|
||||
api_key_id="workflow",
|
||||
api_key_id=0,
|
||||
temperature=step.temperature,
|
||||
max_tokens=step.max_tokens
|
||||
)
|
||||
|
||||
return response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
response = await llm_service.create_chat_completion(llm_request)
|
||||
|
||||
return response.choices[0].message.content if response.choices else ""
|
||||
|
||||
async def _generate_custom_prompt(self, variables: Dict[str, Any], step: AIGenerationStep) -> str:
|
||||
"""Generate content using custom prompt template"""
|
||||
@@ -705,16 +725,21 @@ class WorkflowEngine:
|
||||
|
||||
messages = [{"role": "user", "content": self._template_string(step.prompt_template, variables)}]
|
||||
|
||||
response = await self.litellm_client.create_chat_completion(
|
||||
# Convert to LLM service format
|
||||
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
|
||||
|
||||
llm_request = LLMChatRequest(
|
||||
model=step.model,
|
||||
messages=messages,
|
||||
messages=llm_messages,
|
||||
user_id="workflow_system",
|
||||
api_key_id="workflow",
|
||||
api_key_id=0,
|
||||
temperature=step.temperature,
|
||||
max_tokens=step.max_tokens
|
||||
)
|
||||
|
||||
return response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
response = await llm_service.create_chat_completion(llm_request)
|
||||
|
||||
return response.choices[0].message.content if response.choices else ""
|
||||
|
||||
async def _execute_aggregate_step(self, step: WorkflowStep, context: WorkflowContext):
|
||||
"""Execute aggregate step to combine multiple inputs"""
|
||||
|
||||
@@ -23,7 +23,8 @@ from app.core.config import settings
|
||||
from app.db.database import async_session_factory
|
||||
from app.models.user import User
|
||||
from app.models.chatbot import ChatbotInstance
|
||||
from app.services.litellm_client import LiteLLMClient
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
|
||||
from cryptography.fernet import Fernet
|
||||
import base64
|
||||
import os
|
||||
@@ -65,8 +66,8 @@ class ZammadModule(BaseModule):
|
||||
try:
|
||||
logger.info("Initializing Zammad module...")
|
||||
|
||||
# Initialize LLM client for chatbot integration
|
||||
self.llm_client = LiteLLMClient()
|
||||
# Initialize LLM service for chatbot integration
|
||||
# Note: llm_service is already a global singleton, no need to create instance
|
||||
|
||||
# Create HTTP session pool for Zammad API calls
|
||||
timeout = aiohttp.ClientTimeout(total=60, connect=10)
|
||||
@@ -597,19 +598,21 @@ class ZammadModule(BaseModule):
|
||||
}
|
||||
]
|
||||
|
||||
# Generate summary using LLM client
|
||||
response = await self.llm_client.create_chat_completion(
|
||||
messages=messages,
|
||||
# Generate summary using new LLM service
|
||||
chat_request = LLMChatRequest(
|
||||
model=await self._get_chatbot_model(config.chatbot_id),
|
||||
user_id=str(config.user_id),
|
||||
api_key_id=0, # Using 0 for module requests
|
||||
messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
|
||||
temperature=0.3,
|
||||
max_tokens=500
|
||||
max_tokens=500,
|
||||
user_id=str(config.user_id),
|
||||
api_key_id=0 # Using 0 for module requests
|
||||
)
|
||||
|
||||
# Extract content from LiteLLM response
|
||||
if "choices" in response and len(response["choices"]) > 0:
|
||||
return response["choices"][0]["message"]["content"].strip()
|
||||
response = await llm_service.create_chat_completion(chat_request)
|
||||
|
||||
# Extract content from new LLM service response
|
||||
if response.choices and len(response.choices) > 0:
|
||||
return response.choices[0].message.content.strip()
|
||||
|
||||
return "Unable to generate summary."
|
||||
|
||||
|
||||
@@ -1,132 +0,0 @@
"""
Test LLM API endpoints.
"""
import pytest
from httpx import AsyncClient
from unittest.mock import patch, AsyncMock


class TestLLMEndpoints:
    """Test LLM API endpoints."""

    @pytest.mark.asyncio
    async def test_chat_completion_success(self, client: AsyncClient):
        """Test successful chat completion."""
        # Mock the LiteLLM client response
        mock_response = {
            "choices": [
                {
                    "message": {
                        "content": "Hello! How can I help you today?",
                        "role": "assistant"
                    }
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 15,
                "total_tokens": 25
            }
        }

        with patch("app.services.litellm_client.LiteLLMClient.create_chat_completion") as mock_chat:
            mock_chat.return_value = mock_response

            response = await client.post(
                "/api/v1/llm/chat/completions",
                json={
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "user", "content": "Hello"}
                    ]
                },
                headers={"Authorization": "Bearer test-api-key"}
            )

            assert response.status_code == 200
            data = response.json()
            assert "choices" in data
            assert data["choices"][0]["message"]["content"] == "Hello! How can I help you today?"

    @pytest.mark.asyncio
    async def test_chat_completion_unauthorized(self, client: AsyncClient):
        """Test chat completion without API key."""
        response = await client.post(
            "/api/v1/llm/chat/completions",
            json={
                "model": "gpt-3.5-turbo",
                "messages": [
                    {"role": "user", "content": "Hello"}
                ]
            }
        )

        assert response.status_code == 401

    @pytest.mark.asyncio
    async def test_embeddings_success(self, client: AsyncClient):
        """Test successful embeddings generation."""
        mock_response = {
            "data": [
                {
                    "embedding": [0.1, 0.2, 0.3],
                    "index": 0
                }
            ],
            "usage": {
                "prompt_tokens": 5,
                "total_tokens": 5
            }
        }

        with patch("app.services.litellm_client.LiteLLMClient.create_embedding") as mock_embeddings:
            mock_embeddings.return_value = mock_response

            response = await client.post(
                "/api/v1/llm/embeddings",
                json={
                    "model": "text-embedding-ada-002",
                    "input": "Hello world"
                },
                headers={"Authorization": "Bearer test-api-key"}
            )

            assert response.status_code == 200
            data = response.json()
            assert "data" in data
            assert len(data["data"][0]["embedding"]) == 3

    @pytest.mark.asyncio
    async def test_budget_exceeded(self, client: AsyncClient):
        """Test budget exceeded scenario."""
        with patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_check:
            mock_check.side_effect = Exception("Budget exceeded")

            response = await client.post(
                "/api/v1/llm/chat/completions",
                json={
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "user", "content": "Hello"}
                    ]
                },
                headers={"Authorization": "Bearer test-api-key"}
            )

            assert response.status_code == 402  # Payment required

    @pytest.mark.asyncio
    async def test_model_validation(self, client: AsyncClient):
        """Test model validation."""
        response = await client.post(
            "/api/v1/llm/chat/completions",
            json={
                "model": "invalid-model",
                "messages": [
                    {"role": "user", "content": "Hello"}
                ]
            },
            headers={"Authorization": "Bearer test-api-key"}
        )

        assert response.status_code == 400
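The deleted suite above patched the dict-returning LiteLLM client; the new suites below patch app.services.llm.service.llm_service and hand back typed response models instead. A minimal sketch of that pattern, reusing only names that already appear in the integration tests below (the helper itself is illustrative and not part of the commit):

# Sketch only: build a typed mock response for the new service, mirroring the
# fields the integration tests below pass to ChatCompletionResponse.
from unittest.mock import patch

from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage


def make_mock_completion(content: str) -> ChatCompletionResponse:
    return ChatCompletionResponse(
        id="mock-1",
        object="chat.completion",
        created=0,
        model="privatemode-llama-3-70b",
        choices=[ChatChoice(index=0,
                            message=ChatMessage(role="assistant", content=content),
                            finish_reason="stop")],
        usage=Usage(prompt_tokens=1, completion_tokens=1, total_tokens=2),
    )


# Usage inside a test:
# with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
#     mock_chat.return_value = make_mock_completion("Hello!")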
496
backend/tests/integration/test_llm_service_integration.py
Normal file
@@ -0,0 +1,496 @@
|
||||
"""
|
||||
Integration tests for the new LLM service.
|
||||
Tests end-to-end functionality including provider integration, security, and performance.
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
from httpx import AsyncClient
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
import json
|
||||
|
||||
|
||||
class TestLLMServiceIntegration:
|
||||
"""Integration tests for LLM service."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_chat_flow(self, client: AsyncClient):
|
||||
"""Test complete chat completion flow with security and budget checks."""
|
||||
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
|
||||
|
||||
# Mock successful LLM service response
|
||||
mock_response = ChatCompletionResponse(
|
||||
id="test-completion-123",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model="privatemode-llama-3-70b",
|
||||
choices=[
|
||||
ChatChoice(
|
||||
index=0,
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content="Hello! I'm a TEE-protected AI assistant. How can I help you today?"
|
||||
),
|
||||
finish_reason="stop"
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=25,
|
||||
completion_tokens=15,
|
||||
total_tokens=40
|
||||
),
|
||||
security_analysis={
|
||||
"risk_score": 0.1,
|
||||
"threats_detected": [],
|
||||
"risk_level": "low",
|
||||
"analysis_time_ms": 12.5
|
||||
}
|
||||
)
|
||||
|
||||
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat, \
|
||||
patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget:
|
||||
|
||||
mock_chat.return_value = mock_response
|
||||
mock_budget.return_value = True # Budget check passes
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "privatemode-llama-3-70b",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, what are your capabilities?"}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 150
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
# Verify response structure
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Check standard OpenAI-compatible fields
|
||||
assert "id" in data
|
||||
assert "object" in data
|
||||
assert "created" in data
|
||||
assert "model" in data
|
||||
assert "choices" in data
|
||||
assert "usage" in data
|
||||
|
||||
# Check security integration
|
||||
assert "security_analysis" in data
|
||||
assert data["security_analysis"]["risk_level"] == "low"
|
||||
|
||||
# Verify content
|
||||
assert len(data["choices"]) == 1
|
||||
assert data["choices"][0]["message"]["role"] == "assistant"
|
||||
assert "TEE-protected" in data["choices"][0]["message"]["content"]
|
||||
|
||||
# Verify usage tracking
|
||||
assert data["usage"]["total_tokens"] == 40
|
||||
assert data["usage"]["prompt_tokens"] == 25
|
||||
assert data["usage"]["completion_tokens"] == 15
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_integration(self, client: AsyncClient):
|
||||
"""Test embedding generation with fallback handling."""
|
||||
from app.services.llm.models import EmbeddingResponse, EmbeddingData, Usage
|
||||
|
||||
# Create realistic 1024-dimensional embedding
|
||||
embedding_vector = [0.1 * i for i in range(1024)]
|
||||
|
||||
mock_response = EmbeddingResponse(
|
||||
object="list",
|
||||
data=[
|
||||
EmbeddingData(
|
||||
object="embedding",
|
||||
embedding=embedding_vector,
|
||||
index=0
|
||||
)
|
||||
],
|
||||
model="privatemode-embeddings",
|
||||
usage=Usage(
|
||||
prompt_tokens=8,
|
||||
total_tokens=8
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.llm.service.llm_service.create_embedding") as mock_embedding:
|
||||
mock_embedding.return_value = mock_response
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/embeddings",
|
||||
json={
|
||||
"model": "privatemode-embeddings",
|
||||
"input": "This is a test document for embedding generation."
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify embedding structure
|
||||
assert "object" in data
|
||||
assert "data" in data
|
||||
assert "usage" in data
|
||||
assert len(data["data"]) == 1
|
||||
assert len(data["data"][0]["embedding"]) == 1024
|
||||
assert data["data"][0]["index"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_health_integration(self, client: AsyncClient):
|
||||
"""Test provider health monitoring integration."""
|
||||
mock_status = {
|
||||
"privatemode": {
|
||||
"provider": "PrivateMode.ai",
|
||||
"status": "healthy",
|
||||
"latency_ms": 245.8,
|
||||
"success_rate": 0.987,
|
||||
"last_check": "2025-01-01T12:00:00Z",
|
||||
"error_message": None,
|
||||
"models_available": [
|
||||
"privatemode-llama-3-70b",
|
||||
"privatemode-claude-3-sonnet",
|
||||
"privatemode-gpt-4o",
|
||||
"privatemode-embeddings"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
with patch("app.services.llm.service.llm_service.get_provider_status") as mock_provider:
|
||||
mock_provider.return_value = mock_status
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/llm/providers/status",
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Check response structure
|
||||
assert "data" in data
|
||||
assert "privatemode" in data["data"]
|
||||
|
||||
provider_data = data["data"]["privatemode"]
|
||||
assert provider_data["status"] == "healthy"
|
||||
assert provider_data["latency_ms"] < 300 # Reasonable latency
|
||||
assert provider_data["success_rate"] > 0.95 # High success rate
|
||||
assert len(provider_data["models_available"]) >= 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_and_fallback(self, client: AsyncClient):
|
||||
"""Test error handling and fallback scenarios."""
|
||||
# Test provider unavailable scenario
|
||||
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
|
||||
mock_chat.side_effect = Exception("Provider temporarily unavailable")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "privatemode-llama-3-70b",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
# Should return error but not crash
|
||||
assert response.status_code in [500, 503] # Server error or service unavailable
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_security_threat_detection(self, client: AsyncClient):
|
||||
"""Test security threat detection integration."""
|
||||
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
|
||||
|
||||
# Mock response with security threat detected
|
||||
mock_response = ChatCompletionResponse(
|
||||
id="test-completion-security",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model="privatemode-llama-3-70b",
|
||||
choices=[
|
||||
ChatChoice(
|
||||
index=0,
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content="I cannot help with that request as it violates security policies."
|
||||
),
|
||||
finish_reason="stop"
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=15,
|
||||
completion_tokens=12,
|
||||
total_tokens=27
|
||||
),
|
||||
security_analysis={
|
||||
"risk_score": 0.8,
|
||||
"threats_detected": ["potential_malicious_code"],
|
||||
"risk_level": "high",
|
||||
"blocked": True,
|
||||
"analysis_time_ms": 45.2
|
||||
}
|
||||
)
|
||||
|
||||
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
|
||||
mock_chat.return_value = mock_response
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "privatemode-llama-3-70b",
|
||||
"messages": [
|
||||
{"role": "user", "content": "How to create malicious code?"}
|
||||
]
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200 # Request succeeds but content is filtered
|
||||
data = response.json()
|
||||
|
||||
# Verify security analysis
|
||||
assert "security_analysis" in data
|
||||
assert data["security_analysis"]["risk_level"] == "high"
|
||||
assert data["security_analysis"]["blocked"] is True
|
||||
assert "malicious" in data["security_analysis"]["threats_detected"][0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_characteristics(self, client: AsyncClient):
|
||||
"""Test performance characteristics of the LLM service."""
|
||||
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
|
||||
|
||||
# Mock fast response
|
||||
mock_response = ChatCompletionResponse(
|
||||
id="test-perf",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model="privatemode-llama-3-70b",
|
||||
choices=[
|
||||
ChatChoice(
|
||||
index=0,
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content="Quick response for performance testing."
|
||||
),
|
||||
finish_reason="stop"
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=8,
|
||||
total_tokens=18
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
|
||||
mock_chat.return_value = mock_response
|
||||
|
||||
# Measure response time
|
||||
start_time = time.time()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "privatemode-llama-3-70b",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Quick test"}
|
||||
]
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
response_time = time.time() - start_time
|
||||
|
||||
assert response.status_code == 200
|
||||
# API should respond quickly (mocked, so should be very fast)
|
||||
assert response_time < 1.0 # Less than 1 second for mocked response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_capabilities_detection(self, client: AsyncClient):
|
||||
"""Test model capabilities detection and reporting."""
|
||||
from app.services.llm.models import Model
|
||||
|
||||
mock_models = [
|
||||
Model(
|
||||
id="privatemode-llama-3-70b",
|
||||
object="model",
|
||||
created=1234567890,
|
||||
owned_by="PrivateMode.ai",
|
||||
provider="PrivateMode.ai",
|
||||
capabilities=["tee", "chat", "function_calling"],
|
||||
context_window=32768,
|
||||
max_output_tokens=4096,
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True
|
||||
),
|
||||
Model(
|
||||
id="privatemode-embeddings",
|
||||
object="model",
|
||||
created=1234567890,
|
||||
owned_by="PrivateMode.ai",
|
||||
provider="PrivateMode.ai",
|
||||
capabilities=["tee", "embeddings"],
|
||||
context_window=512,
|
||||
supports_streaming=False,
|
||||
supports_function_calling=False
|
||||
)
|
||||
]
|
||||
|
||||
with patch("app.services.llm.service.llm_service.get_models") as mock_models_call:
|
||||
mock_models_call.return_value = mock_models
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/llm/models",
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify model capabilities
|
||||
assert len(data["data"]) == 2
|
||||
|
||||
# Check chat model capabilities
|
||||
chat_model = next(m for m in data["data"] if m["id"] == "privatemode-llama-3-70b")
|
||||
assert "tee" in chat_model["capabilities"]
|
||||
assert "chat" in chat_model["capabilities"]
|
||||
assert chat_model["supports_streaming"] is True
|
||||
assert chat_model["supports_function_calling"] is True
|
||||
assert chat_model["context_window"] == 32768
|
||||
|
||||
# Check embedding model capabilities
|
||||
embed_model = next(m for m in data["data"] if m["id"] == "privatemode-embeddings")
|
||||
assert "tee" in embed_model["capabilities"]
|
||||
assert "embeddings" in embed_model["capabilities"]
|
||||
assert embed_model["supports_streaming"] is False
|
||||
assert embed_model["context_window"] == 512
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests(self, client: AsyncClient):
|
||||
"""Test handling of concurrent requests."""
|
||||
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
|
||||
|
||||
mock_response = ChatCompletionResponse(
|
||||
id="test-concurrent",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model="privatemode-llama-3-70b",
|
||||
choices=[
|
||||
ChatChoice(
|
||||
index=0,
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content="Concurrent response"
|
||||
),
|
||||
finish_reason="stop"
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=5,
|
||||
completion_tokens=3,
|
||||
total_tokens=8
|
||||
)
|
||||
)
|
||||
|
||||
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
|
||||
mock_chat.return_value = mock_response
|
||||
|
||||
# Create multiple concurrent requests
|
||||
tasks = []
|
||||
for i in range(5):
|
||||
task = client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "privatemode-llama-3-70b",
|
||||
"messages": [
|
||||
{"role": "user", "content": f"Concurrent test {i}"}
|
||||
]
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all requests concurrently
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
# Verify all requests succeeded
|
||||
for response in responses:
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "choices" in data
|
||||
assert data["choices"][0]["message"]["content"] == "Concurrent response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_budget_enforcement_integration(self, client: AsyncClient):
|
||||
"""Test budget enforcement integration with LLM service."""
|
||||
# Test budget exceeded scenario
|
||||
with patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget:
|
||||
mock_budget.side_effect = Exception("Monthly budget limit exceeded")
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "privatemode-llama-3-70b",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Test budget enforcement"}
|
||||
]
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
assert response.status_code == 402 # Payment required
|
||||
|
||||
# Test budget warning scenario
|
||||
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
|
||||
|
||||
mock_response = ChatCompletionResponse(
|
||||
id="test-budget-warning",
|
||||
object="chat.completion",
|
||||
created=int(time.time()),
|
||||
model="privatemode-llama-3-70b",
|
||||
choices=[
|
||||
ChatChoice(
|
||||
index=0,
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content="Response with budget warning"
|
||||
),
|
||||
finish_reason="stop"
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=8,
|
||||
total_tokens=18
|
||||
),
|
||||
budget_warnings=["Approaching monthly budget limit (85% used)"]
|
||||
)
|
||||
|
||||
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat, \
|
||||
patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget:
|
||||
|
||||
mock_chat.return_value = mock_response
|
||||
mock_budget.return_value = True # Budget check passes but with warning
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "privatemode-llama-3-70b",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Test budget warning"}
|
||||
]
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "budget_warnings" in data
|
||||
assert len(data["budget_warnings"]) > 0
|
||||
assert "85%" in data["budget_warnings"][0]
|
||||
296
backend/tests/integration/test_llm_validation.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
Simple validation tests for the new LLM service integration.
|
||||
Tests basic functionality without complex mocking.
|
||||
"""
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestLLMServiceValidation:
|
||||
"""Basic validation tests for LLM service integration."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_models_endpoint_exists(self, client: AsyncClient):
|
||||
"""Test that the LLM models endpoint exists and is accessible."""
|
||||
response = await client.get(
|
||||
"/api/v1/llm/models",
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
# Should not return 404 (endpoint exists)
|
||||
assert response.status_code != 404
|
||||
# May return 500 or other error due to missing LLM service, but endpoint exists
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_chat_endpoint_exists(self, client: AsyncClient):
|
||||
"""Test that the LLM chat endpoint exists and is accessible."""
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
# Should not return 404 (endpoint exists)
|
||||
assert response.status_code != 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_embeddings_endpoint_exists(self, client: AsyncClient):
|
||||
"""Test that the LLM embeddings endpoint exists and is accessible."""
|
||||
response = await client.post(
|
||||
"/api/v1/llm/embeddings",
|
||||
json={
|
||||
"model": "test-embedding-model",
|
||||
"input": "Test text"
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
# Should not return 404 (endpoint exists)
|
||||
assert response.status_code != 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_provider_status_endpoint_exists(self, client: AsyncClient):
|
||||
"""Test that the provider status endpoint exists and is accessible."""
|
||||
response = await client.get(
|
||||
"/api/v1/llm/providers/status",
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
# Should not return 404 (endpoint exists)
|
||||
assert response.status_code != 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_mocked_service(self, client: AsyncClient):
|
||||
"""Test chat completion with mocked LLM service."""
|
||||
from app.services.llm.models import ChatResponse, ChatChoice, ChatMessage, TokenUsage
|
||||
|
||||
# Mock successful response
|
||||
mock_response = ChatResponse(
|
||||
id="test-123",
|
||||
object="chat.completion",
|
||||
created=1234567890,
|
||||
model="privatemode-llama-3-70b",
|
||||
provider="PrivateMode.ai",
|
||||
choices=[
|
||||
ChatChoice(
|
||||
index=0,
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content="Hello! How can I help you?"
|
||||
),
|
||||
finish_reason="stop"
|
||||
)
|
||||
],
|
||||
usage=TokenUsage(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=8,
|
||||
total_tokens=18
|
||||
),
|
||||
security_check=True,
|
||||
risk_score=0.1,
|
||||
detected_patterns=[],
|
||||
latency_ms=250.5
|
||||
)
|
||||
|
||||
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
|
||||
mock_chat.return_value = mock_response
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "privatemode-llama-3-70b",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify basic response structure
|
||||
assert "id" in data
|
||||
assert "choices" in data
|
||||
assert len(data["choices"]) == 1
|
||||
assert data["choices"][0]["message"]["content"] == "Hello! How can I help you?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_with_mocked_service(self, client: AsyncClient):
|
||||
"""Test embedding generation with mocked LLM service."""
|
||||
from app.services.llm.models import EmbeddingResponse, EmbeddingData, TokenUsage
|
||||
|
||||
# Create a simple embedding vector
|
||||
embedding_vector = [0.1, 0.2, 0.3] * 341 + [0.1]  # 1024 dimensions (1023 + 1)
|
||||
|
||||
mock_response = EmbeddingResponse(
|
||||
object="list",
|
||||
data=[
|
||||
EmbeddingData(
|
||||
object="embedding",
|
||||
index=0,
|
||||
embedding=embedding_vector
|
||||
)
|
||||
],
|
||||
model="privatemode-embeddings",
|
||||
provider="PrivateMode.ai",
|
||||
usage=TokenUsage(
|
||||
prompt_tokens=5,
|
||||
completion_tokens=0,
|
||||
total_tokens=5
|
||||
),
|
||||
security_check=True,
|
||||
risk_score=0.0,
|
||||
latency_ms=150.0
|
||||
)
|
||||
|
||||
with patch("app.services.llm.service.llm_service.create_embedding") as mock_embedding:
|
||||
mock_embedding.return_value = mock_response
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/embeddings",
|
||||
json={
|
||||
"model": "privatemode-embeddings",
|
||||
"input": "Test text for embedding"
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify basic response structure
|
||||
assert "data" in data
|
||||
assert len(data["data"]) == 1
|
||||
assert len(data["data"][0]["embedding"]) == 1024
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_with_mocked_service(self, client: AsyncClient):
|
||||
"""Test models listing with mocked LLM service."""
|
||||
from app.services.llm.models import ModelInfo
|
||||
|
||||
mock_models = [
|
||||
ModelInfo(
|
||||
id="privatemode-llama-3-70b",
|
||||
object="model",
|
||||
created=1234567890,
|
||||
owned_by="PrivateMode.ai",
|
||||
provider="PrivateMode.ai",
|
||||
capabilities=["tee", "chat"],
|
||||
context_window=32768,
|
||||
supports_streaming=True
|
||||
),
|
||||
ModelInfo(
|
||||
id="privatemode-embeddings",
|
||||
object="model",
|
||||
created=1234567890,
|
||||
owned_by="PrivateMode.ai",
|
||||
provider="PrivateMode.ai",
|
||||
capabilities=["tee", "embeddings"],
|
||||
context_window=512,
|
||||
supports_streaming=False
|
||||
)
|
||||
]
|
||||
|
||||
with patch("app.services.llm.service.llm_service.get_models") as mock_models_call:
|
||||
mock_models_call.return_value = mock_models
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/llm/models",
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify basic response structure
|
||||
assert "data" in data
|
||||
assert len(data["data"]) == 2
|
||||
assert data["data"][0]["id"] == "privatemode-llama-3-70b"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_status_with_mocked_service(self, client: AsyncClient):
|
||||
"""Test provider status with mocked LLM service."""
|
||||
mock_status = {
|
||||
"privatemode": {
|
||||
"provider": "PrivateMode.ai",
|
||||
"status": "healthy",
|
||||
"latency_ms": 250.5,
|
||||
"success_rate": 0.98,
|
||||
"last_check": "2025-01-01T12:00:00Z",
|
||||
"models_available": ["privatemode-llama-3-70b", "privatemode-embeddings"]
|
||||
}
|
||||
}
|
||||
|
||||
with patch("app.services.llm.service.llm_service.get_provider_status") as mock_provider:
|
||||
mock_provider.return_value = mock_status
|
||||
|
||||
response = await client.get(
|
||||
"/api/v1/llm/providers/status",
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify basic response structure
|
||||
assert "data" in data
|
||||
assert "privatemode" in data["data"]
|
||||
assert data["data"]["privatemode"]["status"] == "healthy"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_access(self, client: AsyncClient):
|
||||
"""Test that unauthorized requests are properly rejected."""
|
||||
# Test without authorization header
|
||||
response = await client.get("/api/v1/llm/models")
|
||||
assert response.status_code == 401
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/llm/embeddings",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"input": "Hello"
|
||||
}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_request_data(self, client: AsyncClient):
|
||||
"""Test that invalid request data is properly handled."""
|
||||
# Test invalid JSON structure
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
# Missing required fields
|
||||
"model": "test-model"
|
||||
# messages field is missing
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
assert response.status_code == 422 # Unprocessable Entity
|
||||
|
||||
# Test empty messages
|
||||
response = await client.post(
|
||||
"/api/v1/llm/chat/completions",
|
||||
json={
|
||||
"model": "test-model",
|
||||
"messages": [] # Empty messages
|
||||
},
|
||||
headers={"Authorization": "Bearer test-api-key"}
|
||||
)
|
||||
assert response.status_code == 422 # Unprocessable Entity
|
||||
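Every backend test above takes a `client: AsyncClient` fixture that is not part of this diff. A minimal sketch of the kind of fixture those tests assume, using httpx's in-process ASGI transport; the `app.main.app` import path is a guess and the project's real conftest.py may differ:

# Sketch only — the real conftest.py is not shown in this commit.
import pytest_asyncio
from httpx import ASGITransport, AsyncClient

from app.main import app  # assumed application import path


@pytest_asyncio.fixture
async def client():
    # In-process ASGI client, so no server needs to be running during tests.
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
        yield c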
@@ -6,7 +6,7 @@ import { Badge } from '@/components/ui/badge'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
|
||||
import { Alert, AlertDescription } from '@/components/ui/alert'
|
||||
import { RefreshCw, Zap, Info, AlertCircle } from 'lucide-react'
|
||||
import { RefreshCw, Zap, Info, AlertCircle, CheckCircle, XCircle, Clock } from 'lucide-react'
|
||||
|
||||
interface Model {
|
||||
id: string
|
||||
@@ -16,6 +16,22 @@ interface Model {
|
||||
permission?: any[]
|
||||
root?: string
|
||||
parent?: string
|
||||
provider?: string
|
||||
capabilities?: string[]
|
||||
context_window?: number
|
||||
max_output_tokens?: number
|
||||
supports_streaming?: boolean
|
||||
supports_function_calling?: boolean
|
||||
}
|
||||
|
||||
interface ProviderStatus {
|
||||
provider: string
|
||||
status: 'healthy' | 'degraded' | 'unavailable'
|
||||
latency_ms?: number
|
||||
success_rate?: number
|
||||
last_check: string
|
||||
error_message?: string
|
||||
models_available: string[]
|
||||
}
|
||||
|
||||
interface ModelSelectorProps {
|
||||
@@ -27,6 +43,7 @@ interface ModelSelectorProps {
|
||||
|
||||
export default function ModelSelector({ value, onValueChange, filter = 'all', className }: ModelSelectorProps) {
|
||||
const [models, setModels] = useState<Model[]>([])
|
||||
const [providerStatus, setProviderStatus] = useState<Record<string, ProviderStatus>>({})
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
const [showDetails, setShowDetails] = useState(false)
|
||||
@@ -37,20 +54,31 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
||||
|
||||
// Get the auth token from localStorage
|
||||
const token = localStorage.getItem('token')
|
||||
const headers = {
|
||||
'Authorization': token ? `Bearer ${token}` : '',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
const response = await fetch('/api/llm/models', {
|
||||
headers: {
|
||||
'Authorization': token ? `Bearer ${token}` : '',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
// Fetch models and provider status in parallel
|
||||
const [modelsResponse, statusResponse] = await Promise.allSettled([
|
||||
fetch('/api/llm/models', { headers }),
|
||||
fetch('/api/llm/providers/status', { headers })
|
||||
])
|
||||
|
||||
// Handle models response
|
||||
if (modelsResponse.status === 'fulfilled' && modelsResponse.value.ok) {
|
||||
const modelsData = await modelsResponse.value.json()
|
||||
setModels(modelsData.data || [])
|
||||
} else {
|
||||
throw new Error('Failed to fetch models')
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
setModels(data.data || [])
|
||||
|
||||
// Handle provider status response (optional)
|
||||
if (statusResponse.status === 'fulfilled' && statusResponse.value.ok) {
|
||||
const statusData = await statusResponse.value.json()
|
||||
setProviderStatus(statusData.data || {})
|
||||
}
|
||||
|
||||
setError(null)
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to load models')
|
||||
@@ -64,30 +92,39 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
||||
}, [])
|
||||
|
||||
const getProviderFromModel = (modelId: string): string => {
|
||||
// PrivateMode models have specific prefixes
|
||||
if (modelId.startsWith('privatemode-')) return 'PrivateMode.ai'
|
||||
|
||||
// Legacy detection for other providers
|
||||
if (modelId.startsWith('gpt-') || modelId.includes('openai')) return 'OpenAI'
|
||||
if (modelId.startsWith('claude-') || modelId.includes('anthropic')) return 'Anthropic'
|
||||
if (modelId.startsWith('gemini-') || modelId.includes('google')) return 'Google'
|
||||
if (modelId.includes('privatemode')) return 'Privatemode.ai'
|
||||
if (modelId.includes('cohere')) return 'Cohere'
|
||||
if (modelId.includes('mistral')) return 'Mistral'
|
||||
if (modelId.includes('llama')) return 'Meta'
|
||||
if (modelId.includes('llama') && !modelId.startsWith('privatemode-')) return 'Meta'
|
||||
return 'Unknown'
|
||||
}
|
||||
|
||||
const getModelType = (modelId: string): 'chat' | 'embedding' | 'other' => {
|
||||
if (modelId.includes('embedding')) return 'embedding'
|
||||
if (modelId.includes('whisper')) return 'other' // Audio transcription models
|
||||
if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding'
|
||||
if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models
|
||||
|
||||
// PrivateMode and other chat models
|
||||
if (
|
||||
modelId.startsWith('privatemode-llama') ||
|
||||
modelId.startsWith('privatemode-claude') ||
|
||||
modelId.startsWith('privatemode-gpt') ||
|
||||
modelId.startsWith('privatemode-gemini') ||
|
||||
modelId.includes('text-') ||
|
||||
modelId.includes('gpt-') ||
|
||||
modelId.includes('claude-') ||
|
||||
modelId.includes('gemini-') ||
|
||||
modelId.includes('privatemode-') ||
|
||||
modelId.includes('llama') ||
|
||||
modelId.includes('gemma') ||
|
||||
modelId.includes('qwen') ||
|
||||
modelId.includes('latest')
|
||||
) return 'chat'
|
||||
|
||||
return 'other'
|
||||
}
|
||||
|
||||
@@ -112,6 +149,28 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
||||
acc[provider].push(model)
|
||||
return acc
|
||||
}, {} as Record<string, Model[]>)
|
||||
|
||||
const getProviderStatusIcon = (provider: string) => {
|
||||
const status = providerStatus[provider.toLowerCase()]?.status || 'unknown'
|
||||
switch (status) {
|
||||
case 'healthy':
|
||||
return <CheckCircle className="h-3 w-3 text-green-500" />
|
||||
case 'degraded':
|
||||
return <Clock className="h-3 w-3 text-yellow-500" />
|
||||
case 'unavailable':
|
||||
return <XCircle className="h-3 w-3 text-red-500" />
|
||||
default:
|
||||
return <AlertCircle className="h-3 w-3 text-gray-400" />
|
||||
}
|
||||
}
|
||||
|
||||
const getProviderStatusText = (provider: string) => {
|
||||
const status = providerStatus[provider.toLowerCase()]
|
||||
if (!status) return 'Status unknown'
|
||||
|
||||
const latencyText = status.latency_ms ? ` (${Math.round(status.latency_ms)}ms)` : ''
|
||||
return `${status.status.charAt(0).toUpperCase() + status.status.slice(1)}${latencyText}`
|
||||
}
|
||||
|
||||
const selectedModel = models.find(m => m.id === value)
|
||||
|
||||
@@ -191,16 +250,32 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
||||
<SelectContent>
|
||||
{Object.entries(groupedModels).map(([provider, providerModels]) => (
|
||||
<div key={provider}>
|
||||
<div className="px-2 py-1.5 text-sm font-semibold text-muted-foreground">
|
||||
{provider}
|
||||
<div className="px-2 py-1.5 text-sm font-semibold text-muted-foreground flex items-center gap-2">
|
||||
{getProviderStatusIcon(provider)}
|
||||
<span>{provider}</span>
|
||||
<span className="text-xs font-normal text-muted-foreground">
|
||||
{getProviderStatusText(provider)}
|
||||
</span>
|
||||
</div>
|
||||
{providerModels.map((model) => (
|
||||
<SelectItem key={model.id} value={model.id}>
|
||||
<div className="flex items-center gap-2">
|
||||
<span>{model.id}</span>
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{getModelCategory(model.id)}
|
||||
</Badge>
|
||||
<div className="flex gap-1">
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{getModelCategory(model.id)}
|
||||
</Badge>
|
||||
{model.supports_streaming && (
|
||||
<Badge variant="secondary" className="text-xs">
|
||||
Streaming
|
||||
</Badge>
|
||||
)}
|
||||
{model.supports_function_calling && (
|
||||
<Badge variant="secondary" className="text-xs">
|
||||
Functions
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
@@ -217,7 +292,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
||||
Model Details
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-2 text-sm">
|
||||
<CardContent className="space-y-3 text-sm">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div>
|
||||
<span className="font-medium">ID:</span>
|
||||
@@ -225,7 +300,10 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
||||
</div>
|
||||
<div>
|
||||
<span className="font-medium">Provider:</span>
|
||||
<div className="text-muted-foreground">{getProviderFromModel(selectedModel.id)}</div>
|
||||
<div className="text-muted-foreground flex items-center gap-1">
|
||||
{getProviderStatusIcon(getProviderFromModel(selectedModel.id))}
|
||||
{getProviderFromModel(selectedModel.id)}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<span className="font-medium">Type:</span>
|
||||
@@ -237,6 +315,40 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{(selectedModel.context_window || selectedModel.max_output_tokens) && (
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
{selectedModel.context_window && (
|
||||
<div>
|
||||
<span className="font-medium">Context Window:</span>
|
||||
<div className="text-muted-foreground">{selectedModel.context_window.toLocaleString()} tokens</div>
|
||||
</div>
|
||||
)}
|
||||
{selectedModel.max_output_tokens && (
|
||||
<div>
|
||||
<span className="font-medium">Max Output:</span>
|
||||
<div className="text-muted-foreground">{selectedModel.max_output_tokens.toLocaleString()} tokens</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{(selectedModel.supports_streaming || selectedModel.supports_function_calling) && (
|
||||
<div>
|
||||
<span className="font-medium">Capabilities:</span>
|
||||
<div className="flex gap-1 mt-1">
|
||||
{selectedModel.supports_streaming && (
|
||||
<Badge variant="secondary" className="text-xs">Streaming</Badge>
|
||||
)}
|
||||
{selectedModel.supports_function_calling && (
|
||||
<Badge variant="secondary" className="text-xs">Function Calling</Badge>
|
||||
)}
|
||||
{selectedModel.capabilities?.includes('tee') && (
|
||||
<Badge variant="outline" className="text-xs border-green-500 text-green-700">TEE Protected</Badge>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedModel.created && (
|
||||
<div>
|
||||
<span className="font-medium">Created:</span>
|
||||
@@ -252,6 +364,46 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
||||
<div className="text-muted-foreground">{selectedModel.owned_by}</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Provider Status Details */}
|
||||
{providerStatus[getProviderFromModel(selectedModel.id).toLowerCase()] && (
|
||||
<div className="border-t pt-3">
|
||||
<span className="font-medium">Provider Status:</span>
|
||||
<div className="mt-1 text-xs space-y-1">
|
||||
{(() => {
|
||||
const status = providerStatus[getProviderFromModel(selectedModel.id).toLowerCase()]
|
||||
return (
|
||||
<>
|
||||
<div className="flex justify-between">
|
||||
<span>Status:</span>
|
||||
<span className={`font-medium ${
|
||||
status.status === 'healthy' ? 'text-green-600' :
|
||||
status.status === 'degraded' ? 'text-yellow-600' :
|
||||
'text-red-600'
|
||||
}`}>{status.status}</span>
|
||||
</div>
|
||||
{status.latency_ms && (
|
||||
<div className="flex justify-between">
|
||||
<span>Latency:</span>
|
||||
<span>{Math.round(status.latency_ms)}ms</span>
|
||||
</div>
|
||||
)}
|
||||
{status.success_rate && (
|
||||
<div className="flex justify-between">
|
||||
<span>Success Rate:</span>
|
||||
<span>{Math.round(status.success_rate * 100)}%</span>
|
||||
</div>
|
||||
)}
|
||||
<div className="flex justify-between">
|
||||
<span>Last Check:</span>
|
||||
<span>{new Date(status.last_check).toLocaleTimeString()}</span>
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
})()}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
373
frontend/src/components/playground/ProviderHealthDashboard.tsx
Normal file
@@ -0,0 +1,373 @@
|
||||
"use client"
|
||||
|
||||
import { useState, useEffect } from 'react'
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
|
||||
import { Badge } from '@/components/ui/badge'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Alert, AlertDescription } from '@/components/ui/alert'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import {
|
||||
CheckCircle,
|
||||
XCircle,
|
||||
Clock,
|
||||
AlertCircle,
|
||||
RefreshCw,
|
||||
Activity,
|
||||
Zap,
|
||||
Shield,
|
||||
Server
|
||||
} from 'lucide-react'
|
||||
|
||||
interface ProviderStatus {
|
||||
provider: string
|
||||
status: 'healthy' | 'degraded' | 'unavailable'
|
||||
latency_ms?: number
|
||||
success_rate?: number
|
||||
last_check: string
|
||||
error_message?: string
|
||||
models_available: string[]
|
||||
}
|
||||
|
||||
interface LLMMetrics {
|
||||
total_requests: number
|
||||
successful_requests: number
|
||||
failed_requests: number
|
||||
security_blocked_requests: number
|
||||
average_latency_ms: number
|
||||
average_risk_score: number
|
||||
provider_metrics: Record<string, any>
|
||||
last_updated: string
|
||||
}
|
||||
|
||||
export default function ProviderHealthDashboard() {
|
||||
const [providers, setProviders] = useState<Record<string, ProviderStatus>>({})
|
||||
const [metrics, setMetrics] = useState<LLMMetrics | null>(null)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
|
||||
const fetchData = async () => {
|
||||
try {
|
||||
setLoading(true)
|
||||
const token = localStorage.getItem('token')
|
||||
const headers = {
|
||||
'Authorization': token ? `Bearer ${token}` : '',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
const [statusResponse, metricsResponse] = await Promise.allSettled([
|
||||
fetch('/api/llm/providers/status', { headers }),
|
||||
fetch('/api/llm/metrics', { headers })
|
||||
])
|
||||
|
||||
// Handle provider status
|
||||
if (statusResponse.status === 'fulfilled' && statusResponse.value.ok) {
|
||||
const statusData = await statusResponse.value.json()
|
||||
setProviders(statusData.data || {})
|
||||
}
|
||||
|
||||
// Handle metrics (optional, might require admin permissions)
|
||||
if (metricsResponse.status === 'fulfilled' && metricsResponse.value.ok) {
|
||||
const metricsData = await metricsResponse.value.json()
|
||||
setMetrics(metricsData.data)
|
||||
}
|
||||
|
||||
setError(null)
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to load provider data')
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
fetchData()
|
||||
|
||||
// Auto-refresh every 30 seconds
|
||||
const interval = setInterval(fetchData, 30000)
|
||||
return () => clearInterval(interval)
|
||||
}, [])
|
||||
|
||||
const getStatusIcon = (status: string) => {
|
||||
switch (status) {
|
||||
case 'healthy':
|
||||
return <CheckCircle className="h-5 w-5 text-green-500" />
|
||||
case 'degraded':
|
||||
return <Clock className="h-5 w-5 text-yellow-500" />
|
||||
case 'unavailable':
|
||||
return <XCircle className="h-5 w-5 text-red-500" />
|
||||
default:
|
||||
return <AlertCircle className="h-5 w-5 text-gray-400" />
|
||||
}
|
||||
}
|
||||
|
||||
const getStatusColor = (status: string) => {
|
||||
switch (status) {
|
||||
case 'healthy':
|
||||
return 'text-green-600 bg-green-50 border-green-200'
|
||||
case 'degraded':
|
||||
return 'text-yellow-600 bg-yellow-50 border-yellow-200'
|
||||
case 'unavailable':
|
||||
return 'text-red-600 bg-red-50 border-red-200'
|
||||
default:
|
||||
return 'text-gray-600 bg-gray-50 border-gray-200'
|
||||
}
|
||||
}
|
||||
|
||||
const getLatencyColor = (latency: number) => {
|
||||
if (latency < 500) return 'text-green-600'
|
||||
if (latency < 2000) return 'text-yellow-600'
|
||||
return 'text-red-600'
|
||||
}
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h2 className="text-2xl font-bold">Provider Health Dashboard</h2>
|
||||
<RefreshCw className="h-5 w-5 animate-spin" />
|
||||
</div>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{[1, 2, 3].map(i => (
|
||||
<Card key={i} className="animate-pulse">
|
||||
<CardHeader className="space-y-2">
|
||||
<div className="h-4 bg-gray-200 rounded w-3/4"></div>
|
||||
<div className="h-3 bg-gray-200 rounded w-1/2"></div>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="space-y-2">
|
||||
<div className="h-3 bg-gray-200 rounded"></div>
|
||||
<div className="h-3 bg-gray-200 rounded w-2/3"></div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h2 className="text-2xl font-bold">Provider Health Dashboard</h2>
|
||||
<Button onClick={fetchData} size="sm">
|
||||
<RefreshCw className="h-4 w-4 mr-2" />
|
||||
Retry
|
||||
</Button>
|
||||
</div>
|
||||
<Alert variant="destructive">
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertDescription>{error}</AlertDescription>
|
||||
</Alert>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const totalProviders = Object.keys(providers).length
|
||||
const healthyProviders = Object.values(providers).filter(p => p.status === 'healthy').length
|
||||
const overallHealth = totalProviders > 0 ? (healthyProviders / totalProviders) * 100 : 0
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="flex items-center justify-between">
|
||||
<h2 className="text-2xl font-bold">Provider Health Dashboard</h2>
|
||||
<Button onClick={fetchData} size="sm" disabled={loading}>
|
||||
<RefreshCw className={`h-4 w-4 mr-2 ${loading ? 'animate-spin' : ''}`} />
|
||||
Refresh
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Overall Health Summary */}
|
||||
<div className="grid grid-cols-1 md:grid-cols-4 gap-4">
|
||||
<Card>
|
||||
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
|
||||
<CardTitle className="text-sm font-medium">Overall Health</CardTitle>
|
||||
<Activity className="h-4 w-4 text-muted-foreground" />
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="text-2xl font-bold">{Math.round(overallHealth)}%</div>
|
||||
<Progress value={overallHealth} className="mt-2" />
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<Card>
|
||||
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
|
||||
<CardTitle className="text-sm font-medium">Healthy Providers</CardTitle>
|
||||
<CheckCircle className="h-4 w-4 text-green-500" />
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="text-2xl font-bold">{healthyProviders}</div>
|
||||
<p className="text-xs text-muted-foreground">of {totalProviders} providers</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{metrics && (
|
||||
<>
|
||||
<Card>
|
||||
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
|
||||
<CardTitle className="text-sm font-medium">Success Rate</CardTitle>
|
||||
<Zap className="h-4 w-4 text-muted-foreground" />
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="text-2xl font-bold">
|
||||
{metrics.total_requests > 0
|
||||
? Math.round((metrics.successful_requests / metrics.total_requests) * 100)
|
||||
: 0}%
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{metrics.successful_requests.toLocaleString()} / {metrics.total_requests.toLocaleString()} requests
|
||||
</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<Card>
|
||||
<CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
|
||||
<CardTitle className="text-sm font-medium">Security Score</CardTitle>
|
||||
<Shield className="h-4 w-4 text-muted-foreground" />
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="text-2xl font-bold">
|
||||
{Math.round((1 - metrics.average_risk_score) * 100)}%
|
||||
</div>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{metrics.security_blocked_requests} blocked requests
|
||||
</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Provider Details */}
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{Object.entries(providers).map(([name, provider]) => (
|
||||
<Card key={name} className={`border-2 ${getStatusColor(provider.status)}`}>
|
||||
<CardHeader>
|
||||
<div className="flex items-center justify-between">
|
||||
<CardTitle className="text-lg flex items-center gap-2">
|
||||
{getStatusIcon(provider.status)}
|
||||
{provider.provider}
|
||||
</CardTitle>
|
||||
<Badge
|
||||
variant={provider.status === 'healthy' ? 'default' : 'destructive'}
|
||||
className="capitalize"
|
||||
>
|
||||
{provider.status}
|
||||
</Badge>
|
||||
</div>
|
||||
<CardDescription>
|
||||
{provider.models_available.length} models available
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
|
||||
<CardContent className="space-y-3">
|
||||
{/* Performance Metrics */}
|
||||
{provider.latency_ms && (
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium">Latency</span>
|
||||
<span className={`text-sm font-mono ${getLatencyColor(provider.latency_ms)}`}>
|
||||
{Math.round(provider.latency_ms)}ms
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{provider.success_rate !== undefined && (
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium">Success Rate</span>
|
||||
<span className="text-sm font-mono">
|
||||
{Math.round(provider.success_rate * 100)}%
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex justify-between items-center">
|
||||
<span className="text-sm font-medium">Last Check</span>
|
||||
<span className="text-sm text-muted-foreground">
|
||||
{new Date(provider.last_check).toLocaleTimeString()}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Error Message */}
|
||||
{provider.error_message && (
|
||||
<Alert variant="destructive" className="mt-3">
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertDescription className="text-xs">
|
||||
{provider.error_message}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{/* Models */}
|
||||
<div className="space-y-2">
|
||||
<span className="text-sm font-medium">Available Models</span>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{provider.models_available.slice(0, 3).map(model => (
|
||||
<Badge key={model} variant="outline" className="text-xs">
|
||||
{model}
|
||||
</Badge>
|
||||
))}
|
||||
{provider.models_available.length > 3 && (
|
||||
<Badge variant="outline" className="text-xs">
|
||||
+{provider.models_available.length - 3} more
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Provider Metrics Details */}
|
||||
{metrics && Object.keys(metrics.provider_metrics).length > 0 && (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2">
|
||||
<Server className="h-5 w-5" />
|
||||
Provider Performance Metrics
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
Detailed performance statistics for each provider
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{Object.entries(metrics.provider_metrics).map(([provider, data]: [string, any]) => (
|
||||
<div key={provider} className="border rounded-lg p-4">
|
||||
<h4 className="font-semibold mb-3 capitalize">{provider}</h4>
|
||||
<div className="space-y-2 text-sm">
|
||||
<div className="flex justify-between">
|
||||
<span>Total Requests:</span>
|
||||
<span className="font-mono">{data.total_requests?.toLocaleString() || 0}</span>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span>Success Rate:</span>
|
||||
<span className="font-mono">
|
||||
{data.success_rate ? Math.round(data.success_rate * 100) : 0}%
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span>Avg Latency:</span>
|
||||
<span className={`font-mono ${getLatencyColor(data.average_latency_ms || 0)}`}>
|
||||
{Math.round(data.average_latency_ms || 0)}ms
|
||||
</span>
|
||||
</div>
|
||||
{data.token_usage && (
|
||||
<div className="flex justify-between">
|
||||
<span>Total Tokens:</span>
|
||||
<span className="font-mono">
|
||||
{data.token_usage.total_tokens?.toLocaleString() || 0}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||