Mirror of https://github.com/aljazceru/enclava.git (synced 2025-12-17 07:24:34 +01:00)
Removing LiteLLM and going directly to Privatemode
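At a glance, this commit swaps the HTTP-based LiteLLM proxy client for an in-process LLM service backed by the Privatemode proxy. A minimal sketch of the new call pattern, assuming the `llm_service` and request/response models shown in the diff below (the model id is hypothetical, used only for illustration):

# Sketch of the new call path introduced by this commit; names are taken from the diff below.
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest, ChatMessage

async def example_completion() -> str:
    # Old path: await litellm_client.create_chat_completion(model=..., messages=[{"role": ..., "content": ...}], ...)
    # New path: build a typed ChatRequest and call the LLM service directly.
    request = ChatRequest(
        model="privatemode-llm",  # hypothetical model id for illustration
        messages=[ChatMessage(role="user", content="Hello")],
        user_id="example_user",
        api_key_id=0,  # internal/system usage, as the modules in this diff do
    )
    response = await llm_service.create_chat_completion(request)
    return response.choices[0].message.content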
@@ -1,5 +1,5 @@
 """
-LLM API endpoints - proxy to LiteLLM service with authentication and budget enforcement
+LLM API endpoints - interface to secure LLM service with authentication and budget enforcement
 """

 import logging
@@ -16,7 +16,9 @@ from app.services.api_key_auth import require_api_key, RequireScope, APIKeyAuthS
 from app.core.security import get_current_user
 from app.models.user import User
 from app.core.config import settings
-from app.services.litellm_client import litellm_client
+from app.services.llm.service import llm_service
+from app.services.llm.models import ChatRequest, ChatMessage as LLMChatMessage, EmbeddingRequest as LLMEmbeddingRequest
+from app.services.llm.exceptions import LLMError, ProviderError, SecurityError, ValidationError
 from app.services.budget_enforcement import (
     check_budget_for_request, record_request_usage, BudgetEnforcementService,
     atomic_check_and_reserve_budget, atomic_finalize_usage
@@ -38,7 +40,7 @@ router = APIRouter()


 async def get_cached_models() -> List[Dict[str, Any]]:
-    """Get models from cache or fetch from LiteLLM if cache is stale"""
+    """Get models from cache or fetch from LLM service if cache is stale"""
     current_time = time.time()

     # Check if cache is still valid
@@ -47,10 +49,20 @@ async def get_cached_models() -> List[Dict[str, Any]]:
         logger.debug("Returning cached models list")
         return _models_cache["data"]

-    # Cache miss or stale - fetch from LiteLLM
+    # Cache miss or stale - fetch from LLM service
     try:
-        logger.debug("Fetching fresh models list from LiteLLM")
-        models = await litellm_client.get_models()
+        logger.debug("Fetching fresh models list from LLM service")
+        model_infos = await llm_service.get_models()

+        # Convert ModelInfo objects to dict format for compatibility
+        models = []
+        for model_info in model_infos:
+            models.append({
+                "id": model_info.id,
+                "object": model_info.object,
+                "created": model_info.created or int(time.time()),
+                "owned_by": model_info.owned_by
+            })

         # Update cache
         _models_cache["data"] = models
@@ -58,7 +70,7 @@ async def get_cached_models() -> List[Dict[str, Any]]:

         return models
     except Exception as e:
-        logger.error(f"Failed to fetch models from LiteLLM: {e}")
+        logger.error(f"Failed to fetch models from LLM service: {e}")

         # Return stale cache if available, otherwise empty list
         if _models_cache["data"] is not None:
@@ -75,7 +87,7 @@ def invalidate_models_cache():
     logger.info("Models cache invalidated")


-# Request/Response Models
+# Request/Response Models (API layer)
 class ChatMessage(BaseModel):
     role: str = Field(..., description="Message role (system, user, assistant)")
     content: str = Field(..., description="Message content")
@@ -183,7 +195,7 @@ async def list_models(
             detail="Insufficient permissions to list models"
         )

-    # Get models from cache or LiteLLM
+    # Get models from cache or LLM service
     models = await get_cached_models()

     # Filter models based on API key permissions
@@ -309,35 +321,55 @@ async def create_chat_completion(
         warnings = budget_warnings
         reserved_budget_ids = budget_ids

-        # Convert messages to dict format
+        # Convert messages to LLM service format
-        messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]
+        llm_messages = [LLMChatMessage(role=msg.role, content=msg.content) for msg in chat_request.messages]

-        # Prepare additional parameters
+        # Create LLM service request
-        kwargs = {}
+        llm_request = ChatRequest(
-        if chat_request.max_tokens is not None:
-            kwargs["max_tokens"] = chat_request.max_tokens
-        if chat_request.temperature is not None:
-            kwargs["temperature"] = chat_request.temperature
-        if chat_request.top_p is not None:
-            kwargs["top_p"] = chat_request.top_p
-        if chat_request.frequency_penalty is not None:
-            kwargs["frequency_penalty"] = chat_request.frequency_penalty
-        if chat_request.presence_penalty is not None:
-            kwargs["presence_penalty"] = chat_request.presence_penalty
-        if chat_request.stop is not None:
-            kwargs["stop"] = chat_request.stop
-        if chat_request.stream is not None:
-            kwargs["stream"] = chat_request.stream

-        # Make request to LiteLLM
-        response = await litellm_client.create_chat_completion(
             model=chat_request.model,
-            messages=messages,
+            messages=llm_messages,
+            temperature=chat_request.temperature,
+            max_tokens=chat_request.max_tokens,
+            top_p=chat_request.top_p,
+            frequency_penalty=chat_request.frequency_penalty,
+            presence_penalty=chat_request.presence_penalty,
+            stop=chat_request.stop,
+            stream=chat_request.stream or False,
             user_id=str(context.get("user_id", "anonymous")),
-            api_key_id=context.get("api_key_id", "jwt_user"),
+            api_key_id=context.get("api_key_id", 0) if auth_type == "api_key" else 0
-            **kwargs
         )

+        # Make request to LLM service
+        llm_response = await llm_service.create_chat_completion(llm_request)

+        # Convert LLM service response to API format
+        response = {
+            "id": llm_response.id,
+            "object": llm_response.object,
+            "created": llm_response.created,
+            "model": llm_response.model,
+            "choices": [
+                {
+                    "index": choice.index,
+                    "message": {
+                        "role": choice.message.role,
+                        "content": choice.message.content
+                    },
+                    "finish_reason": choice.finish_reason
+                }
+                for choice in llm_response.choices
+            ],
+            "usage": {
+                "prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
+                "completion_tokens": llm_response.usage.completion_tokens if llm_response.usage else 0,
+                "total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
+            } if llm_response.usage else {
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0
+            }
+        }

         # Calculate actual cost and update usage
         usage = response.get("usage", {})
         input_tokens = usage.get("prompt_tokens", 0)
@@ -382,8 +414,38 @@ async def create_chat_completion(

     except HTTPException:
         raise
+    except SecurityError as e:
+        logger.warning(f"Security error in chat completion: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Security validation failed: {e.message}"
+        )
+    except ValidationError as e:
+        logger.warning(f"Validation error in chat completion: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Request validation failed: {e.message}"
+        )
+    except ProviderError as e:
+        logger.error(f"Provider error in chat completion: {e}")
+        if "rate limit" in str(e).lower():
+            raise HTTPException(
+                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+                detail="Rate limit exceeded"
+            )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="LLM service temporarily unavailable"
+            )
+    except LLMError as e:
+        logger.error(f"LLM service error in chat completion: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="LLM service error"
+        )
     except Exception as e:
-        logger.error(f"Error creating chat completion: {e}")
+        logger.error(f"Unexpected error creating chat completion: {e}")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="Failed to create chat completion"
@@ -438,15 +500,39 @@ async def create_embedding(
             detail=f"Budget exceeded: {error_message}"
         )

-        # Make request to LiteLLM
+        # Create LLM service request
-        response = await litellm_client.create_embedding(
+        llm_request = LLMEmbeddingRequest(
             model=request.model,
-            input_text=request.input,
+            input=request.input,
+            encoding_format=request.encoding_format,
             user_id=str(context["user_id"]),
-            api_key_id=context["api_key_id"],
+            api_key_id=context["api_key_id"]
-            encoding_format=request.encoding_format
         )

+        # Make request to LLM service
+        llm_response = await llm_service.create_embedding(llm_request)

+        # Convert LLM service response to API format
+        response = {
+            "object": llm_response.object,
+            "data": [
+                {
+                    "object": emb.object,
+                    "index": emb.index,
+                    "embedding": emb.embedding
+                }
+                for emb in llm_response.data
+            ],
+            "model": llm_response.model,
+            "usage": {
+                "prompt_tokens": llm_response.usage.prompt_tokens if llm_response.usage else 0,
+                "total_tokens": llm_response.usage.total_tokens if llm_response.usage else 0
+            } if llm_response.usage else {
+                "prompt_tokens": int(estimated_tokens),
+                "total_tokens": int(estimated_tokens)
+            }
+        }

         # Calculate actual cost and update usage
         usage = response.get("usage", {})
         total_tokens = usage.get("total_tokens", int(estimated_tokens))
@@ -475,8 +561,38 @@ async def create_embedding(

     except HTTPException:
         raise
+    except SecurityError as e:
+        logger.warning(f"Security error in embedding: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Security validation failed: {e.message}"
+        )
+    except ValidationError as e:
+        logger.warning(f"Validation error in embedding: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Request validation failed: {e.message}"
+        )
+    except ProviderError as e:
+        logger.error(f"Provider error in embedding: {e}")
+        if "rate limit" in str(e).lower():
+            raise HTTPException(
+                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
+                detail="Rate limit exceeded"
+            )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="LLM service temporarily unavailable"
+            )
+    except LLMError as e:
+        logger.error(f"LLM service error in embedding: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="LLM service error"
+        )
     except Exception as e:
-        logger.error(f"Error creating embedding: {e}")
+        logger.error(f"Unexpected error creating embedding: {e}")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="Failed to create embedding"
@@ -489,11 +605,28 @@ async def llm_health_check(
 ):
     """Health check for LLM service"""
     try:
-        health_status = await litellm_client.health_check()
+        health_summary = llm_service.get_health_summary()
+        provider_status = await llm_service.get_provider_status()

+        # Determine overall health
+        overall_status = "healthy"
+        if health_summary["service_status"] != "healthy":
+            overall_status = "degraded"

+        for provider, status in provider_status.items():
+            if status.status == "unavailable":
+                overall_status = "degraded"
+                break

         return {
-            "status": "healthy",
+            "status": overall_status,
-            "service": "LLM Proxy",
+            "service": "LLM Service",
-            "litellm_status": health_status,
+            "service_status": health_summary,
+            "provider_status": {name: {
+                "status": status.status,
+                "latency_ms": status.latency_ms,
+                "error_message": status.error_message
+            } for name, status in provider_status.items()},
             "user_id": context["user_id"],
             "api_key_name": context["api_key_name"]
         }
@@ -501,7 +634,7 @@ async def llm_health_check(
         logger.error(f"LLM health check error: {e}")
         return {
             "status": "unhealthy",
-            "service": "LLM Proxy",
+            "service": "LLM Service",
             "error": str(e)
         }

@@ -626,50 +759,83 @@ async def get_budget_status(
    )


-# Generic proxy endpoint for other LiteLLM endpoints
+# Generic endpoint for additional LLM service functionality
-@router.api_route("/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
+@router.get("/metrics")
-async def proxy_endpoint(
+async def get_llm_metrics(
-    endpoint: str,
-    request: Request,
     context: Dict[str, Any] = Depends(require_api_key),
     db: AsyncSession = Depends(get_db)
 ):
-    """Generic proxy endpoint for LiteLLM requests"""
+    """Get LLM service metrics (admin only)"""
     try:
+        # Check for admin permissions
         auth_service = APIKeyAuthService(db)
+        if not await auth_service.check_scope_permission(context, "admin.metrics"):
-        # Check endpoint permission
-        if not await auth_service.check_endpoint_permission(context, endpoint):
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=f"Endpoint '{endpoint}' not allowed"
+                detail="Admin permissions required to view metrics"
             )

-        # Get request body
+        metrics = llm_service.get_metrics()
-        if request.method in ["POST", "PUT", "PATCH"]:
+        return {
-            try:
+            "object": "llm_metrics",
-                payload = await request.json()
+            "data": {
-            except:
+                "total_requests": metrics.total_requests,
-                payload = {}
+                "successful_requests": metrics.successful_requests,
-        else:
+                "failed_requests": metrics.failed_requests,
-            payload = dict(request.query_params)
+                "security_blocked_requests": metrics.security_blocked_requests,
+                "average_latency_ms": metrics.average_latency_ms,
-        # Make request to LiteLLM
+                "average_risk_score": metrics.average_risk_score,
-        response = await litellm_client.proxy_request(
+                "provider_metrics": metrics.provider_metrics,
-            method=request.method,
+                "last_updated": metrics.last_updated.isoformat()
-            endpoint=endpoint,
+            }
-            payload=payload,
+        }
-            user_id=str(context["user_id"]),
-            api_key_id=context["api_key_id"]
-        )

-        return response

     except HTTPException:
         raise
     except Exception as e:
-        logger.error(f"Error proxying request to {endpoint}: {e}")
+        logger.error(f"Error getting LLM metrics: {e}")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail="Failed to proxy request"
+            detail="Failed to get LLM metrics"
+        )


+@router.get("/providers/status")
+async def get_provider_status(
+    context: Dict[str, Any] = Depends(require_api_key),
+    db: AsyncSession = Depends(get_db)
+):
+    """Get status of all LLM providers"""
+    try:
+        auth_service = APIKeyAuthService(db)
+        if not await auth_service.check_scope_permission(context, "admin.status"):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail="Admin permissions required to view provider status"
+            )

+        provider_status = await llm_service.get_provider_status()
+        return {
+            "object": "provider_status",
+            "data": {
+                name: {
+                    "provider": status.provider,
+                    "status": status.status,
+                    "latency_ms": status.latency_ms,
+                    "success_rate": status.success_rate,
+                    "last_check": status.last_check.isoformat(),
+                    "error_message": status.error_message,
+                    "models_available": status.models_available
+                }
+                for name, status in provider_status.items()
+            }
+        }

+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error getting provider status: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to get provider status"
         )
@@ -447,7 +447,7 @@ async def get_module_config(module_name: str):
    log_api_request("get_module_config", {"module_name": module_name})

    from app.services.module_config_manager import module_config_manager
-    from app.services.litellm_client import litellm_client
+    from app.services.llm.service import llm_service
    import copy

    # Get module manifest and schema
@@ -461,9 +461,9 @@ async def get_module_config(module_name: str):
    # For Signal module, populate model options dynamically
    if module_name == "signal" and schema:
        try:
-            # Get available models from LiteLLM
+            # Get available models from LLM service
-            models_data = await litellm_client.get_models()
+            models_data = await llm_service.get_models()
-            model_ids = [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")]
+            model_ids = [model.id for model in models_data]

            if model_ids:
                # Create a copy of the schema to avoid modifying the original
@@ -15,7 +15,8 @@ from app.models.prompt_template import PromptTemplate, ChatbotPromptVariable
 from app.core.security import get_current_user
 from app.models.user import User
 from app.core.logging import log_api_request
-from app.services.litellm_client import litellm_client
+from app.services.llm.service import llm_service
+from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage

 router = APIRouter()

@@ -394,25 +395,28 @@ Please improve this prompt to make it more effective for a {request.chatbot_type
    ]

    # Get available models to use a default model
-    models = await litellm_client.get_models()
+    models = await llm_service.get_models()
    if not models:
        raise HTTPException(status_code=503, detail="No LLM models available")

    # Use the first available model (you might want to make this configurable)
-    default_model = models[0]["id"]
+    default_model = models[0].id

-    # Make the AI call
+    # Prepare the chat request for the new LLM service
-    response = await litellm_client.create_chat_completion(
+    chat_request = LLMChatRequest(
        model=default_model,
-        messages=messages,
+        messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
-        user_id=str(user_id),
-        api_key_id=1,  # Using default API key, you might want to make this dynamic
        temperature=0.3,
-        max_tokens=1000
+        max_tokens=1000,
+        user_id=str(user_id),
+        api_key_id=1  # Using default API key, you might want to make this dynamic
    )

+    # Make the AI call
+    response = await llm_service.create_chat_completion(chat_request)

    # Extract the improved prompt from the response
-    improved_prompt = response["choices"][0]["message"]["content"].strip()
+    improved_prompt = response.choices[0].message.content.strip()

    return {
        "improved_prompt": improved_prompt,
@@ -51,7 +51,7 @@ class SystemInfoResponse(BaseModel):
    environment: str
    database_status: str
    redis_status: str
-    litellm_status: str
+    llm_service_status: str
    modules_loaded: int
    active_users: int
    total_api_keys: int
@@ -227,8 +227,13 @@ async def get_system_info(
    # Get Redis status (simplified check)
    redis_status = "healthy"  # Would implement actual Redis check

-    # Get LiteLLM status (simplified check)
+    # Get LLM service status
-    litellm_status = "healthy"  # Would implement actual LiteLLM check
+    try:
+        from app.services.llm.service import llm_service
+        health_summary = llm_service.get_health_summary()
+        llm_service_status = health_summary.get("service_status", "unknown")
+    except Exception:
+        llm_service_status = "unavailable"

    # Get modules loaded (from module manager)
    modules_loaded = 8  # Would get from actual module manager
@@ -261,7 +266,7 @@ async def get_system_info(
        environment="production",
        database_status=database_status,
        redis_status=redis_status,
-        litellm_status=litellm_status,
+        llm_service_status=llm_service_status,
        modules_loaded=modules_loaded,
        active_users=active_users,
        total_api_keys=total_api_keys,
@@ -43,15 +43,18 @@ class Settings(BaseSettings):
    # CORS
    CORS_ORIGINS: List[str] = ["http://localhost:3000", "http://localhost:8000"]

-    # LiteLLM
+    # LLM Service Configuration (replaced LiteLLM)
-    LITELLM_BASE_URL: str = "http://localhost:4000"
+    # LLM service configuration is now handled in app/services/llm/config.py
-    LITELLM_MASTER_KEY: str = "enclava-master-key"

+    # LLM Service Security
+    LLM_ENCRYPTION_KEY: Optional[str] = None  # Key for encrypting LLM provider API keys

    # API Keys for LLM providers
    OPENAI_API_KEY: Optional[str] = None
    ANTHROPIC_API_KEY: Optional[str] = None
    GOOGLE_API_KEY: Optional[str] = None
    PRIVATEMODE_API_KEY: Optional[str] = None
+    PRIVATEMODE_PROXY_URL: str = "http://privatemode-proxy:8080/v1"

    # Qdrant
    QDRANT_HOST: str = "localhost"
@@ -1,6 +1,6 @@
 """
 Embedding Service
-Provides text embedding functionality using LiteLLM proxy
+Provides text embedding functionality using LLM service
 """

 import logging
@@ -11,32 +11,34 @@ logger = logging.getLogger(__name__)


 class EmbeddingService:
-    """Service for generating text embeddings using LiteLLM"""
+    """Service for generating text embeddings using LLM service"""

     def __init__(self, model_name: str = "privatemode-embeddings"):
         self.model_name = model_name
-        self.litellm_client = None
         self.dimension = 1024  # Actual dimension for privatemode-embeddings
         self.initialized = False

     async def initialize(self):
-        """Initialize the embedding service with LiteLLM"""
+        """Initialize the embedding service with LLM service"""
         try:
-            from app.services.litellm_client import litellm_client
+            from app.services.llm.service import llm_service
-            self.litellm_client = litellm_client

-            # Test connection to LiteLLM
+            # Initialize LLM service if not already done
-            health = await self.litellm_client.health_check()
+            if not llm_service._initialized:
-            if health.get("status") == "unhealthy":
+                await llm_service.initialize()
-                logger.error(f"LiteLLM service unhealthy: {health.get('error')}")
+            # Test LLM service health
+            health_summary = llm_service.get_health_summary()
+            if health_summary.get("service_status") != "healthy":
+                logger.error(f"LLM service unhealthy: {health_summary}")
                 return False

             self.initialized = True
-            logger.info(f"Embedding service initialized with LiteLLM: {self.model_name} (dimension: {self.dimension})")
+            logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})")
             return True

         except Exception as e:
-            logger.error(f"Failed to initialize LiteLLM embedding service: {e}")
+            logger.error(f"Failed to initialize LLM embedding service: {e}")
             logger.warning("Using fallback random embeddings")
             return False

@@ -46,10 +48,10 @@ class EmbeddingService:
         return embeddings[0]

     async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """Get embeddings for multiple texts using LiteLLM"""
+        """Get embeddings for multiple texts using LLM service"""
-        if not self.initialized or not self.litellm_client:
+        if not self.initialized:
             # Fallback to random embeddings if not initialized
-            logger.warning("LiteLLM not available, using random embeddings")
+            logger.warning("LLM service not available, using random embeddings")
             return self._generate_fallback_embeddings(texts)

         try:
@@ -73,17 +75,22 @@ class EmbeddingService:
                 else:
                     truncated_text = text

-                # Call LiteLLM embedding endpoint
+                # Call LLM service embedding endpoint
-                response = await self.litellm_client.create_embedding(
+                from app.services.llm.service import llm_service
+                from app.services.llm.models import EmbeddingRequest

+                llm_request = EmbeddingRequest(
                     model=self.model_name,
-                    input_text=truncated_text,
+                    input=truncated_text,
                     user_id="rag_system",
                     api_key_id=0  # System API key
                 )

+                response = await llm_service.create_embedding(llm_request)

                 # Extract embedding from response
-                if "data" in response and len(response["data"]) > 0:
+                if response.data and len(response.data) > 0:
-                    embedding = response["data"][0].get("embedding", [])
+                    embedding = response.data[0].embedding
                     if embedding:
                         batch_embeddings.append(embedding)
                         # Update dimension based on actual embedding size
@@ -106,7 +113,7 @@ class EmbeddingService:
             return embeddings

         except Exception as e:
-            logger.error(f"Error generating embeddings with LiteLLM: {e}")
+            logger.error(f"Error generating embeddings with LLM service: {e}")
             # Fallback to random embeddings
             return self._generate_fallback_embeddings(texts)

@@ -146,14 +153,13 @@ class EmbeddingService:
             "model_name": self.model_name,
             "model_loaded": self.initialized,
             "dimension": self.dimension,
-            "backend": "LiteLLM",
+            "backend": "LLM Service",
             "initialized": self.initialized
         }

     async def cleanup(self):
         """Cleanup resources"""
         self.initialized = False
-        self.litellm_client = None


 # Global embedding service instance
@@ -1,304 +0,0 @@
-"""
-LiteLLM Client Service
-Handles communication with the LiteLLM proxy service
-"""
-
-import asyncio
-import json
-import logging
-from typing import Dict, Any, Optional, List
-from datetime import datetime
-
-import aiohttp
-from fastapi import HTTPException, status
-
-from app.core.config import settings
-
-logger = logging.getLogger(__name__)
-
-
-class LiteLLMClient:
-    """Client for communicating with LiteLLM proxy service"""
-
-    def __init__(self):
-        self.base_url = settings.LITELLM_BASE_URL
-        self.master_key = settings.LITELLM_MASTER_KEY
-        self.session: Optional[aiohttp.ClientSession] = None
-        self.timeout = aiohttp.ClientTimeout(total=600)  # 10 minutes timeout
-
-    async def _get_session(self) -> aiohttp.ClientSession:
-        """Get or create aiohttp session"""
-        if self.session is None or self.session.closed:
-            self.session = aiohttp.ClientSession(
-                timeout=self.timeout,
-                headers={
-                    "Content-Type": "application/json",
-                    "Authorization": f"Bearer {self.master_key}"
-                }
-            )
-        return self.session
-
-    async def close(self):
-        """Close the HTTP session"""
-        if self.session and not self.session.closed:
-            await self.session.close()
-
-    async def health_check(self) -> Dict[str, Any]:
-        """Check LiteLLM proxy health"""
-        try:
-            session = await self._get_session()
-            async with session.get(f"{self.base_url}/health") as response:
-                if response.status == 200:
-                    return await response.json()
-                else:
-                    logger.error(f"LiteLLM health check failed: {response.status}")
-                    return {"status": "unhealthy", "error": f"HTTP {response.status}"}
-        except Exception as e:
-            logger.error(f"LiteLLM health check error: {e}")
-            return {"status": "unhealthy", "error": str(e)}
-
-    async def get_models(self) -> List[Dict[str, Any]]:
-        """Get available models from LiteLLM"""
-        try:
-            session = await self._get_session()
-            async with session.get(f"{self.base_url}/models") as response:
-                if response.status == 200:
-                    data = await response.json()
-                    return data.get("data", [])
-                else:
-                    logger.error(f"Failed to get models: {response.status}")
-                    raise HTTPException(
-                        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                        detail="LiteLLM service unavailable"
-                    )
-        except aiohttp.ClientError as e:
-            logger.error(f"LiteLLM models request error: {e}")
-            raise HTTPException(
-                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                detail="LiteLLM service unavailable"
-            )
-
-    async def create_chat_completion(
-        self,
-        model: str,
-        messages: List[Dict[str, str]],
-        user_id: str,
-        api_key_id: int,
-        **kwargs
-    ) -> Dict[str, Any]:
-        """Create chat completion via LiteLLM proxy"""
-        try:
-            # Prepare request payload
-            payload = {
-                "model": model,
-                "messages": messages,
-                "user": f"user_{user_id}",  # User identifier for tracking
-                "metadata": {
-                    "api_key_id": api_key_id,
-                    "user_id": user_id,
-                    "timestamp": datetime.utcnow().isoformat()
-                },
-                **kwargs
-            }
-
-            session = await self._get_session()
-            async with session.post(
-                f"{self.base_url}/chat/completions",
-                json=payload
-            ) as response:
-                if response.status == 200:
-                    return await response.json()
-                else:
-                    error_text = await response.text()
-                    logger.error(f"LiteLLM chat completion failed: {response.status} - {error_text}")
-
-                    # Handle specific error cases
-                    if response.status == 401:
-                        raise HTTPException(
-                            status_code=status.HTTP_401_UNAUTHORIZED,
-                            detail="Invalid API key"
-                        )
-                    elif response.status == 429:
-                        raise HTTPException(
-                            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
-                            detail="Rate limit exceeded"
-                        )
-                    elif response.status == 400:
-                        try:
-                            error_data = await response.json()
-                            raise HTTPException(
-                                status_code=status.HTTP_400_BAD_REQUEST,
-                                detail=error_data.get("error", {}).get("message", "Bad request")
-                            )
-                        except:
-                            raise HTTPException(
-                                status_code=status.HTTP_400_BAD_REQUEST,
-                                detail="Invalid request"
-                            )
-                    else:
-                        raise HTTPException(
-                            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                            detail="LiteLLM service error"
-                        )
-        except aiohttp.ClientError as e:
-            logger.error(f"LiteLLM chat completion request error: {e}")
-            raise HTTPException(
-                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                detail="LiteLLM service unavailable"
-            )
-
-    async def create_embedding(
-        self,
-        model: str,
-        input_text: str,
-        user_id: str,
-        api_key_id: int,
-        **kwargs
-    ) -> Dict[str, Any]:
-        """Create embedding via LiteLLM proxy"""
-        try:
-            payload = {
-                "model": model,
-                "input": input_text,
-                "user": f"user_{user_id}",
-                "metadata": {
-                    "api_key_id": api_key_id,
-                    "user_id": user_id,
-                    "timestamp": datetime.utcnow().isoformat()
-                },
-                **kwargs
-            }
-
-            session = await self._get_session()
-            async with session.post(
-                f"{self.base_url}/embeddings",
-                json=payload
-            ) as response:
-                if response.status == 200:
-                    return await response.json()
-                else:
-                    error_text = await response.text()
-                    logger.error(f"LiteLLM embedding failed: {response.status} - {error_text}")
-
-                    if response.status == 401:
-                        raise HTTPException(
-                            status_code=status.HTTP_401_UNAUTHORIZED,
-                            detail="Invalid API key"
-                        )
-                    elif response.status == 429:
-                        raise HTTPException(
-                            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
-                            detail="Rate limit exceeded"
-                        )
-                    else:
-                        raise HTTPException(
-                            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                            detail="LiteLLM service error"
-                        )
-        except aiohttp.ClientError as e:
-            logger.error(f"LiteLLM embedding request error: {e}")
-            raise HTTPException(
-                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                detail="LiteLLM service unavailable"
-            )
-
-    async def get_models(self) -> List[Dict[str, Any]]:
-        """Get available models from LiteLLM proxy"""
-        try:
-            session = await self._get_session()
-            async with session.get(f"{self.base_url}/models") as response:
-                if response.status == 200:
-                    data = await response.json()
-                    # Return models with exact names from upstream providers
-                    models = data.get("data", [])
-
-                    # Pass through model names exactly as they come from upstream
-                    # Don't modify model IDs - keep them as the original provider names
-                    processed_models = []
-                    for model in models:
-                        # Keep the exact model ID from upstream provider
-                        processed_models.append({
-                            "id": model.get("id", ""),  # Exact model name from provider
-                            "object": model.get("object", "model"),
-                            "created": model.get("created", 1677610602),
-                            "owned_by": model.get("owned_by", "openai")
-                        })
-
-                    return processed_models
-                else:
-                    error_text = await response.text()
-                    logger.error(f"LiteLLM models request failed: {response.status} - {error_text}")
-                    return []
-        except aiohttp.ClientError as e:
-            logger.error(f"LiteLLM models request error: {e}")
-            return []
-
-    async def proxy_request(
-        self,
-        method: str,
-        endpoint: str,
-        payload: Dict[str, Any],
-        user_id: str,
-        api_key_id: int
-    ) -> Dict[str, Any]:
-        """Generic proxy request to LiteLLM"""
-        try:
-            # Add metadata to payload
-            if isinstance(payload, dict):
-                payload["metadata"] = {
-                    "api_key_id": api_key_id,
-                    "user_id": user_id,
-                    "timestamp": datetime.utcnow().isoformat()
-                }
-                if "user" not in payload:
-                    payload["user"] = f"user_{user_id}"
-
-            session = await self._get_session()
-
-            # Make the request
-            async with session.request(
-                method,
-                f"{self.base_url}/{endpoint.lstrip('/')}",
-                json=payload if method.upper() in ['POST', 'PUT', 'PATCH'] else None,
-                params=payload if method.upper() == 'GET' else None
-            ) as response:
-                if response.status == 200:
-                    return await response.json()
-                else:
-                    error_text = await response.text()
-                    logger.error(f"LiteLLM proxy request failed: {response.status} - {error_text}")
-
-                    if response.status == 401:
-                        raise HTTPException(
-                            status_code=status.HTTP_401_UNAUTHORIZED,
-                            detail="Invalid API key"
-                        )
-                    elif response.status == 429:
-                        raise HTTPException(
-                            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
-                            detail="Rate limit exceeded"
-                        )
-                    else:
-                        raise HTTPException(
-                            status_code=response.status,
-                            detail=f"LiteLLM service error: {error_text}"
-                        )
-        except aiohttp.ClientError as e:
-            logger.error(f"LiteLLM proxy request error: {e}")
-            raise HTTPException(
-                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                detail="LiteLLM service unavailable"
-            )
-
-    async def list_models(self) -> List[str]:
-        """Get list of available model names/IDs"""
-        try:
-            models_data = await self.get_models()
-            return [model.get("id", model.get("model", "")) for model in models_data if model.get("id") or model.get("model")]
-        except Exception as e:
-            logger.error(f"Error listing model names: {str(e)}")
-            return []
-
-
-# Global LiteLLM client instance
-litellm_client = LiteLLMClient()
@@ -23,7 +23,9 @@ from fastapi import APIRouter, HTTPException, Depends
 from sqlalchemy.orm import Session

 from app.core.logging import get_logger
-from app.services.litellm_client import LiteLLMClient
+from app.services.llm.service import llm_service
+from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
+from app.services.llm.exceptions import LLMError, ProviderError, SecurityError
 from app.services.base_module import BaseModule, Permission
 from app.models.user import User
 from app.models.chatbot import ChatbotInstance as DBChatbotInstance, ChatbotConversation as DBConversation, ChatbotMessage as DBMessage, ChatbotAnalytics
@@ -32,7 +34,8 @@ from app.db.database import get_db
 from app.core.config import settings

 # Import protocols for type hints and dependency injection
-from ..protocols import RAGServiceProtocol, LiteLLMClientProtocol
+from ..protocols import RAGServiceProtocol
+# Note: LiteLLMClientProtocol replaced with direct LLM service usage

 logger = get_logger(__name__)

@@ -131,10 +134,8 @@ class ChatbotInstance(BaseModel):
 class ChatbotModule(BaseModule):
     """Main chatbot module implementation"""

-    def __init__(self, litellm_client: Optional[LiteLLMClientProtocol] = None,
+    def __init__(self, rag_service: Optional[RAGServiceProtocol] = None):
-                 rag_service: Optional[RAGServiceProtocol] = None):
         super().__init__("chatbot")
-        self.litellm_client = litellm_client
         self.rag_module = rag_service  # Keep same name for compatibility
         self.db_session = None

@@ -145,15 +146,10 @@ class ChatbotModule(BaseModule):
         """Initialize the chatbot module"""
         await super().initialize(**kwargs)

-        # Get dependencies from global services if not already injected
+        # Initialize the LLM service
-        if not self.litellm_client:
+        await llm_service.initialize()
-            try:
-                from app.services.litellm_client import litellm_client
-                self.litellm_client = litellm_client
-                logger.info("LiteLLM client injected from global service")
-            except Exception as e:
-                logger.warning(f"Could not inject LiteLLM client: {e}")

+        # Get RAG module dependency if not already injected
         if not self.rag_module:
             try:
                 # Try to get RAG module from module manager
@@ -168,19 +164,16 @@ class ChatbotModule(BaseModule):
         await self._load_prompt_templates()

         logger.info("Chatbot module initialized")
-        logger.info(f"LiteLLM client available after init: {self.litellm_client is not None}")
+        logger.info(f"LLM service available: {llm_service._initialized}")
         logger.info(f"RAG module available after init: {self.rag_module is not None}")
         logger.info(f"Loaded {len(self.system_prompts)} prompt templates")

     async def _ensure_dependencies(self):
         """Lazy load dependencies if not available"""
-        if not self.litellm_client:
+        # Ensure LLM service is initialized
-            try:
+        if not llm_service._initialized:
-                from app.services.litellm_client import litellm_client
+            await llm_service.initialize()
-                self.litellm_client = litellm_client
+            logger.info("LLM service lazy loaded")
-                logger.info("LiteLLM client lazy loaded")
-            except Exception as e:
-                logger.warning(f"Could not lazy load LiteLLM client: {e}")

         if not self.rag_module:
             try:
@@ -468,45 +461,58 @@ class ChatbotModule(BaseModule):
             logger.info(msg['content'])
         logger.info("=== END COMPREHENSIVE LLM REQUEST ===")

-        if self.litellm_client:
+        try:
-            try:
+            logger.info("Calling LLM service create_chat_completion...")
-                logger.info("Calling LiteLLM client create_chat_completion...")
-                response = await self.litellm_client.create_chat_completion(
-                    model=config.model,
-                    messages=messages,
-                    user_id="chatbot_user",
-                    api_key_id="chatbot_api_key",
-                    temperature=config.temperature,
-                    max_tokens=config.max_tokens
-                )
-                logger.info(f"LiteLLM response received, response keys: {list(response.keys())}")

-                # Extract response content from the LiteLLM response format
+            # Convert messages to LLM service format
-                if 'choices' in response and response['choices']:
+            llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
-                    content = response['choices'][0]['message']['content']
-                    logger.info(f"Response content length: {len(content)}")

-                    # Always log response for debugging
+            # Create LLM service request
-                    logger.info("=== COMPREHENSIVE LLM RESPONSE ===")
+            llm_request = LLMChatRequest(
-                    logger.info(f"Response content ({len(content)} chars):")
+                model=config.model,
-                    logger.info(content)
+                messages=llm_messages,
-                    if 'usage' in response:
+                temperature=config.temperature,
-                        usage = response['usage']
+                max_tokens=config.max_tokens,
-                        logger.info(f"Token usage - Prompt: {usage.get('prompt_tokens', 'N/A')}, Completion: {usage.get('completion_tokens', 'N/A')}, Total: {usage.get('total_tokens', 'N/A')}")
+                user_id="chatbot_user",
-                    if sources:
+                api_key_id=0  # Chatbot module uses internal service
-                        logger.info(f"RAG sources included: {len(sources)} documents")
+            )
-                    logger.info("=== END COMPREHENSIVE LLM RESPONSE ===")

-                    return content, sources
+            # Make request to LLM service
-                else:
+            llm_response = await llm_service.create_chat_completion(llm_request)
-                    logger.warning("No choices in LiteLLM response")
-                    return "I received an empty response from the AI model.", sources
+            # Extract response content
-            except Exception as e:
+            if llm_response.choices:
-                logger.error(f"LiteLLM completion failed: {e}")
+                content = llm_response.choices[0].message.content
-                raise e
+                logger.info(f"Response content length: {len(content)}")
-        else:
-            logger.warning("No LiteLLM client available, using fallback")
+                # Always log response for debugging
-            # Fallback if no LLM client
+                logger.info("=== COMPREHENSIVE LLM RESPONSE ===")
+                logger.info(f"Response content ({len(content)} chars):")
+                logger.info(content)
+                if llm_response.usage:
+                    usage = llm_response.usage
+                    logger.info(f"Token usage - Prompt: {usage.prompt_tokens}, Completion: {usage.completion_tokens}, Total: {usage.total_tokens}")
+                if sources:
+                    logger.info(f"RAG sources included: {len(sources)} documents")
+                logger.info("=== END COMPREHENSIVE LLM RESPONSE ===")

+                return content, sources
+            else:
+                logger.warning("No choices in LLM response")
+                return "I received an empty response from the AI model.", sources

+        except SecurityError as e:
+            logger.error(f"Security error in LLM completion: {e}")
+            raise HTTPException(status_code=400, detail=f"Security validation failed: {e.message}")
+        except ProviderError as e:
+            logger.error(f"Provider error in LLM completion: {e}")
+            raise HTTPException(status_code=503, detail="LLM service temporarily unavailable")
+        except LLMError as e:
+            logger.error(f"LLM service error: {e}")
+            raise HTTPException(status_code=500, detail="LLM service error")
+        except Exception as e:
+            logger.error(f"LLM completion failed: {e}")
+            # Return fallback if available
             return "I'm currently unable to process your request. Please try again later.", None

     def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig,
@@ -685,7 +691,7 @@ class ChatbotModule(BaseModule):
         # Lazy load dependencies
         await self._ensure_dependencies()

-        logger.info(f"LiteLLM client available: {self.litellm_client is not None}")
+        logger.info(f"LLM service available: {llm_service._initialized}")
         logger.info(f"RAG module available: {self.rag_module is not None}")

         try:
@@ -884,10 +890,9 @@ class ChatbotModule(BaseModule):


 # Module factory function
-def create_module(litellm_client: Optional[LiteLLMClientProtocol] = None,
+def create_module(rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule:
-                  rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule:
     """Factory function to create chatbot module instance"""
-    return ChatbotModule(litellm_client=litellm_client, rag_service=rag_service)
+    return ChatbotModule(rag_service=rag_service)

 # Create module instance (dependencies will be injected via factory)
 chatbot_module = ChatbotModule()
@@ -401,7 +401,7 @@ class RAGModule(BaseModule):
         """Initialize embedding model"""
         from app.services.embedding_service import embedding_service

-        # Use privatemode-embeddings for LiteLLM integration
+        # Use privatemode-embeddings for LLM service integration
         model_name = self.config.get("embedding_model", "privatemode-embeddings")
         embedding_service.model_name = model_name

@@ -22,13 +22,16 @@ from fastapi import APIRouter, HTTPException, Depends
 from sqlalchemy.orm import Session
 from sqlalchemy import select
 from app.core.logging import get_logger
-from app.services.litellm_client import LiteLLMClient
+from app.services.llm.service import llm_service
+from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
+from app.services.llm.exceptions import LLMError, ProviderError, SecurityError
 from app.services.base_module import Permission
 from app.db.database import SessionLocal
 from app.models.workflow import WorkflowDefinition as DBWorkflowDefinition, WorkflowExecution as DBWorkflowExecution

 # Import protocols for type hints and dependency injection
-from ..protocols import ChatbotServiceProtocol, LiteLLMClientProtocol
+from ..protocols import ChatbotServiceProtocol
+# Note: LiteLLMClientProtocol replaced with direct LLM service usage

 logger = get_logger(__name__)

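# --- editor note: illustrative sketch, not part of the diff -------------------------
# The newly imported exception types let the workflow engine distinguish failure modes
# (blocked request vs. provider outage vs. generic service error) instead of catching a
# single client error. The handling policy in the except branches is assumed, not taken
# from the diff; only the imports and the service call mirror it.
from app.services.llm.service import llm_service
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError

async def call_llm(llm_request):
    try:
        return await llm_service.create_chat_completion(llm_request)
    except SecurityError:
        raise  # request rejected by security validation; do not retry
    except ProviderError:
        raise  # provider temporarily unavailable; the step may be retried
    except LLMError:
        raise  # any other LLM service failure
# -------------------------------------------------------------------------------------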
@@ -234,8 +237,7 @@ class WorkflowExecution(BaseModel):
 class WorkflowEngine:
     """Core workflow execution engine"""

-    def __init__(self, litellm_client: LiteLLMClient, chatbot_service: Optional[ChatbotServiceProtocol] = None):
-        self.litellm_client = litellm_client
+    def __init__(self, chatbot_service: Optional[ChatbotServiceProtocol] = None):
         self.chatbot_service = chatbot_service
         self.executions: Dict[str, WorkflowExecution] = {}
         self.workflows: Dict[str, WorkflowDefinition] = {}
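# --- editor note: illustrative sketch, not part of the diff -------------------------
# WorkflowEngine no longer receives an injected client; only the optional chatbot
# service remains, and all LLM calls route through the module-level llm_service.
engine = WorkflowEngine()  # or WorkflowEngine(chatbot_service=my_chatbot_service)
# -------------------------------------------------------------------------------------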
@@ -343,15 +345,23 @@ class WorkflowEngine:
             # Template message content with context variables
             messages = self._template_messages(llm_step.messages, context.variables)

-            # Make LLM call
-            response = await self.litellm_client.chat_completion(
+            # Convert messages to LLM service format
+            llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
+
+            # Create LLM service request
+            llm_request = LLMChatRequest(
                 model=llm_step.model,
-                messages=messages,
-                **llm_step.parameters
+                messages=llm_messages,
+                user_id="workflow_user",
+                api_key_id=0,  # Workflow module uses internal service
+                **{k: v for k, v in llm_step.parameters.items() if k in ['temperature', 'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty', 'stop']}
             )

+            # Make LLM call
+            response = await llm_service.create_chat_completion(llm_request)
+
             # Store result
-            result = response.get("choices", [{}])[0].get("message", {}).get("content", "")
+            result = response.choices[0].message.content if response.choices else ""
             context.variables[llm_step.output_variable] = result
             context.results[step.id] = result

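# --- editor note: illustrative sketch, not part of the diff -------------------------
# The return value changes from a LiteLLM dict to a typed response object, so result
# extraction moves from chained .get() calls to attribute access. `response` stands for
# the value returned by llm_service.create_chat_completion above.
def extract_text(response) -> str:
    # old dict-based access (removed): response.get("choices", [{}])[0].get("message", {}).get("content", "")
    # new typed access (added):
    return response.choices[0].message.content if response.choices else ""
# -------------------------------------------------------------------------------------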
@@ -631,16 +641,21 @@ class WorkflowEngine:

         messages = [{"role": "user", "content": self._template_string(prompt, variables)}]

-        response = await self.litellm_client.create_chat_completion(
+        # Convert to LLM service format
+        llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
+
+        llm_request = LLMChatRequest(
             model=step.model,
-            messages=messages,
+            messages=llm_messages,
             user_id="workflow_system",
-            api_key_id="workflow",
+            api_key_id=0,
             temperature=step.temperature,
             max_tokens=step.max_tokens
         )

-        return response.get("choices", [{}])[0].get("message", {}).get("content", "")
+        response = await llm_service.create_chat_completion(llm_request)
+
+        return response.choices[0].message.content if response.choices else ""

     async def _generate_brand_names(self, variables: Dict[str, Any], step: AIGenerationStep) -> List[Dict[str, str]]:
         """Generate brand names for a specific category"""
@@ -687,16 +702,21 @@ class WorkflowEngine:

         messages = [{"role": "user", "content": self._template_string(prompt, variables)}]

-        response = await self.litellm_client.create_chat_completion(
+        # Convert to LLM service format
+        llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
+
+        llm_request = LLMChatRequest(
             model=step.model,
-            messages=messages,
+            messages=llm_messages,
             user_id="workflow_system",
-            api_key_id="workflow",
+            api_key_id=0,
             temperature=step.temperature,
             max_tokens=step.max_tokens
         )

-        return response.get("choices", [{}])[0].get("message", {}).get("content", "")
+        response = await llm_service.create_chat_completion(llm_request)
+
+        return response.choices[0].message.content if response.choices else ""

     async def _generate_custom_prompt(self, variables: Dict[str, Any], step: AIGenerationStep) -> str:
         """Generate content using custom prompt template"""
@@ -705,16 +725,21 @@ class WorkflowEngine:

         messages = [{"role": "user", "content": self._template_string(step.prompt_template, variables)}]

-        response = await self.litellm_client.create_chat_completion(
+        # Convert to LLM service format
+        llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
+
+        llm_request = LLMChatRequest(
             model=step.model,
-            messages=messages,
+            messages=llm_messages,
             user_id="workflow_system",
-            api_key_id="workflow",
+            api_key_id=0,
             temperature=step.temperature,
             max_tokens=step.max_tokens
         )

-        return response.get("choices", [{}])[0].get("message", {}).get("content", "")
+        response = await llm_service.create_chat_completion(llm_request)
+
+        return response.choices[0].message.content if response.choices else ""

     async def _execute_aggregate_step(self, step: WorkflowStep, context: WorkflowContext):
         """Execute aggregate step to combine multiple inputs"""
@@ -23,7 +23,8 @@ from app.core.config import settings
 from app.db.database import async_session_factory
 from app.models.user import User
 from app.models.chatbot import ChatbotInstance
-from app.services.litellm_client import LiteLLMClient
+from app.services.llm.service import llm_service
+from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
 from cryptography.fernet import Fernet
 import base64
 import os
@@ -65,8 +66,8 @@ class ZammadModule(BaseModule):
         try:
             logger.info("Initializing Zammad module...")

-            # Initialize LLM client for chatbot integration
-            self.llm_client = LiteLLMClient()
+            # Initialize LLM service for chatbot integration
+            # Note: llm_service is already a global singleton, no need to create instance

             # Create HTTP session pool for Zammad API calls
             timeout = aiohttp.ClientTimeout(total=60, connect=10)
@@ -597,19 +598,21 @@ class ZammadModule(BaseModule):
                 }
             ]

-            # Generate summary using LLM client
-            response = await self.llm_client.create_chat_completion(
-                messages=messages,
+            # Generate summary using new LLM service
+            chat_request = LLMChatRequest(
                 model=await self._get_chatbot_model(config.chatbot_id),
-                user_id=str(config.user_id),
-                api_key_id=0, # Using 0 for module requests
+                messages=[LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages],
                 temperature=0.3,
-                max_tokens=500
+                max_tokens=500,
+                user_id=str(config.user_id),
+                api_key_id=0  # Using 0 for module requests
             )

-            # Extract content from LiteLLM response
-            if "choices" in response and len(response["choices"]) > 0:
-                return response["choices"][0]["message"]["content"].strip()
+            response = await llm_service.create_chat_completion(chat_request)
+
+            # Extract content from new LLM service response
+            if response.choices and len(response.choices) > 0:
+                return response.choices[0].message.content.strip()

             return "Unable to generate summary."
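# --- editor note: illustrative sketch, not part of the diff -------------------------
# With the direct LLM service, usage attribution travels on the request itself: modules
# pass the acting user's id and api_key_id=0 to mark internal (non-API-key) traffic.
# The model id, prompt, and helper name below are placeholders for illustration.
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage

async def summarize_ticket(text: str, model: str, user_id: str) -> str:
    request = LLMChatRequest(
        model=model,
        messages=[LLMChatMessage(role="user", content=f"Summarize this ticket:\n{text}")],
        temperature=0.3,
        max_tokens=500,
        user_id=user_id,
        api_key_id=0,  # internal module request, matching the hunk above
    )
    response = await llm_service.create_chat_completion(request)
    if response.choices:
        return response.choices[0].message.content.strip()
    return "Unable to generate summary."
# -------------------------------------------------------------------------------------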
@@ -1,132 +0,0 @@
|
|||||||
"""
|
|
||||||
Test LLM API endpoints.
|
|
||||||
"""
|
|
||||||
import pytest
|
|
||||||
from httpx import AsyncClient
|
|
||||||
from unittest.mock import patch, AsyncMock
|
|
||||||
|
|
||||||
|
|
||||||
class TestLLMEndpoints:
|
|
||||||
"""Test LLM API endpoints."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_completion_success(self, client: AsyncClient):
|
|
||||||
"""Test successful chat completion."""
|
|
||||||
# Mock the LiteLLM client response
|
|
||||||
mock_response = {
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"message": {
|
|
||||||
"content": "Hello! How can I help you today?",
|
|
||||||
"role": "assistant"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 10,
|
|
||||||
"completion_tokens": 15,
|
|
||||||
"total_tokens": 25
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch("app.services.litellm_client.LiteLLMClient.create_chat_completion") as mock_chat:
|
|
||||||
mock_chat.return_value = mock_response
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/api/v1/llm/chat/completions",
|
|
||||||
json={
|
|
||||||
"model": "gpt-3.5-turbo",
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Hello"}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
headers={"Authorization": "Bearer test-api-key"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "choices" in data
|
|
||||||
assert data["choices"][0]["message"]["content"] == "Hello! How can I help you today?"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_chat_completion_unauthorized(self, client: AsyncClient):
|
|
||||||
"""Test chat completion without API key."""
|
|
||||||
response = await client.post(
|
|
||||||
"/api/v1/llm/chat/completions",
|
|
||||||
json={
|
|
||||||
"model": "gpt-3.5-turbo",
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Hello"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_embeddings_success(self, client: AsyncClient):
|
|
||||||
"""Test successful embeddings generation."""
|
|
||||||
mock_response = {
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"embedding": [0.1, 0.2, 0.3],
|
|
||||||
"index": 0
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 5,
|
|
||||||
"total_tokens": 5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch("app.services.litellm_client.LiteLLMClient.create_embedding") as mock_embeddings:
|
|
||||||
mock_embeddings.return_value = mock_response
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/api/v1/llm/embeddings",
|
|
||||||
json={
|
|
||||||
"model": "text-embedding-ada-002",
|
|
||||||
"input": "Hello world"
|
|
||||||
},
|
|
||||||
headers={"Authorization": "Bearer test-api-key"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "data" in data
|
|
||||||
assert len(data["data"][0]["embedding"]) == 3
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_budget_exceeded(self, client: AsyncClient):
|
|
||||||
"""Test budget exceeded scenario."""
|
|
||||||
with patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_check:
|
|
||||||
mock_check.side_effect = Exception("Budget exceeded")
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/api/v1/llm/chat/completions",
|
|
||||||
json={
|
|
||||||
"model": "gpt-3.5-turbo",
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Hello"}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
headers={"Authorization": "Bearer test-api-key"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 402 # Payment required
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_model_validation(self, client: AsyncClient):
|
|
||||||
"""Test model validation."""
|
|
||||||
response = await client.post(
|
|
||||||
"/api/v1/llm/chat/completions",
|
|
||||||
json={
|
|
||||||
"model": "invalid-model",
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Hello"}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
headers={"Authorization": "Bearer test-api-key"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
backend/tests/integration/test_llm_service_integration.py (new file, 496 lines)
@@ -0,0 +1,496 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for the new LLM service.
|
||||||
|
Tests end-to-end functionality including provider integration, security, and performance.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from unittest.mock import patch, AsyncMock, MagicMock
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMServiceIntegration:
|
||||||
|
"""Integration tests for LLM service."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_chat_flow(self, client: AsyncClient):
|
||||||
|
"""Test complete chat completion flow with security and budget checks."""
|
||||||
|
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
|
||||||
|
|
||||||
|
# Mock successful LLM service response
|
||||||
|
mock_response = ChatCompletionResponse(
|
||||||
|
id="test-completion-123",
|
||||||
|
object="chat.completion",
|
||||||
|
created=int(time.time()),
|
||||||
|
model="privatemode-llama-3-70b",
|
||||||
|
choices=[
|
||||||
|
ChatChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="Hello! I'm a TEE-protected AI assistant. How can I help you today?"
|
||||||
|
),
|
||||||
|
finish_reason="stop"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=Usage(
|
||||||
|
prompt_tokens=25,
|
||||||
|
completion_tokens=15,
|
||||||
|
total_tokens=40
|
||||||
|
),
|
||||||
|
security_analysis={
|
||||||
|
"risk_score": 0.1,
|
||||||
|
"threats_detected": [],
|
||||||
|
"risk_level": "low",
|
||||||
|
"analysis_time_ms": 12.5
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat, \
|
||||||
|
patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget:
|
||||||
|
|
||||||
|
mock_chat.return_value = mock_response
|
||||||
|
mock_budget.return_value = True # Budget check passes
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "privatemode-llama-3-70b",
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "Hello, what are your capabilities?"}
|
||||||
|
],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 150
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify response structure
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Check standard OpenAI-compatible fields
|
||||||
|
assert "id" in data
|
||||||
|
assert "object" in data
|
||||||
|
assert "created" in data
|
||||||
|
assert "model" in data
|
||||||
|
assert "choices" in data
|
||||||
|
assert "usage" in data
|
||||||
|
|
||||||
|
# Check security integration
|
||||||
|
assert "security_analysis" in data
|
||||||
|
assert data["security_analysis"]["risk_level"] == "low"
|
||||||
|
|
||||||
|
# Verify content
|
||||||
|
assert len(data["choices"]) == 1
|
||||||
|
assert data["choices"][0]["message"]["role"] == "assistant"
|
||||||
|
assert "TEE-protected" in data["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
# Verify usage tracking
|
||||||
|
assert data["usage"]["total_tokens"] == 40
|
||||||
|
assert data["usage"]["prompt_tokens"] == 25
|
||||||
|
assert data["usage"]["completion_tokens"] == 15
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embedding_integration(self, client: AsyncClient):
|
||||||
|
"""Test embedding generation with fallback handling."""
|
||||||
|
from app.services.llm.models import EmbeddingResponse, EmbeddingData, Usage
|
||||||
|
|
||||||
|
# Create realistic 1024-dimensional embedding
|
||||||
|
embedding_vector = [0.1 * i for i in range(1024)]
|
||||||
|
|
||||||
|
mock_response = EmbeddingResponse(
|
||||||
|
object="list",
|
||||||
|
data=[
|
||||||
|
EmbeddingData(
|
||||||
|
object="embedding",
|
||||||
|
embedding=embedding_vector,
|
||||||
|
index=0
|
||||||
|
)
|
||||||
|
],
|
||||||
|
model="privatemode-embeddings",
|
||||||
|
usage=Usage(
|
||||||
|
prompt_tokens=8,
|
||||||
|
total_tokens=8
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.create_embedding") as mock_embedding:
|
||||||
|
mock_embedding.return_value = mock_response
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/embeddings",
|
||||||
|
json={
|
||||||
|
"model": "privatemode-embeddings",
|
||||||
|
"input": "This is a test document for embedding generation."
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify embedding structure
|
||||||
|
assert "object" in data
|
||||||
|
assert "data" in data
|
||||||
|
assert "usage" in data
|
||||||
|
assert len(data["data"]) == 1
|
||||||
|
assert len(data["data"][0]["embedding"]) == 1024
|
||||||
|
assert data["data"][0]["index"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_health_integration(self, client: AsyncClient):
|
||||||
|
"""Test provider health monitoring integration."""
|
||||||
|
mock_status = {
|
||||||
|
"privatemode": {
|
||||||
|
"provider": "PrivateMode.ai",
|
||||||
|
"status": "healthy",
|
||||||
|
"latency_ms": 245.8,
|
||||||
|
"success_rate": 0.987,
|
||||||
|
"last_check": "2025-01-01T12:00:00Z",
|
||||||
|
"error_message": None,
|
||||||
|
"models_available": [
|
||||||
|
"privatemode-llama-3-70b",
|
||||||
|
"privatemode-claude-3-sonnet",
|
||||||
|
"privatemode-gpt-4o",
|
||||||
|
"privatemode-embeddings"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.get_provider_status") as mock_provider:
|
||||||
|
mock_provider.return_value = mock_status
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/llm/providers/status",
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Check response structure
|
||||||
|
assert "data" in data
|
||||||
|
assert "privatemode" in data["data"]
|
||||||
|
|
||||||
|
provider_data = data["data"]["privatemode"]
|
||||||
|
assert provider_data["status"] == "healthy"
|
||||||
|
assert provider_data["latency_ms"] < 300 # Reasonable latency
|
||||||
|
assert provider_data["success_rate"] > 0.95 # High success rate
|
||||||
|
assert len(provider_data["models_available"]) >= 4
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_handling_and_fallback(self, client: AsyncClient):
|
||||||
|
"""Test error handling and fallback scenarios."""
|
||||||
|
# Test provider unavailable scenario
|
||||||
|
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
|
||||||
|
mock_chat.side_effect = Exception("Provider temporarily unavailable")
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "privatemode-llama-3-70b",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return error but not crash
|
||||||
|
assert response.status_code in [500, 503] # Server error or service unavailable
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_security_threat_detection(self, client: AsyncClient):
|
||||||
|
"""Test security threat detection integration."""
|
||||||
|
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
|
||||||
|
|
||||||
|
# Mock response with security threat detected
|
||||||
|
mock_response = ChatCompletionResponse(
|
||||||
|
id="test-completion-security",
|
||||||
|
object="chat.completion",
|
||||||
|
created=int(time.time()),
|
||||||
|
model="privatemode-llama-3-70b",
|
||||||
|
choices=[
|
||||||
|
ChatChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="I cannot help with that request as it violates security policies."
|
||||||
|
),
|
||||||
|
finish_reason="stop"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=Usage(
|
||||||
|
prompt_tokens=15,
|
||||||
|
completion_tokens=12,
|
||||||
|
total_tokens=27
|
||||||
|
),
|
||||||
|
security_analysis={
|
||||||
|
"risk_score": 0.8,
|
||||||
|
"threats_detected": ["potential_malicious_code"],
|
||||||
|
"risk_level": "high",
|
||||||
|
"blocked": True,
|
||||||
|
"analysis_time_ms": 45.2
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
|
||||||
|
mock_chat.return_value = mock_response
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "privatemode-llama-3-70b",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "How to create malicious code?"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200 # Request succeeds but content is filtered
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify security analysis
|
||||||
|
assert "security_analysis" in data
|
||||||
|
assert data["security_analysis"]["risk_level"] == "high"
|
||||||
|
assert data["security_analysis"]["blocked"] is True
|
||||||
|
assert "malicious" in data["security_analysis"]["threats_detected"][0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_performance_characteristics(self, client: AsyncClient):
|
||||||
|
"""Test performance characteristics of the LLM service."""
|
||||||
|
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
|
||||||
|
|
||||||
|
# Mock fast response
|
||||||
|
mock_response = ChatCompletionResponse(
|
||||||
|
id="test-perf",
|
||||||
|
object="chat.completion",
|
||||||
|
created=int(time.time()),
|
||||||
|
model="privatemode-llama-3-70b",
|
||||||
|
choices=[
|
||||||
|
ChatChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="Quick response for performance testing."
|
||||||
|
),
|
||||||
|
finish_reason="stop"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=Usage(
|
||||||
|
prompt_tokens=10,
|
||||||
|
completion_tokens=8,
|
||||||
|
total_tokens=18
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
|
||||||
|
mock_chat.return_value = mock_response
|
||||||
|
|
||||||
|
# Measure response time
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "privatemode-llama-3-70b",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Quick test"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
response_time = time.time() - start_time
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
# API should respond quickly (mocked, so should be very fast)
|
||||||
|
assert response_time < 1.0 # Less than 1 second for mocked response
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_capabilities_detection(self, client: AsyncClient):
|
||||||
|
"""Test model capabilities detection and reporting."""
|
||||||
|
from app.services.llm.models import Model
|
||||||
|
|
||||||
|
mock_models = [
|
||||||
|
Model(
|
||||||
|
id="privatemode-llama-3-70b",
|
||||||
|
object="model",
|
||||||
|
created=1234567890,
|
||||||
|
owned_by="PrivateMode.ai",
|
||||||
|
provider="PrivateMode.ai",
|
||||||
|
capabilities=["tee", "chat", "function_calling"],
|
||||||
|
context_window=32768,
|
||||||
|
max_output_tokens=4096,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True
|
||||||
|
),
|
||||||
|
Model(
|
||||||
|
id="privatemode-embeddings",
|
||||||
|
object="model",
|
||||||
|
created=1234567890,
|
||||||
|
owned_by="PrivateMode.ai",
|
||||||
|
provider="PrivateMode.ai",
|
||||||
|
capabilities=["tee", "embeddings"],
|
||||||
|
context_window=512,
|
||||||
|
supports_streaming=False,
|
||||||
|
supports_function_calling=False
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.get_models") as mock_models_call:
|
||||||
|
mock_models_call.return_value = mock_models
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/llm/models",
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify model capabilities
|
||||||
|
assert len(data["data"]) == 2
|
||||||
|
|
||||||
|
# Check chat model capabilities
|
||||||
|
chat_model = next(m for m in data["data"] if m["id"] == "privatemode-llama-3-70b")
|
||||||
|
assert "tee" in chat_model["capabilities"]
|
||||||
|
assert "chat" in chat_model["capabilities"]
|
||||||
|
assert chat_model["supports_streaming"] is True
|
||||||
|
assert chat_model["supports_function_calling"] is True
|
||||||
|
assert chat_model["context_window"] == 32768
|
||||||
|
|
||||||
|
# Check embedding model capabilities
|
||||||
|
embed_model = next(m for m in data["data"] if m["id"] == "privatemode-embeddings")
|
||||||
|
assert "tee" in embed_model["capabilities"]
|
||||||
|
assert "embeddings" in embed_model["capabilities"]
|
||||||
|
assert embed_model["supports_streaming"] is False
|
||||||
|
assert embed_model["context_window"] == 512
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_requests(self, client: AsyncClient):
|
||||||
|
"""Test handling of concurrent requests."""
|
||||||
|
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
|
||||||
|
|
||||||
|
mock_response = ChatCompletionResponse(
|
||||||
|
id="test-concurrent",
|
||||||
|
object="chat.completion",
|
||||||
|
created=int(time.time()),
|
||||||
|
model="privatemode-llama-3-70b",
|
||||||
|
choices=[
|
||||||
|
ChatChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="Concurrent response"
|
||||||
|
),
|
||||||
|
finish_reason="stop"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=Usage(
|
||||||
|
prompt_tokens=5,
|
||||||
|
completion_tokens=3,
|
||||||
|
total_tokens=8
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
|
||||||
|
mock_chat.return_value = mock_response
|
||||||
|
|
||||||
|
# Create multiple concurrent requests
|
||||||
|
tasks = []
|
||||||
|
for i in range(5):
|
||||||
|
task = client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "privatemode-llama-3-70b",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": f"Concurrent test {i}"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Execute all requests concurrently
|
||||||
|
responses = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Verify all requests succeeded
|
||||||
|
for response in responses:
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "choices" in data
|
||||||
|
assert data["choices"][0]["message"]["content"] == "Concurrent response"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_budget_enforcement_integration(self, client: AsyncClient):
|
||||||
|
"""Test budget enforcement integration with LLM service."""
|
||||||
|
# Test budget exceeded scenario
|
||||||
|
with patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget:
|
||||||
|
mock_budget.side_effect = Exception("Monthly budget limit exceeded")
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "privatemode-llama-3-70b",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Test budget enforcement"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 402 # Payment required
|
||||||
|
|
||||||
|
# Test budget warning scenario
|
||||||
|
from app.services.llm.models import ChatCompletionResponse, ChatChoice, ChatMessage, Usage
|
||||||
|
|
||||||
|
mock_response = ChatCompletionResponse(
|
||||||
|
id="test-budget-warning",
|
||||||
|
object="chat.completion",
|
||||||
|
created=int(time.time()),
|
||||||
|
model="privatemode-llama-3-70b",
|
||||||
|
choices=[
|
||||||
|
ChatChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="Response with budget warning"
|
||||||
|
),
|
||||||
|
finish_reason="stop"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=Usage(
|
||||||
|
prompt_tokens=10,
|
||||||
|
completion_tokens=8,
|
||||||
|
total_tokens=18
|
||||||
|
),
|
||||||
|
budget_warnings=["Approaching monthly budget limit (85% used)"]
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat, \
|
||||||
|
patch("app.services.budget_enforcement.BudgetEnforcementService.check_budget_compliance") as mock_budget:
|
||||||
|
|
||||||
|
mock_chat.return_value = mock_response
|
||||||
|
mock_budget.return_value = True # Budget check passes but with warning
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "privatemode-llama-3-70b",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Test budget warning"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "budget_warnings" in data
|
||||||
|
assert len(data["budget_warnings"]) > 0
|
||||||
|
assert "85%" in data["budget_warnings"][0]
|
||||||
backend/tests/integration/test_llm_validation.py (new file, 296 lines)
@@ -0,0 +1,296 @@
|
|||||||
|
"""
|
||||||
|
Simple validation tests for the new LLM service integration.
|
||||||
|
Tests basic functionality without complex mocking.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMServiceValidation:
|
||||||
|
"""Basic validation tests for LLM service integration."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_models_endpoint_exists(self, client: AsyncClient):
|
||||||
|
"""Test that the LLM models endpoint exists and is accessible."""
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/llm/models",
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not return 404 (endpoint exists)
|
||||||
|
assert response.status_code != 404
|
||||||
|
# May return 500 or other error due to missing LLM service, but endpoint exists
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_chat_endpoint_exists(self, client: AsyncClient):
|
||||||
|
"""Test that the LLM chat endpoint exists and is accessible."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not return 404 (endpoint exists)
|
||||||
|
assert response.status_code != 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_embeddings_endpoint_exists(self, client: AsyncClient):
|
||||||
|
"""Test that the LLM embeddings endpoint exists and is accessible."""
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/embeddings",
|
||||||
|
json={
|
||||||
|
"model": "test-embedding-model",
|
||||||
|
"input": "Test text"
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not return 404 (endpoint exists)
|
||||||
|
assert response.status_code != 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_provider_status_endpoint_exists(self, client: AsyncClient):
|
||||||
|
"""Test that the provider status endpoint exists and is accessible."""
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/llm/providers/status",
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not return 404 (endpoint exists)
|
||||||
|
assert response.status_code != 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_with_mocked_service(self, client: AsyncClient):
|
||||||
|
"""Test chat completion with mocked LLM service."""
|
||||||
|
from app.services.llm.models import ChatResponse, ChatChoice, ChatMessage, TokenUsage
|
||||||
|
|
||||||
|
# Mock successful response
|
||||||
|
mock_response = ChatResponse(
|
||||||
|
id="test-123",
|
||||||
|
object="chat.completion",
|
||||||
|
created=1234567890,
|
||||||
|
model="privatemode-llama-3-70b",
|
||||||
|
provider="PrivateMode.ai",
|
||||||
|
choices=[
|
||||||
|
ChatChoice(
|
||||||
|
index=0,
|
||||||
|
message=ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content="Hello! How can I help you?"
|
||||||
|
),
|
||||||
|
finish_reason="stop"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=TokenUsage(
|
||||||
|
prompt_tokens=10,
|
||||||
|
completion_tokens=8,
|
||||||
|
total_tokens=18
|
||||||
|
),
|
||||||
|
security_check=True,
|
||||||
|
risk_score=0.1,
|
||||||
|
detected_patterns=[],
|
||||||
|
latency_ms=250.5
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.create_chat_completion") as mock_chat:
|
||||||
|
mock_chat.return_value = mock_response
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "privatemode-llama-3-70b",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify basic response structure
|
||||||
|
assert "id" in data
|
||||||
|
assert "choices" in data
|
||||||
|
assert len(data["choices"]) == 1
|
||||||
|
assert data["choices"][0]["message"]["content"] == "Hello! How can I help you?"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embedding_with_mocked_service(self, client: AsyncClient):
|
||||||
|
"""Test embedding generation with mocked LLM service."""
|
||||||
|
from app.services.llm.models import EmbeddingResponse, EmbeddingData, TokenUsage
|
||||||
|
|
||||||
|
# Create a simple embedding vector
|
||||||
|
embedding_vector = [0.1, 0.2, 0.3] * 341 + [0.1, 0.2, 0.3] # 1024 dimensions
|
||||||
|
|
||||||
|
mock_response = EmbeddingResponse(
|
||||||
|
object="list",
|
||||||
|
data=[
|
||||||
|
EmbeddingData(
|
||||||
|
object="embedding",
|
||||||
|
index=0,
|
||||||
|
embedding=embedding_vector
|
||||||
|
)
|
||||||
|
],
|
||||||
|
model="privatemode-embeddings",
|
||||||
|
provider="PrivateMode.ai",
|
||||||
|
usage=TokenUsage(
|
||||||
|
prompt_tokens=5,
|
||||||
|
completion_tokens=0,
|
||||||
|
total_tokens=5
|
||||||
|
),
|
||||||
|
security_check=True,
|
||||||
|
risk_score=0.0,
|
||||||
|
latency_ms=150.0
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.create_embedding") as mock_embedding:
|
||||||
|
mock_embedding.return_value = mock_response
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/embeddings",
|
||||||
|
json={
|
||||||
|
"model": "privatemode-embeddings",
|
||||||
|
"input": "Test text for embedding"
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify basic response structure
|
||||||
|
assert "data" in data
|
||||||
|
assert len(data["data"]) == 1
|
||||||
|
assert len(data["data"][0]["embedding"]) == 1024
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_models_with_mocked_service(self, client: AsyncClient):
|
||||||
|
"""Test models listing with mocked LLM service."""
|
||||||
|
from app.services.llm.models import ModelInfo
|
||||||
|
|
||||||
|
mock_models = [
|
||||||
|
ModelInfo(
|
||||||
|
id="privatemode-llama-3-70b",
|
||||||
|
object="model",
|
||||||
|
created=1234567890,
|
||||||
|
owned_by="PrivateMode.ai",
|
||||||
|
provider="PrivateMode.ai",
|
||||||
|
capabilities=["tee", "chat"],
|
||||||
|
context_window=32768,
|
||||||
|
supports_streaming=True
|
||||||
|
),
|
||||||
|
ModelInfo(
|
||||||
|
id="privatemode-embeddings",
|
||||||
|
object="model",
|
||||||
|
created=1234567890,
|
||||||
|
owned_by="PrivateMode.ai",
|
||||||
|
provider="PrivateMode.ai",
|
||||||
|
capabilities=["tee", "embeddings"],
|
||||||
|
context_window=512,
|
||||||
|
supports_streaming=False
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.get_models") as mock_models_call:
|
||||||
|
mock_models_call.return_value = mock_models
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/llm/models",
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify basic response structure
|
||||||
|
assert "data" in data
|
||||||
|
assert len(data["data"]) == 2
|
||||||
|
assert data["data"][0]["id"] == "privatemode-llama-3-70b"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_status_with_mocked_service(self, client: AsyncClient):
|
||||||
|
"""Test provider status with mocked LLM service."""
|
||||||
|
mock_status = {
|
||||||
|
"privatemode": {
|
||||||
|
"provider": "PrivateMode.ai",
|
||||||
|
"status": "healthy",
|
||||||
|
"latency_ms": 250.5,
|
||||||
|
"success_rate": 0.98,
|
||||||
|
"last_check": "2025-01-01T12:00:00Z",
|
||||||
|
"models_available": ["privatemode-llama-3-70b", "privatemode-embeddings"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("app.services.llm.service.llm_service.get_provider_status") as mock_provider:
|
||||||
|
mock_provider.return_value = mock_status
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/llm/providers/status",
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify basic response structure
|
||||||
|
assert "data" in data
|
||||||
|
assert "privatemode" in data["data"]
|
||||||
|
assert data["data"]["privatemode"]["status"] == "healthy"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unauthorized_access(self, client: AsyncClient):
|
||||||
|
"""Test that unauthorized requests are properly rejected."""
|
||||||
|
# Test without authorization header
|
||||||
|
response = await client.get("/api/v1/llm/models")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/embeddings",
|
||||||
|
json={
|
||||||
|
"model": "test-model",
|
||||||
|
"input": "Hello"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_request_data(self, client: AsyncClient):
|
||||||
|
"""Test that invalid request data is properly handled."""
|
||||||
|
# Test invalid JSON structure
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
# Missing required fields
|
||||||
|
"model": "test-model"
|
||||||
|
# messages field is missing
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 422 # Unprocessable Entity
|
||||||
|
|
||||||
|
# Test empty messages
|
||||||
|
response = await client.post(
|
||||||
|
"/api/v1/llm/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [] # Empty messages
|
||||||
|
},
|
||||||
|
headers={"Authorization": "Bearer test-api-key"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 422 # Unprocessable Entity
|
||||||
@@ -6,7 +6,7 @@ import { Badge } from '@/components/ui/badge'
 import { Button } from '@/components/ui/button'
 import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
 import { Alert, AlertDescription } from '@/components/ui/alert'
-import { RefreshCw, Zap, Info, AlertCircle } from 'lucide-react'
+import { RefreshCw, Zap, Info, AlertCircle, CheckCircle, XCircle, Clock } from 'lucide-react'

 interface Model {
   id: string
@@ -16,6 +16,22 @@ interface Model {
   permission?: any[]
   root?: string
   parent?: string
+  provider?: string
+  capabilities?: string[]
+  context_window?: number
+  max_output_tokens?: number
+  supports_streaming?: boolean
+  supports_function_calling?: boolean
+}
+
+interface ProviderStatus {
+  provider: string
+  status: 'healthy' | 'degraded' | 'unavailable'
+  latency_ms?: number
+  success_rate?: number
+  last_check: string
+  error_message?: string
+  models_available: string[]
 }

 interface ModelSelectorProps {
@@ -27,6 +43,7 @@ interface ModelSelectorProps {

 export default function ModelSelector({ value, onValueChange, filter = 'all', className }: ModelSelectorProps) {
   const [models, setModels] = useState<Model[]>([])
+  const [providerStatus, setProviderStatus] = useState<Record<string, ProviderStatus>>({})
   const [loading, setLoading] = useState(true)
   const [error, setError] = useState<string | null>(null)
   const [showDetails, setShowDetails] = useState(false)
@@ -37,20 +54,31 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
|||||||
|
|
||||||
// Get the auth token from localStorage
|
// Get the auth token from localStorage
|
||||||
const token = localStorage.getItem('token')
|
const token = localStorage.getItem('token')
|
||||||
|
const headers = {
|
||||||
|
'Authorization': token ? `Bearer ${token}` : '',
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
|
||||||
const response = await fetch('/api/llm/models', {
|
// Fetch models and provider status in parallel
|
||||||
headers: {
|
const [modelsResponse, statusResponse] = await Promise.allSettled([
|
||||||
'Authorization': token ? `Bearer ${token}` : '',
|
fetch('/api/llm/models', { headers }),
|
||||||
'Content-Type': 'application/json'
|
fetch('/api/llm/providers/status', { headers })
|
||||||
}
|
])
|
||||||
})
|
|
||||||
|
|
||||||
if (!response.ok) {
|
// Handle models response
|
||||||
|
if (modelsResponse.status === 'fulfilled' && modelsResponse.value.ok) {
|
||||||
|
const modelsData = await modelsResponse.value.json()
|
||||||
|
setModels(modelsData.data || [])
|
||||||
|
} else {
|
||||||
throw new Error('Failed to fetch models')
|
throw new Error('Failed to fetch models')
|
||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json()
|
// Handle provider status response (optional)
|
||||||
setModels(data.data || [])
|
if (statusResponse.status === 'fulfilled' && statusResponse.value.ok) {
|
||||||
|
const statusData = await statusResponse.value.json()
|
||||||
|
setProviderStatus(statusData.data || {})
|
||||||
|
}
|
||||||
|
|
||||||
setError(null)
|
setError(null)
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
setError(err instanceof Error ? err.message : 'Failed to load models')
|
setError(err instanceof Error ? err.message : 'Failed to load models')
|
||||||
@@ -64,30 +92,39 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
|||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
const getProviderFromModel = (modelId: string): string => {
|
const getProviderFromModel = (modelId: string): string => {
|
||||||
|
// PrivateMode models have specific prefixes
|
||||||
|
if (modelId.startsWith('privatemode-')) return 'PrivateMode.ai'
|
||||||
|
|
||||||
|
// Legacy detection for other providers
|
||||||
if (modelId.startsWith('gpt-') || modelId.includes('openai')) return 'OpenAI'
|
if (modelId.startsWith('gpt-') || modelId.includes('openai')) return 'OpenAI'
|
||||||
if (modelId.startsWith('claude-') || modelId.includes('anthropic')) return 'Anthropic'
|
if (modelId.startsWith('claude-') || modelId.includes('anthropic')) return 'Anthropic'
|
||||||
if (modelId.startsWith('gemini-') || modelId.includes('google')) return 'Google'
|
if (modelId.startsWith('gemini-') || modelId.includes('google')) return 'Google'
|
||||||
if (modelId.includes('privatemode')) return 'Privatemode.ai'
|
|
||||||
if (modelId.includes('cohere')) return 'Cohere'
|
if (modelId.includes('cohere')) return 'Cohere'
|
||||||
if (modelId.includes('mistral')) return 'Mistral'
|
if (modelId.includes('mistral')) return 'Mistral'
|
||||||
if (modelId.includes('llama')) return 'Meta'
|
if (modelId.includes('llama') && !modelId.startsWith('privatemode-')) return 'Meta'
|
||||||
return 'Unknown'
|
return 'Unknown'
|
||||||
}
|
}
|
||||||
|
|
||||||
const getModelType = (modelId: string): 'chat' | 'embedding' | 'other' => {
|
const getModelType = (modelId: string): 'chat' | 'embedding' | 'other' => {
|
||||||
if (modelId.includes('embedding')) return 'embedding'
|
if (modelId.includes('embedding') || modelId.includes('embed')) return 'embedding'
|
||||||
if (modelId.includes('whisper')) return 'other' // Audio transcription models
|
if (modelId.includes('whisper') || modelId.includes('speech')) return 'other' // Audio models
|
||||||
|
|
||||||
|
// PrivateMode and other chat models
|
||||||
if (
|
if (
|
||||||
|
modelId.startsWith('privatemode-llama') ||
|
||||||
|
modelId.startsWith('privatemode-claude') ||
|
||||||
|
modelId.startsWith('privatemode-gpt') ||
|
||||||
|
modelId.startsWith('privatemode-gemini') ||
|
||||||
modelId.includes('text-') ||
|
modelId.includes('text-') ||
|
||||||
modelId.includes('gpt-') ||
|
modelId.includes('gpt-') ||
|
||||||
modelId.includes('claude-') ||
|
modelId.includes('claude-') ||
|
||||||
modelId.includes('gemini-') ||
|
modelId.includes('gemini-') ||
|
||||||
modelId.includes('privatemode-') ||
|
|
||||||
modelId.includes('llama') ||
|
modelId.includes('llama') ||
|
||||||
modelId.includes('gemma') ||
|
modelId.includes('gemma') ||
|
||||||
modelId.includes('qwen') ||
|
modelId.includes('qwen') ||
|
||||||
modelId.includes('latest')
|
modelId.includes('latest')
|
||||||
) return 'chat'
|
) return 'chat'
|
||||||
|
|
||||||
return 'other'
|
return 'other'
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,6 +150,28 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
|||||||
return acc
|
return acc
|
||||||
}, {} as Record<string, Model[]>)
|
}, {} as Record<string, Model[]>)
|
||||||
|
|
||||||
|
const getProviderStatusIcon = (provider: string) => {
|
||||||
|
const status = providerStatus[provider.toLowerCase()]?.status || 'unknown'
|
||||||
|
switch (status) {
|
||||||
|
case 'healthy':
|
||||||
|
return <CheckCircle className="h-3 w-3 text-green-500" />
|
||||||
|
case 'degraded':
|
||||||
|
return <Clock className="h-3 w-3 text-yellow-500" />
|
||||||
|
case 'unavailable':
|
||||||
|
return <XCircle className="h-3 w-3 text-red-500" />
|
||||||
|
default:
|
||||||
|
return <AlertCircle className="h-3 w-3 text-gray-400" />
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const getProviderStatusText = (provider: string) => {
|
||||||
|
const status = providerStatus[provider.toLowerCase()]
|
||||||
|
if (!status) return 'Status unknown'
|
||||||
|
|
||||||
|
const latencyText = status.latency_ms ? ` (${Math.round(status.latency_ms)}ms)` : ''
|
||||||
|
return `${status.status.charAt(0).toUpperCase() + status.status.slice(1)}${latencyText}`
|
||||||
|
}
|
||||||
|
|
||||||
const selectedModel = models.find(m => m.id === value)
|
const selectedModel = models.find(m => m.id === value)
|
||||||
|
|
||||||
if (loading) {
|
if (loading) {
|
||||||
@@ -191,16 +250,32 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
|||||||
<SelectContent>
|
<SelectContent>
|
||||||
{Object.entries(groupedModels).map(([provider, providerModels]) => (
|
{Object.entries(groupedModels).map(([provider, providerModels]) => (
|
||||||
<div key={provider}>
|
<div key={provider}>
|
||||||
<div className="px-2 py-1.5 text-sm font-semibold text-muted-foreground">
|
<div className="px-2 py-1.5 text-sm font-semibold text-muted-foreground flex items-center gap-2">
|
||||||
{provider}
|
{getProviderStatusIcon(provider)}
|
||||||
|
<span>{provider}</span>
|
||||||
|
<span className="text-xs font-normal text-muted-foreground">
|
||||||
|
{getProviderStatusText(provider)}
|
||||||
|
</span>
|
||||||
</div>
|
</div>
|
||||||
{providerModels.map((model) => (
|
{providerModels.map((model) => (
|
||||||
<SelectItem key={model.id} value={model.id}>
|
<SelectItem key={model.id} value={model.id}>
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<span>{model.id}</span>
|
<span>{model.id}</span>
|
||||||
<Badge variant="outline" className="text-xs">
|
<div className="flex gap-1">
|
||||||
{getModelCategory(model.id)}
|
<Badge variant="outline" className="text-xs">
|
||||||
</Badge>
|
{getModelCategory(model.id)}
|
||||||
|
</Badge>
|
||||||
|
{model.supports_streaming && (
|
||||||
|
<Badge variant="secondary" className="text-xs">
|
||||||
|
Streaming
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
{model.supports_function_calling && (
|
||||||
|
<Badge variant="secondary" className="text-xs">
|
||||||
|
Functions
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
))}
|
))}
|
||||||
@@ -217,7 +292,7 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
|
|||||||
Model Details
|
Model Details
|
||||||
</CardTitle>
|
</CardTitle>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="space-y-2 text-sm">
|
<CardContent className="space-y-3 text-sm">
|
||||||
<div className="grid grid-cols-2 gap-4">
|
<div className="grid grid-cols-2 gap-4">
|
||||||
<div>
|
<div>
|
||||||
<span className="font-medium">ID:</span>
|
<span className="font-medium">ID:</span>
|
||||||
@@ -225,7 +300,10 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
             </div>
             <div>
               <span className="font-medium">Provider:</span>
-              <div className="text-muted-foreground">{getProviderFromModel(selectedModel.id)}</div>
+              <div className="text-muted-foreground flex items-center gap-1">
+                {getProviderStatusIcon(getProviderFromModel(selectedModel.id))}
+                {getProviderFromModel(selectedModel.id)}
+              </div>
             </div>
             <div>
               <span className="font-medium">Type:</span>
@@ -237,6 +315,40 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
             </div>
           </div>

+          {(selectedModel.context_window || selectedModel.max_output_tokens) && (
+            <div className="grid grid-cols-2 gap-4">
+              {selectedModel.context_window && (
+                <div>
+                  <span className="font-medium">Context Window:</span>
+                  <div className="text-muted-foreground">{selectedModel.context_window.toLocaleString()} tokens</div>
+                </div>
+              )}
+              {selectedModel.max_output_tokens && (
+                <div>
+                  <span className="font-medium">Max Output:</span>
+                  <div className="text-muted-foreground">{selectedModel.max_output_tokens.toLocaleString()} tokens</div>
+                </div>
+              )}
+            </div>
+          )}
+
+          {(selectedModel.supports_streaming || selectedModel.supports_function_calling) && (
+            <div>
+              <span className="font-medium">Capabilities:</span>
+              <div className="flex gap-1 mt-1">
+                {selectedModel.supports_streaming && (
+                  <Badge variant="secondary" className="text-xs">Streaming</Badge>
+                )}
+                {selectedModel.supports_function_calling && (
+                  <Badge variant="secondary" className="text-xs">Function Calling</Badge>
+                )}
+                {selectedModel.capabilities?.includes('tee') && (
+                  <Badge variant="outline" className="text-xs border-green-500 text-green-700">TEE Protected</Badge>
+                )}
+              </div>
+            </div>
+          )}
+
           {selectedModel.created && (
             <div>
               <span className="font-medium">Created:</span>
@@ -252,6 +364,46 @@ export default function ModelSelector({ value, onValueChange, filter = 'all', cl
               <div className="text-muted-foreground">{selectedModel.owned_by}</div>
             </div>
           )}
+
+          {/* Provider Status Details */}
+          {providerStatus[getProviderFromModel(selectedModel.id).toLowerCase()] && (
+            <div className="border-t pt-3">
+              <span className="font-medium">Provider Status:</span>
+              <div className="mt-1 text-xs space-y-1">
+                {(() => {
+                  const status = providerStatus[getProviderFromModel(selectedModel.id).toLowerCase()]
+                  return (
+                    <>
+                      <div className="flex justify-between">
+                        <span>Status:</span>
+                        <span className={`font-medium ${
+                          status.status === 'healthy' ? 'text-green-600' :
+                          status.status === 'degraded' ? 'text-yellow-600' :
+                          'text-red-600'
+                        }`}>{status.status}</span>
+                      </div>
+                      {status.latency_ms && (
+                        <div className="flex justify-between">
+                          <span>Latency:</span>
+                          <span>{Math.round(status.latency_ms)}ms</span>
+                        </div>
+                      )}
+                      {status.success_rate && (
+                        <div className="flex justify-between">
+                          <span>Success Rate:</span>
+                          <span>{Math.round(status.success_rate * 100)}%</span>
+                        </div>
+                      )}
+                      <div className="flex justify-between">
+                        <span>Last Check:</span>
+                        <span>{new Date(status.last_check).toLocaleTimeString()}</span>
+                      </div>
+                    </>
+                  )
+                })()}
+              </div>
+            </div>
+          )}
         </CardContent>
       </Card>
     )}
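The two helpers added above key into a providerStatus record by lowercased provider name and format a short status label. A quick sketch of the expected behavior, assuming the record is populated from /api/llm/providers/status with the same fields as the dashboard's ProviderStatus (how ModelSelector fills providerStatus is outside this hunk, so treat the wiring as illustrative):

// Illustrative only: assumes providerStatus is keyed by lowercased provider name.
const providerStatus: Record<string, { status: string; latency_ms?: number }> = {
  privatemode: { status: 'healthy', latency_ms: 412 },
}

getProviderStatusText('PrivateMode')  // -> "Healthy (412ms)"
getProviderStatusText('unknown')      // -> "Status unknown"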
frontend/src/components/playground/ProviderHealthDashboard.tsx (new file, 373 lines)
@@ -0,0 +1,373 @@
"use client"

import { useState, useEffect } from 'react'
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'
import { Badge } from '@/components/ui/badge'
import { Button } from '@/components/ui/button'
import { Alert, AlertDescription } from '@/components/ui/alert'
import { Progress } from '@/components/ui/progress'
import {
  CheckCircle,
  XCircle,
  Clock,
  AlertCircle,
  RefreshCw,
  Activity,
  Zap,
  Shield,
  Server
} from 'lucide-react'

interface ProviderStatus {
  provider: string
  status: 'healthy' | 'degraded' | 'unavailable'
  latency_ms?: number
  success_rate?: number
  last_check: string
  error_message?: string
  models_available: string[]
}

interface LLMMetrics {
  total_requests: number
  successful_requests: number
  failed_requests: number
  security_blocked_requests: number
  average_latency_ms: number
  average_risk_score: number
  provider_metrics: Record<string, any>
  last_updated: string
}

export default function ProviderHealthDashboard() {
  const [providers, setProviders] = useState<Record<string, ProviderStatus>>({})
  const [metrics, setMetrics] = useState<LLMMetrics | null>(null)
  const [loading, setLoading] = useState(true)
  const [error, setError] = useState<string | null>(null)

  const fetchData = async () => {
    try {
      setLoading(true)
      const token = localStorage.getItem('token')
      const headers = {
        'Authorization': token ? `Bearer ${token}` : '',
        'Content-Type': 'application/json'
      }

      const [statusResponse, metricsResponse] = await Promise.allSettled([
        fetch('/api/llm/providers/status', { headers }),
        fetch('/api/llm/metrics', { headers })
      ])

      // Handle provider status
      if (statusResponse.status === 'fulfilled' && statusResponse.value.ok) {
        const statusData = await statusResponse.value.json()
        setProviders(statusData.data || {})
      }

      // Handle metrics (optional, might require admin permissions)
      if (metricsResponse.status === 'fulfilled' && metricsResponse.value.ok) {
        const metricsData = await metricsResponse.value.json()
        setMetrics(metricsData.data)
      }

      setError(null)
    } catch (err) {
      setError(err instanceof Error ? err.message : 'Failed to load provider data')
    } finally {
      setLoading(false)
    }
  }

  useEffect(() => {
    fetchData()

    // Auto-refresh every 30 seconds
    const interval = setInterval(fetchData, 30000)
    return () => clearInterval(interval)
  }, [])

  const getStatusIcon = (status: string) => {
    switch (status) {
      case 'healthy':
        return <CheckCircle className="h-5 w-5 text-green-500" />
      case 'degraded':
        return <Clock className="h-5 w-5 text-yellow-500" />
      case 'unavailable':
        return <XCircle className="h-5 w-5 text-red-500" />
      default:
        return <AlertCircle className="h-5 w-5 text-gray-400" />
    }
  }

  const getStatusColor = (status: string) => {
    switch (status) {
      case 'healthy':
        return 'text-green-600 bg-green-50 border-green-200'
      case 'degraded':
        return 'text-yellow-600 bg-yellow-50 border-yellow-200'
      case 'unavailable':
        return 'text-red-600 bg-red-50 border-red-200'
      default:
        return 'text-gray-600 bg-gray-50 border-gray-200'
    }
  }

  const getLatencyColor = (latency: number) => {
    if (latency < 500) return 'text-green-600'
    if (latency < 2000) return 'text-yellow-600'
    return 'text-red-600'
  }

  if (loading) {
    return (
      <div className="space-y-4">
        <div className="flex items-center justify-between">
          <h2 className="text-2xl font-bold">Provider Health Dashboard</h2>
          <RefreshCw className="h-5 w-5 animate-spin" />
        </div>
        <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
          {[1, 2, 3].map(i => (
            <Card key={i} className="animate-pulse">
              <CardHeader className="space-y-2">
                <div className="h-4 bg-gray-200 rounded w-3/4"></div>
                <div className="h-3 bg-gray-200 rounded w-1/2"></div>
              </CardHeader>
              <CardContent>
                <div className="space-y-2">
                  <div className="h-3 bg-gray-200 rounded"></div>
                  <div className="h-3 bg-gray-200 rounded w-2/3"></div>
                </div>
              </CardContent>
            </Card>
          ))}
        </div>
      </div>
    )
  }

  if (error) {
    return (
      <div className="space-y-4">
        <div className="flex items-center justify-between">
          <h2 className="text-2xl font-bold">Provider Health Dashboard</h2>
          <Button onClick={fetchData} size="sm">
            <RefreshCw className="h-4 w-4 mr-2" />
            Retry
          </Button>
        </div>
        <Alert variant="destructive">
          <AlertCircle className="h-4 w-4" />
          <AlertDescription>{error}</AlertDescription>
        </Alert>
      </div>
    )
  }

  const totalProviders = Object.keys(providers).length
  const healthyProviders = Object.values(providers).filter(p => p.status === 'healthy').length
  const overallHealth = totalProviders > 0 ? (healthyProviders / totalProviders) * 100 : 0

  return (
    <div className="space-y-6">
      <div className="flex items-center justify-between">
        <h2 className="text-2xl font-bold">Provider Health Dashboard</h2>
        <Button onClick={fetchData} size="sm" disabled={loading}>
          <RefreshCw className={`h-4 w-4 mr-2 ${loading ? 'animate-spin' : ''}`} />
          Refresh
        </Button>
      </div>

      {/* Overall Health Summary */}
      <div className="grid grid-cols-1 md:grid-cols-4 gap-4">
        <Card>
          <CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
            <CardTitle className="text-sm font-medium">Overall Health</CardTitle>
            <Activity className="h-4 w-4 text-muted-foreground" />
          </CardHeader>
          <CardContent>
            <div className="text-2xl font-bold">{Math.round(overallHealth)}%</div>
            <Progress value={overallHealth} className="mt-2" />
          </CardContent>
        </Card>

        <Card>
          <CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
            <CardTitle className="text-sm font-medium">Healthy Providers</CardTitle>
            <CheckCircle className="h-4 w-4 text-green-500" />
          </CardHeader>
          <CardContent>
            <div className="text-2xl font-bold">{healthyProviders}</div>
            <p className="text-xs text-muted-foreground">of {totalProviders} providers</p>
          </CardContent>
        </Card>

        {metrics && (
          <>
            <Card>
              <CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
                <CardTitle className="text-sm font-medium">Success Rate</CardTitle>
                <Zap className="h-4 w-4 text-muted-foreground" />
              </CardHeader>
              <CardContent>
                <div className="text-2xl font-bold">
                  {metrics.total_requests > 0
                    ? Math.round((metrics.successful_requests / metrics.total_requests) * 100)
                    : 0}%
                </div>
                <p className="text-xs text-muted-foreground">
                  {metrics.successful_requests.toLocaleString()} / {metrics.total_requests.toLocaleString()} requests
                </p>
              </CardContent>
            </Card>

            <Card>
              <CardHeader className="flex flex-row items-center justify-between space-y-0 pb-2">
                <CardTitle className="text-sm font-medium">Security Score</CardTitle>
                <Shield className="h-4 w-4 text-muted-foreground" />
              </CardHeader>
              <CardContent>
                <div className="text-2xl font-bold">
                  {Math.round((1 - metrics.average_risk_score) * 100)}%
                </div>
                <p className="text-xs text-muted-foreground">
                  {metrics.security_blocked_requests} blocked requests
                </p>
              </CardContent>
            </Card>
          </>
        )}
      </div>

      {/* Provider Details */}
      <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
        {Object.entries(providers).map(([name, provider]) => (
          <Card key={name} className={`border-2 ${getStatusColor(provider.status)}`}>
            <CardHeader>
              <div className="flex items-center justify-between">
                <CardTitle className="text-lg flex items-center gap-2">
                  {getStatusIcon(provider.status)}
                  {provider.provider}
                </CardTitle>
                <Badge
                  variant={provider.status === 'healthy' ? 'default' : 'destructive'}
                  className="capitalize"
                >
                  {provider.status}
                </Badge>
              </div>
              <CardDescription>
                {provider.models_available.length} models available
              </CardDescription>
            </CardHeader>

            <CardContent className="space-y-3">
              {/* Performance Metrics */}
              {provider.latency_ms && (
                <div className="flex justify-between items-center">
                  <span className="text-sm font-medium">Latency</span>
                  <span className={`text-sm font-mono ${getLatencyColor(provider.latency_ms)}`}>
                    {Math.round(provider.latency_ms)}ms
                  </span>
                </div>
              )}

              {provider.success_rate !== undefined && (
                <div className="flex justify-between items-center">
                  <span className="text-sm font-medium">Success Rate</span>
                  <span className="text-sm font-mono">
                    {Math.round(provider.success_rate * 100)}%
                  </span>
                </div>
              )}

              <div className="flex justify-between items-center">
                <span className="text-sm font-medium">Last Check</span>
                <span className="text-sm text-muted-foreground">
                  {new Date(provider.last_check).toLocaleTimeString()}
                </span>
              </div>

              {/* Error Message */}
              {provider.error_message && (
                <Alert variant="destructive" className="mt-3">
                  <AlertCircle className="h-4 w-4" />
                  <AlertDescription className="text-xs">
                    {provider.error_message}
                  </AlertDescription>
                </Alert>
              )}

              {/* Models */}
              <div className="space-y-2">
                <span className="text-sm font-medium">Available Models</span>
                <div className="flex flex-wrap gap-1">
                  {provider.models_available.slice(0, 3).map(model => (
                    <Badge key={model} variant="outline" className="text-xs">
                      {model}
                    </Badge>
                  ))}
                  {provider.models_available.length > 3 && (
                    <Badge variant="outline" className="text-xs">
                      +{provider.models_available.length - 3} more
                    </Badge>
                  )}
                </div>
              </div>
            </CardContent>
          </Card>
        ))}
      </div>

      {/* Provider Metrics Details */}
      {metrics && Object.keys(metrics.provider_metrics).length > 0 && (
        <Card>
          <CardHeader>
            <CardTitle className="flex items-center gap-2">
              <Server className="h-5 w-5" />
              Provider Performance Metrics
            </CardTitle>
            <CardDescription>
              Detailed performance statistics for each provider
            </CardDescription>
          </CardHeader>
          <CardContent>
            <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
              {Object.entries(metrics.provider_metrics).map(([provider, data]: [string, any]) => (
                <div key={provider} className="border rounded-lg p-4">
                  <h4 className="font-semibold mb-3 capitalize">{provider}</h4>
                  <div className="space-y-2 text-sm">
                    <div className="flex justify-between">
                      <span>Total Requests:</span>
                      <span className="font-mono">{data.total_requests?.toLocaleString() || 0}</span>
                    </div>
                    <div className="flex justify-between">
                      <span>Success Rate:</span>
                      <span className="font-mono">
                        {data.success_rate ? Math.round(data.success_rate * 100) : 0}%
                      </span>
                    </div>
                    <div className="flex justify-between">
                      <span>Avg Latency:</span>
                      <span className={`font-mono ${getLatencyColor(data.average_latency_ms || 0)}`}>
                        {Math.round(data.average_latency_ms || 0)}ms
                      </span>
                    </div>
                    {data.token_usage && (
                      <div className="flex justify-between">
                        <span>Total Tokens:</span>
                        <span className="font-mono">
                          {data.token_usage.total_tokens?.toLocaleString() || 0}
                        </span>
                      </div>
                    )}
                  </div>
                </div>
              ))}
            </div>
          </CardContent>
        </Card>
      )}
    </div>
  )
}
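The dashboard is a self-contained client component with no props: it fetches /api/llm/providers/status and /api/llm/metrics itself and re-polls every 30 seconds. A minimal usage sketch, assuming a hypothetical Next.js app-router page at frontend/src/app/playground/health/page.tsx (adjust the path to wherever the playground routes actually live):

// Hypothetical page path; shown only to illustrate mounting the component.
import ProviderHealthDashboard from '@/components/playground/ProviderHealthDashboard'

export default function ProviderHealthPage() {
  // No props or data loading needed here; the component polls its own endpoints.
  return (
    <div className="container mx-auto py-6">
      <ProviderHealthDashboard />
    </div>
  )
}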