Merge pull request #4 from aljazceru/redoing-things

Redoing things
This commit is contained in:
2025-10-02 10:54:55 +02:00
committed by GitHub
65 changed files with 7211 additions and 3007 deletions

152
.env Normal file
View File

@@ -0,0 +1,152 @@
# ===================================
# ENCLAVA MINIMAL CONFIGURATION
# ===================================
# Only essential environment variables that CANNOT have defaults
# Other settings should be configurable through the app UI
# ===================================
# INFRASTRUCTURE (Required)
# ===================================
DATABASE_URL=postgresql://enclava_user:enclava_pass@enclava-postgres:5432/enclava_db
REDIS_URL=redis://enclava-redis:6379
# ===================================
# SECURITY CRITICAL (Required)
# ===================================
JWT_SECRET=your-super-secret-jwt-key-here-change-in-production
PRIVATEMODE_API_KEY=dfaea90e-df15-48d4-94ff-5ee243b846bb
# Admin user (created on first startup only)
ADMIN_EMAIL=admin@example.com
ADMIN_PASSWORD=admin123
API_RATE_LIMITING_ENABLED=false
# ===================================
# ADDITIONAL SECURITY SETTINGS (Optional but recommended)
# ===================================
# JWT Algorithm (default: HS256)
# JWT_ALGORITHM=HS256
# Token expiration times (in minutes)
# ACCESS_TOKEN_EXPIRE_MINUTES=30
# REFRESH_TOKEN_EXPIRE_MINUTES=10080
# SESSION_EXPIRE_MINUTES=1440
# API Key prefix (default: en_)
# API_KEY_PREFIX=en_
# Security thresholds (0.0-1.0)
# API_SECURITY_RISK_THRESHOLD=0.8
# API_SECURITY_WARNING_THRESHOLD=0.6
# API_SECURITY_ANOMALY_THRESHOLD=0.7
# IP security (comma-separated for multiple IPs)
# API_BLOCKED_IPS=
# API_ALLOWED_IPS=
# ===================================
# APPLICATION BASE URL (Required - derives all URLs and CORS)
# ===================================
BASE_URL=localhost:80
# Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost
# Backend derives: CORS_ORIGINS=["http://localhost"]
# ===================================
# DOCKER NETWORKING (Required for containers)
# ===================================
BACKEND_INTERNAL_PORT=8000
FRONTEND_INTERNAL_PORT=3000
# Hosts are fixed: enclava-backend, enclava-frontend
# Upstreams derive: enclava-backend:8000, enclava-frontend:3000
# ===================================
# QDRANT (Required for RAG)
# ===================================
QDRANT_HOST=enclava-qdrant
QDRANT_PORT=6333
QDRANT_URL=http://enclava-qdrant:6333
# ===================================
# OPTIONAL PRIVATEMODE SETTINGS (Have defaults)
# ===================================
# PRIVATEMODE_CACHE_MODE=none # Optional: defaults to 'none'
# PRIVATEMODE_CACHE_SALT= # Optional: defaults to empty
# ===================================
# OPTIONAL CONFIGURATION (All have sensible defaults)
# ===================================
# Application Settings
# APP_NAME=Enclava
# APP_DEBUG=false
# APP_LOG_LEVEL=INFO
# APP_HOST=0.0.0.0
# APP_PORT=8000
# Security Features
API_SECURITY_ENABLED=false
# API_THREAT_DETECTION_ENABLED=true
# API_IP_REPUTATION_ENABLED=true
# API_ANOMALY_DETECTION_ENABLED=true
API_RATE_LIMITING_ENABLED=false
# API_SECURITY_HEADERS_ENABLED=true
# Content Security Policy
# API_CSP_HEADER=default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'
# Rate Limiting (requests per minute/hour)
# API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE=300
# API_RATE_LIMIT_AUTHENTICATED_PER_HOUR=5000
# API_RATE_LIMIT_API_KEY_PER_MINUTE=1000
# API_RATE_LIMIT_API_KEY_PER_HOUR=20000
# API_RATE_LIMIT_PREMIUM_PER_MINUTE=5000
# API_RATE_LIMIT_PREMIUM_PER_HOUR=100000
# Request Size Limits (in bytes)
# API_MAX_REQUEST_BODY_SIZE=10485760 # 10MB
# API_MAX_REQUEST_BODY_SIZE_PREMIUM=52428800 # 50MB
# MAX_UPLOAD_SIZE=10485760 # 10MB
# Monitoring
# PROMETHEUS_ENABLED=true
# PROMETHEUS_PORT=9090
# Logging
# LOG_FORMAT=json
# LOG_LEVEL=INFO
# LOG_LLM_PROMPTS=false
# Module Configuration
# MODULES_CONFIG_PATH=config/modules.yaml
# Plugin Configuration
# PLUGINS_DIR=/plugins
# PLUGINS_CONFIG_PATH=config/plugins.yaml
# PLUGIN_REPOSITORY_URL=https://plugins.enclava.com
# PLUGIN_ENCRYPTION_KEY=
# ===================================
# RAG EMBEDDING ENHANCED SETTINGS
# ===================================
# Enhanced embedding service configuration
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=60
RAG_EMBEDDING_BATCH_SIZE=5
RAG_EMBEDDING_RETRY_COUNT=3
RAG_EMBEDDING_RETRY_DELAYS=1,2,4,8,16
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=0.5
# Fallback embedding behavior
RAG_ALLOW_FALLBACK_EMBEDDINGS=true
RAG_WARN_ON_FALLBACK=true
# Processing timeouts (in seconds)
RAG_DOCUMENT_PROCESSING_TIMEOUT=300
RAG_EMBEDDING_GENERATION_TIMEOUT=120
RAG_INDEXING_TIMEOUT=120
# ===================================
# SUMMARY
# ===================================
# Required: DATABASE_URL, REDIS_URL, JWT_SECRET, PRIVATEMODE_API_KEY, ADMIN_EMAIL, ADMIN_PASSWORD, BASE_URL
# Required for RAG: QDRANT_HOST, QDRANT_PORT
# Optional: All other settings have secure defaults
# ===================================

View File

@@ -77,6 +77,16 @@ QDRANT_HOST=enclava-qdrant
QDRANT_PORT=6333
QDRANT_URL=http://enclava-qdrant:6333
# ===================================
# RAG EMBEDDING CONFIGURATION (Optional overrides)
# ===================================
# These control embedding throughput to avoid provider 429s.
# Defaults are conservative; uncomment to override.
# RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=12
# RAG_EMBEDDING_BATCH_SIZE=3
# RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=1.0 # seconds
# RAG_EMBEDDING_DELAY_PER_REQUEST=0.5 # seconds
# ===================================
# OPTIONAL PRIVATEMODE SETTINGS (Have defaults)
# ===================================

68
.gitignore vendored
View File

@@ -1,4 +1,68 @@
.env
backend/.config_encryption_key
# Python
__pycache__/
*.py[cod]
*.pyo
*.pyd
*.env
*.venv
env/
venv/
ENV/
env.bak/
venv.bak/
*.sqlite3
*.db
# FastAPI logs
*.log
# Node.js
node_modules/
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
# Next.js build
frontend/.next/
frontend/out/
frontend/.env.local
frontend/.env.production
frontend/.env.development
backend/storage/
# TypeScript
*.tsbuildinfo
# Coverage reports
htmlcov/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.pyc
*.pyo
*.pyd
.pytest_cache/
backend/.pytest_cache/
backend/.mypy_cache/
.mypy_cache/
*.prof
backend/_to_delete/
backend/__pycache__/
backend/app/core/__pycache__/
backend/app/services/__pycache__/
backend/app/services/llm/__pycache__/
backend/app/services/llm/providers/__pycache__/
backend/app/utils/__pycache__/
backend/modules/rag/__pycache__/
frontend/.next/
frontend/node_modules/
node_modules/
venv/

0
backend/.env Normal file
View File

View File

@@ -17,6 +17,9 @@ RUN DEBIAN_FRONTEND=noninteractive apt-get update && apt-get install -y \
ffmpeg \
&& rm -rf /var/lib/apt/lists/*
# Install CPU-only PyTorch and torchaudio first (faster download)
RUN pip install --no-cache-dir torch==2.5.1+cpu torchaudio==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu -f https://download.pytorch.org/whl/torch_stable.html
# Copy requirements and install Python dependencies
COPY requirements.txt .
COPY tests/requirements-test.txt ./tests/

View File

@@ -12,8 +12,8 @@ from ..v1.audit import router as audit_router
from ..v1.settings import router as settings_router
from ..v1.analytics import router as analytics_router
from ..v1.rag import router as rag_router
from ..rag_debug import router as rag_debug_router
from ..v1.prompt_templates import router as prompt_templates_router
from ..v1.security import router as security_router
from ..v1.plugin_registry import router as plugin_registry_router
from ..v1.platform import router as platform_router
from ..v1.llm_internal import router as llm_internal_router
@@ -53,11 +53,12 @@ internal_api_router.include_router(analytics_router, prefix="/analytics", tags=[
# Include RAG routes (frontend RAG document management)
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])
# Include RAG debug routes (for demo and debugging)
internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])
# Include prompt template routes (frontend prompt template management)
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])
# Include security routes (frontend security settings)
internal_api_router.include_router(security_router, prefix="/security", tags=["internal-security"])
# Include plugin registry routes (frontend plugin management)
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])

View File

@@ -0,0 +1,97 @@
"""
RAG Debug API endpoints for testing and debugging
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Dict, Any, Optional
import logging
from app.core.security import get_current_user
from app.core.config import settings
from app.modules.rag.main import RAGModule
from app.models.user import User
logger = logging.getLogger(__name__)
# Create router
router = APIRouter()
@router.get("/collections")
async def list_collections(
    current_user: User = Depends(get_current_user)
):
    """Return the names of all RAG collections known to Qdrant.

    The response has two keys: ``collections`` (list of collection names)
    and ``count`` (length of that list). Any failure is logged and surfaced
    as an HTTP 500 whose detail is the exception text.
    """
    try:
        # Imported lazily, inside the try, so an import failure is also
        # reported as a 500 by the handler below.
        from app.services.qdrant_stats_service import qdrant_stats_service

        # Same data source as the main RAG API.
        qdrant_stats = await qdrant_stats_service.get_collections_stats()

        names = []
        for entry in qdrant_stats.get("collections", []):
            names.append(entry["name"])

        payload = {"collections": names, "count": len(names)}
        return payload
    except Exception as exc:
        logger.error(f"List collections error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/search")
async def debug_search(
    query: str = Query(..., description="Search query"),
    max_results: int = Query(10, ge=1, le=50, description="Maximum number of results"),
    score_threshold: float = Query(0.3, ge=0.0, le=1.0, description="Minimum score threshold"),
    collection_name: Optional[str] = Query(None, description="Collection name to search"),
    config: Optional[Dict[str, Any]] = None,
    current_user: User = Depends(get_current_user)
):
    """Run a RAG search and report failures in the response body.

    When no collection is named, the first collection reported by the RAG
    module is searched. On success the value returned by
    ``RAGModule.search`` is passed through unchanged. Errors are never
    raised to the client: they are logged and returned as an empty result
    set with the error text under ``debug_info``.
    """
    try:
        # A fresh module instance is built from the application settings
        # for every request.
        rag_module = RAGModule(settings)

        if not collection_name:
            # No explicit target: fall back to the first known collection.
            available = await rag_module.list_collections()
            if not available:
                return {
                    "results": [],
                    "debug_info": {
                        "error": "No collections available",
                        "collections_found": 0
                    },
                    "search_time_ms": 0
                }
            collection_name = available[0]

        search_kwargs = {
            "query": query,
            "max_results": max_results,
            "score_threshold": score_threshold,
            "collection_name": collection_name,
            "config": config or {},
        }
        return await rag_module.search(**search_kwargs)
    except Exception as exc:
        logger.error(f"Debug search error: {exc}")
        failure = {
            "results": [],
            "debug_info": {
                "error": str(exc),
                "query": query,
                "collection_name": collection_name
            },
            "search_time_ms": 0
        }
        return failure

View File

@@ -16,7 +16,6 @@ from .analytics import router as analytics_router
from .rag import router as rag_router
from .chatbot import router as chatbot_router
from .prompt_templates import router as prompt_templates_router
from .security import router as security_router
from .plugin_registry import router as plugin_registry_router
# Create main API router
@@ -61,8 +60,6 @@ api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
# Include prompt template routes
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
# Include security routes
api_router.include_router(security_router, prefix="/security", tags=["security"])
# Include plugin registry routes

View File

@@ -32,12 +32,28 @@ class ChatbotCreateRequest(BaseModel):
use_rag: bool = False
rag_collection: Optional[str] = None
rag_top_k: int = 5
rag_score_threshold: float = 0.02 # Lowered from default 0.3 to allow more results
temperature: float = 0.7
max_tokens: int = 1000
memory_length: int = 10
fallback_responses: List[str] = []
class ChatbotUpdateRequest(BaseModel):
    """Request body for a partial chatbot update.

    Every field is optional and defaults to ``None``. The update handler in
    this file reads the model with ``request.dict(exclude_unset=True)`` and
    merges the result into the chatbot's existing config, so only fields the
    client actually sent are written.
    """
    name: Optional[str] = None
    chatbot_type: Optional[str] = None
    model: Optional[str] = None
    system_prompt: Optional[str] = None
    # RAG settings
    use_rag: Optional[bool] = None
    rag_collection: Optional[str] = None
    rag_top_k: Optional[int] = None
    rag_score_threshold: Optional[float] = None
    # Generation settings
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    memory_length: Optional[int] = None
    fallback_responses: Optional[List[str]] = None
class ChatRequest(BaseModel):
message: str
conversation_id: Optional[str] = None
@@ -190,7 +206,7 @@ async def create_chatbot(
@router.put("/update/{chatbot_id}")
async def update_chatbot(
chatbot_id: str,
request: ChatbotCreateRequest,
request: ChatbotUpdateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
@@ -214,28 +230,23 @@ async def update_chatbot(
if not chatbot:
raise HTTPException(status_code=404, detail="Chatbot not found or access denied")
# Update chatbot configuration
config = {
"name": request.name,
"chatbot_type": request.chatbot_type,
"model": request.model,
"system_prompt": request.system_prompt,
"use_rag": request.use_rag,
"rag_collection": request.rag_collection,
"rag_top_k": request.rag_top_k,
"temperature": request.temperature,
"max_tokens": request.max_tokens,
"memory_length": request.memory_length,
"fallback_responses": request.fallback_responses
}
# Get existing config
existing_config = chatbot.config.copy() if chatbot.config else {}
# Update only the fields that are provided in the request
update_data = request.dict(exclude_unset=True)
# Merge with existing config, preserving unset values
for key, value in update_data.items():
existing_config[key] = value
# Update the chatbot
await db.execute(
update(ChatbotInstance)
.where(ChatbotInstance.id == chatbot_id)
.values(
name=request.name,
config=config,
name=existing_config.get("name", chatbot.name),
config=existing_config,
updated_at=datetime.utcnow()
)
)
@@ -275,7 +286,7 @@ async def chat_with_chatbot(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""Send a message to a chatbot and get a response"""
"""Send a message to a chatbot and get a response (without persisting conversation)"""
user_id = current_user.get("id") if isinstance(current_user, dict) else current_user.id
log_api_request("chat_with_chatbot", {
"user_id": user_id,
@@ -298,42 +309,17 @@ async def chat_with_chatbot(
if not chatbot.is_active:
raise HTTPException(status_code=400, detail="Chatbot is not active")
# Initialize conversation service
conversation_service = ConversationService(db)
# Get or create conversation
conversation = await conversation_service.get_or_create_conversation(
chatbot_id=chatbot_id,
user_id=str(user_id),
conversation_id=request.conversation_id
)
# Add user message to conversation
await conversation_service.add_message(
conversation_id=conversation.id,
role="user",
content=request.message,
metadata={}
)
# Get chatbot module and generate response
try:
chatbot_module = module_manager.modules.get("chatbot")
if not chatbot_module:
raise HTTPException(status_code=500, detail="Chatbot module not available")
# Load conversation history for context
conversation_history = await conversation_service.get_conversation_history(
conversation_id=conversation.id,
limit=chatbot.config.get('memory_length', 10),
include_system=False
)
# Use the chatbot module to generate a response
# Use the chatbot module to generate a response (without persisting)
response_data = await chatbot_module.chat(
chatbot_config=chatbot.config,
message=request.message,
conversation_history=conversation_history,
conversation_history=[], # Empty history for test chat
user_id=str(user_id)
)
@@ -346,19 +332,10 @@ async def chat_with_chatbot(
])
response_content = fallback_responses[0] if fallback_responses else "I'm sorry, I couldn't process your request."
# Save assistant message using conversation service
assistant_message = await conversation_service.add_message(
conversation_id=conversation.id,
role="assistant",
content=response_content,
metadata={},
sources=response_data.get("sources")
)
# Return response without conversation ID (since we're not persisting)
return {
"conversation_id": conversation.id,
"response": response_content,
"timestamp": assistant_message.timestamp.isoformat()
"sources": response_data.get("sources")
}
except HTTPException:

View File

@@ -745,7 +745,6 @@ async def get_llm_metrics(
"total_requests": metrics.total_requests,
"successful_requests": metrics.successful_requests,
"failed_requests": metrics.failed_requests,
"security_blocked_requests": metrics.security_blocked_requests,
"average_latency_ms": metrics.average_latency_ms,
"average_risk_score": metrics.average_risk_score,
"provider_metrics": metrics.provider_metrics,

View File

@@ -493,6 +493,25 @@ async def seed_default_templates(
existing_template.system_prompt = template_data["prompt"]
existing_template.updated_at = datetime.utcnow()
updated_templates.append(type_key)
else:
# Check if any inactive template exists with this type_key
inactive_result = await db.execute(
select(PromptTemplate)
.where(PromptTemplate.type_key == type_key)
.where(PromptTemplate.is_active == False)
)
inactive_template = inactive_result.scalar_one_or_none()
if inactive_template:
# Reactivate the inactive template
inactive_template.is_active = True
inactive_template.name = template_data["name"]
inactive_template.description = template_data["description"]
inactive_template.system_prompt = template_data["prompt"]
inactive_template.is_default = True
inactive_template.version = 1
inactive_template.updated_at = datetime.utcnow()
updated_templates.append(type_key)
else:
# Create new template
new_template = PromptTemplate(

View File

@@ -3,12 +3,14 @@ RAG API Endpoints
Provides REST API for RAG (Retrieval Augmented Generation) operations
"""
from typing import List, Optional
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
import io
import asyncio
from datetime import datetime
from app.db.database import get_db
from app.core.security import get_current_user
@@ -16,6 +18,9 @@ from app.models.user import User
from app.services.rag_service import RAGService
from app.utils.exceptions import APIException
# Import RAG module from module manager
from app.services.module_manager import module_manager
router = APIRouter(tags=["RAG"])
@@ -78,14 +83,25 @@ async def get_collections(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata"""
"""Get all RAG collections - live data directly from Qdrant (source of truth)"""
try:
rag_service = RAGService(db)
collections_data = await rag_service.get_all_collections(skip=skip, limit=limit)
from app.services.qdrant_stats_service import qdrant_stats_service
# Get live stats from Qdrant
stats_data = await qdrant_stats_service.get_collections_stats()
collections = stats_data.get("collections", [])
# Apply pagination
start_idx = skip
end_idx = skip + limit
paginated_collections = collections[start_idx:end_idx]
return {
"success": True,
"collections": collections_data,
"total": len(collections_data)
"collections": paginated_collections,
"total": len(collections),
"total_documents": stats_data.get("total_documents", 0),
"total_size_bytes": stats_data.get("total_size_bytes", 0)
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -116,6 +132,62 @@ async def create_collection(
raise HTTPException(status_code=500, detail=str(e))
@router.get("/stats", response_model=dict)
async def get_rag_stats(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Get overall RAG statistics - live data directly from Qdrant.

    Collection, document and storage totals are read from
    ``qdrant_stats_service.get_collections_stats()``. The number of documents
    still being processed is counted in the relational database; that count
    is best-effort and falls back to 0 if the query fails.

    Returns:
        ``{"success": True, "stats": {...}}`` where ``stats`` holds the
        ``collections``, ``documents``, ``storage`` and ``vectors`` totals
        plus a ``last_updated`` ISO timestamp.

    Raises:
        HTTPException: 500 with the error text if the statistics cannot be
            assembled.
    """
    try:
        from app.services.qdrant_stats_service import qdrant_stats_service

        # Get live stats from Qdrant
        stats_data = await qdrant_stats_service.get_collections_stats()
        total_documents = stats_data.get("total_documents", 0)
        total_size_bytes = stats_data.get("total_size_bytes", 0)

        # Active collections are those holding at least one document
        active_collections = sum(
            1
            for col in stats_data.get("collections", [])
            if col.get("document_count", 0) > 0
        )

        # Count processing documents in the database. Use COUNT(*) so the
        # rows themselves are never loaded just to be counted.
        processing_docs = 0
        try:
            from sqlalchemy import func, select
            from app.models.rag_document import RagDocument, ProcessingStatus

            result = await db.execute(
                select(func.count())
                .select_from(RagDocument)
                .where(RagDocument.status == ProcessingStatus.PROCESSING)
            )
            processing_docs = result.scalar() or 0
        except Exception:
            # Best-effort: keep the endpoint usable, but leave a trace of the
            # failure instead of discarding it silently.
            import logging

            logging.getLogger(__name__).warning(
                "Could not count processing RAG documents; reporting 0",
                exc_info=True,
            )

        return {
            "success": True,
            "stats": {
                "collections": {
                    "total": stats_data.get("total_collections", 0),
                    "active": active_collections
                },
                "documents": {
                    "total": total_documents,
                    "processing": processing_docs,
                    "processed": total_documents  # Indexed documents
                },
                "storage": {
                    "total_size_bytes": total_size_bytes,
                    "total_size_mb": round(total_size_bytes / (1024 * 1024), 2)
                },
                "vectors": {
                    "total": total_documents  # Same as documents for RAG
                },
                "last_updated": datetime.utcnow().isoformat()
            }
        }
    except Exception as e:
        # Chain the cause so the original traceback is kept in server logs
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/collections/{collection_id}", response_model=dict)
async def get_collection(
collection_id: int,
@@ -232,11 +304,55 @@ async def upload_document(
if len(file_content) > 50 * 1024 * 1024: # 50MB limit
raise HTTPException(status_code=400, detail="File too large (max 50MB)")
# Validate file can be read before processing
filename = file.filename or "unknown"
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
try:
# Test file readability based on type
if file_extension == 'jsonl':
# Validate JSONL format - try to parse first few lines
try:
content_str = file_content.decode('utf-8')
lines = content_str.strip().split('\n')[:5] # Check first 5 lines
import json
for i, line in enumerate(lines):
if line.strip(): # Skip empty lines
json.loads(line) # Will raise JSONDecodeError if invalid
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
# Validate text files can be decoded
try:
file_content.decode('utf-8')
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
elif file_extension in ['pdf']:
# For PDF files, just check if it starts with PDF signature
if not file_content.startswith(b'%PDF'):
raise HTTPException(status_code=400, detail="Invalid PDF file format")
elif file_extension in ['docx', 'xlsx', 'pptx']:
# For Office documents, check ZIP signature
if not file_content.startswith(b'PK'):
raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
# For other file types, we'll rely on the document processor
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
rag_service = RAGService(db)
document = await rag_service.upload_document(
collection_id=collection_id,
file_content=file_content,
filename=file.filename or "unknown",
filename=filename,
content_type=file.content_type
)
@@ -362,21 +478,167 @@ async def download_document(
raise HTTPException(status_code=500, detail=str(e))
# Stats Endpoint
@router.get("/stats", response_model=dict)
async def get_rag_stats(
db: AsyncSession = Depends(get_db),
# Debug Endpoints
@router.post("/debug/search")
async def search_with_debug(
query: str,
max_results: int = 10,
score_threshold: float = 0.3,
collection_name: str = None,
config: Dict[str, Any] = None,
current_user: User = Depends(get_current_user)
):
"""Get RAG system statistics"""
) -> Dict[str, Any]:
"""
Enhanced search with comprehensive debug information
"""
# Get RAG module from module manager
rag_module = module_manager.modules.get('rag')
if not rag_module or not rag_module.enabled:
raise HTTPException(status_code=503, detail="RAG module not initialized")
debug_info = {}
start_time = datetime.utcnow()
try:
rag_service = RAGService(db)
stats = await rag_service.get_stats()
# Apply configuration if provided
if config:
# Update RAG config temporarily
original_config = rag_module.config.copy()
rag_module.config.update(config)
# Generate query embedding (with or without prefix)
if config and config.get("use_query_prefix"):
optimized_query = f"query: {query}"
else:
optimized_query = query
query_embedding = await rag_module._generate_embedding(optimized_query)
# Store embedding info for debug
if config and config.get("debug", {}).get("show_embeddings"):
debug_info["query_embedding"] = query_embedding[:10] # First 10 dimensions
debug_info["embedding_dimension"] = len(query_embedding)
debug_info["optimized_query"] = optimized_query
# Perform search
search_start = asyncio.get_event_loop().time()
results = await rag_module.search_documents(
query,
max_results=max_results,
score_threshold=score_threshold,
collection_name=collection_name
)
search_time = (asyncio.get_event_loop().time() - search_start) * 1000
# Calculate score statistics
scores = [r.score for r in results if r.score is not None]
if scores:
import statistics
debug_info["score_stats"] = {
"min": min(scores),
"max": max(scores),
"avg": statistics.mean(scores),
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0
}
# Get collection statistics
try:
from qdrant_client.http.models import Filter
collection_name = collection_name or rag_module.default_collection_name
# Count total documents
count_result = rag_module.qdrant_client.count(
collection_name=collection_name,
count_filter=Filter(must=[])
)
total_points = count_result.count
# Get unique documents and languages
scroll_result = rag_module.qdrant_client.scroll(
collection_name=collection_name,
limit=1000, # Sample for stats
with_payload=True,
with_vectors=False
)
unique_docs = set()
languages = set()
for point in scroll_result[0]:
payload = point.payload or {}
doc_id = payload.get("document_id")
if doc_id:
unique_docs.add(doc_id)
language = payload.get("language")
if language:
languages.add(language)
debug_info["collection_stats"] = {
"total_documents": len(unique_docs),
"total_chunks": total_points,
"languages": sorted(list(languages))
}
except Exception as e:
debug_info["collection_stats_error"] = str(e)
# Enhance results with debug info
enhanced_results = []
for result in results:
enhanced_result = {
"document": {
"id": result.document.id,
"content": result.document.content,
"metadata": result.document.metadata
},
"score": result.score,
"debug_info": {}
}
# Add hybrid search debug info if available
metadata = result.document.metadata or {}
if "_vector_score" in metadata:
enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
if "_bm25_score" in metadata:
enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]
enhanced_results.append(enhanced_result)
# Note: Analytics logging disabled (module not available)
return {
"success": True,
"stats": stats
"results": enhanced_results,
"debug_info": debug_info,
"search_time_ms": search_time,
"timestamp": start_time.isoformat()
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Note: Analytics logging disabled (module not available)
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
finally:
# Restore original config if modified
if config and 'original_config' in locals():
rag_module.config = original_config
@router.get("/debug/config")
async def get_current_config(
    current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
    """Return the live configuration of the RAG module.

    The module is looked up in the module manager under the key ``'rag'``.
    Responds with HTTP 503 when it is missing or not enabled.
    """
    rag = module_manager.modules.get('rag')
    if not (rag and rag.enabled):
        raise HTTPException(status_code=503, detail="RAG module not initialized")

    snapshot: Dict[str, Any] = {}
    snapshot["config"] = rag.config
    snapshot["embedding_model"] = rag.embedding_model
    snapshot["enabled"] = rag.enabled
    snapshot["collections"] = await rag._get_collections_safely()
    return snapshot

View File

@@ -1,251 +0,0 @@
"""
Security API endpoints for monitoring and configuration
"""
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field
from app.core.security import get_current_active_user, RequiresRole
from app.middleware.security import get_security_stats, get_request_auth_level, get_request_risk_score
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
router = APIRouter(tags=["security"])
# Pydantic models for API responses
class SecurityStatsResponse(BaseModel):
    """Security statistics response model

    Built by the ``/stats`` endpoint in this file as
    ``SecurityStatsResponse(**get_security_stats())``, so the field names
    must match the keys of that dictionary.
    """
    # Counters
    total_requests_analyzed: int
    threats_detected: int
    threats_blocked: int
    anomalies_detected: int
    rate_limits_exceeded: int
    avg_analysis_time: float  # NOTE(review): unit is not visible here - confirm
    # Breakdowns keyed by name, with a count per key
    threat_types: Dict[str, int]
    threat_levels: Dict[str, int]
    # NOTE(review): tuple layout is not visible here; presumably (ip, count) - confirm
    top_attacking_ips: List[tuple]
    # Feature flags
    security_enabled: bool
    threat_detection_enabled: bool
    rate_limiting_enabled: bool
class SecurityConfigResponse(BaseModel):
    """Security configuration response model

    Populated by the ``/config`` endpoint in this file from the ``API_*``
    attributes of the application settings object.
    """
    # Feature enablement flags
    security_enabled: bool = Field(description="Overall security system enabled")
    threat_detection_enabled: bool = Field(description="Threat detection analysis enabled")
    rate_limiting_enabled: bool = Field(description="Rate limiting enabled")
    ip_reputation_enabled: bool = Field(description="IP reputation checking enabled")
    anomaly_detection_enabled: bool = Field(description="Anomaly detection enabled")
    security_headers_enabled: bool = Field(description="Security headers enabled")
    # Rate limiting settings (requests per minute, one per auth level)
    unauthenticated_per_minute: int = Field(description="Rate limit for unauthenticated requests per minute")
    authenticated_per_minute: int = Field(description="Rate limit for authenticated users per minute")
    api_key_per_minute: int = Field(description="Rate limit for API key users per minute")
    premium_per_minute: int = Field(description="Rate limit for premium users per minute")
    # Security thresholds
    risk_threshold: float = Field(description="Risk score threshold for blocking requests")
    warning_threshold: float = Field(description="Risk score threshold for warnings")
    anomaly_threshold: float = Field(description="Anomaly severity threshold")
    # IP settings
    blocked_ips: List[str] = Field(description="List of blocked IP addresses")
    allowed_ips: List[str] = Field(description="List of allowed IP addresses (empty = allow all)")
class RateLimitInfoResponse(BaseModel):
    """Rate limit information for current request"""
    # One of the auth-level names listed in the field description
    auth_level: str = Field(description="Authentication level (unauthenticated, authenticated, api_key, premium)")
    # Limit values keyed by window name
    current_limits: Dict[str, int] = Field(description="Current rate limits for this auth level")
    # NOTE(review): Optional but declared without an explicit default; whether
    # the field is required depends on the Pydantic major version - confirm
    remaining_requests: Optional[Dict[str, int]] = Field(description="Estimated remaining requests (if available)")
@router.get("/stats", response_model=SecurityStatsResponse)
async def get_security_statistics(
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Get security system statistics

    Requires admin role. The statistics cover:
    - Request analysis counts
    - Threat detection results
    - Rate limiting enforcement
    - Top attacking IPs
    - Performance metrics

    Any failure is logged and reported as a generic HTTP 500.
    """
    try:
        # Fetching and model construction both stay inside the try so that a
        # malformed stats dictionary also maps to the generic 500 below.
        return SecurityStatsResponse(**get_security_stats())
    except Exception as exc:
        logger.error(f"Error getting security stats: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve security statistics"
        )
@router.get("/config", response_model=SecurityConfigResponse)
async def get_security_config(
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Get current security configuration

    Requires admin role. The response mirrors the application settings:
    - Feature enablement flags
    - Rate limiting thresholds
    - Security thresholds
    - IP allowlists/blocklists
    """
    # Feature enablement flags
    values = {
        "security_enabled": settings.API_SECURITY_ENABLED,
        "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
        "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
        "ip_reputation_enabled": settings.API_IP_REPUTATION_ENABLED,
        "anomaly_detection_enabled": settings.API_ANOMALY_DETECTION_ENABLED,
        "security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED,
    }
    # Per-minute rate limits, one per auth level
    values["unauthenticated_per_minute"] = settings.API_RATE_LIMIT_UNAUTHENTICATED_PER_MINUTE
    values["authenticated_per_minute"] = settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE
    values["api_key_per_minute"] = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE
    values["premium_per_minute"] = settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE
    # Security thresholds
    values["risk_threshold"] = settings.API_SECURITY_RISK_THRESHOLD
    values["warning_threshold"] = settings.API_SECURITY_WARNING_THRESHOLD
    values["anomaly_threshold"] = settings.API_SECURITY_ANOMALY_THRESHOLD
    # IP block/allow lists
    values["blocked_ips"] = settings.API_BLOCKED_IPS
    values["allowed_ips"] = settings.API_ALLOWED_IPS

    return SecurityConfigResponse(**values)
@router.get("/status")
async def get_security_status(
    request: Request,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """
    Describe how the security layer classified the current request.

    Reports the request's authentication level, its risk score (``None``
    unless the score is positive) and the rate limits tied to that level.
    """
    level_name = get_request_auth_level(request)
    score = get_request_risk_score(request)

    from app.core.threat_detection import AuthLevel

    try:
        # An unrecognised level name raises ValueError and falls through to
        # the empty-limits response below.
        level = AuthLevel(level_name)
        from app.core.threat_detection import threat_detection_service
        per_minute, per_hour = threat_detection_service.get_rate_limits(level)
        limits_report = RateLimitInfoResponse(
            auth_level=level_name,
            current_limits={"per_minute": per_minute, "per_hour": per_hour},
            remaining_requests=None,  # remaining quota is not tracked
        )
    except ValueError:
        limits_report = RateLimitInfoResponse(
            auth_level=level_name,
            current_limits={},
            remaining_requests=None,
        )

    return {
        "security_enabled": settings.API_SECURITY_ENABLED,
        "auth_level": level_name,
        "risk_score": round(score, 3) if score > 0 else None,
        "rate_limit_info": limits_report.dict(),
        "security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED,
    }
@router.post("/test")
async def test_security_analysis(
    request: Request,
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Run the security analysis against this very request (admin role required).

    Calls ``analyze_request_security`` with the request and the current user
    and returns the verdict, score and one summary entry per detected threat.
    Any failure is logged and reported as HTTP 500.
    """
    def summarize(threat):
        # Flatten one threat object into JSON-friendly primitives.
        return {
            "type": threat.threat_type,
            "level": threat.level.value,
            "confidence": round(threat.confidence, 3),
            "description": threat.description,
            "mitigation": threat.mitigation
        }

    try:
        from app.middleware.security import analyze_request_security

        result = await analyze_request_security(request, current_user)
        return {
            "analysis_complete": True,
            "is_threat": result.is_threat,
            "risk_score": round(result.risk_score, 3),
            "auth_level": result.auth_level.value,
            "should_block": result.should_block,
            "rate_limit_exceeded": result.rate_limit_exceeded,
            "threat_count": len(result.threats),
            "threats": [summarize(item) for item in result.threats],
            "recommendations": result.recommendations
        }
    except Exception as exc:
        logger.error(f"Error in security analysis test: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to perform security analysis test"
        )
@router.get("/health")
async def security_health_check():
    """
    Report the health of the security subsystem; no authentication dependency.

    The status is "healthy" when security is enabled, the analysed-request
    counter is non-negative and the average analysis time is under one
    second; otherwise "degraded". If gathering the statistics fails, the
    error is logged and "unhealthy" is returned instead of raising.
    """
    try:
        stats = get_security_stats()
        analysed_total = stats.get("total_requests_analyzed", 0)
        avg_seconds = stats.get("avg_analysis_time", 0)

        if settings.API_SECURITY_ENABLED and analysed_total >= 0 and avg_seconds < 1.0:
            state = "healthy"
        else:
            state = "degraded"

        return {
            "status": state,
            "security_enabled": settings.API_SECURITY_ENABLED,
            "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
            "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
            "avg_analysis_time_ms": round(avg_seconds * 1000, 2),
            "total_requests_analyzed": analysed_total
        }
    except Exception as exc:
        logger.error(f"Security health check failed: {exc}")
        return {
            "status": "unhealthy",
            "error": "Security system error",
            "security_enabled": settings.API_SECURITY_ENABLED
        }

View File

@@ -97,7 +97,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
"api": {
# Security Settings
"security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
"threat_detection_enabled": {"value": True, "type": "boolean", "description": "Enable threat detection analysis"},
"rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
"ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
"anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
@@ -112,7 +111,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},
# Security Thresholds
"security_risk_threshold": {"value": 0.8, "type": "float", "description": "Risk score threshold for blocking requests (0.0-1.0)"},
"security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
"anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},
@@ -601,7 +599,6 @@ async def reset_to_defaults(
"api": {
# Security Settings
"security_enabled": {"value": True, "type": "boolean"},
"threat_detection_enabled": {"value": True, "type": "boolean"},
"rate_limiting_enabled": {"value": True, "type": "boolean"},
"ip_reputation_enabled": {"value": True, "type": "boolean"},
"anomaly_detection_enabled": {"value": True, "type": "boolean"},
@@ -616,7 +613,6 @@ async def reset_to_defaults(
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},
# Security Thresholds
"security_risk_threshold": {"value": 0.8, "type": "float"},
"security_warning_threshold": {"value": 0.6, "type": "float"},
"anomaly_threshold": {"value": 0.7, "type": "float"},

View File

@@ -17,6 +17,8 @@ class Settings(BaseSettings):
APP_LOG_LEVEL: str = os.getenv("APP_LOG_LEVEL", "INFO")
APP_HOST: str = os.getenv("APP_HOST", "0.0.0.0")
APP_PORT: int = int(os.getenv("APP_PORT", "8000"))
BACKEND_INTERNAL_PORT: int = int(os.getenv("BACKEND_INTERNAL_PORT", "8000"))
FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000"))
# Detailed logging for LLM interactions
LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true" # Set to True to log prompts and context sent to LLM
@@ -75,44 +77,37 @@ class Settings(BaseSettings):
QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333"))
QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")
# API & Security Settings
API_SECURITY_ENABLED: bool = os.getenv("API_SECURITY_ENABLED", "True").lower() == "true"
API_THREAT_DETECTION_ENABLED: bool = os.getenv("API_THREAT_DETECTION_ENABLED", "True").lower() == "true"
API_IP_REPUTATION_ENABLED: bool = os.getenv("API_IP_REPUTATION_ENABLED", "True").lower() == "true"
API_ANOMALY_DETECTION_ENABLED: bool = os.getenv("API_ANOMALY_DETECTION_ENABLED", "True").lower() == "true"
# Rate Limiting Configuration
API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"
# Authenticated users (JWT token)
API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "300"))
API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "5000"))
# PrivateMode Standard tier limits (organization-level, not per user)
# These are shared across all API keys and users in the organization
PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
PRIVATEMODE_REQUESTS_PER_HOUR: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_HOUR", "1200"))
PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_PROMPT_TOKENS_PER_MINUTE", "20000"))
PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_COMPLETION_TOKENS_PER_MINUTE", "10000"))
# Per-user limits (additional protection on top of organization limits)
API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE", "20")) # Match PrivateMode
API_RATE_LIMIT_AUTHENTICATED_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_AUTHENTICATED_PER_HOUR", "1200"))
# API key users (programmatic access)
API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "1000"))
API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "20000"))
API_RATE_LIMIT_API_KEY_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_MINUTE", "20")) # Match PrivateMode
API_RATE_LIMIT_API_KEY_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_API_KEY_PER_HOUR", "1200"))
# Premium/Enterprise API keys
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "5000"))
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "100000"))
# Security Thresholds
API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8")) # Block requests above this risk score
API_SECURITY_WARNING_THRESHOLD: float = float(os.getenv("API_SECURITY_WARNING_THRESHOLD", "0.6")) # Log warnings above this threshold
API_SECURITY_ANOMALY_THRESHOLD: float = float(os.getenv("API_SECURITY_ANOMALY_THRESHOLD", "0.7")) # Flag anomalies above this threshold
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
# Request Size Limits
API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")) # 10MB
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")) # 50MB for premium
# IP Security
API_BLOCKED_IPS: List[str] = os.getenv("API_BLOCKED_IPS", "").split(",") if os.getenv("API_BLOCKED_IPS") else []
API_ALLOWED_IPS: List[str] = os.getenv("API_ALLOWED_IPS", "").split(",") if os.getenv("API_ALLOWED_IPS") else []
API_IP_REPUTATION_CACHE_TTL: int = int(os.getenv("API_IP_REPUTATION_CACHE_TTL", "3600")) # 1 hour
# Security Headers
API_SECURITY_HEADERS_ENABLED: bool = os.getenv("API_SECURITY_HEADERS_ENABLED", "True").lower() == "true"
API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
# Monitoring
@@ -125,6 +120,19 @@ class Settings(BaseSettings):
# Module configuration
MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")
# RAG Embedding Configuration
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12"))
RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3"))
RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3"))
RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16")
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0"))
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
# Plugin configuration
PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
PLUGINS_CONFIG_PATH: str = os.getenv("PLUGINS_CONFIG_PATH", "config/plugins.yaml")
@@ -137,17 +145,13 @@ class Settings(BaseSettings):
model_config = {
"env_file": ".env",
"case_sensitive": True
"case_sensitive": True,
# Ignore unknown environment variables to avoid validation errors
# when optional/deprecated flags are present in .env
"extra": "ignore",
}
# Global settings instance
settings = Settings()
# Log configuration values for debugging
import logging
logger = logging.getLogger(__name__)
logger.info(f"JWT Configuration loaded:")
logger.info(f"ACCESS_TOKEN_EXPIRE_MINUTES: {settings.ACCESS_TOKEN_EXPIRE_MINUTES}")
logger.info(f"REFRESH_TOKEN_EXPIRE_MINUTES: {settings.REFRESH_TOKEN_EXPIRE_MINUTES}")
logger.info(f"JWT_ALGORITHM: {settings.JWT_ALGORITHM}")

View File

@@ -1,744 +0,0 @@
"""
Core threat detection and security analysis for the platform
"""
import re
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Any, Union
from urllib.parse import unquote
from fastapi import Request
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
class ThreatLevel(Enum):
    """Severity grades attached to a detected threat, mildest first."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
class AuthLevel(Enum):
    """Caller tiers used to select which rate limits apply."""

    AUTHENTICATED = "authenticated"
    API_KEY = "api_key"
    PREMIUM = "premium"
@dataclass
class SecurityThreat:
    """One finding produced while analysing a request."""

    # Short machine-readable label, e.g. "sql_injection" or "ip_blocked".
    threat_type: str
    level: ThreatLevel
    # Detector confidence; the detectors in this module use values in 0..1.
    confidence: float
    description: str
    source_ip: str
    user_agent: Optional[str] = None
    request_path: Optional[str] = None
    # The pattern detectors store the matching regex source here, not the
    # offending request data.
    payload: Optional[str] = None
    # Creation time from datetime.utcnow (a naive datetime).
    timestamp: datetime = field(default_factory=datetime.utcnow)
    mitigation: Optional[str] = None
@dataclass
class SecurityAnalysis:
    """Overall verdict for a single analysed request."""

    is_threat: bool
    # Every individual finding that contributed to the verdict.
    threats: List[SecurityThreat]
    risk_score: float
    recommendations: List[str]
    auth_level: AuthLevel
    rate_limit_exceeded: bool
    # True when the caller is expected to reject the request.
    should_block: bool
    # Creation time from datetime.utcnow (a naive datetime).
    timestamp: datetime = field(default_factory=datetime.utcnow)
@dataclass
class RateLimitInfo:
    """Outcome of one rate-limit check for a client."""

    auth_level: AuthLevel
    # Request counts observed in the minute / hour windows.
    requests_per_minute: int
    requests_per_hour: int
    # Limits the counts were compared against.
    minute_limit: int
    hour_limit: int
    # True when either window had already reached its limit.
    exceeded: bool
@dataclass
class AnomalyDetection:
    """Outcome of a behavioural anomaly check."""

    is_anomaly: bool
    # Label of the anomaly kind, e.g. "request_size"; "none" when clean.
    anomaly_type: str
    severity: float
    # Free-form supporting data for the finding.
    details: Dict[str, Any]
    # Optional reference value and the observed value it was compared with.
    baseline_value: Optional[float] = None
    current_value: Optional[float] = None
class ThreatDetectionService:
"""Core threat detection and security analysis service"""
def __init__(self):
    """Set up counters, detection patterns, rate-limit trackers and IP lists.

    Reads the blocked/allowed IP lists from ``settings`` and compiles every
    regex pattern once via ``_compile_patterns``.
    """
    self.name = "threat_detection"

    # Statistics
    self.stats = {
        'total_requests_analyzed': 0,
        'threats_detected': 0,
        'threats_blocked': 0,
        'anomalies_detected': 0,
        'rate_limits_exceeded': 0,
        'total_analysis_time': 0,
        'threat_types': defaultdict(int),
        'threat_levels': defaultdict(int),
        'attacking_ips': defaultdict(int)
    }

    # Threat detection patterns (compiled case-insensitively below)
    self.sql_injection_patterns = [
        r"(\bunion\b.*\bselect\b)",
        r"(\bselect\b.*\bfrom\b)",
        r"(\binsert\b.*\binto\b)",
        r"(\bupdate\b.*\bset\b)",
        r"(\bdelete\b.*\bfrom\b)",
        r"(\bdrop\b.*\btable\b)",
        r"(\bor\b.*\b1\s*=\s*1\b)",
        r"(\band\b.*\b1\s*=\s*1\b)",
        r"(\bexec\b.*\bxp_\w+)",
        r"(\bsp_\w+)",
        r"(\bsleep\b\s*\(\s*\d+\s*\))",
        r"(\bwaitfor\b.*\bdelay\b)",
        r"(\bbenchmark\b\s*\(\s*\d+)",
        r"(\bload_file\b\s*\()",
        r"(\binto\b.*\boutfile\b)"
    ]
    self.xss_patterns = [
        r"<script[^>]*>.*?</script>",
        r"<iframe[^>]*>.*?</iframe>",
        r"<object[^>]*>.*?</object>",
        r"<embed[^>]*>.*?</embed>",
        r"<link[^>]*>",
        r"<meta[^>]*>",
        r"javascript:",
        r"vbscript:",
        r"on\w+\s*=",
        r"style\s*=.*expression",
        r"style\s*=.*javascript"
    ]
    self.path_traversal_patterns = [
        r"\.\.\/",
        r"\.\.\\",
        r"%2e%2e%2f",
        r"%2e%2e%5c",
        # Dots are escaped: unescaped, ".." matches ANY two characters, so
        # every URL containing an encoded slash was flagged as traversal.
        r"\.\.%2f",
        r"\.\.%5c",
        r"%252e%252e%252f",
        r"%252e%252e%255c"
    ]
    self.command_injection_patterns = [
        r";\s*cat\s+",
        r";\s*ls\s+",
        r";\s*pwd\s*",
        r";\s*whoami\s*",
        r";\s*id\s*",
        r";\s*uname\s*",
        r";\s*ps\s+",
        r";\s*netstat\s+",
        r";\s*wget\s+",
        r";\s*curl\s+",
        r"\|\s*cat\s+",
        r"\|\s*ls\s+",
        r"&&\s*cat\s+",
        r"&&\s*ls\s+"
    ]
    self.suspicious_ua_patterns = [
        r"sqlmap",
        r"nikto",
        r"nmap",
        r"masscan",
        r"zap",
        r"burp",
        r"w3af",
        r"acunetix",
        r"nessus",
        r"openvas",
        r"metasploit"
    ]

    def _new_windows():
        # No maxlen: check_rate_limit prunes by timestamp and counts the
        # remaining entries. A fixed cap (previously 60 / 3600) made any
        # configured limit above the cap impossible to reach, so such
        # limits were silently never enforced.
        return {'minute': deque(), 'hour': deque()}

    # Rate limiting tracking - separate by auth level (excluding unauthenticated since they're blocked)
    self.rate_limits = {
        AuthLevel.AUTHENTICATED: defaultdict(_new_windows),
        AuthLevel.API_KEY: defaultdict(_new_windows),
        AuthLevel.PREMIUM: defaultdict(_new_windows)
    }

    # Anomaly detection
    self.request_history = deque(maxlen=1000)
    self.ip_history = defaultdict(lambda: deque(maxlen=100))
    self.endpoint_history = defaultdict(lambda: deque(maxlen=100))

    # Blocked and allowed IPs. Entries are stripped because the settings
    # value is a plain comma split, so "a, b" would otherwise yield " b",
    # which never equals a client address.
    self.blocked_ips = {ip.strip() for ip in (settings.API_BLOCKED_IPS or []) if ip.strip()}
    allowed = {ip.strip() for ip in (settings.API_ALLOWED_IPS or []) if ip.strip()}
    # None (rather than an empty set) means "no allowlist configured".
    self.allowed_ips = allowed or None

    # IP reputation cache
    self.ip_reputation_cache = {}
    self.cache_expiry = {}

    # Compile patterns for performance
    self._compile_patterns()

    logger.info(f"ThreatDetectionService initialized with {len(self.sql_injection_patterns)} SQL patterns, "
                f"{len(self.xss_patterns)} XSS patterns, rate limiting enabled: {settings.API_RATE_LIMITING_ENABLED}")
def _compile_patterns(self):
"""Compile regex patterns for better performance"""
try:
self.compiled_sql_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.sql_injection_patterns]
self.compiled_xss_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.xss_patterns]
self.compiled_path_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.path_traversal_patterns]
self.compiled_cmd_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.command_injection_patterns]
self.compiled_ua_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.suspicious_ua_patterns]
except re.error as e:
logger.error(f"Failed to compile security patterns: {e}")
# Fallback to empty lists to prevent crashes
self.compiled_sql_patterns = []
self.compiled_xss_patterns = []
self.compiled_path_patterns = []
self.compiled_cmd_patterns = []
self.compiled_ua_patterns = []
def determine_auth_level(self, request: Request, user_context: Optional[Dict] = None) -> AuthLevel:
    """Classify the caller into the tier used for rate limiting.

    Checks, in order: an API-key context on ``request.state`` (PREMIUM for
    'premium'/'enterprise' tiers, otherwise API_KEY), a user context or
    ``request.state.user`` (AUTHENTICATED), then a Bearer ``Authorization``
    or ``X-API-Key`` header (API_KEY). Anything else is AUTHENTICATED.
    """
    state = request.state

    # API-key authentication recorded on the request state
    key_context = getattr(state, 'api_key_context', None)
    if key_context:
        key = key_context.get('api_key')
        is_premium = bool(key) and hasattr(key, 'tier') and key.tier in ['premium', 'enterprise']
        return AuthLevel.PREMIUM if is_premium else AuthLevel.API_KEY

    # JWT authentication
    if user_context or hasattr(state, 'user'):
        return AuthLevel.AUTHENTICATED

    # Credentials present only as headers
    bearer = request.headers.get("Authorization", "")
    key_header = request.headers.get("X-API-Key", "")
    if bearer.startswith("Bearer ") or key_header:
        return AuthLevel.API_KEY

    # Default to authenticated since unauthenticated requests are blocked at middleware
    return AuthLevel.AUTHENTICATED
def get_rate_limits(self, auth_level: AuthLevel) -> Tuple[int, int]:
    """Return ``(per_minute, per_hour)`` limits for *auth_level*.

    Both values are ``float('inf')`` while rate limiting is disabled in
    settings. Levels other than PREMIUM and API_KEY get the AUTHENTICATED
    limits.
    """
    if not settings.API_RATE_LIMITING_ENABLED:
        unlimited = float('inf')
        return unlimited, unlimited

    if auth_level == AuthLevel.PREMIUM:
        return (settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE, settings.API_RATE_LIMIT_PREMIUM_PER_HOUR)

    if auth_level == AuthLevel.API_KEY:
        return (settings.API_RATE_LIMIT_API_KEY_PER_MINUTE, settings.API_RATE_LIMIT_API_KEY_PER_HOUR)

    # AUTHENTICATED, and the fallback for anything unrecognised
    return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
def check_rate_limit(self, client_ip: str, auth_level: AuthLevel) -> RateLimitInfo:
    """Check *client_ip* against the minute/hour limits of *auth_level*.

    Timestamps older than each window are dropped, the remaining entries
    are counted, and the current request is recorded only when neither
    limit has been reached. The returned counts exclude the current request.
    """
    minute_limit, hour_limit = self.get_rate_limits(auth_level)
    now = time.time()

    trackers = self.rate_limits.get(auth_level)
    if trackers is None:
        # This shouldn't happen, but handle gracefully
        return RateLimitInfo(
            auth_level=auth_level,
            requests_per_minute=0,
            requests_per_hour=0,
            minute_limit=minute_limit,
            hour_limit=hour_limit,
            exceeded=False
        )

    windows = trackers[client_ip]

    # Drop timestamps that have aged out of each window
    for window_name, span_seconds in (('minute', 60), ('hour', 3600)):
        cutoff = now - span_seconds
        stamps = windows[window_name]
        while stamps and stamps[0] < cutoff:
            stamps.popleft()

    in_last_minute = len(windows['minute'])
    in_last_hour = len(windows['hour'])
    over_limit = (in_last_minute >= minute_limit) or (in_last_hour >= hour_limit)

    # Only accepted requests are recorded
    if not over_limit:
        windows['minute'].append(now)
        windows['hour'].append(now)

    return RateLimitInfo(
        auth_level=auth_level,
        requests_per_minute=in_last_minute,
        requests_per_hour=in_last_hour,
        minute_limit=minute_limit,
        hour_limit=hour_limit,
        exceeded=over_limit
    )
@staticmethod
def _cached_body_text(request: Request) -> str:
    """Return the body held in the request's private ``_body`` attribute as text.

    Returns "" when the attribute is missing or empty.
    NOTE(review): presumably ``_body`` is only set once something has already
    read the body - confirm against the middleware that calls this service.
    """
    try:
        raw = getattr(request, '_body', None)
        if not raw:
            return ""
        if isinstance(raw, bytes):
            # errors="replace": a strict decode raises on any invalid UTF-8
            # byte, which previously left the body unscanned - a trivial way
            # to bypass every body-based check.
            return raw.decode("utf-8", errors="replace")
        return str(raw)
    except Exception as exc:
        # Best effort: a body that cannot be read must not abort the analysis.
        logger.warning(f"Could not read request body for threat analysis: {exc}")
        return ""

async def analyze_request(self, request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
    """Perform comprehensive security analysis on a request.

    Checks run in this order and the first three return immediately with
    ``should_block=True``: IP allowlist, IP blocklist, rate limit. If threat
    detection is disabled a clean result is returned. Otherwise the pattern
    detectors (and optionally anomaly detection) run, a risk score is
    computed and statistics are updated.

    Args:
        request: The incoming request.
        user_context: Optional user info, used only to determine the auth level.

    Returns:
        The analysis result. If anything raises, the error is logged and a
        non-blocking, no-threat result is returned (fail-open).
    """
    start_time = time.time()
    try:
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "")
        path = str(request.url.path)
        method = request.method

        # Determine authentication level
        auth_level = self.determine_auth_level(request, user_context)

        # Check IP allowlist/blocklist first
        if self.allowed_ips and client_ip not in self.allowed_ips:
            threat = SecurityThreat(
                threat_type="ip_not_allowed",
                level=ThreatLevel.HIGH,
                confidence=1.0,
                description=f"IP {client_ip} not in allowlist",
                source_ip=client_ip,
                mitigation="Add IP to allowlist or remove IP restrictions"
            )
            return SecurityAnalysis(
                is_threat=True,
                threats=[threat],
                risk_score=1.0,
                recommendations=["Block request immediately"],
                auth_level=auth_level,
                rate_limit_exceeded=False,
                should_block=True
            )

        if client_ip in self.blocked_ips:
            threat = SecurityThreat(
                threat_type="ip_blocked",
                level=ThreatLevel.CRITICAL,
                confidence=1.0,
                description=f"IP {client_ip} is blocked",
                source_ip=client_ip,
                mitigation="Remove IP from blocklist if legitimate"
            )
            return SecurityAnalysis(
                is_threat=True,
                threats=[threat],
                risk_score=1.0,
                recommendations=["Block request immediately"],
                auth_level=auth_level,
                rate_limit_exceeded=False,
                should_block=True
            )

        # Check rate limiting
        rate_limit_info = self.check_rate_limit(client_ip, auth_level)
        if rate_limit_info.exceeded:
            self.stats['rate_limits_exceeded'] += 1
            threat = SecurityThreat(
                threat_type="rate_limit_exceeded",
                level=ThreatLevel.MEDIUM,
                confidence=0.9,
                description=f"Rate limit exceeded for {auth_level.value}: {rate_limit_info.requests_per_minute}/min, {rate_limit_info.requests_per_hour}/hr",
                source_ip=client_ip,
                mitigation=f"Implement rate limiting, current limits: {rate_limit_info.minute_limit}/min, {rate_limit_info.hour_limit}/hr"
            )
            return SecurityAnalysis(
                is_threat=True,
                threats=[threat],
                risk_score=0.7,
                recommendations=[f"Rate limit exceeded for {auth_level.value} user"],
                auth_level=auth_level,
                rate_limit_exceeded=True,
                should_block=True
            )

        # Skip threat detection if disabled
        if not settings.API_THREAT_DETECTION_ENABLED:
            return SecurityAnalysis(
                is_threat=False,
                threats=[],
                risk_score=0.0,
                recommendations=[],
                auth_level=auth_level,
                rate_limit_exceeded=False,
                should_block=False
            )

        # Collect request data for threat analysis
        query_params = str(request.query_params)
        headers = dict(request.headers)
        body_content = self._cached_body_text(request)

        threats = []

        # Analyze for various threats
        threats.extend(await self._detect_sql_injection(query_params, body_content, path, client_ip))
        threats.extend(await self._detect_xss(query_params, body_content, headers, client_ip))
        threats.extend(await self._detect_path_traversal(path, query_params, client_ip))
        threats.extend(await self._detect_command_injection(query_params, body_content, client_ip))
        threats.extend(await self._detect_suspicious_patterns(headers, user_agent, path, client_ip))

        # Anomaly detection if enabled
        if settings.API_ANOMALY_DETECTION_ENABLED:
            anomaly = await self._detect_anomalies(client_ip, path, method, len(body_content))
            if anomaly.is_anomaly and anomaly.severity > settings.API_SECURITY_ANOMALY_THRESHOLD:
                threat = SecurityThreat(
                    threat_type=f"anomaly_{anomaly.anomaly_type}",
                    level=ThreatLevel.MEDIUM if anomaly.severity > 0.7 else ThreatLevel.LOW,
                    confidence=anomaly.severity,
                    description=f"Anomalous behavior detected: {anomaly.details}",
                    source_ip=client_ip,
                    user_agent=user_agent,
                    request_path=path
                )
                threats.append(threat)

        # Calculate risk score
        risk_score = self._calculate_risk_score(threats)

        # Determine if request should be blocked
        should_block = risk_score >= settings.API_SECURITY_RISK_THRESHOLD

        # Generate recommendations
        recommendations = self._generate_recommendations(threats, risk_score, auth_level)

        # Update statistics
        self._update_stats(threats, time.time() - start_time)

        return SecurityAnalysis(
            is_threat=len(threats) > 0,
            threats=threats,
            risk_score=risk_score,
            recommendations=recommendations,
            auth_level=auth_level,
            rate_limit_exceeded=False,
            should_block=should_block
        )
    except Exception as e:
        # Fail open: an analysis error must not take the API down.
        logger.error(f"Error in threat analysis: {e}")
        return SecurityAnalysis(
            is_threat=False,
            threats=[],
            risk_score=0.0,
            recommendations=["Error occurred during security analysis"],
            auth_level=AuthLevel.AUTHENTICATED,
            rate_limit_exceeded=False,
            should_block=False
        )
async def _detect_sql_injection(self, query_params: str, body_content: str, path: str, client_ip: str) -> List[SecurityThreat]:
    """Scan query string, body and path for SQL-injection patterns.

    Returns at most one threat: the first compiled SQL pattern that matches
    the lower-cased, space-joined inputs. The threat's ``payload`` holds the
    regex source of that pattern.
    """
    haystack = f"{query_params} {body_content} {path}".lower()
    # First hit wins - further patterns would only duplicate the finding.
    hit = next((rx for rx in self.compiled_sql_patterns if rx.search(haystack)), None)
    if hit is None:
        return []
    return [
        SecurityThreat(
            threat_type="sql_injection",
            level=ThreatLevel.HIGH,
            confidence=0.85,
            description="Potential SQL injection attempt detected",
            source_ip=client_ip,
            payload=hit.pattern,
            mitigation="Block request, sanitize input, use parameterized queries"
        )
    ]
async def _detect_xss(self, query_params: str, body_content: str, headers: dict, client_ip: str) -> List[SecurityThreat]:
    """Scan query string, body and all header values for XSS patterns.

    Returns at most one threat, for the first compiled XSS pattern that
    matches; its ``payload`` is that pattern's regex source.
    """
    # Each piece is lower-cased on its own, then joined once.
    pieces = [f"{query_params} {body_content}".lower()]
    pieces.extend(f" {value}".lower() for value in headers.values())
    haystack = "".join(pieces)

    hit = next((rx for rx in self.compiled_xss_patterns if rx.search(haystack)), None)
    if hit is None:
        return []
    return [
        SecurityThreat(
            threat_type="xss",
            level=ThreatLevel.HIGH,
            confidence=0.80,
            description="Potential XSS attack detected",
            source_ip=client_ip,
            payload=hit.pattern,
            mitigation="Block request, sanitize input, implement CSP headers"
        )
    ]
async def _detect_path_traversal(self, path: str, query_params: str, client_ip: str) -> List[SecurityThreat]:
    """Scan the path and query string for directory-traversal sequences.

    Both the lower-cased text and its URL-decoded form are checked. Returns
    at most one threat, carrying the original ``path`` as ``request_path``.
    """
    raw_text = f"{path} {query_params}".lower()
    decoded_text = unquote(raw_text)

    matched = any(
        rx.search(raw_text) or rx.search(decoded_text)
        for rx in self.compiled_path_patterns
    )
    if not matched:
        return []
    return [
        SecurityThreat(
            threat_type="path_traversal",
            level=ThreatLevel.HIGH,
            confidence=0.90,
            description="Path traversal attempt detected",
            source_ip=client_ip,
            request_path=path,
            mitigation="Block request, validate file paths, implement access controls"
        )
    ]
async def _detect_command_injection(self, query_params: str, body_content: str, client_ip: str) -> List[SecurityThreat]:
    """Scan query string and body for shell command-injection patterns.

    Returns at most one CRITICAL threat, for the first compiled pattern that
    matches; its ``payload`` is that pattern's regex source.
    """
    haystack = f"{query_params} {body_content}".lower()
    hit = next((rx for rx in self.compiled_cmd_patterns if rx.search(haystack)), None)
    if hit is None:
        return []
    return [
        SecurityThreat(
            threat_type="command_injection",
            level=ThreatLevel.CRITICAL,
            confidence=0.95,
            description="Command injection attempt detected",
            source_ip=client_ip,
            payload=hit.pattern,
            mitigation="Block request immediately, sanitize input, disable shell execution"
        )
    ]
async def _detect_suspicious_patterns(self, headers: dict, user_agent: str, path: str, client_ip: str) -> List[SecurityThreat]:
    """Flag scanner-like user agents and conflicting proxy headers.

    Produces up to two threats: one HIGH finding for the first user-agent
    pattern that matches, and one LOW finding when both ``x-forwarded-for``
    and ``x-real-ip`` are present in *headers*. ``path`` is accepted but not
    inspected.
    """
    findings = []

    # Known scanner / attack-tool signatures in the user agent
    agent = user_agent.lower()
    signature = next((rx for rx in self.compiled_ua_patterns if rx.search(agent)), None)
    if signature is not None:
        findings.append(
            SecurityThreat(
                threat_type="suspicious_user_agent",
                level=ThreatLevel.HIGH,
                confidence=0.85,
                description=f"Suspicious user agent detected: {signature.pattern}",
                source_ip=client_ip,
                user_agent=user_agent,
                mitigation="Block request, monitor IP for further activity"
            )
        )

    # Both proxy headers together may indicate header manipulation
    if "x-forwarded-for" in headers and "x-real-ip" in headers:
        findings.append(
            SecurityThreat(
                threat_type="header_manipulation",
                level=ThreatLevel.LOW,
                confidence=0.30,
                description="Potential IP header manipulation detected",
                source_ip=client_ip,
                mitigation="Validate proxy headers, implement IP whitelisting"
            )
        )

    return findings
async def _detect_anomalies(self, client_ip: str, path: str, method: str, body_size: int) -> AnomalyDetection:
    """Detect anomalous behavior patterns for a single request.

    Checks, in order, and returns on the first hit:
      1. body larger than ``settings.API_MAX_REQUEST_BODY_SIZE``,
      2. access to an ``/admin`` or ``/api/admin`` path,
      3. too many requests from ``client_ip`` in the last five minutes.

    Note that the first two checks return before the request is recorded in
    the per-IP history. ``method`` is accepted but not inspected.

    Returns:
        The detection result; ``is_anomaly=False`` with type "none" when
        nothing is found, or type "error" if the check itself failed.
    """
    try:
        # Request size anomaly
        max_size = settings.API_MAX_REQUEST_BODY_SIZE
        if body_size > max_size:
            return AnomalyDetection(
                is_anomaly=True,
                anomaly_type="request_size",
                severity=0.8,
                details={"body_size": body_size, "threshold": max_size},
                current_value=body_size,
                baseline_value=max_size // 10
            )

        # Unusual endpoint access
        if path.startswith("/admin") or path.startswith("/api/admin"):
            return AnomalyDetection(
                is_anomaly=True,
                anomaly_type="sensitive_endpoint",
                severity=0.6,
                details={"path": path, "reason": "admin endpoint access"},
                current_value=1.0,
                baseline_value=0.0
            )

        # IP request frequency anomaly
        window_seconds = 300
        frequency_threshold = 100
        current_time = time.time()
        ip_requests = self.ip_history[client_ip]

        # Clean old entries (last 5 minutes)
        oldest_allowed = current_time - window_seconds
        while ip_requests and ip_requests[0] < oldest_allowed:
            ip_requests.popleft()
        ip_requests.append(current_time)

        # ">=" rather than ">": the per-IP history is a deque capped at 100
        # entries, so its length can never exceed 100 and a strict
        # comparison made this branch unreachable.
        if len(ip_requests) >= frequency_threshold:
            return AnomalyDetection(
                is_anomaly=True,
                anomaly_type="request_frequency",
                severity=0.7,
                details={"requests_5min": len(ip_requests), "threshold": frequency_threshold},
                current_value=len(ip_requests),
                baseline_value=10  # 10 requests baseline
            )

        return AnomalyDetection(
            is_anomaly=False,
            anomaly_type="none",
            severity=0.0,
            details={}
        )
    except Exception as e:
        # Anomaly detection is advisory; report the failure instead of raising.
        logger.error(f"Error in anomaly detection: {e}")
        return AnomalyDetection(
            is_anomaly=False,
            anomaly_type="error",
            severity=0.0,
            details={"error": str(e)}
        )
def _calculate_risk_score(self, threats: List[SecurityThreat]) -> float:
"""Calculate overall risk score based on threats"""
if not threats:
return 0.0
score = 0.0
for threat in threats:
level_multiplier = {
ThreatLevel.LOW: 0.25,
ThreatLevel.MEDIUM: 0.5,
ThreatLevel.HIGH: 0.75,
ThreatLevel.CRITICAL: 1.0
}
score += threat.confidence * level_multiplier.get(threat.level, 0.5)
# Normalize to 0-1 range
return min(score / len(threats), 1.0)
def _generate_recommendations(self, threats: List[SecurityThreat], risk_score: float, auth_level: AuthLevel) -> List[str]:
    """Build the list of human-readable recommendations for one analysis.

    The list starts with at most one entry for the overall risk tier
    (compared against ``settings.API_SECURITY_RISK_THRESHOLD``,
    ``settings.API_SECURITY_WARNING_THRESHOLD`` and a fixed 0.4), followed by
    one entry per recognised threat type that is present. When nothing
    applies, a single "continue monitoring" entry is returned.

    Args:
        threats: Detected threats; only their ``threat_type`` is consulted.
        risk_score: Overall risk score for the request.
        auth_level: Caller's auth level; its ``value`` is embedded in the
            rate-limit message.

    Returns:
        A non-empty list of recommendation strings.
    """
    advice = []

    # Overall risk tier: only the highest matching tier is reported
    if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
        advice.append("CRITICAL: Block this request immediately")
    elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
        advice.append("HIGH: Consider blocking or rate limiting this IP")
    elif risk_score > 0.4:
        advice.append("MEDIUM: Monitor this IP closely")

    # Fixed advice per threat type, emitted in this order
    advice_by_type = (
        ("sql_injection", "Implement parameterized queries and input validation"),
        ("xss", "Implement Content Security Policy (CSP) headers"),
        ("command_injection", "Disable shell execution and validate all inputs"),
        ("path_traversal", "Implement proper file path validation and access controls"),
    )
    seen_types = {item.threat_type for item in threats}
    for threat_type, message in advice_by_type:
        if threat_type in seen_types:
            advice.append(message)

    if "rate_limit_exceeded" in seen_types:
        advice.append(f"Rate limiting active for {auth_level.value} user")

    if len(advice) == 0:
        advice.append("No immediate action required, continue monitoring")
    return advice
def _update_stats(self, threats: List[SecurityThreat], analysis_time: float):
"""Update service statistics"""
self.stats['total_requests_analyzed'] += 1
self.stats['total_analysis_time'] += analysis_time
if threats:
self.stats['threats_detected'] += len(threats)
for threat in threats:
self.stats['threat_types'][threat.threat_type] += 1
self.stats['threat_levels'][threat.level.value] += 1
if threat.source_ip:
self.stats['attacking_ips'][threat.source_ip] += 1
def get_stats(self) -> Dict[str, Any]:
    """Return a snapshot of the service counters plus the feature flags.

    Returns:
        A dict with the raw counters from ``self.stats``, the mean analysis
        time per request (0 when nothing has been analysed yet), plain-dict
        copies of the per-type and per-level tallies, the ten source IPs with
        the most recorded threats as ``(ip, count)`` pairs in descending
        order, and the current values of the security-related settings flags.
    """
    counters = self.stats

    analyzed = counters['total_requests_analyzed']
    if analyzed > 0:
        mean_time = counters['total_analysis_time'] / analyzed
    else:
        mean_time = 0

    # Rank source IPs by threat count, highest first, and keep the top ten
    ranked_ips = sorted(
        counters['attacking_ips'].items(),
        key=lambda pair: pair[1],
        reverse=True,
    )
    top_ten = ranked_ips[:10]

    return {
        "total_requests_analyzed": counters['total_requests_analyzed'],
        "threats_detected": counters['threats_detected'],
        "threats_blocked": counters['threats_blocked'],
        "anomalies_detected": counters['anomalies_detected'],
        "rate_limits_exceeded": counters['rate_limits_exceeded'],
        "avg_analysis_time": mean_time,
        "threat_types": dict(counters['threat_types']),
        "threat_levels": dict(counters['threat_levels']),
        "top_attacking_ips": top_ten,
        "security_enabled": settings.API_SECURITY_ENABLED,
        "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
        "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
    }
# Global threat detection service instance.
# Created once, when this module is first imported; its counters therefore
# accumulate for the lifetime of the process.
threat_detection_service = ThreatDetectionService()

View File

@@ -125,6 +125,14 @@ async def lifespan(app: FastAPI):
# Initialize config manager
await init_config_manager()
# Initialize LLM service (needed by RAG module)
from app.services.llm.service import llm_service
try:
await llm_service.initialize()
logger.info("LLM service initialized successfully")
except Exception as e:
logger.warning(f"LLM service initialization failed: {e}")
# Initialize analytics service
init_analytics_service()
@@ -215,13 +223,9 @@ app.add_middleware(
# Add analytics middleware
setup_analytics_middleware(app)
# Add debugging middleware for detailed request/response logging
from app.middleware.debugging import setup_debugging_middleware
setup_debugging_middleware(app)
# Security middleware disabled - handled externally
# Add security middleware
from app.middleware.security import setup_security_middleware
setup_security_middleware(app, enabled=settings.API_SECURITY_ENABLED)
# Rate limiting middleware disabled - handled externally
# Exception handlers

View File

@@ -1,313 +0,0 @@
"""
Rate limiting middleware
"""
import time
import redis
from typing import Dict, Optional
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
import asyncio
from datetime import datetime, timedelta
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
class RateLimiter:
    """Sliding-window rate limiter using Redis, with an in-memory fallback.

    Every request is recorded as one uniquely named entry stamped with the
    request time; entries older than the window are discarded before the
    remaining ones are counted against the limit.
    """

    # Sequence number mixed into entry names so that requests arriving within
    # the same second are counted separately. Declared on the class so it is
    # also available on instances created without running __init__.
    _request_seq: int = 0

    def __init__(self):
        try:
            self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
            self.redis_client.ping()  # Test connection
            logger.info("Rate limiter initialized with Redis backend")
        except Exception as e:
            logger.warning(f"Redis not available for rate limiting: {e}")
            self.redis_client = None
        # In-memory store used when Redis is unavailable. Created
        # unconditionally so the attribute always exists.
        self.memory_store: Dict[str, Dict[str, float]] = {}

    def _next_member(self, current_time: int) -> str:
        """Return a name for the current request that is unique per request.

        Keying entries by the whole-second timestamp alone makes every request
        in the same second overwrite the previous one, so bursts would be
        counted as a single request.
        """
        self._request_seq += 1
        return f"{current_time}:{time.time_ns()}:{self._request_seq}"

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        identifier: str = "default"
    ) -> tuple[bool, Dict[str, int]]:
        """
        Check if request is within rate limit

        Args:
            key: Rate limiting key (e.g., IP address, API key)
            limit: Maximum number of requests allowed
            window_seconds: Time window in seconds
            identifier: Additional identifier for the rate limit

        Returns:
            Tuple of (is_allowed, headers_dict)
        """
        full_key = f"rate_limit:{identifier}:{key}"
        current_time = int(time.time())
        window_start = current_time - window_seconds

        if self.redis_client:
            return await self._check_redis_rate_limit(
                full_key, limit, window_seconds, current_time, window_start
            )
        return self._check_memory_rate_limit(
            full_key, limit, window_seconds, current_time, window_start
        )

    async def _check_redis_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        current_time: int,
        window_start: int
    ) -> tuple[bool, Dict[str, int]]:
        """Check rate limit using a Redis sorted set scored by request time.

        The current request is added to the set whether or not it is allowed,
        so rejected requests also count towards the window.
        """
        # NOTE(review): the pipeline is executed without ``await``, so this
        # coroutine blocks until Redis replies -- consider redis.asyncio.
        pipe = self.redis_client.pipeline()
        # Remove old entries
        pipe.zremrangebyscore(key, 0, window_start)
        # Count current requests in window (before adding this one)
        pipe.zcard(key)
        # Add current request under a unique member name
        pipe.zadd(key, {self._next_member(current_time): current_time})
        # Set expiration so idle keys disappear on their own
        pipe.expire(key, window_seconds + 1)
        results = pipe.execute()

        current_requests = results[1]

        # Calculate remaining requests and reset time
        remaining = max(0, limit - current_requests - 1)
        reset_time = current_time + window_seconds
        headers = {
            "X-RateLimit-Limit": limit,
            "X-RateLimit-Remaining": remaining,
            "X-RateLimit-Reset": reset_time,
            "X-RateLimit-Window": window_seconds
        }

        is_allowed = current_requests < limit
        if not is_allowed:
            logger.warning(f"Rate limit exceeded for key: {key}")
        return is_allowed, headers

    def _check_memory_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        current_time: int,
        window_start: int
    ) -> tuple[bool, Dict[str, int]]:
        """Check rate limit using in-memory storage.

        Unlike the Redis path, a request is only recorded when it is allowed.
        """
        store = self.memory_store.setdefault(key, {})

        # Clean old entries
        expired = [member for member, stamp in store.items() if stamp < window_start]
        for member in expired:
            del store[member]

        current_requests = len(store)

        # Calculate remaining requests and reset time
        remaining = max(0, limit - current_requests - 1)
        reset_time = current_time + window_seconds
        headers = {
            "X-RateLimit-Limit": limit,
            "X-RateLimit-Remaining": remaining,
            "X-RateLimit-Reset": reset_time,
            "X-RateLimit-Window": window_seconds
        }

        is_allowed = current_requests < limit
        if is_allowed:
            # Add current request under a unique member name
            store[self._next_member(current_time)] = current_time
        else:
            logger.warning(f"Rate limit exceeded for key: {key}")
        return is_allowed, headers
# Global rate limiter instance.
# Constructed at import time, so the Redis connection attempt (and the
# fallback to in-memory storage on failure) happens when this module loads.
rate_limiter = RateLimiter()
async def rate_limit_middleware(request: Request, call_next):
    """
    Rate limiting middleware for FastAPI

    Requests carrying a bearer token or an ``X-API-Key`` header are limited
    per credential (100/minute, 1000/hour); all other requests are limited per
    client IP (20/minute, 100/hour). Both windows are always checked. A
    request over either limit gets a 429 JSON response describing the window
    that was exceeded; allowed requests get the per-minute rate limit headers
    added to their response.

    Args:
        request: Incoming request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response, or a 429 ``JSONResponse``.
    """
    # Skip rate limiting for health checks and API documentation endpoints
    if request.url.path in ("/health", "/", "/api/v1/docs", "/api/v1/openapi.json"):
        return await call_next(request)

    # Get client IP. request.client may be absent; fall back to a placeholder
    # instead of raising AttributeError.
    client_ip = request.client.host if request.client else "unknown"
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for:
        # NOTE(review): this header is client-supplied unless a trusted proxy
        # overwrites it -- confirm the deployment before relying on it.
        client_ip = forwarded_for.split(",")[0].strip()

    # Check for API key in headers
    api_key = None
    auth_header = request.headers.get("Authorization")
    if auth_header and auth_header.startswith("Bearer "):
        api_key = auth_header[7:]
    elif request.headers.get("X-API-Key"):
        api_key = request.headers.get("X-API-Key")

    # Determine rate limiting strategy
    if api_key:
        # API key-based rate limiting.
        # NOTE(review): the raw credential becomes part of the storage key,
        # which the limiter also writes to the log when a limit is exceeded.
        rate_limit_key = f"api_key:{api_key}"
        # Get API key limits from database (simplified - would implement proper lookup)
        limit_per_minute = 100  # Default limit
        limit_per_hour = 1000  # Default limit
    else:
        # IP-based rate limiting for unauthenticated requests,
        # with more restrictive limits
        rate_limit_key = f"ip:{client_ip}"
        limit_per_minute = 20
        limit_per_hour = 100

    # Check per-minute limit
    is_allowed_minute, headers_minute = await rate_limiter.check_rate_limit(
        rate_limit_key, limit_per_minute, 60, "minute"
    )
    # Check per-hour limit
    is_allowed_hour, headers_hour = await rate_limiter.check_rate_limit(
        rate_limit_key, limit_per_hour, 3600, "hour"
    )

    # If rate limit exceeded, return 429
    if not (is_allowed_minute and is_allowed_hour):
        # Report the window that was actually exceeded, so that the limit and
        # reset time sent to the client match the reason for the rejection.
        headers = headers_minute if not is_allowed_minute else headers_hour
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "error": "RATE_LIMIT_EXCEEDED",
                "message": "Rate limit exceeded. Please try again later.",
                "details": {
                    "limit": headers["X-RateLimit-Limit"],
                    "reset_time": headers["X-RateLimit-Reset"]
                }
            },
            headers={k: str(v) for k, v in headers.items()}
        )

    # Continue with request
    response = await call_next(request)

    # Add rate limit headers to response (minute window)
    for key, value in headers_minute.items():
        response.headers[key] = str(value)
    return response
class RateLimitExceeded(HTTPException):
    """HTTP 429 error carrying the exceeded limit and its reset time.

    Args:
        limit: The limit that was exceeded.
        reset_time: Reset time to report; both values are embedded verbatim in
            the exception's ``detail`` message.
    """

    def __init__(self, limit: int, reset_time: int):
        message = f"Rate limit exceeded. Limit: {limit}, Reset: {reset_time}"
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=message,
        )
# Decorator for applying rate limits to specific endpoints
def rate_limit(requests_per_minute: int = 60, requests_per_hour: int = 1000):
    """
    Decorator to apply rate limiting to specific endpoints

    Currently a placeholder: the wrapped coroutine is awaited unchanged and
    neither limit is enforced.

    Args:
        requests_per_minute: Maximum requests per minute (not yet enforced)
        requests_per_hour: Maximum requests per hour (not yet enforced)

    Returns:
        A decorator for ``async`` callables.
    """
    # Local import so this block does not depend on the file's import section.
    import functools

    def decorator(func):
        # functools.wraps keeps the endpoint's name, docstring and signature
        # visible on the wrapper; without it, anything that inspects the
        # decorated callable only sees (*args, **kwargs).
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # This would be implemented to work with FastAPI dependencies
            # For now, this is a placeholder for endpoint-specific rate limiting
            return await func(*args, **kwargs)
        return wrapper
    return decorator
# Helper functions for different rate limiting strategies
async def check_api_key_rate_limit(api_key: str, endpoint: str) -> bool:
    """Return whether this API key may call this endpoint right now.

    Uses fixed default limits (100 requests per 60 seconds) scoped to the
    combination of API key and endpoint; per-key limits are not looked up.

    Args:
        api_key: API key making the request.
        endpoint: Endpoint being called.

    Returns:
        True when the request is within the limit.
    """
    # This would lookup API key specific limits from database
    # For now, using default limits
    scoped_key = f"api_key:{api_key}:endpoint:{endpoint}"
    allowed, _headers = await rate_limiter.check_rate_limit(
        scoped_key,
        limit=100,
        window_seconds=60,
        identifier="endpoint",
    )
    return allowed
async def check_user_rate_limit(user_id: str, action: str) -> bool:
    """Return whether this user may perform this action right now.

    The limit is 50 requests per 60 seconds, scoped to the combination of
    user and action.

    Args:
        user_id: User performing the action.
        action: Name of the action being limited.

    Returns:
        True when the request is within the limit.
    """
    scoped_key = f"user:{user_id}:action:{action}"
    allowed, _headers = await rate_limiter.check_rate_limit(
        scoped_key,
        limit=50,
        window_seconds=60,
        identifier="user_action",
    )
    return allowed
async def apply_burst_protection(key: str) -> bool:
    """Return whether a high-frequency action is within its burst allowance.

    Args:
        key: Rate limiting key identifying the caller or action.

    Returns:
        True when fewer than 10 requests were counted for ``key`` in the
        10-second window.
    """
    # Allow burst of 10 requests in 10 seconds
    allowed, _headers = await rate_limiter.check_rate_limit(
        key,
        limit=10,
        window_seconds=10,
        identifier="burst",
    )
    return allowed

View File

@@ -1,286 +0,0 @@
"""
Security middleware for request/response processing
"""
import json
import time
from typing import Callable, Optional, Dict, Any
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.config import settings
from app.core.logging import get_logger
from app.core.threat_detection import threat_detection_service, SecurityAnalysis
logger = get_logger(__name__)
class SecurityMiddleware(BaseHTTPMiddleware):
    """Security middleware for threat detection and request filtering.

    For every request that is not exempt, the middleware requires that some
    credential is present, runs the threat detection service, blocks the
    request when the analysis says so, and decorates the response with
    security headers. Failures inside the analysis are logged and the request
    is let through (fail open).
    """

    # Exact paths exempt from the credential check and threat analysis:
    # health checks, API docs, authentication endpoints and the root.
    _SKIP_PATHS = frozenset({
        "/health",
        "/metrics",
        "/api/v1/docs",
        "/api/v1/openapi.json",
        "/api/v1/redoc",
        "/favicon.ico",
        "/api/v1/auth/register",
        "/api/v1/auth/login",
        "/api/v1/auth/refresh",  # Allow refresh endpoint
        "/api-internal/v1/auth/register",
        "/api-internal/v1/auth/login",
        "/api-internal/v1/auth/refresh",  # Allow refresh endpoint for internal API
        "/",  # Root endpoint
    })

    # Static file extensions that are exempt as well
    _STATIC_EXTENSIONS = (".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".woff", ".woff2")

    def __init__(self, app, enabled: bool = True):
        super().__init__(app)
        self.enabled = enabled and settings.API_SECURITY_ENABLED
        logger.info(f"SecurityMiddleware initialized, enabled: {self.enabled}")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request through security analysis.

        Returns a 401 response when no credential is present, the block
        response when the analysis demands it, and otherwise the downstream
        response with security headers (and, in debug mode, metrics) added.
        """
        if not self.enabled:
            # Security disabled, pass through
            return await call_next(request)

        # Skip security analysis for certain endpoints
        if self._should_skip_security(request):
            response = await call_next(request)
            return self._add_security_headers(response, request)

        # Simple authentication check - drop requests without any credential
        if not self._has_valid_auth(request):
            return JSONResponse(
                content={"error": "Authentication required", "message": "Valid API key or authentication token required"},
                status_code=401,
                headers={"WWW-Authenticate": "Bearer"}
            )

        client_host = request.client.host if request.client else 'unknown'
        analysis = None
        analysis_time = 0.0

        # Only the analysis itself is guarded. The downstream call is made
        # exactly once, outside this block, so an exception raised by the
        # application is never answered by running the request a second time.
        try:
            # Get user context if available
            user_context = getattr(request.state, 'user', None)

            # Perform security analysis
            start_time = time.time()
            analysis = await threat_detection_service.analyze_request(request, user_context)
            analysis_time = time.time() - start_time

            # Store analysis in request state for later use
            request.state.security_analysis = analysis

            # Log security events (only for significant threats to reduce false positive noise)
            # Only log if: being blocked OR risk score above warning threshold
            if analysis.is_threat and (analysis.should_block or analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD):
                await self._log_security_event(request, analysis)

            # Check if request should be blocked
            if analysis.should_block:
                threat_detection_service.stats['threats_blocked'] += 1
                logger.warning(f"Blocked request from {client_host}: "
                               f"risk_score={analysis.risk_score:.3f}, threats={len(analysis.threats)}")
                # Return security block response
                return self._create_block_response(analysis)

            # Log warnings for requests at or above the warning threshold
            if analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
                logger.warning(f"High-risk request detected from {client_host}: "
                               f"risk_score={analysis.risk_score:.3f}, auth_level={analysis.auth_level.value}")
        except Exception as e:
            logger.error(f"Security middleware error: {e}")
            # Fail open: continue with the request on security middleware
            # errors to avoid breaking the app, without attaching metrics.
            analysis = None

        # Continue with request processing
        response = await call_next(request)

        # Add security headers and metrics; decoration is best effort and must
        # not turn a successful response into an error.
        try:
            response = self._add_security_headers(response, request)
            if analysis is not None:
                response = self._add_security_metrics(response, analysis, analysis_time)
        except Exception as e:
            logger.error(f"Security middleware error: {e}")
        return response

    def _should_skip_security(self, request: Request) -> bool:
        """Determine if security analysis should be skipped for this request.

        True for the exempt paths, for paths ending in a static file
        extension, and for anything under ``/static/``.
        """
        # NOTE(review): the extension match applies to any path, so a route
        # whose path ends in e.g. ".js" skips both the credential check and
        # the analysis -- confirm that no API route can match.
        path = request.url.path
        return (
            path in self._SKIP_PATHS
            or path.endswith(self._STATIC_EXTENSIONS)
            or path.startswith("/static/")
        )

    def _has_valid_auth(self, request: Request) -> bool:
        """Check whether the request carries a credential.

        Only presence is checked: a non-empty ``Bearer`` token in the
        ``Authorization`` header or a non-blank ``X-API-Key`` header. The
        credential itself is not verified here.
        """
        auth_header = request.headers.get("Authorization", "")
        api_key_header = request.headers.get("X-API-Key", "")

        has_bearer = auth_header.startswith("Bearer ") and len(auth_header) > 7
        has_api_key = len(api_key_header.strip()) > 0
        return has_bearer or has_api_key

    def _create_block_response(self, analysis: SecurityAnalysis) -> JSONResponse:
        """Create response for blocked requests.

        The status is 429 when the rate limit was exceeded and no injection
        threat is present, otherwise 403.
        """
        # Forbidden by default; rate limiting gets 429
        status_code = 429 if analysis.rate_limit_exceeded else 403
        # Injection attempts always get 403, even when also rate limited
        if any(threat.threat_type in ("command_injection", "sql_injection") for threat in analysis.threats):
            status_code = 403

        response_data = {
            "error": "Security Policy Violation",
            "message": "Request blocked due to security policy violation",
            "risk_score": round(analysis.risk_score, 3),
            "auth_level": analysis.auth_level.value,
            "threat_count": len(analysis.threats),
            "recommendations": analysis.recommendations[:3]  # Limit to first 3 recommendations
        }

        # Add rate limiting info if applicable
        if analysis.rate_limit_exceeded:
            response_data["error"] = "Rate Limit Exceeded"
            response_data["message"] = f"Rate limit exceeded for {analysis.auth_level.value} user"
            response_data["retry_after"] = "60"  # Suggest retry after 60 seconds

        response = JSONResponse(
            content=response_data,
            status_code=status_code
        )

        # Add rate limiting headers
        if analysis.rate_limit_exceeded:
            response.headers["Retry-After"] = "60"
            response.headers["X-RateLimit-Limit"] = "See API documentation"
            response.headers["X-RateLimit-Reset"] = str(int(time.time() + 60))
        return response

    def _add_security_headers(self, response: Response, request: Optional[Request] = None) -> Response:
        """Add security headers to response.

        Args:
            response: Response to decorate in place.
            request: The request being answered. When given, it is used to
                decide whether the connection is HTTPS for the HSTS header.

        Returns:
            The same response object.
        """
        if not settings.API_SECURITY_HEADERS_ENABLED:
            return response

        # Standard security headers
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        # Only add HSTS for HTTPS
        if self._is_https(request, response):
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        # Content Security Policy
        if settings.API_CSP_HEADER:
            response.headers["Content-Security-Policy"] = settings.API_CSP_HEADER
        return response

    @staticmethod
    def _is_https(request: Optional[Request], response: Response) -> bool:
        """Tell whether the client reached the service over HTTPS.

        ``X-Forwarded-Proto`` is a header a proxy puts on the request, so the
        request is consulted first (forwarded protocol, then the URL scheme).
        The response header is still checked afterwards so that callers
        passing only a response keep the earlier behaviour.
        """
        if request is not None:
            if request.headers.get("X-Forwarded-Proto") == "https":
                return True
            if request.url.scheme == "https":
                return True
        return response.headers.get("X-Forwarded-Proto") == "https"

    def _add_security_metrics(self, response: Response, analysis: SecurityAnalysis, analysis_time: float) -> Response:
        """Add security metrics to response headers (for debugging/monitoring).

        Headers are only added when ``settings.APP_DEBUG`` is set.
        """
        if settings.APP_DEBUG:
            response.headers["X-Security-Risk-Score"] = str(round(analysis.risk_score, 3))
            response.headers["X-Security-Threats"] = str(len(analysis.threats))
            response.headers["X-Security-Auth-Level"] = analysis.auth_level.value
            response.headers["X-Security-Analysis-Time"] = f"{analysis_time*1000:.1f}ms"
        return response

    async def _log_security_event(self, request: Request, analysis: SecurityAnalysis):
        """Log security events for audit and monitoring.

        Emits one JSON-encoded log line describing the request and up to five
        of its threats, tagged by outcome.
        """
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "")

        # Create security event log
        event_data = {
            "timestamp": analysis.timestamp.isoformat(),
            "client_ip": client_ip,
            "user_agent": user_agent,
            "path": str(request.url.path),
            "method": request.method,
            "risk_score": round(analysis.risk_score, 3),
            "auth_level": analysis.auth_level.value,
            "threat_count": len(analysis.threats),
            "rate_limit_exceeded": analysis.rate_limit_exceeded,
            "should_block": analysis.should_block,
            "threats": [
                {
                    "type": threat.threat_type,
                    "level": threat.level.value,
                    "confidence": round(threat.confidence, 3),
                    "description": threat.description
                }
                for threat in analysis.threats[:5]  # Limit to first 5 threats
            ],
            "recommendations": analysis.recommendations
        }

        # Log at appropriate level based on risk
        if analysis.should_block:
            logger.warning(f"SECURITY_BLOCK: {json.dumps(event_data)}")
        elif analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
            logger.warning(f"SECURITY_WARNING: {json.dumps(event_data)}")
        else:
            logger.info(f"SECURITY_THREAT: {json.dumps(event_data)}")
def setup_security_middleware(app, enabled: bool = True) -> None:
    """Register ``SecurityMiddleware`` on the app when security is switched on.

    The middleware is added only when both ``enabled`` and
    ``settings.API_SECURITY_ENABLED`` are truthy; either way the outcome is
    logged.

    Args:
        app: Application exposing ``add_middleware``.
        enabled: Caller-side switch, also forwarded to the middleware.
    """
    if not (enabled and settings.API_SECURITY_ENABLED):
        logger.info("Security middleware disabled")
        return

    app.add_middleware(SecurityMiddleware, enabled=enabled)
    logger.info("Security middleware enabled")
# Helper functions for manual security checks
async def analyze_request_security(request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
    """Manually analyze request security (for use in route handlers).

    Thin wrapper that forwards both arguments to
    ``threat_detection_service.analyze_request`` and returns its result.

    Args:
        request: Request to analyse.
        user_context: Optional user information passed through unchanged.
    """
    return await threat_detection_service.analyze_request(request, user_context)
def get_security_stats() -> Dict[str, Any]:
    """Get security statistics.

    Returns whatever ``threat_detection_service.get_stats()`` reports.
    """
    return threat_detection_service.get_stats()
def is_request_blocked(request: Request) -> bool:
    """Report whether the stored security analysis marked the request as blocked.

    Args:
        request: Request whose ``state`` may hold a ``security_analysis``.

    Returns:
        The analysis' ``should_block`` value, or False when no analysis has
        been stored on the request.
    """
    state = request.state
    try:
        analysis = state.security_analysis
    except AttributeError:
        # No analysis was attached to this request
        return False
    return analysis.should_block
def get_request_risk_score(request: Request) -> float:
    """Return the risk score recorded for the request.

    Args:
        request: Request whose ``state`` may hold a ``security_analysis``.

    Returns:
        The analysis' ``risk_score``, or 0.0 when no analysis has been stored
        on the request.
    """
    state = request.state
    try:
        analysis = state.security_analysis
    except AttributeError:
        # No analysis was attached to this request
        return 0.0
    return analysis.risk_score
def get_request_auth_level(request: Request) -> str:
    """Return the authentication level recorded for the request.

    Args:
        request: Request whose ``state`` may hold a ``security_analysis``.

    Returns:
        The ``value`` of the analysis' ``auth_level``, or ``"unknown"`` when
        no analysis has been stored on the request.
    """
    state = request.state
    try:
        analysis = state.security_analysis
    except AttributeError:
        # No analysis was attached to this request
        return "unknown"
    return analysis.auth_level.value

View File

@@ -0,0 +1,21 @@
"""
Chatbot Module - AI Chatbot with RAG Integration
This module provides AI chatbot capabilities with:
- Multiple personality types (Assistant, Customer Support, Teacher, etc.)
- RAG integration for knowledge-based responses
- Conversation memory and context management
- Workflow integration as building blocks
- UI-configurable settings
"""
from .main import ChatbotModule, create_module
__version__ = "1.0.0"
__author__ = "Enclava Team"
# Export main classes for easy importing
__all__ = [
"ChatbotModule",
"create_module"
]

View File

@@ -0,0 +1,126 @@
{
"title": "Chatbot Configuration",
"type": "object",
"properties": {
"name": {
"type": "string",
"title": "Chatbot Name",
"description": "Display name for this chatbot instance",
"minLength": 1,
"maxLength": 100
},
"chatbot_type": {
"type": "string",
"title": "Chatbot Type",
"description": "Select the type of chatbot personality",
"enum": ["assistant", "customer_support", "teacher", "researcher", "creative_writer", "custom"],
"enumNames": ["General Assistant", "Customer Support", "Teacher", "Researcher", "Creative Writer", "Custom"],
"default": "assistant"
},
"model": {
"type": "string",
"title": "AI Model",
"description": "Choose the LLM model for responses",
"enum": ["gpt-4", "gpt-3.5-turbo", "claude-3-sonnet", "claude-3-opus", "llama-70b"],
"default": "gpt-3.5-turbo"
},
"system_prompt": {
"type": "string",
"title": "System Prompt",
"description": "Define the chatbot's personality and behavior instructions",
"ui:widget": "textarea",
"ui:options": {
"rows": 6,
"placeholder": "You are a helpful AI assistant..."
}
},
"use_rag": {
"type": "boolean",
"title": "Enable Knowledge Base",
"description": "Use RAG to search knowledge base for context",
"default": false
},
"rag_collection": {
"type": "string",
"title": "Knowledge Base Collection",
"description": "Select which document collection to search",
"ui:widget": "rag-collection-selector",
"ui:condition": "use_rag === true"
},
"rag_top_k": {
"type": "integer",
"title": "Knowledge Base Results",
"description": "Number of relevant documents to include",
"minimum": 1,
"maximum": 10,
"default": 5,
"ui:condition": "use_rag === true"
},
"temperature": {
"type": "number",
"title": "Response Creativity",
"description": "Controls randomness (0.0 = focused, 1.0 = creative)",
"minimum": 0,
"maximum": 1,
"default": 0.7,
"ui:widget": "range",
"ui:options": {
"step": 0.1
}
},
"max_tokens": {
"type": "integer",
"title": "Maximum Response Length",
"description": "Maximum number of tokens in response",
"minimum": 50,
"maximum": 4000,
"default": 1000,
"ui:widget": "range",
"ui:options": {
"step": 50
}
},
"memory_length": {
"type": "integer",
"title": "Conversation Memory",
"description": "Number of previous message pairs to remember",
"minimum": 1,
"maximum": 50,
"default": 10,
"ui:widget": "range"
},
"fallback_responses": {
"type": "array",
"title": "Fallback Responses",
"description": "Responses to use when the AI cannot answer",
"items": {
"type": "string",
"title": "Fallback Response"
},
"default": [
"I'm not sure how to help with that. Could you please rephrase your question?",
"I don't have enough information to answer that question accurately.",
"That's outside my knowledge area. Is there something else I can help you with?"
],
"ui:options": {
"orderable": true,
"addable": true,
"removable": true
}
}
},
"required": ["name", "chatbot_type", "model"],
"ui:order": [
"name",
"chatbot_type",
"model",
"system_prompt",
"use_rag",
"rag_collection",
"rag_top_k",
"temperature",
"max_tokens",
"memory_length",
"fallback_responses"
]
}

View File

@@ -0,0 +1,182 @@
{
"name": "Customer Support Workflow",
"description": "Intelligent customer support workflow with intent classification, knowledge base search, and chatbot response generation",
"version": "1.0",
"variables": {
"support_chatbot_id": "cs-bot-001",
"escalation_threshold": 0.3,
"max_attempts": 3
},
"steps": [
{
"id": "classify_intent",
"name": "Classify Customer Intent",
"type": "llm_call",
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "You are an intent classifier for customer support. Classify the customer message into one of these categories: technical_issue, billing_question, feature_request, complaint, general_inquiry. Also provide a confidence score between 0 and 1. Respond with JSON: {\"intent\": \"category\", \"confidence\": 0.95, \"reasoning\": \"explanation\"}"
},
{
"role": "user",
"content": "{{ inputs.customer_message }}"
}
],
"output_variable": "intent_classification"
},
{
"id": "search_knowledge_base",
"name": "Search Knowledge Base",
"type": "workflow_step",
"module": "rag",
"action": "search",
"config": {
"query": "{{ inputs.customer_message }}",
"collection": "support_documentation",
"top_k": 5,
"include_metadata": true
},
"output_variable": "knowledge_results"
},
{
"id": "check_confidence",
"name": "Check Intent Confidence",
"type": "condition",
"condition": "JSON.parse(steps.classify_intent.result).confidence > variables.escalation_threshold",
"true_steps": [
{
"id": "generate_chatbot_response",
"name": "Generate Chatbot Response",
"type": "workflow_step",
"module": "chatbot",
"action": "workflow_chat_step",
"config": {
"message": "{{ inputs.customer_message }}",
"chatbot_id": "{{ variables.support_chatbot_id }}",
"use_rag": true,
"context": {
"intent": "{{ steps.classify_intent.result }}",
"knowledge_base_results": "{{ steps.search_knowledge_base.result }}",
"customer_history": "{{ inputs.customer_history }}",
"additional_instructions": "Be empathetic and professional. If you cannot fully resolve the issue, offer to escalate to a human agent."
}
},
"output_variable": "chatbot_response"
},
{
"id": "analyze_response_quality",
"name": "Analyze Response Quality",
"type": "llm_call",
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "Analyze if this customer support response adequately addresses the customer's question. Consider completeness, accuracy, and helpfulness. Respond with JSON: {\"quality_score\": 0.85, \"is_adequate\": true, \"requires_escalation\": false, \"reasoning\": \"explanation\"}"
},
{
"role": "user",
"content": "Customer Question: {{ inputs.customer_message }}\\n\\nChatbot Response: {{ steps.generate_chatbot_response.result.response }}\\n\\nKnowledge Base Context: {{ steps.search_knowledge_base.result }}"
}
],
"output_variable": "response_quality"
},
{
"id": "final_response_decision",
"name": "Final Response Decision",
"type": "condition",
"condition": "JSON.parse(steps.analyze_response_quality.result).is_adequate === true",
"true_steps": [
{
"id": "send_chatbot_response",
"name": "Send Chatbot Response",
"type": "output",
"config": {
"response_type": "chatbot_response",
"message": "{{ steps.generate_chatbot_response.result.response }}",
"sources": "{{ steps.generate_chatbot_response.result.sources }}",
"confidence": "{{ JSON.parse(steps.classify_intent.result).confidence }}",
"quality_score": "{{ JSON.parse(steps.analyze_response_quality.result).quality_score }}"
}
}
],
"false_steps": [
{
"id": "escalate_to_human",
"name": "Escalate to Human Agent",
"type": "output",
"config": {
"response_type": "human_escalation",
"message": "I'd like to connect you with one of our human support agents who can better assist with your specific situation. Please hold on while I transfer you.",
"escalation_reason": "Response quality below threshold",
"intent": "{{ steps.classify_intent.result }}",
"attempted_response": "{{ steps.generate_chatbot_response.result.response }}",
"priority": "normal"
}
}
]
}
],
"false_steps": [
{
"id": "low_confidence_escalation",
"name": "Low Confidence Escalation",
"type": "output",
"config": {
"response_type": "human_escalation",
"message": "I want to make sure you get the best possible help. Let me connect you with one of our human support agents.",
"escalation_reason": "Low intent classification confidence",
"intent": "{{ steps.classify_intent.result }}",
"priority": "high"
}
}
]
},
{
"id": "log_interaction",
"name": "Log Customer Interaction",
"type": "workflow_step",
"module": "analytics",
"action": "log_event",
"config": {
"event_type": "customer_support_interaction",
"data": {
"customer_message": "{{ inputs.customer_message }}",
"intent_classification": "{{ steps.classify_intent.result }}",
"response_generated": "{{ steps.generate_chatbot_response.result.response }}",
"knowledge_base_used": "{{ steps.search_knowledge_base.result }}",
"escalated": "{{ outputs.response_type === 'human_escalation' }}",
"workflow_execution_time": "{{ execution_time }}",
"timestamp": "{{ current_timestamp }}"
}
}
}
],
"outputs": {
"response_type": "string",
"message": "string",
"sources": "array",
"escalation_reason": "string",
"confidence": "number",
"quality_score": "number"
},
"error_handling": {
"retry_failed_steps": true,
"max_retries": 2,
"fallback_response": "I apologize, but I'm experiencing technical difficulties. Please contact our support team directly for assistance."
},
"metadata": {
"created_by": "support_team",
"use_case": "customer_support_automation",
"tags": ["customer_support", "chatbot", "rag", "escalation"],
"estimated_execution_time": "5-15 seconds"
}
}

View File

@@ -0,0 +1,949 @@
"""
Chatbot Module Implementation
Provides AI chatbot capabilities with:
- RAG integration for knowledge-based responses
- Custom prompts and personalities
- Conversation memory and context
- Workflow integration as building blocks
- UI-configurable settings
"""
import json
from pprint import pprint
import uuid
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Union
from dataclasses import dataclass
from pydantic import BaseModel, Field
from enum import Enum
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.services.llm.service import llm_service
from app.services.llm.models import ChatRequest as LLMChatRequest, ChatMessage as LLMChatMessage
from app.services.llm.exceptions import LLMError, ProviderError, SecurityError
from app.services.base_module import BaseModule, Permission
from app.models.user import User
from app.models.chatbot import ChatbotInstance as DBChatbotInstance, ChatbotConversation as DBConversation, ChatbotMessage as DBMessage, ChatbotAnalytics
from app.core.security import get_current_user
from app.db.database import get_db
from app.core.config import settings
# Import protocols for type hints and dependency injection
from ..protocols import RAGServiceProtocol
# Note: LiteLLMClientProtocol replaced with direct LLM service usage
logger = get_logger(__name__)
class ChatbotType(str, Enum):
    """Types of chatbot personalities"""

    # Members mix in ``str`` so they compare equal to their raw string values.
    ASSISTANT = "assistant"
    CUSTOMER_SUPPORT = "customer_support"
    TEACHER = "teacher"
    RESEARCHER = "researcher"
    CREATIVE_WRITER = "creative_writer"
    # Placeholder for user-defined personalities.
    CUSTOM = "custom"
class MessageRole(str, Enum):
    """Message roles in conversation"""

    # ``str`` mixin: each member compares equal to the plain role string.
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
@dataclass
class ChatbotConfig:
    """Chatbot configuration"""

    name: str
    # Plain string rather than ChatbotType so custom type keys are accepted.
    chatbot_type: str
    model: str
    rag_collection: Optional[str] = None
    system_prompt: str = ""
    temperature: float = 0.7
    max_tokens: int = 1000
    # Number of previous messages to remember.
    memory_length: int = 10
    use_rag: bool = False
    rag_top_k: int = 5
    # Lowered from the former default of 0.3 to allow more results.
    rag_score_threshold: float = 0.02
    # ``None`` means "use the built-in fallbacks" (filled in by __post_init__).
    # Annotated Optional because None is the declared default; an explicit
    # list (including an empty one) is kept as given.
    fallback_responses: Optional[List[str]] = None

    def __post_init__(self):
        """Install the default fallback responses when none were supplied."""
        if self.fallback_responses is None:
            # A fresh list per instance, so instances never share state.
            self.fallback_responses = [
                "I'm not sure how to help with that. Could you please rephrase your question?",
                "I don't have enough information to answer that question accurately.",
                "That's outside my knowledge area. Is there something else I can help you with?"
            ]
class ChatMessage(BaseModel):
    """Individual chat message"""

    # Random UUID string generated per message when no id is supplied.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    role: MessageRole
    content: str
    # Defaults to the naive UTC time at construction.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Optional list of source dicts attached to the message.
    sources: Optional[List[Dict[str, Any]]] = None
class Conversation(BaseModel):
    """Conversation state"""

    # Random UUID string generated when no id is supplied.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    chatbot_id: str
    user_id: str
    messages: List[ChatMessage] = Field(default_factory=list)
    # Both timestamps default to the naive UTC time at construction.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    metadata: Dict[str, Any] = Field(default_factory=dict)
class ChatRequest(BaseModel):
    """Chat completion request"""

    message: str
    # Omit to start a new conversation.
    conversation_id: Optional[str] = None
    chatbot_id: str
    use_rag: Optional[bool] = None
    context: Optional[Dict[str, Any]] = None
class ChatResponse(BaseModel):
    """Chat completion response"""

    response: str
    conversation_id: str
    # Identifier of the stored assistant message.
    message_id: str
    sources: Optional[List[Dict[str, Any]]] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
class ChatbotInstance(BaseModel):
    """Configured chatbot instance"""

    # Random UUID string generated when no id is supplied.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    config: ChatbotConfig
    created_by: str
    # Both timestamps default to the naive UTC time at construction.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    is_active: bool = True
class ChatbotModule(BaseModule):
"""Main chatbot module implementation"""
def __init__(self, rag_service: Optional[RAGServiceProtocol] = None):
    """Create the module, optionally with an already-built RAG service.

    Args:
        rag_service: RAG backend to use; when omitted it is resolved
            later from the module manager.
    """
    super().__init__("chatbot")
    # Filled from the database (or hardcoded fallbacks) during initialization.
    self.system_prompts = {}
    self.db_session = None
    # Attribute name kept as ``rag_module`` for compatibility.
    self.rag_module = rag_service
async def initialize(self, **kwargs):
    """Start the module: base init, LLM service, RAG wiring, prompt templates.

    A missing or failing RAG lookup is logged and tolerated; the module
    still initializes without RAG.
    """
    await super().initialize(**kwargs)
    await llm_service.initialize()

    # Only consult the module manager when no RAG service was injected.
    if not self.rag_module:
        try:
            from app.services.module_manager import module_manager

            rag_registered = hasattr(module_manager, 'modules') and 'rag' in module_manager.modules
            if rag_registered:
                self.rag_module = module_manager.modules['rag']
                logger.info("RAG module injected from module manager")
        except Exception as e:
            logger.warning(f"Could not inject RAG module: {e}")

    await self._load_prompt_templates()

    # Startup summary.
    logger.info("Chatbot module initialized")
    logger.info(f"LLM service available: {llm_service._initialized}")
    logger.info(f"RAG module available after init: {self.rag_module is not None}")
    logger.info(f"Loaded {len(self.system_prompts)} prompt templates")
async def _ensure_dependencies(self):
    """Initialize the LLM service and resolve the RAG module on demand.

    Both steps are skipped when the dependency is already available.
    A failed RAG lookup is logged and otherwise ignored.
    """
    if not llm_service._initialized:
        await llm_service.initialize()
        logger.info("LLM service lazy loaded")

    if self.rag_module:
        return

    try:
        from app.services.module_manager import module_manager

        rag_registered = hasattr(module_manager, 'modules') and 'rag' in module_manager.modules
        if rag_registered:
            self.rag_module = module_manager.modules['rag']
            logger.info("RAG module lazy loaded from module manager")
    except Exception as e:
        logger.warning(f"Could not lazy load RAG module: {e}")
async def _load_prompt_templates(self):
    """Fill ``self.system_prompts`` from active PromptTemplate rows.

    Each active template is stored under its ``type_key``. If anything
    fails (import, query, row access) the whole mapping is replaced by the
    hardcoded prompt set below.
    """
    try:
        from app.db.database import SessionLocal
        from app.models.prompt_template import PromptTemplate
        from sqlalchemy import select

        session = SessionLocal()
        try:
            # SQLAlchemy column comparison, so ``== True`` is intentional.
            active_query = select(PromptTemplate).where(PromptTemplate.is_active == True)
            active_templates = session.execute(active_query).scalars().all()
            self.system_prompts.update(
                {tpl.type_key: tpl.system_prompt for tpl in active_templates}
            )
            logger.info(f"Loaded {len(self.system_prompts)} prompt templates from database")
        finally:
            session.close()
    except Exception as e:
        logger.warning(f"Could not load prompt templates from database: {e}")
        # Database unavailable: fall back to the hardcoded prompts.
        self.system_prompts = {
            "assistant": "You are a helpful AI assistant. Provide accurate, concise, and friendly responses. Always aim to be helpful while being honest about your limitations.",
            "customer_support": "You are a professional customer support representative. Be empathetic, professional, and solution-focused in all interactions.",
            "teacher": "You are an experienced educational tutor. Break down complex concepts into understandable parts. Be patient, supportive, and encouraging.",
            "researcher": "You are a thorough research assistant with a focus on accuracy and evidence-based information.",
            "creative_writer": "You are an experienced creative writing mentor and storytelling expert.",
            "custom": "You are a helpful AI assistant. Your personality and behavior will be defined by custom instructions."
        }
async def get_system_prompt_for_type(self, chatbot_type: str) -> str:
    """Return the system prompt registered for ``chatbot_type``.

    A cached prompt is returned immediately. On a miss the templates are
    reloaded once; if the type is still unknown the "assistant" prompt is
    used, and failing that a built-in generic prompt.
    """
    try:
        return self.system_prompts[chatbot_type]
    except KeyError:
        pass

    # Unknown type: refresh the cache, then resolve with fallbacks.
    await self._load_prompt_templates()
    prompts = self.system_prompts
    default_prompt = prompts.get(
        "assistant",
        "You are a helpful AI assistant. Provide accurate, concise, and friendly responses.",
    )
    return prompts.get(chatbot_type, default_prompt)
async def create_chatbot(self, config: ChatbotConfig, user_id: str, db: Session) -> ChatbotInstance:
    """Persist a new chatbot and return it as a ChatbotInstance.

    When ``config.system_prompt`` is empty or whitespace-only it is filled
    in (on the passed config object) with the prompt for the chatbot type.

    Args:
        config: Chatbot settings; stored as its attribute dict.
        user_id: Recorded as the creator of the chatbot.
        db: Session used to insert and refresh the record.
    """
    prompt_missing = not config.system_prompt or config.system_prompt.strip() == ""
    if prompt_missing:
        config.system_prompt = await self.get_system_prompt_for_type(config.chatbot_type)

    # Insert the record and reload it to pick up database-generated values.
    record = DBChatbotInstance(
        name=config.name,
        description=f"{config.chatbot_type.replace('_', ' ').title()} chatbot",
        config=config.__dict__,
        created_by=user_id
    )
    db.add(record)
    db.commit()
    db.refresh(record)

    # Build the response model from the stored record.
    stored_config = ChatbotConfig(**record.config)
    chatbot = ChatbotInstance(
        id=record.id,
        name=record.name,
        config=stored_config,
        created_by=record.created_by,
        created_at=record.created_at,
        updated_at=record.updated_at,
        is_active=record.is_active
    )
    logger.info(f"Created new chatbot: {chatbot.name} ({chatbot.id})")
    return chatbot
async def chat_completion(self, request: ChatRequest, user_id: str, db: Session) -> ChatResponse:
    """Handle one chat turn: store the user message, generate and store the reply.

    Args:
        request: Incoming chat request (message, chatbot id, optional
            conversation id and context).
        user_id: Identifier of the caller; used when creating a conversation.
        db: Session used for all reads and writes.

    Returns:
        ChatResponse with the assistant text. On an unexpected failure a
        fallback text is stored and returned with ``metadata`` carrying
        ``error`` and ``fallback``.

    Raises:
        HTTPException: 404 when the chatbot does not exist; any
            HTTPException raised while generating the response (for
            example the 400/503/500 mapped from LLM errors) is propagated
            instead of being turned into a fallback reply.
    """
    # NOTE(review): request.use_rag is accepted but never applied; the stored
    # chatbot configuration alone decides whether RAG is used.
    db_chatbot = db.query(DBChatbotInstance).filter(DBChatbotInstance.id == request.chatbot_id).first()
    if not db_chatbot:
        raise HTTPException(status_code=404, detail="Chatbot not found")
    chatbot_config = ChatbotConfig(**db_chatbot.config)

    # Get or create conversation
    conversation = await self._get_or_create_conversation(
        request.conversation_id, request.chatbot_id, user_id, db
    )

    # Persist the user message first so it is part of the history query below.
    user_message = DBMessage(
        conversation_id=conversation.id,
        role=MessageRole.USER.value,
        content=request.message
    )
    db.add(user_message)
    db.commit()
    db.refresh(user_message)
    logger.info(f"Created user message with ID {user_message.id} for conversation {conversation.id}")

    try:
        # Force the session to see the committed changes
        db.expire_all()

        # Newest-first history, capped at memory_length user/assistant pairs
        # plus one so the message just stored is always included.
        messages = db.query(DBMessage).filter(
            DBMessage.conversation_id == conversation.id
        ).order_by(DBMessage.timestamp.desc()).limit(chatbot_config.memory_length * 2 + 1).all()
        logger.info(f"Query for conversation_id={conversation.id}, memory_length={chatbot_config.memory_length}")
        logger.info(f"Found {len(messages)} messages in conversation history")

        # If the query returned nothing, fall back to the message just created.
        if len(messages) == 0:
            logger.warning(f"No messages found in query, but we just created message {user_message.id}")
            logger.warning(f"Using the user message we just created")
            messages = [user_message]

        for idx, msg in enumerate(messages):
            logger.info(f"Message {idx}: id={msg.id}, role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...")

        # Generate response
        response_content, sources = await self._generate_response(
            request.message, messages, chatbot_config, request.context, db
        )

        # Create assistant message
        assistant_message = DBMessage(
            conversation_id=conversation.id,
            role=MessageRole.ASSISTANT.value,
            content=response_content,
            sources=sources,
            metadata={"model": chatbot_config.model, "temperature": chatbot_config.temperature}
        )
        db.add(assistant_message)
        db.commit()
        db.refresh(assistant_message)

        # Update conversation timestamp
        conversation.updated_at = datetime.utcnow()
        db.commit()

        return ChatResponse(
            response=response_content,
            conversation_id=conversation.id,
            message_id=assistant_message.id,
            sources=sources
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. security rejection, provider outage)
        # carry a status code for the client and must not become a 200 fallback.
        raise
    except Exception as e:
        logger.error(f"Chat completion failed: {e}")
        # A failed flush/commit leaves the session unusable until rolled back;
        # this also discards any half-written assistant message.
        db.rollback()

        # Return fallback response
        fallback = chatbot_config.fallback_responses[0] if chatbot_config.fallback_responses else "I'm having trouble responding right now."
        assistant_message = DBMessage(
            conversation_id=conversation.id,
            role=MessageRole.ASSISTANT.value,
            content=fallback,
            metadata={"error": str(e), "fallback": True}
        )
        db.add(assistant_message)
        db.commit()
        db.refresh(assistant_message)
        return ChatResponse(
            response=fallback,
            conversation_id=conversation.id,
            message_id=assistant_message.id,
            metadata={"error": str(e), "fallback": True}
        )
async def _generate_response(self, message: str, db_messages: List[DBMessage],
config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]:
"""Generate the assistant reply for ``message``, optionally grounded by RAG.

Steps, in order:
1. Ensure the LLM service and RAG module are available.
2. If ``config.use_rag``, ``config.rag_collection`` and ``self.rag_module`` are
   all truthy, search the resolved Qdrant collection and build ``sources``
   (200-character previews) and ``rag_context`` (full document texts).
   Any RAG failure is logged and the reply proceeds without RAG data.
3. When ``config.use_rag`` is set, add grounding guardrails to the prompt,
   with extra wording for encryption-related queries.
4. For an encryption-related query whose RAG context has no encryption
   wording, return a fixed policy answer without calling the LLM.
5. Otherwise build the message list and call the LLM service.

Args:
    message: The current user message text.
    db_messages: Conversation history rows, passed on to
        ``_build_conversation_messages``.
    config: Chatbot settings (model, temperature, RAG options).
    context: Accepted but not read in this method.
    db: Session used to resolve the RAG collection name.

Returns:
    ``(reply_text, sources)``; ``sources`` is ``None`` when no RAG results
    were used, and also when an unexpected error produced the generic reply.

Raises:
    HTTPException: 400 for SecurityError, 503 for ProviderError,
        500 for LLMError.
"""
# Lazy load dependencies if not available
await self._ensure_dependencies()
sources = None
rag_context = ""
# Helper: detect encryption-related queries for extra care
def _is_encryption_query(q: str) -> bool:
ql = (q or "").lower()
return any(k in ql for k in ["encrypt", "encryption", "encrypted", "decrypt", "decryption", "sd card", "microsd", "micro-sd"])
is_encryption = _is_encryption_query(message)
# RAG search if enabled
if config.use_rag and config.rag_collection and self.rag_module:
logger.info(f"RAG search enabled for collection: {config.rag_collection}")
try:
# Get the Qdrant collection name from RAG collection
qdrant_collection_name = await self._get_qdrant_collection_name(config.rag_collection, db)
logger.info(f"Qdrant collection name: {qdrant_collection_name}")
if qdrant_collection_name:
logger.info(f"Searching RAG documents: query='{message[:50]}...', max_results={config.rag_top_k}")
rag_results = await self.rag_module.search_documents(
query=message,
max_results=config.rag_top_k,
collection_name=qdrant_collection_name,
score_threshold=config.rag_score_threshold
)
# If the user asks about encryption, prefer results that explicitly mention it
# (matching results are moved to the front; the others keep their relative order).
if rag_results and is_encryption:
kw = ["encrypt", "encryption", "encrypted", "decrypt", "decryption"]
filtered = [r for r in rag_results if any(k in (r.document.content or "").lower() for k in kw)]
if filtered:
rag_results = filtered + [r for r in rag_results if r not in filtered]
if rag_results:
logger.info(f"RAG search found {len(rag_results)} results")
# Sources returned to the caller carry only a 200-character preview.
sources = [{"title": f"Document {i+1}", "content": result.document.content[:200]}
for i, result in enumerate(rag_results)]
# Build full RAG context from all results
rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([
f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results)
])
# Detailed RAG logging - ALWAYS log for debugging
# NOTE(review): this writes the full query and document contents to the info log.
logger.info("=== COMPREHENSIVE RAG SEARCH RESULTS ===")
logger.info(f"Query: '{message}'")
logger.info(f"Collection: {qdrant_collection_name}")
logger.info(f"Number of results: {len(rag_results)}")
for i, result in enumerate(rag_results):
logger.info(f"\n--- RAG Result {i+1} ---")
logger.info(f"Score: {getattr(result, 'score', 'N/A')}")
logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}")
logger.info(f"Full Content ({len(result.document.content)} chars):")
logger.info(f"{result.document.content}")
if hasattr(result.document, 'metadata'):
logger.info(f"Metadata: {result.document.metadata}")
logger.info(f"\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
logger.info(rag_context)
logger.info("=== END RAG SEARCH RESULTS ===")
else:
logger.warning("RAG search returned no results")
else:
logger.warning(f"RAG collection '{config.rag_collection}' not found in database")
except Exception as e:
# RAG is best-effort: log the failure and continue without retrieved context.
logger.warning(f"RAG search failed: {e}")
import traceback
logger.warning(f"RAG search traceback: {traceback.format_exc()}")
# Build conversation context (includes the current message from db_messages)
# Inject strict grounding instructions when RAG is used, especially for encryption questions
extra_instructions = {}
if config.use_rag:
guardrails = (
"Answer strictly using the 'Relevant information' provided. "
"If the information does not explicitly answer the question, say you don't have enough information instead of guessing. "
)
if is_encryption:
guardrails += (
"When asked about encryption or SD-card backups, do not claim that backups are encrypted unless the provided context explicitly uses wording like 'encrypt', 'encrypted', or 'encryption'. "
"If such wording is absent, state clearly that the SD-card backup is not encrypted. "
"Product policy: For BitBox devices, microSD (SD card) backups are not encrypted; verification steps may require a recovery password, but that is not encryption. Do not conflate password entry with encryption. "
)
extra_instructions["additional_instructions"] = guardrails
# Deterministic enforcement: if encryption question and RAG context does not explicitly
# contain encryption wording, return policy answer without calling the LLM.
# NOTE(review): indentation is lost in this view, so it is unclear whether this
# check is nested under `if config.use_rag:` — confirm against the repository.
ctx_lower = (rag_context or "").lower()
has_encryption_terms = any(k in ctx_lower for k in ["encrypt", "encrypted", "encryption", "decrypt", "decryption"])
if is_encryption and not has_encryption_terms:
policy_answer = (
"No. BitBox microSD (SD card) backups are not encrypted. "
"Verification may require entering a recovery password, but that does not encrypt the backup — "
"it only proves you have the correct credentials to restore. Keep the card and password secure."
)
return policy_answer, sources
# The guardrails are passed as the `context` argument of the message builder.
messages = self._build_conversation_messages(db_messages, config, rag_context, extra_instructions)
# Note: Current user message is already included in db_messages from the query
logger.info(f"Built conversation context with {len(messages)} messages")
# LLM completion
logger.info(f"Attempting LLM completion with model: {config.model}")
logger.info(f"Messages to send: {len(messages)} messages")
# Always log detailed prompts for debugging
logger.info("=== COMPREHENSIVE LLM REQUEST ===")
logger.info(f"Model: {config.model}")
logger.info(f"Temperature: {config.temperature}")
logger.info(f"Max tokens: {config.max_tokens}")
logger.info(f"RAG enabled: {config.use_rag}")
logger.info(f"RAG collection: {config.rag_collection}")
if config.use_rag and rag_context:
logger.info(f"RAG context added: {len(rag_context)} characters")
logger.info(f"RAG sources: {len(sources) if sources else 0} documents")
logger.info("\n=== COMPLETE MESSAGES SENT TO LLM ===")
for i, msg in enumerate(messages):
logger.info(f"\n--- Message {i+1} ---")
logger.info(f"Role: {msg['role']}")
logger.info(f"Content ({len(msg['content'])} chars):")
# Truncate long content for logging (full RAG context can be very long)
if len(msg['content']) > 500:
logger.info(f"{msg['content'][:500]}... [truncated, total {len(msg['content'])} chars]")
else:
logger.info(msg['content'])
logger.info("=== END COMPREHENSIVE LLM REQUEST ===")
try:
logger.info("Calling LLM service create_chat_completion...")
# Convert messages to LLM service format
llm_messages = [LLMChatMessage(role=msg["role"], content=msg["content"]) for msg in messages]
# Create LLM service request
# A fixed user id and api_key_id are sent for every chatbot request.
llm_request = LLMChatRequest(
model=config.model,
messages=llm_messages,
temperature=config.temperature,
max_tokens=config.max_tokens,
user_id="chatbot_user",
api_key_id=0 # Chatbot module uses internal service
)
# Make request to LLM service
llm_response = await llm_service.create_chat_completion(llm_request)
# Extract response content (only the first choice is used)
if llm_response.choices:
content = llm_response.choices[0].message.content
logger.info(f"Response content length: {len(content)}")
# Always log response for debugging
logger.info("=== COMPREHENSIVE LLM RESPONSE ===")
logger.info(f"Response content ({len(content)} chars):")
logger.info(content)
if llm_response.usage:
usage = llm_response.usage
logger.info(f"Token usage - Prompt: {usage.prompt_tokens}, Completion: {usage.completion_tokens}, Total: {usage.total_tokens}")
if sources:
logger.info(f"RAG sources included: {len(sources)} documents")
logger.info("=== END COMPREHENSIVE LLM RESPONSE ===")
return content, sources
else:
logger.warning("No choices in LLM response")
return "I received an empty response from the AI model.", sources
# LLM service errors are mapped to HTTP status codes for the caller.
except SecurityError as e:
logger.error(f"Security error in LLM completion: {e}")
raise HTTPException(status_code=400, detail=f"Security validation failed: {e.message}")
except ProviderError as e:
logger.error(f"Provider error in LLM completion: {e}")
raise HTTPException(status_code=503, detail="LLM service temporarily unavailable")
except LLMError as e:
logger.error(f"LLM service error: {e}")
raise HTTPException(status_code=500, detail="LLM service error")
except Exception as e:
logger.error(f"LLM completion failed: {e}")
# Any other failure yields a generic reply; note that sources are dropped here.
return "I'm currently unable to process your request. Please try again later.", None
def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig,
                                 rag_context: str = "", context: Optional[Dict] = None) -> List[Dict]:
    """Assemble the role/content message list sent to the LLM.

    The first entry is the system prompt: the configured prompt, followed by
    the RAG context (when non-empty) and by ``context['additional_instructions']``
    (when present). History follows in chronological order.

    Args:
        db_messages: History rows, newest first; iterated in reverse.
        config: Supplies ``system_prompt``.
        rag_context: Retrieved knowledge-base text, or "" for none.
        context: Optional dict; only ``additional_instructions`` is read.

    Returns:
        List of ``{"role", "content"}`` dicts. Rows whose role is neither
        "user" nor "assistant" are left out.
    """
    system_prompt = config.system_prompt
    if rag_context:
        # Tell the model explicitly to treat the retrieved text as primary source.
        rag_preamble = (
            "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. "
            "This information is directly relevant to their query and should be your primary source:\n"
        )
        system_prompt = system_prompt + rag_preamble + rag_context
    if context and context.get('additional_instructions'):
        system_prompt = system_prompt + f"\n\nAdditional instructions: {context['additional_instructions']}"

    messages = [{"role": "system", "content": system_prompt}]
    logger.info(f"Building messages from {len(db_messages)} database messages")

    # The caller already capped the history; reversing yields oldest-first.
    # The current user message is part of the history and must be kept.
    for idx, msg in enumerate(reversed(db_messages)):
        logger.info(f"Processing message {idx}: role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...")
        if msg.role not in ["user", "assistant"]:
            logger.info(f"Skipped message with role {msg.role}")
            continue
        messages.append({
            "role": msg.role,
            "content": msg.content
        })
        logger.info(f"Added message with role {msg.role} to LLM messages")

    logger.info(f"Final messages array has {len(messages)} messages")  # For debugging, can be removed in production
    return messages
async def _get_or_create_conversation(self, conversation_id: Optional[str],
                                      chatbot_id: str, user_id: str, db: Session) -> DBConversation:
    """Return the caller's existing conversation, or create a new one.

    An existing conversation is reused only when it belongs to ``user_id``
    and to ``chatbot_id``. Looking it up by id alone would let any caller
    append to, and pull LLM context from, someone else's conversation just
    by supplying its id. A missing or foreign id is treated the same way:
    a fresh conversation is created, so the response does not reveal
    whether the id exists.

    Args:
        conversation_id: Id of the conversation to continue, or None.
        chatbot_id: Chatbot the conversation must belong to.
        user_id: Owner the conversation must belong to.
        db: Session used for the lookup and the insert.

    Returns:
        The matching or newly created conversation row.
    """
    if conversation_id:
        conversation = db.query(DBConversation).filter(
            DBConversation.id == conversation_id,
            DBConversation.user_id == user_id,
            DBConversation.chatbot_id == chatbot_id,
        ).first()
        if conversation:
            return conversation

    # Create new conversation
    conversation = DBConversation(
        chatbot_id=chatbot_id,
        user_id=user_id,
        title="New Conversation"
    )
    db.add(conversation)
    db.commit()
    db.refresh(conversation)
    return conversation
def get_router(self) -> APIRouter:
    """Build the ``/chatbot`` router with chat, create, list, conversation and type endpoints."""
    router = APIRouter(prefix="/chatbot", tags=["chatbot"])

    def _as_instance(record) -> ChatbotInstance:
        # Shared conversion from a stored chatbot row to the response model.
        return ChatbotInstance(
            id=record.id,
            name=record.name,
            config=ChatbotConfig(**record.config),
            created_by=record.created_by,
            created_at=record.created_at,
            updated_at=record.updated_at,
            is_active=record.is_active
        )

    @router.post("/chat", response_model=ChatResponse)
    async def chat_endpoint(
        request: ChatRequest,
        current_user: User = Depends(get_current_user),
        db: Session = Depends(get_db)
    ):
        """Chat completion endpoint"""
        caller_id = str(current_user['id'])
        return await self.chat_completion(request, caller_id, db)

    @router.post("/create", response_model=ChatbotInstance)
    async def create_chatbot_endpoint(
        config: ChatbotConfig,
        current_user: User = Depends(get_current_user),
        db: Session = Depends(get_db)
    ):
        """Create new chatbot instance"""
        caller_id = str(current_user['id'])
        return await self.create_chatbot(config, caller_id, db)

    @router.get("/list", response_model=List[ChatbotInstance])
    async def list_chatbots_endpoint(
        current_user: User = Depends(get_current_user),
        db: Session = Depends(get_db)
    ):
        """List user's chatbots"""
        # The caller's own chatbots plus those created by "system".
        visible_rows = db.query(DBChatbotInstance).filter(
            (DBChatbotInstance.created_by == str(current_user['id'])) |
            (DBChatbotInstance.created_by == "system")
        ).all()
        return [_as_instance(row) for row in visible_rows]

    @router.get("/conversations/{conversation_id}", response_model=Conversation)
    async def get_conversation_endpoint(
        conversation_id: str,
        current_user: User = Depends(get_current_user),
        db: Session = Depends(get_db)
    ):
        """Get conversation history"""
        conversation = db.query(DBConversation).filter(
            DBConversation.id == conversation_id
        ).first()
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation not found")

        # Only the owner may read a conversation.
        if conversation.user_id != str(current_user['id']):
            raise HTTPException(status_code=403, detail="Not authorized")

        # Messages in timestamp order.
        stored_messages = db.query(DBMessage).filter(
            DBMessage.conversation_id == conversation_id
        ).order_by(DBMessage.timestamp).all()

        # NOTE(review): `msg.metadata` is read from an ORM row here — confirm
        # the model really exposes the message metadata under that name.
        chat_messages = [
            ChatMessage(
                id=msg.id,
                role=MessageRole(msg.role),
                content=msg.content,
                timestamp=msg.timestamp,
                metadata=msg.metadata or {},
                sources=msg.sources
            )
            for msg in stored_messages
        ]

        return Conversation(
            id=conversation.id,
            chatbot_id=conversation.chatbot_id,
            user_id=conversation.user_id,
            messages=chat_messages,
            created_at=conversation.created_at,
            updated_at=conversation.updated_at,
            metadata=conversation.context_data or {}
        )

    @router.get("/types", response_model=List[Dict[str, str]])
    async def get_chatbot_types_endpoint():
        """Get available chatbot types and their descriptions"""
        return [
            {"type": "assistant", "name": "General Assistant", "description": "Helpful AI assistant for general questions"},
            {"type": "customer_support", "name": "Customer Support", "description": "Professional customer service chatbot"},
            {"type": "teacher", "name": "Teacher", "description": "Educational tutor and learning assistant"},
            {"type": "researcher", "name": "Researcher", "description": "Research assistant with fact-checking focus"},
            {"type": "creative_writer", "name": "Creative Writer", "description": "Creative writing and storytelling assistant"},
            {"type": "custom", "name": "Custom", "description": "Custom chatbot with user-defined personality"}
        ]

    return router
# API Compatibility Methods
async def chat(self, chatbot_config: Dict[str, Any], message: str,
               conversation_history: List = None, user_id: str = "anonymous") -> Dict[str, Any]:
    """Stateless chat entry point kept for API compatibility.

    Builds a ChatbotConfig from a plain dict and generates a single reply
    without storing a conversation.

    Args:
        chatbot_config: Chatbot settings as a dict; missing keys fall back
            to defaults. ``rag_score_threshold`` is forwarded when present,
            so this path retrieves with the same threshold as the stored
            configuration.
        message: The user message.
        conversation_history: Accepted for compatibility; not used here.
        user_id: Only used for logging.

    Returns:
        Dict with ``response``, ``sources``, ``conversation_id`` (always
        None) and a generated ``message_id``. On any failure the first
        configured fallback response is returned instead.
    """
    logger.info(f"Chat method called with message: {message[:50]}... by user: {user_id}")

    # Lazy load dependencies
    await self._ensure_dependencies()
    logger.info(f"LLM service available: {llm_service._initialized}")
    logger.info(f"RAG module available: {self.rag_module is not None}")

    try:
        # Create a minimal database session for the chat
        from app.db.database import SessionLocal
        db = SessionLocal()
        try:
            # Forward the score threshold only when the caller set one, so the
            # dataclass default applies otherwise.
            optional_settings = {}
            if chatbot_config.get("rag_score_threshold") is not None:
                optional_settings["rag_score_threshold"] = chatbot_config["rag_score_threshold"]

            # Convert config dict to ChatbotConfig
            config = ChatbotConfig(
                name=chatbot_config.get("name", "Unknown"),
                chatbot_type=chatbot_config.get("chatbot_type", "assistant"),
                model=chatbot_config.get("model", "gpt-3.5-turbo"),
                system_prompt=chatbot_config.get("system_prompt", ""),
                temperature=chatbot_config.get("temperature", 0.7),
                max_tokens=chatbot_config.get("max_tokens", 1000),
                memory_length=chatbot_config.get("memory_length", 10),
                use_rag=chatbot_config.get("use_rag", False),
                rag_collection=chatbot_config.get("rag_collection"),
                rag_top_k=chatbot_config.get("rag_top_k", 5),
                fallback_responses=chatbot_config.get("fallback_responses", []),
                **optional_settings
            )

            # The generator expects history rows, so wrap the current message
            # in a temporary, unsaved message object.
            temp_messages = [
                DBMessage(
                    id=0,
                    conversation_id=0,
                    role="user",
                    content=message,
                    timestamp=datetime.utcnow(),
                    metadata={}
                )
            ]
            response_content, sources = await self._generate_response(
                message, temp_messages, config, None, db
            )
            return {
                "response": response_content,
                "sources": sources,
                "conversation_id": None,
                "message_id": f"msg_{uuid.uuid4()}"
            }
        finally:
            db.close()
    except Exception as e:
        logger.error(f"Chat method failed: {e}")
        fallback_responses = chatbot_config.get("fallback_responses", [
            "I'm sorry, I'm having trouble processing your request right now."
        ])
        return {
            "response": fallback_responses[0] if fallback_responses else "I'm sorry, I couldn't process your request.",
            "sources": None,
            "conversation_id": None,
            "message_id": f"msg_{uuid.uuid4()}"
        }
# Workflow Integration Methods
async def workflow_chat_step(self, context: Dict[str, Any], step_config: Dict[str, Any], db: Session) -> Dict[str, Any]:
    """Run one chat turn on behalf of a workflow step.

    Args:
        context: Workflow variables used to fill ``{{ ... }}`` placeholders
            in the step's message.
        step_config: Step settings; reads ``message``, ``chatbot_id``,
            ``use_rag`` and ``context``.
        db: Session passed through to chat_completion.

    Returns:
        Dict with ``response``, ``conversation_id``, ``sources`` and ``metadata``.
    """
    raw_message = step_config.get('message', '')
    chat_request = ChatRequest(
        message=self._substitute_template_variables(raw_message, context),
        chatbot_id=step_config.get('chatbot_id'),
        use_rag=step_config.get('use_rag', False),
        context=step_config.get('context', {})
    )

    # Workflow executions always run as the fixed system user.
    result = await self.chat_completion(chat_request, "workflow_system", db)

    fields = ("response", "conversation_id", "sources", "metadata")
    return {field: getattr(result, field) for field in fields}
def _substitute_template_variables(self, template: str, context: Dict[str, Any]) -> str:
    """Replace ``{{ path.to.value }}`` placeholders with values from ``context``.

    Each placeholder is a dot-separated path resolved by successive key
    lookups, e.g. ``{{ user.name }}`` -> ``context["user"]["name"]``.
    Whitespace around the path is ignored.

    Args:
        template: Text that may contain placeholders.
        context: Nested mapping providing the values.

    Returns:
        The template with every resolvable placeholder replaced by
        ``str(value)``; placeholders that cannot be resolved are left as-is.
    """
    import re

    def replace_var(match):
        var_path = match.group(1)
        value = context
        try:
            # Simple dot notation support: context.user.name
            for part in var_path.split('.'):
                value = value[part.strip()]
        except (KeyError, TypeError, IndexError):
            return match.group(0)  # Return original if not found
        return str(value)

    # Single backslashes: inside a raw string, "\\{" would demand a literal
    # backslash in the template and never match "{{". The lazy capture keeps
    # trailing spaces out of the path.
    return re.sub(r'\{\{\s*([^{}]+?)\s*\}\}', replace_var, template)
async def _get_qdrant_collection_name(self, collection_identifier: str, db: Session) -> Optional[str]:
    """Resolve a collection identifier to a Qdrant collection name.

    Resolution order:
    1. Identifiers containing an underscore are checked directly against
       Qdrant (a leading "ext_" is stripped first).
    2. All-digit identifiers are looked up as an active RagCollection id.
    3. Otherwise, or if step 2 found nothing, the identifier is looked up
       as an active RagCollection name.

    Returns:
        The Qdrant collection name, or None when nothing matches or any
        error occurs (errors are logged with a traceback).
    """
    try:
        from app.models.rag_collection import RagCollection
        from sqlalchemy import select

        logger.info(f"Looking up RAG collection with identifier: '{collection_identifier}'")

        # Could this be a direct Qdrant collection name?
        # (e.g., starts with "ext_", "rag_", or contains specific patterns)
        maybe_direct = (
            collection_identifier.startswith(("ext_", "rag_", "test_"))
            or "_" in collection_identifier
        )
        if maybe_direct:
            # Drop the "ext_" prefix when present.
            actual_collection_name = (
                collection_identifier[4:]
                if collection_identifier.startswith("ext_")
                else collection_identifier
            )
            logger.info(f"Checking if '{actual_collection_name}' exists in Qdrant directly")
            if self.rag_module:
                try:
                    # Ask Qdrant which collections exist.
                    from qdrant_client import QdrantClient

                    qdrant_client = QdrantClient(host="enclava-qdrant", port=6333)
                    listing = qdrant_client.get_collections()
                    collection_names = [c.name for c in listing.collections]
                    if actual_collection_name in collection_names:
                        logger.info(f"Found Qdrant collection directly: {actual_collection_name}")
                        return actual_collection_name
                except Exception as e:
                    # Best-effort check; fall through to the database lookup.
                    logger.warning(f"Error checking Qdrant collections: {e}")

        rag_collection = None

        # Numeric identifiers are tried as a database id first.
        if collection_identifier.isdigit():
            logger.info(f"Treating '{collection_identifier}' as collection ID")
            by_id = select(RagCollection).where(
                RagCollection.id == int(collection_identifier),
                RagCollection.is_active == True
            )
            rag_collection = db.execute(by_id).scalar_one_or_none()

        # Nothing found by id: try the identifier as a collection name.
        if not rag_collection:
            logger.info(f"Collection not found by ID, trying by name: '{collection_identifier}'")
            by_name = select(RagCollection).where(
                RagCollection.name == collection_identifier,
                RagCollection.is_active == True
            )
            rag_collection = db.execute(by_name).scalar_one_or_none()

        if not rag_collection:
            logger.warning(f"RAG collection '{collection_identifier}' not found in database (tried both ID and name)")
            return None

        logger.info(f"Found RAG collection: ID={rag_collection.id}, name='{rag_collection.name}', qdrant_collection='{rag_collection.qdrant_collection_name}'")
        return rag_collection.qdrant_collection_name
    except Exception as e:
        logger.error(f"Error looking up RAG collection '{collection_identifier}': {e}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        return None
# Required abstract methods from BaseModule
async def cleanup(self):
    """Release module resources; currently this only logs completion."""
    completion_note = "Chatbot module cleanup completed"
    logger.info(completion_note)
def get_required_permissions(self) -> List[Permission]:
    """Return the four permissions declared for the "chatbots" resource."""
    # (action, description) pairs, in declaration order.
    action_specs = (
        ("create", "Create chatbot instances"),
        ("configure", "Configure chatbot settings"),
        ("chat", "Use chatbot for conversations"),
        ("manage", "Manage all chatbots"),
    )
    return [Permission("chatbots", action, description) for action, description in action_specs]
async def process_request(self, request_type: str, data: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch a generic module request; only "chat" is supported.

    Args:
        request_type: Kind of request.
        data: Fields used to build the ChatRequest.
        context: Must provide ``db`` (a session) for chat requests;
            ``user_id`` is optional and defaults to "anonymous".

    Returns:
        On success: ``success`` True plus ``response``, ``conversation_id``
        and ``sources``. Otherwise ``success`` False and an ``error`` text.
    """
    if request_type != "chat":
        return {"success": False, "error": f"Unknown request type: {request_type}"}

    # A chat request without a session used to fall through and be reported
    # as an unknown request type; report the actual problem instead.
    db = context.get("db")
    if not db:
        return {"success": False, "error": "Database session is required for chat requests"}

    chat_request = ChatRequest(**data)
    user_id = context.get("user_id", "anonymous")
    response = await self.chat_completion(chat_request, user_id, db)
    return {
        "success": True,
        "response": response.response,
        "conversation_id": response.conversation_id,
        "sources": response.sources
    }
# Module factory function
def create_module(rag_service: Optional[RAGServiceProtocol] = None) -> ChatbotModule:
    """Build a ChatbotModule, injecting the optional RAG service.

    Args:
        rag_service: RAG service handed to the module; may be None.

    Returns:
        The newly constructed ChatbotModule.
    """
    module = ChatbotModule(rag_service=rag_service)
    return module
# Module-level instance constructed with no arguments (so no RAG service is
# passed here); create_module() builds an instance with the service injected.
chatbot_module = ChatbotModule()

View File

@@ -0,0 +1,110 @@
# Manifest for the chatbot module: identity, lifecycle, capabilities,
# endpoints, UI forms, permissions, analytics events and health checks.
name: chatbot
version: 1.0.0
description: "AI Chatbot with RAG integration and customizable prompts"
author: "Enclava Team"
category: "conversation"
# Module lifecycle
# "rag" is listed as a hard dependency; "analytics" is optional.
enabled: true
auto_start: true
dependencies:
- rag
optional_dependencies:
- analytics
# Configuration (paths are relative to this manifest)
config_schema: "./config_schema.json"
ui_components: "./ui_components/"
# Module capabilities: services this module offers (provides) and
# services it expects from elsewhere (consumes)
provides:
- "chat_completion"
- "conversation_management"
- "chatbot_configuration"
consumes:
- "rag_search"
- "llm_completion"
# API endpoints
endpoints:
- path: "/chatbot/chat"
method: "POST"
description: "Generate chat completion"
- path: "/chatbot/create"
method: "POST"
description: "Create new chatbot instance"
- path: "/chatbot/list"
method: "GET"
description: "List user chatbots"
# UI Configuration
ui_config:
icon: "message-circle"
color: "#10B981"
category: "AI & ML"
# Configuration forms: each form groups the listed config fields
forms:
- name: "basic_config"
title: "Basic Settings"
fields: ["name", "chatbot_type", "model"]
- name: "personality"
title: "Personality & Prompts"
fields: ["system_prompt", "temperature", "fallback_responses"]
- name: "knowledge_base"
title: "Knowledge Base"
fields: ["use_rag", "rag_collection", "rag_top_k"]
- name: "advanced"
title: "Advanced Settings"
fields: ["max_tokens", "memory_length"]
# Permissions
# NOTE(review): the chatbot module's Python code declares its permissions on
# the resource "chatbots" (plural); confirm these "chatbot.*" names resolve
# to the same permissions.
permissions:
- name: "chatbot.create"
description: "Create new chatbot instances"
- name: "chatbot.configure"
description: "Configure chatbot settings"
- name: "chatbot.chat"
description: "Use chatbot for conversations"
- name: "chatbot.manage"
description: "Manage all chatbots (admin)"
# Analytics events
analytics_events:
- name: "chatbot_created"
description: "New chatbot instance created"
- name: "chat_message_sent"
description: "User sent message to chatbot"
- name: "chat_response_generated"
description: "Chatbot generated response"
- name: "rag_context_used"
description: "RAG context was used in response"
# Health checks
health_checks:
- name: "llm_connectivity"
description: "Check LLM client connection"
- name: "rag_availability"
description: "Check RAG module availability"
- name: "conversation_memory"
description: "Check conversation storage health"
# Documentation (paths are relative to this manifest)
documentation:
readme: "./README.md"
examples: "./examples/"
api_docs: "./docs/api.md"

View File

@@ -0,0 +1,225 @@
"""
Module Factory for Confidential Empire
This factory creates and wires up all modules with their dependencies.
It ensures proper dependency injection while maintaining optimal performance
through direct method calls and minimal indirection.
"""
from typing import Dict, Optional, Any
import logging
# Import all modules
from .rag.main import RAGModule
from .chatbot.main import ChatbotModule, create_module as create_chatbot_module
from .workflow.main import WorkflowModule
# Import services that modules depend on
from app.services.litellm_client import LiteLLMClient
# Import protocols for type safety
from .protocols import (
RAGServiceProtocol,
ChatbotServiceProtocol,
LiteLLMClientProtocol,
WorkflowServiceProtocol,
ServiceRegistry
)
logger = logging.getLogger(__name__)
class ModuleFactory:
    """Factory for creating and wiring module dependencies.

    Builds the RAG, chatbot and workflow modules, injects each module into
    the one that depends on it, initialises them in dependency order and
    keeps them addressable by name.
    """

    def __init__(self):
        # name -> module instance; filled only after create_all_modules()
        # has created *and* initialised every module.
        self.modules: Dict[str, Any] = {}
        self.initialized = False

    async def create_all_modules(self, config: Optional[Dict[str, Any]] = None) -> ServiceRegistry:
        """Create all modules with proper dependency injection.

        Args:
            config: Optional configuration keyed by module name
                ("rag", "chatbot", "workflow").

        Returns:
            Dictionary of created modules with their dependencies wired.

        Raises:
            RuntimeError: If any module fails to initialise.
        """
        config = config or {}

        logger.info("Creating modules with dependency injection...")

        # Step 1: Create RAG module (no dependencies on other modules)
        rag_module = RAGModule(config=config.get("rag", {}))

        # Step 2: Create chatbot module with RAG dependency.
        # The chatbot factory accepts only ``rag_service``; passing a
        # ``litellm_client`` keyword raised TypeError, so it is no longer
        # constructed or forwarded here.
        chatbot_module = create_chatbot_module(
            rag_service=rag_module  # RAG module implements RAGServiceProtocol
        )

        # Step 3: Create workflow module with chatbot dependency
        workflow_module = WorkflowModule(
            chatbot_service=chatbot_module  # Chatbot module implements ChatbotServiceProtocol
        )

        # Store all modules
        modules = {
            "rag": rag_module,
            "chatbot": chatbot_module,
            "workflow": workflow_module
        }

        logger.info(f"Created {len(modules)} modules with dependencies wired")

        # Initialize all modules; the factory state is published only if
        # every initialisation succeeds.
        await self._initialize_modules(modules, config)

        self.modules = modules
        self.initialized = True

        return modules

    async def _initialize_modules(self, modules: Dict[str, Any], config: Dict[str, Any]):
        """Initialize all modules in dependency order.

        Raises:
            RuntimeError: Wrapping the first initialisation failure.
        """
        # Initialize in dependency order (modules with no deps first)
        initialization_order = [
            ("rag", modules["rag"]),
            ("chatbot", modules["chatbot"]),    # Depends on RAG
            ("workflow", modules["workflow"])   # Depends on Chatbot
        ]

        for module_name, module in initialization_order:
            try:
                logger.info(f"Initializing {module_name} module...")
                module_config = config.get(module_name, {})

                # Different modules have different initialization patterns:
                # the RAG module is initialised without keyword arguments
                # (its config was passed to the constructor), the others
                # receive their config section as keyword arguments.
                if hasattr(module, 'initialize'):
                    if module_name == "rag":
                        await module.initialize()
                    else:
                        await module.initialize(**module_config)

                logger.info(f"{module_name} module initialized successfully")

            except Exception as e:
                logger.error(f"❌ Failed to initialize {module_name} module: {e}")
                raise RuntimeError(f"Module initialization failed: {module_name}") from e

    async def cleanup_all_modules(self):
        """Cleanup all modules in reverse dependency order.

        Does nothing when the factory was never initialised. A failure in
        one module's cleanup is logged and does not stop the others.
        """
        if not self.initialized:
            return

        # Cleanup in reverse order
        cleanup_order = ["workflow", "chatbot", "rag"]

        for module_name in cleanup_order:
            if module_name in self.modules:
                try:
                    logger.info(f"Cleaning up {module_name} module...")
                    module = self.modules[module_name]
                    if hasattr(module, 'cleanup'):
                        await module.cleanup()
                    logger.info(f"{module_name} module cleaned up")
                except Exception as e:
                    logger.error(f"❌ Error cleaning up {module_name}: {e}")

        self.modules.clear()
        self.initialized = False

    def get_module(self, name: str) -> Optional[Any]:
        """Get a module by name, or None when no such module is registered."""
        return self.modules.get(name)

    def is_initialized(self) -> bool:
        """Check if factory is initialized."""
        return self.initialized
# Global factory instance shared by the module-level convenience functions
# and by the backward-compatible registry defined in this file.
module_factory = ModuleFactory()
# Convenience functions for external use
async def create_modules(config: Optional[Dict[str, Any]] = None) -> ServiceRegistry:
    """Build and initialise every module through the global factory.

    Args:
        config: Optional configuration forwarded unchanged to the factory.

    Returns:
        Whatever the factory's create_all_modules() returns.
    """
    registry = await module_factory.create_all_modules(config)
    return registry
async def cleanup_modules():
    """Run cleanup for every module held by the global factory."""
    await module_factory.cleanup_all_modules()
def get_module(name: str) -> Optional[Any]:
    """Look up a module in the global factory.

    Returns:
        The module registered under ``name``, or None if there is none.
    """
    found = module_factory.get_module(name)
    return found
def get_all_modules() -> Dict[str, Any]:
    """Return a shallow copy of the global factory's name -> module mapping."""
    return dict(module_factory.modules)
# Factory functions for individual modules (for testing/special cases)
def create_rag_module(config: Optional[Dict[str, Any]] = None) -> RAGModule:
    """Create a standalone RAG module.

    Args:
        config: Module configuration; a falsy value is replaced by an
            empty dict.
    """
    rag_config = config if config else {}
    return RAGModule(config=rag_config)
def create_chatbot_with_rag(rag_service: RAGServiceProtocol,
                            litellm_client: LiteLLMClientProtocol) -> ChatbotModule:
    """Create chatbot module with RAG dependency.

    Args:
        rag_service: RAG service injected into the chatbot module.
        litellm_client: Kept in the signature for backward compatibility
            but not forwarded: the chatbot factory imported by this file
            accepts only ``rag_service``, so forwarding the client raised
            TypeError.

    Returns:
        The chatbot module built by the chatbot factory.
    """
    return create_chatbot_module(rag_service=rag_service)
def create_workflow_with_chatbot(chatbot_service: ChatbotServiceProtocol) -> WorkflowModule:
    """Create a workflow module wired to the given chatbot service.

    Args:
        chatbot_service: Chatbot service injected into the workflow module.
    """
    workflow = WorkflowModule(chatbot_service=chatbot_service)
    return workflow
# Module registry for backward compatibility
class ModuleRegistry:
    """Dict-like, read-only view over a ModuleFactory's modules.

    Kept for backward compatibility; every call is answered from the
    wrapped factory at the moment it is made.
    """

    def __init__(self, factory: ModuleFactory):
        self._factory = factory

    @property
    def modules(self) -> Dict[str, Any]:
        """The factory's own name -> module mapping (not a copy)."""
        return self._factory.modules

    def get(self, name: str) -> Optional[Any]:
        """Return the module registered under ``name``, or None."""
        return self._factory.get_module(name)

    def __getitem__(self, name: str) -> Any:
        """Dictionary-style access.

        Raises:
            KeyError: If the lookup yields None.
        """
        found = self.get(name)
        if found is not None:
            return found
        raise KeyError(f"Module '{name}' not found")

    def keys(self):
        """View of the registered module names."""
        return self._factory.modules.keys()

    def values(self):
        """View of the registered module instances."""
        return self._factory.modules.values()

    def items(self):
        """View of (name, module) pairs."""
        return self._factory.modules.items()
# Registry instance for backward compatibility; it wraps the global
# module_factory, so it reflects whatever modules that factory holds.
module_registry = ModuleRegistry(module_factory)

View File

@@ -0,0 +1,258 @@
"""
Module Protocols for Confidential Empire
This file defines the interface contracts that modules must implement for inter-module communication.
Using Python protocols lets static type checkers verify these contracts without adding runtime overhead.
"""
from typing import Protocol, Dict, List, Any, Optional, Union
from datetime import datetime
from abc import abstractmethod
class RAGServiceProtocol(Protocol):
    """Protocol for RAG (Retrieval-Augmented Generation) service interface"""

    @abstractmethod
    async def search(self, query: str, collection_name: str, top_k: int) -> Dict[str, Any]:
        """
        Search for relevant documents

        Args:
            query: Search query string
            collection_name: Name of the collection to search in
            top_k: Number of top results to return

        Returns:
            Dictionary containing search results with 'results' key
        """
        ...

    @abstractmethod
    async def index_document(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Index a document in the vector database

        Args:
            content: Document content to index
            metadata: Optional metadata for the document; None when the
                caller has none to attach

        Returns:
            Document ID
        """
        ...

    @abstractmethod
    async def delete_document(self, document_id: str) -> bool:
        """
        Delete a document from the vector database

        Args:
            document_id: ID of document to delete

        Returns:
            True if successfully deleted
        """
        ...
class ChatbotServiceProtocol(Protocol):
    """Structural interface a chatbot service must satisfy."""

    @abstractmethod
    async def chat_completion(self, request: Any, user_id: str, db: Any) -> Any:
        """
        Produce a chat completion for a request.

        Args:
            request: Chat request object
            user_id: ID of the requesting user
            db: Database session

        Returns:
            Chat response object
        """
        ...

    @abstractmethod
    async def create_chatbot(self, config: Any, user_id: str, db: Any) -> Any:
        """
        Create a chatbot instance.

        Args:
            config: Chatbot configuration
            user_id: ID of the user who will own the chatbot
            db: Database session

        Returns:
            The created chatbot instance
        """
        ...
class LiteLLMClientProtocol(Protocol):
    """Structural interface a LiteLLM client must satisfy."""

    @abstractmethod
    async def completion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> Any:
        """
        Run a completion against the given model.

        Args:
            model: Name of the model to use
            messages: Conversation messages
            **kwargs: Extra completion parameters

        Returns:
            Completion response object
        """
        ...

    @abstractmethod
    async def create_chat_completion(self, model: str, messages: List[Dict[str, str]],
                                     user_id: str, api_key_id: str, **kwargs) -> Any:
        """
        Run a chat completion attributed to a user and API key.

        Args:
            model: Name of the model to use
            messages: Conversation messages
            user_id: ID of the requesting user
            api_key_id: Identifier of the API key used
            **kwargs: Extra parameters

        Returns:
            Chat completion response
        """
        ...
class CacheServiceProtocol(Protocol):
    """Structural interface a cache service must satisfy."""

    @abstractmethod
    async def get(self, key: str, default: Any = None) -> Any:
        """
        Read a value from the cache.

        Args:
            key: Cache key
            default: Value returned when the key is absent

        Returns:
            The cached value, or ``default``
        """
        ...

    @abstractmethod
    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """
        Store a value in the cache.

        Args:
            key: Cache key
            value: Value to store
            ttl: Time to live in seconds

        Returns:
            True if the value was cached
        """
        ...

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """
        Remove a key from the cache.

        Args:
            key: Cache key to remove

        Returns:
            True if the key was deleted
        """
        ...
class SecurityServiceProtocol(Protocol):
    """Structural interface a security service must satisfy."""

    @abstractmethod
    async def analyze_request(self, request: Any) -> Any:
        """
        Run a security analysis on a request.

        Args:
            request: Request object to analyse

        Returns:
            Security analysis result
        """
        ...

    @abstractmethod
    async def validate_request(self, request: Any) -> bool:
        """
        Check a request for security compliance.

        Args:
            request: Request object to validate

        Returns:
            True if the request is valid/safe
        """
        ...
class WorkflowServiceProtocol(Protocol):
    """Protocol for Workflow service interface"""

    @abstractmethod
    async def execute_workflow(self, workflow: Any, input_data: Optional[Dict[str, Any]] = None) -> Any:
        """
        Execute a workflow definition

        Args:
            workflow: Workflow definition to execute
            input_data: Optional input data for the workflow; None when
                the caller supplies none

        Returns:
            Workflow execution result
        """
        ...

    @abstractmethod
    async def get_execution(self, execution_id: str) -> Any:
        """
        Get workflow execution status

        Args:
            execution_id: ID of the execution to retrieve

        Returns:
            Execution status object
        """
        ...
class ModuleServiceProtocol(Protocol):
    """Lifecycle interface shared by all module services."""

    @abstractmethod
    async def initialize(self, **kwargs) -> None:
        """Initialise the module; configuration arrives as keyword arguments."""
        ...

    @abstractmethod
    async def cleanup(self) -> None:
        """Release the module's resources."""
        ...

    @abstractmethod
    def get_required_permissions(self) -> List[Any]:
        """Return the permissions this module requires."""
        ...
# Type aliases for common service combinations
# ServiceRegistry: module name -> module service.
ServiceRegistry = Dict[str, ModuleServiceProtocol]
# ServiceDependencies: dependency name -> module service, or None.
ServiceDependencies = Dict[str, Optional[ModuleServiceProtocol]]

View File

@@ -0,0 +1,6 @@
"""
RAG (Retrieval-Augmented Generation) module for Confidential Empire platform
"""
from .main import RAGModule
__all__ = ["RAGModule"]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,82 @@
# Manifest for the RAG module: identity, lifecycle, capabilities,
# endpoints, UI forms, permissions and health checks.
name: rag
version: 1.0.0
description: "Document search, retrieval, and vector storage"
author: "Enclava Team"
category: "ai"
# Module lifecycle
# No hard dependencies on other modules; "cache" is optional.
enabled: true
auto_start: true
dependencies: []
optional_dependencies:
- cache
# Module capabilities: services this module offers (provides) and
# services it expects from elsewhere (consumes)
provides:
- "document_storage"
- "semantic_search"
- "vector_embeddings"
- "document_processing"
consumes:
- "qdrant_connection"
- "llm_embeddings"
- "document_parsing"
# API endpoints
endpoints:
- path: "/rag/collections"
method: "GET"
description: "List document collections"
- path: "/rag/upload"
method: "POST"
description: "Upload and process documents"
- path: "/rag/search"
method: "POST"
description: "Semantic search in documents"
- path: "/rag/collections/{collection_id}/documents"
method: "GET"
description: "List documents in collection"
# UI Configuration
ui_config:
icon: "search"
color: "#8B5CF6"
category: "AI & ML"
# Configuration forms: each form groups the listed config fields
forms:
- name: "collection_config"
title: "Collection Settings"
fields: ["name", "description", "embedding_model"]
- name: "search_config"
title: "Search Configuration"
fields: ["top_k", "similarity_threshold", "rerank_enabled"]
# Permissions
permissions:
- name: "rag.create"
description: "Create document collections"
- name: "rag.upload"
description: "Upload documents to collections"
- name: "rag.search"
description: "Search document collections"
- name: "rag.manage"
description: "Manage all collections (admin)"
# Health checks
health_checks:
- name: "qdrant_connectivity"
description: "Check Qdrant vector database connection"
- name: "embeddings_service"
description: "Check LLM embeddings service"
- name: "document_processing"
description: "Check document parsing capabilities"

View File

@@ -162,6 +162,7 @@ class DocumentProcessor:
async def _process_document(self, task: ProcessingTask) -> bool:
"""Process a single document"""
from datetime import datetime
from app.db.database import async_session_factory
async with async_session_factory() as session:
try:
@@ -182,16 +183,24 @@ class DocumentProcessor:
document.status = ProcessingStatus.PROCESSING
await session.commit()
# Get RAG module for processing (now includes content processing)
# Get RAG module for processing
try:
from app.services.module_manager import module_manager
rag_module = module_manager.get_module('rag')
# Import RAG module and initialize it properly
from modules.rag.main import RAGModule
from app.core.config import settings
# Create and initialize RAG module instance
rag_module = RAGModule(settings)
init_result = await rag_module.initialize()
if not rag_module.enabled:
raise Exception("Failed to enable RAG module")
except Exception as e:
logger.error(f"Failed to get RAG module: {e}")
raise Exception(f"RAG module not available: {e}")
if not rag_module:
raise Exception("RAG module not available")
if not rag_module or not rag_module.enabled:
raise Exception("RAG module not available or not enabled")
logger.info(f"RAG module loaded successfully for document {task.document_id}")
@@ -204,23 +213,31 @@ class DocumentProcessor:
# Process with RAG module
logger.info(f"Starting document processing for document {task.document_id} with RAG module")
# Special handling for JSONL files - skip processing phase
if document.file_type == 'jsonl':
# For JSONL files, we don't need to process content here
# The optimized JSONL processor will handle everything during indexing
document.converted_content = f"JSONL file with {len(file_content)} bytes"
document.word_count = 0 # Will be updated during indexing
document.character_count = len(file_content)
document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"}
document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow()
logger.info(f"JSONL document {task.document_id} marked for optimized processing")
else:
# Standard processing for other file types
try:
# Add timeout to prevent hanging
processed_doc = await asyncio.wait_for(
rag_module.process_document(
file_content,
document.original_filename,
{}
{"file_path": document.file_path}
),
timeout=300.0 # 5 minute timeout
)
logger.info(f"Document processing completed for document {task.document_id}")
except asyncio.TimeoutError:
logger.error(f"Document processing timed out for document {task.document_id}")
raise Exception("Document processing timed out after 5 minutes")
except Exception as e:
logger.error(f"Document processing failed for document {task.document_id}: {e}")
raise
# Update document with processed content
document.converted_content = processed_doc.content
@@ -229,6 +246,12 @@ class DocumentProcessor:
document.document_metadata = processed_doc.metadata
document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow()
except asyncio.TimeoutError:
logger.error(f"Document processing timed out for document {task.document_id}")
raise Exception("Document processing timed out after 5 minutes")
except Exception as e:
logger.error(f"Document processing failed for document {task.document_id}: {e}")
raise
# Index in RAG system using same RAG module
if rag_module and document.converted_content:
@@ -245,6 +268,49 @@ class DocumentProcessor:
}
# Use the correct Qdrant collection name for this document
# For JSONL files, we need to use the processed document flow
if document.file_type == 'jsonl':
# Create a ProcessedDocument for the JSONL processor
from app.modules.rag.main import ProcessedDocument
from datetime import datetime
import hashlib
# Calculate file hash
processed_at = datetime.utcnow()
file_hash = hashlib.md5(str(document.id).encode()).hexdigest()
processed_doc = ProcessedDocument(
id=str(document.id),
content="", # Will be filled by JSONL processor
extracted_text="", # Will be filled by JSONL processor
metadata={
**doc_metadata,
"file_path": document.file_path
},
original_filename=document.original_filename,
file_type=document.file_type,
mime_type=document.mime_type,
language=document.document_metadata.get('language', 'EN'),
word_count=0, # Will be updated during processing
sentence_count=0, # Will be updated during processing
entities=[],
keywords=[],
processing_time=0.0,
processed_at=processed_at,
file_hash=file_hash,
file_size=document.file_size
)
# The JSONL processor will read the original file
await asyncio.wait_for(
rag_module.index_processed_document(
processed_doc=processed_doc,
collection_name=document.collection.qdrant_collection_name
),
timeout=300.0 # 5 minute timeout for JSONL processing
)
else:
# Use standard indexing for other file types
await asyncio.wait_for(
rag_module.index_document(
content=document.converted_content,
@@ -271,7 +337,9 @@ class DocumentProcessor:
except Exception as e:
logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
# Keep as processed even if indexing fails
# Mark as error since indexing failed
document.status = ProcessingStatus.ERROR
document.processing_error = f"Indexing failed: {str(e)}"
# Don't raise the exception to avoid retries on indexing failures
await session.commit()

View File

@@ -28,9 +28,19 @@ class EmbeddingService:
await llm_service.initialize()
# Test LLM service health
health_summary = llm_service.get_health_summary()
if health_summary.get("service_status") != "healthy":
logger.error(f"LLM service unhealthy: {health_summary}")
if not llm_service._initialized:
logger.error("LLM service not initialized")
return False
# Check if PrivateMode provider is available
try:
provider_status = await llm_service.get_provider_status()
privatemode_status = provider_status.get("privatemode")
if not privatemode_status or privatemode_status.status != "healthy":
logger.error(f"PrivateMode provider not available: {privatemode_status}")
return False
except Exception as e:
logger.error(f"Failed to check provider status: {e}")
return False
self.initialized = True
@@ -75,6 +85,12 @@ class EmbeddingService:
else:
truncated_text = text
# Guard: skip empty inputs (validator rejects empty strings)
if not truncated_text.strip():
logger.debug("Empty input for embedding; using fallback vector")
batch_embeddings.append(self._generate_fallback_embedding(text))
continue
# Call LLM service embedding endpoint
from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest

View File

@@ -0,0 +1,211 @@
# Enhanced Embedding Service with Rate Limiting Handling
"""
Enhanced embedding service with robust rate limiting and retry logic
"""
import asyncio
import logging
import time
from typing import List, Dict, Any, Optional
import numpy as np
from datetime import datetime, timedelta
from .embedding_service import EmbeddingService
from app.core.config import settings
logger = logging.getLogger(__name__)
class EnhancedEmbeddingService(EmbeddingService):
    """Enhanced embedding service with rate limiting handling.

    Adds a per-window request counter and retry-with-delay logic on top of
    EmbeddingService. When retries are exhausted, fallback embeddings are
    returned and the failure is reported through a success flag instead of
    an exception.
    """

    def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
        """Initialise the base service and the rate-limit bookkeeping.

        Every limit is read from ``settings`` and falls back to the default
        given here when the setting is absent.
        """
        super().__init__(model_name)
        self.rate_limit_tracker = {
            'requests_count': 0,
            'window_start': time.time(),
            'window_size': 60,  # 1 minute window
            'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)),  # Configurable
            'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')],  # Exponential backoff
            'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)),
            'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)),
            'last_rate_limit_error': None  # message of the most recent rate-limit failure
        }

    async def get_embeddings_with_retry(self, texts: List[str], max_retries: Optional[int] = None) -> tuple[List[List[float]], bool]:
        """Get embeddings with rate limiting and retry logic.

        Args:
            texts: Texts to embed; they are processed in batches.
            max_retries: Retries per batch; None reads
                RAG_EMBEDDING_RETRY_COUNT from settings (default 3).

        Returns:
            ``(embeddings, success)``. ``success`` is False when the service
            is not initialised or any batch fell back to fallback embeddings.
        """
        if max_retries is None:
            max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))

        batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3))

        if not self.initialized:
            logger.warning("Embedding service not initialized, using fallback")
            return self._generate_fallback_embeddings(texts), False

        embeddings = []
        success = True

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries)
            embeddings.extend(batch_embeddings)
            success = success and batch_success

            # Add delay between batches to avoid rate limiting
            if i + batch_size < len(texts):
                delay = self.rate_limit_tracker['delay_between_batches']
                await asyncio.sleep(delay)  # Configurable delay between batches

        return embeddings, success

    async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]:
        """Get embeddings for a batch with retry logic.

        Returns:
            ``(embeddings, True)`` on success, or fallback embeddings and
            False once ``max_retries`` retries have failed.
        """
        last_error = None
        retry_delays = self.rate_limit_tracker['retry_delays']

        for attempt in range(max_retries + 1):
            try:
                # Check rate limit before making request. Waiting for the
                # window to reset must not consume a retry attempt: with
                # max_retries=0 that returned fallback vectors without ever
                # issuing a request. So wait, then fall through.
                if self._is_rate_limited():
                    delay = self._get_rate_limit_delay()
                    logger.warning(f"Rate limit detected, waiting {delay} seconds")
                    await asyncio.sleep(delay)

                # Make the request
                embeddings = await self._get_embeddings_batch_impl(texts)
                return embeddings, True

            except Exception as e:
                last_error = e
                error_msg = str(e).lower()

                # Check if it's a rate limit error
                if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']):
                    logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}")
                    self._update_rate_limit_tracker(success=False)
                    # Record the failure so get_embedding_stats() can report it.
                    self.rate_limit_tracker['last_rate_limit_error'] = str(e)

                    if attempt < max_retries:
                        # Past the end of the delay list, keep using the last delay.
                        delay = retry_delays[min(attempt, len(retry_delays) - 1)]
                        logger.info(f"Retrying in {delay} seconds...")
                        await asyncio.sleep(delay)
                        continue
                    else:
                        logger.error("Max retries exceeded for rate limit, using fallback embeddings")
                        return self._generate_fallback_embeddings(texts), False
                else:
                    # Non-rate-limit error
                    logger.error(f"Error generating embeddings: {e}")
                    if attempt < max_retries:
                        delay = retry_delays[min(attempt, len(retry_delays) - 1)]
                        await asyncio.sleep(delay)
                    else:
                        logger.error("Max retries exceeded, using fallback embeddings")
                        return self._generate_fallback_embeddings(texts), False

        # If we get here, all retries failed
        logger.error(f"All retries failed, last error: {last_error}")
        return self._generate_fallback_embeddings(texts), False

    async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]:
        """Request one embedding per text from the LLM service.

        Raises:
            ValueError: If a response has no data or an empty embedding.
        """
        from app.services.llm.service import llm_service
        from app.services.llm.models import EmbeddingRequest

        embeddings = []
        for text in texts:
            # Truncate text if needed
            max_chars = 1600
            truncated_text = text[:max_chars] if len(text) > max_chars else text

            # Guard: skip empty inputs, as the base service does (the request
            # validator rejects empty strings). No request is made, so this
            # neither waits on nor counts towards the rate limit.
            if not truncated_text.strip():
                logger.debug("Empty input for embedding; using fallback vector")
                embeddings.append(self._generate_fallback_embedding(text))
                continue

            # Respect rate limit before each request
            while self._is_rate_limited():
                delay = self._get_rate_limit_delay()
                logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request")
                await asyncio.sleep(delay)

            llm_request = EmbeddingRequest(
                model=self.model_name,
                input=truncated_text,
                user_id="rag_system",
                api_key_id=0
            )

            response = await llm_service.create_embedding(llm_request)

            if response.data and len(response.data) > 0:
                embedding = response.data[0].embedding
                if embedding:
                    embeddings.append(embedding)
                    # Adopt the dimension of the first real embedding received.
                    if not hasattr(self, '_dimension_confirmed'):
                        self.dimension = len(embedding)
                        self._dimension_confirmed = True
                else:
                    raise ValueError("Empty embedding in response")
            else:
                raise ValueError("Invalid response structure")

            # Count this successful request and optionally delay between requests
            self._update_rate_limit_tracker(success=True)
            per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0)
            if per_req_delay and per_req_delay > 0:
                await asyncio.sleep(per_req_delay)

        return embeddings

    def _is_rate_limited(self) -> bool:
        """Check if we're currently rate limited.

        Side effect: starts a new window (count reset to 0) when the
        current one has expired.
        """
        now = time.time()
        window_start = self.rate_limit_tracker['window_start']

        # Reset window if it's expired
        if now - window_start > self.rate_limit_tracker['window_size']:
            self.rate_limit_tracker['requests_count'] = 0
            self.rate_limit_tracker['window_start'] = now
            return False

        # Check if we've exceeded the limit
        return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute']

    def _get_rate_limit_delay(self) -> float:
        """Get delay (seconds, never negative) until the current window ends."""
        now = time.time()
        window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size']
        return max(0, window_end - now)

    def _update_rate_limit_tracker(self, success: bool):
        """Update the rate limit tracker.

        Starts a new window if the current one expired; only successful
        requests increment the counter.
        """
        now = time.time()

        # Reset window if it's expired
        if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']:
            self.rate_limit_tracker['requests_count'] = 0
            self.rate_limit_tracker['window_start'] = now

        # Increment counter on successful requests
        if success:
            self.rate_limit_tracker['requests_count'] += 1

    async def get_embedding_stats(self) -> Dict[str, Any]:
        """Get embedding service statistics including rate limiting info.

        Returns:
            The base service's stats plus a "rate_limit_info" entry.
        """
        base_stats = await self.get_stats()

        return {
            **base_stats,
            "rate_limit_info": {
                "requests_in_current_window": self.rate_limit_tracker['requests_count'],
                "max_requests_per_minute": self.rate_limit_tracker['max_requests_per_minute'],
                "window_reset_in_seconds": max(0,
                    self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - time.time()
                ),
                "last_rate_limit_error": self.rate_limit_tracker['last_rate_limit_error']
            }
        }
# Global enhanced embedding service instance, constructed at import time
# with the default model name.
enhanced_embedding_service = EnhancedEmbeddingService()

View File

@@ -0,0 +1,211 @@
"""
Optimized JSONL Processor for RAG Module
Handles JSONL files efficiently to prevent resource exhaustion
"""
import json
import logging
import asyncio
from typing import Dict, Any, List
from datetime import datetime
import uuid
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue
from qdrant_client.http.models import Batch
from app.modules.rag.main import ProcessedDocument
# from app.core.analytics import log_module_event # Analytics module not available
logger = logging.getLogger(__name__)
class JSONLProcessor:
"""Specialized processor for JSONL files"""
def __init__(self, rag_module):
self.rag_module = rag_module
self.config = rag_module.config
async def process_and_index_jsonl(self, collection_name: str, content: bytes,
                                  filename: str, metadata: Dict[str, Any]) -> str:
    """Index a JSONL file batch by batch.

    The content is decoded as UTF-8 (undecodable bytes are replaced),
    split on newlines and handed to ``_process_jsonl_batch`` ten lines
    at a time, with a short pause after every batch.

    Args:
        collection_name: Collection the batches are indexed into.
        content: Raw bytes of the JSONL file.
        filename: Name of the file, used in logs and passed to each batch.
        metadata: Metadata passed to each batch and to ID generation.

    Returns:
        The base document ID generated for the whole file.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        # Decode content
        text = content.decode('utf-8', errors='replace')
        records = text.strip().split('\n')
        total = len(records)

        logger.info(f"Processing JSONL file {filename} with {total} lines")

        # Generate base document ID
        base_doc_id = self.rag_module._generate_document_id(text, metadata)

        # Smaller batches for better memory management
        batch_size = 10
        done = 0

        for offset in range(0, total, batch_size):
            # Slicing past the end is safe, so no explicit upper bound is needed.
            chunk = records[offset:offset + batch_size]

            await self._process_jsonl_batch(
                collection_name,
                chunk,
                offset,
                base_doc_id,
                filename,
                metadata
            )

            done += len(chunk)

            # Log progress
            if done % 50 == 0:
                logger.info(f"Processed {done}/{total} lines from {filename}")

            # Small delay to prevent resource exhaustion
            await asyncio.sleep(0.05)

        logger.info(f"Successfully processed JSONL file {filename} with {total} lines")
        return base_doc_id

    except Exception as e:
        logger.error(f"Error processing JSONL file {filename}: {e}")
        raise
async def _process_jsonl_batch(self, collection_name: str, lines: List[str],
start_idx: int, base_doc_id: str,
filename: str, metadata: Dict[str, Any]) -> None:
"""Process a batch of JSONL lines"""
try:
points = []
for line_idx, line in enumerate(lines, start=start_idx + 1):
if not line.strip():
continue
try:
# Parse JSON line
data = json.loads(line)
# Debug: check if data is None
if data is None:
logger.warning(f"JSON line {line_idx} parsed as None")
continue
# Handle helpjuice export format
if 'payload' in data and data['payload'] is not None:
payload = data['payload']
article_id = data.get('id', f'article_{line_idx}')
# Extract Q&A
question = payload.get('question', '')
answer = payload.get('answer', '')
language = payload.get('language', 'EN')
if question or answer:
# Create Q&A content
content = f"Question: {question}\n\nAnswer: {answer}"
# Create metadata
doc_metadata = {
**metadata,
"article_id": article_id,
"language": language,
"filename": filename,
"line_number": line_idx,
"content_type": "qa_pair",
"question": question[:100], # Truncate for metadata
"processed_at": datetime.utcnow().isoformat()
}
# Generate single embedding for the Q&A pair
embeddings = await self.rag_module._generate_embeddings([content])
# Create point
point_id = str(uuid.uuid4())
points.append(PointStruct(
id=point_id,
vector=embeddings[0],
payload={
**doc_metadata,
"document_id": f"{base_doc_id}_{article_id}",
"content": content,
"chunk_index": 0,
"chunk_count": 1
}
))
# Handle generic JSON format
else:
content = json.dumps(data, indent=2, ensure_ascii=False)
# For larger JSON objects, we might need to chunk
if len(content) > 1000:
chunks = self.rag_module._chunk_text(content, chunk_size=500)
embeddings = await self.rag_module._generate_embeddings(chunks)
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
point_id = str(uuid.uuid4())
points.append(PointStruct(
id=point_id,
vector=embedding,
payload={
**metadata,
"filename": filename,
"line_number": line_idx,
"content_type": "json_object",
"document_id": f"{base_doc_id}_line_{line_idx}",
"content": chunk,
"chunk_index": i,
"chunk_count": len(chunks)
}
))
else:
# Small JSON - no chunking needed
embeddings = await self.rag_module._generate_embeddings([content])
point_id = str(uuid.uuid4())
points.append(PointStruct(
id=point_id,
vector=embeddings[0],
payload={
**metadata,
"filename": filename,
"line_number": line_idx,
"content_type": "json_object",
"document_id": f"{base_doc_id}_line_{line_idx}",
"content": content,
"chunk_index": 0,
"chunk_count": 1
}
))
except json.JSONDecodeError as e:
logger.warning(f"Error parsing JSONL line {line_idx}: {e}")
continue
except Exception as e:
logger.warning(f"Error processing JSONL line {line_idx}: {e}")
continue
# Insert all points in this batch
if points:
self.rag_module.qdrant_client.upsert(
collection_name=collection_name,
points=points
)
# Update stats
self.rag_module.stats["documents_indexed"] += len(points)
# log_module_event("rag", "jsonl_batch_processed", { # Analytics module not available
# "filename": filename,
# "lines_processed": len(lines),
# "points_created": len(points)
# })
except Exception as e:
logger.error(f"Error processing JSONL batch: {e}")
raise

View File

@@ -16,6 +16,7 @@ from .models import ResilienceConfig
class ProviderConfig(BaseModel):
"""Configuration for an LLM provider"""
name: str = Field(..., description="Provider name")
provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')")
enabled: bool = Field(True, description="Whether provider is enabled")
base_url: str = Field(..., description="Provider base URL")
api_key_env_var: str = Field(..., description="Environment variable for API key")
@@ -53,9 +54,6 @@ class LLMServiceConfig(BaseModel):
enable_security_checks: bool = Field(True, description="Enable security validation")
enable_metrics_collection: bool = Field(True, description="Enable metrics collection")
# Security settings
security_risk_threshold: float = Field(0.8, ge=0.0, le=1.0, description="Risk threshold for blocking")
security_warning_threshold: float = Field(0.6, ge=0.0, le=1.0, description="Risk threshold for warnings")
max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
max_response_length: int = Field(32000, ge=1000, description="Maximum response length")
@@ -66,15 +64,18 @@ class LLMServiceConfig(BaseModel):
# Provider configurations
providers: Dict[str, ProviderConfig] = Field(default_factory=dict, description="Provider configurations")
# Token rate limiting (organization-wide)
token_limits_per_minute: Dict[str, int] = Field(
default_factory=lambda: {
"prompt_tokens": 20000, # PrivateMode Standard tier
"completion_tokens": 10000 # PrivateMode Standard tier
},
description="Token rate limits per minute (organization-wide)"
)
# Model routing (model_name -> provider_name)
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
@validator('security_risk_threshold')
def validate_risk_threshold(cls, v, values):
warning_threshold = values.get('security_warning_threshold', 0.6)
if v <= warning_threshold:
raise ValueError("Risk threshold must be greater than warning threshold")
return v
def create_default_config() -> LLMServiceConfig:
@@ -84,6 +85,7 @@ def create_default_config() -> LLMServiceConfig:
# Models will be fetched dynamically from proxy /models endpoint
privatemode_config = ProviderConfig(
name="privatemode",
provider_type="privatemode",
enabled=True,
base_url=settings.PRIVATEMODE_PROXY_URL,
api_key_env_var="PRIVATEMODE_API_KEY",
@@ -91,8 +93,8 @@ def create_default_config() -> LLMServiceConfig:
supported_models=[], # Will be populated dynamically from proxy
capabilities=["chat", "embeddings", "tee"],
priority=1,
max_requests_per_minute=100,
max_requests_per_hour=2000,
max_requests_per_minute=20, # PrivateMode Standard tier limit: 20 req/min
max_requests_per_hour=1200, # 20 req/min * 60 min
supports_streaming=True,
supports_function_calling=True,
max_context_window=128000,
@@ -110,9 +112,6 @@ def create_default_config() -> LLMServiceConfig:
config = LLMServiceConfig(
default_provider="privatemode",
enable_detailed_logging=settings.LOG_LLM_PROMPTS,
enable_security_checks=settings.API_SECURITY_ENABLED,
security_risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
security_warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
providers={
"privatemode": privatemode_config
},

View File

@@ -124,7 +124,6 @@ class MetricsCollector:
total_requests = len(self._metrics)
successful_requests = sum(1 for m in self._metrics if m.success)
failed_requests = total_requests - successful_requests
security_blocked = sum(1 for m in self._metrics if not m.success and m.security_risk_score > 0.8)
# Calculate averages
latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0]
@@ -143,7 +142,6 @@ class MetricsCollector:
total_requests=total_requests,
successful_requests=successful_requests,
failed_requests=failed_requests,
security_blocked_requests=security_blocked,
average_latency_ms=avg_latency,
average_risk_score=avg_risk_score,
provider_metrics=provider_metrics,

View File

@@ -452,6 +452,8 @@ class PrivateModeProvider(BaseLLMProvider):
else:
error_text = await response.text()
# Log the detailed error response from the provider
logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
self._handle_http_error(response.status, error_text, "embeddings")
except aiohttp.ClientError as e:

View File

@@ -1,281 +0,0 @@
"""
LLM Security Manager
Handles prompt injection detection and audit logging.
Provides comprehensive security for LLM interactions.
"""
import os
import re
import json
import logging
import hashlib
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from app.core.config import settings
logger = logging.getLogger(__name__)
class SecurityManager:
    """Manages security for LLM operations.

    Scores chat messages against a list of regular expressions meant to spot
    prompt-injection attempts, builds hashed audit entries, and optionally
    logs request/response previews when ``settings.LOG_LLM_PROMPTS`` is set.
    """

    # Positions in ``injection_patterns`` scored with base risk 0.8 / 0.5;
    # every other position scores 0.3.
    # NOTE(review): these index lists do not line up with the category comments
    # in the pattern list (e.g. the jailbreak patterns at the end of the list
    # fall into the 0.3 bucket) - confirm the intended weighting.
    _HIGH_RISK_PATTERNS = frozenset({0, 1, 2, 3, 4, 5, 6, 7, 14, 15, 16, 22, 23, 24})
    _MEDIUM_RISK_PATTERNS = frozenset({8, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21})
    # Number of characters of message/response content written to the log.
    _LOG_PREVIEW_CHARS = 500

    def __init__(self):
        self._setup_prompt_injection_patterns()

    def _setup_prompt_injection_patterns(self):
        """Setup patterns for prompt injection detection"""
        self.injection_patterns = [
            # Direct instruction injection
            r"(?i)(ignore|forget|disregard|override)\s+(previous|all|above|prior)\s+(instructions|rules|prompts)",
            r"(?i)(new|updated|different)\s+(instructions|rules|system)",
            r"(?i)act\s+as\s+(if|though)\s+you\s+(are|were)",
            r"(?i)pretend\s+(to\s+be|you\s+are)",
            r"(?i)you\s+are\s+now\s+(a|an)\s+",
            # System role manipulation
            r"(?i)system\s*:\s*",
            r"(?i)\[system\]",
            r"(?i)<system>",
            r"(?i)assistant\s*:\s*",
            r"(?i)\[assistant\]",
            # Escape attempts
            r"(?i)\\n\\n#+",
            r"(?i)```\s*(system|assistant|user)",
            r"(?i)---\s*(new|system|override)",
            # Role manipulation
            r"(?i)(you|your)\s+(role|purpose|function)\s+(is|has\s+changed)",
            r"(?i)switch\s+to\s+(admin|developer|debug)\s+mode",
            r"(?i)(admin|root|sudo|developer)\s+(access|mode|privileges)",
            # Information extraction attempts
            r"(?i)(show|display|reveal|expose)\s+(your|the)\s+(prompt|instructions|system)",
            r"(?i)what\s+(are|were)\s+your\s+(original|initial)\s+(instructions|prompts)",
            r"(?i)(debug|verbose|diagnostic)\s+mode",
            # Encoding/obfuscation attempts
            r"(?i)base64\s*:",
            r"(?i)hex\s*:",
            r"(?i)unicode\s*:",
            r"[A-Za-z0-9+/]{20,}={0,2}",  # Potential base64
            # SQL injection patterns (for system prompts)
            r"(?i)(union|select|insert|update|delete|drop|create)\s+",
            r"(?i)(or|and)\s+1\s*=\s*1",
            r"(?i)';?\s*(drop|delete|insert)",
            # Command injection patterns
            r"(?i)(exec|eval|system|shell|cmd)\s*\(",
            r"(?i)(\$\(|\`)[^)]+(\)|\`)",
            r"(?i)&&\s*(rm|del|format)",
            # Jailbreak attempts
            r"(?i)jailbreak",
            r"(?i)break\s+out\s+of",
            r"(?i)escape\s+(the|your)\s+(rules|constraints)",
            r"(?i)(DAN|Do\s+Anything\s+Now)",
            r"(?i)unrestricted\s+mode",
        ]
        self.compiled_patterns = [re.compile(pattern) for pattern in self.injection_patterns]
        logger.info(f"Initialized {len(self.injection_patterns)} prompt injection patterns")

    def validate_prompt_security(
        self, messages: List[Dict[str, str]]
    ) -> Tuple[bool, float, List[Dict[str, Any]]]:
        """
        Validate messages for prompt injection attempts

        Args:
            messages: Chat messages; only the ``content`` value of each is inspected.

        Returns:
            Tuple ``(is_safe, risk_score, detected_patterns)`` where
            ``risk_score`` is the summed risk averaged over the number of
            messages and capped at 1.0, ``is_safe`` is True when the score is
            below ``settings.API_SECURITY_RISK_THRESHOLD``, and each detected
            pattern is a dict with ``pattern_index``, ``pattern``, ``matches``
            and ``risk``.
        """
        detected_patterns = []
        total_risk = 0.0

        for message in messages:
            content = message.get("content", "")
            if not content:
                continue
            if not isinstance(content, str):
                # Structured (non-string) content is scanned via its string
                # form instead of raising TypeError inside the regex engine.
                content = str(content)

            # Check against injection patterns
            for i, pattern in enumerate(self.compiled_patterns):
                matches = pattern.findall(content)
                if matches:
                    pattern_risk = self._calculate_pattern_risk(i, matches)
                    total_risk += pattern_risk
                    detected_patterns.append({
                        "pattern_index": i,
                        "pattern": self.injection_patterns[i],
                        "matches": matches,
                        "risk": pattern_risk
                    })

            # Additional security checks
            total_risk += self._check_message_characteristics(content)

        # Normalize risk score (0.0 to 1.0)
        risk_score = min(total_risk / len(messages) if messages else 0.0, 1.0)
        is_safe = risk_score < settings.API_SECURITY_RISK_THRESHOLD

        if detected_patterns:
            logger.warning(f"Detected {len(detected_patterns)} potential injection patterns, risk score: {risk_score}")

        return is_safe, risk_score, detected_patterns

    def _calculate_pattern_risk(self, pattern_index: int, matches: List) -> float:
        """Calculate risk score for a detected pattern.

        The base risk (0.8 / 0.5 / 0.3) depends on the pattern's position;
        each match beyond the first adds 20%, capped at twice the base risk.
        """
        if pattern_index in self._HIGH_RISK_PATTERNS:
            base_risk = 0.8
        elif pattern_index in self._MEDIUM_RISK_PATTERNS:
            base_risk = 0.5
        else:
            base_risk = 0.3

        # Increase risk based on number of matches
        match_multiplier = min(1.0 + (len(matches) - 1) * 0.2, 2.0)
        return base_risk * match_multiplier

    def _check_message_characteristics(self, content: str) -> float:
        """Check message characteristics for additional risk factors"""
        risk = 0.0

        # Excessive length (potential stuffing attack)
        if len(content) > 10000:
            risk += 0.3

        # High ratio of special characters
        special_chars = sum(1 for c in content if not c.isalnum() and not c.isspace())
        if len(content) > 0 and special_chars / len(content) > 0.5:
            risk += 0.4

        # Multiple encoding indicators
        encoding_indicators = ["base64", "hex", "unicode", "url", "ascii"]
        lowered = content.lower()
        found_encodings = sum(1 for indicator in encoding_indicators if indicator in lowered)
        if found_encodings > 1:
            risk += 0.3

        # Excessive newlines or formatting (potential formatting attacks)
        if content.count('\n') > 50 or content.count('\\n') > 50:
            risk += 0.2

        return risk

    def create_audit_log(
        self,
        user_id: str,
        api_key_id: int,
        provider: str,
        model: str,
        request_type: str,
        risk_score: float,
        detected_patterns: List[Any],
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Create comprehensive audit log for LLM request.

        Builds the audit entry, attaches a SHA-256 hash of its key fields and
        logs it at error / warning / info level depending on how
        ``risk_score`` compares with the configured thresholds.

        Returns:
            The audit entry dict, including ``audit_hash``.
        """
        audit_entry = {
            "timestamp": datetime.utcnow().isoformat(),
            "user_id": user_id,
            "api_key_id": api_key_id,
            "provider": provider,
            "model": model,
            "request_type": request_type,
            "security": {
                "risk_score": risk_score,
                "detected_patterns": detected_patterns,
                "security_check_passed": risk_score < settings.API_SECURITY_RISK_THRESHOLD
            },
            "metadata": metadata or {},
            "audit_hash": None  # Will be set below
        }

        # Create hash for audit integrity
        audit_entry["audit_hash"] = self._create_audit_hash(audit_entry)

        # Log based on risk level. default=str keeps logging from raising when
        # metadata carries values json cannot serialize natively.
        if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
            logger.error(f"HIGH RISK LLM REQUEST BLOCKED: {json.dumps(audit_entry, default=str)}")
        elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
            logger.warning(f"MEDIUM RISK LLM REQUEST: {json.dumps(audit_entry, default=str)}")
        else:
            logger.info(f"LLM REQUEST AUDIT: user={user_id}, model={model}, risk={risk_score:.3f}")

        return audit_entry

    def _create_audit_hash(self, audit_entry: Dict[str, Any]) -> str:
        """Create hash for audit trail integrity.

        Returns the SHA-256 hex digest of the key-sorted JSON form of the
        entry's key fields (the hash itself, detected patterns and metadata
        are not part of the input).
        """
        hash_data = {
            "timestamp": audit_entry["timestamp"],
            "user_id": audit_entry["user_id"],
            "api_key_id": audit_entry["api_key_id"],
            "provider": audit_entry["provider"],
            "model": audit_entry["model"],
            "request_type": audit_entry["request_type"],
            "risk_score": audit_entry["security"]["risk_score"]
        }
        hash_string = json.dumps(hash_data, sort_keys=True)
        return hashlib.sha256(hash_string.encode()).hexdigest()

    def log_detailed_request(
        self,
        messages: List[Dict[str, str]],
        model: str,
        user_id: str,
        provider: str,
        context_info: Optional[Dict[str, Any]] = None
    ):
        """Log detailed LLM request if LOG_LLM_PROMPTS is enabled"""
        if not settings.LOG_LLM_PROMPTS:
            return

        logger.info("=== DETAILED LLM REQUEST ===")
        logger.info(f"Model: {model}")
        logger.info(f"Provider: {provider}")
        logger.info(f"User ID: {user_id}")

        if context_info:
            for key, value in context_info.items():
                logger.info(f"{key}: {value}")

        logger.info("Messages to LLM:")
        limit = self._LOG_PREVIEW_CHARS
        for i, message in enumerate(messages):
            role = message.get("role", "unknown")
            # A message may carry content=None; treat it as empty text instead
            # of failing on None[:500].
            content = message.get("content") or ""
            if not isinstance(content, str):
                content = str(content)
            preview = content[:limit]  # Truncate for logging
            suffix = '...' if len(content) > limit else ''
            logger.info(f" Message {i+1} [{role}]: {preview}{suffix}")

        logger.info("=== END DETAILED LLM REQUEST ===")

    def log_detailed_response(
        self,
        response_content: str,
        token_usage: Optional[Dict[str, int]] = None,
        provider: str = "unknown"
    ):
        """Log detailed LLM response if LOG_LLM_PROMPTS is enabled"""
        if not settings.LOG_LLM_PROMPTS:
            return

        # Guard against a None response so logging never raises.
        response_content = response_content or ""
        limit = self._LOG_PREVIEW_CHARS

        logger.info("=== DETAILED LLM RESPONSE ===")
        logger.info(f"Provider: {provider}")
        logger.info(f"Response content: {response_content[:limit]}{'...' if len(response_content) > limit else ''}")

        if token_usage:
            logger.info(f"Token usage - Prompt: {token_usage.get('prompt_tokens', 0)}, "
                        f"Completion: {token_usage.get('completion_tokens', 0)}, "
                        f"Total: {token_usage.get('total_tokens', 0)}")

        logger.info("=== END DETAILED LLM RESPONSE ===")
class SecurityError(Exception):
    """Exception type for security-related errors in LLM operations."""
# Global security manager instance, created once when this module is imported
security_manager = SecurityManager()

View File

@@ -16,9 +16,10 @@ from .models import (
ModelInfo, ProviderStatus, LLMMetrics
)
from .config import config_manager, ProviderConfig
# Security service removed as requested
from ...core.config import settings
from .resilience import ResilienceManagerFactory
from .metrics import metrics_collector
# from .metrics import metrics_collector
from .providers import BaseLLMProvider, PrivateModeProvider
from .exceptions import (
LLMError, ProviderError, SecurityError, ConfigurationError,
@@ -149,8 +150,7 @@ class LLMService:
if not request.messages:
raise ValidationError("Messages cannot be empty", field="messages")
# Security validation removed as requested
messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]
risk_score = 0.0
# Get provider for model
provider_name = self._get_provider_for_model(request.model)
@@ -159,7 +159,6 @@ class LLMService:
if not provider:
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
# Security logging removed as requested
# Execute with resilience
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
@@ -173,49 +172,20 @@ class LLMService:
non_retryable_exceptions=(ValidationError,)
)
# Set default security values since security is removed
response.security_check = True
response.risk_score = 0.0
response.detected_patterns = []
# Security logging removed as requested
# Record successful request
# Record successful request - metrics disabled
total_latency = (time.time() - start_time) * 1000
metrics_collector.record_request(
provider=provider_name,
model=request.model,
request_type="chat_completion",
success=True,
latency_ms=total_latency,
token_usage=response.usage.model_dump() if response.usage else None,
# security_risk_score removed as requested
user_id=request.user_id,
api_key_id=request.api_key_id
)
# Security audit logging removed as requested
return response
except Exception as e:
# Record failed request
# Record failed request - metrics disabled
total_latency = (time.time() - start_time) * 1000
error_code = getattr(e, 'error_code', e.__class__.__name__)
metrics_collector.record_request(
provider=provider_name,
model=request.model,
request_type="chat_completion",
success=False,
latency_ms=total_latency,
# security_risk_score removed as requested
error_code=error_code,
user_id=request.user_id,
api_key_id=request.api_key_id
)
# Security audit logging removed as requested
raise
@@ -224,8 +194,9 @@ class LLMService:
if not self._initialized:
await self.initialize()
# Security validation removed as requested
messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]
# Security validation disabled - always allow streaming requests
risk_score = 0.0
# Get provider
provider_name = self._get_provider_for_model(request.model)
@@ -247,19 +218,8 @@ class LLMService:
yield chunk
except Exception as e:
# Record streaming failure
# Record streaming failure - metrics disabled
error_code = getattr(e, 'error_code', e.__class__.__name__)
metrics_collector.record_request(
provider=provider_name,
model=request.model,
request_type="chat_completion_stream",
success=False,
latency_ms=0,
# security_risk_score removed as requested
error_code=error_code,
user_id=request.user_id,
api_key_id=request.api_key_id
)
raise
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
@@ -267,7 +227,9 @@ class LLMService:
if not self._initialized:
await self.initialize()
# Security validation removed as requested
# Security validation disabled - always allow embedding requests
risk_score = 0.0
# Get provider
provider_name = self._get_provider_for_model(request.model)
@@ -288,43 +250,17 @@ class LLMService:
non_retryable_exceptions=(ValidationError,)
)
# Set default security values since security is removed
response.security_check = True
response.risk_score = 0.0
# Record successful request
# Record successful request - metrics disabled
total_latency = (time.time() - start_time) * 1000
metrics_collector.record_request(
provider=provider_name,
model=request.model,
request_type="embedding",
success=True,
latency_ms=total_latency,
token_usage=response.usage.model_dump() if response.usage else None,
# security_risk_score removed as requested
user_id=request.user_id,
api_key_id=request.api_key_id
)
return response
except Exception as e:
# Record failed request
# Record failed request - metrics disabled
total_latency = (time.time() - start_time) * 1000
error_code = getattr(e, 'error_code', e.__class__.__name__)
metrics_collector.record_request(
provider=provider_name,
model=request.model,
request_type="embedding",
success=False,
latency_ms=total_latency,
# security_risk_score removed as requested
error_code=error_code,
user_id=request.user_id,
api_key_id=request.api_key_id
)
raise
async def get_models(self, provider_name: Optional[str] = None) -> List[ModelInfo]:
@@ -378,12 +314,18 @@ class LLMService:
return status_dict
def get_metrics(self) -> LLMMetrics:
"""Get service metrics"""
return metrics_collector.get_metrics()
"""Get service metrics - metrics disabled"""
# return metrics_collector.get_metrics()
return LLMMetrics(
total_requests=0,
success_rate=0.0,
avg_latency_ms=0,
error_rates={}
)
def get_health_summary(self) -> Dict[str, Any]:
"""Get comprehensive health summary"""
metrics_health = metrics_collector.get_health_summary()
"""Get comprehensive health summary - metrics disabled"""
# metrics_health = metrics_collector.get_health_summary()
resilience_health = ResilienceManagerFactory.get_all_health_status()
return {
@@ -391,7 +333,7 @@ class LLMService:
"startup_time": self._startup_time.isoformat() if self._startup_time else None,
"provider_count": len(self._providers),
"active_providers": list(self._providers.keys()),
"metrics": metrics_health,
"metrics": {"status": "disabled"},
"resilience": resilience_health
}

View File

@@ -0,0 +1,163 @@
"""
Qdrant Stats Service
Provides direct, live statistics from Qdrant vector database
This is the single source of truth for all RAG collection statistics
"""
import httpx
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime
from app.core.config import settings
logger = logging.getLogger(__name__)
class QdrantStatsService:
    """Service for getting live statistics from Qdrant.

    Talks to the Qdrant HTTP API directly and derives per-collection point
    counts and an estimated storage size.
    """

    # Vector size assumed when the collection config does not expose one
    # (default for multilingual-e5-large).
    DEFAULT_VECTOR_SIZE = 1024
    # Timeout in seconds for every HTTP request to Qdrant.
    REQUEST_TIMEOUT = 10.0

    def __init__(self):
        self.qdrant_host = getattr(settings, 'QDRANT_HOST', 'enclava-qdrant')
        self.qdrant_port = getattr(settings, 'QDRANT_PORT', 6333)
        self.qdrant_url = f"http://{self.qdrant_host}:{self.qdrant_port}"

    @staticmethod
    def _empty_stats() -> Dict[str, Any]:
        """Return the result used when no statistics could be fetched."""
        return {"collections": [], "total_documents": 0, "total_size_bytes": 0, "total_collections": 0}

    @classmethod
    def _extract_vector_size(cls, detail_result: Dict[str, Any]) -> int:
        """Read the vector size from a collection-info ``result`` object.

        Supports both ``vectors: {"size": N}`` and the named-vector layout
        ``vectors: {"default": {"size": N}}``; anything else falls back to
        ``DEFAULT_VECTOR_SIZE``.
        """
        config = detail_result.get("config") or {}
        params = config.get("params") or {} if isinstance(config, dict) else {}
        vectors = params.get("vectors") if isinstance(params, dict) else None
        if isinstance(vectors, dict):
            size = vectors.get("size")
            if size is None and isinstance(vectors.get("default"), dict):
                size = vectors["default"].get("size")
            if isinstance(size, int) and size > 0:
                return size
        return cls.DEFAULT_VECTOR_SIZE

    @staticmethod
    def _estimate_size(points_count: int, vector_size: int) -> int:
        """Estimate size (points * vector_size * 4 bytes + 20% metadata overhead)."""
        return int(points_count * vector_size * 4 * 1.2)

    @staticmethod
    def _display_name(collection_name: str) -> str:
        """Derive a user-friendly name from a ``rag_<name>_<uuid8>`` collection name.

        Strips the ``rag_`` prefix and every 8-character lowercase-hex segment,
        then title-cases the remaining words. Names without the prefix, with a
        single segment, or with nothing left after stripping are returned
        unchanged.
        """
        if not collection_name.startswith("rag_"):
            return collection_name
        parts = collection_name[4:].split("_")
        if len(parts) <= 1:
            return collection_name
        words = [
            p for p in parts
            if not (len(p) == 8 and all(c in '0123456789abcdef' for c in p))
        ]
        display_name = " ".join(words).title()
        return display_name or collection_name

    async def get_collections_stats(self) -> Dict[str, Any]:
        """Get live collection statistics directly from Qdrant.

        Returns:
            Dict with ``collections`` (one stats dict per collection),
            ``total_documents``, ``total_size_bytes`` and ``total_collections``.
            On any failure the same keys are returned with empty/zero values.
        """
        try:
            async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
                # Get all collections
                response = await client.get(f"{self.qdrant_url}/collections")
                if response.status_code != 200:
                    logger.error(f"Failed to get collections: {response.status_code}")
                    return self._empty_stats()

                result = response.json().get("result") or {}
                collections_data = result.get("collections") or []

                collections = []
                total_documents = 0
                total_size_bytes = 0

                # Get detailed info for each collection (all of them, not just rag_ ones)
                for col_info in collections_data:
                    collection_name = col_info.get("name", "")
                    try:
                        detail_response = await client.get(
                            f"{self.qdrant_url}/collections/{collection_name}"
                        )
                        if detail_response.status_code != 200:
                            logger.warning(
                                f"Skipping collection {collection_name}: "
                                f"status {detail_response.status_code}"
                            )
                            continue

                        detail_result = detail_response.json().get("result") or {}
                        # points_count may be present but null; count it as 0.
                        points_count = detail_result.get("points_count") or 0
                        status = detail_result.get("status", "unknown")
                        estimated_size = self._estimate_size(
                            points_count, self._extract_vector_size(detail_result)
                        )

                        collections.append({
                            "id": collection_name,
                            "name": self._display_name(collection_name),
                            "description": "",
                            "document_count": points_count,
                            "vector_count": points_count,
                            "size_bytes": estimated_size,
                            "status": status,
                            "qdrant_collection_name": collection_name,
                            "created_at": "",  # Not available from Qdrant
                            "updated_at": datetime.utcnow().isoformat(),
                            "is_active": status == "green",
                            "is_managed": True,
                            "source": "qdrant"
                        })
                        total_documents += points_count
                        total_size_bytes += estimated_size
                    except Exception as e:
                        # One broken collection must not hide the others.
                        logger.error(f"Error getting details for collection {collection_name}: {e}")
                        continue

                return {
                    "collections": collections,
                    "total_documents": total_documents,
                    "total_size_bytes": total_size_bytes,
                    "total_collections": len(collections)
                }
        except Exception as e:
            logger.error(f"Error getting Qdrant stats: {e}")
            return self._empty_stats()

    async def get_collection_stats(self, collection_name: str) -> Optional[Dict[str, Any]]:
        """Get statistics for a specific collection.

        Returns:
            Dict with ``document_count``, ``vector_count``, ``size_bytes`` and
            ``status``, or ``None`` when the request fails or does not return 200.
        """
        try:
            async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
                response = await client.get(f"{self.qdrant_url}/collections/{collection_name}")
                if response.status_code != 200:
                    return None

                result = response.json().get("result") or {}
                points_count = result.get("points_count") or 0
                status = result.get("status", "unknown")
                estimated_size = self._estimate_size(
                    points_count, self._extract_vector_size(result)
                )

                return {
                    "document_count": points_count,
                    "vector_count": points_count,
                    "size_bytes": estimated_size,
                    "status": status
                }
        except Exception as e:
            logger.error(f"Error getting collection stats for {collection_name}: {e}")
            return None
# Global instance, created once when this module is imported
qdrant_stats_service = QdrantStatsService()

View File

@@ -755,10 +755,11 @@ class RAGService:
# Process with RAG module
try:
# Pass file_path in metadata so JSONL indexing can reopen the source file
processed_doc = await rag_module.process_document(
file_content,
document.original_filename,
{}
{"file_path": document.file_path}
)
# Success case - update document with processed content

View File

@@ -69,6 +69,7 @@ class ChatbotConfig:
memory_length: int = 10 # Number of previous messages to remember
use_rag: bool = False
rag_top_k: int = 5
rag_score_threshold: float = 0.02 # Lowered from default 0.3 to allow more results
fallback_responses: List[str] = None
def __post_init__(self):
@@ -386,7 +387,8 @@ class ChatbotModule(BaseModule):
rag_results = await self.rag_module.search_documents(
query=message,
max_results=config.rag_top_k,
collection_name=qdrant_collection_name
collection_name=qdrant_collection_name,
score_threshold=config.rag_score_threshold
)
if rag_results:
@@ -395,8 +397,8 @@ class ChatbotModule(BaseModule):
for i, result in enumerate(rag_results)]
# Build full RAG context from all results
rag_context = "\\n\\nRelevant information from knowledge base:\\n" + "\\n\\n".join([
f"[Document {i+1}]:\\n{result.document.content}" for i, result in enumerate(rag_results)
rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([
f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results)
])
# Detailed RAG logging - ALWAYS log for debugging
@@ -405,14 +407,14 @@ class ChatbotModule(BaseModule):
logger.info(f"Collection: {qdrant_collection_name}")
logger.info(f"Number of results: {len(rag_results)}")
for i, result in enumerate(rag_results):
logger.info(f"\\n--- RAG Result {i+1} ---")
logger.info(f"\n--- RAG Result {i+1} ---")
logger.info(f"Score: {getattr(result, 'score', 'N/A')}")
logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}")
logger.info(f"Full Content ({len(result.document.content)} chars):")
logger.info(f"{result.document.content}")
if hasattr(result.document, 'metadata'):
logger.info(f"Metadata: {result.document.metadata}")
logger.info(f"\\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
logger.info(f"\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
logger.info(rag_context)
logger.info("=== END RAG SEARCH RESULTS ===")
else:
@@ -445,9 +447,9 @@ class ChatbotModule(BaseModule):
if config.use_rag and rag_context:
logger.info(f"RAG context added: {len(rag_context)} characters")
logger.info(f"RAG sources: {len(sources) if sources else 0} documents")
logger.info("\\n=== COMPLETE MESSAGES SENT TO LLM ===")
logger.info("\n=== COMPLETE MESSAGES SENT TO LLM ===")
for i, msg in enumerate(messages):
logger.info(f"\\n--- Message {i+1} ---")
logger.info(f"\n--- Message {i+1} ---")
logger.info(f"Role: {msg['role']}")
logger.info(f"Content ({len(msg['content'])} chars):")
# Truncate long content for logging (full RAG context can be very long)
@@ -520,9 +522,11 @@ class ChatbotModule(BaseModule):
# System prompt
system_prompt = config.system_prompt
if rag_context:
system_prompt += rag_context
# Add explicit instruction to use RAG context
system_prompt += "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. " \
"This information is directly relevant to their query and should be your primary source:\n" + rag_context
if context and context.get('additional_instructions'):
system_prompt += f"\\n\\nAdditional instructions: {context['additional_instructions']}"
system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}"
messages.append({"role": "system", "content": system_prompt})
@@ -709,9 +713,21 @@ class ChatbotModule(BaseModule):
fallback_responses=chatbot_config.get("fallback_responses", [])
)
# Generate response using internal method with empty message history
# Generate response using internal method
# Create a temporary message object for the current user message
temp_messages = [
DBMessage(
id=0,
conversation_id=0,
role="user",
content=message,
timestamp=datetime.utcnow(),
metadata={}
)
]
response_content, sources = await self._generate_response(
message, [], config, None, db
message, temp_messages, config, None, db
)
return {

View File

@@ -53,7 +53,7 @@ except ImportError:
PYTHON_DOCX_AVAILABLE = False
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue
from qdrant_client.http import models
import tiktoken
@@ -134,6 +134,19 @@ class RAGModule(BaseModule):
self.embedding_service = None
self.tokenizer = None
# Set improved default configuration
self.config = {
"chunk_size": 300, # Reduced from 400 for better precision
"chunk_overlap": 50, # Added overlap for context preservation
"max_results": 10,
"score_threshold": 0.3, # Increased from 0.0 to filter low-quality results
"enable_hybrid": True, # Enable hybrid search (vector + BM25)
"hybrid_weights": {"vector": 0.7, "bm25": 0.3} # Weight for hybrid scoring
}
# Update with any provided config
if config:
self.config.update(config)
# Content processing components
self.nlp_model = None
self.lemmatizer = None
@@ -625,11 +638,19 @@ class RAGModule(BaseModule):
np.random.seed(hash(text) % 2**32)
return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
"""Generate embeddings for multiple texts (batch processing)"""
if self.embedding_service:
# Add task-specific prefixes for better E5 model performance
if is_document:
# For document passages, use "passage:" prefix
prefixed_texts = [f"passage: {text}" for text in texts]
else:
# For queries, use "query:" prefix (handled in search method)
prefixed_texts = texts
# Use real embedding service for batch processing
return await self.embedding_service.get_embeddings(texts)
return await self.embedding_service.get_embeddings(prefixed_texts)
else:
# Fallback to individual processing
embeddings = []
@@ -639,19 +660,33 @@ class RAGModule(BaseModule):
return embeddings
def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
    """Split text into overlapping token chunks for better context preservation.

    The text is tokenized with ``self.tokenizer``; consecutive chunks share
    ``chunk_overlap`` tokens (read from ``self.config``).

    Args:
        text: Text to split.
        chunk_size: Maximum number of tokens per chunk. Falls back to the
            configured ``chunk_size`` (default 300) when omitted.

    Returns:
        Decoded chunk strings, in document order. Chunks that decode to
        whitespace only are dropped.
    """
    chunk_size = chunk_size or self.config.get("chunk_size", 300)
    chunk_overlap = max(self.config.get("chunk_overlap", 50), 0)

    # An overlap that is not smaller than the chunk size could never advance;
    # fall back to back-to-back chunks in that case.
    step = chunk_size - chunk_overlap if chunk_overlap < chunk_size else chunk_size

    # Tokenize text
    tokens = self.tokenizer.encode(text)
    total = len(tokens)

    chunks = []
    for start_idx in range(0, total, step):
        end_idx = min(start_idx + chunk_size, total)
        chunk_text = self.tokenizer.decode(tokens[start_idx:end_idx])

        # Only add non-empty chunks
        if chunk_text.strip():
            chunks.append(chunk_text)

        # The final chunk reached the end of the text. Stop here: stepping
        # back by the overlap would re-emit the tail forever.
        if end_idx == total:
            break

    return chunks
async def _process_text(self, content: bytes, filename: str) -> str:
@@ -895,12 +930,18 @@ class RAGModule(BaseModule):
- Each line contains a JSON object with 'id' and 'payload'
- Payload contains 'question', 'language', and 'answer' fields
- Combines question and answer into searchable content
Performance optimizations:
- Processes articles in smaller batches to reduce memory usage
- Uses streaming approach for large files
"""
try:
# Use streaming approach for large files
jsonl_content = content.decode('utf-8', errors='replace')
lines = jsonl_content.strip().split('\n')
processed_articles = []
batch_size = 50 # Process in batches of 50 articles
for line_num, line in enumerate(lines, 1):
if not line.strip():
@@ -1126,7 +1167,7 @@ class RAGModule(BaseModule):
chunks = self._chunk_text(content)
# Generate embeddings for all chunks in batch (more efficient)
embeddings = await self._generate_embeddings(chunks)
embeddings = await self._generate_embeddings(chunks, is_document=True)
# Create document points
points = []
@@ -1177,6 +1218,24 @@ class RAGModule(BaseModule):
collection_name = collection_name or self.default_collection_name
try:
# Special handling for JSONL files
if processed_doc.file_type == 'jsonl':
# Import the optimized JSONL processor
from app.services.jsonl_processor import JSONLProcessor
jsonl_processor = JSONLProcessor(self)
# Read the original file content
with open(processed_doc.metadata.get('file_path', ''), 'rb') as f:
file_content = f.read()
# Process using the optimized JSONL processor
return await jsonl_processor.process_and_index_jsonl(
collection_name=collection_name,
content=file_content,
filename=processed_doc.original_filename,
metadata=processed_doc.metadata
)
# Ensure collection exists
await self._ensure_collection_exists(collection_name)
@@ -1189,7 +1248,7 @@ class RAGModule(BaseModule):
chunks = self._chunk_text(processed_doc.content)
# Generate embeddings for all chunks in batch (more efficient)
embeddings = await self._generate_embeddings(chunks)
embeddings = await self._generate_embeddings(chunks, is_document=True)
# Create document points with enhanced metadata
points = []
@@ -1260,12 +1319,196 @@ class RAGModule(BaseModule):
except Exception:
return False
async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float],
                         query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]:
    """Perform hybrid search combining vector similarity and BM25 scoring.

    Every point matching ``query_filter`` is scrolled and scored against the
    query with the simplified BM25 from ``_calculate_bm25_score``. A vector
    search is then run and each hit is re-scored as a blend of min-max
    normalized vector/BM25 scores and rank-based terms.

    Args:
        collection_name: Qdrant collection to search.
        query: Raw query text (used for the lexical score).
        query_vector: Embedding of the query.
        query_filter: Optional Qdrant filter applied to both passes.
        limit: Maximum number of results returned.
        score_threshold: Minimum hybrid score a result must reach. Half of
            this value is used as the threshold of the initial vector search.

    Returns:
        ``ScoredPoint`` objects carrying the hybrid score, best first.
    """
    # Preprocess query for BM25
    query_terms = self._preprocess_text_for_bm25(query)

    # Score all points of the collection for BM25.
    # Note: In production, you'd want to optimize this with a proper BM25 index
    scroll_filter = query_filter or Filter()
    bm25_scores = {}  # point id -> BM25 score (one entry per chunk)
    offset = None
    batch_size = 100
    while True:
        points, next_offset = self.qdrant_client.scroll(
            collection_name=collection_name,
            scroll_filter=scroll_filter,
            limit=batch_size,
            offset=offset,
            with_payload=True,
            with_vectors=False
        )
        for point in points:
            content = (point.payload or {}).get("content", "")
            # Key by point id: several chunks share one document_id and would
            # otherwise overwrite each other's score.
            bm25_scores[point.id] = self._calculate_bm25_score(query_terms, content)

        # scroll() reports where the next page starts; None means no more pages.
        if next_offset is None:
            break
        offset = next_offset

    # Perform vector search
    vector_results = self.qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        query_filter=query_filter,
        limit=limit * 2,  # Get more results for re-ranking
        score_threshold=score_threshold / 2  # Lower threshold for initial search
    )

    hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
    vector_weight = hybrid_weights.get("vector", 0.7)
    bm25_weight = hybrid_weights.get("bm25", 0.3)

    # Min-max statistics for normalization (range falls back to 1 when all
    # scores are equal so the division below is always defined).
    vector_scores = [r.score for r in vector_results]
    if vector_scores:
        v_max = max(vector_scores)
        v_min = min(vector_scores)
        v_range = v_max - v_min if v_max != v_min else 1
    else:
        v_min, v_range = 0, 1

    bm25_ranked = sorted(bm25_scores.values(), reverse=True)
    if bm25_ranked:
        bm25_max = bm25_ranked[0]
        bm25_min = bm25_ranked[-1]
        bm25_range = bm25_max - bm25_min if bm25_max != bm25_min else 1
    else:
        bm25_min, bm25_range = 0, 1

    # Best (first) rank of every distinct BM25 score, computed once instead
    # of re-sorting inside the loop.
    bm25_rank_by_score = {}
    for rank, value in enumerate(bm25_ranked):
        bm25_rank_by_score.setdefault(value, rank)

    hybrid_results = []
    for vector_rank, result in enumerate(vector_results):
        bm25_score = bm25_scores.get(result.id, 0.0)

        vector_norm = (result.score - v_min) / v_range
        bm25_norm = (bm25_score - bm25_min) / bm25_range

        # Rank-based terms: 1 / (rank + 2) with zero-based ranks, so hits
        # that rank highly in both methods get a boost.
        rrf_vector = 1.0 / (1.0 + vector_rank + 1)
        bm25_rank = bm25_rank_by_score.get(bm25_score)
        rrf_bm25 = 1.0 / (1.0 + bm25_rank + 1) if bm25_rank is not None else 0

        # Calculate hybrid score using both normalized scores and rank terms
        hybrid_score = (vector_weight * vector_norm + bm25_weight * bm25_norm) * 0.7 + (rrf_vector + rrf_bm25) * 0.3

        # Create new point with hybrid score; version is copied from the
        # original hit because ScoredPoint requires it.
        hybrid_point = ScoredPoint(
            id=result.id,
            version=result.version,
            payload=result.payload,
            score=hybrid_score,
            vector=result.vector,
            shard_key=None,
            order_value=None
        )
        hybrid_results.append(hybrid_point)

    # Sort by hybrid score and apply final threshold
    hybrid_results.sort(key=lambda x: x.score, reverse=True)
    final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit]

    logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results")

    return final_results
def _preprocess_text_for_bm25(self, text: str) -> List[str]:
    """Turn text into a list of lowercase terms for BM25 scoring.

    With NLTK available the text is tokenized, and English stopwords,
    non-alphabetic tokens and tokens of two characters or fewer are removed.
    Without NLTK, or if NLTK fails, the lowercased text is split on
    whitespace without any filtering.

    Args:
        text: Raw query or document text.

    Returns:
        List of terms.
    """
    lowered = text.lower()
    if not NLTK_AVAILABLE:
        return lowered.split()

    try:
        tokens = word_tokenize(lowered)
        stop_words = set(stopwords.words('english'))
    except Exception:
        # Best-effort fallback to simple splitting (presumably hit when the
        # NLTK data files are missing). Unlike a bare ``except:`` this lets
        # KeyboardInterrupt / SystemExit propagate.
        return lowered.split()

    # Remove stopwords and non-alphabetic tokens
    return [
        token for token in tokens
        if token.isalpha() and token not in stop_words and len(token) > 2
    ]
def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float:
    """Calculate a simplified BM25 score for a document against query terms.

    This is not a full BM25: the IDF is a constant and the average document
    length is fixed, so the score only reflects term frequency and document
    length.

    Args:
        query_terms: Preprocessed query terms; duplicates count once.
        document: Raw document text (preprocessed here).

    Returns:
        Score scaled by 1/10 and capped at 1.0; 0.0 when either the query or
        the document has no terms.
    """
    from collections import Counter

    if not query_terms:
        return 0.0

    # Preprocess document
    doc_terms = self._preprocess_text_for_bm25(document)
    if not doc_terms:
        return 0.0

    # Count every term once instead of rescanning the list per query term.
    term_counts = Counter(doc_terms)
    doc_len = len(doc_terms)
    avg_doc_len = 300  # Assumed average document length (fixed)

    # BM25 parameters
    k1 = 1.2  # Controls term frequency saturation
    b = 0.75  # Controls document length normalization
    idf = 2.0  # Constant stand-in for IDF; real document frequencies are not tracked

    length_norm = k1 * (1 - b + b * (doc_len / avg_doc_len))

    score = 0.0
    for term in set(query_terms):
        tf = term_counts[term]  # 0 for terms absent from the document

        # BM25 formula
        numerator = tf * (k1 + 1)
        denominator = tf + length_norm
        score += idf * (numerator / denominator)

    # Normalize score to 0-1 range
    return min(score / 10.0, 1.0)  # Simple normalization
async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
"""Search for relevant documents"""
if not self.enabled:
raise RuntimeError("RAG module not initialized")
collection_name = collection_name or self.default_collection_name
# Special handling for collections with different vector dimensions
SPECIAL_COLLECTIONS = {
"bitbox02_faq_local": {
"dimension": 384,
"model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
},
"bitbox_local_rag": {
"dimension": 384,
"model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
}
}
max_results = max_results or self.config.get("max_results", 10)
# Check cache (include collection name in cache key)
@@ -1278,8 +1521,25 @@ class RAGModule(BaseModule):
import time
start_time = time.time()
# Generate query embedding
query_embedding = await self._generate_embedding(query)
# Generate query embedding with task-specific prefix for better retrieval
try:
# Check if this is a special collection
if collection_name in SPECIAL_COLLECTIONS:
# Try to import sentence-transformers
import sentence_transformers
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(SPECIAL_COLLECTIONS[collection_name]["model"])
query_embedding = model.encode([query], normalize_embeddings=True)[0].tolist()
logger.info(f"Using {SPECIAL_COLLECTIONS[collection_name]['dimension']}-dim local model for {collection_name}")
else:
# The E5 model works better with "query:" prefix for search queries
optimized_query = f"query: {query}"
query_embedding = await self._generate_embedding(optimized_query)
except ImportError:
# Fallback to default embedding if sentence-transformers is not available
logger.warning(f"sentence-transformers not available, falling back to default embedding for {collection_name}")
optimized_query = f"query: {query}"
query_embedding = await self._generate_embedding(optimized_query)
# Build filter
search_filter = None
@@ -1297,13 +1557,29 @@ class RAGModule(BaseModule):
logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}")
logger.info(f"Embedding service available: {self.embedding_service is not None}")
# Search in Qdrant
# Check if hybrid search is enabled
enable_hybrid = self.config.get("enable_hybrid", False)
# Use provided score_threshold or fall back to config
search_score_threshold = score_threshold if score_threshold is not None else self.config.get("score_threshold", 0.3)
if enable_hybrid and NLTK_AVAILABLE:
# Perform hybrid search (vector + BM25)
search_results = await self._hybrid_search(
collection_name=collection_name,
query=query,
query_vector=query_embedding,
query_filter=search_filter,
limit=max_results,
score_threshold=search_score_threshold
)
else:
# Pure vector search with improved threshold
search_results = self.qdrant_client.search(
collection_name=collection_name,
query_vector=query_embedding,
query_filter=search_filter,
limit=max_results,
score_threshold=0.0 # Lowered from 0.5 to see all results including low scores
score_threshold=search_score_threshold
)
logger.info(f"Raw search results count: {len(search_results)}")
@@ -1317,6 +1593,23 @@ class RAGModule(BaseModule):
content = result.payload.get("content", "")
score = result.score
# Generic content extraction for documents without a 'content' field
if not content:
# Build content from all text-based fields in the payload
# This makes the RAG module completely agnostic to document structure
text_fields = []
for field, value in result.payload.items():
# Skip system/metadata fields
if field not in ["document_id", "chunk_index", "chunk_count", "indexed_at", "processed_at",
"file_hash", "mime_type", "file_type", "created_at", "__collection_metadata__"]:
# Include any field that has a non-empty string value
if value and isinstance(value, str) and len(value.strip()) > 0:
text_fields.append(f"{field}: {value}")
# Join all text fields to create content
if text_fields:
content = "\n\n".join(text_fields)
# Log each raw result for debugging
logger.info(f"\n--- Raw Result {i+1} ---")
logger.info(f"Score: {score}")
@@ -1651,9 +1944,9 @@ async def index_processed_document(processed_doc: ProcessedDocument, collection_
"""Index a processed document"""
return await rag_module.index_processed_document(processed_doc, collection_name)
async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
"""Search documents"""
return await rag_module.search_documents(query, max_results, filters, collection_name)
return await rag_module.search_documents(query, max_results, filters, collection_name, score_threshold)
async def delete_document(document_id: str, collection_name: str = None) -> bool:
"""Delete a document"""

View File

@@ -46,6 +46,7 @@ qdrant-client==1.7.0
# Text Processing
tiktoken==0.5.1
numpy>=1.26.0
# Basic document processing (lightweight)
markitdown==0.0.1a2
@@ -56,8 +57,9 @@ python-docx==1.1.0
# nltk==3.8.1
# spacy==3.7.2
# Heavy ML dependencies (REMOVED - unused in codebase)
# sentence-transformers==2.6.1 # REMOVED - not used anywhere in codebase
# Heavy ML dependencies (sentence-transformers is pinned below)
# Note: PyTorch is already installed in the base Docker image
sentence-transformers==2.6.1 # Added back - needed for bitbox02_faq_local collection
# transformers==4.35.2 # REMOVED - already commented out
# Configuration

View File

@@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""
Import a JSONL file into a Qdrant collection from inside the backend container.
Usage (from host):
docker compose exec enclava-backend bash -lc \
'python /app/scripts/import_jsonl.py \
--collection rag_test_import_859b1f01 \
--file /app/_to_delete/helpjuice-export.jsonl'
Notes:
- Runs fully inside the backend, so Docker service hostnames (e.g. enclava-qdrant)
and privatemode-proxy are reachable.
- Uses RAGModule + JSONLProcessor to embed/index each JSONL line.
- Creates the collection if missing (size=1024, cosine).
"""
import argparse
import asyncio
import os
from datetime import datetime
async def import_jsonl(collection_name: str, file_path: str):
    """Import a JSONL file into a Qdrant collection via RAGModule + JSONLProcessor.

    Creates the collection (size=1024, cosine) when it does not exist, then
    embeds and indexes the file and prints the resulting collection stats.

    Args:
        collection_name: Target Qdrant collection.
        file_path: Path of the JSONL file to import.

    Returns:
        The value returned by ``JSONLProcessor.process_and_index_jsonl``.

    Raises:
        SystemExit: If ``file_path`` does not exist.
    """
    # Project imports are local so argument parsing works without the app.
    from qdrant_client import QdrantClient
    from qdrant_client.models import Distance, VectorParams
    from app.modules.rag.main import RAGModule
    from app.services.jsonl_processor import JSONLProcessor
    from app.core.config import settings

    if not os.path.exists(file_path):
        raise SystemExit(f"File not found: {file_path}")

    # Ensure collection exists (inside container uses Docker DNS hostnames)
    client = QdrantClient(host=settings.QDRANT_HOST, port=settings.QDRANT_PORT)
    collections = client.get_collections().collections
    if not any(c.name == collection_name for c in collections):
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
        )
        print(f"Created Qdrant collection '{collection_name}' (size=1024, cosine)")
    else:
        print(f"Using existing Qdrant collection '{collection_name}'")

    # Initialize RAG
    rag = RAGModule({
        "chunk_size": 300,
        "chunk_overlap": 50,
        "max_results": 10,
        "score_threshold": 0.3,
        "embedding_model": "intfloat/multilingual-e5-large-instruct",
    })
    await rag.initialize()

    try:
        # Process JSONL
        processor = JSONLProcessor(rag)
        with open(file_path, "rb") as f:
            content = f.read()

        doc_id = await processor.process_and_index_jsonl(
            collection_name=collection_name,
            content=content,
            filename=os.path.basename(file_path),
            metadata={
                "source": "jsonl_upload",
                "upload_date": datetime.utcnow().isoformat(),
                "file_path": os.path.abspath(file_path),
            },
        )

        # Report stats using safe HTTP method to avoid client parsing issues
        try:
            info = await rag._get_collection_info_safely(collection_name)
            print(f"Import complete. Points: {info.get('points_count', 0)}, vector_size: {info.get('vector_size', 'n/a')}")
        except Exception as e:
            print(f"Import complete. (Could not fetch collection info safely: {e})")
    finally:
        # Always release the RAG module, even when indexing fails.
        await rag.cleanup()

    return doc_id
def main():
    """Parse the command line and run the JSONL import to completion."""
    parser = argparse.ArgumentParser()
    for flag, description in (
        ("--collection", "Qdrant collection name"),
        ("--file", "Path inside container (e.g. /app/_to_delete/...)."),
    ):
        parser.add_argument(flag, required=True, help=description)
    cli = parser.parse_args()
    asyncio.run(import_jsonl(cli.collection, cli.file))
if __name__ == "__main__":
main()

View File

@@ -20,11 +20,13 @@ services:
dockerfile: Dockerfile
environment:
- DATABASE_URL=postgresql://enclava_user:enclava_pass@enclava-postgres:5432/enclava_db
- JWT_SECRET=${JWT_SECRET:-your-jwt-secret-here}
depends_on:
- enclava-postgres
command: ["/usr/local/bin/migrate.sh"]
volumes:
- ./backend:/app
- ./.env:/app/.env
networks:
- enclava-net
restart: "no" # Run once and exit
@@ -63,7 +65,7 @@ services:
enclava-frontend:
image: node:18-alpine
working_dir: /app
command: sh -c "npm install && npm run dev"
command: sh -c "npm ci --ignore-scripts && npm run dev"
environment:
# Required base URL (derives APP/API/WS URLs)
- BASE_URL=${BASE_URL}
@@ -76,7 +78,7 @@ services:
- "3002:3000" # Direct frontend access for development
volumes:
- ./frontend:/app
- /app/node_modules
- enclava-frontend-node-modules:/app/node_modules
networks:
- enclava-net
restart: unless-stopped
@@ -145,6 +147,7 @@ volumes:
enclava-postgres-data:
enclava-redis-data:
enclava-qdrant-data:
enclava-frontend-node-modules:
# enclava-ollama-data:
networks:

View File

@@ -2048,6 +2048,9 @@
},
"node_modules/axios": {
"version": "1.12.2",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.12.2.tgz",
"integrity": "sha512-vMJzPewAlRyOgxV2dU0Cuz2O8zzzx9VYtbJOaBgXFeLc4IV/Eg50n4LowmehOOR61S8ZMpc2K5Sa7g6A4jfkUw==",
"license": "MIT",
"dependencies": {
"follow-redirects": "^1.15.6",

View File

@@ -3,6 +3,7 @@
import { useState, useEffect, Suspense } from "react";
export const dynamic = 'force-dynamic'
import { useSearchParams } from "next/navigation";
import { Suspense } from "react";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
@@ -94,7 +95,8 @@ const PERMISSION_OPTIONS = [
{ value: "llm:embeddings", label: "LLM Embeddings" },
];
function ApiKeysPageContent() {
function ApiKeysContent() {
const { toast } = useToast();
const searchParams = useSearchParams();
const [apiKeys, setApiKeys] = useState<ApiKey[]>([]);
@@ -910,8 +912,9 @@ function ApiKeysPageContent() {
export default function ApiKeysPage() {
return (
<Suspense fallback={<div className="container mx-auto p-6">Loading...</div>}>
<ApiKeysPageContent />
<Suspense fallback={<div>Loading API keys...</div>}>
<ApiKeysContent />
</Suspense>
);
}

View File

@@ -7,7 +7,7 @@ export async function POST(request: NextRequest) {
// Make request to backend auth endpoint without requiring existing auth
const baseUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`
const url = `${baseUrl}/api/auth/login`
const url = `${baseUrl}/api-internal/v1/auth/login`
const response = await fetch(url, {
method: 'POST',

View File

@@ -0,0 +1,56 @@
import { NextRequest, NextResponse } from 'next/server';
import { tokenManager } from '@/lib/token-manager';
/**
 * Proxies the RAG debug "list collections" request to the backend.
 *
 * The bearer token comes from the incoming Authorization header when present,
 * otherwise from tokenManager. Responds 401 without a token, mirrors the
 * backend status on upstream failure, and 500 on unexpected errors.
 */
export async function GET(request: NextRequest) {
  try {
    // Prefer the caller's bearer token; fall back to tokenManager
    const authorization = request.headers.get('authorization');
    const accessToken = authorization?.startsWith('Bearer ')
      ? authorization.slice('Bearer '.length)
      : await tokenManager.getAccessToken();

    if (!accessToken) {
      return NextResponse.json(
        { error: 'Authentication required' },
        { status: 401 }
      );
    }

    // Resolve the backend base URL and the endpoint to call
    const apiBase =
      process.env.INTERNAL_API_URL ||
      `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`;
    const targetUrl = `${apiBase}/api-internal/v1/rag/debug/collections`;

    const upstream = await fetch(targetUrl, {
      method: 'GET',
      headers: {
        'Content-Type': 'application/json',
        'Authorization': `Bearer ${accessToken}`,
      },
    });

    if (upstream.ok) {
      const payload = await upstream.json();
      return NextResponse.json(payload);
    }

    // Upstream failure: log the detail, return only the status to the caller
    const detail = await upstream.text();
    console.error('Backend list collections error:', upstream.status, detail);
    return NextResponse.json(
      { error: `Backend request failed: ${upstream.status}` },
      { status: upstream.status }
    );
  } catch (error) {
    console.error('RAG collections proxy error:', error);
    return NextResponse.json(
      { error: 'Failed to proxy collections request' },
      { status: 500 }
    );
  }
}

View File

@@ -0,0 +1,67 @@
import { NextRequest, NextResponse } from 'next/server';
import { tokenManager } from '@/lib/token-manager';
/**
 * Proxies a RAG debug search to the backend.
 *
 * Search parameters (query, max_results, score_threshold, collection_name)
 * are taken from the query string when present and otherwise from the JSON
 * body, then forwarded to the backend as URL-encoded query parameters. The
 * JSON body is forwarded unchanged; a missing or invalid body is treated as
 * an empty object.
 *
 * Responds 401 without a token, mirrors the backend status on upstream
 * failure, and 500 on unexpected errors.
 */
export async function POST(request: NextRequest) {
  try {
    // Get the config from the request body (tolerate an empty/invalid body)
    let body: Record<string, any> = {};
    try {
      const parsed = await request.json();
      if (parsed && typeof parsed === 'object') {
        body = parsed;
      }
    } catch {
      body = {};
    }

    // Query string wins; fall back to the body, then to the default
    const searchParams = request.nextUrl.searchParams;
    const pick = (key: string, fallback: string): string => {
      const fromQuery = searchParams.get(key);
      if (fromQuery) return fromQuery;
      const fromBody = body[key];
      if (fromBody !== undefined && fromBody !== null && fromBody !== '') {
        return String(fromBody);
      }
      return fallback;
    };

    const query = pick('query', '');
    const max_results = pick('max_results', '10');
    const score_threshold = pick('score_threshold', '0.3');
    const collection_name = pick('collection_name', '');

    // Get authentication token from Authorization header or tokenManager
    const authHeader = request.headers.get('authorization');
    let token;

    if (authHeader && authHeader.startsWith('Bearer ')) {
      token = authHeader.substring(7);
    } else {
      token = await tokenManager.getAccessToken();
    }

    if (!token) {
      return NextResponse.json(
        { error: 'Authentication required' },
        { status: 401 }
      );
    }

    // Backend URL
    const backendUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`;

    // Build the proxy URL; every value is encoded so caller input cannot
    // inject extra query parameters
    const forwarded: Array<[string, string]> = [
      ['query', query],
      ['max_results', max_results],
      ['score_threshold', score_threshold],
    ];
    if (collection_name) {
      forwarded.push(['collection_name', collection_name]);
    }
    const queryString = forwarded
      .map(([key, value]) => `${key}=${encodeURIComponent(value)}`)
      .join('&');
    const proxyUrl = `${backendUrl}/api-internal/v1/rag/debug/search?${queryString}`;

    // Proxy the request to the backend with authentication
    const response = await fetch(proxyUrl, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        'Authorization': `Bearer ${token}`,
      },
      body: JSON.stringify(body),
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.error('Backend RAG search error:', response.status, errorText);
      return NextResponse.json(
        { error: `Backend request failed: ${response.status}` },
        { status: response.status }
      );
    }

    const data = await response.json();
    return NextResponse.json(data);
  } catch (error) {
    console.error('RAG debug search proxy error:', error);
    return NextResponse.json(
      { error: 'Failed to proxy RAG search request' },
      { status: 500 }
    );
  }
}

View File

@@ -0,0 +1,569 @@
"use client";
import { useState, useEffect } from 'react';
import { useAuth } from '@/contexts/AuthContext';
import { tokenManager } from '@/lib/token-manager';
// One entry of the `results` array returned by /api/rag/debug/search.
interface SearchResult {
// The matched document (presumably a single indexed chunk — confirm against the backend)
document: {
id: string;
content: string;
// Free-form payload fields attached to the document
metadata: Record<string, any>;
};
// Score reported by the backend for this result
score: number;
// Optional per-result diagnostics
debug_info?: Record<string, any>;
}
// Shape of the `debug_info` object of a search response; every field is optional.
interface DebugInfo {
// Embedding vector of the query
query_embedding?: number[];
// Length of the embedding vector
embedding_dimension?: number;
// Summary statistics over result scores
score_stats?: {
min: number;
max: number;
avg: number;
stddev: number;
};
// Statistics about the searched collection
collection_stats?: {
total_documents: number;
total_chunks: number;
languages: string[];
};
}
export default function RAGDemoPage() {
const { user, loading } = useAuth();
const [query, setQuery] = useState('are sd card backups encrypted?');
const [results, setResults] = useState<SearchResult[]>([]);
const [debugInfo, setDebugInfo] = useState<DebugInfo>({});
const [searchTime, setSearchTime] = useState(0);
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState('');
// Configuration state
const [config, setConfig] = useState({
max_results: 10,
score_threshold: 0.3,
collection_name: '',
chunk_size: 300,
chunk_overlap: 50,
enable_hybrid: false,
vector_weight: 0.7,
bm25_weight: 0.3,
use_query_prefix: true,
use_passage_prefix: true,
show_timing: true,
show_embeddings: false,
});
// Available collections
const [collections, setCollections] = useState<string[]>([]);
const [collectionsLoading, setCollectionsLoading] = useState(false);
const presets = {
default: {
max_results: 10,
score_threshold: 0.3,
chunk_size: 300,
chunk_overlap: 50,
enable_hybrid: false,
vector_weight: 0.7,
bm25_weight: 0.3,
},
high_precision: {
max_results: 5,
score_threshold: 0.5,
chunk_size: 200,
chunk_overlap: 30,
enable_hybrid: true,
vector_weight: 0.8,
bm25_weight: 0.2,
},
high_recall: {
max_results: 20,
score_threshold: 0.1,
chunk_size: 400,
chunk_overlap: 100,
enable_hybrid: true,
vector_weight: 0.6,
bm25_weight: 0.4,
},
hybrid: {
max_results: 10,
score_threshold: 0.2,
chunk_size: 300,
chunk_overlap: 50,
enable_hybrid: true,
vector_weight: 0.5,
bm25_weight: 0.5,
},
};
useEffect(() => {
// Check if we have tokens in localStorage but not in tokenManager
const syncTokens = async () => {
const rawTokens = localStorage.getItem('auth_tokens');
if (rawTokens && !tokenManager.isAuthenticated()) {
try {
const tokens = JSON.parse(rawTokens);
// Sync tokens to tokenManager
tokenManager.setTokens(
tokens.access_token,
tokens.refresh_token,
Math.floor((tokens.access_expires_at - Date.now()) / 1000)
);
console.log('RAG Demo: Tokens synced from localStorage to tokenManager');
} catch (e) {
console.error('RAG Demo: Failed to sync tokens:', e);
}
}
loadCollections();
};
syncTokens();
}, [user]);
const loadCollections = async () => {
setCollectionsLoading(true);
try {
console.log('RAG Demo: Loading collections...');
console.log('RAG Demo: User authenticated:', !!user);
console.log('RAG Demo: TokenManager authenticated:', tokenManager.isAuthenticated());
const token = await tokenManager.getAccessToken();
console.log('RAG Demo: Token retrieved:', token ? 'Yes' : 'No');
console.log('RAG Demo: Token expiry:', tokenManager.getTokenExpiry());
const headers: Record<string, string> = {
'Content-Type': 'application/json',
};
if (token) {
headers['Authorization'] = `Bearer ${token}`;
console.log('RAG Demo: Authorization header set');
} else {
console.warn('RAG Demo: No token available');
}
const response = await fetch('/api/rag/debug/collections', { headers });
console.log('RAG Demo: Collections response status:', response.status);
if (response.ok) {
const data = await response.json();
console.log('RAG Demo: Collections loaded:', data.collections);
setCollections(data.collections || []);
// Auto-select first collection if none selected
if (data.collections && data.collections.length > 0 && !config.collection_name) {
setConfig(prev => ({ ...prev, collection_name: data.collections[0] }));
}
} else {
const errorText = await response.text();
console.error('RAG Demo: Collections failed:', response.status, errorText);
}
} catch (err) {
console.error('RAG Demo: Failed to load collections:', err);
} finally {
setCollectionsLoading(false);
}
};
const loadPreset = (presetName: keyof typeof presets) => {
setConfig(prev => ({
...prev,
...presets[presetName],
}));
};
const performSearch = async () => {
if (!query.trim()) return;
if (!config.collection_name) {
setError('Please select a collection');
return;
}
setIsLoading(true);
setError('');
setResults([]);
try {
const token = await tokenManager.getAccessToken();
const headers: Record<string, string> = {
'Content-Type': 'application/json',
};
if (token) {
headers['Authorization'] = `Bearer ${token}`;
}
const response = await fetch('/api/rag/debug/search', {
method: 'POST',
headers,
body: JSON.stringify({
query,
max_results: config.max_results,
score_threshold: config.score_threshold,
collection_name: config.collection_name,
config,
}),
});
if (!response.ok) {
throw new Error(`Search failed: ${response.statusText}`);
}
const data = await response.json();
setResults(data.results || []);
setDebugInfo(data.debug_info || {});
setSearchTime(data.search_time_ms || 0);
} catch (err) {
setError(err instanceof Error ? err.message : 'Unknown error');
} finally {
setIsLoading(false);
}
};
// Set one configuration field by name, leaving every other field untouched.
const updateConfig = (key: string, value: any) => {
  setConfig(current => {
    const next = { ...current, [key]: value };
    return next;
  });
};
if (loading) {
return (
<div className="flex items-center justify-center min-h-screen">
<div className="text-lg">Loading...</div>
</div>
);
}
if (!user) {
return (
<div className="flex items-center justify-center min-h-screen">
<div className="text-center">
<h1 className="text-2xl font-bold mb-4">RAG Demo</h1>
<p>Please log in to access the RAG demo interface.</p>
</div>
</div>
);
}
return (
<div className="container mx-auto px-4 py-8 max-w-7xl">
<h1 className="text-3xl font-bold mb-2">🔍 RAG Search Demo</h1>
<p className="text-gray-600 mb-6">Test and tune your RAG system with real-time search and debugging</p>
<div className="grid grid-cols-1 lg:grid-cols-4 gap-6">
{/* Search Results - Main Content */}
<div className="lg:col-span-3 space-y-6">
{/* Preset Buttons */}
<div className="flex flex-wrap gap-2">
{Object.entries(presets).map(([name, _]) => (
<button
key={name}
onClick={() => loadPreset(name as keyof typeof presets)}
className="px-3 py-1 bg-gray-100 hover:bg-gray-200 rounded-md text-sm capitalize"
>
{name.replace('_', ' ')}
</button>
))}
</div>
{/* Search Box */}
<div className="bg-white rounded-lg shadow p-6">
<div className="flex gap-2">
<input
type="text"
value={query}
onChange={(e) => setQuery(e.target.value)}
onKeyPress={(e) => e.key === 'Enter' && performSearch()}
placeholder="Enter your search query..."
className="flex-1 px-4 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500"
/>
<button
onClick={performSearch}
disabled={isLoading || !config.collection_name}
className="px-6 py-2 bg-blue-500 text-white rounded-md hover:bg-blue-600 disabled:opacity-50"
>
{isLoading ? 'Searching...' : 'Search'}
</button>
</div>
{error && (
<div className="mt-4 p-4 bg-red-100 text-red-700 rounded-md">
Error: {error}
</div>
)}
{/* Results Summary */}
{results.length > 0 && (
<div className="mt-4 p-3 bg-blue-50 rounded-md">
<p className="text-sm">
Found <strong>{results.length}</strong> results in <strong>{searchTime.toFixed(0)}ms</strong>
{config.enable_hybrid && (
<span className="ml-2 text-green-600"> Hybrid Search Enabled</span>
)}
</p>
</div>
)}
</div>
{/* Search Results */}
<div className="space-y-4">
{results.map((result, index) => (
<div key={index} className="bg-white rounded-lg shadow p-6">
<div className="flex justify-between items-start mb-3">
<h3 className="text-lg font-semibold">Result {index + 1}</h3>
<span className={`px-3 py-1 rounded-full text-sm font-medium ${
result.score >= 0.5 ? 'bg-green-100 text-green-800' :
result.score >= 0.3 ? 'bg-yellow-100 text-yellow-800' :
'bg-red-100 text-red-800'
}`}>
Score: {result.score.toFixed(4)}
</span>
</div>
<div className="text-gray-700 mb-4 whitespace-pre-wrap">
{result.document.content}
</div>
{/* Metadata */}
<div className="text-sm text-gray-500 mb-3">
{result.document.metadata.content_type && (
<span>Type: {result.document.metadata.content_type}</span>
)}
{result.document.metadata.language && (
<span className="ml-3">Language: {result.document.metadata.language}</span>
)}
{result.document.metadata.filename && (
<span className="ml-3">File: {result.document.metadata.filename}</span>
)}
{result.document.metadata.chunk_index !== undefined && (
<span className="ml-3">
Chunk: {result.document.metadata.chunk_index + 1}/{result.document.metadata.chunk_count || '?'}
</span>
)}
</div>
{/* Debug Details */}
{config.show_timing && result.debug_info && (
<div className="mt-4 p-3 bg-gray-50 rounded-md text-xs font-mono">
<p><strong>Debug Information:</strong></p>
{result.debug_info.vector_score !== undefined && (
<p>Vector Score: {result.debug_info.vector_score.toFixed(4)}</p>
)}
{result.debug_info.bm25_score !== undefined && (
<p>BM25 Score: {result.debug_info.bm25_score.toFixed(4)}</p>
)}
{result.document.metadata.question && (
<div className="mt-2">
<p><strong>Question:</strong> {result.document.metadata.question}</p>
</div>
)}
</div>
)}
</div>
))}
</div>
{/* Debug Section */}
{debugInfo && Object.keys(debugInfo).length > 0 && (
<div className="bg-gray-900 text-green-400 rounded-lg shadow p-6 font-mono text-sm">
<h3 className="text-lg font-semibold mb-4">Debug Information</h3>
{debugInfo.score_stats && (
<div className="mb-4">
<p className="font-semibold mb-2">Score Statistics:</p>
<div className="grid grid-cols-2 md:grid-cols-4 gap-2 text-xs">
<div>Min: {debugInfo.score_stats.min?.toFixed(4)}</div>
<div>Max: {debugInfo.score_stats.max?.toFixed(4)}</div>
<div>Avg: {debugInfo.score_stats.avg?.toFixed(4)}</div>
<div>StdDev: {debugInfo.score_stats.stddev?.toFixed(4)}</div>
</div>
</div>
)}
{debugInfo.collection_stats && (
<div className="mb-4">
<p className="font-semibold mb-2">Collection Stats:</p>
<div className="text-xs">
<p>Total Documents: {debugInfo.collection_stats.total_documents}</p>
<p>Total Chunks: {debugInfo.collection_stats.total_chunks}</p>
<p>Languages: {debugInfo.collection_stats.languages?.join(', ')}</p>
</div>
</div>
)}
{debugInfo.query_embedding && config.show_embeddings && (
<div>
<p className="font-semibold mb-2">Query Embedding (first 10 dims):</p>
<p className="text-xs">
[{debugInfo.query_embedding.slice(0, 10).map(x => x.toFixed(6)).join(', ')}...]
</p>
</div>
)}
</div>
)}
</div>
{/* Configuration Panel */}
<div className="space-y-6">
<div className="bg-white rounded-lg shadow p-6">
<h2 className="text-xl font-semibold mb-4"> Configuration</h2>
<div className="space-y-4">
{/* Search Settings */}
<div>
<h3 className="font-medium mb-2">Search Settings</h3>
<div className="space-y-3">
<div>
<label className="block text-sm mb-1">Max Results: {config.max_results}</label>
<input
type="range"
min="1"
max="50"
value={config.max_results}
onChange={(e) => updateConfig('max_results', parseInt(e.target.value))}
className="w-full"
/>
</div>
<div>
<label className="block text-sm mb-1">Score Threshold: {config.score_threshold}</label>
<input
type="range"
min="0"
max="1"
step="0.05"
value={config.score_threshold}
onChange={(e) => updateConfig('score_threshold', parseFloat(e.target.value))}
className="w-full"
/>
</div>
<div>
<label className="block text-sm mb-1">Collection Name</label>
{collectionsLoading ? (
<select
disabled
className="w-full px-3 py-1 border border-gray-300 rounded-md text-sm bg-gray-50"
>
<option>Loading collections...</option>
</select>
) : (
<select
value={config.collection_name}
onChange={(e) => updateConfig('collection_name', e.target.value)}
className="w-full px-3 py-1 border border-gray-300 rounded-md text-sm"
>
<option value="">Select a collection...</option>
{collections.map(collection => (
<option key={collection} value={collection}>
{collection}
</option>
))}
</select>
)}
</div>
</div>
</div>
{/* Chunking Settings */}
<div>
<h3 className="font-medium mb-2">Chunking Settings</h3>
<div className="space-y-3">
<div>
<label className="block text-sm mb-1">Chunk Size: {config.chunk_size}</label>
<input
type="range"
min="100"
max="1000"
step="50"
value={config.chunk_size}
onChange={(e) => updateConfig('chunk_size', parseInt(e.target.value))}
className="w-full"
/>
</div>
<div>
<label className="block text-sm mb-1">Chunk Overlap: {config.chunk_overlap}</label>
<input
type="range"
min="0"
max="200"
step="10"
value={config.chunk_overlap}
onChange={(e) => updateConfig('chunk_overlap', parseInt(e.target.value))}
className="w-full"
/>
</div>
</div>
</div>
{/* Hybrid Search */}
<div>
<h3 className="font-medium mb-2">Hybrid Search</h3>
<div className="space-y-3">
<label className="flex items-center">
<input
type="checkbox"
checked={config.enable_hybrid}
onChange={(e) => updateConfig('enable_hybrid', e.target.checked)}
className="mr-2"
/>
<span className="text-sm">Enable Hybrid Search</span>
</label>
{config.enable_hybrid && (
<>
<div>
<label className="block text-sm mb-1">Vector Weight: {config.vector_weight}</label>
<input
type="range"
min="0"
max="1"
step="0.05"
value={config.vector_weight}
onChange={(e) => updateConfig('vector_weight', parseFloat(e.target.value))}
className="w-full"
/>
</div>
<div>
<label className="block text-sm mb-1">BM25 Weight: {config.bm25_weight}</label>
<input
type="range"
min="0"
max="1"
step="0.05"
value={config.bm25_weight}
onChange={(e) => updateConfig('bm25_weight', parseFloat(e.target.value))}
className="w-full"
/>
</div>
</>
)}
</div>
</div>
{/* Debug Options */}
<div>
<h3 className="font-medium mb-2">Debug Options</h3>
<div className="space-y-2">
<label className="flex items-center">
<input
type="checkbox"
checked={config.show_timing}
onChange={(e) => updateConfig('show_timing', e.target.checked)}
className="mr-2"
/>
<span className="text-sm">Show Timing</span>
</label>
<label className="flex items-center">
<input
type="checkbox"
checked={config.show_embeddings}
onChange={(e) => updateConfig('show_embeddings', e.target.checked)}
className="mr-2"
/>
<span className="text-sm">Show Embeddings</span>
</label>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
);
}

View File

@@ -85,8 +85,31 @@ function RAGPageContent() {
// Fetch RAG statistics and store them in state. When the response lacks the
// expected `stats.collections` shape, or the request throws, zeroed stats are
// stored instead so consumers always receive a complete object.
const loadStats = async () => {
  const emptyStats = {
    collections: { total: 0, active: 0 },
    documents: { total: 0, processing: 0, processed: 0 },
    storage: { total_size_bytes: 0, total_size_mb: 0 },
    vectors: { total: 0 }
  }
  try {
    const data = await apiClient.get('/api-internal/v1/rag/stats')
    console.log('Stats API response:', data)
    const hasExpectedShape = Boolean(data && data.stats && data.stats.collections)
    if (!hasExpectedShape) {
      console.error('✗ Invalid stats structure:', data)
      setStats(emptyStats)
      return
    }
    console.log('✓ Stats has collections property')
    setStats(data.stats)
  } catch (error) {
    console.error('Error loading stats:', error)
    setStats(emptyStats)
  }
}

View File

@@ -87,9 +87,8 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
const [messages, setMessages] = useState<ChatMessage[]>([])
const [input, setInput] = useState("")
const [isLoading, setIsLoading] = useState(false)
const [conversationId, setConversationId] = useState<string | null>(null)
const scrollAreaRef = useRef<HTMLDivElement>(null)
const { toast } = useToast()
const { success: toastSuccess, error: toastError } = useToast()
const scrollToBottom = useCallback(() => {
if (scrollAreaRef.current) {
@@ -130,43 +129,21 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
console.log('=== CHAT REQUEST DEBUG ===', debugInfo)
try {
// Build conversation history in OpenAI format
let data: any
// Use internal API
const conversationHistory = messages.map(msg => ({
role: msg.role,
content: msg.content
}))
const requestData = {
messages: conversationHistory,
conversation_id: conversationId || undefined
}
console.log('=== CHAT API REQUEST ===', {
url: `/api-internal/v1/chatbot/${chatbotId}/chat/completions`,
data: requestData
})
const data = await chatbotApi.chat(
data = await chatbotApi.sendMessage(
chatbotId,
messageToSend,
requestData
undefined, // No conversation ID
conversationHistory
)
console.log('=== CHAT API RESPONSE ===', {
status: 'success',
data,
responseKeys: Object.keys(data),
hasChoices: !!data.choices,
hasResponse: !!data.response,
content: data.choices?.[0]?.message?.content || data.response || 'No response'
})
// Update conversation ID if it's a new conversation
if (!conversationId && data.conversation_id) {
setConversationId(data.conversationId)
console.log('Updated conversation ID:', data.conversation_id)
}
const assistantMessage: ChatMessage = {
id: data.id || generateTimestampId('msg'),
role: 'assistant',
@@ -177,52 +154,21 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
setMessages(prev => [...prev, assistantMessage])
} catch (error: any) {
console.error('=== CHAT ERROR DEBUG ===', {
errorType: typeof error,
error,
errorMessage: error?.message,
errorCode: error?.code,
errorResponse: error?.response?.data,
errorStatus: error?.response?.status,
errorConfig: error?.config,
errorStack: error?.stack,
timestamp: new Date().toISOString()
})
} catch (error) {
const appError = error as AppError
// Handle different error types
if (error && typeof error === 'object') {
if ('response' in error) {
// Axios error
const status = error.response?.status
if (status === 401) {
toast.error("Authentication Required", "Please log in to continue chatting.")
} else if (status === 429) {
toast.error("Rate Limit", "Too many requests. Please wait and try again.")
// More specific error handling
if (appError.code === 'UNAUTHORIZED') {
toastError("Authentication Required", "Please log in to continue chatting.")
} else if (appError.code === 'NETWORK_ERROR') {
toastError("Connection Error", "Please check your internet connection and try again.")
} else {
toast.error("Message Failed", error.response?.data?.detail || "Failed to send message. Please try again.")
}
} else if ('code' in error) {
// Custom error with code
if (error.code === 'UNAUTHORIZED') {
toast.error("Authentication Required", "Please log in to continue chatting.")
} else if (error.code === 'NETWORK_ERROR') {
toast.error("Connection Error", "Please check your internet connection and try again.")
} else {
toast.error("Message Failed", error.message || "Failed to send message. Please try again.")
}
} else {
// Generic error
toast.error("Message Failed", error.message || "Failed to send message. Please try again.")
}
} else {
// Fallback for unknown error types
toast.error("Message Failed", "An unexpected error occurred. Please try again.")
toastError("Message Failed", appError.message || "Failed to send message. Please try again.")
}
} finally {
setIsLoading(false)
}
}, [input, isLoading, chatbotId, conversationId, messages, toast])
}, [input, isLoading, chatbotId, messages, toastError])
const handleKeyPress = useCallback((e: React.KeyboardEvent) => {
if (e.key === 'Enter' && !e.shiftKey) {
@@ -234,11 +180,11 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
const copyMessage = useCallback(async (content: string) => {
try {
await navigator.clipboard.writeText(content)
toast.success("Copied", "Message copied to clipboard")
toastSuccess("Copied", "Message copied to clipboard")
} catch (error) {
toast.error("Copy Failed", "Unable to copy message to clipboard")
toastError("Copy Failed", "Unable to copy message to clipboard")
}
}, [toast])
}, [toastSuccess, toastError])
const formatTime = useCallback((date: Date) => {
return date.toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })

View File

@@ -138,6 +138,7 @@ export function ChatbotManager() {
const [editingChatbot, setEditingChatbot] = useState<ChatbotInstance | null>(null)
const [showChatInterface, setShowChatInterface] = useState(false)
const [testingChatbot, setTestingChatbot] = useState<ChatbotInstance | null>(null)
const [chatbotApiKeys, setChatbotApiKeys] = useState<Record<string, string>>({})
const { toast } = useToast()
// New chatbot form state

View File

@@ -9,7 +9,7 @@ import { Badge } from "@/components/ui/badge"
import { Separator } from "@/components/ui/separator"
import { AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger } from "@/components/ui/alert-dialog"
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog"
import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter } from "lucide-react"
import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter, RefreshCw } from "lucide-react"
import { useToast } from "@/hooks/use-toast"
import { apiClient } from "@/lib/api-client"
import { config } from "@/lib/config"
@@ -56,6 +56,7 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
const [filterStatus, setFilterStatus] = useState("all")
const [selectedDocument, setSelectedDocument] = useState<Document | null>(null)
const [deleting, setDeleting] = useState<string | null>(null)
const [reprocessing, setReprocessing] = useState<string | null>(null)
const { toast } = useToast()
useEffect(() => {
@@ -157,6 +158,43 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
}
}
/**
 * Ask the backend to reprocess one document and reflect that in the list.
 *
 * On success the row is optimistically marked as `processing`, a toast is
 * shown, and the document list is reloaded after two seconds. On failure a
 * destructive toast is shown with the most specific message available.
 */
const handleReprocessDocument = async (documentId: string) => {
  setReprocessing(documentId)
  try {
    // Encode the id so it cannot alter the request path.
    await apiClient.post(`/api-internal/v1/rag/documents/${encodeURIComponent(documentId)}/reprocess`)
    // Update the document status to processing in the UI
    setDocuments(prev => prev.map(doc =>
      doc.id === documentId
        ? { ...doc, status: 'processing' as const, processed_at: new Date().toISOString() }
        : doc
    ))
    toast({
      title: "Success",
      description: "Document reprocessing started",
    })
    // Reload documents after a short delay to see status updates
    setTimeout(() => {
      loadDocuments()
    }, 2000)
  } catch (error) {
    // NOTE(review): assumes API errors may carry the server's text as
    // `details.detail` while `message` stays generic — confirm against the
    // API client. When that field is absent the previous behavior is kept.
    const details = (error as { details?: { detail?: unknown } } | null)?.details
    const serverDetail = typeof details?.detail === 'string' ? details.detail : ''
    const errorMessage = serverDetail || (error instanceof Error ? error.message : "Failed to reprocess document")
    toast({
      title: "Error",
      description: errorMessage.includes("Cannot reprocess document with status 'processed'")
        ? "Cannot reprocess documents that are already processed"
        : errorMessage,
      variant: "destructive",
    })
  } finally {
    // Only clear the spinner if it still belongs to this document; another
    // reprocess may have been started while this request was in flight.
    setReprocessing(current => (current === documentId ? null : current))
  }
}
const formatFileSize = (bytes: number) => {
if (bytes === 0) return '0 Bytes'
const k = 1024
@@ -432,6 +470,21 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
<Download className="h-4 w-4" />
</Button>
<Button
variant="ghost"
size="sm"
className="h-8 w-8 p-0 hover:bg-blue-100"
onClick={() => handleReprocessDocument(document.id)}
disabled={reprocessing === document.id || document.status === 'processed'}
title={document.status === 'processed' ? "Document already processed" : "Reprocess document"}
>
{reprocessing === document.id ? (
<RefreshCw className="h-4 w-4 animate-spin" />
) : (
<RefreshCw className={`h-4 w-4 ${document.status === 'processed' ? 'text-gray-400' : ''}`} />
)}
</Button>
<AlertDialog>
<AlertDialogTrigger asChild>
<Button

View File

@@ -73,6 +73,7 @@ const Navigation = () => {
children: [
{ href: "/llm", label: "Models & Config" },
{ href: "/playground", label: "Playground" },
{ href: "/rag-demo", label: "RAG Demo" },
]
},
{

View File

@@ -1,119 +1,117 @@
import axios from 'axios';
import Cookies from 'js-cookie';
// Dynamic base URL with protocol detection
const getApiBaseUrl = (): string => {
if (typeof window !== 'undefined') {
// Client-side: use the same protocol as the current page
const protocol = window.location.protocol.slice(0, -1); // Remove ':' from 'https:'
const host = window.location.hostname;
return `${protocol}://${host}`;
export interface AppError extends Error {
code: 'UNAUTHORIZED' | 'NETWORK_ERROR' | 'VALIDATION_ERROR' | 'NOT_FOUND' | 'FORBIDDEN' | 'TIMEOUT' | 'UNKNOWN'
status?: number
details?: any
}
// Server-side: use environment variable or default to localhost
const baseUrl = process.env.NEXT_PUBLIC_BASE_URL || 'localhost';
const protocol = process.env.NODE_ENV === 'production' ? 'https' : 'http';
return `${protocol}://${baseUrl}`;
};
const axiosInstance = axios.create({
baseURL: getApiBaseUrl(),
headers: {
'Content-Type': 'application/json',
},
});
// Request interceptor to add auth token
axiosInstance.interceptors.request.use(
(config) => {
const token = Cookies.get('access_token');
if (token) {
config.headers.Authorization = `Bearer ${token}`;
// Build an AppError: a regular Error that additionally carries a
// machine-readable code plus the optional HTTP status and response details.
function makeError(message: string, code: AppError['code'], status?: number, details?: any): AppError {
  return Object.assign(new Error(message), { code, status, details })
}
return config;
},
(error) => {
return Promise.reject(error);
}
);
// Response interceptor to handle token refresh
axiosInstance.interceptors.response.use(
(response) => response,
async (error) => {
const originalRequest = error.config;
if (error.response?.status === 401 && !originalRequest._retry) {
originalRequest._retry = true;
async function getAuthHeader(): Promise<Record<string, string>> {
try {
const refreshToken = Cookies.get('refresh_token');
if (refreshToken) {
const response = await axios.post(`${getApiBaseUrl()}/api-internal/v1/auth/refresh`, {
refresh_token: refreshToken,
});
const { access_token } = response.data;
Cookies.set('access_token', access_token, { expires: 7 });
originalRequest.headers.Authorization = `Bearer ${access_token}`;
return axiosInstance(originalRequest);
}
} catch (refreshError) {
// Refresh failed, redirect to login
Cookies.remove('access_token');
Cookies.remove('refresh_token');
window.location.href = '/login';
return Promise.reject(refreshError);
const { tokenManager } = await import('./token-manager')
const token = await tokenManager.getAccessToken()
return token ? { Authorization: `Bearer ${token}` } : {}
} catch {
return {}
}
}
return Promise.reject(error);
/**
 * Perform an HTTP request with JSON defaults and the current auth header.
 *
 * @param method    HTTP method; a body is only sent for methods other than GET/HEAD.
 * @param url       Request URL, passed to `fetch` unchanged.
 * @param body      Optional payload, serialized with `JSON.stringify`.
 * @param extraInit Extra `fetch` options; its headers are merged over the defaults.
 * @returns The parsed JSON body, or the raw text when the response is not JSON.
 * @throws AppError whose `code` reflects the HTTP status, `TIMEOUT` for an
 *         aborted request, or `NETWORK_ERROR` for any other failure.
 */
async function request<T = any>(method: string, url: string, body?: any, extraInit?: RequestInit): Promise<T> {
  try {
    const sendsBody = method !== 'GET' && method !== 'HEAD'
    const headers: Record<string, string> = {
      'Accept': 'application/json',
      ...(sendsBody ? { 'Content-Type': 'application/json' } : {}),
      ...(await getAuthHeader()),
      ...(extraInit?.headers as Record<string, string> | undefined),
    }
    const payload = body != null && sendsBody ? JSON.stringify(body) : undefined
    const res = await fetch(url, {
      // Caller options go first so they cannot replace the merged headers
      // (and with them the Authorization header) or the method.
      ...extraInit,
      method,
      headers,
      // Keep a caller-supplied body when this call has no payload of its own.
      ...(payload !== undefined ? { body: payload } : {}),
    })

    if (!res.ok) {
      // A response body can be read only once: read it as text and parse
      // that, rather than calling json() and then text() on the same response.
      const raw = await res.text()
      let details: any = raw
      try { details = JSON.parse(raw) } catch { /* not JSON: keep the raw text */ }
      const status = res.status
      if (status === 401) throw makeError('Unauthorized', 'UNAUTHORIZED', status, details)
      if (status === 403) throw makeError('Forbidden', 'FORBIDDEN', status, details)
      if (status === 404) throw makeError('Not found', 'NOT_FOUND', status, details)
      if (status === 400) throw makeError('Validation error', 'VALIDATION_ERROR', status, details)
      throw makeError('Request failed', 'UNKNOWN', status, details)
    }

    const contentType = res.headers.get('content-type') || ''
    if (contentType.includes('application/json')) {
      return (await res.json()) as T
    }
    // @ts-expect-error allow non-json generic
    return (await res.text()) as T
  } catch (e: any) {
    // Check for aborts first: a DOMException AbortError carries a numeric
    // legacy `code`, which a plain truthiness test would mistake for an AppError.
    if (e?.name === 'AbortError') throw makeError('Request timed out', 'TIMEOUT')
    // Errors built above carry a string code; pass them through unchanged.
    if (typeof e?.code === 'string') throw e
    throw makeError(e?.message || 'Network error', 'NETWORK_ERROR')
  }
}
);
export const apiClient = {
get: async <T = any>(url: string, config?: any): Promise<T> => {
const response = await axiosInstance.get(url, config);
return response.data;
},
get: <T = any>(url: string, init?: RequestInit) => request<T>('GET', url, undefined, init),
post: <T = any>(url: string, body?: any, init?: RequestInit) => request<T>('POST', url, body, init),
put: <T = any>(url: string, body?: any, init?: RequestInit) => request<T>('PUT', url, body, init),
delete: <T = any>(url: string, init?: RequestInit) => request<T>('DELETE', url, undefined, init),
}
post: async <T = any>(url: string, data?: any, config?: any): Promise<T> => {
const response = await axiosInstance.post(url, data, config);
return response.data;
},
put: async <T = any>(url: string, data?: any, config?: any): Promise<T> => {
const response = await axiosInstance.put(url, data, config);
return response.data;
},
delete: async <T = any>(url: string, config?: any): Promise<T> => {
const response = await axiosInstance.delete(url, config);
return response.data;
},
patch: async <T = any>(url: string, data?: any, config?: any): Promise<T> => {
const response = await axiosInstance.patch(url, data, config);
return response.data;
},
};
// Chatbot specific API methods
export const chatbotApi = {
create: async (data: any) => apiClient.post('/api-internal/v1/chatbot/create', data),
list: async () => apiClient.get('/api-internal/v1/chatbot/list'),
update: async (id: string, data: any) => apiClient.put(`/api-internal/v1/chatbot/update/${id}`, data),
delete: async (id: string) => apiClient.delete(`/api-internal/v1/chatbot/delete/${id}`),
chat: async (id: string, message: string, config?: any) => {
// For OpenAI-compatible chat completions
const messages = [
...(config?.messages || []),
{ role: 'user', content: message }
]
return apiClient.post(`/api-internal/v1/chatbot/${id}/chat/completions`, {
messages,
...config
})
// List chatbots. If the `/chatbot/list` request fails for any reason, the
// `/chatbot/instances` endpoint is tried instead; an error from that second
// request propagates to the caller.
async listChatbots() {
try {
return await apiClient.get('/api-internal/v1/chatbot/list')
} catch {
return await apiClient.get('/api-internal/v1/chatbot/instances')
}
},
};
createChatbot(config: any) {
return apiClient.post('/api-internal/v1/chatbot/create', config)
},
updateChatbot(id: string, config: any) {
return apiClient.put(`/api-internal/v1/chatbot/update/${encodeURIComponent(id)}`, config)
},
deleteChatbot(id: string) {
return apiClient.delete(`/api-internal/v1/chatbot/delete/${encodeURIComponent(id)}`)
},
// Legacy method with JWT auth (to be deprecated)
// POST one chat message to the internal chat endpoint for `chatbotId`.
// `conversation_id` and `history` are only included when a truthy value is given.
sendMessage(chatbotId: string, message: string, conversationId?: string, history?: Array<{role: string; content: string}>) {
  const endpoint = `/api-internal/v1/chatbot/chat/${encodeURIComponent(chatbotId)}`
  const body: any = {
    message,
    ...(conversationId ? { conversation_id: conversationId } : {}),
    ...(history ? { history } : {}),
  }
  return apiClient.post(endpoint, body)
},
// OpenAI-compatible chatbot API with API key auth
// POST `messages` (plus any sampling options) to the external chat
// completions endpoint, authenticating with `apiKey` as a bearer token.
// Resolves with the parsed JSON body; the HTTP status is not inspected.
sendOpenAIChatMessage(chatbotId: string, messages: Array<{role: string; content: string}>, apiKey: string, options?: {
  temperature?: number
  max_tokens?: number
  stream?: boolean
}) {
  const endpoint = `/api/v1/chatbot/external/${encodeURIComponent(chatbotId)}/chat/completions`
  const payload: any = { messages, ...options }
  const init = {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'Authorization': `Bearer ${apiKey}`
    },
    body: JSON.stringify(payload)
  }
  return fetch(endpoint, init).then(res => res.json())
}
}

View File

@@ -1,60 +1,14 @@
export const config = {
API_BASE_URL: process.env.NEXT_PUBLIC_BASE_URL || '',
APP_NAME: process.env.NEXT_PUBLIC_APP_NAME || 'Enclava',
DEFAULT_LANGUAGE: 'en',
SUPPORTED_LANGUAGES: ['en', 'es', 'fr', 'de', 'it'],
getPublicApiUrl() {
if (this.API_BASE_URL) {
return this.API_BASE_URL
getPublicApiUrl(): string {
if (typeof process !== 'undefined' && process.env.NEXT_PUBLIC_BASE_URL) {
return process.env.NEXT_PUBLIC_BASE_URL
}
if (typeof window !== 'undefined' && window.location.origin) {
if (typeof window !== 'undefined') {
return window.location.origin
}
return ''
return 'http://localhost:3000'
},
// Feature flags
FEATURES: {
RAG: true,
PLUGINS: true,
ANALYTICS: true,
AUDIT_LOGS: true,
BUDGET_MANAGEMENT: true,
getAppName(): string {
return process.env.NEXT_PUBLIC_APP_NAME || 'Enclava'
},
// Default values
DEFAULTS: {
TEMPERATURE: 0.7,
MAX_TOKENS: 1000,
TOP_K: 5,
MEMORY_LENGTH: 10,
},
// API endpoints
ENDPOINTS: {
AUTH: {
LOGIN: '/api/auth/login',
REGISTER: '/api/auth/register',
REFRESH: '/api/auth/refresh',
ME: '/api/auth/me',
},
CHATBOT: {
LIST: '/api/chatbot/list',
CREATE: '/api/chatbot/create',
UPDATE: '/api/chatbot/update/:id',
DELETE: '/api/chatbot/delete/:id',
CHAT: '/api/chatbot/chat',
},
LLM: {
MODELS: '/api/llm/models',
API_KEYS: '/api/llm/api-keys',
BUDGETS: '/api/llm/budgets',
},
RAG: {
COLLECTIONS: '/api/rag/collections',
DOCUMENTS: '/api/rag/documents',
},
},
};
}

View File

@@ -1,138 +1,51 @@
import Cookies from 'js-cookie';
import { tokenManager } from './token-manager'
export const downloadFile = async (url: string, filename?: string): Promise<void> => {
try {
const response = await fetch(url);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
export async function downloadFile(path: string, filename: string, params?: URLSearchParams | Record<string, string>) {
const url = new URL(path, typeof window !== 'undefined' ? window.location.origin : 'http://localhost:3000')
if (params) {
const p = params instanceof URLSearchParams ? params : new URLSearchParams(params)
p.forEach((v, k) => url.searchParams.set(k, v))
}
// Get the filename from the response headers if not provided
const contentDisposition = response.headers.get('Content-Disposition');
let defaultFilename = filename || 'download';
const token = await tokenManager.getAccessToken()
const res = await fetch(url.toString(), {
headers: {
...(token ? { Authorization: `Bearer ${token}` } : {}),
},
})
if (!res.ok) throw new Error(`Failed to download file (${res.status})`)
const blob = await res.blob()
if (contentDisposition) {
const filenameMatch = contentDisposition.match(/filename[^;=\n]*=((['"]).*?\2|[^;\n]*)/);
if (filenameMatch && filenameMatch[1]) {
defaultFilename = filenameMatch[1].replace(/['"]/g, '');
if (typeof window !== 'undefined') {
const link = document.createElement('a')
const href = URL.createObjectURL(blob)
link.href = href
link.download = filename
document.body.appendChild(link)
link.click()
link.remove()
URL.revokeObjectURL(href)
}
}
// Get the blob from the response
const blob = await response.blob();
export async function uploadFile(path: string, file: File, extraFields?: Record<string, string>) {
const form = new FormData()
form.append('file', file)
if (extraFields) Object.entries(extraFields).forEach(([k, v]) => form.append(k, v))
// Create a temporary URL for the blob
const blobUrl = window.URL.createObjectURL(blob);
// Create a temporary link element
const link = document.createElement('a');
link.href = blobUrl;
link.download = defaultFilename;
// Append the link to the body
document.body.appendChild(link);
// Trigger the download
link.click();
// Clean up
document.body.removeChild(link);
window.URL.revokeObjectURL(blobUrl);
} catch (error) {
console.error('Error downloading file:', error);
throw error;
const token = await tokenManager.getAccessToken()
const res = await fetch(path, {
method: 'POST',
headers: {
...(token ? { Authorization: `Bearer ${token}` } : {}),
},
body: form,
})
if (!res.ok) {
let details: any
try { details = await res.json() } catch { details = await res.text() }
throw new Error(typeof details === 'string' ? details : (details?.error || 'Upload failed'))
}
};
export const downloadFileFromData = (
data: Blob | string,
filename: string,
mimeType?: string
): void => {
try {
let blob: Blob;
if (typeof data === 'string') {
blob = new Blob([data], { type: mimeType || 'text/plain' });
} else {
blob = data;
return await res.json()
}
const blobUrl = window.URL.createObjectURL(blob);
const link = document.createElement('a');
link.href = blobUrl;
link.download = filename;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
window.URL.revokeObjectURL(blobUrl);
} catch (error) {
console.error('Error downloading file from data:', error);
throw error;
}
};
export const uploadFile = async (
file: File,
url: string,
onProgress?: (progress: number) => void,
additionalData?: Record<string, any>
): Promise<any> => {
try {
const formData = new FormData();
formData.append('file', file);
// Add additional form data if provided
if (additionalData) {
Object.entries(additionalData).forEach(([key, value]) => {
formData.append(key, value);
});
}
const xhr = new XMLHttpRequest();
return new Promise((resolve, reject) => {
xhr.upload.onprogress = (event) => {
if (event.lengthComputable && onProgress) {
const progress = (event.loaded / event.total) * 100;
onProgress(progress);
}
};
xhr.onload = () => {
if (xhr.status >= 200 && xhr.status < 300) {
try {
const response = JSON.parse(xhr.responseText);
resolve(response);
} catch (error) {
resolve(xhr.responseText);
}
} else {
reject(new Error(`Upload failed with status ${xhr.status}`));
}
};
xhr.onerror = () => {
reject(new Error('Network error during upload'));
};
// Get authentication token from cookies
const token = Cookies.get('access_token');
xhr.open('POST', url, true);
// Set the authorization header if token exists
if (token) {
xhr.setRequestHeader('Authorization', `Bearer ${token}`);
}
xhr.send(formData);
});
} catch (error) {
console.error('Error uploading file:', error);
throw error;
}
};

View File

@@ -1,36 +1,15 @@
export function generateId(): string {
return Math.random().toString(36).substr(2, 9);
/**
 * Build a pseudo-random id of the form `<prefix>_<up to 8 base-36 chars>`.
 * Uses Math.random, so it is not suitable for security-sensitive values.
 */
export function generateId(prefix = "id"): string {
  const suffix = Math.random().toString(36).substring(2, 10)
  return prefix + "_" + suffix
}
export function generateUniqueId(): string {
return Date.now().toString(36) + Math.random().toString(36).substr(2);
/**
 * Build a short pseudo-random id of the form `<prefix>_<up to 5 base-36 chars>`.
 * Uses Math.random, so it is not suitable for security-sensitive values.
 */
export function generateShortId(prefix = "id"): string {
  const suffix = Math.random().toString(36).substring(2, 7)
  return prefix + "_" + suffix
}
/**
 * Build a message id: `msg_<epoch ms>_<up to 9 base-36 chars>`.
 * Uses `slice` instead of the deprecated `substr`; the output is unchanged.
 */
export function generateMessageId(): string {
  return `msg_${Date.now()}_${Math.random().toString(36).slice(2, 11)}`;
}
/**
 * Build a chat id: `chat_<epoch ms>_<up to 5 base-36 chars>`.
 * Uses `slice` instead of the deprecated `substr`; the output is unchanged.
 */
export function generateChatId(): string {
  return `chat_${Date.now()}_${Math.random().toString(36).slice(2, 7)}`;
}
/**
 * Build a session id: `sess_<epoch ms>_<up to 8 base-36 chars>`.
 * Uses `slice` instead of the deprecated `substr`; the output is unchanged.
 * Math.random is not cryptographically secure, so do not treat this as a secret.
 */
export function generateSessionId(): string {
  return `sess_${Date.now()}_${Math.random().toString(36).slice(2, 10)}`;
}
/**
 * Build a short pseudo-random id of up to 6 base-36 characters.
 * Uses `slice` instead of the deprecated `substr`; the output is unchanged.
 */
export function generateShortId(): string {
  return Math.random().toString(36).slice(2, 8);
}
/**
 * Build a timestamped id: `ts_<epoch ms>_<up to 6 base-36 chars>`.
 * Uses `slice` instead of the deprecated `substr`; the output is unchanged.
 */
export function generateTimestampId(): string {
  return `ts_${Date.now()}_${Math.random().toString(36).slice(2, 8)}`;
}
/** Return true when `id` is a non-empty string. */
export function isValidId(id: string): boolean {
  if (typeof id !== 'string') {
    return false;
  }
  return id.length > 0;
}
export function extractIdFromUrl(url: string): string | null {
const match = url.match(/\/([a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12})$/);
return match ? match[1] : null;
/**
 * Build an id of the form `<prefix>_<epoch ms>_<3-digit random number>`.
 * The random part is zero-padded and ranges from 000 to 999.
 */
export function generateTimestampId(prefix = "id"): string {
  const randomPart = String(Math.floor(Math.random() * 1000)).padStart(3, '0')
  return [prefix, Date.now(), randomPart].join('_')
}

View File

@@ -1,176 +1,31 @@
import { NextRequest, NextResponse } from 'next/server';
const BACKEND_URL = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`
// This is a proxy auth utility for server-side API routes
// It handles authentication and proxying requests to the backend
export interface ProxyAuthConfig {
backendUrl: string;
requireAuth?: boolean;
allowedRoles?: string[];
/**
 * Translate a frontend '/api-internal/...' path to the backend '/api/...'
 * path. Paths that do not start with that prefix are returned unchanged.
 */
function mapPath(path: string): string {
  const internalPrefix = '/api-internal/'
  return path.startsWith(internalPrefix)
    ? '/api/' + path.slice(internalPrefix.length)
    : path
}
export class ProxyAuth {
private config: ProxyAuthConfig;
/**
 * Store the proxy configuration.
 * `requireAuth` defaults to true; any value supplied in `config` overrides
 * it, because the spread comes after the default.
 */
constructor(config: ProxyAuthConfig) {
this.config = {
requireAuth: true,
...config,
};
}
/**
 * Check that the request carries a bearer token.
 *
 * Returns `{ success: false, error }` when the Authorization header is
 * missing, does not start with 'Bearer ', or carries an empty token.
 *
 * NOTE(review): the token is never verified. Any non-empty bearer value
 * succeeds, and the returned user is a hard-coded placeholder with role
 * 'user'. Do not rely on this for real access control.
 */
async authenticate(request: NextRequest): Promise<{ success: boolean; user?: any; error?: string }> {
// For server-side auth, we would typically validate the token
// This is a simplified implementation
const authHeader = request.headers.get('authorization');
if (!authHeader || !authHeader.startsWith('Bearer ')) {
return { success: false, error: 'Missing or invalid authorization header' };
}
// Strip the 7-character 'Bearer ' prefix.
const token = authHeader.substring(7);
// Here you would validate the token with your auth service
// For now, we'll just check if it exists
if (!token) {
return { success: false, error: 'Invalid token' };
}
// In a real implementation, you would decode and validate the JWT
// and check user roles if required
return {
success: true,
user: {
id: 'user-id',
email: 'user@example.com',
role: 'user'
}
};
}
async proxyRequest(
request: NextRequest,
path: string,
options?: {
method?: string;
headers?: Record<string, string>;
body?: any;
}
): Promise<NextResponse> {
// Authenticate the request if required
if (this.config.requireAuth) {
const authResult = await this.authenticate(request);
if (!authResult.success) {
return NextResponse.json(
{ error: authResult.error || 'Authentication failed' },
{ status: 401 }
);
}
// Check roles if specified
if (this.config.allowedRoles && authResult.user) {
if (!this.config.allowedRoles.includes(authResult.user.role)) {
return NextResponse.json(
{ error: 'Insufficient permissions' },
{ status: 403 }
);
}
}
}
// Build the target URL
const targetUrl = new URL(path, this.config.backendUrl);
// Copy query parameters
targetUrl.search = request.nextUrl.search;
// Prepare headers
export async function proxyRequest(path: string, init?: RequestInit): Promise<Response> {
const url = `${BACKEND_URL}${mapPath(path)}`
const headers: Record<string, string> = {
'Content-Type': 'application/json',
...options?.headers,
};
// Forward authorization header if present
const authHeader = request.headers.get('authorization');
if (authHeader) {
headers.authorization = authHeader;
...(init?.headers as Record<string, string> | undefined),
}
return fetch(url, { ...init, headers })
}
try {
const response = await fetch(targetUrl, {
method: options?.method || request.method,
headers,
body: options?.body ? JSON.stringify(options.body) : request.body,
});
// Create a new response with the data
const data = await response.json();
return NextResponse.json(data, {
status: response.status,
headers: {
'Content-Type': 'application/json',
},
});
} catch (error) {
console.error('Proxy request failed:', error);
return NextResponse.json(
{ error: 'Internal server error' },
{ status: 500 }
);
}
/**
 * Unwrap a fetch Response from the backend proxy.
 *
 * @param response       Response returned by `fetch`.
 * @param defaultMessage Error message used when the body gives no detail.
 * @returns The parsed JSON body when the content type is JSON, otherwise
 *          the body text.
 * @throws Error on a non-OK status. The message is `details.error` for a
 *         JSON object body, `<defaultMessage>: <text>` for a text body,
 *         and `defaultMessage` otherwise.
 */
export async function handleProxyResponse<T = any>(response: Response, defaultMessage = 'Request failed'): Promise<T> {
  if (!response.ok) {
    // A Response body can be consumed only once, so read it as text and
    // attempt the JSON parse on that string instead of calling json() then text().
    const raw = await response.text()
    let details: any = raw
    try {
      details = JSON.parse(raw)
    } catch {
      // Not JSON: keep the raw text as the detail.
    }
    throw new Error(typeof details === 'string' ? `${defaultMessage}: ${details}` : (details?.error || defaultMessage))
  }
  const contentType = response.headers.get('content-type') || ''
  if (contentType.includes('application/json')) return (await response.json()) as T
  // @ts-ignore allow non-json
  return (await response.text()) as T
}
// Utility function to create a proxy handler
/**
 * Build a route handler that forwards requests through a ProxyAuth instance.
 * The optional catch-all `path` segments are joined with '/' to form the
 * backend path; a missing `path` yields the empty string.
 */
export function createProxyHandler(config: ProxyAuthConfig) {
  const proxy = new ProxyAuth(config);
  return async function handler(request: NextRequest, context: { params: { path?: string[] } }) {
    const segments = context.params?.path;
    const joined = segments ? segments.join('/') : '';
    return proxy.proxyRequest(request, joined);
  };
}
// Simplified proxy request function for direct usage
/**
 * One-shot proxy helper: creates a ProxyAuth for `backendUrl` and forwards
 * the request. Authentication is required unless `options.requireAuth` is
 * explicitly set.
 */
export async function proxyRequest(
  request: NextRequest,
  backendUrl: string,
  path: string = '',
  options?: {
    method?: string;
    headers?: Record<string, string>;
    body?: any;
    requireAuth?: boolean;
  }
): Promise<NextResponse> {
  const requireAuth = options?.requireAuth ?? true;
  const proxy = new ProxyAuth({ backendUrl, requireAuth });
  return proxy.proxyRequest(request, path, options);
}
// Helper function to handle proxy responses with error handling
/**
 * Call `proxyRequest` and turn any rejection into a JSON 500 response
 * (`{ error: 'Internal server error' }`) after logging it.
 */
export async function handleProxyResponse(
  request: NextRequest,
  backendUrl: string,
  path: string = '',
  options?: {
    method?: string;
    headers?: Record<string, string>;
    body?: any;
    requireAuth?: boolean;
  }
): Promise<NextResponse> {
  let proxied: NextResponse;
  try {
    proxied = await proxyRequest(request, backendUrl, path, options);
  } catch (error) {
    console.error('Proxy request failed:', error);
    const failureBody = { error: 'Internal server error' };
    return NextResponse.json(failureBody, { status: 500 });
  }
  return proxied;
}

View File

@@ -1,99 +1,140 @@
import Cookies from 'js-cookie';
import { EventEmitter } from 'events';
type Listener = (...args: any[]) => void
interface TokenManagerEvents {
tokensUpdated: [];
tokensCleared: [];
class SimpleEmitter {
private listeners = new Map<string, Set<Listener>>()
on(event: string, listener: Listener) {
if (!this.listeners.has(event)) this.listeners.set(event, new Set())
this.listeners.get(event)!.add(listener)
}
export interface TokenManagerInterface {
getTokens(): { access_token: string | null; refresh_token: string | null };
setTokens(access_token: string, refresh_token: string): void;
clearTokens(): void;
isAuthenticated(): boolean;
getAccessToken(): string | null;
getRefreshToken(): string | null;
getTokenExpiry(): { access_token_expiry: number | null; refresh_token_expiry: number | null };
on<E extends keyof TokenManagerEvents>(
event: E,
listener: (...args: TokenManagerEvents[E]) => void
): this;
off<E extends keyof TokenManagerEvents>(
event: E,
listener: (...args: TokenManagerEvents[E]) => void
): this;
off(event: string, listener: Listener) {
this.listeners.get(event)?.delete(listener)
}
class TokenManager extends EventEmitter implements TokenManagerInterface {
private static instance: TokenManager;
private constructor() {
super();
// Set max listeners to avoid memory leak warnings
this.setMaxListeners(100);
emit(event: string, ...args: any[]) {
this.listeners.get(event)?.forEach(l => l(...args))
}
}
static getInstance(): TokenManager {
if (!TokenManager.instance) {
TokenManager.instance = new TokenManager();
}
return TokenManager.instance;
/**
 * Shape of the token record persisted under the 'auth_tokens' localStorage
 * key. Expiry fields are absolute timestamps, not durations.
 */
interface StoredTokens {
access_token: string
refresh_token: string
access_expires_at: number // epoch ms
refresh_expires_at?: number // epoch ms
}
getTokens() {
return {
access_token: Cookies.get('access_token'),
refresh_token: Cookies.get('refresh_token'),
};
const ACCESS_LIFETIME_FALLBACK_MS = 30 * 60 * 1000 // 30 minutes
const REFRESH_LIFETIME_FALLBACK_MS = 7 * 24 * 60 * 60 * 1000 // 7 days
/** Current time in epoch milliseconds. */
function now() {
  const currentMs = Date.now()
  return currentMs
}
/**
 * Load the persisted token record from localStorage.
 *
 * @returns The stored tokens, or null when running without `window`, when
 *          nothing is stored, when storage access or JSON parsing throws,
 *          or when the stored value does not have the expected shape.
 */
function readTokens(): StoredTokens | null {
  if (typeof window === 'undefined') return null
  try {
    const raw = window.localStorage.getItem('auth_tokens')
    if (!raw) return null
    const parsed: unknown = JSON.parse(raw)
    if (typeof parsed !== 'object' || parsed === null) return null
    // Validate before trusting the cast: a corrupt or hand-edited entry
    // would otherwise surface as NaN/Invalid Date in the expiry checks.
    const candidate = parsed as Partial<StoredTokens>
    const wellFormed =
      typeof candidate.access_token === 'string' &&
      typeof candidate.refresh_token === 'string' &&
      typeof candidate.access_expires_at === 'number' &&
      Number.isFinite(candidate.access_expires_at)
    return wellFormed ? (candidate as StoredTokens) : null
  } catch {
    return null
  }
}
setTokens(access_token: string, refresh_token: string) {
// Set cookies with secure attributes
Cookies.set('access_token', access_token, {
expires: 7, // 7 days
secure: process.env.NODE_ENV === 'production',
sameSite: 'strict',
});
/**
 * Persist the token record under 'auth_tokens', or remove the entry when
 * `tokens` is null. Does nothing when `window` is unavailable.
 */
function writeTokens(tokens: StoredTokens | null) {
  if (typeof window === 'undefined') return
  const storage = window.localStorage
  if (!tokens) {
    storage.removeItem('auth_tokens')
    return
  }
  storage.setItem('auth_tokens', JSON.stringify(tokens))
}
Cookies.set('refresh_token', refresh_token, {
expires: 30, // 30 days
secure: process.env.NODE_ENV === 'production',
sameSite: 'strict',
});
class TokenManager extends SimpleEmitter {
private refreshTimer: ReturnType<typeof setTimeout> | null = null
// Emit event
this.emit('tokensUpdated');
/** True when tokens are stored and the access token has not yet expired. */
isAuthenticated(): boolean {
  const stored = readTokens()
  if (!stored) return false
  return stored.access_expires_at > now()
}
/** Expiry time of the stored access token, or null when nothing is stored. */
getTokenExpiry(): Date | null {
  const stored = readTokens()
  if (!stored) return null
  return new Date(stored.access_expires_at)
}
/**
 * Expiry time of the stored refresh token, or null when no tokens are
 * stored or no refresh expiry was recorded.
 */
getRefreshTokenExpiry(): Date | null {
  const refreshAt = readTokens()?.refresh_expires_at
  if (!refreshAt) return null
  return new Date(refreshAt)
}
/**
 * Persist a new token pair, schedule the next refresh and emit
 * 'tokensUpdated'.
 *
 * @param accessToken      New access token.
 * @param refreshToken     Refresh token (may be the one already stored).
 * @param expiresInSeconds Access-token lifetime; the 30-minute fallback is
 *                         used when omitted or 0.
 */
setTokens(accessToken: string, refreshToken: string, expiresInSeconds?: number) {
  const issuedAt = now()
  const previous = readTokens()
  const access_expires_at = issuedAt + (expiresInSeconds ? expiresInSeconds * 1000 : ACCESS_LIFETIME_FALLBACK_MS)
  // An access-token refresh that reuses the same refresh token must not
  // push the refresh expiry forward; only a new refresh token starts a
  // fresh fallback window.
  const sameRefreshToken = !!previous && previous.refresh_token === refreshToken
  const refresh_expires_at =
    sameRefreshToken && previous && previous.refresh_expires_at
      ? previous.refresh_expires_at
      : issuedAt + REFRESH_LIFETIME_FALLBACK_MS
  const tokens: StoredTokens = {
    access_token: accessToken,
    refresh_token: refreshToken,
    access_expires_at,
    refresh_expires_at,
  }
  writeTokens(tokens)
  this.scheduleRefresh()
  this.emit('tokensUpdated')
}
clearTokens() {
Cookies.remove('access_token');
Cookies.remove('refresh_token');
this.emit('tokensCleared');
if (this.refreshTimer) {
clearTimeout(this.refreshTimer)
this.refreshTimer = null
}
writeTokens(null)
this.emit('tokensCleared')
}
isAuthenticated(): boolean {
return !!this.getAccessToken();
/**
 * Clear the stored tokens, then emit 'logout'.
 * `clearTokens` runs first, so its own event fires before 'logout'.
 */
logout() {
this.clearTokens()
this.emit('logout')
}
getAccessToken(): string | null {
return Cookies.get('access_token');
/**
 * (Re)arm the timer that refreshes the access token one minute before it
 * expires, but never sooner than five seconds from now. If the refresh
 * fails, 'sessionExpired' is emitted with 'refresh_failed' and the tokens
 * are cleared. No-op without `window` or without stored tokens.
 */
private scheduleRefresh() {
  if (typeof window === 'undefined') return
  const stored = readTokens()
  if (!stored) return
  if (this.refreshTimer) clearTimeout(this.refreshTimer)
  const leadTimeMs = 60_000 // 1 minute before expiry
  const minimumDelayMs = 5_000
  const delayMs = Math.max(minimumDelayMs, stored.access_expires_at - now() - leadTimeMs)
  const onRefreshFailure = () => {
    this.emit('sessionExpired', 'refresh_failed')
    this.clearTokens()
  }
  this.refreshTimer = setTimeout(() => {
    this.refreshAccessToken().catch(onRefreshFailure)
  }, delayMs)
}
getRefreshToken(): string | null {
return Cookies.get('refresh_token');
}
getTokenExpiry(): { access_token_expiry: number | null; refresh_token_expiry: number | null } {
return {
access_token_expiry: parseInt(Cookies.get('access_token_expiry') || '0') || null,
refresh_token_expiry: parseInt(Cookies.get('refresh_token_expiry') || '0') || null,
};
}
getRefreshTokenExpiry(): number | null {
return parseInt(Cookies.get('refresh_token_expiry') || '0') || null;
/**
 * Return a usable access token, refreshing it first when it expires within
 * the next 10 seconds.
 *
 * @returns The access token, or null when nothing is stored or the refresh
 *          fails. On failure 'sessionExpired' is emitted with 'expired'
 *          and the tokens are cleared.
 *
 * NOTE(review): concurrent callers near expiry each start their own
 * refresh request; confirm the backend tolerates that.
 */
async getAccessToken(): Promise<string | null> {
  const stored = readTokens()
  if (!stored) return null
  const remainingMs = stored.access_expires_at - now()
  if (remainingMs > 10_000) return stored.access_token
  try {
    await this.refreshAccessToken()
  } catch {
    this.emit('sessionExpired', 'expired')
    this.clearTokens()
    return null
  }
  return readTokens()?.access_token || null
}
// Export singleton instance
export const tokenManager = TokenManager.getInstance();
/**
 * Exchange the stored refresh token for a new access token and persist it.
 *
 * @throws Error when no refresh token is stored, when the endpoint returns
 *         a non-OK status, or when the response carries no access token.
 */
private async refreshAccessToken(): Promise<void> {
  const t = readTokens()
  if (!t?.refresh_token) throw new Error('No refresh token')
  const res = await fetch('/api-internal/v1/auth/refresh', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ refresh_token: t.refresh_token }),
  })
  if (!res.ok) throw new Error('Refresh failed')
  const data = await res.json()
  // Guard against a 2xx body without a token; storing `undefined` would
  // look like a valid session until the next request fails.
  if (typeof data?.access_token !== 'string' || !data.access_token) {
    throw new Error('Refresh failed')
  }
  const expiresIn = data.expires_in as number | undefined
  // Keep the current refresh token when the server does not rotate it.
  this.setTokens(data.access_token, data.refresh_token || t.refresh_token, expiresIn)
}
}
export const tokenManager = new TokenManager()

View File

@@ -1,98 +1,8 @@
import { type ClassValue, clsx } from "clsx"
import { twMerge } from "tailwind-merge"
import { type ClassValue } from 'clsx'
import { clsx } from 'clsx'
import { twMerge } from 'tailwind-merge'
/**
 * Combine class-name inputs into one string: `clsx` builds the class list
 * and `twMerge` post-processes the result.
 */
export function cn(...inputs: ClassValue[]) {
return twMerge(clsx(inputs))
}
/** Format a date in the en-US locale with a short month, e.g. "Jan 5, 2024". */
export function formatDate(date: Date | string): string {
  const options: Intl.DateTimeFormatOptions = {
    year: 'numeric',
    month: 'short',
    day: 'numeric',
  };
  return new Date(date).toLocaleDateString('en-US', options);
}
/**
 * Format a date and time in the en-US locale: short month, numeric day and
 * year, two-digit hour and minute.
 */
export function formatDateTime(date: Date | string): string {
  const options: Intl.DateTimeFormatOptions = {
    year: 'numeric',
    month: 'short',
    day: 'numeric',
    hour: '2-digit',
    minute: '2-digit',
  };
  return new Date(date).toLocaleString('en-US', options);
}
/**
 * Describe how long ago `date` was: "just now" under a minute, then
 * minutes, hours and days; a week or more falls back to `formatDate`.
 */
export function formatRelativeTime(date: Date | string): string {
  const target = new Date(date);
  const elapsedMs = new Date().getTime() - target.getTime();
  const seconds = Math.floor(elapsedMs / 1000);
  const minutes = Math.floor(seconds / 60);
  const hours = Math.floor(minutes / 60);
  const days = Math.floor(hours / 24);
  // Pluralize only when the count exceeds one.
  const ago = (count: number, unit: string): string =>
    `${count} ${unit}${count > 1 ? 's' : ''} ago`;
  if (seconds < 60) return 'just now';
  if (minutes < 60) return ago(minutes, 'minute');
  if (hours < 24) return ago(hours, 'hour');
  if (days < 7) return ago(days, 'day');
  return formatDate(target);
}
/**
 * Format a byte count using binary (1024-based) units.
 *
 * @param bytes    Number of bytes; may be fractional or negative.
 * @param decimals Maximum decimals to keep (negative values act as 0).
 * @returns e.g. "1.5 KB". Zero and NaN yield "0 Bytes".
 */
export function formatBytes(bytes: number, decimals = 2): string {
  if (bytes === 0 || Number.isNaN(bytes)) return '0 Bytes';
  const k = 1024;
  const dm = decimals < 0 ? 0 : decimals;
  const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'];
  // Use the magnitude so negatives work, and clamp the index so values
  // below 1 or beyond the largest unit never index outside `sizes`.
  const exponent = Math.floor(Math.log(Math.abs(bytes)) / Math.log(k));
  const i = Math.min(sizes.length - 1, Math.max(0, exponent));
  return parseFloat((bytes / Math.pow(k, i)).toFixed(dm)) + ' ' + sizes[i];
}
/** Format a number with en-US grouping separators, e.g. 1234567 -> "1,234,567". */
export function formatNumber(num: number): string {
  const formatter = new Intl.NumberFormat('en-US');
  return formatter.format(num);
}
/** Format an amount as currency in the en-US locale (USD unless overridden). */
export function formatCurrency(amount: number, currency = 'USD'): string {
  const options: Intl.NumberFormatOptions = { style: 'currency', currency };
  const formatter = new Intl.NumberFormat('en-US', options);
  return formatter.format(amount);
}
/**
 * Shorten `str` to at most `length` characters and append "..." when it was
 * cut. Strings that already fit are returned unchanged.
 */
export function truncate(str: string, length: number): string {
  const fits = str.length <= length;
  return fits ? str : `${str.slice(0, length)}...`;
}
/**
 * Wrap `func` so it runs only after `wait` ms have passed without another
 * call; each call cancels the previously scheduled run.
 */
export function debounce<T extends (...args: any[]) => any>(
  func: T,
  wait: number
): (...args: Parameters<T>) => void {
  let pending: NodeJS.Timeout;
  return (...callArgs: Parameters<T>) => {
    clearTimeout(pending);
    const fire = () => func(...callArgs);
    pending = setTimeout(fire, wait);
  };
}
/**
 * Wrap `func` so it runs at most once per `limit` ms: the first call runs
 * immediately and further calls are dropped until the window has elapsed.
 */
export function throttle<T extends (...args: any[]) => any>(
  func: T,
  limit: number
): (...args: Parameters<T>) => void {
  let cooling: boolean;
  return (...callArgs: Parameters<T>) => {
    if (cooling) return;
    func(...callArgs);
    cooling = true;
    setTimeout(() => {
      cooling = false;
    }, limit);
  };
}

View File

@@ -26,6 +26,12 @@ http {
listen 80;
server_name localhost;
# Static files - serve directly from nginx
# Exact-match location: only the URI /login_helper.html is handled here.
# The file is looked up under /usr/share/nginx/html; a missing file
# yields 404 instead of falling through to another location.
location = /login_helper.html {
root /usr/share/nginx/html;
try_files $uri =404;
}
# Frontend routes
location / {
proxy_pass http://frontend;
@@ -66,6 +72,58 @@ http {
}
}
# RAG debug API routes - proxy to frontend (for Next.js API routes)
location /api/rag/debug/ {
    proxy_pass http://frontend;
    proxy_set_header Host $host;
    proxy_set_header X-Real-IP $remote_addr;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    proxy_set_header X-Forwarded-Proto $scheme;
    # CORS headers ("always" adds them to error responses as well)
    add_header 'Access-Control-Allow-Origin' '*' always;
    add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
    add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
    # The standard header is Access-Control-Expose-Headers; browsers ignore
    # the non-existent "Access-Control-Allow-Expose-Headers".
    add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;
    # Handle preflight requests: answered by nginx, never proxied.
    # add_header directives from the enclosing level are not inherited once
    # this block defines its own, so the CORS headers are repeated here.
    if ($request_method = 'OPTIONS') {
        add_header 'Access-Control-Allow-Origin' '*';
        add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
        add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization';
        add_header 'Access-Control-Max-Age' 1728000;
        add_header 'Content-Type' 'text/plain; charset=utf-8';
        add_header 'Content-Length' 0;
        return 204;
    }
}
# Frontend API routes for authentication - proxy to frontend
# Requests under /api/auth/ go to the "frontend" upstream, with the original
# host, client address and scheme forwarded in headers.
location /api/auth/ {
proxy_pass http://frontend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# CORS headers
# NOTE(review): '*' allows any origin on an authentication route that also
# accepts the Authorization header; confirm this is intended.
add_header 'Access-Control-Allow-Origin' '*' always;
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;
# Handle preflight requests
# OPTIONS is answered here with 204 and is not proxied. The CORS headers are
# repeated because add_header is not inherited from the enclosing level once
# this block declares its own.
if ($request_method = 'OPTIONS') {
add_header 'Access-Control-Allow-Origin' '*';
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization';
add_header 'Access-Control-Max-Age' 1728000;
add_header 'Content-Type' 'text/plain; charset=utf-8';
add_header 'Content-Length' 0;
return 204;
}
}
# Public API routes - proxy to backend (for external clients)
location /api/ {
proxy_pass http://backend;

166
verify_security_removal.py Normal file
View File

@@ -0,0 +1,166 @@
#!/usr/bin/env python3
"""
Verification script for security middleware removal
"""
import subprocess
import sys
import time
def run_command(cmd, cwd=None, timeout=30):
    """Run a command and return ``(returncode, stdout, stderr)``.

    Args:
        cmd: Command to execute. A string is run through the shell (as
            before); a list/tuple of arguments is executed directly
            without a shell.
        cwd: Working directory for the command, or None for the current one.
        timeout: Seconds to wait before giving up.

    Returns:
        A ``(returncode, stdout, stderr)`` tuple. The return code is -1 when
        the command timed out or could not be started (for example, when
        ``cwd`` does not exist); stderr then carries the reason.
    """
    try:
        result = subprocess.run(
            cmd,
            shell=isinstance(cmd, str),
            capture_output=True,
            text=True,
            cwd=cwd,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return -1, "", "Command timed out"
    except OSError as exc:
        # Missing cwd or executable: report it as a failed command instead
        # of crashing the whole verification run with a traceback.
        return -1, "", str(exc)
    return result.returncode, result.stdout, result.stderr
def test_backend_syntax():
    """Byte-compile the backend entry point and the security middleware.

    Runs ``python3 -m py_compile`` on each file from the ``backend``
    directory, stopping at the first failure.

    Returns:
        True when both files compile, False otherwise.
    """
    print("🔍 Testing backend Python syntax...")
    # (label used in messages, path relative to backend/)
    targets = (
        ("main.py", "app/main.py"),
        ("security.py", "app/middleware/security.py"),
    )
    for label, rel_path in targets:
        code, _stdout, stderr = run_command(
            f"python3 -m py_compile {rel_path}", cwd="backend"
        )
        if code != 0:
            print(f"❌ {label} syntax error: {stderr}")
            return False
        print(f"✅ {label} syntax OK")
    return True
def test_docker_build():
    """Sanity-check ``backend/Dockerfile`` without running Docker.

    The file only needs to exist and contain both "FROM" and "python".

    Returns:
        True when the Dockerfile looks valid, False when it is missing or
        lacks either marker.
    """
    print("\n🐳 Testing Docker backend build...")
    # Just check if the Dockerfile exists and is readable
    try:
        with open("backend/Dockerfile", "r") as dockerfile:
            text = dockerfile.read()
    except FileNotFoundError:
        print("❌ Dockerfile not found")
        return False

    looks_valid = "FROM" in text and "python" in text
    if not looks_valid:
        print("❌ Dockerfile appears invalid")
        return False
    print("✅ Dockerfile exists and looks valid")
    return True
def test_env_settings():
    """Check that ``.env`` disables API security and rate limiting.

    Only active ``KEY=value`` assignments count: blank lines and ``#``
    comments are ignored, so a commented-out ``# API_SECURITY_ENABLED=false``
    no longer passes the check. An optional ``export`` prefix, surrounding
    quotes, a trailing `` # comment`` and letter case of the value are
    tolerated.

    Returns:
        True when both ``API_SECURITY_ENABLED`` and
        ``API_RATE_LIMITING_ENABLED`` are set to false, False otherwise
        (including when ``.env`` is missing).
    """
    print("\n⚙️ Testing environment settings...")
    try:
        with open(".env", "r", encoding="utf-8") as f:
            env_content = f.read()
    except FileNotFoundError:
        print("❌ .env file not found")
        return False

    settings = {}
    for raw_line in env_content.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        if line.startswith("export "):
            line = line[len("export "):].lstrip()
        key, sep, value = line.partition("=")
        if not sep:
            continue
        # Drop a trailing inline comment, then surrounding quotes.
        value = value.split(" #", 1)[0].strip().strip("\"'")
        settings[key.strip()] = value.lower()

    required = (
        ("API_SECURITY_ENABLED", "Security"),
        ("API_RATE_LIMITING_ENABLED", "Rate limiting"),
    )
    for key, label in required:
        if settings.get(key) != "false":
            print(f"❌ {label} is not disabled in .env")
            return False
        print(f"✅ {label} is disabled in .env")
    return True
def test_imports():
    """Check that importing the app does not fail because of security code.

    A throwaway script that imports ``app.main`` (with ``backend`` on
    ``sys.path``) is written to the current directory, executed, and always
    removed again. The import is expected to fail when dependencies such as
    FastAPI are missing; that is tolerated. Only a failed run whose output
    mentions "security" without a "No module named" message is treated as
    an error.

    Returns:
        True when no security-related import error is detected, False
        otherwise.
    """
    import os

    print("\n📦 Testing import dependencies...")
    # Create a minimal test script
    test_script = """
import sys
sys.path.insert(0, 'backend')
try:
    # Test if we can create the app without security middleware
    from app.main import app
    print("✅ App can be imported successfully")
except ImportError as e:
    print(f"❌ Import error: {e}")
    sys.exit(1)
except Exception as e:
    print(f"❌ Other error: {e}")
    sys.exit(1)
"""
    script_path = "test_import.py"
    # Write as UTF-8 so the emoji in the script are valid Python source
    # regardless of the locale encoding.
    with open(script_path, "w", encoding="utf-8") as f:
        f.write(test_script)
    try:
        # Run test (will likely fail due to missing dependencies, but should
        # not fail due to security imports)
        code, stdout, stderr = run_command(f"python3 {script_path}")
    finally:
        # Clean up even if running the command raised.
        os.remove(script_path)

    # The child script reports its errors with print(), i.e. on stdout, so
    # both streams must be inspected; checking stderr alone never fails.
    output = f"{stdout}\n{stderr}"
    # We expect this to fail due to missing FastAPI, but not due to security imports
    if code != 0 and "security" in output.lower() and "No module named" not in output:
        print("❌ Security import errors detected")
        return False
    print("✅ No security import errors detected")
    return True
def main():
    """Run every verification check and print a pass/fail summary.

    All checks run even if an earlier one fails.

    Returns:
        True when every check passed, False otherwise.
    """
    print("🚀 Starting verification of security middleware removal...\n")
    suite = (
        ("Environment Settings", test_env_settings),
        ("Python Syntax", test_backend_syntax),
        ("Docker Configuration", test_docker_build),
        ("Import Dependencies", test_imports),
    )

    outcomes = []
    for title, check in suite:
        print(f"\n--- {title} ---")
        outcomes.append((title, check()))

    # Print summary
    divider = "=" * 50
    print("\n" + divider)
    print("📊 VERIFICATION SUMMARY")
    print(divider)
    for title, passed in outcomes:
        verdict = "✅ PASS" if passed else "❌ FAIL"
        print(f"{title}: {verdict}")

    everything_ok = all(passed for _, passed in outcomes)
    if everything_ok:
        print("\n🎉 All tests passed! Security middleware has been successfully removed.")
    else:
        print("\n⚠️ Some tests failed. Please review the issues above.")
    return everything_ok
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)