mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
ollama embeddings
This commit is contained in:
152
.env
152
.env
@@ -1,152 +0,0 @@
|
||||
# ===================================
|
||||
# ENCLAVA MINIMAL CONFIGURATION
|
||||
# ===================================
|
||||
# Only essential environment variables that CANNOT have defaults
|
||||
# Other settings should be configurable through the app UI
|
||||
|
||||
# ===================================
|
||||
# INFRASTRUCTURE (Required)
|
||||
# ===================================
|
||||
DATABASE_URL=postgresql://enclava_user:enclava_pass@enclava-postgres:5432/enclava_db
|
||||
REDIS_URL=redis://enclava-redis:6379
|
||||
|
||||
# ===================================
|
||||
# SECURITY CRITICAL (Required)
|
||||
# ===================================
|
||||
JWT_SECRET=your-super-secret-jwt-key-here-change-in-production
|
||||
PRIVATEMODE_API_KEY=dfaea90e-df15-48d4-94ff-5ee243b846bb
|
||||
|
||||
# Admin user (created on first startup only)
|
||||
ADMIN_EMAIL=admin@example.com
|
||||
ADMIN_PASSWORD=admin123
|
||||
API_RATE_LIMITING_ENABLED=false
|
||||
# ===================================
|
||||
# ADDITIONAL SECURITY SETTINGS (Optional but recommended)
|
||||
# ===================================
|
||||
# JWT Algorithm (default: HS256)
|
||||
# JWT_ALGORITHM=HS256
|
||||
|
||||
# Token expiration times (in minutes)
|
||||
# ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
# REFRESH_TOKEN_EXPIRE_MINUTES=10080
|
||||
# SESSION_EXPIRE_MINUTES=1440
|
||||
|
||||
# API Key prefix (default: en_)
|
||||
# API_KEY_PREFIX=en_
|
||||
|
||||
# Security thresholds (0.0-1.0)
|
||||
# API_SECURITY_RISK_THRESHOLD=0.8
|
||||
# API_SECURITY_WARNING_THRESHOLD=0.6
|
||||
# API_SECURITY_ANOMALY_THRESHOLD=0.7
|
||||
|
||||
# IP security (comma-separated for multiple IPs)
|
||||
# API_BLOCKED_IPS=
|
||||
# API_ALLOWED_IPS=
|
||||
|
||||
# ===================================
|
||||
# APPLICATION BASE URL (Required - derives all URLs and CORS)
|
||||
# ===================================
|
||||
BASE_URL=localhost:80
|
||||
# Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost
|
||||
# Backend derives: CORS_ORIGINS=["http://localhost"]
|
||||
|
||||
# ===================================
|
||||
# DOCKER NETWORKING (Required for containers)
|
||||
# ===================================
|
||||
BACKEND_INTERNAL_PORT=8000
|
||||
FRONTEND_INTERNAL_PORT=3000
|
||||
# Hosts are fixed: enclava-backend, enclava-frontend
|
||||
# Upstreams derive: enclava-backend:8000, enclava-frontend:3000
|
||||
|
||||
# ===================================
|
||||
# QDRANT (Required for RAG)
|
||||
# ===================================
|
||||
QDRANT_HOST=enclava-qdrant
|
||||
QDRANT_PORT=6333
|
||||
QDRANT_URL=http://enclava-qdrant:6333
|
||||
|
||||
# ===================================
|
||||
# OPTIONAL PRIVATEMODE SETTINGS (Have defaults)
|
||||
# ===================================
|
||||
# PRIVATEMODE_CACHE_MODE=none # Optional: defaults to 'none'
|
||||
# PRIVATEMODE_CACHE_SALT= # Optional: defaults to empty
|
||||
|
||||
# ===================================
|
||||
# OPTIONAL CONFIGURATION (All have sensible defaults)
|
||||
# ===================================
|
||||
|
||||
# Application Settings
|
||||
# APP_NAME=Enclava
|
||||
# APP_DEBUG=false
|
||||
# APP_LOG_LEVEL=INFO
|
||||
# APP_HOST=0.0.0.0
|
||||
# APP_PORT=8000
|
||||
|
||||
# Security Features
|
||||
API_SECURITY_ENABLED=false
|
||||
# API_THREAT_DETECTION_ENABLED=true
|
||||
# API_IP_REPUTATION_ENABLED=true
|
||||
# API_ANOMALY_DETECTION_ENABLED=true
|
||||
API_RATE_LIMITING_ENABLED=false
|
||||
# API_SECURITY_HEADERS_ENABLED=true
|
||||
|
||||
# Content Security Policy
|
||||
# API_CSP_HEADER=default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'
|
||||
|
||||
# Rate Limiting (requests per minute/hour)
|
||||
# API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE=300
|
||||
# API_RATE_LIMIT_AUTHENTICATED_PER_HOUR=5000
|
||||
# API_RATE_LIMIT_API_KEY_PER_MINUTE=1000
|
||||
# API_RATE_LIMIT_API_KEY_PER_HOUR=20000
|
||||
# API_RATE_LIMIT_PREMIUM_PER_MINUTE=5000
|
||||
# API_RATE_LIMIT_PREMIUM_PER_HOUR=100000
|
||||
|
||||
# Request Size Limits (in bytes)
|
||||
# API_MAX_REQUEST_BODY_SIZE=10485760 # 10MB
|
||||
# API_MAX_REQUEST_BODY_SIZE_PREMIUM=52428800 # 50MB
|
||||
# MAX_UPLOAD_SIZE=10485760 # 10MB
|
||||
|
||||
# Monitoring
|
||||
# PROMETHEUS_ENABLED=true
|
||||
# PROMETHEUS_PORT=9090
|
||||
|
||||
# Logging
|
||||
# LOG_FORMAT=json
|
||||
# LOG_LEVEL=INFO
|
||||
# LOG_LLM_PROMPTS=false
|
||||
|
||||
# Module Configuration
|
||||
# MODULES_CONFIG_PATH=config/modules.yaml
|
||||
|
||||
# Plugin Configuration
|
||||
# PLUGINS_DIR=/plugins
|
||||
# PLUGINS_CONFIG_PATH=config/plugins.yaml
|
||||
# PLUGIN_REPOSITORY_URL=https://plugins.enclava.com
|
||||
# PLUGIN_ENCRYPTION_KEY=
|
||||
|
||||
# ===================================
|
||||
# RAG EMBEDDING ENHANCED SETTINGS
|
||||
# ===================================
|
||||
# Enhanced embedding service configuration
|
||||
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=60
|
||||
RAG_EMBEDDING_BATCH_SIZE=5
|
||||
RAG_EMBEDDING_RETRY_COUNT=3
|
||||
RAG_EMBEDDING_RETRY_DELAYS=1,2,4,8,16
|
||||
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=0.5
|
||||
|
||||
# Fallback embedding behavior
|
||||
RAG_ALLOW_FALLBACK_EMBEDDINGS=true
|
||||
RAG_WARN_ON_FALLBACK=true
|
||||
|
||||
# Processing timeouts (in seconds)
|
||||
RAG_DOCUMENT_PROCESSING_TIMEOUT=300
|
||||
RAG_EMBEDDING_GENERATION_TIMEOUT=120
|
||||
RAG_INDEXING_TIMEOUT=120
|
||||
|
||||
# ===================================
|
||||
# SUMMARY
|
||||
# ===================================
|
||||
# Required: DATABASE_URL, REDIS_URL, JWT_SECRET, ADMIN_EMAIL, ADMIN_PASSWORD, BASE_URL
|
||||
# Recommended: PRIVATEMODE_API_KEY, QDRANT_HOST, QDRANT_PORT
|
||||
# Optional: All other settings have secure defaults
|
||||
# ===================================
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -12,7 +12,7 @@ env.bak/
|
||||
venv.bak/
|
||||
*.sqlite3
|
||||
*.db
|
||||
|
||||
.env
|
||||
# FastAPI logs
|
||||
*.log
|
||||
|
||||
|
||||
350
backend/app/api/health.py
Normal file
350
backend/app/api/health.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
Enhanced Health Check Endpoints
|
||||
|
||||
Provides comprehensive health monitoring including:
|
||||
- Basic HTTP health
|
||||
- Resource usage checks
|
||||
- Session leak detection
|
||||
- Database connectivity
|
||||
- Service dependencies
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import psutil
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.db.database import async_session_factory
|
||||
from app.services.embedding_service import embedding_service
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class HealthChecker:
    """Comprehensive health checking service.

    Each ``check_*`` coroutine returns a plain dict that always contains a
    ``status`` and an ISO-8601 ``timestamp``; ``get_comprehensive_health``
    runs every check and folds the individual statuses into one overall
    status ("healthy" / "degraded" / "unhealthy").
    """

    # Local ports treated as "HTTP" when looking for potential session leaks.
    HTTP_PORTS = frozenset({80, 8000, 3000})

    def __init__(self):
        # Reserved for future use: latest result per check name, plus a
        # rolling history. Not populated by the current implementation.
        self.last_checks: Dict[str, Dict] = {}
        self.check_history: Dict[str, list] = {}

    async def check_database_health(self) -> Dict[str, Any]:
        """Check database connectivity and basic query execution.

        Returns:
            Dict with ``status`` "healthy"/"unhealthy", round-trip time in
            milliseconds on success, and error details on failure.
        """
        start_time = time.time()

        try:
            async with async_session_factory() as session:
                # Simple query to verify connectivity.
                await session.execute(select(1))

                # Touch the catalog to confirm the server answers a real
                # query, not just that it accepts connections.
                await session.execute(text("SELECT COUNT(*) FROM information_schema.tables"))

            duration = time.time() - start_time

            return {
                "status": "healthy",
                "response_time_ms": round(duration * 1000, 2),
                "timestamp": datetime.utcnow().isoformat(),
                "details": {
                    "connection": "successful",
                    "query_execution": "successful"
                }
            }

        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            return {
                "status": "unhealthy",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat(),
                "details": {
                    "connection": "failed",
                    "error_type": type(e).__name__
                }
            }

    async def check_memory_health(self) -> Dict[str, Any]:
        """Check process and system memory usage.

        Status escalates healthy -> warning -> critical based on fixed
        thresholds: 4GB/8GB process RSS and 80%/90% system memory.
        """
        try:
            memory = psutil.virtual_memory()
            process = psutil.Process()

            # Resident set size of this process, converted to MB.
            process_memory = process.memory_info()
            process_memory_mb = process_memory.rss / (1024 * 1024)

            memory_status = "healthy"
            issues = []

            if process_memory_mb > 4000:  # 4GB warning threshold
                memory_status = "warning"
                issues.append(f"High memory usage: {process_memory_mb:.1f}MB")

            if process_memory_mb > 8000:  # 8GB critical threshold
                memory_status = "critical"
                issues.append(f"Critical memory usage: {process_memory_mb:.1f}MB")

            # System-wide memory pressure; "critical" always wins, a
            # system-level warning never downgrades a critical status.
            if memory.percent > 90:
                memory_status = "critical"
                issues.append(f"System memory pressure: {memory.percent:.1f}%")
            elif memory.percent > 80:
                if memory_status == "healthy":
                    memory_status = "warning"
                issues.append(f"High system memory usage: {memory.percent:.1f}%")

            return {
                "status": memory_status,
                "timestamp": datetime.utcnow().isoformat(),
                "process_memory_mb": round(process_memory_mb, 2),
                "system_memory_percent": memory.percent,
                "system_available_gb": round(memory.available / (1024**3), 2),
                "issues": issues
            }

        except Exception as e:
            logger.error(f"Memory health check failed: {e}")
            return {
                "status": "error",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat()
            }

    async def check_connection_health(self) -> Dict[str, Any]:
        """Check for connection leaks and network health.

        Flags high total connection counts and a high number of sockets on
        well-known HTTP ports (a hint of leaked client sessions).
        """
        try:
            process = psutil.Process()

            # Sockets currently owned by this process.
            connections = process.connections()

            total_connections = len(connections)
            established_connections = len([c for c in connections if c.status == 'ESTABLISHED'])
            # Compare the numeric local port directly. The previous
            # substring test (`port in str(c.laddr)` with int ports) raised
            # TypeError at runtime, so this check always reported "error".
            http_connections = len([
                c for c in connections
                if c.laddr and c.laddr.port in self.HTTP_PORTS
            ])

            connection_status = "healthy"
            issues = []

            if total_connections > 500:
                connection_status = "warning"
                issues.append(f"High connection count: {total_connections}")

            if total_connections > 1000:
                connection_status = "critical"
                issues.append(f"Critical connection count: {total_connections}")

            # Potential session leaks: many connections on HTTP ports.
            if http_connections > 100:
                connection_status = "warning"
                issues.append(f"High HTTP connection count: {http_connections}")

            return {
                "status": connection_status,
                "timestamp": datetime.utcnow().isoformat(),
                "total_connections": total_connections,
                "established_connections": established_connections,
                "http_connections": http_connections,
                "issues": issues
            }

        except Exception as e:
            logger.error(f"Connection health check failed: {e}")
            return {
                "status": "error",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat()
            }

    async def check_embedding_service_health(self) -> Dict[str, Any]:
        """Check embedding service health and session management.

        Inspects the stats reported by the embedding service; an
        uninitialized service or a random-fallback backend degrades the
        status to "warning".
        """
        try:
            start_time = time.time()

            # Get embedding service stats (shape defined by the service).
            stats = await embedding_service.get_stats()

            duration = time.time() - start_time

            service_status = "healthy" if stats.get("initialized", False) else "warning"
            issues = []

            if not stats.get("initialized", False):
                issues.append("Embedding service not initialized")

            # A random fallback backend means embeddings are meaningless.
            backend = stats.get("backend", "unknown")
            if backend == "fallback_random":
                service_status = "warning"
                issues.append("Using fallback random embeddings")

            return {
                "status": service_status,
                "response_time_ms": round(duration * 1000, 2),
                "timestamp": datetime.utcnow().isoformat(),
                "stats": stats,
                "issues": issues
            }

        except Exception as e:
            logger.error(f"Embedding service health check failed: {e}")
            return {
                "status": "error",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat()
            }

    async def check_redis_health(self) -> Dict[str, Any]:
        """Check Redis connectivity with a bounded PING.

        Returns "not_configured" when no REDIS_URL is set.
        """
        if not settings.REDIS_URL:
            return {
                "status": "not_configured",
                "timestamp": datetime.utcnow().isoformat()
            }

        try:
            import redis.asyncio as redis

            start_time = time.time()

            client = redis.from_url(
                settings.REDIS_URL,
                socket_connect_timeout=2.0,
                socket_timeout=2.0,
            )

            try:
                # Bound the ping so a hung server cannot stall the check.
                await asyncio.wait_for(client.ping(), timeout=3.0)
                duration = time.time() - start_time
            finally:
                # Always release the client, even when ping fails or times
                # out — the original leaked the connection on error.
                await client.close()

            return {
                "status": "healthy",
                "response_time_ms": round(duration * 1000, 2),
                "timestamp": datetime.utcnow().isoformat()
            }

        except Exception as e:
            logger.error(f"Redis health check failed: {e}")
            return {
                "status": "unhealthy",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat()
            }

    async def get_comprehensive_health(self) -> Dict[str, Any]:
        """Run all checks and aggregate them into one health report.

        Overall status: "unhealthy" when any check is critical/error,
        "degraded" when any is warning/unhealthy, else "healthy".
        """
        checks = {
            "database": await self.check_database_health(),
            "memory": await self.check_memory_health(),
            "connections": await self.check_connection_health(),
            "embedding_service": await self.check_embedding_service_health(),
            "redis": await self.check_redis_health()
        }

        statuses = [check.get("status", "error") for check in checks.values()]

        if "critical" in statuses or "error" in statuses:
            overall_status = "unhealthy"
        elif "warning" in statuses or "unhealthy" in statuses:
            overall_status = "degraded"
        else:
            overall_status = "healthy"

        # Total number of individual issue strings across all checks.
        total_issues = sum(len(check.get("issues", [])) for check in checks.values())

        return {
            "status": overall_status,
            "timestamp": datetime.utcnow().isoformat(),
            "checks": checks,
            "summary": {
                "total_checks": len(checks),
                "healthy_checks": len([s for s in statuses if s == "healthy"]),
                "degraded_checks": len([s for s in statuses if s in ["warning", "degraded", "unhealthy"]]),
                "failed_checks": len([s for s in statuses if s in ["critical", "error"]]),
                "total_issues": total_issues
            },
            "version": "1.0.0",
            # NOTE(review): psutil.boot_time() is the *host* boot time, so
            # this is system uptime, not application uptime — confirm intent.
            "uptime_seconds": int(time.time() - psutil.boot_time())
        }
|
||||
|
||||
|
||||
# Global health checker instance
|
||||
health_checker = HealthChecker()
|
||||
|
||||
|
||||
@router.get("/health")
async def basic_health_check():
    """Basic health check endpoint"""
    # Lightweight liveness probe: no dependency checks, just identity info.
    payload = {
        "status": "healthy",
        "app": settings.APP_NAME,
        "version": "1.0.0",
        "timestamp": datetime.utcnow().isoformat(),
    }
    return payload
|
||||
|
||||
|
||||
@router.get("/health/detailed")
async def detailed_health_check():
    """Comprehensive health check with all services"""
    # Delegate to the aggregator; any unexpected failure becomes a 503.
    try:
        report = await health_checker.get_comprehensive_health()
    except Exception as e:
        logger.error(f"Detailed health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"Health check failed: {str(e)}",
        )
    return report
|
||||
|
||||
|
||||
@router.get("/health/memory")
async def memory_health_check():
    """Memory-specific health check"""
    # Surface only the memory check; failures are mapped to 503.
    try:
        result = await health_checker.check_memory_health()
    except Exception as e:
        logger.error(f"Memory health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"Memory health check failed: {str(e)}",
        )
    return result
|
||||
|
||||
|
||||
@router.get("/health/connections")
async def connection_health_check():
    """Connection-specific health check"""
    # Surface only the connection check; failures are mapped to 503.
    try:
        result = await health_checker.check_connection_health()
    except Exception as e:
        logger.error(f"Connection health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"Connection health check failed: {str(e)}",
        )
    return result
|
||||
|
||||
|
||||
@router.get("/health/embedding")
async def embedding_service_health_check():
    """Embedding service-specific health check"""
    # Surface only the embedding-service check; failures are mapped to 503.
    try:
        result = await health_checker.check_embedding_service_health()
    except Exception as e:
        logger.error(f"Embedding service health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"Embedding service health check failed: {str(e)}",
        )
    return result
|
||||
@@ -778,53 +778,170 @@ async def external_chat_with_chatbot(
|
||||
raise HTTPException(status_code=500, detail=f"Failed to process chat: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/external/{chatbot_id}/chat/completions", response_model=ChatbotChatCompletionResponse)
|
||||
async def external_chatbot_chat_completions(
|
||||
# OpenAI-compatible models response for chatbot
|
||||
class ChatbotModelsResponse(BaseModel):
|
||||
object: str = "list"
|
||||
data: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# Implementation functions for OpenAI compatibility (called by v1 endpoints)
|
||||
async def external_chatbot_models(
|
||||
chatbot_id: str,
|
||||
request: ChatbotChatCompletionRequest,
|
||||
api_key: APIKey = Depends(get_api_key_auth),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
api_key: APIKey,
|
||||
db: AsyncSession
|
||||
):
|
||||
"""External OpenAI-compatible chat completions endpoint for chatbot with API key authentication"""
|
||||
log_api_request("external_chatbot_chat_completions", {
|
||||
"""
|
||||
OpenAI-compatible models endpoint implementation
|
||||
Returns only the model configured for this specific chatbot
|
||||
"""
|
||||
log_api_request("external_chatbot_models", {
|
||||
"chatbot_id": chatbot_id,
|
||||
"api_key_id": api_key.id,
|
||||
"messages_count": len(request.messages)
|
||||
"api_key_id": api_key.id
|
||||
})
|
||||
|
||||
|
||||
try:
|
||||
# Check if API key can access this chatbot
|
||||
if not api_key.can_access_chatbot(chatbot_id):
|
||||
raise HTTPException(status_code=403, detail="API key not authorized for this chatbot")
|
||||
|
||||
|
||||
# Get the chatbot instance
|
||||
result = await db.execute(
|
||||
select(ChatbotInstance)
|
||||
.where(ChatbotInstance.id == chatbot_id)
|
||||
)
|
||||
chatbot = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not chatbot:
|
||||
raise HTTPException(status_code=404, detail="Chatbot not found")
|
||||
|
||||
|
||||
if not chatbot.is_active:
|
||||
raise HTTPException(status_code=400, detail="Chatbot is not active")
|
||||
|
||||
|
||||
# Get the configured model from chatbot config
|
||||
model_name = chatbot.config.get("model", "gpt-3.5-turbo")
|
||||
|
||||
# Return OpenAI-compatible models response with just this model
|
||||
return ChatbotModelsResponse(
|
||||
object="list",
|
||||
data=[
|
||||
{
|
||||
"id": model_name,
|
||||
"object": "model",
|
||||
"created": int(time.time()),
|
||||
"owned_by": "enclava-chatbot"
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
log_api_request("external_chatbot_models_error", {"error": str(e), "chatbot_id": chatbot_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to retrieve models: {str(e)}")
|
||||
|
||||
|
||||
async def external_chatbot_retrieve_model(
    chatbot_id: str,
    model_id: str,
    api_key: APIKey,
    db: AsyncSession
):
    """
    OpenAI-compatible model retrieve endpoint implementation.

    Returns model info if the model matches the chatbot's configured model.
    """
    log_api_request("external_chatbot_retrieve_model", {
        "chatbot_id": chatbot_id,
        "model_id": model_id,
        "api_key_id": api_key.id
    })

    try:
        # Reject keys that are not scoped to this chatbot.
        if not api_key.can_access_chatbot(chatbot_id):
            raise HTTPException(status_code=403, detail="API key not authorized for this chatbot")

        # Load the chatbot row; None means it does not exist.
        lookup = await db.execute(
            select(ChatbotInstance).where(ChatbotInstance.id == chatbot_id)
        )
        instance = lookup.scalar_one_or_none()

        if instance is None:
            raise HTTPException(status_code=404, detail="Chatbot not found")

        if not instance.is_active:
            raise HTTPException(status_code=400, detail="Chatbot is not active")

        # Only the single model configured for this chatbot is retrievable.
        configured_model = instance.config.get("model", "gpt-3.5-turbo")
        if model_id != configured_model:
            raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")

        # OpenAI-compatible model description.
        return {
            "id": configured_model,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "enclava-chatbot"
        }

    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        log_api_request("external_chatbot_retrieve_model_error", {"error": str(e), "chatbot_id": chatbot_id})
        raise HTTPException(status_code=500, detail=f"Failed to retrieve model: {str(e)}")
|
||||
|
||||
|
||||
async def external_chatbot_chat_completions(
|
||||
chatbot_id: str,
|
||||
request: ChatbotChatCompletionRequest,
|
||||
api_key: APIKey,
|
||||
db: AsyncSession
|
||||
):
|
||||
"""External OpenAI-compatible chat completions endpoint implementation with API key authentication"""
|
||||
log_api_request("external_chatbot_chat_completions", {
|
||||
"chatbot_id": chatbot_id,
|
||||
"api_key_id": api_key.id,
|
||||
"messages_count": len(request.messages)
|
||||
})
|
||||
|
||||
try:
|
||||
# Check if API key can access this chatbot
|
||||
if not api_key.can_access_chatbot(chatbot_id):
|
||||
raise HTTPException(status_code=403, detail="API key not authorized for this chatbot")
|
||||
|
||||
# Get the chatbot instance
|
||||
result = await db.execute(
|
||||
select(ChatbotInstance)
|
||||
.where(ChatbotInstance.id == chatbot_id)
|
||||
)
|
||||
chatbot = result.scalar_one_or_none()
|
||||
|
||||
if not chatbot:
|
||||
raise HTTPException(status_code=404, detail="Chatbot not found")
|
||||
|
||||
if not chatbot.is_active:
|
||||
raise HTTPException(status_code=400, detail="Chatbot is not active")
|
||||
|
||||
# Find the last user message to extract conversation context
|
||||
user_messages = [msg for msg in request.messages if msg.role == "user"]
|
||||
if not user_messages:
|
||||
raise HTTPException(status_code=400, detail="No user message found in conversation")
|
||||
|
||||
|
||||
last_user_message = user_messages[-1].content
|
||||
|
||||
|
||||
# Initialize conversation service
|
||||
conversation_service = ConversationService(db)
|
||||
|
||||
|
||||
# For OpenAI format, we'll try to find an existing conversation or create a new one
|
||||
# We'll use a simple hash of the conversation messages as the conversation identifier
|
||||
import hashlib
|
||||
conv_hash = hashlib.md5(str([f"{msg.role}:{msg.content}" for msg in request.messages]).encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
# Get or create conversation with API key context
|
||||
conversation = await conversation_service.get_or_create_conversation(
|
||||
chatbot_id=chatbot_id,
|
||||
@@ -832,12 +949,12 @@ async def external_chatbot_chat_completions(
|
||||
conversation_id=conv_hash,
|
||||
title=f"API Chat {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
|
||||
)
|
||||
|
||||
|
||||
# Add API key metadata to conversation context if new
|
||||
if not conversation.context_data.get("api_key_id"):
|
||||
conversation.context_data = {"api_key_id": api_key.id}
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Build conversation history from the request messages
|
||||
conversation_history = []
|
||||
for msg in request.messages:
|
||||
@@ -846,20 +963,20 @@ async def external_chatbot_chat_completions(
|
||||
"role": msg.role,
|
||||
"content": msg.content
|
||||
})
|
||||
|
||||
|
||||
# Get chatbot module and generate response
|
||||
try:
|
||||
chatbot_module = module_manager.modules.get("chatbot")
|
||||
if not chatbot_module:
|
||||
raise HTTPException(status_code=500, detail="Chatbot module not available")
|
||||
|
||||
|
||||
# Merge chatbot config with request parameters
|
||||
effective_config = dict(chatbot.config)
|
||||
if request.temperature is not None:
|
||||
effective_config["temperature"] = request.temperature
|
||||
if request.max_tokens is not None:
|
||||
effective_config["max_tokens"] = request.max_tokens
|
||||
|
||||
|
||||
# Use the chatbot module to generate a response
|
||||
response_data = await chatbot_module.chat(
|
||||
chatbot_config=effective_config,
|
||||
@@ -867,10 +984,10 @@ async def external_chatbot_chat_completions(
|
||||
conversation_history=conversation_history,
|
||||
user_id=f"api_key_{api_key.id}"
|
||||
)
|
||||
|
||||
|
||||
response_content = response_data.get("response", "I'm sorry, I couldn't generate a response.")
|
||||
sources = response_data.get("sources")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Use fallback response
|
||||
fallback_responses = chatbot.config.get("fallback_responses", [
|
||||
@@ -878,7 +995,7 @@ async def external_chatbot_chat_completions(
|
||||
])
|
||||
response_content = fallback_responses[0] if fallback_responses else "I'm sorry, I couldn't process your request."
|
||||
sources = None
|
||||
|
||||
|
||||
# Save the conversation messages
|
||||
for msg in request.messages:
|
||||
if msg.role == "user": # Only save the new user message
|
||||
@@ -888,7 +1005,7 @@ async def external_chatbot_chat_completions(
|
||||
content=msg.content,
|
||||
metadata={"api_key_id": api_key.id}
|
||||
)
|
||||
|
||||
|
||||
# Save assistant message using conversation service
|
||||
assistant_message = await conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
@@ -897,18 +1014,18 @@ async def external_chatbot_chat_completions(
|
||||
metadata={"api_key_id": api_key.id},
|
||||
sources=sources
|
||||
)
|
||||
|
||||
|
||||
# Update API key usage stats
|
||||
prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
|
||||
completion_tokens = len(response_content.split())
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
|
||||
api_key.update_usage(tokens_used=total_tokens, cost_cents=0)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Create OpenAI-compatible response
|
||||
response_id = f"chatbot-{chatbot_id}-{int(time.time())}"
|
||||
|
||||
|
||||
return ChatbotChatCompletionResponse(
|
||||
id=response_id,
|
||||
object="chat.completion",
|
||||
@@ -927,10 +1044,48 @@ async def external_chatbot_chat_completions(
|
||||
total_tokens=total_tokens
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
log_api_request("external_chatbot_chat_completions_error", {"error": str(e), "chatbot_id": chatbot_id})
|
||||
raise HTTPException(status_code=500, detail=f"Failed to process chat completions: {str(e)}")
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/external/{chatbot_id}/v1/models", response_model=ChatbotModelsResponse)
async def external_chatbot_models_v1(
    chatbot_id: str,
    api_key: APIKey = Depends(get_api_key_auth),
    db: AsyncSession = Depends(get_db)
):
    """OpenAI v1 API compatible models endpoint with /v1 prefix.

    Thin routing wrapper: authentication and DB session come from FastAPI
    dependencies; all logic lives in external_chatbot_models().
    """
    return await external_chatbot_models(chatbot_id, api_key, db)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/external/{chatbot_id}/v1/models/{model_id}")
async def external_chatbot_retrieve_model_v1(
    chatbot_id: str,
    model_id: str,
    api_key: APIKey = Depends(get_api_key_auth),
    db: AsyncSession = Depends(get_db)
):
    """OpenAI v1 API compatible model retrieve endpoint with /v1 prefix.

    Thin routing wrapper: delegates to external_chatbot_retrieve_model(),
    which enforces API-key scoping and 404s on a non-matching model id.
    """
    return await external_chatbot_retrieve_model(chatbot_id, model_id, api_key, db)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/external/{chatbot_id}/v1/chat/completions", response_model=ChatbotChatCompletionResponse)
async def external_chatbot_chat_completions_v1(
    chatbot_id: str,
    request: ChatbotChatCompletionRequest,
    api_key: APIKey = Depends(get_api_key_auth),
    db: AsyncSession = Depends(get_db)
):
    """OpenAI v1 API compatible chat completions endpoint with /v1 prefix.

    Thin routing wrapper: delegates to external_chatbot_chat_completions(),
    which performs authorization, conversation persistence and response
    generation.
    """
    return await external_chatbot_chat_completions(chatbot_id, request, api_key, db)
|
||||
|
||||
@@ -175,19 +175,27 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Cleanup
|
||||
logger.info("Shutting down platform...")
|
||||
|
||||
|
||||
# Cleanup embedding service HTTP sessions
|
||||
from app.services.embedding_service import embedding_service
|
||||
try:
|
||||
await embedding_service.cleanup()
|
||||
logger.info("Embedding service cleaned up successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up embedding service: {e}")
|
||||
|
||||
# Close core cache service
|
||||
from app.core.cache import core_cache
|
||||
await core_cache.cleanup()
|
||||
|
||||
|
||||
# Close Redis connection for cached API key service
|
||||
from app.services.cached_api_key import cached_api_key_service
|
||||
await cached_api_key_service.close()
|
||||
|
||||
|
||||
# Stop document processor
|
||||
if hasattr(app.state, 'document_processor'):
|
||||
await app.state.document_processor.stop()
|
||||
|
||||
|
||||
await module_manager.cleanup()
|
||||
logger.info("Platform shutdown complete")
|
||||
|
||||
|
||||
@@ -139,26 +139,26 @@ async def get_api_key_context(
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Dependency to get API key context from request"""
|
||||
auth_service = APIKeyAuthService(db)
|
||||
|
||||
|
||||
# Try different auth methods
|
||||
api_key = None
|
||||
|
||||
|
||||
# 1. Check Authorization header (Bearer token)
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
api_key = auth_header[7:]
|
||||
|
||||
|
||||
# 2. Check X-API-Key header
|
||||
if not api_key:
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
|
||||
|
||||
# 3. Check query parameter
|
||||
if not api_key:
|
||||
api_key = request.query_params.get("api_key")
|
||||
|
||||
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
|
||||
return await auth_service.validate_api_key(api_key, request)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user