rag improvements

This commit is contained in:
2025-09-23 15:26:54 +02:00
parent 354b43494d
commit f8d127ff42
30 changed files with 817 additions and 2428 deletions

2
.env
View File

@@ -46,7 +46,7 @@ API_RATE_LIMITING_ENABLED=false
# =================================== # ===================================
# APPLICATION BASE URL (Required - derives all URLs and CORS) # APPLICATION BASE URL (Required - derives all URLs and CORS)
# =================================== # ===================================
BASE_URL=localhost BASE_URL=localhost:80
# Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost # Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost
# Backend derives: CORS_ORIGINS=["http://localhost"] # Backend derives: CORS_ORIGINS=["http://localhost"]

View File

@@ -65,6 +65,16 @@ QDRANT_HOST=enclava-qdrant
QDRANT_PORT=6333 QDRANT_PORT=6333
QDRANT_URL=http://enclava-qdrant:6333 QDRANT_URL=http://enclava-qdrant:6333
# ===================================
# RAG EMBEDDING CONFIGURATION (Optional overrides)
# ===================================
# These control embedding throughput to avoid provider 429s.
# Defaults are conservative; uncomment to override.
# RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=12
# RAG_EMBEDDING_BATCH_SIZE=3
# RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=1.0 # seconds
# RAG_EMBEDDING_DELAY_PER_REQUEST=0.5 # seconds
# =================================== # ===================================
# OPTIONAL PRIVATEMODE SETTINGS (Have defaults) # OPTIONAL PRIVATEMODE SETTINGS (Have defaults)
# =================================== # ===================================

View File

@@ -12,8 +12,8 @@ from ..v1.audit import router as audit_router
from ..v1.settings import router as settings_router from ..v1.settings import router as settings_router
from ..v1.analytics import router as analytics_router from ..v1.analytics import router as analytics_router
from ..v1.rag import router as rag_router from ..v1.rag import router as rag_router
from ..rag_debug import router as rag_debug_router
from ..v1.prompt_templates import router as prompt_templates_router from ..v1.prompt_templates import router as prompt_templates_router
from ..v1.security import router as security_router
from ..v1.plugin_registry import router as plugin_registry_router from ..v1.plugin_registry import router as plugin_registry_router
from ..v1.platform import router as platform_router from ..v1.platform import router as platform_router
from ..v1.llm_internal import router as llm_internal_router from ..v1.llm_internal import router as llm_internal_router
@@ -52,11 +52,12 @@ internal_api_router.include_router(analytics_router, prefix="/analytics", tags=[
# Include RAG routes (frontend RAG document management) # Include RAG routes (frontend RAG document management)
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"]) internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])
# Include RAG debug routes (for demo and debugging)
internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])
# Include prompt template routes (frontend prompt template management) # Include prompt template routes (frontend prompt template management)
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"]) internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])
# Include security routes (frontend security settings)
internal_api_router.include_router(security_router, prefix="/security", tags=["internal-security"])
# Include plugin registry routes (frontend plugin management) # Include plugin registry routes (frontend plugin management)
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"]) internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])

View File

@@ -16,7 +16,6 @@ from .analytics import router as analytics_router
from .rag import router as rag_router from .rag import router as rag_router
from .chatbot import router as chatbot_router from .chatbot import router as chatbot_router
from .prompt_templates import router as prompt_templates_router from .prompt_templates import router as prompt_templates_router
from .security import router as security_router
from .plugin_registry import router as plugin_registry_router from .plugin_registry import router as plugin_registry_router
# Create main API router # Create main API router
@@ -61,8 +60,6 @@ api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
# Include prompt template routes # Include prompt template routes
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"]) api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
# Include security routes
api_router.include_router(security_router, prefix="/security", tags=["security"])
# Include plugin registry routes # Include plugin registry routes

View File

@@ -745,7 +745,6 @@ async def get_llm_metrics(
"total_requests": metrics.total_requests, "total_requests": metrics.total_requests,
"successful_requests": metrics.successful_requests, "successful_requests": metrics.successful_requests,
"failed_requests": metrics.failed_requests, "failed_requests": metrics.failed_requests,
"security_blocked_requests": metrics.security_blocked_requests,
"average_latency_ms": metrics.average_latency_ms, "average_latency_ms": metrics.average_latency_ms,
"average_risk_score": metrics.average_risk_score, "average_risk_score": metrics.average_risk_score,
"provider_metrics": metrics.provider_metrics, "provider_metrics": metrics.provider_metrics,

View File

@@ -3,12 +3,14 @@ RAG API Endpoints
Provides REST API for RAG (Retrieval Augmented Generation) operations Provides REST API for RAG (Retrieval Augmented Generation) operations
""" """
from typing import List, Optional from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel from pydantic import BaseModel
import io import io
import asyncio
from datetime import datetime
from app.db.database import get_db from app.db.database import get_db
from app.core.security import get_current_user from app.core.security import get_current_user
@@ -16,6 +18,9 @@ from app.models.user import User
from app.services.rag_service import RAGService from app.services.rag_service import RAGService
from app.utils.exceptions import APIException from app.utils.exceptions import APIException
# Import RAG module from module manager
from app.services.module_manager import module_manager
router = APIRouter(tags=["RAG"]) router = APIRouter(tags=["RAG"])
@@ -78,14 +83,25 @@ async def get_collections(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
"""Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata""" """Get all RAG collections - live data directly from Qdrant (source of truth)"""
try: try:
rag_service = RAGService(db) from app.services.qdrant_stats_service import qdrant_stats_service
collections_data = await rag_service.get_all_collections(skip=skip, limit=limit)
# Get live stats from Qdrant
stats_data = await qdrant_stats_service.get_collections_stats()
collections = stats_data.get("collections", [])
# Apply pagination
start_idx = skip
end_idx = skip + limit
paginated_collections = collections[start_idx:end_idx]
return { return {
"success": True, "success": True,
"collections": collections_data, "collections": paginated_collections,
"total": len(collections_data) "total": len(collections),
"total_documents": stats_data.get("total_documents", 0),
"total_size_bytes": stats_data.get("total_size_bytes", 0)
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -116,6 +132,62 @@ async def create_collection(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("/stats", response_model=dict)
async def get_rag_stats(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get overall RAG statistics - live data directly from Qdrant"""
try:
from app.services.qdrant_stats_service import qdrant_stats_service
# Get live stats from Qdrant
stats_data = await qdrant_stats_service.get_collections_stats()
# Calculate active collections (collections with documents)
active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)
# Calculate processing documents from database
processing_docs = 0
try:
from sqlalchemy import select
from app.models.rag_document import RagDocument, ProcessingStatus
result = await db.execute(
select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
)
processing_docs = len(result.scalars().all())
except Exception:
pass # If database query fails, default to 0
response_data = {
"success": True,
"stats": {
"collections": {
"total": stats_data.get("total_collections", 0),
"active": active_collections
},
"documents": {
"total": stats_data.get("total_documents", 0),
"processing": processing_docs,
"processed": stats_data.get("total_documents", 0) # Indexed documents
},
"storage": {
"total_size_bytes": stats_data.get("total_size_bytes", 0),
"total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
},
"vectors": {
"total": stats_data.get("total_documents", 0) # Same as documents for RAG
},
"last_updated": datetime.utcnow().isoformat()
}
}
return response_data
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/collections/{collection_id}", response_model=dict) @router.get("/collections/{collection_id}", response_model=dict)
async def get_collection( async def get_collection(
collection_id: int, collection_id: int,
@@ -232,11 +304,55 @@ async def upload_document(
if len(file_content) > 50 * 1024 * 1024: # 50MB limit if len(file_content) > 50 * 1024 * 1024: # 50MB limit
raise HTTPException(status_code=400, detail="File too large (max 50MB)") raise HTTPException(status_code=400, detail="File too large (max 50MB)")
# Validate file can be read before processing
filename = file.filename or "unknown"
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
try:
# Test file readability based on type
if file_extension == 'jsonl':
# Validate JSONL format - try to parse first few lines
try:
content_str = file_content.decode('utf-8')
lines = content_str.strip().split('\n')[:5] # Check first 5 lines
import json
for i, line in enumerate(lines):
if line.strip(): # Skip empty lines
json.loads(line) # Will raise JSONDecodeError if invalid
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
# Validate text files can be decoded
try:
file_content.decode('utf-8')
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
elif file_extension in ['pdf']:
# For PDF files, just check if it starts with PDF signature
if not file_content.startswith(b'%PDF'):
raise HTTPException(status_code=400, detail="Invalid PDF file format")
elif file_extension in ['docx', 'xlsx', 'pptx']:
# For Office documents, check ZIP signature
if not file_content.startswith(b'PK'):
raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
# For other file types, we'll rely on the document processor
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
rag_service = RAGService(db) rag_service = RAGService(db)
document = await rag_service.upload_document( document = await rag_service.upload_document(
collection_id=collection_id, collection_id=collection_id,
file_content=file_content, file_content=file_content,
filename=file.filename or "unknown", filename=filename,
content_type=file.content_type content_type=file.content_type
) )
@@ -362,21 +478,167 @@ async def download_document(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
# Stats Endpoint
@router.get("/stats", response_model=dict) # Debug Endpoints
async def get_rag_stats(
db: AsyncSession = Depends(get_db), @router.post("/debug/search")
async def search_with_debug(
query: str,
max_results: int = 10,
score_threshold: float = 0.3,
collection_name: str = None,
config: Dict[str, Any] = None,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ) -> Dict[str, Any]:
"""Get RAG system statistics""" """
Enhanced search with comprehensive debug information
"""
# Get RAG module from module manager
rag_module = module_manager.modules.get('rag')
if not rag_module or not rag_module.enabled:
raise HTTPException(status_code=503, detail="RAG module not initialized")
debug_info = {}
start_time = datetime.utcnow()
try: try:
rag_service = RAGService(db) # Apply configuration if provided
stats = await rag_service.get_stats() if config:
# Update RAG config temporarily
original_config = rag_module.config.copy()
rag_module.config.update(config)
# Generate query embedding (with or without prefix)
if config and config.get("use_query_prefix"):
optimized_query = f"query: {query}"
else:
optimized_query = query
query_embedding = await rag_module._generate_embedding(optimized_query)
# Store embedding info for debug
if config and config.get("debug", {}).get("show_embeddings"):
debug_info["query_embedding"] = query_embedding[:10] # First 10 dimensions
debug_info["embedding_dimension"] = len(query_embedding)
debug_info["optimized_query"] = optimized_query
# Perform search
search_start = asyncio.get_event_loop().time()
results = await rag_module.search_documents(
query,
max_results=max_results,
score_threshold=score_threshold,
collection_name=collection_name
)
search_time = (asyncio.get_event_loop().time() - search_start) * 1000
# Calculate score statistics
scores = [r.score for r in results if r.score is not None]
if scores:
import statistics
debug_info["score_stats"] = {
"min": min(scores),
"max": max(scores),
"avg": statistics.mean(scores),
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0
}
# Get collection statistics
try:
from qdrant_client.http.models import Filter
collection_name = collection_name or rag_module.default_collection_name
# Count total documents
count_result = rag_module.qdrant_client.count(
collection_name=collection_name,
count_filter=Filter(must=[])
)
total_points = count_result.count
# Get unique documents and languages
scroll_result = rag_module.qdrant_client.scroll(
collection_name=collection_name,
limit=1000, # Sample for stats
with_payload=True,
with_vectors=False
)
unique_docs = set()
languages = set()
for point in scroll_result[0]:
payload = point.payload or {}
doc_id = payload.get("document_id")
if doc_id:
unique_docs.add(doc_id)
language = payload.get("language")
if language:
languages.add(language)
debug_info["collection_stats"] = {
"total_documents": len(unique_docs),
"total_chunks": total_points,
"languages": sorted(list(languages))
}
except Exception as e:
debug_info["collection_stats_error"] = str(e)
# Enhance results with debug info
enhanced_results = []
for result in results:
enhanced_result = {
"document": {
"id": result.document.id,
"content": result.document.content,
"metadata": result.document.metadata
},
"score": result.score,
"debug_info": {}
}
# Add hybrid search debug info if available
metadata = result.document.metadata or {}
if "_vector_score" in metadata:
enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
if "_bm25_score" in metadata:
enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]
enhanced_results.append(enhanced_result)
# Note: Analytics logging disabled (module not available)
return { return {
"success": True, "results": enhanced_results,
"stats": stats "debug_info": debug_info,
"search_time_ms": search_time,
"timestamp": start_time.isoformat()
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) # Note: Analytics logging disabled (module not available)
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
finally:
# Restore original config if modified
if config and 'original_config' in locals():
rag_module.config = original_config
@router.get("/debug/config")
async def get_current_config(
current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
"""Get current RAG configuration"""
# Get RAG module from module manager
rag_module = module_manager.modules.get('rag')
if not rag_module or not rag_module.enabled:
raise HTTPException(status_code=503, detail="RAG module not initialized")
return {
"config": rag_module.config,
"embedding_model": rag_module.embedding_model,
"enabled": rag_module.enabled,
"collections": await rag_module._get_collections_safely()
}

View File

@@ -1,251 +0,0 @@
"""
Security API endpoints for monitoring and configuration
"""
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field
from app.core.security import get_current_active_user, RequiresRole
from app.middleware.security import get_security_stats, get_request_auth_level, get_request_risk_score
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
router = APIRouter(tags=["security"])
# Pydantic models for API responses
class SecurityStatsResponse(BaseModel):
"""Security statistics response model"""
total_requests_analyzed: int
threats_detected: int
threats_blocked: int
anomalies_detected: int
rate_limits_exceeded: int
avg_analysis_time: float
threat_types: Dict[str, int]
threat_levels: Dict[str, int]
top_attacking_ips: List[tuple]
security_enabled: bool
threat_detection_enabled: bool
rate_limiting_enabled: bool
class SecurityConfigResponse(BaseModel):
"""Security configuration response model"""
security_enabled: bool = Field(description="Overall security system enabled")
threat_detection_enabled: bool = Field(description="Threat detection analysis enabled")
rate_limiting_enabled: bool = Field(description="Rate limiting enabled")
ip_reputation_enabled: bool = Field(description="IP reputation checking enabled")
anomaly_detection_enabled: bool = Field(description="Anomaly detection enabled")
security_headers_enabled: bool = Field(description="Security headers enabled")
# Rate limiting settings
unauthenticated_per_minute: int = Field(description="Rate limit for unauthenticated requests per minute")
authenticated_per_minute: int = Field(description="Rate limit for authenticated users per minute")
api_key_per_minute: int = Field(description="Rate limit for API key users per minute")
premium_per_minute: int = Field(description="Rate limit for premium users per minute")
# Security thresholds
risk_threshold: float = Field(description="Risk score threshold for blocking requests")
warning_threshold: float = Field(description="Risk score threshold for warnings")
anomaly_threshold: float = Field(description="Anomaly severity threshold")
# IP settings
blocked_ips: List[str] = Field(description="List of blocked IP addresses")
allowed_ips: List[str] = Field(description="List of allowed IP addresses (empty = allow all)")
class RateLimitInfoResponse(BaseModel):
"""Rate limit information for current request"""
auth_level: str = Field(description="Authentication level (unauthenticated, authenticated, api_key, premium)")
current_limits: Dict[str, int] = Field(description="Current rate limits for this auth level")
remaining_requests: Optional[Dict[str, int]] = Field(description="Estimated remaining requests (if available)")
@router.get("/stats", response_model=SecurityStatsResponse)
async def get_security_statistics(
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
"""
Get security system statistics
Requires admin role. Returns comprehensive statistics about:
- Request analysis counts
- Threat detection results
- Rate limiting enforcement
- Top attacking IPs
- Performance metrics
"""
try:
stats = get_security_stats()
return SecurityStatsResponse(**stats)
except Exception as e:
logger.error(f"Error getting security stats: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve security statistics"
)
@router.get("/config", response_model=SecurityConfigResponse)
async def get_security_config(
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
"""
Get current security configuration
Requires admin role. Returns current security settings including:
- Feature enablement flags
- Rate limiting thresholds
- Security thresholds
- IP allowlists/blocklists
"""
return SecurityConfigResponse(
security_enabled=settings.API_SECURITY_ENABLED,
threat_detection_enabled=settings.API_THREAT_DETECTION_ENABLED,
rate_limiting_enabled=settings.API_RATE_LIMITING_ENABLED,
ip_reputation_enabled=settings.API_IP_REPUTATION_ENABLED,
anomaly_detection_enabled=settings.API_ANOMALY_DETECTION_ENABLED,
security_headers_enabled=settings.API_SECURITY_HEADERS_ENABLED,
unauthenticated_per_minute=settings.API_RATE_LIMIT_UNAUTHENTICATED_PER_MINUTE,
authenticated_per_minute=settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE,
api_key_per_minute=settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
premium_per_minute=settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE,
risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
anomaly_threshold=settings.API_SECURITY_ANOMALY_THRESHOLD,
blocked_ips=settings.API_BLOCKED_IPS,
allowed_ips=settings.API_ALLOWED_IPS
)
@router.get("/status")
async def get_security_status(
request: Request,
current_user: Dict[str, Any] = Depends(get_current_active_user)
):
"""
Get security status for current request
Returns information about the security analysis of the current request:
- Authentication level
- Risk score (if available)
- Rate limiting status
"""
auth_level = get_request_auth_level(request)
risk_score = get_request_risk_score(request)
# Get rate limits for current auth level
from app.core.threat_detection import AuthLevel
try:
auth_enum = AuthLevel(auth_level)
from app.core.threat_detection import threat_detection_service
minute_limit, hour_limit = threat_detection_service.get_rate_limits(auth_enum)
rate_limit_info = RateLimitInfoResponse(
auth_level=auth_level,
current_limits={
"per_minute": minute_limit,
"per_hour": hour_limit
},
remaining_requests=None # We don't track remaining requests in current implementation
)
except ValueError:
rate_limit_info = RateLimitInfoResponse(
auth_level=auth_level,
current_limits={},
remaining_requests=None
)
return {
"security_enabled": settings.API_SECURITY_ENABLED,
"auth_level": auth_level,
"risk_score": round(risk_score, 3) if risk_score > 0 else None,
"rate_limit_info": rate_limit_info.dict(),
"security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED
}
@router.post("/test")
async def test_security_analysis(
request: Request,
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
"""
Test security analysis on current request
Requires admin role. Manually triggers security analysis on the current request
and returns detailed results. Useful for testing security rules and thresholds.
"""
try:
from app.middleware.security import analyze_request_security
analysis = await analyze_request_security(request, current_user)
return {
"analysis_complete": True,
"is_threat": analysis.is_threat,
"risk_score": round(analysis.risk_score, 3),
"auth_level": analysis.auth_level.value,
"should_block": analysis.should_block,
"rate_limit_exceeded": analysis.rate_limit_exceeded,
"threat_count": len(analysis.threats),
"threats": [
{
"type": threat.threat_type,
"level": threat.level.value,
"confidence": round(threat.confidence, 3),
"description": threat.description,
"mitigation": threat.mitigation
}
for threat in analysis.threats
],
"recommendations": analysis.recommendations
}
except Exception as e:
logger.error(f"Error in security analysis test: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to perform security analysis test"
)
@router.get("/health")
async def security_health_check():
"""
Security system health check
Public endpoint that returns the health status of the security system.
Does not require authentication.
"""
try:
stats = get_security_stats()
# Basic health checks
is_healthy = (
settings.API_SECURITY_ENABLED and
stats.get("total_requests_analyzed", 0) >= 0 and
stats.get("avg_analysis_time", 0) < 1.0 # Analysis should be under 1 second
)
return {
"status": "healthy" if is_healthy else "degraded",
"security_enabled": settings.API_SECURITY_ENABLED,
"threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
"rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
"avg_analysis_time_ms": round(stats.get("avg_analysis_time", 0) * 1000, 2),
"total_requests_analyzed": stats.get("total_requests_analyzed", 0)
}
except Exception as e:
logger.error(f"Security health check failed: {e}")
return {
"status": "unhealthy",
"error": "Security system error",
"security_enabled": settings.API_SECURITY_ENABLED
}

View File

@@ -97,7 +97,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
"api": { "api": {
# Security Settings # Security Settings
"security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"}, "security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
"threat_detection_enabled": {"value": True, "type": "boolean", "description": "Enable threat detection analysis"},
"rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"}, "rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
"ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"}, "ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
"anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"}, "anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
@@ -112,7 +111,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"}, "rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},
# Security Thresholds # Security Thresholds
"security_risk_threshold": {"value": 0.8, "type": "float", "description": "Risk score threshold for blocking requests (0.0-1.0)"},
"security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"}, "security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
"anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"}, "anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},
@@ -601,7 +599,6 @@ async def reset_to_defaults(
"api": { "api": {
# Security Settings # Security Settings
"security_enabled": {"value": True, "type": "boolean"}, "security_enabled": {"value": True, "type": "boolean"},
"threat_detection_enabled": {"value": True, "type": "boolean"},
"rate_limiting_enabled": {"value": True, "type": "boolean"}, "rate_limiting_enabled": {"value": True, "type": "boolean"},
"ip_reputation_enabled": {"value": True, "type": "boolean"}, "ip_reputation_enabled": {"value": True, "type": "boolean"},
"anomaly_detection_enabled": {"value": True, "type": "boolean"}, "anomaly_detection_enabled": {"value": True, "type": "boolean"},
@@ -616,7 +613,6 @@ async def reset_to_defaults(
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer"}, "rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},
# Security Thresholds # Security Thresholds
"security_risk_threshold": {"value": 0.8, "type": "float"},
"security_warning_threshold": {"value": 0.6, "type": "float"}, "security_warning_threshold": {"value": 0.6, "type": "float"},
"anomaly_threshold": {"value": 0.7, "type": "float"}, "anomaly_threshold": {"value": 0.7, "type": "float"},

View File

@@ -17,6 +17,8 @@ class Settings(BaseSettings):
APP_LOG_LEVEL: str = os.getenv("APP_LOG_LEVEL", "INFO") APP_LOG_LEVEL: str = os.getenv("APP_LOG_LEVEL", "INFO")
APP_HOST: str = os.getenv("APP_HOST", "0.0.0.0") APP_HOST: str = os.getenv("APP_HOST", "0.0.0.0")
APP_PORT: int = int(os.getenv("APP_PORT", "8000")) APP_PORT: int = int(os.getenv("APP_PORT", "8000"))
BACKEND_INTERNAL_PORT: int = int(os.getenv("BACKEND_INTERNAL_PORT", "8000"))
FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000"))
# Detailed logging for LLM interactions # Detailed logging for LLM interactions
LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true" # Set to True to log prompts and context sent to LLM LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true" # Set to True to log prompts and context sent to LLM
@@ -73,15 +75,10 @@ class Settings(BaseSettings):
QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost") QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333")) QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333"))
QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY") QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")
# API & Security Settings
API_SECURITY_ENABLED: bool = os.getenv("API_SECURITY_ENABLED", "True").lower() == "true"
API_THREAT_DETECTION_ENABLED: bool = os.getenv("API_THREAT_DETECTION_ENABLED", "True").lower() == "true"
API_IP_REPUTATION_ENABLED: bool = os.getenv("API_IP_REPUTATION_ENABLED", "True").lower() == "true"
API_ANOMALY_DETECTION_ENABLED: bool = os.getenv("API_ANOMALY_DETECTION_ENABLED", "True").lower() == "true"
# Rate Limiting Configuration # Rate Limiting Configuration
API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"
# PrivateMode Standard tier limits (organization-level, not per user) # PrivateMode Standard tier limits (organization-level, not per user)
# These are shared across all API keys and users in the organization # These are shared across all API keys and users in the organization
@@ -102,22 +99,13 @@ class Settings(BaseSettings):
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200")) API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
# Security Thresholds
API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8")) # Block requests above this risk score
API_SECURITY_WARNING_THRESHOLD: float = float(os.getenv("API_SECURITY_WARNING_THRESHOLD", "0.6")) # Log warnings above this threshold
API_SECURITY_ANOMALY_THRESHOLD: float = float(os.getenv("API_SECURITY_ANOMALY_THRESHOLD", "0.7")) # Flag anomalies above this threshold
# Request Size Limits # Request Size Limits
API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")) # 10MB API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")) # 10MB
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")) # 50MB for premium API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")) # 50MB for premium
# IP Security # IP Security
API_BLOCKED_IPS: List[str] = os.getenv("API_BLOCKED_IPS", "").split(",") if os.getenv("API_BLOCKED_IPS") else []
API_ALLOWED_IPS: List[str] = os.getenv("API_ALLOWED_IPS", "").split(",") if os.getenv("API_ALLOWED_IPS") else []
API_IP_REPUTATION_CACHE_TTL: int = int(os.getenv("API_IP_REPUTATION_CACHE_TTL", "3600")) # 1 hour
# Security Headers # Security Headers
API_SECURITY_HEADERS_ENABLED: bool = os.getenv("API_SECURITY_HEADERS_ENABLED", "True").lower() == "true"
API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'") API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
# Monitoring # Monitoring
@@ -130,6 +118,19 @@ class Settings(BaseSettings):
# Module configuration # Module configuration
MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml") MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")
# RAG Embedding Configuration
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12"))
RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3"))
RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3"))
RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16")
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0"))
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
# Plugin configuration # Plugin configuration
PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins") PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
PLUGINS_CONFIG_PATH: str = os.getenv("PLUGINS_CONFIG_PATH", "config/plugins.yaml") PLUGINS_CONFIG_PATH: str = os.getenv("PLUGINS_CONFIG_PATH", "config/plugins.yaml")
@@ -142,7 +143,10 @@ class Settings(BaseSettings):
model_config = { model_config = {
"env_file": ".env", "env_file": ".env",
"case_sensitive": True "case_sensitive": True,
# Ignore unknown environment variables to avoid validation errors
# when optional/deprecated flags are present in .env
"extra": "ignore",
} }

View File

@@ -1,744 +0,0 @@
"""
Core threat detection and security analysis for the platform
"""
import re
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Any, Union
from urllib.parse import unquote
from fastapi import Request
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
class ThreatLevel(Enum):
"""Threat severity levels"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class AuthLevel(Enum):
"""Authentication levels for rate limiting"""
AUTHENTICATED = "authenticated"
API_KEY = "api_key"
PREMIUM = "premium"
@dataclass
class SecurityThreat:
"""Security threat detection result"""
threat_type: str
level: ThreatLevel
confidence: float
description: str
source_ip: str
user_agent: Optional[str] = None
request_path: Optional[str] = None
payload: Optional[str] = None
timestamp: datetime = field(default_factory=datetime.utcnow)
mitigation: Optional[str] = None
@dataclass
class SecurityAnalysis:
"""Comprehensive security analysis result"""
is_threat: bool
threats: List[SecurityThreat]
risk_score: float
recommendations: List[str]
auth_level: AuthLevel
rate_limit_exceeded: bool
should_block: bool
timestamp: datetime = field(default_factory=datetime.utcnow)
@dataclass
class RateLimitInfo:
"""Rate limiting information"""
auth_level: AuthLevel
requests_per_minute: int
requests_per_hour: int
minute_limit: int
hour_limit: int
exceeded: bool
@dataclass
class AnomalyDetection:
"""Anomaly detection result"""
is_anomaly: bool
anomaly_type: str
severity: float
details: Dict[str, Any]
baseline_value: Optional[float] = None
current_value: Optional[float] = None
class ThreatDetectionService:
"""Core threat detection and security analysis service"""
def __init__(self):
self.name = "threat_detection"
# Statistics
self.stats = {
'total_requests_analyzed': 0,
'threats_detected': 0,
'threats_blocked': 0,
'anomalies_detected': 0,
'rate_limits_exceeded': 0,
'total_analysis_time': 0,
'threat_types': defaultdict(int),
'threat_levels': defaultdict(int),
'attacking_ips': defaultdict(int)
}
# Threat detection patterns
self.sql_injection_patterns = [
r"(\bunion\b.*\bselect\b)",
r"(\bselect\b.*\bfrom\b)",
r"(\binsert\b.*\binto\b)",
r"(\bupdate\b.*\bset\b)",
r"(\bdelete\b.*\bfrom\b)",
r"(\bdrop\b.*\btable\b)",
r"(\bor\b.*\b1\s*=\s*1\b)",
r"(\band\b.*\b1\s*=\s*1\b)",
r"(\bexec\b.*\bxp_\w+)",
r"(\bsp_\w+)",
r"(\bsleep\b\s*\(\s*\d+\s*\))",
r"(\bwaitfor\b.*\bdelay\b)",
r"(\bbenchmark\b\s*\(\s*\d+)",
r"(\bload_file\b\s*\()",
r"(\binto\b.*\boutfile\b)"
]
self.xss_patterns = [
r"<script[^>]*>.*?</script>",
r"<iframe[^>]*>.*?</iframe>",
r"<object[^>]*>.*?</object>",
r"<embed[^>]*>.*?</embed>",
r"<link[^>]*>",
r"<meta[^>]*>",
r"javascript:",
r"vbscript:",
r"on\w+\s*=",
r"style\s*=.*expression",
r"style\s*=.*javascript"
]
self.path_traversal_patterns = [
r"\.\.\/",
r"\.\.\\",
r"%2e%2e%2f",
r"%2e%2e%5c",
r"..%2f",
r"..%5c",
r"%252e%252e%252f",
r"%252e%252e%255c"
]
self.command_injection_patterns = [
r";\s*cat\s+",
r";\s*ls\s+",
r";\s*pwd\s*",
r";\s*whoami\s*",
r";\s*id\s*",
r";\s*uname\s*",
r";\s*ps\s+",
r";\s*netstat\s+",
r";\s*wget\s+",
r";\s*curl\s+",
r"\|\s*cat\s+",
r"\|\s*ls\s+",
r"&&\s*cat\s+",
r"&&\s*ls\s+"
]
self.suspicious_ua_patterns = [
r"sqlmap",
r"nikto",
r"nmap",
r"masscan",
r"zap",
r"burp",
r"w3af",
r"acunetix",
r"nessus",
r"openvas",
r"metasploit"
]
# Rate limiting tracking - separate by auth level (excluding unauthenticated since they're blocked)
self.rate_limits = {
AuthLevel.AUTHENTICATED: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}),
AuthLevel.API_KEY: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}),
AuthLevel.PREMIUM: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)})
}
# Anomaly detection
self.request_history = deque(maxlen=1000)
self.ip_history = defaultdict(lambda: deque(maxlen=100))
self.endpoint_history = defaultdict(lambda: deque(maxlen=100))
# Blocked and allowed IPs
self.blocked_ips = set(settings.API_BLOCKED_IPS)
self.allowed_ips = set(settings.API_ALLOWED_IPS) if settings.API_ALLOWED_IPS else None
# IP reputation cache
self.ip_reputation_cache = {}
self.cache_expiry = {}
# Compile patterns for performance
self._compile_patterns()
logger.info(f"ThreatDetectionService initialized with {len(self.sql_injection_patterns)} SQL patterns, "
f"{len(self.xss_patterns)} XSS patterns, rate limiting enabled: {settings.API_RATE_LIMITING_ENABLED}")
def _compile_patterns(self):
"""Compile regex patterns for better performance"""
try:
self.compiled_sql_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.sql_injection_patterns]
self.compiled_xss_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.xss_patterns]
self.compiled_path_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.path_traversal_patterns]
self.compiled_cmd_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.command_injection_patterns]
self.compiled_ua_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.suspicious_ua_patterns]
except re.error as e:
logger.error(f"Failed to compile security patterns: {e}")
# Fallback to empty lists to prevent crashes
self.compiled_sql_patterns = []
self.compiled_xss_patterns = []
self.compiled_path_patterns = []
self.compiled_cmd_patterns = []
self.compiled_ua_patterns = []
def determine_auth_level(self, request: Request, user_context: Optional[Dict] = None) -> AuthLevel:
"""Determine authentication level for rate limiting"""
# Check if request has API key authentication
if hasattr(request.state, 'api_key_context') and request.state.api_key_context:
api_key = request.state.api_key_context.get('api_key')
if api_key and hasattr(api_key, 'tier'):
# Check for premium tier
if api_key.tier in ['premium', 'enterprise']:
return AuthLevel.PREMIUM
return AuthLevel.API_KEY
# Check for JWT authentication
if user_context or hasattr(request.state, 'user'):
return AuthLevel.AUTHENTICATED
# Check Authorization header for API key
auth_header = request.headers.get("Authorization", "")
api_key_header = request.headers.get("X-API-Key", "")
if auth_header.startswith("Bearer ") or api_key_header:
return AuthLevel.API_KEY
# Default to authenticated since unauthenticated requests are blocked at middleware
return AuthLevel.AUTHENTICATED
def get_rate_limits(self, auth_level: AuthLevel) -> Tuple[int, int]:
"""Get rate limits for authentication level"""
if not settings.API_RATE_LIMITING_ENABLED:
return float('inf'), float('inf')
if auth_level == AuthLevel.AUTHENTICATED:
return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
elif auth_level == AuthLevel.API_KEY:
return (settings.API_RATE_LIMIT_API_KEY_PER_MINUTE, settings.API_RATE_LIMIT_API_KEY_PER_HOUR)
elif auth_level == AuthLevel.PREMIUM:
return (settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE, settings.API_RATE_LIMIT_PREMIUM_PER_HOUR)
else:
# Fallback to authenticated limits
return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
def check_rate_limit(self, client_ip: str, auth_level: AuthLevel) -> RateLimitInfo:
"""Check if request exceeds rate limits"""
minute_limit, hour_limit = self.get_rate_limits(auth_level)
current_time = time.time()
# Get or create tracking for this auth level
if auth_level not in self.rate_limits:
# This shouldn't happen, but handle gracefully
return RateLimitInfo(
auth_level=auth_level,
requests_per_minute=0,
requests_per_hour=0,
minute_limit=minute_limit,
hour_limit=hour_limit,
exceeded=False
)
ip_limits = self.rate_limits[auth_level][client_ip]
# Clean old entries
minute_ago = current_time - 60
hour_ago = current_time - 3600
while ip_limits['minute'] and ip_limits['minute'][0] < minute_ago:
ip_limits['minute'].popleft()
while ip_limits['hour'] and ip_limits['hour'][0] < hour_ago:
ip_limits['hour'].popleft()
# Check current counts
requests_per_minute = len(ip_limits['minute'])
requests_per_hour = len(ip_limits['hour'])
# Check if limits exceeded
exceeded = (requests_per_minute >= minute_limit) or (requests_per_hour >= hour_limit)
# Add current request to tracking
if not exceeded:
ip_limits['minute'].append(current_time)
ip_limits['hour'].append(current_time)
return RateLimitInfo(
auth_level=auth_level,
requests_per_minute=requests_per_minute,
requests_per_hour=requests_per_hour,
minute_limit=minute_limit,
hour_limit=hour_limit,
exceeded=exceeded
)
async def analyze_request(self, request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
"""Perform comprehensive security analysis on a request"""
start_time = time.time()
try:
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "")
path = str(request.url.path)
method = request.method
# Determine authentication level
auth_level = self.determine_auth_level(request, user_context)
# Check IP allowlist/blocklist first
if self.allowed_ips and client_ip not in self.allowed_ips:
threat = SecurityThreat(
threat_type="ip_not_allowed",
level=ThreatLevel.HIGH,
confidence=1.0,
description=f"IP {client_ip} not in allowlist",
source_ip=client_ip,
mitigation="Add IP to allowlist or remove IP restrictions"
)
return SecurityAnalysis(
is_threat=True,
threats=[threat],
risk_score=1.0,
recommendations=["Block request immediately"],
auth_level=auth_level,
rate_limit_exceeded=False,
should_block=True
)
if client_ip in self.blocked_ips:
threat = SecurityThreat(
threat_type="ip_blocked",
level=ThreatLevel.CRITICAL,
confidence=1.0,
description=f"IP {client_ip} is blocked",
source_ip=client_ip,
mitigation="Remove IP from blocklist if legitimate"
)
return SecurityAnalysis(
is_threat=True,
threats=[threat],
risk_score=1.0,
recommendations=["Block request immediately"],
auth_level=auth_level,
rate_limit_exceeded=False,
should_block=True
)
# Check rate limiting
rate_limit_info = self.check_rate_limit(client_ip, auth_level)
if rate_limit_info.exceeded:
self.stats['rate_limits_exceeded'] += 1
threat = SecurityThreat(
threat_type="rate_limit_exceeded",
level=ThreatLevel.MEDIUM,
confidence=0.9,
description=f"Rate limit exceeded for {auth_level.value}: {rate_limit_info.requests_per_minute}/min, {rate_limit_info.requests_per_hour}/hr",
source_ip=client_ip,
mitigation=f"Implement rate limiting, current limits: {rate_limit_info.minute_limit}/min, {rate_limit_info.hour_limit}/hr"
)
return SecurityAnalysis(
is_threat=True,
threats=[threat],
risk_score=0.7,
recommendations=[f"Rate limit exceeded for {auth_level.value} user"],
auth_level=auth_level,
rate_limit_exceeded=True,
should_block=True
)
# Skip threat detection if disabled
if not settings.API_THREAT_DETECTION_ENABLED:
return SecurityAnalysis(
is_threat=False,
threats=[],
risk_score=0.0,
recommendations=[],
auth_level=auth_level,
rate_limit_exceeded=False,
should_block=False
)
# Collect request data for threat analysis
query_params = str(request.query_params)
headers = dict(request.headers)
# Try to get body content safely
body_content = ""
try:
if hasattr(request, '_body') and request._body:
body_content = request._body.decode() if isinstance(request._body, bytes) else str(request._body)
except:
pass
threats = []
# Analyze for various threats
threats.extend(await self._detect_sql_injection(query_params, body_content, path, client_ip))
threats.extend(await self._detect_xss(query_params, body_content, headers, client_ip))
threats.extend(await self._detect_path_traversal(path, query_params, client_ip))
threats.extend(await self._detect_command_injection(query_params, body_content, client_ip))
threats.extend(await self._detect_suspicious_patterns(headers, user_agent, path, client_ip))
# Anomaly detection if enabled
if settings.API_ANOMALY_DETECTION_ENABLED:
anomaly = await self._detect_anomalies(client_ip, path, method, len(body_content))
if anomaly.is_anomaly and anomaly.severity > settings.API_SECURITY_ANOMALY_THRESHOLD:
threat = SecurityThreat(
threat_type=f"anomaly_{anomaly.anomaly_type}",
level=ThreatLevel.MEDIUM if anomaly.severity > 0.7 else ThreatLevel.LOW,
confidence=anomaly.severity,
description=f"Anomalous behavior detected: {anomaly.details}",
source_ip=client_ip,
user_agent=user_agent,
request_path=path
)
threats.append(threat)
# Calculate risk score
risk_score = self._calculate_risk_score(threats)
# Determine if request should be blocked
should_block = risk_score >= settings.API_SECURITY_RISK_THRESHOLD
# Generate recommendations
recommendations = self._generate_recommendations(threats, risk_score, auth_level)
# Update statistics
self._update_stats(threats, time.time() - start_time)
return SecurityAnalysis(
is_threat=len(threats) > 0,
threats=threats,
risk_score=risk_score,
recommendations=recommendations,
auth_level=auth_level,
rate_limit_exceeded=False,
should_block=should_block
)
except Exception as e:
logger.error(f"Error in threat analysis: {e}")
return SecurityAnalysis(
is_threat=False,
threats=[],
risk_score=0.0,
recommendations=["Error occurred during security analysis"],
auth_level=AuthLevel.AUTHENTICATED,
rate_limit_exceeded=False,
should_block=False
)
async def _detect_sql_injection(self, query_params: str, body_content: str, path: str, client_ip: str) -> List[SecurityThreat]:
"""Detect SQL injection attempts"""
threats = []
content_to_check = f"{query_params} {body_content} {path}".lower()
for pattern in self.compiled_sql_patterns:
if pattern.search(content_to_check):
threat = SecurityThreat(
threat_type="sql_injection",
level=ThreatLevel.HIGH,
confidence=0.85,
description="Potential SQL injection attempt detected",
source_ip=client_ip,
payload=pattern.pattern,
mitigation="Block request, sanitize input, use parameterized queries"
)
threats.append(threat)
break # Don't duplicate for multiple patterns
return threats
async def _detect_xss(self, query_params: str, body_content: str, headers: dict, client_ip: str) -> List[SecurityThreat]:
"""Detect XSS attempts"""
threats = []
content_to_check = f"{query_params} {body_content}".lower()
# Check headers for XSS
for header_name, header_value in headers.items():
content_to_check += f" {header_value}".lower()
for pattern in self.compiled_xss_patterns:
if pattern.search(content_to_check):
threat = SecurityThreat(
threat_type="xss",
level=ThreatLevel.HIGH,
confidence=0.80,
description="Potential XSS attack detected",
source_ip=client_ip,
payload=pattern.pattern,
mitigation="Block request, sanitize input, implement CSP headers"
)
threats.append(threat)
break
return threats
async def _detect_path_traversal(self, path: str, query_params: str, client_ip: str) -> List[SecurityThreat]:
"""Detect path traversal attempts"""
threats = []
content_to_check = f"{path} {query_params}".lower()
decoded_content = unquote(content_to_check)
for pattern in self.compiled_path_patterns:
if pattern.search(content_to_check) or pattern.search(decoded_content):
threat = SecurityThreat(
threat_type="path_traversal",
level=ThreatLevel.HIGH,
confidence=0.90,
description="Path traversal attempt detected",
source_ip=client_ip,
request_path=path,
mitigation="Block request, validate file paths, implement access controls"
)
threats.append(threat)
break
return threats
async def _detect_command_injection(self, query_params: str, body_content: str, client_ip: str) -> List[SecurityThreat]:
"""Detect command injection attempts"""
threats = []
content_to_check = f"{query_params} {body_content}".lower()
for pattern in self.compiled_cmd_patterns:
if pattern.search(content_to_check):
threat = SecurityThreat(
threat_type="command_injection",
level=ThreatLevel.CRITICAL,
confidence=0.95,
description="Command injection attempt detected",
source_ip=client_ip,
payload=pattern.pattern,
mitigation="Block request immediately, sanitize input, disable shell execution"
)
threats.append(threat)
break
return threats
async def _detect_suspicious_patterns(self, headers: dict, user_agent: str, path: str, client_ip: str) -> List[SecurityThreat]:
"""Detect suspicious patterns in headers and user agent"""
threats = []
# Check for suspicious user agents
ua_lower = user_agent.lower()
for pattern in self.compiled_ua_patterns:
if pattern.search(ua_lower):
threat = SecurityThreat(
threat_type="suspicious_user_agent",
level=ThreatLevel.HIGH,
confidence=0.85,
description=f"Suspicious user agent detected: {pattern.pattern}",
source_ip=client_ip,
user_agent=user_agent,
mitigation="Block request, monitor IP for further activity"
)
threats.append(threat)
break
# Check for suspicious headers
if "x-forwarded-for" in headers and "x-real-ip" in headers:
# Potential header manipulation
threat = SecurityThreat(
threat_type="header_manipulation",
level=ThreatLevel.LOW,
confidence=0.30,
description="Potential IP header manipulation detected",
source_ip=client_ip,
mitigation="Validate proxy headers, implement IP whitelisting"
)
threats.append(threat)
return threats
async def _detect_anomalies(self, client_ip: str, path: str, method: str, body_size: int) -> AnomalyDetection:
"""Detect anomalous behavior patterns"""
try:
# Request size anomaly
max_size = settings.API_MAX_REQUEST_BODY_SIZE
if body_size > max_size:
return AnomalyDetection(
is_anomaly=True,
anomaly_type="request_size",
severity=0.8,
details={"body_size": body_size, "threshold": max_size},
current_value=body_size,
baseline_value=max_size // 10
)
# Unusual endpoint access
if path.startswith("/admin") or path.startswith("/api/admin"):
return AnomalyDetection(
is_anomaly=True,
anomaly_type="sensitive_endpoint",
severity=0.6,
details={"path": path, "reason": "admin endpoint access"},
current_value=1.0,
baseline_value=0.0
)
# IP request frequency anomaly
current_time = time.time()
ip_requests = self.ip_history[client_ip]
# Clean old entries (last 5 minutes)
five_minutes_ago = current_time - 300
while ip_requests and ip_requests[0] < five_minutes_ago:
ip_requests.popleft()
ip_requests.append(current_time)
if len(ip_requests) > 100: # More than 100 requests in 5 minutes
return AnomalyDetection(
is_anomaly=True,
anomaly_type="request_frequency",
severity=0.7,
details={"requests_5min": len(ip_requests), "threshold": 100},
current_value=len(ip_requests),
baseline_value=10 # 10 requests baseline
)
return AnomalyDetection(
is_anomaly=False,
anomaly_type="none",
severity=0.0,
details={}
)
except Exception as e:
logger.error(f"Error in anomaly detection: {e}")
return AnomalyDetection(
is_anomaly=False,
anomaly_type="error",
severity=0.0,
details={"error": str(e)}
)
def _calculate_risk_score(self, threats: List[SecurityThreat]) -> float:
"""Calculate overall risk score based on threats"""
if not threats:
return 0.0
score = 0.0
for threat in threats:
level_multiplier = {
ThreatLevel.LOW: 0.25,
ThreatLevel.MEDIUM: 0.5,
ThreatLevel.HIGH: 0.75,
ThreatLevel.CRITICAL: 1.0
}
score += threat.confidence * level_multiplier.get(threat.level, 0.5)
# Normalize to 0-1 range
return min(score / len(threats), 1.0)
def _generate_recommendations(self, threats: List[SecurityThreat], risk_score: float, auth_level: AuthLevel) -> List[str]:
"""Generate security recommendations based on analysis"""
recommendations = []
if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
recommendations.append("CRITICAL: Block this request immediately")
elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
recommendations.append("HIGH: Consider blocking or rate limiting this IP")
elif risk_score > 0.4:
recommendations.append("MEDIUM: Monitor this IP closely")
threat_types = {threat.threat_type for threat in threats}
if "sql_injection" in threat_types:
recommendations.append("Implement parameterized queries and input validation")
if "xss" in threat_types:
recommendations.append("Implement Content Security Policy (CSP) headers")
if "command_injection" in threat_types:
recommendations.append("Disable shell execution and validate all inputs")
if "path_traversal" in threat_types:
recommendations.append("Implement proper file path validation and access controls")
if "rate_limit_exceeded" in threat_types:
recommendations.append(f"Rate limiting active for {auth_level.value} user")
if not recommendations:
recommendations.append("No immediate action required, continue monitoring")
return recommendations
def _update_stats(self, threats: List[SecurityThreat], analysis_time: float):
"""Update service statistics"""
self.stats['total_requests_analyzed'] += 1
self.stats['total_analysis_time'] += analysis_time
if threats:
self.stats['threats_detected'] += len(threats)
for threat in threats:
self.stats['threat_types'][threat.threat_type] += 1
self.stats['threat_levels'][threat.level.value] += 1
if threat.source_ip:
self.stats['attacking_ips'][threat.source_ip] += 1
def get_stats(self) -> Dict[str, Any]:
"""Get service statistics"""
avg_time = (self.stats['total_analysis_time'] / self.stats['total_requests_analyzed']
if self.stats['total_requests_analyzed'] > 0 else 0)
# Get top attacking IPs
top_ips = sorted(self.stats['attacking_ips'].items(), key=lambda x: x[1], reverse=True)[:10]
return {
"total_requests_analyzed": self.stats['total_requests_analyzed'],
"threats_detected": self.stats['threats_detected'],
"threats_blocked": self.stats['threats_blocked'],
"anomalies_detected": self.stats['anomalies_detected'],
"rate_limits_exceeded": self.stats['rate_limits_exceeded'],
"avg_analysis_time": avg_time,
"threat_types": dict(self.stats['threat_types']),
"threat_levels": dict(self.stats['threat_levels']),
"top_attacking_ips": top_ips,
"security_enabled": settings.API_SECURITY_ENABLED,
"threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
"rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED
}
# Global threat detection service instance
threat_detection_service = ThreatDetectionService()

View File

@@ -53,6 +53,14 @@ async def lifespan(app: FastAPI):
# Initialize config manager # Initialize config manager
await init_config_manager() await init_config_manager()
# Initialize LLM service (needed by RAG module)
from app.services.llm.service import llm_service
try:
await llm_service.initialize()
logger.info("LLM service initialized successfully")
except Exception as e:
logger.warning(f"LLM service initialization failed: {e}")
# Initialize analytics service # Initialize analytics service
init_analytics_service() init_analytics_service()

View File

@@ -1,371 +0,0 @@
"""
Rate limiting middleware
"""
import time
import redis
from typing import Dict, Optional
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import asyncio
from datetime import datetime, timedelta
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
class RateLimiter:
"""Rate limiting implementation using Redis"""
def __init__(self):
try:
self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
self.redis_client.ping() # Test connection
logger.info("Rate limiter initialized with Redis backend")
except Exception as e:
logger.warning(f"Redis not available for rate limiting: {e}")
self.redis_client = None
# Fall back to in-memory rate limiting
self.memory_store: Dict[str, Dict[str, float]] = {}
async def check_rate_limit(
self,
key: str,
limit: int,
window_seconds: int,
identifier: str = "default"
) -> tuple[bool, Dict[str, int]]:
"""
Check if request is within rate limit
Args:
key: Rate limiting key (e.g., IP address, API key)
limit: Maximum number of requests allowed
window_seconds: Time window in seconds
identifier: Additional identifier for the rate limit
Returns:
Tuple of (is_allowed, headers_dict)
"""
full_key = f"rate_limit:{identifier}:{key}"
current_time = int(time.time())
window_start = current_time - window_seconds
if self.redis_client:
return await self._check_redis_rate_limit(
full_key, limit, window_seconds, current_time, window_start
)
else:
return self._check_memory_rate_limit(
full_key, limit, window_seconds, current_time, window_start
)
async def _check_redis_rate_limit(
self,
key: str,
limit: int,
window_seconds: int,
current_time: int,
window_start: int
) -> tuple[bool, Dict[str, int]]:
"""Check rate limit using Redis"""
pipe = self.redis_client.pipeline()
# Remove old entries
pipe.zremrangebyscore(key, 0, window_start)
# Count current requests in window
pipe.zcard(key)
# Add current request
pipe.zadd(key, {str(current_time): current_time})
# Set expiration
pipe.expire(key, window_seconds + 1)
results = pipe.execute()
current_requests = results[1]
# Calculate remaining requests and reset time
remaining = max(0, limit - current_requests - 1)
reset_time = current_time + window_seconds
headers = {
"X-RateLimit-Limit": limit,
"X-RateLimit-Remaining": remaining,
"X-RateLimit-Reset": reset_time,
"X-RateLimit-Window": window_seconds
}
is_allowed = current_requests < limit
if not is_allowed:
logger.warning(f"Rate limit exceeded for key: {key}")
return is_allowed, headers
def _check_memory_rate_limit(
self,
key: str,
limit: int,
window_seconds: int,
current_time: int,
window_start: int
) -> tuple[bool, Dict[str, int]]:
"""Check rate limit using in-memory storage"""
if key not in self.memory_store:
self.memory_store[key] = {}
# Clean old entries
store = self.memory_store[key]
keys_to_remove = [k for k, v in store.items() if v < window_start]
for k in keys_to_remove:
del store[k]
current_requests = len(store)
# Calculate remaining requests and reset time
remaining = max(0, limit - current_requests - 1)
reset_time = current_time + window_seconds
headers = {
"X-RateLimit-Limit": limit,
"X-RateLimit-Remaining": remaining,
"X-RateLimit-Reset": reset_time,
"X-RateLimit-Window": window_seconds
}
is_allowed = current_requests < limit
if is_allowed:
# Add current request
store[str(current_time)] = current_time
else:
logger.warning(f"Rate limit exceeded for key: {key}")
return is_allowed, headers
# Global rate limiter instance
rate_limiter = RateLimiter()
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Rate limiting middleware for FastAPI"""
def __init__(self, app):
super().__init__(app)
self.rate_limiter = RateLimiter()
logger.info("RateLimitMiddleware initialized")
async def dispatch(self, request: Request, call_next):
"""Process request through rate limiting"""
# Skip rate limiting if disabled in settings
if not settings.API_RATE_LIMITING_ENABLED:
response = await call_next(request)
return response
# Skip rate limiting for all internal API endpoints (platform operations)
if request.url.path.startswith("/api-internal/v1/"):
response = await call_next(request)
return response
# Only apply rate limiting to privatemode.ai proxy endpoints (OpenAI-compatible API and LLM service)
# Skip for all other endpoints
if not (request.url.path.startswith("/api/v1/chat/completions") or
request.url.path.startswith("/api/v1/embeddings") or
request.url.path.startswith("/api/v1/models") or
request.url.path.startswith("/api/v1/llm/")):
response = await call_next(request)
return response
# Skip rate limiting for health checks and static files
if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
response = await call_next(request)
return response
# Get client IP
client_ip = request.client.host
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
client_ip = forwarded_for.split(",")[0].strip()
# Check for API key in headers
api_key = None
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
api_key = auth_header[7:]
elif request.headers.get("X-API-Key"):
api_key = request.headers.get("X-API-Key")
# Determine rate limiting strategy
headers = {}
is_allowed = True
if api_key:
# API key-based rate limiting
api_key_key = f"api_key:{api_key}"
# First check organization-wide limits (PrivateMode limits are org-wide)
org_key = "organization:privatemode"
# Check organization per-minute limit
org_allowed_minute, org_headers_minute = await self.rate_limiter.check_rate_limit(
org_key, settings.PRIVATEMODE_REQUESTS_PER_MINUTE, 60, "minute"
)
# Check organization per-hour limit
org_allowed_hour, org_headers_hour = await self.rate_limiter.check_rate_limit(
org_key, settings.PRIVATEMODE_REQUESTS_PER_HOUR, 3600, "hour"
)
# If organization limits are exceeded, return 429
if not (org_allowed_minute and org_allowed_hour):
logger.warning(f"Organization rate limit exceeded for {org_key}")
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"detail": "Organization rate limit exceeded"},
headers=org_headers_minute
)
# Then check per-API key limits
limit_per_minute = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE
limit_per_hour = settings.API_RATE_LIMIT_API_KEY_PER_HOUR
# Check per-minute limit
is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
api_key_key, limit_per_minute, 60, "minute"
)
# Check per-hour limit
is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
api_key_key, limit_per_hour, 3600, "hour"
)
is_allowed = is_allowed_minute and is_allowed_hour
headers = headers_minute # Use minute headers for response
else:
# IP-based rate limiting for unauthenticated requests
rate_limit_key = f"ip:{client_ip}"
# More restrictive limits for unauthenticated requests
limit_per_minute = 20 # Hardcoded for unauthenticated users
limit_per_hour = 100
# Check per-minute limit
is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
rate_limit_key, limit_per_minute, 60, "minute"
)
# Check per-hour limit
is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
rate_limit_key, limit_per_hour, 3600, "hour"
)
is_allowed = is_allowed_minute and is_allowed_hour
headers = headers_minute # Use minute headers for response
# If rate limit exceeded, return 429
if not is_allowed:
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={
"error": "RATE_LIMIT_EXCEEDED",
"message": "Rate limit exceeded. Please try again later.",
"details": {
"limit": headers["X-RateLimit-Limit"],
"reset_time": headers["X-RateLimit-Reset"]
}
},
headers={k: str(v) for k, v in headers.items()}
)
# Continue with request
response = await call_next(request)
# Add rate limit headers to response
for key, value in headers.items():
response.headers[key] = str(value)
return response
# Keep the old function for backward compatibility
async def rate_limit_middleware(request: Request, call_next):
"""Legacy function - use RateLimitMiddleware class instead"""
middleware = RateLimitMiddleware(None)
return await middleware.dispatch(request, call_next)
class RateLimitExceeded(HTTPException):
"""Exception raised when rate limit is exceeded"""
def __init__(self, limit: int, reset_time: int):
super().__init__(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"Rate limit exceeded. Limit: {limit}, Reset: {reset_time}"
)
# Decorator for applying rate limits to specific endpoints
def rate_limit(requests_per_minute: int = 60, requests_per_hour: int = 1000):
"""
Decorator to apply rate limiting to specific endpoints
Args:
requests_per_minute: Maximum requests per minute
requests_per_hour: Maximum requests per hour
"""
def decorator(func):
async def wrapper(*args, **kwargs):
# This would be implemented to work with FastAPI dependencies
# For now, this is a placeholder for endpoint-specific rate limiting
return await func(*args, **kwargs)
return wrapper
return decorator
# Helper functions for different rate limiting strategies
async def check_api_key_rate_limit(api_key: str, endpoint: str) -> bool:
"""Check rate limit for specific API key and endpoint"""
# This would lookup API key specific limits from database
# For now, using default limits
key = f"api_key:{api_key}:endpoint:{endpoint}"
is_allowed, _ = await rate_limiter.check_rate_limit(
key, limit=100, window_seconds=60, identifier="endpoint"
)
return is_allowed
async def check_user_rate_limit(user_id: str, action: str) -> bool:
"""Check rate limit for specific user and action"""
key = f"user:{user_id}:action:{action}"
is_allowed, _ = await rate_limiter.check_rate_limit(
key, limit=50, window_seconds=60, identifier="user_action"
)
return is_allowed
async def apply_burst_protection(key: str) -> bool:
"""Apply burst protection for high-frequency actions"""
# Allow burst of 10 requests in 10 seconds
is_allowed, _ = await rate_limiter.check_rate_limit(
key, limit=10, window_seconds=10, identifier="burst"
)
return is_allowed

View File

@@ -1,210 +0,0 @@
"""
Security middleware for request/response processing
"""
import json
import time
from typing import Callable, Optional, Dict, Any
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.config import settings
from app.core.logging import get_logger
from app.core.threat_detection import threat_detection_service, SecurityAnalysis
logger = get_logger(__name__)
class SecurityMiddleware(BaseHTTPMiddleware):
"""Security middleware for threat detection and request filtering - DISABLED"""
def __init__(self, app, enabled: bool = True):
super().__init__(app)
self.enabled = False # Force disable regardless of settings
logger.info("SecurityMiddleware initialized, enabled: False (DISABLED)")
async def dispatch(self, request: Request, call_next: Callable) -> Response:
"""Process request through security analysis - DISABLED"""
# Security disabled, always pass through
return await call_next(request)
def _should_skip_security(self, request: Request) -> bool:
"""Determine if security analysis should be skipped for this request"""
path = request.url.path
# Skip for health checks, authentication endpoints, and static assets
skip_paths = [
"/health",
"/metrics",
"/api/v1/docs",
"/api/v1/openapi.json",
"/api/v1/redoc",
"/favicon.ico",
"/api/v1/auth/register",
"/api/v1/auth/login",
"/api/v1/auth/refresh", # Allow refresh endpoint
"/api-internal/v1/auth/register",
"/api-internal/v1/auth/login",
"/api-internal/v1/auth/refresh", # Allow refresh endpoint for internal API
"/", # Root endpoint
]
# Skip for static file extensions
static_extensions = [".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".woff", ".woff2"]
return (
path in skip_paths or
any(path.endswith(ext) for ext in static_extensions) or
path.startswith("/static/")
)
def _has_valid_auth(self, request: Request) -> bool:
"""Check if request has valid authentication"""
# Check Authorization header
auth_header = request.headers.get("Authorization", "")
api_key_header = request.headers.get("X-API-Key", "")
# Has some form of auth token/key
return (
auth_header.startswith("Bearer ") and len(auth_header) > 7 or
len(api_key_header.strip()) > 0
)
def _create_block_response(self, analysis: SecurityAnalysis) -> JSONResponse:
"""Create response for blocked requests"""
# Determine status code based on threat type
status_code = 403 # Forbidden by default
# Critical threats get 403
for threat in analysis.threats:
if threat.threat_type in ["command_injection", "sql_injection"]:
status_code = 403
break
response_data = {
"error": "Security Policy Violation",
"message": "Request blocked due to security policy violation",
"risk_score": round(analysis.risk_score, 3),
"auth_level": analysis.auth_level.value,
"threat_count": len(analysis.threats),
"recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations
}
response = JSONResponse(
content=response_data,
status_code=status_code
)
return response
def _add_security_headers(self, response: Response) -> Response:
"""Add security headers to response"""
if not settings.API_SECURITY_HEADERS_ENABLED:
return response
# Standard security headers
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
# Only add HSTS for HTTPS
if hasattr(response, 'headers') and response.headers.get("X-Forwarded-Proto") == "https":
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
# Content Security Policy
if settings.API_CSP_HEADER:
response.headers["Content-Security-Policy"] = settings.API_CSP_HEADER
return response
def _add_security_metrics(self, response: Response, analysis: SecurityAnalysis, analysis_time: float) -> Response:
"""Add security metrics to response headers (for debugging/monitoring)"""
# Only add in debug mode or for admin users
if settings.APP_DEBUG:
response.headers["X-Security-Risk-Score"] = str(round(analysis.risk_score, 3))
response.headers["X-Security-Threats"] = str(len(analysis.threats))
response.headers["X-Security-Auth-Level"] = analysis.auth_level.value
response.headers["X-Security-Analysis-Time"] = f"{analysis_time*1000:.1f}ms"
return response
async def _log_security_event(self, request: Request, analysis: SecurityAnalysis):
"""Log security events for audit and monitoring"""
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "")
# Create security event log
event_data = {
"timestamp": analysis.timestamp.isoformat(),
"client_ip": client_ip,
"user_agent": user_agent,
"path": str(request.url.path),
"method": request.method,
"risk_score": round(analysis.risk_score, 3),
"auth_level": analysis.auth_level.value,
"threat_count": len(analysis.threats),
"rate_limit_exceeded": analysis.rate_limit_exceeded,
"should_block": analysis.should_block,
"threats": [
{
"type": threat.threat_type,
"level": threat.level.value,
"confidence": round(threat.confidence, 3),
"description": threat.description
}
for threat in analysis.threats[:5] # Limit to first 5 threats
],
"recommendations": analysis.recommendations
}
# Log at appropriate level based on risk
if analysis.should_block:
logger.warning(f"SECURITY_BLOCK: {json.dumps(event_data)}")
elif analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
logger.warning(f"SECURITY_WARNING: {json.dumps(event_data)}")
else:
logger.info(f"SECURITY_THREAT: {json.dumps(event_data)}")
def setup_security_middleware(app, enabled: bool = True) -> None:
"""Setup security middleware on FastAPI app"""
if enabled and settings.API_SECURITY_ENABLED:
app.add_middleware(SecurityMiddleware, enabled=enabled)
logger.info("Security middleware enabled")
else:
logger.info("Security middleware disabled")
# Helper functions for manual security checks
async def analyze_request_security(request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
"""Manually analyze request security (for use in route handlers)"""
return await threat_detection_service.analyze_request(request, user_context)
def get_security_stats() -> Dict[str, Any]:
"""Get security statistics"""
return threat_detection_service.get_stats()
def is_request_blocked(request: Request) -> bool:
"""Check if request was blocked by security analysis"""
if hasattr(request.state, 'security_analysis'):
return request.state.security_analysis.should_block
return False
def get_request_risk_score(request: Request) -> float:
"""Get risk score for request"""
if hasattr(request.state, 'security_analysis'):
return request.state.security_analysis.risk_score
return 0.0
def get_request_auth_level(request: Request) -> str:
"""Get authentication level for request"""
if hasattr(request.state, 'security_analysis'):
return request.state.security_analysis.auth_level.value
return "unknown"

View File

@@ -162,6 +162,7 @@ class DocumentProcessor:
async def _process_document(self, task: ProcessingTask) -> bool: async def _process_document(self, task: ProcessingTask) -> bool:
"""Process a single document""" """Process a single document"""
from datetime import datetime
from app.db.database import async_session_factory from app.db.database import async_session_factory
async with async_session_factory() as session: async with async_session_factory() as session:
try: try:
@@ -182,16 +183,24 @@ class DocumentProcessor:
document.status = ProcessingStatus.PROCESSING document.status = ProcessingStatus.PROCESSING
await session.commit() await session.commit()
# Get RAG module for processing (now includes content processing) # Get RAG module for processing
try: try:
from app.services.module_manager import module_manager # Import RAG module and initialize it properly
rag_module = module_manager.get_module('rag') from modules.rag.main import RAGModule
from app.core.config import settings
# Create and initialize RAG module instance
rag_module = RAGModule(settings)
init_result = await rag_module.initialize()
if not rag_module.enabled:
raise Exception("Failed to enable RAG module")
except Exception as e: except Exception as e:
logger.error(f"Failed to get RAG module: {e}") logger.error(f"Failed to get RAG module: {e}")
raise Exception(f"RAG module not available: {e}") raise Exception(f"RAG module not available: {e}")
if not rag_module: if not rag_module or not rag_module.enabled:
raise Exception("RAG module not available") raise Exception("RAG module not available or not enabled")
logger.info(f"RAG module loaded successfully for document {task.document_id}") logger.info(f"RAG module loaded successfully for document {task.document_id}")
@@ -204,23 +213,31 @@ class DocumentProcessor:
# Process with RAG module # Process with RAG module
logger.info(f"Starting document processing for document {task.document_id} with RAG module") logger.info(f"Starting document processing for document {task.document_id} with RAG module")
# Special handling for JSONL files - skip processing phase
if document.file_type == 'jsonl':
# For JSONL files, we don't need to process content here
# The optimized JSONL processor will handle everything during indexing
document.converted_content = f"JSONL file with {len(file_content)} bytes"
document.word_count = 0 # Will be updated during indexing
document.character_count = len(file_content)
document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"}
document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow()
logger.info(f"JSONL document {task.document_id} marked for optimized processing")
else:
# Standard processing for other file types
try: try:
# Add timeout to prevent hanging # Add timeout to prevent hanging
processed_doc = await asyncio.wait_for( processed_doc = await asyncio.wait_for(
rag_module.process_document( rag_module.process_document(
file_content, file_content,
document.original_filename, document.original_filename,
{} {"file_path": document.file_path}
), ),
timeout=300.0 # 5 minute timeout timeout=300.0 # 5 minute timeout
) )
logger.info(f"Document processing completed for document {task.document_id}") logger.info(f"Document processing completed for document {task.document_id}")
except asyncio.TimeoutError:
logger.error(f"Document processing timed out for document {task.document_id}")
raise Exception("Document processing timed out after 5 minutes")
except Exception as e:
logger.error(f"Document processing failed for document {task.document_id}: {e}")
raise
# Update document with processed content # Update document with processed content
document.converted_content = processed_doc.content document.converted_content = processed_doc.content
@@ -229,6 +246,12 @@ class DocumentProcessor:
document.document_metadata = processed_doc.metadata document.document_metadata = processed_doc.metadata
document.status = ProcessingStatus.PROCESSED document.status = ProcessingStatus.PROCESSED
document.processed_at = datetime.utcnow() document.processed_at = datetime.utcnow()
except asyncio.TimeoutError:
logger.error(f"Document processing timed out for document {task.document_id}")
raise Exception("Document processing timed out after 5 minutes")
except Exception as e:
logger.error(f"Document processing failed for document {task.document_id}: {e}")
raise
# Index in RAG system using same RAG module # Index in RAG system using same RAG module
if rag_module and document.converted_content: if rag_module and document.converted_content:
@@ -245,6 +268,49 @@ class DocumentProcessor:
} }
# Use the correct Qdrant collection name for this document # Use the correct Qdrant collection name for this document
# For JSONL files, we need to use the processed document flow
if document.file_type == 'jsonl':
# Create a ProcessedDocument for the JSONL processor
from app.modules.rag.main import ProcessedDocument
from datetime import datetime
import hashlib
# Calculate file hash
processed_at = datetime.utcnow()
file_hash = hashlib.md5(str(document.id).encode()).hexdigest()
processed_doc = ProcessedDocument(
id=str(document.id),
content="", # Will be filled by JSONL processor
extracted_text="", # Will be filled by JSONL processor
metadata={
**doc_metadata,
"file_path": document.file_path
},
original_filename=document.original_filename,
file_type=document.file_type,
mime_type=document.mime_type,
language=document.document_metadata.get('language', 'EN'),
word_count=0, # Will be updated during processing
sentence_count=0, # Will be updated during processing
entities=[],
keywords=[],
processing_time=0.0,
processed_at=processed_at,
file_hash=file_hash,
file_size=document.file_size
)
# The JSONL processor will read the original file
await asyncio.wait_for(
rag_module.index_processed_document(
processed_doc=processed_doc,
collection_name=document.collection.qdrant_collection_name
),
timeout=300.0 # 5 minute timeout for JSONL processing
)
else:
# Use standard indexing for other file types
await asyncio.wait_for( await asyncio.wait_for(
rag_module.index_document( rag_module.index_document(
content=document.converted_content, content=document.converted_content,
@@ -271,7 +337,9 @@ class DocumentProcessor:
except Exception as e: except Exception as e:
logger.error(f"Failed to index document {task.document_id} in RAG: {e}") logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
# Keep as processed even if indexing fails # Mark as error since indexing failed
document.status = ProcessingStatus.ERROR
document.processing_error = f"Indexing failed: {str(e)}"
# Don't raise the exception to avoid retries on indexing failures # Don't raise the exception to avoid retries on indexing failures
await session.commit() await session.commit()

View File

@@ -28,9 +28,19 @@ class EmbeddingService:
await llm_service.initialize() await llm_service.initialize()
# Test LLM service health # Test LLM service health
health_summary = llm_service.get_health_summary() if not llm_service._initialized:
if health_summary.get("service_status") != "healthy": logger.error("LLM service not initialized")
logger.error(f"LLM service unhealthy: {health_summary}") return False
# Check if PrivateMode provider is available
try:
provider_status = await llm_service.get_provider_status()
privatemode_status = provider_status.get("privatemode")
if not privatemode_status or privatemode_status.status != "healthy":
logger.error(f"PrivateMode provider not available: {privatemode_status}")
return False
except Exception as e:
logger.error(f"Failed to check provider status: {e}")
return False return False
self.initialized = True self.initialized = True
@@ -75,6 +85,12 @@ class EmbeddingService:
else: else:
truncated_text = text truncated_text = text
# Guard: skip empty inputs (validator rejects empty strings)
if not truncated_text.strip():
logger.debug("Empty input for embedding; using fallback vector")
batch_embeddings.append(self._generate_fallback_embedding(text))
continue
# Call LLM service embedding endpoint # Call LLM service embedding endpoint
from app.services.llm.service import llm_service from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest from app.services.llm.models import EmbeddingRequest

View File

@@ -25,9 +25,10 @@ class EnhancedEmbeddingService(EmbeddingService):
'requests_count': 0, 'requests_count': 0,
'window_start': time.time(), 'window_start': time.time(),
'window_size': 60, # 1 minute window 'window_size': 60, # 1 minute window
'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 60)), # Configurable 'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)), # Configurable
'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff 'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff
'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 0.5)), 'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)),
'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)),
'last_rate_limit_error': None 'last_rate_limit_error': None
} }
@@ -38,7 +39,7 @@ class EnhancedEmbeddingService(EmbeddingService):
if max_retries is None: if max_retries is None:
max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3)) max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))
batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 5)) batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3))
if not self.initialized: if not self.initialized:
logger.warning("Embedding service not initialized, using fallback") logger.warning("Embedding service not initialized, using fallback")
@@ -76,9 +77,6 @@ class EnhancedEmbeddingService(EmbeddingService):
# Make the request # Make the request
embeddings = await self._get_embeddings_batch_impl(texts) embeddings = await self._get_embeddings_batch_impl(texts)
# Update rate limit tracker on success
self._update_rate_limit_tracker(success=True)
return embeddings, True return embeddings, True
except Exception as e: except Exception as e:
@@ -120,6 +118,12 @@ class EnhancedEmbeddingService(EmbeddingService):
embeddings = [] embeddings = []
for text in texts: for text in texts:
# Respect rate limit before each request
while self._is_rate_limited():
delay = self._get_rate_limit_delay()
logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request")
await asyncio.sleep(delay)
# Truncate text if needed # Truncate text if needed
max_chars = 1600 max_chars = 1600
truncated_text = text[:max_chars] if len(text) > max_chars else text truncated_text = text[:max_chars] if len(text) > max_chars else text
@@ -145,6 +149,12 @@ class EnhancedEmbeddingService(EmbeddingService):
else: else:
raise ValueError("Invalid response structure") raise ValueError("Invalid response structure")
# Count this successful request and optionally delay between requests
self._update_rate_limit_tracker(success=True)
per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0)
if per_req_delay and per_req_delay > 0:
await asyncio.sleep(per_req_delay)
return embeddings return embeddings
def _is_rate_limited(self) -> bool: def _is_rate_limited(self) -> bool:

View File

@@ -16,6 +16,7 @@ from .models import ResilienceConfig
class ProviderConfig(BaseModel): class ProviderConfig(BaseModel):
"""Configuration for an LLM provider""" """Configuration for an LLM provider"""
name: str = Field(..., description="Provider name") name: str = Field(..., description="Provider name")
provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')")
enabled: bool = Field(True, description="Whether provider is enabled") enabled: bool = Field(True, description="Whether provider is enabled")
base_url: str = Field(..., description="Provider base URL") base_url: str = Field(..., description="Provider base URL")
api_key_env_var: str = Field(..., description="Environment variable for API key") api_key_env_var: str = Field(..., description="Environment variable for API key")
@@ -53,9 +54,6 @@ class LLMServiceConfig(BaseModel):
enable_security_checks: bool = Field(True, description="Enable security validation") enable_security_checks: bool = Field(True, description="Enable security validation")
enable_metrics_collection: bool = Field(True, description="Enable metrics collection") enable_metrics_collection: bool = Field(True, description="Enable metrics collection")
# Security settings
security_risk_threshold: float = Field(0.8, ge=0.0, le=1.0, description="Risk threshold for blocking")
security_warning_threshold: float = Field(0.6, ge=0.0, le=1.0, description="Risk threshold for warnings")
max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length") max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
max_response_length: int = Field(32000, ge=1000, description="Maximum response length") max_response_length: int = Field(32000, ge=1000, description="Maximum response length")
@@ -78,12 +76,6 @@ class LLMServiceConfig(BaseModel):
# Model routing (model_name -> provider_name) # Model routing (model_name -> provider_name)
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing") model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
@validator('security_risk_threshold')
def validate_risk_threshold(cls, v, values):
warning_threshold = values.get('security_warning_threshold', 0.6)
if v <= warning_threshold:
raise ValueError("Risk threshold must be greater than warning threshold")
return v
def create_default_config() -> LLMServiceConfig: def create_default_config() -> LLMServiceConfig:
@@ -93,6 +85,7 @@ def create_default_config() -> LLMServiceConfig:
# Models will be fetched dynamically from proxy /models endpoint # Models will be fetched dynamically from proxy /models endpoint
privatemode_config = ProviderConfig( privatemode_config = ProviderConfig(
name="privatemode", name="privatemode",
provider_type="privatemode",
enabled=True, enabled=True,
base_url=settings.PRIVATEMODE_PROXY_URL, base_url=settings.PRIVATEMODE_PROXY_URL,
api_key_env_var="PRIVATEMODE_API_KEY", api_key_env_var="PRIVATEMODE_API_KEY",
@@ -119,9 +112,6 @@ def create_default_config() -> LLMServiceConfig:
config = LLMServiceConfig( config = LLMServiceConfig(
default_provider="privatemode", default_provider="privatemode",
enable_detailed_logging=settings.LOG_LLM_PROMPTS, enable_detailed_logging=settings.LOG_LLM_PROMPTS,
enable_security_checks=settings.API_SECURITY_ENABLED,
security_risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
security_warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
providers={ providers={
"privatemode": privatemode_config "privatemode": privatemode_config
}, },

View File

@@ -124,7 +124,6 @@ class MetricsCollector:
total_requests = len(self._metrics) total_requests = len(self._metrics)
successful_requests = sum(1 for m in self._metrics if m.success) successful_requests = sum(1 for m in self._metrics if m.success)
failed_requests = total_requests - successful_requests failed_requests = total_requests - successful_requests
security_blocked = sum(1 for m in self._metrics if not m.success and m.security_risk_score > 0.8)
# Calculate averages # Calculate averages
latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0] latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0]
@@ -143,7 +142,6 @@ class MetricsCollector:
total_requests=total_requests, total_requests=total_requests,
successful_requests=successful_requests, successful_requests=successful_requests,
failed_requests=failed_requests, failed_requests=failed_requests,
security_blocked_requests=security_blocked,
average_latency_ms=avg_latency, average_latency_ms=avg_latency,
average_risk_score=avg_risk_score, average_risk_score=avg_risk_score,
provider_metrics=provider_metrics, provider_metrics=provider_metrics,

View File

@@ -157,7 +157,6 @@ class LLMMetrics(BaseModel):
total_requests: int = Field(0, description="Total requests processed") total_requests: int = Field(0, description="Total requests processed")
successful_requests: int = Field(0, description="Successful requests") successful_requests: int = Field(0, description="Successful requests")
failed_requests: int = Field(0, description="Failed requests") failed_requests: int = Field(0, description="Failed requests")
security_blocked_requests: int = Field(0, description="Security blocked requests")
average_latency_ms: float = Field(0.0, description="Average response latency") average_latency_ms: float = Field(0.0, description="Average response latency")
average_risk_score: float = Field(0.0, description="Average security risk score") average_risk_score: float = Field(0.0, description="Average security risk score")
provider_metrics: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-provider metrics") provider_metrics: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-provider metrics")

View File

@@ -452,6 +452,8 @@ class PrivateModeProvider(BaseLLMProvider):
else: else:
error_text = await response.text() error_text = await response.text()
# Log the detailed error response from the provider
logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
self._handle_http_error(response.status, error_text, "embeddings") self._handle_http_error(response.status, error_text, "embeddings")
except aiohttp.ClientError as e: except aiohttp.ClientError as e:

View File

@@ -1,325 +0,0 @@
"""
LLM Security Manager
Handles prompt injection detection and audit logging.
Provides comprehensive security for LLM interactions.
"""
import os
import re
import json
import logging
import hashlib
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from app.core.config import settings
logger = logging.getLogger(__name__)
class SecurityManager:
"""Manages security for LLM operations"""
def __init__(self):
self._setup_prompt_injection_patterns()
def _setup_prompt_injection_patterns(self):
"""Setup patterns for prompt injection detection"""
self.injection_patterns = [
# Direct instruction injection
r"(?i)(ignore|forget|disregard|override).{0,20}(instructions|rules|prompts)",
r"(?i)(new|updated|different)\s+(instructions|rules|system)",
r"(?i)act\s+as\s+(if|though)\s+you\s+(are|were)",
r"(?i)pretend\s+(to\s+be|you\s+are)",
r"(?i)you\s+are\s+now\s+(a|an)\s+",
# System role manipulation
r"(?i)system\s*:\s*",
r"(?i)\[system\]",
r"(?i)<system>",
r"(?i)assistant\s*:\s*",
r"(?i)\[assistant\]",
# Escape attempts
r"(?i)\\n\\n#+",
r"(?i)```\s*(system|assistant|user)",
r"(?i)---\s*(new|system|override)",
# Role manipulation
r"(?i)(you|your)\s+(role|purpose|function)\s+(is|has\s+changed)",
r"(?i)switch\s+to\s+(admin|developer|debug)\s+mode",
r"(?i)(admin|root|sudo|developer)\s+(access|mode|privileges)",
# Information extraction attempts
r"(?i)(show|display|reveal|expose)\s+(your|the)\s+(prompt|instructions|system)",
r"(?i)what\s+(are|were)\s+your\s+(original|initial)\s+(instructions|prompts)",
r"(?i)(debug|verbose|diagnostic)\s+mode",
# Encoding/obfuscation attempts
r"(?i)base64\s*:",
r"(?i)hex\s*:",
r"(?i)unicode\s*:",
r"(?i)\b[A-Za-z0-9+/]{40,}={0,2}\b", # More specific base64 pattern (longer sequences)
# SQL injection patterns (more specific to reduce false positives)
r"(?i)(union\s+select|select\s+\*|insert\s+into|update\s+\w+\s+set|delete\s+from|drop\s+table|create\s+table)\s",
r"(?i)(or|and)\s+\d+\s*=\s*\d+",
r"(?i)';?\s*(drop\s+table|delete\s+from|insert\s+into)",
# Command injection patterns
r"(?i)(exec|eval|system|shell|cmd)\s*\(",
r"(?i)(\$\(|\`)[^)]+(\)|\`)",
r"(?i)&&\s*(rm|del|format)",
# Jailbreak attempts
r"(?i)jailbreak",
r"(?i)break\s+out\s+of",
r"(?i)escape\s+(the|your)\s+(rules|constraints)",
r"(?i)(DAN|Do\s+Anything\s+Now)",
r"(?i)unrestricted\s+mode",
]
self.compiled_patterns = [re.compile(pattern) for pattern in self.injection_patterns]
logger.info(f"Initialized {len(self.injection_patterns)} prompt injection patterns")
def validate_prompt_security(self, messages: List[Dict[str, str]]) -> Tuple[bool, float, List[str]]:
"""
Validate messages for prompt injection attempts
Returns:
Tuple[bool, float, List[str]]: (is_safe, risk_score, detected_patterns)
"""
detected_patterns = []
total_risk = 0.0
# Check if this is a system/RAG request
is_system_request = self._is_system_request(messages)
for message in messages:
content = message.get("content", "")
if not content:
continue
# Check against injection patterns with context awareness
for i, pattern in enumerate(self.compiled_patterns):
matches = pattern.findall(content)
if matches:
# Apply context-aware risk calculation
pattern_risk = self._calculate_pattern_risk(i, matches, message.get("role", "user"), is_system_request)
total_risk += pattern_risk
detected_patterns.append({
"pattern_index": i,
"pattern": self.injection_patterns[i],
"matches": matches,
"risk": pattern_risk
})
# Additional security checks with context awareness
total_risk += self._check_message_characteristics(content, message.get("role", "user"), is_system_request)
# Normalize risk score (0.0 to 1.0)
risk_score = min(total_risk / len(messages) if messages else 0.0, 1.0)
# Never block - always return True for is_safe
is_safe = True
if detected_patterns:
logger.info(f"Detected {len(detected_patterns)} potential injection patterns, risk score: {risk_score} (system_request: {is_system_request})")
return is_safe, risk_score, detected_patterns
def _calculate_pattern_risk(self, pattern_index: int, matches: List, role: str, is_system_request: bool) -> float:
"""Calculate risk score for a detected pattern with context awareness"""
# Different patterns have different risk levels
high_risk_patterns = [0, 1, 2, 3, 4, 5, 6, 7, 22, 23, 24] # System manipulation, jailbreak
medium_risk_patterns = [8, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21] # Escape attempts, info extraction
# Base risk score
base_risk = 0.8 if pattern_index in high_risk_patterns else 0.5 if pattern_index in medium_risk_patterns else 0.3
# Apply context-specific risk reduction
if is_system_request or role == "system":
# Reduce risk for system messages and RAG content
if pattern_index in [14, 15, 16]: # Encoding patterns (base64, hex, unicode)
base_risk *= 0.2 # Reduce encoding risk by 80% for system content
elif pattern_index in [17, 18, 19]: # SQL patterns
base_risk *= 0.3 # Reduce SQL risk by 70% for system content
else:
base_risk *= 0.6 # Reduce other risks by 40% for system content
# Increase risk based on number of matches, but cap it
match_multiplier = min(1.0 + (len(matches) - 1) * 0.1, 1.5) # Reduced multiplier
return base_risk * match_multiplier
def _check_message_characteristics(self, content: str, role: str, is_system_request: bool) -> float:
"""Check message characteristics for additional risk factors with context awareness"""
risk = 0.0
# Excessive length (potential stuffing attack) - less restrictive for system content
length_threshold = 50000 if is_system_request else 10000 # Much higher threshold for system content
if len(content) > length_threshold:
risk += 0.1 if is_system_request else 0.3
# High ratio of special characters - more lenient for system content
special_chars = sum(1 for c in content if not c.isalnum() and not c.isspace())
if len(content) > 0:
char_ratio = special_chars / len(content)
threshold = 0.8 if is_system_request else 0.5
if char_ratio > threshold:
risk += 0.2 if is_system_request else 0.4
# Multiple encoding indicators - reduced risk for system content
encoding_indicators = ["base64", "hex", "unicode", "url", "ascii"]
found_encodings = sum(1 for indicator in encoding_indicators if indicator.lower() in content.lower())
if found_encodings > 1:
risk += 0.1 if is_system_request else 0.3
# Excessive newlines or formatting - more lenient for system content
newline_threshold = 200 if is_system_request else 50
if content.count('\n') > newline_threshold or content.count('\\n') > newline_threshold:
risk += 0.1 if is_system_request else 0.2
return risk
def _is_system_request(self, messages: List[Dict[str, str]]) -> bool:
"""Determine if this is a system/RAG request"""
if not messages:
return False
# Check for system messages
for message in messages:
if message.get("role") == "system":
return True
# Check message content for RAG indicators
for message in messages:
content = message.get("content", "")
if ("document:" in content.lower() or
"context:" in content.lower() or
"source:" in content.lower() or
"retrieved:" in content.lower() or
"citation:" in content.lower() or
"reference:" in content.lower()):
return True
return False
def create_audit_log(
self,
user_id: str,
api_key_id: int,
provider: str,
model: str,
request_type: str,
risk_score: float,
detected_patterns: List[str],
metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""Create comprehensive audit log for LLM request"""
audit_entry = {
"timestamp": datetime.utcnow().isoformat(),
"user_id": user_id,
"api_key_id": api_key_id,
"provider": provider,
"model": model,
"request_type": request_type,
"security": {
"risk_score": risk_score,
"detected_patterns": detected_patterns,
"security_check_passed": risk_score < settings.API_SECURITY_RISK_THRESHOLD
},
"metadata": metadata or {},
"audit_hash": None # Will be set below
}
# Create hash for audit integrity
audit_hash = self._create_audit_hash(audit_entry)
audit_entry["audit_hash"] = audit_hash
# Log based on risk level (never block, only log)
if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
logger.warning(f"HIGH RISK LLM REQUEST DETECTED (NOT BLOCKED): {json.dumps(audit_entry)}")
elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
logger.info(f"MEDIUM RISK LLM REQUEST: {json.dumps(audit_entry)}")
else:
logger.info(f"LLM REQUEST AUDIT: user={user_id}, model={model}, risk={risk_score:.3f}")
return audit_entry
def _create_audit_hash(self, audit_entry: Dict[str, Any]) -> str:
"""Create hash for audit trail integrity"""
# Create hash from key fields (excluding the hash itself)
hash_data = {
"timestamp": audit_entry["timestamp"],
"user_id": audit_entry["user_id"],
"api_key_id": audit_entry["api_key_id"],
"provider": audit_entry["provider"],
"model": audit_entry["model"],
"request_type": audit_entry["request_type"],
"risk_score": audit_entry["security"]["risk_score"]
}
hash_string = json.dumps(hash_data, sort_keys=True)
return hashlib.sha256(hash_string.encode()).hexdigest()
def log_detailed_request(
self,
messages: List[Dict[str, str]],
model: str,
user_id: str,
provider: str,
context_info: Optional[Dict[str, Any]] = None
):
"""Log detailed LLM request if LOG_LLM_PROMPTS is enabled"""
if not settings.LOG_LLM_PROMPTS:
return
logger.info("=== DETAILED LLM REQUEST ===")
logger.info(f"Model: {model}")
logger.info(f"Provider: {provider}")
logger.info(f"User ID: {user_id}")
if context_info:
for key, value in context_info.items():
logger.info(f"{key}: {value}")
logger.info("Messages to LLM:")
for i, message in enumerate(messages):
role = message.get("role", "unknown")
content = message.get("content", "")[:500] # Truncate for logging
logger.info(f" Message {i+1} [{role}]: {content}{'...' if len(message.get('content', '')) > 500 else ''}")
logger.info("=== END DETAILED LLM REQUEST ===")
def log_detailed_response(
self,
response_content: str,
token_usage: Optional[Dict[str, int]] = None,
provider: str = "unknown"
):
"""Log detailed LLM response if LOG_LLM_PROMPTS is enabled"""
if not settings.LOG_LLM_PROMPTS:
return
logger.info("=== DETAILED LLM RESPONSE ===")
logger.info(f"Provider: {provider}")
logger.info(f"Response content: {response_content[:500]}{'...' if len(response_content) > 500 else ''}")
if token_usage:
logger.info(f"Token usage - Prompt: {token_usage.get('prompt_tokens', 0)}, "
f"Completion: {token_usage.get('completion_tokens', 0)}, "
f"Total: {token_usage.get('total_tokens', 0)}")
logger.info("=== END DETAILED LLM RESPONSE ===")
class SecurityError(Exception):
"""Security-related errors in LLM operations"""
pass
# Global security manager instance
security_manager = SecurityManager()

View File

@@ -17,9 +17,8 @@ from .models import (
) )
from .config import config_manager, ProviderConfig from .config import config_manager, ProviderConfig
from ...core.config import settings from ...core.config import settings
from .security import security_manager
from .resilience import ResilienceManagerFactory from .resilience import ResilienceManagerFactory
from .metrics import metrics_collector # from .metrics import metrics_collector
from .providers import BaseLLMProvider, PrivateModeProvider from .providers import BaseLLMProvider, PrivateModeProvider
from .exceptions import ( from .exceptions import (
LLMError, ProviderError, SecurityError, ConfigurationError, LLMError, ProviderError, SecurityError, ConfigurationError,
@@ -150,45 +149,8 @@ class LLMService:
if not request.messages: if not request.messages:
raise ValidationError("Messages cannot be empty", field="messages") raise ValidationError("Messages cannot be empty", field="messages")
# Security validation (only if enabled) # Security validation disabled - always allow requests
messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages] risk_score = 0.0
if settings.API_SECURITY_ENABLED:
is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
else:
# Security disabled - always safe
is_safe, risk_score, detected_patterns = True, 0.0, []
if not is_safe:
# Log security violation
security_manager.create_audit_log(
user_id=request.user_id,
api_key_id=request.api_key_id,
provider="blocked",
model=request.model,
request_type="chat_completion",
risk_score=risk_score,
detected_patterns=[p.get("pattern", "") for p in detected_patterns]
)
# Record blocked request
metrics_collector.record_request(
provider="security",
model=request.model,
request_type="chat_completion",
success=False,
latency_ms=0,
security_risk_score=risk_score,
error_code="SECURITY_BLOCKED",
user_id=request.user_id,
api_key_id=request.api_key_id
)
raise SecurityError(
"Request blocked due to security concerns",
risk_score=risk_score,
details={"detected_patterns": detected_patterns}
)
# Get provider for model # Get provider for model
provider_name = self._get_provider_for_model(request.model) provider_name = self._get_provider_for_model(request.model)
@@ -197,18 +159,7 @@ class LLMService:
if not provider: if not provider:
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name) raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
# Log detailed request if enabled # Security logging disabled
security_manager.log_detailed_request(
messages=messages_dict,
model=request.model,
user_id=request.user_id,
provider=provider_name,
context_info={
"temperature": request.temperature,
"max_tokens": request.max_tokens,
"risk_score": f"{risk_score:.3f}"
}
)
# Execute with resilience # Execute with resilience
resilience_manager = ResilienceManagerFactory.get_manager(provider_name) resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
@@ -222,85 +173,46 @@ class LLMService:
non_retryable_exceptions=(SecurityError, ValidationError) non_retryable_exceptions=(SecurityError, ValidationError)
) )
# Update response with security information # Security features disabled
response.security_check = is_safe
response.risk_score = risk_score
response.detected_patterns = [p.get("pattern", "") for p in detected_patterns]
# Log detailed response if enabled # Security logging disabled
if response.choices:
content = response.choices[0].message.content
security_manager.log_detailed_response(
response_content=content,
token_usage=response.usage.model_dump() if response.usage else None,
provider=provider_name
)
# Record successful request # Record successful request - metrics disabled
total_latency = (time.time() - start_time) * 1000 total_latency = (time.time() - start_time) * 1000
metrics_collector.record_request( # metrics_collector.record_request(
provider=provider_name, # provider=provider_name,
model=request.model, # model=request.model,
request_type="chat_completion", # request_type="chat_completion",
success=True, # success=True,
latency_ms=total_latency, # latency_ms=total_latency,
token_usage=response.usage.model_dump() if response.usage else None, # token_usage=response.usage.model_dump() if response.usage else None,
security_risk_score=risk_score, # security_risk_score=risk_score,
user_id=request.user_id, # user_id=request.user_id,
api_key_id=request.api_key_id # api_key_id=request.api_key_id
) # )
# Create audit log # Security audit logging disabled
security_manager.create_audit_log(
user_id=request.user_id,
api_key_id=request.api_key_id,
provider=provider_name,
model=request.model,
request_type="chat_completion",
risk_score=risk_score,
detected_patterns=[p.get("pattern", "") for p in detected_patterns],
metadata={
"success": True,
"latency_ms": total_latency,
"token_usage": response.usage.model_dump() if response.usage else None
}
)
return response return response
except Exception as e: except Exception as e:
# Record failed request # Record failed request - metrics disabled
total_latency = (time.time() - start_time) * 1000 total_latency = (time.time() - start_time) * 1000
error_code = getattr(e, 'error_code', e.__class__.__name__) error_code = getattr(e, 'error_code', e.__class__.__name__)
metrics_collector.record_request( # metrics_collector.record_request(
provider=provider_name, # provider=provider_name,
model=request.model, # model=request.model,
request_type="chat_completion", # request_type="chat_completion",
success=False, # success=False,
latency_ms=total_latency, # latency_ms=total_latency,
security_risk_score=risk_score, # security_risk_score=risk_score,
error_code=error_code, # error_code=error_code,
user_id=request.user_id, # user_id=request.user_id,
api_key_id=request.api_key_id # api_key_id=request.api_key_id
) # )
# Create audit log for failure # Security audit logging disabled
security_manager.create_audit_log(
user_id=request.user_id,
api_key_id=request.api_key_id,
provider=provider_name,
model=request.model,
request_type="chat_completion",
risk_score=risk_score,
detected_patterns=[p.get("pattern", "") for p in detected_patterns],
metadata={
"success": False,
"error": str(e),
"error_code": error_code,
"latency_ms": total_latency
}
)
raise raise
@@ -309,21 +221,8 @@ class LLMService:
if not self._initialized: if not self._initialized:
await self.initialize() await self.initialize()
# Security validation (same as non-streaming) # Security validation disabled - always allow streaming requests
messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages] risk_score = 0.0
if settings.API_SECURITY_ENABLED:
is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
else:
# Security disabled - always safe
is_safe, risk_score, detected_patterns = True, 0.0, []
if not is_safe:
raise SecurityError(
"Streaming request blocked due to security concerns",
risk_score=risk_score,
details={"detected_patterns": detected_patterns}
)
# Get provider # Get provider
provider_name = self._get_provider_for_model(request.model) provider_name = self._get_provider_for_model(request.model)
@@ -345,19 +244,19 @@ class LLMService:
yield chunk yield chunk
except Exception as e: except Exception as e:
# Record streaming failure # Record streaming failure - metrics disabled
error_code = getattr(e, 'error_code', e.__class__.__name__) error_code = getattr(e, 'error_code', e.__class__.__name__)
metrics_collector.record_request( # metrics_collector.record_request(
provider=provider_name, # provider=provider_name,
model=request.model, # model=request.model,
request_type="chat_completion_stream", # request_type="chat_completion_stream",
success=False, # success=False,
latency_ms=0, # latency_ms=0,
security_risk_score=risk_score, # security_risk_score=risk_score,
error_code=error_code, # error_code=error_code,
user_id=request.user_id, # user_id=request.user_id,
api_key_id=request.api_key_id # api_key_id=request.api_key_id
) # )
raise raise
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse: async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
@@ -365,23 +264,8 @@ class LLMService:
if not self._initialized: if not self._initialized:
await self.initialize() await self.initialize()
# Security validation for embedding input # Security validation disabled - always allow embedding requests
input_text = request.input if isinstance(request.input, str) else " ".join(request.input) risk_score = 0.0
if settings.API_SECURITY_ENABLED:
is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security([
{"role": "user", "content": input_text}
])
else:
# Security disabled - always safe
is_safe, risk_score, detected_patterns = True, 0.0, []
if not is_safe:
raise SecurityError(
"Embedding request blocked due to security concerns",
risk_score=risk_score,
details={"detected_patterns": detected_patterns}
)
# Get provider # Get provider
provider_name = self._get_provider_for_model(request.model) provider_name = self._get_provider_for_model(request.model)
@@ -402,42 +286,40 @@ class LLMService:
non_retryable_exceptions=(SecurityError, ValidationError) non_retryable_exceptions=(SecurityError, ValidationError)
) )
# Update response with security information # Security features disabled
response.security_check = is_safe
response.risk_score = risk_score
# Record successful request # Record successful request - metrics disabled
total_latency = (time.time() - start_time) * 1000 total_latency = (time.time() - start_time) * 1000
metrics_collector.record_request( # metrics_collector.record_request(
provider=provider_name, # provider=provider_name,
model=request.model, # model=request.model,
request_type="embedding", # request_type="embedding",
success=True, # success=True,
latency_ms=total_latency, # latency_ms=total_latency,
token_usage=response.usage.model_dump() if response.usage else None, # token_usage=response.usage.model_dump() if response.usage else None,
security_risk_score=risk_score, # security_risk_score=risk_score,
user_id=request.user_id, # user_id=request.user_id,
api_key_id=request.api_key_id # api_key_id=request.api_key_id
) # )
return response return response
except Exception as e: except Exception as e:
# Record failed request # Record failed request - metrics disabled
total_latency = (time.time() - start_time) * 1000 total_latency = (time.time() - start_time) * 1000
error_code = getattr(e, 'error_code', e.__class__.__name__) error_code = getattr(e, 'error_code', e.__class__.__name__)
metrics_collector.record_request( # metrics_collector.record_request(
provider=provider_name, # provider=provider_name,
model=request.model, # model=request.model,
request_type="embedding", # request_type="embedding",
success=False, # success=False,
latency_ms=total_latency, # latency_ms=total_latency,
security_risk_score=risk_score, # security_risk_score=risk_score,
error_code=error_code, # error_code=error_code,
user_id=request.user_id, # user_id=request.user_id,
api_key_id=request.api_key_id # api_key_id=request.api_key_id
) # )
raise raise
@@ -492,12 +374,18 @@ class LLMService:
return status_dict return status_dict
def get_metrics(self) -> LLMMetrics: def get_metrics(self) -> LLMMetrics:
"""Get service metrics""" """Get service metrics - metrics disabled"""
return metrics_collector.get_metrics() # return metrics_collector.get_metrics()
return LLMMetrics(
total_requests=0,
success_rate=0.0,
avg_latency_ms=0,
error_rates={}
)
def get_health_summary(self) -> Dict[str, Any]: def get_health_summary(self) -> Dict[str, Any]:
"""Get comprehensive health summary""" """Get comprehensive health summary - metrics disabled"""
metrics_health = metrics_collector.get_health_summary() # metrics_health = metrics_collector.get_health_summary()
resilience_health = ResilienceManagerFactory.get_all_health_status() resilience_health = ResilienceManagerFactory.get_all_health_status()
return { return {
@@ -505,7 +393,7 @@ class LLMService:
"startup_time": self._startup_time.isoformat() if self._startup_time else None, "startup_time": self._startup_time.isoformat() if self._startup_time else None,
"provider_count": len(self._providers), "provider_count": len(self._providers),
"active_providers": list(self._providers.keys()), "active_providers": list(self._providers.keys()),
"metrics": metrics_health, "metrics": {"status": "disabled"},
"resilience": resilience_health "resilience": resilience_health
} }

View File

@@ -1,153 +0,0 @@
"""
Token-based rate limiting for LLM service
"""
import time
import redis
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
from ..core.config import settings
from ..core.logging import get_logger
logger = get_logger(__name__)
class TokenRateLimiter:
"""Token-based rate limiting implementation"""
def __init__(self):
try:
self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
self.redis_client.ping()
logger.info("Token rate limiter initialized with Redis backend")
except Exception as e:
logger.warning(f"Redis not available for token rate limiting: {e}")
self.redis_client = None
# Fall back to in-memory rate limiting
self.in_memory_store = {}
logger.info("Token rate limiter using in-memory fallback")
async def check_token_limits(
self,
provider: str,
prompt_tokens: int,
completion_tokens: int = 0
) -> Tuple[bool, Dict[str, str]]:
"""
Check if token usage is within limits
Args:
provider: Provider name (e.g., "privatemode")
prompt_tokens: Number of prompt tokens to use
completion_tokens: Number of completion tokens to use
Returns:
Tuple of (is_allowed, headers)
"""
# Get token limits from configuration
from .config import get_config
config = get_config()
token_limits = config.token_limits_per_minute
# Check organization-wide limits
org_key = f"tokens:org:{provider}"
# Get current usage
current_usage = await self._get_token_usage(org_key)
# Calculate new usage
new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens
new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens
# Check limits
prompt_limit = token_limits.get("prompt_tokens", 20000)
completion_limit = token_limits.get("completion_tokens", 10000)
is_allowed = (
new_prompt_tokens <= prompt_limit and
new_completion_tokens <= completion_limit
)
if is_allowed:
# Update usage
await self._update_token_usage(org_key, prompt_tokens, completion_tokens)
logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, "
f"{new_completion_tokens}/{completion_limit} completion")
# Calculate remaining tokens
remaining_prompt = max(0, prompt_limit - new_prompt_tokens)
remaining_completion = max(0, completion_limit - new_completion_tokens)
# Create headers
headers = {
"X-TokenLimit-Prompt-Remaining": str(remaining_prompt),
"X-TokenLimit-Completion-Remaining": str(remaining_completion),
"X-TokenLimit-Prompt-Limit": str(prompt_limit),
"X-TokenLimit-Completion-Limit": str(completion_limit),
"X-TokenLimit-Reset": str(int(time.time() + 60)) # Reset in 1 minute
}
if not is_allowed:
logger.warning(f"Token rate limit exceeded for {provider}. "
f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. "
f"Current: {current_usage}")
return is_allowed, headers
async def _get_token_usage(self, key: str) -> Dict[str, int]:
"""Get current token usage"""
if self.redis_client:
try:
data = self.redis_client.hgetall(key)
if data:
return {
"prompt_tokens": int(data.get("prompt_tokens", 0)),
"completion_tokens": int(data.get("completion_tokens", 0)),
"updated_at": float(data.get("updated_at", time.time()))
}
except Exception as e:
logger.error(f"Error getting token usage from Redis: {e}")
# Fallback to in-memory
return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0})
async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int):
"""Update token usage"""
if self.redis_client:
try:
pipe = self.redis_client.pipeline()
pipe.hincrby(key, "prompt_tokens", prompt_tokens)
pipe.hincrby(key, "completion_tokens", completion_tokens)
pipe.hset(key, "updated_at", time.time())
pipe.expire(key, 60) # Expire after 1 minute
pipe.execute()
except Exception as e:
logger.error(f"Error updating token usage in Redis: {e}")
# Fallback to in-memory
self._update_in_memory(key, prompt_tokens, completion_tokens)
else:
self._update_in_memory(key, prompt_tokens, completion_tokens)
def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int):
"""Update in-memory token usage"""
if key not in self.in_memory_store:
self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0}
self.in_memory_store[key]["prompt_tokens"] += prompt_tokens
self.in_memory_store[key]["completion_tokens"] += completion_tokens
self.in_memory_store[key]["updated_at"] = time.time()
def cleanup_expired(self):
"""Clean up expired entries (for in-memory store)"""
if not self.redis_client:
current_time = time.time()
expired_keys = [
key for key, data in self.in_memory_store.items()
if current_time - data.get("updated_at", 0) > 60
]
for key in expired_keys:
del self.in_memory_store[key]
# Global token rate limiter instance
token_rate_limiter = TokenRateLimiter()

View File

@@ -755,10 +755,11 @@ class RAGService:
# Process with RAG module # Process with RAG module
try: try:
# Pass file_path in metadata so JSONL indexing can reopen the source file
processed_doc = await rag_module.process_document( processed_doc = await rag_module.process_document(
file_content, file_content,
document.original_filename, document.original_filename,
{} {"file_path": document.file_path}
) )
# Success case - update document with processed content # Success case - update document with processed content

View File

@@ -638,11 +638,19 @@ class RAGModule(BaseModule):
np.random.seed(hash(text) % 2**32) np.random.seed(hash(text) % 2**32)
return np.random.random(self.embedding_model.get("dimension", 768)).tolist() return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]: async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
"""Generate embeddings for multiple texts (batch processing)""" """Generate embeddings for multiple texts (batch processing)"""
if self.embedding_service: if self.embedding_service:
# Add task-specific prefixes for better E5 model performance
if is_document:
# For document passages, use "passage:" prefix
prefixed_texts = [f"passage: {text}" for text in texts]
else:
# For queries, use "query:" prefix (handled in search method)
prefixed_texts = texts
# Use real embedding service for batch processing # Use real embedding service for batch processing
return await self.embedding_service.get_embeddings(texts) return await self.embedding_service.get_embeddings(prefixed_texts)
else: else:
# Fallback to individual processing # Fallback to individual processing
embeddings = [] embeddings = []
@@ -922,12 +930,18 @@ class RAGModule(BaseModule):
- Each line contains a JSON object with 'id' and 'payload' - Each line contains a JSON object with 'id' and 'payload'
- Payload contains 'question', 'language', and 'answer' fields - Payload contains 'question', 'language', and 'answer' fields
- Combines question and answer into searchable content - Combines question and answer into searchable content
Performance optimizations:
- Processes articles in smaller batches to reduce memory usage
- Uses streaming approach for large files
""" """
try: try:
# Use streaming approach for large files
jsonl_content = content.decode('utf-8', errors='replace') jsonl_content = content.decode('utf-8', errors='replace')
lines = jsonl_content.strip().split('\n') lines = jsonl_content.strip().split('\n')
processed_articles = [] processed_articles = []
batch_size = 50 # Process in batches of 50 articles
for line_num, line in enumerate(lines, 1): for line_num, line in enumerate(lines, 1):
if not line.strip(): if not line.strip():
@@ -1153,7 +1167,7 @@ class RAGModule(BaseModule):
chunks = self._chunk_text(content) chunks = self._chunk_text(content)
# Generate embeddings for all chunks in batch (more efficient) # Generate embeddings for all chunks in batch (more efficient)
embeddings = await self._generate_embeddings(chunks) embeddings = await self._generate_embeddings(chunks, is_document=True)
# Create document points # Create document points
points = [] points = []
@@ -1204,6 +1218,24 @@ class RAGModule(BaseModule):
collection_name = collection_name or self.default_collection_name collection_name = collection_name or self.default_collection_name
try: try:
# Special handling for JSONL files
if processed_doc.file_type == 'jsonl':
# Import the optimized JSONL processor
from app.services.jsonl_processor import JSONLProcessor
jsonl_processor = JSONLProcessor(self)
# Read the original file content
with open(processed_doc.metadata.get('file_path', ''), 'rb') as f:
file_content = f.read()
# Process using the optimized JSONL processor
return await jsonl_processor.process_and_index_jsonl(
collection_name=collection_name,
content=file_content,
filename=processed_doc.original_filename,
metadata=processed_doc.metadata
)
# Ensure collection exists # Ensure collection exists
await self._ensure_collection_exists(collection_name) await self._ensure_collection_exists(collection_name)
@@ -1216,7 +1248,7 @@ class RAGModule(BaseModule):
chunks = self._chunk_text(processed_doc.content) chunks = self._chunk_text(processed_doc.content)
# Generate embeddings for all chunks in batch (more efficient) # Generate embeddings for all chunks in batch (more efficient)
embeddings = await self._generate_embeddings(chunks) embeddings = await self._generate_embeddings(chunks, is_document=True)
# Create document points with enhanced metadata # Create document points with enhanced metadata
points = [] points = []
@@ -1339,24 +1371,48 @@ class RAGModule(BaseModule):
score_threshold=score_threshold / 2 # Lower threshold for initial search score_threshold=score_threshold / 2 # Lower threshold for initial search
) )
# Combine scores # Combine scores with improved normalization
hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3}) hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
vector_weight = hybrid_weights.get("vector", 0.7) vector_weight = hybrid_weights.get("vector", 0.7)
bm25_weight = hybrid_weights.get("bm25", 0.3) bm25_weight = hybrid_weights.get("bm25", 0.3)
# Create hybrid results # Get score distributions for better normalization
vector_scores = [r.score for r in vector_results]
bm25_scores_list = list(bm25_scores.values())
# Calculate statistics for normalization
if vector_scores:
v_max = max(vector_scores)
v_min = min(vector_scores)
v_range = v_max - v_min if v_max != v_min else 1
else:
v_max, v_min, v_range = 1, 0, 1
if bm25_scores_list:
bm25_max = max(bm25_scores_list)
bm25_min = min(bm25_scores_list)
bm25_range = bm25_max - bm25_min if bm25_max != bm25_min else 1
else:
bm25_max, bm25_min, bm25_range = 1, 0, 1
# Create hybrid results with improved scoring
hybrid_results = [] hybrid_results = []
for result in vector_results: for result in vector_results:
doc_id = result.payload.get("document_id", "") doc_id = result.payload.get("document_id", "")
vector_score = result.score vector_score = result.score
bm25_score = bm25_scores.get(doc_id, 0.0) bm25_score = bm25_scores.get(doc_id, 0.0)
# Normalize scores (simple min-max normalization) # Improved normalization using actual score distributions
vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0 vector_norm = (vector_score - v_min) / v_range if v_range > 0 else 0.5
bm25_norm = min(bm25_score, 1.0) # BM25 scores are typically 0-1 bm25_norm = (bm25_score - bm25_min) / bm25_range if bm25_range > 0 else 0.5
# Calculate hybrid score # Apply reciprocal rank fusion for better combination
hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm) # This gives more weight to documents that rank highly in both methods
rrf_vector = 1.0 / (1.0 + vector_results.index(result) + 1) # +1 to avoid division by zero
rrf_bm25 = 1.0 / (1.0 + sorted(bm25_scores_list, reverse=True).index(bm25_score) + 1) if bm25_score in bm25_scores_list else 0
# Calculate hybrid score using both normalized scores and RRF
hybrid_score = (vector_weight * vector_norm + bm25_weight * bm25_norm) * 0.7 + (rrf_vector + rrf_bm25) * 0.3
# Create new point with hybrid score # Create new point with hybrid score
hybrid_point = ScoredPoint( hybrid_point = ScoredPoint(
@@ -1435,7 +1491,7 @@ class RAGModule(BaseModule):
# Normalize score to 0-1 range # Normalize score to 0-1 range
return min(score / 10.0, 1.0) # Simple normalization return min(score / 10.0, 1.0) # Simple normalization
async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]: async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
"""Search for relevant documents""" """Search for relevant documents"""
if not self.enabled: if not self.enabled:
raise RuntimeError("RAG module not initialized") raise RuntimeError("RAG module not initialized")
@@ -1453,8 +1509,10 @@ class RAGModule(BaseModule):
import time import time
start_time = time.time() start_time = time.time()
# Generate query embedding # Generate query embedding with task-specific prefix for better retrieval
query_embedding = await self._generate_embedding(query) # The E5 model works better with "query:" prefix for search queries
optimized_query = f"query: {query}"
query_embedding = await self._generate_embedding(optimized_query)
# Build filter # Build filter
search_filter = None search_filter = None
@@ -1474,7 +1532,8 @@ class RAGModule(BaseModule):
# Check if hybrid search is enabled # Check if hybrid search is enabled
enable_hybrid = self.config.get("enable_hybrid", False) enable_hybrid = self.config.get("enable_hybrid", False)
score_threshold = self.config.get("score_threshold", 0.3) # Use provided score_threshold or fall back to config
search_score_threshold = score_threshold if score_threshold is not None else self.config.get("score_threshold", 0.3)
if enable_hybrid and NLTK_AVAILABLE: if enable_hybrid and NLTK_AVAILABLE:
# Perform hybrid search (vector + BM25) # Perform hybrid search (vector + BM25)
@@ -1484,7 +1543,7 @@ class RAGModule(BaseModule):
query_vector=query_embedding, query_vector=query_embedding,
query_filter=search_filter, query_filter=search_filter,
limit=max_results, limit=max_results,
score_threshold=score_threshold score_threshold=search_score_threshold
) )
else: else:
# Pure vector search with improved threshold # Pure vector search with improved threshold
@@ -1493,7 +1552,7 @@ class RAGModule(BaseModule):
query_vector=query_embedding, query_vector=query_embedding,
query_filter=search_filter, query_filter=search_filter,
limit=max_results, limit=max_results,
score_threshold=score_threshold score_threshold=search_score_threshold
) )
logger.info(f"Raw search results count: {len(search_results)}") logger.info(f"Raw search results count: {len(search_results)}")
@@ -1841,9 +1900,9 @@ async def index_processed_document(processed_doc: ProcessedDocument, collection_
"""Index a processed document""" """Index a processed document"""
return await rag_module.index_processed_document(processed_doc, collection_name) return await rag_module.index_processed_document(processed_doc, collection_name)
async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]: async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
"""Search documents""" """Search documents"""
return await rag_module.search_documents(query, max_results, filters, collection_name) return await rag_module.search_documents(query, max_results, filters, collection_name, score_threshold)
async def delete_document(document_id: str, collection_name: str = None) -> bool: async def delete_document(document_id: str, collection_name: str = None) -> bool:
"""Delete a document""" """Delete a document"""

View File

@@ -7,7 +7,7 @@ export async function POST(request: NextRequest) {
// Make request to backend auth endpoint without requiring existing auth // Make request to backend auth endpoint without requiring existing auth
const baseUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}` const baseUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`
const url = `${baseUrl}/api/auth/login` const url = `${baseUrl}/api-internal/v1/auth/login`
const response = await fetch(url, { const response = await fetch(url, {
method: 'POST', method: 'POST',

View File

@@ -85,8 +85,31 @@ function RAGPageContent() {
const loadStats = async () => { const loadStats = async () => {
try { try {
const data = await apiClient.get('/api-internal/v1/rag/stats') const data = await apiClient.get('/api-internal/v1/rag/stats')
console.log('Stats API response:', data)
// Check if the response has the expected structure
if (data && data.stats && data.stats.collections) {
console.log('✓ Stats has collections property')
setStats(data.stats) setStats(data.stats)
} else {
console.error('✗ Invalid stats structure:', data)
// Set default empty stats to prevent error
setStats({
collections: { total: 0, active: 0 },
documents: { total: 0, processing: 0, processed: 0 },
storage: { total_size_bytes: 0, total_size_mb: 0 },
vectors: { total: 0 }
})
}
} catch (error) { } catch (error) {
console.error('Error loading stats:', error)
// Set default empty stats on error
setStats({
collections: { total: 0, active: 0 },
documents: { total: 0, processing: 0, processed: 0 },
storage: { total_size_bytes: 0, total_size_mb: 0 },
vectors: { total: 0 }
})
} }
} }

View File

@@ -9,7 +9,7 @@ import { Badge } from "@/components/ui/badge"
import { Separator } from "@/components/ui/separator" import { Separator } from "@/components/ui/separator"
import { AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger } from "@/components/ui/alert-dialog" import { AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger } from "@/components/ui/alert-dialog"
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog" import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog"
import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter } from "lucide-react" import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter, RefreshCw } from "lucide-react"
import { useToast } from "@/hooks/use-toast" import { useToast } from "@/hooks/use-toast"
import { apiClient } from "@/lib/api-client" import { apiClient } from "@/lib/api-client"
import { config } from "@/lib/config" import { config } from "@/lib/config"
@@ -56,6 +56,7 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
const [filterStatus, setFilterStatus] = useState("all") const [filterStatus, setFilterStatus] = useState("all")
const [selectedDocument, setSelectedDocument] = useState<Document | null>(null) const [selectedDocument, setSelectedDocument] = useState<Document | null>(null)
const [deleting, setDeleting] = useState<string | null>(null) const [deleting, setDeleting] = useState<string | null>(null)
const [reprocessing, setReprocessing] = useState<string | null>(null)
const { toast } = useToast() const { toast } = useToast()
useEffect(() => { useEffect(() => {
@@ -157,6 +158,43 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
} }
} }
const handleReprocessDocument = async (documentId: string) => {
setReprocessing(documentId)
try {
await apiClient.post(`/api-internal/v1/rag/documents/${documentId}/reprocess`)
// Update the document status to processing in the UI
setDocuments(prev => prev.map(doc =>
doc.id === documentId
? { ...doc, status: 'processing' as const, processed_at: new Date().toISOString() }
: doc
))
toast({
title: "Success",
description: "Document reprocessing started",
})
// Reload documents after a short delay to see status updates
setTimeout(() => {
loadDocuments()
}, 2000)
} catch (error) {
const errorMessage = error instanceof Error ? error.message : "Failed to reprocess document"
toast({
title: "Error",
description: errorMessage.includes("Cannot reprocess document with status 'processed'")
? "Cannot reprocess documents that are already processed"
: errorMessage,
variant: "destructive",
})
} finally {
setReprocessing(null)
}
}
const formatFileSize = (bytes: number) => { const formatFileSize = (bytes: number) => {
if (bytes === 0) return '0 Bytes' if (bytes === 0) return '0 Bytes'
const k = 1024 const k = 1024
@@ -432,6 +470,21 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
<Download className="h-4 w-4" /> <Download className="h-4 w-4" />
</Button> </Button>
<Button
variant="ghost"
size="sm"
className="h-8 w-8 p-0 hover:bg-blue-100"
onClick={() => handleReprocessDocument(document.id)}
disabled={reprocessing === document.id || document.status === 'processed'}
title={document.status === 'processed' ? "Document already processed" : "Reprocess document"}
>
{reprocessing === document.id ? (
<RefreshCw className="h-4 w-4 animate-spin" />
) : (
<RefreshCw className={`h-4 w-4 ${document.status === 'processed' ? 'text-gray-400' : ''}`} />
)}
</Button>
<AlertDialog> <AlertDialog>
<AlertDialogTrigger asChild> <AlertDialogTrigger asChild>
<Button <Button

View File

@@ -73,6 +73,7 @@ const Navigation = () => {
children: [ children: [
{ href: "/llm", label: "Models & Config" }, { href: "/llm", label: "Models & Config" },
{ href: "/playground", label: "Playground" }, { href: "/playground", label: "Playground" },
{ href: "/rag-demo", label: "RAG Demo" },
] ]
}, },
{ {

View File

@@ -25,6 +25,12 @@ http {
listen 80; listen 80;
server_name localhost; server_name localhost;
# Static files - serve directly from nginx
location = /login_helper.html {
root /usr/share/nginx/html;
try_files $uri =404;
}
# Frontend routes # Frontend routes
location / { location / {
proxy_pass http://frontend; proxy_pass http://frontend;
@@ -65,6 +71,58 @@ http {
} }
} }
# RAG debug API routes - proxy to frontend (for Next.js API routes)
location /api/rag/debug/ {
proxy_pass http://frontend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# CORS headers
add_header 'Access-Control-Allow-Origin' '*' always;
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
add_header 'Access-Control-Allow-Expose-Headers' 'Content-Length,Content-Range' always;
# Handle preflight requests
if ($request_method = 'OPTIONS') {
add_header 'Access-Control-Allow-Origin' '*';
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization';
add_header 'Access-Control-Max-Age' 1728000;
add_header 'Content-Type' 'text/plain; charset=utf-8';
add_header 'Content-Length' 0;
return 204;
}
}
# Frontend API routes for authentication - proxy to frontend
location /api/auth/ {
proxy_pass http://frontend;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# CORS headers
add_header 'Access-Control-Allow-Origin' '*' always;
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;
# Handle preflight requests
if ($request_method = 'OPTIONS') {
add_header 'Access-Control-Allow-Origin' '*';
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization';
add_header 'Access-Control-Max-Age' 1728000;
add_header 'Content-Type' 'text/plain; charset=utf-8';
add_header 'Content-Length' 0;
return 204;
}
}
# Public API routes - proxy to backend (for external clients) # Public API routes - proxy to backend (for external clients)
location /api/ { location /api/ {
proxy_pass http://backend; proxy_pass http://backend;