mirror of https://github.com/aljazceru/enclava.git, synced 2025-12-17 07:24:34 +01:00

152 .env Normal file
@@ -0,0 +1,152 @@
# ===================================
# ENCLAVA MINIMAL CONFIGURATION
# ===================================
# Only essential environment variables that CANNOT have defaults
# Other settings should be configurable through the app UI

# ===================================
# INFRASTRUCTURE (Required)
# ===================================
DATABASE_URL=postgresql://enclava_user:enclava_pass@enclava-postgres:5432/enclava_db
REDIS_URL=redis://enclava-redis:6379

# ===================================
# SECURITY CRITICAL (Required)
# ===================================
JWT_SECRET=your-super-secret-jwt-key-here-change-in-production
PRIVATEMODE_API_KEY=dfaea90e-df15-48d4-94ff-5ee243b846bb

# Admin user (created on first startup only)
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=admin123
API_RATE_LIMITING_ENABLED=false
# ===================================
# ADDITIONAL SECURITY SETTINGS (Optional but recommended)
# ===================================
# JWT Algorithm (default: HS256)
# JWT_ALGORITHM=HS256

# Token expiration times (in minutes)
# ACCESS_TOKEN_EXPIRE_MINUTES=30
# REFRESH_TOKEN_EXPIRE_MINUTES=10080
# SESSION_EXPIRE_MINUTES=1440

# API Key prefix (default: en_)
# API_KEY_PREFIX=en_

# Security thresholds (0.0-1.0)
# API_SECURITY_RISK_THRESHOLD=0.8
# API_SECURITY_WARNING_THRESHOLD=0.6
# API_SECURITY_ANOMALY_THRESHOLD=0.7

# IP security (comma-separated for multiple IPs)
# API_BLOCKED_IPS=
# API_ALLOWED_IPS=

# ===================================
# APPLICATION BASE URL (Required - derives all URLs and CORS)
# ===================================
BASE_URL=localhost:80
# Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost
# Backend derives: CORS_ORIGINS=["http://localhost"]

# ===================================
# DOCKER NETWORKING (Required for containers)
# ===================================
BACKEND_INTERNAL_PORT=8000
FRONTEND_INTERNAL_PORT=3000
# Hosts are fixed: enclava-backend, enclava-frontend
# Upstreams derive: enclava-backend:8000, enclava-frontend:3000

# ===================================
# QDRANT (Required for RAG)
# ===================================
QDRANT_HOST=enclava-qdrant
QDRANT_PORT=6333
QDRANT_URL=http://enclava-qdrant:6333

# ===================================
# OPTIONAL PRIVATEMODE SETTINGS (Have defaults)
# ===================================
# PRIVATEMODE_CACHE_MODE=none   # Optional: defaults to 'none'
# PRIVATEMODE_CACHE_SALT=       # Optional: defaults to empty

# ===================================
# OPTIONAL CONFIGURATION (All have sensible defaults)
# ===================================

# Application Settings
# APP_NAME=Enclava
# APP_DEBUG=false
# APP_LOG_LEVEL=INFO
# APP_HOST=0.0.0.0
# APP_PORT=8000

# Security Features
API_SECURITY_ENABLED=false
# API_THREAT_DETECTION_ENABLED=true
# API_IP_REPUTATION_ENABLED=true
# API_ANOMALY_DETECTION_ENABLED=true
API_RATE_LIMITING_ENABLED=false
# API_SECURITY_HEADERS_ENABLED=true

# Content Security Policy
# API_CSP_HEADER=default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'

# Rate Limiting (requests per minute/hour)
# API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE=300
# API_RATE_LIMIT_AUTHENTICATED_PER_HOUR=5000
# API_RATE_LIMIT_API_KEY_PER_MINUTE=1000
# API_RATE_LIMIT_API_KEY_PER_HOUR=20000
# API_RATE_LIMIT_PREMIUM_PER_MINUTE=5000
# API_RATE_LIMIT_PREMIUM_PER_HOUR=100000

# Request Size Limits (in bytes)
# API_MAX_REQUEST_BODY_SIZE=10485760  # 10MB
# API_MAX_REQUEST_BODY_SIZE_PREMIUM=52428800  # 50MB
# MAX_UPLOAD_SIZE=10485760  # 10MB

# Monitoring
# PROMETHEUS_ENABLED=true
# PROMETHEUS_PORT=9090

# Logging
# LOG_FORMAT=json
# LOG_LEVEL=INFO
# LOG_LLM_PROMPTS=false

# Module Configuration
# MODULES_CONFIG_PATH=config/modules.yaml

# Plugin Configuration
# PLUGINS_DIR=/plugins
# PLUGINS_CONFIG_PATH=config/plugins.yaml
# PLUGIN_REPOSITORY_URL=https://plugins.enclava.com
# PLUGIN_ENCRYPTION_KEY=

# ===================================
# RAG EMBEDDING ENHANCED SETTINGS
# ===================================
# Enhanced embedding service configuration
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=60
RAG_EMBEDDING_BATCH_SIZE=5
RAG_EMBEDDING_RETRY_COUNT=3
RAG_EMBEDDING_RETRY_DELAYS=1,2,4,8,16
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=0.5

# Fallback embedding behavior
RAG_ALLOW_FALLBACK_EMBEDDINGS=true
RAG_WARN_ON_FALLBACK=true

# Processing timeouts (in seconds)
RAG_DOCUMENT_PROCESSING_TIMEOUT=300
RAG_EMBEDDING_GENERATION_TIMEOUT=120
RAG_INDEXING_TIMEOUT=120

# ===================================
# SUMMARY
# ===================================
# Required: DATABASE_URL, REDIS_URL, JWT_SECRET, ADMIN_EMAIL, ADMIN_PASSWORD, BASE_URL
# Recommended: PRIVATEMODE_API_KEY, QDRANT_HOST, QDRANT_PORT
# Optional: All other settings have secure defaults
# ===================================
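The comments above describe how the single BASE_URL value is expanded into the frontend URLs and the backend CORS origins. As a rough illustration only (the helper name and the http:// default scheme are assumptions, not project code), the derivation could look like this in Python:

from urllib.parse import urlsplit

def derive_urls(base_url: str) -> dict:
    # Accept values with or without a scheme, e.g. "localhost:80" or "https://example.com"
    if "://" not in base_url:
        base_url = f"http://{base_url}"
    parts = urlsplit(base_url)
    port = f":{parts.port}" if parts.port and parts.port not in (80, 443) else ""
    ws_scheme = "wss" if parts.scheme == "https" else "ws"
    origin = f"{parts.scheme}://{parts.hostname}{port}"
    return {
        "APP_URL": origin,
        "API_URL": origin,
        "WS_URL": f"{ws_scheme}://{parts.hostname}{port}",
        "CORS_ORIGINS": [origin],
    }

# derive_urls("localhost:80") -> APP_URL=http://localhost, WS_URL=ws://localhost, CORS_ORIGINS=["http://localhost"]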
10 .env.example
@@ -77,6 +77,16 @@ QDRANT_HOST=enclava-qdrant
QDRANT_PORT=6333
QDRANT_URL=http://enclava-qdrant:6333

# ===================================
# RAG EMBEDDING CONFIGURATION (Optional overrides)
# ===================================
# These control embedding throughput to avoid provider 429s.
# Defaults are conservative; uncomment to override.
# RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=12
# RAG_EMBEDDING_BATCH_SIZE=3
# RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=1.0  # seconds
# RAG_EMBEDDING_DELAY_PER_REQUEST=0.5  # seconds

# ===================================
# OPTIONAL PRIVATEMODE SETTINGS (Have defaults)
# ===================================
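These overrides exist to keep embedding traffic under the provider's rate limits. A minimal sketch of how such knobs are typically consumed (the embed_batch callable and the constant names below are placeholders, not Enclava's embedding service):

import asyncio

RETRY_DELAYS = [1, 2, 4, 8, 16]       # mirrors RAG_EMBEDDING_RETRY_DELAYS=1,2,4,8,16
BATCH_SIZE = 3                         # mirrors RAG_EMBEDDING_BATCH_SIZE
DELAY_BETWEEN_BATCHES = 1.0            # mirrors RAG_EMBEDDING_DELAY_BETWEEN_BATCHES (seconds)

async def embed_all(chunks, embed_batch):
    vectors = []
    for i in range(0, len(chunks), BATCH_SIZE):
        batch = chunks[i:i + BATCH_SIZE]
        for attempt, delay in enumerate([0] + RETRY_DELAYS):
            if delay:
                await asyncio.sleep(delay)           # back off after a 429 or timeout
            try:
                vectors.extend(await embed_batch(batch))
                break
            except Exception:
                if attempt == len(RETRY_DELAYS):
                    raise                             # retries exhausted
        await asyncio.sleep(DELAY_BETWEEN_BATCHES)    # stay under the per-minute budget
    return vectors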
69 .gitignore vendored
@@ -1,3 +1,68 @@
.env
backend/.config_encryption_key
# Python
__pycache__/
*.py[cod]
*.pyo
*.pyd
*.env
*.venv
env/
venv/
ENV/
env.bak/
venv.bak/
*.sqlite3
*.db

# FastAPI logs
*.log

# Node.js
node_modules/
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*

# Next.js build
frontend/.next/
frontend/out/
frontend/.env.local
frontend/.env.production
frontend/.env.development

backend/storage/

# TypeScript
*.tsbuildinfo

# Coverage reports
htmlcov/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.pyc
*.pyo
*.pyd
.pytest_cache/
backend/.pytest_cache/
backend/.mypy_cache/
.mypy_cache/
*.prof

backend/_to_delete/
backend/__pycache__/
backend/app/core/__pycache__/
backend/app/services/__pycache__/
backend/app/services/llm/__pycache__/
backend/app/services/llm/providers/__pycache__/
backend/app/utils/__pycache__/
backend/modules/rag/__pycache__/
frontend/.next/
frontend/node_modules/
node_modules/
venv/
0 backend/.env Normal file

@@ -17,6 +17,9 @@ RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/*

# Install CPU-only PyTorch and compatible numpy first (faster download)
RUN pip install --no-cache-dir torch==2.5.1+cpu torchaudio==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu -f https://download.pytorch.org/whl/torch_stable.html

# Copy requirements and install Python dependencies
COPY requirements.txt .
COPY tests/requirements-test.txt ./tests/
@@ -12,8 +12,8 @@ from ..v1.audit import router as audit_router
from ..v1.settings import router as settings_router
from ..v1.analytics import router as analytics_router
from ..v1.rag import router as rag_router
from ..rag_debug import router as rag_debug_router
from ..v1.prompt_templates import router as prompt_templates_router
from ..v1.security import router as security_router
from ..v1.plugin_registry import router as plugin_registry_router
from ..v1.platform import router as platform_router
from ..v1.llm_internal import router as llm_internal_router
@@ -53,11 +53,12 @@ internal_api_router.include_router(analytics_router, prefix="/analytics", tags=[
# Include RAG routes (frontend RAG document management)
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])

# Include RAG debug routes (for demo and debugging)
internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])

# Include prompt template routes (frontend prompt template management)
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])

# Include security routes (frontend security settings)
internal_api_router.include_router(security_router, prefix="/security", tags=["internal-security"])

# Include plugin registry routes (frontend plugin management)
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])
97 backend/app/api/rag_debug.py Normal file
@@ -0,0 +1,97 @@
"""
RAG Debug API endpoints for testing and debugging
"""

from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Dict, Any, Optional
import logging

from app.core.security import get_current_user
from app.core.config import settings
from app.modules.rag.main import RAGModule
from app.models.user import User

logger = logging.getLogger(__name__)

# Create router
router = APIRouter()

@router.get("/collections")
async def list_collections(
    current_user: User = Depends(get_current_user)
):
    """List all available RAG collections"""
    try:
        from app.services.qdrant_stats_service import qdrant_stats_service

        # Get collections from Qdrant (same as main RAG API)
        stats_data = await qdrant_stats_service.get_collections_stats()
        collections = stats_data.get("collections", [])

        # Extract collection names
        collection_names = [col["name"] for col in collections]

        return {
            "collections": collection_names,
            "count": len(collection_names)
        }

    except Exception as e:
        logger.error(f"List collections error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/search")
async def debug_search(
    query: str = Query(..., description="Search query"),
    max_results: int = Query(10, ge=1, le=50, description="Maximum number of results"),
    score_threshold: float = Query(0.3, ge=0.0, le=1.0, description="Minimum score threshold"),
    collection_name: Optional[str] = Query(None, description="Collection name to search"),
    config: Optional[Dict[str, Any]] = None,
    current_user: User = Depends(get_current_user)
):
    """Debug search endpoint with detailed information"""
    try:
        # Get configuration
        app_config = settings

        # Initialize RAG module
        rag_module = RAGModule(app_config)

        # Get available collections if none specified
        if not collection_name:
            collections = await rag_module.list_collections()
            if collections:
                collection_name = collections[0]  # Use first collection
            else:
                return {
                    "results": [],
                    "debug_info": {
                        "error": "No collections available",
                        "collections_found": 0
                    },
                    "search_time_ms": 0
                }

        # Perform search
        results = await rag_module.search(
            query=query,
            max_results=max_results,
            score_threshold=score_threshold,
            collection_name=collection_name,
            config=config or {}
        )

        return results

    except Exception as e:
        logger.error(f"Debug search error: {e}")
        return {
            "results": [],
            "debug_info": {
                "error": str(e),
                "query": query,
                "collection_name": collection_name
            },
            "search_time_ms": 0
        }
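A hypothetical client call against the new debug search endpoint; the mount prefix below is an assumption (it depends on where internal_api_router is exposed), so adjust the URL to the actual deployment:

import httpx

resp = httpx.post(
    "http://localhost:8000/api-internal/v1/rag/debug/search",  # assumed prefix
    params={"query": "refund policy", "max_results": 5, "score_threshold": 0.3},
    headers={"Authorization": "Bearer <jwt>"},
)
print(resp.json())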
@@ -16,7 +16,6 @@ from .analytics import router as analytics_router
from .rag import router as rag_router
from .chatbot import router as chatbot_router
from .prompt_templates import router as prompt_templates_router
from .security import router as security_router
from .plugin_registry import router as plugin_registry_router

# Create main API router
@@ -61,8 +60,6 @@ api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
# Include prompt template routes
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])

# Include security routes
api_router.include_router(security_router, prefix="/security", tags=["security"])

# Include plugin registry routes
@@ -1,6 +1,4 @@
"""
Authentication API endpoints
"""
"""Authentication API endpoints"""

import logging
from datetime import datetime, timedelta
@@ -13,6 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.core.config import settings
from app.core.logging import get_logger
from app.core.security import (
    verify_password,
    get_password_hash,
@@ -26,7 +25,7 @@ from app.db.database import get_db
from app.models.user import User
from app.utils.exceptions import AuthenticationError, ValidationError

logger = logging.getLogger(__name__)
logger = get_logger(__name__)

router = APIRouter()
security = HTTPBearer()
@@ -162,88 +161,70 @@ async def login(
):
    """Login user and return access tokens"""

    logger.info(f"==================================================")
    logger.info(f"=== LOGIN ENDPOINT REACHED ===")
    logger.info(f"==================================================")
    logger.info(f"=== LOGIN DEBUG ===")
    logger.info(f"Login attempt for email: {user_data.email}")
    logger.info(f"Current UTC time: {datetime.utcnow().isoformat()}")
    logger.info(f"Settings check - DATABASE_URL: {'SET' if settings.DATABASE_URL else 'NOT SET'}")
    logger.info(f"Settings check - JWT_SECRET: {'SET' if settings.JWT_SECRET else 'NOT SET'}")
    logger.info(f"Settings check - ADMIN_EMAIL: {settings.ADMIN_EMAIL}")
    logger.info(f"Settings check - BCRYPT_ROUNDS: {settings.BCRYPT_ROUNDS}")

    # DEBUG: Check Redis connection
    try:
        logger.info("Testing Redis connection...")
        import redis.asyncio as redis
        redis_url = settings.REDIS_URL
        logger.info(f"Redis URL: {redis_url}")
        redis_client = redis.from_url(redis_url)
        test_start = datetime.utcnow()
        await redis_client.ping()
        test_end = datetime.utcnow()
        logger.info(f"Redis connection test successful. Time: {(test_end - test_start).total_seconds()} seconds")
        await redis_client.close()
    except Exception as e:
        logger.error(f"Redis connection test failed: {e}")

    # DEBUG: Check database connection with timeout
    try:
        logger.info("Testing database connection...")
        test_start = datetime.utcnow()
        await db.execute(select(1))
        test_end = datetime.utcnow()
        logger.info(f"Database connection test successful. Time: {(test_end - test_start).total_seconds()} seconds")
    except Exception as e:
        logger.error(f"Database connection test failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Database connection error"
    logger.info(
        "LOGIN_DEBUG_START",
        request_time=datetime.utcnow().isoformat(),
        email=user_data.email,
        database_url="SET" if settings.DATABASE_URL else "NOT SET",
        jwt_secret="SET" if settings.JWT_SECRET else "NOT SET",
        admin_email=settings.ADMIN_EMAIL,
        bcrypt_rounds=settings.BCRYPT_ROUNDS,
    )

    start_time = datetime.utcnow()

    # Get user by email
    logger.info("Querying user by email...")
    logger.info("LOGIN_USER_QUERY_START")
    query_start = datetime.utcnow()
    stmt = select(User).where(User.email == user_data.email)
    result = await db.execute(stmt)
    query_end = datetime.utcnow()
    logger.info(f"User query completed. Time: {(query_end - query_start).total_seconds()} seconds")
    logger.info(
        "LOGIN_USER_QUERY_END",
        duration_seconds=(query_end - query_start).total_seconds(),
    )

    user = result.scalar_one_or_none()

    if not user:
        logger.warning(f"User not found: {user_data.email}")
        logger.warning("LOGIN_USER_NOT_FOUND", email=user_data.email)
        # List available users for debugging
        try:
            all_users_stmt = select(User).limit(5)
            all_users_result = await db.execute(all_users_stmt)
            all_users = all_users_result.scalars().all()
            logger.info(f"Available users (first 5): {[u.email for u in all_users]}")
            logger.info(
                "LOGIN_USER_LIST",
                users=[u.email for u in all_users],
            )
        except Exception as e:
            logger.error(f"Could not list users: {e}")
            logger.error("LOGIN_USER_LIST_FAILURE", error=str(e))

        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password"
        )

    logger.info(f"User found: {user.email}, is_active: {user.is_active}")
    logger.info(f"User found, starting password verification...")
    logger.info("LOGIN_USER_FOUND", email=user.email, is_active=user.is_active)
    logger.info("LOGIN_PASSWORD_VERIFY_START")
    verify_start = datetime.utcnow()

    if not verify_password(user_data.password, user.hashed_password):
        verify_end = datetime.utcnow()
        logger.warning(f"Password verification failed. Time taken: {(verify_end - verify_start).total_seconds()} seconds")
        logger.warning(
            "LOGIN_PASSWORD_VERIFY_FAILURE",
            duration_seconds=(verify_end - verify_start).total_seconds(),
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password"
        )

    verify_end = datetime.utcnow()
    logger.info(f"Password verification successful. Time taken: {(verify_end - verify_start).total_seconds()} seconds")
    logger.info(
        "LOGIN_PASSWORD_VERIFY_SUCCESS",
        duration_seconds=(verify_end - verify_start).total_seconds(),
    )

    if not user.is_active:
        raise HTTPException(
@@ -252,19 +233,20 @@ async def login(
        )

    # Update last login
    logger.info("Updating last login...")
    logger.info("LOGIN_LAST_LOGIN_UPDATE_START")
    update_start = datetime.utcnow()
    user.update_last_login()
    await db.commit()
    update_end = datetime.utcnow()
    logger.info(f"Last login updated. Time: {(update_end - update_start).total_seconds()} seconds")
    logger.info(
        "LOGIN_LAST_LOGIN_UPDATE_SUCCESS",
        duration_seconds=(update_end - update_start).total_seconds(),
    )

    # Create tokens
    logger.info("Creating tokens...")
    logger.info("LOGIN_TOKEN_CREATE_START")
    token_start = datetime.utcnow()
    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    logger.info(f"Creating access token with expiration: {access_token_expires}")
    logger.info(f"ACCESS_TOKEN_EXPIRE_MINUTES from settings: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")

    access_token = create_access_token(
        data={
@@ -280,10 +262,16 @@ async def login(
        data={"sub": str(user.id), "type": "refresh"}
    )
    token_end = datetime.utcnow()
    logger.info(f"Tokens created. Time: {(token_end - token_start).total_seconds()} seconds")
    logger.info(
        "LOGIN_TOKEN_CREATE_SUCCESS",
        duration_seconds=(token_end - token_start).total_seconds(),
    )

    total_time = datetime.utcnow() - start_time
    logger.info(f"=== LOGIN COMPLETED === Total time: {total_time.total_seconds()} seconds")
    logger.info(
        "LOGIN_DEBUG_COMPLETE",
        total_duration_seconds=total_time.total_seconds(),
    )

    return TokenResponse(
        access_token=access_token,
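The change above swaps interpolated f-string log lines for named events with keyword fields. A minimal sketch of the pattern, assuming get_logger returns a structlog-style logger that accepts keyword arguments (the rendered output format depends on the logging configuration):

logger = get_logger(__name__)
logger.warning("LOGIN_PASSWORD_VERIFY_FAILURE", duration_seconds=0.31)
# emits one machine-parseable event, e.g. {"event": "LOGIN_PASSWORD_VERIFY_FAILURE", "duration_seconds": 0.31}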
@@ -32,12 +32,28 @@ class ChatbotCreateRequest(BaseModel):
    use_rag: bool = False
    rag_collection: Optional[str] = None
    rag_top_k: int = 5
    rag_score_threshold: float = 0.02  # Lowered from default 0.3 to allow more results
    temperature: float = 0.7
    max_tokens: int = 1000
    memory_length: int = 10
    fallback_responses: List[str] = []


class ChatbotUpdateRequest(BaseModel):
    name: Optional[str] = None
    chatbot_type: Optional[str] = None
    model: Optional[str] = None
    system_prompt: Optional[str] = None
    use_rag: Optional[bool] = None
    rag_collection: Optional[str] = None
    rag_top_k: Optional[int] = None
    rag_score_threshold: Optional[float] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    memory_length: Optional[int] = None
    fallback_responses: Optional[List[str]] = None


class ChatRequest(BaseModel):
    message: str
    conversation_id: Optional[str] = None
@@ -190,7 +206,7 @@ async def create_chatbot(
@router.put("/update/{chatbot_id}")
async def update_chatbot(
    chatbot_id: str,
    request: ChatbotCreateRequest,
    request: ChatbotUpdateRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
@@ -214,28 +230,23 @@ async def update_chatbot(
    if not chatbot:
        raise HTTPException(status_code=404, detail="Chatbot not found or access denied")

    # Update chatbot configuration
    config = {
        "name": request.name,
        "chatbot_type": request.chatbot_type,
        "model": request.model,
        "system_prompt": request.system_prompt,
        "use_rag": request.use_rag,
        "rag_collection": request.rag_collection,
        "rag_top_k": request.rag_top_k,
        "temperature": request.temperature,
        "max_tokens": request.max_tokens,
        "memory_length": request.memory_length,
        "fallback_responses": request.fallback_responses
    }
    # Get existing config
    existing_config = chatbot.config.copy() if chatbot.config else {}

    # Update only the fields that are provided in the request
    update_data = request.dict(exclude_unset=True)

    # Merge with existing config, preserving unset values
    for key, value in update_data.items():
        existing_config[key] = value

    # Update the chatbot
    await db.execute(
        update(ChatbotInstance)
        .where(ChatbotInstance.id == chatbot_id)
        .values(
            name=request.name,
            config=config,
            name=existing_config.get("name", chatbot.name),
            config=existing_config,
            updated_at=datetime.utcnow()
        )
    )
@@ -275,7 +286,7 @@ async def chat_with_chatbot(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Send a message to a chatbot and get a response"""
    """Send a message to a chatbot and get a response (without persisting conversation)"""
    user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
    log_api_request("chat_with_chatbot", {
        "user_id": user_id,
@@ -298,42 +309,17 @@ async def chat_with_chatbot(
    if not chatbot.is_active:
        raise HTTPException(status_code=400, detail="Chatbot is not active")

    # Initialize conversation service
    conversation_service = ConversationService(db)

    # Get or create conversation
    conversation = await conversation_service.get_or_create_conversation(
        chatbot_id=chatbot_id,
        user_id=str(user_id),
        conversation_id=request.conversation_id
    )

    # Add user message to conversation
    await conversation_service.add_message(
        conversation_id=conversation.id,
        role="user",
        content=request.message,
        metadata={}
    )

    # Get chatbot module and generate response
    try:
        chatbot_module = module_manager.modules.get("chatbot")
        if not chatbot_module:
            raise HTTPException(status_code=500, detail="Chatbot module not available")

        # Load conversation history for context
        conversation_history = await conversation_service.get_conversation_history(
            conversation_id=conversation.id,
            limit=chatbot.config.get('memory_length', 10),
            include_system=False
        )

        # Use the chatbot module to generate a response
        # Use the chatbot module to generate a response (without persisting)
        response_data = await chatbot_module.chat(
            chatbot_config=chatbot.config,
            message=request.message,
            conversation_history=conversation_history,
            conversation_history=[],  # Empty history for test chat
            user_id=str(user_id)
        )
@@ -346,19 +332,10 @@ async def chat_with_chatbot(
        ])
        response_content = fallback_responses[0] if fallback_responses else "I'm sorry, I couldn't process your request."

        # Save assistant message using conversation service
        assistant_message = await conversation_service.add_message(
            conversation_id=conversation.id,
            role="assistant",
            content=response_content,
            metadata={},
            sources=response_data.get("sources")
        )

        # Return response without conversation ID (since we're not persisting)
        return {
            "conversation_id": conversation.id,
            "response": response_content,
            "timestamp": assistant_message.timestamp.isoformat()
            "sources": response_data.get("sources")
        }

    except HTTPException:
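The partial-update logic above leans on Pydantic's exclude_unset so that fields the client did not send are left untouched. The same merge behaviour in isolation (the model and values here are illustrative, not project code):

from typing import Optional
from pydantic import BaseModel

class UpdateRequest(BaseModel):
    name: Optional[str] = None
    temperature: Optional[float] = None

existing_config = {"name": "Support bot", "temperature": 0.7, "use_rag": True}
request = UpdateRequest(temperature=0.2)        # client only sent temperature

update_data = request.dict(exclude_unset=True)  # {"temperature": 0.2}; unset fields are dropped
existing_config.update(update_data)             # "name" and "use_rag" survive, only "temperature" changes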
@@ -745,7 +745,6 @@ async def get_llm_metrics(
        "total_requests": metrics.total_requests,
        "successful_requests": metrics.successful_requests,
        "failed_requests": metrics.failed_requests,
        "security_blocked_requests": metrics.security_blocked_requests,
        "average_latency_ms": metrics.average_latency_ms,
        "average_risk_score": metrics.average_risk_score,
        "provider_metrics": metrics.provider_metrics,
@@ -493,6 +493,25 @@ async def seed_default_templates(
            existing_template.system_prompt = template_data["prompt"]
            existing_template.updated_at = datetime.utcnow()
            updated_templates.append(type_key)
        else:
            # Check if any inactive template exists with this type_key
            inactive_result = await db.execute(
                select(PromptTemplate)
                .where(PromptTemplate.type_key == type_key)
                .where(PromptTemplate.is_active == False)
            )
            inactive_template = inactive_result.scalar_one_or_none()

            if inactive_template:
                # Reactivate the inactive template
                inactive_template.is_active = True
                inactive_template.name = template_data["name"]
                inactive_template.description = template_data["description"]
                inactive_template.system_prompt = template_data["prompt"]
                inactive_template.is_default = True
                inactive_template.version = 1
                inactive_template.updated_at = datetime.utcnow()
                updated_templates.append(type_key)
            else:
                # Create new template
                new_template = PromptTemplate(
@@ -3,12 +3,14 @@ RAG API Endpoints
Provides REST API for RAG (Retrieval Augmented Generation) operations
"""

from typing import List, Optional
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
import io
import asyncio
from datetime import datetime

from app.db.database import get_db
from app.core.security import get_current_user
@@ -16,6 +18,9 @@ from app.models.user import User
from app.services.rag_service import RAGService
from app.utils.exceptions import APIException

# Import RAG module from module manager
from app.services.module_manager import module_manager


router = APIRouter(tags=["RAG"])

@@ -78,14 +83,25 @@ async def get_collections(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata"""
    """Get all RAG collections - live data directly from Qdrant (source of truth)"""
    try:
        rag_service = RAGService(db)
        collections_data = await rag_service.get_all_collections(skip=skip, limit=limit)
        from app.services.qdrant_stats_service import qdrant_stats_service

        # Get live stats from Qdrant
        stats_data = await qdrant_stats_service.get_collections_stats()
        collections = stats_data.get("collections", [])

        # Apply pagination
        start_idx = skip
        end_idx = skip + limit
        paginated_collections = collections[start_idx:end_idx]

        return {
            "success": True,
            "collections": collections_data,
            "total": len(collections_data)
            "collections": paginated_collections,
            "total": len(collections),
            "total_documents": stats_data.get("total_documents", 0),
            "total_size_bytes": stats_data.get("total_size_bytes", 0)
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@@ -116,6 +132,62 @@ async def create_collection(
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/stats", response_model=dict)
async def get_rag_stats(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Get overall RAG statistics - live data directly from Qdrant"""
    try:
        from app.services.qdrant_stats_service import qdrant_stats_service

        # Get live stats from Qdrant
        stats_data = await qdrant_stats_service.get_collections_stats()

        # Calculate active collections (collections with documents)
        active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)

        # Calculate processing documents from database
        processing_docs = 0
        try:
            from sqlalchemy import select
            from app.models.rag_document import RagDocument, ProcessingStatus

            result = await db.execute(
                select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
            )
            processing_docs = len(result.scalars().all())
        except Exception:
            pass  # If database query fails, default to 0

        response_data = {
            "success": True,
            "stats": {
                "collections": {
                    "total": stats_data.get("total_collections", 0),
                    "active": active_collections
                },
                "documents": {
                    "total": stats_data.get("total_documents", 0),
                    "processing": processing_docs,
                    "processed": stats_data.get("total_documents", 0)  # Indexed documents
                },
                "storage": {
                    "total_size_bytes": stats_data.get("total_size_bytes", 0),
                    "total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
                },
                "vectors": {
                    "total": stats_data.get("total_documents", 0)  # Same as documents for RAG
                },
                "last_updated": datetime.utcnow().isoformat()
            }
        }

        return response_data
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/collections/{collection_id}", response_model=dict)
async def get_collection(
    collection_id: int,
@@ -232,11 +304,55 @@ async def upload_document(
        if len(file_content) > 50 * 1024 * 1024:  # 50MB limit
            raise HTTPException(status_code=400, detail="File too large (max 50MB)")

        # Validate file can be read before processing
        filename = file.filename or "unknown"
        file_extension = filename.split('.')[-1].lower() if '.' in filename else ''

        try:
            # Test file readability based on type
            if file_extension == 'jsonl':
                # Validate JSONL format - try to parse first few lines
                try:
                    content_str = file_content.decode('utf-8')
                    lines = content_str.strip().split('\n')[:5]  # Check first 5 lines
                    import json
                    for i, line in enumerate(lines):
                        if line.strip():  # Skip empty lines
                            json.loads(line)  # Will raise JSONDecodeError if invalid
                except UnicodeDecodeError:
                    raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
                except json.JSONDecodeError as e:
                    raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")

            elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
                # Validate text files can be decoded
                try:
                    file_content.decode('utf-8')
                except UnicodeDecodeError:
                    raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")

            elif file_extension in ['pdf']:
                # For PDF files, just check if it starts with PDF signature
                if not file_content.startswith(b'%PDF'):
                    raise HTTPException(status_code=400, detail="Invalid PDF file format")

            elif file_extension in ['docx', 'xlsx', 'pptx']:
                # For Office documents, check ZIP signature
                if not file_content.startswith(b'PK'):
                    raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")

            # For other file types, we'll rely on the document processor

        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")

        rag_service = RAGService(db)
        document = await rag_service.upload_document(
            collection_id=collection_id,
            file_content=file_content,
            filename=file.filename or "unknown",
            filename=filename,
            content_type=file.content_type
        )

@@ -362,21 +478,167 @@ async def download_document(
        raise HTTPException(status_code=500, detail=str(e))


# Stats Endpoint

@router.get("/stats", response_model=dict)
async def get_rag_stats(
    db: AsyncSession = Depends(get_db),
# Debug Endpoints

@router.post("/debug/search")
async def search_with_debug(
    query: str,
    max_results: int = 10,
    score_threshold: float = 0.3,
    collection_name: str = None,
    config: Dict[str, Any] = None,
    current_user: User = Depends(get_current_user)
):
    """Get RAG system statistics"""
) -> Dict[str, Any]:
    """
    Enhanced search with comprehensive debug information
    """
    # Get RAG module from module manager
    rag_module = module_manager.modules.get('rag')
    if not rag_module or not rag_module.enabled:
        raise HTTPException(status_code=503, detail="RAG module not initialized")

    debug_info = {}
    start_time = datetime.utcnow()

    try:
        rag_service = RAGService(db)
        stats = await rag_service.get_stats()
        # Apply configuration if provided
        if config:
            # Update RAG config temporarily
            original_config = rag_module.config.copy()
            rag_module.config.update(config)

        # Generate query embedding (with or without prefix)
        if config and config.get("use_query_prefix"):
            optimized_query = f"query: {query}"
        else:
            optimized_query = query

        query_embedding = await rag_module._generate_embedding(optimized_query)

        # Store embedding info for debug
        if config and config.get("debug", {}).get("show_embeddings"):
            debug_info["query_embedding"] = query_embedding[:10]  # First 10 dimensions
            debug_info["embedding_dimension"] = len(query_embedding)
            debug_info["optimized_query"] = optimized_query

        # Perform search
        search_start = asyncio.get_event_loop().time()
        results = await rag_module.search_documents(
            query,
            max_results=max_results,
            score_threshold=score_threshold,
            collection_name=collection_name
        )
        search_time = (asyncio.get_event_loop().time() - search_start) * 1000

        # Calculate score statistics
        scores = [r.score for r in results if r.score is not None]
        if scores:
            import statistics
            debug_info["score_stats"] = {
                "min": min(scores),
                "max": max(scores),
                "avg": statistics.mean(scores),
                "stddev": statistics.stdev(scores) if len(scores) > 1 else 0
            }

        # Get collection statistics
        try:
            from qdrant_client.http.models import Filter
            collection_name = collection_name or rag_module.default_collection_name

            # Count total documents
            count_result = rag_module.qdrant_client.count(
                collection_name=collection_name,
                count_filter=Filter(must=[])
            )
            total_points = count_result.count

            # Get unique documents and languages
            scroll_result = rag_module.qdrant_client.scroll(
                collection_name=collection_name,
                limit=1000,  # Sample for stats
                with_payload=True,
                with_vectors=False
            )

            unique_docs = set()
            languages = set()

            for point in scroll_result[0]:
                payload = point.payload or {}
                doc_id = payload.get("document_id")
                if doc_id:
                    unique_docs.add(doc_id)

                language = payload.get("language")
                if language:
                    languages.add(language)

            debug_info["collection_stats"] = {
                "total_documents": len(unique_docs),
                "total_chunks": total_points,
                "languages": sorted(list(languages))
            }

        except Exception as e:
            debug_info["collection_stats_error"] = str(e)

        # Enhance results with debug info
        enhanced_results = []
        for result in results:
            enhanced_result = {
                "document": {
                    "id": result.document.id,
                    "content": result.document.content,
                    "metadata": result.document.metadata
                },
                "score": result.score,
                "debug_info": {}
            }

            # Add hybrid search debug info if available
            metadata = result.document.metadata or {}
            if "_vector_score" in metadata:
                enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
            if "_bm25_score" in metadata:
                enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]

            enhanced_results.append(enhanced_result)

        # Note: Analytics logging disabled (module not available)

        return {
            "success": True,
            "stats": stats
            "results": enhanced_results,
            "debug_info": debug_info,
            "search_time_ms": search_time,
            "timestamp": start_time.isoformat()
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
        # Note: Analytics logging disabled (module not available)
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")

    finally:
        # Restore original config if modified
        if config and 'original_config' in locals():
            rag_module.config = original_config


@router.get("/debug/config")
async def get_current_config(
    current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
    """Get current RAG configuration"""
    # Get RAG module from module manager
    rag_module = module_manager.modules.get('rag')
    if not rag_module or not rag_module.enabled:
        raise HTTPException(status_code=503, detail="RAG module not initialized")

    return {
        "config": rag_module.config,
        "embedding_model": rag_module.embedding_model,
        "enabled": rag_module.enabled,
        "collections": await rag_module._get_collections_safely()
    }
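The debug search above overlays the caller-supplied config onto the RAG module and restores the original in the finally block. The same pattern in isolation (the module object and callable are stand-ins for illustration):

def run_with_temporary_config(module, overrides, fn):
    original = module.config.copy()
    try:
        module.config.update(overrides)
        return fn()                  # run the search with the overridden config
    finally:
        module.config = original     # always restored, even if fn() raises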
@@ -1,251 +0,0 @@
"""
Security API endpoints for monitoring and configuration
"""

from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field

from app.core.security import get_current_active_user, RequiresRole
from app.middleware.security import get_security_stats, get_request_auth_level, get_request_risk_score
from app.core.config import settings
from app.core.logging import get_logger

logger = get_logger(__name__)

router = APIRouter(tags=["security"])


# Pydantic models for API responses
class SecurityStatsResponse(BaseModel):
    """Security statistics response model"""
    total_requests_analyzed: int
    threats_detected: int
    threats_blocked: int
    anomalies_detected: int
    rate_limits_exceeded: int
    avg_analysis_time: float
    threat_types: Dict[str, int]
    threat_levels: Dict[str, int]
    top_attacking_ips: List[tuple]
    security_enabled: bool
    threat_detection_enabled: bool
    rate_limiting_enabled: bool


class SecurityConfigResponse(BaseModel):
    """Security configuration response model"""
    security_enabled: bool = Field(description="Overall security system enabled")
    threat_detection_enabled: bool = Field(description="Threat detection analysis enabled")
    rate_limiting_enabled: bool = Field(description="Rate limiting enabled")
    ip_reputation_enabled: bool = Field(description="IP reputation checking enabled")
    anomaly_detection_enabled: bool = Field(description="Anomaly detection enabled")
    security_headers_enabled: bool = Field(description="Security headers enabled")

    # Rate limiting settings
    unauthenticated_per_minute: int = Field(description="Rate limit for unauthenticated requests per minute")
    authenticated_per_minute: int = Field(description="Rate limit for authenticated users per minute")
    api_key_per_minute: int = Field(description="Rate limit for API key users per minute")
    premium_per_minute: int = Field(description="Rate limit for premium users per minute")

    # Security thresholds
    risk_threshold: float = Field(description="Risk score threshold for blocking requests")
    warning_threshold: float = Field(description="Risk score threshold for warnings")
    anomaly_threshold: float = Field(description="Anomaly severity threshold")

    # IP settings
    blocked_ips: List[str] = Field(description="List of blocked IP addresses")
    allowed_ips: List[str] = Field(description="List of allowed IP addresses (empty = allow all)")


class RateLimitInfoResponse(BaseModel):
    """Rate limit information for current request"""
    auth_level: str = Field(description="Authentication level (unauthenticated, authenticated, api_key, premium)")
    current_limits: Dict[str, int] = Field(description="Current rate limits for this auth level")
    remaining_requests: Optional[Dict[str, int]] = Field(description="Estimated remaining requests (if available)")


@router.get("/stats", response_model=SecurityStatsResponse)
async def get_security_statistics(
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Get security system statistics

    Requires admin role. Returns comprehensive statistics about:
    - Request analysis counts
    - Threat detection results
    - Rate limiting enforcement
    - Top attacking IPs
    - Performance metrics
    """
    try:
        stats = get_security_stats()
        return SecurityStatsResponse(**stats)
    except Exception as e:
        logger.error(f"Error getting security stats: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve security statistics"
        )


@router.get("/config", response_model=SecurityConfigResponse)
async def get_security_config(
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Get current security configuration

    Requires admin role. Returns current security settings including:
    - Feature enablement flags
    - Rate limiting thresholds
    - Security thresholds
    - IP allowlists/blocklists
    """
    return SecurityConfigResponse(
        security_enabled=settings.API_SECURITY_ENABLED,
        threat_detection_enabled=settings.API_THREAT_DETECTION_ENABLED,
        rate_limiting_enabled=settings.API_RATE_LIMITING_ENABLED,
        ip_reputation_enabled=settings.API_IP_REPUTATION_ENABLED,
        anomaly_detection_enabled=settings.API_ANOMALY_DETECTION_ENABLED,
        security_headers_enabled=settings.API_SECURITY_HEADERS_ENABLED,

        unauthenticated_per_minute=settings.API_RATE_LIMIT_UNAUTHENTICATED_PER_MINUTE,
        authenticated_per_minute=settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE,
        api_key_per_minute=settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
        premium_per_minute=settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE,

        risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
        warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
        anomaly_threshold=settings.API_SECURITY_ANOMALY_THRESHOLD,

        blocked_ips=settings.API_BLOCKED_IPS,
        allowed_ips=settings.API_ALLOWED_IPS
    )


@router.get("/status")
async def get_security_status(
    request: Request,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """
    Get security status for current request

    Returns information about the security analysis of the current request:
    - Authentication level
    - Risk score (if available)
    - Rate limiting status
    """
    auth_level = get_request_auth_level(request)
    risk_score = get_request_risk_score(request)

    # Get rate limits for current auth level
    from app.core.threat_detection import AuthLevel
    try:
        auth_enum = AuthLevel(auth_level)
        from app.core.threat_detection import threat_detection_service
        minute_limit, hour_limit = threat_detection_service.get_rate_limits(auth_enum)

        rate_limit_info = RateLimitInfoResponse(
            auth_level=auth_level,
            current_limits={
                "per_minute": minute_limit,
                "per_hour": hour_limit
            },
            remaining_requests=None  # We don't track remaining requests in current implementation
        )
    except ValueError:
        rate_limit_info = RateLimitInfoResponse(
            auth_level=auth_level,
            current_limits={},
            remaining_requests=None
        )

    return {
        "security_enabled": settings.API_SECURITY_ENABLED,
        "auth_level": auth_level,
        "risk_score": round(risk_score, 3) if risk_score > 0 else None,
        "rate_limit_info": rate_limit_info.dict(),
        "security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED
    }


@router.post("/test")
async def test_security_analysis(
    request: Request,
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Test security analysis on current request

    Requires admin role. Manually triggers security analysis on the current request
    and returns detailed results. Useful for testing security rules and thresholds.
    """
    try:
        from app.middleware.security import analyze_request_security

        analysis = await analyze_request_security(request, current_user)

        return {
            "analysis_complete": True,
            "is_threat": analysis.is_threat,
            "risk_score": round(analysis.risk_score, 3),
            "auth_level": analysis.auth_level.value,
            "should_block": analysis.should_block,
            "rate_limit_exceeded": analysis.rate_limit_exceeded,
            "threat_count": len(analysis.threats),
            "threats": [
                {
                    "type": threat.threat_type,
                    "level": threat.level.value,
                    "confidence": round(threat.confidence, 3),
                    "description": threat.description,
                    "mitigation": threat.mitigation
                }
                for threat in analysis.threats
            ],
            "recommendations": analysis.recommendations
        }
    except Exception as e:
        logger.error(f"Error in security analysis test: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to perform security analysis test"
        )


@router.get("/health")
async def security_health_check():
    """
    Security system health check

    Public endpoint that returns the health status of the security system.
    Does not require authentication.
    """
    try:
        stats = get_security_stats()

        # Basic health checks
        is_healthy = (
            settings.API_SECURITY_ENABLED and
            stats.get("total_requests_analyzed", 0) >= 0 and
            stats.get("avg_analysis_time", 0) < 1.0  # Analysis should be under 1 second
        )

        return {
            "status": "healthy" if is_healthy else "degraded",
            "security_enabled": settings.API_SECURITY_ENABLED,
            "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
            "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
            "avg_analysis_time_ms": round(stats.get("avg_analysis_time", 0) * 1000, 2),
            "total_requests_analyzed": stats.get("total_requests_analyzed", 0)
        }
    except Exception as e:
        logger.error(f"Security health check failed: {e}")
        return {
            "status": "unhealthy",
            "error": "Security system error",
            "security_enabled": settings.API_SECURITY_ENABLED
        }
@@ -97,7 +97,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
    "api": {
        # Security Settings
        "security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
        "threat_detection_enabled": {"value": True, "type": "boolean", "description": "Enable threat detection analysis"},
        "rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
        "ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
        "anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
@@ -112,7 +111,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
        "rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},

        # Security Thresholds
        "security_risk_threshold": {"value": 0.8, "type": "float", "description": "Risk score threshold for blocking requests (0.0-1.0)"},
        "security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
        "anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},

@@ -601,7 +599,6 @@ async def reset_to_defaults(
        "api": {
            # Security Settings
            "security_enabled": {"value": True, "type": "boolean"},
            "threat_detection_enabled": {"value": True, "type": "boolean"},
            "rate_limiting_enabled": {"value": True, "type": "boolean"},
            "ip_reputation_enabled": {"value": True, "type": "boolean"},
            "anomaly_detection_enabled": {"value": True, "type": "boolean"},
@@ -616,7 +613,6 @@ async def reset_to_defaults(
            "rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},

            # Security Thresholds
            "security_risk_threshold": {"value": 0.8, "type": "float"},
            "security_warning_threshold": {"value": 0.6, "type": "float"},
            "anomaly_threshold": {"value": 0.7, "type": "float"},
@@ -17,6 +17,8 @@ class Settings(BaseSettings):
    APP_LOG_LEVEL: str = os.getenv("APP_LOG_LEVEL", "INFO")
    APP_HOST: str = os.getenv("APP_HOST", "0.0.0.0")
    APP_PORT: int = int(os.getenv("APP_PORT", "8000"))
    BACKEND_INTERNAL_PORT: int = int(os.getenv("BACKEND_INTERNAL_PORT", "8000"))
    FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000"))

    # Detailed logging for LLM interactions
    LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true"  # Set to True to log prompts and context sent to LLM
@@ -75,44 +77,37 @@ class Settings(BaseSettings):
    QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
    QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333"))
    QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
    QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")

    # API & Security Settings
    API_SECURITY_ENABLED: bool = os.getenv("API_SECURITY_ENABLED", "True").lower() == "true"
    API_THREAT_DETECTION_ENABLED: bool = os.getenv("API_THREAT_DETECTION_ENABLED", "True").lower() == "true"
    API_IP_REPUTATION_ENABLED: bool = os.getenv("API_IP_REPUTATION_ENABLED", "True").lower() == "true"
    API_ANOMALY_DETECTION_ENABLED: bool = os.getenv("API_ANOMALY_DETECTION_ENABLED", "True").lower() == "true"

    # Rate Limiting Configuration
    API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"

    # Authenticated users (JWT token)
    API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "300"))
    API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "5000"))
    # PrivateMode Standard tier limits (organization-level, not per user)
    # These are shared across all API keys and users in the organization
    PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
    PRIVATEMODE_REQUESTS_PER_HOUR: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200"))
    PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000"))
    PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000"))

    # Per-user limits (additional protection on top of organization limits)
    API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20"))  # Match PrivateMode
    API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200"))

    # API key users (programmatic access)
    API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "1000"))
    API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "20000"))
    API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20"))  # Match PrivateMode
    API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200"))

    # Premium/Enterprise API keys
    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "5000"))
    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "100000"))

    # Security Thresholds
    API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8"))  # Block requests above this risk score
    API_SECURITY_WARNING_THRESHOLD: float = float(os.getenv("API_SECURITY_WARNING_THRESHOLD", "0.6"))  # Log warnings above this threshold
    API_SECURITY_ANOMALY_THRESHOLD: float = float(os.getenv("API_SECURITY_ANOMALY_THRESHOLD", "0.7"))  # Flag anomalies above this threshold
    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20"))  # Match PrivateMode
    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))

    # Request Size Limits
    API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760"))  # 10MB
    API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800"))  # 50MB for premium

    # IP Security
    API_BLOCKED_IPS: List[str] = os.getenv("API_BLOCKED_IPS", "").split(",") if os.getenv("API_BLOCKED_IPS") else []
    API_ALLOWED_IPS: List[str] = os.getenv("API_ALLOWED_IPS", "").split(",") if os.getenv("API_ALLOWED_IPS") else []
    API_IP_REPUTATION_CACHE_TTL: int = int(os.getenv("API_IP_REPUTATION_CACHE_TTL", "3600"))  # 1 hour

    # Security Headers
    API_SECURITY_HEADERS_ENABLED: bool = os.getenv("API_SECURITY_HEADERS_ENABLED", "True").lower() == "true"
    API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")

    # Monitoring
@@ -125,6 +120,19 @@ class Settings(BaseSettings):
    # Module configuration
    MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")

    # RAG Embedding Configuration
    RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12"))
    RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3"))
    RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3"))
    RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16")
    RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0"))
    RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
    RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
    RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
    RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
    RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
    RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))

    # Plugin configuration
    PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
    PLUGINS_CONFIG_PATH: str = os.getenv("PLUGINS_CONFIG_PATH", "config/plugins.yaml")
@@ -137,17 +145,13 @@ class Settings(BaseSettings):

    model_config = {
        "env_file": ".env",
        "case_sensitive": True
        "case_sensitive": True,
        # Ignore unknown environment variables to avoid validation errors
        # when optional/deprecated flags are present in .env
        "extra": "ignore",
    }


# Global settings instance
settings = Settings()

# Log configuration values for debugging
import logging
logger = logging.getLogger(__name__)
logger.info(f"JWT Configuration loaded:")
logger.info(f"ACCESS_TOKEN_EXPIRE_MINUTES: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")
logger.info(f"REFRESH_TOKEN_EXPIRE_MINUTES: {settings.REFRESH_TOKEN_EXPIRE_MINUTES}")
logger.info(f"JWT_ALGORITHM: {settings.JWT_ALGORITHM}")
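The `"extra": "ignore"` addition in the hunk above is what lets deprecated or UI-managed flags stay in `.env` without breaking startup. A minimal illustrative sketch of the effect (assuming pydantic-settings v2; the `DemoSettings` class is invented for illustration and is not part of this diff):

from pydantic_settings import BaseSettings

class DemoSettings(BaseSettings):
    APP_NAME: str = "Enclava"
    model_config = {"env_file": ".env", "case_sensitive": True, "extra": "ignore"}

# With the stricter default ("extra": "forbid"), an unknown key left in .env can raise
# a ValidationError at construction time; with "ignore", such keys are simply dropped.
settings = DemoSettings()
print(settings.APP_NAME)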
@@ -2,6 +2,8 @@
Security utilities for authentication and authorization
"""

import asyncio
import concurrent.futures
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
@@ -33,19 +35,29 @@ security = HTTPBearer()
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its hash"""
    import time

    start_time = time.time()
    logger.info(f"=== PASSWORD VERIFICATION START === BCRYPT_ROUNDS: {settings.BCRYPT_ROUNDS}")

    try:
        result = pwd_context.verify(plain_password, hashed_password)
        # Run password verification in a thread with timeout
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(pwd_context.verify, plain_password, hashed_password)
            result = future.result(timeout=5.0)  # 5 second timeout

        end_time = time.time()
        duration = end_time - start_time
        logger.info(f"=== PASSWORD VERIFICATION END === Duration: {duration:.3f}s, Result: {result}")

        if duration > 5:
        if duration > 1:
            logger.warning(f"PASSWORD VERIFICATION TOOK TOO LONG: {duration:.3f}s")

        return result
    except concurrent.futures.TimeoutError:
        end_time = time.time()
        duration = end_time - start_time
        logger.error(f"=== PASSWORD VERIFICATION TIMEOUT === Duration: {duration:.3f}s")
        return False  # Treat timeout as verification failure
    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time

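The hunk above caps bcrypt verification at five seconds by pushing it into a worker thread. For comparison, a hedged sketch of the same idea for fully async callers (the `verify_password_async` name and the timeout default are assumptions for illustration, not part of this diff):

import asyncio
from passlib.context import CryptContext

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

async def verify_password_async(plain: str, hashed: str, timeout: float = 5.0) -> bool:
    try:
        # bcrypt hashing is CPU-bound; run it off the event loop and bound the wait.
        return await asyncio.wait_for(asyncio.to_thread(pwd_context.verify, plain, hashed), timeout)
    except asyncio.TimeoutError:
        # Mirror the diff's behaviour: treat a timeout as a failed verification.
        return False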
@@ -1,744 +0,0 @@
"""
Core threat detection and security analysis for the platform
"""

import re
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Any, Union
from urllib.parse import unquote

from fastapi import Request
from app.core.config import settings
from app.core.logging import get_logger

logger = get_logger(__name__)


class ThreatLevel(Enum):
    """Threat severity levels"""
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"


class AuthLevel(Enum):
    """Authentication levels for rate limiting"""
    AUTHENTICATED = "authenticated"
    API_KEY = "api_key"
    PREMIUM = "premium"


@dataclass
class SecurityThreat:
    """Security threat detection result"""
    threat_type: str
    level: ThreatLevel
    confidence: float
    description: str
    source_ip: str
    user_agent: Optional[str] = None
    request_path: Optional[str] = None
    payload: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.utcnow)
    mitigation: Optional[str] = None


@dataclass
class SecurityAnalysis:
    """Comprehensive security analysis result"""
    is_threat: bool
    threats: List[SecurityThreat]
    risk_score: float
    recommendations: List[str]
    auth_level: AuthLevel
    rate_limit_exceeded: bool
    should_block: bool
    timestamp: datetime = field(default_factory=datetime.utcnow)


@dataclass
class RateLimitInfo:
    """Rate limiting information"""
    auth_level: AuthLevel
    requests_per_minute: int
    requests_per_hour: int
    minute_limit: int
    hour_limit: int
    exceeded: bool


@dataclass
class AnomalyDetection:
    """Anomaly detection result"""
    is_anomaly: bool
    anomaly_type: str
    severity: float
    details: Dict[str, Any]
    baseline_value: Optional[float] = None
    current_value: Optional[float] = None


class ThreatDetectionService:
    """Core threat detection and security analysis service"""

    def __init__(self):
        self.name = "threat_detection"

        # Statistics
        self.stats = {
            'total_requests_analyzed': 0,
            'threats_detected': 0,
            'threats_blocked': 0,
            'anomalies_detected': 0,
            'rate_limits_exceeded': 0,
            'total_analysis_time': 0,
            'threat_types': defaultdict(int),
            'threat_levels': defaultdict(int),
            'attacking_ips': defaultdict(int)
        }

        # Threat detection patterns
        self.sql_injection_patterns = [
            r"(\bunion\b.*\bselect\b)",
            r"(\bselect\b.*\bfrom\b)",
            r"(\binsert\b.*\binto\b)",
            r"(\bupdate\b.*\bset\b)",
            r"(\bdelete\b.*\bfrom\b)",
            r"(\bdrop\b.*\btable\b)",
            r"(\bor\b.*\b1\s*=\s*1\b)",
            r"(\band\b.*\b1\s*=\s*1\b)",
            r"(\bexec\b.*\bxp_\w+)",
            r"(\bsp_\w+)",
            r"(\bsleep\b\s*\(\s*\d+\s*\))",
            r"(\bwaitfor\b.*\bdelay\b)",
            r"(\bbenchmark\b\s*\(\s*\d+)",
            r"(\bload_file\b\s*\()",
            r"(\binto\b.*\boutfile\b)"
        ]

        self.xss_patterns = [
            r"<script[^>]*>.*?</script>",
            r"<iframe[^>]*>.*?</iframe>",
            r"<object[^>]*>.*?</object>",
            r"<embed[^>]*>.*?</embed>",
            r"<link[^>]*>",
            r"<meta[^>]*>",
            r"javascript:",
            r"vbscript:",
            r"on\w+\s*=",
            r"style\s*=.*expression",
            r"style\s*=.*javascript"
        ]

        self.path_traversal_patterns = [
            r"\.\.\/",
            r"\.\.\\",
            r"%2e%2e%2f",
            r"%2e%2e%5c",
            r"..%2f",
            r"..%5c",
            r"%252e%252e%252f",
            r"%252e%252e%255c"
        ]

        self.command_injection_patterns = [
            r";\s*cat\s+",
            r";\s*ls\s+",
            r";\s*pwd\s*",
            r";\s*whoami\s*",
            r";\s*id\s*",
            r";\s*uname\s*",
            r";\s*ps\s+",
            r";\s*netstat\s+",
            r";\s*wget\s+",
            r";\s*curl\s+",
            r"\|\s*cat\s+",
            r"\|\s*ls\s+",
            r"&&\s*cat\s+",
            r"&&\s*ls\s+"
        ]

        self.suspicious_ua_patterns = [
            r"sqlmap",
            r"nikto",
            r"nmap",
            r"masscan",
            r"zap",
            r"burp",
            r"w3af",
            r"acunetix",
            r"nessus",
            r"openvas",
            r"metasploit"
        ]

        # Rate limiting tracking - separate by auth level (excluding unauthenticated since they're blocked)
        self.rate_limits = {
            AuthLevel.AUTHENTICATED: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}),
            AuthLevel.API_KEY: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}),
            AuthLevel.PREMIUM: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)})
        }

        # Anomaly detection
        self.request_history = deque(maxlen=1000)
        self.ip_history = defaultdict(lambda: deque(maxlen=100))
        self.endpoint_history = defaultdict(lambda: deque(maxlen=100))

        # Blocked and allowed IPs
        self.blocked_ips = set(settings.API_BLOCKED_IPS)
        self.allowed_ips = set(settings.API_ALLOWED_IPS) if settings.API_ALLOWED_IPS else None

        # IP reputation cache
        self.ip_reputation_cache = {}
        self.cache_expiry = {}

        # Compile patterns for performance
        self._compile_patterns()

        logger.info(f"ThreatDetectionService initialized with {len(self.sql_injection_patterns)} SQL patterns, "
                    f"{len(self.xss_patterns)} XSS patterns, rate limiting enabled: {settings.API_RATE_LIMITING_ENABLED}")

    def _compile_patterns(self):
        """Compile regex patterns for better performance"""
        try:
            self.compiled_sql_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.sql_injection_patterns]
            self.compiled_xss_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.xss_patterns]
            self.compiled_path_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.path_traversal_patterns]
            self.compiled_cmd_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.command_injection_patterns]
            self.compiled_ua_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.suspicious_ua_patterns]
        except re.error as e:
            logger.error(f"Failed to compile security patterns: {e}")
            # Fallback to empty lists to prevent crashes
            self.compiled_sql_patterns = []
            self.compiled_xss_patterns = []
            self.compiled_path_patterns = []
            self.compiled_cmd_patterns = []
            self.compiled_ua_patterns = []

    def determine_auth_level(self, request: Request, user_context: Optional[Dict] = None) -> AuthLevel:
        """Determine authentication level for rate limiting"""
        # Check if request has API key authentication
        if hasattr(request.state, 'api_key_context') and request.state.api_key_context:
            api_key = request.state.api_key_context.get('api_key')
            if api_key and hasattr(api_key, 'tier'):
                # Check for premium tier
                if api_key.tier in ['premium', 'enterprise']:
                    return AuthLevel.PREMIUM
                return AuthLevel.API_KEY

        # Check for JWT authentication
        if user_context or hasattr(request.state, 'user'):
            return AuthLevel.AUTHENTICATED

        # Check Authorization header for API key
        auth_header = request.headers.get("Authorization", "")
        api_key_header = request.headers.get("X-API-Key", "")
        if auth_header.startswith("Bearer ") or api_key_header:
            return AuthLevel.API_KEY

        # Default to authenticated since unauthenticated requests are blocked at middleware
        return AuthLevel.AUTHENTICATED

    def get_rate_limits(self, auth_level: AuthLevel) -> Tuple[int, int]:
        """Get rate limits for authentication level"""
        if not settings.API_RATE_LIMITING_ENABLED:
            return float('inf'), float('inf')

        if auth_level == AuthLevel.AUTHENTICATED:
            return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
        elif auth_level == AuthLevel.API_KEY:
            return (settings.API_RATE_LIMIT_API_KEY_PER_MINUTE, settings.API_RATE_LIMIT_API_KEY_PER_HOUR)
        elif auth_level == AuthLevel.PREMIUM:
            return (settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE, settings.API_RATE_LIMIT_PREMIUM_PER_HOUR)
        else:
            # Fallback to authenticated limits
            return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)

    def check_rate_limit(self, client_ip: str, auth_level: AuthLevel) -> RateLimitInfo:
        """Check if request exceeds rate limits"""
        minute_limit, hour_limit = self.get_rate_limits(auth_level)
        current_time = time.time()

        # Get or create tracking for this auth level
        if auth_level not in self.rate_limits:
            # This shouldn't happen, but handle gracefully
            return RateLimitInfo(
                auth_level=auth_level,
                requests_per_minute=0,
                requests_per_hour=0,
                minute_limit=minute_limit,
                hour_limit=hour_limit,
                exceeded=False
            )

        ip_limits = self.rate_limits[auth_level][client_ip]

        # Clean old entries
        minute_ago = current_time - 60
        hour_ago = current_time - 3600

        while ip_limits['minute'] and ip_limits['minute'][0] < minute_ago:
            ip_limits['minute'].popleft()

        while ip_limits['hour'] and ip_limits['hour'][0] < hour_ago:
            ip_limits['hour'].popleft()

        # Check current counts
        requests_per_minute = len(ip_limits['minute'])
        requests_per_hour = len(ip_limits['hour'])

        # Check if limits exceeded
        exceeded = (requests_per_minute >= minute_limit) or (requests_per_hour >= hour_limit)

        # Add current request to tracking
        if not exceeded:
            ip_limits['minute'].append(current_time)
            ip_limits['hour'].append(current_time)

        return RateLimitInfo(
            auth_level=auth_level,
            requests_per_minute=requests_per_minute,
            requests_per_hour=requests_per_hour,
            minute_limit=minute_limit,
            hour_limit=hour_limit,
            exceeded=exceeded
        )

    async def analyze_request(self, request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
        """Perform comprehensive security analysis on a request"""
        start_time = time.time()

        try:
            client_ip = request.client.host if request.client else "unknown"
            user_agent = request.headers.get("user-agent", "")
            path = str(request.url.path)
            method = request.method

            # Determine authentication level
            auth_level = self.determine_auth_level(request, user_context)

            # Check IP allowlist/blocklist first
            if self.allowed_ips and client_ip not in self.allowed_ips:
                threat = SecurityThreat(
                    threat_type="ip_not_allowed",
                    level=ThreatLevel.HIGH,
                    confidence=1.0,
                    description=f"IP {client_ip} not in allowlist",
                    source_ip=client_ip,
                    mitigation="Add IP to allowlist or remove IP restrictions"
                )
                return SecurityAnalysis(
                    is_threat=True,
                    threats=[threat],
                    risk_score=1.0,
                    recommendations=["Block request immediately"],
                    auth_level=auth_level,
                    rate_limit_exceeded=False,
                    should_block=True
                )

            if client_ip in self.blocked_ips:
                threat = SecurityThreat(
                    threat_type="ip_blocked",
                    level=ThreatLevel.CRITICAL,
                    confidence=1.0,
                    description=f"IP {client_ip} is blocked",
                    source_ip=client_ip,
                    mitigation="Remove IP from blocklist if legitimate"
                )
                return SecurityAnalysis(
                    is_threat=True,
                    threats=[threat],
                    risk_score=1.0,
                    recommendations=["Block request immediately"],
                    auth_level=auth_level,
                    rate_limit_exceeded=False,
                    should_block=True
                )

            # Check rate limiting
            rate_limit_info = self.check_rate_limit(client_ip, auth_level)
            if rate_limit_info.exceeded:
                self.stats['rate_limits_exceeded'] += 1
                threat = SecurityThreat(
                    threat_type="rate_limit_exceeded",
                    level=ThreatLevel.MEDIUM,
                    confidence=0.9,
                    description=f"Rate limit exceeded for {auth_level.value}: {rate_limit_info.requests_per_minute}/min, {rate_limit_info.requests_per_hour}/hr",
                    source_ip=client_ip,
                    mitigation=f"Implement rate limiting, current limits: {rate_limit_info.minute_limit}/min, {rate_limit_info.hour_limit}/hr"
                )
                return SecurityAnalysis(
                    is_threat=True,
                    threats=[threat],
                    risk_score=0.7,
                    recommendations=[f"Rate limit exceeded for {auth_level.value} user"],
                    auth_level=auth_level,
                    rate_limit_exceeded=True,
                    should_block=True
                )

            # Skip threat detection if disabled
            if not settings.API_THREAT_DETECTION_ENABLED:
                return SecurityAnalysis(
                    is_threat=False,
                    threats=[],
                    risk_score=0.0,
                    recommendations=[],
                    auth_level=auth_level,
                    rate_limit_exceeded=False,
                    should_block=False
                )

            # Collect request data for threat analysis
            query_params = str(request.query_params)
            headers = dict(request.headers)

            # Try to get body content safely
            body_content = ""
            try:
                if hasattr(request, '_body') and request._body:
                    body_content = request._body.decode() if isinstance(request._body, bytes) else str(request._body)
            except:
                pass

            threats = []

            # Analyze for various threats
            threats.extend(await self._detect_sql_injection(query_params, body_content, path, client_ip))
            threats.extend(await self._detect_xss(query_params, body_content, headers, client_ip))
            threats.extend(await self._detect_path_traversal(path, query_params, client_ip))
            threats.extend(await self._detect_command_injection(query_params, body_content, client_ip))
            threats.extend(await self._detect_suspicious_patterns(headers, user_agent, path, client_ip))

            # Anomaly detection if enabled
            if settings.API_ANOMALY_DETECTION_ENABLED:
                anomaly = await self._detect_anomalies(client_ip, path, method, len(body_content))
                if anomaly.is_anomaly and anomaly.severity > settings.API_SECURITY_ANOMALY_THRESHOLD:
                    threat = SecurityThreat(
                        threat_type=f"anomaly_{anomaly.anomaly_type}",
                        level=ThreatLevel.MEDIUM if anomaly.severity > 0.7 else ThreatLevel.LOW,
                        confidence=anomaly.severity,
                        description=f"Anomalous behavior detected: {anomaly.details}",
                        source_ip=client_ip,
                        user_agent=user_agent,
                        request_path=path
                    )
                    threats.append(threat)

            # Calculate risk score
            risk_score = self._calculate_risk_score(threats)

            # Determine if request should be blocked
            should_block = risk_score >= settings.API_SECURITY_RISK_THRESHOLD

            # Generate recommendations
            recommendations = self._generate_recommendations(threats, risk_score, auth_level)

            # Update statistics
            self._update_stats(threats, time.time() - start_time)

            return SecurityAnalysis(
                is_threat=len(threats) > 0,
                threats=threats,
                risk_score=risk_score,
                recommendations=recommendations,
                auth_level=auth_level,
                rate_limit_exceeded=False,
                should_block=should_block
            )

        except Exception as e:
            logger.error(f"Error in threat analysis: {e}")
            return SecurityAnalysis(
                is_threat=False,
                threats=[],
                risk_score=0.0,
                recommendations=["Error occurred during security analysis"],
                auth_level=AuthLevel.AUTHENTICATED,
                rate_limit_exceeded=False,
                should_block=False
            )

    async def _detect_sql_injection(self, query_params: str, body_content: str, path: str, client_ip: str) -> List[SecurityThreat]:
        """Detect SQL injection attempts"""
        threats = []
        content_to_check = f"{query_params} {body_content} {path}".lower()

        for pattern in self.compiled_sql_patterns:
            if pattern.search(content_to_check):
                threat = SecurityThreat(
                    threat_type="sql_injection",
                    level=ThreatLevel.HIGH,
                    confidence=0.85,
                    description="Potential SQL injection attempt detected",
                    source_ip=client_ip,
                    payload=pattern.pattern,
                    mitigation="Block request, sanitize input, use parameterized queries"
                )
                threats.append(threat)
                break  # Don't duplicate for multiple patterns

        return threats

    async def _detect_xss(self, query_params: str, body_content: str, headers: dict, client_ip: str) -> List[SecurityThreat]:
        """Detect XSS attempts"""
        threats = []
        content_to_check = f"{query_params} {body_content}".lower()

        # Check headers for XSS
        for header_name, header_value in headers.items():
            content_to_check += f" {header_value}".lower()

        for pattern in self.compiled_xss_patterns:
            if pattern.search(content_to_check):
                threat = SecurityThreat(
                    threat_type="xss",
                    level=ThreatLevel.HIGH,
                    confidence=0.80,
                    description="Potential XSS attack detected",
                    source_ip=client_ip,
                    payload=pattern.pattern,
                    mitigation="Block request, sanitize input, implement CSP headers"
                )
                threats.append(threat)
                break

        return threats

    async def _detect_path_traversal(self, path: str, query_params: str, client_ip: str) -> List[SecurityThreat]:
        """Detect path traversal attempts"""
        threats = []
        content_to_check = f"{path} {query_params}".lower()
        decoded_content = unquote(content_to_check)

        for pattern in self.compiled_path_patterns:
            if pattern.search(content_to_check) or pattern.search(decoded_content):
                threat = SecurityThreat(
                    threat_type="path_traversal",
                    level=ThreatLevel.HIGH,
                    confidence=0.90,
                    description="Path traversal attempt detected",
                    source_ip=client_ip,
                    request_path=path,
                    mitigation="Block request, validate file paths, implement access controls"
                )
                threats.append(threat)
                break

        return threats

    async def _detect_command_injection(self, query_params: str, body_content: str, client_ip: str) -> List[SecurityThreat]:
        """Detect command injection attempts"""
        threats = []
        content_to_check = f"{query_params} {body_content}".lower()

        for pattern in self.compiled_cmd_patterns:
            if pattern.search(content_to_check):
                threat = SecurityThreat(
                    threat_type="command_injection",
                    level=ThreatLevel.CRITICAL,
                    confidence=0.95,
                    description="Command injection attempt detected",
                    source_ip=client_ip,
                    payload=pattern.pattern,
                    mitigation="Block request immediately, sanitize input, disable shell execution"
                )
                threats.append(threat)
                break

        return threats

    async def _detect_suspicious_patterns(self, headers: dict, user_agent: str, path: str, client_ip: str) -> List[SecurityThreat]:
        """Detect suspicious patterns in headers and user agent"""
        threats = []

        # Check for suspicious user agents
        ua_lower = user_agent.lower()
        for pattern in self.compiled_ua_patterns:
            if pattern.search(ua_lower):
                threat = SecurityThreat(
                    threat_type="suspicious_user_agent",
                    level=ThreatLevel.HIGH,
                    confidence=0.85,
                    description=f"Suspicious user agent detected: {pattern.pattern}",
                    source_ip=client_ip,
                    user_agent=user_agent,
                    mitigation="Block request, monitor IP for further activity"
                )
                threats.append(threat)
                break

        # Check for suspicious headers
        if "x-forwarded-for" in headers and "x-real-ip" in headers:
            # Potential header manipulation
            threat = SecurityThreat(
                threat_type="header_manipulation",
                level=ThreatLevel.LOW,
                confidence=0.30,
                description="Potential IP header manipulation detected",
                source_ip=client_ip,
                mitigation="Validate proxy headers, implement IP whitelisting"
            )
            threats.append(threat)

        return threats

    async def _detect_anomalies(self, client_ip: str, path: str, method: str, body_size: int) -> AnomalyDetection:
        """Detect anomalous behavior patterns"""
        try:
            # Request size anomaly
            max_size = settings.API_MAX_REQUEST_BODY_SIZE
            if body_size > max_size:
                return AnomalyDetection(
                    is_anomaly=True,
                    anomaly_type="request_size",
                    severity=0.8,
                    details={"body_size": body_size, "threshold": max_size},
                    current_value=body_size,
                    baseline_value=max_size // 10
                )

            # Unusual endpoint access
            if path.startswith("/admin") or path.startswith("/api/admin"):
                return AnomalyDetection(
                    is_anomaly=True,
                    anomaly_type="sensitive_endpoint",
                    severity=0.6,
                    details={"path": path, "reason": "admin endpoint access"},
                    current_value=1.0,
                    baseline_value=0.0
                )

            # IP request frequency anomaly
            current_time = time.time()
            ip_requests = self.ip_history[client_ip]

            # Clean old entries (last 5 minutes)
            five_minutes_ago = current_time - 300
            while ip_requests and ip_requests[0] < five_minutes_ago:
                ip_requests.popleft()

            ip_requests.append(current_time)

            if len(ip_requests) > 100:  # More than 100 requests in 5 minutes
                return AnomalyDetection(
                    is_anomaly=True,
                    anomaly_type="request_frequency",
                    severity=0.7,
                    details={"requests_5min": len(ip_requests), "threshold": 100},
                    current_value=len(ip_requests),
                    baseline_value=10  # 10 requests baseline
                )

            return AnomalyDetection(
                is_anomaly=False,
                anomaly_type="none",
                severity=0.0,
                details={}
            )

        except Exception as e:
            logger.error(f"Error in anomaly detection: {e}")
            return AnomalyDetection(
                is_anomaly=False,
                anomaly_type="error",
                severity=0.0,
                details={"error": str(e)}
            )

    def _calculate_risk_score(self, threats: List[SecurityThreat]) -> float:
        """Calculate overall risk score based on threats"""
        if not threats:
            return 0.0

        score = 0.0
        for threat in threats:
            level_multiplier = {
                ThreatLevel.LOW: 0.25,
                ThreatLevel.MEDIUM: 0.5,
                ThreatLevel.HIGH: 0.75,
                ThreatLevel.CRITICAL: 1.0
            }
            score += threat.confidence * level_multiplier.get(threat.level, 0.5)

        # Normalize to 0-1 range
        return min(score / len(threats), 1.0)

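For reference, a worked example of the scoring above (numbers chosen purely for illustration):

# One HIGH threat with confidence 0.85 and one MEDIUM threat with confidence 0.9 give
#   score = 0.85 * 0.75 + 0.9 * 0.5 = 0.6375 + 0.45 = 1.0875
#   risk  = min(1.0875 / 2, 1.0) = 0.54375
# which is below the default API_SECURITY_RISK_THRESHOLD of 0.8 (not blocked) but above
# 0.4, so _generate_recommendations below would suggest monitoring the IP.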
    def _generate_recommendations(self, threats: List[SecurityThreat], risk_score: float, auth_level: AuthLevel) -> List[str]:
        """Generate security recommendations based on analysis"""
        recommendations = []

        if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
            recommendations.append("CRITICAL: Block this request immediately")
        elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
            recommendations.append("HIGH: Consider blocking or rate limiting this IP")
        elif risk_score > 0.4:
            recommendations.append("MEDIUM: Monitor this IP closely")

        threat_types = {threat.threat_type for threat in threats}

        if "sql_injection" in threat_types:
            recommendations.append("Implement parameterized queries and input validation")

        if "xss" in threat_types:
            recommendations.append("Implement Content Security Policy (CSP) headers")

        if "command_injection" in threat_types:
            recommendations.append("Disable shell execution and validate all inputs")

        if "path_traversal" in threat_types:
            recommendations.append("Implement proper file path validation and access controls")

        if "rate_limit_exceeded" in threat_types:
            recommendations.append(f"Rate limiting active for {auth_level.value} user")

        if not recommendations:
            recommendations.append("No immediate action required, continue monitoring")

        return recommendations

    def _update_stats(self, threats: List[SecurityThreat], analysis_time: float):
        """Update service statistics"""
        self.stats['total_requests_analyzed'] += 1
        self.stats['total_analysis_time'] += analysis_time

        if threats:
            self.stats['threats_detected'] += len(threats)
            for threat in threats:
                self.stats['threat_types'][threat.threat_type] += 1
                self.stats['threat_levels'][threat.level.value] += 1
                if threat.source_ip:
                    self.stats['attacking_ips'][threat.source_ip] += 1

    def get_stats(self) -> Dict[str, Any]:
        """Get service statistics"""
        avg_time = (self.stats['total_analysis_time'] / self.stats['total_requests_analyzed']
                    if self.stats['total_requests_analyzed'] > 0 else 0)

        # Get top attacking IPs
        top_ips = sorted(self.stats['attacking_ips'].items(), key=lambda x: x[1], reverse=True)[:10]

        return {
            "total_requests_analyzed": self.stats['total_requests_analyzed'],
            "threats_detected": self.stats['threats_detected'],
            "threats_blocked": self.stats['threats_blocked'],
            "anomalies_detected": self.stats['anomalies_detected'],
            "rate_limits_exceeded": self.stats['rate_limits_exceeded'],
            "avg_analysis_time": avg_time,
            "threat_types": dict(self.stats['threat_types']),
            "threat_levels": dict(self.stats['threat_levels']),
            "top_attacking_ips": top_ips,
            "security_enabled": settings.API_SECURITY_ENABLED,
            "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
            "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED
        }


# Global threat detection service instance
threat_detection_service = ThreatDetectionService()
@@ -24,6 +24,7 @@ engine = create_async_engine(
    pool_recycle=3600,  # Recycle connections every hour
    pool_timeout=30,  # Max time to get connection from pool
    connect_args={
        "timeout": 5,
        "command_timeout": 5,
        "server_settings": {
            "application_name": "enclava_backend",
@@ -49,6 +50,7 @@ sync_engine = create_engine(
    pool_recycle=3600,  # Recycle connections every hour
    pool_timeout=30,  # Max time to get connection from pool
    connect_args={
        "connect_timeout": 5,
        "application_name": "enclava_backend_sync",
    },
)

@@ -1,9 +1,9 @@
"""
Main FastAPI application entry point
"""
"""Main FastAPI application entry point"""

import asyncio
import logging
import sys
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI
@@ -14,10 +14,13 @@ from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException
from starlette.middleware.sessions import SessionMiddleware

from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError

from app.core.config import settings
from app.core.logging import setup_logging
from app.core.security import get_current_user
from app.db.database import init_db
from app.db.database import init_db, async_session_factory
from app.api.internal_v1 import internal_api_router
from app.api.public_v1 import public_api_router
from app.utils.exceptions import CustomHTTPException
@@ -32,6 +35,68 @@ setup_logging()
logger = logging.getLogger(__name__)


async def _check_redis_startup():
    """Validate Redis connectivity during startup."""
    if not settings.REDIS_URL:
        logger.info("Startup Redis check skipped: REDIS_URL not configured")
        return

    try:
        import redis.asyncio as redis
    except ModuleNotFoundError:
        logger.warning("Startup Redis check skipped: redis library not installed")
        return

    client = redis.from_url(
        settings.REDIS_URL,
        socket_connect_timeout=1.0,
        socket_timeout=1.0,
    )

    start = time.perf_counter()
    try:
        await asyncio.wait_for(client.ping(), timeout=3.0)
        duration = time.perf_counter() - start
        logger.info(
            "Startup Redis check succeeded",
            extra={"redis_url": settings.REDIS_URL, "duration_seconds": round(duration, 3)},
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "Startup Redis check failed",
            extra={"error": str(exc), "redis_url": settings.REDIS_URL},
        )
    finally:
        await client.close()


async def _check_database_startup():
    """Validate database connectivity during startup."""
    start = time.perf_counter()
    try:
        async with async_session_factory() as session:
            await asyncio.wait_for(session.execute(select(1)), timeout=3.0)
        duration = time.perf_counter() - start
        logger.info(
            "Startup database check succeeded",
            extra={"duration_seconds": round(duration, 3)},
        )
    except (asyncio.TimeoutError, SQLAlchemyError) as exc:
        logger.error(
            "Startup database check failed",
            extra={"error": str(exc)},
        )
        raise


async def run_startup_dependency_checks():
    """Run dependency checks once during application startup."""
    logger.info("Running startup dependency checks...")
    await _check_redis_startup()
    await _check_database_startup()
    logger.info("Startup dependency checks complete")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
@@ -47,12 +112,27 @@ async def lifespan(app: FastAPI):
    except Exception as e:
        logger.warning(f"Core cache service initialization failed: {e}")

    # Run one-time dependency checks (non-blocking for auth requests)
    try:
        await run_startup_dependency_checks()
    except Exception:
        logger.error("Critical dependency check failed during startup")
        raise

    # Initialize database
    await init_db()

    # Initialize config manager
    await init_config_manager()

    # Initialize LLM service (needed by RAG module)
    from app.services.llm.service import llm_service
    try:
        await llm_service.initialize()
        logger.info("LLM service initialized successfully")
    except Exception as e:
        logger.warning(f"LLM service initialization failed: {e}")

    # Initialize analytics service
    init_analytics_service()

@@ -143,13 +223,9 @@ app.add_middleware(
# Add analytics middleware
setup_analytics_middleware(app)

# Add debugging middleware for detailed request/response logging
from app.middleware.debugging import setup_debugging_middleware
setup_debugging_middleware(app)
# Security middleware disabled - handled externally

# Add security middleware
from app.middleware.security import setup_security_middleware
setup_security_middleware(app, enabled=settings.API_SECURITY_ENABLED)
# Rate limiting middleware disabled - handled externally


# Exception handlers

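The startup checks added above are plain coroutines, so they can also be exercised outside the FastAPI lifespan. A small hedged sketch (the `app.main` import path is an assumption based on the imports shown in this diff, not something the diff states):

import asyncio

from app.main import run_startup_dependency_checks

if __name__ == "__main__":
    # Logs the Redis/database check results and raises if the database is unreachable,
    # mirroring what the lifespan hook does before init_db().
    asyncio.run(run_startup_dependency_checks())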
@@ -1,313 +0,0 @@
"""
Rate limiting middleware
"""

import time
import redis
from typing import Dict, Optional
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
import asyncio
from datetime import datetime, timedelta

from app.core.config import settings
from app.core.logging import get_logger

logger = get_logger(__name__)


class RateLimiter:
    """Rate limiting implementation using Redis"""

    def __init__(self):
        try:
            self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
            self.redis_client.ping()  # Test connection
            logger.info("Rate limiter initialized with Redis backend")
        except Exception as e:
            logger.warning(f"Redis not available for rate limiting: {e}")
            self.redis_client = None
            # Fall back to in-memory rate limiting
            self.memory_store: Dict[str, Dict[str, float]] = {}

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        identifier: str = "default"
    ) -> tuple[bool, Dict[str, int]]:
        """
        Check if request is within rate limit

        Args:
            key: Rate limiting key (e.g., IP address, API key)
            limit: Maximum number of requests allowed
            window_seconds: Time window in seconds
            identifier: Additional identifier for the rate limit

        Returns:
            Tuple of (is_allowed, headers_dict)
        """
        full_key = f"rate_limit:{identifier}:{key}"
        current_time = int(time.time())
        window_start = current_time - window_seconds

        if self.redis_client:
            return await self._check_redis_rate_limit(
                full_key, limit, window_seconds, current_time, window_start
            )
        else:
            return self._check_memory_rate_limit(
                full_key, limit, window_seconds, current_time, window_start
            )

    async def _check_redis_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        current_time: int,
        window_start: int
    ) -> tuple[bool, Dict[str, int]]:
        """Check rate limit using Redis"""
        pipe = self.redis_client.pipeline()

        # Remove old entries
        pipe.zremrangebyscore(key, 0, window_start)

        # Count current requests in window
        pipe.zcard(key)

        # Add current request
        pipe.zadd(key, {str(current_time): current_time})

        # Set expiration
        pipe.expire(key, window_seconds + 1)

        results = pipe.execute()
        current_requests = results[1]

        # Calculate remaining requests and reset time
        remaining = max(0, limit - current_requests - 1)
        reset_time = current_time + window_seconds

        headers = {
            "X-RateLimit-Limit": limit,
            "X-RateLimit-Remaining": remaining,
            "X-RateLimit-Reset": reset_time,
            "X-RateLimit-Window": window_seconds
        }

        is_allowed = current_requests < limit

        if not is_allowed:
            logger.warning(f"Rate limit exceeded for key: {key}")

        return is_allowed, headers

    def _check_memory_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        current_time: int,
        window_start: int
    ) -> tuple[bool, Dict[str, int]]:
        """Check rate limit using in-memory storage"""
        if key not in self.memory_store:
            self.memory_store[key] = {}

        # Clean old entries
        store = self.memory_store[key]
        keys_to_remove = [k for k, v in store.items() if v < window_start]
        for k in keys_to_remove:
            del store[k]

        current_requests = len(store)

        # Calculate remaining requests and reset time
        remaining = max(0, limit - current_requests - 1)
        reset_time = current_time + window_seconds

        headers = {
            "X-RateLimit-Limit": limit,
            "X-RateLimit-Remaining": remaining,
            "X-RateLimit-Reset": reset_time,
            "X-RateLimit-Window": window_seconds
        }

        is_allowed = current_requests < limit

        if is_allowed:
            # Add current request
            store[str(current_time)] = current_time
        else:
            logger.warning(f"Rate limit exceeded for key: {key}")

        return is_allowed, headers


# Global rate limiter instance
rate_limiter = RateLimiter()

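A hedged usage sketch of the module-level `rate_limiter` above, called directly from a route instead of through the middleware (the endpoint path and the 30/minute limit are invented for illustration; they are not part of this diff):

from fastapi import APIRouter, HTTPException, Request, status

router = APIRouter()

@router.post("/api/v1/example")
async def example_endpoint(request: Request):
    client_ip = request.client.host if request.client else "unknown"
    # Sliding-window check against the same Redis/in-memory backend used above.
    allowed, headers = await rate_limiter.check_rate_limit(
        key=f"ip:{client_ip}", limit=30, window_seconds=60, identifier="example"
    )
    if not allowed:
        raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                            detail="Rate limit exceeded")
    return {"ok": True, "remaining": headers["X-RateLimit-Remaining"]}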
async def rate_limit_middleware(request: Request, call_next):
    """
    Rate limiting middleware for FastAPI
    """
    # Skip rate limiting for health checks and static files
    if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
        response = await call_next(request)
        return response

    # Get client IP
    client_ip = request.client.host
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        client_ip = forwarded_for.split(",")[0].strip()

    # Check for API key in headers
    api_key = None
    auth_header = request.headers.get("Authorization")
    if auth_header and auth_header.startswith("Bearer "):
        api_key = auth_header[7:]
    elif request.headers.get("X-API-Key"):
        api_key = request.headers.get("X-API-Key")

    # Determine rate limiting strategy
    if api_key:
        # API key-based rate limiting
        rate_limit_key = f"api_key:{api_key}"

        # Get API key limits from database (simplified - would implement proper lookup)
        limit_per_minute = 100  # Default limit
        limit_per_hour = 1000  # Default limit

        # Check per-minute limit
        is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
            rate_limit_key, limit_per_minute, 60, "minute"
        )

        # Check per-hour limit
        is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
            rate_limit_key, limit_per_hour, 3600, "hour"
        )

        is_allowed = is_allowed_minute and is_allowed_hour
        headers = headers_minute  # Use minute headers for response

    else:
        # IP-based rate limiting for unauthenticated requests
        rate_limit_key = f"ip:{client_ip}"

        # More restrictive limits for unauthenticated requests
        limit_per_minute = 20
        limit_per_hour = 100

        # Check per-minute limit
        is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
            rate_limit_key, limit_per_minute, 60, "minute"
        )

        # Check per-hour limit
        is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
            rate_limit_key, limit_per_hour, 3600, "hour"
        )

        is_allowed = is_allowed_minute and is_allowed_hour
        headers = headers_minute  # Use minute headers for response

    # If rate limit exceeded, return 429
    if not is_allowed:
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "error": "RATE_LIMIT_EXCEEDED",
                "message": "Rate limit exceeded. Please try again later.",
                "details": {
                    "limit": headers["X-RateLimit-Limit"],
                    "reset_time": headers["X-RateLimit-Reset"]
                }
            },
            headers={k: str(v) for k, v in headers.items()}
        )

    # Continue with request
    response = await call_next(request)

    # Add rate limit headers to response
    for key, value in headers.items():
        response.headers[key] = str(value)

    return response


class RateLimitExceeded(HTTPException):
    """Exception raised when rate limit is exceeded"""

    def __init__(self, limit: int, reset_time: int):
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Rate limit exceeded. Limit: {limit}, Reset: {reset_time}"
        )


# Decorator for applying rate limits to specific endpoints
def rate_limit(requests_per_minute: int = 60, requests_per_hour: int = 1000):
    """
    Decorator to apply rate limiting to specific endpoints

    Args:
        requests_per_minute: Maximum requests per minute
        requests_per_hour: Maximum requests per hour
    """
    def decorator(func):
        async def wrapper(*args, **kwargs):
            # This would be implemented to work with FastAPI dependencies
            # For now, this is a placeholder for endpoint-specific rate limiting
            return await func(*args, **kwargs)
        return wrapper
    return decorator


# Helper functions for different rate limiting strategies
async def check_api_key_rate_limit(api_key: str, endpoint: str) -> bool:
    """Check rate limit for specific API key and endpoint"""
    # This would lookup API key specific limits from database
    # For now, using default limits
    key = f"api_key:{api_key}:endpoint:{endpoint}"

    is_allowed, _ = await rate_limiter.check_rate_limit(
        key, limit=100, window_seconds=60, identifier="endpoint"
    )

    return is_allowed


async def check_user_rate_limit(user_id: str, action: str) -> bool:
    """Check rate limit for specific user and action"""
    key = f"user:{user_id}:action:{action}"

    is_allowed, _ = await rate_limiter.check_rate_limit(
        key, limit=50, window_seconds=60, identifier="user_action"
    )

    return is_allowed


async def apply_burst_protection(key: str) -> bool:
    """Apply burst protection for high-frequency actions"""
    # Allow burst of 10 requests in 10 seconds
    is_allowed, _ = await rate_limiter.check_rate_limit(
        key, limit=10, window_seconds=10, identifier="burst"
    )

    return is_allowed
@@ -1,286 +0,0 @@
"""
Security middleware for request/response processing
"""

import json
import time
from typing import Callable, Optional, Dict, Any

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.config import settings
from app.core.logging import get_logger
from app.core.threat_detection import threat_detection_service, SecurityAnalysis

logger = get_logger(__name__)


class SecurityMiddleware(BaseHTTPMiddleware):
    """Security middleware for threat detection and request filtering"""

    def __init__(self, app, enabled: bool = True):
        super().__init__(app)
        self.enabled = enabled and settings.API_SECURITY_ENABLED
        logger.info(f"SecurityMiddleware initialized, enabled: {self.enabled}")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request through security analysis"""
        if not self.enabled:
            # Security disabled, pass through
            return await call_next(request)

        # Skip security analysis for certain endpoints
        if self._should_skip_security(request):
            response = await call_next(request)
            return self._add_security_headers(response)

        # Simple authentication check - drop requests without valid auth
        if not self._has_valid_auth(request):
            return JSONResponse(
                content={"error": "Authentication required", "message": "Valid API key or authentication token required"},
                status_code=401,
                headers={"WWW-Authenticate": "Bearer"}
            )

        try:
            # Get user context if available
            user_context = getattr(request.state, 'user', None)

            # Perform security analysis
            start_time = time.time()
            analysis = await threat_detection_service.analyze_request(request, user_context)
            analysis_time = time.time() - start_time

            # Store analysis in request state for later use
            request.state.security_analysis = analysis

            # Log security events (only for significant threats to reduce false positive noise)
            # Only log if: being blocked OR risk score above warning threshold (0.6)
            if analysis.is_threat and (analysis.should_block or analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD):
                await self._log_security_event(request, analysis)

            # Check if request should be blocked
            if analysis.should_block:
                threat_detection_service.stats['threats_blocked'] += 1
                logger.warning(f"Blocked request from {request.client.host if request.client else 'unknown'}: "
                               f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}")

                # Return security block response
                return self._create_block_response(analysis)

            # Log warnings for medium-risk requests
            if analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
                logger.warning(f"High-risk request detected from {request.client.host if request.client else 'unknown'}: "
                               f"risk_score={analysis.risk_score:.3f}, auth_level={analysis.auth_level.value}")

            # Continue with request processing
            response = await call_next(request)

            # Add security headers and metrics
            response = self._add_security_headers(response)
            response = self._add_security_metrics(response, analysis, analysis_time)

            return response

        except Exception as e:
            logger.error(f"Security middleware error: {e}")
            # Continue with request on security middleware errors to avoid breaking the app
            response = await call_next(request)
            return self._add_security_headers(response)

    def _should_skip_security(self, request: Request) -> bool:
        """Determine if security analysis should be skipped for this request"""
        path = request.url.path

        # Skip for health checks, authentication endpoints, and static assets
        skip_paths = [
            "/health",
            "/metrics",
            "/api/v1/docs",
            "/api/v1/openapi.json",
            "/api/v1/redoc",
            "/favicon.ico",
            "/api/v1/auth/register",
            "/api/v1/auth/login",
            "/api/v1/auth/refresh",  # Allow refresh endpoint
            "/api-internal/v1/auth/register",
            "/api-internal/v1/auth/login",
            "/api-internal/v1/auth/refresh",  # Allow refresh endpoint for internal API
            "/",  # Root endpoint
        ]

        # Skip for static file extensions
        static_extensions = [".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".woff", ".woff2"]

        return (
            path in skip_paths or
            any(path.endswith(ext) for ext in static_extensions) or
            path.startswith("/static/")
        )

    def _has_valid_auth(self, request: Request) -> bool:
        """Check if request has valid authentication"""
        # Check Authorization header
        auth_header = request.headers.get("Authorization", "")
        api_key_header = request.headers.get("X-API-Key", "")

        # Has some form of auth token/key
        return (
            auth_header.startswith("Bearer ") and len(auth_header) > 7 or
            len(api_key_header.strip()) > 0
        )

    def _create_block_response(self, analysis: SecurityAnalysis) -> JSONResponse:
        """Create response for blocked requests"""
        # Determine status code based on threat type
        status_code = 403  # Forbidden by default

        # Rate limiting gets 429
        if analysis.rate_limit_exceeded:
            status_code = 429

        # Critical threats get 403
        for threat in analysis.threats:
            if threat.threat_type in ["command_injection", "sql_injection"]:
                status_code = 403
                break

        response_data = {
            "error": "Security Policy Violation",
            "message": "Request blocked due to security policy violation",
            "risk_score": round(analysis.risk_score, 3),
            "auth_level": analysis.auth_level.value,
            "threat_count": len(analysis.threats),
            "recommendations": analysis.recommendations[:3]  # Limit to first 3 recommendations
        }

        # Add rate limiting info if applicable
        if analysis.rate_limit_exceeded:
            response_data["error"] = "Rate Limit Exceeded"
            response_data["message"] = f"Rate limit exceeded for {analysis.auth_level.value} user"
            response_data["retry_after"] = "60"  # Suggest retry after 60 seconds

        response = JSONResponse(
            content=response_data,
            status_code=status_code
        )

        # Add rate limiting headers
        if analysis.rate_limit_exceeded:
            response.headers["Retry-After"] = "60"
            response.headers["X-RateLimit-Limit"] = "See API documentation"
            response.headers["X-RateLimit-Reset"] = str(int(time.time() + 60))

        return response

def _add_security_headers(self, response: Response) -> Response:
|
||||
"""Add security headers to response"""
|
||||
if not settings.API_SECURITY_HEADERS_ENABLED:
|
||||
return response
|
||||
|
||||
# Standard security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
# Only add HSTS for HTTPS (note: X-Forwarded-Proto is normally a request header; this check only works if a proxy mirrors it onto the response)
|
||||
if hasattr(response, 'headers') and response.headers.get("X-Forwarded-Proto") == "https":
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
|
||||
# Content Security Policy
|
||||
if settings.API_CSP_HEADER:
|
||||
response.headers["Content-Security-Policy"] = settings.API_CSP_HEADER
|
||||
|
||||
return response
|
||||
|
||||
def _add_security_metrics(self, response: Response, analysis: SecurityAnalysis, analysis_time: float) -> Response:
|
||||
"""Add security metrics to response headers (for debugging/monitoring)"""
|
||||
# Only add in debug mode or for admin users
|
||||
if settings.APP_DEBUG:
|
||||
response.headers["X-Security-Risk-Score"] = str(round(analysis.risk_score, 3))
|
||||
response.headers["X-Security-Threats"] = str(len(analysis.threats))
|
||||
response.headers["X-Security-Auth-Level"] = analysis.auth_level.value
|
||||
response.headers["X-Security-Analysis-Time"] = f"{analysis_time*1000:.1f}ms"
|
||||
|
||||
return response
|
||||
|
||||
async def _log_security_event(self, request: Request, analysis: SecurityAnalysis):
|
||||
"""Log security events for audit and monitoring"""
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
|
||||
# Create security event log
|
||||
event_data = {
|
||||
"timestamp": analysis.timestamp.isoformat(),
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent,
|
||||
"path": str(request.url.path),
|
||||
"method": request.method,
|
||||
"risk_score": round(analysis.risk_score, 3),
|
||||
"auth_level": analysis.auth_level.value,
|
||||
"threat_count": len(analysis.threats),
|
||||
"rate_limit_exceeded": analysis.rate_limit_exceeded,
|
||||
"should_block": analysis.should_block,
|
||||
"threats": [
|
||||
{
|
||||
"type": threat.threat_type,
|
||||
"level": threat.level.value,
|
||||
"confidence": round(threat.confidence, 3),
|
||||
"description": threat.description
|
||||
}
|
||||
for threat in analysis.threats[:5] # Limit to first 5 threats
|
||||
],
|
||||
"recommendations": analysis.recommendations
|
||||
}
|
||||
|
||||
# Log at appropriate level based on risk
|
||||
if analysis.should_block:
|
||||
logger.warning(f"SECURITY_BLOCK: {json.dumps(event_data)}")
|
||||
elif analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
|
||||
logger.warning(f"SECURITY_WARNING: {json.dumps(event_data)}")
|
||||
else:
|
||||
logger.info(f"SECURITY_THREAT: {json.dumps(event_data)}")
|
||||
|
||||
|
||||
def setup_security_middleware(app, enabled: bool = True) -> None:
    """Setup security middleware on FastAPI app"""
    if enabled and settings.API_SECURITY_ENABLED:
        app.add_middleware(SecurityMiddleware, enabled=enabled)
        logger.info("Security middleware enabled")
    else:
        logger.info("Security middleware disabled")


# Helper functions for manual security checks
async def analyze_request_security(request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
    """Manually analyze request security (for use in route handlers)"""
    return await threat_detection_service.analyze_request(request, user_context)


def get_security_stats() -> Dict[str, Any]:
    """Get security statistics"""
    return threat_detection_service.get_stats()


def is_request_blocked(request: Request) -> bool:
    """Check if request was blocked by security analysis"""
    if hasattr(request.state, 'security_analysis'):
        return request.state.security_analysis.should_block
    return False


def get_request_risk_score(request: Request) -> float:
    """Get risk score for request"""
    if hasattr(request.state, 'security_analysis'):
        return request.state.security_analysis.risk_score
    return 0.0


def get_request_auth_level(request: Request) -> str:
    """Get authentication level for request"""
    if hasattr(request.state, 'security_analysis'):
        return request.state.security_analysis.auth_level.value
    return "unknown"
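A minimal usage sketch (illustrative, not part of the repository): wire the middleware at startup and read the per-request analysis that the middleware stores on request.state. The import path and endpoint are hypothetical; the helper names match the functions above.

from fastapi import FastAPI, Request
# Hypothetical module path for the file above
from app.middleware.security import (
    setup_security_middleware,
    get_request_risk_score,
    get_request_auth_level,
    is_request_blocked,
)

app = FastAPI()
setup_security_middleware(app)  # no-op unless API_SECURITY_ENABLED is true

@app.get("/api/v1/example")  # hypothetical endpoint
async def example(request: Request):
    # The middleware attaches its SecurityAnalysis to request.state.security_analysis,
    # so these helpers return meaningful values only when it ran for this request.
    return {
        "blocked": is_request_blocked(request),
        "risk_score": get_request_risk_score(request),
        "auth_level": get_request_auth_level(request),
    }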
21
backend/app/modules/chatbot/__init__.py
Normal file
@@ -0,0 +1,21 @@
"""
Chatbot Module - AI Chatbot with RAG Integration

This module provides AI chatbot capabilities with:
- Multiple personality types (Assistant, Customer Support, Teacher, etc.)
- RAG integration for knowledge-based responses
- Conversation memory and context management
- Workflow integration as building blocks
- UI-configurable settings
"""

from .main import ChatbotModule, create_module

__version__ = "1.0.0"
__author__ = "Enclava Team"

# Export main classes for easy importing
__all__ = [
    "ChatbotModule",
    "create_module"
]
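A short illustrative example (assuming the package is importable as app.modules.chatbot, which matches the factory imports later in this listing):

from app.modules.chatbot import ChatbotModule, create_module

# Build an instance without a RAG service; initialize() will try to pull one
# from the module manager (see main.py below).
chatbot: ChatbotModule = create_module()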
126
backend/app/modules/chatbot/config_schema.json
Normal file
@@ -0,0 +1,126 @@
|
||||
{
|
||||
"title": "Chatbot Configuration",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"title": "Chatbot Name",
|
||||
"description": "Display name for this chatbot instance",
|
||||
"minLength": 1,
|
||||
"maxLength": 100
|
||||
},
|
||||
"chatbot_type": {
|
||||
"type": "string",
|
||||
"title": "Chatbot Type",
|
||||
"description": "Select the type of chatbot personality",
|
||||
"enum": ["assistant", "customer_support", "teacher", "researcher", "creative_writer", "custom"],
|
||||
"enumNames": ["General Assistant", "Customer Support", "Teacher", "Researcher", "Creative Writer", "Custom"],
|
||||
"default": "assistant"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"title": "AI Model",
|
||||
"description": "Choose the LLM model for responses",
|
||||
"enum": ["gpt-4", "gpt-3.5-turbo", "claude-3-sonnet", "claude-3-opus", "llama-70b"],
|
||||
"default": "gpt-3.5-turbo"
|
||||
},
|
||||
"system_prompt": {
|
||||
"type": "string",
|
||||
"title": "System Prompt",
|
||||
"description": "Define the chatbot's personality and behavior instructions",
|
||||
"ui:widget": "textarea",
|
||||
"ui:options": {
|
||||
"rows": 6,
|
||||
"placeholder": "You are a helpful AI assistant..."
|
||||
}
|
||||
},
|
||||
"use_rag": {
|
||||
"type": "boolean",
|
||||
"title": "Enable Knowledge Base",
|
||||
"description": "Use RAG to search knowledge base for context",
|
||||
"default": false
|
||||
},
|
||||
"rag_collection": {
|
||||
"type": "string",
|
||||
"title": "Knowledge Base Collection",
|
||||
"description": "Select which document collection to search",
|
||||
"ui:widget": "rag-collection-selector",
|
||||
"ui:condition": "use_rag === true"
|
||||
},
|
||||
"rag_top_k": {
|
||||
"type": "integer",
|
||||
"title": "Knowledge Base Results",
|
||||
"description": "Number of relevant documents to include",
|
||||
"minimum": 1,
|
||||
"maximum": 10,
|
||||
"default": 5,
|
||||
"ui:condition": "use_rag === true"
|
||||
},
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"title": "Response Creativity",
|
||||
"description": "Controls randomness (0.0 = focused, 1.0 = creative)",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
"default": 0.7,
|
||||
"ui:widget": "range",
|
||||
"ui:options": {
|
||||
"step": 0.1
|
||||
}
|
||||
},
|
||||
"max_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Maximum Response Length",
|
||||
"description": "Maximum number of tokens in response",
|
||||
"minimum": 50,
|
||||
"maximum": 4000,
|
||||
"default": 1000,
|
||||
"ui:widget": "range",
|
||||
"ui:options": {
|
||||
"step": 50
|
||||
}
|
||||
},
|
||||
"memory_length": {
|
||||
"type": "integer",
|
||||
"title": "Conversation Memory",
|
||||
"description": "Number of previous message pairs to remember",
|
||||
"minimum": 1,
|
||||
"maximum": 50,
|
||||
"default": 10,
|
||||
"ui:widget": "range"
|
||||
},
|
||||
"fallback_responses": {
|
||||
"type": "array",
|
||||
"title": "Fallback Responses",
|
||||
"description": "Responses to use when the AI cannot answer",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"title": "Fallback Response"
|
||||
},
|
||||
"default": [
|
||||
"I'm not sure how to help with that. Could you please rephrase your question?",
|
||||
"I don't have enough information to answer that question accurately.",
|
||||
"That's outside my knowledge area. Is there something else I can help you with?"
|
||||
],
|
||||
"ui:options": {
|
||||
"orderable": true,
|
||||
"addable": true,
|
||||
"removable": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["name", "chatbot_type", "model"],
|
||||
"ui:order": [
|
||||
"name",
|
||||
"chatbot_type",
|
||||
"model",
|
||||
"system_prompt",
|
||||
"use_rag",
|
||||
"rag_collection",
|
||||
"rag_top_k",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"memory_length",
|
||||
"fallback_responses"
|
||||
]
|
||||
}
|
||||
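A hedged sketch of validating a chatbot configuration against this schema using the third-party jsonschema package (an assumption; the repository may use its own validator). The ui:* keys and enumNames are form-rendering hints that standard JSON Schema validators ignore.

import json
import jsonschema

with open("backend/app/modules/chatbot/config_schema.json") as f:
    schema = json.load(f)

# Example values; only name, chatbot_type and model are required by the schema.
config = {
    "name": "Support Bot",
    "chatbot_type": "customer_support",
    "model": "gpt-3.5-turbo",
    "use_rag": True,
    "rag_collection": "support_documentation",
    "temperature": 0.3,
}

jsonschema.validate(instance=config, schema=schema)  # raises ValidationError on bad input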
@@ -0,0 +1,182 @@
|
||||
{
|
||||
"name": "Customer Support Workflow",
|
||||
"description": "Intelligent customer support workflow with intent classification, knowledge base search, and chatbot response generation",
|
||||
"version": "1.0",
|
||||
"variables": {
|
||||
"support_chatbot_id": "cs-bot-001",
|
||||
"escalation_threshold": 0.3,
|
||||
"max_attempts": 3
|
||||
},
|
||||
"steps": [
|
||||
{
|
||||
"id": "classify_intent",
|
||||
"name": "Classify Customer Intent",
|
||||
"type": "llm_call",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an intent classifier for customer support. Classify the customer message into one of these categories: technical_issue, billing_question, feature_request, complaint, general_inquiry. Also provide a confidence score between 0 and 1. Respond with JSON: {\"intent\": \"category\", \"confidence\": 0.95, \"reasoning\": \"explanation\"}"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "{{ inputs.customer_message }}"
|
||||
}
|
||||
],
|
||||
"output_variable": "intent_classification"
|
||||
},
|
||||
|
||||
{
|
||||
"id": "search_knowledge_base",
|
||||
"name": "Search Knowledge Base",
|
||||
"type": "workflow_step",
|
||||
"module": "rag",
|
||||
"action": "search",
|
||||
"config": {
|
||||
"query": "{{ inputs.customer_message }}",
|
||||
"collection": "support_documentation",
|
||||
"top_k": 5,
|
||||
"include_metadata": true
|
||||
},
|
||||
"output_variable": "knowledge_results"
|
||||
},
|
||||
|
||||
{
|
||||
"id": "check_confidence",
|
||||
"name": "Check Intent Confidence",
|
||||
"type": "condition",
|
||||
"condition": "JSON.parse(steps.classify_intent.result).confidence > variables.escalation_threshold",
|
||||
"true_steps": [
|
||||
{
|
||||
"id": "generate_chatbot_response",
|
||||
"name": "Generate Chatbot Response",
|
||||
"type": "workflow_step",
|
||||
"module": "chatbot",
|
||||
"action": "workflow_chat_step",
|
||||
"config": {
|
||||
"message": "{{ inputs.customer_message }}",
|
||||
"chatbot_id": "{{ variables.support_chatbot_id }}",
|
||||
"use_rag": true,
|
||||
"context": {
|
||||
"intent": "{{ steps.classify_intent.result }}",
|
||||
"knowledge_base_results": "{{ steps.search_knowledge_base.result }}",
|
||||
"customer_history": "{{ inputs.customer_history }}",
|
||||
"additional_instructions": "Be empathetic and professional. If you cannot fully resolve the issue, offer to escalate to a human agent."
|
||||
}
|
||||
},
|
||||
"output_variable": "chatbot_response"
|
||||
},
|
||||
|
||||
{
|
||||
"id": "analyze_response_quality",
|
||||
"name": "Analyze Response Quality",
|
||||
"type": "llm_call",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Analyze if this customer support response adequately addresses the customer's question. Consider completeness, accuracy, and helpfulness. Respond with JSON: {\"quality_score\": 0.85, \"is_adequate\": true, \"requires_escalation\": false, \"reasoning\": \"explanation\"}"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Customer Question: {{ inputs.customer_message }}\\n\\nChatbot Response: {{ steps.generate_chatbot_response.result.response }}\\n\\nKnowledge Base Context: {{ steps.search_knowledge_base.result }}"
|
||||
}
|
||||
],
|
||||
"output_variable": "response_quality"
|
||||
},
|
||||
|
||||
{
|
||||
"id": "final_response_decision",
|
||||
"name": "Final Response Decision",
|
||||
"type": "condition",
|
||||
"condition": "JSON.parse(steps.analyze_response_quality.result).is_adequate === true",
|
||||
"true_steps": [
|
||||
{
|
||||
"id": "send_chatbot_response",
|
||||
"name": "Send Chatbot Response",
|
||||
"type": "output",
|
||||
"config": {
|
||||
"response_type": "chatbot_response",
|
||||
"message": "{{ steps.generate_chatbot_response.result.response }}",
|
||||
"sources": "{{ steps.generate_chatbot_response.result.sources }}",
|
||||
"confidence": "{{ JSON.parse(steps.classify_intent.result).confidence }}",
|
||||
"quality_score": "{{ JSON.parse(steps.analyze_response_quality.result).quality_score }}"
|
||||
}
|
||||
}
|
||||
],
|
||||
"false_steps": [
|
||||
{
|
||||
"id": "escalate_to_human",
|
||||
"name": "Escalate to Human Agent",
|
||||
"type": "output",
|
||||
"config": {
|
||||
"response_type": "human_escalation",
|
||||
"message": "I'd like to connect you with one of our human support agents who can better assist with your specific situation. Please hold on while I transfer you.",
|
||||
"escalation_reason": "Response quality below threshold",
|
||||
"intent": "{{ steps.classify_intent.result }}",
|
||||
"attempted_response": "{{ steps.generate_chatbot_response.result.response }}",
|
||||
"priority": "normal"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"false_steps": [
|
||||
{
|
||||
"id": "low_confidence_escalation",
|
||||
"name": "Low Confidence Escalation",
|
||||
"type": "output",
|
||||
"config": {
|
||||
"response_type": "human_escalation",
|
||||
"message": "I want to make sure you get the best possible help. Let me connect you with one of our human support agents.",
|
||||
"escalation_reason": "Low intent classification confidence",
|
||||
"intent": "{{ steps.classify_intent.result }}",
|
||||
"priority": "high"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
{
|
||||
"id": "log_interaction",
|
||||
"name": "Log Customer Interaction",
|
||||
"type": "workflow_step",
|
||||
"module": "analytics",
|
||||
"action": "log_event",
|
||||
"config": {
|
||||
"event_type": "customer_support_interaction",
|
||||
"data": {
|
||||
"customer_message": "{{ inputs.customer_message }}",
|
||||
"intent_classification": "{{ steps.classify_intent.result }}",
|
||||
"response_generated": "{{ steps.generate_chatbot_response.result.response }}",
|
||||
"knowledge_base_used": "{{ steps.search_knowledge_base.result }}",
|
||||
"escalated": "{{ outputs.response_type === 'human_escalation' }}",
|
||||
"workflow_execution_time": "{{ execution_time }}",
|
||||
"timestamp": "{{ current_timestamp }}"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
"outputs": {
|
||||
"response_type": "string",
|
||||
"message": "string",
|
||||
"sources": "array",
|
||||
"escalation_reason": "string",
|
||||
"confidence": "number",
|
||||
"quality_score": "number"
|
||||
},
|
||||
|
||||
"error_handling": {
|
||||
"retry_failed_steps": true,
|
||||
"max_retries": 2,
|
||||
"fallback_response": "I apologize, but I'm experiencing technical difficulties. Please contact our support team directly for assistance."
|
||||
},
|
||||
|
||||
"metadata": {
|
||||
"created_by": "support_team",
|
||||
"use_case": "customer_support_automation",
|
||||
"tags": ["customer_support", "chatbot", "rag", "escalation"],
|
||||
"estimated_execution_time": "5-15 seconds"
|
||||
}
|
||||
}
|
||||
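The conditions above are JavaScript-style expressions over step results and workflow variables; the engine that evaluates them is not shown in this excerpt. A rough Python sketch of the escalation check, assuming step results arrive as the JSON strings the prompts request:

import json

def should_use_chatbot(intent_result: str, escalation_threshold: float = 0.3) -> bool:
    """Mirror of the 'check_confidence' condition: only answer with the chatbot
    when intent-classification confidence exceeds the escalation threshold."""
    classification = json.loads(intent_result)
    return classification["confidence"] > escalation_threshold

# Example: the classifier replied with the JSON shape its system prompt asks for.
print(should_use_chatbot('{"intent": "technical_issue", "confidence": 0.92, "reasoning": "..."}'))  # True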
949
backend/app/modules/chatbot/main.py
Normal file
@@ -0,0 +1,949 @@
|
||||
"""
|
||||
Chatbot Module Implementation
|
||||
|
||||
Provides AI chatbot capabilities with:
|
||||
- RAG integration for knowledge-based responses
|
||||
- Custom prompts and personalities
|
||||
- Conversation memory and context
|
||||
- Workflow integration as building blocks
|
||||
- UI-configurable settings
|
||||
"""
|
||||
|
||||
import json
|
||||
from pprint import pprint
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
|
||||
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError
|
||||
from app.services.base_module import BaseModule, Permission
|
||||
from app.models.user import User
|
||||
from app.models.chatbot import ChatbotInstance as DBChatbotInstance, ChatbotConversation as DBConversation, ChatbotMessage as DBMessage, ChatbotAnalytics
|
||||
from app.core.security import get_current_user
|
||||
from app.db.database import get_db
|
||||
from app.core.config import settings
|
||||
|
||||
# Import protocols for type hints and dependency injection
|
||||
from ..protocols import RAGServiceProtocol
|
||||
# Note: LiteLLMClientProtocol replaced with direct LLM service usage
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ChatbotType(str, Enum):
|
||||
"""Types of chatbot personalities"""
|
||||
ASSISTANT = "assistant"
|
||||
CUSTOMER_SUPPORT = "customer_support"
|
||||
TEACHER = "teacher"
|
||||
RESEARCHER = "researcher"
|
||||
CREATIVE_WRITER = "creative_writer"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
"""Message roles in conversation"""
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatbotConfig:
|
||||
"""Chatbot configuration"""
|
||||
name: str
|
||||
chatbot_type: str # Changed from ChatbotType enum to str to allow custom types
|
||||
model: str
|
||||
rag_collection: Optional[str] = None
|
||||
system_prompt: str = ""
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 1000
|
||||
memory_length: int = 10 # Number of previous messages to remember
|
||||
use_rag: bool = False
|
||||
rag_top_k: int = 5
|
||||
rag_score_threshold: float = 0.02 # Lowered from default 0.3 to allow more results
|
||||
fallback_responses: Optional[List[str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.fallback_responses is None:
|
||||
self.fallback_responses = [
|
||||
"I'm not sure how to help with that. Could you please rephrase your question?",
|
||||
"I don't have enough information to answer that question accurately.",
|
||||
"That's outside my knowledge area. Is there something else I can help you with?"
|
||||
]
|
||||
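# Illustrative example (comment only, not executed): a minimal config for a
# RAG-backed support bot. Only name, chatbot_type and model are required; the
# remaining fields fall back to the defaults above.
#
#   config = ChatbotConfig(
#       name="Support Bot",
#       chatbot_type="customer_support",
#       model="gpt-3.5-turbo",
#       use_rag=True,
#       rag_collection="support_documentation",
#   )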
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""Individual chat message"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
role: MessageRole
|
||||
content: str
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
sources: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
"""Conversation state"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
chatbot_id: str
|
||||
user_id: str
|
||||
messages: List[ChatMessage] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""Chat completion request"""
|
||||
message: str
|
||||
conversation_id: Optional[str] = None
|
||||
chatbot_id: str
|
||||
use_rag: Optional[bool] = None
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""Chat completion response"""
|
||||
response: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
sources: Optional[List[Dict[str, Any]]] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChatbotInstance(BaseModel):
|
||||
"""Configured chatbot instance"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
name: str
|
||||
config: ChatbotConfig
|
||||
created_by: str
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class ChatbotModule(BaseModule):
|
||||
"""Main chatbot module implementation"""
|
||||
|
||||
def __init__(self, rag_service: Optional[RAGServiceProtocol] = None):
|
||||
super().__init__("chatbot")
|
||||
self.rag_module = rag_service # Keep same name for compatibility
|
||||
self.db_session = None
|
||||
|
||||
# System prompts will be loaded from database
|
||||
self.system_prompts = {}
|
||||
|
||||
async def initialize(self, **kwargs):
|
||||
"""Initialize the chatbot module"""
|
||||
await super().initialize(**kwargs)
|
||||
|
||||
# Initialize the LLM service
|
||||
await llm_service.initialize()
|
||||
|
||||
# Get RAG module dependency if not already injected
|
||||
if not self.rag_module:
|
||||
try:
|
||||
# Try to get RAG module from module manager
|
||||
from app.services.module_manager import module_manager
|
||||
if hasattr(module_manager, 'modules') and 'rag' in module_manager.modules:
|
||||
self.rag_module = module_manager.modules['rag']
|
||||
logger.info("RAG module injected from module manager")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not inject RAG module: {e}")
|
||||
|
||||
# Load prompt templates from database
|
||||
await self._load_prompt_templates()
|
||||
|
||||
logger.info("Chatbot module initialized")
|
||||
logger.info(f"LLM service available: {llm_service._initialized}")
|
||||
logger.info(f"RAG module available after init: {self.rag_module is not None}")
|
||||
logger.info(f"Loaded {len(self.system_prompts)} prompt templates")
|
||||
|
||||
async def _ensure_dependencies(self):
|
||||
"""Lazy load dependencies if not available"""
|
||||
# Ensure LLM service is initialized
|
||||
if not llm_service._initialized:
|
||||
await llm_service.initialize()
|
||||
logger.info("LLM service lazy loaded")
|
||||
|
||||
if not self.rag_module:
|
||||
try:
|
||||
# Try to get RAG module from module manager
|
||||
from app.services.module_manager import module_manager
|
||||
if hasattr(module_manager, 'modules') and 'rag' in module_manager.modules:
|
||||
self.rag_module = module_manager.modules['rag']
|
||||
logger.info("RAG module lazy loaded from module manager")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not lazy load RAG module: {e}")
|
||||
|
||||
async def _load_prompt_templates(self):
|
||||
"""Load prompt templates from database"""
|
||||
try:
|
||||
from app.db.database import SessionLocal
|
||||
from app.models.prompt_template import PromptTemplate
|
||||
from sqlalchemy import select
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
result = db.execute(
|
||||
select(PromptTemplate)
|
||||
.where(PromptTemplate.is_active == True)
|
||||
)
|
||||
templates = result.scalars().all()
|
||||
|
||||
for template in templates:
|
||||
self.system_prompts[template.type_key] = template.system_prompt
|
||||
|
||||
logger.info(f"Loaded {len(self.system_prompts)} prompt templates from database")
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load prompt templates from database: {e}")
|
||||
# Fallback to hardcoded prompts
|
||||
self.system_prompts = {
|
||||
"assistant": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations.",
|
||||
"customer_support": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions.",
|
||||
"teacher": "You are an experienced educational tutor. Break down complex concepts into understandable parts. Be patient, supportive, and encouraging.",
|
||||
"researcher": "You are a thorough research assistant with a focus on accuracy and evidence-based information.",
|
||||
"creative_writer": "You are an experienced creative writing mentor and storytelling expert.",
|
||||
"custom": "You are a helpful AI assistant. Your personality and behavior will be defined by custom instructions."
|
||||
}
|
||||
|
||||
async def get_system_prompt_for_type(self, chatbot_type: str) -> str:
|
||||
"""Get system prompt for a specific chatbot type"""
|
||||
if chatbot_type in self.system_prompts:
|
||||
return self.system_prompts[chatbot_type]
|
||||
|
||||
# If not found, try to reload templates
|
||||
await self._load_prompt_templates()
|
||||
|
||||
return self.system_prompts.get(chatbot_type, self.system_prompts.get("assistant",
|
||||
"You are a helpful AI assistant. Provide accurate, concise, and friendly responses."))
|
||||
|
||||
async def create_chatbot(self, config: ChatbotConfig, user_id: str, db: Session) -> ChatbotInstance:
|
||||
"""Create a new chatbot instance"""
|
||||
|
||||
# Set system prompt based on type if not provided or empty
|
||||
if not config.system_prompt or config.system_prompt.strip() == "":
|
||||
config.system_prompt = await self.get_system_prompt_for_type(config.chatbot_type)
|
||||
|
||||
# Create database record
|
||||
db_chatbot = DBChatbotInstance(
|
||||
name=config.name,
|
||||
description=f"{config.chatbot_type.replace('_', ' ').title()} chatbot",
|
||||
config=config.__dict__,
|
||||
created_by=user_id
|
||||
)
|
||||
|
||||
db.add(db_chatbot)
|
||||
db.commit()
|
||||
db.refresh(db_chatbot)
|
||||
|
||||
# Convert to response model
|
||||
chatbot = ChatbotInstance(
|
||||
id=db_chatbot.id,
|
||||
name=db_chatbot.name,
|
||||
config=ChatbotConfig(**db_chatbot.config),
|
||||
created_by=db_chatbot.created_by,
|
||||
created_at=db_chatbot.created_at,
|
||||
updated_at=db_chatbot.updated_at,
|
||||
is_active=db_chatbot.is_active
|
||||
)
|
||||
|
||||
logger.info(f"Created new chatbot: {chatbot.name} ({chatbot.id})")
|
||||
return chatbot
|
||||
|
||||
async def chat_completion(self, request: ChatRequest, user_id: str, db: Session) -> ChatResponse:
|
||||
"""Generate chat completion response"""
|
||||
|
||||
# Get chatbot configuration from database
|
||||
db_chatbot = db.query(DBChatbotInstance).filter(DBChatbotInstance.id == request.chatbot_id).first()
|
||||
if not db_chatbot:
|
||||
raise HTTPException(status_code=404, detail="Chatbot not found")
|
||||
|
||||
chatbot_config = ChatbotConfig(**db_chatbot.config)
|
||||
|
||||
# Get or create conversation
|
||||
conversation = await self._get_or_create_conversation(
|
||||
request.conversation_id, request.chatbot_id, user_id, db
|
||||
)
|
||||
|
||||
# Create user message
|
||||
user_message = DBMessage(
|
||||
conversation_id=conversation.id,
|
||||
role=MessageRole.USER.value,
|
||||
content=request.message
|
||||
)
|
||||
db.add(user_message)
|
||||
db.commit()
|
||||
db.refresh(user_message)
|
||||
|
||||
logger.info(f"Created user message with ID {user_message.id} for conversation {conversation.id}")
|
||||
|
||||
try:
|
||||
# Force the session to see the committed changes
|
||||
db.expire_all()
|
||||
|
||||
# Get conversation history for context - includes the current message we just created
|
||||
# Fetch up to memory_length pairs of messages (user + assistant)
|
||||
# The +1 ensures we include the current message if we're at the limit
|
||||
messages = db.query(DBMessage).filter(
|
||||
DBMessage.conversation_id == conversation.id
|
||||
).order_by(DBMessage.timestamp.desc()).limit(chatbot_config.memory_length * 2 + 1).all()
|
||||
|
||||
logger.info(f"Query for conversation_id={conversation.id}, memory_length={chatbot_config.memory_length}")
|
||||
logger.info(f"Found {len(messages)} messages in conversation history")
|
||||
|
||||
# If we don't have any messages, manually add the user message we just created
|
||||
if len(messages) == 0:
|
||||
logger.warning(f"No messages found in query, but we just created message {user_message.id}")
|
||||
logger.warning(f"Using the user message we just created")
|
||||
messages = [user_message]
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
logger.info(f"Message {idx}: id={msg.id}, role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...")
|
||||
|
||||
# Generate response
|
||||
response_content, sources = await self._generate_response(
|
||||
request.message, messages, chatbot_config, request.context, db
|
||||
)
|
||||
|
||||
# Create assistant message
|
||||
assistant_message = DBMessage(
|
||||
conversation_id=conversation.id,
|
||||
role=MessageRole.ASSISTANT.value,
|
||||
content=response_content,
|
||||
sources=sources,
|
||||
metadata={"model": chatbot_config.model, "temperature": chatbot_config.temperature}
|
||||
)
|
||||
db.add(assistant_message)
|
||||
db.commit()
|
||||
db.refresh(assistant_message)
|
||||
|
||||
# Update conversation timestamp
|
||||
conversation.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
return ChatResponse(
|
||||
response=response_content,
|
||||
conversation_id=conversation.id,
|
||||
message_id=assistant_message.id,
|
||||
sources=sources
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Chat completion failed: {e}")
|
||||
# Return fallback response
|
||||
fallback = chatbot_config.fallback_responses[0] if chatbot_config.fallback_responses else "I'm having trouble responding right now."
|
||||
|
||||
assistant_message = DBMessage(
|
||||
conversation_id=conversation.id,
|
||||
role=MessageRole.ASSISTANT.value,
|
||||
content=fallback,
|
||||
metadata={"error": str(e), "fallback": True}
|
||||
)
|
||||
db.add(assistant_message)
|
||||
db.commit()
|
||||
db.refresh(assistant_message)
|
||||
|
||||
return ChatResponse(
|
||||
response=fallback,
|
||||
conversation_id=conversation.id,
|
||||
message_id=assistant_message.id,
|
||||
metadata={"error": str(e), "fallback": True}
|
||||
)
|
||||
|
||||
async def _generate_response(self, message: str, db_messages: List[DBMessage],
|
||||
config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]:
|
||||
"""Generate response using LLM with optional RAG"""
|
||||
|
||||
# Lazy load dependencies if not available
|
||||
await self._ensure_dependencies()
|
||||
|
||||
sources = None
|
||||
rag_context = ""
|
||||
|
||||
# Helper: detect encryption-related queries for extra care
|
||||
def _is_encryption_query(q: str) -> bool:
|
||||
ql = (q or "").lower()
|
||||
return any(k in ql for k in ["encrypt", "encryption", "encrypted", "decrypt", "decryption", "sd card", "microsd", "micro-sd"])
|
||||
|
||||
is_encryption = _is_encryption_query(message)
|
||||
|
||||
# RAG search if enabled
|
||||
if config.use_rag and config.rag_collection and self.rag_module:
|
||||
logger.info(f"RAG search enabled for collection: {config.rag_collection}")
|
||||
try:
|
||||
# Get the Qdrant collection name from RAG collection
|
||||
qdrant_collection_name = await self._get_qdrant_collection_name(config.rag_collection, db)
|
||||
logger.info(f"Qdrant collection name: {qdrant_collection_name}")
|
||||
|
||||
if qdrant_collection_name:
|
||||
logger.info(f"Searching RAG documents: query='{message[:50]}...', max_results={config.rag_top_k}")
|
||||
rag_results = await self.rag_module.search_documents(
|
||||
query=message,
|
||||
max_results=config.rag_top_k,
|
||||
collection_name=qdrant_collection_name,
|
||||
score_threshold=config.rag_score_threshold
|
||||
)
|
||||
|
||||
# If the user asks about encryption, prefer results that explicitly mention it
|
||||
if rag_results and is_encryption:
|
||||
kw = ["encrypt", "encryption", "encrypted", "decrypt", "decryption"]
|
||||
filtered = [r for r in rag_results if any(k in (r.document.content or "").lower() for k in kw)]
|
||||
if filtered:
|
||||
rag_results = filtered + [r for r in rag_results if r not in filtered]
|
||||
|
||||
if rag_results:
|
||||
logger.info(f"RAG search found {len(rag_results)} results")
|
||||
sources = [{"title": f"Document {i+1}", "content": result.document.content[:200]}
|
||||
for i, result in enumerate(rag_results)]
|
||||
|
||||
# Build full RAG context from all results
|
||||
rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([
|
||||
f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results)
|
||||
])
|
||||
|
||||
# Detailed RAG logging - ALWAYS log for debugging
|
||||
logger.info("=== COMPREHENSIVE RAG SEARCH RESULTS ===")
|
||||
logger.info(f"Query: '{message}'")
|
||||
logger.info(f"Collection: {qdrant_collection_name}")
|
||||
logger.info(f"Number of results: {len(rag_results)}")
|
||||
for i, result in enumerate(rag_results):
|
||||
logger.info(f"\n--- RAG Result {i+1} ---")
|
||||
logger.info(f"Score: {getattr(result, 'score', 'N/A')}")
|
||||
logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}")
|
||||
logger.info(f"Full Content ({len(result.document.content)} chars):")
|
||||
logger.info(f"{result.document.content}")
|
||||
if hasattr(result.document, 'metadata'):
|
||||
logger.info(f"Metadata: {result.document.metadata}")
|
||||
logger.info(f"\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
|
||||
logger.info(rag_context)
|
||||
logger.info("=== END RAG SEARCH RESULTS ===")
|
||||
else:
|
||||
logger.warning("RAG search returned no results")
|
||||
else:
|
||||
logger.warning(f"RAG collection '{config.rag_collection}' not found in database")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"RAG search failed: {e}")
|
||||
import traceback
|
||||
logger.warning(f"RAG search traceback: {traceback.format_exc()}")
|
||||
|
||||
# Build conversation context (includes the current message from db_messages)
|
||||
# Inject strict grounding instructions when RAG is used, especially for encryption questions
|
||||
extra_instructions = {}
|
||||
if config.use_rag:
|
||||
guardrails = (
|
||||
"Answer strictly using the 'Relevant information' provided. "
|
||||
"If the information does not explicitly answer the question, say you don't have enough information instead of guessing. "
|
||||
)
|
||||
if is_encryption:
|
||||
guardrails += (
|
||||
"When asked about encryption or SD-card backups, do not claim that backups are encrypted unless the provided context explicitly uses wording like 'encrypt', 'encrypted', or 'encryption'. "
|
||||
"If such wording is absent, state clearly that the SD-card backup is not encrypted. "
|
||||
"Product policy: For BitBox devices, microSD (SD card) backups are not encrypted; verification steps may require a recovery password, but that is not encryption. Do not conflate password entry with encryption. "
|
||||
)
|
||||
extra_instructions["additional_instructions"] = guardrails
|
||||
|
||||
# Deterministic enforcement: if encryption question and RAG context does not explicitly
|
||||
# contain encryption wording, return policy answer without calling the LLM.
|
||||
ctx_lower = (rag_context or "").lower()
|
||||
has_encryption_terms = any(k in ctx_lower for k in ["encrypt", "encrypted", "encryption", "decrypt", "decryption"])
|
||||
if is_encryption and not has_encryption_terms:
|
||||
policy_answer = (
|
||||
"No. BitBox microSD (SD card) backups are not encrypted. "
|
||||
"Verification may require entering a recovery password, but that does not encrypt the backup — "
|
||||
"it only proves you have the correct credentials to restore. Keep the card and password secure."
|
||||
)
|
||||
return policy_answer, sources
|
||||
|
||||
messages = self._build_conversation_messages(db_messages, config, rag_context, extra_instructions)
|
||||
|
||||
# Note: Current user message is already included in db_messages from the query
|
||||
logger.info(f"Built conversation context with {len(messages)} messages")
|
||||
|
||||
# LLM completion
|
||||
logger.info(f"Attempting LLM completion with model: {config.model}")
|
||||
logger.info(f"Messages to send: {len(messages)} messages")
|
||||
|
||||
# Always log detailed prompts for debugging
|
||||
logger.info("=== COMPREHENSIVE LLM REQUEST ===")
|
||||
logger.info(f"Model: {config.model}")
|
||||
logger.info(f"Temperature: {config.temperature}")
|
||||
logger.info(f"Max tokens: {config.max_tokens}")
|
||||
logger.info(f"RAG enabled: {config.use_rag}")
|
||||
logger.info(f"RAG collection: {config.rag_collection}")
|
||||
if config.use_rag and rag_context:
|
||||
logger.info(f"RAG context added: {len(rag_context)} characters")
|
||||
logger.info(f"RAG sources: {len(sources) if sources else 0} documents")
|
||||
logger.info("\n=== COMPLETE MESSAGES SENT TO LLM ===")
|
||||
for i, msg in enumerate(messages):
|
||||
logger.info(f"\n--- Message {i+1} ---")
|
||||
logger.info(f"Role: {msg['role']}")
|
||||
logger.info(f"Content ({len(msg['content'])} chars):")
|
||||
# Truncate long content for logging (full RAG context can be very long)
|
||||
if len(msg['content']) > 500:
|
||||
logger.info(f"{msg['content'][:500]}... [truncated, total {len(msg['content'])} chars]")
|
||||
else:
|
||||
logger.info(msg['content'])
|
||||
logger.info("=== END COMPREHENSIVE LLM REQUEST ===")
|
||||
|
||||
try:
|
||||
logger.info("Calling LLM service create_chat_completion...")
|
||||
|
||||
# Convert messages to LLM service format
|
||||
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
|
||||
|
||||
# Create LLM service request
|
||||
llm_request = LLMChatRequest(
|
||||
model=config.model,
|
||||
messages=llm_messages,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
user_id="chatbot_user",
|
||||
api_key_id=0 # Chatbot module uses internal service
|
||||
)
|
||||
|
||||
# Make request to LLM service
|
||||
llm_response = await llm_service.create_chat_completion(llm_request)
|
||||
|
||||
# Extract response content
|
||||
if llm_response.choices:
|
||||
content = llm_response.choices[0].message.content
|
||||
logger.info(f"Response content length: {len(content)}")
|
||||
|
||||
# Always log response for debugging
|
||||
logger.info("=== COMPREHENSIVE LLM RESPONSE ===")
|
||||
logger.info(f"Response content ({len(content)} chars):")
|
||||
logger.info(content)
|
||||
if llm_response.usage:
|
||||
usage = llm_response.usage
|
||||
logger.info(f"Token usage - Prompt: {usage.prompt_tokens}, Completion: {usage.completion_tokens}, Total: {usage.total_tokens}")
|
||||
if sources:
|
||||
logger.info(f"RAG sources included: {len(sources)} documents")
|
||||
logger.info("=== END COMPREHENSIVE LLM RESPONSE ===")
|
||||
|
||||
return content, sources
|
||||
else:
|
||||
logger.warning("No choices in LLM response")
|
||||
return "I received an empty response from the AI model.", sources
|
||||
|
||||
except SecurityError as e:
|
||||
logger.error(f"Security error in LLM completion: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Security validation failed: {e.message}")
|
||||
except ProviderError as e:
|
||||
logger.error(f"Provider error in LLM completion: {e}")
|
||||
raise HTTPException(status_code=503, detail="LLM service temporarily unavailable")
|
||||
except LLMError as e:
|
||||
logger.error(f"LLM service error: {e}")
|
||||
raise HTTPException(status_code=500, detail="LLM service error")
|
||||
except Exception as e:
|
||||
logger.error(f"LLM completion failed: {e}")
|
||||
# Return fallback if available
|
||||
return "I'm currently unable to process your request. Please try again later.", None
|
||||
|
||||
def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig,
|
||||
rag_context: str = "", context: Optional[Dict] = None) -> List[Dict]:
|
||||
"""Build messages array for LLM completion"""
|
||||
|
||||
messages = []
|
||||
|
||||
# System prompt
|
||||
system_prompt = config.system_prompt
|
||||
if rag_context:
|
||||
# Add explicit instruction to use RAG context
|
||||
system_prompt += "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. " \
|
||||
"This information is directly relevant to their query and should be your primary source:\n" + rag_context
|
||||
if context and context.get('additional_instructions'):
|
||||
system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}"
|
||||
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
logger.info(f"Building messages from {len(db_messages)} database messages")
|
||||
|
||||
# Conversation history (messages are already limited by memory_length in the query)
|
||||
# Reverse to get chronological order
|
||||
# Include ALL messages - the current user message is needed for the LLM to respond!
|
||||
for idx, msg in enumerate(reversed(db_messages)):
|
||||
logger.info(f"Processing message {idx}: role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...")
|
||||
if msg.role in ["user", "assistant"]:
|
||||
messages.append({
|
||||
"role": msg.role,
|
||||
"content": msg.content
|
||||
})
|
||||
logger.info(f"Added message with role {msg.role} to LLM messages")
|
||||
else:
|
||||
logger.info(f"Skipped message with role {msg.role}")
|
||||
|
||||
logger.info(f"Final messages array has {len(messages)} messages") # For debugging, can be removed in production
|
||||
return messages
|
||||
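# Illustrative shape of the returned list (values abridged):
#   [
#     {"role": "system", "content": "<system prompt [+ RAG context] [+ additional instructions]>"},
#     {"role": "user", "content": "earlier question ..."},
#     {"role": "assistant", "content": "earlier answer ..."},
#     {"role": "user", "content": "current message"},
#   ]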
|
||||
async def _get_or_create_conversation(self, conversation_id: Optional[str],
|
||||
chatbot_id: str, user_id: str, db: Session) -> DBConversation:
|
||||
"""Get existing conversation or create new one"""
|
||||
|
||||
if conversation_id:
|
||||
conversation = db.query(DBConversation).filter(DBConversation.id == conversation_id).first()
|
||||
if conversation:
|
||||
return conversation
|
||||
|
||||
# Create new conversation
|
||||
conversation = DBConversation(
|
||||
chatbot_id=chatbot_id,
|
||||
user_id=user_id,
|
||||
title="New Conversation"
|
||||
)
|
||||
|
||||
db.add(conversation)
|
||||
db.commit()
|
||||
db.refresh(conversation)
|
||||
return conversation
|
||||
|
||||
def get_router(self) -> APIRouter:
|
||||
"""Get FastAPI router for chatbot endpoints"""
|
||||
router = APIRouter(prefix="/chatbot", tags=["chatbot"])
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def chat_endpoint(
|
||||
request: ChatRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Chat completion endpoint"""
|
||||
return await self.chat_completion(request, str(current_user['id']), db)
|
||||
|
||||
@router.post("/create", response_model=ChatbotInstance)
|
||||
async def create_chatbot_endpoint(
|
||||
config: ChatbotConfig,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Create new chatbot instance"""
|
||||
return await self.create_chatbot(config, str(current_user['id']), db)
|
||||
|
||||
@router.get("/list", response_model=List[ChatbotInstance])
|
||||
async def list_chatbots_endpoint(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""List user's chatbots"""
|
||||
db_chatbots = db.query(DBChatbotInstance).filter(
|
||||
(DBChatbotInstance.created_by == str(current_user['id'])) |
|
||||
(DBChatbotInstance.created_by == "system")
|
||||
).all()
|
||||
|
||||
chatbots = []
|
||||
for db_chatbot in db_chatbots:
|
||||
chatbot = ChatbotInstance(
|
||||
id=db_chatbot.id,
|
||||
name=db_chatbot.name,
|
||||
config=ChatbotConfig(**db_chatbot.config),
|
||||
created_by=db_chatbot.created_by,
|
||||
created_at=db_chatbot.created_at,
|
||||
updated_at=db_chatbot.updated_at,
|
||||
is_active=db_chatbot.is_active
|
||||
)
|
||||
chatbots.append(chatbot)
|
||||
|
||||
return chatbots
|
||||
|
||||
@router.get("/conversations/{conversation_id}", response_model=Conversation)
|
||||
async def get_conversation_endpoint(
|
||||
conversation_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get conversation history"""
|
||||
conversation = db.query(DBConversation).filter(
|
||||
DBConversation.id == conversation_id
|
||||
).first()
|
||||
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
|
||||
# Check if user owns this conversation
|
||||
if conversation.user_id != str(current_user['id']):
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
# Get messages
|
||||
messages = db.query(DBMessage).filter(
|
||||
DBMessage.conversation_id == conversation_id
|
||||
).order_by(DBMessage.timestamp).all()
|
||||
|
||||
# Convert to response model
|
||||
chat_messages = []
|
||||
for msg in messages:
|
||||
chat_message = ChatMessage(
|
||||
id=msg.id,
|
||||
role=MessageRole(msg.role),
|
||||
content=msg.content,
|
||||
timestamp=msg.timestamp,
|
||||
metadata=msg.metadata or {},
|
||||
sources=msg.sources
|
||||
)
|
||||
chat_messages.append(chat_message)
|
||||
|
||||
response_conversation = Conversation(
|
||||
id=conversation.id,
|
||||
chatbot_id=conversation.chatbot_id,
|
||||
user_id=conversation.user_id,
|
||||
messages=chat_messages,
|
||||
created_at=conversation.created_at,
|
||||
updated_at=conversation.updated_at,
|
||||
metadata=conversation.context_data or {}
|
||||
)
|
||||
|
||||
return response_conversation
|
||||
|
||||
@router.get("/types", response_model=List[Dict[str, str]])
|
||||
async def get_chatbot_types_endpoint():
|
||||
"""Get available chatbot types and their descriptions"""
|
||||
return [
|
||||
{"type": "assistant", "name": "General Assistant", "description": "Helpful AI assistant for general questions"},
|
||||
{"type": "customer_support", "name": "Customer Support", "description": "Professional customer service chatbot"},
|
||||
{"type": "teacher", "name": "Teacher", "description": "Educational tutor and learning assistant"},
|
||||
{"type": "researcher", "name": "Researcher", "description": "Research assistant with fact-checking focus"},
|
||||
{"type": "creative_writer", "name": "Creative Writer", "description": "Creative writing and storytelling assistant"},
|
||||
{"type": "custom", "name": "Custom", "description": "Custom chatbot with user-defined personality"}
|
||||
]
|
||||
|
||||
return router
|
||||
|
||||
# API Compatibility Methods
|
||||
async def chat(self, chatbot_config: Dict[str, Any], message: str,
|
||||
conversation_history: List = None, user_id: str = "anonymous") -> Dict[str, Any]:
|
||||
"""Chat method for API compatibility"""
|
||||
logger.info(f"Chat method called with message: {message[:50]}... by user: {user_id}")
|
||||
|
||||
# Lazy load dependencies
|
||||
await self._ensure_dependencies()
|
||||
|
||||
logger.info(f"LLM service available: {llm_service._initialized}")
|
||||
logger.info(f"RAG module available: {self.rag_module is not None}")
|
||||
|
||||
try:
|
||||
# Create a minimal database session for the chat
|
||||
from app.db.database import SessionLocal
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
# Convert config dict to ChatbotConfig
|
||||
config = ChatbotConfig(
|
||||
name=chatbot_config.get("name", "Unknown"),
|
||||
chatbot_type=chatbot_config.get("chatbot_type", "assistant"),
|
||||
model=chatbot_config.get("model", "gpt-3.5-turbo"),
|
||||
system_prompt=chatbot_config.get("system_prompt", ""),
|
||||
temperature=chatbot_config.get("temperature", 0.7),
|
||||
max_tokens=chatbot_config.get("max_tokens", 1000),
|
||||
memory_length=chatbot_config.get("memory_length", 10),
|
||||
use_rag=chatbot_config.get("use_rag", False),
|
||||
rag_collection=chatbot_config.get("rag_collection"),
|
||||
rag_top_k=chatbot_config.get("rag_top_k", 5),
|
||||
fallback_responses=chatbot_config.get("fallback_responses", [])
|
||||
)
|
||||
|
||||
# Generate response using internal method
|
||||
# Create a temporary message object for the current user message
|
||||
temp_messages = [
|
||||
DBMessage(
|
||||
id=0,
|
||||
conversation_id=0,
|
||||
role="user",
|
||||
content=message,
|
||||
timestamp=datetime.utcnow(),
|
||||
metadata={}
|
||||
)
|
||||
]
|
||||
|
||||
response_content, sources = await self._generate_response(
|
||||
message, temp_messages, config, None, db
|
||||
)
|
||||
|
||||
return {
|
||||
"response": response_content,
|
||||
"sources": sources,
|
||||
"conversation_id": None,
|
||||
"message_id": f"msg_{uuid.uuid4()}"
|
||||
}
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Chat method failed: {e}")
|
||||
fallback_responses = chatbot_config.get("fallback_responses", [
|
||||
"I'm sorry, I'm having trouble processing your request right now."
|
||||
])
|
||||
return {
|
||||
"response": fallback_responses[0] if fallback_responses else "I'm sorry, I couldn't process your request.",
|
||||
"sources": None,
|
||||
"conversation_id": None,
|
||||
"message_id": f"msg_{uuid.uuid4()}"
|
||||
}
|
||||
|
||||
# Workflow Integration Methods
|
||||
async def workflow_chat_step(self, context: Dict[str, Any], step_config: Dict[str, Any], db: Session) -> Dict[str, Any]:
|
||||
"""Execute chatbot as a workflow step"""
|
||||
|
||||
message = step_config.get('message', '')
|
||||
chatbot_id = step_config.get('chatbot_id')
|
||||
use_rag = step_config.get('use_rag', False)
|
||||
|
||||
# Template substitution from context
|
||||
message = self._substitute_template_variables(message, context)
|
||||
|
||||
request = ChatRequest(
|
||||
message=message,
|
||||
chatbot_id=chatbot_id,
|
||||
use_rag=use_rag,
|
||||
context=step_config.get('context', {})
|
||||
)
|
||||
|
||||
# Use system user for workflow executions
|
||||
response = await self.chat_completion(request, "workflow_system", db)
|
||||
|
||||
return {
|
||||
"response": response.response,
|
||||
"conversation_id": response.conversation_id,
|
||||
"sources": response.sources,
|
||||
"metadata": response.metadata
|
||||
}
|
||||
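# Illustrative step_config for this method (hypothetical values; compare the
# "generate_chatbot_response" step in the customer support workflow above):
#   {
#       "message": "{{ inputs.customer_message }}",
#       "chatbot_id": "cs-bot-001",
#       "use_rag": True,
#       "context": {"additional_instructions": "Be empathetic and professional."},
#   }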
|
||||
def _substitute_template_variables(self, template: str, context: Dict[str, Any]) -> str:
|
||||
"""Simple template variable substitution"""
|
||||
import re
|
||||
|
||||
def replace_var(match):
|
||||
var_path = match.group(1).strip()  # strip padding so "{{ var }}" and "{{var}}" both resolve
|
||||
try:
|
||||
# Simple dot notation support: context.user.name
|
||||
value = context
|
||||
for part in var_path.split('.'):
|
||||
value = value[part]
|
||||
return str(value)
|
||||
except (KeyError, TypeError):
|
||||
return match.group(0) # Return original if not found
|
||||
|
||||
return re.sub(r'\{\{\s*([^}]+)\s*\}\}', replace_var, template)
|
||||
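# Example (illustrative): with context = {"inputs": {"customer_message": "Hi"}},
#   _substitute_template_variables("Q: {{ inputs.customer_message }}", context) -> "Q: Hi"
# Unknown variables are left untouched, e.g. "{{ missing.var }}" stays as written.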
|
||||
async def _get_qdrant_collection_name(self, collection_identifier: str, db: Session) -> Optional[str]:
|
||||
"""Get Qdrant collection name from RAG collection ID, name, or direct Qdrant collection"""
|
||||
try:
|
||||
from app.models.rag_collection import RagCollection
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Looking up RAG collection with identifier: '{collection_identifier}'")
|
||||
|
||||
# First check if this might be a direct Qdrant collection name
|
||||
# (e.g., starts with "ext_", "rag_", or contains specific patterns)
|
||||
if collection_identifier.startswith(("ext_", "rag_", "test_")) or "_" in collection_identifier:
|
||||
# Check if this collection exists in Qdrant directly
|
||||
actual_collection_name = collection_identifier
|
||||
# Remove "ext_" prefix if present
|
||||
if collection_identifier.startswith("ext_"):
|
||||
actual_collection_name = collection_identifier[4:]
|
||||
|
||||
logger.info(f"Checking if '{actual_collection_name}' exists in Qdrant directly")
|
||||
if self.rag_module:
|
||||
try:
|
||||
# Try to verify the collection exists in Qdrant
|
||||
from qdrant_client import QdrantClient
|
||||
qdrant_client = QdrantClient(host="enclava-qdrant", port=6333)
|
||||
collections = qdrant_client.get_collections()
|
||||
collection_names = [c.name for c in collections.collections]
|
||||
|
||||
if actual_collection_name in collection_names:
|
||||
logger.info(f"Found Qdrant collection directly: {actual_collection_name}")
|
||||
return actual_collection_name
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking Qdrant collections: {e}")
|
||||
|
||||
rag_collection = None
|
||||
|
||||
# Then try PostgreSQL lookup by ID if numeric
|
||||
if collection_identifier.isdigit():
|
||||
logger.info(f"Treating '{collection_identifier}' as collection ID")
|
||||
stmt = select(RagCollection).where(
|
||||
RagCollection.id == int(collection_identifier),
|
||||
RagCollection.is_active == True
|
||||
)
|
||||
result = db.execute(stmt)
|
||||
rag_collection = result.scalar_one_or_none()
|
||||
|
||||
# If not found by ID, try to look up by name in PostgreSQL
|
||||
if not rag_collection:
|
||||
logger.info(f"Collection not found by ID, trying by name: '{collection_identifier}'")
|
||||
stmt = select(RagCollection).where(
|
||||
RagCollection.name == collection_identifier,
|
||||
RagCollection.is_active == True
|
||||
)
|
||||
result = db.execute(stmt)
|
||||
rag_collection = result.scalar_one_or_none()
|
||||
|
||||
if rag_collection:
|
||||
logger.info(f"Found RAG collection: ID={rag_collection.id}, name='{rag_collection.name}', qdrant_collection='{rag_collection.qdrant_collection_name}'")
|
||||
return rag_collection.qdrant_collection_name
|
||||
else:
|
||||
logger.warning(f"RAG collection '{collection_identifier}' not found in database (tried both ID and name)")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error looking up RAG collection '{collection_identifier}': {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
# Required abstract methods from BaseModule
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup chatbot module resources"""
|
||||
logger.info("Chatbot module cleanup completed")
|
||||
|
||||
def get_required_permissions(self) -> List[Permission]:
|
||||
"""Get required permissions for chatbot module"""
|
||||
return [
|
||||
Permission("chatbots", "create", "Create chatbot instances"),
|
||||
Permission("chatbots", "configure", "Configure chatbot settings"),
|
||||
Permission("chatbots", "chat", "Use chatbot for conversations"),
|
||||
Permission("chatbots", "manage", "Manage all chatbots")
|
||||
]
|
||||
|
||||
async def process_request(self, request_type: str, data: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process chatbot requests"""
|
||||
if request_type == "chat":
|
||||
# Handle chat requests
|
||||
chat_request = ChatRequest(**data)
|
||||
user_id = context.get("user_id", "anonymous")
|
||||
db = context.get("db")
|
||||
|
||||
if db:
|
||||
response = await self.chat_completion(chat_request, user_id, db)
|
||||
return {
|
||||
"success": True,
|
||||
"response": response.response,
|
||||
"conversation_id": response.conversation_id,
|
||||
"sources": response.sources
|
||||
}
|
||||
|
||||
return {"success": False, "error": f"Unknown request type: {request_type}"}
|
||||
|
||||
|
||||
# Module factory function
|
||||
def create_module(rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule:
|
||||
"""Factory function to create chatbot module instance"""
|
||||
return ChatbotModule(rag_service=rag_service)
|
||||
|
||||
# Create module instance (dependencies will be injected via factory)
|
||||
chatbot_module = ChatbotModule()
|
||||
110
backend/app/modules/chatbot/module.yaml
Normal file
@@ -0,0 +1,110 @@
|
||||
name: chatbot
|
||||
version: 1.0.0
|
||||
description: "AI Chatbot with RAG integration and customizable prompts"
|
||||
author: "Enclava Team"
|
||||
category: "conversation"
|
||||
|
||||
# Module lifecycle
|
||||
enabled: true
|
||||
auto_start: true
|
||||
dependencies:
|
||||
- rag
|
||||
optional_dependencies:
|
||||
- analytics
|
||||
|
||||
# Configuration
|
||||
config_schema: "./config_schema.json"
|
||||
ui_components: "./ui_components/"
|
||||
|
||||
# Module capabilities
|
||||
provides:
|
||||
- "chat_completion"
|
||||
- "conversation_management"
|
||||
- "chatbot_configuration"
|
||||
|
||||
consumes:
|
||||
- "rag_search"
|
||||
- "llm_completion"
|
||||
|
||||
# API endpoints
|
||||
endpoints:
|
||||
- path: "/chatbot/chat"
|
||||
method: "POST"
|
||||
description: "Generate chat completion"
|
||||
|
||||
- path: "/chatbot/create"
|
||||
method: "POST"
|
||||
description: "Create new chatbot instance"
|
||||
|
||||
- path: "/chatbot/list"
|
||||
method: "GET"
|
||||
description: "List user chatbots"
|
||||
|
||||
# UI Configuration
|
||||
ui_config:
|
||||
icon: "message-circle"
|
||||
color: "#10B981"
|
||||
category: "AI & ML"
|
||||
|
||||
# Configuration forms
|
||||
forms:
|
||||
- name: "basic_config"
|
||||
title: "Basic Settings"
|
||||
fields: ["name", "chatbot_type", "model"]
|
||||
|
||||
- name: "personality"
|
||||
title: "Personality & Prompts"
|
||||
fields: ["system_prompt", "temperature", "fallback_responses"]
|
||||
|
||||
- name: "knowledge_base"
|
||||
title: "Knowledge Base"
|
||||
fields: ["use_rag", "rag_collection", "rag_top_k"]
|
||||
|
||||
- name: "advanced"
|
||||
title: "Advanced Settings"
|
||||
fields: ["max_tokens", "memory_length"]
|
||||
|
||||
# Permissions
|
||||
permissions:
|
||||
- name: "chatbot.create"
|
||||
description: "Create new chatbot instances"
|
||||
|
||||
- name: "chatbot.configure"
|
||||
description: "Configure chatbot settings"
|
||||
|
||||
- name: "chatbot.chat"
|
||||
description: "Use chatbot for conversations"
|
||||
|
||||
- name: "chatbot.manage"
|
||||
description: "Manage all chatbots (admin)"
|
||||
|
||||
# Analytics events
|
||||
analytics_events:
|
||||
- name: "chatbot_created"
|
||||
description: "New chatbot instance created"
|
||||
|
||||
- name: "chat_message_sent"
|
||||
description: "User sent message to chatbot"
|
||||
|
||||
- name: "chat_response_generated"
|
||||
description: "Chatbot generated response"
|
||||
|
||||
- name: "rag_context_used"
|
||||
description: "RAG context was used in response"
|
||||
|
||||
# Health checks
|
||||
health_checks:
|
||||
- name: "llm_connectivity"
|
||||
description: "Check LLM client connection"
|
||||
|
||||
- name: "rag_availability"
|
||||
description: "Check RAG module availability"
|
||||
|
||||
- name: "conversation_memory"
|
||||
description: "Check conversation storage health"
|
||||
|
||||
# Documentation
|
||||
documentation:
|
||||
readme: "./README.md"
|
||||
examples: "./examples/"
|
||||
api_docs: "./docs/api.md"
|
||||
225
backend/app/modules/factory.py
Normal file
225
backend/app/modules/factory.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Module Factory for Confidential Empire
|
||||
|
||||
This factory creates and wires up all modules with their dependencies.
|
||||
It ensures proper dependency injection while maintaining optimal performance
|
||||
through direct method calls and minimal indirection.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional, Any
|
||||
import logging
|
||||
|
||||
# Import all modules
|
||||
from .rag.main import RAGModule
|
||||
from .chatbot.main import ChatbotModule, create_module as create_chatbot_module
|
||||
from .workflow.main import WorkflowModule
|
||||
|
||||
# Import services that modules depend on
|
||||
from app.services.litellm_client import LiteLLMClient
|
||||
|
||||
# Import protocols for type safety
|
||||
from .protocols import (
|
||||
RAGServiceProtocol,
|
||||
ChatbotServiceProtocol,
|
||||
LiteLLMClientProtocol,
|
||||
WorkflowServiceProtocol,
|
||||
ServiceRegistry
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModuleFactory:
|
||||
"""Factory for creating and wiring module dependencies"""
|
||||
|
||||
def __init__(self):
|
||||
self.modules: Dict[str, Any] = {}
|
||||
self.initialized = False
|
||||
|
||||
async def create_all_modules(self, config: Optional[Dict[str, Any]] = None) -> ServiceRegistry:
|
||||
"""
|
||||
Create all modules with proper dependency injection
|
||||
|
||||
Args:
|
||||
config: Optional configuration for modules
|
||||
|
||||
Returns:
|
||||
Dictionary of created modules with their dependencies wired
|
||||
"""
|
||||
config = config or {}
|
||||
|
||||
logger.info("Creating modules with dependency injection...")
|
||||
|
||||
# Step 1: Create LiteLLM client (shared dependency)
|
||||
litellm_client = LiteLLMClient()
|
||||
|
||||
# Step 2: Create RAG module (no dependencies on other modules)
|
||||
rag_module = RAGModule(config=config.get("rag", {}))
|
||||
|
||||
# Step 3: Create chatbot module with RAG dependency
|
||||
chatbot_module = create_chatbot_module(
|
||||
litellm_client=litellm_client,
|
||||
rag_service=rag_module # RAG module implements RAGServiceProtocol
|
||||
)
|
||||
|
||||
# Step 4: Create workflow module with chatbot dependency
|
||||
workflow_module = WorkflowModule(
|
||||
chatbot_service=chatbot_module # Chatbot module implements ChatbotServiceProtocol
|
||||
)
|
||||
|
||||
# Store all modules
|
||||
modules = {
|
||||
"rag": rag_module,
|
||||
"chatbot": chatbot_module,
|
||||
"workflow": workflow_module
|
||||
}
|
||||
|
||||
logger.info(f"Created {len(modules)} modules with dependencies wired")
|
||||
|
||||
# Initialize all modules
|
||||
await self._initialize_modules(modules, config)
|
||||
|
||||
self.modules = modules
|
||||
self.initialized = True
|
||||
|
||||
return modules
|
||||
|
||||
async def _initialize_modules(self, modules: Dict[str, Any], config: Dict[str, Any]):
|
||||
"""Initialize all modules in dependency order"""
|
||||
|
||||
# Initialize in dependency order (modules with no deps first)
|
||||
initialization_order = [
|
||||
("rag", modules["rag"]),
|
||||
("chatbot", modules["chatbot"]), # Depends on RAG
|
||||
("workflow", modules["workflow"]) # Depends on Chatbot
|
||||
]
|
||||
|
||||
for module_name, module in initialization_order:
|
||||
try:
|
||||
logger.info(f"Initializing {module_name} module...")
|
||||
module_config = config.get(module_name, {})
|
||||
|
||||
# Different modules have different initialization patterns
|
||||
if hasattr(module, 'initialize'):
|
||||
if module_name == "rag":
|
||||
await module.initialize()
|
||||
else:
|
||||
await module.initialize(**module_config)
|
||||
|
||||
logger.info(f"✅ {module_name} module initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to initialize {module_name} module: {e}")
|
||||
raise RuntimeError(f"Module initialization failed: {module_name}") from e
|
||||
|
||||
async def cleanup_all_modules(self):
|
||||
"""Cleanup all modules in reverse dependency order"""
|
||||
if not self.initialized:
|
||||
return
|
||||
|
||||
# Cleanup in reverse order
|
||||
cleanup_order = ["workflow", "chatbot", "rag"]
|
||||
|
||||
for module_name in cleanup_order:
|
||||
if module_name in self.modules:
|
||||
try:
|
||||
logger.info(f"Cleaning up {module_name} module...")
|
||||
module = self.modules[module_name]
|
||||
if hasattr(module, 'cleanup'):
|
||||
await module.cleanup()
|
||||
logger.info(f"✅ {module_name} module cleaned up")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error cleaning up {module_name}: {e}")
|
||||
|
||||
self.modules.clear()
|
||||
self.initialized = False
|
||||
|
||||
def get_module(self, name: str) -> Optional[Any]:
|
||||
"""Get a module by name"""
|
||||
return self.modules.get(name)
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if factory is initialized"""
|
||||
return self.initialized
|
||||
|
||||
|
||||
# Global factory instance
|
||||
module_factory = ModuleFactory()
|
||||
|
||||
|
||||
# Convenience functions for external use
|
||||
async def create_modules(config: Optional[Dict[str, Any]] = None) -> ServiceRegistry:
|
||||
"""Create all modules with dependencies wired"""
|
||||
return await module_factory.create_all_modules(config)
|
||||
|
||||
|
||||
async def cleanup_modules():
|
||||
"""Cleanup all modules"""
|
||||
await module_factory.cleanup_all_modules()
|
||||
|
||||
|
||||
def get_module(name: str) -> Optional[Any]:
|
||||
"""Get a module by name"""
|
||||
return module_factory.get_module(name)
|
||||
|
||||
|
||||
def get_all_modules() -> Dict[str, Any]:
|
||||
"""Get all modules"""
|
||||
return module_factory.modules.copy()
|
||||
|
||||
|
||||
# Factory functions for individual modules (for testing/special cases)
|
||||
def create_rag_module(config: Optional[Dict[str, Any]] = None) -> RAGModule:
|
||||
"""Create RAG module"""
|
||||
return RAGModule(config=config or {})
|
||||
|
||||
|
||||
def create_chatbot_with_rag(rag_service: RAGServiceProtocol,
|
||||
litellm_client: LiteLLMClientProtocol) -> ChatbotModule:
|
||||
"""Create chatbot module with RAG dependency"""
|
||||
return create_chatbot_module(litellm_client=litellm_client, rag_service=rag_service)
|
||||
|
||||
|
||||
def create_workflow_with_chatbot(chatbot_service: ChatbotServiceProtocol) -> WorkflowModule:
|
||||
"""Create workflow module with chatbot dependency"""
|
||||
return WorkflowModule(chatbot_service=chatbot_service)
|
||||
|
||||
|
||||
# Module registry for backward compatibility
|
||||
class ModuleRegistry:
|
||||
"""Registry that provides access to modules (for backward compatibility)"""
|
||||
|
||||
def __init__(self, factory: ModuleFactory):
|
||||
self._factory = factory
|
||||
|
||||
@property
|
||||
def modules(self) -> Dict[str, Any]:
|
||||
"""Get all modules (compatible with existing module_manager interface)"""
|
||||
return self._factory.modules
|
||||
|
||||
def get(self, name: str) -> Optional[Any]:
|
||||
"""Get module by name"""
|
||||
return self._factory.get_module(name)
|
||||
|
||||
def __getitem__(self, name: str) -> Any:
|
||||
"""Support dictionary-style access"""
|
||||
module = self.get(name)
|
||||
if module is None:
|
||||
raise KeyError(f"Module '{name}' not found")
|
||||
return module
|
||||
|
||||
def keys(self):
|
||||
"""Get module names"""
|
||||
return self._factory.modules.keys()
|
||||
|
||||
def values(self):
|
||||
"""Get module instances"""
|
||||
return self._factory.modules.values()
|
||||
|
||||
def items(self):
|
||||
"""Get module name-instance pairs"""
|
||||
return self._factory.modules.items()
|
||||
|
||||
|
||||
# Create registry instance for backward compatibility
|
||||
module_registry = ModuleRegistry(module_factory)
|
||||
258
backend/app/modules/protocols.py
Normal file
258
backend/app/modules/protocols.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Module Protocols for Confidential Empire
|
||||
|
||||
This file defines the interface contracts that modules must implement for inter-module communication.
|
||||
Using Python protocols provides compile-time type checking with zero runtime overhead.
|
||||
"""
|
||||
|
||||
from typing import Protocol, Dict, List, Any, Optional, Union
|
||||
from datetime import datetime
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class RAGServiceProtocol(Protocol):
|
||||
"""Protocol for RAG (Retrieval-Augmented Generation) service interface"""
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, query: str, collection_name: str, top_k: int) -> Dict[str, Any]:
|
||||
"""
|
||||
Search for relevant documents
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
collection_name: Name of the collection to search in
|
||||
top_k: Number of top results to return
|
||||
|
||||
Returns:
|
||||
Dictionary containing search results with 'results' key
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def index_document(self, content: str, metadata: Dict[str, Any] = None) -> str:
|
||||
"""
|
||||
Index a document in the vector database
|
||||
|
||||
Args:
|
||||
content: Document content to index
|
||||
metadata: Optional metadata for the document
|
||||
|
||||
Returns:
|
||||
Document ID
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def delete_document(self, document_id: str) -> bool:
|
||||
"""
|
||||
Delete a document from the vector database
|
||||
|
||||
Args:
|
||||
document_id: ID of document to delete
|
||||
|
||||
Returns:
|
||||
True if successfully deleted
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ChatbotServiceProtocol(Protocol):
|
||||
"""Protocol for Chatbot service interface"""
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion(self, request: Any, user_id: str, db: Any) -> Any:
|
||||
"""
|
||||
Generate chat completion response
|
||||
|
||||
Args:
|
||||
request: Chat request object
|
||||
user_id: ID of the user making the request
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Chat response object
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def create_chatbot(self, config: Any, user_id: str, db: Any) -> Any:
|
||||
"""
|
||||
Create a new chatbot instance
|
||||
|
||||
Args:
|
||||
config: Chatbot configuration
|
||||
user_id: ID of the user creating the chatbot
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Created chatbot instance
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class LiteLLMClientProtocol(Protocol):
|
||||
"""Protocol for LiteLLM client interface"""
|
||||
|
||||
@abstractmethod
|
||||
async def completion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||
"""
|
||||
Create a completion using the specified model
|
||||
|
||||
Args:
|
||||
model: Model name to use
|
||||
messages: List of messages for the conversation
|
||||
**kwargs: Additional parameters for the completion
|
||||
|
||||
Returns:
|
||||
Completion response object
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def create_chat_completion(self, model: str, messages: List[Dict[str, str]],
|
||||
user_id: str, api_key_id: str, **kwargs) -> Any:
|
||||
"""
|
||||
Create a chat completion with user tracking
|
||||
|
||||
Args:
|
||||
model: Model name to use
|
||||
messages: List of messages for the conversation
|
||||
user_id: ID of the user making the request
|
||||
api_key_id: API key identifier
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Chat completion response
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class CacheServiceProtocol(Protocol):
|
||||
"""Protocol for Cache service interface"""
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get value from cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
default: Default value if key not found
|
||||
|
||||
Returns:
|
||||
Cached value or default
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
|
||||
"""
|
||||
Set value in cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successfully cached
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete key from cache
|
||||
|
||||
Args:
|
||||
key: Cache key to delete
|
||||
|
||||
Returns:
|
||||
True if successfully deleted
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SecurityServiceProtocol(Protocol):
|
||||
"""Protocol for Security service interface"""
|
||||
|
||||
@abstractmethod
|
||||
async def analyze_request(self, request: Any) -> Any:
|
||||
"""
|
||||
Perform security analysis on a request
|
||||
|
||||
Args:
|
||||
request: Request object to analyze
|
||||
|
||||
Returns:
|
||||
Security analysis result
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def validate_request(self, request: Any) -> bool:
|
||||
"""
|
||||
Validate request for security compliance
|
||||
|
||||
Args:
|
||||
request: Request object to validate
|
||||
|
||||
Returns:
|
||||
True if request is valid/safe
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class WorkflowServiceProtocol(Protocol):
|
||||
"""Protocol for Workflow service interface"""
|
||||
|
||||
@abstractmethod
|
||||
async def execute_workflow(self, workflow: Any, input_data: Dict[str, Any] = None) -> Any:
|
||||
"""
|
||||
Execute a workflow definition
|
||||
|
||||
Args:
|
||||
workflow: Workflow definition to execute
|
||||
input_data: Optional input data for the workflow
|
||||
|
||||
Returns:
|
||||
Workflow execution result
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_execution(self, execution_id: str) -> Any:
|
||||
"""
|
||||
Get workflow execution status
|
||||
|
||||
Args:
|
||||
execution_id: ID of the execution to retrieve
|
||||
|
||||
Returns:
|
||||
Execution status object
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ModuleServiceProtocol(Protocol):
|
||||
"""Base protocol for all module services"""
|
||||
|
||||
@abstractmethod
|
||||
async def initialize(self, **kwargs) -> None:
|
||||
"""Initialize the module"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup(self) -> None:
|
||||
"""Cleanup module resources"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_required_permissions(self) -> List[Any]:
|
||||
"""Get required permissions for this module"""
|
||||
...
|
||||
|
||||
|
||||
# Type aliases for common service combinations
|
||||
ServiceRegistry = Dict[str, ModuleServiceProtocol]
|
||||
ServiceDependencies = Dict[str, Optional[ModuleServiceProtocol]]
|
||||
6
backend/app/modules/rag/__init__.py
Normal file
6
backend/app/modules/rag/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
RAG (Retrieval-Augmented Generation) module for Confidential Empire platform
|
||||
"""
|
||||
from .main import RAGModule
|
||||
|
||||
__all__ = ["RAGModule"]
|
||||
1922
backend/app/modules/rag/main.py
Normal file
1922
backend/app/modules/rag/main.py
Normal file
File diff suppressed because it is too large
Load Diff
82
backend/app/modules/rag/module.yaml
Normal file
82
backend/app/modules/rag/module.yaml
Normal file
@@ -0,0 +1,82 @@
|
||||
name: rag
|
||||
version: 1.0.0
|
||||
description: "Document search, retrieval, and vector storage"
|
||||
author: "Enclava Team"
|
||||
category: "ai"
|
||||
|
||||
# Module lifecycle
|
||||
enabled: true
|
||||
auto_start: true
|
||||
dependencies: []
|
||||
optional_dependencies:
|
||||
- cache
|
||||
|
||||
# Module capabilities
|
||||
provides:
|
||||
- "document_storage"
|
||||
- "semantic_search"
|
||||
- "vector_embeddings"
|
||||
- "document_processing"
|
||||
|
||||
consumes:
|
||||
- "qdrant_connection"
|
||||
- "llm_embeddings"
|
||||
- "document_parsing"
|
||||
|
||||
# API endpoints
|
||||
endpoints:
|
||||
- path: "/rag/collections"
|
||||
method: "GET"
|
||||
description: "List document collections"
|
||||
|
||||
- path: "/rag/upload"
|
||||
method: "POST"
|
||||
description: "Upload and process documents"
|
||||
|
||||
- path: "/rag/search"
|
||||
method: "POST"
|
||||
description: "Semantic search in documents"
|
||||
|
||||
- path: "/rag/collections/{collection_id}/documents"
|
||||
method: "GET"
|
||||
description: "List documents in collection"
|
||||
|
||||
# UI Configuration
|
||||
ui_config:
|
||||
icon: "search"
|
||||
color: "#8B5CF6"
|
||||
category: "AI & ML"
|
||||
|
||||
forms:
|
||||
- name: "collection_config"
|
||||
title: "Collection Settings"
|
||||
fields: ["name", "description", "embedding_model"]
|
||||
|
||||
- name: "search_config"
|
||||
title: "Search Configuration"
|
||||
fields: ["top_k", "similarity_threshold", "rerank_enabled"]
|
||||
|
||||
# Permissions
|
||||
permissions:
|
||||
- name: "rag.create"
|
||||
description: "Create document collections"
|
||||
|
||||
- name: "rag.upload"
|
||||
description: "Upload documents to collections"
|
||||
|
||||
- name: "rag.search"
|
||||
description: "Search document collections"
|
||||
|
||||
- name: "rag.manage"
|
||||
description: "Manage all collections (admin)"
|
||||
|
||||
# Health checks
|
||||
health_checks:
|
||||
- name: "qdrant_connectivity"
|
||||
description: "Check Qdrant vector database connection"
|
||||
|
||||
- name: "embeddings_service"
|
||||
description: "Check LLM embeddings service"
|
||||
|
||||
- name: "document_processing"
|
||||
description: "Check document parsing capabilities"
|
||||
@@ -162,6 +162,7 @@ class DocumentProcessor:
|
||||
|
||||
async def _process_document(self, task: ProcessingTask) -> bool:
|
||||
"""Process a single document"""
|
||||
from datetime import datetime
|
||||
from app.db.database import async_session_factory
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
@@ -182,16 +183,24 @@ class DocumentProcessor:
|
||||
document.status = ProcessingStatus.PROCESSING
|
||||
await session.commit()
|
||||
|
||||
# Get RAG module for processing (now includes content processing)
|
||||
# Get RAG module for processing
|
||||
try:
|
||||
from app.services.module_manager import module_manager
|
||||
rag_module = module_manager.get_module('rag')
|
||||
# Import RAG module and initialize it properly
|
||||
from modules.rag.main import RAGModule
|
||||
from app.core.config import settings
|
||||
|
||||
# Create and initialize RAG module instance
|
||||
rag_module = RAGModule(settings)
|
||||
init_result = await rag_module.initialize()
|
||||
if not rag_module.enabled:
|
||||
raise Exception("Failed to enable RAG module")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get RAG module: {e}")
|
||||
raise Exception(f"RAG module not available: {e}")
|
||||
|
||||
if not rag_module:
|
||||
raise Exception("RAG module not available")
|
||||
if not rag_module or not rag_module.enabled:
|
||||
raise Exception("RAG module not available or not enabled")
|
||||
|
||||
logger.info(f"RAG module loaded successfully for document {task.document_id}")
|
||||
|
||||
@@ -204,23 +213,31 @@ class DocumentProcessor:
|
||||
|
||||
# Process with RAG module
|
||||
logger.info(f"Starting document processing for document {task.document_id} with RAG module")
|
||||
|
||||
# Special handling for JSONL files - skip processing phase
|
||||
if document.file_type == 'jsonl':
|
||||
# For JSONL files, we don't need to process content here
|
||||
# The optimized JSONL processor will handle everything during indexing
|
||||
document.converted_content = f"JSONL file with {len(file_content)} bytes"
|
||||
document.word_count = 0 # Will be updated during indexing
|
||||
document.character_count = len(file_content)
|
||||
document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"}
|
||||
document.status = ProcessingStatus.PROCESSED
|
||||
document.processed_at = datetime.utcnow()
|
||||
logger.info(f"JSONL document {task.document_id} marked for optimized processing")
|
||||
else:
|
||||
# Standard processing for other file types
|
||||
try:
|
||||
# Add timeout to prevent hanging
|
||||
processed_doc = await asyncio.wait_for(
|
||||
rag_module.process_document(
|
||||
file_content,
|
||||
document.original_filename,
|
||||
{}
|
||||
{"file_path": document.file_path}
|
||||
),
|
||||
timeout=300.0 # 5 minute timeout
|
||||
)
|
||||
logger.info(f"Document processing completed for document {task.document_id}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Document processing timed out for document {task.document_id}")
|
||||
raise Exception("Document processing timed out after 5 minutes")
|
||||
except Exception as e:
|
||||
logger.error(f"Document processing failed for document {task.document_id}: {e}")
|
||||
raise
|
||||
|
||||
# Update document with processed content
|
||||
document.converted_content = processed_doc.content
|
||||
@@ -229,6 +246,12 @@ class DocumentProcessor:
|
||||
document.document_metadata = processed_doc.metadata
|
||||
document.status = ProcessingStatus.PROCESSED
|
||||
document.processed_at = datetime.utcnow()
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Document processing timed out for document {task.document_id}")
|
||||
raise Exception("Document processing timed out after 5 minutes")
|
||||
except Exception as e:
|
||||
logger.error(f"Document processing failed for document {task.document_id}: {e}")
|
||||
raise
|
||||
|
||||
# Index in RAG system using same RAG module
|
||||
if rag_module and document.converted_content:
|
||||
@@ -245,6 +268,49 @@ class DocumentProcessor:
|
||||
}
|
||||
|
||||
# Use the correct Qdrant collection name for this document
|
||||
# For JSONL files, we need to use the processed document flow
|
||||
if document.file_type == 'jsonl':
|
||||
# Create a ProcessedDocument for the JSONL processor
|
||||
from app.modules.rag.main import ProcessedDocument
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
|
||||
# Calculate file hash
|
||||
processed_at = datetime.utcnow()
|
||||
file_hash = hashlib.md5(str(document.id).encode()).hexdigest()
|
||||
|
||||
processed_doc = ProcessedDocument(
|
||||
id=str(document.id),
|
||||
content="", # Will be filled by JSONL processor
|
||||
extracted_text="", # Will be filled by JSONL processor
|
||||
metadata={
|
||||
**doc_metadata,
|
||||
"file_path": document.file_path
|
||||
},
|
||||
original_filename=document.original_filename,
|
||||
file_type=document.file_type,
|
||||
mime_type=document.mime_type,
|
||||
language=document.document_metadata.get('language', 'EN'),
|
||||
word_count=0, # Will be updated during processing
|
||||
sentence_count=0, # Will be updated during processing
|
||||
entities=[],
|
||||
keywords=[],
|
||||
processing_time=0.0,
|
||||
processed_at=processed_at,
|
||||
file_hash=file_hash,
|
||||
file_size=document.file_size
|
||||
)
|
||||
|
||||
# The JSONL processor will read the original file
|
||||
await asyncio.wait_for(
|
||||
rag_module.index_processed_document(
|
||||
processed_doc=processed_doc,
|
||||
collection_name=document.collection.qdrant_collection_name
|
||||
),
|
||||
timeout=300.0 # 5 minute timeout for JSONL processing
|
||||
)
|
||||
else:
|
||||
# Use standard indexing for other file types
|
||||
await asyncio.wait_for(
|
||||
rag_module.index_document(
|
||||
content=document.converted_content,
|
||||
@@ -271,7 +337,9 @@ class DocumentProcessor:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
|
||||
# Keep as processed even if indexing fails
|
||||
# Mark as error since indexing failed
|
||||
document.status = ProcessingStatus.ERROR
|
||||
document.processing_error = f"Indexing failed: {str(e)}"
|
||||
# Don't raise the exception to avoid retries on indexing failures
|
||||
|
||||
await session.commit()
|
||||
|
||||
@@ -28,9 +28,19 @@ class EmbeddingService:
|
||||
await llm_service.initialize()
|
||||
|
||||
# Test LLM service health
|
||||
health_summary = llm_service.get_health_summary()
|
||||
if health_summary.get("service_status") != "healthy":
|
||||
logger.error(f"LLM service unhealthy: {health_summary}")
|
||||
if not llm_service._initialized:
|
||||
logger.error("LLM service not initialized")
|
||||
return False
|
||||
|
||||
# Check if PrivateMode provider is available
|
||||
try:
|
||||
provider_status = await llm_service.get_provider_status()
|
||||
privatemode_status = provider_status.get("privatemode")
|
||||
if not privatemode_status or privatemode_status.status != "healthy":
|
||||
logger.error(f"PrivateMode provider not available: {privatemode_status}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check provider status: {e}")
|
||||
return False
|
||||
|
||||
self.initialized = True
|
||||
@@ -75,6 +85,12 @@ class EmbeddingService:
|
||||
else:
|
||||
truncated_text = text
|
||||
|
||||
# Guard: skip empty inputs (validator rejects empty strings)
|
||||
if not truncated_text.strip():
|
||||
logger.debug("Empty input for embedding; using fallback vector")
|
||||
batch_embeddings.append(self._generate_fallback_embedding(text))
|
||||
continue
|
||||
|
||||
# Call LLM service embedding endpoint
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import EmbeddingRequest
|
||||
|
||||
211
backend/app/services/enhanced_embedding_service.py
Normal file
211
backend/app/services/enhanced_embedding_service.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# Enhanced Embedding Service with Rate Limiting Handling
|
||||
"""
|
||||
Enhanced embedding service with robust rate limiting and retry logic
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from .embedding_service import EmbeddingService
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnhancedEmbeddingService(EmbeddingService):
|
||||
"""Enhanced embedding service with rate limiting handling"""
|
||||
|
||||
def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
|
||||
super().__init__(model_name)
|
||||
self.rate_limit_tracker = {
|
||||
'requests_count': 0,
|
||||
'window_start': time.time(),
|
||||
'window_size': 60, # 1 minute window
|
||||
'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)), # Configurable
|
||||
'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff
|
||||
'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)),
|
||||
'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)),
|
||||
'last_rate_limit_error': None
|
||||
}
|
||||
|
||||
async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]:
|
||||
"""
|
||||
Get embeddings with rate limiting and retry logic
|
||||
"""
|
||||
if max_retries is None:
|
||||
max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))
|
||||
|
||||
batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3))
|
||||
|
||||
if not self.initialized:
|
||||
logger.warning("Embedding service not initialized, using fallback")
|
||||
return self._generate_fallback_embeddings(texts), False
|
||||
|
||||
embeddings = []
|
||||
success = True
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i:i+batch_size]
|
||||
batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries)
|
||||
embeddings.extend(batch_embeddings)
|
||||
success = success and batch_success
|
||||
|
||||
# Add delay between batches to avoid rate limiting
|
||||
if i + batch_size < len(texts):
|
||||
delay = self.rate_limit_tracker['delay_between_batches']
|
||||
await asyncio.sleep(delay) # Configurable delay between batches
|
||||
|
||||
return embeddings, success
|
||||
|
||||
async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]:
|
||||
"""Get embeddings for a batch with retry logic"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
# Check rate limit before making request
|
||||
if self._is_rate_limited():
|
||||
delay = self._get_rate_limit_delay()
|
||||
logger.warning(f"Rate limit detected, waiting {delay} seconds")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
|
||||
# Make the request
|
||||
embeddings = await self._get_embeddings_batch_impl(texts)
|
||||
|
||||
return embeddings, True
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
error_msg = str(e).lower()
|
||||
|
||||
# Check if it's a rate limit error
|
||||
if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']):
|
||||
logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}")
|
||||
self._update_rate_limit_tracker(success=False)
|
||||
|
||||
if attempt < max_retries:
|
||||
delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
|
||||
logger.info(f"Retrying in {delay} seconds...")
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Max retries exceeded for rate limit, using fallback embeddings")
|
||||
return self._generate_fallback_embeddings(texts), False
|
||||
else:
|
||||
# Non-rate-limit error
|
||||
logger.error(f"Error generating embeddings: {e}")
|
||||
if attempt < max_retries:
|
||||
delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logger.error("Max retries exceeded, using fallback embeddings")
|
||||
return self._generate_fallback_embeddings(texts), False
|
||||
|
||||
# If we get here, all retries failed
|
||||
logger.error(f"All retries failed, last error: {last_error}")
|
||||
return self._generate_fallback_embeddings(texts), False
|
||||
|
||||
async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Implementation of getting embeddings for a batch"""
|
||||
from app.services.llm.service import llm_service
|
||||
from app.services.llm.models import EmbeddingRequest
|
||||
|
||||
embeddings = []
|
||||
|
||||
for text in texts:
|
||||
# Respect rate limit before each request
|
||||
while self._is_rate_limited():
|
||||
delay = self._get_rate_limit_delay()
|
||||
logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Truncate text if needed
|
||||
max_chars = 1600
|
||||
truncated_text = text[:max_chars] if len(text) > max_chars else text
|
||||
|
||||
llm_request = EmbeddingRequest(
|
||||
model=self.model_name,
|
||||
input=truncated_text,
|
||||
user_id="rag_system",
|
||||
api_key_id=0
|
||||
)
|
||||
|
||||
response = await llm_service.create_embedding(llm_request)
|
||||
|
||||
if response.data and len(response.data) > 0:
|
||||
embedding = response.data[0].embedding
|
||||
if embedding:
|
||||
embeddings.append(embedding)
|
||||
if not hasattr(self, '_dimension_confirmed'):
|
||||
self.dimension = len(embedding)
|
||||
self._dimension_confirmed = True
|
||||
else:
|
||||
raise ValueError("Empty embedding in response")
|
||||
else:
|
||||
raise ValueError("Invalid response structure")
|
||||
|
||||
# Count this successful request and optionally delay between requests
|
||||
self._update_rate_limit_tracker(success=True)
|
||||
per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0)
|
||||
if per_req_delay and per_req_delay > 0:
|
||||
await asyncio.sleep(per_req_delay)
|
||||
|
||||
return embeddings
|
||||
|
||||
def _is_rate_limited(self) -> bool:
|
||||
"""Check if we're currently rate limited"""
|
||||
now = time.time()
|
||||
window_start = self.rate_limit_tracker['window_start']
|
||||
|
||||
# Reset window if it's expired
|
||||
if now - window_start > self.rate_limit_tracker['window_size']:
|
||||
self.rate_limit_tracker['requests_count'] = 0
|
||||
self.rate_limit_tracker['window_start'] = now
|
||||
return False
|
||||
|
||||
# Check if we've exceeded the limit
|
||||
return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute']
|
||||
|
||||
def _get_rate_limit_delay(self) -> float:
|
||||
"""Get delay to wait for rate limit reset"""
|
||||
now = time.time()
|
||||
window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size']
|
||||
return max(0, window_end - now)
|
||||
|
||||
def _update_rate_limit_tracker(self, success: bool):
|
||||
"""Update the rate limit tracker"""
|
||||
now = time.time()
|
||||
|
||||
# Reset window if it's expired
|
||||
if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']:
|
||||
self.rate_limit_tracker['requests_count'] = 0
|
||||
self.rate_limit_tracker['window_start'] = now
|
||||
|
||||
# Increment counter on successful requests
|
||||
if success:
|
||||
self.rate_limit_tracker['requests_count'] += 1
|
||||
|
||||
async def get_embedding_stats(self) -> Dict[str, Any]:
|
||||
"""Get embedding service statistics including rate limiting info"""
|
||||
base_stats = await self.get_stats()
|
||||
|
||||
return {
|
||||
**base_stats,
|
||||
"rate_limit_info": {
|
||||
"requests_in_current_window": self.rate_limit_tracker['requests_count'],
|
||||
"max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'],
|
||||
"window_reset_in_seconds": max(0,
|
||||
self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time()
|
||||
),
|
||||
"last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error']
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Global enhanced embedding service instance
|
||||
enhanced_embedding_service = EnhancedEmbeddingService()
|
||||
211
backend/app/services/jsonl_processor.py
Normal file
211
backend/app/services/jsonl_processor.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Optimized JSONL Processor for RAG Module
|
||||
Handles JSONL files efficiently to prevent resource exhaustion
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue
|
||||
from qdrant_client.http.models import Batch
|
||||
|
||||
from app.modules.rag.main import ProcessedDocument
|
||||
# from app.core.analytics import log_module_event # Analytics module not available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JSONLProcessor:
|
||||
"""Specialized processor for JSONL files"""
|
||||
|
||||
def __init__(self, rag_module):
|
||||
self.rag_module = rag_module
|
||||
self.config = rag_module.config
|
||||
|
||||
async def process_and_index_jsonl(self, collection_name: str, content: bytes,
|
||||
filename: str, metadata: Dict[str, Any]) -> str:
|
||||
"""Process and index a JSONL file efficiently
|
||||
|
||||
Processes each JSON line as a separate document to avoid
|
||||
creating thousands of chunks from a single large document.
|
||||
"""
|
||||
try:
|
||||
# Decode content
|
||||
jsonl_content = content.decode('utf-8', errors='replace')
|
||||
lines = jsonl_content.strip().split('\n')
|
||||
|
||||
logger.info(f"Processing JSONL file {filename} with {len(lines)} lines")
|
||||
|
||||
# Generate base document ID
|
||||
base_doc_id = self.rag_module._generate_document_id(jsonl_content, metadata)
|
||||
|
||||
# Process lines in batches
|
||||
batch_size = 10 # Smaller batches for better memory management
|
||||
processed_count = 0
|
||||
|
||||
for batch_start in range(0, len(lines), batch_size):
|
||||
batch_end = min(batch_start + batch_size, len(lines))
|
||||
batch_lines = lines[batch_start:batch_end]
|
||||
|
||||
# Process batch
|
||||
await self._process_jsonl_batch(
|
||||
collection_name,
|
||||
batch_lines,
|
||||
batch_start,
|
||||
base_doc_id,
|
||||
filename,
|
||||
metadata
|
||||
)
|
||||
|
||||
processed_count += len(batch_lines)
|
||||
|
||||
# Log progress
|
||||
if processed_count % 50 == 0:
|
||||
logger.info(f"Processed {processed_count}/{len(lines)} lines from {filename}")
|
||||
|
||||
# Small delay to prevent resource exhaustion
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
logger.info(f"Successfully processed JSONL file {filename} with {len(lines)} lines")
|
||||
return base_doc_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing JSONL file {filename}: {e}")
|
||||
raise
|
||||
|
||||
async def _process_jsonl_batch(self, collection_name: str, lines: List[str],
|
||||
start_idx: int, base_doc_id: str,
|
||||
filename: str, metadata: Dict[str, Any]) -> None:
|
||||
"""Process a batch of JSONL lines"""
|
||||
try:
|
||||
points = []
|
||||
|
||||
for line_idx, line in enumerate(lines, start=start_idx + 1):
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
# Parse JSON line
|
||||
data = json.loads(line)
|
||||
|
||||
# Debug: check if data is None
|
||||
if data is None:
|
||||
logger.warning(f"JSON line {line_idx} parsed as None")
|
||||
continue
|
||||
|
||||
# Handle helpjuice export format
|
||||
if 'payload' in data and data['payload'] is not None:
|
||||
payload = data['payload']
|
||||
article_id = data.get('id', f'article_{line_idx}')
|
||||
|
||||
# Extract Q&A
|
||||
question = payload.get('question', '')
|
||||
answer = payload.get('answer', '')
|
||||
language = payload.get('language', 'EN')
|
||||
|
||||
if question or answer:
|
||||
# Create Q&A content
|
||||
content = f"Question: {question}\n\nAnswer: {answer}"
|
||||
|
||||
# Create metadata
|
||||
doc_metadata = {
|
||||
**metadata,
|
||||
"article_id": article_id,
|
||||
"language": language,
|
||||
"filename": filename,
|
||||
"line_number": line_idx,
|
||||
"content_type": "qa_pair",
|
||||
"question": question[:100], # Truncate for metadata
|
||||
"processed_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Generate single embedding for the Q&A pair
|
||||
embeddings = await self.rag_module._generate_embeddings([content])
|
||||
|
||||
# Create point
|
||||
point_id = str(uuid.uuid4())
|
||||
points.append(PointStruct(
|
||||
id=point_id,
|
||||
vector=embeddings[0],
|
||||
payload={
|
||||
**doc_metadata,
|
||||
"document_id": f"{base_doc_id}_{article_id}",
|
||||
"content": content,
|
||||
"chunk_index": 0,
|
||||
"chunk_count": 1
|
||||
}
|
||||
))
|
||||
|
||||
# Handle generic JSON format
|
||||
else:
|
||||
content = json.dumps(data, indent=2, ensure_ascii=False)
|
||||
|
||||
# For larger JSON objects, we might need to chunk
|
||||
if len(content) > 1000:
|
||||
chunks = self.rag_module._chunk_text(content, chunk_size=500)
|
||||
embeddings = await self.rag_module._generate_embeddings(chunks)
|
||||
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
point_id = str(uuid.uuid4())
|
||||
points.append(PointStruct(
|
||||
id=point_id,
|
||||
vector=embedding,
|
||||
payload={
|
||||
**metadata,
|
||||
"filename": filename,
|
||||
"line_number": line_idx,
|
||||
"content_type": "json_object",
|
||||
"document_id": f"{base_doc_id}_line_{line_idx}",
|
||||
"content": chunk,
|
||||
"chunk_index": i,
|
||||
"chunk_count": len(chunks)
|
||||
}
|
||||
))
|
||||
else:
|
||||
# Small JSON - no chunking needed
|
||||
embeddings = await self.rag_module._generate_embeddings([content])
|
||||
point_id = str(uuid.uuid4())
|
||||
points.append(PointStruct(
|
||||
id=point_id,
|
||||
vector=embeddings[0],
|
||||
payload={
|
||||
**metadata,
|
||||
"filename": filename,
|
||||
"line_number": line_idx,
|
||||
"content_type": "json_object",
|
||||
"document_id": f"{base_doc_id}_line_{line_idx}",
|
||||
"content": content,
|
||||
"chunk_index": 0,
|
||||
"chunk_count": 1
|
||||
}
|
||||
))
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Error parsing JSONL line {line_idx}: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing JSONL line {line_idx}: {e}")
|
||||
continue
|
||||
|
||||
# Insert all points in this batch
|
||||
if points:
|
||||
self.rag_module.qdrant_client.upsert(
|
||||
collection_name=collection_name,
|
||||
points=points
|
||||
)
|
||||
|
||||
# Update stats
|
||||
self.rag_module.stats["documents_indexed"] += len(points)
|
||||
# log_module_event("rag", "jsonl_batch_processed", { # Analytics module not available
|
||||
# "filename": filename,
|
||||
# "lines_processed": len(lines),
|
||||
# "points_created": len(points)
|
||||
# })
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing JSONL batch: {e}")
|
||||
raise
|
||||
@@ -16,6 +16,7 @@ from .models import ResilienceConfig
|
||||
class ProviderConfig(BaseModel):
|
||||
"""Configuration for an LLM provider"""
|
||||
name: str = Field(..., description="Provider name")
|
||||
provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')")
|
||||
enabled: bool = Field(True, description="Whether provider is enabled")
|
||||
base_url: str = Field(..., description="Provider base URL")
|
||||
api_key_env_var: str = Field(..., description="Environment variable for API key")
|
||||
@@ -53,9 +54,6 @@ class LLMServiceConfig(BaseModel):
|
||||
enable_security_checks: bool = Field(True, description="Enable security validation")
|
||||
enable_metrics_collection: bool = Field(True, description="Enable metrics collection")
|
||||
|
||||
# Security settings
|
||||
security_risk_threshold: float = Field(0.8, ge=0.0, le=1.0, description="Risk threshold for blocking")
|
||||
security_warning_threshold: float = Field(0.6, ge=0.0, le=1.0, description="Risk threshold for warnings")
|
||||
max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
|
||||
max_response_length: int = Field(32000, ge=1000, description="Maximum response length")
|
||||
|
||||
@@ -66,15 +64,18 @@ class LLMServiceConfig(BaseModel):
|
||||
# Provider configurations
|
||||
providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")
|
||||
|
||||
# Token rate limiting (organization-wide)
|
||||
token_limits_per_minute: Dict[str, int] = Field(
|
||||
default_factory=lambda: {
|
||||
"prompt_tokens": 20000, # PrivateMode Standard tier
|
||||
"completion_tokens": 10000 # PrivateMode Standard tier
|
||||
},
|
||||
description="Token rate limits per minute (organization-wide)"
|
||||
)
|
||||
|
||||
# Model routing (model_name -> provider_name)
|
||||
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
|
||||
|
||||
@validator('security_risk_threshold')
|
||||
def validate_risk_threshold(cls, v, values):
|
||||
warning_threshold = values.get('security_warning_threshold', 0.6)
|
||||
if v <= warning_threshold:
|
||||
raise ValueError("Risk threshold must be greater than warning threshold")
|
||||
return v
|
||||
|
||||
|
||||
def create_default_config() -> LLMServiceConfig:
|
||||
@@ -84,6 +85,7 @@ def create_default_config() -> LLMServiceConfig:
|
||||
# Models will be fetched dynamically from proxy /models endpoint
|
||||
privatemode_config = ProviderConfig(
|
||||
name="privatemode",
|
||||
provider_type="privatemode",
|
||||
enabled=True,
|
||||
base_url=settings.PRIVATEMODE_PROXY_URL,
|
||||
api_key_env_var="PRIVATEMODE_API_KEY",
|
||||
@@ -91,8 +93,8 @@ def create_default_config() -> LLMServiceConfig:
|
||||
supported_models=[], # Will be populated dynamically from proxy
|
||||
capabilities=["chat", "embeddings", "tee"],
|
||||
priority=1,
|
||||
max_requests_per_minute=100,
|
||||
max_requests_per_hour=2000,
|
||||
max_requests_per_minute=20, # PrivateMode Standard tier limit: 20 req/min
|
||||
max_requests_per_hour=1200, # 20 req/min * 60 min
|
||||
supports_streaming=True,
|
||||
supports_function_calling=True,
|
||||
max_context_window=128000,
|
||||
@@ -110,9 +112,6 @@ def create_default_config() -> LLMServiceConfig:
|
||||
config = LLMServiceConfig(
|
||||
default_provider="privatemode",
|
||||
enable_detailed_logging=settings.LOG_LLM_PROMPTS,
|
||||
enable_security_checks=settings.API_SECURITY_ENABLED,
|
||||
security_risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
|
||||
security_warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
|
||||
providers={
|
||||
"privatemode": privatemode_config
|
||||
},
|
||||
|
||||
@@ -124,7 +124,6 @@ class MetricsCollector:
|
||||
total_requests = len(self._metrics)
|
||||
successful_requests = sum(1 for m in self._metrics if m.success)
|
||||
failed_requests = total_requests - successful_requests
|
||||
security_blocked = sum(1 for m in self._metrics if not m.success and m.security_risk_score > 0.8)
|
||||
|
||||
# Calculate averages
|
||||
latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0]
|
||||
@@ -143,7 +142,6 @@ class MetricsCollector:
|
||||
total_requests=total_requests,
|
||||
successful_requests=successful_requests,
|
||||
failed_requests=failed_requests,
|
||||
security_blocked_requests=security_blocked,
|
||||
average_latency_ms=avg_latency,
|
||||
average_risk_score=avg_risk_score,
|
||||
provider_metrics=provider_metrics,
|
||||
|
||||
@@ -452,6 +452,8 @@ class PrivateModeProvider(BaseLLMProvider):
|
||||
|
||||
else:
|
||||
error_text = await response.text()
|
||||
# Log the detailed error response from the provider
|
||||
logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
|
||||
self._handle_http_error(response.status, error_text, "embeddings")
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
"""
|
||||
LLM Security Manager
|
||||
|
||||
Handles prompt injection detection and audit logging.
|
||||
Provides comprehensive security for LLM interactions.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityManager:
|
||||
"""Manages security for LLM operations"""
|
||||
|
||||
def __init__(self):
|
||||
self._setup_prompt_injection_patterns()
|
||||
|
||||
|
||||
def _setup_prompt_injection_patterns(self):
|
||||
"""Setup patterns for prompt injection detection"""
|
||||
self.injection_patterns = [
|
||||
# Direct instruction injection
|
||||
r"(?i)(ignore|forget|disregard|override)\s+(previous|all|above|prior)\s+(instructions|rules|prompts)",
|
||||
r"(?i)(new|updated|different)\s+(instructions|rules|system)",
|
||||
r"(?i)act\s+as\s+(if|though)\s+you\s+(are|were)",
|
||||
r"(?i)pretend\s+(to\s+be|you\s+are)",
|
||||
r"(?i)you\s+are\s+now\s+(a|an)\s+",
|
||||
|
||||
# System role manipulation
|
||||
r"(?i)system\s*:\s*",
|
||||
r"(?i)\[system\]",
|
||||
r"(?i)<system>",
|
||||
r"(?i)assistant\s*:\s*",
|
||||
r"(?i)\[assistant\]",
|
||||
|
||||
# Escape attempts
|
||||
r"(?i)\\n\\n#+",
|
||||
r"(?i)```\s*(system|assistant|user)",
|
||||
r"(?i)---\s*(new|system|override)",
|
||||
|
||||
# Role manipulation
|
||||
r"(?i)(you|your)\s+(role|purpose|function)\s+(is|has\s+changed)",
|
||||
r"(?i)switch\s+to\s+(admin|developer|debug)\s+mode",
|
||||
r"(?i)(admin|root|sudo|developer)\s+(access|mode|privileges)",
|
||||
|
||||
# Information extraction attempts
|
||||
r"(?i)(show|display|reveal|expose)\s+(your|the)\s+(prompt|instructions|system)",
|
||||
r"(?i)what\s+(are|were)\s+your\s+(original|initial)\s+(instructions|prompts)",
|
||||
r"(?i)(debug|verbose|diagnostic)\s+mode",
|
||||
|
||||
# Encoding/obfuscation attempts
|
||||
r"(?i)base64\s*:",
|
||||
r"(?i)hex\s*:",
|
||||
r"(?i)unicode\s*:",
|
||||
r"[A-Za-z0-9+/]{20,}={0,2}", # Potential base64
|
||||
|
||||
# SQL injection patterns (for system prompts)
|
||||
r"(?i)(union|select|insert|update|delete|drop|create)\s+",
|
||||
r"(?i)(or|and)\s+1\s*=\s*1",
|
||||
r"(?i)';?\s*(drop|delete|insert)",
|
||||
|
||||
# Command injection patterns
|
||||
r"(?i)(exec|eval|system|shell|cmd)\s*\(",
|
||||
r"(?i)(\$\(|\`)[^)]+(\)|\`)",
|
||||
r"(?i)&&\s*(rm|del|format)",
|
||||
|
||||
# Jailbreak attempts
|
||||
r"(?i)jailbreak",
|
||||
r"(?i)break\s+out\s+of",
|
||||
r"(?i)escape\s+(the|your)\s+(rules|constraints)",
|
||||
r"(?i)(DAN|Do\s+Anything\s+Now)",
|
||||
r"(?i)unrestricted\s+mode",
|
||||
]
|
||||
|
||||
self.compiled_patterns = [re.compile(pattern) for pattern in self.injection_patterns]
|
||||
logger.info(f"Initialized {len(self.injection_patterns)} prompt injection patterns")
|
||||
|
||||
|
||||
def validate_prompt_security(self, messages: List[Dict[str, str]]) -> Tuple[bool, float, List[str]]:
|
||||
"""
|
||||
Validate messages for prompt injection attempts
|
||||
|
||||
Returns:
|
||||
Tuple[bool, float, List[str]]: (is_safe, risk_score, detected_patterns)
|
||||
"""
|
||||
detected_patterns = []
|
||||
total_risk = 0.0
|
||||
|
||||
for message in messages:
|
||||
content = message.get("content", "")
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# Check against injection patterns
|
||||
for i, pattern in enumerate(self.compiled_patterns):
|
||||
matches = pattern.findall(content)
|
||||
if matches:
|
||||
pattern_risk = self._calculate_pattern_risk(i, matches)
|
||||
total_risk += pattern_risk
|
||||
detected_patterns.append({
|
||||
"pattern_index": i,
|
||||
"pattern": self.injection_patterns[i],
|
||||
"matches": matches,
|
||||
"risk": pattern_risk
|
||||
})
|
||||
|
||||
# Additional security checks
|
||||
total_risk += self._check_message_characteristics(content)
|
||||
|
||||
# Normalize risk score (0.0 to 1.0)
|
||||
risk_score = min(total_risk / len(messages) if messages else 0.0, 1.0)
|
||||
is_safe = risk_score < settings.API_SECURITY_RISK_THRESHOLD
|
||||
|
||||
if detected_patterns:
|
||||
logger.warning(f"Detected {len(detected_patterns)} potential injection patterns, risk score: {risk_score}")
|
||||
|
||||
return is_safe, risk_score, detected_patterns
|
||||
|
||||
def _calculate_pattern_risk(self, pattern_index: int, matches: List) -> float:
|
||||
"""Calculate risk score for a detected pattern"""
|
||||
# Different patterns have different risk levels
|
||||
high_risk_patterns = [0, 1, 2, 3, 4, 5, 6, 7, 14, 15, 16, 22, 23, 24] # System manipulation, jailbreak
|
||||
medium_risk_patterns = [8, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21] # Escape attempts, info extraction
|
||||
|
||||
base_risk = 0.8 if pattern_index in high_risk_patterns else 0.5 if pattern_index in medium_risk_patterns else 0.3
|
||||
|
||||
# Increase risk based on number of matches
|
||||
match_multiplier = min(1.0 + (len(matches) - 1) * 0.2, 2.0)
|
||||
|
||||
return base_risk * match_multiplier
|
||||
|
||||
def _check_message_characteristics(self, content: str) -> float:
|
||||
"""Check message characteristics for additional risk factors"""
|
||||
risk = 0.0
|
||||
|
||||
# Excessive length (potential stuffing attack)
|
||||
if len(content) > 10000:
|
||||
risk += 0.3
|
||||
|
||||
# High ratio of special characters
|
||||
special_chars = sum(1 for c in content if not c.isalnum() and not c.isspace())
|
||||
if len(content) > 0 and special_chars / len(content) > 0.5:
|
||||
risk += 0.4
|
||||
|
||||
# Multiple encoding indicators
|
||||
encoding_indicators = ["base64", "hex", "unicode", "url", "ascii"]
|
||||
found_encodings = sum(1 for indicator in encoding_indicators if indicator.lower() in content.lower())
|
||||
if found_encodings > 1:
|
||||
risk += 0.3
|
||||
|
||||
# Excessive newlines or formatting (potential formatting attacks)
|
||||
if content.count('\n') > 50 or content.count('\\n') > 50:
|
||||
risk += 0.2
|
||||
|
||||
return risk
|
||||
|
||||
    def create_audit_log(
        self,
        user_id: str,
        api_key_id: int,
        provider: str,
        model: str,
        request_type: str,
        risk_score: float,
        detected_patterns: List[str],
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Create comprehensive audit log for LLM request"""
        audit_entry = {
            "timestamp": datetime.utcnow().isoformat(),
            "user_id": user_id,
            "api_key_id": api_key_id,
            "provider": provider,
            "model": model,
            "request_type": request_type,
            "security": {
                "risk_score": risk_score,
                "detected_patterns": detected_patterns,
                "security_check_passed": risk_score < settings.API_SECURITY_RISK_THRESHOLD
            },
            "metadata": metadata or {},
            "audit_hash": None  # Will be set below
        }

        # Create hash for audit integrity
        audit_hash = self._create_audit_hash(audit_entry)
        audit_entry["audit_hash"] = audit_hash

        # Log based on risk level
        if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
            logger.error(f"HIGH RISK LLM REQUEST BLOCKED: {json.dumps(audit_entry)}")
        elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
            logger.warning(f"MEDIUM RISK LLM REQUEST: {json.dumps(audit_entry)}")
        else:
            logger.info(f"LLM REQUEST AUDIT: user={user_id}, model={model}, risk={risk_score:.3f}")

        return audit_entry

    def _create_audit_hash(self, audit_entry: Dict[str, Any]) -> str:
        """Create hash for audit trail integrity"""
        # Create hash from key fields (excluding the hash itself)
        hash_data = {
            "timestamp": audit_entry["timestamp"],
            "user_id": audit_entry["user_id"],
            "api_key_id": audit_entry["api_key_id"],
            "provider": audit_entry["provider"],
            "model": audit_entry["model"],
            "request_type": audit_entry["request_type"],
            "risk_score": audit_entry["security"]["risk_score"]
        }

        hash_string = json.dumps(hash_data, sort_keys=True)
        return hashlib.sha256(hash_string.encode()).hexdigest()
    def log_detailed_request(
        self,
        messages: List[Dict[str, str]],
        model: str,
        user_id: str,
        provider: str,
        context_info: Optional[Dict[str, Any]] = None
    ):
        """Log detailed LLM request if LOG_LLM_PROMPTS is enabled"""
        if not settings.LOG_LLM_PROMPTS:
            return

        logger.info("=== DETAILED LLM REQUEST ===")
        logger.info(f"Model: {model}")
        logger.info(f"Provider: {provider}")
        logger.info(f"User ID: {user_id}")

        if context_info:
            for key, value in context_info.items():
                logger.info(f"{key}: {value}")

        logger.info("Messages to LLM:")
        for i, message in enumerate(messages):
            role = message.get("role", "unknown")
            content = message.get("content", "")[:500]  # Truncate for logging
            logger.info(f" Message {i+1} [{role}]: {content}{'...' if len(message.get('content', '')) > 500 else ''}")

        logger.info("=== END DETAILED LLM REQUEST ===")

    def log_detailed_response(
        self,
        response_content: str,
        token_usage: Optional[Dict[str, int]] = None,
        provider: str = "unknown"
    ):
        """Log detailed LLM response if LOG_LLM_PROMPTS is enabled"""
        if not settings.LOG_LLM_PROMPTS:
            return

        logger.info("=== DETAILED LLM RESPONSE ===")
        logger.info(f"Provider: {provider}")
        logger.info(f"Response content: {response_content[:500]}{'...' if len(response_content) > 500 else ''}")

        if token_usage:
            logger.info(f"Token usage - Prompt: {token_usage.get('prompt_tokens', 0)}, "
                        f"Completion: {token_usage.get('completion_tokens', 0)}, "
                        f"Total: {token_usage.get('total_tokens', 0)}")

        logger.info("=== END DETAILED LLM RESPONSE ===")


class SecurityError(Exception):
    """Security-related errors in LLM operations"""
    pass


# Global security manager instance
security_manager = SecurityManager()
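Because _create_audit_hash hashes a sorted-JSON projection of a fixed set of fields, any consumer of the audit log can recheck an entry's integrity with the same recipe. A minimal verification sketch (standalone illustration; verify_audit_entry is a hypothetical helper, not part of the service):

import hashlib
import json

def verify_audit_entry(audit_entry: dict) -> bool:
    """Recompute the integrity hash over the same key fields and compare."""
    hash_data = {
        "timestamp": audit_entry["timestamp"],
        "user_id": audit_entry["user_id"],
        "api_key_id": audit_entry["api_key_id"],
        "provider": audit_entry["provider"],
        "model": audit_entry["model"],
        "request_type": audit_entry["request_type"],
        "risk_score": audit_entry["security"]["risk_score"],
    }
    expected = hashlib.sha256(json.dumps(hash_data, sort_keys=True).encode()).hexdigest()
    return expected == audit_entry.get("audit_hash")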
@@ -16,9 +16,10 @@ from .models import (
    ModelInfo, ProviderStatus, LLMMetrics
)
from .config import config_manager, ProviderConfig
# Security service removed as requested
from ...core.config import settings

from .resilience import ResilienceManagerFactory
from .metrics import metrics_collector
# from .metrics import metrics_collector
from .providers import BaseLLMProvider, PrivateModeProvider
from .exceptions import (
    LLMError, ProviderError, SecurityError, ConfigurationError,
@@ -149,8 +150,7 @@ class LLMService:
        if not request.messages:
            raise ValidationError("Messages cannot be empty", field="messages")

        # Security validation removed as requested
        messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        risk_score = 0.0

        # Get provider for model
        provider_name = self._get_provider_for_model(request.model)
@@ -159,7 +159,6 @@ class LLMService:
        if not provider:
            raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)

        # Security logging removed as requested

        # Execute with resilience
        resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
@@ -173,49 +172,20 @@ class LLMService:
            non_retryable_exceptions=(ValidationError,)
        )

        # Set default security values since security is removed
        response.security_check = True
        response.risk_score = 0.0
        response.detected_patterns = []

        # Security logging removed as requested

        # Record successful request

        # Record successful request - metrics disabled
        total_latency = (time.time() - start_time) * 1000
        metrics_collector.record_request(
            provider=provider_name,
            model=request.model,
            request_type="chat_completion",
            success=True,
            latency_ms=total_latency,
            token_usage=response.usage.model_dump() if response.usage else None,
            # security_risk_score removed as requested
            user_id=request.user_id,
            api_key_id=request.api_key_id
        )

        # Security audit logging removed as requested

        return response

    except Exception as e:
        # Record failed request
        # Record failed request - metrics disabled
        total_latency = (time.time() - start_time) * 1000
        error_code = getattr(e, 'error_code', e.__class__.__name__)

        metrics_collector.record_request(
            provider=provider_name,
            model=request.model,
            request_type="chat_completion",
            success=False,
            latency_ms=total_latency,
            # security_risk_score removed as requested
            error_code=error_code,
            user_id=request.user_id,
            api_key_id=request.api_key_id
        )

        # Security audit logging removed as requested

        raise

@@ -224,8 +194,9 @@ class LLMService:
        if not self._initialized:
            await self.initialize()

        # Security validation removed as requested
        messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        # Security validation disabled - always allow streaming requests
        risk_score = 0.0

        # Get provider
        provider_name = self._get_provider_for_model(request.model)
@@ -247,19 +218,8 @@ class LLMService:
                yield chunk

        except Exception as e:
            # Record streaming failure
            # Record streaming failure - metrics disabled
            error_code = getattr(e, 'error_code', e.__class__.__name__)
            metrics_collector.record_request(
                provider=provider_name,
                model=request.model,
                request_type="chat_completion_stream",
                success=False,
                latency_ms=0,
                # security_risk_score removed as requested
                error_code=error_code,
                user_id=request.user_id,
                api_key_id=request.api_key_id
            )
            raise

    async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
@@ -267,7 +227,9 @@ class LLMService:
        if not self._initialized:
            await self.initialize()

        # Security validation removed as requested
        # Security validation disabled - always allow embedding requests
        risk_score = 0.0

        # Get provider
        provider_name = self._get_provider_for_model(request.model)
@@ -288,43 +250,17 @@ class LLMService:
            non_retryable_exceptions=(ValidationError,)
        )

        # Set default security values since security is removed
        response.security_check = True
        response.risk_score = 0.0

        # Record successful request
        # Record successful request - metrics disabled
        total_latency = (time.time() - start_time) * 1000
        metrics_collector.record_request(
            provider=provider_name,
            model=request.model,
            request_type="embedding",
            success=True,
            latency_ms=total_latency,
            token_usage=response.usage.model_dump() if response.usage else None,
            # security_risk_score removed as requested
            user_id=request.user_id,
            api_key_id=request.api_key_id
        )


        return response

    except Exception as e:
        # Record failed request
        # Record failed request - metrics disabled
        total_latency = (time.time() - start_time) * 1000
        error_code = getattr(e, 'error_code', e.__class__.__name__)

        metrics_collector.record_request(
            provider=provider_name,
            model=request.model,
            request_type="embedding",
            success=False,
            latency_ms=total_latency,
            # security_risk_score removed as requested
            error_code=error_code,
            user_id=request.user_id,
            api_key_id=request.api_key_id
        )

        raise

    async def get_models(self, provider_name: Optional[str] = None) -> List[ModelInfo]:
@@ -378,12 +314,18 @@ class LLMService:
        return status_dict

    def get_metrics(self) -> LLMMetrics:
        """Get service metrics"""
        return metrics_collector.get_metrics()
        """Get service metrics - metrics disabled"""
        # return metrics_collector.get_metrics()
        return LLMMetrics(
            total_requests=0,
            success_rate=0.0,
            avg_latency_ms=0,
            error_rates={}
        )

    def get_health_summary(self) -> Dict[str, Any]:
        """Get comprehensive health summary"""
        metrics_health = metrics_collector.get_health_summary()
        """Get comprehensive health summary - metrics disabled"""
        # metrics_health = metrics_collector.get_health_summary()
        resilience_health = ResilienceManagerFactory.get_all_health_status()

        return {
@@ -391,7 +333,7 @@ class LLMService:
            "startup_time": self._startup_time.isoformat() if self._startup_time else None,
            "provider_count": len(self._providers),
            "active_providers": list(self._providers.keys()),
            "metrics": metrics_health,
            "metrics": {"status": "disabled"},
            "resilience": resilience_health
        }
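With the collector calls effectively disabled, get_metrics() returns a zeroed placeholder, so callers should treat all-zero metrics as "metrics disabled" rather than "healthy but idle". A hedged caller-side sketch (assumes only the LLMMetrics fields shown above; not part of the service itself):

# Illustrative only: distinguish the zeroed placeholder from real data.
def metrics_are_meaningful(metrics) -> bool:
    return metrics.total_requests > 0 or bool(metrics.error_rates)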
backend/app/services/qdrant_stats_service.py (new file, 163 lines)
@@ -0,0 +1,163 @@
"""
Qdrant Stats Service
Provides direct, live statistics from Qdrant vector database
This is the single source of truth for all RAG collection statistics
"""

import httpx
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime

from app.core.config import settings

logger = logging.getLogger(__name__)


class QdrantStatsService:
    """Service for getting live statistics from Qdrant"""

    def __init__(self):
        self.qdrant_host = getattr(settings, 'QDRANT_HOST', 'enclava-qdrant')
        self.qdrant_port = getattr(settings, 'QDRANT_PORT', 6333)
        self.qdrant_url = f"http://{self.qdrant_host}:{self.qdrant_port}"

    async def get_collections_stats(self) -> Dict[str, Any]:
        """Get live collection statistics directly from Qdrant"""
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                # Get all collections
                response = await client.get(f"{self.qdrant_url}/collections")
                if response.status_code != 200:
                    logger.error(f"Failed to get collections: {response.status_code}")
                    return {"collections": [], "total_documents": 0, "total_size_bytes": 0}

                data = response.json()
                result = data.get("result", {})
                collections_data = result.get("collections", [])

                collections = []
                total_documents = 0
                total_size_bytes = 0

                # Get detailed info for each collection
                for col_info in collections_data:
                    collection_name = col_info.get("name", "")
                    # Include all collections, not just rag_ ones

                    # Get detailed collection info
                    try:
                        detail_response = await client.get(f"{self.qdrant_url}/collections/{collection_name}")
                        if detail_response.status_code == 200:
                            detail_data = detail_response.json()
                            detail_result = detail_data.get("result", {})

                            points_count = detail_result.get("points_count", 0)
                            status = detail_result.get("status", "unknown")

                            # Get vector size for size calculation
                            vector_size = 1024  # Default for multilingual-e5-large
                            try:
                                config = detail_result.get("config", {})
                                params = config.get("params", {})
                                vectors = params.get("vectors", {})
                                if isinstance(vectors, dict) and "size" in vectors:
                                    vector_size = vectors["size"]
                                elif isinstance(vectors, dict) and "default" in vectors:
                                    vector_size = vectors["default"].get("size", 1024)
                            except Exception:
                                pass

                            # Estimate size (points * vector_size * 4 bytes + 20% metadata overhead)
                            estimated_size = int(points_count * vector_size * 4 * 1.2)

                            # Extract collection metadata for user-friendly name
                            display_name = collection_name
                            description = ""

                            # Parse collection name to get original name
                            if collection_name.startswith("rag_"):
                                parts = collection_name[4:].split("_")
                                if len(parts) > 1:
                                    # Remove the UUID suffix
                                    uuid_parts = [p for p in parts if len(p) == 8 and all(c in '0123456789abcdef' for c in p)]
                                    for uuid_part in uuid_parts:
                                        parts.remove(uuid_part)
                                    display_name = " ".join(parts).replace("_", " ").title()

                            collection_stat = {
                                "id": collection_name,
                                "name": display_name,
                                "description": description,
                                "document_count": points_count,
                                "vector_count": points_count,
                                "size_bytes": estimated_size,
                                "status": status,
                                "qdrant_collection_name": collection_name,
                                "created_at": "",  # Not available from Qdrant
                                "updated_at": datetime.utcnow().isoformat(),
                                "is_active": status == "green",
                                "is_managed": True,
                                "source": "qdrant"
                            }

                            collections.append(collection_stat)
                            total_documents += points_count
                            total_size_bytes += estimated_size

                    except Exception as e:
                        logger.error(f"Error getting details for collection {collection_name}: {e}")
                        continue

                return {
                    "collections": collections,
                    "total_documents": total_documents,
                    "total_size_bytes": total_size_bytes,
                    "total_collections": len(collections)
                }

        except Exception as e:
            logger.error(f"Error getting Qdrant stats: {e}")
            return {"collections": [], "total_documents": 0, "total_size_bytes": 0, "total_collections": 0}

    async def get_collection_stats(self, collection_name: str) -> Optional[Dict[str, Any]]:
        """Get statistics for a specific collection"""
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(f"{self.qdrant_url}/collections/{collection_name}")
                if response.status_code != 200:
                    return None

                data = response.json()
                result = data.get("result", {})

                points_count = result.get("points_count", 0)
                status = result.get("status", "unknown")

                # Get vector size
                vector_size = 1024
                try:
                    config = result.get("config", {})
                    params = config.get("params", {})
                    vectors = params.get("vectors", {})
                    if isinstance(vectors, dict) and "size" in vectors:
                        vector_size = vectors["size"]
                except Exception:
                    pass

                estimated_size = int(points_count * vector_size * 4 * 1.2)

                return {
                    "document_count": points_count,
                    "vector_count": points_count,
                    "size_bytes": estimated_size,
                    "status": status
                }

        except Exception as e:
            logger.error(f"Error getting collection stats for {collection_name}: {e}")
            return None


# Global instance
qdrant_stats_service = QdrantStatsService()
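A minimal usage sketch for the global instance (assumes it runs inside the backend container where enclava-qdrant is resolvable; the standalone asyncio wrapper is illustrative, not how the API layer actually calls it):

import asyncio
from app.services.qdrant_stats_service import qdrant_stats_service

async def main():
    # Fetch live stats straight from Qdrant and print a one-line summary.
    stats = await qdrant_stats_service.get_collections_stats()
    print(f"{stats['total_collections']} collections, "
          f"{stats['total_documents']} points, ~{stats['total_size_bytes'] / 1e6:.1f} MB estimated")

asyncio.run(main())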
@@ -755,10 +755,11 @@ class RAGService:

            # Process with RAG module
            try:
                # Pass file_path in metadata so JSONL indexing can reopen the source file
                processed_doc = await rag_module.process_document(
                    file_content,
                    document.original_filename,
                    {}
                    {"file_path": document.file_path}
                )

                # Success case - update document with processed content
@@ -69,6 +69,7 @@ class ChatbotConfig:
    memory_length: int = 10  # Number of previous messages to remember
    use_rag: bool = False
    rag_top_k: int = 5
    rag_score_threshold: float = 0.02  # Lowered from default 0.3 to allow more results
    fallback_responses: List[str] = None

    def __post_init__(self):
@@ -386,7 +387,8 @@ class ChatbotModule(BaseModule):
            rag_results = await self.rag_module.search_documents(
                query=message,
                max_results=config.rag_top_k,
                collection_name=qdrant_collection_name
                collection_name=qdrant_collection_name,
                score_threshold=config.rag_score_threshold
            )

            if rag_results:
@@ -395,8 +397,8 @@ class ChatbotModule(BaseModule):
                           for i, result in enumerate(rag_results)]

                # Build full RAG context from all results
                rag_context = "\\n\\nRelevant information from knowledge base:\\n" + "\\n\\n".join([
                    f"[Document {i+1}]:\\n{result.document.content}" for i, result in enumerate(rag_results)
                rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([
                    f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results)
                ])

                # Detailed RAG logging - ALWAYS log for debugging
@@ -405,14 +407,14 @@ class ChatbotModule(BaseModule):
                logger.info(f"Collection: {qdrant_collection_name}")
                logger.info(f"Number of results: {len(rag_results)}")
                for i, result in enumerate(rag_results):
                    logger.info(f"\\n--- RAG Result {i+1} ---")
                    logger.info(f"\n--- RAG Result {i+1} ---")
                    logger.info(f"Score: {getattr(result, 'score', 'N/A')}")
                    logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}")
                    logger.info(f"Full Content ({len(result.document.content)} chars):")
                    logger.info(f"{result.document.content}")
                    if hasattr(result.document, 'metadata'):
                        logger.info(f"Metadata: {result.document.metadata}")
                logger.info(f"\\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
                logger.info(f"\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
                logger.info(rag_context)
                logger.info("=== END RAG SEARCH RESULTS ===")
            else:
@@ -445,9 +447,9 @@ class ChatbotModule(BaseModule):
        if config.use_rag and rag_context:
            logger.info(f"RAG context added: {len(rag_context)} characters")
            logger.info(f"RAG sources: {len(sources) if sources else 0} documents")
            logger.info("\\n=== COMPLETE MESSAGES SENT TO LLM ===")
            logger.info("\n=== COMPLETE MESSAGES SENT TO LLM ===")
            for i, msg in enumerate(messages):
                logger.info(f"\\n--- Message {i+1} ---")
                logger.info(f"\n--- Message {i+1} ---")
                logger.info(f"Role: {msg['role']}")
                logger.info(f"Content ({len(msg['content'])} chars):")
                # Truncate long content for logging (full RAG context can be very long)
@@ -520,9 +522,11 @@ class ChatbotModule(BaseModule):
        # System prompt
        system_prompt = config.system_prompt
        if rag_context:
            system_prompt += rag_context
            # Add explicit instruction to use RAG context
            system_prompt += "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. " \
                             "This information is directly relevant to their query and should be your primary source:\n" + rag_context
        if context and context.get('additional_instructions'):
            system_prompt += f"\\n\\nAdditional instructions: {context['additional_instructions']}"
            system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}"

        messages.append({"role": "system", "content": system_prompt})

@@ -709,9 +713,21 @@ class ChatbotModule(BaseModule):
            fallback_responses=chatbot_config.get("fallback_responses", [])
        )

        # Generate response using internal method with empty message history
        # Generate response using internal method
        # Create a temporary message object for the current user message
        temp_messages = [
            DBMessage(
                id=0,
                conversation_id=0,
                role="user",
                content=message,
                timestamp=datetime.utcnow(),
                metadata={}
            )
        ]

        response_content, sources = await self._generate_response(
            message, [], config, None, db
            message, temp_messages, config, None, db
        )

        return {
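The effect of the changed prompt assembly is easiest to see with concrete strings. A standalone sketch of the same concatenation (all values invented for illustration):

system_prompt = "You are a support assistant."
rag_context = "\n\nRelevant information from knowledge base:\n[Document 1]:\nExample answer text."

if rag_context:
    # Mirrors the new branch above: the instruction is prepended before the RAG context.
    system_prompt += (
        "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. "
        "This information is directly relevant to their query and should be your primary source:\n" + rag_context
    )

messages = [{"role": "system", "content": system_prompt},
            {"role": "user", "content": "are sd card backups encrypted?"}]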
@@ -53,7 +53,7 @@ except ImportError:
    PYTHON_DOCX_AVAILABLE = False

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue
from qdrant_client.http import models
import tiktoken

@@ -134,6 +134,19 @@ class RAGModule(BaseModule):
        self.embedding_service = None
        self.tokenizer = None

        # Set improved default configuration
        self.config = {
            "chunk_size": 300,  # Reduced from 400 for better precision
            "chunk_overlap": 50,  # Added overlap for context preservation
            "max_results": 10,
            "score_threshold": 0.3,  # Increased from 0.0 to filter low-quality results
            "enable_hybrid": True,  # Enable hybrid search (vector + BM25)
            "hybrid_weights": {"vector": 0.7, "bm25": 0.3}  # Weight for hybrid scoring
        }
        # Update with any provided config
        if config:
            self.config.update(config)

        # Content processing components
        self.nlp_model = None
        self.lemmatizer = None
@@ -625,11 +638,19 @@ class RAGModule(BaseModule):
            np.random.seed(hash(text) % 2**32)
            return np.random.random(self.embedding_model.get("dimension", 768)).tolist()

    async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
    async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
        """Generate embeddings for multiple texts (batch processing)"""
        if self.embedding_service:
            # Add task-specific prefixes for better E5 model performance
            if is_document:
                # For document passages, use "passage:" prefix
                prefixed_texts = [f"passage: {text}" for text in texts]
            else:
                # For queries, use "query:" prefix (handled in search method)
                prefixed_texts = texts

            # Use real embedding service for batch processing
            return await self.embedding_service.get_embeddings(texts)
            return await self.embedding_service.get_embeddings(prefixed_texts)
        else:
            # Fallback to individual processing
            embeddings = []
@@ -639,19 +660,33 @@ class RAGModule(BaseModule):
            return embeddings

    def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
        """Split text into chunks"""
        chunk_size = chunk_size or self.config.get("chunk_size", 400)
        """Split text into overlapping chunks for better context preservation"""
        chunk_size = chunk_size or self.config.get("chunk_size", 300)
        chunk_overlap = self.config.get("chunk_overlap", 50)

        # Tokenize text
        tokens = self.tokenizer.encode(text)

        # Split into chunks
        # Split into chunks with overlap
        chunks = []
        for i in range(0, len(tokens), chunk_size):
            chunk_tokens = tokens[i:i + chunk_size]
        start_idx = 0

        while start_idx < len(tokens):
            end_idx = min(start_idx + chunk_size, len(tokens))
            chunk_tokens = tokens[start_idx:end_idx]
            chunk_text = self.tokenizer.decode(chunk_tokens)

            # Only add non-empty chunks
            if chunk_text.strip():
                chunks.append(chunk_text)

            # Stop after the final chunk; otherwise the overlapping start index
            # would never pass len(tokens) and the loop would re-emit the tail forever
            if end_idx == len(tokens):
                break

            # Move to next chunk with overlap
            start_idx = end_idx - chunk_overlap

            # Ensure progress (in case overlap >= chunk_size)
            if start_idx >= end_idx:
                start_idx = end_idx

        return chunks

    async def _process_text(self, content: bytes, filename: str) -> str:
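With chunk_size=300 and chunk_overlap=50, consecutive chunks share 50 tokens: a chunk covering tokens 0-300 is followed by one starting at 250. A tiny standalone illustration of the same sliding window over token indices (no tokenizer needed; mirrors the loop above, including the end-of-text stop):

def windows(n_tokens: int, chunk_size: int = 300, overlap: int = 50):
    start, spans = 0, []
    while start < n_tokens:
        end = min(start + chunk_size, n_tokens)
        spans.append((start, end))
        if end == n_tokens:
            break
        start = end - overlap
    return spans

print(windows(700))  # [(0, 300), (250, 550), (500, 700)]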
@@ -895,12 +930,18 @@ class RAGModule(BaseModule):
        - Each line contains a JSON object with 'id' and 'payload'
        - Payload contains 'question', 'language', and 'answer' fields
        - Combines question and answer into searchable content

        Performance optimizations:
        - Processes articles in smaller batches to reduce memory usage
        - Uses streaming approach for large files
        """
        try:
            # Use streaming approach for large files
            jsonl_content = content.decode('utf-8', errors='replace')
            lines = jsonl_content.strip().split('\n')

            processed_articles = []
            batch_size = 50  # Process in batches of 50 articles

            for line_num, line in enumerate(lines, 1):
                if not line.strip():
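For reference, one input line in that format might look like the following; the field values are invented for illustration, only the 'id'/'payload' structure with 'question', 'language', and 'answer' comes from the docstring above:

import json

example_line = json.dumps({
    "id": "faq-0001",
    "payload": {
        "question": "Are SD card backups encrypted?",
        "language": "en",
        "answer": "Example answer text."
    }
})
record = json.loads(example_line)
# Question and answer are combined into the searchable content.
searchable = f"{record['payload']['question']}\n{record['payload']['answer']}"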
@@ -1126,7 +1167,7 @@
            chunks = self._chunk_text(content)

            # Generate embeddings for all chunks in batch (more efficient)
            embeddings = await self._generate_embeddings(chunks)
            embeddings = await self._generate_embeddings(chunks, is_document=True)

            # Create document points
            points = []
@@ -1177,6 +1218,24 @@
        collection_name = collection_name or self.default_collection_name

        try:
            # Special handling for JSONL files
            if processed_doc.file_type == 'jsonl':
                # Import the optimized JSONL processor
                from app.services.jsonl_processor import JSONLProcessor
                jsonl_processor = JSONLProcessor(self)

                # Read the original file content
                with open(processed_doc.metadata.get('file_path', ''), 'rb') as f:
                    file_content = f.read()

                # Process using the optimized JSONL processor
                return await jsonl_processor.process_and_index_jsonl(
                    collection_name=collection_name,
                    content=file_content,
                    filename=processed_doc.original_filename,
                    metadata=processed_doc.metadata
                )

            # Ensure collection exists
            await self._ensure_collection_exists(collection_name)

@@ -1189,7 +1248,7 @@
            chunks = self._chunk_text(processed_doc.content)

            # Generate embeddings for all chunks in batch (more efficient)
            embeddings = await self._generate_embeddings(chunks)
            embeddings = await self._generate_embeddings(chunks, is_document=True)

            # Create document points with enhanced metadata
            points = []
@@ -1260,12 +1319,196 @@
        except Exception:
            return False

    async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
    async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float],
                             query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]:
        """Perform hybrid search combining vector similarity and BM25 scoring"""

        # Preprocess query for BM25
        query_terms = self._preprocess_text_for_bm25(query)

        # Get all documents from the collection (for BM25 scoring)
        # Note: In production, you'd want to optimize this with a proper BM25 index
        scroll_filter = query_filter or Filter()
        all_points = []

        # Use scroll to get all points
        offset = None
        batch_size = 100
        while True:
            search_result = self.qdrant_client.scroll(
                collection_name=collection_name,
                scroll_filter=scroll_filter,
                limit=batch_size,
                offset=offset,
                with_payload=True,
                with_vectors=False
            )

            points = search_result[0]
            all_points.extend(points)

            if len(points) < batch_size:
                break

            offset = points[-1].id

        # Calculate BM25 scores for each document
        bm25_scores = {}
        for point in all_points:
            doc_id = point.payload.get("document_id", "")
            content = point.payload.get("content", "")

            # Calculate BM25 score
            bm25_score = self._calculate_bm25_score(query_terms, content)
            bm25_scores[doc_id] = bm25_score

        # Perform vector search
        vector_results = self.qdrant_client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            query_filter=query_filter,
            limit=limit * 2,  # Get more results for re-ranking
            score_threshold=score_threshold / 2  # Lower threshold for initial search
        )

        # Combine scores with improved normalization
        hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
        vector_weight = hybrid_weights.get("vector", 0.7)
        bm25_weight = hybrid_weights.get("bm25", 0.3)

        # Get score distributions for better normalization
        vector_scores = [r.score for r in vector_results]
        bm25_scores_list = list(bm25_scores.values())

        # Calculate statistics for normalization
        if vector_scores:
            v_max = max(vector_scores)
            v_min = min(vector_scores)
            v_range = v_max - v_min if v_max != v_min else 1
        else:
            v_max, v_min, v_range = 1, 0, 1

        if bm25_scores_list:
            bm25_max = max(bm25_scores_list)
            bm25_min = min(bm25_scores_list)
            bm25_range = bm25_max - bm25_min if bm25_max != bm25_min else 1
        else:
            bm25_max, bm25_min, bm25_range = 1, 0, 1

        # Create hybrid results with improved scoring
        hybrid_results = []
        for result in vector_results:
            doc_id = result.payload.get("document_id", "")
            vector_score = result.score
            bm25_score = bm25_scores.get(doc_id, 0.0)

            # Improved normalization using actual score distributions
            vector_norm = (vector_score - v_min) / v_range if v_range > 0 else 0.5
            bm25_norm = (bm25_score - bm25_min) / bm25_range if bm25_range > 0 else 0.5

            # Apply reciprocal rank fusion for better combination
            # This gives more weight to documents that rank highly in both methods
            rrf_vector = 1.0 / (1.0 + vector_results.index(result) + 1)  # +1 to avoid division by zero
            rrf_bm25 = 1.0 / (1.0 + sorted(bm25_scores_list, reverse=True).index(bm25_score) + 1) if bm25_score in bm25_scores_list else 0

            # Calculate hybrid score using both normalized scores and RRF
            hybrid_score = (vector_weight * vector_norm + bm25_weight * bm25_norm) * 0.7 + (rrf_vector + rrf_bm25) * 0.3

            # Create new point with hybrid score
            hybrid_point = ScoredPoint(
                id=result.id,
                payload=result.payload,
                score=hybrid_score,
                vector=result.vector,
                shard_key=None,
                order_value=None
            )
            hybrid_results.append(hybrid_point)

        # Sort by hybrid score and apply final threshold
        hybrid_results.sort(key=lambda x: x.score, reverse=True)
        final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit]

        logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results")
        return final_results
    def _preprocess_text_for_bm25(self, text: str) -> List[str]:
        """Preprocess text for BM25 scoring"""
        if not NLTK_AVAILABLE:
            return text.lower().split()

        try:
            # Tokenize
            tokens = word_tokenize(text.lower())

            # Remove stopwords and non-alphabetic tokens
            stop_words = set(stopwords.words('english'))
            filtered_tokens = [
                token for token in tokens
                if token.isalpha() and token not in stop_words and len(token) > 2
            ]

            return filtered_tokens
        except:
            # Fallback to simple splitting
            return text.lower().split()

    def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float:
        """Calculate BM25 score for a document against query terms"""
        if not query_terms:
            return 0.0

        # Preprocess document
        doc_terms = self._preprocess_text_for_bm25(document)
        if not doc_terms:
            return 0.0

        # Calculate term frequencies
        doc_len = len(doc_terms)
        avg_doc_len = 300  # Average document length (configurable)

        # BM25 parameters
        k1 = 1.2  # Controls term frequency saturation
        b = 0.75  # Controls document length normalization

        score = 0.0

        # Calculate IDF for each query term
        for term in set(query_terms):
            # Term frequency in document
            tf = doc_terms.count(term)

            # Simple IDF (log(N/n) + 1)
            # In production, you'd use the actual document frequency
            idf = 2.0  # Simplified IDF

            # BM25 formula
            numerator = tf * (k1 + 1)
            denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len))

            score += idf * (numerator / denominator)

        # Normalize score to 0-1 range
        return min(score / 10.0, 1.0)  # Simple normalization
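To make the simplified scoring concrete: with k1 = 1.2, b = 0.75, the fixed idf = 2.0, and avg_doc_len = 300, a term appearing twice in a 300-token chunk contributes 2.0 * (2 * 2.2) / (2 + 1.2) = 2.75 before the final score is divided by 10 and capped at 1.0. A standalone re-computation of that arithmetic (illustration only, mirroring the formula above):

def bm25_contribution(tf: int, doc_len: int, k1: float = 1.2, b: float = 0.75,
                      idf: float = 2.0, avg_doc_len: int = 300) -> float:
    numerator = tf * (k1 + 1)
    denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len))
    return idf * (numerator / denominator)

score = bm25_contribution(tf=2, doc_len=300)   # 2.75
normalized = min(score / 10.0, 1.0)            # 0.275, as returned by _calculate_bm25_score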
    async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
        """Search for relevant documents"""
        if not self.enabled:
            raise RuntimeError("RAG module not initialized")

        collection_name = collection_name or self.default_collection_name

        # Special handling for collections with different vector dimensions
        SPECIAL_COLLECTIONS = {
            "bitbox02_faq_local": {
                "dimension": 384,
                "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
            },
            "bitbox_local_rag": {
                "dimension": 384,
                "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
            }
        }
        max_results = max_results or self.config.get("max_results", 10)

        # Check cache (include collection name in cache key)
@@ -1278,8 +1521,25 @@
        import time
        start_time = time.time()

        # Generate query embedding
        query_embedding = await self._generate_embedding(query)
        # Generate query embedding with task-specific prefix for better retrieval
        try:
            # Check if this is a special collection
            if collection_name in SPECIAL_COLLECTIONS:
                # Try to import sentence-transformers
                import sentence_transformers
                from sentence_transformers import SentenceTransformer
                model = SentenceTransformer(SPECIAL_COLLECTIONS[collection_name]["model"])
                query_embedding = model.encode([query], normalize_embeddings=True)[0].tolist()
                logger.info(f"Using {SPECIAL_COLLECTIONS[collection_name]['dimension']}-dim local model for {collection_name}")
            else:
                # The E5 model works better with "query:" prefix for search queries
                optimized_query = f"query: {query}"
                query_embedding = await self._generate_embedding(optimized_query)
        except ImportError:
            # Fallback to default embedding if sentence-transformers is not available
            logger.warning(f"sentence-transformers not available, falling back to default embedding for {collection_name}")
            optimized_query = f"query: {query}"
            query_embedding = await self._generate_embedding(optimized_query)

        # Build filter
        search_filter = None
@@ -1297,13 +1557,29 @@
        logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}")
        logger.info(f"Embedding service available: {self.embedding_service is not None}")

        # Search in Qdrant
        # Check if hybrid search is enabled
        enable_hybrid = self.config.get("enable_hybrid", False)
        # Use provided score_threshold or fall back to config
        search_score_threshold = score_threshold if score_threshold is not None else self.config.get("score_threshold", 0.3)

        if enable_hybrid and NLTK_AVAILABLE:
            # Perform hybrid search (vector + BM25)
            search_results = await self._hybrid_search(
                collection_name=collection_name,
                query=query,
                query_vector=query_embedding,
                query_filter=search_filter,
                limit=max_results,
                score_threshold=search_score_threshold
            )
        else:
            # Pure vector search with improved threshold
            search_results = self.qdrant_client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                query_filter=search_filter,
                limit=max_results,
                score_threshold=0.0  # Lowered from 0.5 to see all results including low scores
                score_threshold=search_score_threshold
            )

        logger.info(f"Raw search results count: {len(search_results)}")
@@ -1317,6 +1593,23 @@
            content = result.payload.get("content", "")
            score = result.score

            # Generic content extraction for documents without a 'content' field
            if not content:
                # Build content from all text-based fields in the payload
                # This makes the RAG module completely agnostic to document structure
                text_fields = []
                for field, value in result.payload.items():
                    # Skip system/metadata fields
                    if field not in ["document_id", "chunk_index", "chunk_count", "indexed_at", "processed_at",
                                     "file_hash", "mime_type", "file_type", "created_at", "__collection_metadata__"]:
                        # Include any field that has a non-empty string value
                        if value and isinstance(value, str) and len(value.strip()) > 0:
                            text_fields.append(f"{field}: {value}")

                # Join all text fields to create content
                if text_fields:
                    content = "\n\n".join(text_fields)

            # Log each raw result for debugging
            logger.info(f"\n--- Raw Result {i+1} ---")
            logger.info(f"Score: {score}")
@@ -1651,9 +1944,9 @@ async def index_processed_document(processed_doc: ProcessedDocument, collection_
    """Index a processed document"""
    return await rag_module.index_processed_document(processed_doc, collection_name)

async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
    """Search documents"""
    return await rag_module.search_documents(query, max_results, filters, collection_name)
    return await rag_module.search_documents(query, max_results, filters, collection_name, score_threshold)

async def delete_document(document_id: str, collection_name: str = None) -> bool:
    """Delete a document"""
@@ -46,6 +46,7 @@ qdrant-client==1.7.0

# Text Processing
tiktoken==0.5.1
numpy>=1.26.0

# Basic document processing (lightweight)
markitdown==0.0.1a2
@@ -56,8 +57,9 @@ python-docx==1.1.0
# nltk==3.8.1
# spacy==3.7.2

# Heavy ML dependencies (REMOVED - unused in codebase)
# sentence-transformers==2.6.1 # REMOVED - not used anywhere in codebase
# Heavy ML dependencies (sentence-transformers will be installed separately)
# Note: PyTorch is already installed in the base Docker image
sentence-transformers==2.6.1  # Added back - needed for bitbox02_faq_local collection
# transformers==4.35.2 # REMOVED - already commented out

# Configuration
backend/scripts/import_jsonl.py (new file, 92 lines)
@@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""
Import a JSONL file into a Qdrant collection from inside the backend container.

Usage (from host):
  docker compose exec enclava-backend bash -lc \
    'python /app/scripts/import_jsonl.py \
      --collection rag_test_import_859b1f01 \
      --file /app/_to_delete/helpjuice-export.jsonl'

Notes:
- Runs fully inside the backend, so Docker service hostnames (e.g. enclava-qdrant)
  and privatemode-proxy are reachable.
- Uses RAGModule + JSONLProcessor to embed/index each JSONL line.
- Creates the collection if missing (size=1024, cosine).
"""

import argparse
import asyncio
import os
from datetime import datetime


async def import_jsonl(collection_name: str, file_path: str):
    from qdrant_client import QdrantClient
    from qdrant_client.models import Distance, VectorParams
    from app.modules.rag.main import RAGModule
    from app.services.jsonl_processor import JSONLProcessor
    from app.core.config import settings

    if not os.path.exists(file_path):
        raise SystemExit(f"File not found: {file_path}")

    # Ensure collection exists (inside container uses Docker DNS hostnames)
    client = QdrantClient(host=settings.QDRANT_HOST, port=settings.QDRANT_PORT)
    collections = client.get_collections().collections
    if not any(c.name == collection_name for c in collections):
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
        )
        print(f"Created Qdrant collection '{collection_name}' (size=1024, cosine)")
    else:
        print(f"Using existing Qdrant collection '{collection_name}'")

    # Initialize RAG
    rag = RAGModule({
        "chunk_size": 300,
        "chunk_overlap": 50,
        "max_results": 10,
        "score_threshold": 0.3,
        "embedding_model": "intfloat/multilingual-e5-large-instruct",
    })
    await rag.initialize()

    # Process JSONL
    processor = JSONLProcessor(rag)
    with open(file_path, "rb") as f:
        content = f.read()

    doc_id = await processor.process_and_index_jsonl(
        collection_name=collection_name,
        content=content,
        filename=os.path.basename(file_path),
        metadata={
            "source": "jsonl_upload",
            "upload_date": datetime.utcnow().isoformat(),
            "file_path": os.path.abspath(file_path),
        },
    )

    # Report stats using safe HTTP method to avoid client parsing issues
    try:
        info = await rag._get_collection_info_safely(collection_name)
        print(f"Import complete. Points: {info.get('points_count', 0)}, vector_size: {info.get('vector_size', 'n/a')}")
    except Exception as e:
        print(f"Import complete. (Could not fetch collection info safely: {e})")
    await rag.cleanup()
    return doc_id


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--collection", required=True, help="Qdrant collection name")
    ap.add_argument("--file", required=True, help="Path inside container (e.g. /app/_to_delete/...).")
    args = ap.parse_args()

    asyncio.run(import_jsonl(args.collection, args.file))


if __name__ == "__main__":
    main()
@@ -20,11 +20,13 @@ services:
      dockerfile: Dockerfile
    environment:
      - DATABASE_URL=postgresql://enclava_user:enclava_pass@enclava-postgres:5432/enclava_db
      - JWT_SECRET=${JWT_SECRET:-your-jwt-secret-here}
    depends_on:
      - enclava-postgres
    command: ["/usr/local/bin/migrate.sh"]
    volumes:
      - ./backend:/app
      - ./.env:/app/.env
    networks:
      - enclava-net
    restart: "no"  # Run once and exit
@@ -63,7 +65,7 @@ services:
  enclava-frontend:
    image: node:18-alpine
    working_dir: /app
    command: sh -c "npm install && npm run dev"
    command: sh -c "npm ci --ignore-scripts && npm run dev"
    environment:
      # Required base URL (derives APP/API/WS URLs)
      - BASE_URL=${BASE_URL}
@@ -76,7 +78,7 @@ services:
      - "3002:3000"  # Direct frontend access for development
    volumes:
      - ./frontend:/app
      - /app/node_modules
      - enclava-frontend-node-modules:/app/node_modules
    networks:
      - enclava-net
    restart: unless-stopped
@@ -145,6 +147,7 @@ volumes:
  enclava-postgres-data:
  enclava-redis-data:
  enclava-qdrant-data:
  enclava-frontend-node-modules:
  # enclava-ollama-data:

networks:
frontend/package-lock.json (generated)
@@ -2048,6 +2048,9 @@
    },
    "node_modules/axios": {
      "version": "1.12.2",
      "resolved": "https://registry.npmjs.org/axios/-/axios-1.12.2.tgz",
      "integrity": "sha512-vMJzPewAlRyOgxV2dU0Cuz2O8zzzx9VYtbJOaBgXFeLc4IV/Eg50n4LowmehOOR61S8ZMpc2K5Sa7g6A4jfkUw==",
      "license": "MIT",
      "dependencies": {
        "follow-redirects": "^1.15.6",
@@ -3,6 +3,7 @@
import { useState, useEffect, Suspense } from "react";
export const dynamic = 'force-dynamic'
import { useSearchParams } from "next/navigation";
import { Suspense } from "react";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
@@ -94,7 +95,8 @@ const PERMISSION_OPTIONS = [
  { value: "llm:embeddings", label: "LLM Embeddings" },
];

function ApiKeysPageContent() {
function ApiKeysContent() {

  const { toast } = useToast();
  const searchParams = useSearchParams();
  const [apiKeys, setApiKeys] = useState<ApiKey[]>([]);
@@ -910,8 +912,9 @@ function ApiKeysPageContent() {

export default function ApiKeysPage() {
  return (
    <Suspense fallback={<div className="container mx-auto p-6">Loading...</div>}>
      <ApiKeysPageContent />
    <Suspense fallback={<div>Loading API keys...</div>}>
      <ApiKeysContent />
    </Suspense>
  );
}
@@ -7,7 +7,7 @@ export async function POST(request: NextRequest) {

  // Make request to backend auth endpoint without requiring existing auth
  const baseUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`
  const url = `${baseUrl}/api/auth/login`
  const url = `${baseUrl}/api-internal/v1/auth/login`

  const response = await fetch(url, {
    method: 'POST',
||||
56
frontend/src/app/api/rag/debug/collections/route.ts
Normal file
56
frontend/src/app/api/rag/debug/collections/route.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
import { tokenManager } from '@/lib/token-manager';
|
||||
|
||||
export async function GET(request: NextRequest) {
|
||||
try {
|
||||
// Get authentication token from Authorization header or tokenManager
|
||||
const authHeader = request.headers.get('authorization');
|
||||
let token;
|
||||
|
||||
if (authHeader && authHeader.startsWith('Bearer ')) {
|
||||
token = authHeader.substring(7);
|
||||
} else {
|
||||
token = await tokenManager.getAccessToken();
|
||||
}
|
||||
|
||||
if (!token) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Authentication required' },
|
||||
{ status: 401 }
|
||||
);
|
||||
}
|
||||
|
||||
// Backend URL
|
||||
const backendUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`;
|
||||
|
||||
// Build the proxy URL
|
||||
const proxyUrl = `${backendUrl}/api-internal/v1/rag/debug/collections`;
|
||||
|
||||
// Proxy the request to the backend with authentication
|
||||
const response = await fetch(proxyUrl, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${token}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
console.error('Backend list collections error:', response.status, errorText);
|
||||
return NextResponse.json(
|
||||
{ error: `Backend request failed: ${response.status}` },
|
||||
{ status: response.status }
|
||||
);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
return NextResponse.json(data);
|
||||
} catch (error) {
|
||||
console.error('RAG collections proxy error:', error);
|
||||
return NextResponse.json(
|
||||
{ error: 'Failed to proxy collections request' },
|
||||
{ status: 500 }
|
||||
);
|
||||
}
|
||||
}
|
||||
frontend/src/app/api/rag/debug/search/route.ts (new file, 67 lines)
@@ -0,0 +1,67 @@
import { NextRequest, NextResponse } from 'next/server';
import { tokenManager } from '@/lib/token-manager';

export async function POST(request: NextRequest) {
  try {
    // Get the search parameters from the query string
    const searchParams = request.nextUrl.searchParams;
    const query = searchParams.get('query') || '';
    const max_results = searchParams.get('max_results') || '10';
    const score_threshold = searchParams.get('score_threshold') || '0.3';
    const collection_name = searchParams.get('collection_name');

    // Get the config from the request body
    const body = await request.json();

    // Get authentication token from Authorization header or tokenManager
    const authHeader = request.headers.get('authorization');
    let token;

    if (authHeader && authHeader.startsWith('Bearer ')) {
      token = authHeader.substring(7);
    } else {
      token = await tokenManager.getAccessToken();
    }

    if (!token) {
      return NextResponse.json(
        { error: 'Authentication required' },
        { status: 401 }
      );
    }

    // Backend URL
    const backendUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`;

    // Build the proxy URL with query parameters
    const proxyUrl = `${backendUrl}/api-internal/v1/rag/debug/search?query=${encodeURIComponent(query)}&max_results=${max_results}&score_threshold=${score_threshold}${collection_name ? `&collection_name=${encodeURIComponent(collection_name)}` : ''}`;

    // Proxy the request to the backend with authentication
    const response = await fetch(proxyUrl, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        'Authorization': `Bearer ${token}`,
      },
      body: JSON.stringify(body),
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.error('Backend RAG search error:', response.status, errorText);
      return NextResponse.json(
        { error: `Backend request failed: ${response.status}` },
        { status: response.status }
      );
    }

    const data = await response.json();
    return NextResponse.json(data);
  } catch (error) {
    console.error('RAG debug search proxy error:', error);
    return NextResponse.json(
      { error: 'Failed to proxy RAG search request' },
      { status: 500 }
    );
  }
}
frontend/src/app/rag-demo/page.tsx (new file, 569 lines)
@@ -0,0 +1,569 @@
"use client";

import { useState, useEffect } from 'react';
import { useAuth } from '@/contexts/AuthContext';
import { tokenManager } from '@/lib/token-manager';

interface SearchResult {
  document: {
    id: string;
    content: string;
    metadata: Record<string, any>;
  };
  score: number;
  debug_info?: Record<string, any>;
}

interface DebugInfo {
  query_embedding?: number[];
  embedding_dimension?: number;
  score_stats?: {
    min: number;
    max: number;
    avg: number;
    stddev: number;
  };
  collection_stats?: {
    total_documents: number;
    total_chunks: number;
    languages: string[];
  };
}

export default function RAGDemoPage() {
  const { user, loading } = useAuth();
  const [query, setQuery] = useState('are sd card backups encrypted?');
  const [results, setResults] = useState<SearchResult[]>([]);
  const [debugInfo, setDebugInfo] = useState<DebugInfo>({});
  const [searchTime, setSearchTime] = useState(0);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState('');

  // Configuration state
  const [config, setConfig] = useState({
    max_results: 10,
    score_threshold: 0.3,
    collection_name: '',
    chunk_size: 300,
    chunk_overlap: 50,
    enable_hybrid: false,
    vector_weight: 0.7,
    bm25_weight: 0.3,
    use_query_prefix: true,
    use_passage_prefix: true,
    show_timing: true,
    show_embeddings: false,
  });

  // Available collections
  const [collections, setCollections] = useState<string[]>([]);
  const [collectionsLoading, setCollectionsLoading] = useState(false);

  const presets = {
    default: {
      max_results: 10,
      score_threshold: 0.3,
      chunk_size: 300,
      chunk_overlap: 50,
      enable_hybrid: false,
      vector_weight: 0.7,
      bm25_weight: 0.3,
    },
    high_precision: {
      max_results: 5,
      score_threshold: 0.5,
      chunk_size: 200,
      chunk_overlap: 30,
      enable_hybrid: true,
      vector_weight: 0.8,
      bm25_weight: 0.2,
    },
    high_recall: {
      max_results: 20,
      score_threshold: 0.1,
      chunk_size: 400,
      chunk_overlap: 100,
      enable_hybrid: true,
      vector_weight: 0.6,
      bm25_weight: 0.4,
    },
    hybrid: {
      max_results: 10,
      score_threshold: 0.2,
      chunk_size: 300,
      chunk_overlap: 50,
      enable_hybrid: true,
      vector_weight: 0.5,
      bm25_weight: 0.5,
    },
  };

  useEffect(() => {
    // Check if we have tokens in localStorage but not in tokenManager
    const syncTokens = async () => {
      const rawTokens = localStorage.getItem('auth_tokens');
      if (rawTokens && !tokenManager.isAuthenticated()) {
        try {
          const tokens = JSON.parse(rawTokens);
          // Sync tokens to tokenManager
          tokenManager.setTokens(
            tokens.access_token,
            tokens.refresh_token,
            Math.floor((tokens.access_expires_at - Date.now()) / 1000)
          );
          console.log('RAG Demo: Tokens synced from localStorage to tokenManager');
        } catch (e) {
          console.error('RAG Demo: Failed to sync tokens:', e);
        }
      }
      loadCollections();
    };

    syncTokens();
  }, [user]);

  const loadCollections = async () => {
    setCollectionsLoading(true);
    try {
      console.log('RAG Demo: Loading collections...');
      console.log('RAG Demo: User authenticated:', !!user);
      console.log('RAG Demo: TokenManager authenticated:', tokenManager.isAuthenticated());

      const token = await tokenManager.getAccessToken();
      console.log('RAG Demo: Token retrieved:', token ? 'Yes' : 'No');
      console.log('RAG Demo: Token expiry:', tokenManager.getTokenExpiry());

      const headers: Record<string, string> = {
        'Content-Type': 'application/json',
      };

      if (token) {
        headers['Authorization'] = `Bearer ${token}`;
        console.log('RAG Demo: Authorization header set');
      } else {
        console.warn('RAG Demo: No token available');
      }

      const response = await fetch('/api/rag/debug/collections', { headers });
      console.log('RAG Demo: Collections response status:', response.status);
      if (response.ok) {
        const data = await response.json();
        console.log('RAG Demo: Collections loaded:', data.collections);
        setCollections(data.collections || []);
        // Auto-select first collection if none selected
        if (data.collections && data.collections.length > 0 && !config.collection_name) {
          setConfig(prev => ({ ...prev, collection_name: data.collections[0] }));
        }
      } else {
        const errorText = await response.text();
        console.error('RAG Demo: Collections failed:', response.status, errorText);
      }
    } catch (err) {
      console.error('RAG Demo: Failed to load collections:', err);
    } finally {
      setCollectionsLoading(false);
    }
  };

  const loadPreset = (presetName: keyof typeof presets) => {
    setConfig(prev => ({
      ...prev,
      ...presets[presetName],
    }));
  };

  const performSearch = async () => {
    if (!query.trim()) return;
    if (!config.collection_name) {
      setError('Please select a collection');
      return;
    }

    setIsLoading(true);
    setError('');
    setResults([]);

    try {
      const token = await tokenManager.getAccessToken();
      const headers: Record<string, string> = {
        'Content-Type': 'application/json',
      };

      if (token) {
        headers['Authorization'] = `Bearer ${token}`;
      }

      const response = await fetch('/api/rag/debug/search', {
        method: 'POST',
        headers,
        body: JSON.stringify({
          query,
          max_results: config.max_results,
          score_threshold: config.score_threshold,
          collection_name: config.collection_name,
          config,
        }),
      });

      if (!response.ok) {
        throw new Error(`Search failed: ${response.statusText}`);
      }

      const data = await response.json();
      setResults(data.results || []);
      setDebugInfo(data.debug_info || {});
      setSearchTime(data.search_time_ms || 0);
    } catch (err) {
      setError(err instanceof Error ? err.message : 'Unknown error');
    } finally {
      setIsLoading(false);
    }
  };

  const updateConfig = (key: string, value: any) => {
    setConfig(prev => ({ ...prev, [key]: value }));
  };

  if (loading) {
    return (
      <div className="flex items-center justify-center min-h-screen">
        <div className="text-lg">Loading...</div>
      </div>
    );
  }

  if (!user) {
    return (
      <div className="flex items-center justify-center min-h-screen">
        <div className="text-center">
          <h1 className="text-2xl font-bold mb-4">RAG Demo</h1>
          <p>Please log in to access the RAG demo interface.</p>
        </div>
      </div>
    );
  }

  return (
    <div className="container mx-auto px-4 py-8 max-w-7xl">
      <h1 className="text-3xl font-bold mb-2">🔍 RAG Search Demo</h1>
      <p className="text-gray-600 mb-6">Test and tune your RAG system with real-time search and debugging</p>

      <div className="grid grid-cols-1 lg:grid-cols-4 gap-6">
        {/* Search Results - Main Content */}
        <div className="lg:col-span-3 space-y-6">
          {/* Preset Buttons */}
          <div className="flex flex-wrap gap-2">
            {Object.entries(presets).map(([name, _]) => (
              <button
                key={name}
                onClick={() => loadPreset(name as keyof typeof presets)}
                className="px-3 py-1 bg-gray-100 hover:bg-gray-200 rounded-md text-sm capitalize"
              >
                {name.replace('_', ' ')}
              </button>
            ))}
</div>
|
||||
|
||||
{/* Search Box */}
|
||||
<div className="bg-white rounded-lg shadow p-6">
|
||||
<div className="flex gap-2">
|
||||
<input
|
||||
type="text"
|
||||
value={query}
|
||||
onChange={(e) => setQuery(e.target.value)}
|
||||
onKeyPress={(e) => e.key === 'Enter' && performSearch()}
|
||||
placeholder="Enter your search query..."
|
||||
className="flex-1 px-4 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||
/>
|
||||
<button
|
||||
onClick={performSearch}
|
||||
disabled={isLoading || !config.collection_name}
|
||||
className="px-6 py-2 bg-blue-500 text-white rounded-md hover:bg-blue-600 disabled:opacity-50"
|
||||
>
|
||||
{isLoading ? 'Searching...' : 'Search'}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="mt-4 p-4 bg-red-100 text-red-700 rounded-md">
|
||||
Error: {error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Results Summary */}
|
||||
{results.length > 0 && (
|
||||
<div className="mt-4 p-3 bg-blue-50 rounded-md">
|
||||
<p className="text-sm">
|
||||
Found <strong>{results.length}</strong> results in <strong>{searchTime.toFixed(0)}ms</strong>
|
||||
{config.enable_hybrid && (
|
||||
<span className="ml-2 text-green-600">• Hybrid Search Enabled</span>
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Search Results */}
|
||||
<div className="space-y-4">
|
||||
{results.map((result, index) => (
|
||||
<div key={index} className="bg-white rounded-lg shadow p-6">
|
||||
<div className="flex justify-between items-start mb-3">
|
||||
<h3 className="text-lg font-semibold">Result {index + 1}</h3>
|
||||
<span className={`px-3 py-1 rounded-full text-sm font-medium ${
|
||||
result.score >= 0.5 ? 'bg-green-100 text-green-800' :
|
||||
result.score >= 0.3 ? 'bg-yellow-100 text-yellow-800' :
|
||||
'bg-red-100 text-red-800'
|
||||
}`}>
|
||||
Score: {result.score.toFixed(4)}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="text-gray-700 mb-4 whitespace-pre-wrap">
|
||||
{result.document.content}
|
||||
</div>
|
||||
|
||||
{/* Metadata */}
|
||||
<div className="text-sm text-gray-500 mb-3">
|
||||
{result.document.metadata.content_type && (
|
||||
<span>Type: {result.document.metadata.content_type}</span>
|
||||
)}
|
||||
{result.document.metadata.language && (
|
||||
<span className="ml-3">Language: {result.document.metadata.language}</span>
|
||||
)}
|
||||
{result.document.metadata.filename && (
|
||||
<span className="ml-3">File: {result.document.metadata.filename}</span>
|
||||
)}
|
||||
{result.document.metadata.chunk_index !== undefined && (
|
||||
<span className="ml-3">
|
||||
Chunk: {result.document.metadata.chunk_index + 1}/{result.document.metadata.chunk_count || '?'}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Debug Details */}
|
||||
{config.show_timing && result.debug_info && (
|
||||
<div className="mt-4 p-3 bg-gray-50 rounded-md text-xs font-mono">
|
||||
<p><strong>Debug Information:</strong></p>
|
||||
{result.debug_info.vector_score !== undefined && (
|
||||
<p>Vector Score: {result.debug_info.vector_score.toFixed(4)}</p>
|
||||
)}
|
||||
{result.debug_info.bm25_score !== undefined && (
|
||||
<p>BM25 Score: {result.debug_info.bm25_score.toFixed(4)}</p>
|
||||
)}
|
||||
{result.document.metadata.question && (
|
||||
<div className="mt-2">
|
||||
<p><strong>Question:</strong> {result.document.metadata.question}</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Debug Section */}
|
||||
{debugInfo && Object.keys(debugInfo).length > 0 && (
|
||||
<div className="bg-gray-900 text-green-400 rounded-lg shadow p-6 font-mono text-sm">
|
||||
<h3 className="text-lg font-semibold mb-4">Debug Information</h3>
|
||||
|
||||
{debugInfo.score_stats && (
|
||||
<div className="mb-4">
|
||||
<p className="font-semibold mb-2">Score Statistics:</p>
|
||||
<div className="grid grid-cols-2 md:grid-cols-4 gap-2 text-xs">
|
||||
<div>Min: {debugInfo.score_stats.min?.toFixed(4)}</div>
|
||||
<div>Max: {debugInfo.score_stats.max?.toFixed(4)}</div>
|
||||
<div>Avg: {debugInfo.score_stats.avg?.toFixed(4)}</div>
|
||||
<div>StdDev: {debugInfo.score_stats.stddev?.toFixed(4)}</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{debugInfo.collection_stats && (
|
||||
<div className="mb-4">
|
||||
<p className="font-semibold mb-2">Collection Stats:</p>
|
||||
<div className="text-xs">
|
||||
<p>Total Documents: {debugInfo.collection_stats.total_documents}</p>
|
||||
<p>Total Chunks: {debugInfo.collection_stats.total_chunks}</p>
|
||||
<p>Languages: {debugInfo.collection_stats.languages?.join(', ')}</p>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{debugInfo.query_embedding && config.show_embeddings && (
|
||||
<div>
|
||||
<p className="font-semibold mb-2">Query Embedding (first 10 dims):</p>
|
||||
<p className="text-xs">
|
||||
[{debugInfo.query_embedding.slice(0, 10).map(x => x.toFixed(6)).join(', ')}...]
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Configuration Panel */}
|
||||
<div className="space-y-6">
|
||||
<div className="bg-white rounded-lg shadow p-6">
|
||||
<h2 className="text-xl font-semibold mb-4">⚙️ Configuration</h2>
|
||||
|
||||
<div className="space-y-4">
|
||||
{/* Search Settings */}
|
||||
<div>
|
||||
<h3 className="font-medium mb-2">Search Settings</h3>
|
||||
<div className="space-y-3">
|
||||
<div>
|
||||
<label className="block text-sm mb-1">Max Results: {config.max_results}</label>
|
||||
<input
|
||||
type="range"
|
||||
min="1"
|
||||
max="50"
|
||||
value={config.max_results}
|
||||
onChange={(e) => updateConfig('max_results', parseInt(e.target.value))}
|
||||
className="w-full"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm mb-1">Score Threshold: {config.score_threshold}</label>
|
||||
<input
|
||||
type="range"
|
||||
min="0"
|
||||
max="1"
|
||||
step="0.05"
|
||||
value={config.score_threshold}
|
||||
onChange={(e) => updateConfig('score_threshold', parseFloat(e.target.value))}
|
||||
className="w-full"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm mb-1">Collection Name</label>
|
||||
{collectionsLoading ? (
|
||||
<select
|
||||
disabled
|
||||
className="w-full px-3 py-1 border border-gray-300 rounded-md text-sm bg-gray-50"
|
||||
>
|
||||
<option>Loading collections...</option>
|
||||
</select>
|
||||
) : (
|
||||
<select
|
||||
value={config.collection_name}
|
||||
onChange={(e) => updateConfig('collection_name', e.target.value)}
|
||||
className="w-full px-3 py-1 border border-gray-300 rounded-md text-sm"
|
||||
>
|
||||
<option value="">Select a collection...</option>
|
||||
{collections.map(collection => (
|
||||
<option key={collection} value={collection}>
|
||||
{collection}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Chunking Settings */}
|
||||
<div>
|
||||
<h3 className="font-medium mb-2">Chunking Settings</h3>
|
||||
<div className="space-y-3">
|
||||
<div>
|
||||
<label className="block text-sm mb-1">Chunk Size: {config.chunk_size}</label>
|
||||
<input
|
||||
type="range"
|
||||
min="100"
|
||||
max="1000"
|
||||
step="50"
|
||||
value={config.chunk_size}
|
||||
onChange={(e) => updateConfig('chunk_size', parseInt(e.target.value))}
|
||||
className="w-full"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm mb-1">Chunk Overlap: {config.chunk_overlap}</label>
|
||||
<input
|
||||
type="range"
|
||||
min="0"
|
||||
max="200"
|
||||
step="10"
|
||||
value={config.chunk_overlap}
|
||||
onChange={(e) => updateConfig('chunk_overlap', parseInt(e.target.value))}
|
||||
className="w-full"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Hybrid Search */}
|
||||
<div>
|
||||
<h3 className="font-medium mb-2">Hybrid Search</h3>
|
||||
<div className="space-y-3">
|
||||
<label className="flex items-center">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={config.enable_hybrid}
|
||||
onChange={(e) => updateConfig('enable_hybrid', e.target.checked)}
|
||||
className="mr-2"
|
||||
/>
|
||||
<span className="text-sm">Enable Hybrid Search</span>
|
||||
</label>
|
||||
{config.enable_hybrid && (
|
||||
<>
|
||||
<div>
|
||||
<label className="block text-sm mb-1">Vector Weight: {config.vector_weight}</label>
|
||||
<input
|
||||
type="range"
|
||||
min="0"
|
||||
max="1"
|
||||
step="0.05"
|
||||
value={config.vector_weight}
|
||||
onChange={(e) => updateConfig('vector_weight', parseFloat(e.target.value))}
|
||||
className="w-full"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm mb-1">BM25 Weight: {config.bm25_weight}</label>
|
||||
<input
|
||||
type="range"
|
||||
min="0"
|
||||
max="1"
|
||||
step="0.05"
|
||||
value={config.bm25_weight}
|
||||
onChange={(e) => updateConfig('bm25_weight', parseFloat(e.target.value))}
|
||||
className="w-full"
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Debug Options */}
|
||||
<div>
|
||||
<h3 className="font-medium mb-2">Debug Options</h3>
|
||||
<div className="space-y-2">
|
||||
<label className="flex items-center">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={config.show_timing}
|
||||
onChange={(e) => updateConfig('show_timing', e.target.checked)}
|
||||
className="mr-2"
|
||||
/>
|
||||
<span className="text-sm">Show Timing</span>
|
||||
</label>
|
||||
<label className="flex items-center">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={config.show_embeddings}
|
||||
onChange={(e) => updateConfig('show_embeddings', e.target.checked)}
|
||||
className="mr-2"
|
||||
/>
|
||||
<span className="text-sm">Show Embeddings</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -85,8 +85,31 @@ function RAGPageContent() {
|
||||
const loadStats = async () => {
|
||||
try {
|
||||
const data = await apiClient.get('/api-internal/v1/rag/stats')
|
||||
console.log('Stats API response:', data)
|
||||
|
||||
// Check if the response has the expected structure
|
||||
if (data && data.stats && data.stats.collections) {
|
||||
console.log('✓ Stats has collections property')
|
||||
setStats(data.stats)
|
||||
} else {
|
||||
console.error('✗ Invalid stats structure:', data)
|
||||
// Set default empty stats to prevent error
|
||||
setStats({
|
||||
collections: { total: 0, active: 0 },
|
||||
documents: { total: 0, processing: 0, processed: 0 },
|
||||
storage: { total_size_bytes: 0, total_size_mb: 0 },
|
||||
vectors: { total: 0 }
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error loading stats:', error)
|
||||
// Set default empty stats on error
|
||||
setStats({
|
||||
collections: { total: 0, active: 0 },
|
||||
documents: { total: 0, processing: 0, processed: 0 },
|
||||
storage: { total_size_bytes: 0, total_size_mb: 0 },
|
||||
vectors: { total: 0 }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -87,9 +87,8 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
|
||||
const [messages, setMessages] = useState<ChatMessage[]>([])
|
||||
const [input, setInput] = useState("")
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
const [conversationId, setConversationId] = useState<string | null>(null)
|
||||
const scrollAreaRef = useRef<HTMLDivElement>(null)
|
||||
const { toast } = useToast()
|
||||
const { success: toastSuccess, error: toastError } = useToast()
|
||||
|
||||
const scrollToBottom = useCallback(() => {
|
||||
if (scrollAreaRef.current) {
|
||||
@@ -130,43 +129,21 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
|
||||
console.log('=== CHAT REQUEST DEBUG ===', debugInfo)
|
||||
|
||||
try {
|
||||
// Build conversation history in OpenAI format
|
||||
let data: any
|
||||
|
||||
// Use internal API
|
||||
const conversationHistory = messages.map(msg => ({
|
||||
role: msg.role,
|
||||
content: msg.content
|
||||
}))
|
||||
|
||||
const requestData = {
|
||||
messages: conversationHistory,
|
||||
conversation_id: conversationId || undefined
|
||||
}
|
||||
|
||||
console.log('=== CHAT API REQUEST ===', {
|
||||
url: `/api-internal/v1/chatbot/${chatbotId}/chat/completions`,
|
||||
data: requestData
|
||||
})
|
||||
|
||||
const data = await chatbotApi.chat(
|
||||
data = await chatbotApi.sendMessage(
|
||||
chatbotId,
|
||||
messageToSend,
|
||||
requestData
|
||||
undefined, // No conversation ID
|
||||
conversationHistory
|
||||
)
|
||||
|
||||
console.log('=== CHAT API RESPONSE ===', {
|
||||
status: 'success',
|
||||
data,
|
||||
responseKeys: Object.keys(data),
|
||||
hasChoices: !!data.choices,
|
||||
hasResponse: !!data.response,
|
||||
content: data.choices?.[0]?.message?.content || data.response || 'No response'
|
||||
})
|
||||
|
||||
// Update conversation ID if it's a new conversation
|
||||
if (!conversationId && data.conversation_id) {
|
||||
setConversationId(data.conversationId)
|
||||
console.log('Updated conversation ID:', data.conversation_id)
|
||||
}
|
||||
|
||||
const assistantMessage: ChatMessage = {
|
||||
id: data.id || generateTimestampId('msg'),
|
||||
role: 'assistant',
|
||||
@@ -177,52 +154,21 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
|
||||
|
||||
setMessages(prev => [...prev, assistantMessage])
|
||||
|
||||
} catch (error: any) {
|
||||
console.error('=== CHAT ERROR DEBUG ===', {
|
||||
errorType: typeof error,
|
||||
error,
|
||||
errorMessage: error?.message,
|
||||
errorCode: error?.code,
|
||||
errorResponse: error?.response?.data,
|
||||
errorStatus: error?.response?.status,
|
||||
errorConfig: error?.config,
|
||||
errorStack: error?.stack,
|
||||
timestamp: new Date().toISOString()
|
||||
})
|
||||
} catch (error) {
|
||||
const appError = error as AppError
|
||||
|
||||
// Handle different error types
|
||||
if (error && typeof error === 'object') {
|
||||
if ('response' in error) {
|
||||
// Axios error
|
||||
const status = error.response?.status
|
||||
if (status === 401) {
|
||||
toast.error("Authentication Required", "Please log in to continue chatting.")
|
||||
} else if (status === 429) {
|
||||
toast.error("Rate Limit", "Too many requests. Please wait and try again.")
|
||||
// More specific error handling
|
||||
if (appError.code === 'UNAUTHORIZED') {
|
||||
toastError("Authentication Required", "Please log in to continue chatting.")
|
||||
} else if (appError.code === 'NETWORK_ERROR') {
|
||||
toastError("Connection Error", "Please check your internet connection and try again.")
|
||||
} else {
|
||||
toast.error("Message Failed", error.response?.data?.detail || "Failed to send message. Please try again.")
|
||||
}
|
||||
} else if ('code' in error) {
|
||||
// Custom error with code
|
||||
if (error.code === 'UNAUTHORIZED') {
|
||||
toast.error("Authentication Required", "Please log in to continue chatting.")
|
||||
} else if (error.code === 'NETWORK_ERROR') {
|
||||
toast.error("Connection Error", "Please check your internet connection and try again.")
|
||||
} else {
|
||||
toast.error("Message Failed", error.message || "Failed to send message. Please try again.")
|
||||
}
|
||||
} else {
|
||||
// Generic error
|
||||
toast.error("Message Failed", error.message || "Failed to send message. Please try again.")
|
||||
}
|
||||
} else {
|
||||
// Fallback for unknown error types
|
||||
toast.error("Message Failed", "An unexpected error occurred. Please try again.")
|
||||
toastError("Message Failed", appError.message || "Failed to send message. Please try again.")
|
||||
}
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}, [input, isLoading, chatbotId, conversationId, messages, toast])
|
||||
}, [input, isLoading, chatbotId, messages, toastError])
|
||||
|
||||
const handleKeyPress = useCallback((e: React.KeyboardEvent) => {
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
@@ -234,11 +180,11 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
|
||||
const copyMessage = useCallback(async (content: string) => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(content)
|
||||
toast.success("Copied", "Message copied to clipboard")
|
||||
toastSuccess("Copied", "Message copied to clipboard")
|
||||
} catch (error) {
|
||||
toast.error("Copy Failed", "Unable to copy message to clipboard")
|
||||
toastError("Copy Failed", "Unable to copy message to clipboard")
|
||||
}
|
||||
}, [toast])
|
||||
}, [toastSuccess, toastError])
|
||||
|
||||
const formatTime = useCallback((date: Date) => {
|
||||
return date.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })
|
||||
|
||||
@@ -138,6 +138,7 @@ export function ChatbotManager() {
|
||||
const [editingChatbot, setEditingChatbot] = useState<ChatbotInstance | null>(null)
|
||||
const [showChatInterface, setShowChatInterface] = useState(false)
|
||||
const [testingChatbot, setTestingChatbot] = useState<ChatbotInstance | null>(null)
|
||||
const [chatbotApiKeys, setChatbotApiKeys] = useState<Record<string, string>>({})
|
||||
const { toast } = useToast()
|
||||
|
||||
// New chatbot form state
|
||||
|
||||
@@ -20,7 +20,7 @@ interface AuthContextType {
|
||||
user: User | null
|
||||
isLoading: boolean
|
||||
isAuthenticated: boolean
|
||||
login: (username: string, password: string) => Promise<void>
|
||||
login: (email: string, password: string) => Promise<void>
|
||||
logout: () => void
|
||||
register: (username: string, email: string, password: string) => Promise<void>
|
||||
refreshToken: () => Promise<void>
|
||||
@@ -62,9 +62,9 @@ export function AuthProvider({ children }: { children: React.ReactNode }) {
|
||||
}
|
||||
}
|
||||
|
||||
const login = async (username: string, password: string) => {
|
||||
const login = async (email: string, password: string) => {
|
||||
try {
|
||||
const data = await apiClient.post("/api-internal/v1/auth/login", { username, password })
|
||||
const data = await apiClient.post("/api-internal/v1/auth/login", { email, password })
|
||||
|
||||
// Store tokens using tokenManager
|
||||
tokenManager.setTokens(data.access_token, data.refresh_token)
|
||||
|
||||
@@ -9,7 +9,7 @@ import { Badge } from "@/components/ui/badge"
|
||||
import { Separator } from "@/components/ui/separator"
|
||||
import { AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger } from "@/components/ui/alert-dialog"
|
||||
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog"
|
||||
import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter } from "lucide-react"
|
||||
import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter, RefreshCw } from "lucide-react"
|
||||
import { useToast } from "@/hooks/use-toast"
|
||||
import { apiClient } from "@/lib/api-client"
|
||||
import { config } from "@/lib/config"
|
||||
@@ -56,6 +56,7 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
|
||||
const [filterStatus, setFilterStatus] = useState("all")
|
||||
const [selectedDocument, setSelectedDocument] = useState<Document | null>(null)
|
||||
const [deleting, setDeleting] = useState<string | null>(null)
|
||||
const [reprocessing, setReprocessing] = useState<string | null>(null)
|
||||
const { toast } = useToast()
|
||||
|
||||
useEffect(() => {
|
||||
@@ -157,6 +158,43 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
|
||||
}
|
||||
}
|
||||
|
||||
const handleReprocessDocument = async (documentId: string) => {
|
||||
setReprocessing(documentId)
|
||||
|
||||
try {
|
||||
await apiClient.post(`/api-internal/v1/rag/documents/${documentId}/reprocess`)
|
||||
|
||||
// Update the document status to processing in the UI
|
||||
setDocuments(prev => prev.map(doc =>
|
||||
doc.id === documentId
|
||||
? { ...doc, status: 'processing' as const, processed_at: new Date().toISOString() }
|
||||
: doc
|
||||
))
|
||||
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "Document reprocessing started",
|
||||
})
|
||||
|
||||
// Reload documents after a short delay to see status updates
|
||||
setTimeout(() => {
|
||||
loadDocuments()
|
||||
}, 2000)
|
||||
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : "Failed to reprocess document"
|
||||
toast({
|
||||
title: "Error",
|
||||
description: errorMessage.includes("Cannot reprocess document with status 'processed'")
|
||||
? "Cannot reprocess documents that are already processed"
|
||||
: errorMessage,
|
||||
variant: "destructive",
|
||||
})
|
||||
} finally {
|
||||
setReprocessing(null)
|
||||
}
|
||||
}
|
||||
|
||||
const formatFileSize = (bytes: number) => {
|
||||
if (bytes === 0) return '0 Bytes'
|
||||
const k = 1024
|
||||
@@ -432,6 +470,21 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
|
||||
<Download className="h-4 w-4" />
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-8 w-8 p-0 hover:bg-blue-100"
|
||||
onClick={() => handleReprocessDocument(document.id)}
|
||||
disabled={reprocessing === document.id || document.status === 'processed'}
|
||||
title={document.status === 'processed' ? "Document already processed" : "Reprocess document"}
|
||||
>
|
||||
{reprocessing === document.id ? (
|
||||
<RefreshCw className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<RefreshCw className={`h-4 w-4 ${document.status === 'processed' ? 'text-gray-400' : ''}`} />
|
||||
)}
|
||||
</Button>
|
||||
|
||||
<AlertDialog>
|
||||
<AlertDialogTrigger asChild>
|
||||
<Button
|
||||
|
||||
@@ -73,6 +73,7 @@ const Navigation = () => {
|
||||
children: [
|
||||
{ href: "/llm", label: "Models & Config" },
|
||||
{ href: "/playground", label: "Playground" },
|
||||
{ href: "/rag-demo", label: "RAG Demo" },
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,119 +1,117 @@
|
||||
import axios from 'axios';
|
||||
import Cookies from 'js-cookie';
|
||||
|
||||
// Dynamic base URL with protocol detection
|
||||
const getApiBaseUrl = (): string => {
|
||||
if (typeof window !== 'undefined') {
|
||||
// Client-side: use the same protocol as the current page
|
||||
const protocol = window.location.protocol.slice(0, -1); // Remove ':' from 'https:'
|
||||
const host = window.location.hostname;
|
||||
return `${protocol}://${host}`;
|
||||
}
|
||||
export interface AppError extends Error {
|
||||
code: 'UNAUTHORIZED' | 'NETWORK_ERROR' | 'VALIDATION_ERROR' | 'NOT_FOUND' | 'FORBIDDEN' | 'TIMEOUT' | 'UNKNOWN'
|
||||
status?: number
|
||||
details?: any
|
||||
}
|
||||
|
||||
// Server-side: use environment variable or default to localhost
|
||||
const baseUrl = process.env.NEXT_PUBLIC_BASE_URL || 'localhost';
|
||||
const protocol = process.env.NODE_ENV === 'production' ? 'https' : 'http';
|
||||
return `${protocol}://${baseUrl}`;
|
||||
};
|
||||
|
||||
const axiosInstance = axios.create({
|
||||
baseURL: getApiBaseUrl(),
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
});
|
||||
|
||||
// Request interceptor to add auth token
|
||||
axiosInstance.interceptors.request.use(
|
||||
(config) => {
|
||||
const token = Cookies.get('access_token');
|
||||
if (token) {
|
||||
config.headers.Authorization = `Bearer ${token}`;
|
||||
}
|
||||
return config;
|
||||
},
|
||||
(error) => {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
);
|
||||
|
||||
// Response interceptor to handle token refresh
|
||||
axiosInstance.interceptors.response.use(
|
||||
(response) => response,
|
||||
async (error) => {
|
||||
const originalRequest = error.config;
|
||||
|
||||
if (error.response?.status === 401 && !originalRequest._retry) {
|
||||
originalRequest._retry = true;
|
||||
function makeError(message: string, code: AppError['code'], status?: number, details?: any): AppError {
|
||||
const err = new Error(message) as AppError
|
||||
err.code = code
|
||||
err.status = status
|
||||
err.details = details
|
||||
return err
|
||||
}
|
||||
|
||||
async function getAuthHeader(): Promise<Record<string, string>> {
|
||||
try {
|
||||
const refreshToken = Cookies.get('refresh_token');
|
||||
if (refreshToken) {
|
||||
const response = await axios.post(`${getApiBaseUrl()}/api-internal/v1/auth/refresh`, {
|
||||
refresh_token: refreshToken,
|
||||
});
|
||||
|
||||
const { access_token } = response.data;
|
||||
Cookies.set('access_token', access_token, { expires: 7 });
|
||||
|
||||
originalRequest.headers.Authorization = `Bearer ${access_token}`;
|
||||
return axiosInstance(originalRequest);
|
||||
}
|
||||
} catch (refreshError) {
|
||||
// Refresh failed, redirect to login
|
||||
Cookies.remove('access_token');
|
||||
Cookies.remove('refresh_token');
|
||||
window.location.href = '/login';
|
||||
return Promise.reject(refreshError);
|
||||
const { tokenManager } = await import('./token-manager')
|
||||
const token = await tokenManager.getAccessToken()
|
||||
return token ? { Authorization: `Bearer ${token}` } : {}
|
||||
} catch {
|
||||
return {}
|
||||
}
|
||||
}
|
||||
|
||||
async function request<T = any>(method: string, url: string, body?: any, extraInit?: RequestInit): Promise<T> {
|
||||
try {
|
||||
const headers: Record<string, string> = {
|
||||
'Accept': 'application/json',
|
||||
...(method !== 'GET' && method !== 'HEAD' ? { 'Content-Type': 'application/json' } : {}),
|
||||
...(await getAuthHeader()),
|
||||
...(extraInit?.headers as Record<string, string> | undefined),
|
||||
}
|
||||
|
||||
return Promise.reject(error);
|
||||
const res = await fetch(url, {
|
||||
method,
|
||||
headers,
|
||||
body: body != null && method !== 'GET' && method !== 'HEAD' ? JSON.stringify(body) : undefined,
|
||||
...extraInit,
|
||||
})
|
||||
|
||||
if (!res.ok) {
|
||||
let details: any = undefined
|
||||
try { details = await res.json() } catch { details = await res.text() }
|
||||
const status = res.status
|
||||
if (status === 401) throw makeError('Unauthorized', 'UNAUTHORIZED', status, details)
|
||||
if (status === 403) throw makeError('Forbidden', 'FORBIDDEN', status, details)
|
||||
if (status === 404) throw makeError('Not found', 'NOT_FOUND', status, details)
|
||||
if (status === 400) throw makeError('Validation error', 'VALIDATION_ERROR', status, details)
|
||||
throw makeError('Request failed', 'UNKNOWN', status, details)
|
||||
}
|
||||
);
|
||||
|
||||
const contentType = res.headers.get('content-type') || ''
|
||||
if (contentType.includes('application/json')) {
|
||||
return (await res.json()) as T
|
||||
}
|
||||
// @ts-expect-error allow non-json generic
|
||||
return (await res.text()) as T
|
||||
} catch (e: any) {
|
||||
if (e?.code) throw e
|
||||
if (e?.name === 'AbortError') throw makeError('Request timed out', 'TIMEOUT')
|
||||
throw makeError(e?.message || 'Network error', 'NETWORK_ERROR')
|
||||
}
|
||||
}
|
||||
|
||||
export const apiClient = {
|
||||
get: async <T = any>(url: string, config?: any): Promise<T> => {
|
||||
const response = await axiosInstance.get(url, config);
|
||||
return response.data;
|
||||
},
|
||||
get: <T = any>(url: string, init?: RequestInit) => request<T>('GET', url, undefined, init),
|
||||
post: <T = any>(url: string, body?: any, init?: RequestInit) => request<T>('POST', url, body, init),
|
||||
put: <T = any>(url: string, body?: any, init?: RequestInit) => request<T>('PUT', url, body, init),
|
||||
delete: <T = any>(url: string, init?: RequestInit) => request<T>('DELETE', url, undefined, init),
|
||||
}
|
||||
|
||||
post: async <T = any>(url: string, data?: any, config?: any): Promise<T> => {
|
||||
const response = await axiosInstance.post(url, data, config);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
put: async <T = any>(url: string, data?: any, config?: any): Promise<T> => {
|
||||
const response = await axiosInstance.put(url, data, config);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
delete: async <T = any>(url: string, config?: any): Promise<T> => {
|
||||
const response = await axiosInstance.delete(url, config);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
patch: async <T = any>(url: string, data?: any, config?: any): Promise<T> => {
|
||||
const response = await axiosInstance.patch(url, data, config);
|
||||
return response.data;
|
||||
},
|
||||
};
|
||||
|
||||
// Chatbot specific API methods
|
||||
export const chatbotApi = {
|
||||
create: async (data: any) => apiClient.post('/api-internal/v1/chatbot/create', data),
|
||||
list: async () => apiClient.get('/api-internal/v1/chatbot/list'),
|
||||
update: async (id: string, data: any) => apiClient.put(`/api-internal/v1/chatbot/update/${id}`, data),
|
||||
delete: async (id: string) => apiClient.delete(`/api-internal/v1/chatbot/delete/${id}`),
|
||||
chat: async (id: string, message: string, config?: any) => {
|
||||
// For OpenAI-compatible chat completions
|
||||
const messages = [
|
||||
...(config?.messages || []),
|
||||
{ role: 'user', content: message }
|
||||
]
|
||||
return apiClient.post(`/api-internal/v1/chatbot/${id}/chat/completions`, {
|
||||
messages,
|
||||
...config
|
||||
})
|
||||
async listChatbots() {
|
||||
try {
|
||||
return await apiClient.get('/api-internal/v1/chatbot/list')
|
||||
} catch {
|
||||
return await apiClient.get('/api-internal/v1/chatbot/instances')
|
||||
}
|
||||
},
|
||||
};
|
||||
createChatbot(config: any) {
|
||||
return apiClient.post('/api-internal/v1/chatbot/create', config)
|
||||
},
|
||||
updateChatbot(id: string, config: any) {
|
||||
return apiClient.put(`/api-internal/v1/chatbot/update/${encodeURIComponent(id)}`, config)
|
||||
},
|
||||
deleteChatbot(id: string) {
|
||||
return apiClient.delete(`/api-internal/v1/chatbot/delete/${encodeURIComponent(id)}`)
|
||||
},
|
||||
// Legacy method with JWT auth (to be deprecated)
|
||||
sendMessage(chatbotId: string, message: string, conversationId?: string, history?: Array<{role: string; content: string}>) {
|
||||
const body: any = { message }
|
||||
if (conversationId) body.conversation_id = conversationId
|
||||
if (history) body.history = history
|
||||
return apiClient.post(`/api-internal/v1/chatbot/chat/${encodeURIComponent(chatbotId)}`, body)
|
||||
},
|
||||
// OpenAI-compatible chatbot API with API key auth
|
||||
sendOpenAIChatMessage(chatbotId: string, messages: Array<{role: string; content: string}>, apiKey: string, options?: {
|
||||
temperature?: number
|
||||
max_tokens?: number
|
||||
stream?: boolean
|
||||
}) {
|
||||
const body: any = {
|
||||
messages,
|
||||
...options
|
||||
}
|
||||
return fetch(`/api/v1/chatbot/external/${encodeURIComponent(chatbotId)}/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${apiKey}`
|
||||
},
|
||||
body: JSON.stringify(body)
|
||||
}).then(res => res.json())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,60 +1,14 @@
|
||||
export const config = {
|
||||
API_BASE_URL: process.env.NEXT_PUBLIC_BASE_URL || '',
|
||||
APP_NAME: process.env.NEXT_PUBLIC_APP_NAME || 'Enclava',
|
||||
DEFAULT_LANGUAGE: 'en',
|
||||
SUPPORTED_LANGUAGES: ['en', 'es', 'fr', 'de', 'it'],
|
||||
getPublicApiUrl() {
|
||||
if (this.API_BASE_URL) {
|
||||
return this.API_BASE_URL
|
||||
getPublicApiUrl(): string {
|
||||
if (typeof process !== 'undefined' && process.env.NEXT_PUBLIC_BASE_URL) {
|
||||
return process.env.NEXT_PUBLIC_BASE_URL
|
||||
}
|
||||
|
||||
if (typeof window !== 'undefined' && window.location.origin) {
|
||||
if (typeof window !== 'undefined') {
|
||||
return window.location.origin
|
||||
}
|
||||
|
||||
return ''
|
||||
return 'http://localhost:3000'
|
||||
},
|
||||
|
||||
// Feature flags
|
||||
FEATURES: {
|
||||
RAG: true,
|
||||
PLUGINS: true,
|
||||
ANALYTICS: true,
|
||||
AUDIT_LOGS: true,
|
||||
BUDGET_MANAGEMENT: true,
|
||||
getAppName(): string {
|
||||
return process.env.NEXT_PUBLIC_APP_NAME || 'Enclava'
|
||||
},
|
||||
|
||||
// Default values
|
||||
DEFAULTS: {
|
||||
TEMPERATURE: 0.7,
|
||||
MAX_TOKENS: 1000,
|
||||
TOP_K: 5,
|
||||
MEMORY_LENGTH: 10,
|
||||
},
|
||||
|
||||
// API endpoints
|
||||
ENDPOINTS: {
|
||||
AUTH: {
|
||||
LOGIN: '/api/auth/login',
|
||||
REGISTER: '/api/auth/register',
|
||||
REFRESH: '/api/auth/refresh',
|
||||
ME: '/api/auth/me',
|
||||
},
|
||||
CHATBOT: {
|
||||
LIST: '/api/chatbot/list',
|
||||
CREATE: '/api/chatbot/create',
|
||||
UPDATE: '/api/chatbot/update/:id',
|
||||
DELETE: '/api/chatbot/delete/:id',
|
||||
CHAT: '/api/chatbot/chat',
|
||||
},
|
||||
LLM: {
|
||||
MODELS: '/api/llm/models',
|
||||
API_KEYS: '/api/llm/api-keys',
|
||||
BUDGETS: '/api/llm/budgets',
|
||||
},
|
||||
RAG: {
|
||||
COLLECTIONS: '/api/rag/collections',
|
||||
DOCUMENTS: '/api/rag/documents',
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,138 +1,51 @@
|
||||
import Cookies from 'js-cookie';
|
||||
import { tokenManager } from './token-manager'
|
||||
|
||||
export const downloadFile = async (url: string, filename?: string): Promise<void> => {
|
||||
try {
|
||||
const response = await fetch(url);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
export async function downloadFile(path: string, filename: string, params?: URLSearchParams | Record<string, string>) {
|
||||
const url = new URL(path, typeof window !== 'undefined' ? window.location.origin : 'http://localhost:3000')
|
||||
if (params) {
|
||||
const p = params instanceof URLSearchParams ? params : new URLSearchParams(params)
|
||||
p.forEach((v, k) => url.searchParams.set(k, v))
|
||||
}
|
||||
|
||||
// Get the filename from the response headers if not provided
|
||||
const contentDisposition = response.headers.get('Content-Disposition');
|
||||
let defaultFilename = filename || 'download';
|
||||
const token = await tokenManager.getAccessToken()
|
||||
const res = await fetch(url.toString(), {
|
||||
headers: {
|
||||
...(token ? { Authorization: `Bearer ${token}` } : {}),
|
||||
},
|
||||
})
|
||||
if (!res.ok) throw new Error(`Failed to download file (${res.status})`)
|
||||
const blob = await res.blob()
|
||||
|
||||
if (contentDisposition) {
|
||||
const filenameMatch = contentDisposition.match(/filename[^;=\n]*=((['"]).*?\2|[^;\n]*)/);
|
||||
if (filenameMatch && filenameMatch[1]) {
|
||||
defaultFilename = filenameMatch[1].replace(/['"]/g, '');
|
||||
if (typeof window !== 'undefined') {
|
||||
const link = document.createElement('a')
|
||||
const href = URL.createObjectURL(blob)
|
||||
link.href = href
|
||||
link.download = filename
|
||||
document.body.appendChild(link)
|
||||
link.click()
|
||||
link.remove()
|
||||
URL.revokeObjectURL(href)
|
||||
}
|
||||
}
|
||||
|
||||
export async function uploadFile(path: string, file: File, extraFields?: Record<string, string>) {
|
||||
const form = new FormData()
|
||||
form.append('file', file)
|
||||
if (extraFields) Object.entries(extraFields).forEach(([k, v]) => form.append(k, v))
|
||||
|
||||
const token = await tokenManager.getAccessToken()
|
||||
const res = await fetch(path, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
...(token ? { Authorization: `Bearer ${token}` } : {}),
|
||||
},
|
||||
body: form,
|
||||
})
|
||||
if (!res.ok) {
|
||||
let details: any
|
||||
try { details = await res.json() } catch { details = await res.text() }
|
||||
throw new Error(typeof details === 'string' ? details : (details?.error || 'Upload failed'))
|
||||
}
|
||||
return await res.json()
|
||||
}
|
||||
|
||||
// Get the blob from the response
|
||||
const blob = await response.blob();
|
||||
|
||||
// Create a temporary URL for the blob
|
||||
const blobUrl = window.URL.createObjectURL(blob);
|
||||
|
||||
// Create a temporary link element
|
||||
const link = document.createElement('a');
|
||||
link.href = blobUrl;
|
||||
link.download = defaultFilename;
|
||||
|
||||
// Append the link to the body
|
||||
document.body.appendChild(link);
|
||||
|
||||
// Trigger the download
|
||||
link.click();
|
||||
|
||||
// Clean up
|
||||
document.body.removeChild(link);
|
||||
window.URL.revokeObjectURL(blobUrl);
|
||||
} catch (error) {
|
||||
console.error('Error downloading file:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const downloadFileFromData = (
|
||||
data: Blob | string,
|
||||
filename: string,
|
||||
mimeType?: string
|
||||
): void => {
|
||||
try {
|
||||
let blob: Blob;
|
||||
|
||||
if (typeof data === 'string') {
|
||||
blob = new Blob([data], { type: mimeType || 'text/plain' });
|
||||
} else {
|
||||
blob = data;
|
||||
}
|
||||
|
||||
const blobUrl = window.URL.createObjectURL(blob);
|
||||
const link = document.createElement('a');
|
||||
link.href = blobUrl;
|
||||
link.download = filename;
|
||||
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
|
||||
document.body.removeChild(link);
|
||||
window.URL.revokeObjectURL(blobUrl);
|
||||
} catch (error) {
|
||||
console.error('Error downloading file from data:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const uploadFile = async (
|
||||
file: File,
|
||||
url: string,
|
||||
onProgress?: (progress: number) => void,
|
||||
additionalData?: Record<string, any>
|
||||
): Promise<any> => {
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
|
||||
// Add additional form data if provided
|
||||
if (additionalData) {
|
||||
Object.entries(additionalData).forEach(([key, value]) => {
|
||||
formData.append(key, value);
|
||||
});
|
||||
}
|
||||
|
||||
const xhr = new XMLHttpRequest();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
xhr.upload.onprogress = (event) => {
|
||||
if (event.lengthComputable && onProgress) {
|
||||
const progress = (event.loaded / event.total) * 100;
|
||||
onProgress(progress);
|
||||
}
|
||||
};
|
||||
|
||||
xhr.onload = () => {
|
||||
if (xhr.status >= 200 && xhr.status < 300) {
|
||||
try {
|
||||
const response = JSON.parse(xhr.responseText);
|
||||
resolve(response);
|
||||
} catch (error) {
|
||||
resolve(xhr.responseText);
|
||||
}
|
||||
} else {
|
||||
reject(new Error(`Upload failed with status ${xhr.status}`));
|
||||
}
|
||||
};
|
||||
|
||||
xhr.onerror = () => {
|
||||
reject(new Error('Network error during upload'));
|
||||
};
|
||||
|
||||
// Get authentication token from cookies
|
||||
const token = Cookies.get('access_token');
|
||||
|
||||
xhr.open('POST', url, true);
|
||||
|
||||
// Set the authorization header if token exists
|
||||
if (token) {
|
||||
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
|
||||
}
|
||||
|
||||
xhr.send(formData);
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Error uploading file:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
@@ -1,36 +1,15 @@
|
||||
export function generateId(): string {
|
||||
return Math.random().toString(36).substr(2, 9);
|
||||
export function generateId(prefix = "id"): string {
|
||||
const rand = Math.random().toString(36).slice(2, 10)
|
||||
return `${prefix}_${rand}`
|
||||
}
|
||||
|
||||
export function generateUniqueId(): string {
|
||||
return Date.now().toString(36) + Math.random().toString(36).substr(2);
|
||||
export function generateShortId(prefix = "id"): string {
|
||||
const rand = Math.random().toString(36).slice(2, 7)
|
||||
return `${prefix}_${rand}`
|
||||
}
|
||||
|
||||
export function generateMessageId(): string {
|
||||
return `msg_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
||||
}
|
||||
|
||||
export function generateChatId(): string {
|
||||
return `chat_${Date.now()}_${Math.random().toString(36).substr(2, 5)}`;
|
||||
}
|
||||
|
||||
export function generateSessionId(): string {
|
||||
return `sess_${Date.now()}_${Math.random().toString(36).substr(2, 8)}`;
|
||||
}
|
||||
|
||||
export function generateShortId(): string {
|
||||
return Math.random().toString(36).substr(2, 6);
|
||||
}
|
||||
|
||||
export function generateTimestampId(): string {
|
||||
return `ts_${Date.now()}_${Math.random().toString(36).substr(2, 6)}`;
|
||||
}
|
||||
|
||||
export function isValidId(id: string): boolean {
|
||||
return typeof id === 'string' && id.length > 0;
|
||||
}
|
||||
|
||||
export function extractIdFromUrl(url: string): string | null {
|
||||
const match = url.match(/\/([a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12})$/);
|
||||
return match ? match[1] : null;
|
||||
export function generateTimestampId(prefix = "id"): string {
|
||||
const ts = Date.now()
|
||||
const rand = Math.floor(Math.random() * 1000).toString().padStart(3, '0')
|
||||
return `${prefix}_${ts}_${rand}`
|
||||
}
|
||||
@@ -1,176 +1,31 @@
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
const BACKEND_URL = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`
|
||||
|
||||
// This is a proxy auth utility for server-side API routes
|
||||
// It handles authentication and proxying requests to the backend
|
||||
|
||||
export interface ProxyAuthConfig {
|
||||
backendUrl: string;
|
||||
requireAuth?: boolean;
|
||||
allowedRoles?: string[];
|
||||
function mapPath(path: string): string {
|
||||
// Convert '/api-internal/..' to backend '/api/..'
|
||||
if (path.startsWith('/api-internal/')) {
|
||||
return path.replace('/api-internal/', '/api/')
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
export class ProxyAuth {
|
||||
private config: ProxyAuthConfig;
|
||||
|
||||
constructor(config: ProxyAuthConfig) {
|
||||
this.config = {
|
||||
requireAuth: true,
|
||||
...config,
|
||||
};
|
||||
}
|
||||
|
||||
async authenticate(request: NextRequest): Promise<{ success: boolean; user?: any; error?: string }> {
|
||||
// For server-side auth, we would typically validate the token
|
||||
// This is a simplified implementation
|
||||
const authHeader = request.headers.get('authorization');
|
||||
|
||||
if (!authHeader || !authHeader.startsWith('Bearer ')) {
|
||||
return { success: false, error: 'Missing or invalid authorization header' };
|
||||
}
|
||||
|
||||
const token = authHeader.substring(7);
|
||||
|
||||
// Here you would validate the token with your auth service
|
||||
// For now, we'll just check if it exists
|
||||
if (!token) {
|
||||
return { success: false, error: 'Invalid token' };
|
||||
}
|
||||
|
||||
// In a real implementation, you would decode and validate the JWT
|
||||
// and check user roles if required
|
||||
return {
|
||||
success: true,
|
||||
user: {
|
||||
id: 'user-id',
|
||||
email: 'user@example.com',
|
||||
role: 'user'
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async proxyRequest(
|
||||
request: NextRequest,
|
||||
path: string,
|
||||
options?: {
|
||||
method?: string;
|
||||
headers?: Record<string, string>;
|
||||
body?: any;
|
||||
}
|
||||
): Promise<NextResponse> {
|
||||
// Authenticate the request if required
|
||||
if (this.config.requireAuth) {
|
||||
const authResult = await this.authenticate(request);
|
||||
if (!authResult.success) {
|
||||
return NextResponse.json(
|
||||
{ error: authResult.error || 'Authentication failed' },
|
||||
{ status: 401 }
|
||||
);
|
||||
}
|
||||
|
||||
// Check roles if specified
|
||||
if (this.config.allowedRoles && authResult.user) {
|
||||
if (!this.config.allowedRoles.includes(authResult.user.role)) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Insufficient permissions' },
|
||||
{ status: 403 }
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the target URL
|
||||
const targetUrl = new URL(path, this.config.backendUrl);
|
||||
|
||||
// Copy query parameters
|
||||
targetUrl.search = request.nextUrl.search;
|
||||
|
||||
// Prepare headers
|
||||
export async function proxyRequest(path: string, init?: RequestInit): Promise<Response> {
|
||||
const url = `${BACKEND_URL}${mapPath(path)}`
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
...options?.headers,
|
||||
};
|
||||
|
||||
// Forward authorization header if present
|
||||
const authHeader = request.headers.get('authorization');
|
||||
if (authHeader) {
|
||||
headers.authorization = authHeader;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(targetUrl, {
|
||||
method: options?.method || request.method,
|
||||
headers,
|
||||
body: options?.body ? JSON.stringify(options.body) : request.body,
|
||||
});
|
||||
|
||||
// Create a new response with the data
|
||||
const data = await response.json();
|
||||
|
||||
return NextResponse.json(data, {
|
||||
status: response.status,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Proxy request failed:', error);
|
||||
return NextResponse.json(
|
||||
{ error: 'Internal server error' },
|
||||
{ status: 500 }
|
||||
);
|
||||
}
|
||||
...(init?.headers as Record<string, string> | undefined),
|
||||
}
|
||||
return fetch(url, { ...init, headers })
|
||||
}
|
||||
|
||||
// Utility function to create a proxy handler
|
||||
export function createProxyHandler(config: ProxyAuthConfig) {
|
||||
const proxyAuth = new ProxyAuth(config);
|
||||
|
||||
return async (request: NextRequest, { params }: { params: { path?: string[] } }) => {
|
||||
const path = params?.path ? params.path.join('/') : '';
|
||||
return proxyAuth.proxyRequest(request, path);
|
||||
};
|
||||
export async function handleProxyResponse<T = any>(response: Response, defaultMessage = 'Request failed'): Promise<T> {
|
||||
if (!response.ok) {
|
||||
let details: any
|
||||
try { details = await response.json() } catch { details = await response.text() }
|
||||
throw new Error(typeof details === 'string' ? `${defaultMessage}: ${details}` : (details?.error || defaultMessage))
|
||||
}
|
||||
const contentType = response.headers.get('content-type') || ''
|
||||
if (contentType.includes('application/json')) return (await response.json()) as T
|
||||
// @ts-ignore allow non-json
|
||||
return (await response.text()) as T
|
||||
}
|
||||
|
||||
// Simplified proxy request function for direct usage
|
||||
export async function proxyRequest(
|
||||
request: NextRequest,
|
||||
backendUrl: string,
|
||||
path: string = '',
|
||||
options?: {
|
||||
method?: string;
|
||||
headers?: Record<string, string>;
|
||||
body?: any;
|
||||
requireAuth?: boolean;
|
||||
}
|
||||
): Promise<NextResponse> {
|
||||
const proxyAuth = new ProxyAuth({
|
||||
backendUrl,
|
||||
requireAuth: options?.requireAuth ?? true,
|
||||
});
|
||||
|
||||
return proxyAuth.proxyRequest(request, path, options);
|
||||
}
|
||||
|
||||
// Helper function to handle proxy responses with error handling
|
||||
export async function handleProxyResponse(
|
||||
request: NextRequest,
|
||||
backendUrl: string,
|
||||
path: string = '',
|
||||
options?: {
|
||||
method?: string;
|
||||
headers?: Record<string, string>;
|
||||
body?: any;
|
||||
requireAuth?: boolean;
|
||||
}
|
||||
): Promise<NextResponse> {
|
||||
try {
|
||||
return await proxyRequest(request, backendUrl, path, options);
|
||||
} catch (error) {
|
||||
console.error('Proxy request failed:', error);
|
||||
return NextResponse.json(
|
||||
{ error: 'Internal server error' },
|
||||
{ status: 500 }
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,99 +1,140 @@
|
||||
import Cookies from 'js-cookie';
|
||||
import { EventEmitter } from 'events';
|
||||
type Listener = (...args: any[]) => void
|
||||
|
||||
interface TokenManagerEvents {
|
||||
tokensUpdated: [];
|
||||
tokensCleared: [];
|
||||
class SimpleEmitter {
|
||||
private listeners = new Map<string, Set<Listener>>()
|
||||
|
||||
on(event: string, listener: Listener) {
|
||||
if (!this.listeners.has(event)) this.listeners.set(event, new Set())
|
||||
this.listeners.get(event)!.add(listener)
|
||||
}
|
||||
|
||||
off(event: string, listener: Listener) {
|
||||
this.listeners.get(event)?.delete(listener)
|
||||
}
|
||||
|
||||
emit(event: string, ...args: any[]) {
|
||||
this.listeners.get(event)?.forEach(l => l(...args))
|
||||
}
|
||||
}
|
||||
|
||||
export interface TokenManagerInterface {
|
||||
getTokens(): { access_token: string | null; refresh_token: string | null };
|
||||
setTokens(access_token: string, refresh_token: string): void;
|
||||
clearTokens(): void;
|
||||
isAuthenticated(): boolean;
|
||||
getAccessToken(): string | null;
|
||||
getRefreshToken(): string | null;
|
||||
getTokenExpiry(): { access_token_expiry: number | null; refresh_token_expiry: number | null };
|
||||
on<E extends keyof TokenManagerEvents>(
|
||||
event: E,
|
||||
listener: (...args: TokenManagerEvents[E]) => void
|
||||
): this;
|
||||
off<E extends keyof TokenManagerEvents>(
|
||||
event: E,
|
||||
listener: (...args: TokenManagerEvents[E]) => void
|
||||
): this;
|
||||
interface StoredTokens {
|
||||
access_token: string
|
||||
refresh_token: string
|
||||
access_expires_at: number // epoch ms
|
||||
refresh_expires_at?: number // epoch ms
|
||||
}
|
||||
|
||||
class TokenManager extends EventEmitter implements TokenManagerInterface {
|
||||
private static instance: TokenManager;
|
||||
const ACCESS_LIFETIME_FALLBACK_MS = 30 * 60 * 1000 // 30 minutes
|
||||
const REFRESH_LIFETIME_FALLBACK_MS = 7 * 24 * 60 * 60 * 1000 // 7 days
|
||||
|
||||
private constructor() {
|
||||
super();
|
||||
// Set max listeners to avoid memory leak warnings
|
||||
this.setMaxListeners(100);
|
||||
function now() { return Date.now() }
|
||||
|
||||
function readTokens(): StoredTokens | null {
|
||||
if (typeof window === 'undefined') return null
|
||||
try {
|
||||
const raw = window.localStorage.getItem('auth_tokens')
|
||||
return raw ? JSON.parse(raw) as StoredTokens : null
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
function writeTokens(tokens: StoredTokens | null) {
|
||||
if (typeof window === 'undefined') return
|
||||
if (tokens) {
|
||||
window.localStorage.setItem('auth_tokens', JSON.stringify(tokens))
|
||||
} else {
|
||||
window.localStorage.removeItem('auth_tokens')
|
||||
}
|
||||
}
|
||||
|
||||
class TokenManager extends SimpleEmitter {
  private refreshTimer: ReturnType<typeof setTimeout> | null = null

  isAuthenticated(): boolean {
    const t = readTokens()
    return !!t && t.access_expires_at > now()
  }

  getTokenExpiry(): Date | null {
    const t = readTokens()
    return t ? new Date(t.access_expires_at) : null
  }

  getRefreshTokenExpiry(): Date | null {
    const t = readTokens()
    return t?.refresh_expires_at ? new Date(t.refresh_expires_at) : null
  }

  setTokens(accessToken: string, refreshToken: string, expiresInSeconds?: number) {
    const access_expires_at = now() + (expiresInSeconds ? expiresInSeconds * 1000 : ACCESS_LIFETIME_FALLBACK_MS)
    const refresh_expires_at = now() + REFRESH_LIFETIME_FALLBACK_MS
    const tokens: StoredTokens = {
      access_token: accessToken,
      refresh_token: refreshToken,
      access_expires_at,
      refresh_expires_at,
    }
    writeTokens(tokens)
    this.scheduleRefresh()
    this.emit('tokensUpdated')
  }

  clearTokens() {
    if (this.refreshTimer) {
      clearTimeout(this.refreshTimer)
      this.refreshTimer = null
    }
    writeTokens(null)
    this.emit('tokensCleared')
  }

  logout() {
    this.clearTokens()
    this.emit('logout')
  }

  // Schedules a background refresh roughly one minute before the access token
  // expires (never sooner than 5 seconds from now).
  private scheduleRefresh() {
    if (typeof window === 'undefined') return
    const t = readTokens()
    if (!t) return
    if (this.refreshTimer) clearTimeout(this.refreshTimer)
    const msUntilRefresh = Math.max(5_000, t.access_expires_at - now() - 60_000) // 1 minute before expiry
    this.refreshTimer = setTimeout(() => {
      this.refreshAccessToken().catch(() => {
        this.emit('sessionExpired', 'refresh_failed')
        this.clearTokens()
      })
    }, msUntilRefresh)
  }

  // Returns a valid access token, refreshing it first when it is within
  // ~10 seconds of expiry; resolves to null if no usable session remains.
  async getAccessToken(): Promise<string | null> {
    const t = readTokens()
    if (!t) return null
    if (t.access_expires_at - now() > 10_000) return t.access_token
    try {
      await this.refreshAccessToken()
      return readTokens()?.access_token || null
    } catch {
      this.emit('sessionExpired', 'expired')
      this.clearTokens()
      return null
    }
  }

  private async refreshAccessToken(): Promise<void> {
    const t = readTokens()
    if (!t?.refresh_token) throw new Error('No refresh token')
    const res = await fetch('/api-internal/v1/auth/refresh', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ refresh_token: t.refresh_token }),
    })
    if (!res.ok) throw new Error('Refresh failed')
    const data = await res.json()
    const expiresIn = data.expires_in as number | undefined
    this.setTokens(data.access_token, data.refresh_token || t.refresh_token, expiresIn)
  }
}

// Export a module-level singleton instance
export const tokenManager = new TokenManager()
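A brief consumer-side sketch (not part of the repository) of how UI code might use this singleton. The `apiFetch` helper name and the `@/lib/token-manager` import path are assumptions made for illustration; the chatbot list endpoint path is taken from the test script further below.

// Hypothetical helper built on the tokenManager singleton (illustrative only).
import { tokenManager } from '@/lib/token-manager' // assumed import path

// Attach a fresh access token to an internal API request; getAccessToken()
// refreshes the token transparently when it is close to expiry.
async function apiFetch(
  path: string,
  init: Omit<RequestInit, 'headers'> & { headers?: Record<string, string> } = {}
): Promise<Response> {
  const token = await tokenManager.getAccessToken()
  return fetch(path, {
    ...init,
    headers: {
      ...(init.headers ?? {}),
      ...(token ? { Authorization: `Bearer ${token}` } : {}),
    },
  })
}

// Example: only call protected endpoints while a session is active.
if (tokenManager.isAuthenticated()) {
  void apiFetch('/api-internal/v1/chatbot/list')
}
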
@@ -1,98 +1,8 @@
import { type ClassValue, clsx } from 'clsx'
import { twMerge } from 'tailwind-merge'

export function cn(...inputs: ClassValue[]) {
  return twMerge(clsx(inputs))
}

export function formatDate(date: Date | string): string {
  const d = new Date(date);
  return d.toLocaleDateString('en-US', {
    year: 'numeric',
    month: 'short',
    day: 'numeric',
  });
}

export function formatDateTime(date: Date | string): string {
  const d = new Date(date);
  return d.toLocaleString('en-US', {
    year: 'numeric',
    month: 'short',
    day: 'numeric',
    hour: '2-digit',
    minute: '2-digit',
  });
}

export function formatRelativeTime(date: Date | string): string {
  const d = new Date(date);
  const now = new Date();
  const diffMs = now.getTime() - d.getTime();
  const diffSeconds = Math.floor(diffMs / 1000);
  const diffMinutes = Math.floor(diffSeconds / 60);
  const diffHours = Math.floor(diffMinutes / 60);
  const diffDays = Math.floor(diffHours / 24);

  if (diffSeconds < 60) return 'just now';
  if (diffMinutes < 60) return `${diffMinutes} minute${diffMinutes > 1 ? 's' : ''} ago`;
  if (diffHours < 24) return `${diffHours} hour${diffHours > 1 ? 's' : ''} ago`;
  if (diffDays < 7) return `${diffDays} day${diffDays > 1 ? 's' : ''} ago`;

  return formatDate(d);
}

export function formatBytes(bytes: number, decimals = 2): string {
  if (bytes === 0) return '0 Bytes';

  const k = 1024;
  const dm = decimals < 0 ? 0 : decimals;
  const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB'];

  const i = Math.floor(Math.log(bytes) / Math.log(k));

  return parseFloat((bytes / Math.pow(k, i)).toFixed(dm)) + ' ' + sizes[i];
}

export function formatNumber(num: number): string {
  return new Intl.NumberFormat('en-US').format(num);
}

export function formatCurrency(amount: number, currency = 'USD'): string {
  return new Intl.NumberFormat('en-US', {
    style: 'currency',
    currency,
  }).format(amount);
}

export function truncate(str: string, length: number): string {
  if (str.length <= length) return str;
  return str.slice(0, length) + '...';
}

export function debounce<T extends (...args: any[]) => any>(
  func: T,
  wait: number
): (...args: Parameters<T>) => void {
  let timeout: NodeJS.Timeout;

  return (...args: Parameters<T>) => {
    clearTimeout(timeout);
    timeout = setTimeout(() => func(...args), wait);
  };
}

export function throttle<T extends (...args: any[]) => any>(
  func: T,
  limit: number
): (...args: Parameters<T>) => void {
  let inThrottle: boolean;

  return (...args: Parameters<T>) => {
    if (!inThrottle) {
      func(...args);
      inThrottle = true;
      setTimeout(() => (inThrottle = false), limit);
    }
  };
}

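A small, hypothetical usage sketch for the helpers above (not repository code); the sample values and the search handler are invented for illustration, and the helpers are assumed to be in scope or imported from this utils module.

// Illustrative calls to the formatting and timing helpers above.
const sizeLabel = formatBytes(1536)                                   // "1.5 KB"
const uploadedAgo = formatRelativeTime(new Date(Date.now() - 90_000)) // "1 minute ago"

// Debounce a search handler so it runs only after 300 ms of typing inactivity.
const onSearchInput = debounce((query: string) => {
  console.log(`searching for "${query}"`, sizeLabel, uploadedAgo)
}, 300)

onSearchInput('encl')
onSearchInput('enclava') // only this call's callback fires, ~300 ms later
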
@@ -26,6 +26,12 @@ http {
        listen 80;
        server_name localhost;

        # Static files - serve directly from nginx
        location = /login_helper.html {
            root /usr/share/nginx/html;
            try_files $uri =404;
        }

        # Frontend routes
        location / {
            proxy_pass http://frontend;
@@ -66,6 +72,58 @@ http {
            }
        }

        # RAG debug API routes - proxy to frontend (for Next.js API routes)
        location /api/rag/debug/ {
            proxy_pass http://frontend;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;

            # CORS headers
            add_header 'Access-Control-Allow-Origin' '*' always;
            add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
            add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
            add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;

            # Handle preflight requests
            if ($request_method = 'OPTIONS') {
                add_header 'Access-Control-Allow-Origin' '*';
                add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
                add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization';
                add_header 'Access-Control-Max-Age' 1728000;
                add_header 'Content-Type' 'text/plain; charset=utf-8';
                add_header 'Content-Length' 0;
                return 204;
            }
        }

        # Frontend API routes for authentication - proxy to frontend
        location /api/auth/ {
            proxy_pass http://frontend;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;

            # CORS headers
            add_header 'Access-Control-Allow-Origin' '*' always;
            add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
            add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
            add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;

            # Handle preflight requests
            if ($request_method = 'OPTIONS') {
                add_header 'Access-Control-Allow-Origin' '*';
                add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
                add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization';
                add_header 'Access-Control-Max-Age' 1728000;
                add_header 'Content-Type' 'text/plain; charset=utf-8';
                add_header 'Content-Length' 0;
                return 204;
            }
        }

        # Public API routes - proxy to backend (for external clients)
        location /api/ {
            proxy_pass http://backend;

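For context, a hedged TypeScript sketch of the kind of browser request the CORS directives above are written for: a cross-origin call carrying an Authorization header is "non-simple", so the browser sends an OPTIONS preflight first, which these location blocks answer directly with 204. The host and endpoint path below are placeholders, not values taken from this configuration.

// Hypothetical cross-origin request (illustrative only). The Authorization
// header triggers an OPTIONS preflight, which nginx answers with 204 and the
// Access-Control-* headers configured above before the GET is sent.
async function callAuthApi(token: string): Promise<unknown> {
  const res = await fetch('http://localhost/api/auth/session', { // placeholder host and path
    method: 'GET',
    headers: { Authorization: `Bearer ${token}` },
  })
  if (!res.ok) throw new Error(`Request failed: ${res.status}`)
  return res.json()
}
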
@@ -1,93 +0,0 @@
#!/usr/bin/env python3
"""
Test script to verify debugging endpoints
"""
import requests
import json
import sys
import os

# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from backend.app.core.security import create_access_token
from backend.app.db.database import SessionLocal
from backend.app.models.user import User


def get_auth_token():
    """Get an authentication token for testing"""
    db = SessionLocal()
    try:
        # Get first user (or create one for testing)
        user = db.query(User).first()
        if not user:
            # Create test user if none exists
            user = User(
                email="test@example.com",
                hashed_password="$2b$12$LQv3c1yqBWVHxkd0LHAkCOYz6TtxMQJqhN8/LeZeUfkZMBs9kYZP6"  # password: password
            )
            db.add(user)
            db.commit()
            db.refresh(user)

        # Create JWT token
        token = create_access_token(data={"sub": str(user.id)})
        return token
    finally:
        db.close()


def test_endpoint(url, token, method="GET", data=None):
    """Test an endpoint with authentication"""
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }

    if method == "GET":
        response = requests.get(f"http://localhost:3000{url}", headers=headers)
    elif method == "POST":
        response = requests.post(f"http://localhost:3000{url}", headers=headers, json=data)

    print(f"\n{method} {url}")
    print(f"Status: {response.status_code}")

    if response.status_code == 200:
        print("Response:")
        print(json.dumps(response.json(), indent=2))
    else:
        print("Error:", response.text)

    return response


def main():
    print("=== Testing Debugging Endpoints ===")

    # Get authentication token
    print("\n1. Getting authentication token...")
    token = get_auth_token()
    print(f"Token: {token[:50]}...")

    # Test system status
    print("\n2. Testing system status...")
    test_endpoint("/api-internal/v1/debugging/system/status", token)

    # Test getting chatbot list first
    print("\n3. Getting chatbot list...")
    response = test_endpoint("/api-internal/v1/chatbot/list", token)

    if response.status_code == 200:
        chatbots = response.json()
        if chatbots:
            chatbot_id = chatbots[0]["id"]
            print(f"\n4. Testing chatbot config for: {chatbot_id}")
            test_endpoint(f"/api-internal/v1/debugging/chatbot/{chatbot_id}/config", token)

            print(f"\n5. Testing RAG search for: {chatbot_id}")
            test_endpoint(f"/api-internal/v1/debugging/chatbot/{chatbot_id}/test-rag?query=What is security?", token)
        else:
            print("\n4. No chatbots found to test")
    else:
        print("\n4. Could not get chatbot list")


if __name__ == "__main__":
    main()

166
verify_security_removal.py
Normal file
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
"""
Verification script for security middleware removal
"""
import subprocess
import sys
import time


def run_command(cmd, cwd=None):
    """Run a command and return the result"""
    try:
        result = subprocess.run(
            cmd,
            shell=True,
            capture_output=True,
            text=True,
            cwd=cwd,
            timeout=30
        )
        return result.returncode, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        return -1, "", "Command timed out"


def test_backend_syntax():
    """Test if backend Python files have valid syntax"""
    print("🔍 Testing backend Python syntax...")

    # Check main.py
    code, stdout, stderr = run_command("python3 -m py_compile app/main.py", cwd="backend")
    if code == 0:
        print("✅ main.py syntax OK")
    else:
        print(f"❌ main.py syntax error: {stderr}")
        return False

    # Check security middleware
    code, stdout, stderr = run_command("python3 -m py_compile app/middleware/security.py", cwd="backend")
    if code == 0:
        print("✅ security.py syntax OK")
    else:
        print(f"❌ security.py syntax error: {stderr}")
        return False

    return True


def test_docker_build():
    """Test if Docker can build the backend service"""
    print("\n🐳 Testing Docker backend build...")

    # Just check if the Dockerfile exists and is readable
    try:
        with open("backend/Dockerfile", "r") as f:
            content = f.read()
            if "FROM" in content and "python" in content:
                print("✅ Dockerfile exists and looks valid")
                return True
            else:
                print("❌ Dockerfile appears invalid")
                return False
    except FileNotFoundError:
        print("❌ Dockerfile not found")
        return False


def test_env_settings():
    """Test if environment settings are correct"""
    print("\n⚙️ Testing environment settings...")

    try:
        with open(".env", "r") as f:
            env_content = f.read()

        if "API_SECURITY_ENABLED=false" in env_content:
            print("✅ Security is disabled in .env")
        else:
            print("❌ Security is not disabled in .env")
            return False

        if "API_RATE_LIMITING_ENABLED=false" in env_content:
            print("✅ Rate limiting is disabled in .env")
        else:
            print("❌ Rate limiting is not disabled in .env")
            return False

        return True
    except FileNotFoundError:
        print("❌ .env file not found")
        return False


def test_imports():
    """Test if the main application can be imported without security dependencies"""
    print("\n📦 Testing import dependencies...")

    # Create a minimal test script
    test_script = """
import sys
sys.path.insert(0, 'backend')

try:
    # Test if we can create the app without security middleware
    from app.main import app
    print("✅ App can be imported successfully")
except ImportError as e:
    print(f"❌ Import error: {e}")
    sys.exit(1)
except Exception as e:
    print(f"❌ Other error: {e}")
    sys.exit(1)
"""

    # Save test script
    with open("test_import.py", "w") as f:
        f.write(test_script)

    # Run test (will likely fail due to missing dependencies, but should not fail due to security imports)
    code, stdout, stderr = run_command("python3 test_import.py")

    # Clean up
    import os
    os.remove("test_import.py")

    # We expect this to fail due to missing FastAPI, but not due to security imports
    if "security" in stderr.lower() and "No module named" not in stderr:
        print("❌ Security import errors detected")
        return False
    else:
        print("✅ No security import errors detected")
        return True


def main():
    """Run all verification tests"""
    print("🚀 Starting verification of security middleware removal...\n")

    tests = [
        ("Environment Settings", test_env_settings),
        ("Python Syntax", test_backend_syntax),
        ("Docker Configuration", test_docker_build),
        ("Import Dependencies", test_imports),
    ]

    results = []
    for test_name, test_func in tests:
        print(f"\n--- {test_name} ---")
        result = test_func()
        results.append((test_name, result))

    # Print summary
    print("\n" + "=" * 50)
    print("📊 VERIFICATION SUMMARY")
    print("=" * 50)

    for test_name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{test_name}: {status}")

    all_passed = all(result for _, result in results)

    if all_passed:
        print("\n🎉 All tests passed! Security middleware has been successfully removed.")
    else:
        print("\n⚠️ Some tests failed. Please review the issues above.")

    return all_passed


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)