rag improvements

This commit is contained in:
2025-09-23 15:26:54 +02:00
parent 354b43494d
commit f8d127ff42
30 changed files with 817 additions and 2428 deletions

View File

@@ -12,8 +12,8 @@ from ..v1.audit import router as audit_router
from ..v1.settings import router as settings_router
from ..v1.analytics import router as analytics_router
from ..v1.rag import router as rag_router
from ..rag_debug import router as rag_debug_router
from ..v1.prompt_templates import router as prompt_templates_router
from ..v1.security import router as security_router
from ..v1.plugin_registry import router as plugin_registry_router
from ..v1.platform import router as platform_router
from ..v1.llm_internal import router as llm_internal_router
@@ -52,11 +52,12 @@ internal_api_router.include_router(analytics_router, prefix="/analytics", tags=[
# Include RAG routes (frontend RAG document management)
internal_api_router.include_router(rag_router, prefix="/rag", tags=["internal-rag"])
# Include RAG debug routes (for demo and debugging)
internal_api_router.include_router(rag_debug_router, prefix="/rag/debug", tags=["internal-rag-debug"])
# Include prompt template routes (frontend prompt template management)
internal_api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["internal-prompt-templates"])
# Include security routes (frontend security settings)
internal_api_router.include_router(security_router, prefix="/security", tags=["internal-security"])
# Include plugin registry routes (frontend plugin management)
internal_api_router.include_router(plugin_registry_router, prefix="/plugins", tags=["internal-plugins"])

View File

@@ -16,7 +16,6 @@ from .analytics import router as analytics_router
from .rag import router as rag_router
from .chatbot import router as chatbot_router
from .prompt_templates import router as prompt_templates_router
from .security import router as security_router
from .plugin_registry import router as plugin_registry_router
# Create main API router
@@ -61,8 +60,6 @@ api_router.include_router(chatbot_router, prefix="/chatbot", tags=["chatbot"])
# Include prompt template routes
api_router.include_router(prompt_templates_router, prefix="/prompt-templates", tags=["prompt-templates"])
# Include security routes
api_router.include_router(security_router, prefix="/security", tags=["security"])
# Include plugin registry routes

View File

@@ -745,8 +745,7 @@ async def get_llm_metrics(
"total_requests": metrics.total_requests,
"successful_requests": metrics.successful_requests,
"failed_requests": metrics.failed_requests,
"security_blocked_requests": metrics.security_blocked_requests,
"average_latency_ms": metrics.average_latency_ms,
"average_latency_ms": metrics.average_latency_ms,
"average_risk_score": metrics.average_risk_score,
"provider_metrics": metrics.provider_metrics,
"last_updated": metrics.last_updated.isoformat()

View File

@@ -3,12 +3,14 @@ RAG API Endpoints
Provides REST API for RAG (Retrieval Augmented Generation) operations
"""
from typing import List, Optional
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
import io
import asyncio
from datetime import datetime
from app.db.database import get_db
from app.core.security import get_current_user
@@ -16,6 +18,9 @@ from app.models.user import User
from app.services.rag_service import RAGService
from app.utils.exceptions import APIException
# Import RAG module from module manager
from app.services.module_manager import module_manager
router = APIRouter(tags=["RAG"])
@@ -78,14 +83,25 @@ async def get_collections(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata"""
"""Get all RAG collections - live data directly from Qdrant (source of truth)"""
try:
rag_service = RAGService(db)
collections_data = await rag_service.get_all_collections(skip=skip, limit=limit)
from app.services.qdrant_stats_service import qdrant_stats_service
# Get live stats from Qdrant
stats_data = await qdrant_stats_service.get_collections_stats()
collections = stats_data.get("collections", [])
# Apply pagination
start_idx = skip
end_idx = skip + limit
paginated_collections = collections[start_idx:end_idx]
return {
"success": True,
"collections": collections_data,
"total": len(collections_data)
"collections": paginated_collections,
"total": len(collections),
"total_documents": stats_data.get("total_documents", 0),
"total_size_bytes": stats_data.get("total_size_bytes", 0)
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -116,6 +132,62 @@ async def create_collection(
raise HTTPException(status_code=500, detail=str(e))
@router.get("/stats", response_model=dict)
async def get_rag_stats(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Get overall RAG statistics - live data directly from Qdrant.

    Collection, document, storage and vector figures come from
    ``qdrant_stats_service.get_collections_stats()``. The number of documents
    currently being processed is counted in the database; if that query
    fails the count falls back to 0 and the failure is logged.

    Args:
        db: Async database session used for the processing-document count.
        current_user: Authenticated user (authentication dependency only).

    Returns:
        Dict with ``success`` and a nested ``stats`` payload
        (``collections``, ``documents``, ``storage``, ``vectors``,
        ``last_updated``).

    Raises:
        HTTPException: 500 with the error text when statistics cannot be
            gathered.
    """
    try:
        from app.services.qdrant_stats_service import qdrant_stats_service

        # Get live stats from Qdrant
        stats_data = await qdrant_stats_service.get_collections_stats()
        total_documents = stats_data.get("total_documents", 0)
        total_size_bytes = stats_data.get("total_size_bytes", 0)

        # Calculate active collections (collections with documents).
        # A missing or None document_count is treated as 0.
        active_collections = sum(
            1
            for col in stats_data.get("collections", [])
            if (col.get("document_count") or 0) > 0
        )

        # Calculate processing documents from database
        processing_docs = 0
        try:
            from sqlalchemy import func, select
            from app.models.rag_document import RagDocument, ProcessingStatus

            # Let the database count instead of loading every row just to len() it.
            result = await db.execute(
                select(func.count())
                .select_from(RagDocument)
                .where(RagDocument.status == ProcessingStatus.PROCESSING)
            )
            processing_docs = result.scalar_one()
        except Exception:
            # Best effort: keep the default of 0, but leave a trace of the failure.
            import logging

            logging.getLogger(__name__).warning(
                "Could not count processing RAG documents; defaulting to 0",
                exc_info=True,
            )

        response_data = {
            "success": True,
            "stats": {
                "collections": {
                    "total": stats_data.get("total_collections", 0),
                    "active": active_collections
                },
                "documents": {
                    "total": total_documents,
                    "processing": processing_docs,
                    "processed": total_documents  # Indexed documents
                },
                "storage": {
                    "total_size_bytes": total_size_bytes,
                    "total_size_mb": round(total_size_bytes / (1024 * 1024), 2)
                },
                "vectors": {
                    "total": total_documents  # Same as documents for RAG
                },
                "last_updated": datetime.utcnow().isoformat()
            }
        }

        return response_data
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/collections/{collection_id}", response_model=dict)
async def get_collection(
collection_id: int,
@@ -225,21 +297,65 @@ async def upload_document(
try:
# Read file content
file_content = await file.read()
if len(file_content) == 0:
raise HTTPException(status_code=400, detail="Empty file uploaded")
if len(file_content) > 50 * 1024 * 1024: # 50MB limit
raise HTTPException(status_code=400, detail="File too large (max 50MB)")
# Validate file can be read before processing
filename = file.filename or "unknown"
file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
try:
# Test file readability based on type
if file_extension == 'jsonl':
# Validate JSONL format - try to parse first few lines
try:
content_str = file_content.decode('utf-8')
lines = content_str.strip().split('\n')[:5] # Check first 5 lines
import json
for i, line in enumerate(lines):
if line.strip(): # Skip empty lines
json.loads(line) # Will raise JSONDecodeError if invalid
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
# Validate text files can be decoded
try:
file_content.decode('utf-8')
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
elif file_extension in ['pdf']:
# For PDF files, just check if it starts with PDF signature
if not file_content.startswith(b'%PDF'):
raise HTTPException(status_code=400, detail="Invalid PDF file format")
elif file_extension in ['docx', 'xlsx', 'pptx']:
# For Office documents, check ZIP signature
if not file_content.startswith(b'PK'):
raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
# For other file types, we'll rely on the document processor
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
rag_service = RAGService(db)
document = await rag_service.upload_document(
collection_id=collection_id,
file_content=file_content,
filename=file.filename or "unknown",
filename=filename,
content_type=file.content_type
)
return {
"success": True,
"document": document.to_dict(),
@@ -362,21 +478,167 @@ async def download_document(
raise HTTPException(status_code=500, detail=str(e))
# Stats Endpoint
@router.get("/stats", response_model=dict)
async def get_rag_stats(
db: AsyncSession = Depends(get_db),
# Debug Endpoints
@router.post("/debug/search")
async def search_with_debug(
    query: str,
    max_results: int = 10,
    score_threshold: float = 0.3,
    collection_name: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Enhanced search with comprehensive debug information

    Runs a search through the RAG module and returns the hits together with
    timing, score statistics and sampled collection statistics.

    Args:
        query: Search text passed to ``rag_module.search_documents``.
        max_results: Maximum number of results requested from the search.
        score_threshold: Score threshold forwarded to the search.
        collection_name: Collection to search. For the collection statistics
            a missing value falls back to
            ``rag_module.default_collection_name``.
        config: Optional overrides merged into ``rag_module.config`` while the
            request runs and restored afterwards. Keys read here are
            ``use_query_prefix`` and ``debug.show_embeddings``.
        current_user: Authenticated user (authentication dependency only).

    Returns:
        Dict with ``success``, ``results``, ``debug_info``,
        ``search_time_ms`` and ``timestamp``.

    Raises:
        HTTPException: 503 when the RAG module is missing or disabled,
            500 ("Search failed: ...") when anything else goes wrong.
    """
    # Get RAG module from module manager
    rag_module = module_manager.modules.get('rag')
    if not rag_module or not rag_module.enabled:
        raise HTTPException(status_code=503, detail="RAG module not initialized")

    debug_info: Dict[str, Any] = {}
    start_time = datetime.utcnow()
    # Explicit sentinel instead of probing locals(): stays None unless
    # overrides were actually applied.
    original_config = None

    try:
        # Apply configuration if provided
        if config:
            # NOTE(review): this mutates the module-wide config, so concurrent
            # requests can observe the temporary overrides until restored.
            original_config = rag_module.config.copy()
            rag_module.config.update(config)

        # Generate query embedding (with or without prefix)
        if config and config.get("use_query_prefix"):
            optimized_query = f"query: {query}"
        else:
            optimized_query = query

        query_embedding = await rag_module._generate_embedding(optimized_query)

        # Store embedding info for debug
        if config and config.get("debug", {}).get("show_embeddings"):
            debug_info["query_embedding"] = query_embedding[:10]  # First 10 dimensions
            debug_info["embedding_dimension"] = len(query_embedding)
            debug_info["optimized_query"] = optimized_query

        # Perform search.
        # NOTE(review): the search receives the raw query, not optimized_query;
        # confirm this is intended.
        loop = asyncio.get_running_loop()
        search_start = loop.time()
        results = await rag_module.search_documents(
            query,
            max_results=max_results,
            score_threshold=score_threshold,
            collection_name=collection_name
        )
        search_time = (loop.time() - search_start) * 1000

        # Calculate score statistics
        scores = [r.score for r in results if r.score is not None]
        if scores:
            import statistics

            debug_info["score_stats"] = {
                "min": min(scores),
                "max": max(scores),
                "avg": statistics.mean(scores),
                "stddev": statistics.stdev(scores) if len(scores) > 1 else 0
            }

        # Get collection statistics (best effort: failures are reported in
        # debug_info rather than failing the whole search)
        try:
            from qdrant_client.http.models import Filter

            # Separate local so the request parameter is not overwritten.
            stats_collection = collection_name or rag_module.default_collection_name

            # Count total documents
            count_result = rag_module.qdrant_client.count(
                collection_name=stats_collection,
                count_filter=Filter(must=[])
            )
            total_points = count_result.count

            # Get unique documents and languages
            scroll_result = rag_module.qdrant_client.scroll(
                collection_name=stats_collection,
                limit=1000,  # Sample for stats
                with_payload=True,
                with_vectors=False
            )

            unique_docs = set()
            languages = set()

            for point in scroll_result[0]:
                payload = point.payload or {}
                doc_id = payload.get("document_id")
                if doc_id:
                    unique_docs.add(doc_id)

                language = payload.get("language")
                if language:
                    languages.add(language)

            debug_info["collection_stats"] = {
                "total_documents": len(unique_docs),
                "total_chunks": total_points,
                "languages": sorted(languages)
            }
        except Exception as e:
            debug_info["collection_stats_error"] = str(e)

        # Enhance results with debug info
        enhanced_results = []
        for result in results:
            enhanced_result = {
                "document": {
                    "id": result.document.id,
                    "content": result.document.content,
                    "metadata": result.document.metadata
                },
                "score": result.score,
                "debug_info": {}
            }

            # Add hybrid search debug info if available
            metadata = result.document.metadata or {}
            if "_vector_score" in metadata:
                enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
            if "_bm25_score" in metadata:
                enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]

            enhanced_results.append(enhanced_result)

        # Note: Analytics logging disabled (module not available)

        return {
            "results": enhanced_results,
            "debug_info": debug_info,
            "search_time_ms": search_time,
            "timestamp": start_time.isoformat()
        }

    except HTTPException:
        # Keep deliberate HTTP errors intact instead of rewrapping them as 500.
        raise
    except Exception as e:
        # Note: Analytics logging disabled (module not available)
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e

    finally:
        # Restore original config if modified
        if original_config is not None:
            rag_module.config = original_config
@router.get("/debug/config")
async def get_current_config(
    current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
    """Return the RAG module's live configuration.

    Responds with 503 when the module manager has no enabled ``rag`` module;
    otherwise reports the module's config, embedding model, enabled flag and
    the collections returned by ``_get_collections_safely()``.
    """
    module = module_manager.modules.get('rag')
    if not (module and module.enabled):
        raise HTTPException(status_code=503, detail="RAG module not initialized")

    # Read the plain attributes first, then await the collection listing.
    snapshot = {
        "config": module.config,
        "embedding_model": module.embedding_model,
        "enabled": module.enabled,
    }
    snapshot["collections"] = await module._get_collections_safely()
    return snapshot

View File

@@ -1,251 +0,0 @@
"""
Security API endpoints for monitoring and configuration
"""
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field
from app.core.security import get_current_active_user, RequiresRole
from app.middleware.security import get_security_stats, get_request_auth_level, get_request_risk_score
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
router = APIRouter(tags=["security"])
# Pydantic models for API responses
class SecurityStatsResponse(BaseModel):
    """Response schema for aggregate security statistics.

    Used as the ``response_model`` of the ``/stats`` endpoint, which builds it
    by unpacking the dictionary returned by ``get_security_stats()``.
    """

    # Counters
    total_requests_analyzed: int
    threats_detected: int
    threats_blocked: int
    anomalies_detected: int
    rate_limits_exceeded: int

    # Average analysis time; the health check in this module compares it
    # against 1.0 and multiplies by 1000 to report milliseconds.
    avg_analysis_time: float

    # Breakdowns - presumably counts keyed by threat type / level (verify
    # against get_security_stats()).
    threat_types: Dict[str, int]
    threat_levels: Dict[str, int]
    top_attacking_ips: List[tuple]

    # Feature flags
    security_enabled: bool
    threat_detection_enabled: bool
    rate_limiting_enabled: bool
class SecurityConfigResponse(BaseModel):
    """Response schema describing the active security configuration.

    Used as the ``response_model`` of the ``/config`` endpoint, which fills
    every field from the application ``settings`` object.
    """

    # Feature enablement flags
    security_enabled: bool = Field(description="Overall security system enabled")
    threat_detection_enabled: bool = Field(description="Threat detection analysis enabled")
    rate_limiting_enabled: bool = Field(description="Rate limiting enabled")
    ip_reputation_enabled: bool = Field(description="IP reputation checking enabled")
    anomaly_detection_enabled: bool = Field(description="Anomaly detection enabled")
    security_headers_enabled: bool = Field(description="Security headers enabled")

    # Per-minute rate limits, one per authentication level
    unauthenticated_per_minute: int = Field(description="Rate limit for unauthenticated requests per minute")
    authenticated_per_minute: int = Field(description="Rate limit for authenticated users per minute")
    api_key_per_minute: int = Field(description="Rate limit for API key users per minute")
    premium_per_minute: int = Field(description="Rate limit for premium users per minute")

    # Score thresholds
    risk_threshold: float = Field(description="Risk score threshold for blocking requests")
    warning_threshold: float = Field(description="Risk score threshold for warnings")
    anomaly_threshold: float = Field(description="Anomaly severity threshold")

    # IP block/allow lists
    blocked_ips: List[str] = Field(description="List of blocked IP addresses")
    allowed_ips: List[str] = Field(description="List of allowed IP addresses (empty = allow all)")
class RateLimitInfoResponse(BaseModel):
    """Rate limit information for current request

    Built by the ``/status`` endpoint and serialised with ``.dict()``.
    """

    auth_level: str = Field(description="Authentication level (unauthenticated, authenticated, api_key, premium)")
    current_limits: Dict[str, int] = Field(description="Current rate limits for this auth level")
    # Explicit default: an Optional annotation alone does not make a field
    # optional on every pydantic major version, so spell out None here.
    remaining_requests: Optional[Dict[str, int]] = Field(
        default=None,
        description="Estimated remaining requests (if available)"
    )
@router.get("/stats", response_model=SecurityStatsResponse)
async def get_security_statistics(
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Return security system statistics (admin role required).

    The payload is whatever ``get_security_stats()`` reports, validated
    through ``SecurityStatsResponse``: request analysis counts, threat
    detection results, rate limiting enforcement, top attacking IPs and
    performance metrics.

    Any failure is logged and surfaced as a generic HTTP 500 so internal
    error details are not exposed to the client.
    """
    try:
        return SecurityStatsResponse(**get_security_stats())
    except Exception as e:
        logger.error(f"Error getting security stats: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve security statistics"
        )
@router.get("/config", response_model=SecurityConfigResponse)
async def get_security_config(
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Return the current security configuration (admin role required).

    Every value is read straight from the application ``settings`` object and
    grouped as: feature enablement flags, rate limiting thresholds, security
    thresholds, and IP allowlists/blocklists.
    """
    feature_flags = {
        "security_enabled": settings.API_SECURITY_ENABLED,
        "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
        "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
        "ip_reputation_enabled": settings.API_IP_REPUTATION_ENABLED,
        "anomaly_detection_enabled": settings.API_ANOMALY_DETECTION_ENABLED,
        "security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED,
    }
    rate_limits = {
        "unauthenticated_per_minute": settings.API_RATE_LIMIT_UNAUTHENTICATED_PER_MINUTE,
        "authenticated_per_minute": settings.API_RATE_LIMIT_AUTHENTICATED_PER_MINUTE,
        "api_key_per_minute": settings.API_RATE_LIMIT_API_KEY_PER_MINUTE,
        "premium_per_minute": settings.API_RATE_LIMIT_PREMIUM_PER_MINUTE,
    }
    thresholds = {
        "risk_threshold": settings.API_SECURITY_RISK_THRESHOLD,
        "warning_threshold": settings.API_SECURITY_WARNING_THRESHOLD,
        "anomaly_threshold": settings.API_SECURITY_ANOMALY_THRESHOLD,
    }
    ip_lists = {
        "blocked_ips": settings.API_BLOCKED_IPS,
        "allowed_ips": settings.API_ALLOWED_IPS,
    }
    return SecurityConfigResponse(**feature_flags, **rate_limits, **thresholds, **ip_lists)
@router.get("/status")
async def get_security_status(
    request: Request,
    current_user: Dict[str, Any] = Depends(get_current_active_user)
):
    """
    Report the security view of the current request.

    Combines the authentication level and risk score read from the request
    with the rate limits configured for that authentication level. When the
    level is not a valid ``AuthLevel`` value, the limits are reported as an
    empty mapping.
    """
    level = get_request_auth_level(request)
    score = get_request_risk_score(request)

    # Get rate limits for current auth level
    from app.core.threat_detection import AuthLevel

    try:
        level_enum = AuthLevel(level)
        from app.core.threat_detection import threat_detection_service

        per_minute, per_hour = threat_detection_service.get_rate_limits(level_enum)
        limits_info = RateLimitInfoResponse(
            auth_level=level,
            current_limits={"per_minute": per_minute, "per_hour": per_hour},
            remaining_requests=None,  # remaining requests are not tracked
        )
    except ValueError:
        limits_info = RateLimitInfoResponse(
            auth_level=level,
            current_limits={},
            remaining_requests=None,
        )

    # A non-positive score is reported as "no score".
    if score > 0:
        reported_score = round(score, 3)
    else:
        reported_score = None

    return {
        "security_enabled": settings.API_SECURITY_ENABLED,
        "auth_level": level,
        "risk_score": reported_score,
        "rate_limit_info": limits_info.dict(),
        "security_headers_enabled": settings.API_SECURITY_HEADERS_ENABLED
    }
@router.post("/test")
async def test_security_analysis(
    request: Request,
    current_user: Dict[str, Any] = Depends(RequiresRole("admin"))
):
    """
    Run the security analysis against the current request (admin role required).

    Calls ``analyze_request_security`` on demand and returns its findings in
    detail, which is useful when tuning security rules and thresholds.
    Failures are logged and reported as a generic HTTP 500.
    """

    def _describe(threat):
        # Flatten one detected threat into a JSON-friendly dict.
        return {
            "type": threat.threat_type,
            "level": threat.level.value,
            "confidence": round(threat.confidence, 3),
            "description": threat.description,
            "mitigation": threat.mitigation
        }

    try:
        from app.middleware.security import analyze_request_security

        outcome = await analyze_request_security(request, current_user)

        return {
            "analysis_complete": True,
            "is_threat": outcome.is_threat,
            "risk_score": round(outcome.risk_score, 3),
            "auth_level": outcome.auth_level.value,
            "should_block": outcome.should_block,
            "rate_limit_exceeded": outcome.rate_limit_exceeded,
            "threat_count": len(outcome.threats),
            "threats": [_describe(threat) for threat in outcome.threats],
            "recommendations": outcome.recommendations
        }
    except Exception as e:
        logger.error(f"Error in security analysis test: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to perform security analysis test"
        )
@router.get("/health")
async def security_health_check():
    """
    Health probe for the security system.

    Public endpoint: no authentication dependency is declared. Reports
    "healthy" when security is enabled and the average analysis time is
    below one second, "degraded" otherwise, and "unhealthy" if gathering
    the statistics raises.
    """
    try:
        stats = get_security_stats()
        analyzed = stats.get("total_requests_analyzed", 0)
        avg_time = stats.get("avg_analysis_time", 0)

        # Basic health checks; analysis should be under 1 second
        healthy = (
            settings.API_SECURITY_ENABLED
            and analyzed >= 0
            and avg_time < 1.0
        )
        if healthy:
            state = "healthy"
        else:
            state = "degraded"

        return {
            "status": state,
            "security_enabled": settings.API_SECURITY_ENABLED,
            "threat_detection_enabled": settings.API_THREAT_DETECTION_ENABLED,
            "rate_limiting_enabled": settings.API_RATE_LIMITING_ENABLED,
            "avg_analysis_time_ms": round(avg_time * 1000, 2),
            "total_requests_analyzed": analyzed
        }
    except Exception as e:
        logger.error(f"Security health check failed: {e}")
        return {
            "status": "unhealthy",
            "error": "Security system error",
            "security_enabled": settings.API_SECURITY_ENABLED
        }

View File

@@ -97,7 +97,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
"api": {
# Security Settings
"security_enabled": {"value": True, "type": "boolean", "description": "Enable API security system"},
"threat_detection_enabled": {"value": True, "type": "boolean", "description": "Enable threat detection analysis"},
"rate_limiting_enabled": {"value": True, "type": "boolean", "description": "Enable rate limiting"},
"ip_reputation_enabled": {"value": True, "type": "boolean", "description": "Enable IP reputation checking"},
"anomaly_detection_enabled": {"value": True, "type": "boolean", "description": "Enable anomaly detection"},
@@ -112,7 +111,6 @@ SETTINGS_STORE: Dict[str, Dict[str, Any]] = {
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer", "description": "Rate limit for premium users per hour"},
# Security Thresholds
"security_risk_threshold": {"value": 0.8, "type": "float", "description": "Risk score threshold for blocking requests (0.0-1.0)"},
"security_warning_threshold": {"value": 0.6, "type": "float", "description": "Risk score threshold for warnings (0.0-1.0)"},
"anomaly_threshold": {"value": 0.7, "type": "float", "description": "Anomaly severity threshold (0.0-1.0)"},
@@ -601,7 +599,6 @@ async def reset_to_defaults(
"api": {
# Security Settings
"security_enabled": {"value": True, "type": "boolean"},
"threat_detection_enabled": {"value": True, "type": "boolean"},
"rate_limiting_enabled": {"value": True, "type": "boolean"},
"ip_reputation_enabled": {"value": True, "type": "boolean"},
"anomaly_detection_enabled": {"value": True, "type": "boolean"},
@@ -616,7 +613,6 @@ async def reset_to_defaults(
"rate_limit_premium_per_hour": {"value": 100000, "type": "integer"},
# Security Thresholds
"security_risk_threshold": {"value": 0.8, "type": "float"},
"security_warning_threshold": {"value": 0.6, "type": "float"},
"anomaly_threshold": {"value": 0.7, "type": "float"},