mirror of https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00

rag improvements

.env (2 changes)
@@ -46,7 +46,7 @@ API_RATE_LIMITING_ENABLED=false
# ===================================
# APPLICATION BASE URL (Required - derives all URLs and CORS)
# ===================================
BASE_URL=localhost
BASE_URL=localhost:80
# Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost
# Backend derives: CORS_ORIGINS=["http://localhost"]
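The derivation described in the two comments above, as a standalone sketch (assumed logic inferred from those comments, not code from this repository; the port-stripping rule for the default HTTP port is an assumption):

    # Sketch of deriving frontend/backend URLs from BASE_URL (assumed logic).
    base = "localhost:80"
    host = base.removesuffix(":80")   # default HTTP port is dropped -> "localhost"
    app_url = f"http://{host}"        # APP_URL / API_URL -> http://localhost
    ws_url = f"ws://{host}"           # WS_URL -> ws://localhost
    cors_origins = [app_url]          # CORS_ORIGINS -> ["http://localhost"]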
.env.example (12 changes)
@@ -65,6 +65,16 @@ QDRANT_HOST=enclava-qdrant
QDRANT_PORT=6333
QDRANT_URL=http://enclava-qdrant:6333

# ===================================
# RAG EMBEDDING CONFIGURATION (Optional overrides)
# ===================================
# These control embedding throughput to avoid provider 429s.
# Defaults are conservative; uncomment to override.
# RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=12
# RAG_EMBEDDING_BATCH_SIZE=3
# RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=1.0  # seconds
# RAG_EMBEDDING_DELAY_PER_REQUEST=0.5  # seconds

# ===================================
# OPTIONAL PRIVATEMODE SETTINGS (Have defaults)
# ===================================
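How the four throttle knobs above interact, as a minimal sketch under stated assumptions: `call_embedding_api` is a hypothetical stand-in for the real embedding client (not code from this commit), and the values mirror the documented defaults. Requests are paced per request and per batch, and a rolling one-minute window enforces the RPM cap:

    import asyncio
    import time

    MAX_REQUESTS_PER_MINUTE = 12
    BATCH_SIZE = 3
    DELAY_BETWEEN_BATCHES = 1.0  # seconds
    DELAY_PER_REQUEST = 0.5      # seconds

    async def call_embedding_api(text: str) -> list[float]:
        """Hypothetical provider call; stands in for the real embedding client."""
        await asyncio.sleep(0.01)
        return [0.0]

    async def embed_texts(texts: list[str]) -> list[list[float]]:
        """Embed texts in small batches, pacing requests to stay under the RPM cap."""
        embeddings: list[list[float]] = []
        window_start = time.monotonic()
        requests_in_window = 0
        for i in range(0, len(texts), BATCH_SIZE):
            for text in texts[i:i + BATCH_SIZE]:
                # Reset the rolling one-minute window when it expires.
                if time.monotonic() - window_start >= 60:
                    window_start = time.monotonic()
                    requests_in_window = 0
                # Sleep out the rest of the window once the cap is reached.
                if requests_in_window >= MAX_REQUESTS_PER_MINUTE:
                    await asyncio.sleep(60 - (time.monotonic() - window_start))
                    window_start = time.monotonic()
                    requests_in_window = 0
                embeddings.append(await call_embedding_api(text))
                requests_in_window += 1
                await asyncio.sleep(DELAY_PER_REQUEST)
            await asyncio.sleep(DELAY_BETWEEN_BATCHES)
        return embeddings

The conservative defaults (12 requests/minute, batches of 3) trade indexing speed for staying clear of provider 429s, which is the tradeoff the comment above calls out.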
@@ -130,4 +140,4 @@ QDRANT_URL=http://enclava-qdrant:6333
# Required: DATABASE_URL, REDIS_URL, JWT_SECRET, ADMIN_EMAIL, ADMIN_PASSWORD, BASE_URL
# Recommended: PRIVATEMODE_API_KEY, QDRANT_HOST, QDRANT_PORT
# Optional: All other settings have secure defaults
# ===================================
# ===================================
@@ -12,8 +12,8 @@ from ..v1.audit import router as audit_router
from ..v1.settings import router as settings_router
from ..v1.analytics import router as analytics_router
from ..v1.rag import router as rag_router
from ..rag_debug import router as rag_debug_router
from ..v1.prompt_templates import router as prompt_templates_router
from ..v1.security import router as security_router
from ..v1.plugin_registry import router as plugin_registry_router
from ..v1.platform import router as platform_router
from ..v1.llm_internal import router as llm_internal_router
@@ -52,11 +52,12 @@ internal_api_router.include_router(analytics_router, prefix="/analytics", tags=[
# Include RAG routes (frontend RAG document management)
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])

# Include RAG debug routes (for demo and debugging)
internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])

# Include prompt template routes (frontend prompt template management)
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])

# Include security routes (frontend security settings)
internal_api_router.include_router(security_router, prefix="/security", tags=["internal-security"])

# Include plugin registry routes (frontend plugin management)
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])
@@ -16,7 +16,6 @@ from .analytics import router as analytics_router
from .rag import router as rag_router
from .chatbot import router as chatbot_router
from .prompt_templates import router as prompt_templates_router
from .security import router as security_router
from .plugin_registry import router as plugin_registry_router

# Create main API router
@@ -61,8 +60,6 @@ api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
# Include prompt template routes
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])

# Include security routes
api_router.include_router(security_router, prefix="/security", tags=["security"])

# Include plugin registry routes
@@ -745,8 +745,7 @@ async def get_llm_metrics(
        "total_requests": metrics.total_requests,
        "successful_requests": metrics.successful_requests,
        "failed_requests": metrics.failed_requests,
        "security_blocked_requests": metrics.security_blocked_requests,
        "average_latency_ms": metrics.average_latency_ms,
        "average_latency_ms": metrics.average_latency_ms,
        "average_risk_score": metrics.average_risk_score,
        "provider_metrics": metrics.provider_metrics,
        "last_updated": metrics.last_updated.isoformat()
@@ -3,12 +3,14 @@ RAG API Endpoints

Provides REST API for RAG (Retrieval Augmented Generation) operations
"""

from typing import List, Optional
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
import io
import asyncio
from datetime import datetime

from app.db.database import get_db
from app.core.security import get_current_user
@@ -16,6 +18,9 @@ from app.models.user import User
from app.services.rag_service import RAGService
from app.utils.exceptions import APIException

# Import RAG module from module manager
from app.services.module_manager import module_manager


router = APIRouter(tags=["RAG"])
@@ -78,14 +83,25 @@ async def get_collections(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata"""
    """Get all RAG collections - live data directly from Qdrant (source of truth)"""
    try:
        rag_service = RAGService(db)
        collections_data = await rag_service.get_all_collections(skip=skip, limit=limit)
        from app.services.qdrant_stats_service import qdrant_stats_service

        # Get live stats from Qdrant
        stats_data = await qdrant_stats_service.get_collections_stats()
        collections = stats_data.get("collections", [])

        # Apply pagination
        start_idx = skip
        end_idx = skip + limit
        paginated_collections = collections[start_idx:end_idx]

        return {
            "success": True,
            "collections": collections_data,
            "total": len(collections_data)
            "collections": paginated_collections,
            "total": len(collections),
            "total_documents": stats_data.get("total_documents", 0),
            "total_size_bytes": stats_data.get("total_size_bytes", 0)
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@@ -116,6 +132,62 @@ async def create_collection(
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/stats", response_model=dict)
async def get_rag_stats(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Get overall RAG statistics - live data directly from Qdrant"""
    try:
        from app.services.qdrant_stats_service import qdrant_stats_service

        # Get live stats from Qdrant
        stats_data = await qdrant_stats_service.get_collections_stats()

        # Calculate active collections (collections with documents)
        active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)

        # Calculate processing documents from database
        processing_docs = 0
        try:
            from sqlalchemy import select
            from app.models.rag_document import RagDocument, ProcessingStatus

            result = await db.execute(
                select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
            )
            processing_docs = len(result.scalars().all())
        except Exception:
            pass  # If database query fails, default to 0

        response_data = {
            "success": True,
            "stats": {
                "collections": {
                    "total": stats_data.get("total_collections", 0),
                    "active": active_collections
                },
                "documents": {
                    "total": stats_data.get("total_documents", 0),
                    "processing": processing_docs,
                    "processed": stats_data.get("total_documents", 0)  # Indexed documents
                },
                "storage": {
                    "total_size_bytes": stats_data.get("total_size_bytes", 0),
                    "total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
                },
                "vectors": {
                    "total": stats_data.get("total_documents", 0)  # Same as documents for RAG
                },
                "last_updated": datetime.utcnow().isoformat()
            }
        }

        return response_data
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/collections/{collection_id}", response_model=dict)
async def get_collection(
    collection_id: int,
@@ -225,21 +297,65 @@ async def upload_document(
    try:
        # Read file content
        file_content = await file.read()

        if len(file_content) == 0:
            raise HTTPException(status_code=400, detail="Empty file uploaded")

        if len(file_content) > 50 * 1024 * 1024:  # 50MB limit
            raise HTTPException(status_code=400, detail="File too large (max 50MB)")

        # Validate file can be read before processing
        filename = file.filename or "unknown"
        file_extension = filename.split('.')[-1].lower() if '.' in filename else ''

        try:
            # Test file readability based on type
            if file_extension == 'jsonl':
                # Validate JSONL format - try to parse first few lines
                try:
                    content_str = file_content.decode('utf-8')
                    lines = content_str.strip().split('\n')[:5]  # Check first 5 lines
                    import json
                    for i, line in enumerate(lines):
                        if line.strip():  # Skip empty lines
                            json.loads(line)  # Will raise JSONDecodeError if invalid
                except UnicodeDecodeError:
                    raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
                except json.JSONDecodeError as e:
                    raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")

            elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
                # Validate text files can be decoded
                try:
                    file_content.decode('utf-8')
                except UnicodeDecodeError:
                    raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")

            elif file_extension in ['pdf']:
                # For PDF files, just check if it starts with PDF signature
                if not file_content.startswith(b'%PDF'):
                    raise HTTPException(status_code=400, detail="Invalid PDF file format")

            elif file_extension in ['docx', 'xlsx', 'pptx']:
                # For Office documents, check ZIP signature
                if not file_content.startswith(b'PK'):
                    raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")

            # For other file types, we'll rely on the document processor

        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")

        rag_service = RAGService(db)
        document = await rag_service.upload_document(
            collection_id=collection_id,
            file_content=file_content,
            filename=file.filename or "unknown",
            filename=filename,
            content_type=file.content_type
        )

        return {
            "success": True,
            "document": document.to_dict(),
@@ -362,21 +478,167 @@ async def download_document(
        raise HTTPException(status_code=500, detail=str(e))


# Stats Endpoint

@router.get("/stats", response_model=dict)
async def get_rag_stats(
    db: AsyncSession = Depends(get_db),
# Debug Endpoints

@router.post("/debug/search")
async def search_with_debug(
    query: str,
    max_results: int = 10,
    score_threshold: float = 0.3,
    collection_name: str = None,
    config: Dict[str, Any] = None,
    current_user: User = Depends(get_current_user)
):
    """Get RAG system statistics"""
) -> Dict[str, Any]:
    """
    Enhanced search with comprehensive debug information
    """
    # Get RAG module from module manager
    rag_module = module_manager.modules.get('rag')
    if not rag_module or not rag_module.enabled:
        raise HTTPException(status_code=503, detail="RAG module not initialized")

    debug_info = {}
    start_time = datetime.utcnow()

    try:
        rag_service = RAGService(db)
        stats = await rag_service.get_stats()

        # Apply configuration if provided
        if config:
            # Update RAG config temporarily
            original_config = rag_module.config.copy()
            rag_module.config.update(config)

        # Generate query embedding (with or without prefix)
        if config and config.get("use_query_prefix"):
            optimized_query = f"query: {query}"
        else:
            optimized_query = query

        query_embedding = await rag_module._generate_embedding(optimized_query)

        # Store embedding info for debug
        if config and config.get("debug", {}).get("show_embeddings"):
            debug_info["query_embedding"] = query_embedding[:10]  # First 10 dimensions
            debug_info["embedding_dimension"] = len(query_embedding)
            debug_info["optimized_query"] = optimized_query

        # Perform search
        search_start = asyncio.get_event_loop().time()
        results = await rag_module.search_documents(
            query,
            max_results=max_results,
            score_threshold=score_threshold,
            collection_name=collection_name
        )
        search_time = (asyncio.get_event_loop().time() - search_start) * 1000

        # Calculate score statistics
        scores = [r.score for r in results if r.score is not None]
        if scores:
            import statistics
            debug_info["score_stats"] = {
                "min": min(scores),
                "max": max(scores),
                "avg": statistics.mean(scores),
                "stddev": statistics.stdev(scores) if len(scores) > 1 else 0
            }

        # Get collection statistics
        try:
            from qdrant_client.http.models import Filter
            collection_name = collection_name or rag_module.default_collection_name

            # Count total documents
            count_result = rag_module.qdrant_client.count(
                collection_name=collection_name,
                count_filter=Filter(must=[])
            )
            total_points = count_result.count

            # Get unique documents and languages
            scroll_result = rag_module.qdrant_client.scroll(
                collection_name=collection_name,
                limit=1000,  # Sample for stats
                with_payload=True,
                with_vectors=False
            )

            unique_docs = set()
            languages = set()

            for point in scroll_result[0]:
                payload = point.payload or {}
                doc_id = payload.get("document_id")
                if doc_id:
                    unique_docs.add(doc_id)

                language = payload.get("language")
                if language:
                    languages.add(language)

            debug_info["collection_stats"] = {
                "total_documents": len(unique_docs),
                "total_chunks": total_points,
                "languages": sorted(list(languages))
            }

        except Exception as e:
            debug_info["collection_stats_error"] = str(e)

        # Enhance results with debug info
        enhanced_results = []
        for result in results:
            enhanced_result = {
                "document": {
                    "id": result.document.id,
                    "content": result.document.content,
                    "metadata": result.document.metadata
                },
                "score": result.score,
                "debug_info": {}
            }

            # Add hybrid search debug info if available
            metadata = result.document.metadata or {}
            if "_vector_score" in metadata:
                enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
            if "_bm25_score" in metadata:
                enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]

            enhanced_results.append(enhanced_result)

        # Note: Analytics logging disabled (module not available)

        return {
            "success": True,
            "stats": stats
            "results": enhanced_results,
            "debug_info": debug_info,
            "search_time_ms": search_time,
            "timestamp": start_time.isoformat()
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
        # Note: Analytics logging disabled (module not available)
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")

    finally:
        # Restore original config if modified
        if config and 'original_config' in locals():
            rag_module.config = original_config


@router.get("/debug/config")
async def get_current_config(
    current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
    """Get current RAG configuration"""
    # Get RAG module from module manager
    rag_module = module_manager.modules.get('rag')
    if not rag_module or not rag_module.enabled:
        raise HTTPException(status_code=503, detail="RAG module not initialized")

    return {
        "config": rag_module.config,
        "embedding_model": rag_module.embedding_model,
        "enabled": rag_module.enabled,
        "collections": await rag_module._get_collections_safely()
    }
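A hypothetical client call against the new debug endpoint. Assumptions: the host, mount prefix, and auth scheme depend on deployment (the internal router above mounts these routes under /rag); the scalar parameters travel as query parameters, while the JSON body maps to the `config` dict:

    import requests

    # Illustrative only: URL, prefix, and token are placeholders.
    resp = requests.post(
        "http://localhost:8000/rag/debug/search",
        params={"query": "vector databases", "max_results": 5, "score_threshold": 0.3},
        json={"use_query_prefix": True, "debug": {"show_embeddings": True}},
        headers={"Authorization": "Bearer <token>"},
    )
    body = resp.json()
    print(body["search_time_ms"], body["debug_info"].get("score_stats"))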
@@ -1,251 +0,0 @@
"""
Security API endpoints for monitoring and configuration
"""

from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field

from app.core.security import get_current_active_user, RequiresRole
from app.middleware.security import get_security_stats, get_request_auth_level, get_request_risk_score
from app.core.config import settings
from app.core.logging import get_logger

logger = get_logger(__name__)

router = APIRouter(tags=["security"])


# Pydantic models for API responses
class SecurityStatsResponse(BaseModel):
    """Security statistics response model"""
    total_requests_analyzed: int
    threats_detected: int
    threats_blocked: int
    anomalies_detected: int
    rate_limits_exceeded: int
    avg_analysis_time: float
    threat_types: Dict[str, int]
    threat_levels: Dict[str, int]
    top_attacking_ips: List[tuple]
    security_enabled: bool
    threat_detection_enabled: bool
    rate_limiting_enabled: bool


class SecurityConfigResponse(BaseModel):
    """Security configuration response model"""
    security_enabled: bool = Field(description="Overall security system enabled")
    threat_detection_enabled: bool = Field(description="Threat detection analysis enabled")
    rate_limiting_enabled: bool = Field(description="Rate limiting enabled")
    ip_reputation_enabled: bool = Field(description="IP reputation checking enabled")
    anomaly_detection_enabled: bool = Field(description="Anomaly detection enabled")
    security_headers_enabled: bool = Field(description="Security headers enabled")

    # Rate limiting settings
    unauthenticated_per_minute: int = Field(description="Rate limit for unauthenticated requests per minute")
    authenticated_per_minute: int = Field(description="Rate limit for authenticated users per minute")
    api_key_per_minute: int = Field(description="Rate limit for API key users per minute")
    premium_per_minute: int = Field(description="Rate limit for premium users per minute")

    # Security thresholds
    risk_threshold: float = Field(description="Risk score threshold for blocking requests")
    warning_threshold: float = Field(description="Risk score threshold for warnings")
    anomaly_threshold: float = Field(description="Anomaly severity threshold")

    # IP settings
    blocked_ips: List[str] = Field(description="List of blocked IP addresses")
    allowed_ips: List[str] = Field(description="List of allowed IP addresses (empty = allow all)")


class RateLimitInfoResponse(BaseModel):
    """Rate limit information for current request"""
    auth_level: str = Field(description="Authentication level (unauthenticated, authenticated, api_key, premium)")
    current_limits: Dict[str, int] = Field(description="Current rate limits for this auth level")
    remaining_requests: Optional[Dict[str, int]] = Field(description="Estimated remaining requests (if available)")


@router.get("/stats", response_model=SecurityStatsResponse)
async def get_security_statistics(
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Get security system statistics

    Requires admin role. Returns comprehensive statistics about:
    - Request analysis counts
    - Threat detection results
    - Rate limiting enforcement
    - Top attacking IPs
    - Performance metrics
    """
    try:
        stats = get_security_stats()
        return SecurityStatsResponse(**stats)
    except Exception as e:
        logger.error(f"Error getting security stats: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve security statistics"
        )


@router.get("/config", response_model=SecurityConfigResponse)
async def get_security_config(
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Get current security configuration

    Requires admin role. Returns current security settings including:
    - Feature enablement flags
    - Rate limiting thresholds
    - Security thresholds
    - IP allowlists/blocklists
    """
    return SecurityConfigResponse(
        security_enabled=settings.API_SECURITY_ENABLED,
        threat_detection_enabled=settings.API_THREAT_DETECTION_ENABLED,
        rate_limiting_enabled=settings.API_RATE_LIMITING_ENABLED,
        ip_reputation_enabled=settings.API_IP_REPUTATION_ENABLED,
        anomaly_detection_enabled=settings.API_ANOMALY_DETECTION_ENABLED,
        security_headers_enabled=settings.API_SECURITY_HEADERS_ENABLED,

        unauthenticated_per_minute=settings.API_RATE_LIMIT_UNAUTHENTICATED_PER_MINUTE,
        authenticated_per_minute=settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE,
        api_key_per_minute=settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
        premium_per_minute=settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE,

        risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
        warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
        anomaly_threshold=settings.API_SECURITY_ANOMALY_THRESHOLD,

        blocked_ips=settings.API_BLOCKED_IPS,
        allowed_ips=settings.API_ALLOWED_IPS
    )


@router.get("/status")
async def get_security_status(
    request: Request,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """
    Get security status for current request

    Returns information about the security analysis of the current request:
    - Authentication level
    - Risk score (if available)
    - Rate limiting status
    """
    auth_level = get_request_auth_level(request)
    risk_score = get_request_risk_score(request)

    # Get rate limits for current auth level
    from app.core.threat_detection import AuthLevel
    try:
        auth_enum = AuthLevel(auth_level)
        from app.core.threat_detection import threat_detection_service
        minute_limit, hour_limit = threat_detection_service.get_rate_limits(auth_enum)

        rate_limit_info = RateLimitInfoResponse(
            auth_level=auth_level,
            current_limits={
                "per_minute": minute_limit,
                "per_hour": hour_limit
            },
            remaining_requests=None  # We don't track remaining requests in current implementation
        )
    except ValueError:
        rate_limit_info = RateLimitInfoResponse(
            auth_level=auth_level,
            current_limits={},
            remaining_requests=None
        )

    return {
        "security_enabled": settings.API_SECURITY_ENABLED,
        "auth_level": auth_level,
        "risk_score": round(risk_score, 3) if risk_score > 0 else None,
        "rate_limit_info": rate_limit_info.dict(),
        "security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED
    }


@router.post("/test")
async def test_security_analysis(
    request: Request,
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Test security analysis on current request

    Requires admin role. Manually triggers security analysis on the current request
    and returns detailed results. Useful for testing security rules and thresholds.
    """
    try:
        from app.middleware.security import analyze_request_security

        analysis = await analyze_request_security(request, current_user)

        return {
            "analysis_complete": True,
            "is_threat": analysis.is_threat,
            "risk_score": round(analysis.risk_score, 3),
            "auth_level": analysis.auth_level.value,
            "should_block": analysis.should_block,
            "rate_limit_exceeded": analysis.rate_limit_exceeded,
            "threat_count": len(analysis.threats),
            "threats": [
                {
                    "type": threat.threat_type,
                    "level": threat.level.value,
                    "confidence": round(threat.confidence, 3),
                    "description": threat.description,
                    "mitigation": threat.mitigation
                }
                for threat in analysis.threats
            ],
            "recommendations": analysis.recommendations
        }
    except Exception as e:
        logger.error(f"Error in security analysis test: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to perform security analysis test"
        )


@router.get("/health")
async def security_health_check():
    """
    Security system health check

    Public endpoint that returns the health status of the security system.
    Does not require authentication.
    """
    try:
        stats = get_security_stats()

        # Basic health checks
        is_healthy = (
            settings.API_SECURITY_ENABLED and
            stats.get("total_requests_analyzed", 0) >= 0 and
            stats.get("avg_analysis_time", 0) < 1.0  # Analysis should be under 1 second
        )

        return {
            "status": "healthy" if is_healthy else "degraded",
            "security_enabled": settings.API_SECURITY_ENABLED,
            "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
            "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
            "avg_analysis_time_ms": round(stats.get("avg_analysis_time", 0) * 1000, 2),
            "total_requests_analyzed": stats.get("total_requests_analyzed", 0)
        }
    except Exception as e:
        logger.error(f"Security health check failed: {e}")
        return {
            "status": "unhealthy",
            "error": "Security system error",
            "security_enabled": settings.API_SECURITY_ENABLED
        }
@@ -97,7 +97,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
    "api": {
        # Security Settings
        "security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
        "threat_detection_enabled": {"value": True, "type": "boolean", "description": "Enable threat detection analysis"},
        "rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
        "ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
        "anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
@@ -112,7 +111,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
        "rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},

        # Security Thresholds
        "security_risk_threshold": {"value": 0.8, "type": "float", "description": "Risk score threshold for blocking requests (0.0-1.0)"},
        "security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
        "anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},

@@ -601,7 +599,6 @@ async def reset_to_defaults(
    "api": {
        # Security Settings
        "security_enabled": {"value": True, "type": "boolean"},
        "threat_detection_enabled": {"value": True, "type": "boolean"},
        "rate_limiting_enabled": {"value": True, "type": "boolean"},
        "ip_reputation_enabled": {"value": True, "type": "boolean"},
        "anomaly_detection_enabled": {"value": True, "type": "boolean"},
@@ -616,7 +613,6 @@ async def reset_to_defaults(
        "rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},

        # Security Thresholds
        "security_risk_threshold": {"value": 0.8, "type": "float"},
        "security_warning_threshold": {"value": 0.6, "type": "float"},
        "anomaly_threshold": {"value": 0.7, "type": "float"},
@@ -17,6 +17,8 @@ class Settings(BaseSettings):
    APP_LOG_LEVEL: str = os.getenv("APP_LOG_LEVEL", "INFO")
    APP_HOST: str = os.getenv("APP_HOST", "0.0.0.0")
    APP_PORT: int = int(os.getenv("APP_PORT", "8000"))
    BACKEND_INTERNAL_PORT: int = int(os.getenv("BACKEND_INTERNAL_PORT", "8000"))
    FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000"))

    # Detailed logging for LLM interactions
    LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true"  # Set to True to log prompts and context sent to LLM
@@ -73,16 +75,11 @@ class Settings(BaseSettings):
    QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
    QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333"))
    QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
    QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")

    # API & Security Settings
    API_SECURITY_ENABLED: bool = os.getenv("API_SECURITY_ENABLED", "True").lower() == "true"
    API_THREAT_DETECTION_ENABLED: bool = os.getenv("API_THREAT_DETECTION_ENABLED", "True").lower() == "true"
    API_IP_REPUTATION_ENABLED: bool = os.getenv("API_IP_REPUTATION_ENABLED", "True").lower() == "true"
    API_ANOMALY_DETECTION_ENABLED: bool = os.getenv("API_ANOMALY_DETECTION_ENABLED", "True").lower() == "true"

    # Rate Limiting Configuration
    API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"

    # PrivateMode Standard tier limits (organization-level, not per user)
    # These are shared across all API keys and users in the organization
    PRIVATEMODE_REQUESTS_PER_MINUTE: int = int(os.getenv("PRIVATEMODE_REQUESTS_PER_MINUTE", "20"))
@@ -101,23 +98,14 @@ class Settings(BaseSettings):
    # Premium/Enterprise API keys
    API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20"))  # Match PrivateMode
    API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))

    # Security Thresholds
    API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8"))  # Block requests above this risk score
    API_SECURITY_WARNING_THRESHOLD: float = float(os.getenv("API_SECURITY_WARNING_THRESHOLD", "0.6"))  # Log warnings above this threshold
    API_SECURITY_ANOMALY_THRESHOLD: float = float(os.getenv("API_SECURITY_ANOMALY_THRESHOLD", "0.7"))  # Flag anomalies above this threshold

    # Request Size Limits
    API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760"))  # 10MB
    API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800"))  # 50MB for premium

    # IP Security
    API_BLOCKED_IPS: List[str] = os.getenv("API_BLOCKED_IPS", "").split(",") if os.getenv("API_BLOCKED_IPS") else []
    API_ALLOWED_IPS: List[str] = os.getenv("API_ALLOWED_IPS", "").split(",") if os.getenv("API_ALLOWED_IPS") else []
    API_IP_REPUTATION_CACHE_TTL: int = int(os.getenv("API_IP_REPUTATION_CACHE_TTL", "3600"))  # 1 hour

    # Security Headers
    API_SECURITY_HEADERS_ENABLED: bool = os.getenv("API_SECURITY_HEADERS_ENABLED", "True").lower() == "true"
    API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")

    # Monitoring
@@ -129,6 +117,19 @@ class Settings(BaseSettings):

    # Module configuration
    MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")

    # RAG Embedding Configuration
    RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12"))
    RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3"))
    RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3"))
    RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16")
    RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0"))
    RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
    RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
    RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
    RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
    RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
    RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))

    # Plugin configuration
    PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
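A minimal sketch (an assumed helper, not code from this commit) of how RAG_EMBEDDING_RETRY_COUNT and the CSV RAG_EMBEDDING_RETRY_DELAYS schedule could drive retries on provider 429s, with the 1,2,4,8,16 default giving exponential backoff:

    import random
    import time

    RETRY_COUNT = 3
    RETRY_DELAYS = [float(d) for d in "1,2,4,8,16".split(",")]  # seconds

    def embed_with_retries(call):
        """Retry a flaky embedding call, walking the configured delay schedule."""
        for attempt in range(RETRY_COUNT + 1):
            try:
                return call()
            except Exception:
                if attempt >= RETRY_COUNT:
                    raise
                # Use the schedule while it lasts, then stay at the final delay.
                delay = RETRY_DELAYS[min(attempt, len(RETRY_DELAYS) - 1)]
                time.sleep(delay + random.uniform(0, 0.25))  # small jitter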
@@ -142,9 +143,12 @@ class Settings(BaseSettings):

    model_config = {
        "env_file": ".env",
        "case_sensitive": True
        "case_sensitive": True,
        # Ignore unknown environment variables to avoid validation errors
        # when optional/deprecated flags are present in .env
        "extra": "ignore",
    }


# Global settings instance
settings = Settings()
settings = Settings()
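For context, "extra": "ignore" is standard pydantic-settings behavior rather than anything project-specific; a standalone sketch (hypothetical Demo class, assuming pydantic-settings v2) of what the option changes:

    from pydantic_settings import BaseSettings

    class Demo(BaseSettings):
        model_config = {"env_file": ".env", "case_sensitive": True, "extra": "ignore"}

        APP_PORT: int = 8000

    # With extra="ignore", a .env entry with no matching field on Demo (say a
    # commented-out-then-enabled RAG flag) is skipped; with the stricter default,
    # pydantic-settings would raise a ValidationError for the unknown key.
    demo = Demo()
    print(demo.APP_PORT)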
@@ -1,744 +0,0 @@
"""
Core threat detection and security analysis for the platform
"""

import re
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Any, Union
from urllib.parse import unquote

from fastapi import Request
from app.core.config import settings
from app.core.logging import get_logger

logger = get_logger(__name__)


class ThreatLevel(Enum):
    """Threat severity levels"""
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"


class AuthLevel(Enum):
    """Authentication levels for rate limiting"""
    AUTHENTICATED = "authenticated"
    API_KEY = "api_key"
    PREMIUM = "premium"


@dataclass
class SecurityThreat:
    """Security threat detection result"""
    threat_type: str
    level: ThreatLevel
    confidence: float
    description: str
    source_ip: str
    user_agent: Optional[str] = None
    request_path: Optional[str] = None
    payload: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.utcnow)
    mitigation: Optional[str] = None


@dataclass
class SecurityAnalysis:
    """Comprehensive security analysis result"""
    is_threat: bool
    threats: List[SecurityThreat]
    risk_score: float
    recommendations: List[str]
    auth_level: AuthLevel
    rate_limit_exceeded: bool
    should_block: bool
    timestamp: datetime = field(default_factory=datetime.utcnow)


@dataclass
class RateLimitInfo:
    """Rate limiting information"""
    auth_level: AuthLevel
    requests_per_minute: int
    requests_per_hour: int
    minute_limit: int
    hour_limit: int
    exceeded: bool


@dataclass
class AnomalyDetection:
    """Anomaly detection result"""
    is_anomaly: bool
    anomaly_type: str
    severity: float
    details: Dict[str, Any]
    baseline_value: Optional[float] = None
    current_value: Optional[float] = None


class ThreatDetectionService:
    """Core threat detection and security analysis service"""

    def __init__(self):
        self.name = "threat_detection"

        # Statistics
        self.stats = {
            'total_requests_analyzed': 0,
            'threats_detected': 0,
            'threats_blocked': 0,
            'anomalies_detected': 0,
            'rate_limits_exceeded': 0,
            'total_analysis_time': 0,
            'threat_types': defaultdict(int),
            'threat_levels': defaultdict(int),
            'attacking_ips': defaultdict(int)
        }

        # Threat detection patterns
        self.sql_injection_patterns = [
            r"(\bunion\b.*\bselect\b)",
            r"(\bselect\b.*\bfrom\b)",
            r"(\binsert\b.*\binto\b)",
            r"(\bupdate\b.*\bset\b)",
            r"(\bdelete\b.*\bfrom\b)",
            r"(\bdrop\b.*\btable\b)",
            r"(\bor\b.*\b1\s*=\s*1\b)",
            r"(\band\b.*\b1\s*=\s*1\b)",
            r"(\bexec\b.*\bxp_\w+)",
            r"(\bsp_\w+)",
            r"(\bsleep\b\s*\(\s*\d+\s*\))",
            r"(\bwaitfor\b.*\bdelay\b)",
            r"(\bbenchmark\b\s*\(\s*\d+)",
            r"(\bload_file\b\s*\()",
            r"(\binto\b.*\boutfile\b)"
        ]

        self.xss_patterns = [
            r"<script[^>]*>.*?</script>",
            r"<iframe[^>]*>.*?</iframe>",
            r"<object[^>]*>.*?</object>",
            r"<embed[^>]*>.*?</embed>",
            r"<link[^>]*>",
            r"<meta[^>]*>",
            r"javascript:",
            r"vbscript:",
            r"on\w+\s*=",
            r"style\s*=.*expression",
            r"style\s*=.*javascript"
        ]

        self.path_traversal_patterns = [
            r"\.\.\/",
            r"\.\.\\",
            r"%2e%2e%2f",
            r"%2e%2e%5c",
            r"..%2f",
            r"..%5c",
            r"%252e%252e%252f",
            r"%252e%252e%255c"
        ]

        self.command_injection_patterns = [
            r";\s*cat\s+",
            r";\s*ls\s+",
            r";\s*pwd\s*",
            r";\s*whoami\s*",
            r";\s*id\s*",
            r";\s*uname\s*",
            r";\s*ps\s+",
            r";\s*netstat\s+",
            r";\s*wget\s+",
            r";\s*curl\s+",
            r"\|\s*cat\s+",
            r"\|\s*ls\s+",
            r"&&\s*cat\s+",
            r"&&\s*ls\s+"
        ]

        self.suspicious_ua_patterns = [
            r"sqlmap",
            r"nikto",
            r"nmap",
            r"masscan",
            r"zap",
            r"burp",
            r"w3af",
            r"acunetix",
            r"nessus",
            r"openvas",
            r"metasploit"
        ]

        # Rate limiting tracking - separate by auth level (excluding unauthenticated since they're blocked)
        self.rate_limits = {
            AuthLevel.AUTHENTICATED: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}),
            AuthLevel.API_KEY: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}),
            AuthLevel.PREMIUM: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)})
        }

        # Anomaly detection
        self.request_history = deque(maxlen=1000)
        self.ip_history = defaultdict(lambda: deque(maxlen=100))
        self.endpoint_history = defaultdict(lambda: deque(maxlen=100))

        # Blocked and allowed IPs
        self.blocked_ips = set(settings.API_BLOCKED_IPS)
        self.allowed_ips = set(settings.API_ALLOWED_IPS) if settings.API_ALLOWED_IPS else None

        # IP reputation cache
        self.ip_reputation_cache = {}
        self.cache_expiry = {}

        # Compile patterns for performance
        self._compile_patterns()

        logger.info(f"ThreatDetectionService initialized with {len(self.sql_injection_patterns)} SQL patterns, "
                    f"{len(self.xss_patterns)} XSS patterns, rate limiting enabled: {settings.API_RATE_LIMITING_ENABLED}")

    def _compile_patterns(self):
        """Compile regex patterns for better performance"""
        try:
            self.compiled_sql_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.sql_injection_patterns]
            self.compiled_xss_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.xss_patterns]
            self.compiled_path_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.path_traversal_patterns]
            self.compiled_cmd_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.command_injection_patterns]
            self.compiled_ua_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.suspicious_ua_patterns]
        except re.error as e:
            logger.error(f"Failed to compile security patterns: {e}")
            # Fallback to empty lists to prevent crashes
            self.compiled_sql_patterns = []
            self.compiled_xss_patterns = []
            self.compiled_path_patterns = []
            self.compiled_cmd_patterns = []
            self.compiled_ua_patterns = []

    def determine_auth_level(self, request: Request, user_context: Optional[Dict] = None) -> AuthLevel:
        """Determine authentication level for rate limiting"""
        # Check if request has API key authentication
        if hasattr(request.state, 'api_key_context') and request.state.api_key_context:
            api_key = request.state.api_key_context.get('api_key')
            if api_key and hasattr(api_key, 'tier'):
                # Check for premium tier
                if api_key.tier in ['premium', 'enterprise']:
                    return AuthLevel.PREMIUM
            return AuthLevel.API_KEY

        # Check for JWT authentication
        if user_context or hasattr(request.state, 'user'):
            return AuthLevel.AUTHENTICATED

        # Check Authorization header for API key
        auth_header = request.headers.get("Authorization", "")
        api_key_header = request.headers.get("X-API-Key", "")
        if auth_header.startswith("Bearer ") or api_key_header:
            return AuthLevel.API_KEY

        # Default to authenticated since unauthenticated requests are blocked at middleware
        return AuthLevel.AUTHENTICATED

    def get_rate_limits(self, auth_level: AuthLevel) -> Tuple[int, int]:
        """Get rate limits for authentication level"""
        if not settings.API_RATE_LIMITING_ENABLED:
            return float('inf'), float('inf')

        if auth_level == AuthLevel.AUTHENTICATED:
            return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
        elif auth_level == AuthLevel.API_KEY:
            return (settings.API_RATE_LIMIT_API_KEY_PER_MINUTE, settings.API_RATE_LIMIT_API_KEY_PER_HOUR)
        elif auth_level == AuthLevel.PREMIUM:
            return (settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE, settings.API_RATE_LIMIT_PREMIUM_PER_HOUR)
        else:
            # Fallback to authenticated limits
            return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)

    def check_rate_limit(self, client_ip: str, auth_level: AuthLevel) -> RateLimitInfo:
        """Check if request exceeds rate limits"""
        minute_limit, hour_limit = self.get_rate_limits(auth_level)
        current_time = time.time()

        # Get or create tracking for this auth level
        if auth_level not in self.rate_limits:
            # This shouldn't happen, but handle gracefully
            return RateLimitInfo(
                auth_level=auth_level,
                requests_per_minute=0,
                requests_per_hour=0,
                minute_limit=minute_limit,
                hour_limit=hour_limit,
                exceeded=False
            )

        ip_limits = self.rate_limits[auth_level][client_ip]

        # Clean old entries
        minute_ago = current_time - 60
        hour_ago = current_time - 3600

        while ip_limits['minute'] and ip_limits['minute'][0] < minute_ago:
            ip_limits['minute'].popleft()

        while ip_limits['hour'] and ip_limits['hour'][0] < hour_ago:
            ip_limits['hour'].popleft()

        # Check current counts
        requests_per_minute = len(ip_limits['minute'])
        requests_per_hour = len(ip_limits['hour'])

        # Check if limits exceeded
        exceeded = (requests_per_minute >= minute_limit) or (requests_per_hour >= hour_limit)

        # Add current request to tracking
        if not exceeded:
            ip_limits['minute'].append(current_time)
            ip_limits['hour'].append(current_time)

        return RateLimitInfo(
            auth_level=auth_level,
            requests_per_minute=requests_per_minute,
            requests_per_hour=requests_per_hour,
            minute_limit=minute_limit,
            hour_limit=hour_limit,
            exceeded=exceeded
        )

    async def analyze_request(self, request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
        """Perform comprehensive security analysis on a request"""
        start_time = time.time()

        try:
            client_ip = request.client.host if request.client else "unknown"
            user_agent = request.headers.get("user-agent", "")
            path = str(request.url.path)
            method = request.method

            # Determine authentication level
            auth_level = self.determine_auth_level(request, user_context)

            # Check IP allowlist/blocklist first
            if self.allowed_ips and client_ip not in self.allowed_ips:
                threat = SecurityThreat(
                    threat_type="ip_not_allowed",
                    level=ThreatLevel.HIGH,
                    confidence=1.0,
                    description=f"IP {client_ip} not in allowlist",
                    source_ip=client_ip,
                    mitigation="Add IP to allowlist or remove IP restrictions"
                )
                return SecurityAnalysis(
                    is_threat=True,
                    threats=[threat],
                    risk_score=1.0,
                    recommendations=["Block request immediately"],
                    auth_level=auth_level,
                    rate_limit_exceeded=False,
                    should_block=True
                )

            if client_ip in self.blocked_ips:
                threat = SecurityThreat(
                    threat_type="ip_blocked",
                    level=ThreatLevel.CRITICAL,
                    confidence=1.0,
                    description=f"IP {client_ip} is blocked",
                    source_ip=client_ip,
                    mitigation="Remove IP from blocklist if legitimate"
                )
                return SecurityAnalysis(
                    is_threat=True,
                    threats=[threat],
                    risk_score=1.0,
                    recommendations=["Block request immediately"],
                    auth_level=auth_level,
                    rate_limit_exceeded=False,
                    should_block=True
                )

            # Check rate limiting
            rate_limit_info = self.check_rate_limit(client_ip, auth_level)
            if rate_limit_info.exceeded:
                self.stats['rate_limits_exceeded'] += 1
                threat = SecurityThreat(
                    threat_type="rate_limit_exceeded",
                    level=ThreatLevel.MEDIUM,
                    confidence=0.9,
                    description=f"Rate limit exceeded for {auth_level.value}: {rate_limit_info.requests_per_minute}/min, {rate_limit_info.requests_per_hour}/hr",
                    source_ip=client_ip,
                    mitigation=f"Implement rate limiting, current limits: {rate_limit_info.minute_limit}/min, {rate_limit_info.hour_limit}/hr"
                )
                return SecurityAnalysis(
                    is_threat=True,
                    threats=[threat],
                    risk_score=0.7,
                    recommendations=[f"Rate limit exceeded for {auth_level.value} user"],
                    auth_level=auth_level,
                    rate_limit_exceeded=True,
                    should_block=True
                )

            # Skip threat detection if disabled
            if not settings.API_THREAT_DETECTION_ENABLED:
                return SecurityAnalysis(
                    is_threat=False,
                    threats=[],
                    risk_score=0.0,
                    recommendations=[],
                    auth_level=auth_level,
                    rate_limit_exceeded=False,
                    should_block=False
                )

            # Collect request data for threat analysis
            query_params = str(request.query_params)
            headers = dict(request.headers)

            # Try to get body content safely
            body_content = ""
            try:
                if hasattr(request, '_body') and request._body:
                    body_content = request._body.decode() if isinstance(request._body, bytes) else str(request._body)
            except:
                pass

            threats = []

            # Analyze for various threats
            threats.extend(await self._detect_sql_injection(query_params, body_content, path, client_ip))
            threats.extend(await self._detect_xss(query_params, body_content, headers, client_ip))
            threats.extend(await self._detect_path_traversal(path, query_params, client_ip))
            threats.extend(await self._detect_command_injection(query_params, body_content, client_ip))
            threats.extend(await self._detect_suspicious_patterns(headers, user_agent, path, client_ip))

            # Anomaly detection if enabled
            if settings.API_ANOMALY_DETECTION_ENABLED:
                anomaly = await self._detect_anomalies(client_ip, path, method, len(body_content))
                if anomaly.is_anomaly and anomaly.severity > settings.API_SECURITY_ANOMALY_THRESHOLD:
                    threat = SecurityThreat(
                        threat_type=f"anomaly_{anomaly.anomaly_type}",
                        level=ThreatLevel.MEDIUM if anomaly.severity > 0.7 else ThreatLevel.LOW,
                        confidence=anomaly.severity,
                        description=f"Anomalous behavior detected: {anomaly.details}",
                        source_ip=client_ip,
                        user_agent=user_agent,
                        request_path=path
                    )
                    threats.append(threat)

            # Calculate risk score
            risk_score = self._calculate_risk_score(threats)

            # Determine if request should be blocked
            should_block = risk_score >= settings.API_SECURITY_RISK_THRESHOLD

            # Generate recommendations
            recommendations = self._generate_recommendations(threats, risk_score, auth_level)

            # Update statistics
            self._update_stats(threats, time.time() - start_time)

            return SecurityAnalysis(
                is_threat=len(threats) > 0,
                threats=threats,
                risk_score=risk_score,
                recommendations=recommendations,
                auth_level=auth_level,
                rate_limit_exceeded=False,
                should_block=should_block
            )

        except Exception as e:
            logger.error(f"Error in threat analysis: {e}")
            return SecurityAnalysis(
                is_threat=False,
                threats=[],
                risk_score=0.0,
                recommendations=["Error occurred during security analysis"],
                auth_level=AuthLevel.AUTHENTICATED,
                rate_limit_exceeded=False,
                should_block=False
            )

    async def _detect_sql_injection(self, query_params: str, body_content: str, path: str, client_ip: str) -> List[SecurityThreat]:
        """Detect SQL injection attempts"""
        threats = []
        content_to_check = f"{query_params} {body_content} {path}".lower()

        for pattern in self.compiled_sql_patterns:
            if pattern.search(content_to_check):
                threat = SecurityThreat(
                    threat_type="sql_injection",
                    level=ThreatLevel.HIGH,
                    confidence=0.85,
                    description="Potential SQL injection attempt detected",
                    source_ip=client_ip,
                    payload=pattern.pattern,
                    mitigation="Block request, sanitize input, use parameterized queries"
                )
                threats.append(threat)
                break  # Don't duplicate for multiple patterns

        return threats

    async def _detect_xss(self, query_params: str, body_content: str, headers: dict, client_ip: str) -> List[SecurityThreat]:
        """Detect XSS attempts"""
        threats = []
        content_to_check = f"{query_params} {body_content}".lower()

        # Check headers for XSS
        for header_name, header_value in headers.items():
            content_to_check += f" {header_value}".lower()

        for pattern in self.compiled_xss_patterns:
            if pattern.search(content_to_check):
                threat = SecurityThreat(
                    threat_type="xss",
                    level=ThreatLevel.HIGH,
                    confidence=0.80,
                    description="Potential XSS attack detected",
                    source_ip=client_ip,
                    payload=pattern.pattern,
                    mitigation="Block request, sanitize input, implement CSP headers"
                )
                threats.append(threat)
                break

        return threats

    async def _detect_path_traversal(self, path: str, query_params: str, client_ip: str) -> List[SecurityThreat]:
        """Detect path traversal attempts"""
        threats = []
        content_to_check = f"{path} {query_params}".lower()
        decoded_content = unquote(content_to_check)

        for pattern in self.compiled_path_patterns:
            if pattern.search(content_to_check) or pattern.search(decoded_content):
                threat = SecurityThreat(
                    threat_type="path_traversal",
                    level=ThreatLevel.HIGH,
                    confidence=0.90,
                    description="Path traversal attempt detected",
                    source_ip=client_ip,
                    request_path=path,
                    mitigation="Block request, validate file paths, implement access controls"
                )
                threats.append(threat)
                break

        return threats

    async def _detect_command_injection(self, query_params: str, body_content: str, client_ip: str) -> List[SecurityThreat]:
        """Detect command injection attempts"""
        threats = []
        content_to_check = f"{query_params} {body_content}".lower()

        for pattern in self.compiled_cmd_patterns:
            if pattern.search(content_to_check):
                threat = SecurityThreat(
                    threat_type="command_injection",
                    level=ThreatLevel.CRITICAL,
                    confidence=0.95,
                    description="Command injection attempt detected",
                    source_ip=client_ip,
                    payload=pattern.pattern,
                    mitigation="Block request immediately, sanitize input, disable shell execution"
                )
                threats.append(threat)
                break

        return threats

    async def _detect_suspicious_patterns(self, headers: dict, user_agent: str, path: str, client_ip: str) -> List[SecurityThreat]:
        """Detect suspicious patterns in headers and user agent"""
        threats = []

        # Check for suspicious user agents
        ua_lower = user_agent.lower()
        for pattern in self.compiled_ua_patterns:
            if pattern.search(ua_lower):
                threat = SecurityThreat(
                    threat_type="suspicious_user_agent",
                    level=ThreatLevel.HIGH,
                    confidence=0.85,
                    description=f"Suspicious user agent detected: {pattern.pattern}",
                    source_ip=client_ip,
                    user_agent=user_agent,
                    mitigation="Block request, monitor IP for further activity"
                )
                threats.append(threat)
                break

        # Check for suspicious headers
        if "x-forwarded-for" in headers and "x-real-ip" in headers:
            # Potential header manipulation
            threat = SecurityThreat(
                threat_type="header_manipulation",
                level=ThreatLevel.LOW,
                confidence=0.30,
                description="Potential IP header manipulation detected",
                source_ip=client_ip,
                mitigation="Validate proxy headers, implement IP whitelisting"
            )
            threats.append(threat)

        return threats

    async def _detect_anomalies(self, client_ip: str, path: str, method: str, body_size: int) -> AnomalyDetection:
        """Detect anomalous behavior patterns"""
        try:
            # Request size anomaly
            max_size = settings.API_MAX_REQUEST_BODY_SIZE
            if body_size > max_size:
                return AnomalyDetection(
                    is_anomaly=True,
                    anomaly_type="request_size",
                    severity=0.8,
                    details={"body_size": body_size, "threshold": max_size},
                    current_value=body_size,
                    baseline_value=max_size // 10
                )

            # Unusual endpoint access
            if path.startswith("/admin") or path.startswith("/api/admin"):
                return AnomalyDetection(
                    is_anomaly=True,
                    anomaly_type="sensitive_endpoint",
                    severity=0.6,
                    details={"path": path, "reason": "admin endpoint access"},
                    current_value=1.0,
                    baseline_value=0.0
                )

            # IP request frequency anomaly
            current_time = time.time()
            ip_requests = self.ip_history[client_ip]

            # Clean old entries (last 5 minutes)
            five_minutes_ago = current_time - 300
            while ip_requests and ip_requests[0] < five_minutes_ago:
                ip_requests.popleft()

            ip_requests.append(current_time)

            if len(ip_requests) > 100:  # More than 100 requests in 5 minutes
                return AnomalyDetection(
                    is_anomaly=True,
                    anomaly_type="request_frequency",
                    severity=0.7,
                    details={"requests_5min": len(ip_requests), "threshold": 100},
                    current_value=len(ip_requests),
                    baseline_value=10  # 10 requests baseline
                )

            return AnomalyDetection(
                is_anomaly=False,
                anomaly_type="none",
                severity=0.0,
                details={}
            )

        except Exception as e:
            logger.error(f"Error in anomaly detection: {e}")
            return AnomalyDetection(
                is_anomaly=False,
                anomaly_type="error",
                severity=0.0,
                details={"error": str(e)}
            )

    def _calculate_risk_score(self, threats: List[SecurityThreat]) -> float:
        """Calculate overall risk score based on threats"""
        if not threats:
            return 0.0

        score = 0.0
        for threat in threats:
            level_multiplier = {
                ThreatLevel.LOW: 0.25,
                ThreatLevel.MEDIUM: 0.5,
                ThreatLevel.HIGH: 0.75,
                ThreatLevel.CRITICAL: 1.0
            }
            score += threat.confidence * level_multiplier.get(threat.level, 0.5)

        # Normalize to 0-1 range
        return min(score / len(threats), 1.0)

    def _generate_recommendations(self, threats: List[SecurityThreat], risk_score: float, auth_level: AuthLevel) -> List[str]:
        """Generate security recommendations based on analysis"""
        recommendations = []

        if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
            recommendations.append("CRITICAL: Block this request immediately")
        elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
            recommendations.append("HIGH: Consider blocking or rate limiting this IP")
        elif risk_score > 0.4:
            recommendations.append("MEDIUM: Monitor this IP closely")

        threat_types = {threat.threat_type for threat in threats}

        if "sql_injection" in threat_types:
            recommendations.append("Implement parameterized queries and input validation")

        if "xss" in threat_types:
            recommendations.append("Implement Content Security Policy (CSP) headers")
|
||||
|
||||
if "command_injection" in threat_types:
|
||||
recommendations.append("Disable shell execution and validate all inputs")
|
||||
|
||||
if "path_traversal" in threat_types:
|
||||
recommendations.append("Implement proper file path validation and access controls")
|
||||
|
||||
if "rate_limit_exceeded" in threat_types:
|
||||
recommendations.append(f"Rate limiting active for {auth_level.value} user")
|
||||
|
||||
if not recommendations:
|
||||
recommendations.append("No immediate action required, continue monitoring")
|
||||
|
||||
return recommendations
|
||||
|
||||
def _update_stats(self, threats: List[SecurityThreat], analysis_time: float):
|
||||
"""Update service statistics"""
|
||||
self.stats['total_requests_analyzed'] += 1
|
||||
self.stats['total_analysis_time'] += analysis_time
|
||||
|
||||
if threats:
|
||||
self.stats['threats_detected'] += len(threats)
|
||||
for threat in threats:
|
||||
self.stats['threat_types'][threat.threat_type] += 1
|
||||
self.stats['threat_levels'][threat.level.value] += 1
|
||||
if threat.source_ip:
|
||||
self.stats['attacking_ips'][threat.source_ip] += 1
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get service statistics"""
|
||||
avg_time = (self.stats['total_analysis_time'] / self.stats['total_requests_analyzed']
|
||||
if self.stats['total_requests_analyzed'] > 0 else 0)
|
||||
|
||||
# Get top attacking IPs
|
||||
top_ips = sorted(self.stats['attacking_ips'].items(), key=lambda x: x[1], reverse=True)[:10]
|
||||
|
||||
return {
|
||||
"total_requests_analyzed": self.stats['total_requests_analyzed'],
|
||||
"threats_detected": self.stats['threats_detected'],
|
||||
"threats_blocked": self.stats['threats_blocked'],
|
||||
"anomalies_detected": self.stats['anomalies_detected'],
|
||||
"rate_limits_exceeded": self.stats['rate_limits_exceeded'],
|
||||
"avg_analysis_time": avg_time,
|
||||
"threat_types": dict(self.stats['threat_types']),
|
||||
"threat_levels": dict(self.stats['threat_levels']),
|
||||
"top_attacking_ips": top_ips,
|
||||
"security_enabled": settings.API_SECURITY_ENABLED,
|
||||
"threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
|
||||
"rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED
|
||||
}
|
||||
|
||||
|
||||
# Global threat detection service instance
|
||||
threat_detection_service = ThreatDetectionService()
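
Editor's note: for orientation, a minimal standalone sketch of the arithmetic in _calculate_risk_score above; the Threat stand-in below is hypothetical, only the weights and the averaging mirror the real method.

from dataclasses import dataclass
from enum import Enum

class Level(Enum):
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

@dataclass
class Threat:  # hypothetical stand-in for SecurityThreat
    confidence: float
    level: Level

WEIGHTS = {Level.LOW: 0.25, Level.MEDIUM: 0.5, Level.HIGH: 0.75, Level.CRITICAL: 1.0}

def risk_score(threats):
    # Average of confidence * level weight, clamped to [0, 1], as in the service.
    if not threats:
        return 0.0
    total = sum(t.confidence * WEIGHTS.get(t.level, 0.5) for t in threats)
    return min(total / len(threats), 1.0)

# One HIGH threat at 0.80 confidence scores 0.80 * 0.75:
print(risk_score([Threat(0.80, Level.HIGH)]))  # ~0.6
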
@@ -52,10 +52,18 @@ async def lifespan(app: FastAPI):

    # Initialize config manager
    await init_config_manager()

    # Initialize LLM service (needed by RAG module)
    from app.services.llm.service import llm_service
    try:
        await llm_service.initialize()
        logger.info("LLM service initialized successfully")
    except Exception as e:
        logger.warning(f"LLM service initialization failed: {e}")

    # Initialize analytics service
    init_analytics_service()

    # Initialize module manager with FastAPI app for router registration
    await module_manager.initialize(app)
    app.state.module_manager = module_manager
@@ -1,371 +0,0 @@
"""
Rate limiting middleware
"""

import time
import redis
from typing import Dict, Optional
from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import asyncio
from datetime import datetime, timedelta

from app.core.config import settings
from app.core.logging import get_logger

logger = get_logger(__name__)


class RateLimiter:
    """Rate limiting implementation using Redis"""

    def __init__(self):
        try:
            self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
            self.redis_client.ping()  # Test connection
            logger.info("Rate limiter initialized with Redis backend")
        except Exception as e:
            logger.warning(f"Redis not available for rate limiting: {e}")
            self.redis_client = None
            # Fall back to in-memory rate limiting
            self.memory_store: Dict[str, Dict[str, float]] = {}

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        identifier: str = "default"
    ) -> tuple[bool, Dict[str, int]]:
        """
        Check if request is within rate limit

        Args:
            key: Rate limiting key (e.g., IP address, API key)
            limit: Maximum number of requests allowed
            window_seconds: Time window in seconds
            identifier: Additional identifier for the rate limit

        Returns:
            Tuple of (is_allowed, headers_dict)
        """
        full_key = f"rate_limit:{identifier}:{key}"
        current_time = int(time.time())
        window_start = current_time - window_seconds

        if self.redis_client:
            return await self._check_redis_rate_limit(
                full_key, limit, window_seconds, current_time, window_start
            )
        else:
            return self._check_memory_rate_limit(
                full_key, limit, window_seconds, current_time, window_start
            )

    async def _check_redis_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        current_time: int,
        window_start: int
    ) -> tuple[bool, Dict[str, int]]:
        """Check rate limit using Redis"""
        pipe = self.redis_client.pipeline()

        # Remove old entries
        pipe.zremrangebyscore(key, 0, window_start)

        # Count current requests in window
        pipe.zcard(key)

        # Add current request
        pipe.zadd(key, {str(current_time): current_time})

        # Set expiration
        pipe.expire(key, window_seconds + 1)

        results = pipe.execute()
        current_requests = results[1]

        # Calculate remaining requests and reset time
        remaining = max(0, limit - current_requests - 1)
        reset_time = current_time + window_seconds

        headers = {
            "X-RateLimit-Limit": limit,
            "X-RateLimit-Remaining": remaining,
            "X-RateLimit-Reset": reset_time,
            "X-RateLimit-Window": window_seconds
        }

        is_allowed = current_requests < limit

        if not is_allowed:
            logger.warning(f"Rate limit exceeded for key: {key}")

        return is_allowed, headers

    def _check_memory_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        current_time: int,
        window_start: int
    ) -> tuple[bool, Dict[str, int]]:
        """Check rate limit using in-memory storage"""
        if key not in self.memory_store:
            self.memory_store[key] = {}

        # Clean old entries
        store = self.memory_store[key]
        keys_to_remove = [k for k, v in store.items() if v < window_start]
        for k in keys_to_remove:
            del store[k]

        current_requests = len(store)

        # Calculate remaining requests and reset time
        remaining = max(0, limit - current_requests - 1)
        reset_time = current_time + window_seconds

        headers = {
            "X-RateLimit-Limit": limit,
            "X-RateLimit-Remaining": remaining,
            "X-RateLimit-Reset": reset_time,
            "X-RateLimit-Window": window_seconds
        }

        is_allowed = current_requests < limit

        if is_allowed:
            # Add current request
            store[str(current_time)] = current_time
        else:
            logger.warning(f"Rate limit exceeded for key: {key}")

        return is_allowed, headers


# Global rate limiter instance
rate_limiter = RateLimiter()
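
Editor's note: a quick usage sketch of the sliding-window API above; it must run inside an async function, and the key and limits here are illustrative, not values the middleware actually uses.

# Hypothetical call: allow at most 20 requests per minute for one client IP.
is_allowed, headers = await rate_limiter.check_rate_limit(
    key="ip:203.0.113.7",  # illustrative key
    limit=20,
    window_seconds=60,
    identifier="minute",
)
if not is_allowed:
    print("429, retry after", headers["X-RateLimit-Reset"])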


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware for FastAPI"""

    def __init__(self, app):
        super().__init__(app)
        self.rate_limiter = RateLimiter()
        logger.info("RateLimitMiddleware initialized")

    async def dispatch(self, request: Request, call_next):
        """Process request through rate limiting"""
        # Skip rate limiting if disabled in settings
        if not settings.API_RATE_LIMITING_ENABLED:
            response = await call_next(request)
            return response

        # Skip rate limiting for all internal API endpoints (platform operations)
        if request.url.path.startswith("/api-internal/v1/"):
            response = await call_next(request)
            return response

        # Only apply rate limiting to privatemode.ai proxy endpoints (OpenAI-compatible API and LLM service)
        # Skip for all other endpoints
        if not (request.url.path.startswith("/api/v1/chat/completions") or
                request.url.path.startswith("/api/v1/embeddings") or
                request.url.path.startswith("/api/v1/models") or
                request.url.path.startswith("/api/v1/llm/")):
            response = await call_next(request)
            return response

        # Skip rate limiting for health checks and static files
        if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
            response = await call_next(request)
            return response

        # Get client IP
        client_ip = request.client.host
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            client_ip = forwarded_for.split(",")[0].strip()

        # Check for API key in headers
        api_key = None
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            api_key = auth_header[7:]
        elif request.headers.get("X-API-Key"):
            api_key = request.headers.get("X-API-Key")

        # Determine rate limiting strategy
        headers = {}
        is_allowed = True

        if api_key:
            # API key-based rate limiting
            api_key_key = f"api_key:{api_key}"

            # First check organization-wide limits (PrivateMode limits are org-wide)
            org_key = "organization:privatemode"

            # Check organization per-minute limit
            org_allowed_minute, org_headers_minute = await self.rate_limiter.check_rate_limit(
                org_key, settings.PRIVATEMODE_REQUESTS_PER_MINUTE, 60, "minute"
            )

            # Check organization per-hour limit
            org_allowed_hour, org_headers_hour = await self.rate_limiter.check_rate_limit(
                org_key, settings.PRIVATEMODE_REQUESTS_PER_HOUR, 3600, "hour"
            )

            # If organization limits are exceeded, return 429
            if not (org_allowed_minute and org_allowed_hour):
                logger.warning(f"Organization rate limit exceeded for {org_key}")
                return JSONResponse(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    content={"detail": "Organization rate limit exceeded"},
                    headers=org_headers_minute
                )

            # Then check per-API key limits
            limit_per_minute = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE
            limit_per_hour = settings.API_RATE_LIMIT_API_KEY_PER_HOUR

            # Check per-minute limit
            is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
                api_key_key, limit_per_minute, 60, "minute"
            )

            # Check per-hour limit
            is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
                api_key_key, limit_per_hour, 3600, "hour"
            )

            is_allowed = is_allowed_minute and is_allowed_hour
            headers = headers_minute  # Use minute headers for response

        else:
            # IP-based rate limiting for unauthenticated requests
            rate_limit_key = f"ip:{client_ip}"

            # More restrictive limits for unauthenticated requests
            limit_per_minute = 20  # Hardcoded for unauthenticated users
            limit_per_hour = 100

            # Check per-minute limit
            is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
                rate_limit_key, limit_per_minute, 60, "minute"
            )

            # Check per-hour limit
            is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
                rate_limit_key, limit_per_hour, 3600, "hour"
            )

            is_allowed = is_allowed_minute and is_allowed_hour
            headers = headers_minute  # Use minute headers for response

        # If rate limit exceeded, return 429
        if not is_allowed:
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "error": "RATE_LIMIT_EXCEEDED",
                    "message": "Rate limit exceeded. Please try again later.",
                    "details": {
                        "limit": headers["X-RateLimit-Limit"],
                        "reset_time": headers["X-RateLimit-Reset"]
                    }
                },
                headers={k: str(v) for k, v in headers.items()}
            )

        # Continue with request
        response = await call_next(request)

        # Add rate limit headers to response
        for key, value in headers.items():
            response.headers[key] = str(value)

        return response


# Keep the old function for backward compatibility
async def rate_limit_middleware(request: Request, call_next):
    """Legacy function - use RateLimitMiddleware class instead"""
    middleware = RateLimitMiddleware(None)
    return await middleware.dispatch(request, call_next)


class RateLimitExceeded(HTTPException):
    """Exception raised when rate limit is exceeded"""

    def __init__(self, limit: int, reset_time: int):
        super().__init__(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Rate limit exceeded. Limit: {limit}, Reset: {reset_time}"
        )


# Decorator for applying rate limits to specific endpoints
def rate_limit(requests_per_minute: int = 60, requests_per_hour: int = 1000):
    """
    Decorator to apply rate limiting to specific endpoints

    Args:
        requests_per_minute: Maximum requests per minute
        requests_per_hour: Maximum requests per hour
    """
    def decorator(func):
        async def wrapper(*args, **kwargs):
            # This would be implemented to work with FastAPI dependencies
            # For now, this is a placeholder for endpoint-specific rate limiting
            return await func(*args, **kwargs)
        return wrapper
    return decorator
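
Editor's note: one way the placeholder above could be fleshed out, as a hedged sketch; the keying on the endpoint's Request argument is an assumption, and endpoints wired differently would need their own lookup.

import functools

def rate_limit_enforcing(requests_per_minute: int = 60):
    """Hypothetical enforcing variant of the rate_limit decorator."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Assumes the endpoint receives the Request object (positionally or by name).
            request = kwargs.get("request") or next(
                (a for a in args if isinstance(a, Request)), None
            )
            if request is not None:
                key = f"ip:{request.client.host}"
                allowed, hdrs = await rate_limiter.check_rate_limit(
                    key, requests_per_minute, 60, identifier=f"endpoint:{func.__name__}"
                )
                if not allowed:
                    raise RateLimitExceeded(requests_per_minute, hdrs["X-RateLimit-Reset"])
            return await func(*args, **kwargs)
        return wrapper
    return decorator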


# Helper functions for different rate limiting strategies
async def check_api_key_rate_limit(api_key: str, endpoint: str) -> bool:
    """Check rate limit for specific API key and endpoint"""
    # This would look up API key-specific limits from the database
    # For now, using default limits
    key = f"api_key:{api_key}:endpoint:{endpoint}"

    is_allowed, _ = await rate_limiter.check_rate_limit(
        key, limit=100, window_seconds=60, identifier="endpoint"
    )

    return is_allowed


async def check_user_rate_limit(user_id: str, action: str) -> bool:
    """Check rate limit for specific user and action"""
    key = f"user:{user_id}:action:{action}"

    is_allowed, _ = await rate_limiter.check_rate_limit(
        key, limit=50, window_seconds=60, identifier="user_action"
    )

    return is_allowed


async def apply_burst_protection(key: str) -> bool:
    """Apply burst protection for high-frequency actions"""
    # Allow a burst of 10 requests in 10 seconds
    is_allowed, _ = await rate_limiter.check_rate_limit(
        key, limit=10, window_seconds=10, identifier="burst"
    )

    return is_allowed
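
Editor's note: these helpers compose; a hedged sketch of layering burst protection over a per-action limit (the "export" action name is illustrative, not an action the codebase defines).

async def allow_export(user_id: str) -> bool:
    # Both checks must pass: short burst window first, then the per-minute action limit.
    if not await apply_burst_protection(f"user:{user_id}:export"):
        return False
    return await check_user_rate_limit(user_id, "export")
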
@@ -1,210 +0,0 @@
"""
Security middleware for request/response processing
"""

import json
import time
from typing import Callable, Optional, Dict, Any

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.config import settings
from app.core.logging import get_logger
from app.core.threat_detection import threat_detection_service, SecurityAnalysis

logger = get_logger(__name__)


class SecurityMiddleware(BaseHTTPMiddleware):
    """Security middleware for threat detection and request filtering - DISABLED"""

    def __init__(self, app, enabled: bool = True):
        super().__init__(app)
        self.enabled = False  # Force disable regardless of settings
        logger.info("SecurityMiddleware initialized, enabled: False (DISABLED)")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request through security analysis - DISABLED"""
        # Security disabled, always pass through
        return await call_next(request)

    def _should_skip_security(self, request: Request) -> bool:
        """Determine if security analysis should be skipped for this request"""
        path = request.url.path

        # Skip for health checks, authentication endpoints, and static assets
        skip_paths = [
            "/health",
            "/metrics",
            "/api/v1/docs",
            "/api/v1/openapi.json",
            "/api/v1/redoc",
            "/favicon.ico",
            "/api/v1/auth/register",
            "/api/v1/auth/login",
            "/api/v1/auth/refresh",  # Allow refresh endpoint
            "/api-internal/v1/auth/register",
            "/api-internal/v1/auth/login",
            "/api-internal/v1/auth/refresh",  # Allow refresh endpoint for internal API
            "/",  # Root endpoint
        ]

        # Skip for static file extensions
        static_extensions = [".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".woff", ".woff2"]

        return (
            path in skip_paths or
            any(path.endswith(ext) for ext in static_extensions) or
            path.startswith("/static/")
        )

    def _has_valid_auth(self, request: Request) -> bool:
        """Check if request has valid authentication"""
        # Check Authorization header
        auth_header = request.headers.get("Authorization", "")
        api_key_header = request.headers.get("X-API-Key", "")

        # Has some form of auth token/key
        return (
            auth_header.startswith("Bearer ") and len(auth_header) > 7 or
            len(api_key_header.strip()) > 0
        )

    def _create_block_response(self, analysis: SecurityAnalysis) -> JSONResponse:
        """Create response for blocked requests"""
        # Determine status code based on threat type
        status_code = 403  # Forbidden by default

        # Critical threats get 403
        for threat in analysis.threats:
            if threat.threat_type in ["command_injection", "sql_injection"]:
                status_code = 403
                break

        response_data = {
            "error": "Security Policy Violation",
            "message": "Request blocked due to security policy violation",
            "risk_score": round(analysis.risk_score, 3),
            "auth_level": analysis.auth_level.value,
            "threat_count": len(analysis.threats),
            "recommendations": analysis.recommendations[:3]  # Limit to first 3 recommendations
        }

        response = JSONResponse(
            content=response_data,
            status_code=status_code
        )

        return response

    def _add_security_headers(self, response: Response) -> Response:
        """Add security headers to response"""
        if not settings.API_SECURITY_HEADERS_ENABLED:
            return response

        # Standard security headers
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        # Only add HSTS for HTTPS
        if hasattr(response, 'headers') and response.headers.get("X-Forwarded-Proto") == "https":
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        # Content Security Policy
        if settings.API_CSP_HEADER:
            response.headers["Content-Security-Policy"] = settings.API_CSP_HEADER

        return response

    def _add_security_metrics(self, response: Response, analysis: SecurityAnalysis, analysis_time: float) -> Response:
        """Add security metrics to response headers (for debugging/monitoring)"""
        # Only add in debug mode or for admin users
        if settings.APP_DEBUG:
            response.headers["X-Security-Risk-Score"] = str(round(analysis.risk_score, 3))
            response.headers["X-Security-Threats"] = str(len(analysis.threats))
            response.headers["X-Security-Auth-Level"] = analysis.auth_level.value
            response.headers["X-Security-Analysis-Time"] = f"{analysis_time*1000:.1f}ms"

        return response

    async def _log_security_event(self, request: Request, analysis: SecurityAnalysis):
        """Log security events for audit and monitoring"""
        client_ip = request.client.host if request.client else "unknown"
        user_agent = request.headers.get("user-agent", "")

        # Create security event log
        event_data = {
            "timestamp": analysis.timestamp.isoformat(),
            "client_ip": client_ip,
            "user_agent": user_agent,
            "path": str(request.url.path),
            "method": request.method,
            "risk_score": round(analysis.risk_score, 3),
            "auth_level": analysis.auth_level.value,
            "threat_count": len(analysis.threats),
            "rate_limit_exceeded": analysis.rate_limit_exceeded,
            "should_block": analysis.should_block,
            "threats": [
                {
                    "type": threat.threat_type,
                    "level": threat.level.value,
                    "confidence": round(threat.confidence, 3),
                    "description": threat.description
                }
                for threat in analysis.threats[:5]  # Limit to first 5 threats
            ],
            "recommendations": analysis.recommendations
        }

        # Log at appropriate level based on risk
        if analysis.should_block:
            logger.warning(f"SECURITY_BLOCK: {json.dumps(event_data)}")
        elif analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
            logger.warning(f"SECURITY_WARNING: {json.dumps(event_data)}")
        else:
            logger.info(f"SECURITY_THREAT: {json.dumps(event_data)}")


def setup_security_middleware(app, enabled: bool = True) -> None:
    """Setup security middleware on FastAPI app"""
    if enabled and settings.API_SECURITY_ENABLED:
        app.add_middleware(SecurityMiddleware, enabled=enabled)
        logger.info("Security middleware enabled")
    else:
        logger.info("Security middleware disabled")


# Helper functions for manual security checks
async def analyze_request_security(request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
    """Manually analyze request security (for use in route handlers)"""
    return await threat_detection_service.analyze_request(request, user_context)


def get_security_stats() -> Dict[str, Any]:
    """Get security statistics"""
    return threat_detection_service.get_stats()


def is_request_blocked(request: Request) -> bool:
    """Check if request was blocked by security analysis"""
    if hasattr(request.state, 'security_analysis'):
        return request.state.security_analysis.should_block
    return False


def get_request_risk_score(request: Request) -> float:
    """Get risk score for request"""
    if hasattr(request.state, 'security_analysis'):
        return request.state.security_analysis.risk_score
    return 0.0


def get_request_auth_level(request: Request) -> str:
    """Get authentication level for request"""
    if hasattr(request.state, 'security_analysis'):
        return request.state.security_analysis.auth_level.value
    return "unknown"
@@ -162,6 +162,7 @@ class DocumentProcessor:

    async def _process_document(self, task: ProcessingTask) -> bool:
        """Process a single document"""
        from datetime import datetime
        from app.db.database import async_session_factory
        async with async_session_factory() as session:
            try:
@@ -182,16 +183,24 @@
                document.status = ProcessingStatus.PROCESSING
                await session.commit()

                # Get RAG module for processing (now includes content processing)
                # Get RAG module for processing
                try:
                    from app.services.module_manager import module_manager
                    rag_module = module_manager.get_module('rag')
                    # Import RAG module and initialize it properly
                    from modules.rag.main import RAGModule
                    from app.core.config import settings

                    # Create and initialize RAG module instance
                    rag_module = RAGModule(settings)
                    init_result = await rag_module.initialize()
                    if not rag_module.enabled:
                        raise Exception("Failed to enable RAG module")

                except Exception as e:
                    logger.error(f"Failed to get RAG module: {e}")
                    raise Exception(f"RAG module not available: {e}")

                if not rag_module:
                    raise Exception("RAG module not available")

                if not rag_module or not rag_module.enabled:
                    raise Exception("RAG module not available or not enabled")

                logger.info(f"RAG module loaded successfully for document {task.document_id}")

@@ -204,31 +213,45 @@

                # Process with RAG module
                logger.info(f"Starting document processing for document {task.document_id} with RAG module")
                try:
                    # Add timeout to prevent hanging
                    processed_doc = await asyncio.wait_for(
                        rag_module.process_document(
                            file_content,
                            document.original_filename,
                            {}
                        ),
                        timeout=300.0  # 5 minute timeout
                    )
                    logger.info(f"Document processing completed for document {task.document_id}")
                except asyncio.TimeoutError:
                    logger.error(f"Document processing timed out for document {task.document_id}")
                    raise Exception("Document processing timed out after 5 minutes")
                except Exception as e:
                    logger.error(f"Document processing failed for document {task.document_id}: {e}")
                    raise

                # Update document with processed content
                document.converted_content = processed_doc.content
                document.word_count = processed_doc.word_count
                document.character_count = len(processed_doc.content)
                document.document_metadata = processed_doc.metadata
                document.status = ProcessingStatus.PROCESSED
                document.processed_at = datetime.utcnow()

                # Special handling for JSONL files - skip processing phase
                if document.file_type == 'jsonl':
                    # For JSONL files, we don't need to process content here
                    # The optimized JSONL processor will handle everything during indexing
                    document.converted_content = f"JSONL file with {len(file_content)} bytes"
                    document.word_count = 0  # Will be updated during indexing
                    document.character_count = len(file_content)
                    document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"}
                    document.status = ProcessingStatus.PROCESSED
                    document.processed_at = datetime.utcnow()
                    logger.info(f"JSONL document {task.document_id} marked for optimized processing")
                else:
                    # Standard processing for other file types
                    try:
                        # Add timeout to prevent hanging
                        processed_doc = await asyncio.wait_for(
                            rag_module.process_document(
                                file_content,
                                document.original_filename,
                                {"file_path": document.file_path}
                            ),
                            timeout=300.0  # 5 minute timeout
                        )
                        logger.info(f"Document processing completed for document {task.document_id}")

                        # Update document with processed content
                        document.converted_content = processed_doc.content
                        document.word_count = processed_doc.word_count
                        document.character_count = len(processed_doc.content)
                        document.document_metadata = processed_doc.metadata
                        document.status = ProcessingStatus.PROCESSED
                        document.processed_at = datetime.utcnow()
                    except asyncio.TimeoutError:
                        logger.error(f"Document processing timed out for document {task.document_id}")
                        raise Exception("Document processing timed out after 5 minutes")
                    except Exception as e:
                        logger.error(f"Document processing failed for document {task.document_id}: {e}")
                        raise

                # Index in RAG system using same RAG module
                if rag_module and document.converted_content:
@@ -245,14 +268,57 @@
                    }

                    # Use the correct Qdrant collection name for this document
                    await asyncio.wait_for(
                        rag_module.index_document(
                            content=document.converted_content,
                            metadata=doc_metadata,
                            collection_name=document.collection.qdrant_collection_name
                        ),
                        timeout=120.0  # 2 minute timeout for indexing
                    )
                    # For JSONL files, we need to use the processed document flow
                    if document.file_type == 'jsonl':
                        # Create a ProcessedDocument for the JSONL processor
                        from app.modules.rag.main import ProcessedDocument
                        from datetime import datetime
                        import hashlib

                        # Calculate file hash
                        processed_at = datetime.utcnow()
                        file_hash = hashlib.md5(str(document.id).encode()).hexdigest()

                        processed_doc = ProcessedDocument(
                            id=str(document.id),
                            content="",  # Will be filled by JSONL processor
                            extracted_text="",  # Will be filled by JSONL processor
                            metadata={
                                **doc_metadata,
                                "file_path": document.file_path
                            },
                            original_filename=document.original_filename,
                            file_type=document.file_type,
                            mime_type=document.mime_type,
                            language=document.document_metadata.get('language', 'EN'),
                            word_count=0,  # Will be updated during processing
                            sentence_count=0,  # Will be updated during processing
                            entities=[],
                            keywords=[],
                            processing_time=0.0,
                            processed_at=processed_at,
                            file_hash=file_hash,
                            file_size=document.file_size
                        )

                        # The JSONL processor will read the original file
                        await asyncio.wait_for(
                            rag_module.index_processed_document(
                                processed_doc=processed_doc,
                                collection_name=document.collection.qdrant_collection_name
                            ),
                            timeout=300.0  # 5 minute timeout for JSONL processing
                        )
                    else:
                        # Use standard indexing for other file types
                        await asyncio.wait_for(
                            rag_module.index_document(
                                content=document.converted_content,
                                metadata=doc_metadata,
                                collection_name=document.collection.qdrant_collection_name
                            ),
                            timeout=120.0  # 2 minute timeout for indexing
                        )

                    logger.info(f"Document {task.document_id} indexed successfully in collection {document.collection.qdrant_collection_name}")

@@ -271,7 +337,9 @@

                except Exception as e:
                    logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
                    # Keep as processed even if indexing fails
                    # Mark as error since indexing failed
                    document.status = ProcessingStatus.ERROR
                    document.processing_error = f"Indexing failed: {str(e)}"
                    # Don't raise the exception to avoid retries on indexing failures

                await session.commit()
@@ -28,9 +28,19 @@ class EmbeddingService:
            await llm_service.initialize()

            # Test LLM service health
            health_summary = llm_service.get_health_summary()
            if health_summary.get("service_status") != "healthy":
                logger.error(f"LLM service unhealthy: {health_summary}")
            if not llm_service._initialized:
                logger.error("LLM service not initialized")
                return False

            # Check if PrivateMode provider is available
            try:
                provider_status = await llm_service.get_provider_status()
                privatemode_status = provider_status.get("privatemode")
                if not privatemode_status or privatemode_status.status != "healthy":
                    logger.error(f"PrivateMode provider not available: {privatemode_status}")
                    return False
            except Exception as e:
                logger.error(f"Failed to check provider status: {e}")
                return False

            self.initialized = True
@@ -75,6 +85,12 @@ class EmbeddingService:
                else:
                    truncated_text = text

                # Guard: skip empty inputs (validator rejects empty strings)
                if not truncated_text.strip():
                    logger.debug("Empty input for embedding; using fallback vector")
                    batch_embeddings.append(self._generate_fallback_embedding(text))
                    continue

                # Call LLM service embedding endpoint
                from app.services.llm.service import llm_service
                from app.services.llm.models import EmbeddingRequest
@@ -163,4 +179,4 @@


# Global embedding service instance
embedding_service = EmbeddingService()
embedding_service = EmbeddingService()
@@ -25,9 +25,10 @@ class EnhancedEmbeddingService(EmbeddingService):
            'requests_count': 0,
            'window_start': time.time(),
            'window_size': 60,  # 1 minute window
            'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 60)),  # Configurable
            'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)),  # Configurable
            'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')],  # Exponential backoff
            'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 0.5)),
            'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)),
            'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)),
            'last_rate_limit_error': None
        }

@@ -38,7 +39,7 @@ class EnhancedEmbeddingService(EmbeddingService):
        if max_retries is None:
            max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))

        batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 5))
        batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3))

        if not self.initialized:
            logger.warning("Embedding service not initialized, using fallback")
@@ -76,9 +77,6 @@ class EnhancedEmbeddingService(EmbeddingService):
                # Make the request
                embeddings = await self._get_embeddings_batch_impl(texts)

                # Update rate limit tracker on success
                self._update_rate_limit_tracker(success=True)

                return embeddings, True

            except Exception as e:
@@ -120,6 +118,12 @@ class EnhancedEmbeddingService(EmbeddingService):
        embeddings = []

        for text in texts:
            # Respect rate limit before each request
            while self._is_rate_limited():
                delay = self._get_rate_limit_delay()
                logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request")
                await asyncio.sleep(delay)

            # Truncate text if needed
            max_chars = 1600
            truncated_text = text[:max_chars] if len(text) > max_chars else text
@@ -142,8 +146,14 @@ class EnhancedEmbeddingService(EmbeddingService):
                        self._dimension_confirmed = True
                    else:
                        raise ValueError("Empty embedding in response")
                else:
                    raise ValueError("Invalid response structure")
            else:
                raise ValueError("Invalid response structure")

            # Count this successful request and optionally delay between requests
            self._update_rate_limit_tracker(success=True)
            per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0)
            if per_req_delay and per_req_delay > 0:
                await asyncio.sleep(per_req_delay)

        return embeddings

@@ -198,4 +208,4 @@


# Global enhanced embedding service instance
enhanced_embedding_service = EnhancedEmbeddingService()
enhanced_embedding_service = EnhancedEmbeddingService()
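
Editor's note: for a sense of the throughput these defaults allow, a hedged back-of-the-envelope sketch; it assumes provider latency is negligible, so real throughput will only be lower.

# With the new defaults: 12 requests/minute, batches of 3,
# 1.0 s between batches, 0.5 s after each request.
max_rpm = 12
batch_size = 3
delay_between_batches = 1.0
delay_per_request = 0.5

# Pacing cost of one batch: 3 * 0.5 + 1.0 = 2.5 s, i.e. ~72 req/min from pacing alone,
# so the 12 req/min window cap is the binding constraint:
batch_pacing = batch_size * delay_per_request + delay_between_batches
effective_rpm = min(max_rpm, 60 / batch_pacing * batch_size)
print(effective_rpm)  # 12 -> roughly one embedding request every 5 seconds
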
@@ -16,6 +16,7 @@ from .models import ResilienceConfig
class ProviderConfig(BaseModel):
    """Configuration for an LLM provider"""
    name: str = Field(..., description="Provider name")
    provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')")
    enabled: bool = Field(True, description="Whether provider is enabled")
    base_url: str = Field(..., description="Provider base URL")
    api_key_env_var: str = Field(..., description="Environment variable for API key")
@@ -53,9 +54,6 @@ class LLMServiceConfig(BaseModel):
    enable_security_checks: bool = Field(True, description="Enable security validation")
    enable_metrics_collection: bool = Field(True, description="Enable metrics collection")

    # Security settings
    security_risk_threshold: float = Field(0.8, ge=0.0, le=1.0, description="Risk threshold for blocking")
    security_warning_threshold: float = Field(0.6, ge=0.0, le=1.0, description="Risk threshold for warnings")
    max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
    max_response_length: int = Field(32000, ge=1000, description="Maximum response length")

@@ -78,12 +76,6 @@
    # Model routing (model_name -> provider_name)
    model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")

    @validator('security_risk_threshold')
    def validate_risk_threshold(cls, v, values):
        warning_threshold = values.get('security_warning_threshold', 0.6)
        if v <= warning_threshold:
            raise ValueError("Risk threshold must be greater than warning threshold")
        return v


def create_default_config() -> LLMServiceConfig:
@@ -93,6 +85,7 @@ def create_default_config() -> LLMServiceConfig:
    # Models will be fetched dynamically from proxy /models endpoint
    privatemode_config = ProviderConfig(
        name="privatemode",
        provider_type="privatemode",
        enabled=True,
        base_url=settings.PRIVATEMODE_PROXY_URL,
        api_key_env_var="PRIVATEMODE_API_KEY",
@@ -119,9 +112,6 @@
    config = LLMServiceConfig(
        default_provider="privatemode",
        enable_detailed_logging=settings.LOG_LLM_PROMPTS,
        enable_security_checks=settings.API_SECURITY_ENABLED,
        security_risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
        security_warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
        providers={
            "privatemode": privatemode_config
        },

@@ -124,7 +124,6 @@ class MetricsCollector:
        total_requests = len(self._metrics)
        successful_requests = sum(1 for m in self._metrics if m.success)
        failed_requests = total_requests - successful_requests
        security_blocked = sum(1 for m in self._metrics if not m.success and m.security_risk_score > 0.8)

        # Calculate averages
        latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0]
@@ -143,7 +142,6 @@
            total_requests=total_requests,
            successful_requests=successful_requests,
            failed_requests=failed_requests,
            security_blocked_requests=security_blocked,
            average_latency_ms=avg_latency,
            average_risk_score=avg_risk_score,
            provider_metrics=provider_metrics,

@@ -157,7 +157,6 @@ class LLMMetrics(BaseModel):
    total_requests: int = Field(0, description="Total requests processed")
    successful_requests: int = Field(0, description="Successful requests")
    failed_requests: int = Field(0, description="Failed requests")
    security_blocked_requests: int = Field(0, description="Security blocked requests")
    average_latency_ms: float = Field(0.0, description="Average response latency")
    average_risk_score: float = Field(0.0, description="Average security risk score")
    provider_metrics: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-provider metrics")

@@ -452,6 +452,8 @@ class PrivateModeProvider(BaseLLMProvider):

            else:
                error_text = await response.text()
                # Log the detailed error response from the provider
                logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
                self._handle_http_error(response.status, error_text, "embeddings")

        except aiohttp.ClientError as e:
@@ -1,325 +0,0 @@
|
||||
"""
|
||||
LLM Security Manager
|
||||
|
||||
Handles prompt injection detection and audit logging.
|
||||
Provides comprehensive security for LLM interactions.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityManager:
|
||||
"""Manages security for LLM operations"""
|
||||
|
||||
def __init__(self):
|
||||
self._setup_prompt_injection_patterns()
|
||||
|
||||
|
||||
def _setup_prompt_injection_patterns(self):
|
||||
"""Setup patterns for prompt injection detection"""
|
||||
self.injection_patterns = [
|
||||
# Direct instruction injection
|
||||
r"(?i)(ignore|forget|disregard|override).{0,20}(instructions|rules|prompts)",
|
||||
r"(?i)(new|updated|different)\s+(instructions|rules|system)",
|
||||
r"(?i)act\s+as\s+(if|though)\s+you\s+(are|were)",
|
||||
r"(?i)pretend\s+(to\s+be|you\s+are)",
|
||||
r"(?i)you\s+are\s+now\s+(a|an)\s+",
|
||||
|
||||
# System role manipulation
|
||||
r"(?i)system\s*:\s*",
|
||||
r"(?i)\[system\]",
|
||||
r"(?i)<system>",
|
||||
r"(?i)assistant\s*:\s*",
|
||||
r"(?i)\[assistant\]",
|
||||
|
||||
# Escape attempts
|
||||
r"(?i)\\n\\n#+",
|
||||
r"(?i)```\s*(system|assistant|user)",
|
||||
r"(?i)---\s*(new|system|override)",
|
||||
|
||||
# Role manipulation
|
||||
r"(?i)(you|your)\s+(role|purpose|function)\s+(is|has\s+changed)",
|
||||
r"(?i)switch\s+to\s+(admin|developer|debug)\s+mode",
|
||||
r"(?i)(admin|root|sudo|developer)\s+(access|mode|privileges)",
|
||||
|
||||
# Information extraction attempts
|
||||
r"(?i)(show|display|reveal|expose)\s+(your|the)\s+(prompt|instructions|system)",
|
||||
r"(?i)what\s+(are|were)\s+your\s+(original|initial)\s+(instructions|prompts)",
|
||||
r"(?i)(debug|verbose|diagnostic)\s+mode",
|
||||
|
||||
# Encoding/obfuscation attempts
|
||||
r"(?i)base64\s*:",
|
||||
r"(?i)hex\s*:",
|
||||
r"(?i)unicode\s*:",
|
||||
r"(?i)\b[A-Za-z0-9+/]{40,}={0,2}\b", # More specific base64 pattern (longer sequences)
|
||||
|
||||
# SQL injection patterns (more specific to reduce false positives)
|
||||
r"(?i)(union\s+select|select\s+\*|insert\s+into|update\s+\w+\s+set|delete\s+from|drop\s+table|create\s+table)\s",
|
||||
r"(?i)(or|and)\s+\d+\s*=\s*\d+",
|
||||
r"(?i)';?\s*(drop\s+table|delete\s+from|insert\s+into)",
|
||||
|
||||
# Command injection patterns
|
||||
r"(?i)(exec|eval|system|shell|cmd)\s*\(",
|
||||
r"(?i)(\$\(|\`)[^)]+(\)|\`)",
|
||||
r"(?i)&&\s*(rm|del|format)",
|
||||
|
||||
# Jailbreak attempts
|
||||
r"(?i)jailbreak",
|
||||
r"(?i)break\s+out\s+of",
|
||||
r"(?i)escape\s+(the|your)\s+(rules|constraints)",
|
||||
r"(?i)(DAN|Do\s+Anything\s+Now)",
|
||||
r"(?i)unrestricted\s+mode",
|
||||
]
|
||||
|
||||
self.compiled_patterns = [re.compile(pattern) for pattern in self.injection_patterns]
|
||||
logger.info(f"Initialized {len(self.injection_patterns)} prompt injection patterns")
|
||||
|
||||
|
||||
def validate_prompt_security(self, messages: List[Dict[str, str]]) -> Tuple[bool, float, List[str]]:
|
||||
"""
|
||||
Validate messages for prompt injection attempts
|
||||
|
||||
Returns:
|
||||
Tuple[bool, float, List[str]]: (is_safe, risk_score, detected_patterns)
|
||||
"""
|
||||
detected_patterns = []
|
||||
total_risk = 0.0
|
||||
|
||||
# Check if this is a system/RAG request
|
||||
is_system_request = self._is_system_request(messages)
|
||||
|
||||
for message in messages:
|
||||
content = message.get("content", "")
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# Check against injection patterns with context awareness
|
||||
for i, pattern in enumerate(self.compiled_patterns):
|
||||
matches = pattern.findall(content)
|
||||
if matches:
|
||||
# Apply context-aware risk calculation
|
||||
pattern_risk = self._calculate_pattern_risk(i, matches, message.get("role", "user"), is_system_request)
|
||||
total_risk += pattern_risk
|
||||
detected_patterns.append({
|
||||
"pattern_index": i,
|
||||
"pattern": self.injection_patterns[i],
|
||||
"matches": matches,
|
||||
"risk": pattern_risk
|
||||
})
|
||||
|
||||
# Additional security checks with context awareness
|
||||
total_risk += self._check_message_characteristics(content, message.get("role", "user"), is_system_request)
|
||||
|
||||
# Normalize risk score (0.0 to 1.0)
|
||||
risk_score = min(total_risk / len(messages) if messages else 0.0, 1.0)
|
||||
# Never block - always return True for is_safe
|
||||
is_safe = True
|
||||
|
||||
if detected_patterns:
|
||||
logger.info(f"Detected {len(detected_patterns)} potential injection patterns, risk score: {risk_score} (system_request: {is_system_request})")
|
||||
|
||||
return is_safe, risk_score, detected_patterns
|
||||
|
||||
def _calculate_pattern_risk(self, pattern_index: int, matches: List, role: str, is_system_request: bool) -> float:
|
||||
"""Calculate risk score for a detected pattern with context awareness"""
|
||||
# Different patterns have different risk levels
|
||||
high_risk_patterns = [0, 1, 2, 3, 4, 5, 6, 7, 22, 23, 24] # System manipulation, jailbreak
|
||||
medium_risk_patterns = [8, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21] # Escape attempts, info extraction
|
||||
|
||||
# Base risk score
|
||||
base_risk = 0.8 if pattern_index in high_risk_patterns else 0.5 if pattern_index in medium_risk_patterns else 0.3
|
||||
|
||||
# Apply context-specific risk reduction
|
||||
if is_system_request or role == "system":
|
||||
# Reduce risk for system messages and RAG content
|
||||
if pattern_index in [14, 15, 16]: # Encoding patterns (base64, hex, unicode)
|
||||
base_risk *= 0.2 # Reduce encoding risk by 80% for system content
|
||||
elif pattern_index in [17, 18, 19]: # SQL patterns
|
||||
base_risk *= 0.3 # Reduce SQL risk by 70% for system content
|
||||
else:
|
||||
base_risk *= 0.6 # Reduce other risks by 40% for system content
|
||||
|
||||
# Increase risk based on number of matches, but cap it
|
||||
match_multiplier = min(1.0 + (len(matches) - 1) * 0.1, 1.5) # Reduced multiplier
|
||||
|
||||
return base_risk * match_multiplier
|
||||
|
||||
def _check_message_characteristics(self, content: str, role: str, is_system_request: bool) -> float:
|
||||
"""Check message characteristics for additional risk factors with context awareness"""
|
||||
risk = 0.0
|
||||
|
||||
# Excessive length (potential stuffing attack) - less restrictive for system content
|
||||
length_threshold = 50000 if is_system_request else 10000 # Much higher threshold for system content
|
||||
if len(content) > length_threshold:
|
||||
risk += 0.1 if is_system_request else 0.3
|
||||
|
||||
# High ratio of special characters - more lenient for system content
|
||||
special_chars = sum(1 for c in content if not c.isalnum() and not c.isspace())
|
||||
if len(content) > 0:
|
||||
char_ratio = special_chars / len(content)
|
||||
threshold = 0.8 if is_system_request else 0.5
|
||||
if char_ratio > threshold:
|
||||
risk += 0.2 if is_system_request else 0.4
|
||||
|
||||
# Multiple encoding indicators - reduced risk for system content
|
||||
encoding_indicators = ["base64", "hex", "unicode", "url", "ascii"]
|
||||
found_encodings = sum(1 for indicator in encoding_indicators if indicator.lower() in content.lower())
|
||||
if found_encodings > 1:
|
||||
risk += 0.1 if is_system_request else 0.3
|
||||
|
||||
# Excessive newlines or formatting - more lenient for system content
|
||||
newline_threshold = 200 if is_system_request else 50
|
||||
if content.count('\n') > newline_threshold or content.count('\\n') > newline_threshold:
|
||||
risk += 0.1 if is_system_request else 0.2
|
||||
|
||||
return risk
|
||||
|
||||
def _is_system_request(self, messages: List[Dict[str, str]]) -> bool:
|
||||
"""Determine if this is a system/RAG request"""
|
||||
if not messages:
|
||||
return False
|
||||
|
||||
# Check for system messages
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
return True
|
||||
|
||||
# Check message content for RAG indicators
|
||||
for message in messages:
|
||||
content = message.get("content", "")
|
||||
if ("document:" in content.lower() or
|
||||
"context:" in content.lower() or
|
||||
"source:" in content.lower() or
|
||||
"retrieved:" in content.lower() or
|
||||
"citation:" in content.lower() or
|
||||
"reference:" in content.lower()):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def create_audit_log(
|
||||
self,
|
||||
user_id: str,
|
||||
api_key_id: int,
|
||||
provider: str,
|
||||
model: str,
|
||||
request_type: str,
|
||||
risk_score: float,
|
||||
detected_patterns: List[str],
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create comprehensive audit log for LLM request"""
|
||||
audit_entry = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"user_id": user_id,
|
||||
"api_key_id": api_key_id,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"request_type": request_type,
|
||||
"security": {
|
||||
"risk_score": risk_score,
|
||||
"detected_patterns": detected_patterns,
|
||||
"security_check_passed": risk_score < settings.API_SECURITY_RISK_THRESHOLD
|
||||
},
|
||||
"metadata": metadata or {},
|
||||
"audit_hash": None # Will be set below
|
||||
}
|
||||
|
||||
# Create hash for audit integrity
|
||||
audit_hash = self._create_audit_hash(audit_entry)
|
||||
audit_entry["audit_hash"] = audit_hash
|
||||
|
||||
# Log based on risk level (never block, only log)
|
||||
if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
|
||||
logger.warning(f"HIGH RISK LLM REQUEST DETECTED (NOT BLOCKED): {json.dumps(audit_entry)}")
|
||||
elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
|
||||
logger.info(f"MEDIUM RISK LLM REQUEST: {json.dumps(audit_entry)}")
|
||||
else:
|
||||
logger.info(f"LLM REQUEST AUDIT: user={user_id}, model={model}, risk={risk_score:.3f}")
|
||||
|
||||
return audit_entry
|
||||
|
||||
def _create_audit_hash(self, audit_entry: Dict[str, Any]) -> str:
|
||||
"""Create hash for audit trail integrity"""
|
||||
# Create hash from key fields (excluding the hash itself)
|
||||
hash_data = {
|
||||
"timestamp": audit_entry["timestamp"],
|
||||
"user_id": audit_entry["user_id"],
|
||||
"api_key_id": audit_entry["api_key_id"],
|
||||
"provider": audit_entry["provider"],
|
||||
"model": audit_entry["model"],
|
||||
"request_type": audit_entry["request_type"],
|
||||
"risk_score": audit_entry["security"]["risk_score"]
|
||||
}
|
||||
|
||||
hash_string = json.dumps(hash_data, sort_keys=True)
|
||||
return hashlib.sha256(hash_string.encode()).hexdigest()
|
||||
|
||||
def log_detailed_request(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
context_info: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""Log detailed LLM request if LOG_LLM_PROMPTS is enabled"""
|
||||
if not settings.LOG_LLM_PROMPTS:
|
||||
return
|
||||
|
||||
logger.info("=== DETAILED LLM REQUEST ===")
|
||||
logger.info(f"Model: {model}")
|
||||
logger.info(f"Provider: {provider}")
|
||||
logger.info(f"User ID: {user_id}")
|
||||
|
||||
if context_info:
|
||||
for key, value in context_info.items():
|
||||
logger.info(f"{key}: {value}")
|
||||
|
||||
logger.info("Messages to LLM:")
|
||||
for i, message in enumerate(messages):
|
||||
role = message.get("role", "unknown")
|
||||
content = message.get("content", "")[:500] # Truncate for logging
|
||||
logger.info(f" Message {i+1} [{role}]: {content}{'...' if len(message.get('content', '')) > 500 else ''}")
|
||||
|
||||
logger.info("=== END DETAILED LLM REQUEST ===")
|
||||
|
||||
def log_detailed_response(
|
||||
self,
|
||||
response_content: str,
|
||||
token_usage: Optional[Dict[str, int]] = None,
|
||||
provider: str = "unknown"
|
||||
):
|
||||
"""Log detailed LLM response if LOG_LLM_PROMPTS is enabled"""
|
||||
if not settings.LOG_LLM_PROMPTS:
|
||||
return
|
||||
|
||||
logger.info("=== DETAILED LLM RESPONSE ===")
|
||||
logger.info(f"Provider: {provider}")
|
||||
logger.info(f"Response content: {response_content[:500]}{'...' if len(response_content) > 500 else ''}")
|
||||
|
||||
if token_usage:
|
||||
logger.info(f"Token usage - Prompt: {token_usage.get('prompt_tokens', 0)}, "
|
||||
f"Completion: {token_usage.get('completion_tokens', 0)}, "
|
||||
f"Total: {token_usage.get('total_tokens', 0)}")
|
||||
|
||||
logger.info("=== END DETAILED LLM RESPONSE ===")
|
||||
|
||||
|
||||
class SecurityError(Exception):
|
||||
"""Security-related errors in LLM operations"""
|
||||
pass
|
||||
|
||||
|
||||
# Global security manager instance
|
||||
security_manager = SecurityManager()
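
Note that the audit hash covers only the identifying fields listed in _create_audit_hash; the free-form metadata block sits outside the hash. A stored entry can therefore be re-verified later to detect tampering with any hashed field. A minimal verification sketch (verify_audit_entry is illustrative, not part of the codebase):

    import hashlib
    import json
    from typing import Any, Dict

    def verify_audit_entry(audit_entry: Dict[str, Any]) -> bool:
        # Recompute the hash over the same key fields and compare with the stored value
        hash_data = {
            "timestamp": audit_entry["timestamp"],
            "user_id": audit_entry["user_id"],
            "api_key_id": audit_entry["api_key_id"],
            "provider": audit_entry["provider"],
            "model": audit_entry["model"],
            "request_type": audit_entry["request_type"],
            "risk_score": audit_entry["security"]["risk_score"],
        }
        hash_string = json.dumps(hash_data, sort_keys=True)
        return hashlib.sha256(hash_string.encode()).hexdigest() == audit_entry["audit_hash"]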
@@ -17,9 +17,8 @@ from .models import (
)
from .config import config_manager, ProviderConfig
from ...core.config import settings
from .security import security_manager
from .resilience import ResilienceManagerFactory
from .metrics import metrics_collector
# from .metrics import metrics_collector
from .providers import BaseLLMProvider, PrivateModeProvider
from .exceptions import (
    LLMError, ProviderError, SecurityError, ConfigurationError,
@@ -150,45 +149,8 @@ class LLMService:
        if not request.messages:
            raise ValidationError("Messages cannot be empty", field="messages")

        # Security validation (only if enabled)
        messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        if settings.API_SECURITY_ENABLED:
            is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
        else:
            # Security disabled - always safe
            is_safe, risk_score, detected_patterns = True, 0.0, []

        if not is_safe:
            # Log security violation
            security_manager.create_audit_log(
                user_id=request.user_id,
                api_key_id=request.api_key_id,
                provider="blocked",
                model=request.model,
                request_type="chat_completion",
                risk_score=risk_score,
                detected_patterns=[p.get("pattern", "") for p in detected_patterns]
            )

            # Record blocked request
            metrics_collector.record_request(
                provider="security",
                model=request.model,
                request_type="chat_completion",
                success=False,
                latency_ms=0,
                security_risk_score=risk_score,
                error_code="SECURITY_BLOCKED",
                user_id=request.user_id,
                api_key_id=request.api_key_id
            )

            raise SecurityError(
                "Request blocked due to security concerns",
                risk_score=risk_score,
                details={"detected_patterns": detected_patterns}
            )
        # Security validation disabled - always allow requests
        risk_score = 0.0

        # Get provider for model
        provider_name = self._get_provider_for_model(request.model)
@@ -197,18 +159,7 @@ class LLMService:
        if not provider:
            raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)

        # Log detailed request if enabled
        security_manager.log_detailed_request(
            messages=messages_dict,
            model=request.model,
            user_id=request.user_id,
            provider=provider_name,
            context_info={
                "temperature": request.temperature,
                "max_tokens": request.max_tokens,
                "risk_score": f"{risk_score:.3f}"
            }
        )
        # Security logging disabled

        # Execute with resilience
        resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
@@ -222,85 +173,46 @@ class LLMService:
                non_retryable_exceptions=(SecurityError, ValidationError)
            )

            # Update response with security information
            response.security_check = is_safe
            response.risk_score = risk_score
            response.detected_patterns = [p.get("pattern", "") for p in detected_patterns]
            # Security features disabled

            # Log detailed response if enabled
            if response.choices:
                content = response.choices[0].message.content
                security_manager.log_detailed_response(
                    response_content=content,
                    token_usage=response.usage.model_dump() if response.usage else None,
                    provider=provider_name
                )
            # Security logging disabled

            # Record successful request
            # Record successful request - metrics disabled
            total_latency = (time.time() - start_time) * 1000
            metrics_collector.record_request(
                provider=provider_name,
                model=request.model,
                request_type="chat_completion",
                success=True,
                latency_ms=total_latency,
                token_usage=response.usage.model_dump() if response.usage else None,
                security_risk_score=risk_score,
                user_id=request.user_id,
                api_key_id=request.api_key_id
            )
            # metrics_collector.record_request(
            #     provider=provider_name,
            #     model=request.model,
            #     request_type="chat_completion",
            #     success=True,
            #     latency_ms=total_latency,
            #     token_usage=response.usage.model_dump() if response.usage else None,
            #     security_risk_score=risk_score,
            #     user_id=request.user_id,
            #     api_key_id=request.api_key_id
            # )

            # Create audit log
            security_manager.create_audit_log(
                user_id=request.user_id,
                api_key_id=request.api_key_id,
                provider=provider_name,
                model=request.model,
                request_type="chat_completion",
                risk_score=risk_score,
                detected_patterns=[p.get("pattern", "") for p in detected_patterns],
                metadata={
                    "success": True,
                    "latency_ms": total_latency,
                    "token_usage": response.usage.model_dump() if response.usage else None
                }
            )
            # Security audit logging disabled

            return response

        except Exception as e:
            # Record failed request
            # Record failed request - metrics disabled
            total_latency = (time.time() - start_time) * 1000
            error_code = getattr(e, 'error_code', e.__class__.__name__)

            # metrics_collector.record_request(
            #     provider=provider_name,
            #     model=request.model,
            #     request_type="chat_completion",
            #     success=False,
            #     latency_ms=total_latency,
            #     security_risk_score=risk_score,
            #     error_code=error_code,
            #     user_id=request.user_id,
            #     api_key_id=request.api_key_id
            # )

            metrics_collector.record_request(
                provider=provider_name,
                model=request.model,
                request_type="chat_completion",
                success=False,
                latency_ms=total_latency,
                security_risk_score=risk_score,
                error_code=error_code,
                user_id=request.user_id,
                api_key_id=request.api_key_id
            )

            # Create audit log for failure
            security_manager.create_audit_log(
                user_id=request.user_id,
                api_key_id=request.api_key_id,
                provider=provider_name,
                model=request.model,
                request_type="chat_completion",
                risk_score=risk_score,
                detected_patterns=[p.get("pattern", "") for p in detected_patterns],
                metadata={
                    "success": False,
                    "error": str(e),
                    "error_code": error_code,
                    "latency_ms": total_latency
                }
            )
            # Security audit logging disabled

            raise

@@ -309,21 +221,8 @@ class LLMService:
        if not self._initialized:
            await self.initialize()

        # Security validation (same as non-streaming)
        messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        if settings.API_SECURITY_ENABLED:
            is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
        else:
            # Security disabled - always safe
            is_safe, risk_score, detected_patterns = True, 0.0, []

        if not is_safe:
            raise SecurityError(
                "Streaming request blocked due to security concerns",
                risk_score=risk_score,
                details={"detected_patterns": detected_patterns}
            )
        # Security validation disabled - always allow streaming requests
        risk_score = 0.0

        # Get provider
        provider_name = self._get_provider_for_model(request.model)
@@ -345,19 +244,19 @@ class LLMService:
                yield chunk

        except Exception as e:
            # Record streaming failure
            # Record streaming failure - metrics disabled
            error_code = getattr(e, 'error_code', e.__class__.__name__)
            metrics_collector.record_request(
                provider=provider_name,
                model=request.model,
                request_type="chat_completion_stream",
                success=False,
                latency_ms=0,
                security_risk_score=risk_score,
                error_code=error_code,
                user_id=request.user_id,
                api_key_id=request.api_key_id
            )
            # metrics_collector.record_request(
            #     provider=provider_name,
            #     model=request.model,
            #     request_type="chat_completion_stream",
            #     success=False,
            #     latency_ms=0,
            #     security_risk_score=risk_score,
            #     error_code=error_code,
            #     user_id=request.user_id,
            #     api_key_id=request.api_key_id
            # )
            raise

    async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
@@ -365,23 +264,8 @@ class LLMService:
        if not self._initialized:
            await self.initialize()

        # Security validation for embedding input
        input_text = request.input if isinstance(request.input, str) else " ".join(request.input)

        if settings.API_SECURITY_ENABLED:
            is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security([
                {"role": "user", "content": input_text}
            ])
        else:
            # Security disabled - always safe
            is_safe, risk_score, detected_patterns = True, 0.0, []

        if not is_safe:
            raise SecurityError(
                "Embedding request blocked due to security concerns",
                risk_score=risk_score,
                details={"detected_patterns": detected_patterns}
            )
        # Security validation disabled - always allow embedding requests
        risk_score = 0.0

        # Get provider
        provider_name = self._get_provider_for_model(request.model)
@@ -402,42 +286,40 @@ class LLMService:
                non_retryable_exceptions=(SecurityError, ValidationError)
            )

            # Update response with security information
            response.security_check = is_safe
            response.risk_score = risk_score
            # Security features disabled

            # Record successful request
            # Record successful request - metrics disabled
            total_latency = (time.time() - start_time) * 1000
            metrics_collector.record_request(
                provider=provider_name,
                model=request.model,
                request_type="embedding",
                success=True,
                latency_ms=total_latency,
                token_usage=response.usage.model_dump() if response.usage else None,
                security_risk_score=risk_score,
                user_id=request.user_id,
                api_key_id=request.api_key_id
            )
            # metrics_collector.record_request(
            #     provider=provider_name,
            #     model=request.model,
            #     request_type="embedding",
            #     success=True,
            #     latency_ms=total_latency,
            #     token_usage=response.usage.model_dump() if response.usage else None,
            #     security_risk_score=risk_score,
            #     user_id=request.user_id,
            #     api_key_id=request.api_key_id
            # )

            return response

        except Exception as e:
            # Record failed request
            # Record failed request - metrics disabled
            total_latency = (time.time() - start_time) * 1000
            error_code = getattr(e, 'error_code', e.__class__.__name__)

            metrics_collector.record_request(
                provider=provider_name,
                model=request.model,
                request_type="embedding",
                success=False,
                latency_ms=total_latency,
                security_risk_score=risk_score,
                error_code=error_code,
                user_id=request.user_id,
                api_key_id=request.api_key_id
            )

            # metrics_collector.record_request(
            #     provider=provider_name,
            #     model=request.model,
            #     request_type="embedding",
            #     success=False,
            #     latency_ms=total_latency,
            #     security_risk_score=risk_score,
            #     error_code=error_code,
            #     user_id=request.user_id,
            #     api_key_id=request.api_key_id
            # )

            raise

@@ -492,20 +374,26 @@ class LLMService:
        return status_dict

    def get_metrics(self) -> LLMMetrics:
        """Get service metrics"""
        return metrics_collector.get_metrics()
        """Get service metrics - metrics disabled"""
        # return metrics_collector.get_metrics()
        return LLMMetrics(
            total_requests=0,
            success_rate=0.0,
            avg_latency_ms=0,
            error_rates={}
        )

    def get_health_summary(self) -> Dict[str, Any]:
        """Get comprehensive health summary"""
        metrics_health = metrics_collector.get_health_summary()
        """Get comprehensive health summary - metrics disabled"""
        # metrics_health = metrics_collector.get_health_summary()
        resilience_health = ResilienceManagerFactory.get_all_health_status()

        return {
            "service_status": "healthy" if self._initialized else "initializing",
            "startup_time": self._startup_time.isoformat() if self._startup_time else None,
            "provider_count": len(self._providers),
            "active_providers": list(self._providers.keys()),
            "metrics": metrics_health,
            "metrics": {"status": "disabled"},
            "resilience": resilience_health
        }

@@ -1,153 +0,0 @@
"""
Token-based rate limiting for LLM service
"""

import time
import redis
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
from ..core.config import settings
from ..core.logging import get_logger

logger = get_logger(__name__)


class TokenRateLimiter:
    """Token-based rate limiting implementation"""

    def __init__(self):
        try:
            self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
            self.redis_client.ping()
            logger.info("Token rate limiter initialized with Redis backend")
        except Exception as e:
            logger.warning(f"Redis not available for token rate limiting: {e}")
            self.redis_client = None
            # Fall back to in-memory rate limiting
            self.in_memory_store = {}
            logger.info("Token rate limiter using in-memory fallback")

    async def check_token_limits(
        self,
        provider: str,
        prompt_tokens: int,
        completion_tokens: int = 0
    ) -> Tuple[bool, Dict[str, str]]:
        """
        Check if token usage is within limits

        Args:
            provider: Provider name (e.g., "privatemode")
            prompt_tokens: Number of prompt tokens to use
            completion_tokens: Number of completion tokens to use

        Returns:
            Tuple of (is_allowed, headers)
        """
        # Get token limits from configuration
        from .config import get_config
        config = get_config()
        token_limits = config.token_limits_per_minute

        # Check organization-wide limits
        org_key = f"tokens:org:{provider}"

        # Get current usage
        current_usage = await self._get_token_usage(org_key)

        # Calculate new usage
        new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens
        new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens

        # Check limits
        prompt_limit = token_limits.get("prompt_tokens", 20000)
        completion_limit = token_limits.get("completion_tokens", 10000)

        is_allowed = (
            new_prompt_tokens <= prompt_limit and
            new_completion_tokens <= completion_limit
        )

        if is_allowed:
            # Update usage
            await self._update_token_usage(org_key, prompt_tokens, completion_tokens)
            logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, "
                         f"{new_completion_tokens}/{completion_limit} completion")

        # Calculate remaining tokens
        remaining_prompt = max(0, prompt_limit - new_prompt_tokens)
        remaining_completion = max(0, completion_limit - new_completion_tokens)

        # Create headers
        headers = {
            "X-TokenLimit-Prompt-Remaining": str(remaining_prompt),
            "X-TokenLimit-Completion-Remaining": str(remaining_completion),
            "X-TokenLimit-Prompt-Limit": str(prompt_limit),
            "X-TokenLimit-Completion-Limit": str(completion_limit),
            "X-TokenLimit-Reset": str(int(time.time() + 60))  # Reset in 1 minute
        }

        if not is_allowed:
            logger.warning(f"Token rate limit exceeded for {provider}. "
                           f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. "
                           f"Current: {current_usage}")

        return is_allowed, headers

    async def _get_token_usage(self, key: str) -> Dict[str, int]:
        """Get current token usage"""
        if self.redis_client:
            try:
                data = self.redis_client.hgetall(key)
                if data:
                    return {
                        "prompt_tokens": int(data.get("prompt_tokens", 0)),
                        "completion_tokens": int(data.get("completion_tokens", 0)),
                        "updated_at": float(data.get("updated_at", time.time()))
                    }
            except Exception as e:
                logger.error(f"Error getting token usage from Redis: {e}")

        # Fallback to in-memory
        return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0})

    async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int):
        """Update token usage"""
        if self.redis_client:
            try:
                pipe = self.redis_client.pipeline()
                pipe.hincrby(key, "prompt_tokens", prompt_tokens)
                pipe.hincrby(key, "completion_tokens", completion_tokens)
                pipe.hset(key, "updated_at", time.time())
                pipe.expire(key, 60)  # Expire after 1 minute
                pipe.execute()
            except Exception as e:
                logger.error(f"Error updating token usage in Redis: {e}")
                # Fallback to in-memory
                self._update_in_memory(key, prompt_tokens, completion_tokens)
        else:
            self._update_in_memory(key, prompt_tokens, completion_tokens)

    def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int):
        """Update in-memory token usage"""
        if key not in self.in_memory_store:
            self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0}

        self.in_memory_store[key]["prompt_tokens"] += prompt_tokens
        self.in_memory_store[key]["completion_tokens"] += completion_tokens
        self.in_memory_store[key]["updated_at"] = time.time()

    def cleanup_expired(self):
        """Clean up expired entries (for in-memory store)"""
        if not self.redis_client:
            current_time = time.time()
            expired_keys = [
                key for key, data in self.in_memory_store.items()
                if current_time - data.get("updated_at", 0) > 60
            ]
            for key in expired_keys:
                del self.in_memory_store[key]


# Global token rate limiter instance
token_rate_limiter = TokenRateLimiter()
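
This file is deleted by the commit. For reference, the limiter it removes implemented a per-provider window in a Redis hash: usage accumulates via HINCRBY, and because EXPIRE is refreshed on every update, the window resets 60 seconds after the last write rather than the first. A standalone sketch of the same pattern (key name and limit are illustrative):

    import redis  # assumes the redis-py client is available

    r = redis.Redis(decode_responses=True)

    def consume(provider: str, prompt_tokens: int, limit: int = 20000) -> bool:
        """Return False once the current window's prompt-token budget is spent."""
        key = f"tokens:org:{provider}"
        used = int(r.hget(key, "prompt_tokens") or 0)
        if used + prompt_tokens > limit:
            return False
        pipe = r.pipeline()
        pipe.hincrby(key, "prompt_tokens", prompt_tokens)
        pipe.expire(key, 60)  # refreshed on each write, like the deleted code
        pipe.execute()
        return True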
@@ -755,10 +755,11 @@ class RAGService:

        # Process with RAG module
        try:
            # Pass file_path in metadata so JSONL indexing can reopen the source file
            processed_doc = await rag_module.process_document(
                file_content,
                document.original_filename,
                {}
                file_content,
                document.original_filename,
                {"file_path": document.file_path}
            )

            # Success case - update document with processed content
@@ -873,4 +874,4 @@ class RAGService:

        except Exception as e:
            logger.error(f"Error reprocessing document {document_id}: {e}")
            return False
            return False

@@ -638,11 +638,19 @@ class RAGModule(BaseModule):
        np.random.seed(hash(text) % 2**32)
        return np.random.random(self.embedding_model.get("dimension", 768)).tolist()

    async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
    async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
        """Generate embeddings for multiple texts (batch processing)"""
        if self.embedding_service:
            # Add task-specific prefixes for better E5 model performance
            if is_document:
                # For document passages, use "passage:" prefix
                prefixed_texts = [f"passage: {text}" for text in texts]
            else:
                # For queries, use "query:" prefix (handled in search method)
                prefixed_texts = texts

            # Use real embedding service for batch processing
            return await self.embedding_service.get_embeddings(texts)
            return await self.embedding_service.get_embeddings(prefixed_texts)
        else:
            # Fallback to individual processing
            embeddings = []
@@ -917,69 +925,75 @@ class RAGModule(BaseModule):

    async def _process_jsonl(self, content: bytes, filename: str) -> str:
        """Process JSONL files (newline-delimited JSON)

        Specifically optimized for helpjuice-export.jsonl format:
        - Each line contains a JSON object with 'id' and 'payload'
        - Payload contains 'question', 'language', and 'answer' fields
        - Combines question and answer into searchable content

        Performance optimizations:
        - Processes articles in smaller batches to reduce memory usage
        - Uses streaming approach for large files
        """
        try:
            # Use streaming approach for large files
            jsonl_content = content.decode('utf-8', errors='replace')
            lines = jsonl_content.strip().split('\n')

            processed_articles = []

            batch_size = 50  # Process in batches of 50 articles

            for line_num, line in enumerate(lines, 1):
                if not line.strip():
                    continue

                try:
                    # Parse each JSON line
                    data = json.loads(line)

                    # Handle helpjuice export format
                    if 'payload' in data:
                        payload = data['payload']
                        article_id = data.get('id', f'article_{line_num}')

                        # Extract fields
                        question = payload.get('question', '')
                        answer = payload.get('answer', '')
                        language = payload.get('language', 'EN')

                        # Combine question and answer for better search
                        if question or answer:
                            # Format as Q&A for better context
                            article_text = f"## {question}\n\n{answer}\n\n"

                            # Add language tag if not English
                            if language != 'EN':
                                article_text = f"[{language}] {article_text}"

                            # Add metadata separator
                            article_text += f"---\nArticle ID: {article_id}\nLanguage: {language}\n\n"

                            processed_articles.append(article_text)

                    # Handle generic JSONL format
                    else:
                        # Convert the entire JSON object to readable text
                        json_text = json.dumps(data, indent=2, ensure_ascii=False)
                        processed_articles.append(json_text + "\n\n")

                except json.JSONDecodeError as e:
                    logger.warning(f"Error parsing JSONL line {line_num}: {e}")
                    continue
                except Exception as e:
                    logger.warning(f"Error processing JSONL line {line_num}: {e}")
                    continue

            # Combine all articles
            combined_text = '\n'.join(processed_articles)

            logger.info(f"Successfully processed {len(processed_articles)} articles from JSONL file {filename}")
            return combined_text

        except Exception as e:
            logger.error(f"Error processing JSONL file {filename}: {e}")
            return ""
@@ -1153,7 +1167,7 @@ class RAGModule(BaseModule):
        chunks = self._chunk_text(content)

        # Generate embeddings for all chunks in batch (more efficient)
        embeddings = await self._generate_embeddings(chunks)
        embeddings = await self._generate_embeddings(chunks, is_document=True)

        # Create document points
        points = []
@@ -1200,10 +1214,28 @@ class RAGModule(BaseModule):
        """Index a processed document in the vector database"""
        if not self.enabled:
            raise RuntimeError("RAG module not initialized")

        collection_name = collection_name or self.default_collection_name

        try:
            # Special handling for JSONL files
            if processed_doc.file_type == 'jsonl':
                # Import the optimized JSONL processor
                from app.services.jsonl_processor import JSONLProcessor
                jsonl_processor = JSONLProcessor(self)

                # Read the original file content
                with open(processed_doc.metadata.get('file_path', ''), 'rb') as f:
                    file_content = f.read()

                # Process using the optimized JSONL processor
                return await jsonl_processor.process_and_index_jsonl(
                    collection_name=collection_name,
                    content=file_content,
                    filename=processed_doc.original_filename,
                    metadata=processed_doc.metadata
                )

            # Ensure collection exists
            await self._ensure_collection_exists(collection_name)

@@ -1216,7 +1248,7 @@ class RAGModule(BaseModule):
        chunks = self._chunk_text(processed_doc.content)

        # Generate embeddings for all chunks in batch (more efficient)
        embeddings = await self._generate_embeddings(chunks)
        embeddings = await self._generate_embeddings(chunks, is_document=True)

        # Create document points with enhanced metadata
        points = []
@@ -1339,24 +1371,48 @@ class RAGModule(BaseModule):
            score_threshold=score_threshold / 2  # Lower threshold for initial search
        )

        # Combine scores
        # Combine scores with improved normalization
        hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
        vector_weight = hybrid_weights.get("vector", 0.7)
        bm25_weight = hybrid_weights.get("bm25", 0.3)

        # Create hybrid results
        # Get score distributions for better normalization
        vector_scores = [r.score for r in vector_results]
        bm25_scores_list = list(bm25_scores.values())

        # Calculate statistics for normalization
        if vector_scores:
            v_max = max(vector_scores)
            v_min = min(vector_scores)
            v_range = v_max - v_min if v_max != v_min else 1
        else:
            v_max, v_min, v_range = 1, 0, 1

        if bm25_scores_list:
            bm25_max = max(bm25_scores_list)
            bm25_min = min(bm25_scores_list)
            bm25_range = bm25_max - bm25_min if bm25_max != bm25_min else 1
        else:
            bm25_max, bm25_min, bm25_range = 1, 0, 1

        # Create hybrid results with improved scoring
        hybrid_results = []
        for result in vector_results:
            doc_id = result.payload.get("document_id", "")
            vector_score = result.score
            bm25_score = bm25_scores.get(doc_id, 0.0)

            # Normalize scores (simple min-max normalization)
            vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0
            bm25_norm = min(bm25_score, 1.0)  # BM25 scores are typically 0-1
            # Improved normalization using actual score distributions
            vector_norm = (vector_score - v_min) / v_range if v_range > 0 else 0.5
            bm25_norm = (bm25_score - bm25_min) / bm25_range if bm25_range > 0 else 0.5

            # Calculate hybrid score
            hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm)
            # Apply reciprocal rank fusion for better combination
            # This gives more weight to documents that rank highly in both methods
            rrf_vector = 1.0 / (1.0 + vector_results.index(result) + 1)  # denominator is rank + 2, so the top rank contributes 0.5
            rrf_bm25 = 1.0 / (1.0 + sorted(bm25_scores_list, reverse=True).index(bm25_score) + 1) if bm25_score in bm25_scores_list else 0

            # Calculate hybrid score using both normalized scores and RRF
            hybrid_score = (vector_weight * vector_norm + bm25_weight * bm25_norm) * 0.7 + (rrf_vector + rrf_bm25) * 0.3

            # Create new point with hybrid score
            hybrid_point = ScoredPoint(
@@ -1435,7 +1491,7 @@ class RAGModule(BaseModule):
        # Normalize score to 0-1 range
        return min(score / 10.0, 1.0)  # Simple normalization
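
To see how the hybrid formula behaves, a toy walk-through with invented scores (the 0.7/0.3 blend and the rank + 2 denominators follow the code above):

    # All numbers invented; assumes a document ranked first by both methods.
    vector_scores = [0.82, 0.74, 0.61]
    bm25_scores_list = [3.1, 1.2]

    v_min, v_range = min(vector_scores), max(vector_scores) - min(vector_scores)
    b_min, b_range = min(bm25_scores_list), max(bm25_scores_list) - min(bm25_scores_list)

    vector_norm = (0.82 - v_min) / v_range   # 1.0 (best vector score)
    bm25_norm = (3.1 - b_min) / b_range      # 1.0 (best BM25 score)
    rrf_vector = 1.0 / (1.0 + 0 + 1)         # 0.5 (rank 0)
    rrf_bm25 = 1.0 / (1.0 + 0 + 1)           # 0.5 (rank 0)

    hybrid = (0.7 * vector_norm + 0.3 * bm25_norm) * 0.7 + (rrf_vector + rrf_bm25) * 0.3
    assert abs(hybrid - 1.0) < 1e-9          # tops both lists -> saturates at 1.0

A document that ranks first in both signals therefore saturates at 1.0, while one that appears in only one list loses both its missing normalized term and its missing RRF term.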

    async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
    async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
        """Search for relevant documents"""
        if not self.enabled:
            raise RuntimeError("RAG module not initialized")
@@ -1453,8 +1509,10 @@ class RAGModule(BaseModule):
        import time
        start_time = time.time()

        # Generate query embedding
        query_embedding = await self._generate_embedding(query)
        # Generate query embedding with task-specific prefix for better retrieval
        # The E5 model works better with "query:" prefix for search queries
        optimized_query = f"query: {query}"
        query_embedding = await self._generate_embedding(optimized_query)

        # Build filter
        search_filter = None
@@ -1474,7 +1532,8 @@ class RAGModule(BaseModule):

        # Check if hybrid search is enabled
        enable_hybrid = self.config.get("enable_hybrid", False)
        score_threshold = self.config.get("score_threshold", 0.3)
        # Use provided score_threshold or fall back to config
        search_score_threshold = score_threshold if score_threshold is not None else self.config.get("score_threshold", 0.3)

        if enable_hybrid and NLTK_AVAILABLE:
            # Perform hybrid search (vector + BM25)
@@ -1484,7 +1543,7 @@ class RAGModule(BaseModule):
                query_vector=query_embedding,
                query_filter=search_filter,
                limit=max_results,
                score_threshold=score_threshold
                score_threshold=search_score_threshold
            )
        else:
            # Pure vector search with improved threshold
@@ -1493,7 +1552,7 @@ class RAGModule(BaseModule):
                query_vector=query_embedding,
                query_filter=search_filter,
                limit=max_results,
                score_threshold=score_threshold
                score_threshold=search_score_threshold
            )

        logger.info(f"Raw search results count: {len(search_results)}")
@@ -1841,9 +1900,9 @@ async def index_processed_document(processed_doc: ProcessedDocument, collection_
    """Index a processed document"""
    return await rag_module.index_processed_document(processed_doc, collection_name)

async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
    """Search documents"""
    return await rag_module.search_documents(query, max_results, filters, collection_name)
    return await rag_module.search_documents(query, max_results, filters, collection_name, score_threshold)
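
With the threaded-through score_threshold, callers can tighten or loosen retrieval per query instead of editing module config; document chunks are embedded with a "passage:" prefix and queries with a "query:" prefix, matching E5-style models. A hypothetical call (the collection name, threshold, and the SearchResult field names are assumptions for illustration):

    results = await search_documents(
        "How do I rotate my API key?",
        max_results=5,
        collection_name="support-kb",
        score_threshold=0.5,  # overrides the configured default of 0.3
    )
    for r in results:
        print(r.score, r.content[:80])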

async def delete_document(document_id: str, collection_name: str = None) -> bool:
    """Delete a document"""

@@ -7,7 +7,7 @@ export async function POST(request: NextRequest) {

  // Make request to backend auth endpoint without requiring existing auth
  const baseUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`
  const url = `${baseUrl}/api/auth/login`
  const url = `${baseUrl}/api-internal/v1/auth/login`

  const response = await fetch(url, {
    method: 'POST',

@@ -85,8 +85,31 @@ function RAGPageContent() {
  const loadStats = async () => {
    try {
      const data = await apiClient.get('/api-internal/v1/rag/stats')
      setStats(data.stats)
      console.log('Stats API response:', data)

      // Check if the response has the expected structure
      if (data && data.stats && data.stats.collections) {
        console.log('✓ Stats has collections property')
        setStats(data.stats)
      } else {
        console.error('✗ Invalid stats structure:', data)
        // Set default empty stats to prevent error
        setStats({
          collections: { total: 0, active: 0 },
          documents: { total: 0, processing: 0, processed: 0 },
          storage: { total_size_bytes: 0, total_size_mb: 0 },
          vectors: { total: 0 }
        })
      }
    } catch (error) {
      console.error('Error loading stats:', error)
      // Set default empty stats on error
      setStats({
        collections: { total: 0, active: 0 },
        documents: { total: 0, processing: 0, processed: 0 },
        storage: { total_size_bytes: 0, total_size_mb: 0 },
        vectors: { total: 0 }
      })
    }
  }
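
The guard above encodes the response shape the page expects from /api-internal/v1/rag/stats. For reference, a payload that passes the check would look like this (numbers invented):

    stats_response = {
        "stats": {
            "collections": {"total": 2, "active": 2},
            "documents": {"total": 120, "processing": 1, "processed": 119},
            "storage": {"total_size_bytes": 5242880, "total_size_mb": 5.0},
            "vectors": {"total": 8432},
        }
    }
    assert "collections" in stats_response["stats"]  # the condition loadStats tests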

@@ -9,7 +9,7 @@ import { Badge } from "@/components/ui/badge"
import { Separator } from "@/components/ui/separator"
import { AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger } from "@/components/ui/alert-dialog"
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog"
import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter } from "lucide-react"
import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter, RefreshCw } from "lucide-react"
import { useToast } from "@/hooks/use-toast"
import { apiClient } from "@/lib/api-client"
import { config } from "@/lib/config"
@@ -56,6 +56,7 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
  const [filterStatus, setFilterStatus] = useState("all")
  const [selectedDocument, setSelectedDocument] = useState<Document | null>(null)
  const [deleting, setDeleting] = useState<string | null>(null)
  const [reprocessing, setReprocessing] = useState<string | null>(null)
  const { toast } = useToast()

  useEffect(() => {
@@ -157,6 +158,43 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
    }
  }

  const handleReprocessDocument = async (documentId: string) => {
    setReprocessing(documentId)

    try {
      await apiClient.post(`/api-internal/v1/rag/documents/${documentId}/reprocess`)

      // Update the document status to processing in the UI
      setDocuments(prev => prev.map(doc =>
        doc.id === documentId
          ? { ...doc, status: 'processing' as const, processed_at: new Date().toISOString() }
          : doc
      ))

      toast({
        title: "Success",
        description: "Document reprocessing started",
      })

      // Reload documents after a short delay to see status updates
      setTimeout(() => {
        loadDocuments()
      }, 2000)

    } catch (error) {
      const errorMessage = error instanceof Error ? error.message : "Failed to reprocess document"
      toast({
        title: "Error",
        description: errorMessage.includes("Cannot reprocess document with status 'processed'")
          ? "Cannot reprocess documents that are already processed"
          : errorMessage,
        variant: "destructive",
      })
    } finally {
      setReprocessing(null)
    }
  }

  const formatFileSize = (bytes: number) => {
    if (bytes === 0) return '0 Bytes'
    const k = 1024
@@ -432,6 +470,21 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
            <Download className="h-4 w-4" />
          </Button>

          <Button
            variant="ghost"
            size="sm"
            className="h-8 w-8 p-0 hover:bg-blue-100"
            onClick={() => handleReprocessDocument(document.id)}
            disabled={reprocessing === document.id || document.status === 'processed'}
            title={document.status === 'processed' ? "Document already processed" : "Reprocess document"}
          >
            {reprocessing === document.id ? (
              <RefreshCw className="h-4 w-4 animate-spin" />
            ) : (
              <RefreshCw className={`h-4 w-4 ${document.status === 'processed' ? 'text-gray-400' : ''}`} />
            )}
          </Button>

          <AlertDialog>
            <AlertDialogTrigger asChild>
              <Button
@@ -67,12 +67,13 @@ const Navigation = () => {
  // Core navigation items that are always visible
  const coreNavItems = [
    { href: "/dashboard", label: "Dashboard" },
    {
      href: "/llm",
    {
      href: "/llm",
      label: "LLM",
      children: [
        { href: "/llm", label: "Models & Config" },
        { href: "/playground", label: "Playground" },
        { href: "/rag-demo", label: "RAG Demo" },
      ]
    },
    {

@@ -25,6 +25,12 @@ http {
        listen 80;
        server_name localhost;

        # Static files - serve directly from nginx
        location = /login_helper.html {
            root /usr/share/nginx/html;
            try_files $uri =404;
        }

        # Frontend routes
        location / {
            proxy_pass http://frontend;
@@ -32,7 +38,7 @@ http {
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;

            # WebSocket support for Next.js HMR
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
@@ -65,6 +71,58 @@ http {
            }
        }

        # RAG debug API routes - proxy to frontend (for Next.js API routes)
        location /api/rag/debug/ {
            proxy_pass http://frontend;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;

            # CORS headers
            add_header 'Access-Control-Allow-Origin' '*' always;
            add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
            add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
            add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;

            # Handle preflight requests
            if ($request_method = 'OPTIONS') {
                add_header 'Access-Control-Allow-Origin' '*';
                add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
                add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization';
                add_header 'Access-Control-Max-Age' 1728000;
                add_header 'Content-Type' 'text/plain; charset=utf-8';
                add_header 'Content-Length' 0;
                return 204;
            }
        }

        # Frontend API routes for authentication - proxy to frontend
        location /api/auth/ {
            proxy_pass http://frontend;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;

            # CORS headers
            add_header 'Access-Control-Allow-Origin' '*' always;
            add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
            add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
            add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;

            # Handle preflight requests
            if ($request_method = 'OPTIONS') {
                add_header 'Access-Control-Allow-Origin' '*';
                add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
                add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization';
                add_header 'Access-Control-Max-Age' 1728000;
                add_header 'Content-Type' 'text/plain; charset=utf-8';
                add_header 'Content-Length' 0;
                return 204;
            }
        }
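
A quick way to sanity-check the preflight branch from outside the stack, as a sketch (assumes the proxy is reachable on localhost port 80):

    import http.client

    conn = http.client.HTTPConnection("localhost", 80)
    conn.request("OPTIONS", "/api/auth/login", headers={
        "Origin": "http://example.com",
        "Access-Control-Request-Method": "POST",
    })
    resp = conn.getresponse()
    print(resp.status)                                    # expect 204
    print(resp.getheader("Access-Control-Allow-Origin"))  # expect '*'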

        # Public API routes - proxy to backend (for external clients)
        location /api/ {
            proxy_pass http://backend;
@@ -72,13 +130,13 @@ http {
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;

            # CORS headers for external clients
            add_header 'Access-Control-Allow-Origin' '*' always;
            add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
            add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
            add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;

            # Handle preflight requests
            if ($request_method = 'OPTIONS') {
                add_header 'Access-Control-Allow-Origin' '*';