Mirror of https://github.com/aljazceru/enclava.git (synced 2025-12-17 07:24:34 +01:00)

Commit: rag improvements
@@ -3,12 +3,14 @@ RAG API Endpoints
 Provides REST API for RAG (Retrieval Augmented Generation) operations
 """
 
-from typing import List, Optional
+from typing import List, Optional, Dict, Any
 from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
 from fastapi.responses import StreamingResponse
 from sqlalchemy.ext.asyncio import AsyncSession
 from pydantic import BaseModel
 import io
+import asyncio
+from datetime import datetime
 
 from app.db.database import get_db
 from app.core.security import get_current_user
@@ -16,6 +18,9 @@ from app.models.user import User
 from app.services.rag_service import RAGService
 from app.utils.exceptions import APIException
 
+# Import RAG module from module manager
+from app.services.module_manager import module_manager
+
 
 router = APIRouter(tags=["RAG"])
 
@@ -78,14 +83,25 @@ async def get_collections(
     db: AsyncSession = Depends(get_db),
     current_user: User = Depends(get_current_user)
 ):
-    """Get all RAG collections from Qdrant (source of truth) with PostgreSQL metadata"""
+    """Get all RAG collections - live data directly from Qdrant (source of truth)"""
     try:
-        rag_service = RAGService(db)
-        collections_data = await rag_service.get_all_collections(skip=skip, limit=limit)
+        from app.services.qdrant_stats_service import qdrant_stats_service
+
+        # Get live stats from Qdrant
+        stats_data = await qdrant_stats_service.get_collections_stats()
+        collections = stats_data.get("collections", [])
+
+        # Apply pagination
+        start_idx = skip
+        end_idx = skip + limit
+        paginated_collections = collections[start_idx:end_idx]
+
         return {
             "success": True,
-            "collections": collections_data,
-            "total": len(collections_data)
+            "collections": paginated_collections,
+            "total": len(collections),
+            "total_documents": stats_data.get("total_documents", 0),
+            "total_size_bytes": stats_data.get("total_size_bytes", 0)
         }
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
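For reference, a minimal sketch of how a client might page through the rewritten collections endpoint. The base URL, the `/api/v1/rag` mount prefix, and bearer-token auth are assumptions not confirmed by this diff:

    # Hypothetical client sketch - endpoint prefix and auth scheme are assumed
    import httpx

    async def list_collections(token: str, skip: int = 0, limit: int = 20) -> dict:
        async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
            resp = await client.get(
                "/api/v1/rag/collections",  # assumed mount path for this router
                params={"skip": skip, "limit": limit},
                headers={"Authorization": f"Bearer {token}"},
            )
            resp.raise_for_status()
            data = resp.json()
            # The response now carries live Qdrant totals alongside the page:
            # data["collections"], data["total"], data["total_documents"],
            # data["total_size_bytes"]
            return data

Note that pagination is applied in the handler by slicing the full Qdrant collection list (`collections[skip:skip + limit]`), so `total` reflects all collections, not just the returned page.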
@@ -116,6 +132,62 @@ async def create_collection(
         raise HTTPException(status_code=500, detail=str(e))
 
 
+@router.get("/stats", response_model=dict)
+async def get_rag_stats(
+    db: AsyncSession = Depends(get_db),
+    current_user: User = Depends(get_current_user)
+):
+    """Get overall RAG statistics - live data directly from Qdrant"""
+    try:
+        from app.services.qdrant_stats_service import qdrant_stats_service
+
+        # Get live stats from Qdrant
+        stats_data = await qdrant_stats_service.get_collections_stats()
+
+        # Calculate active collections (collections with documents)
+        active_collections = sum(1 for col in stats_data.get("collections", []) if col.get("document_count", 0) > 0)
+
+        # Calculate processing documents from database
+        processing_docs = 0
+        try:
+            from sqlalchemy import select
+            from app.models.rag_document import RagDocument, ProcessingStatus
+
+            result = await db.execute(
+                select(RagDocument).where(RagDocument.status == ProcessingStatus.PROCESSING)
+            )
+            processing_docs = len(result.scalars().all())
+        except Exception:
+            pass  # If database query fails, default to 0
+
+        response_data = {
+            "success": True,
+            "stats": {
+                "collections": {
+                    "total": stats_data.get("total_collections", 0),
+                    "active": active_collections
+                },
+                "documents": {
+                    "total": stats_data.get("total_documents", 0),
+                    "processing": processing_docs,
+                    "processed": stats_data.get("total_documents", 0)  # Indexed documents
+                },
+                "storage": {
+                    "total_size_bytes": stats_data.get("total_size_bytes", 0),
+                    "total_size_mb": round(stats_data.get("total_size_bytes", 0) / (1024 * 1024), 2)
+                },
+                "vectors": {
+                    "total": stats_data.get("total_documents", 0)  # Same as documents for RAG
+                },
+                "last_updated": datetime.utcnow().isoformat()
+            }
+        }
+
+        return response_data
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @router.get("/collections/{collection_id}", response_model=dict)
 async def get_collection(
     collection_id: int,
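The relocated `/stats` handler builds a nested payload; a sketch of the shape a client can expect, derived from the keys assembled above (all values illustrative):

    # Illustrative response shape for GET /stats (values are made up)
    example_stats_response = {
        "success": True,
        "stats": {
            "collections": {"total": 3, "active": 2},
            "documents": {"total": 120, "processing": 1, "processed": 120},
            "storage": {"total_size_bytes": 5242880, "total_size_mb": 5.0},
            "vectors": {"total": 120},
            "last_updated": "2025-12-17T06:00:00",
        },
    }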
@@ -225,21 +297,65 @@ async def upload_document(
     try:
         # Read file content
         file_content = await file.read()
 
         if len(file_content) == 0:
             raise HTTPException(status_code=400, detail="Empty file uploaded")
 
         if len(file_content) > 50 * 1024 * 1024:  # 50MB limit
             raise HTTPException(status_code=400, detail="File too large (max 50MB)")
 
+        # Validate file can be read before processing
+        filename = file.filename or "unknown"
+        file_extension = filename.split('.')[-1].lower() if '.' in filename else ''
+
+        try:
+            # Test file readability based on type
+            if file_extension == 'jsonl':
+                # Validate JSONL format - try to parse first few lines
+                try:
+                    content_str = file_content.decode('utf-8')
+                    lines = content_str.strip().split('\n')[:5]  # Check first 5 lines
+                    import json
+                    for i, line in enumerate(lines):
+                        if line.strip():  # Skip empty lines
+                            json.loads(line)  # Will raise JSONDecodeError if invalid
+                except UnicodeDecodeError:
+                    raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
+                except json.JSONDecodeError as e:
+                    raise HTTPException(status_code=400, detail=f"Invalid JSONL format: {str(e)}")
+
+            elif file_extension in ['txt', 'md', 'py', 'js', 'html', 'css', 'json']:
+                # Validate text files can be decoded
+                try:
+                    file_content.decode('utf-8')
+                except UnicodeDecodeError:
+                    raise HTTPException(status_code=400, detail="File is not valid UTF-8 text")
+
+            elif file_extension in ['pdf']:
+                # For PDF files, just check if it starts with PDF signature
+                if not file_content.startswith(b'%PDF'):
+                    raise HTTPException(status_code=400, detail="Invalid PDF file format")
+
+            elif file_extension in ['docx', 'xlsx', 'pptx']:
+                # For Office documents, check ZIP signature
+                if not file_content.startswith(b'PK'):
+                    raise HTTPException(status_code=400, detail=f"Invalid {file_extension.upper()} file format")
+
+            # For other file types, we'll rely on the document processor
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"File validation failed: {str(e)}")
+
         rag_service = RAGService(db)
         document = await rag_service.upload_document(
             collection_id=collection_id,
             file_content=file_content,
-            filename=file.filename or "unknown",
+            filename=filename,
             content_type=file.content_type
         )
 
         return {
             "success": True,
             "document": document.to_dict(),
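The validation added above keys off leading magic bytes for binary formats. A standalone sketch of the same idea, as a hypothetical helper that is not part of this commit:

    # Hypothetical helper mirroring the signature checks in upload_document
    def sniff_file_signature(content: bytes, extension: str) -> bool:
        """Return True if the leading bytes match the expected signature."""
        signatures = {
            "pdf": b"%PDF",  # PDF header
            "docx": b"PK",   # Office Open XML files are ZIP archives
            "xlsx": b"PK",
            "pptx": b"PK",
        }
        expected = signatures.get(extension.lower())
        if expected is None:
            return True  # unknown type: defer to the document processor
        return content.startswith(expected)

    assert sniff_file_signature(b"%PDF-1.7 ...", "pdf")
    assert not sniff_file_signature(b"<html>", "pdf")

This catches mislabeled uploads cheaply before any expensive parsing; a `PK` prefix only proves the file is a ZIP archive, so deeper validation still falls to the document processor.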
@@ -362,21 +478,167 @@ async def download_document(
         raise HTTPException(status_code=500, detail=str(e))
 
 
-# Stats Endpoint
-
-@router.get("/stats", response_model=dict)
-async def get_rag_stats(
-    db: AsyncSession = Depends(get_db),
+# Debug Endpoints
+
+@router.post("/debug/search")
+async def search_with_debug(
+    query: str,
+    max_results: int = 10,
+    score_threshold: float = 0.3,
+    collection_name: str = None,
+    config: Dict[str, Any] = None,
     current_user: User = Depends(get_current_user)
-):
-    """Get RAG system statistics"""
+) -> Dict[str, Any]:
+    """
+    Enhanced search with comprehensive debug information
+    """
+    # Get RAG module from module manager
+    rag_module = module_manager.modules.get('rag')
+    if not rag_module or not rag_module.enabled:
+        raise HTTPException(status_code=503, detail="RAG module not initialized")
+
+    debug_info = {}
+    start_time = datetime.utcnow()
+
     try:
-        rag_service = RAGService(db)
-        stats = await rag_service.get_stats()
+        # Apply configuration if provided
+        if config:
+            # Update RAG config temporarily
+            original_config = rag_module.config.copy()
+            rag_module.config.update(config)
+
+        # Generate query embedding (with or without prefix)
+        if config and config.get("use_query_prefix"):
+            optimized_query = f"query: {query}"
+        else:
+            optimized_query = query
+
+        query_embedding = await rag_module._generate_embedding(optimized_query)
+
+        # Store embedding info for debug
+        if config and config.get("debug", {}).get("show_embeddings"):
+            debug_info["query_embedding"] = query_embedding[:10]  # First 10 dimensions
+            debug_info["embedding_dimension"] = len(query_embedding)
+            debug_info["optimized_query"] = optimized_query
+
+        # Perform search
+        search_start = asyncio.get_event_loop().time()
+        results = await rag_module.search_documents(
+            query,
+            max_results=max_results,
+            score_threshold=score_threshold,
+            collection_name=collection_name
+        )
+        search_time = (asyncio.get_event_loop().time() - search_start) * 1000
+
+        # Calculate score statistics
+        scores = [r.score for r in results if r.score is not None]
+        if scores:
+            import statistics
+            debug_info["score_stats"] = {
+                "min": min(scores),
+                "max": max(scores),
+                "avg": statistics.mean(scores),
+                "stddev": statistics.stdev(scores) if len(scores) > 1 else 0
+            }
+
+        # Get collection statistics
+        try:
+            from qdrant_client.http.models import Filter
+            collection_name = collection_name or rag_module.default_collection_name
+
+            # Count total documents
+            count_result = rag_module.qdrant_client.count(
+                collection_name=collection_name,
+                count_filter=Filter(must=[])
+            )
+            total_points = count_result.count
+
+            # Get unique documents and languages
+            scroll_result = rag_module.qdrant_client.scroll(
+                collection_name=collection_name,
+                limit=1000,  # Sample for stats
+                with_payload=True,
+                with_vectors=False
+            )
+
+            unique_docs = set()
+            languages = set()
+
+            for point in scroll_result[0]:
+                payload = point.payload or {}
+                doc_id = payload.get("document_id")
+                if doc_id:
+                    unique_docs.add(doc_id)
+
+                language = payload.get("language")
+                if language:
+                    languages.add(language)
+
+            debug_info["collection_stats"] = {
+                "total_documents": len(unique_docs),
+                "total_chunks": total_points,
+                "languages": sorted(list(languages))
+            }
+
+        except Exception as e:
+            debug_info["collection_stats_error"] = str(e)
+
+        # Enhance results with debug info
+        enhanced_results = []
+        for result in results:
+            enhanced_result = {
+                "document": {
+                    "id": result.document.id,
+                    "content": result.document.content,
+                    "metadata": result.document.metadata
+                },
+                "score": result.score,
+                "debug_info": {}
+            }
+
+            # Add hybrid search debug info if available
+            metadata = result.document.metadata or {}
+            if "_vector_score" in metadata:
+                enhanced_result["debug_info"]["vector_score"] = metadata["_vector_score"]
+            if "_bm25_score" in metadata:
+                enhanced_result["debug_info"]["bm25_score"] = metadata["_bm25_score"]
+
+            enhanced_results.append(enhanced_result)
+
+        # Note: Analytics logging disabled (module not available)
+
         return {
             "success": True,
-            "stats": stats
+            "results": enhanced_results,
+            "debug_info": debug_info,
+            "search_time_ms": search_time,
+            "timestamp": start_time.isoformat()
         }
 
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+        # Note: Analytics logging disabled (module not available)
+        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
+
+    finally:
+        # Restore original config if modified
+        if config and 'original_config' in locals():
+            rag_module.config = original_config
+
+
+@router.get("/debug/config")
+async def get_current_config(
+    current_user: User = Depends(get_current_user)
+) -> Dict[str, Any]:
+    """Get current RAG configuration"""
+    # Get RAG module from module manager
+    rag_module = module_manager.modules.get('rag')
+    if not rag_module or not rag_module.enabled:
+        raise HTTPException(status_code=503, detail="RAG module not initialized")
+
+    return {
+        "config": rag_module.config,
+        "embedding_model": rag_module.embedding_model,
+        "enabled": rag_module.enabled,
+        "collections": await rag_module._get_collections_safely()
+    }
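For the new debug endpoints, a hypothetical client call sketch. The router prefix and auth are assumptions; the scalar parameters (`query`, `max_results`, `score_threshold`, `collection_name`) would be read by FastAPI from the query string, while the `config` dict parameter would typically be read from the JSON body:

    # Hypothetical client call - prefix, port, and auth scheme are assumed
    import httpx

    def debug_search(token: str) -> dict:
        resp = httpx.post(
            "http://localhost:8000/api/v1/rag/debug/search",  # assumed mount path
            params={
                "query": "how do I rotate api keys?",
                "max_results": 5,
                "score_threshold": 0.3,
            },
            # Keys below match what search_with_debug inspects in `config`
            json={"use_query_prefix": True, "debug": {"show_embeddings": True}},
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()  # results, debug_info, search_time_ms, timestamp

GET /debug/config takes no parameters beyond auth and returns the live module config, embedding model, enabled flag, and collection list.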