mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
rag improvements
This commit is contained in:
2
.env
2
.env
@@ -46,7 +46,7 @@ API_RATE_LIMITING_ENABLED=false
|
|||||||
# ===================================
|
# ===================================
|
||||||
# APPLICATION BASE URL (Required - derives all URLs and CORS)
|
# APPLICATION BASE URL (Required - derives all URLs and CORS)
|
||||||
# ===================================
|
# ===================================
|
||||||
BASE_URL=localhost
|
BASE_URL=localhost:80
|
||||||
# Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost
|
# Frontend derives: APP_URL=http://localhost, API_URL=http://localhost, WS_URL=ws://localhost
|
||||||
# Backend derives: CORS_ORIGINS=["http://localhost"]
|
# Backend derives: CORS_ORIGINS=["http://localhost"]
|
||||||
|
|
||||||
|
|||||||
10
.env.example
10
.env.example
@@ -65,6 +65,16 @@ QDRANT_HOST=enclava-qdrant
|
|||||||
QDRANT_PORT=6333
|
QDRANT_PORT=6333
|
||||||
QDRANT_URL=http://enclava-qdrant:6333
|
QDRANT_URL=http://enclava-qdrant:6333
|
||||||
|
|
||||||
|
# ===================================
|
||||||
|
# RAG EMBEDDING CONFIGURATION (Optional overrides)
|
||||||
|
# ===================================
|
||||||
|
# These control embedding throughput to avoid provider 429s.
|
||||||
|
# Defaults are conservative; uncomment to override.
|
||||||
|
# RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE=12
|
||||||
|
# RAG_EMBEDDING_BATCH_SIZE=3
|
||||||
|
# RAG_EMBEDDING_DELAY_BETWEEN_BATCHES=1.0 # seconds
|
||||||
|
# RAG_EMBEDDING_DELAY_PER_REQUEST=0.5 # seconds
|
||||||
|
|
||||||
# ===================================
|
# ===================================
|
||||||
# OPTIONAL PRIVATEMODE SETTINGS (Have defaults)
|
# OPTIONAL PRIVATEMODE SETTINGS (Have defaults)
|
||||||
# ===================================
|
# ===================================
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from ..v1.audit import router as audit_router
|
|||||||
from ..v1.settings import router as settings_router
|
from ..v1.settings import router as settings_router
|
||||||
from ..v1.analytics import router as analytics_router
|
from ..v1.analytics import router as analytics_router
|
||||||
from ..v1.rag import router as rag_router
|
from ..v1.rag import router as rag_router
|
||||||
|
from ..rag_debug import router as rag_debug_router
|
||||||
from ..v1.prompt_templates import router as prompt_templates_router
|
from ..v1.prompt_templates import router as prompt_templates_router
|
||||||
from ..v1.security import router as security_router
|
|
||||||
from ..v1.plugin_registry import router as plugin_registry_router
|
from ..v1.plugin_registry import router as plugin_registry_router
|
||||||
from ..v1.platform import router as platform_router
|
from ..v1.platform import router as platform_router
|
||||||
from ..v1.llm_internal import router as llm_internal_router
|
from ..v1.llm_internal import router as llm_internal_router
|
||||||
@@ -52,11 +52,12 @@ internal_api_router.include_router(analytics_router, prefix="/analytics", tags=[
|
|||||||
# Include RAG routes (frontend RAG document management)
|
# Include RAG routes (frontend RAG document management)
|
||||||
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])
|
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])
|
||||||
|
|
||||||
|
# Include RAG debug routes (for demo and debugging)
|
||||||
|
internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])
|
||||||
|
|
||||||
# Include prompt template routes (frontend prompt template management)
|
# Include prompt template routes (frontend prompt template management)
|
||||||
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])
|
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])
|
||||||
|
|
||||||
# Include security routes (frontend security settings)
|
|
||||||
internal_api_router.include_router(security_router, prefix="/security", tags=["internal-security"])
|
|
||||||
|
|
||||||
# Include plugin registry routes (frontend plugin management)
|
# Include plugin registry routes (frontend plugin management)
|
||||||
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])
|
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from .analytics import router as analytics_router
|
|||||||
from .rag import router as rag_router
|
from .rag import router as rag_router
|
||||||
from .chatbot import router as chatbot_router
|
from .chatbot import router as chatbot_router
|
||||||
from .prompt_templates import router as prompt_templates_router
|
from .prompt_templates import router as prompt_templates_router
|
||||||
from .security import router as security_router
|
|
||||||
from .plugin_registry import router as plugin_registry_router
|
from .plugin_registry import router as plugin_registry_router
|
||||||
|
|
||||||
# Create main API router
|
# Create main API router
|
||||||
@@ -61,8 +60,6 @@ api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
|
|||||||
# Include prompt template routes
|
# Include prompt template routes
|
||||||
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
|
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
|
||||||
|
|
||||||
# Include security routes
|
|
||||||
api_router.include_router(security_router, prefix="/security", tags=["security"])
|
|
||||||
|
|
||||||
|
|
||||||
# Include plugin registry routes
|
# Include plugin registry routes
|
||||||
|
|||||||
@@ -745,7 +745,6 @@ async def get_llm_metrics(
|
|||||||
"total_requests": metrics.total_requests,
|
"total_requests": metrics.total_requests,
|
||||||
"successful_requests": metrics.successful_requests,
|
"successful_requests": metrics.successful_requests,
|
||||||
"failed_requests": metrics.failed_requests,
|
"failed_requests": metrics.failed_requests,
|
||||||
"security_blocked_requests": metrics.security_blocked_requests,
|
|
||||||
"average_latency_ms": metrics.average_latency_ms,
|
"average_latency_ms": metrics.average_latency_ms,
|
||||||
"average_risk_score": metrics.average_risk_score,
|
"average_risk_score": metrics.average_risk_score,
|
||||||
"provider_metrics": metrics.provider_metrics,
|
"provider_metrics": metrics.provider_metrics,
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ RAG API Endpoints
|
|||||||
Provides REST API for RAG (Retrieval Augmented Generation) operations
|
Provides REST API for RAG (Retrieval Augmented Generation) operations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Dict, Any
|
||||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import io
|
import io
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from app.db.database import get_db
|
from app.db.database import get_db
|
||||||
from app.core.security import get_current_user
|
from app.core.security import get_current_user
|
||||||
@@ -16,6 +18,9 @@ from app.models.user import User
|
|||||||
from app.services.rag_service import RAGService
|
from app.services.rag_service import RAGService
|
||||||
from app.utils.exceptions import APIException
|
from app.utils.exceptions import APIException
|
||||||
|
|
||||||
|
# Import RAG module from module manager
|
||||||
|
from app.services.module_manager import module_manager
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(tags=["RAG"])
|
router = APIRouter(tags=["RAG"])
|
||||||
|
|
||||||
@@ -78,14 +83,25 @@ async def get_collections(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata"""
|
"""Get all RAG collections - live data directly from Qdrant (source of truth)"""
|
||||||
try:
|
try:
|
||||||
rag_service = RAGService(db)
|
from app.services.qdrant_stats_service import qdrant_stats_service
|
||||||
collections_data = await rag_service.get_all_collections(skip=skip, limit=limit)
|
|
||||||
|
# Get live stats from Qdrant
|
||||||
|
stats_data = await qdrant_stats_service.get_collections_stats()
|
||||||
|
collections = stats_data.get("collections", [])
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
start_idx = skip
|
||||||
|
end_idx = skip + limit
|
||||||
|
paginated_collections = collections[start_idx:end_idx]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"collections": collections_data,
|
"collections": paginated_collections,
|
||||||
"total": len(collections_data)
|
"total": len(collections),
|
||||||
|
"total_documents": stats_data.get("total_documents", 0),
|
||||||
|
"total_size_bytes": stats_data.get("total_size_bytes", 0)
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -116,6 +132,62 @@ async def create_collection(
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/stats", response_model=dict)
|
||||||
|
async def get_rag_stats(
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""Get overall RAG statistics - live data directly from Qdrant"""
|
||||||
|
try:
|
||||||
|
from app.services.qdrant_stats_service import qdrant_stats_service
|
||||||
|
|
||||||
|
# Get live stats from Qdrant
|
||||||
|
stats_data = await qdrant_stats_service.get_collections_stats()
|
||||||
|
|
||||||
|
# Calculate active collections (collections with documents)
|
||||||
|
active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)
|
||||||
|
|
||||||
|
# Calculate processing documents from database
|
||||||
|
processing_docs = 0
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models.rag_document import RagDocument, ProcessingStatus
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
|
||||||
|
)
|
||||||
|
processing_docs = len(result.scalars().all())
|
||||||
|
except Exception:
|
||||||
|
pass # If database query fails, default to 0
|
||||||
|
|
||||||
|
response_data = {
|
||||||
|
"success": True,
|
||||||
|
"stats": {
|
||||||
|
"collections": {
|
||||||
|
"total": stats_data.get("total_collections", 0),
|
||||||
|
"active": active_collections
|
||||||
|
},
|
||||||
|
"documents": {
|
||||||
|
"total": stats_data.get("total_documents", 0),
|
||||||
|
"processing": processing_docs,
|
||||||
|
"processed": stats_data.get("total_documents", 0) # Indexed documents
|
||||||
|
},
|
||||||
|
"storage": {
|
||||||
|
"total_size_bytes": stats_data.get("total_size_bytes", 0),
|
||||||
|
"total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
|
||||||
|
},
|
||||||
|
"vectors": {
|
||||||
|
"total": stats_data.get("total_documents", 0) # Same as documents for RAG
|
||||||
|
},
|
||||||
|
"last_updated": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return response_data
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/collections/{collection_id}", response_model=dict)
|
@router.get("/collections/{collection_id}", response_model=dict)
|
||||||
async def get_collection(
|
async def get_collection(
|
||||||
collection_id: int,
|
collection_id: int,
|
||||||
@@ -232,11 +304,55 @@ async def upload_document(
|
|||||||
if len(file_content) > 50 * 1024 * 1024: # 50MB limit
|
if len(file_content) > 50 * 1024 * 1024: # 50MB limit
|
||||||
raise HTTPException(status_code=400, detail="File too large (max 50MB)")
|
raise HTTPException(status_code=400, detail="File too large (max 50MB)")
|
||||||
|
|
||||||
|
# Validate file can be read before processing
|
||||||
|
filename = file.filename or "unknown"
|
||||||
|
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test file readability based on type
|
||||||
|
if file_extension == 'jsonl':
|
||||||
|
# Validate JSONL format - try to parse first few lines
|
||||||
|
try:
|
||||||
|
content_str = file_content.decode('utf-8')
|
||||||
|
lines = content_str.strip().split('\n')[:5] # Check first 5 lines
|
||||||
|
import json
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if line.strip(): # Skip empty lines
|
||||||
|
json.loads(line) # Will raise JSONDecodeError if invalid
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
|
||||||
|
|
||||||
|
elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
|
||||||
|
# Validate text files can be decoded
|
||||||
|
try:
|
||||||
|
file_content.decode('utf-8')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
|
||||||
|
|
||||||
|
elif file_extension in ['pdf']:
|
||||||
|
# For PDF files, just check if it starts with PDF signature
|
||||||
|
if not file_content.startswith(b'%PDF'):
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid PDF file format")
|
||||||
|
|
||||||
|
elif file_extension in ['docx', 'xlsx', 'pptx']:
|
||||||
|
# For Office documents, check ZIP signature
|
||||||
|
if not file_content.startswith(b'PK'):
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
|
||||||
|
|
||||||
|
# For other file types, we'll rely on the document processor
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
|
||||||
|
|
||||||
rag_service = RAGService(db)
|
rag_service = RAGService(db)
|
||||||
document = await rag_service.upload_document(
|
document = await rag_service.upload_document(
|
||||||
collection_id=collection_id,
|
collection_id=collection_id,
|
||||||
file_content=file_content,
|
file_content=file_content,
|
||||||
filename=file.filename or "unknown",
|
filename=filename,
|
||||||
content_type=file.content_type
|
content_type=file.content_type
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -362,21 +478,167 @@ async def download_document(
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
# Stats Endpoint
|
|
||||||
|
|
||||||
@router.get("/stats", response_model=dict)
|
# Debug Endpoints
|
||||||
async def get_rag_stats(
|
|
||||||
db: AsyncSession = Depends(get_db),
|
@router.post("/debug/search")
|
||||||
|
async def search_with_debug(
|
||||||
|
query: str,
|
||||||
|
max_results: int = 10,
|
||||||
|
score_threshold: float = 0.3,
|
||||||
|
collection_name: str = None,
|
||||||
|
config: Dict[str, Any] = None,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
) -> Dict[str, Any]:
|
||||||
"""Get RAG system statistics"""
|
"""
|
||||||
|
Enhanced search with comprehensive debug information
|
||||||
|
"""
|
||||||
|
# Get RAG module from module manager
|
||||||
|
rag_module = module_manager.modules.get('rag')
|
||||||
|
if not rag_module or not rag_module.enabled:
|
||||||
|
raise HTTPException(status_code=503, detail="RAG module not initialized")
|
||||||
|
|
||||||
|
debug_info = {}
|
||||||
|
start_time = datetime.utcnow()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rag_service = RAGService(db)
|
# Apply configuration if provided
|
||||||
stats = await rag_service.get_stats()
|
if config:
|
||||||
|
# Update RAG config temporarily
|
||||||
|
original_config = rag_module.config.copy()
|
||||||
|
rag_module.config.update(config)
|
||||||
|
|
||||||
|
# Generate query embedding (with or without prefix)
|
||||||
|
if config and config.get("use_query_prefix"):
|
||||||
|
optimized_query = f"query: {query}"
|
||||||
|
else:
|
||||||
|
optimized_query = query
|
||||||
|
|
||||||
|
query_embedding = await rag_module._generate_embedding(optimized_query)
|
||||||
|
|
||||||
|
# Store embedding info for debug
|
||||||
|
if config and config.get("debug", {}).get("show_embeddings"):
|
||||||
|
debug_info["query_embedding"] = query_embedding[:10] # First 10 dimensions
|
||||||
|
debug_info["embedding_dimension"] = len(query_embedding)
|
||||||
|
debug_info["optimized_query"] = optimized_query
|
||||||
|
|
||||||
|
# Perform search
|
||||||
|
search_start = asyncio.get_event_loop().time()
|
||||||
|
results = await rag_module.search_documents(
|
||||||
|
query,
|
||||||
|
max_results=max_results,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
collection_name=collection_name
|
||||||
|
)
|
||||||
|
search_time = (asyncio.get_event_loop().time() - search_start) * 1000
|
||||||
|
|
||||||
|
# Calculate score statistics
|
||||||
|
scores = [r.score for r in results if r.score is not None]
|
||||||
|
if scores:
|
||||||
|
import statistics
|
||||||
|
debug_info["score_stats"] = {
|
||||||
|
"min": min(scores),
|
||||||
|
"max": max(scores),
|
||||||
|
"avg": statistics.mean(scores),
|
||||||
|
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get collection statistics
|
||||||
|
try:
|
||||||
|
from qdrant_client.http.models import Filter
|
||||||
|
collection_name = collection_name or rag_module.default_collection_name
|
||||||
|
|
||||||
|
# Count total documents
|
||||||
|
count_result = rag_module.qdrant_client.count(
|
||||||
|
collection_name=collection_name,
|
||||||
|
count_filter=Filter(must=[])
|
||||||
|
)
|
||||||
|
total_points = count_result.count
|
||||||
|
|
||||||
|
# Get unique documents and languages
|
||||||
|
scroll_result = rag_module.qdrant_client.scroll(
|
||||||
|
collection_name=collection_name,
|
||||||
|
limit=1000, # Sample for stats
|
||||||
|
with_payload=True,
|
||||||
|
with_vectors=False
|
||||||
|
)
|
||||||
|
|
||||||
|
unique_docs = set()
|
||||||
|
languages = set()
|
||||||
|
|
||||||
|
for point in scroll_result[0]:
|
||||||
|
payload = point.payload or {}
|
||||||
|
doc_id = payload.get("document_id")
|
||||||
|
if doc_id:
|
||||||
|
unique_docs.add(doc_id)
|
||||||
|
|
||||||
|
language = payload.get("language")
|
||||||
|
if language:
|
||||||
|
languages.add(language)
|
||||||
|
|
||||||
|
debug_info["collection_stats"] = {
|
||||||
|
"total_documents": len(unique_docs),
|
||||||
|
"total_chunks": total_points,
|
||||||
|
"languages": sorted(list(languages))
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
debug_info["collection_stats_error"] = str(e)
|
||||||
|
|
||||||
|
# Enhance results with debug info
|
||||||
|
enhanced_results = []
|
||||||
|
for result in results:
|
||||||
|
enhanced_result = {
|
||||||
|
"document": {
|
||||||
|
"id": result.document.id,
|
||||||
|
"content": result.document.content,
|
||||||
|
"metadata": result.document.metadata
|
||||||
|
},
|
||||||
|
"score": result.score,
|
||||||
|
"debug_info": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add hybrid search debug info if available
|
||||||
|
metadata = result.document.metadata or {}
|
||||||
|
if "_vector_score" in metadata:
|
||||||
|
enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
|
||||||
|
if "_bm25_score" in metadata:
|
||||||
|
enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]
|
||||||
|
|
||||||
|
enhanced_results.append(enhanced_result)
|
||||||
|
|
||||||
|
# Note: Analytics logging disabled (module not available)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"results": enhanced_results,
|
||||||
"stats": stats
|
"debug_info": debug_info,
|
||||||
|
"search_time_ms": search_time,
|
||||||
|
"timestamp": start_time.isoformat()
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
# Note: Analytics logging disabled (module not available)
|
||||||
|
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original config if modified
|
||||||
|
if config and 'original_config' in locals():
|
||||||
|
rag_module.config = original_config
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/debug/config")
|
||||||
|
async def get_current_config(
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Get current RAG configuration"""
|
||||||
|
# Get RAG module from module manager
|
||||||
|
rag_module = module_manager.modules.get('rag')
|
||||||
|
if not rag_module or not rag_module.enabled:
|
||||||
|
raise HTTPException(status_code=503, detail="RAG module not initialized")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"config": rag_module.config,
|
||||||
|
"embedding_model": rag_module.embedding_model,
|
||||||
|
"enabled": rag_module.enabled,
|
||||||
|
"collections": await rag_module._get_collections_safely()
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,251 +0,0 @@
|
|||||||
"""
|
|
||||||
Security API endpoints for monitoring and configuration
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from app.core.security import get_current_active_user, RequiresRole
|
|
||||||
from app.middleware.security import get_security_stats, get_request_auth_level, get_request_risk_score
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.logging import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(tags=["security"])
|
|
||||||
|
|
||||||
|
|
||||||
# Pydantic models for API responses
|
|
||||||
class SecurityStatsResponse(BaseModel):
|
|
||||||
"""Security statistics response model"""
|
|
||||||
total_requests_analyzed: int
|
|
||||||
threats_detected: int
|
|
||||||
threats_blocked: int
|
|
||||||
anomalies_detected: int
|
|
||||||
rate_limits_exceeded: int
|
|
||||||
avg_analysis_time: float
|
|
||||||
threat_types: Dict[str, int]
|
|
||||||
threat_levels: Dict[str, int]
|
|
||||||
top_attacking_ips: List[tuple]
|
|
||||||
security_enabled: bool
|
|
||||||
threat_detection_enabled: bool
|
|
||||||
rate_limiting_enabled: bool
|
|
||||||
|
|
||||||
|
|
||||||
class SecurityConfigResponse(BaseModel):
|
|
||||||
"""Security configuration response model"""
|
|
||||||
security_enabled: bool = Field(description="Overall security system enabled")
|
|
||||||
threat_detection_enabled: bool = Field(description="Threat detection analysis enabled")
|
|
||||||
rate_limiting_enabled: bool = Field(description="Rate limiting enabled")
|
|
||||||
ip_reputation_enabled: bool = Field(description="IP reputation checking enabled")
|
|
||||||
anomaly_detection_enabled: bool = Field(description="Anomaly detection enabled")
|
|
||||||
security_headers_enabled: bool = Field(description="Security headers enabled")
|
|
||||||
|
|
||||||
# Rate limiting settings
|
|
||||||
unauthenticated_per_minute: int = Field(description="Rate limit for unauthenticated requests per minute")
|
|
||||||
authenticated_per_minute: int = Field(description="Rate limit for authenticated users per minute")
|
|
||||||
api_key_per_minute: int = Field(description="Rate limit for API key users per minute")
|
|
||||||
premium_per_minute: int = Field(description="Rate limit for premium users per minute")
|
|
||||||
|
|
||||||
# Security thresholds
|
|
||||||
risk_threshold: float = Field(description="Risk score threshold for blocking requests")
|
|
||||||
warning_threshold: float = Field(description="Risk score threshold for warnings")
|
|
||||||
anomaly_threshold: float = Field(description="Anomaly severity threshold")
|
|
||||||
|
|
||||||
# IP settings
|
|
||||||
blocked_ips: List[str] = Field(description="List of blocked IP addresses")
|
|
||||||
allowed_ips: List[str] = Field(description="List of allowed IP addresses (empty = allow all)")
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimitInfoResponse(BaseModel):
|
|
||||||
"""Rate limit information for current request"""
|
|
||||||
auth_level: str = Field(description="Authentication level (unauthenticated, authenticated, api_key, premium)")
|
|
||||||
current_limits: Dict[str, int] = Field(description="Current rate limits for this auth level")
|
|
||||||
remaining_requests: Optional[Dict[str, int]] = Field(description="Estimated remaining requests (if available)")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stats", response_model=SecurityStatsResponse)
|
|
||||||
async def get_security_statistics(
|
|
||||||
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get security system statistics
|
|
||||||
|
|
||||||
Requires admin role. Returns comprehensive statistics about:
|
|
||||||
- Request analysis counts
|
|
||||||
- Threat detection results
|
|
||||||
- Rate limiting enforcement
|
|
||||||
- Top attacking IPs
|
|
||||||
- Performance metrics
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
stats = get_security_stats()
|
|
||||||
return SecurityStatsResponse(**stats)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting security stats: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Failed to retrieve security statistics"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config", response_model=SecurityConfigResponse)
|
|
||||||
async def get_security_config(
|
|
||||||
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get current security configuration
|
|
||||||
|
|
||||||
Requires admin role. Returns current security settings including:
|
|
||||||
- Feature enablement flags
|
|
||||||
- Rate limiting thresholds
|
|
||||||
- Security thresholds
|
|
||||||
- IP allowlists/blocklists
|
|
||||||
"""
|
|
||||||
return SecurityConfigResponse(
|
|
||||||
security_enabled=settings.API_SECURITY_ENABLED,
|
|
||||||
threat_detection_enabled=settings.API_THREAT_DETECTION_ENABLED,
|
|
||||||
rate_limiting_enabled=settings.API_RATE_LIMITING_ENABLED,
|
|
||||||
ip_reputation_enabled=settings.API_IP_REPUTATION_ENABLED,
|
|
||||||
anomaly_detection_enabled=settings.API_ANOMALY_DETECTION_ENABLED,
|
|
||||||
security_headers_enabled=settings.API_SECURITY_HEADERS_ENABLED,
|
|
||||||
|
|
||||||
unauthenticated_per_minute=settings.API_RATE_LIMIT_UNAUTHENTICATED_PER_MINUTE,
|
|
||||||
authenticated_per_minute=settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE,
|
|
||||||
api_key_per_minute=settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
|
|
||||||
premium_per_minute=settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE,
|
|
||||||
|
|
||||||
risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
|
|
||||||
warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
|
|
||||||
anomaly_threshold=settings.API_SECURITY_ANOMALY_THRESHOLD,
|
|
||||||
|
|
||||||
blocked_ips=settings.API_BLOCKED_IPS,
|
|
||||||
allowed_ips=settings.API_ALLOWED_IPS
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/status")
|
|
||||||
async def get_security_status(
|
|
||||||
request: Request,
|
|
||||||
current_user: Dict[str, Any] = Depends(get_current_active_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get security status for current request
|
|
||||||
|
|
||||||
Returns information about the security analysis of the current request:
|
|
||||||
- Authentication level
|
|
||||||
- Risk score (if available)
|
|
||||||
- Rate limiting status
|
|
||||||
"""
|
|
||||||
auth_level = get_request_auth_level(request)
|
|
||||||
risk_score = get_request_risk_score(request)
|
|
||||||
|
|
||||||
# Get rate limits for current auth level
|
|
||||||
from app.core.threat_detection import AuthLevel
|
|
||||||
try:
|
|
||||||
auth_enum = AuthLevel(auth_level)
|
|
||||||
from app.core.threat_detection import threat_detection_service
|
|
||||||
minute_limit, hour_limit = threat_detection_service.get_rate_limits(auth_enum)
|
|
||||||
|
|
||||||
rate_limit_info = RateLimitInfoResponse(
|
|
||||||
auth_level=auth_level,
|
|
||||||
current_limits={
|
|
||||||
"per_minute": minute_limit,
|
|
||||||
"per_hour": hour_limit
|
|
||||||
},
|
|
||||||
remaining_requests=None # We don't track remaining requests in current implementation
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
rate_limit_info = RateLimitInfoResponse(
|
|
||||||
auth_level=auth_level,
|
|
||||||
current_limits={},
|
|
||||||
remaining_requests=None
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"security_enabled": settings.API_SECURITY_ENABLED,
|
|
||||||
"auth_level": auth_level,
|
|
||||||
"risk_score": round(risk_score, 3) if risk_score > 0 else None,
|
|
||||||
"rate_limit_info": rate_limit_info.dict(),
|
|
||||||
"security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/test")
|
|
||||||
async def test_security_analysis(
|
|
||||||
request: Request,
|
|
||||||
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Test security analysis on current request
|
|
||||||
|
|
||||||
Requires admin role. Manually triggers security analysis on the current request
|
|
||||||
and returns detailed results. Useful for testing security rules and thresholds.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from app.middleware.security import analyze_request_security
|
|
||||||
|
|
||||||
analysis = await analyze_request_security(request, current_user)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"analysis_complete": True,
|
|
||||||
"is_threat": analysis.is_threat,
|
|
||||||
"risk_score": round(analysis.risk_score, 3),
|
|
||||||
"auth_level": analysis.auth_level.value,
|
|
||||||
"should_block": analysis.should_block,
|
|
||||||
"rate_limit_exceeded": analysis.rate_limit_exceeded,
|
|
||||||
"threat_count": len(analysis.threats),
|
|
||||||
"threats": [
|
|
||||||
{
|
|
||||||
"type": threat.threat_type,
|
|
||||||
"level": threat.level.value,
|
|
||||||
"confidence": round(threat.confidence, 3),
|
|
||||||
"description": threat.description,
|
|
||||||
"mitigation": threat.mitigation
|
|
||||||
}
|
|
||||||
for threat in analysis.threats
|
|
||||||
],
|
|
||||||
"recommendations": analysis.recommendations
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in security analysis test: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Failed to perform security analysis test"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/health")
|
|
||||||
async def security_health_check():
|
|
||||||
"""
|
|
||||||
Security system health check
|
|
||||||
|
|
||||||
Public endpoint that returns the health status of the security system.
|
|
||||||
Does not require authentication.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
stats = get_security_stats()
|
|
||||||
|
|
||||||
# Basic health checks
|
|
||||||
is_healthy = (
|
|
||||||
settings.API_SECURITY_ENABLED and
|
|
||||||
stats.get("total_requests_analyzed", 0) >= 0 and
|
|
||||||
stats.get("avg_analysis_time", 0) < 1.0 # Analysis should be under 1 second
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "healthy" if is_healthy else "degraded",
|
|
||||||
"security_enabled": settings.API_SECURITY_ENABLED,
|
|
||||||
"threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
|
|
||||||
"rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
|
|
||||||
"avg_analysis_time_ms": round(stats.get("avg_analysis_time", 0) * 1000, 2),
|
|
||||||
"total_requests_analyzed": stats.get("total_requests_analyzed", 0)
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Security health check failed: {e}")
|
|
||||||
return {
|
|
||||||
"status": "unhealthy",
|
|
||||||
"error": "Security system error",
|
|
||||||
"security_enabled": settings.API_SECURITY_ENABLED
|
|
||||||
}
|
|
||||||
@@ -97,7 +97,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
|
|||||||
"api": {
|
"api": {
|
||||||
# Security Settings
|
# Security Settings
|
||||||
"security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
|
"security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
|
||||||
"threat_detection_enabled": {"value": True, "type": "boolean", "description": "Enable threat detection analysis"},
|
|
||||||
"rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
|
"rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
|
||||||
"ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
|
"ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
|
||||||
"anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
|
"anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
|
||||||
@@ -112,7 +111,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
|
|||||||
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},
|
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},
|
||||||
|
|
||||||
# Security Thresholds
|
# Security Thresholds
|
||||||
"security_risk_threshold": {"value": 0.8, "type": "float", "description": "Risk score threshold for blocking requests (0.0-1.0)"},
|
|
||||||
"security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
|
"security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
|
||||||
"anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},
|
"anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},
|
||||||
|
|
||||||
@@ -601,7 +599,6 @@ async def reset_to_defaults(
|
|||||||
"api": {
|
"api": {
|
||||||
# Security Settings
|
# Security Settings
|
||||||
"security_enabled": {"value": True, "type": "boolean"},
|
"security_enabled": {"value": True, "type": "boolean"},
|
||||||
"threat_detection_enabled": {"value": True, "type": "boolean"},
|
|
||||||
"rate_limiting_enabled": {"value": True, "type": "boolean"},
|
"rate_limiting_enabled": {"value": True, "type": "boolean"},
|
||||||
"ip_reputation_enabled": {"value": True, "type": "boolean"},
|
"ip_reputation_enabled": {"value": True, "type": "boolean"},
|
||||||
"anomaly_detection_enabled": {"value": True, "type": "boolean"},
|
"anomaly_detection_enabled": {"value": True, "type": "boolean"},
|
||||||
@@ -616,7 +613,6 @@ async def reset_to_defaults(
|
|||||||
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},
|
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},
|
||||||
|
|
||||||
# Security Thresholds
|
# Security Thresholds
|
||||||
"security_risk_threshold": {"value": 0.8, "type": "float"},
|
|
||||||
"security_warning_threshold": {"value": 0.6, "type": "float"},
|
"security_warning_threshold": {"value": 0.6, "type": "float"},
|
||||||
"anomaly_threshold": {"value": 0.7, "type": "float"},
|
"anomaly_threshold": {"value": 0.7, "type": "float"},
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ class Settings(BaseSettings):
|
|||||||
APP_LOG_LEVEL: str = os.getenv("APP_LOG_LEVEL", "INFO")
|
APP_LOG_LEVEL: str = os.getenv("APP_LOG_LEVEL", "INFO")
|
||||||
APP_HOST: str = os.getenv("APP_HOST", "0.0.0.0")
|
APP_HOST: str = os.getenv("APP_HOST", "0.0.0.0")
|
||||||
APP_PORT: int = int(os.getenv("APP_PORT", "8000"))
|
APP_PORT: int = int(os.getenv("APP_PORT", "8000"))
|
||||||
|
BACKEND_INTERNAL_PORT: int = int(os.getenv("BACKEND_INTERNAL_PORT", "8000"))
|
||||||
|
FRONTEND_INTERNAL_PORT: int = int(os.getenv("FRONTEND_INTERNAL_PORT", "3000"))
|
||||||
|
|
||||||
# Detailed logging for LLM interactions
|
# Detailed logging for LLM interactions
|
||||||
LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true" # Set to True to log prompts and context sent to LLM
|
LOG_LLM_PROMPTS: bool = os.getenv("LOG_LLM_PROMPTS", "False").lower() == "true" # Set to True to log prompts and context sent to LLM
|
||||||
@@ -73,15 +75,10 @@ class Settings(BaseSettings):
|
|||||||
QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
|
QDRANT_HOST: str = os.getenv("QDRANT_HOST", "localhost")
|
||||||
QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333"))
|
QDRANT_PORT: int = int(os.getenv("QDRANT_PORT", "6333"))
|
||||||
QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
|
QDRANT_API_KEY: Optional[str] = os.getenv("QDRANT_API_KEY")
|
||||||
|
QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||||
|
|
||||||
# API & Security Settings
|
|
||||||
API_SECURITY_ENABLED: bool = os.getenv("API_SECURITY_ENABLED", "True").lower() == "true"
|
|
||||||
API_THREAT_DETECTION_ENABLED: bool = os.getenv("API_THREAT_DETECTION_ENABLED", "True").lower() == "true"
|
|
||||||
API_IP_REPUTATION_ENABLED: bool = os.getenv("API_IP_REPUTATION_ENABLED", "True").lower() == "true"
|
|
||||||
API_ANOMALY_DETECTION_ENABLED: bool = os.getenv("API_ANOMALY_DETECTION_ENABLED", "True").lower() == "true"
|
|
||||||
|
|
||||||
# Rate Limiting Configuration
|
# Rate Limiting Configuration
|
||||||
API_RATE_LIMITING_ENABLED: bool = os.getenv("API_RATE_LIMITING_ENABLED", "True").lower() == "true"
|
|
||||||
|
|
||||||
# PrivateMode Standard tier limits (organization-level, not per user)
|
# PrivateMode Standard tier limits (organization-level, not per user)
|
||||||
# These are shared across all API keys and users in the organization
|
# These are shared across all API keys and users in the organization
|
||||||
@@ -102,22 +99,13 @@ class Settings(BaseSettings):
|
|||||||
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode
|
API_RATE_LIMIT_PREMIUM_PER_MINUTE: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_MINUTE", "20")) # Match PrivateMode
|
||||||
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
|
API_RATE_LIMIT_PREMIUM_PER_HOUR: int = int(os.getenv("API_RATE_LIMIT_PREMIUM_PER_HOUR", "1200"))
|
||||||
|
|
||||||
# Security Thresholds
|
|
||||||
API_SECURITY_RISK_THRESHOLD: float = float(os.getenv("API_SECURITY_RISK_THRESHOLD", "0.8")) # Block requests above this risk score
|
|
||||||
API_SECURITY_WARNING_THRESHOLD: float = float(os.getenv("API_SECURITY_WARNING_THRESHOLD", "0.6")) # Log warnings above this threshold
|
|
||||||
API_SECURITY_ANOMALY_THRESHOLD: float = float(os.getenv("API_SECURITY_ANOMALY_THRESHOLD", "0.7")) # Flag anomalies above this threshold
|
|
||||||
|
|
||||||
# Request Size Limits
|
# Request Size Limits
|
||||||
API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")) # 10MB
|
API_MAX_REQUEST_BODY_SIZE: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE", "10485760")) # 10MB
|
||||||
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")) # 50MB for premium
|
API_MAX_REQUEST_BODY_SIZE_PREMIUM: int = int(os.getenv("API_MAX_REQUEST_BODY_SIZE_PREMIUM", "52428800")) # 50MB for premium
|
||||||
|
|
||||||
# IP Security
|
# IP Security
|
||||||
API_BLOCKED_IPS: List[str] = os.getenv("API_BLOCKED_IPS", "").split(",") if os.getenv("API_BLOCKED_IPS") else []
|
|
||||||
API_ALLOWED_IPS: List[str] = os.getenv("API_ALLOWED_IPS", "").split(",") if os.getenv("API_ALLOWED_IPS") else []
|
|
||||||
API_IP_REPUTATION_CACHE_TTL: int = int(os.getenv("API_IP_REPUTATION_CACHE_TTL", "3600")) # 1 hour
|
|
||||||
|
|
||||||
# Security Headers
|
# Security Headers
|
||||||
API_SECURITY_HEADERS_ENABLED: bool = os.getenv("API_SECURITY_HEADERS_ENABLED", "True").lower() == "true"
|
|
||||||
API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
|
API_CSP_HEADER: str = os.getenv("API_CSP_HEADER", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
|
||||||
|
|
||||||
# Monitoring
|
# Monitoring
|
||||||
@@ -130,6 +118,19 @@ class Settings(BaseSettings):
|
|||||||
# Module configuration
|
# Module configuration
|
||||||
MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")
|
MODULES_CONFIG_PATH: str = os.getenv("MODULES_CONFIG_PATH", "config/modules.yaml")
|
||||||
|
|
||||||
|
# RAG Embedding Configuration
|
||||||
|
RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE: int = int(os.getenv("RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE", "12"))
|
||||||
|
RAG_EMBEDDING_BATCH_SIZE: int = int(os.getenv("RAG_EMBEDDING_BATCH_SIZE", "3"))
|
||||||
|
RAG_EMBEDDING_RETRY_COUNT: int = int(os.getenv("RAG_EMBEDDING_RETRY_COUNT", "3"))
|
||||||
|
RAG_EMBEDDING_RETRY_DELAYS: str = os.getenv("RAG_EMBEDDING_RETRY_DELAYS", "1,2,4,8,16")
|
||||||
|
RAG_EMBEDDING_DELAY_BETWEEN_BATCHES: float = float(os.getenv("RAG_EMBEDDING_DELAY_BETWEEN_BATCHES", "1.0"))
|
||||||
|
RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
|
||||||
|
RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
|
||||||
|
RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
|
||||||
|
RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
|
||||||
|
RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
|
||||||
|
RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
|
||||||
|
|
||||||
# Plugin configuration
|
# Plugin configuration
|
||||||
PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
|
PLUGINS_DIR: str = os.getenv("PLUGINS_DIR", "/plugins")
|
||||||
PLUGINS_CONFIG_PATH: str = os.getenv("PLUGINS_CONFIG_PATH", "config/plugins.yaml")
|
PLUGINS_CONFIG_PATH: str = os.getenv("PLUGINS_CONFIG_PATH", "config/plugins.yaml")
|
||||||
@@ -142,7 +143,10 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"env_file": ".env",
|
"env_file": ".env",
|
||||||
"case_sensitive": True
|
"case_sensitive": True,
|
||||||
|
# Ignore unknown environment variables to avoid validation errors
|
||||||
|
# when optional/deprecated flags are present in .env
|
||||||
|
"extra": "ignore",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,744 +0,0 @@
|
|||||||
"""
|
|
||||||
Core threat detection and security analysis for the platform
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from collections import defaultdict, deque
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Dict, List, Optional, Set, Tuple, Any, Union
|
|
||||||
from urllib.parse import unquote
|
|
||||||
|
|
||||||
from fastapi import Request
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.logging import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ThreatLevel(Enum):
|
|
||||||
"""Threat severity levels"""
|
|
||||||
LOW = "low"
|
|
||||||
MEDIUM = "medium"
|
|
||||||
HIGH = "high"
|
|
||||||
CRITICAL = "critical"
|
|
||||||
|
|
||||||
|
|
||||||
class AuthLevel(Enum):
|
|
||||||
"""Authentication levels for rate limiting"""
|
|
||||||
AUTHENTICATED = "authenticated"
|
|
||||||
API_KEY = "api_key"
|
|
||||||
PREMIUM = "premium"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SecurityThreat:
|
|
||||||
"""Security threat detection result"""
|
|
||||||
threat_type: str
|
|
||||||
level: ThreatLevel
|
|
||||||
confidence: float
|
|
||||||
description: str
|
|
||||||
source_ip: str
|
|
||||||
user_agent: Optional[str] = None
|
|
||||||
request_path: Optional[str] = None
|
|
||||||
payload: Optional[str] = None
|
|
||||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
|
||||||
mitigation: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SecurityAnalysis:
|
|
||||||
"""Comprehensive security analysis result"""
|
|
||||||
is_threat: bool
|
|
||||||
threats: List[SecurityThreat]
|
|
||||||
risk_score: float
|
|
||||||
recommendations: List[str]
|
|
||||||
auth_level: AuthLevel
|
|
||||||
rate_limit_exceeded: bool
|
|
||||||
should_block: bool
|
|
||||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RateLimitInfo:
|
|
||||||
"""Rate limiting information"""
|
|
||||||
auth_level: AuthLevel
|
|
||||||
requests_per_minute: int
|
|
||||||
requests_per_hour: int
|
|
||||||
minute_limit: int
|
|
||||||
hour_limit: int
|
|
||||||
exceeded: bool
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AnomalyDetection:
|
|
||||||
"""Anomaly detection result"""
|
|
||||||
is_anomaly: bool
|
|
||||||
anomaly_type: str
|
|
||||||
severity: float
|
|
||||||
details: Dict[str, Any]
|
|
||||||
baseline_value: Optional[float] = None
|
|
||||||
current_value: Optional[float] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ThreatDetectionService:
|
|
||||||
"""Core threat detection and security analysis service"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.name = "threat_detection"
|
|
||||||
|
|
||||||
# Statistics
|
|
||||||
self.stats = {
|
|
||||||
'total_requests_analyzed': 0,
|
|
||||||
'threats_detected': 0,
|
|
||||||
'threats_blocked': 0,
|
|
||||||
'anomalies_detected': 0,
|
|
||||||
'rate_limits_exceeded': 0,
|
|
||||||
'total_analysis_time': 0,
|
|
||||||
'threat_types': defaultdict(int),
|
|
||||||
'threat_levels': defaultdict(int),
|
|
||||||
'attacking_ips': defaultdict(int)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Threat detection patterns
|
|
||||||
self.sql_injection_patterns = [
|
|
||||||
r"(\bunion\b.*\bselect\b)",
|
|
||||||
r"(\bselect\b.*\bfrom\b)",
|
|
||||||
r"(\binsert\b.*\binto\b)",
|
|
||||||
r"(\bupdate\b.*\bset\b)",
|
|
||||||
r"(\bdelete\b.*\bfrom\b)",
|
|
||||||
r"(\bdrop\b.*\btable\b)",
|
|
||||||
r"(\bor\b.*\b1\s*=\s*1\b)",
|
|
||||||
r"(\band\b.*\b1\s*=\s*1\b)",
|
|
||||||
r"(\bexec\b.*\bxp_\w+)",
|
|
||||||
r"(\bsp_\w+)",
|
|
||||||
r"(\bsleep\b\s*\(\s*\d+\s*\))",
|
|
||||||
r"(\bwaitfor\b.*\bdelay\b)",
|
|
||||||
r"(\bbenchmark\b\s*\(\s*\d+)",
|
|
||||||
r"(\bload_file\b\s*\()",
|
|
||||||
r"(\binto\b.*\boutfile\b)"
|
|
||||||
]
|
|
||||||
|
|
||||||
self.xss_patterns = [
|
|
||||||
r"<script[^>]*>.*?</script>",
|
|
||||||
r"<iframe[^>]*>.*?</iframe>",
|
|
||||||
r"<object[^>]*>.*?</object>",
|
|
||||||
r"<embed[^>]*>.*?</embed>",
|
|
||||||
r"<link[^>]*>",
|
|
||||||
r"<meta[^>]*>",
|
|
||||||
r"javascript:",
|
|
||||||
r"vbscript:",
|
|
||||||
r"on\w+\s*=",
|
|
||||||
r"style\s*=.*expression",
|
|
||||||
r"style\s*=.*javascript"
|
|
||||||
]
|
|
||||||
|
|
||||||
self.path_traversal_patterns = [
|
|
||||||
r"\.\.\/",
|
|
||||||
r"\.\.\\",
|
|
||||||
r"%2e%2e%2f",
|
|
||||||
r"%2e%2e%5c",
|
|
||||||
r"..%2f",
|
|
||||||
r"..%5c",
|
|
||||||
r"%252e%252e%252f",
|
|
||||||
r"%252e%252e%255c"
|
|
||||||
]
|
|
||||||
|
|
||||||
self.command_injection_patterns = [
|
|
||||||
r";\s*cat\s+",
|
|
||||||
r";\s*ls\s+",
|
|
||||||
r";\s*pwd\s*",
|
|
||||||
r";\s*whoami\s*",
|
|
||||||
r";\s*id\s*",
|
|
||||||
r";\s*uname\s*",
|
|
||||||
r";\s*ps\s+",
|
|
||||||
r";\s*netstat\s+",
|
|
||||||
r";\s*wget\s+",
|
|
||||||
r";\s*curl\s+",
|
|
||||||
r"\|\s*cat\s+",
|
|
||||||
r"\|\s*ls\s+",
|
|
||||||
r"&&\s*cat\s+",
|
|
||||||
r"&&\s*ls\s+"
|
|
||||||
]
|
|
||||||
|
|
||||||
self.suspicious_ua_patterns = [
|
|
||||||
r"sqlmap",
|
|
||||||
r"nikto",
|
|
||||||
r"nmap",
|
|
||||||
r"masscan",
|
|
||||||
r"zap",
|
|
||||||
r"burp",
|
|
||||||
r"w3af",
|
|
||||||
r"acunetix",
|
|
||||||
r"nessus",
|
|
||||||
r"openvas",
|
|
||||||
r"metasploit"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Rate limiting tracking - separate by auth level (excluding unauthenticated since they're blocked)
|
|
||||||
self.rate_limits = {
|
|
||||||
AuthLevel.AUTHENTICATED: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}),
|
|
||||||
AuthLevel.API_KEY: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)}),
|
|
||||||
AuthLevel.PREMIUM: defaultdict(lambda: {'minute': deque(maxlen=60), 'hour': deque(maxlen=3600)})
|
|
||||||
}
|
|
||||||
|
|
||||||
# Anomaly detection
|
|
||||||
self.request_history = deque(maxlen=1000)
|
|
||||||
self.ip_history = defaultdict(lambda: deque(maxlen=100))
|
|
||||||
self.endpoint_history = defaultdict(lambda: deque(maxlen=100))
|
|
||||||
|
|
||||||
# Blocked and allowed IPs
|
|
||||||
self.blocked_ips = set(settings.API_BLOCKED_IPS)
|
|
||||||
self.allowed_ips = set(settings.API_ALLOWED_IPS) if settings.API_ALLOWED_IPS else None
|
|
||||||
|
|
||||||
# IP reputation cache
|
|
||||||
self.ip_reputation_cache = {}
|
|
||||||
self.cache_expiry = {}
|
|
||||||
|
|
||||||
# Compile patterns for performance
|
|
||||||
self._compile_patterns()
|
|
||||||
|
|
||||||
logger.info(f"ThreatDetectionService initialized with {len(self.sql_injection_patterns)} SQL patterns, "
|
|
||||||
f"{len(self.xss_patterns)} XSS patterns, rate limiting enabled: {settings.API_RATE_LIMITING_ENABLED}")
|
|
||||||
|
|
||||||
def _compile_patterns(self):
|
|
||||||
"""Compile regex patterns for better performance"""
|
|
||||||
try:
|
|
||||||
self.compiled_sql_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.sql_injection_patterns]
|
|
||||||
self.compiled_xss_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.xss_patterns]
|
|
||||||
self.compiled_path_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.path_traversal_patterns]
|
|
||||||
self.compiled_cmd_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.command_injection_patterns]
|
|
||||||
self.compiled_ua_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.suspicious_ua_patterns]
|
|
||||||
except re.error as e:
|
|
||||||
logger.error(f"Failed to compile security patterns: {e}")
|
|
||||||
# Fallback to empty lists to prevent crashes
|
|
||||||
self.compiled_sql_patterns = []
|
|
||||||
self.compiled_xss_patterns = []
|
|
||||||
self.compiled_path_patterns = []
|
|
||||||
self.compiled_cmd_patterns = []
|
|
||||||
self.compiled_ua_patterns = []
|
|
||||||
|
|
||||||
def determine_auth_level(self, request: Request, user_context: Optional[Dict] = None) -> AuthLevel:
|
|
||||||
"""Determine authentication level for rate limiting"""
|
|
||||||
# Check if request has API key authentication
|
|
||||||
if hasattr(request.state, 'api_key_context') and request.state.api_key_context:
|
|
||||||
api_key = request.state.api_key_context.get('api_key')
|
|
||||||
if api_key and hasattr(api_key, 'tier'):
|
|
||||||
# Check for premium tier
|
|
||||||
if api_key.tier in ['premium', 'enterprise']:
|
|
||||||
return AuthLevel.PREMIUM
|
|
||||||
return AuthLevel.API_KEY
|
|
||||||
|
|
||||||
# Check for JWT authentication
|
|
||||||
if user_context or hasattr(request.state, 'user'):
|
|
||||||
return AuthLevel.AUTHENTICATED
|
|
||||||
|
|
||||||
# Check Authorization header for API key
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
|
||||||
api_key_header = request.headers.get("X-API-Key", "")
|
|
||||||
if auth_header.startswith("Bearer ") or api_key_header:
|
|
||||||
return AuthLevel.API_KEY
|
|
||||||
|
|
||||||
# Default to authenticated since unauthenticated requests are blocked at middleware
|
|
||||||
return AuthLevel.AUTHENTICATED
|
|
||||||
|
|
||||||
def get_rate_limits(self, auth_level: AuthLevel) -> Tuple[int, int]:
|
|
||||||
"""Get rate limits for authentication level"""
|
|
||||||
if not settings.API_RATE_LIMITING_ENABLED:
|
|
||||||
return float('inf'), float('inf')
|
|
||||||
|
|
||||||
if auth_level == AuthLevel.AUTHENTICATED:
|
|
||||||
return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
|
|
||||||
elif auth_level == AuthLevel.API_KEY:
|
|
||||||
return (settings.API_RATE_LIMIT_API_KEY_PER_MINUTE, settings.API_RATE_LIMIT_API_KEY_PER_HOUR)
|
|
||||||
elif auth_level == AuthLevel.PREMIUM:
|
|
||||||
return (settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE, settings.API_RATE_LIMIT_PREMIUM_PER_HOUR)
|
|
||||||
else:
|
|
||||||
# Fallback to authenticated limits
|
|
||||||
return (settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE, settings.API_RATE_LIMIT_AUTHENTICATED_PER_HOUR)
|
|
||||||
|
|
||||||
def check_rate_limit(self, client_ip: str, auth_level: AuthLevel) -> RateLimitInfo:
|
|
||||||
"""Check if request exceeds rate limits"""
|
|
||||||
minute_limit, hour_limit = self.get_rate_limits(auth_level)
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
# Get or create tracking for this auth level
|
|
||||||
if auth_level not in self.rate_limits:
|
|
||||||
# This shouldn't happen, but handle gracefully
|
|
||||||
return RateLimitInfo(
|
|
||||||
auth_level=auth_level,
|
|
||||||
requests_per_minute=0,
|
|
||||||
requests_per_hour=0,
|
|
||||||
minute_limit=minute_limit,
|
|
||||||
hour_limit=hour_limit,
|
|
||||||
exceeded=False
|
|
||||||
)
|
|
||||||
|
|
||||||
ip_limits = self.rate_limits[auth_level][client_ip]
|
|
||||||
|
|
||||||
# Clean old entries
|
|
||||||
minute_ago = current_time - 60
|
|
||||||
hour_ago = current_time - 3600
|
|
||||||
|
|
||||||
while ip_limits['minute'] and ip_limits['minute'][0] < minute_ago:
|
|
||||||
ip_limits['minute'].popleft()
|
|
||||||
|
|
||||||
while ip_limits['hour'] and ip_limits['hour'][0] < hour_ago:
|
|
||||||
ip_limits['hour'].popleft()
|
|
||||||
|
|
||||||
# Check current counts
|
|
||||||
requests_per_minute = len(ip_limits['minute'])
|
|
||||||
requests_per_hour = len(ip_limits['hour'])
|
|
||||||
|
|
||||||
# Check if limits exceeded
|
|
||||||
exceeded = (requests_per_minute >= minute_limit) or (requests_per_hour >= hour_limit)
|
|
||||||
|
|
||||||
# Add current request to tracking
|
|
||||||
if not exceeded:
|
|
||||||
ip_limits['minute'].append(current_time)
|
|
||||||
ip_limits['hour'].append(current_time)
|
|
||||||
|
|
||||||
return RateLimitInfo(
|
|
||||||
auth_level=auth_level,
|
|
||||||
requests_per_minute=requests_per_minute,
|
|
||||||
requests_per_hour=requests_per_hour,
|
|
||||||
minute_limit=minute_limit,
|
|
||||||
hour_limit=hour_limit,
|
|
||||||
exceeded=exceeded
|
|
||||||
)
|
|
||||||
|
|
||||||
async def analyze_request(self, request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
|
|
||||||
"""Perform comprehensive security analysis on a request"""
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
try:
|
|
||||||
client_ip = request.client.host if request.client else "unknown"
|
|
||||||
user_agent = request.headers.get("user-agent", "")
|
|
||||||
path = str(request.url.path)
|
|
||||||
method = request.method
|
|
||||||
|
|
||||||
# Determine authentication level
|
|
||||||
auth_level = self.determine_auth_level(request, user_context)
|
|
||||||
|
|
||||||
# Check IP allowlist/blocklist first
|
|
||||||
if self.allowed_ips and client_ip not in self.allowed_ips:
|
|
||||||
threat = SecurityThreat(
|
|
||||||
threat_type="ip_not_allowed",
|
|
||||||
level=ThreatLevel.HIGH,
|
|
||||||
confidence=1.0,
|
|
||||||
description=f"IP {client_ip} not in allowlist",
|
|
||||||
source_ip=client_ip,
|
|
||||||
mitigation="Add IP to allowlist or remove IP restrictions"
|
|
||||||
)
|
|
||||||
return SecurityAnalysis(
|
|
||||||
is_threat=True,
|
|
||||||
threats=[threat],
|
|
||||||
risk_score=1.0,
|
|
||||||
recommendations=["Block request immediately"],
|
|
||||||
auth_level=auth_level,
|
|
||||||
rate_limit_exceeded=False,
|
|
||||||
should_block=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if client_ip in self.blocked_ips:
|
|
||||||
threat = SecurityThreat(
|
|
||||||
threat_type="ip_blocked",
|
|
||||||
level=ThreatLevel.CRITICAL,
|
|
||||||
confidence=1.0,
|
|
||||||
description=f"IP {client_ip} is blocked",
|
|
||||||
source_ip=client_ip,
|
|
||||||
mitigation="Remove IP from blocklist if legitimate"
|
|
||||||
)
|
|
||||||
return SecurityAnalysis(
|
|
||||||
is_threat=True,
|
|
||||||
threats=[threat],
|
|
||||||
risk_score=1.0,
|
|
||||||
recommendations=["Block request immediately"],
|
|
||||||
auth_level=auth_level,
|
|
||||||
rate_limit_exceeded=False,
|
|
||||||
should_block=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check rate limiting
|
|
||||||
rate_limit_info = self.check_rate_limit(client_ip, auth_level)
|
|
||||||
if rate_limit_info.exceeded:
|
|
||||||
self.stats['rate_limits_exceeded'] += 1
|
|
||||||
threat = SecurityThreat(
|
|
||||||
threat_type="rate_limit_exceeded",
|
|
||||||
level=ThreatLevel.MEDIUM,
|
|
||||||
confidence=0.9,
|
|
||||||
description=f"Rate limit exceeded for {auth_level.value}: {rate_limit_info.requests_per_minute}/min, {rate_limit_info.requests_per_hour}/hr",
|
|
||||||
source_ip=client_ip,
|
|
||||||
mitigation=f"Implement rate limiting, current limits: {rate_limit_info.minute_limit}/min, {rate_limit_info.hour_limit}/hr"
|
|
||||||
)
|
|
||||||
return SecurityAnalysis(
|
|
||||||
is_threat=True,
|
|
||||||
threats=[threat],
|
|
||||||
risk_score=0.7,
|
|
||||||
recommendations=[f"Rate limit exceeded for {auth_level.value} user"],
|
|
||||||
auth_level=auth_level,
|
|
||||||
rate_limit_exceeded=True,
|
|
||||||
should_block=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Skip threat detection if disabled
|
|
||||||
if not settings.API_THREAT_DETECTION_ENABLED:
|
|
||||||
return SecurityAnalysis(
|
|
||||||
is_threat=False,
|
|
||||||
threats=[],
|
|
||||||
risk_score=0.0,
|
|
||||||
recommendations=[],
|
|
||||||
auth_level=auth_level,
|
|
||||||
rate_limit_exceeded=False,
|
|
||||||
should_block=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Collect request data for threat analysis
|
|
||||||
query_params = str(request.query_params)
|
|
||||||
headers = dict(request.headers)
|
|
||||||
|
|
||||||
# Try to get body content safely
|
|
||||||
body_content = ""
|
|
||||||
try:
|
|
||||||
if hasattr(request, '_body') and request._body:
|
|
||||||
body_content = request._body.decode() if isinstance(request._body, bytes) else str(request._body)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
threats = []
|
|
||||||
|
|
||||||
# Analyze for various threats
|
|
||||||
threats.extend(await self._detect_sql_injection(query_params, body_content, path, client_ip))
|
|
||||||
threats.extend(await self._detect_xss(query_params, body_content, headers, client_ip))
|
|
||||||
threats.extend(await self._detect_path_traversal(path, query_params, client_ip))
|
|
||||||
threats.extend(await self._detect_command_injection(query_params, body_content, client_ip))
|
|
||||||
threats.extend(await self._detect_suspicious_patterns(headers, user_agent, path, client_ip))
|
|
||||||
|
|
||||||
# Anomaly detection if enabled
|
|
||||||
if settings.API_ANOMALY_DETECTION_ENABLED:
|
|
||||||
anomaly = await self._detect_anomalies(client_ip, path, method, len(body_content))
|
|
||||||
if anomaly.is_anomaly and anomaly.severity > settings.API_SECURITY_ANOMALY_THRESHOLD:
|
|
||||||
threat = SecurityThreat(
|
|
||||||
threat_type=f"anomaly_{anomaly.anomaly_type}",
|
|
||||||
level=ThreatLevel.MEDIUM if anomaly.severity > 0.7 else ThreatLevel.LOW,
|
|
||||||
confidence=anomaly.severity,
|
|
||||||
description=f"Anomalous behavior detected: {anomaly.details}",
|
|
||||||
source_ip=client_ip,
|
|
||||||
user_agent=user_agent,
|
|
||||||
request_path=path
|
|
||||||
)
|
|
||||||
threats.append(threat)
|
|
||||||
|
|
||||||
# Calculate risk score
|
|
||||||
risk_score = self._calculate_risk_score(threats)
|
|
||||||
|
|
||||||
# Determine if request should be blocked
|
|
||||||
should_block = risk_score >= settings.API_SECURITY_RISK_THRESHOLD
|
|
||||||
|
|
||||||
# Generate recommendations
|
|
||||||
recommendations = self._generate_recommendations(threats, risk_score, auth_level)
|
|
||||||
|
|
||||||
# Update statistics
|
|
||||||
self._update_stats(threats, time.time() - start_time)
|
|
||||||
|
|
||||||
return SecurityAnalysis(
|
|
||||||
is_threat=len(threats) > 0,
|
|
||||||
threats=threats,
|
|
||||||
risk_score=risk_score,
|
|
||||||
recommendations=recommendations,
|
|
||||||
auth_level=auth_level,
|
|
||||||
rate_limit_exceeded=False,
|
|
||||||
should_block=should_block
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in threat analysis: {e}")
|
|
||||||
return SecurityAnalysis(
|
|
||||||
is_threat=False,
|
|
||||||
threats=[],
|
|
||||||
risk_score=0.0,
|
|
||||||
recommendations=["Error occurred during security analysis"],
|
|
||||||
auth_level=AuthLevel.AUTHENTICATED,
|
|
||||||
rate_limit_exceeded=False,
|
|
||||||
should_block=False
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _detect_sql_injection(self, query_params: str, body_content: str, path: str, client_ip: str) -> List[SecurityThreat]:
|
|
||||||
"""Detect SQL injection attempts"""
|
|
||||||
threats = []
|
|
||||||
content_to_check = f"{query_params} {body_content} {path}".lower()
|
|
||||||
|
|
||||||
for pattern in self.compiled_sql_patterns:
|
|
||||||
if pattern.search(content_to_check):
|
|
||||||
threat = SecurityThreat(
|
|
||||||
threat_type="sql_injection",
|
|
||||||
level=ThreatLevel.HIGH,
|
|
||||||
confidence=0.85,
|
|
||||||
description="Potential SQL injection attempt detected",
|
|
||||||
source_ip=client_ip,
|
|
||||||
payload=pattern.pattern,
|
|
||||||
mitigation="Block request, sanitize input, use parameterized queries"
|
|
||||||
)
|
|
||||||
threats.append(threat)
|
|
||||||
break # Don't duplicate for multiple patterns
|
|
||||||
|
|
||||||
return threats
|
|
||||||
|
|
||||||
async def _detect_xss(self, query_params: str, body_content: str, headers: dict, client_ip: str) -> List[SecurityThreat]:
|
|
||||||
"""Detect XSS attempts"""
|
|
||||||
threats = []
|
|
||||||
content_to_check = f"{query_params} {body_content}".lower()
|
|
||||||
|
|
||||||
# Check headers for XSS
|
|
||||||
for header_name, header_value in headers.items():
|
|
||||||
content_to_check += f" {header_value}".lower()
|
|
||||||
|
|
||||||
for pattern in self.compiled_xss_patterns:
|
|
||||||
if pattern.search(content_to_check):
|
|
||||||
threat = SecurityThreat(
|
|
||||||
threat_type="xss",
|
|
||||||
level=ThreatLevel.HIGH,
|
|
||||||
confidence=0.80,
|
|
||||||
description="Potential XSS attack detected",
|
|
||||||
source_ip=client_ip,
|
|
||||||
payload=pattern.pattern,
|
|
||||||
mitigation="Block request, sanitize input, implement CSP headers"
|
|
||||||
)
|
|
||||||
threats.append(threat)
|
|
||||||
break
|
|
||||||
|
|
||||||
return threats
|
|
||||||
|
|
||||||
async def _detect_path_traversal(self, path: str, query_params: str, client_ip: str) -> List[SecurityThreat]:
|
|
||||||
"""Detect path traversal attempts"""
|
|
||||||
threats = []
|
|
||||||
content_to_check = f"{path} {query_params}".lower()
|
|
||||||
decoded_content = unquote(content_to_check)
|
|
||||||
|
|
||||||
for pattern in self.compiled_path_patterns:
|
|
||||||
if pattern.search(content_to_check) or pattern.search(decoded_content):
|
|
||||||
threat = SecurityThreat(
|
|
||||||
threat_type="path_traversal",
|
|
||||||
level=ThreatLevel.HIGH,
|
|
||||||
confidence=0.90,
|
|
||||||
description="Path traversal attempt detected",
|
|
||||||
source_ip=client_ip,
|
|
||||||
request_path=path,
|
|
||||||
mitigation="Block request, validate file paths, implement access controls"
|
|
||||||
)
|
|
||||||
threats.append(threat)
|
|
||||||
break
|
|
||||||
|
|
||||||
return threats
|
|
||||||
|
|
||||||
async def _detect_command_injection(self, query_params: str, body_content: str, client_ip: str) -> List[SecurityThreat]:
|
|
||||||
"""Detect command injection attempts"""
|
|
||||||
threats = []
|
|
||||||
content_to_check = f"{query_params} {body_content}".lower()
|
|
||||||
|
|
||||||
for pattern in self.compiled_cmd_patterns:
|
|
||||||
if pattern.search(content_to_check):
|
|
||||||
threat = SecurityThreat(
|
|
||||||
threat_type="command_injection",
|
|
||||||
level=ThreatLevel.CRITICAL,
|
|
||||||
confidence=0.95,
|
|
||||||
description="Command injection attempt detected",
|
|
||||||
source_ip=client_ip,
|
|
||||||
payload=pattern.pattern,
|
|
||||||
mitigation="Block request immediately, sanitize input, disable shell execution"
|
|
||||||
)
|
|
||||||
threats.append(threat)
|
|
||||||
break
|
|
||||||
|
|
||||||
return threats
|
|
||||||
|
|
||||||
async def _detect_suspicious_patterns(self, headers: dict, user_agent: str, path: str, client_ip: str) -> List[SecurityThreat]:
|
|
||||||
"""Detect suspicious patterns in headers and user agent"""
|
|
||||||
threats = []
|
|
||||||
|
|
||||||
# Check for suspicious user agents
|
|
||||||
ua_lower = user_agent.lower()
|
|
||||||
for pattern in self.compiled_ua_patterns:
|
|
||||||
if pattern.search(ua_lower):
|
|
||||||
threat = SecurityThreat(
|
|
||||||
threat_type="suspicious_user_agent",
|
|
||||||
level=ThreatLevel.HIGH,
|
|
||||||
confidence=0.85,
|
|
||||||
description=f"Suspicious user agent detected: {pattern.pattern}",
|
|
||||||
source_ip=client_ip,
|
|
||||||
user_agent=user_agent,
|
|
||||||
mitigation="Block request, monitor IP for further activity"
|
|
||||||
)
|
|
||||||
threats.append(threat)
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check for suspicious headers
|
|
||||||
if "x-forwarded-for" in headers and "x-real-ip" in headers:
|
|
||||||
# Potential header manipulation
|
|
||||||
threat = SecurityThreat(
|
|
||||||
threat_type="header_manipulation",
|
|
||||||
level=ThreatLevel.LOW,
|
|
||||||
confidence=0.30,
|
|
||||||
description="Potential IP header manipulation detected",
|
|
||||||
source_ip=client_ip,
|
|
||||||
mitigation="Validate proxy headers, implement IP whitelisting"
|
|
||||||
)
|
|
||||||
threats.append(threat)
|
|
||||||
|
|
||||||
return threats
|
|
||||||
|
|
||||||
async def _detect_anomalies(self, client_ip: str, path: str, method: str, body_size: int) -> AnomalyDetection:
|
|
||||||
"""Detect anomalous behavior patterns"""
|
|
||||||
try:
|
|
||||||
# Request size anomaly
|
|
||||||
max_size = settings.API_MAX_REQUEST_BODY_SIZE
|
|
||||||
if body_size > max_size:
|
|
||||||
return AnomalyDetection(
|
|
||||||
is_anomaly=True,
|
|
||||||
anomaly_type="request_size",
|
|
||||||
severity=0.8,
|
|
||||||
details={"body_size": body_size, "threshold": max_size},
|
|
||||||
current_value=body_size,
|
|
||||||
baseline_value=max_size // 10
|
|
||||||
)
|
|
||||||
|
|
||||||
# Unusual endpoint access
|
|
||||||
if path.startswith("/admin") or path.startswith("/api/admin"):
|
|
||||||
return AnomalyDetection(
|
|
||||||
is_anomaly=True,
|
|
||||||
anomaly_type="sensitive_endpoint",
|
|
||||||
severity=0.6,
|
|
||||||
details={"path": path, "reason": "admin endpoint access"},
|
|
||||||
current_value=1.0,
|
|
||||||
baseline_value=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
# IP request frequency anomaly
|
|
||||||
current_time = time.time()
|
|
||||||
ip_requests = self.ip_history[client_ip]
|
|
||||||
|
|
||||||
# Clean old entries (last 5 minutes)
|
|
||||||
five_minutes_ago = current_time - 300
|
|
||||||
while ip_requests and ip_requests[0] < five_minutes_ago:
|
|
||||||
ip_requests.popleft()
|
|
||||||
|
|
||||||
ip_requests.append(current_time)
|
|
||||||
|
|
||||||
if len(ip_requests) > 100: # More than 100 requests in 5 minutes
|
|
||||||
return AnomalyDetection(
|
|
||||||
is_anomaly=True,
|
|
||||||
anomaly_type="request_frequency",
|
|
||||||
severity=0.7,
|
|
||||||
details={"requests_5min": len(ip_requests), "threshold": 100},
|
|
||||||
current_value=len(ip_requests),
|
|
||||||
baseline_value=10 # 10 requests baseline
|
|
||||||
)
|
|
||||||
|
|
||||||
return AnomalyDetection(
|
|
||||||
is_anomaly=False,
|
|
||||||
anomaly_type="none",
|
|
||||||
severity=0.0,
|
|
||||||
details={}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in anomaly detection: {e}")
|
|
||||||
return AnomalyDetection(
|
|
||||||
is_anomaly=False,
|
|
||||||
anomaly_type="error",
|
|
||||||
severity=0.0,
|
|
||||||
details={"error": str(e)}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _calculate_risk_score(self, threats: List[SecurityThreat]) -> float:
|
|
||||||
"""Calculate overall risk score based on threats"""
|
|
||||||
if not threats:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
score = 0.0
|
|
||||||
for threat in threats:
|
|
||||||
level_multiplier = {
|
|
||||||
ThreatLevel.LOW: 0.25,
|
|
||||||
ThreatLevel.MEDIUM: 0.5,
|
|
||||||
ThreatLevel.HIGH: 0.75,
|
|
||||||
ThreatLevel.CRITICAL: 1.0
|
|
||||||
}
|
|
||||||
score += threat.confidence * level_multiplier.get(threat.level, 0.5)
|
|
||||||
|
|
||||||
# Normalize to 0-1 range
|
|
||||||
return min(score / len(threats), 1.0)
|
|
||||||
|
|
||||||
def _generate_recommendations(self, threats: List[SecurityThreat], risk_score: float, auth_level: AuthLevel) -> List[str]:
|
|
||||||
"""Generate security recommendations based on analysis"""
|
|
||||||
recommendations = []
|
|
||||||
|
|
||||||
if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
|
|
||||||
recommendations.append("CRITICAL: Block this request immediately")
|
|
||||||
elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
|
|
||||||
recommendations.append("HIGH: Consider blocking or rate limiting this IP")
|
|
||||||
elif risk_score > 0.4:
|
|
||||||
recommendations.append("MEDIUM: Monitor this IP closely")
|
|
||||||
|
|
||||||
threat_types = {threat.threat_type for threat in threats}
|
|
||||||
|
|
||||||
if "sql_injection" in threat_types:
|
|
||||||
recommendations.append("Implement parameterized queries and input validation")
|
|
||||||
|
|
||||||
if "xss" in threat_types:
|
|
||||||
recommendations.append("Implement Content Security Policy (CSP) headers")
|
|
||||||
|
|
||||||
if "command_injection" in threat_types:
|
|
||||||
recommendations.append("Disable shell execution and validate all inputs")
|
|
||||||
|
|
||||||
if "path_traversal" in threat_types:
|
|
||||||
recommendations.append("Implement proper file path validation and access controls")
|
|
||||||
|
|
||||||
if "rate_limit_exceeded" in threat_types:
|
|
||||||
recommendations.append(f"Rate limiting active for {auth_level.value} user")
|
|
||||||
|
|
||||||
if not recommendations:
|
|
||||||
recommendations.append("No immediate action required, continue monitoring")
|
|
||||||
|
|
||||||
return recommendations
|
|
||||||
|
|
||||||
def _update_stats(self, threats: List[SecurityThreat], analysis_time: float):
|
|
||||||
"""Update service statistics"""
|
|
||||||
self.stats['total_requests_analyzed'] += 1
|
|
||||||
self.stats['total_analysis_time'] += analysis_time
|
|
||||||
|
|
||||||
if threats:
|
|
||||||
self.stats['threats_detected'] += len(threats)
|
|
||||||
for threat in threats:
|
|
||||||
self.stats['threat_types'][threat.threat_type] += 1
|
|
||||||
self.stats['threat_levels'][threat.level.value] += 1
|
|
||||||
if threat.source_ip:
|
|
||||||
self.stats['attacking_ips'][threat.source_ip] += 1
|
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
|
||||||
"""Get service statistics"""
|
|
||||||
avg_time = (self.stats['total_analysis_time'] / self.stats['total_requests_analyzed']
|
|
||||||
if self.stats['total_requests_analyzed'] > 0 else 0)
|
|
||||||
|
|
||||||
# Get top attacking IPs
|
|
||||||
top_ips = sorted(self.stats['attacking_ips'].items(), key=lambda x: x[1], reverse=True)[:10]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_requests_analyzed": self.stats['total_requests_analyzed'],
|
|
||||||
"threats_detected": self.stats['threats_detected'],
|
|
||||||
"threats_blocked": self.stats['threats_blocked'],
|
|
||||||
"anomalies_detected": self.stats['anomalies_detected'],
|
|
||||||
"rate_limits_exceeded": self.stats['rate_limits_exceeded'],
|
|
||||||
"avg_analysis_time": avg_time,
|
|
||||||
"threat_types": dict(self.stats['threat_types']),
|
|
||||||
"threat_levels": dict(self.stats['threat_levels']),
|
|
||||||
"top_attacking_ips": top_ips,
|
|
||||||
"security_enabled": settings.API_SECURITY_ENABLED,
|
|
||||||
"threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
|
|
||||||
"rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Global threat detection service instance
|
|
||||||
threat_detection_service = ThreatDetectionService()
|
|
||||||
@@ -53,6 +53,14 @@ async def lifespan(app: FastAPI):
|
|||||||
# Initialize config manager
|
# Initialize config manager
|
||||||
await init_config_manager()
|
await init_config_manager()
|
||||||
|
|
||||||
|
# Initialize LLM service (needed by RAG module)
|
||||||
|
from app.services.llm.service import llm_service
|
||||||
|
try:
|
||||||
|
await llm_service.initialize()
|
||||||
|
logger.info("LLM service initialized successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"LLM service initialization failed: {e}")
|
||||||
|
|
||||||
# Initialize analytics service
|
# Initialize analytics service
|
||||||
init_analytics_service()
|
init_analytics_service()
|
||||||
|
|
||||||
|
|||||||
@@ -1,371 +0,0 @@
|
|||||||
"""
|
|
||||||
Rate limiting middleware
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
import redis
|
|
||||||
from typing import Dict, Optional
|
|
||||||
from fastapi import Request, HTTPException, status
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
import asyncio
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.logging import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter:
|
|
||||||
"""Rate limiting implementation using Redis"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
try:
|
|
||||||
self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
|
||||||
self.redis_client.ping() # Test connection
|
|
||||||
logger.info("Rate limiter initialized with Redis backend")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Redis not available for rate limiting: {e}")
|
|
||||||
self.redis_client = None
|
|
||||||
# Fall back to in-memory rate limiting
|
|
||||||
self.memory_store: Dict[str, Dict[str, float]] = {}
|
|
||||||
|
|
||||||
async def check_rate_limit(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
limit: int,
|
|
||||||
window_seconds: int,
|
|
||||||
identifier: str = "default"
|
|
||||||
) -> tuple[bool, Dict[str, int]]:
|
|
||||||
"""
|
|
||||||
Check if request is within rate limit
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: Rate limiting key (e.g., IP address, API key)
|
|
||||||
limit: Maximum number of requests allowed
|
|
||||||
window_seconds: Time window in seconds
|
|
||||||
identifier: Additional identifier for the rate limit
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_allowed, headers_dict)
|
|
||||||
"""
|
|
||||||
|
|
||||||
full_key = f"rate_limit:{identifier}:{key}"
|
|
||||||
current_time = int(time.time())
|
|
||||||
window_start = current_time - window_seconds
|
|
||||||
|
|
||||||
if self.redis_client:
|
|
||||||
return await self._check_redis_rate_limit(
|
|
||||||
full_key, limit, window_seconds, current_time, window_start
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self._check_memory_rate_limit(
|
|
||||||
full_key, limit, window_seconds, current_time, window_start
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _check_redis_rate_limit(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
limit: int,
|
|
||||||
window_seconds: int,
|
|
||||||
current_time: int,
|
|
||||||
window_start: int
|
|
||||||
) -> tuple[bool, Dict[str, int]]:
|
|
||||||
"""Check rate limit using Redis"""
|
|
||||||
|
|
||||||
pipe = self.redis_client.pipeline()
|
|
||||||
|
|
||||||
# Remove old entries
|
|
||||||
pipe.zremrangebyscore(key, 0, window_start)
|
|
||||||
|
|
||||||
# Count current requests in window
|
|
||||||
pipe.zcard(key)
|
|
||||||
|
|
||||||
# Add current request
|
|
||||||
pipe.zadd(key, {str(current_time): current_time})
|
|
||||||
|
|
||||||
# Set expiration
|
|
||||||
pipe.expire(key, window_seconds + 1)
|
|
||||||
|
|
||||||
results = pipe.execute()
|
|
||||||
current_requests = results[1]
|
|
||||||
|
|
||||||
# Calculate remaining requests and reset time
|
|
||||||
remaining = max(0, limit - current_requests - 1)
|
|
||||||
reset_time = current_time + window_seconds
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"X-RateLimit-Limit": limit,
|
|
||||||
"X-RateLimit-Remaining": remaining,
|
|
||||||
"X-RateLimit-Reset": reset_time,
|
|
||||||
"X-RateLimit-Window": window_seconds
|
|
||||||
}
|
|
||||||
|
|
||||||
is_allowed = current_requests < limit
|
|
||||||
|
|
||||||
if not is_allowed:
|
|
||||||
logger.warning(f"Rate limit exceeded for key: {key}")
|
|
||||||
|
|
||||||
return is_allowed, headers
|
|
||||||
|
|
||||||
def _check_memory_rate_limit(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
limit: int,
|
|
||||||
window_seconds: int,
|
|
||||||
current_time: int,
|
|
||||||
window_start: int
|
|
||||||
) -> tuple[bool, Dict[str, int]]:
|
|
||||||
"""Check rate limit using in-memory storage"""
|
|
||||||
|
|
||||||
if key not in self.memory_store:
|
|
||||||
self.memory_store[key] = {}
|
|
||||||
|
|
||||||
# Clean old entries
|
|
||||||
store = self.memory_store[key]
|
|
||||||
keys_to_remove = [k for k, v in store.items() if v < window_start]
|
|
||||||
for k in keys_to_remove:
|
|
||||||
del store[k]
|
|
||||||
|
|
||||||
current_requests = len(store)
|
|
||||||
|
|
||||||
# Calculate remaining requests and reset time
|
|
||||||
remaining = max(0, limit - current_requests - 1)
|
|
||||||
reset_time = current_time + window_seconds
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"X-RateLimit-Limit": limit,
|
|
||||||
"X-RateLimit-Remaining": remaining,
|
|
||||||
"X-RateLimit-Reset": reset_time,
|
|
||||||
"X-RateLimit-Window": window_seconds
|
|
||||||
}
|
|
||||||
|
|
||||||
is_allowed = current_requests < limit
|
|
||||||
|
|
||||||
if is_allowed:
|
|
||||||
# Add current request
|
|
||||||
store[str(current_time)] = current_time
|
|
||||||
else:
|
|
||||||
logger.warning(f"Rate limit exceeded for key: {key}")
|
|
||||||
|
|
||||||
return is_allowed, headers
|
|
||||||
|
|
||||||
|
|
||||||
# Global rate limiter instance
|
|
||||||
rate_limiter = RateLimiter()
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""Rate limiting middleware for FastAPI"""
|
|
||||||
|
|
||||||
def __init__(self, app):
|
|
||||||
super().__init__(app)
|
|
||||||
self.rate_limiter = RateLimiter()
|
|
||||||
logger.info("RateLimitMiddleware initialized")
|
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
|
||||||
"""Process request through rate limiting"""
|
|
||||||
|
|
||||||
# Skip rate limiting if disabled in settings
|
|
||||||
if not settings.API_RATE_LIMITING_ENABLED:
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
|
|
||||||
# Skip rate limiting for all internal API endpoints (platform operations)
|
|
||||||
if request.url.path.startswith("/api-internal/v1/"):
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
|
|
||||||
# Only apply rate limiting to privatemode.ai proxy endpoints (OpenAI-compatible API and LLM service)
|
|
||||||
# Skip for all other endpoints
|
|
||||||
if not (request.url.path.startswith("/api/v1/chat/completions") or
|
|
||||||
request.url.path.startswith("/api/v1/embeddings") or
|
|
||||||
request.url.path.startswith("/api/v1/models") or
|
|
||||||
request.url.path.startswith("/api/v1/llm/")):
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
|
|
||||||
# Skip rate limiting for health checks and static files
|
|
||||||
if request.url.path in ["/health", "/", "/api/v1/docs", "/api/v1/openapi.json"]:
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
|
|
||||||
# Get client IP
|
|
||||||
client_ip = request.client.host
|
|
||||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
|
||||||
if forwarded_for:
|
|
||||||
client_ip = forwarded_for.split(",")[0].strip()
|
|
||||||
|
|
||||||
# Check for API key in headers
|
|
||||||
api_key = None
|
|
||||||
auth_header = request.headers.get("Authorization")
|
|
||||||
if auth_header and auth_header.startswith("Bearer "):
|
|
||||||
api_key = auth_header[7:]
|
|
||||||
elif request.headers.get("X-API-Key"):
|
|
||||||
api_key = request.headers.get("X-API-Key")
|
|
||||||
|
|
||||||
# Determine rate limiting strategy
|
|
||||||
headers = {}
|
|
||||||
is_allowed = True
|
|
||||||
|
|
||||||
if api_key:
|
|
||||||
# API key-based rate limiting
|
|
||||||
api_key_key = f"api_key:{api_key}"
|
|
||||||
|
|
||||||
# First check organization-wide limits (PrivateMode limits are org-wide)
|
|
||||||
org_key = "organization:privatemode"
|
|
||||||
|
|
||||||
# Check organization per-minute limit
|
|
||||||
org_allowed_minute, org_headers_minute = await self.rate_limiter.check_rate_limit(
|
|
||||||
org_key, settings.PRIVATEMODE_REQUESTS_PER_MINUTE, 60, "minute"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check organization per-hour limit
|
|
||||||
org_allowed_hour, org_headers_hour = await self.rate_limiter.check_rate_limit(
|
|
||||||
org_key, settings.PRIVATEMODE_REQUESTS_PER_HOUR, 3600, "hour"
|
|
||||||
)
|
|
||||||
|
|
||||||
# If organization limits are exceeded, return 429
|
|
||||||
if not (org_allowed_minute and org_allowed_hour):
|
|
||||||
logger.warning(f"Organization rate limit exceeded for {org_key}")
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
|
||||||
content={"detail": "Organization rate limit exceeded"},
|
|
||||||
headers=org_headers_minute
|
|
||||||
)
|
|
||||||
|
|
||||||
# Then check per-API key limits
|
|
||||||
limit_per_minute = settings.API_RATE_LIMIT_API_KEY_PER_MINUTE
|
|
||||||
limit_per_hour = settings.API_RATE_LIMIT_API_KEY_PER_HOUR
|
|
||||||
|
|
||||||
# Check per-minute limit
|
|
||||||
is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
|
|
||||||
api_key_key, limit_per_minute, 60, "minute"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check per-hour limit
|
|
||||||
is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
|
|
||||||
api_key_key, limit_per_hour, 3600, "hour"
|
|
||||||
)
|
|
||||||
|
|
||||||
is_allowed = is_allowed_minute and is_allowed_hour
|
|
||||||
headers = headers_minute # Use minute headers for response
|
|
||||||
|
|
||||||
else:
|
|
||||||
# IP-based rate limiting for unauthenticated requests
|
|
||||||
rate_limit_key = f"ip:{client_ip}"
|
|
||||||
|
|
||||||
# More restrictive limits for unauthenticated requests
|
|
||||||
limit_per_minute = 20 # Hardcoded for unauthenticated users
|
|
||||||
limit_per_hour = 100
|
|
||||||
|
|
||||||
# Check per-minute limit
|
|
||||||
is_allowed_minute, headers_minute = await self.rate_limiter.check_rate_limit(
|
|
||||||
rate_limit_key, limit_per_minute, 60, "minute"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check per-hour limit
|
|
||||||
is_allowed_hour, headers_hour = await self.rate_limiter.check_rate_limit(
|
|
||||||
rate_limit_key, limit_per_hour, 3600, "hour"
|
|
||||||
)
|
|
||||||
|
|
||||||
is_allowed = is_allowed_minute and is_allowed_hour
|
|
||||||
headers = headers_minute # Use minute headers for response
|
|
||||||
|
|
||||||
# If rate limit exceeded, return 429
|
|
||||||
if not is_allowed:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
|
||||||
content={
|
|
||||||
"error": "RATE_LIMIT_EXCEEDED",
|
|
||||||
"message": "Rate limit exceeded. Please try again later.",
|
|
||||||
"details": {
|
|
||||||
"limit": headers["X-RateLimit-Limit"],
|
|
||||||
"reset_time": headers["X-RateLimit-Reset"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
headers={k: str(v) for k, v in headers.items()}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Continue with request
|
|
||||||
response = await call_next(request)
|
|
||||||
|
|
||||||
# Add rate limit headers to response
|
|
||||||
for key, value in headers.items():
|
|
||||||
response.headers[key] = str(value)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
# Keep the old function for backward compatibility
|
|
||||||
async def rate_limit_middleware(request: Request, call_next):
|
|
||||||
"""Legacy function - use RateLimitMiddleware class instead"""
|
|
||||||
middleware = RateLimitMiddleware(None)
|
|
||||||
return await middleware.dispatch(request, call_next)
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimitExceeded(HTTPException):
|
|
||||||
"""Exception raised when rate limit is exceeded"""
|
|
||||||
|
|
||||||
def __init__(self, limit: int, reset_time: int):
|
|
||||||
super().__init__(
|
|
||||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
|
||||||
detail=f"Rate limit exceeded. Limit: {limit}, Reset: {reset_time}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Decorator for applying rate limits to specific endpoints
|
|
||||||
def rate_limit(requests_per_minute: int = 60, requests_per_hour: int = 1000):
|
|
||||||
"""
|
|
||||||
Decorator to apply rate limiting to specific endpoints
|
|
||||||
|
|
||||||
Args:
|
|
||||||
requests_per_minute: Maximum requests per minute
|
|
||||||
requests_per_hour: Maximum requests per hour
|
|
||||||
"""
|
|
||||||
def decorator(func):
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
# This would be implemented to work with FastAPI dependencies
|
|
||||||
# For now, this is a placeholder for endpoint-specific rate limiting
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
return wrapper
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
# Helper functions for different rate limiting strategies
|
|
||||||
async def check_api_key_rate_limit(api_key: str, endpoint: str) -> bool:
|
|
||||||
"""Check rate limit for specific API key and endpoint"""
|
|
||||||
|
|
||||||
# This would lookup API key specific limits from database
|
|
||||||
# For now, using default limits
|
|
||||||
key = f"api_key:{api_key}:endpoint:{endpoint}"
|
|
||||||
|
|
||||||
is_allowed, _ = await rate_limiter.check_rate_limit(
|
|
||||||
key, limit=100, window_seconds=60, identifier="endpoint"
|
|
||||||
)
|
|
||||||
|
|
||||||
return is_allowed
|
|
||||||
|
|
||||||
|
|
||||||
async def check_user_rate_limit(user_id: str, action: str) -> bool:
|
|
||||||
"""Check rate limit for specific user and action"""
|
|
||||||
|
|
||||||
key = f"user:{user_id}:action:{action}"
|
|
||||||
|
|
||||||
is_allowed, _ = await rate_limiter.check_rate_limit(
|
|
||||||
key, limit=50, window_seconds=60, identifier="user_action"
|
|
||||||
)
|
|
||||||
|
|
||||||
return is_allowed
|
|
||||||
|
|
||||||
|
|
||||||
async def apply_burst_protection(key: str) -> bool:
|
|
||||||
"""Apply burst protection for high-frequency actions"""
|
|
||||||
|
|
||||||
# Allow burst of 10 requests in 10 seconds
|
|
||||||
is_allowed, _ = await rate_limiter.check_rate_limit(
|
|
||||||
key, limit=10, window_seconds=10, identifier="burst"
|
|
||||||
)
|
|
||||||
|
|
||||||
return is_allowed
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
"""
|
|
||||||
Security middleware for request/response processing
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from typing import Callable, Optional, Dict, Any
|
|
||||||
|
|
||||||
from fastapi import Request, Response
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.logging import get_logger
|
|
||||||
from app.core.threat_detection import threat_detection_service, SecurityAnalysis
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SecurityMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""Security middleware for threat detection and request filtering - DISABLED"""
|
|
||||||
|
|
||||||
def __init__(self, app, enabled: bool = True):
|
|
||||||
super().__init__(app)
|
|
||||||
self.enabled = False # Force disable regardless of settings
|
|
||||||
logger.info("SecurityMiddleware initialized, enabled: False (DISABLED)")
|
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
|
||||||
"""Process request through security analysis - DISABLED"""
|
|
||||||
# Security disabled, always pass through
|
|
||||||
return await call_next(request)
|
|
||||||
|
|
||||||
def _should_skip_security(self, request: Request) -> bool:
|
|
||||||
"""Determine if security analysis should be skipped for this request"""
|
|
||||||
path = request.url.path
|
|
||||||
|
|
||||||
# Skip for health checks, authentication endpoints, and static assets
|
|
||||||
skip_paths = [
|
|
||||||
"/health",
|
|
||||||
"/metrics",
|
|
||||||
"/api/v1/docs",
|
|
||||||
"/api/v1/openapi.json",
|
|
||||||
"/api/v1/redoc",
|
|
||||||
"/favicon.ico",
|
|
||||||
"/api/v1/auth/register",
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
"/api/v1/auth/refresh", # Allow refresh endpoint
|
|
||||||
"/api-internal/v1/auth/register",
|
|
||||||
"/api-internal/v1/auth/login",
|
|
||||||
"/api-internal/v1/auth/refresh", # Allow refresh endpoint for internal API
|
|
||||||
"/", # Root endpoint
|
|
||||||
]
|
|
||||||
|
|
||||||
# Skip for static file extensions
|
|
||||||
static_extensions = [".css", ".js", ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".woff", ".woff2"]
|
|
||||||
|
|
||||||
return (
|
|
||||||
path in skip_paths or
|
|
||||||
any(path.endswith(ext) for ext in static_extensions) or
|
|
||||||
path.startswith("/static/")
|
|
||||||
)
|
|
||||||
|
|
||||||
def _has_valid_auth(self, request: Request) -> bool:
|
|
||||||
"""Check if request has valid authentication"""
|
|
||||||
# Check Authorization header
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
|
||||||
api_key_header = request.headers.get("X-API-Key", "")
|
|
||||||
|
|
||||||
# Has some form of auth token/key
|
|
||||||
return (
|
|
||||||
auth_header.startswith("Bearer ") and len(auth_header) > 7 or
|
|
||||||
len(api_key_header.strip()) > 0
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_block_response(self, analysis: SecurityAnalysis) -> JSONResponse:
|
|
||||||
"""Create response for blocked requests"""
|
|
||||||
# Determine status code based on threat type
|
|
||||||
status_code = 403 # Forbidden by default
|
|
||||||
|
|
||||||
# Critical threats get 403
|
|
||||||
for threat in analysis.threats:
|
|
||||||
if threat.threat_type in ["command_injection", "sql_injection"]:
|
|
||||||
status_code = 403
|
|
||||||
break
|
|
||||||
|
|
||||||
response_data = {
|
|
||||||
"error": "Security Policy Violation",
|
|
||||||
"message": "Request blocked due to security policy violation",
|
|
||||||
"risk_score": round(analysis.risk_score, 3),
|
|
||||||
"auth_level": analysis.auth_level.value,
|
|
||||||
"threat_count": len(analysis.threats),
|
|
||||||
"recommendations": analysis.recommendations[:3] # Limit to first 3 recommendations
|
|
||||||
}
|
|
||||||
|
|
||||||
response = JSONResponse(
|
|
||||||
content=response_data,
|
|
||||||
status_code=status_code
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
def _add_security_headers(self, response: Response) -> Response:
|
|
||||||
"""Add security headers to response"""
|
|
||||||
if not settings.API_SECURITY_HEADERS_ENABLED:
|
|
||||||
return response
|
|
||||||
|
|
||||||
# Standard security headers
|
|
||||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
|
||||||
response.headers["X-Frame-Options"] = "DENY"
|
|
||||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
|
||||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
|
||||||
|
|
||||||
# Only add HSTS for HTTPS
|
|
||||||
if hasattr(response, 'headers') and response.headers.get("X-Forwarded-Proto") == "https":
|
|
||||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
|
||||||
|
|
||||||
# Content Security Policy
|
|
||||||
if settings.API_CSP_HEADER:
|
|
||||||
response.headers["Content-Security-Policy"] = settings.API_CSP_HEADER
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
def _add_security_metrics(self, response: Response, analysis: SecurityAnalysis, analysis_time: float) -> Response:
|
|
||||||
"""Add security metrics to response headers (for debugging/monitoring)"""
|
|
||||||
# Only add in debug mode or for admin users
|
|
||||||
if settings.APP_DEBUG:
|
|
||||||
response.headers["X-Security-Risk-Score"] = str(round(analysis.risk_score, 3))
|
|
||||||
response.headers["X-Security-Threats"] = str(len(analysis.threats))
|
|
||||||
response.headers["X-Security-Auth-Level"] = analysis.auth_level.value
|
|
||||||
response.headers["X-Security-Analysis-Time"] = f"{analysis_time*1000:.1f}ms"
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def _log_security_event(self, request: Request, analysis: SecurityAnalysis):
|
|
||||||
"""Log security events for audit and monitoring"""
|
|
||||||
client_ip = request.client.host if request.client else "unknown"
|
|
||||||
user_agent = request.headers.get("user-agent", "")
|
|
||||||
|
|
||||||
# Create security event log
|
|
||||||
event_data = {
|
|
||||||
"timestamp": analysis.timestamp.isoformat(),
|
|
||||||
"client_ip": client_ip,
|
|
||||||
"user_agent": user_agent,
|
|
||||||
"path": str(request.url.path),
|
|
||||||
"method": request.method,
|
|
||||||
"risk_score": round(analysis.risk_score, 3),
|
|
||||||
"auth_level": analysis.auth_level.value,
|
|
||||||
"threat_count": len(analysis.threats),
|
|
||||||
"rate_limit_exceeded": analysis.rate_limit_exceeded,
|
|
||||||
"should_block": analysis.should_block,
|
|
||||||
"threats": [
|
|
||||||
{
|
|
||||||
"type": threat.threat_type,
|
|
||||||
"level": threat.level.value,
|
|
||||||
"confidence": round(threat.confidence, 3),
|
|
||||||
"description": threat.description
|
|
||||||
}
|
|
||||||
for threat in analysis.threats[:5] # Limit to first 5 threats
|
|
||||||
],
|
|
||||||
"recommendations": analysis.recommendations
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log at appropriate level based on risk
|
|
||||||
if analysis.should_block:
|
|
||||||
logger.warning(f"SECURITY_BLOCK: {json.dumps(event_data)}")
|
|
||||||
elif analysis.risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
|
|
||||||
logger.warning(f"SECURITY_WARNING: {json.dumps(event_data)}")
|
|
||||||
else:
|
|
||||||
logger.info(f"SECURITY_THREAT: {json.dumps(event_data)}")
|
|
||||||
|
|
||||||
|
|
||||||
def setup_security_middleware(app, enabled: bool = True) -> None:
|
|
||||||
"""Setup security middleware on FastAPI app"""
|
|
||||||
if enabled and settings.API_SECURITY_ENABLED:
|
|
||||||
app.add_middleware(SecurityMiddleware, enabled=enabled)
|
|
||||||
logger.info("Security middleware enabled")
|
|
||||||
else:
|
|
||||||
logger.info("Security middleware disabled")
|
|
||||||
|
|
||||||
|
|
||||||
# Helper functions for manual security checks
|
|
||||||
async def analyze_request_security(request: Request, user_context: Optional[Dict] = None) -> SecurityAnalysis:
|
|
||||||
"""Manually analyze request security (for use in route handlers)"""
|
|
||||||
return await threat_detection_service.analyze_request(request, user_context)
|
|
||||||
|
|
||||||
|
|
||||||
def get_security_stats() -> Dict[str, Any]:
|
|
||||||
"""Get security statistics"""
|
|
||||||
return threat_detection_service.get_stats()
|
|
||||||
|
|
||||||
|
|
||||||
def is_request_blocked(request: Request) -> bool:
|
|
||||||
"""Check if request was blocked by security analysis"""
|
|
||||||
if hasattr(request.state, 'security_analysis'):
|
|
||||||
return request.state.security_analysis.should_block
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_request_risk_score(request: Request) -> float:
|
|
||||||
"""Get risk score for request"""
|
|
||||||
if hasattr(request.state, 'security_analysis'):
|
|
||||||
return request.state.security_analysis.risk_score
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
|
|
||||||
def get_request_auth_level(request: Request) -> str:
|
|
||||||
"""Get authentication level for request"""
|
|
||||||
if hasattr(request.state, 'security_analysis'):
|
|
||||||
return request.state.security_analysis.auth_level.value
|
|
||||||
return "unknown"
|
|
||||||
@@ -162,6 +162,7 @@ class DocumentProcessor:
|
|||||||
|
|
||||||
async def _process_document(self, task: ProcessingTask) -> bool:
|
async def _process_document(self, task: ProcessingTask) -> bool:
|
||||||
"""Process a single document"""
|
"""Process a single document"""
|
||||||
|
from datetime import datetime
|
||||||
from app.db.database import async_session_factory
|
from app.db.database import async_session_factory
|
||||||
async with async_session_factory() as session:
|
async with async_session_factory() as session:
|
||||||
try:
|
try:
|
||||||
@@ -182,16 +183,24 @@ class DocumentProcessor:
|
|||||||
document.status = ProcessingStatus.PROCESSING
|
document.status = ProcessingStatus.PROCESSING
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# Get RAG module for processing (now includes content processing)
|
# Get RAG module for processing
|
||||||
try:
|
try:
|
||||||
from app.services.module_manager import module_manager
|
# Import RAG module and initialize it properly
|
||||||
rag_module = module_manager.get_module('rag')
|
from modules.rag.main import RAGModule
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
# Create and initialize RAG module instance
|
||||||
|
rag_module = RAGModule(settings)
|
||||||
|
init_result = await rag_module.initialize()
|
||||||
|
if not rag_module.enabled:
|
||||||
|
raise Exception("Failed to enable RAG module")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get RAG module: {e}")
|
logger.error(f"Failed to get RAG module: {e}")
|
||||||
raise Exception(f"RAG module not available: {e}")
|
raise Exception(f"RAG module not available: {e}")
|
||||||
|
|
||||||
if not rag_module:
|
if not rag_module or not rag_module.enabled:
|
||||||
raise Exception("RAG module not available")
|
raise Exception("RAG module not available or not enabled")
|
||||||
|
|
||||||
logger.info(f"RAG module loaded successfully for document {task.document_id}")
|
logger.info(f"RAG module loaded successfully for document {task.document_id}")
|
||||||
|
|
||||||
@@ -204,23 +213,31 @@ class DocumentProcessor:
|
|||||||
|
|
||||||
# Process with RAG module
|
# Process with RAG module
|
||||||
logger.info(f"Starting document processing for document {task.document_id} with RAG module")
|
logger.info(f"Starting document processing for document {task.document_id} with RAG module")
|
||||||
|
|
||||||
|
# Special handling for JSONL files - skip processing phase
|
||||||
|
if document.file_type == 'jsonl':
|
||||||
|
# For JSONL files, we don't need to process content here
|
||||||
|
# The optimized JSONL processor will handle everything during indexing
|
||||||
|
document.converted_content = f"JSONL file with {len(file_content)} bytes"
|
||||||
|
document.word_count = 0 # Will be updated during indexing
|
||||||
|
document.character_count = len(file_content)
|
||||||
|
document.document_metadata = {"file_path": document.file_path, "processed": "jsonl"}
|
||||||
|
document.status = ProcessingStatus.PROCESSED
|
||||||
|
document.processed_at = datetime.utcnow()
|
||||||
|
logger.info(f"JSONL document {task.document_id} marked for optimized processing")
|
||||||
|
else:
|
||||||
|
# Standard processing for other file types
|
||||||
try:
|
try:
|
||||||
# Add timeout to prevent hanging
|
# Add timeout to prevent hanging
|
||||||
processed_doc = await asyncio.wait_for(
|
processed_doc = await asyncio.wait_for(
|
||||||
rag_module.process_document(
|
rag_module.process_document(
|
||||||
file_content,
|
file_content,
|
||||||
document.original_filename,
|
document.original_filename,
|
||||||
{}
|
{"file_path": document.file_path}
|
||||||
),
|
),
|
||||||
timeout=300.0 # 5 minute timeout
|
timeout=300.0 # 5 minute timeout
|
||||||
)
|
)
|
||||||
logger.info(f"Document processing completed for document {task.document_id}")
|
logger.info(f"Document processing completed for document {task.document_id}")
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error(f"Document processing timed out for document {task.document_id}")
|
|
||||||
raise Exception("Document processing timed out after 5 minutes")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Document processing failed for document {task.document_id}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Update document with processed content
|
# Update document with processed content
|
||||||
document.converted_content = processed_doc.content
|
document.converted_content = processed_doc.content
|
||||||
@@ -229,6 +246,12 @@ class DocumentProcessor:
|
|||||||
document.document_metadata = processed_doc.metadata
|
document.document_metadata = processed_doc.metadata
|
||||||
document.status = ProcessingStatus.PROCESSED
|
document.status = ProcessingStatus.PROCESSED
|
||||||
document.processed_at = datetime.utcnow()
|
document.processed_at = datetime.utcnow()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"Document processing timed out for document {task.document_id}")
|
||||||
|
raise Exception("Document processing timed out after 5 minutes")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Document processing failed for document {task.document_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
# Index in RAG system using same RAG module
|
# Index in RAG system using same RAG module
|
||||||
if rag_module and document.converted_content:
|
if rag_module and document.converted_content:
|
||||||
@@ -245,6 +268,49 @@ class DocumentProcessor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Use the correct Qdrant collection name for this document
|
# Use the correct Qdrant collection name for this document
|
||||||
|
# For JSONL files, we need to use the processed document flow
|
||||||
|
if document.file_type == 'jsonl':
|
||||||
|
# Create a ProcessedDocument for the JSONL processor
|
||||||
|
from app.modules.rag.main import ProcessedDocument
|
||||||
|
from datetime import datetime
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
# Calculate file hash
|
||||||
|
processed_at = datetime.utcnow()
|
||||||
|
file_hash = hashlib.md5(str(document.id).encode()).hexdigest()
|
||||||
|
|
||||||
|
processed_doc = ProcessedDocument(
|
||||||
|
id=str(document.id),
|
||||||
|
content="", # Will be filled by JSONL processor
|
||||||
|
extracted_text="", # Will be filled by JSONL processor
|
||||||
|
metadata={
|
||||||
|
**doc_metadata,
|
||||||
|
"file_path": document.file_path
|
||||||
|
},
|
||||||
|
original_filename=document.original_filename,
|
||||||
|
file_type=document.file_type,
|
||||||
|
mime_type=document.mime_type,
|
||||||
|
language=document.document_metadata.get('language', 'EN'),
|
||||||
|
word_count=0, # Will be updated during processing
|
||||||
|
sentence_count=0, # Will be updated during processing
|
||||||
|
entities=[],
|
||||||
|
keywords=[],
|
||||||
|
processing_time=0.0,
|
||||||
|
processed_at=processed_at,
|
||||||
|
file_hash=file_hash,
|
||||||
|
file_size=document.file_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# The JSONL processor will read the original file
|
||||||
|
await asyncio.wait_for(
|
||||||
|
rag_module.index_processed_document(
|
||||||
|
processed_doc=processed_doc,
|
||||||
|
collection_name=document.collection.qdrant_collection_name
|
||||||
|
),
|
||||||
|
timeout=300.0 # 5 minute timeout for JSONL processing
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use standard indexing for other file types
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
rag_module.index_document(
|
rag_module.index_document(
|
||||||
content=document.converted_content,
|
content=document.converted_content,
|
||||||
@@ -271,7 +337,9 @@ class DocumentProcessor:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
|
logger.error(f"Failed to index document {task.document_id} in RAG: {e}")
|
||||||
# Keep as processed even if indexing fails
|
# Mark as error since indexing failed
|
||||||
|
document.status = ProcessingStatus.ERROR
|
||||||
|
document.processing_error = f"Indexing failed: {str(e)}"
|
||||||
# Don't raise the exception to avoid retries on indexing failures
|
# Don't raise the exception to avoid retries on indexing failures
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -28,9 +28,19 @@ class EmbeddingService:
|
|||||||
await llm_service.initialize()
|
await llm_service.initialize()
|
||||||
|
|
||||||
# Test LLM service health
|
# Test LLM service health
|
||||||
health_summary = llm_service.get_health_summary()
|
if not llm_service._initialized:
|
||||||
if health_summary.get("service_status") != "healthy":
|
logger.error("LLM service not initialized")
|
||||||
logger.error(f"LLM service unhealthy: {health_summary}")
|
return False
|
||||||
|
|
||||||
|
# Check if PrivateMode provider is available
|
||||||
|
try:
|
||||||
|
provider_status = await llm_service.get_provider_status()
|
||||||
|
privatemode_status = provider_status.get("privatemode")
|
||||||
|
if not privatemode_status or privatemode_status.status != "healthy":
|
||||||
|
logger.error(f"PrivateMode provider not available: {privatemode_status}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to check provider status: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.initialized = True
|
self.initialized = True
|
||||||
@@ -75,6 +85,12 @@ class EmbeddingService:
|
|||||||
else:
|
else:
|
||||||
truncated_text = text
|
truncated_text = text
|
||||||
|
|
||||||
|
# Guard: skip empty inputs (validator rejects empty strings)
|
||||||
|
if not truncated_text.strip():
|
||||||
|
logger.debug("Empty input for embedding; using fallback vector")
|
||||||
|
batch_embeddings.append(self._generate_fallback_embedding(text))
|
||||||
|
continue
|
||||||
|
|
||||||
# Call LLM service embedding endpoint
|
# Call LLM service embedding endpoint
|
||||||
from app.services.llm.service import llm_service
|
from app.services.llm.service import llm_service
|
||||||
from app.services.llm.models import EmbeddingRequest
|
from app.services.llm.models import EmbeddingRequest
|
||||||
|
|||||||
@@ -25,9 +25,10 @@ class EnhancedEmbeddingService(EmbeddingService):
|
|||||||
'requests_count': 0,
|
'requests_count': 0,
|
||||||
'window_start': time.time(),
|
'window_start': time.time(),
|
||||||
'window_size': 60, # 1 minute window
|
'window_size': 60, # 1 minute window
|
||||||
'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 60)), # Configurable
|
'max_requests_per_minute': int(getattr(settings, 'RAG_EMBEDDING_MAX_REQUESTS_PER_MINUTE', 12)), # Configurable
|
||||||
'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff
|
'retry_delays': [int(x) for x in getattr(settings, 'RAG_EMBEDDING_RETRY_DELAYS', '1,2,4,8,16').split(',')], # Exponential backoff
|
||||||
'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 0.5)),
|
'delay_between_batches': float(getattr(settings, 'RAG_EMBEDDING_DELAY_BETWEEN_BATCHES', 1.0)),
|
||||||
|
'delay_per_request': float(getattr(settings, 'RAG_EMBEDDING_DELAY_PER_REQUEST', 0.5)),
|
||||||
'last_rate_limit_error': None
|
'last_rate_limit_error': None
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,7 +39,7 @@ class EnhancedEmbeddingService(EmbeddingService):
|
|||||||
if max_retries is None:
|
if max_retries is None:
|
||||||
max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))
|
max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))
|
||||||
|
|
||||||
batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 5))
|
batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3))
|
||||||
|
|
||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
logger.warning("Embedding service not initialized, using fallback")
|
logger.warning("Embedding service not initialized, using fallback")
|
||||||
@@ -76,9 +77,6 @@ class EnhancedEmbeddingService(EmbeddingService):
|
|||||||
# Make the request
|
# Make the request
|
||||||
embeddings = await self._get_embeddings_batch_impl(texts)
|
embeddings = await self._get_embeddings_batch_impl(texts)
|
||||||
|
|
||||||
# Update rate limit tracker on success
|
|
||||||
self._update_rate_limit_tracker(success=True)
|
|
||||||
|
|
||||||
return embeddings, True
|
return embeddings, True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -120,6 +118,12 @@ class EnhancedEmbeddingService(EmbeddingService):
|
|||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
for text in texts:
|
for text in texts:
|
||||||
|
# Respect rate limit before each request
|
||||||
|
while self._is_rate_limited():
|
||||||
|
delay = self._get_rate_limit_delay()
|
||||||
|
logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request")
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
# Truncate text if needed
|
# Truncate text if needed
|
||||||
max_chars = 1600
|
max_chars = 1600
|
||||||
truncated_text = text[:max_chars] if len(text) > max_chars else text
|
truncated_text = text[:max_chars] if len(text) > max_chars else text
|
||||||
@@ -145,6 +149,12 @@ class EnhancedEmbeddingService(EmbeddingService):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid response structure")
|
raise ValueError("Invalid response structure")
|
||||||
|
|
||||||
|
# Count this successful request and optionally delay between requests
|
||||||
|
self._update_rate_limit_tracker(success=True)
|
||||||
|
per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0)
|
||||||
|
if per_req_delay and per_req_delay > 0:
|
||||||
|
await asyncio.sleep(per_req_delay)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
def _is_rate_limited(self) -> bool:
|
def _is_rate_limited(self) -> bool:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from .models import ResilienceConfig
|
|||||||
class ProviderConfig(BaseModel):
|
class ProviderConfig(BaseModel):
|
||||||
"""Configuration for an LLM provider"""
|
"""Configuration for an LLM provider"""
|
||||||
name: str = Field(..., description="Provider name")
|
name: str = Field(..., description="Provider name")
|
||||||
|
provider_type: str = Field(..., description="Provider type (e.g., 'openai', 'privatemode')")
|
||||||
enabled: bool = Field(True, description="Whether provider is enabled")
|
enabled: bool = Field(True, description="Whether provider is enabled")
|
||||||
base_url: str = Field(..., description="Provider base URL")
|
base_url: str = Field(..., description="Provider base URL")
|
||||||
api_key_env_var: str = Field(..., description="Environment variable for API key")
|
api_key_env_var: str = Field(..., description="Environment variable for API key")
|
||||||
@@ -53,9 +54,6 @@ class LLMServiceConfig(BaseModel):
|
|||||||
enable_security_checks: bool = Field(True, description="Enable security validation")
|
enable_security_checks: bool = Field(True, description="Enable security validation")
|
||||||
enable_metrics_collection: bool = Field(True, description="Enable metrics collection")
|
enable_metrics_collection: bool = Field(True, description="Enable metrics collection")
|
||||||
|
|
||||||
# Security settings
|
|
||||||
security_risk_threshold: float = Field(0.8, ge=0.0, le=1.0, description="Risk threshold for blocking")
|
|
||||||
security_warning_threshold: float = Field(0.6, ge=0.0, le=1.0, description="Risk threshold for warnings")
|
|
||||||
max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
|
max_prompt_length: int = Field(50000, ge=1000, description="Maximum prompt length")
|
||||||
max_response_length: int = Field(32000, ge=1000, description="Maximum response length")
|
max_response_length: int = Field(32000, ge=1000, description="Maximum response length")
|
||||||
|
|
||||||
@@ -78,12 +76,6 @@ class LLMServiceConfig(BaseModel):
|
|||||||
# Model routing (model_name -> provider_name)
|
# Model routing (model_name -> provider_name)
|
||||||
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
|
model_routing: Dict[str, str] = Field(default_factory=dict, description="Model to provider routing")
|
||||||
|
|
||||||
@validator('security_risk_threshold')
|
|
||||||
def validate_risk_threshold(cls, v, values):
|
|
||||||
warning_threshold = values.get('security_warning_threshold', 0.6)
|
|
||||||
if v <= warning_threshold:
|
|
||||||
raise ValueError("Risk threshold must be greater than warning threshold")
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
def create_default_config() -> LLMServiceConfig:
|
def create_default_config() -> LLMServiceConfig:
|
||||||
@@ -93,6 +85,7 @@ def create_default_config() -> LLMServiceConfig:
|
|||||||
# Models will be fetched dynamically from proxy /models endpoint
|
# Models will be fetched dynamically from proxy /models endpoint
|
||||||
privatemode_config = ProviderConfig(
|
privatemode_config = ProviderConfig(
|
||||||
name="privatemode",
|
name="privatemode",
|
||||||
|
provider_type="privatemode",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
base_url=settings.PRIVATEMODE_PROXY_URL,
|
base_url=settings.PRIVATEMODE_PROXY_URL,
|
||||||
api_key_env_var="PRIVATEMODE_API_KEY",
|
api_key_env_var="PRIVATEMODE_API_KEY",
|
||||||
@@ -119,9 +112,6 @@ def create_default_config() -> LLMServiceConfig:
|
|||||||
config = LLMServiceConfig(
|
config = LLMServiceConfig(
|
||||||
default_provider="privatemode",
|
default_provider="privatemode",
|
||||||
enable_detailed_logging=settings.LOG_LLM_PROMPTS,
|
enable_detailed_logging=settings.LOG_LLM_PROMPTS,
|
||||||
enable_security_checks=settings.API_SECURITY_ENABLED,
|
|
||||||
security_risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
|
|
||||||
security_warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
|
|
||||||
providers={
|
providers={
|
||||||
"privatemode": privatemode_config
|
"privatemode": privatemode_config
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -124,7 +124,6 @@ class MetricsCollector:
|
|||||||
total_requests = len(self._metrics)
|
total_requests = len(self._metrics)
|
||||||
successful_requests = sum(1 for m in self._metrics if m.success)
|
successful_requests = sum(1 for m in self._metrics if m.success)
|
||||||
failed_requests = total_requests - successful_requests
|
failed_requests = total_requests - successful_requests
|
||||||
security_blocked = sum(1 for m in self._metrics if not m.success and m.security_risk_score > 0.8)
|
|
||||||
|
|
||||||
# Calculate averages
|
# Calculate averages
|
||||||
latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0]
|
latencies = [m.latency_ms for m in self._metrics if m.latency_ms > 0]
|
||||||
@@ -143,7 +142,6 @@ class MetricsCollector:
|
|||||||
total_requests=total_requests,
|
total_requests=total_requests,
|
||||||
successful_requests=successful_requests,
|
successful_requests=successful_requests,
|
||||||
failed_requests=failed_requests,
|
failed_requests=failed_requests,
|
||||||
security_blocked_requests=security_blocked,
|
|
||||||
average_latency_ms=avg_latency,
|
average_latency_ms=avg_latency,
|
||||||
average_risk_score=avg_risk_score,
|
average_risk_score=avg_risk_score,
|
||||||
provider_metrics=provider_metrics,
|
provider_metrics=provider_metrics,
|
||||||
|
|||||||
@@ -157,7 +157,6 @@ class LLMMetrics(BaseModel):
|
|||||||
total_requests: int = Field(0, description="Total requests processed")
|
total_requests: int = Field(0, description="Total requests processed")
|
||||||
successful_requests: int = Field(0, description="Successful requests")
|
successful_requests: int = Field(0, description="Successful requests")
|
||||||
failed_requests: int = Field(0, description="Failed requests")
|
failed_requests: int = Field(0, description="Failed requests")
|
||||||
security_blocked_requests: int = Field(0, description="Security blocked requests")
|
|
||||||
average_latency_ms: float = Field(0.0, description="Average response latency")
|
average_latency_ms: float = Field(0.0, description="Average response latency")
|
||||||
average_risk_score: float = Field(0.0, description="Average security risk score")
|
average_risk_score: float = Field(0.0, description="Average security risk score")
|
||||||
provider_metrics: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-provider metrics")
|
provider_metrics: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Per-provider metrics")
|
||||||
|
|||||||
@@ -452,6 +452,8 @@ class PrivateModeProvider(BaseLLMProvider):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
|
# Log the detailed error response from the provider
|
||||||
|
logger.error(f"PrivateMode embedding error - Status {response.status}: {error_text}")
|
||||||
self._handle_http_error(response.status, error_text, "embeddings")
|
self._handle_http_error(response.status, error_text, "embeddings")
|
||||||
|
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
|
|||||||
@@ -1,325 +0,0 @@
|
|||||||
"""
|
|
||||||
LLM Security Manager
|
|
||||||
|
|
||||||
Handles prompt injection detection and audit logging.
|
|
||||||
Provides comprehensive security for LLM interactions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import hashlib
|
|
||||||
from typing import Dict, Any, List, Optional, Tuple
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SecurityManager:
|
|
||||||
"""Manages security for LLM operations"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._setup_prompt_injection_patterns()
|
|
||||||
|
|
||||||
|
|
||||||
def _setup_prompt_injection_patterns(self):
|
|
||||||
"""Setup patterns for prompt injection detection"""
|
|
||||||
self.injection_patterns = [
|
|
||||||
# Direct instruction injection
|
|
||||||
r"(?i)(ignore|forget|disregard|override).{0,20}(instructions|rules|prompts)",
|
|
||||||
r"(?i)(new|updated|different)\s+(instructions|rules|system)",
|
|
||||||
r"(?i)act\s+as\s+(if|though)\s+you\s+(are|were)",
|
|
||||||
r"(?i)pretend\s+(to\s+be|you\s+are)",
|
|
||||||
r"(?i)you\s+are\s+now\s+(a|an)\s+",
|
|
||||||
|
|
||||||
# System role manipulation
|
|
||||||
r"(?i)system\s*:\s*",
|
|
||||||
r"(?i)\[system\]",
|
|
||||||
r"(?i)<system>",
|
|
||||||
r"(?i)assistant\s*:\s*",
|
|
||||||
r"(?i)\[assistant\]",
|
|
||||||
|
|
||||||
# Escape attempts
|
|
||||||
r"(?i)\\n\\n#+",
|
|
||||||
r"(?i)```\s*(system|assistant|user)",
|
|
||||||
r"(?i)---\s*(new|system|override)",
|
|
||||||
|
|
||||||
# Role manipulation
|
|
||||||
r"(?i)(you|your)\s+(role|purpose|function)\s+(is|has\s+changed)",
|
|
||||||
r"(?i)switch\s+to\s+(admin|developer|debug)\s+mode",
|
|
||||||
r"(?i)(admin|root|sudo|developer)\s+(access|mode|privileges)",
|
|
||||||
|
|
||||||
# Information extraction attempts
|
|
||||||
r"(?i)(show|display|reveal|expose)\s+(your|the)\s+(prompt|instructions|system)",
|
|
||||||
r"(?i)what\s+(are|were)\s+your\s+(original|initial)\s+(instructions|prompts)",
|
|
||||||
r"(?i)(debug|verbose|diagnostic)\s+mode",
|
|
||||||
|
|
||||||
# Encoding/obfuscation attempts
|
|
||||||
r"(?i)base64\s*:",
|
|
||||||
r"(?i)hex\s*:",
|
|
||||||
r"(?i)unicode\s*:",
|
|
||||||
r"(?i)\b[A-Za-z0-9+/]{40,}={0,2}\b", # More specific base64 pattern (longer sequences)
|
|
||||||
|
|
||||||
# SQL injection patterns (more specific to reduce false positives)
|
|
||||||
r"(?i)(union\s+select|select\s+\*|insert\s+into|update\s+\w+\s+set|delete\s+from|drop\s+table|create\s+table)\s",
|
|
||||||
r"(?i)(or|and)\s+\d+\s*=\s*\d+",
|
|
||||||
r"(?i)';?\s*(drop\s+table|delete\s+from|insert\s+into)",
|
|
||||||
|
|
||||||
# Command injection patterns
|
|
||||||
r"(?i)(exec|eval|system|shell|cmd)\s*\(",
|
|
||||||
r"(?i)(\$\(|\`)[^)]+(\)|\`)",
|
|
||||||
r"(?i)&&\s*(rm|del|format)",
|
|
||||||
|
|
||||||
# Jailbreak attempts
|
|
||||||
r"(?i)jailbreak",
|
|
||||||
r"(?i)break\s+out\s+of",
|
|
||||||
r"(?i)escape\s+(the|your)\s+(rules|constraints)",
|
|
||||||
r"(?i)(DAN|Do\s+Anything\s+Now)",
|
|
||||||
r"(?i)unrestricted\s+mode",
|
|
||||||
]
|
|
||||||
|
|
||||||
self.compiled_patterns = [re.compile(pattern) for pattern in self.injection_patterns]
|
|
||||||
logger.info(f"Initialized {len(self.injection_patterns)} prompt injection patterns")
|
|
||||||
|
|
||||||
|
|
||||||
def validate_prompt_security(self, messages: List[Dict[str, str]]) -> Tuple[bool, float, List[str]]:
|
|
||||||
"""
|
|
||||||
Validate messages for prompt injection attempts
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, float, List[str]]: (is_safe, risk_score, detected_patterns)
|
|
||||||
"""
|
|
||||||
detected_patterns = []
|
|
||||||
total_risk = 0.0
|
|
||||||
|
|
||||||
# Check if this is a system/RAG request
|
|
||||||
is_system_request = self._is_system_request(messages)
|
|
||||||
|
|
||||||
for message in messages:
|
|
||||||
content = message.get("content", "")
|
|
||||||
if not content:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check against injection patterns with context awareness
|
|
||||||
for i, pattern in enumerate(self.compiled_patterns):
|
|
||||||
matches = pattern.findall(content)
|
|
||||||
if matches:
|
|
||||||
# Apply context-aware risk calculation
|
|
||||||
pattern_risk = self._calculate_pattern_risk(i, matches, message.get("role", "user"), is_system_request)
|
|
||||||
total_risk += pattern_risk
|
|
||||||
detected_patterns.append({
|
|
||||||
"pattern_index": i,
|
|
||||||
"pattern": self.injection_patterns[i],
|
|
||||||
"matches": matches,
|
|
||||||
"risk": pattern_risk
|
|
||||||
})
|
|
||||||
|
|
||||||
# Additional security checks with context awareness
|
|
||||||
total_risk += self._check_message_characteristics(content, message.get("role", "user"), is_system_request)
|
|
||||||
|
|
||||||
# Normalize risk score (0.0 to 1.0)
|
|
||||||
risk_score = min(total_risk / len(messages) if messages else 0.0, 1.0)
|
|
||||||
# Never block - always return True for is_safe
|
|
||||||
is_safe = True
|
|
||||||
|
|
||||||
if detected_patterns:
|
|
||||||
logger.info(f"Detected {len(detected_patterns)} potential injection patterns, risk score: {risk_score} (system_request: {is_system_request})")
|
|
||||||
|
|
||||||
return is_safe, risk_score, detected_patterns
|
|
||||||
|
|
||||||
def _calculate_pattern_risk(self, pattern_index: int, matches: List, role: str, is_system_request: bool) -> float:
|
|
||||||
"""Calculate risk score for a detected pattern with context awareness"""
|
|
||||||
# Different patterns have different risk levels
|
|
||||||
high_risk_patterns = [0, 1, 2, 3, 4, 5, 6, 7, 22, 23, 24] # System manipulation, jailbreak
|
|
||||||
medium_risk_patterns = [8, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21] # Escape attempts, info extraction
|
|
||||||
|
|
||||||
# Base risk score
|
|
||||||
base_risk = 0.8 if pattern_index in high_risk_patterns else 0.5 if pattern_index in medium_risk_patterns else 0.3
|
|
||||||
|
|
||||||
# Apply context-specific risk reduction
|
|
||||||
if is_system_request or role == "system":
|
|
||||||
# Reduce risk for system messages and RAG content
|
|
||||||
if pattern_index in [14, 15, 16]: # Encoding patterns (base64, hex, unicode)
|
|
||||||
base_risk *= 0.2 # Reduce encoding risk by 80% for system content
|
|
||||||
elif pattern_index in [17, 18, 19]: # SQL patterns
|
|
||||||
base_risk *= 0.3 # Reduce SQL risk by 70% for system content
|
|
||||||
else:
|
|
||||||
base_risk *= 0.6 # Reduce other risks by 40% for system content
|
|
||||||
|
|
||||||
# Increase risk based on number of matches, but cap it
|
|
||||||
match_multiplier = min(1.0 + (len(matches) - 1) * 0.1, 1.5) # Reduced multiplier
|
|
||||||
|
|
||||||
return base_risk * match_multiplier
|
|
||||||
|
|
||||||
def _check_message_characteristics(self, content: str, role: str, is_system_request: bool) -> float:
|
|
||||||
"""Check message characteristics for additional risk factors with context awareness"""
|
|
||||||
risk = 0.0
|
|
||||||
|
|
||||||
# Excessive length (potential stuffing attack) - less restrictive for system content
|
|
||||||
length_threshold = 50000 if is_system_request else 10000 # Much higher threshold for system content
|
|
||||||
if len(content) > length_threshold:
|
|
||||||
risk += 0.1 if is_system_request else 0.3
|
|
||||||
|
|
||||||
# High ratio of special characters - more lenient for system content
|
|
||||||
special_chars = sum(1 for c in content if not c.isalnum() and not c.isspace())
|
|
||||||
if len(content) > 0:
|
|
||||||
char_ratio = special_chars / len(content)
|
|
||||||
threshold = 0.8 if is_system_request else 0.5
|
|
||||||
if char_ratio > threshold:
|
|
||||||
risk += 0.2 if is_system_request else 0.4
|
|
||||||
|
|
||||||
# Multiple encoding indicators - reduced risk for system content
|
|
||||||
encoding_indicators = ["base64", "hex", "unicode", "url", "ascii"]
|
|
||||||
found_encodings = sum(1 for indicator in encoding_indicators if indicator.lower() in content.lower())
|
|
||||||
if found_encodings > 1:
|
|
||||||
risk += 0.1 if is_system_request else 0.3
|
|
||||||
|
|
||||||
# Excessive newlines or formatting - more lenient for system content
|
|
||||||
newline_threshold = 200 if is_system_request else 50
|
|
||||||
if content.count('\n') > newline_threshold or content.count('\\n') > newline_threshold:
|
|
||||||
risk += 0.1 if is_system_request else 0.2
|
|
||||||
|
|
||||||
return risk
|
|
||||||
|
|
||||||
def _is_system_request(self, messages: List[Dict[str, str]]) -> bool:
|
|
||||||
"""Determine if this is a system/RAG request"""
|
|
||||||
if not messages:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for system messages
|
|
||||||
for message in messages:
|
|
||||||
if message.get("role") == "system":
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check message content for RAG indicators
|
|
||||||
for message in messages:
|
|
||||||
content = message.get("content", "")
|
|
||||||
if ("document:" in content.lower() or
|
|
||||||
"context:" in content.lower() or
|
|
||||||
"source:" in content.lower() or
|
|
||||||
"retrieved:" in content.lower() or
|
|
||||||
"citation:" in content.lower() or
|
|
||||||
"reference:" in content.lower()):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def create_audit_log(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
api_key_id: int,
|
|
||||||
provider: str,
|
|
||||||
model: str,
|
|
||||||
request_type: str,
|
|
||||||
risk_score: float,
|
|
||||||
detected_patterns: List[str],
|
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Create comprehensive audit log for LLM request"""
|
|
||||||
audit_entry = {
|
|
||||||
"timestamp": datetime.utcnow().isoformat(),
|
|
||||||
"user_id": user_id,
|
|
||||||
"api_key_id": api_key_id,
|
|
||||||
"provider": provider,
|
|
||||||
"model": model,
|
|
||||||
"request_type": request_type,
|
|
||||||
"security": {
|
|
||||||
"risk_score": risk_score,
|
|
||||||
"detected_patterns": detected_patterns,
|
|
||||||
"security_check_passed": risk_score < settings.API_SECURITY_RISK_THRESHOLD
|
|
||||||
},
|
|
||||||
"metadata": metadata or {},
|
|
||||||
"audit_hash": None # Will be set below
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create hash for audit integrity
|
|
||||||
audit_hash = self._create_audit_hash(audit_entry)
|
|
||||||
audit_entry["audit_hash"] = audit_hash
|
|
||||||
|
|
||||||
# Log based on risk level (never block, only log)
|
|
||||||
if risk_score >= settings.API_SECURITY_RISK_THRESHOLD:
|
|
||||||
logger.warning(f"HIGH RISK LLM REQUEST DETECTED (NOT BLOCKED): {json.dumps(audit_entry)}")
|
|
||||||
elif risk_score >= settings.API_SECURITY_WARNING_THRESHOLD:
|
|
||||||
logger.info(f"MEDIUM RISK LLM REQUEST: {json.dumps(audit_entry)}")
|
|
||||||
else:
|
|
||||||
logger.info(f"LLM REQUEST AUDIT: user={user_id}, model={model}, risk={risk_score:.3f}")
|
|
||||||
|
|
||||||
return audit_entry
|
|
||||||
|
|
||||||
def _create_audit_hash(self, audit_entry: Dict[str, Any]) -> str:
|
|
||||||
"""Create hash for audit trail integrity"""
|
|
||||||
# Create hash from key fields (excluding the hash itself)
|
|
||||||
hash_data = {
|
|
||||||
"timestamp": audit_entry["timestamp"],
|
|
||||||
"user_id": audit_entry["user_id"],
|
|
||||||
"api_key_id": audit_entry["api_key_id"],
|
|
||||||
"provider": audit_entry["provider"],
|
|
||||||
"model": audit_entry["model"],
|
|
||||||
"request_type": audit_entry["request_type"],
|
|
||||||
"risk_score": audit_entry["security"]["risk_score"]
|
|
||||||
}
|
|
||||||
|
|
||||||
hash_string = json.dumps(hash_data, sort_keys=True)
|
|
||||||
return hashlib.sha256(hash_string.encode()).hexdigest()
|
|
||||||
|
|
||||||
def log_detailed_request(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, str]],
|
|
||||||
model: str,
|
|
||||||
user_id: str,
|
|
||||||
provider: str,
|
|
||||||
context_info: Optional[Dict[str, Any]] = None
|
|
||||||
):
|
|
||||||
"""Log detailed LLM request if LOG_LLM_PROMPTS is enabled"""
|
|
||||||
if not settings.LOG_LLM_PROMPTS:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("=== DETAILED LLM REQUEST ===")
|
|
||||||
logger.info(f"Model: {model}")
|
|
||||||
logger.info(f"Provider: {provider}")
|
|
||||||
logger.info(f"User ID: {user_id}")
|
|
||||||
|
|
||||||
if context_info:
|
|
||||||
for key, value in context_info.items():
|
|
||||||
logger.info(f"{key}: {value}")
|
|
||||||
|
|
||||||
logger.info("Messages to LLM:")
|
|
||||||
for i, message in enumerate(messages):
|
|
||||||
role = message.get("role", "unknown")
|
|
||||||
content = message.get("content", "")[:500] # Truncate for logging
|
|
||||||
logger.info(f" Message {i+1} [{role}]: {content}{'...' if len(message.get('content', '')) > 500 else ''}")
|
|
||||||
|
|
||||||
logger.info("=== END DETAILED LLM REQUEST ===")
|
|
||||||
|
|
||||||
def log_detailed_response(
|
|
||||||
self,
|
|
||||||
response_content: str,
|
|
||||||
token_usage: Optional[Dict[str, int]] = None,
|
|
||||||
provider: str = "unknown"
|
|
||||||
):
|
|
||||||
"""Log detailed LLM response if LOG_LLM_PROMPTS is enabled"""
|
|
||||||
if not settings.LOG_LLM_PROMPTS:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("=== DETAILED LLM RESPONSE ===")
|
|
||||||
logger.info(f"Provider: {provider}")
|
|
||||||
logger.info(f"Response content: {response_content[:500]}{'...' if len(response_content) > 500 else ''}")
|
|
||||||
|
|
||||||
if token_usage:
|
|
||||||
logger.info(f"Token usage - Prompt: {token_usage.get('prompt_tokens', 0)}, "
|
|
||||||
f"Completion: {token_usage.get('completion_tokens', 0)}, "
|
|
||||||
f"Total: {token_usage.get('total_tokens', 0)}")
|
|
||||||
|
|
||||||
logger.info("=== END DETAILED LLM RESPONSE ===")
|
|
||||||
|
|
||||||
|
|
||||||
class SecurityError(Exception):
|
|
||||||
"""Security-related errors in LLM operations"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# Global security manager instance
|
|
||||||
security_manager = SecurityManager()
|
|
||||||
@@ -17,9 +17,8 @@ from .models import (
|
|||||||
)
|
)
|
||||||
from .config import config_manager, ProviderConfig
|
from .config import config_manager, ProviderConfig
|
||||||
from ...core.config import settings
|
from ...core.config import settings
|
||||||
from .security import security_manager
|
|
||||||
from .resilience import ResilienceManagerFactory
|
from .resilience import ResilienceManagerFactory
|
||||||
from .metrics import metrics_collector
|
# from .metrics import metrics_collector
|
||||||
from .providers import BaseLLMProvider, PrivateModeProvider
|
from .providers import BaseLLMProvider, PrivateModeProvider
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
LLMError, ProviderError, SecurityError, ConfigurationError,
|
LLMError, ProviderError, SecurityError, ConfigurationError,
|
||||||
@@ -150,45 +149,8 @@ class LLMService:
|
|||||||
if not request.messages:
|
if not request.messages:
|
||||||
raise ValidationError("Messages cannot be empty", field="messages")
|
raise ValidationError("Messages cannot be empty", field="messages")
|
||||||
|
|
||||||
# Security validation (only if enabled)
|
# Security validation disabled - always allow requests
|
||||||
messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]
|
risk_score = 0.0
|
||||||
|
|
||||||
if settings.API_SECURITY_ENABLED:
|
|
||||||
is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
|
|
||||||
else:
|
|
||||||
# Security disabled - always safe
|
|
||||||
is_safe, risk_score, detected_patterns = True, 0.0, []
|
|
||||||
|
|
||||||
if not is_safe:
|
|
||||||
# Log security violation
|
|
||||||
security_manager.create_audit_log(
|
|
||||||
user_id=request.user_id,
|
|
||||||
api_key_id=request.api_key_id,
|
|
||||||
provider="blocked",
|
|
||||||
model=request.model,
|
|
||||||
request_type="chat_completion",
|
|
||||||
risk_score=risk_score,
|
|
||||||
detected_patterns=[p.get("pattern", "") for p in detected_patterns]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Record blocked request
|
|
||||||
metrics_collector.record_request(
|
|
||||||
provider="security",
|
|
||||||
model=request.model,
|
|
||||||
request_type="chat_completion",
|
|
||||||
success=False,
|
|
||||||
latency_ms=0,
|
|
||||||
security_risk_score=risk_score,
|
|
||||||
error_code="SECURITY_BLOCKED",
|
|
||||||
user_id=request.user_id,
|
|
||||||
api_key_id=request.api_key_id
|
|
||||||
)
|
|
||||||
|
|
||||||
raise SecurityError(
|
|
||||||
"Request blocked due to security concerns",
|
|
||||||
risk_score=risk_score,
|
|
||||||
details={"detected_patterns": detected_patterns}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get provider for model
|
# Get provider for model
|
||||||
provider_name = self._get_provider_for_model(request.model)
|
provider_name = self._get_provider_for_model(request.model)
|
||||||
@@ -197,18 +159,7 @@ class LLMService:
|
|||||||
if not provider:
|
if not provider:
|
||||||
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
|
raise ProviderError(f"No available provider for model '{request.model}'", provider=provider_name)
|
||||||
|
|
||||||
# Log detailed request if enabled
|
# Security logging disabled
|
||||||
security_manager.log_detailed_request(
|
|
||||||
messages=messages_dict,
|
|
||||||
model=request.model,
|
|
||||||
user_id=request.user_id,
|
|
||||||
provider=provider_name,
|
|
||||||
context_info={
|
|
||||||
"temperature": request.temperature,
|
|
||||||
"max_tokens": request.max_tokens,
|
|
||||||
"risk_score": f"{risk_score:.3f}"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute with resilience
|
# Execute with resilience
|
||||||
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
|
resilience_manager = ResilienceManagerFactory.get_manager(provider_name)
|
||||||
@@ -222,85 +173,46 @@ class LLMService:
|
|||||||
non_retryable_exceptions=(SecurityError, ValidationError)
|
non_retryable_exceptions=(SecurityError, ValidationError)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update response with security information
|
# Security features disabled
|
||||||
response.security_check = is_safe
|
|
||||||
response.risk_score = risk_score
|
|
||||||
response.detected_patterns = [p.get("pattern", "") for p in detected_patterns]
|
|
||||||
|
|
||||||
# Log detailed response if enabled
|
# Security logging disabled
|
||||||
if response.choices:
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
security_manager.log_detailed_response(
|
|
||||||
response_content=content,
|
|
||||||
token_usage=response.usage.model_dump() if response.usage else None,
|
|
||||||
provider=provider_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# Record successful request
|
# Record successful request - metrics disabled
|
||||||
total_latency = (time.time() - start_time) * 1000
|
total_latency = (time.time() - start_time) * 1000
|
||||||
metrics_collector.record_request(
|
# metrics_collector.record_request(
|
||||||
provider=provider_name,
|
# provider=provider_name,
|
||||||
model=request.model,
|
# model=request.model,
|
||||||
request_type="chat_completion",
|
# request_type="chat_completion",
|
||||||
success=True,
|
# success=True,
|
||||||
latency_ms=total_latency,
|
# latency_ms=total_latency,
|
||||||
token_usage=response.usage.model_dump() if response.usage else None,
|
# token_usage=response.usage.model_dump() if response.usage else None,
|
||||||
security_risk_score=risk_score,
|
# security_risk_score=risk_score,
|
||||||
user_id=request.user_id,
|
# user_id=request.user_id,
|
||||||
api_key_id=request.api_key_id
|
# api_key_id=request.api_key_id
|
||||||
)
|
# )
|
||||||
|
|
||||||
# Create audit log
|
# Security audit logging disabled
|
||||||
security_manager.create_audit_log(
|
|
||||||
user_id=request.user_id,
|
|
||||||
api_key_id=request.api_key_id,
|
|
||||||
provider=provider_name,
|
|
||||||
model=request.model,
|
|
||||||
request_type="chat_completion",
|
|
||||||
risk_score=risk_score,
|
|
||||||
detected_patterns=[p.get("pattern", "") for p in detected_patterns],
|
|
||||||
metadata={
|
|
||||||
"success": True,
|
|
||||||
"latency_ms": total_latency,
|
|
||||||
"token_usage": response.usage.model_dump() if response.usage else None
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Record failed request
|
# Record failed request - metrics disabled
|
||||||
total_latency = (time.time() - start_time) * 1000
|
total_latency = (time.time() - start_time) * 1000
|
||||||
error_code = getattr(e, 'error_code', e.__class__.__name__)
|
error_code = getattr(e, 'error_code', e.__class__.__name__)
|
||||||
|
|
||||||
metrics_collector.record_request(
|
# metrics_collector.record_request(
|
||||||
provider=provider_name,
|
# provider=provider_name,
|
||||||
model=request.model,
|
# model=request.model,
|
||||||
request_type="chat_completion",
|
# request_type="chat_completion",
|
||||||
success=False,
|
# success=False,
|
||||||
latency_ms=total_latency,
|
# latency_ms=total_latency,
|
||||||
security_risk_score=risk_score,
|
# security_risk_score=risk_score,
|
||||||
error_code=error_code,
|
# error_code=error_code,
|
||||||
user_id=request.user_id,
|
# user_id=request.user_id,
|
||||||
api_key_id=request.api_key_id
|
# api_key_id=request.api_key_id
|
||||||
)
|
# )
|
||||||
|
|
||||||
# Create audit log for failure
|
# Security audit logging disabled
|
||||||
security_manager.create_audit_log(
|
|
||||||
user_id=request.user_id,
|
|
||||||
api_key_id=request.api_key_id,
|
|
||||||
provider=provider_name,
|
|
||||||
model=request.model,
|
|
||||||
request_type="chat_completion",
|
|
||||||
risk_score=risk_score,
|
|
||||||
detected_patterns=[p.get("pattern", "") for p in detected_patterns],
|
|
||||||
metadata={
|
|
||||||
"success": False,
|
|
||||||
"error": str(e),
|
|
||||||
"error_code": error_code,
|
|
||||||
"latency_ms": total_latency
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -309,21 +221,8 @@ class LLMService:
|
|||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
await self.initialize()
|
await self.initialize()
|
||||||
|
|
||||||
# Security validation (same as non-streaming)
|
# Security validation disabled - always allow streaming requests
|
||||||
messages_dict = [{"role": msg.role, "content": msg.content} for msg in request.messages]
|
risk_score = 0.0
|
||||||
|
|
||||||
if settings.API_SECURITY_ENABLED:
|
|
||||||
is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security(messages_dict)
|
|
||||||
else:
|
|
||||||
# Security disabled - always safe
|
|
||||||
is_safe, risk_score, detected_patterns = True, 0.0, []
|
|
||||||
|
|
||||||
if not is_safe:
|
|
||||||
raise SecurityError(
|
|
||||||
"Streaming request blocked due to security concerns",
|
|
||||||
risk_score=risk_score,
|
|
||||||
details={"detected_patterns": detected_patterns}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get provider
|
# Get provider
|
||||||
provider_name = self._get_provider_for_model(request.model)
|
provider_name = self._get_provider_for_model(request.model)
|
||||||
@@ -345,19 +244,19 @@ class LLMService:
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Record streaming failure
|
# Record streaming failure - metrics disabled
|
||||||
error_code = getattr(e, 'error_code', e.__class__.__name__)
|
error_code = getattr(e, 'error_code', e.__class__.__name__)
|
||||||
metrics_collector.record_request(
|
# metrics_collector.record_request(
|
||||||
provider=provider_name,
|
# provider=provider_name,
|
||||||
model=request.model,
|
# model=request.model,
|
||||||
request_type="chat_completion_stream",
|
# request_type="chat_completion_stream",
|
||||||
success=False,
|
# success=False,
|
||||||
latency_ms=0,
|
# latency_ms=0,
|
||||||
security_risk_score=risk_score,
|
# security_risk_score=risk_score,
|
||||||
error_code=error_code,
|
# error_code=error_code,
|
||||||
user_id=request.user_id,
|
# user_id=request.user_id,
|
||||||
api_key_id=request.api_key_id
|
# api_key_id=request.api_key_id
|
||||||
)
|
# )
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
|
async def create_embedding(self, request: EmbeddingRequest) -> EmbeddingResponse:
|
||||||
@@ -365,23 +264,8 @@ class LLMService:
|
|||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
await self.initialize()
|
await self.initialize()
|
||||||
|
|
||||||
# Security validation for embedding input
|
# Security validation disabled - always allow embedding requests
|
||||||
input_text = request.input if isinstance(request.input, str) else " ".join(request.input)
|
risk_score = 0.0
|
||||||
|
|
||||||
if settings.API_SECURITY_ENABLED:
|
|
||||||
is_safe, risk_score, detected_patterns = security_manager.validate_prompt_security([
|
|
||||||
{"role": "user", "content": input_text}
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
# Security disabled - always safe
|
|
||||||
is_safe, risk_score, detected_patterns = True, 0.0, []
|
|
||||||
|
|
||||||
if not is_safe:
|
|
||||||
raise SecurityError(
|
|
||||||
"Embedding request blocked due to security concerns",
|
|
||||||
risk_score=risk_score,
|
|
||||||
details={"detected_patterns": detected_patterns}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get provider
|
# Get provider
|
||||||
provider_name = self._get_provider_for_model(request.model)
|
provider_name = self._get_provider_for_model(request.model)
|
||||||
@@ -402,42 +286,40 @@ class LLMService:
|
|||||||
non_retryable_exceptions=(SecurityError, ValidationError)
|
non_retryable_exceptions=(SecurityError, ValidationError)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update response with security information
|
# Security features disabled
|
||||||
response.security_check = is_safe
|
|
||||||
response.risk_score = risk_score
|
|
||||||
|
|
||||||
# Record successful request
|
# Record successful request - metrics disabled
|
||||||
total_latency = (time.time() - start_time) * 1000
|
total_latency = (time.time() - start_time) * 1000
|
||||||
metrics_collector.record_request(
|
# metrics_collector.record_request(
|
||||||
provider=provider_name,
|
# provider=provider_name,
|
||||||
model=request.model,
|
# model=request.model,
|
||||||
request_type="embedding",
|
# request_type="embedding",
|
||||||
success=True,
|
# success=True,
|
||||||
latency_ms=total_latency,
|
# latency_ms=total_latency,
|
||||||
token_usage=response.usage.model_dump() if response.usage else None,
|
# token_usage=response.usage.model_dump() if response.usage else None,
|
||||||
security_risk_score=risk_score,
|
# security_risk_score=risk_score,
|
||||||
user_id=request.user_id,
|
# user_id=request.user_id,
|
||||||
api_key_id=request.api_key_id
|
# api_key_id=request.api_key_id
|
||||||
)
|
# )
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Record failed request
|
# Record failed request - metrics disabled
|
||||||
total_latency = (time.time() - start_time) * 1000
|
total_latency = (time.time() - start_time) * 1000
|
||||||
error_code = getattr(e, 'error_code', e.__class__.__name__)
|
error_code = getattr(e, 'error_code', e.__class__.__name__)
|
||||||
|
|
||||||
metrics_collector.record_request(
|
# metrics_collector.record_request(
|
||||||
provider=provider_name,
|
# provider=provider_name,
|
||||||
model=request.model,
|
# model=request.model,
|
||||||
request_type="embedding",
|
# request_type="embedding",
|
||||||
success=False,
|
# success=False,
|
||||||
latency_ms=total_latency,
|
# latency_ms=total_latency,
|
||||||
security_risk_score=risk_score,
|
# security_risk_score=risk_score,
|
||||||
error_code=error_code,
|
# error_code=error_code,
|
||||||
user_id=request.user_id,
|
# user_id=request.user_id,
|
||||||
api_key_id=request.api_key_id
|
# api_key_id=request.api_key_id
|
||||||
)
|
# )
|
||||||
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -492,12 +374,18 @@ class LLMService:
|
|||||||
return status_dict
|
return status_dict
|
||||||
|
|
||||||
def get_metrics(self) -> LLMMetrics:
|
def get_metrics(self) -> LLMMetrics:
|
||||||
"""Get service metrics"""
|
"""Get service metrics - metrics disabled"""
|
||||||
return metrics_collector.get_metrics()
|
# return metrics_collector.get_metrics()
|
||||||
|
return LLMMetrics(
|
||||||
|
total_requests=0,
|
||||||
|
success_rate=0.0,
|
||||||
|
avg_latency_ms=0,
|
||||||
|
error_rates={}
|
||||||
|
)
|
||||||
|
|
||||||
def get_health_summary(self) -> Dict[str, Any]:
|
def get_health_summary(self) -> Dict[str, Any]:
|
||||||
"""Get comprehensive health summary"""
|
"""Get comprehensive health summary - metrics disabled"""
|
||||||
metrics_health = metrics_collector.get_health_summary()
|
# metrics_health = metrics_collector.get_health_summary()
|
||||||
resilience_health = ResilienceManagerFactory.get_all_health_status()
|
resilience_health = ResilienceManagerFactory.get_all_health_status()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -505,7 +393,7 @@ class LLMService:
|
|||||||
"startup_time": self._startup_time.isoformat() if self._startup_time else None,
|
"startup_time": self._startup_time.isoformat() if self._startup_time else None,
|
||||||
"provider_count": len(self._providers),
|
"provider_count": len(self._providers),
|
||||||
"active_providers": list(self._providers.keys()),
|
"active_providers": list(self._providers.keys()),
|
||||||
"metrics": metrics_health,
|
"metrics": {"status": "disabled"},
|
||||||
"resilience": resilience_health
|
"resilience": resilience_health
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,153 +0,0 @@
|
|||||||
"""
|
|
||||||
Token-based rate limiting for LLM service
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
import redis
|
|
||||||
from typing import Dict, Optional, Tuple
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from ..core.config import settings
|
|
||||||
from ..core.logging import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenRateLimiter:
|
|
||||||
"""Token-based rate limiting implementation"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
try:
|
|
||||||
self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
|
||||||
self.redis_client.ping()
|
|
||||||
logger.info("Token rate limiter initialized with Redis backend")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Redis not available for token rate limiting: {e}")
|
|
||||||
self.redis_client = None
|
|
||||||
# Fall back to in-memory rate limiting
|
|
||||||
self.in_memory_store = {}
|
|
||||||
logger.info("Token rate limiter using in-memory fallback")
|
|
||||||
|
|
||||||
async def check_token_limits(
|
|
||||||
self,
|
|
||||||
provider: str,
|
|
||||||
prompt_tokens: int,
|
|
||||||
completion_tokens: int = 0
|
|
||||||
) -> Tuple[bool, Dict[str, str]]:
|
|
||||||
"""
|
|
||||||
Check if token usage is within limits
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider: Provider name (e.g., "privatemode")
|
|
||||||
prompt_tokens: Number of prompt tokens to use
|
|
||||||
completion_tokens: Number of completion tokens to use
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_allowed, headers)
|
|
||||||
"""
|
|
||||||
# Get token limits from configuration
|
|
||||||
from .config import get_config
|
|
||||||
config = get_config()
|
|
||||||
token_limits = config.token_limits_per_minute
|
|
||||||
|
|
||||||
# Check organization-wide limits
|
|
||||||
org_key = f"tokens:org:{provider}"
|
|
||||||
|
|
||||||
# Get current usage
|
|
||||||
current_usage = await self._get_token_usage(org_key)
|
|
||||||
|
|
||||||
# Calculate new usage
|
|
||||||
new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens
|
|
||||||
new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens
|
|
||||||
|
|
||||||
# Check limits
|
|
||||||
prompt_limit = token_limits.get("prompt_tokens", 20000)
|
|
||||||
completion_limit = token_limits.get("completion_tokens", 10000)
|
|
||||||
|
|
||||||
is_allowed = (
|
|
||||||
new_prompt_tokens <= prompt_limit and
|
|
||||||
new_completion_tokens <= completion_limit
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_allowed:
|
|
||||||
# Update usage
|
|
||||||
await self._update_token_usage(org_key, prompt_tokens, completion_tokens)
|
|
||||||
logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, "
|
|
||||||
f"{new_completion_tokens}/{completion_limit} completion")
|
|
||||||
|
|
||||||
# Calculate remaining tokens
|
|
||||||
remaining_prompt = max(0, prompt_limit - new_prompt_tokens)
|
|
||||||
remaining_completion = max(0, completion_limit - new_completion_tokens)
|
|
||||||
|
|
||||||
# Create headers
|
|
||||||
headers = {
|
|
||||||
"X-TokenLimit-Prompt-Remaining": str(remaining_prompt),
|
|
||||||
"X-TokenLimit-Completion-Remaining": str(remaining_completion),
|
|
||||||
"X-TokenLimit-Prompt-Limit": str(prompt_limit),
|
|
||||||
"X-TokenLimit-Completion-Limit": str(completion_limit),
|
|
||||||
"X-TokenLimit-Reset": str(int(time.time() + 60)) # Reset in 1 minute
|
|
||||||
}
|
|
||||||
|
|
||||||
if not is_allowed:
|
|
||||||
logger.warning(f"Token rate limit exceeded for {provider}. "
|
|
||||||
f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. "
|
|
||||||
f"Current: {current_usage}")
|
|
||||||
|
|
||||||
return is_allowed, headers
|
|
||||||
|
|
||||||
async def _get_token_usage(self, key: str) -> Dict[str, int]:
|
|
||||||
"""Get current token usage"""
|
|
||||||
if self.redis_client:
|
|
||||||
try:
|
|
||||||
data = self.redis_client.hgetall(key)
|
|
||||||
if data:
|
|
||||||
return {
|
|
||||||
"prompt_tokens": int(data.get("prompt_tokens", 0)),
|
|
||||||
"completion_tokens": int(data.get("completion_tokens", 0)),
|
|
||||||
"updated_at": float(data.get("updated_at", time.time()))
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting token usage from Redis: {e}")
|
|
||||||
|
|
||||||
# Fallback to in-memory
|
|
||||||
return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0})
|
|
||||||
|
|
||||||
async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int):
|
|
||||||
"""Update token usage"""
|
|
||||||
if self.redis_client:
|
|
||||||
try:
|
|
||||||
pipe = self.redis_client.pipeline()
|
|
||||||
pipe.hincrby(key, "prompt_tokens", prompt_tokens)
|
|
||||||
pipe.hincrby(key, "completion_tokens", completion_tokens)
|
|
||||||
pipe.hset(key, "updated_at", time.time())
|
|
||||||
pipe.expire(key, 60) # Expire after 1 minute
|
|
||||||
pipe.execute()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating token usage in Redis: {e}")
|
|
||||||
# Fallback to in-memory
|
|
||||||
self._update_in_memory(key, prompt_tokens, completion_tokens)
|
|
||||||
else:
|
|
||||||
self._update_in_memory(key, prompt_tokens, completion_tokens)
|
|
||||||
|
|
||||||
def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int):
|
|
||||||
"""Update in-memory token usage"""
|
|
||||||
if key not in self.in_memory_store:
|
|
||||||
self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0}
|
|
||||||
|
|
||||||
self.in_memory_store[key]["prompt_tokens"] += prompt_tokens
|
|
||||||
self.in_memory_store[key]["completion_tokens"] += completion_tokens
|
|
||||||
self.in_memory_store[key]["updated_at"] = time.time()
|
|
||||||
|
|
||||||
def cleanup_expired(self):
|
|
||||||
"""Clean up expired entries (for in-memory store)"""
|
|
||||||
if not self.redis_client:
|
|
||||||
current_time = time.time()
|
|
||||||
expired_keys = [
|
|
||||||
key for key, data in self.in_memory_store.items()
|
|
||||||
if current_time - data.get("updated_at", 0) > 60
|
|
||||||
]
|
|
||||||
for key in expired_keys:
|
|
||||||
del self.in_memory_store[key]
|
|
||||||
|
|
||||||
|
|
||||||
# Global token rate limiter instance
|
|
||||||
token_rate_limiter = TokenRateLimiter()
|
|
||||||
@@ -755,10 +755,11 @@ class RAGService:
|
|||||||
|
|
||||||
# Process with RAG module
|
# Process with RAG module
|
||||||
try:
|
try:
|
||||||
|
# Pass file_path in metadata so JSONL indexing can reopen the source file
|
||||||
processed_doc = await rag_module.process_document(
|
processed_doc = await rag_module.process_document(
|
||||||
file_content,
|
file_content,
|
||||||
document.original_filename,
|
document.original_filename,
|
||||||
{}
|
{"file_path": document.file_path}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Success case - update document with processed content
|
# Success case - update document with processed content
|
||||||
|
|||||||
@@ -638,11 +638,19 @@ class RAGModule(BaseModule):
|
|||||||
np.random.seed(hash(text) % 2**32)
|
np.random.seed(hash(text) % 2**32)
|
||||||
return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
|
return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
|
||||||
|
|
||||||
async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
|
async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
|
||||||
"""Generate embeddings for multiple texts (batch processing)"""
|
"""Generate embeddings for multiple texts (batch processing)"""
|
||||||
if self.embedding_service:
|
if self.embedding_service:
|
||||||
|
# Add task-specific prefixes for better E5 model performance
|
||||||
|
if is_document:
|
||||||
|
# For document passages, use "passage:" prefix
|
||||||
|
prefixed_texts = [f"passage: {text}" for text in texts]
|
||||||
|
else:
|
||||||
|
# For queries, use "query:" prefix (handled in search method)
|
||||||
|
prefixed_texts = texts
|
||||||
|
|
||||||
# Use real embedding service for batch processing
|
# Use real embedding service for batch processing
|
||||||
return await self.embedding_service.get_embeddings(texts)
|
return await self.embedding_service.get_embeddings(prefixed_texts)
|
||||||
else:
|
else:
|
||||||
# Fallback to individual processing
|
# Fallback to individual processing
|
||||||
embeddings = []
|
embeddings = []
|
||||||
@@ -922,12 +930,18 @@ class RAGModule(BaseModule):
|
|||||||
- Each line contains a JSON object with 'id' and 'payload'
|
- Each line contains a JSON object with 'id' and 'payload'
|
||||||
- Payload contains 'question', 'language', and 'answer' fields
|
- Payload contains 'question', 'language', and 'answer' fields
|
||||||
- Combines question and answer into searchable content
|
- Combines question and answer into searchable content
|
||||||
|
|
||||||
|
Performance optimizations:
|
||||||
|
- Processes articles in smaller batches to reduce memory usage
|
||||||
|
- Uses streaming approach for large files
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Use streaming approach for large files
|
||||||
jsonl_content = content.decode('utf-8', errors='replace')
|
jsonl_content = content.decode('utf-8', errors='replace')
|
||||||
lines = jsonl_content.strip().split('\n')
|
lines = jsonl_content.strip().split('\n')
|
||||||
|
|
||||||
processed_articles = []
|
processed_articles = []
|
||||||
|
batch_size = 50 # Process in batches of 50 articles
|
||||||
|
|
||||||
for line_num, line in enumerate(lines, 1):
|
for line_num, line in enumerate(lines, 1):
|
||||||
if not line.strip():
|
if not line.strip():
|
||||||
@@ -1153,7 +1167,7 @@ class RAGModule(BaseModule):
|
|||||||
chunks = self._chunk_text(content)
|
chunks = self._chunk_text(content)
|
||||||
|
|
||||||
# Generate embeddings for all chunks in batch (more efficient)
|
# Generate embeddings for all chunks in batch (more efficient)
|
||||||
embeddings = await self._generate_embeddings(chunks)
|
embeddings = await self._generate_embeddings(chunks, is_document=True)
|
||||||
|
|
||||||
# Create document points
|
# Create document points
|
||||||
points = []
|
points = []
|
||||||
@@ -1204,6 +1218,24 @@ class RAGModule(BaseModule):
|
|||||||
collection_name = collection_name or self.default_collection_name
|
collection_name = collection_name or self.default_collection_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Special handling for JSONL files
|
||||||
|
if processed_doc.file_type == 'jsonl':
|
||||||
|
# Import the optimized JSONL processor
|
||||||
|
from app.services.jsonl_processor import JSONLProcessor
|
||||||
|
jsonl_processor = JSONLProcessor(self)
|
||||||
|
|
||||||
|
# Read the original file content
|
||||||
|
with open(processed_doc.metadata.get('file_path', ''), 'rb') as f:
|
||||||
|
file_content = f.read()
|
||||||
|
|
||||||
|
# Process using the optimized JSONL processor
|
||||||
|
return await jsonl_processor.process_and_index_jsonl(
|
||||||
|
collection_name=collection_name,
|
||||||
|
content=file_content,
|
||||||
|
filename=processed_doc.original_filename,
|
||||||
|
metadata=processed_doc.metadata
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure collection exists
|
# Ensure collection exists
|
||||||
await self._ensure_collection_exists(collection_name)
|
await self._ensure_collection_exists(collection_name)
|
||||||
|
|
||||||
@@ -1216,7 +1248,7 @@ class RAGModule(BaseModule):
|
|||||||
chunks = self._chunk_text(processed_doc.content)
|
chunks = self._chunk_text(processed_doc.content)
|
||||||
|
|
||||||
# Generate embeddings for all chunks in batch (more efficient)
|
# Generate embeddings for all chunks in batch (more efficient)
|
||||||
embeddings = await self._generate_embeddings(chunks)
|
embeddings = await self._generate_embeddings(chunks, is_document=True)
|
||||||
|
|
||||||
# Create document points with enhanced metadata
|
# Create document points with enhanced metadata
|
||||||
points = []
|
points = []
|
||||||
@@ -1339,24 +1371,48 @@ class RAGModule(BaseModule):
|
|||||||
score_threshold=score_threshold / 2 # Lower threshold for initial search
|
score_threshold=score_threshold / 2 # Lower threshold for initial search
|
||||||
)
|
)
|
||||||
|
|
||||||
# Combine scores
|
# Combine scores with improved normalization
|
||||||
hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
|
hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
|
||||||
vector_weight = hybrid_weights.get("vector", 0.7)
|
vector_weight = hybrid_weights.get("vector", 0.7)
|
||||||
bm25_weight = hybrid_weights.get("bm25", 0.3)
|
bm25_weight = hybrid_weights.get("bm25", 0.3)
|
||||||
|
|
||||||
# Create hybrid results
|
# Get score distributions for better normalization
|
||||||
|
vector_scores = [r.score for r in vector_results]
|
||||||
|
bm25_scores_list = list(bm25_scores.values())
|
||||||
|
|
||||||
|
# Calculate statistics for normalization
|
||||||
|
if vector_scores:
|
||||||
|
v_max = max(vector_scores)
|
||||||
|
v_min = min(vector_scores)
|
||||||
|
v_range = v_max - v_min if v_max != v_min else 1
|
||||||
|
else:
|
||||||
|
v_max, v_min, v_range = 1, 0, 1
|
||||||
|
|
||||||
|
if bm25_scores_list:
|
||||||
|
bm25_max = max(bm25_scores_list)
|
||||||
|
bm25_min = min(bm25_scores_list)
|
||||||
|
bm25_range = bm25_max - bm25_min if bm25_max != bm25_min else 1
|
||||||
|
else:
|
||||||
|
bm25_max, bm25_min, bm25_range = 1, 0, 1
|
||||||
|
|
||||||
|
# Create hybrid results with improved scoring
|
||||||
hybrid_results = []
|
hybrid_results = []
|
||||||
for result in vector_results:
|
for result in vector_results:
|
||||||
doc_id = result.payload.get("document_id", "")
|
doc_id = result.payload.get("document_id", "")
|
||||||
vector_score = result.score
|
vector_score = result.score
|
||||||
bm25_score = bm25_scores.get(doc_id, 0.0)
|
bm25_score = bm25_scores.get(doc_id, 0.0)
|
||||||
|
|
||||||
# Normalize scores (simple min-max normalization)
|
# Improved normalization using actual score distributions
|
||||||
vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0
|
vector_norm = (vector_score - v_min) / v_range if v_range > 0 else 0.5
|
||||||
bm25_norm = min(bm25_score, 1.0) # BM25 scores are typically 0-1
|
bm25_norm = (bm25_score - bm25_min) / bm25_range if bm25_range > 0 else 0.5
|
||||||
|
|
||||||
# Calculate hybrid score
|
# Apply reciprocal rank fusion for better combination
|
||||||
hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm)
|
# This gives more weight to documents that rank highly in both methods
|
||||||
|
rrf_vector = 1.0 / (1.0 + vector_results.index(result) + 1) # +1 to avoid division by zero
|
||||||
|
rrf_bm25 = 1.0 / (1.0 + sorted(bm25_scores_list, reverse=True).index(bm25_score) + 1) if bm25_score in bm25_scores_list else 0
|
||||||
|
|
||||||
|
# Calculate hybrid score using both normalized scores and RRF
|
||||||
|
hybrid_score = (vector_weight * vector_norm + bm25_weight * bm25_norm) * 0.7 + (rrf_vector + rrf_bm25) * 0.3
|
||||||
|
|
||||||
# Create new point with hybrid score
|
# Create new point with hybrid score
|
||||||
hybrid_point = ScoredPoint(
|
hybrid_point = ScoredPoint(
|
||||||
@@ -1435,7 +1491,7 @@ class RAGModule(BaseModule):
|
|||||||
# Normalize score to 0-1 range
|
# Normalize score to 0-1 range
|
||||||
return min(score / 10.0, 1.0) # Simple normalization
|
return min(score / 10.0, 1.0) # Simple normalization
|
||||||
|
|
||||||
async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
|
async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
|
||||||
"""Search for relevant documents"""
|
"""Search for relevant documents"""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
raise RuntimeError("RAG module not initialized")
|
raise RuntimeError("RAG module not initialized")
|
||||||
@@ -1453,8 +1509,10 @@ class RAGModule(BaseModule):
|
|||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Generate query embedding
|
# Generate query embedding with task-specific prefix for better retrieval
|
||||||
query_embedding = await self._generate_embedding(query)
|
# The E5 model works better with "query:" prefix for search queries
|
||||||
|
optimized_query = f"query: {query}"
|
||||||
|
query_embedding = await self._generate_embedding(optimized_query)
|
||||||
|
|
||||||
# Build filter
|
# Build filter
|
||||||
search_filter = None
|
search_filter = None
|
||||||
@@ -1474,7 +1532,8 @@ class RAGModule(BaseModule):
|
|||||||
|
|
||||||
# Check if hybrid search is enabled
|
# Check if hybrid search is enabled
|
||||||
enable_hybrid = self.config.get("enable_hybrid", False)
|
enable_hybrid = self.config.get("enable_hybrid", False)
|
||||||
score_threshold = self.config.get("score_threshold", 0.3)
|
# Use provided score_threshold or fall back to config
|
||||||
|
search_score_threshold = score_threshold if score_threshold is not None else self.config.get("score_threshold", 0.3)
|
||||||
|
|
||||||
if enable_hybrid and NLTK_AVAILABLE:
|
if enable_hybrid and NLTK_AVAILABLE:
|
||||||
# Perform hybrid search (vector + BM25)
|
# Perform hybrid search (vector + BM25)
|
||||||
@@ -1484,7 +1543,7 @@ class RAGModule(BaseModule):
|
|||||||
query_vector=query_embedding,
|
query_vector=query_embedding,
|
||||||
query_filter=search_filter,
|
query_filter=search_filter,
|
||||||
limit=max_results,
|
limit=max_results,
|
||||||
score_threshold=score_threshold
|
score_threshold=search_score_threshold
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Pure vector search with improved threshold
|
# Pure vector search with improved threshold
|
||||||
@@ -1493,7 +1552,7 @@ class RAGModule(BaseModule):
|
|||||||
query_vector=query_embedding,
|
query_vector=query_embedding,
|
||||||
query_filter=search_filter,
|
query_filter=search_filter,
|
||||||
limit=max_results,
|
limit=max_results,
|
||||||
score_threshold=score_threshold
|
score_threshold=search_score_threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Raw search results count: {len(search_results)}")
|
logger.info(f"Raw search results count: {len(search_results)}")
|
||||||
@@ -1841,9 +1900,9 @@ async def index_processed_document(processed_doc: ProcessedDocument, collection_
|
|||||||
"""Index a processed document"""
|
"""Index a processed document"""
|
||||||
return await rag_module.index_processed_document(processed_doc, collection_name)
|
return await rag_module.index_processed_document(processed_doc, collection_name)
|
||||||
|
|
||||||
async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
|
async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
|
||||||
"""Search documents"""
|
"""Search documents"""
|
||||||
return await rag_module.search_documents(query, max_results, filters, collection_name)
|
return await rag_module.search_documents(query, max_results, filters, collection_name, score_threshold)
|
||||||
|
|
||||||
async def delete_document(document_id: str, collection_name: str = None) -> bool:
|
async def delete_document(document_id: str, collection_name: str = None) -> bool:
|
||||||
"""Delete a document"""
|
"""Delete a document"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ export async function POST(request: NextRequest) {
|
|||||||
|
|
||||||
// Make request to backend auth endpoint without requiring existing auth
|
// Make request to backend auth endpoint without requiring existing auth
|
||||||
const baseUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`
|
const baseUrl = process.env.INTERNAL_API_URL || `http://enclava-backend:${process.env.BACKEND_INTERNAL_PORT || '8000'}`
|
||||||
const url = `${baseUrl}/api/auth/login`
|
const url = `${baseUrl}/api-internal/v1/auth/login`
|
||||||
|
|
||||||
const response = await fetch(url, {
|
const response = await fetch(url, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
|
|||||||
@@ -85,8 +85,31 @@ function RAGPageContent() {
|
|||||||
const loadStats = async () => {
|
const loadStats = async () => {
|
||||||
try {
|
try {
|
||||||
const data = await apiClient.get('/api-internal/v1/rag/stats')
|
const data = await apiClient.get('/api-internal/v1/rag/stats')
|
||||||
|
console.log('Stats API response:', data)
|
||||||
|
|
||||||
|
// Check if the response has the expected structure
|
||||||
|
if (data && data.stats && data.stats.collections) {
|
||||||
|
console.log('✓ Stats has collections property')
|
||||||
setStats(data.stats)
|
setStats(data.stats)
|
||||||
|
} else {
|
||||||
|
console.error('✗ Invalid stats structure:', data)
|
||||||
|
// Set default empty stats to prevent error
|
||||||
|
setStats({
|
||||||
|
collections: { total: 0, active: 0 },
|
||||||
|
documents: { total: 0, processing: 0, processed: 0 },
|
||||||
|
storage: { total_size_bytes: 0, total_size_mb: 0 },
|
||||||
|
vectors: { total: 0 }
|
||||||
|
})
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
console.error('Error loading stats:', error)
|
||||||
|
// Set default empty stats on error
|
||||||
|
setStats({
|
||||||
|
collections: { total: 0, active: 0 },
|
||||||
|
documents: { total: 0, processing: 0, processed: 0 },
|
||||||
|
storage: { total_size_bytes: 0, total_size_mb: 0 },
|
||||||
|
vectors: { total: 0 }
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import { Badge } from "@/components/ui/badge"
|
|||||||
import { Separator } from "@/components/ui/separator"
|
import { Separator } from "@/components/ui/separator"
|
||||||
import { AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger } from "@/components/ui/alert-dialog"
|
import { AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger } from "@/components/ui/alert-dialog"
|
||||||
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog"
|
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger } from "@/components/ui/dialog"
|
||||||
import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter } from "lucide-react"
|
import { Search, FileText, Trash2, Eye, Download, Calendar, Hash, FileIcon, Filter, RefreshCw } from "lucide-react"
|
||||||
import { useToast } from "@/hooks/use-toast"
|
import { useToast } from "@/hooks/use-toast"
|
||||||
import { apiClient } from "@/lib/api-client"
|
import { apiClient } from "@/lib/api-client"
|
||||||
import { config } from "@/lib/config"
|
import { config } from "@/lib/config"
|
||||||
@@ -56,6 +56,7 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
|
|||||||
const [filterStatus, setFilterStatus] = useState("all")
|
const [filterStatus, setFilterStatus] = useState("all")
|
||||||
const [selectedDocument, setSelectedDocument] = useState<Document | null>(null)
|
const [selectedDocument, setSelectedDocument] = useState<Document | null>(null)
|
||||||
const [deleting, setDeleting] = useState<string | null>(null)
|
const [deleting, setDeleting] = useState<string | null>(null)
|
||||||
|
const [reprocessing, setReprocessing] = useState<string | null>(null)
|
||||||
const { toast } = useToast()
|
const { toast } = useToast()
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -157,6 +158,43 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleReprocessDocument = async (documentId: string) => {
|
||||||
|
setReprocessing(documentId)
|
||||||
|
|
||||||
|
try {
|
||||||
|
await apiClient.post(`/api-internal/v1/rag/documents/${documentId}/reprocess`)
|
||||||
|
|
||||||
|
// Update the document status to processing in the UI
|
||||||
|
setDocuments(prev => prev.map(doc =>
|
||||||
|
doc.id === documentId
|
||||||
|
? { ...doc, status: 'processing' as const, processed_at: new Date().toISOString() }
|
||||||
|
: doc
|
||||||
|
))
|
||||||
|
|
||||||
|
toast({
|
||||||
|
title: "Success",
|
||||||
|
description: "Document reprocessing started",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Reload documents after a short delay to see status updates
|
||||||
|
setTimeout(() => {
|
||||||
|
loadDocuments()
|
||||||
|
}, 2000)
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
const errorMessage = error instanceof Error ? error.message : "Failed to reprocess document"
|
||||||
|
toast({
|
||||||
|
title: "Error",
|
||||||
|
description: errorMessage.includes("Cannot reprocess document with status 'processed'")
|
||||||
|
? "Cannot reprocess documents that are already processed"
|
||||||
|
: errorMessage,
|
||||||
|
variant: "destructive",
|
||||||
|
})
|
||||||
|
} finally {
|
||||||
|
setReprocessing(null)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const formatFileSize = (bytes: number) => {
|
const formatFileSize = (bytes: number) => {
|
||||||
if (bytes === 0) return '0 Bytes'
|
if (bytes === 0) return '0 Bytes'
|
||||||
const k = 1024
|
const k = 1024
|
||||||
@@ -432,6 +470,21 @@ export function DocumentBrowser({ collections, selectedCollection, onCollectionS
|
|||||||
<Download className="h-4 w-4" />
|
<Download className="h-4 w-4" />
|
||||||
</Button>
|
</Button>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
className="h-8 w-8 p-0 hover:bg-blue-100"
|
||||||
|
onClick={() => handleReprocessDocument(document.id)}
|
||||||
|
disabled={reprocessing === document.id || document.status === 'processed'}
|
||||||
|
title={document.status === 'processed' ? "Document already processed" : "Reprocess document"}
|
||||||
|
>
|
||||||
|
{reprocessing === document.id ? (
|
||||||
|
<RefreshCw className="h-4 w-4 animate-spin" />
|
||||||
|
) : (
|
||||||
|
<RefreshCw className={`h-4 w-4 ${document.status === 'processed' ? 'text-gray-400' : ''}`} />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
|
||||||
<AlertDialog>
|
<AlertDialog>
|
||||||
<AlertDialogTrigger asChild>
|
<AlertDialogTrigger asChild>
|
||||||
<Button
|
<Button
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ const Navigation = () => {
|
|||||||
children: [
|
children: [
|
||||||
{ href: "/llm", label: "Models & Config" },
|
{ href: "/llm", label: "Models & Config" },
|
||||||
{ href: "/playground", label: "Playground" },
|
{ href: "/playground", label: "Playground" },
|
||||||
|
{ href: "/rag-demo", label: "RAG Demo" },
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -25,6 +25,12 @@ http {
|
|||||||
listen 80;
|
listen 80;
|
||||||
server_name localhost;
|
server_name localhost;
|
||||||
|
|
||||||
|
# Static files - serve directly from nginx
|
||||||
|
location = /login_helper.html {
|
||||||
|
root /usr/share/nginx/html;
|
||||||
|
try_files $uri =404;
|
||||||
|
}
|
||||||
|
|
||||||
# Frontend routes
|
# Frontend routes
|
||||||
location / {
|
location / {
|
||||||
proxy_pass http://frontend;
|
proxy_pass http://frontend;
|
||||||
@@ -65,6 +71,58 @@ http {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# RAG debug API routes - proxy to frontend (for Next.js API routes)
|
||||||
|
location /api/rag/debug/ {
|
||||||
|
proxy_pass http://frontend;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
|
# CORS headers
|
||||||
|
add_header 'Access-Control-Allow-Origin' '*' always;
|
||||||
|
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
|
||||||
|
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
|
||||||
|
add_header 'Access-Control-Allow-Expose-Headers' 'Content-Length,Content-Range' always;
|
||||||
|
|
||||||
|
# Handle preflight requests
|
||||||
|
if ($request_method = 'OPTIONS') {
|
||||||
|
add_header 'Access-Control-Allow-Origin' '*';
|
||||||
|
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
|
||||||
|
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization';
|
||||||
|
add_header 'Access-Control-Max-Age' 1728000;
|
||||||
|
add_header 'Content-Type' 'text/plain; charset=utf-8';
|
||||||
|
add_header 'Content-Length' 0;
|
||||||
|
return 204;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Frontend API routes for authentication - proxy to frontend
|
||||||
|
location /api/auth/ {
|
||||||
|
proxy_pass http://frontend;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
|
|
||||||
|
# CORS headers
|
||||||
|
add_header 'Access-Control-Allow-Origin' '*' always;
|
||||||
|
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS' always;
|
||||||
|
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always;
|
||||||
|
add_header 'Access-Control-Expose-Headers' 'Content-Length,Content-Range' always;
|
||||||
|
|
||||||
|
# Handle preflight requests
|
||||||
|
if ($request_method = 'OPTIONS') {
|
||||||
|
add_header 'Access-Control-Allow-Origin' '*';
|
||||||
|
add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS';
|
||||||
|
add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization';
|
||||||
|
add_header 'Access-Control-Max-Age' 1728000;
|
||||||
|
add_header 'Content-Type' 'text/plain; charset=utf-8';
|
||||||
|
add_header 'Content-Length' 0;
|
||||||
|
return 204;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# Public API routes - proxy to backend (for external clients)
|
# Public API routes - proxy to backend (for external clients)
|
||||||
location /api/ {
|
location /api/ {
|
||||||
proxy_pass http://backend;
|
proxy_pass http://backend;
|
||||||
|
|||||||
Reference in New Issue
Block a user