mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-18 07:54:29 +01:00
rag improvements
This commit is contained in:
@@ -16,7 +16,6 @@ from .analytics import router as analytics_router
|
||||
from .rag import router as rag_router
|
||||
from .chatbot import router as chatbot_router
|
||||
from .prompt_templates import router as prompt_templates_router
|
||||
from .security import router as security_router
|
||||
from .plugin_registry import router as plugin_registry_router
|
||||
|
||||
# Create main API router
|
||||
@@ -61,8 +60,6 @@ api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
|
||||
# Include prompt template routes
|
||||
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
|
||||
|
||||
# Include security routes
|
||||
api_router.include_router(security_router, prefix="/security", tags=["security"])
|
||||
|
||||
|
||||
# Include plugin registry routes
|
||||
|
||||
@@ -745,8 +745,7 @@ async def get_llm_metrics(
|
||||
"total_requests": metrics.total_requests,
|
||||
"successful_requests": metrics.successful_requests,
|
||||
"failed_requests": metrics.failed_requests,
|
||||
"security_blocked_requests": metrics.security_blocked_requests,
|
||||
"average_latency_ms": metrics.average_latency_ms,
|
||||
"average_latency_ms": metrics.average_latency_ms,
|
||||
"average_risk_score": metrics.average_risk_score,
|
||||
"provider_metrics": metrics.provider_metrics,
|
||||
"last_updated": metrics.last_updated.isoformat()
|
||||
|
||||
@@ -3,12 +3,14 @@ RAG API Endpoints
|
||||
Provides REST API for RAG (Retrieval Augmented Generation) operations
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from pydantic import BaseModel
|
||||
import io
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from app.db.database import get_db
|
||||
from app.core.security import get_current_user
|
||||
@@ -16,6 +18,9 @@ from app.models.user import User
|
||||
from app.services.rag_service import RAGService
|
||||
from app.utils.exceptions import APIException
|
||||
|
||||
# Import RAG module from module manager
|
||||
from app.services.module_manager import module_manager
|
||||
|
||||
|
||||
router = APIRouter(tags=["RAG"])
|
||||
|
||||
@@ -78,14 +83,25 @@ async def get_collections(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata"""
|
||||
"""Get all RAG collections - live data directly from Qdrant (source of truth)"""
|
||||
try:
|
||||
rag_service = RAGService(db)
|
||||
collections_data = await rag_service.get_all_collections(skip=skip, limit=limit)
|
||||
from app.services.qdrant_stats_service import qdrant_stats_service
|
||||
|
||||
# Get live stats from Qdrant
|
||||
stats_data = await qdrant_stats_service.get_collections_stats()
|
||||
collections = stats_data.get("collections", [])
|
||||
|
||||
# Apply pagination
|
||||
start_idx = skip
|
||||
end_idx = skip + limit
|
||||
paginated_collections = collections[start_idx:end_idx]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"collections": collections_data,
|
||||
"total": len(collections_data)
|
||||
"collections": paginated_collections,
|
||||
"total": len(collections),
|
||||
"total_documents": stats_data.get("total_documents", 0),
|
||||
"total_size_bytes": stats_data.get("total_size_bytes", 0)
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -116,6 +132,62 @@ async def create_collection(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stats", response_model=dict)
|
||||
async def get_rag_stats(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get overall RAG statistics - live data directly from Qdrant"""
|
||||
try:
|
||||
from app.services.qdrant_stats_service import qdrant_stats_service
|
||||
|
||||
# Get live stats from Qdrant
|
||||
stats_data = await qdrant_stats_service.get_collections_stats()
|
||||
|
||||
# Calculate active collections (collections with documents)
|
||||
active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)
|
||||
|
||||
# Calculate processing documents from database
|
||||
processing_docs = 0
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
from app.models.rag_document import RagDocument, ProcessingStatus
|
||||
|
||||
result = await db.execute(
|
||||
select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
|
||||
)
|
||||
processing_docs = len(result.scalars().all())
|
||||
except Exception:
|
||||
pass # If database query fails, default to 0
|
||||
|
||||
response_data = {
|
||||
"success": True,
|
||||
"stats": {
|
||||
"collections": {
|
||||
"total": stats_data.get("total_collections", 0),
|
||||
"active": active_collections
|
||||
},
|
||||
"documents": {
|
||||
"total": stats_data.get("total_documents", 0),
|
||||
"processing": processing_docs,
|
||||
"processed": stats_data.get("total_documents", 0) # Indexed documents
|
||||
},
|
||||
"storage": {
|
||||
"total_size_bytes": stats_data.get("total_size_bytes", 0),
|
||||
"total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
|
||||
},
|
||||
"vectors": {
|
||||
"total": stats_data.get("total_documents", 0) # Same as documents for RAG
|
||||
},
|
||||
"last_updated": datetime.utcnow().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
return response_data
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/collections/{collection_id}", response_model=dict)
|
||||
async def get_collection(
|
||||
collection_id: int,
|
||||
@@ -225,21 +297,65 @@ async def upload_document(
|
||||
try:
|
||||
# Read file content
|
||||
file_content = await file.read()
|
||||
|
||||
|
||||
if len(file_content) == 0:
|
||||
raise HTTPException(status_code=400, detail="Empty file uploaded")
|
||||
|
||||
|
||||
if len(file_content) > 50 * 1024 * 1024: # 50MB limit
|
||||
raise HTTPException(status_code=400, detail="File too large (max 50MB)")
|
||||
|
||||
|
||||
# Validate file can be read before processing
|
||||
filename = file.filename or "unknown"
|
||||
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
|
||||
|
||||
try:
|
||||
# Test file readability based on type
|
||||
if file_extension == 'jsonl':
|
||||
# Validate JSONL format - try to parse first few lines
|
||||
try:
|
||||
content_str = file_content.decode('utf-8')
|
||||
lines = content_str.strip().split('\n')[:5] # Check first 5 lines
|
||||
import json
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip(): # Skip empty lines
|
||||
json.loads(line) # Will raise JSONDecodeError if invalid
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
|
||||
|
||||
elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
|
||||
# Validate text files can be decoded
|
||||
try:
|
||||
file_content.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
|
||||
|
||||
elif file_extension in ['pdf']:
|
||||
# For PDF files, just check if it starts with PDF signature
|
||||
if not file_content.startswith(b'%PDF'):
|
||||
raise HTTPException(status_code=400, detail="Invalid PDF file format")
|
||||
|
||||
elif file_extension in ['docx', 'xlsx', 'pptx']:
|
||||
# For Office documents, check ZIP signature
|
||||
if not file_content.startswith(b'PK'):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
|
||||
|
||||
# For other file types, we'll rely on the document processor
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
|
||||
|
||||
rag_service = RAGService(db)
|
||||
document = await rag_service.upload_document(
|
||||
collection_id=collection_id,
|
||||
file_content=file_content,
|
||||
filename=file.filename or "unknown",
|
||||
filename=filename,
|
||||
content_type=file.content_type
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"document": document.to_dict(),
|
||||
@@ -362,21 +478,167 @@ async def download_document(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Stats Endpoint
|
||||
|
||||
@router.get("/stats", response_model=dict)
|
||||
async def get_rag_stats(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
# Debug Endpoints
|
||||
|
||||
@router.post("/debug/search")
|
||||
async def search_with_debug(
|
||||
query: str,
|
||||
max_results: int = 10,
|
||||
score_threshold: float = 0.3,
|
||||
collection_name: str = None,
|
||||
config: Dict[str, Any] = None,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get RAG system statistics"""
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Enhanced search with comprehensive debug information
|
||||
"""
|
||||
# Get RAG module from module manager
|
||||
rag_module = module_manager.modules.get('rag')
|
||||
if not rag_module or not rag_module.enabled:
|
||||
raise HTTPException(status_code=503, detail="RAG module not initialized")
|
||||
|
||||
debug_info = {}
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
rag_service = RAGService(db)
|
||||
stats = await rag_service.get_stats()
|
||||
|
||||
# Apply configuration if provided
|
||||
if config:
|
||||
# Update RAG config temporarily
|
||||
original_config = rag_module.config.copy()
|
||||
rag_module.config.update(config)
|
||||
|
||||
# Generate query embedding (with or without prefix)
|
||||
if config and config.get("use_query_prefix"):
|
||||
optimized_query = f"query: {query}"
|
||||
else:
|
||||
optimized_query = query
|
||||
|
||||
query_embedding = await rag_module._generate_embedding(optimized_query)
|
||||
|
||||
# Store embedding info for debug
|
||||
if config and config.get("debug", {}).get("show_embeddings"):
|
||||
debug_info["query_embedding"] = query_embedding[:10] # First 10 dimensions
|
||||
debug_info["embedding_dimension"] = len(query_embedding)
|
||||
debug_info["optimized_query"] = optimized_query
|
||||
|
||||
# Perform search
|
||||
search_start = asyncio.get_event_loop().time()
|
||||
results = await rag_module.search_documents(
|
||||
query,
|
||||
max_results=max_results,
|
||||
score_threshold=score_threshold,
|
||||
collection_name=collection_name
|
||||
)
|
||||
search_time = (asyncio.get_event_loop().time() - search_start) * 1000
|
||||
|
||||
# Calculate score statistics
|
||||
scores = [r.score for r in results if r.score is not None]
|
||||
if scores:
|
||||
import statistics
|
||||
debug_info["score_stats"] = {
|
||||
"min": min(scores),
|
||||
"max": max(scores),
|
||||
"avg": statistics.mean(scores),
|
||||
"stddev": statistics.stdev(scores) if len(scores) > 1 else 0
|
||||
}
|
||||
|
||||
# Get collection statistics
|
||||
try:
|
||||
from qdrant_client.http.models import Filter
|
||||
collection_name = collection_name or rag_module.default_collection_name
|
||||
|
||||
# Count total documents
|
||||
count_result = rag_module.qdrant_client.count(
|
||||
collection_name=collection_name,
|
||||
count_filter=Filter(must=[])
|
||||
)
|
||||
total_points = count_result.count
|
||||
|
||||
# Get unique documents and languages
|
||||
scroll_result = rag_module.qdrant_client.scroll(
|
||||
collection_name=collection_name,
|
||||
limit=1000, # Sample for stats
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
)
|
||||
|
||||
unique_docs = set()
|
||||
languages = set()
|
||||
|
||||
for point in scroll_result[0]:
|
||||
payload = point.payload or {}
|
||||
doc_id = payload.get("document_id")
|
||||
if doc_id:
|
||||
unique_docs.add(doc_id)
|
||||
|
||||
language = payload.get("language")
|
||||
if language:
|
||||
languages.add(language)
|
||||
|
||||
debug_info["collection_stats"] = {
|
||||
"total_documents": len(unique_docs),
|
||||
"total_chunks": total_points,
|
||||
"languages": sorted(list(languages))
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
debug_info["collection_stats_error"] = str(e)
|
||||
|
||||
# Enhance results with debug info
|
||||
enhanced_results = []
|
||||
for result in results:
|
||||
enhanced_result = {
|
||||
"document": {
|
||||
"id": result.document.id,
|
||||
"content": result.document.content,
|
||||
"metadata": result.document.metadata
|
||||
},
|
||||
"score": result.score,
|
||||
"debug_info": {}
|
||||
}
|
||||
|
||||
# Add hybrid search debug info if available
|
||||
metadata = result.document.metadata or {}
|
||||
if "_vector_score" in metadata:
|
||||
enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
|
||||
if "_bm25_score" in metadata:
|
||||
enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]
|
||||
|
||||
enhanced_results.append(enhanced_result)
|
||||
|
||||
# Note: Analytics logging disabled (module not available)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"stats": stats
|
||||
"results": enhanced_results,
|
||||
"debug_info": debug_info,
|
||||
"search_time_ms": search_time,
|
||||
"timestamp": start_time.isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
# Note: Analytics logging disabled (module not available)
|
||||
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
|
||||
|
||||
finally:
|
||||
# Restore original config if modified
|
||||
if config and 'original_config' in locals():
|
||||
rag_module.config = original_config
|
||||
|
||||
|
||||
@router.get("/debug/config")
|
||||
async def get_current_config(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current RAG configuration"""
|
||||
# Get RAG module from module manager
|
||||
rag_module = module_manager.modules.get('rag')
|
||||
if not rag_module or not rag_module.enabled:
|
||||
raise HTTPException(status_code=503, detail="RAG module not initialized")
|
||||
|
||||
return {
|
||||
"config": rag_module.config,
|
||||
"embedding_model": rag_module.embedding_model,
|
||||
"enabled": rag_module.enabled,
|
||||
"collections": await rag_module._get_collections_safely()
|
||||
}
|
||||
|
||||
@@ -1,251 +0,0 @@
|
||||
"""
|
||||
Security API endpoints for monitoring and configuration
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.security import get_current_active_user, RequiresRole
|
||||
from app.middleware.security import get_security_stats, get_request_auth_level, get_request_risk_score
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(tags=["security"])
|
||||
|
||||
|
||||
# Pydantic models for API responses
|
||||
class SecurityStatsResponse(BaseModel):
|
||||
"""Security statistics response model"""
|
||||
total_requests_analyzed: int
|
||||
threats_detected: int
|
||||
threats_blocked: int
|
||||
anomalies_detected: int
|
||||
rate_limits_exceeded: int
|
||||
avg_analysis_time: float
|
||||
threat_types: Dict[str, int]
|
||||
threat_levels: Dict[str, int]
|
||||
top_attacking_ips: List[tuple]
|
||||
security_enabled: bool
|
||||
threat_detection_enabled: bool
|
||||
rate_limiting_enabled: bool
|
||||
|
||||
|
||||
class SecurityConfigResponse(BaseModel):
|
||||
"""Security configuration response model"""
|
||||
security_enabled: bool = Field(description="Overall security system enabled")
|
||||
threat_detection_enabled: bool = Field(description="Threat detection analysis enabled")
|
||||
rate_limiting_enabled: bool = Field(description="Rate limiting enabled")
|
||||
ip_reputation_enabled: bool = Field(description="IP reputation checking enabled")
|
||||
anomaly_detection_enabled: bool = Field(description="Anomaly detection enabled")
|
||||
security_headers_enabled: bool = Field(description="Security headers enabled")
|
||||
|
||||
# Rate limiting settings
|
||||
unauthenticated_per_minute: int = Field(description="Rate limit for unauthenticated requests per minute")
|
||||
authenticated_per_minute: int = Field(description="Rate limit for authenticated users per minute")
|
||||
api_key_per_minute: int = Field(description="Rate limit for API key users per minute")
|
||||
premium_per_minute: int = Field(description="Rate limit for premium users per minute")
|
||||
|
||||
# Security thresholds
|
||||
risk_threshold: float = Field(description="Risk score threshold for blocking requests")
|
||||
warning_threshold: float = Field(description="Risk score threshold for warnings")
|
||||
anomaly_threshold: float = Field(description="Anomaly severity threshold")
|
||||
|
||||
# IP settings
|
||||
blocked_ips: List[str] = Field(description="List of blocked IP addresses")
|
||||
allowed_ips: List[str] = Field(description="List of allowed IP addresses (empty = allow all)")
|
||||
|
||||
|
||||
class RateLimitInfoResponse(BaseModel):
|
||||
"""Rate limit information for current request"""
|
||||
auth_level: str = Field(description="Authentication level (unauthenticated, authenticated, api_key, premium)")
|
||||
current_limits: Dict[str, int] = Field(description="Current rate limits for this auth level")
|
||||
remaining_requests: Optional[Dict[str, int]] = Field(description="Estimated remaining requests (if available)")
|
||||
|
||||
|
||||
@router.get("/stats", response_model=SecurityStatsResponse)
|
||||
async def get_security_statistics(
|
||||
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
|
||||
):
|
||||
"""
|
||||
Get security system statistics
|
||||
|
||||
Requires admin role. Returns comprehensive statistics about:
|
||||
- Request analysis counts
|
||||
- Threat detection results
|
||||
- Rate limiting enforcement
|
||||
- Top attacking IPs
|
||||
- Performance metrics
|
||||
"""
|
||||
try:
|
||||
stats = get_security_stats()
|
||||
return SecurityStatsResponse(**stats)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting security stats: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve security statistics"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config", response_model=SecurityConfigResponse)
|
||||
async def get_security_config(
|
||||
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
|
||||
):
|
||||
"""
|
||||
Get current security configuration
|
||||
|
||||
Requires admin role. Returns current security settings including:
|
||||
- Feature enablement flags
|
||||
- Rate limiting thresholds
|
||||
- Security thresholds
|
||||
- IP allowlists/blocklists
|
||||
"""
|
||||
return SecurityConfigResponse(
|
||||
security_enabled=settings.API_SECURITY_ENABLED,
|
||||
threat_detection_enabled=settings.API_THREAT_DETECTION_ENABLED,
|
||||
rate_limiting_enabled=settings.API_RATE_LIMITING_ENABLED,
|
||||
ip_reputation_enabled=settings.API_IP_REPUTATION_ENABLED,
|
||||
anomaly_detection_enabled=settings.API_ANOMALY_DETECTION_ENABLED,
|
||||
security_headers_enabled=settings.API_SECURITY_HEADERS_ENABLED,
|
||||
|
||||
unauthenticated_per_minute=settings.API_RATE_LIMIT_UNAUTHENTICATED_PER_MINUTE,
|
||||
authenticated_per_minute=settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE,
|
||||
api_key_per_minute=settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
|
||||
premium_per_minute=settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE,
|
||||
|
||||
risk_threshold=settings.API_SECURITY_RISK_THRESHOLD,
|
||||
warning_threshold=settings.API_SECURITY_WARNING_THRESHOLD,
|
||||
anomaly_threshold=settings.API_SECURITY_ANOMALY_THRESHOLD,
|
||||
|
||||
blocked_ips=settings.API_BLOCKED_IPS,
|
||||
allowed_ips=settings.API_ALLOWED_IPS
|
||||
)
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def get_security_status(
|
||||
request: Request,
|
||||
current_user: Dict[str, Any] = Depends(get_current_active_user)
|
||||
):
|
||||
"""
|
||||
Get security status for current request
|
||||
|
||||
Returns information about the security analysis of the current request:
|
||||
- Authentication level
|
||||
- Risk score (if available)
|
||||
- Rate limiting status
|
||||
"""
|
||||
auth_level = get_request_auth_level(request)
|
||||
risk_score = get_request_risk_score(request)
|
||||
|
||||
# Get rate limits for current auth level
|
||||
from app.core.threat_detection import AuthLevel
|
||||
try:
|
||||
auth_enum = AuthLevel(auth_level)
|
||||
from app.core.threat_detection import threat_detection_service
|
||||
minute_limit, hour_limit = threat_detection_service.get_rate_limits(auth_enum)
|
||||
|
||||
rate_limit_info = RateLimitInfoResponse(
|
||||
auth_level=auth_level,
|
||||
current_limits={
|
||||
"per_minute": minute_limit,
|
||||
"per_hour": hour_limit
|
||||
},
|
||||
remaining_requests=None # We don't track remaining requests in current implementation
|
||||
)
|
||||
except ValueError:
|
||||
rate_limit_info = RateLimitInfoResponse(
|
||||
auth_level=auth_level,
|
||||
current_limits={},
|
||||
remaining_requests=None
|
||||
)
|
||||
|
||||
return {
|
||||
"security_enabled": settings.API_SECURITY_ENABLED,
|
||||
"auth_level": auth_level,
|
||||
"risk_score": round(risk_score, 3) if risk_score > 0 else None,
|
||||
"rate_limit_info": rate_limit_info.dict(),
|
||||
"security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED
|
||||
}
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_security_analysis(
|
||||
request: Request,
|
||||
current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
|
||||
):
|
||||
"""
|
||||
Test security analysis on current request
|
||||
|
||||
Requires admin role. Manually triggers security analysis on the current request
|
||||
and returns detailed results. Useful for testing security rules and thresholds.
|
||||
"""
|
||||
try:
|
||||
from app.middleware.security import analyze_request_security
|
||||
|
||||
analysis = await analyze_request_security(request, current_user)
|
||||
|
||||
return {
|
||||
"analysis_complete": True,
|
||||
"is_threat": analysis.is_threat,
|
||||
"risk_score": round(analysis.risk_score, 3),
|
||||
"auth_level": analysis.auth_level.value,
|
||||
"should_block": analysis.should_block,
|
||||
"rate_limit_exceeded": analysis.rate_limit_exceeded,
|
||||
"threat_count": len(analysis.threats),
|
||||
"threats": [
|
||||
{
|
||||
"type": threat.threat_type,
|
||||
"level": threat.level.value,
|
||||
"confidence": round(threat.confidence, 3),
|
||||
"description": threat.description,
|
||||
"mitigation": threat.mitigation
|
||||
}
|
||||
for threat in analysis.threats
|
||||
],
|
||||
"recommendations": analysis.recommendations
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in security analysis test: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to perform security analysis test"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def security_health_check():
|
||||
"""
|
||||
Security system health check
|
||||
|
||||
Public endpoint that returns the health status of the security system.
|
||||
Does not require authentication.
|
||||
"""
|
||||
try:
|
||||
stats = get_security_stats()
|
||||
|
||||
# Basic health checks
|
||||
is_healthy = (
|
||||
settings.API_SECURITY_ENABLED and
|
||||
stats.get("total_requests_analyzed", 0) >= 0 and
|
||||
stats.get("avg_analysis_time", 0) < 1.0 # Analysis should be under 1 second
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "healthy" if is_healthy else "degraded",
|
||||
"security_enabled": settings.API_SECURITY_ENABLED,
|
||||
"threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
|
||||
"rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
|
||||
"avg_analysis_time_ms": round(stats.get("avg_analysis_time", 0) * 1000, 2),
|
||||
"total_requests_analyzed": stats.get("total_requests_analyzed", 0)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Security health check failed: {e}")
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": "Security system error",
|
||||
"security_enabled": settings.API_SECURITY_ENABLED
|
||||
}
|
||||
@@ -97,7 +97,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
|
||||
"api": {
|
||||
# Security Settings
|
||||
"security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
|
||||
"threat_detection_enabled": {"value": True, "type": "boolean", "description": "Enable threat detection analysis"},
|
||||
"rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
|
||||
"ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
|
||||
"anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
|
||||
@@ -112,7 +111,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
|
||||
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},
|
||||
|
||||
# Security Thresholds
|
||||
"security_risk_threshold": {"value": 0.8, "type": "float", "description": "Risk score threshold for blocking requests (0.0-1.0)"},
|
||||
"security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
|
||||
"anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},
|
||||
|
||||
@@ -601,7 +599,6 @@ async def reset_to_defaults(
|
||||
"api": {
|
||||
# Security Settings
|
||||
"security_enabled": {"value": True, "type": "boolean"},
|
||||
"threat_detection_enabled": {"value": True, "type": "boolean"},
|
||||
"rate_limiting_enabled": {"value": True, "type": "boolean"},
|
||||
"ip_reputation_enabled": {"value": True, "type": "boolean"},
|
||||
"anomaly_detection_enabled": {"value": True, "type": "boolean"},
|
||||
@@ -616,7 +613,6 @@ async def reset_to_defaults(
|
||||
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},
|
||||
|
||||
# Security Thresholds
|
||||
"security_risk_threshold": {"value": 0.8, "type": "float"},
|
||||
"security_warning_threshold": {"value": 0.6, "type": "float"},
|
||||
"anomaly_threshold": {"value": 0.7, "type": "float"},
|
||||
|
||||
|
||||
Reference in New Issue
Block a user