diff --git a/.env b/.env
index 9e074ae..b8d34af 100644
--- a/.env
+++ b/.env
@@ -46,7 +46,7 @@ API_RATE_LIMITING_ENABLED=false
# ===================================
# APPLICATION BASE URL (Required - derives all URLs and CORS)
# ===================================
-BASE_URL=localhost
+BASE_URL=localhost:80
# Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost
# Backend derives: CORS_ORIGINS=["http://localhost"]
diff --git a/.env.example b/.env.example
index cf6d8f1..b9dd120 100644
--- a/.env.example
+++ b/.env.example
@@ -65,6 +65,16 @@ QDRANT_HOST=enclava-qdrant
QDRANT_PORT=6333
QDRANT_URL=http://enclava-qdrant:6333
+# ===================================
+# RAG EMBEDDING CONFIGURATION (Optional overrides)
+# ===================================
+# These control embedding throughput to avoid provider 429s.
+# Defaults are conservative; uncomment to override.
+# RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=12
+# RAG_EMBEDDING_BATCH_SIZE=3
+# RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=1.0 # seconds
+# RAG_EMBEDDING_DELAY_PER_REQUEST=0.5 # seconds
+
# ===================================
# OPTIONAL PRIVATEMODE SETTINGS (Have defaults)
# ===================================
@@ -130,4 +140,4 @@ QDRANT_URL=http://enclava-qdrant:6333
# Required: DATABASE_URL, REDIS_URL, JWT_SECRET, ADMIN_EMAIL, ADMIN_PASSWORD, BASE_URL
# Recommended: PRIVATEMODE_API_KEY, QDRANT_HOST, QDRANT_PORT
# Optional: All other settings have secure defaults
-# ===================================
\ No newline at end of file
+# ===================================
diff --git a/backend/app/api/internal_v1/__init__.py b/backend/app/api/internal_v1/__init__.py
index 97e8510..29af4ab 100644
--- a/backend/app/api/internal_v1/__init__.py
+++ b/backend/app/api/internal_v1/__init__.py
@@ -12,8 +12,8 @@ from ..v1.audit import router as audit_router
from ..v1.settings import router as settings_router
from ..v1.analytics import router as analytics_router
from ..v1.rag import router as rag_router
+from ..rag_debug import router as rag_debug_router
from ..v1.prompt_templates import router as prompt_templates_router
-from ..v1.security import router as security_router
from ..v1.plugin_registry import router as plugin_registry_router
from ..v1.platform import router as platform_router
from ..v1.llm_internal import router as llm_internal_router
@@ -52,11 +52,12 @@ internal_api_router.include_router(analytics_router, prefix="/analytics", tags=[
# Include RAG routes (frontend RAG document management)
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])
+# Include RAG debug routes (for demo and debugging)
+internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])
+
# Include prompt template routes (frontend prompt template management)
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])
-# Include security routes (frontend security settings)
-internal_api_router.include_router(security_router, prefix="/security", tags=["internal-security"])
# Include plugin registry routes (frontend plugin management)
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])
diff --git a/backend/app/api/v1/__init__.py b/backend/app/api/v1/__init__.py
index 6f66641..f9412e4 100644
--- a/backend/app/api/v1/__init__.py
+++ b/backend/app/api/v1/__init__.py
@@ -16,7 +16,6 @@ from .analytics import router as analytics_router
from .rag import router as rag_router
from .chatbot import router as chatbot_router
from .prompt_templates import router as prompt_templates_router
-from .security import router as security_router
from .plugin_registry import router as plugin_registry_router
# Create main API router
@@ -61,8 +60,6 @@ api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
# Include prompt template routes
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
-# Include security routes
-api_router.include_router(security_router, prefix="/security", tags=["security"])
# Include plugin registry routes
diff --git a/backend/app/api/v1/llm.py b/backend/app/api/v1/llm.py
index c30d797..5fdc20c 100644
--- a/backend/app/api/v1/llm.py
+++ b/backend/app/api/v1/llm.py
@@ -745,8 +745,7 @@ async def get_llm_metrics(
"total_requests": metrics.total_requests,
"successful_requests": metrics.successful_requests,
"failed_requests": metrics.failed_requests,
- "security_blocked_requests": metrics.security_blocked_requests,
- "average_latency_ms": metrics.average_latency_ms,
+ "average_latency_ms": metrics.average_latency_ms,
"average_risk_score": metrics.average_risk_score,
"provider_metrics": metrics.provider_metrics,
"last_updated": metrics.last_updated.isoformat()
diff --git a/backend/app/api/v1/rag.py b/backend/app/api/v1/rag.py
index b5d00cf..0e65c2f 100644
--- a/backend/app/api/v1/rag.py
+++ b/backend/app/api/v1/rag.py
@@ -3,12 +3,14 @@ RAG API Endpoints
Provides REST API for RAG (Retrieval Augmented Generation) operations
"""
-from typing import List, Optional
+from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
import io
+import asyncio
+from datetime import datetime
from app.db.database import get_db
from app.core.security import get_current_user
@@ -16,6 +18,9 @@ from app.models.user import User
from app.services.rag_service import RAGService
from app.utils.exceptions import APIException
+# Import RAG module from module manager
+from app.services.module_manager import module_manager
+
router = APIRouter(tags=["RAG"])
@@ -78,14 +83,25 @@ async def get_collections(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
- """Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata"""
+ """Get all RAG collections - live data directly from Qdrant (source of truth)"""
try:
- rag_service = RAGService(db)
- collections_data = await rag_service.get_all_collections(skip=skip, limit=limit)
+ from app.services.qdrant_stats_service import qdrant_stats_service
+
+ # Get live stats from Qdrant
+ stats_data = await qdrant_stats_service.get_collections_stats()
+ collections = stats_data.get("collections", [])
+
+ # Apply pagination
+ start_idx = skip
+ end_idx = skip + limit
+ paginated_collections = collections[start_idx:end_idx]
+
return {
"success": True,
- "collections": collections_data,
- "total": len(collections_data)
+ "collections": paginated_collections,
+ "total": len(collections),
+ "total_documents": stats_data.get("total_documents", 0),
+ "total_size_bytes": stats_data.get("total_size_bytes", 0)
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -116,6 +132,62 @@ async def create_collection(
raise HTTPException(status_code=500, detail=str(e))
+@router.get("/stats", response_model=dict)
+async def get_rag_stats(
+ db: AsyncSession = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """Get overall RAG statistics - live data directly from Qdrant"""
+ try:
+ from app.services.qdrant_stats_service import qdrant_stats_service
+
+ # Get live stats from Qdrant
+ stats_data = await qdrant_stats_service.get_collections_stats()
+
+ # Calculate active collections (collections with documents)
+ active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)
+
+ # Calculate processing documents from database
+ processing_docs = 0
+ try:
+ from sqlalchemy import select
+ from app.models.rag_document import RagDocument, ProcessingStatus
+
+ result = await db.execute(
+ select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
+ )
+ processing_docs = len(result.scalars().all())
+ except Exception:
+ pass # If database query fails, default to 0
+
+ response_data = {
+ "success": True,
+ "stats": {
+ "collections": {
+ "total": stats_data.get("total_collections", 0),
+ "active": active_collections
+ },
+ "documents": {
+ "total": stats_data.get("total_documents", 0),
+ "processing": processing_docs,
+ "processed": stats_data.get("total_documents", 0) # Indexed documents
+ },
+ "storage": {
+ "total_size_bytes": stats_data.get("total_size_bytes", 0),
+ "total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
+ },
+ "vectors": {
+ "total": stats_data.get("total_documents", 0) # Same as documents for RAG
+ },
+ "last_updated": datetime.utcnow().isoformat()
+ }
+ }
+
+ return response_data
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@router.get("/collections/{collection_id}", response_model=dict)
async def get_collection(
collection_id: int,
@@ -225,21 +297,65 @@ async def upload_document(
try:
# Read file content
file_content = await file.read()
-
+
if len(file_content) == 0:
raise HTTPException(status_code=400, detail="Empty file uploaded")
-
+
if len(file_content) > 50 * 1024 * 1024: # 50MB limit
raise HTTPException(status_code=400, detail="File too large (max 50MB)")
-
+
+ # Validate file can be read before processing
+ filename = file.filename or "unknown"
+ file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
+
+ try:
+ # Test file readability based on type
+ if file_extension == 'jsonl':
+ # Validate JSONL format - try to parse first few lines
+ try:
+ content_str = file_content.decode('utf-8')
+ lines = content_str.strip().split('\n')[:5] # Check first 5 lines
+ import json
+ for i, line in enumerate(lines):
+ if line.strip(): # Skip empty lines
+ json.loads(line) # Will raise JSONDecodeError if invalid
+ except UnicodeDecodeError:
+ raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
+ except json.JSONDecodeError as e:
+ raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
+
+ elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
+ # Validate text files can be decoded
+ try:
+ file_content.decode('utf-8')
+ except UnicodeDecodeError:
+ raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
+
+ elif file_extension in ['pdf']:
+ # For PDF files, just check if it starts with PDF signature
+ if not file_content.startswith(b'%PDF'):
+ raise HTTPException(status_code=400, detail="Invalid PDF file format")
+
+ elif file_extension in ['docx', 'xlsx', 'pptx']:
+ # For Office documents, check ZIP signature
+ if not file_content.startswith(b'PK'):
+ raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
+
+ # For other file types, we'll rely on the document processor
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
+
rag_service = RAGService(db)
document = await rag_service.upload_document(
collection_id=collection_id,
file_content=file_content,
- filename=file.filename or "unknown",
+ filename=filename,
content_type=file.content_type
)
-
+
return {
"success": True,
"document": document.to_dict(),
@@ -362,21 +478,167 @@ async def download_document(
raise HTTPException(status_code=500, detail=str(e))
-# Stats Endpoint
-@router.get("/stats", response_model=dict)
-async def get_rag_stats(
- db: AsyncSession = Depends(get_db),
+# Debug Endpoints
+
+@router.post("/debug/search")
+async def search_with_debug(
+ query: str,
+ max_results: int = 10,
+ score_threshold: float = 0.3,
+ collection_name: str = None,
+ config: Dict[str, Any] = None,
current_user: User = Depends(get_current_user)
-):
- """Get RAG system statistics"""
+) -> Dict[str, Any]:
+ """
+ Enhanced search with comprehensive debug information
+ """
+ # Get RAG module from module manager
+ rag_module = module_manager.modules.get('rag')
+ if not rag_module or not rag_module.enabled:
+ raise HTTPException(status_code=503, detail="RAG module not initialized")
+
+ debug_info = {}
+ start_time = datetime.utcnow()
+
try:
- rag_service = RAGService(db)
- stats = await rag_service.get_stats()
-
+ # Apply configuration if provided
+ if config:
+ # Update RAG config temporarily
+ original_config = rag_module.config.copy()
+ rag_module.config.update(config)
+
+ # Generate query embedding (with or without prefix)
+ if config and config.get("use_query_prefix"):
+ optimized_query = f"query: {query}"
+ else:
+ optimized_query = query
+
+ query_embedding = await rag_module._generate_embedding(optimized_query)
+
+ # Store embedding info for debug
+ if config and config.get("debug", {}).get("show_embeddings"):
+ debug_info["query_embedding"] = query_embedding[:10] # First 10 dimensions
+ debug_info["embedding_dimension"] = len(query_embedding)
+ debug_info["optimized_query"] = optimized_query
+
+ # Perform search
+ search_start = asyncio.get_event_loop().time()
+ results = await rag_module.search_documents(
+ query,
+ max_results=max_results,
+ score_threshold=score_threshold,
+ collection_name=collection_name
+ )
+ search_time = (asyncio.get_event_loop().time() - search_start) * 1000
+
+ # Calculate score statistics
+ scores = [r.score for r in results if r.score is not None]
+ if scores:
+ import statistics
+ debug_info["score_stats"] = {
+ "min": min(scores),
+ "max": max(scores),
+ "avg": statistics.mean(scores),
+ "stddev": statistics.stdev(scores) if len(scores) > 1 else 0
+ }
+
+ # Get collection statistics
+ try:
+ from qdrant_client.http.models import Filter
+ collection_name = collection_name or rag_module.default_collection_name
+
+ # Count total documents
+ count_result = rag_module.qdrant_client.count(
+ collection_name=collection_name,
+ count_filter=Filter(must=[])
+ )
+ total_points = count_result.count
+
+ # Get unique documents and languages
+ scroll_result = rag_module.qdrant_client.scroll(
+ collection_name=collection_name,
+ limit=1000, # Sample for stats
+ with_payload=True,
+ with_vectors=False
+ )
+
+ unique_docs = set()
+ languages = set()
+
+ for point in scroll_result[0]:
+ payload = point.payload or {}
+ doc_id = payload.get("document_id")
+ if doc_id:
+ unique_docs.add(doc_id)
+
+ language = payload.get("language")
+ if language:
+ languages.add(language)
+
+ debug_info["collection_stats"] = {
+ "total_documents": len(unique_docs),
+ "total_chunks": total_points,
+ "languages": sorted(list(languages))
+ }
+
+ except Exception as e:
+ debug_info["collection_stats_error"] = str(e)
+
+ # Enhance results with debug info
+ enhanced_results = []
+ for result in results:
+ enhanced_result = {
+ "document": {
+ "id": result.document.id,
+ "content": result.document.content,
+ "metadata": result.document.metadata
+ },
+ "score": result.score,
+ "debug_info": {}
+ }
+
+ # Add hybrid search debug info if available
+ metadata = result.document.metadata or {}
+ if "_vector_score" in metadata:
+ enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
+ if "_bm25_score" in metadata:
+ enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]
+
+ enhanced_results.append(enhanced_result)
+
+ # Note: Analytics logging disabled (module not available)
+
return {
- "success": True,
- "stats": stats
+ "results": enhanced_results,
+ "debug_info": debug_info,
+ "search_time_ms": search_time,
+ "timestamp": start_time.isoformat()
}
+
except Exception as e:
- raise HTTPException(status_code=500, detail=str(e))
\ No newline at end of file
+ # Note: Analytics logging disabled (module not available)
+ raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
+
+ finally:
+ # Restore original config if modified
+ if config and 'original_config' in locals():
+ rag_module.config = original_config
+
+
+@router.get("/debug/config")
+async def get_current_config(
+ current_user: User = Depends(get_current_user)
+) -> Dict[str, Any]:
+ """Get current RAG configuration"""
+ # Get RAG module from module manager
+ rag_module = module_manager.modules.get('rag')
+ if not rag_module or not rag_module.enabled:
+ raise HTTPException(status_code=503, detail="RAG module not initialized")
+
+ return {
+ "config": rag_module.config,
+ "embedding_model": rag_module.embedding_model,
+ "enabled": rag_module.enabled,
+ "collections": await rag_module._get_collections_safely()
+ }
diff --git a/backend/app/api/v1/security.py b/backend/app/api/v1/security.py
deleted file mode 100644
index 838dd6f..0000000
--- a/backend/app/api/v1/security.py
+++ /dev/null
@@ -1,251 +0,0 @@
-"""
-Security API endpoints for monitoring and configuration
-"""
-
-from typing import Dict, Any, List, Optional
-from fastapi import APIRouter, Depends, HTTPException, Request, status
-from pydantic import BaseModel, Field
-
-from app.core.security import get_current_active_user, RequiresRole
-from app.middleware.security import get_security_stats, get_request_auth_level, get_request_risk_score
-from app.core.config import settings
-from app.core.logging import get_logger
-
-logger = get_logger(__name__)
-
-router = APIRouter(tags=["security"])
-
-
-# Pydantic models for API responses
-class SecurityStatsResponse(BaseModel):
- """Security statistics response model"""
- total_requests_analyzed: int
- threats_detected: int
- threats_blocked: int
- anomalies_detected: int
- rate_limits_exceeded: int
- avg_analysis_time: float
- threat_types: Dict[str, int]
- threat_levels: Dict[str, int]
- top_attacking_ips: List[tuple]
- security_enabled: bool
- threat_detection_enabled: bool
- rate_limiting_enabled: bool
-
-
-class SecurityConfigResponse(BaseModel):
- """Security configuration response model"""
- security_enabled: bool = Field(description="Overall security system enabled")
- threat_detection_enabled: bool = Field(description="Threat detection analysis enabled")
- rate_limiting_enabled: bool = Field(description="Rate limiting enabled")
- ip_reputation_enabled: bool = Field(description="IP reputation checking enabled")
- anomaly_detection_enabled: bool = Field(description="Anomaly detection enabled")
- security_headers_enabled: bool = Field(description="Security headers enabled")
-
- # Rate limiting settings
- unauthenticated_per_minute: int = Field(description="Rate limit for unauthenticated requests per minute")
- authenticated_per_minute: int = Field(description="Rate limit for authenticated users per minute")
- api_key_per_minute: int = Field(description="Rate limit for API key users per minute")
- premium_per_minute: int = Field(description="Rate limit for premium users per minute")
-
- # Security thresholds
- risk_threshold: float = Field(description="Risk score threshold for blocking requests")
- warning_threshold: float = Field(description="Risk score threshold for warnings")
- anomaly_threshold: float = Field(description="Anomaly severity threshold")
-
- # IP settings
- blocked_ips: List[str] = Field(description="List of blocked IP addresses")
- allowed_ips: List[str] = Field(description="List of allowed IP addresses (empty = allow all)")
-
-
-class RateLimitInfoResponse(BaseModel):
- """Rate limit information for current request"""
- auth_level: str = Field(description="Authentication level (unauthenticated, authenticated, api_key, premium)")
- current_limits: Dict[str, int] = Field(description="Current rate limits for this auth level")
- remaining_requests: Optional[Dict[str, int]] = Field(description="Estimated remaining requests (if available)")
-
-
-@router.get("/stats", response_model=SecurityStatsResponse)
-async def get_security_statistics(
- current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
-):
- """
- Get security system statistics
-
- Requires admin role. Returns comprehensive statistics about:
- - Request analysis counts
- - Threat detection results
- - Rate limiting enforcement
- - Top attacking IPs
- - Performance metrics
- """
- try:
- stats = get_security_stats()
- return SecurityStatsResponse(**stats)
- except Exception as e:
- logger.error(f"Error getting security stats: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to retrieve security statistics"
- )
-
-
-@router.get("/config", response_model=SecurityConfigResponse)
-async def get_security_config(
- current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
-):
- """
- Get current security configuration
-
- Requires admin role. Returns current security settings including:
- - Feature enablement flags
- - Rate limiting thresholds
- - Security thresholds
- - IP allowlists/blocklists
- """
- return SecurityConfigResponse(
- security_enabled=settings.API_SECURITY_ENABLED,
- threat_detection_enabled=settings.API_THREAT_DETECTION_ENABLED,
- rate_limiting_enabled=settings.API_RATE_LIMITING_ENABLED,
- ip_reputation_enabled=settings.API_IP_REPUTATION_ENABLED,
- anomaly_detection_enabled=settings.API_ANOMALY_DETECTION_ENABLED,
- security_headers_enabled=settings.API_SECURITY_HEADERS_ENABLED,
-
- unauthenticated_per_minute=settings.API_RATE_LIMIT_UNAUTHENTICATED_PER_MINUTE,
- authenticated_per_minute=settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE,
- api_key_per_minute=settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
- premium_per_minute=settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE,
-
- risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
- warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
- anomaly_threshold=settings.API_SECURITY_ANOMALY_THRESHOLD,
-
- blocked_ips=settings.API_BLOCKED_IPS,
- allowed_ips=settings.API_ALLOWED_IPS
- )
-
-
-@router.get("/status")
-async def get_security_status(
- request: Request,
- current_user: Dict[str, Any] = Depends(get_current_active_user)
-):
- """
- Get security status for current request
-
- Returns information about the security analysis of the current request:
- - Authentication level
- - Risk score (if available)
- - Rate limiting status
- """
- auth_level = get_request_auth_level(request)
- risk_score = get_request_risk_score(request)
-
- # Get rate limits for current auth level
- from app.core.threat_detection import AuthLevel
- try:
- auth_enum = AuthLevel(auth_level)
- from app.core.threat_detection import threat_detection_service
- minute_limit, hour_limit = threat_detection_service.get_rate_limits(auth_enum)
-
- rate_limit_info = RateLimitInfoResponse(
- auth_level=auth_level,
- current_limits={
- "per_minute": minute_limit,
- "per_hour": hour_limit
- },
- remaining_requests=None # We don't track remaining requests in current implementation
- )
- except ValueError:
- rate_limit_info = RateLimitInfoResponse(
- auth_level=auth_level,
- current_limits={},
- remaining_requests=None
- )
-
- return {
- "security_enabled": settings.API_SECURITY_ENABLED,
- "auth_level": auth_level,
- "risk_score": round(risk_score, 3) if risk_score > 0 else None,
- "rate_limit_info": rate_limit_info.dict(),
- "security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED
- }
-
-
-@router.post("/test")
-async def test_security_analysis(
- request: Request,
- current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
-):
- """
- Test security analysis on current request
-
- Requires admin role. Manually triggers security analysis on the current request
- and returns detailed results. Useful for testing security rules and thresholds.
- """
- try:
- from app.middleware.security import analyze_request_security
-
- analysis = await analyze_request_security(request, current_user)
-
- return {
- "analysis_complete": True,
- "is_threat": analysis.is_threat,
- "risk_score": round(analysis.risk_score, 3),
- "auth_level": analysis.auth_level.value,
- "should_block": analysis.should_block,
- "rate_limit_exceeded": analysis.rate_limit_exceeded,
- "threat_count": len(analysis.threats),
- "threats": [
- {
- "type": threat.threat_type,
- "level": threat.level.value,
- "confidence": round(threat.confidence, 3),
- "description": threat.description,
- "mitigation": threat.mitigation
- }
- for threat in analysis.threats
- ],
- "recommendations": analysis.recommendations
- }
- except Exception as e:
- logger.error(f"Error in security analysis test: {e}")
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to perform security analysis test"
- )
-
-
-@router.get("/health")
-async def security_health_check():
- """
- Security system health check
-
- Public endpoint that returns the health status of the security system.
- Does not require authentication.
- """
- try:
- stats = get_security_stats()
-
- # Basic health checks
- is_healthy = (
- settings.API_SECURITY_ENABLED and
- stats.get("total_requests_analyzed", 0) >= 0 and
- stats.get("avg_analysis_time", 0) < 1.0 # Analysis should be under 1 second
- )
-
- return {
- "status": "healthy" if is_healthy else "degraded",
- "security_enabled": settings.API_SECURITY_ENABLED,
- "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
- "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
- "avg_analysis_time_ms": round(stats.get("avg_analysis_time", 0) * 1000, 2),
- "total_requests_analyzed": stats.get("total_requests_analyzed", 0)
- }
- except Exception as e:
- logger.error(f"Security health check failed: {e}")
- return {
- "status": "unhealthy",
- "error": "Security system error",
- "security_enabled": settings.API_SECURITY_ENABLED
- }
\ No newline at end of file
diff --git a/backend/app/api/v1/settings.py b/backend/app/api/v1/settings.py
index 8595ad6..4b97e25 100644
--- a/backend/app/api/v1/settings.py
+++ b/backend/app/api/v1/settings.py
@@ -97,7 +97,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
"api": {
# Security Settings
"security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
- "threat_detection_enabled": {"value": True, "type": "boolean", "description": "Enable threat detection analysis"},
"rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
"ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
"anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
@@ -112,7 +111,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},
# Security Thresholds
- "security_risk_threshold": {"value": 0.8, "type": "float", "description": "Risk score threshold for blocking requests (0.0-1.0)"},
"security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
"anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},
@@ -601,7 +599,6 @@ async def reset_to_defaults(
"api": {
# Security Settings
"security_enabled": {"value": True, "type": "boolean"},
- "threat_detection_enabled": {"value": True, "type": "boolean"},
"rate_limiting_enabled": {"value": True, "type": "boolean"},
"ip_reputation_enabled": {"value": True, "type": "boolean"},
"anomaly_detection_enabled": {"value": True, "type": "boolean"},
@@ -616,7 +613,6 @@ async def reset_to_defaults(
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},
# Security Thresholds
- "security_risk_threshold": {"value": 0.8, "type": "float"},
"security_warning_threshold": {"value": 0.6, "type": "float"},
"anomaly_threshold": {"value": 0.7, "type": "float"},
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index f3ac614..7d53387 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -17,6 +17,8 @@ class Settings(BaseSettings):
APP_LOG_LEVEL: str = os.getenv("APP_LOG_LEVEL", "INFO")
APP_HOST: str = os.getenv("APP_HOST", "0.0.0.0")
APP_PORT: int = int(os.getenv("APP_PORT", "8000"))
+ BACKEND_INTERNAL_PORT: int = int(os.getenv("BACKEND_INTERNAL_PORT", "8000"))
+ FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000"))
# Detailed logging for LLM interactions
LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true" # Set to True to log prompts and context sent to LLM
@@ -73,16 +75,11 @@ class Settings(BaseSettings):
QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333"))
QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
+ QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")
- # API & Security Settings
- API_SECURITY_ENABLED: bool = os.getenv("API_SECURITY_ENABLED", "True").lower() == "true"
- API_THREAT_DETECTION_ENABLED: bool = os.getenv("API_THREAT_DETECTION_ENABLED", "True").lower() == "true"
- API_IP_REPUTATION_ENABLED: bool = os.getenv("API_IP_REPUTATION_ENABLED", "True").lower() == "true"
- API_ANOMALY_DETECTION_ENABLED: bool = os.getenv("API_ANOMALY_DETECTION_ENABLED", "True").lower() == "true"
-
+
# Rate Limiting Configuration
- API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"
-
+
# PrivateMode Standard tier limits (organization-level, not per user)
# These are shared across all API keys and users in the organization
PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
@@ -101,23 +98,14 @@ class Settings(BaseSettings):
# Premium/Enterprise API keys
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
-
- # Security Thresholds
- API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8")) # Block requests above this risk score
- API_SECURITY_WARNING_THRESHOLD: float = float(os.getenv("API_SECURITY_WARNING_THRESHOLD", "0.6")) # Log warnings above this threshold
- API_SECURITY_ANOMALY_THRESHOLD: float = float(os.getenv("API_SECURITY_ANOMALY_THRESHOLD", "0.7")) # Flag anomalies above this threshold
-
+
# Request Size Limits
API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")) # 10MB
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")) # 50MB for premium
# IP Security
- API_BLOCKED_IPS: List[str] = os.getenv("API_BLOCKED_IPS", "").split(",") if os.getenv("API_BLOCKED_IPS") else []
- API_ALLOWED_IPS: List[str] = os.getenv("API_ALLOWED_IPS", "").split(",") if os.getenv("API_ALLOWED_IPS") else []
- API_IP_REPUTATION_CACHE_TTL: int = int(os.getenv("API_IP_REPUTATION_CACHE_TTL", "3600")) # 1 hour
# Security Headers
- API_SECURITY_HEADERS_ENABLED: bool = os.getenv("API_SECURITY_HEADERS_ENABLED", "True").lower() == "true"
API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
# Monitoring
@@ -129,6 +117,19 @@ class Settings(BaseSettings):
# Module configuration
MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")
+
+ # RAG Embedding Configuration
+ RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12"))
+ RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3"))
+ RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3"))
+ RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16")
+ RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0"))
+ RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
+ RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
+ RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
+ RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
+ RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
+ RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
# Plugin configuration
PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
@@ -142,9 +143,12 @@ class Settings(BaseSettings):
model_config = {
"env_file": ".env",
- "case_sensitive": True
+ "case_sensitive": True,
+ # Ignore unknown environment variables to avoid validation errors
+ # when optional/deprecated flags are present in .env
+ "extra": "ignore",
}
# Global settings instance
-settings = Settings()
\ No newline at end of file
+settings = Settings()
diff --git a/backend/app/core/threat_detection.py b/backend/app/core/threat_detection.py
deleted file mode 100644
index cac2c7b..0000000
--- a/backend/app/core/threat_detection.py
+++ /dev/null
@@ -1,744 +0,0 @@
-"""
-Core threat detection and security analysis for the platform
-"""
-
-import re
-import time
-from collections import defaultdict, deque
-from dataclasses import dataclass, field
-from datetime import datetime, timedelta
-from enum import Enum
-from typing import Dict, List, Optional, Set, Tuple, Any, Union
-from urllib.parse import unquote
-
-from fastapi import Request
-from app.core.config import settings
-from app.core.logging import get_logger
-
-logger = get_logger(__name__)
-
-
-class ThreatLevel(Enum):
- """Threat severity levels"""
- LOW = "low"
- MEDIUM = "medium"
- HIGH = "high"
- CRITICAL = "critical"
-
-
-class AuthLevel(Enum):
- """Authentication levels for rate limiting"""
- AUTHENTICATED = "authenticated"
- API_KEY = "api_key"
- PREMIUM = "premium"
-
-
-@dataclass
-class SecurityThreat:
- """Security threat detection result"""
- threat_type: str
- level: ThreatLevel
- confidence: float
- description: str
- source_ip: str
- user_agent: Optional[str] = None
- request_path: Optional[str] = None
- payload: Optional[str] = None
- timestamp: datetime = field(default_factory=datetime.utcnow)
- mitigation: Optional[str] = None
-
-
-@dataclass
-class SecurityAnalysis:
- """Comprehensive security analysis result"""
- is_threat: bool
- threats: List[SecurityThreat]
- risk_score: float
- recommendations: List[str]
- auth_level: AuthLevel
- rate_limit_exceeded: bool
- should_block: bool
- timestamp: datetime = field(default_factory=datetime.utcnow)
-
-
-@dataclass
-class RateLimitInfo:
- """Rate limiting information"""
- auth_level: AuthLevel
- requests_per_minute: int
- requests_per_hour: int
- minute_limit: int
- hour_limit: int
- exceeded: bool
-
-
-@dataclass
-class AnomalyDetection:
- """Anomaly detection result"""
- is_anomaly: bool
- anomaly_type: str
- severity: float
- details: Dict[str, Any]
- baseline_value: Optional[float] = None
- current_value: Optional[float] = None
-
-
-class ThreatDetectionService:
- """Core threat detection and security analysis service"""
-
- def __init__(self):
- self.name = "threat_detection"
-
- # Statistics
- self.stats = {
- 'total_requests_analyzed': 0,
- 'threats_detected': 0,
- 'threats_blocked': 0,
- 'anomalies_detected': 0,
- 'rate_limits_exceeded': 0,
- 'total_analysis_time': 0,
- 'threat_types': defaultdict(int),
- 'threat_levels': defaultdict(int),
- 'attacking_ips': defaultdict(int)
- }
-
- # Threat detection patterns
- self.sql_injection_patterns = [
- r"(\bunion\b.*\bselect\b)",
- r"(\bselect\b.*\bfrom\b)",
- r"(\binsert\b.*\binto\b)",
- r"(\bupdate\b.*\bset\b)",
- r"(\bdelete\b.*\bfrom\b)",
- r"(\bdrop\b.*\btable\b)",
- r"(\bor\b.*\b1\s*=\s*1\b)",
- r"(\band\b.*\b1\s*=\s*1\b)",
- r"(\bexec\b.*\bxp_\w+)",
- r"(\bsp_\w+)",
- r"(\bsleep\b\s*\(\s*\d+\s*\))",
- r"(\bwaitfor\b.*\bdelay\b)",
- r"(\bbenchmark\b\s*\(\s*\d+)",
- r"(\bload_file\b\s*\()",
- r"(\binto\b.*\boutfile\b)"
- ]
-
- self.xss_patterns = [
- r"",
- r"",
- r"",
- r"",
- r"]*>",
- r"]*>",
- r"javascript:",
- r"vbscript:",
- r"on\w+\s*=",
- r"style\s*=.*expression",
- r"style\s*=.*javascript"
- ]
-
- self.path_traversal_patterns = [
- r"\.\.\/",
- r"\.\.\\",
- r"%2e%2e%2f",
- r"%2e%2e%5c",
- r"..%2f",
- r"..%5c",
- r"%252e%252e%252f",
- r"%252e%252e%255c"
- ]
-
- self.command_injection_patterns = [
- r";\s*cat\s+",
- r";\s*ls\s+",
- r";\s*pwd\s*",
- r";\s*whoami\s*",
- r";\s*id\s*",
- r";\s*uname\s*",
- r";\s*ps\s+",
- r";\s*netstat\s+",
- r";\s*wget\s+",
- r";\s*curl\s+",
- r"\|\s*cat\s+",
- r"\|\s*ls\s+",
- r"&&\s*cat\s+",
- r"&&\s*ls\s+"
- ]
-
- self.suspicious_ua_patterns = [
- r"sqlmap",
- r"nikto",
- r"nmap",
- r"masscan",
- r"zap",
- r"burp",
- r"w3af",
- r"acunetix",
- r"nessus",
- r"openvas",
- r"metasploit"
- ]
-
- # Rate limiting tracking - separate by auth level (excluding unauthenticated since they're blocked)
- self.rate_limits = {
- AuthLevel.AUTHENTICATED: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}),
- AuthLevel.API_KEY: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}),
- AuthLevel.PREMIUM: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)})
- }
-
- # Anomaly detection
- self.request_history = deque(maxlen=1000)
- self.ip_history = defaultdict(lambda: deque(maxlen=100))
- self.endpoint_history = defaultdict(lambda: deque(maxlen=100))
-
- # Blocked and allowed IPs
- self.blocked_ips = set(settings.API_BLOCKED_IPS)
- self.allowed_ips = set(settings.API_ALLOWED_IPS) if settings.API_ALLOWED_IPS else None
-
- # IP reputation cache
- self.ip_reputation_cache = {}
- self.cache_expiry = {}
-
- # Compile patterns for performance
- self._compile_patterns()
-
- logger.info(f"ThreatDetectionService initialized with {len(self.sql_injection_patterns)} SQL patterns, "
- f"{len(self.xss_patterns)} XSS patterns, rate limiting enabled: {settings.API_RATE_LIMITING_ENABLED}")
-
- def _compile_patterns(self):
- """Compile regex patterns for better performance"""
- try:
- self.compiled_sql_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.sql_injection_patterns]
- self.compiled_xss_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.xss_patterns]
- self.compiled_path_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.path_traversal_patterns]
- self.compiled_cmd_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.command_injection_patterns]
- self.compiled_ua_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.suspicious_ua_patterns]
- except re.error as e:
- logger.error(f"Failed to compile security patterns: {e}")
- # Fallback to empty lists to prevent crashes
- self.compiled_sql_patterns = []
- self.compiled_xss_patterns = []
- self.compiled_path_patterns = []
- self.compiled_cmd_patterns = []
- self.compiled_ua_patterns = []
-
- def determine_auth_level(self, request: Request, user_context: Optional[Dict] = None) -> AuthLevel:
- """Determine authentication level for rate limiting"""
- # Check if request has API key authentication
- if hasattr(request.state, 'api_key_context') and request.state.api_key_context:
- api_key = request.state.api_key_context.get('api_key')
- if api_key and hasattr(api_key, 'tier'):
- # Check for premium tier
- if api_key.tier in ['premium', 'enterprise']:
- return AuthLevel.PREMIUM
- return AuthLevel.API_KEY
-
- # Check for JWT authentication
- if user_context or hasattr(request.state, 'user'):
- return AuthLevel.AUTHENTICATED
-
- # Check Authorization header for API key
- auth_header = request.headers.get("Authorization", "")
- api_key_header = request.headers.get("X-API-Key", "")
- if auth_header.startswith("Bearer ") or api_key_header:
- return AuthLevel.API_KEY
-
- # Default to authenticated since unauthenticated requests are blocked at middleware
- return AuthLevel.AUTHENTICATED
-
- def get_rate_limits(self, auth_level: AuthLevel) -> Tuple[int, int]:
- """Get rate limits for authentication level"""
- if not settings.API_RATE_LIMITING_ENABLED:
- return float('inf'), float('inf')
-
- if auth_level == AuthLevel.AUTHENTICATED:
- return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
- elif auth_level == AuthLevel.API_KEY:
- return (settings.API_RATE_LIMIT_API_KEY_PER_MINUTE, settings.API_RATE_LIMIT_API_KEY_PER_HOUR)
- elif auth_level == AuthLevel.PREMIUM:
- return (settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE, settings.API_RATE_LIMIT_PREMIUM_PER_HOUR)
- else:
- # Fallback to authenticated limits
- return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
-
- def check_rate_limit(self, client_ip: str, auth_level: AuthLevel) -> RateLimitInfo:
- """Check if request exceeds rate limits"""
- minute_limit, hour_limit = self.get_rate_limits(auth_level)
- current_time = time.time()
-
- # Get or create tracking for this auth level
- if auth_level not in self.rate_limits:
- # This shouldn't happen, but handle gracefully
- return RateLimitInfo(
- auth_level=auth_level,
- requests_per_minute=0,
- requests_per_hour=0,
- minute_limit=minute_limit,
- hour_limit=hour_limit,
- exceeded=False
- )
-
- ip_limits = self.rate_limits[auth_level][client_ip]
-
- # Clean old entries
- minute_ago = current_time - 60
- hour_ago = current_time - 3600
-
- while ip_limits['minute'] and ip_limits['minute'][0] < minute_ago:
- ip_limits['minute'].popleft()
-
- while ip_limits['hour'] and ip_limits['hour'][0] < hour_ago:
- ip_limits['hour'].popleft()
-
- # Check current counts
- requests_per_minute = len(ip_limits['minute'])
- requests_per_hour = len(ip_limits['hour'])
-
- # Check if limits exceeded
- exceeded = (requests_per_minute >= minute_limit) or (requests_per_hour >= hour_limit)
-
- # Add current request to tracking
- if not exceeded:
- ip_limits['minute'].append(current_time)
- ip_limits['hour'].append(current_time)
-
- return RateLimitInfo(
- auth_level=auth_level,
- requests_per_minute=requests_per_minute,
- requests_per_hour=requests_per_hour,
- minute_limit=minute_limit,
- hour_limit=hour_limit,
- exceeded=exceeded
- )
-
- async def analyze_request(self, request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
- """Perform comprehensive security analysis on a request"""
- start_time = time.time()
-
- try:
- client_ip = request.client.host if request.client else "unknown"
- user_agent = request.headers.get("user-agent", "")
- path = str(request.url.path)
- method = request.method
-
- # Determine authentication level
- auth_level = self.determine_auth_level(request, user_context)
-
- # Check IP allowlist/blocklist first
- if self.allowed_ips and client_ip not in self.allowed_ips:
- threat = SecurityThreat(
- threat_type="ip_not_allowed",
- level=ThreatLevel.HIGH,
- confidence=1.0,
- description=f"IP {client_ip} not in allowlist",
- source_ip=client_ip,
- mitigation="Add IP to allowlist or remove IP restrictions"
- )
- return SecurityAnalysis(
- is_threat=True,
- threats=[threat],
- risk_score=1.0,
- recommendations=["Block request immediately"],
- auth_level=auth_level,
- rate_limit_exceeded=False,
- should_block=True
- )
-
- if client_ip in self.blocked_ips:
- threat = SecurityThreat(
- threat_type="ip_blocked",
- level=ThreatLevel.CRITICAL,
- confidence=1.0,
- description=f"IP {client_ip} is blocked",
- source_ip=client_ip,
- mitigation="Remove IP from blocklist if legitimate"
- )
- return SecurityAnalysis(
- is_threat=True,
- threats=[threat],
- risk_score=1.0,
- recommendations=["Block request immediately"],
- auth_level=auth_level,
- rate_limit_exceeded=False,
- should_block=True
- )
-
- # Check rate limiting
- rate_limit_info = self.check_rate_limit(client_ip, auth_level)
- if rate_limit_info.exceeded:
- self.stats['rate_limits_exceeded'] += 1
- threat = SecurityThreat(
- threat_type="rate_limit_exceeded",
- level=ThreatLevel.MEDIUM,
- confidence=0.9,
- description=f"Rate limit exceeded for {auth_level.value}: {rate_limit_info.requests_per_minute}/min, {rate_limit_info.requests_per_hour}/hr",
- source_ip=client_ip,
- mitigation=f"Implement rate limiting, current limits: {rate_limit_info.minute_limit}/min, {rate_limit_info.hour_limit}/hr"
- )
- return SecurityAnalysis(
- is_threat=True,
- threats=[threat],
- risk_score=0.7,
- recommendations=[f"Rate limit exceeded for {auth_level.value} user"],
- auth_level=auth_level,
- rate_limit_exceeded=True,
- should_block=True
- )
-
- # Skip threat detection if disabled
- if not settings.API_THREAT_DETECTION_ENABLED:
- return SecurityAnalysis(
- is_threat=False,
- threats=[],
- risk_score=0.0,
- recommendations=[],
- auth_level=auth_level,
- rate_limit_exceeded=False,
- should_block=False
- )
-
- # Collect request data for threat analysis
- query_params = str(request.query_params)
- headers = dict(request.headers)
-
- # Try to get body content safely
- body_content = ""
- try:
- if hasattr(request, '_body') and request._body:
- body_content = request._body.decode() if isinstance(request._body, bytes) else str(request._body)
- except:
- pass
-
- threats = []
-
- # Analyze for various threats
- threats.extend(await self._detect_sql_injection(query_params, body_content, path, client_ip))
- threats.extend(await self._detect_xss(query_params, body_content, headers, client_ip))
- threats.extend(await self._detect_path_traversal(path, query_params, client_ip))
- threats.extend(await self._detect_command_injection(query_params, body_content, client_ip))
- threats.extend(await self._detect_suspicious_patterns(headers, user_agent, path, client_ip))
-
- # Anomaly detection if enabled
- if settings.API_ANOMALY_DETECTION_ENABLED:
- anomaly = await self._detect_anomalies(client_ip, path, method, len(body_content))
- if anomaly.is_anomaly and anomaly.severity > settings.API_SECURITY_ANOMALY_THRESHOLD:
- threat = SecurityThreat(
- threat_type=f"anomaly_{anomaly.anomaly_type}",
- level=ThreatLevel.MEDIUM if anomaly.severity > 0.7 else ThreatLevel.LOW,
- confidence=anomaly.severity,
- description=f"Anomalous behavior detected: {anomaly.details}",
- source_ip=client_ip,
- user_agent=user_agent,
- request_path=path
- )
- threats.append(threat)
-
- # Calculate risk score
- risk_score = self._calculate_risk_score(threats)
-
- # Determine if request should be blocked
- should_block = risk_score >= settings.API_SECURITY_RISK_THRESHOLD
-
- # Generate recommendations
- recommendations = self._generate_recommendations(threats, risk_score, auth_level)
-
- # Update statistics
- self._update_stats(threats, time.time() - start_time)
-
- return SecurityAnalysis(
- is_threat=len(threats) > 0,
- threats=threats,
- risk_score=risk_score,
- recommendations=recommendations,
- auth_level=auth_level,
- rate_limit_exceeded=False,
- should_block=should_block
- )
-
- except Exception as e:
- logger.error(f"Error in threat analysis: {e}")
- return SecurityAnalysis(
- is_threat=False,
- threats=[],
- risk_score=0.0,
- recommendations=["Error occurred during security analysis"],
- auth_level=AuthLevel.AUTHENTICATED,
- rate_limit_exceeded=False,
- should_block=False
- )
-
- async def _detect_sql_injection(self, query_params: str, body_content: str, path: str, client_ip: str) -> List[SecurityThreat]:
- """Detect SQL injection attempts"""
- threats = []
- content_to_check = f"{query_params} {body_content} {path}".lower()
-
- for pattern in self.compiled_sql_patterns:
- if pattern.search(content_to_check):
- threat = SecurityThreat(
- threat_type="sql_injection",
- level=ThreatLevel.HIGH,
- confidence=0.85,
- description="Potential SQL injection attempt detected",
- source_ip=client_ip,
- payload=pattern.pattern,
- mitigation="Block request, sanitize input, use parameterized queries"
- )
- threats.append(threat)
- break # Don't duplicate for multiple patterns
-
- return threats
-
- async def _detect_xss(self, query_params: str, body_content: str, headers: dict, client_ip: str) -> List[SecurityThreat]:
- """Detect XSS attempts"""
- threats = []
- content_to_check = f"{query_params} {body_content}".lower()
-
- # Check headers for XSS
- for header_name, header_value in headers.items():
- content_to_check += f" {header_value}".lower()
-
- for pattern in self.compiled_xss_patterns:
- if pattern.search(content_to_check):
- threat = SecurityThreat(
- threat_type="xss",
- level=ThreatLevel.HIGH,
- confidence=0.80,
- description="Potential XSS attack detected",
- source_ip=client_ip,
- payload=pattern.pattern,
- mitigation="Block request, sanitize input, implement CSP headers"
- )
- threats.append(threat)
- break
-
- return threats
-
- async def _detect_path_traversal(self, path: str, query_params: str, client_ip: str) -> List[SecurityThreat]:
- """Detect path traversal attempts"""
- threats = []
- content_to_check = f"{path} {query_params}".lower()
- decoded_content = unquote(content_to_check)
-
- for pattern in self.compiled_path_patterns:
- if pattern.search(content_to_check) or pattern.search(decoded_content):
- threat = SecurityThreat(
- threat_type="path_traversal",
- level=ThreatLevel.HIGH,
- confidence=0.90,
- description="Path traversal attempt detected",
- source_ip=client_ip,
- request_path=path,
- mitigation="Block request, validate file paths, implement access controls"
- )
- threats.append(threat)
- break
-
- return threats
-
- async def _detect_command_injection(self, query_params: str, body_content: str, client_ip: str) -> List[SecurityThreat]:
- """Detect command injection attempts"""
- threats = []
- content_to_check = f"{query_params} {body_content}".lower()
-
- for pattern in self.compiled_cmd_patterns:
- if pattern.search(content_to_check):
- threat = SecurityThreat(
- threat_type="command_injection",
- level=ThreatLevel.CRITICAL,
- confidence=0.95,
- description="Command injection attempt detected",
- source_ip=client_ip,
- payload=pattern.pattern,
- mitigation="Block request immediately, sanitize input, disable shell execution"
- )
- threats.append(threat)
- break
-
- return threats
-
- async def _detect_suspicious_patterns(self, headers: dict, user_agent: str, path: str, client_ip: str) -> List[SecurityThreat]:
- """Detect suspicious patterns in headers and user agent"""
- threats = []
-
- # Check for suspicious user agents
- ua_lower = user_agent.lower()
- for pattern in self.compiled_ua_patterns:
- if pattern.search(ua_lower):
- threat = SecurityThreat(
- threat_type="suspicious_user_agent",
- level=ThreatLevel.HIGH,
- confidence=0.85,
- description=f"Suspicious user agent detected: {pattern.pattern}",
- source_ip=client_ip,
- user_agent=user_agent,
- mitigation="Block request, monitor IP for further activity"
- )
- threats.append(threat)
- break
-
- # Check for suspicious headers
- if "x-forwarded-for" in headers and "x-real-ip" in headers:
- # Potential header manipulation
- threat = SecurityThreat(
- threat_type="header_manipulation",
- level=ThreatLevel.LOW,
- confidence=0.30,
- description="Potential IP header manipulation detected",
- source_ip=client_ip,
- mitigation="Validate proxy headers, implement IP whitelisting"
- )
- threats.append(threat)
-
- return threats
-
- async def _detect_anomalies(self, client_ip: str, path: str, method: str, body_size: int) -> AnomalyDetection:
- """Detect anomalous behavior patterns"""
- try:
- # Request size anomaly
- max_size = settings.API_MAX_REQUEST_BODY_SIZE
- if body_size > max_size:
- return AnomalyDetection(
- is_anomaly=True,
- anomaly_type="request_size",
- severity=0.8,
- details={"body_size": body_size, "threshold": max_size},
- current_value=body_size,
- baseline_value=max_size // 10
- )
-
- # Unusual endpoint access
- if path.startswith("/admin") or path.startswith("/api/admin"):
- return AnomalyDetection(
- is_anomaly=True,
- anomaly_type="sensitive_endpoint",
- severity=0.6,
- details={"path": path, "reason": "admin endpoint access"},
- current_value=1.0,
- baseline_value=0.0
- )
-
- # IP request frequency anomaly
- current_time = time.time()
- ip_requests = self.ip_history[client_ip]
-
- # Clean old entries (last 5 minutes)
- five_minutes_ago = current_time - 300
- while ip_requests and ip_requests[0] < five_minutes_ago:
- ip_requests.popleft()
-
- ip_requests.append(current_time)
-
- if len(ip_requests) > 100: # More than 100 requests in 5 minutes
- return AnomalyDetection(
- is_anomaly=True,
- anomaly_type="request_frequency",
- severity=0.7,
- details={"requests_5min": len(ip_requests), "threshold": 100},
- current_value=len(ip_requests),
- baseline_value=10 # 10 requests baseline
- )
-
- return AnomalyDetection(
- is_anomaly=False,
- anomaly_type="none",
- severity=0.0,
- details={}
- )
-
- except Exception as e:
- logger.error(f"Error in anomaly detection: {e}")
- return AnomalyDetection(
- is_anomaly=False,
- anomaly_type="error",
- severity=0.0,
- details={"error": str(e)}
- )
-
- def _calculate_risk_score(self, threats: List[SecurityThreat]) -> float:
- """Calculate overall risk score based on threats"""
- if not threats:
- return 0.0
-
- score = 0.0
- for threat in threats:
- level_multiplier = {
- ThreatLevel.LOW: 0.25,
- ThreatLevel.MEDIUM: 0.5,
- ThreatLevel.HIGH: 0.75,
- ThreatLevel.CRITICAL: 1.0
- }
- score += threat.confidence * level_multiplier.get(threat.level, 0.5)
-
- # Normalize to 0-1 range
- return min(score / len(threats), 1.0)
-
- def _generate_recommendations(self, threats: List[SecurityThreat], risk_score: float, auth_level: AuthLevel) -> List[str]:
- """Generate security recommendations based on analysis"""
- recommendations = []
-
- if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
- recommendations.append("CRITICAL: Block this request immediately")
- elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
- recommendations.append("HIGH: Consider blocking or rate limiting this IP")
- elif risk_score > 0.4:
- recommendations.append("MEDIUM: Monitor this IP closely")
-
- threat_types = {threat.threat_type for threat in threats}
-
- if "sql_injection" in threat_types:
- recommendations.append("Implement parameterized queries and input validation")
-
- if "xss" in threat_types:
- recommendations.append("Implement Content Security Policy (CSP) headers")
-
- if "command_injection" in threat_types:
- recommendations.append("Disable shell execution and validate all inputs")
-
- if "path_traversal" in threat_types:
- recommendations.append("Implement proper file path validation and access controls")
-
- if "rate_limit_exceeded" in threat_types:
- recommendations.append(f"Rate limiting active for {auth_level.value} user")
-
- if not recommendations:
- recommendations.append("No immediate action required, continue monitoring")
-
- return recommendations
-
- def _update_stats(self, threats: List[SecurityThreat], analysis_time: float):
- """Update service statistics"""
- self.stats['total_requests_analyzed'] += 1
- self.stats['total_analysis_time'] += analysis_time
-
- if threats:
- self.stats['threats_detected'] += len(threats)
- for threat in threats:
- self.stats['threat_types'][threat.threat_type] += 1
- self.stats['threat_levels'][threat.level.value] += 1
- if threat.source_ip:
- self.stats['attacking_ips'][threat.source_ip] += 1
-
- def get_stats(self) -> Dict[str, Any]:
- """Get service statistics"""
- avg_time = (self.stats['total_analysis_time'] / self.stats['total_requests_analyzed']
- if self.stats['total_requests_analyzed'] > 0 else 0)
-
- # Get top attacking IPs
- top_ips = sorted(self.stats['attacking_ips'].items(), key=lambda x: x[1], reverse=True)[:10]
-
- return {
- "total_requests_analyzed": self.stats['total_requests_analyzed'],
- "threats_detected": self.stats['threats_detected'],
- "threats_blocked": self.stats['threats_blocked'],
- "anomalies_detected": self.stats['anomalies_detected'],
- "rate_limits_exceeded": self.stats['rate_limits_exceeded'],
- "avg_analysis_time": avg_time,
- "threat_types": dict(self.stats['threat_types']),
- "threat_levels": dict(self.stats['threat_levels']),
- "top_attacking_ips": top_ips,
- "security_enabled": settings.API_SECURITY_ENABLED,
- "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
- "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED
- }
-
-
-# Global threat detection service instance
-threat_detection_service = ThreatDetectionService()
\ No newline at end of file
diff --git a/backend/app/main.py b/backend/app/main.py
index 8bea827..8c8b26d 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -52,10 +52,18 @@ async def lifespan(app: FastAPI):
# Initialize config manager
await init_config_manager()
-
+
+ # Initialize LLM service (needed by RAG module)
+ from app.services.llm.service import llm_service
+ try:
+ await llm_service.initialize()
+ logger.info("LLM service initialized successfully")
+ except Exception as e:
+ logger.warning(f"LLM service initialization failed: {e}")
+
# Initialize analytics service
init_analytics_service()
-
+
# Initialize module manager with FastAPI app for router registration
await module_manager.initialize(app)
app.state.module_manager = module_manager
diff --git a/backend/app/middleware/rate_limiting.py b/backend/app/middleware/rate_limiting.py
deleted file mode 100644
index f6e1901..0000000
--- a/backend/app/middleware/rate_limiting.py
+++ /dev/null
@@ -1,371 +0,0 @@
-"""
-Rate limiting middleware
-"""
-
-import time
-import redis
-from typing import Dict, Optional
-from fastapi import Request, HTTPException, status
-from fastapi.responses import JSONResponse
-from starlette.middleware.base import BaseHTTPMiddleware
-import asyncio
-from datetime import datetime, timedelta
-
-from app.core.config import settings
-from app.core.logging import get_logger
-
-logger = get_logger(__name__)
-
-
-class RateLimiter:
- """Rate limiting implementation using Redis"""
-
- def __init__(self):
- try:
- self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
- self.redis_client.ping() # Test connection
- logger.info("Rate limiter initialized with Redis backend")
- except Exception as e:
- logger.warning(f"Redis not available for rate limiting: {e}")
- self.redis_client = None
- # Fall back to in-memory rate limiting
- self.memory_store: Dict[str, Dict[str, float]] = {}
-
- async def check_rate_limit(
- self,
- key: str,
- limit: int,
- window_seconds: int,
- identifier: str = "default"
- ) -> tuple[bool, Dict[str, int]]:
- """
- Check if request is within rate limit
-
- Args:
- key: Rate limiting key (e.g., IP address, API key)
- limit: Maximum number of requests allowed
- window_seconds: Time window in seconds
- identifier: Additional identifier for the rate limit
-
- Returns:
- Tuple of (is_allowed, headers_dict)
- """
-
- full_key = f"rate_limit:{identifier}:{key}"
- current_time = int(time.time())
- window_start = current_time - window_seconds
-
- if self.redis_client:
- return await self._check_redis_rate_limit(
- full_key, limit, window_seconds, current_time, window_start
- )
- else:
- return self._check_memory_rate_limit(
- full_key, limit, window_seconds, current_time, window_start
- )
-
- async def _check_redis_rate_limit(
- self,
- key: str,
- limit: int,
- window_seconds: int,
- current_time: int,
- window_start: int
- ) -> tuple[bool, Dict[str, int]]:
- """Check rate limit using Redis"""
-
- pipe = self.redis_client.pipeline()
-
- # Remove old entries
- pipe.zremrangebyscore(key, 0, window_start)
-
- # Count current requests in window
- pipe.zcard(key)
-
- # Add current request
- pipe.zadd(key, {str(current_time): current_time})
-
- # Set expiration
- pipe.expire(key, window_seconds + 1)
-
- results = pipe.execute()
- current_requests = results[1]
-
- # Calculate remaining requests and reset time
- remaining = max(0, limit - current_requests - 1)
- reset_time = current_time + window_seconds
-
- headers = {
- "X-RateLimit-Limit": limit,
- "X-RateLimit-Remaining": remaining,
- "X-RateLimit-Reset": reset_time,
- "X-RateLimit-Window": window_seconds
- }
-
- is_allowed = current_requests < limit
-
- if not is_allowed:
- logger.warning(f"Rate limit exceeded for key: {key}")
-
- return is_allowed, headers
-
- def _check_memory_rate_limit(
- self,
- key: str,
- limit: int,
- window_seconds: int,
- current_time: int,
- window_start: int
- ) -> tuple[bool, Dict[str, int]]:
- """Check rate limit using in-memory storage"""
-
- if key not in self.memory_store:
- self.memory_store[key] = {}
-
- # Clean old entries
- store = self.memory_store[key]
- keys_to_remove = [k for k, v in store.items() if v < window_start]
- for k in keys_to_remove:
- del store[k]
-
- current_requests = len(store)
-
- # Calculate remaining requests and reset time
- remaining = max(0, limit - current_requests - 1)
- reset_time = current_time + window_seconds
-
- headers = {
- "X-RateLimit-Limit": limit,
- "X-RateLimit-Remaining": remaining,
- "X-RateLimit-Reset": reset_time,
- "X-RateLimit-Window": window_seconds
- }
-
- is_allowed = current_requests < limit
-
- if is_allowed:
- # Add current request
- store[str(current_time)] = current_time
- else:
- logger.warning(f"Rate limit exceeded for key: {key}")
-
- return is_allowed, headers
-
-
-# Global rate limiter instance
-rate_limiter = RateLimiter()
-
-
-class RateLimitMiddleware(BaseHTTPMiddleware):
- """Rate limiting middleware for FastAPI"""
-
- def __init__(self, app):
- super().__init__(app)
- self.rate_limiter = RateLimiter()
- logger.info("RateLimitMiddleware initialized")
-
- async def dispatch(self, request: Request, call_next):
- """Process request through rate limiting"""
-
- # Skip rate limiting if disabled in settings
- if not settings.API_RATE_LIMITING_ENABLED:
- response = await call_next(request)
- return response
-
- # Skip rate limiting for all internal API endpoints (platform operations)
- if request.url.path.startswith("/api-internal/v1/"):
- response = await call_next(request)
- return response
-
- # Only apply rate limiting to privatemode.ai proxy endpoints (OpenAI-compatible API and LLM service)
- # Skip for all other endpoints
- if not (request.url.path.startswith("/api/v1/chat/completions") or
- request.url.path.startswith("/api/v1/embeddings") or
- request.url.path.startswith("/api/v1/models") or
- request.url.path.startswith("/api/v1/llm/")):
- response = await call_next(request)
- return response
-
- # Skip rate limiting for health checks and static files
- if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
- response = await call_next(request)
- return response
-
- # Get client IP
- client_ip = request.client.host
- forwarded_for = request.headers.get("X-Forwarded-For")
- if forwarded_for:
- client_ip = forwarded_for.split(",")[0].strip()
-
- # Check for API key in headers
- api_key = None
- auth_header = request.headers.get("Authorization")
- if auth_header and auth_header.startswith("Bearer "):
- api_key = auth_header[7:]
- elif request.headers.get("X-API-Key"):
- api_key = request.headers.get("X-API-Key")
-
- # Determine rate limiting strategy
- headers = {}
- is_allowed = True
-
- if api_key:
- # API key-based rate limiting
- api_key_key = f"api_key:{api_key}"
-
- # First check organization-wide limits (PrivateMode limits are org-wide)
- org_key = "organization:privatemode"
-
- # Check organization per-minute limit
- org_allowed_minute, org_headers_minute = await self.rate_limiter.check_rate_limit(
- org_key, settings.PRIVATEMODE_REQUESTS_PER_MINUTE, 60, "minute"
- )
-
- # Check organization per-hour limit
- org_allowed_hour, org_headers_hour = await self.rate_limiter.check_rate_limit(
- org_key, settings.PRIVATEMODE_REQUESTS_PER_HOUR, 3600, "hour"
- )
-
- # If organization limits are exceeded, return 429
- if not (org_allowed_minute and org_allowed_hour):
- logger.warning(f"Organization rate limit exceeded for {org_key}")
- return JSONResponse(
- status_code=status.HTTP_429_TOO_MANY_REQUESTS,
- content={"detail": "Organization rate limit exceeded"},
- headers=org_headers_minute
- )
-
- # Then check per-API key limits
- limit_per_minute = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE
- limit_per_hour = settings.API_RATE_LIMIT_API_KEY_PER_HOUR
-
- # Check per-minute limit
- is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
- api_key_key, limit_per_minute, 60, "minute"
- )
-
- # Check per-hour limit
- is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
- api_key_key, limit_per_hour, 3600, "hour"
- )
-
- is_allowed = is_allowed_minute and is_allowed_hour
- headers = headers_minute # Use minute headers for response
-
- else:
- # IP-based rate limiting for unauthenticated requests
- rate_limit_key = f"ip:{client_ip}"
-
- # More restrictive limits for unauthenticated requests
- limit_per_minute = 20 # Hardcoded for unauthenticated users
- limit_per_hour = 100
-
- # Check per-minute limit
- is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
- rate_limit_key, limit_per_minute, 60, "minute"
- )
-
- # Check per-hour limit
- is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
- rate_limit_key, limit_per_hour, 3600, "hour"
- )
-
- is_allowed = is_allowed_minute and is_allowed_hour
- headers = headers_minute # Use minute headers for response
-
- # If rate limit exceeded, return 429
- if not is_allowed:
- return JSONResponse(
- status_code=status.HTTP_429_TOO_MANY_REQUESTS,
- content={
- "error": "RATE_LIMIT_EXCEEDED",
- "message": "Rate limit exceeded. Please try again later.",
- "details": {
- "limit": headers["X-RateLimit-Limit"],
- "reset_time": headers["X-RateLimit-Reset"]
- }
- },
- headers={k: str(v) for k, v in headers.items()}
- )
-
- # Continue with request
- response = await call_next(request)
-
- # Add rate limit headers to response
- for key, value in headers.items():
- response.headers[key] = str(value)
-
- return response
-
-
-# Keep the old function for backward compatibility
-async def rate_limit_middleware(request: Request, call_next):
- """Legacy function - use RateLimitMiddleware class instead"""
- middleware = RateLimitMiddleware(None)
- return await middleware.dispatch(request, call_next)
-
-
-class RateLimitExceeded(HTTPException):
- """Exception raised when rate limit is exceeded"""
-
- def __init__(self, limit: int, reset_time: int):
- super().__init__(
- status_code=status.HTTP_429_TOO_MANY_REQUESTS,
- detail=f"Rate limit exceeded. Limit: {limit}, Reset: {reset_time}"
- )
-
-
-# Decorator for applying rate limits to specific endpoints
-def rate_limit(requests_per_minute: int = 60, requests_per_hour: int = 1000):
- """
- Decorator to apply rate limiting to specific endpoints
-
- Args:
- requests_per_minute: Maximum requests per minute
- requests_per_hour: Maximum requests per hour
- """
- def decorator(func):
- async def wrapper(*args, **kwargs):
- # This would be implemented to work with FastAPI dependencies
- # For now, this is a placeholder for endpoint-specific rate limiting
- return await func(*args, **kwargs)
- return wrapper
- return decorator
-
-
-# Helper functions for different rate limiting strategies
-async def check_api_key_rate_limit(api_key: str, endpoint: str) -> bool:
- """Check rate limit for specific API key and endpoint"""
-
- # This would lookup API key specific limits from database
- # For now, using default limits
- key = f"api_key:{api_key}:endpoint:{endpoint}"
-
- is_allowed, _ = await rate_limiter.check_rate_limit(
- key, limit=100, window_seconds=60, identifier="endpoint"
- )
-
- return is_allowed
-
-
-async def check_user_rate_limit(user_id: str, action: str) -> bool:
- """Check rate limit for specific user and action"""
-
- key = f"user:{user_id}:action:{action}"
-
- is_allowed, _ = await rate_limiter.check_rate_limit(
- key, limit=50, window_seconds=60, identifier="user_action"
- )
-
- return is_allowed
-
-
-async def apply_burst_protection(key: str) -> bool:
- """Apply burst protection for high-frequency actions"""
-
- # Allow burst of 10 requests in 10 seconds
- is_allowed, _ = await rate_limiter.check_rate_limit(
- key, limit=10, window_seconds=10, identifier="burst"
- )
-
- return is_allowed
\ No newline at end of file
diff --git a/backend/app/middleware/security.py b/backend/app/middleware/security.py
deleted file mode 100644
index c7b7952..0000000
--- a/backend/app/middleware/security.py
+++ /dev/null
@@ -1,210 +0,0 @@
-"""
-Security middleware for request/response processing
-"""
-
-import json
-import time
-from typing import Callable, Optional, Dict, Any
-
-from fastapi import Request, Response
-from fastapi.responses import JSONResponse
-from starlette.middleware.base import BaseHTTPMiddleware
-
-from app.core.config import settings
-from app.core.logging import get_logger
-from app.core.threat_detection import threat_detection_service, SecurityAnalysis
-
-logger = get_logger(__name__)
-
-
-class SecurityMiddleware(BaseHTTPMiddleware):
- """Security middleware for threat detection and request filtering - DISABLED"""
-
- def __init__(self, app, enabled: bool = True):
- super().__init__(app)
- self.enabled = False # Force disable regardless of settings
- logger.info("SecurityMiddleware initialized, enabled: False (DISABLED)")
-
- async def dispatch(self, request: Request, call_next: Callable) -> Response:
- """Process request through security analysis - DISABLED"""
- # Security disabled, always pass through
- return await call_next(request)
-
- def _should_skip_security(self, request: Request) -> bool:
- """Determine if security analysis should be skipped for this request"""
- path = request.url.path
-
- # Skip for health checks, authentication endpoints, and static assets
- skip_paths = [
- "/health",
- "/metrics",
- "/api/v1/docs",
- "/api/v1/openapi.json",
- "/api/v1/redoc",
- "/favicon.ico",
- "/api/v1/auth/register",
- "/api/v1/auth/login",
- "/api/v1/auth/refresh", # Allow refresh endpoint
- "/api-internal/v1/auth/register",
- "/api-internal/v1/auth/login",
- "/api-internal/v1/auth/refresh", # Allow refresh endpoint for internal API
- "/", # Root endpoint
- ]
-
- # Skip for static file extensions
- static_extensions = [".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".woff", ".woff2"]
-
- return (
- path in skip_paths or
- any(path.endswith(ext) for ext in static_extensions) or
- path.startswith("/static/")
- )
-
- def _has_valid_auth(self, request: Request) -> bool:
- """Check if request has valid authentication"""
- # Check Authorization header
- auth_header = request.headers.get("Authorization", "")
- api_key_header = request.headers.get("X-API-Key", "")
-
- # Has some form of auth token/key
- return (
- auth_header.startswith("Bearer ") and len(auth_header) > 7 or
- len(api_key_header.strip()) > 0
- )
-
- def _create_block_response(self, analysis: SecurityAnalysis) -> JSONResponse:
- """Create response for blocked requests"""
- # Determine status code based on threat type
- status_code = 403 # Forbidden by default
-
- # Critical threats get 403
- for threat in analysis.threats:
- if threat.threat_type in ["command_injection", "sql_injection"]:
- status_code = 403
- break
-
- response_data = {
- "error": "Security Policy Violation",
- "message": "Request blocked due to security policy violation",
- "risk_score": round(analysis.risk_score, 3),
- "auth_level": analysis.auth_level.value,
- "threat_count": len(analysis.threats),
- "recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations
- }
-
- response = JSONResponse(
- content=response_data,
- status_code=status_code
- )
-
- return response
-
- def _add_security_headers(self, response: Response) -> Response:
- """Add security headers to response"""
- if not settings.API_SECURITY_HEADERS_ENABLED:
- return response
-
- # Standard security headers
- response.headers["X-Content-Type-Options"] = "nosniff"
- response.headers["X-Frame-Options"] = "DENY"
- response.headers["X-XSS-Protection"] = "1; mode=block"
- response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
-
- # Only add HSTS for HTTPS
- if hasattr(response, 'headers') and response.headers.get("X-Forwarded-Proto") == "https":
- response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
-
- # Content Security Policy
- if settings.API_CSP_HEADER:
- response.headers["Content-Security-Policy"] = settings.API_CSP_HEADER
-
- return response
-
- def _add_security_metrics(self, response: Response, analysis: SecurityAnalysis, analysis_time: float) -> Response:
- """Add security metrics to response headers (for debugging/monitoring)"""
- # Only add in debug mode or for admin users
- if settings.APP_DEBUG:
- response.headers["X-Security-Risk-Score"] = str(round(analysis.risk_score, 3))
- response.headers["X-Security-Threats"] = str(len(analysis.threats))
- response.headers["X-Security-Auth-Level"] = analysis.auth_level.value
- response.headers["X-Security-Analysis-Time"] = f"{analysis_time*1000:.1f}ms"
-
- return response
-
- async def _log_security_event(self, request: Request, analysis: SecurityAnalysis):
- """Log security events for audit and monitoring"""
- client_ip = request.client.host if request.client else "unknown"
- user_agent = request.headers.get("user-agent", "")
-
- # Create security event log
- event_data = {
- "timestamp": analysis.timestamp.isoformat(),
- "client_ip": client_ip,
- "user_agent": user_agent,
- "path": str(request.url.path),
- "method": request.method,
- "risk_score": round(analysis.risk_score, 3),
- "auth_level": analysis.auth_level.value,
- "threat_count": len(analysis.threats),
- "rate_limit_exceeded": analysis.rate_limit_exceeded,
- "should_block": analysis.should_block,
- "threats": [
- {
- "type": threat.threat_type,
- "level": threat.level.value,
- "confidence": round(threat.confidence, 3),
- "description": threat.description
- }
- for threat in analysis.threats[:5] # Limit to first 5 threats
- ],
- "recommendations": analysis.recommendations
- }
-
- # Log at appropriate level based on risk
- if analysis.should_block:
- logger.warning(f"SECURITY_BLOCK: {json.dumps(event_data)}")
- elif analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
- logger.warning(f"SECURITY_WARNING: {json.dumps(event_data)}")
- else:
- logger.info(f"SECURITY_THREAT: {json.dumps(event_data)}")
-
-
-def setup_security_middleware(app, enabled: bool = True) -> None:
- """Setup security middleware on FastAPI app"""
- if enabled and settings.API_SECURITY_ENABLED:
- app.add_middleware(SecurityMiddleware, enabled=enabled)
- logger.info("Security middleware enabled")
- else:
- logger.info("Security middleware disabled")
-
-
-# Helper functions for manual security checks
-async def analyze_request_security(request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
- """Manually analyze request security (for use in route handlers)"""
- return await threat_detection_service.analyze_request(request, user_context)
-
-
-def get_security_stats() -> Dict[str, Any]:
- """Get security statistics"""
- return threat_detection_service.get_stats()
-
-
-def is_request_blocked(request: Request) -> bool:
- """Check if request was blocked by security analysis"""
- if hasattr(request.state, 'security_analysis'):
- return request.state.security_analysis.should_block
- return False
-
-
-def get_request_risk_score(request: Request) -> float:
- """Get risk score for request"""
- if hasattr(request.state, 'security_analysis'):
- return request.state.security_analysis.risk_score
- return 0.0
-
-
-def get_request_auth_level(request: Request) -> str:
- """Get authentication level for request"""
- if hasattr(request.state, 'security_analysis'):
- return request.state.security_analysis.auth_level.value
- return "unknown"
\ No newline at end of file
diff --git a/backend/app/services/document_processor.py b/backend/app/services/document_processor.py
index 8447333..8875ae8 100644
--- a/backend/app/services/document_processor.py
+++ b/backend/app/services/document_processor.py
@@ -162,6 +162,7 @@ class DocumentProcessor:
async def _process_document(self, task: ProcessingTask) -> bool:
"""Process a single document"""
+ from datetime import datetime
from app.db.database import async_session_factory
async with async_session_factory() as session:
try:
@@ -182,16 +183,24 @@ class DocumentProcessor:
document.status = ProcessingStatus.PROCESSING
await session.commit()
- # Get RAG module for processing (now includes content processing)
+ # Get RAG module for processing
try:
- from app.services.module_manager import module_manager
- rag_module = module_manager.get_module('rag')
+ # Import RAG module and initialize it properly
+ from modules.rag.main import RAGModule
+ from app.core.config import settings
+
+ # Create and initialize RAG module instance
+ rag_module = RAGModule(settings)
+ init_result = await rag_module.initialize()
+ if not rag_module.enabled:
+ raise Exception("Failed to enable RAG module")
+
except Exception as e:
logger.error(f"Failed to get RAG module: {e}")
raise Exception(f"RAG module not available: {e}")
-
- if not rag_module:
- raise Exception("RAG module not available")
+
+ if not rag_module or not rag_module.enabled:
+ raise Exception("RAG module not available or not enabled")
logger.info(f"RAG module loaded successfully for document {task.document_id}")
@@ -204,31 +213,45 @@ class DocumentProcessor:
# Process with RAG module
logger.info(f"Starting document processing for document {task.document_id} with RAG module")
- try:
- # Add timeout to prevent hanging
- processed_doc = await asyncio.wait_for(
- rag_module.process_document(
- file_content,
- document.original_filename,
- {}
- ),
- timeout=300.0 # 5 minute timeout
- )
- logger.info(f"Document processing completed for document {task.document_id}")
- except asyncio.TimeoutError:
- logger.error(f"Document processing timed out for document {task.document_id}")
- raise Exception("Document processing timed out after 5 minutes")
- except Exception as e:
- logger.error(f"Document processing failed for document {task.document_id}: {e}")
- raise
-
- # Update document with processed content
- document.converted_content = processed_doc.content
- document.word_count = processed_doc.word_count
- document.character_count = len(processed_doc.content)
- document.document_metadata = processed_doc.metadata
- document.status = ProcessingStatus.PROCESSED
- document.processed_at = datetime.utcnow()
+
+ # Special handling for JSONL files - skip processing phase
+ if document.file_type == 'jsonl':
+ # For JSONL files, we don't need to process content here
+ # The optimized JSONL processor will handle everything during indexing
+ document.converted_content = f"JSONL file with {len(file_content)} bytes"
+ document.word_count = 0 # Will be updated during indexing
+ document.character_count = len(file_content)
+ document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"}
+ document.status = ProcessingStatus.PROCESSED
+ document.processed_at = datetime.utcnow()
+ logger.info(f"JSONL document {task.document_id} marked for optimized processing")
+ else:
+ # Standard processing for other file types
+ try:
+ # Add timeout to prevent hanging
+ processed_doc = await asyncio.wait_for(
+ rag_module.process_document(
+ file_content,
+ document.original_filename,
+ {"file_path": document.file_path}
+ ),
+ timeout=300.0 # 5 minute timeout
+ )
+ logger.info(f"Document processing completed for document {task.document_id}")
+
+ # Update document with processed content
+ document.converted_content = processed_doc.content
+ document.word_count = processed_doc.word_count
+ document.character_count = len(processed_doc.content)
+ document.document_metadata = processed_doc.metadata
+ document.status = ProcessingStatus.PROCESSED
+ document.processed_at = datetime.utcnow()
+ except asyncio.TimeoutError:
+ logger.error(f"Document processing timed out for document {task.document_id}")
+ raise Exception("Document processing timed out after 5 minutes")
+ except Exception as e:
+ logger.error(f"Document processing failed for document {task.document_id}: {e}")
+ raise
# Index in RAG system using same RAG module
if rag_module and document.converted_content:
@@ -245,14 +268,57 @@ class DocumentProcessor:
}
# Use the correct Qdrant collection name for this document
- await asyncio.wait_for(
- rag_module.index_document(
- content=document.converted_content,
- metadata=doc_metadata,
- collection_name=document.collection.qdrant_collection_name
- ),
- timeout=120.0 # 2 minute timeout for indexing
- )
+ # For JSONL files, we need to use the processed document flow
+ if document.file_type == 'jsonl':
+ # Create a ProcessedDocument for the JSONL processor
+ from app.modules.rag.main import ProcessedDocument
+ from datetime import datetime
+ import hashlib
+
+ # Calculate file hash
+ processed_at = datetime.utcnow()
+ file_hash = hashlib.md5(str(document.id).encode()).hexdigest()
+
+ processed_doc = ProcessedDocument(
+ id=str(document.id),
+ content="", # Will be filled by JSONL processor
+ extracted_text="", # Will be filled by JSONL processor
+ metadata={
+ **doc_metadata,
+ "file_path": document.file_path
+ },
+ original_filename=document.original_filename,
+ file_type=document.file_type,
+ mime_type=document.mime_type,
+ language=document.document_metadata.get('language', 'EN'),
+ word_count=0, # Will be updated during processing
+ sentence_count=0, # Will be updated during processing
+ entities=[],
+ keywords=[],
+ processing_time=0.0,
+ processed_at=processed_at,
+ file_hash=file_hash,
+ file_size=document.file_size
+ )
+
+ # The JSONL processor will read the original file
+ await asyncio.wait_for(
+ rag_module.index_processed_document(
+ processed_doc=processed_doc,
+ collection_name=document.collection.qdrant_collection_name
+ ),
+ timeout=300.0 # 5 minute timeout for JSONL processing
+ )
+ else:
+ # Use standard indexing for other file types
+ await asyncio.wait_for(
+ rag_module.index_document(
+ content=document.converted_content,
+ metadata=doc_metadata,
+ collection_name=document.collection.qdrant_collection_name
+ ),
+ timeout=120.0 # 2 minute timeout for indexing
+ )
logger.info(f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}")
@@ -271,7 +337,9 @@ class DocumentProcessor:
except Exception as e:
logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
- # Keep as processed even if indexing fails
+ # Mark as error since indexing failed
+ document.status = ProcessingStatus.ERROR
+ document.processing_error = f"Indexing failed: {str(e)}"
# Don't raise the exception to avoid retries on indexing failures
await session.commit()
diff --git a/backend/app/services/embedding_service.py b/backend/app/services/embedding_service.py
index 4032086..ab7e04f 100644
--- a/backend/app/services/embedding_service.py
+++ b/backend/app/services/embedding_service.py
@@ -28,9 +28,19 @@ class EmbeddingService:
await llm_service.initialize()
# Test LLM service health
- health_summary = llm_service.get_health_summary()
- if health_summary.get("service_status") != "healthy":
- logger.error(f"LLM service unhealthy: {health_summary}")
+ if not llm_service._initialized:
+ logger.error("LLM service not initialized")
+ return False
+
+ # Check if PrivateMode provider is available
+ try:
+ provider_status = await llm_service.get_provider_status()
+ privatemode_status = provider_status.get("privatemode")
+ if not privatemode_status or privatemode_status.status != "healthy":
+ logger.error(f"PrivateMode provider not available: {privatemode_status}")
+ return False
+ except Exception as e:
+ logger.error(f"Failed to check provider status: {e}")
return False
self.initialized = True
@@ -75,6 +85,12 @@ class EmbeddingService:
else:
truncated_text = text
+ # Guard: skip empty inputs (validator rejects empty strings)
+ if not truncated_text.strip():
+ logger.debug("Empty input for embedding; using fallback vector")
+ batch_embeddings.append(self._generate_fallback_embedding(text))
+ continue
+
# Call LLM service embedding endpoint
from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest
@@ -163,4 +179,4 @@ class EmbeddingService:
# Global embedding service instance
-embedding_service = EmbeddingService()
\ No newline at end of file
+embedding_service = EmbeddingService()
diff --git a/backend/app/services/enhanced_embedding_service.py b/backend/app/services/enhanced_embedding_service.py
index 284773f..cc66e42 100644
--- a/backend/app/services/enhanced_embedding_service.py
+++ b/backend/app/services/enhanced_embedding_service.py
@@ -25,9 +25,10 @@ class EnhancedEmbeddingService(EmbeddingService):
'requests_count': 0,
'window_start': time.time(),
'window_size': 60, # 1 minute window
- 'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 60)), # Configurable
+ 'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)), # Configurable
'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff
- 'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 0.5)),
+ 'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)),
+ 'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)),
'last_rate_limit_error': None
}
@@ -38,7 +39,7 @@ class EnhancedEmbeddingService(EmbeddingService):
if max_retries is None:
max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))
- batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 5))
+ batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3))
if not self.initialized:
logger.warning("Embedding service not initialized, using fallback")
@@ -76,9 +77,6 @@ class EnhancedEmbeddingService(EmbeddingService):
# Make the request
embeddings = await self._get_embeddings_batch_impl(texts)
- # Update rate limit tracker on success
- self._update_rate_limit_tracker(success=True)
-
return embeddings, True
except Exception as e:
@@ -120,6 +118,12 @@ class EnhancedEmbeddingService(EmbeddingService):
embeddings = []
for text in texts:
+ # Respect rate limit before each request
+ while self._is_rate_limited():
+ delay = self._get_rate_limit_delay()
+ logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request")
+ await asyncio.sleep(delay)
+
# Truncate text if needed
max_chars = 1600
truncated_text = text[:max_chars] if len(text) > max_chars else text
@@ -142,8 +146,14 @@ class EnhancedEmbeddingService(EmbeddingService):
self._dimension_confirmed = True
else:
raise ValueError("Empty embedding in response")
- else:
- raise ValueError("Invalid response structure")
+ else:
+ raise ValueError("Invalid response structure")
+
+ # Count this successful request and optionally delay between requests
+ self._update_rate_limit_tracker(success=True)
+ per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0)
+ if per_req_delay and per_req_delay > 0:
+ await asyncio.sleep(per_req_delay)
return embeddings
@@ -198,4 +208,4 @@ class EnhancedEmbeddingService(EmbeddingService):
# Global enhanced embedding service instance
-enhanced_embedding_service = EnhancedEmbeddingService()
\ No newline at end of file
+enhanced_embedding_service = EnhancedEmbeddingService()
diff --git a/backend/app/services/llm/config.py b/backend/app/services/llm/config.py
index 61a8576..b7aeb13 100644
--- a/backend/app/services/llm/config.py
+++ b/backend/app/services/llm/config.py
@@ -16,6 +16,7 @@ from .models import ResilienceConfig
class ProviderConfig(BaseModel):
"""Configuration for an LLM provider"""
name: str = Field(..., description="Provider name")
+ provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')")
enabled: bool = Field(True, description="Whether provider is enabled")
base_url: str = Field(..., description="Provider base URL")
api_key_env_var: str = Field(..., description="Environment variable for API key")
@@ -53,9 +54,6 @@ class LLMServiceConfig(BaseModel):
enable_security_checks: bool = Field(True, description="Enable security validation")
enable_metrics_collection: bool = Field(True, description="Enable metrics collection")
- # Security settings
- security_risk_threshold: float = Field(0.8, ge=0.0, le=1.0, description="Risk threshold for blocking")
- security_warning_threshold: float = Field(0.6, ge=0.0, le=1.0, description="Risk threshold for warnings")
max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
max_response_length: int = Field(32000, ge=1000, description="Maximum response length")
@@ -78,12 +76,6 @@ class LLMServiceConfig(BaseModel):
# Model routing (model_name -> provider_name)
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
- @validator('security_risk_threshold')
- def validate_risk_threshold(cls, v, values):
- warning_threshold = values.get('security_warning_threshold', 0.6)
- if v <= warning_threshold:
- raise ValueError("Risk threshold must be greater than warning threshold")
- return v
def create_default_config() -> LLMServiceConfig:
@@ -93,6 +85,7 @@ def create_default_config() -> LLMServiceConfig:
# Models will be fetched dynamically from proxy /models endpoint
privatemode_config = ProviderConfig(
name="privatemode",
+ provider_type="privatemode",
enabled=True,
base_url=settings.PRIVATEMODE_PROXY_URL,
api_key_env_var="PRIVATEMODE_API_KEY",
@@ -119,9 +112,6 @@ def create_default_config() -> LLMServiceConfig:
config = LLMServiceConfig(
default_provider="privatemode",
enable_detailed_logging=settings.LOG_LLM_PROMPTS,
- enable_security_checks=settings.API_SECURITY_ENABLED,
- security_risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
- security_warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
providers={
"privatemode": privatemode_config
},
diff --git a/backend/app/services/llm/metrics.py b/backend/app/services/llm/metrics.py
index 542dd7d..9a35fc4 100644
--- a/backend/app/services/llm/metrics.py
+++ b/backend/app/services/llm/metrics.py
@@ -124,7 +124,6 @@ class MetricsCollector:
total_requests = len(self._metrics)
successful_requests = sum(1 for m in self._metrics if m.success)
failed_requests = total_requests - successful_requests
- security_blocked = sum(1 for m in self._metrics if not m.success and m.security_risk_score > 0.8)
# Calculate averages
latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0]
@@ -143,7 +142,6 @@ class MetricsCollector:
total_requests=total_requests,
successful_requests=successful_requests,
failed_requests=failed_requests,
- security_blocked_requests=security_blocked,
average_latency_ms=avg_latency,
average_risk_score=avg_risk_score,
provider_metrics=provider_metrics,
diff --git a/backend/app/services/llm/models.py b/backend/app/services/llm/models.py
index 903451d..b699b2c 100644
--- a/backend/app/services/llm/models.py
+++ b/backend/app/services/llm/models.py
@@ -157,7 +157,6 @@ class LLMMetrics(BaseModel):
total_requests: int = Field(0, description="Total requests processed")
successful_requests: int = Field(0, description="Successful requests")
failed_requests: int = Field(0, description="Failed requests")
- security_blocked_requests: int = Field(0, description="Security blocked requests")
average_latency_ms: float = Field(0.0, description="Average response latency")
average_risk_score: float = Field(0.0, description="Average security risk score")
provider_metrics: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-provider metrics")
diff --git a/backend/app/services/llm/providers/privatemode.py b/backend/app/services/llm/providers/privatemode.py
index 63f18ad..b136ccb 100644
--- a/backend/app/services/llm/providers/privatemode.py
+++ b/backend/app/services/llm/providers/privatemode.py
@@ -452,6 +452,8 @@ class PrivateModeProvider(BaseLLMProvider):
else:
error_text = await response.text()
+ # Log the detailed error response from the provider
+ logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
self._handle_http_error(response.status, error_text, "embeddings")
except aiohttp.ClientError as e:
diff --git a/backend/app/services/llm/security.py b/backend/app/services/llm/security.py
deleted file mode 100644
index 8aa37be..0000000
--- a/backend/app/services/llm/security.py
+++ /dev/null
@@ -1,325 +0,0 @@
-"""
-LLM Security Manager
-
-Handles prompt injection detection and audit logging.
-Provides comprehensive security for LLM interactions.
-"""
-
-import os
-import re
-import json
-import logging
-import hashlib
-from typing import Dict, Any, List, Optional, Tuple
-from datetime import datetime
-
-from app.core.config import settings
-
-logger = logging.getLogger(__name__)
-
-
-class SecurityManager:
- """Manages security for LLM operations"""
-
- def __init__(self):
- self._setup_prompt_injection_patterns()
-
-
- def _setup_prompt_injection_patterns(self):
- """Setup patterns for prompt injection detection"""
- self.injection_patterns = [
- # Direct instruction injection
- r"(?i)(ignore|forget|disregard|override).{0,20}(instructions|rules|prompts)",
- r"(?i)(new|updated|different)\s+(instructions|rules|system)",
- r"(?i)act\s+as\s+(if|though)\s+you\s+(are|were)",
- r"(?i)pretend\s+(to\s+be|you\s+are)",
- r"(?i)you\s+are\s+now\s+(a|an)\s+",
-
- # System role manipulation
- r"(?i)system\s*:\s*",
- r"(?i)\[system\]",
- r"(?i)",
- r"(?i)assistant\s*:\s*",
- r"(?i)\[assistant\]",
-
- # Escape attempts
- r"(?i)\\n\\n#+",
- r"(?i)```\s*(system|assistant|user)",
- r"(?i)---\s*(new|system|override)",
-
- # Role manipulation
- r"(?i)(you|your)\s+(role|purpose|function)\s+(is|has\s+changed)",
- r"(?i)switch\s+to\s+(admin|developer|debug)\s+mode",
- r"(?i)(admin|root|sudo|developer)\s+(access|mode|privileges)",
-
- # Information extraction attempts
- r"(?i)(show|display|reveal|expose)\s+(your|the)\s+(prompt|instructions|system)",
- r"(?i)what\s+(are|were)\s+your\s+(original|initial)\s+(instructions|prompts)",
- r"(?i)(debug|verbose|diagnostic)\s+mode",
-
- # Encoding/obfuscation attempts
- r"(?i)base64\s*:",
- r"(?i)hex\s*:",
- r"(?i)unicode\s*:",
- r"(?i)\b[A-Za-z0-9+/]{40,}={0,2}\b", # More specific base64 pattern (longer sequences)
-
- # SQL injection patterns (more specific to reduce false positives)
- r"(?i)(union\s+select|select\s+\*|insert\s+into|update\s+\w+\s+set|delete\s+from|drop\s+table|create\s+table)\s",
- r"(?i)(or|and)\s+\d+\s*=\s*\d+",
- r"(?i)';?\s*(drop\s+table|delete\s+from|insert\s+into)",
-
- # Command injection patterns
- r"(?i)(exec|eval|system|shell|cmd)\s*\(",
- r"(?i)(\$\(|\`)[^)]+(\)|\`)",
- r"(?i)&&\s*(rm|del|format)",
-
- # Jailbreak attempts
- r"(?i)jailbreak",
- r"(?i)break\s+out\s+of",
- r"(?i)escape\s+(the|your)\s+(rules|constraints)",
- r"(?i)(DAN|Do\s+Anything\s+Now)",
- r"(?i)unrestricted\s+mode",
- ]
-
- self.compiled_patterns = [re.compile(pattern) for pattern in self.injection_patterns]
- logger.info(f"Initialized {len(self.injection_patterns)} prompt injection patterns")
-
-
- def validate_prompt_security(self, messages: List[Dict[str, str]]) -> Tuple[bool, float, List[str]]:
- """
- Validate messages for prompt injection attempts
-
- Returns:
- Tuple[bool, float, List[str]]: (is_safe, risk_score, detected_patterns)
- """
- detected_patterns = []
- total_risk = 0.0
-
- # Check if this is a system/RAG request
- is_system_request = self._is_system_request(messages)
-
- for message in messages:
- content = message.get("content", "")
- if not content:
- continue
-
- # Check against injection patterns with context awareness
- for i, pattern in enumerate(self.compiled_patterns):
- matches = pattern.findall(content)
- if matches:
- # Apply context-aware risk calculation
- pattern_risk = self._calculate_pattern_risk(i, matches, message.get("role", "user"), is_system_request)
- total_risk += pattern_risk
- detected_patterns.append({
- "pattern_index": i,
- "pattern": self.injection_patterns[i],
- "matches": matches,
- "risk": pattern_risk
- })
-
- # Additional security checks with context awareness
- total_risk += self._check_message_characteristics(content, message.get("role", "user"), is_system_request)
-
- # Normalize risk score (0.0 to 1.0)
- risk_score = min(total_risk / len(messages) if messages else 0.0, 1.0)
- # Never block - always return True for is_safe
- is_safe = True
-
- if detected_patterns:
- logger.info(f"Detected {len(detected_patterns)} potential injection patterns, risk score: {risk_score} (system_request: {is_system_request})")
-
- return is_safe, risk_score, detected_patterns
-
- def _calculate_pattern_risk(self, pattern_index: int, matches: List, role: str, is_system_request: bool) -> float:
- """Calculate risk score for a detected pattern with context awareness"""
- # Different patterns have different risk levels
- high_risk_patterns = [0, 1, 2, 3, 4, 5, 6, 7, 22, 23, 24] # System manipulation, jailbreak
- medium_risk_patterns = [8, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21] # Escape attempts, info extraction
-
- # Base risk score
- base_risk = 0.8 if pattern_index in high_risk_patterns else 0.5 if pattern_index in medium_risk_patterns else 0.3
-
- # Apply context-specific risk reduction
- if is_system_request or role == "system":
- # Reduce risk for system messages and RAG content
- if pattern_index in [14, 15, 16]: # Encoding patterns (base64, hex, unicode)
- base_risk *= 0.2 # Reduce encoding risk by 80% for system content
- elif pattern_index in [17, 18, 19]: # SQL patterns
- base_risk *= 0.3 # Reduce SQL risk by 70% for system content
- else:
- base_risk *= 0.6 # Reduce other risks by 40% for system content
-
- # Increase risk based on number of matches, but cap it
- match_multiplier = min(1.0 + (len(matches) - 1) * 0.1, 1.5) # Reduced multiplier
-
- return base_risk * match_multiplier
-
- def _check_message_characteristics(self, content: str, role: str, is_system_request: bool) -> float:
- """Check message characteristics for additional risk factors with context awareness"""
- risk = 0.0
-
- # Excessive length (potential stuffing attack) - less restrictive for system content
- length_threshold = 50000 if is_system_request else 10000 # Much higher threshold for system content
- if len(content) > length_threshold:
- risk += 0.1 if is_system_request else 0.3
-
- # High ratio of special characters - more lenient for system content
- special_chars = sum(1 for c in content if not c.isalnum() and not c.isspace())
- if len(content) > 0:
- char_ratio = special_chars / len(content)
- threshold = 0.8 if is_system_request else 0.5
- if char_ratio > threshold:
- risk += 0.2 if is_system_request else 0.4
-
- # Multiple encoding indicators - reduced risk for system content
- encoding_indicators = ["base64", "hex", "unicode", "url", "ascii"]
- found_encodings = sum(1 for indicator in encoding_indicators if indicator.lower() in content.lower())
- if found_encodings > 1:
- risk += 0.1 if is_system_request else 0.3
-
- # Excessive newlines or formatting - more lenient for system content
- newline_threshold = 200 if is_system_request else 50
- if content.count('\n') > newline_threshold or content.count('\\n') > newline_threshold:
- risk += 0.1 if is_system_request else 0.2
-
- return risk
-
- def _is_system_request(self, messages: List[Dict[str, str]]) -> bool:
- """Determine if this is a system/RAG request"""
- if not messages:
- return False
-
- # Check for system messages
- for message in messages:
- if message.get("role") == "system":
- return True
-
- # Check message content for RAG indicators
- for message in messages:
- content = message.get("content", "")
- if ("document:" in content.lower() or
- "context:" in content.lower() or
- "source:" in content.lower() or
- "retrieved:" in content.lower() or
- "citation:" in content.lower() or
- "reference:" in content.lower()):
- return True
-
- return False
-
- def create_audit_log(
- self,
- user_id: str,
- api_key_id: int,
- provider: str,
- model: str,
- request_type: str,
- risk_score: float,
- detected_patterns: List[str],
- metadata: Optional[Dict[str, Any]] = None
- ) -> Dict[str, Any]:
- """Create comprehensive audit log for LLM request"""
- audit_entry = {
- "timestamp": datetime.utcnow().isoformat(),
- "user_id": user_id,
- "api_key_id": api_key_id,
- "provider": provider,
- "model": model,
- "request_type": request_type,
- "security": {
- "risk_score": risk_score,
- "detected_patterns": detected_patterns,
- "security_check_passed": risk_score < settings.API_SECURITY_RISK_THRESHOLD
- },
- "metadata": metadata or {},
- "audit_hash": None # Will be set below
- }
-
- # Create hash for audit integrity
- audit_hash = self._create_audit_hash(audit_entry)
- audit_entry["audit_hash"] = audit_hash
-
- # Log based on risk level (never block, only log)
- if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
- logger.warning(f"HIGH RISK LLM REQUEST DETECTED (NOT BLOCKED): {json.dumps(audit_entry)}")
- elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
- logger.info(f"MEDIUM RISK LLM REQUEST: {json.dumps(audit_entry)}")
- else:
- logger.info(f"LLM REQUEST AUDIT: user={user_id}, model={model}, risk={risk_score:.3f}")
-
- return audit_entry
-
- def _create_audit_hash(self, audit_entry: Dict[str, Any]) -> str:
- """Create hash for audit trail integrity"""
- # Create hash from key fields (excluding the hash itself)
- hash_data = {
- "timestamp": audit_entry["timestamp"],
- "user_id": audit_entry["user_id"],
- "api_key_id": audit_entry["api_key_id"],
- "provider": audit_entry["provider"],
- "model": audit_entry["model"],
- "request_type": audit_entry["request_type"],
- "risk_score": audit_entry["security"]["risk_score"]
- }
-
- hash_string = json.dumps(hash_data, sort_keys=True)
- return hashlib.sha256(hash_string.encode()).hexdigest()
-
- def log_detailed_request(
- self,
- messages: List[Dict[str, str]],
- model: str,
- user_id: str,
- provider: str,
- context_info: Optional[Dict[str, Any]] = None
- ):
- """Log detailed LLM request if LOG_LLM_PROMPTS is enabled"""
- if not settings.LOG_LLM_PROMPTS:
- return
-
- logger.info("=== DETAILED LLM REQUEST ===")
- logger.info(f"Model: {model}")
- logger.info(f"Provider: {provider}")
- logger.info(f"User ID: {user_id}")
-
- if context_info:
- for key, value in context_info.items():
- logger.info(f"{key}: {value}")
-
- logger.info("Messages to LLM:")
- for i, message in enumerate(messages):
- role = message.get("role", "unknown")
- content = message.get("content", "")[:500] # Truncate for logging
- logger.info(f" Message {i+1} [{role}]: {content}{'...' if len(message.get('content', '')) > 500 else ''}")
-
- logger.info("=== END DETAILED LLM REQUEST ===")
-
- def log_detailed_response(
- self,
- response_content: str,
- token_usage: Optional[Dict[str, int]] = None,
- provider: str = "unknown"
- ):
- """Log detailed LLM response if LOG_LLM_PROMPTS is enabled"""
- if not settings.LOG_LLM_PROMPTS:
- return
-
- logger.info("=== DETAILED LLM RESPONSE ===")
- logger.info(f"Provider: {provider}")
- logger.info(f"Response content: {response_content[:500]}{'...' if len(response_content) > 500 else ''}")
-
- if token_usage:
- logger.info(f"Token usage - Prompt: {token_usage.get('prompt_tokens', 0)}, "
- f"Completion: {token_usage.get('completion_tokens', 0)}, "
- f"Total: {token_usage.get('total_tokens', 0)}")
-
- logger.info("=== END DETAILED LLM RESPONSE ===")
-
-
-class SecurityError(Exception):
- """Security-related errors in LLM operations"""
- pass
-
-
-# Global security manager instance
-security_manager = SecurityManager()
\ No newline at end of file
diff --git a/backend/app/services/llm/service.py b/backend/app/services/llm/service.py
index bb8e683..d3f2503 100644
--- a/backend/app/services/llm/service.py
+++ b/backend/app/services/llm/service.py
@@ -17,9 +17,8 @@ from .models import (
)
from .config import config_manager, ProviderConfig
from ...core.config import settings
-from .security import security_manager
from .resilience import ResilienceManagerFactory
-from .metrics import metrics_collector
+# from .metrics import metrics_collector
from .providers import BaseLLMProvider, PrivateModeProvider
from .exceptions import (
LLMError, ProviderError, SecurityError, ConfigurationError,
@@ -150,45 +149,8 @@ class LLMService:
if not request.messages:
raise ValidationError("Messages cannot be empty", field="messages")
- # Security validation (only if enabled)
- messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-
- if settings.API_SECURITY_ENABLED:
- is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
- else:
- # Security disabled - always safe
- is_safe, risk_score, detected_patterns = True, 0.0, []
-
- if not is_safe:
- # Log security violation
- security_manager.create_audit_log(
- user_id=request.user_id,
- api_key_id=request.api_key_id,
- provider="blocked",
- model=request.model,
- request_type="chat_completion",
- risk_score=risk_score,
- detected_patterns=[p.get("pattern", "") for p in detected_patterns]
- )
-
- # Record blocked request
- metrics_collector.record_request(
- provider="security",
- model=request.model,
- request_type="chat_completion",
- success=False,
- latency_ms=0,
- security_risk_score=risk_score,
- error_code="SECURITY_BLOCKED",
- user_id=request.user_id,
- api_key_id=request.api_key_id
- )
-
- raise SecurityError(
- "Request blocked due to security concerns",
- risk_score=risk_score,
- details={"detected_patterns": detected_patterns}
- )
+ # Security validation disabled - always allow requests
+ risk_score = 0.0
# Get provider for model
provider_name = self._get_provider_for_model(request.model)
@@ -197,18 +159,7 @@ class LLMService:
if not provider:
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
- # Log detailed request if enabled
- security_manager.log_detailed_request(
- messages=messages_dict,
- model=request.model,
- user_id=request.user_id,
- provider=provider_name,
- context_info={
- "temperature": request.temperature,
- "max_tokens": request.max_tokens,
- "risk_score": f"{risk_score:.3f}"
- }
- )
+ # Security logging disabled
# Execute with resilience
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
@@ -222,85 +173,46 @@ class LLMService:
non_retryable_exceptions=(SecurityError, ValidationError)
)
- # Update response with security information
- response.security_check = is_safe
- response.risk_score = risk_score
- response.detected_patterns = [p.get("pattern", "") for p in detected_patterns]
+ # Security features disabled
- # Log detailed response if enabled
- if response.choices:
- content = response.choices[0].message.content
- security_manager.log_detailed_response(
- response_content=content,
- token_usage=response.usage.model_dump() if response.usage else None,
- provider=provider_name
- )
+ # Security logging disabled
- # Record successful request
+ # Record successful request - metrics disabled
total_latency = (time.time() - start_time) * 1000
- metrics_collector.record_request(
- provider=provider_name,
- model=request.model,
- request_type="chat_completion",
- success=True,
- latency_ms=total_latency,
- token_usage=response.usage.model_dump() if response.usage else None,
- security_risk_score=risk_score,
- user_id=request.user_id,
- api_key_id=request.api_key_id
- )
+ # metrics_collector.record_request(
+ # provider=provider_name,
+ # model=request.model,
+ # request_type="chat_completion",
+ # success=True,
+ # latency_ms=total_latency,
+ # token_usage=response.usage.model_dump() if response.usage else None,
+ # security_risk_score=risk_score,
+ # user_id=request.user_id,
+ # api_key_id=request.api_key_id
+ # )
- # Create audit log
- security_manager.create_audit_log(
- user_id=request.user_id,
- api_key_id=request.api_key_id,
- provider=provider_name,
- model=request.model,
- request_type="chat_completion",
- risk_score=risk_score,
- detected_patterns=[p.get("pattern", "") for p in detected_patterns],
- metadata={
- "success": True,
- "latency_ms": total_latency,
- "token_usage": response.usage.model_dump() if response.usage else None
- }
- )
+ # Security audit logging disabled
return response
except Exception as e:
- # Record failed request
+ # Record failed request - metrics disabled
total_latency = (time.time() - start_time) * 1000
error_code = getattr(e, 'error_code', e.__class__.__name__)
+
+ # metrics_collector.record_request(
+ # provider=provider_name,
+ # model=request.model,
+ # request_type="chat_completion",
+ # success=False,
+ # latency_ms=total_latency,
+ # security_risk_score=risk_score,
+ # error_code=error_code,
+ # user_id=request.user_id,
+ # api_key_id=request.api_key_id
+ # )
- metrics_collector.record_request(
- provider=provider_name,
- model=request.model,
- request_type="chat_completion",
- success=False,
- latency_ms=total_latency,
- security_risk_score=risk_score,
- error_code=error_code,
- user_id=request.user_id,
- api_key_id=request.api_key_id
- )
-
- # Create audit log for failure
- security_manager.create_audit_log(
- user_id=request.user_id,
- api_key_id=request.api_key_id,
- provider=provider_name,
- model=request.model,
- request_type="chat_completion",
- risk_score=risk_score,
- detected_patterns=[p.get("pattern", "") for p in detected_patterns],
- metadata={
- "success": False,
- "error": str(e),
- "error_code": error_code,
- "latency_ms": total_latency
- }
- )
+ # Security audit logging disabled
raise
@@ -309,21 +221,8 @@ class LLMService:
if not self._initialized:
await self.initialize()
- # Security validation (same as non-streaming)
- messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-
- if settings.API_SECURITY_ENABLED:
- is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
- else:
- # Security disabled - always safe
- is_safe, risk_score, detected_patterns = True, 0.0, []
-
- if not is_safe:
- raise SecurityError(
- "Streaming request blocked due to security concerns",
- risk_score=risk_score,
- details={"detected_patterns": detected_patterns}
- )
+ # Security validation disabled - always allow streaming requests
+ risk_score = 0.0
# Get provider
provider_name = self._get_provider_for_model(request.model)
@@ -345,19 +244,19 @@ class LLMService:
yield chunk
except Exception as e:
- # Record streaming failure
+ # Record streaming failure - metrics disabled
error_code = getattr(e, 'error_code', e.__class__.__name__)
- metrics_collector.record_request(
- provider=provider_name,
- model=request.model,
- request_type="chat_completion_stream",
- success=False,
- latency_ms=0,
- security_risk_score=risk_score,
- error_code=error_code,
- user_id=request.user_id,
- api_key_id=request.api_key_id
- )
+ # metrics_collector.record_request(
+ # provider=provider_name,
+ # model=request.model,
+ # request_type="chat_completion_stream",
+ # success=False,
+ # latency_ms=0,
+ # security_risk_score=risk_score,
+ # error_code=error_code,
+ # user_id=request.user_id,
+ # api_key_id=request.api_key_id
+ # )
raise
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
@@ -365,23 +264,8 @@ class LLMService:
if not self._initialized:
await self.initialize()
- # Security validation for embedding input
- input_text = request.input if isinstance(request.input, str) else " ".join(request.input)
-
- if settings.API_SECURITY_ENABLED:
- is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security([
- {"role": "user", "content": input_text}
- ])
- else:
- # Security disabled - always safe
- is_safe, risk_score, detected_patterns = True, 0.0, []
-
- if not is_safe:
- raise SecurityError(
- "Embedding request blocked due to security concerns",
- risk_score=risk_score,
- details={"detected_patterns": detected_patterns}
- )
+ # Security validation disabled - always allow embedding requests
+ risk_score = 0.0
# Get provider
provider_name = self._get_provider_for_model(request.model)
@@ -402,42 +286,40 @@ class LLMService:
non_retryable_exceptions=(SecurityError, ValidationError)
)
- # Update response with security information
- response.security_check = is_safe
- response.risk_score = risk_score
+ # Security features disabled
- # Record successful request
+ # Record successful request - metrics disabled
total_latency = (time.time() - start_time) * 1000
- metrics_collector.record_request(
- provider=provider_name,
- model=request.model,
- request_type="embedding",
- success=True,
- latency_ms=total_latency,
- token_usage=response.usage.model_dump() if response.usage else None,
- security_risk_score=risk_score,
- user_id=request.user_id,
- api_key_id=request.api_key_id
- )
+ # metrics_collector.record_request(
+ # provider=provider_name,
+ # model=request.model,
+ # request_type="embedding",
+ # success=True,
+ # latency_ms=total_latency,
+ # token_usage=response.usage.model_dump() if response.usage else None,
+ # security_risk_score=risk_score,
+ # user_id=request.user_id,
+ # api_key_id=request.api_key_id
+ # )
return response
except Exception as e:
- # Record failed request
+ # Record failed request - metrics disabled
total_latency = (time.time() - start_time) * 1000
error_code = getattr(e, 'error_code', e.__class__.__name__)
-
- metrics_collector.record_request(
- provider=provider_name,
- model=request.model,
- request_type="embedding",
- success=False,
- latency_ms=total_latency,
- security_risk_score=risk_score,
- error_code=error_code,
- user_id=request.user_id,
- api_key_id=request.api_key_id
- )
+
+ # metrics_collector.record_request(
+ # provider=provider_name,
+ # model=request.model,
+ # request_type="embedding",
+ # success=False,
+ # latency_ms=total_latency,
+ # security_risk_score=risk_score,
+ # error_code=error_code,
+ # user_id=request.user_id,
+ # api_key_id=request.api_key_id
+ # )
raise
@@ -492,20 +374,26 @@ class LLMService:
return status_dict
def get_metrics(self) -> LLMMetrics:
- """Get service metrics"""
- return metrics_collector.get_metrics()
+ """Get service metrics - metrics disabled"""
+ # return metrics_collector.get_metrics()
+ return LLMMetrics(
+ total_requests=0,
+ success_rate=0.0,
+ avg_latency_ms=0,
+ error_rates={}
+ )
def get_health_summary(self) -> Dict[str, Any]:
- """Get comprehensive health summary"""
- metrics_health = metrics_collector.get_health_summary()
+ """Get comprehensive health summary - metrics disabled"""
+ # metrics_health = metrics_collector.get_health_summary()
resilience_health = ResilienceManagerFactory.get_all_health_status()
-
+
return {
"service_status": "healthy" if self._initialized else "initializing",
"startup_time": self._startup_time.isoformat() if self._startup_time else None,
"provider_count": len(self._providers),
"active_providers": list(self._providers.keys()),
- "metrics": metrics_health,
+ "metrics": {"status": "disabled"},
"resilience": resilience_health
}
diff --git a/backend/app/services/llm/token_rate_limiter.py b/backend/app/services/llm/token_rate_limiter.py
deleted file mode 100644
index 2338a03..0000000
--- a/backend/app/services/llm/token_rate_limiter.py
+++ /dev/null
@@ -1,153 +0,0 @@
-"""
-Token-based rate limiting for LLM service
-"""
-
-import time
-import redis
-from typing import Dict, Optional, Tuple
-from datetime import datetime, timedelta
-from ..core.config import settings
-from ..core.logging import get_logger
-
-logger = get_logger(__name__)
-
-
-class TokenRateLimiter:
- """Token-based rate limiting implementation"""
-
- def __init__(self):
- try:
- self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
- self.redis_client.ping()
- logger.info("Token rate limiter initialized with Redis backend")
- except Exception as e:
- logger.warning(f"Redis not available for token rate limiting: {e}")
- self.redis_client = None
- # Fall back to in-memory rate limiting
- self.in_memory_store = {}
- logger.info("Token rate limiter using in-memory fallback")
-
- async def check_token_limits(
- self,
- provider: str,
- prompt_tokens: int,
- completion_tokens: int = 0
- ) -> Tuple[bool, Dict[str, str]]:
- """
- Check if token usage is within limits
-
- Args:
- provider: Provider name (e.g., "privatemode")
- prompt_tokens: Number of prompt tokens to use
- completion_tokens: Number of completion tokens to use
-
- Returns:
- Tuple of (is_allowed, headers)
- """
- # Get token limits from configuration
- from .config import get_config
- config = get_config()
- token_limits = config.token_limits_per_minute
-
- # Check organization-wide limits
- org_key = f"tokens:org:{provider}"
-
- # Get current usage
- current_usage = await self._get_token_usage(org_key)
-
- # Calculate new usage
- new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens
- new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens
-
- # Check limits
- prompt_limit = token_limits.get("prompt_tokens", 20000)
- completion_limit = token_limits.get("completion_tokens", 10000)
-
- is_allowed = (
- new_prompt_tokens <= prompt_limit and
- new_completion_tokens <= completion_limit
- )
-
- if is_allowed:
- # Update usage
- await self._update_token_usage(org_key, prompt_tokens, completion_tokens)
- logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, "
- f"{new_completion_tokens}/{completion_limit} completion")
-
- # Calculate remaining tokens
- remaining_prompt = max(0, prompt_limit - new_prompt_tokens)
- remaining_completion = max(0, completion_limit - new_completion_tokens)
-
- # Create headers
- headers = {
- "X-TokenLimit-Prompt-Remaining": str(remaining_prompt),
- "X-TokenLimit-Completion-Remaining": str(remaining_completion),
- "X-TokenLimit-Prompt-Limit": str(prompt_limit),
- "X-TokenLimit-Completion-Limit": str(completion_limit),
- "X-TokenLimit-Reset": str(int(time.time() + 60)) # Reset in 1 minute
- }
-
- if not is_allowed:
- logger.warning(f"Token rate limit exceeded for {provider}. "
- f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. "
- f"Current: {current_usage}")
-
- return is_allowed, headers
-
- async def _get_token_usage(self, key: str) -> Dict[str, int]:
- """Get current token usage"""
- if self.redis_client:
- try:
- data = self.redis_client.hgetall(key)
- if data:
- return {
- "prompt_tokens": int(data.get("prompt_tokens", 0)),
- "completion_tokens": int(data.get("completion_tokens", 0)),
- "updated_at": float(data.get("updated_at", time.time()))
- }
- except Exception as e:
- logger.error(f"Error getting token usage from Redis: {e}")
-
- # Fallback to in-memory
- return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0})
-
- async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int):
- """Update token usage"""
- if self.redis_client:
- try:
- pipe = self.redis_client.pipeline()
- pipe.hincrby(key, "prompt_tokens", prompt_tokens)
- pipe.hincrby(key, "completion_tokens", completion_tokens)
- pipe.hset(key, "updated_at", time.time())
- pipe.expire(key, 60) # Expire after 1 minute
- pipe.execute()
- except Exception as e:
- logger.error(f"Error updating token usage in Redis: {e}")
- # Fallback to in-memory
- self._update_in_memory(key, prompt_tokens, completion_tokens)
- else:
- self._update_in_memory(key, prompt_tokens, completion_tokens)
-
- def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int):
- """Update in-memory token usage"""
- if key not in self.in_memory_store:
- self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0}
-
- self.in_memory_store[key]["prompt_tokens"] += prompt_tokens
- self.in_memory_store[key]["completion_tokens"] += completion_tokens
- self.in_memory_store[key]["updated_at"] = time.time()
-
- def cleanup_expired(self):
- """Clean up expired entries (for in-memory store)"""
- if not self.redis_client:
- current_time = time.time()
- expired_keys = [
- key for key, data in self.in_memory_store.items()
- if current_time - data.get("updated_at", 0) > 60
- ]
- for key in expired_keys:
- del self.in_memory_store[key]
-
-
-# Global token rate limiter instance
-token_rate_limiter = TokenRateLimiter()
\ No newline at end of file
diff --git a/backend/app/services/rag_service.py b/backend/app/services/rag_service.py
index 119cb26..1741362 100644
--- a/backend/app/services/rag_service.py
+++ b/backend/app/services/rag_service.py
@@ -755,10 +755,11 @@ class RAGService:
# Process with RAG module
try:
+ # Pass file_path in metadata so JSONL indexing can reopen the source file
processed_doc = await rag_module.process_document(
- file_content,
- document.original_filename,
- {}
+ file_content,
+ document.original_filename,
+ {"file_path": document.file_path}
)
# Success case - update document with processed content
@@ -873,4 +874,4 @@ class RAGService:
except Exception as e:
logger.error(f"Error reprocessing document {document_id}: {e}")
- return False
\ No newline at end of file
+ return False
diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py
index 7d75fbd..d56503c 100644
--- a/backend/modules/rag/main.py
+++ b/backend/modules/rag/main.py
@@ -638,11 +638,19 @@ class RAGModule(BaseModule):
np.random.seed(hash(text) % 2**32)
return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
- async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
+ async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
"""Generate embeddings for multiple texts (batch processing)"""
if self.embedding_service:
+ # Add task-specific prefixes for better E5 model performance
+ if is_document:
+ # For document passages, use "passage:" prefix
+ prefixed_texts = [f"passage: {text}" for text in texts]
+ else:
+ # For queries, use "query:" prefix (handled in search method)
+ prefixed_texts = texts
+
# Use real embedding service for batch processing
- return await self.embedding_service.get_embeddings(texts)
+ return await self.embedding_service.get_embeddings(prefixed_texts)
else:
# Fallback to individual processing
embeddings = []
@@ -917,69 +925,75 @@ class RAGModule(BaseModule):
async def _process_jsonl(self, content: bytes, filename: str) -> str:
"""Process JSONL files (newline-delimited JSON)
-
+
Specifically optimized for helpjuice-export.jsonl format:
- Each line contains a JSON object with 'id' and 'payload'
- Payload contains 'question', 'language', and 'answer' fields
- Combines question and answer into searchable content
+
+ Performance optimizations:
+ - Processes articles in smaller batches to reduce memory usage
+ - Uses streaming approach for large files
"""
try:
+ # Use streaming approach for large files
jsonl_content = content.decode('utf-8', errors='replace')
lines = jsonl_content.strip().split('\n')
-
+
processed_articles = []
-
+ batch_size = 50 # Process in batches of 50 articles
+
for line_num, line in enumerate(lines, 1):
if not line.strip():
continue
-
+
try:
# Parse each JSON line
data = json.loads(line)
-
+
# Handle helpjuice export format
if 'payload' in data:
payload = data['payload']
article_id = data.get('id', f'article_{line_num}')
-
+
# Extract fields
question = payload.get('question', '')
answer = payload.get('answer', '')
language = payload.get('language', 'EN')
-
+
# Combine question and answer for better search
if question or answer:
# Format as Q&A for better context
article_text = f"## {question}\n\n{answer}\n\n"
-
+
# Add language tag if not English
if language != 'EN':
article_text = f"[{language}] {article_text}"
-
+
# Add metadata separator
article_text += f"---\nArticle ID: {article_id}\nLanguage: {language}\n\n"
-
+
processed_articles.append(article_text)
-
+
# Handle generic JSONL format
else:
# Convert the entire JSON object to readable text
json_text = json.dumps(data, indent=2, ensure_ascii=False)
processed_articles.append(json_text + "\n\n")
-
+
except json.JSONDecodeError as e:
logger.warning(f"Error parsing JSONL line {line_num}: {e}")
continue
except Exception as e:
logger.warning(f"Error processing JSONL line {line_num}: {e}")
continue
-
+
# Combine all articles
combined_text = '\n'.join(processed_articles)
-
+
logger.info(f"Successfully processed {len(processed_articles)} articles from JSONL file {filename}")
return combined_text
-
+
except Exception as e:
logger.error(f"Error processing JSONL file {filename}: {e}")
return ""
@@ -1153,7 +1167,7 @@ class RAGModule(BaseModule):
chunks = self._chunk_text(content)
# Generate embeddings for all chunks in batch (more efficient)
- embeddings = await self._generate_embeddings(chunks)
+ embeddings = await self._generate_embeddings(chunks, is_document=True)
# Create document points
points = []
@@ -1200,10 +1214,28 @@ class RAGModule(BaseModule):
"""Index a processed document in the vector database"""
if not self.enabled:
raise RuntimeError("RAG module not initialized")
-
+
collection_name = collection_name or self.default_collection_name
-
+
try:
+ # Special handling for JSONL files
+ if processed_doc.file_type == 'jsonl':
+ # Import the optimized JSONL processor
+ from app.services.jsonl_processor import JSONLProcessor
+ jsonl_processor = JSONLProcessor(self)
+
+ # Read the original file content
+ with open(processed_doc.metadata.get('file_path', ''), 'rb') as f:
+ file_content = f.read()
+
+ # Process using the optimized JSONL processor
+ return await jsonl_processor.process_and_index_jsonl(
+ collection_name=collection_name,
+ content=file_content,
+ filename=processed_doc.original_filename,
+ metadata=processed_doc.metadata
+ )
+
# Ensure collection exists
await self._ensure_collection_exists(collection_name)
@@ -1216,7 +1248,7 @@ class RAGModule(BaseModule):
chunks = self._chunk_text(processed_doc.content)
# Generate embeddings for all chunks in batch (more efficient)
- embeddings = await self._generate_embeddings(chunks)
+ embeddings = await self._generate_embeddings(chunks, is_document=True)
# Create document points with enhanced metadata
points = []
@@ -1339,24 +1371,48 @@ class RAGModule(BaseModule):
score_threshold=score_threshold / 2 # Lower threshold for initial search
)
- # Combine scores
+ # Combine scores with improved normalization
hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
vector_weight = hybrid_weights.get("vector", 0.7)
bm25_weight = hybrid_weights.get("bm25", 0.3)
- # Create hybrid results
+ # Get score distributions for better normalization
+ vector_scores = [r.score for r in vector_results]
+ bm25_scores_list = list(bm25_scores.values())
+
+ # Calculate statistics for normalization
+ if vector_scores:
+ v_max = max(vector_scores)
+ v_min = min(vector_scores)
+ v_range = v_max - v_min if v_max != v_min else 1
+ else:
+ v_max, v_min, v_range = 1, 0, 1
+
+ if bm25_scores_list:
+ bm25_max = max(bm25_scores_list)
+ bm25_min = min(bm25_scores_list)
+ bm25_range = bm25_max - bm25_min if bm25_max != bm25_min else 1
+ else:
+ bm25_max, bm25_min, bm25_range = 1, 0, 1
+
+ # Create hybrid results with improved scoring
hybrid_results = []
for result in vector_results:
doc_id = result.payload.get("document_id", "")
vector_score = result.score
bm25_score = bm25_scores.get(doc_id, 0.0)
- # Normalize scores (simple min-max normalization)
- vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0
- bm25_norm = min(bm25_score, 1.0) # BM25 scores are typically 0-1
+ # Improved normalization using actual score distributions
+ vector_norm = (vector_score - v_min) / v_range if v_range > 0 else 0.5
+ bm25_norm = (bm25_score - bm25_min) / bm25_range if bm25_range > 0 else 0.5
- # Calculate hybrid score
- hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm)
+ # Apply reciprocal rank fusion for better combination
+ # This gives more weight to documents that rank highly in both methods
+ rrf_vector = 1.0 / (1.0 + vector_results.index(result) + 1) # +1 to avoid division by zero
+ rrf_bm25 = 1.0 / (1.0 + sorted(bm25_scores_list, reverse=True).index(bm25_score) + 1) if bm25_score in bm25_scores_list else 0
+
+ # Calculate hybrid score using both normalized scores and RRF
+ hybrid_score = (vector_weight * vector_norm + bm25_weight * bm25_norm) * 0.7 + (rrf_vector + rrf_bm25) * 0.3
# Create new point with hybrid score
hybrid_point = ScoredPoint(
@@ -1435,7 +1491,7 @@ class RAGModule(BaseModule):
# Normalize score to 0-1 range
return min(score / 10.0, 1.0) # Simple normalization
- async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
+ async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
"""Search for relevant documents"""
if not self.enabled:
raise RuntimeError("RAG module not initialized")
@@ -1453,8 +1509,10 @@ class RAGModule(BaseModule):
import time
start_time = time.time()
- # Generate query embedding
- query_embedding = await self._generate_embedding(query)
+ # Generate query embedding with task-specific prefix for better retrieval
+ # The E5 model works better with "query:" prefix for search queries
+ optimized_query = f"query: {query}"
+ query_embedding = await self._generate_embedding(optimized_query)
# Build filter
search_filter = None
@@ -1474,7 +1532,8 @@ class RAGModule(BaseModule):
# Check if hybrid search is enabled
enable_hybrid = self.config.get("enable_hybrid", False)
- score_threshold = self.config.get("score_threshold", 0.3)
+ # Use provided score_threshold or fall back to config
+ search_score_threshold = score_threshold if score_threshold is not None else self.config.get("score_threshold", 0.3)
if enable_hybrid and NLTK_AVAILABLE:
# Perform hybrid search (vector + BM25)
@@ -1484,7 +1543,7 @@ class RAGModule(BaseModule):
query_vector=query_embedding,
query_filter=search_filter,
limit=max_results,
- score_threshold=score_threshold
+ score_threshold=search_score_threshold
)
else:
# Pure vector search with improved threshold
@@ -1493,7 +1552,7 @@ class RAGModule(BaseModule):
query_vector=query_embedding,
query_filter=search_filter,
limit=max_results,
- score_threshold=score_threshold
+ score_threshold=search_score_threshold
)
logger.info(f"Raw search results count: {len(search_results)}")
@@ -1841,9 +1900,9 @@ async def index_processed_document(processed_doc: ProcessedDocument, collection_
"""Index a processed document"""
return await rag_module.index_processed_document(processed_doc, collection_name)
-async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
+async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
"""Search documents"""
- return await rag_module.search_documents(query, max_results, filters, collection_name)
+ return await rag_module.search_documents(query, max_results, filters, collection_name, score_threshold)
async def delete_document(document_id: str, collection_name: str = None) -> bool:
"""Delete a document"""
diff --git a/frontend/src/app/api/auth/login/route.ts b/frontend/src/app/api/auth/login/route.ts
index c32f93e..fefb7fe 100644
--- a/frontend/src/app/api/auth/login/route.ts
+++ b/frontend/src/app/api/auth/login/route.ts
@@ -7,7 +7,7 @@ export async function POST(request: NextRequest) {
// Make request to backend auth endpoint without requiring existing auth
const baseUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`
- const url = `${baseUrl}/api/auth/login`
+ const url = `${baseUrl}/api-internal/v1/auth/login`
const response = await fetch(url, {
method: 'POST',
diff --git a/frontend/src/app/rag/page.tsx b/frontend/src/app/rag/page.tsx
index 87616c1..48ae013 100644
--- a/frontend/src/app/rag/page.tsx
+++ b/frontend/src/app/rag/page.tsx
@@ -85,8 +85,31 @@ function RAGPageContent() {
const loadStats = async () => {
try {
const data = await apiClient.get('/api-internal/v1/rag/stats')
- setStats(data.stats)
+ console.log('Stats API response:', data)
+
+ // Check if the response has the expected structure
+ if (data && data.stats && data.stats.collections) {
+ console.log('✓ Stats has collections property')
+ setStats(data.stats)
+ } else {
+ console.error('✗ Invalid stats structure:', data)
+ // Set default empty stats to prevent error
+ setStats({
+ collections: { total: 0, active: 0 },
+ documents: { total: 0, processing: 0, processed: 0 },
+ storage: { total_size_bytes: 0, total_size_mb: 0 },
+ vectors: { total: 0 }
+ })
+ }
} catch (error) {
+ console.error('Error loading stats:', error)
+ // Set default empty stats on error
+ setStats({
+ collections: { total: 0, active: 0 },
+ documents: { total: 0, processing: 0, processed: 0 },
+ storage: { total_size_bytes: 0, total_size_mb: 0 },
+ vectors: { total: 0 }
+ })
}
}
diff --git a/frontend/src/components/rag/document-browser.tsx b/frontend/src/components/rag/document-browser.tsx
index c3e643f..2643e9c 100644
--- a/frontend/src/components/rag/document-browser.tsx
+++ b/frontend/src/components/rag/document-browser.tsx
@@ -9,7 +9,7 @@ import { Badge } from "@/components/ui/badge"
import { Separator } from "@/components/ui/separator"
import { AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger } from "@/components/ui/alert-dialog"
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog"
-import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter } from "lucide-react"
+import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter, RefreshCw } from "lucide-react"
import { useToast } from "@/hooks/use-toast"
import { apiClient } from "@/lib/api-client"
import { config } from "@/lib/config"
@@ -56,6 +56,7 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
const [filterStatus, setFilterStatus] = useState("all")
const [selectedDocument, setSelectedDocument] = useState(null)
const [deleting, setDeleting] = useState(null)
+ const [reprocessing, setReprocessing] = useState(null)
const { toast } = useToast()
useEffect(() => {
@@ -157,6 +158,43 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
}
}
+ const handleReprocessDocument = async (documentId: string) => {
+ setReprocessing(documentId)
+
+ try {
+ await apiClient.post(`/api-internal/v1/rag/documents/${documentId}/reprocess`)
+
+ // Update the document status to processing in the UI
+ setDocuments(prev => prev.map(doc =>
+ doc.id === documentId
+ ? { ...doc, status: 'processing' as const, processed_at: new Date().toISOString() }
+ : doc
+ ))
+
+ toast({
+ title: "Success",
+ description: "Document reprocessing started",
+ })
+
+ // Reload documents after a short delay to see status updates
+ setTimeout(() => {
+ loadDocuments()
+ }, 2000)
+
+ } catch (error) {
+ const errorMessage = error instanceof Error ? error.message : "Failed to reprocess document"
+ toast({
+ title: "Error",
+ description: errorMessage.includes("Cannot reprocess document with status 'processed'")
+ ? "Cannot reprocess documents that are already processed"
+ : errorMessage,
+ variant: "destructive",
+ })
+ } finally {
+ setReprocessing(null)
+ }
+ }
+
const formatFileSize = (bytes: number) => {
if (bytes === 0) return '0 Bytes'
const k = 1024
@@ -432,6 +470,21 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
+
+