"""
|
|
RAG module implementation with vector database and document processing
|
|
Includes comprehensive document processing, content extraction, and NLP analysis
|
|
"""
|
|
import asyncio
|
|
import io
|
|
import json
|
|
import logging
|
|
import mimetypes
|
|
import re
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
from datetime import datetime
|
|
from dataclasses import dataclass, asdict
|
|
from pathlib import Path
|
|
import hashlib
|
|
import base64
|
|
import numpy as np
|
|
import uuid
|
|
|
|
# Initialize logger early
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Document processing libraries (with graceful fallbacks)
|
|
try:
|
|
import nltk
|
|
from nltk.tokenize import sent_tokenize, word_tokenize
|
|
from nltk.corpus import stopwords
|
|
from nltk.stem import WordNetLemmatizer
|
|
NLTK_AVAILABLE = True
|
|
except ImportError:
|
|
logger.warning("NLTK not available - NLP features will be limited")
|
|
NLTK_AVAILABLE = False
|
|
|
|
try:
|
|
import spacy
|
|
SPACY_AVAILABLE = True
|
|
except ImportError:
|
|
logger.warning("spaCy not available - entity extraction will be disabled")
|
|
SPACY_AVAILABLE = False
|
|
|
|
try:
|
|
from markitdown import MarkItDown
|
|
MARKITDOWN_AVAILABLE = True
|
|
except ImportError:
|
|
logger.warning("MarkItDown not available - document conversion will be limited")
|
|
MARKITDOWN_AVAILABLE = False
|
|
|
|
try:
|
|
from docx import Document as DocxDocument
|
|
PYTHON_DOCX_AVAILABLE = True
|
|
except ImportError:
|
|
logger.warning("python-docx not available - DOCX processing will be limited")
|
|
PYTHON_DOCX_AVAILABLE = False
|
|
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue
|
|
from qdrant_client.http import models
|
|
import tiktoken
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging import log_module_event
|
|
from app.services.base_module import BaseModule, Permission
|
|
|
|
|
|
@dataclass
|
|
class ProcessedDocument:
|
|
"""Processed document data structure"""
|
|
id: str
|
|
original_filename: str
|
|
file_type: str
|
|
mime_type: str
|
|
content: str
|
|
extracted_text: str
|
|
metadata: Dict[str, Any]
|
|
word_count: int
|
|
sentence_count: int
|
|
language: str
|
|
entities: List[Dict[str, Any]]
|
|
keywords: List[str]
|
|
processing_time: float
|
|
processed_at: datetime
|
|
file_hash: str
|
|
file_size: int
|
|
embedding: Optional[List[float]] = None
|
|
created_at: Optional[datetime] = None
|
|
|
|
def __post_init__(self):
|
|
if self.created_at is None:
|
|
self.created_at = datetime.utcnow()
|
|
|
|
|
|
@dataclass
|
|
class ContentValidationResult:
|
|
"""Content validation result"""
|
|
is_valid: bool
|
|
issues: List[str]
|
|
security_score: float
|
|
content_type: str
|
|
language_confidence: float
|
|
|
|
|
|
# Keep Document class for backward compatibility
|
|
@dataclass
|
|
class Document:
|
|
"""Simple document data structure for backward compatibility"""
|
|
id: str
|
|
content: str
|
|
metadata: Dict[str, Any]
|
|
embedding: Optional[List[float]] = None
|
|
created_at: Optional[datetime] = None
|
|
|
|
def __post_init__(self):
|
|
if self.created_at is None:
|
|
self.created_at = datetime.utcnow()
|
|
|
|
|
|
@dataclass
|
|
class SearchResult:
|
|
"""Search result data structure"""
|
|
document: Document
|
|
score: float
|
|
relevance_score: float
|
|
|
|
|
|
class RAGModule(BaseModule):
|
|
"""RAG module for document storage, retrieval, and augmented generation with integrated content processing"""
|
|
|
|
def __init__(self, config: Dict[str, Any] = None):
|
|
super().__init__(module_id="rag", config=config)
|
|
self.enabled = False
|
|
self.qdrant_client: Optional[QdrantClient] = None
|
|
self.default_collection_name = "documents" # Keep for backward compatibility
|
|
self.embedding_model = None
|
|
self.embedding_service = None
|
|
self.tokenizer = None
|
|
|
|
# Set improved default configuration
|
|
self.config = {
|
|
"chunk_size": 300, # Reduced from 400 for better precision
|
|
"chunk_overlap": 50, # Added overlap for context preservation
|
|
"max_results": 10,
|
|
"score_threshold": 0.3, # Increased from 0.0 to filter low-quality results
|
|
"enable_hybrid": True, # Enable hybrid search (vector + BM25)
|
|
"hybrid_weights": {"vector": 0.7, "bm25": 0.3} # Weight for hybrid scoring
|
|
}
|
|
# Update with any provided config
|
|
if config:
|
|
self.config.update(config)
|
|
|
|
# Content processing components
|
|
self.nlp_model = None
|
|
self.lemmatizer = None
|
|
self.stop_words = set()
|
|
self.markitdown = None
|
|
self.supported_types = {
|
|
'text/plain': self._process_text,
|
|
'application/pdf': self._process_with_markitdown,
|
|
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': self._process_docx,
|
|
'application/msword': self._process_docx,
|
|
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': self._process_with_markitdown,
|
|
'application/vnd.ms-excel': self._process_with_markitdown,
|
|
'text/html': self._process_html,
|
|
'application/json': self._process_json,
|
|
'application/x-ndjson': self._process_jsonl, # JSONL support
|
|
'text/markdown': self._process_markdown,
|
|
'text/csv': self._process_csv
|
|
}
|
|
|
|
self.stats = {
|
|
"documents_indexed": 0,
|
|
"documents_processed": 0,
|
|
"total_processing_time": 0,
|
|
"average_processing_time": 0,
|
|
"searches_performed": 0,
|
|
"average_search_time": 0.0,
|
|
"cache_hits": 0,
|
|
"errors": 0,
|
|
"supported_types": len(self.supported_types)
|
|
}
|
|
self.search_cache = {}
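# Illustrative sketch (not part of the original module): the defaults above can be
# overridden by passing a config dict, which is merged via dict.update(). The values
# shown here are hypothetical examples, not recommended settings.
#
#   rag = RAGModule(config={"chunk_size": 400, "score_threshold": 0.5})
#   rag.config["chunk_size"]       # -> 400 (overridden)
#   rag.config["chunk_overlap"]    # -> 50  (default preserved)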
|
|
|
|
def get_required_permissions(self) -> List[Permission]:
|
|
"""Return list of permissions this module requires"""
|
|
return [
|
|
Permission("documents", "index", "Index new documents"),
|
|
Permission("documents", "search", "Search documents"),
|
|
Permission("documents", "delete", "Delete documents"),
|
|
Permission("collections", "manage", "Manage collections"),
|
|
Permission("settings", "configure", "Configure RAG settings")
|
|
]
|
|
|
|
async def initialize(self):
|
|
"""Initialize the RAG module with content processing capabilities"""
|
|
|
|
try:
|
|
# Initialize Qdrant client
|
|
qdrant_host = getattr(settings, 'QDRANT_HOST', 'localhost')
|
|
qdrant_port = getattr(settings, 'QDRANT_PORT', 6333)
|
|
qdrant_url = f"http://{qdrant_host}:{qdrant_port}"
|
|
self.qdrant_client = QdrantClient(url=qdrant_url)
|
|
|
|
# Initialize tokenizer
|
|
self.tokenizer = tiktoken.get_encoding("cl100k_base")
|
|
|
|
# Initialize embedding model
|
|
self.embedding_model = await self._initialize_embedding_model()
|
|
|
|
# Initialize content processing components
|
|
await self._initialize_content_processing()
|
|
|
|
# Create default collection if it doesn't exist
|
|
await self._ensure_collection_exists(self.default_collection_name)
|
|
|
|
self.enabled = True
|
|
self.initialized = True
|
|
log_module_event("rag", "initialized", {
|
|
"vector_db": self.config.get("vector_db", "qdrant"),
|
|
"embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"),
|
|
"chunk_size": self.config.get("chunk_size", 400),
|
|
"max_results": self.config.get("max_results", 10),
|
|
"supported_file_types": list(self.supported_types.keys()),
|
|
"nltk_ready": True,
|
|
"spacy_ready": self.nlp_model is not None,
|
|
"markitdown_ready": self.markitdown is not None
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to initialize RAG module: {e}")
|
|
log_module_event("rag", "initialization_failed", {"error": str(e)})
|
|
self.enabled = False
|
|
raise
|
|
|
|
def _generate_file_hash(self, content: bytes) -> str:
|
|
"""Generate SHA-256 hash of file content"""
|
|
return hashlib.sha256(content).hexdigest()
|
|
|
|
def _detect_mime_type(self, filename: str, content: bytes) -> str:
|
|
"""Detect MIME type of file"""
|
|
# Try to detect from filename
|
|
mime_type, _ = mimetypes.guess_type(filename)
|
|
if mime_type:
|
|
return mime_type
|
|
|
|
# Check for JSONL file extension
|
|
if filename.lower().endswith('.jsonl'):
|
|
return 'application/x-ndjson'
|
|
|
|
# Try to detect from content
|
|
if content.startswith(b'%PDF'):
|
|
return 'application/pdf'
|
|
elif content.startswith(b'PK'):
|
|
# This could be DOCX, XLSX, or other Office formats
|
|
if filename.lower().endswith(('.docx', '.docm')):
|
|
return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
|
elif filename.lower().endswith(('.xlsx', '.xlsm')):
|
|
return 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
|
|
else:
|
|
return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
|
elif content.startswith(b'\xd0\xcf\x11\xe0'):
|
|
# Old Office format (DOC, XLS)
|
|
if filename.lower().endswith('.xls'):
|
|
return 'application/vnd.ms-excel'
|
|
else:
|
|
return 'application/msword'
|
|
elif content.startswith(b'<html') or content.startswith(b'<!DOCTYPE'):
|
|
return 'text/html'
|
|
elif content.startswith(b'{') or content.startswith(b'['):
|
|
# Check if it's JSONL by looking for newline-delimited JSON
|
|
try:
|
|
lines = content.decode('utf-8', errors='ignore').split('\n')
|
|
if len(lines) > 1 and all(line.strip().startswith('{') for line in lines[:3] if line.strip()):
|
|
return 'application/x-ndjson'
|
|
except Exception:
|
|
pass
|
|
return 'application/json'
|
|
else:
|
|
return 'text/plain'
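# Illustrative sketch of the detection order above, with made-up inputs: the filename
# guess wins first, then the .jsonl extension check, then magic-byte sniffing, with
# text/plain as the final fallback.
#
#   self._detect_mime_type("report.pdf", b"%PDF-1.7 ...")   # -> "application/pdf" (from filename)
#   self._detect_mime_type("export.jsonl", b'{"id": 1}')    # -> "application/x-ndjson"
#   self._detect_mime_type("blob", b"%PDF-1.7 ...")         # -> "application/pdf" (from magic bytes)
#   self._detect_mime_type("blob", b"plain words")          # -> "text/plain" (fallback)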
|
|
|
|
def _detect_language(self, text: str) -> Tuple[str, float]:
|
|
"""Detect language of text (simplified implementation)"""
|
|
if len(text) < 50:
|
|
return 'unknown', 0.0
|
|
|
|
# Simple heuristic based on common English words
|
|
english_words = {'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should', 'may', 'might', 'must', 'can', 'shall'}
|
|
|
|
if NLTK_AVAILABLE:
|
|
words = word_tokenize(text.lower())
|
|
else:
|
|
# Fallback to simple whitespace tokenization
|
|
words = text.lower().split()
|
|
|
|
english_count = sum(1 for word in words if word in english_words)
|
|
confidence = min(english_count / len(words), 1.0) if words else 0.0
|
|
|
|
return 'en' if confidence > 0.1 else 'unknown', confidence
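# Worked example of the heuristic above (values approximate, assuming simple
# whitespace tokenization): texts shorter than 50 characters return ('unknown', 0.0).
#
#   text = "the cat sat on the mat and the dog slept by the door"   # 52 chars, 13 words
#   # 7 of the 13 tokens ("the" x4, "on", "and", "by") are in english_words,
#   # so confidence ~= 7/13 ~= 0.54 > 0.1  ->  returns ('en', ~0.54)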
|
|
|
|
def _extract_entities(self, text: str) -> List[Dict[str, Any]]:
|
|
"""Extract named entities from text"""
|
|
if not self.nlp_model:
|
|
return []
|
|
|
|
try:
|
|
doc = self.nlp_model(text[:10000]) # Limit text length for performance
|
|
entities = []
|
|
|
|
for ent in doc.ents:
|
|
entities.append({
|
|
"text": ent.text,
|
|
"label": ent.label_,
|
|
"start": ent.start_char,
|
|
"end": ent.end_char,
|
|
"confidence": float(ent._.get("score", 0.0)) if hasattr(ent._, "score") else 0.0
|
|
})
|
|
|
|
return entities
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error extracting entities: {e}")
|
|
return []
|
|
|
|
def _extract_keywords(self, text: str, max_keywords: int = 20) -> List[str]:
|
|
"""Extract keywords from text"""
|
|
try:
|
|
if NLTK_AVAILABLE:
|
|
words = word_tokenize(text.lower())
|
|
else:
|
|
# Fallback to simple whitespace tokenization
|
|
words = text.lower().split()
|
|
|
|
words = [word for word in words if word.isalpha() and word not in self.stop_words]
|
|
|
|
if self.lemmatizer and NLTK_AVAILABLE:
|
|
words = [self.lemmatizer.lemmatize(word) for word in words]
|
|
|
|
# Simple frequency-based keyword extraction
|
|
word_freq = {}
|
|
for word in words:
|
|
word_freq[word] = word_freq.get(word, 0) + 1
|
|
|
|
# Sort by frequency and return top keywords
|
|
keywords = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
|
|
return [word for word, freq in keywords[:max_keywords] if freq > 1]
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error extracting keywords: {e}")
|
|
return []
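# Illustrative sketch of the frequency-based extraction above (hypothetical input):
# stopwords and non-alphabetic tokens are dropped, remaining words are lemmatized when
# NLTK is available, and only words occurring more than once survive the freq > 1 filter.
#
#   self._extract_keywords("apples apples and a banana")
#   # tokens after filtering -> ["apples", "apples", "banana"] (lemmatized: "apple" x2, "banana" x1)
#   # frequencies -> {"apple": 2, "banana": 1}  ->  returns ["apple"]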
|
|
|
|
def _clean_text(self, text: str) -> str:
|
|
"""Clean and normalize text"""
|
|
if not text:
|
|
return ""
|
|
|
|
# Remove excessive whitespace
|
|
text = re.sub(r'\s+', ' ', text)
|
|
|
|
# Remove control characters except newlines and tabs
|
|
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)
|
|
|
|
# Normalize quotes
|
|
text = re.sub(r'[\u201c\u201d\u2018\u2019`]', '"', text)
|
|
|
|
# Remove excessive punctuation
|
|
text = re.sub(r'[.]{3,}', '...', text)
|
|
text = re.sub(r'[!]{2,}', '!', text)
|
|
text = re.sub(r'[?]{2,}', '?', text)
|
|
|
|
return text.strip()
|
|
|
|
def _validate_content(self, content: str, file_type: str) -> ContentValidationResult:
|
|
"""Validate and score content for security and quality"""
|
|
issues = []
|
|
security_score = 100.0
|
|
|
|
# Check for potentially malicious content
|
|
if '<script' in content.lower() or 'javascript:' in content.lower():
|
|
issues.append("Potentially malicious JavaScript content detected")
|
|
security_score -= 30
|
|
|
|
if re.search(r'<iframe|<object|<embed', content, re.IGNORECASE):
|
|
issues.append("Embedded content detected")
|
|
security_score -= 20
|
|
|
|
# Check for suspicious URLs
|
|
if re.search(r'https?://[^\s]+\.(exe|bat|cmd|scr|vbs|js)', content, re.IGNORECASE):
|
|
issues.append("Suspicious executable URLs detected")
|
|
security_score -= 40
|
|
|
|
# Check content length
|
|
if len(content) > 1000000: # 1MB limit
|
|
issues.append("Content exceeds maximum size limit")
|
|
security_score -= 10
|
|
|
|
# Detect language
|
|
language, lang_confidence = self._detect_language(content)
|
|
|
|
return ContentValidationResult(
|
|
is_valid=len(issues) == 0,
|
|
issues=issues,
|
|
security_score=max(0, security_score),
|
|
content_type=file_type,
|
|
language_confidence=lang_confidence
|
|
)
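# Worked example of the scoring above (hypothetical content): each finding subtracts
# a fixed penalty from a starting score of 100, floored at 0.
#
#   content = '<script>alert(1)</script> plus an <iframe src="x">'
#   # "<script" match -> -30, "<iframe" match -> -20  =>  security_score = 50.0
#   # issues is non-empty, so is_valid is False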
|
|
|
|
async def cleanup(self):
|
|
"""Cleanup RAG resources"""
|
|
if self.qdrant_client:
|
|
self.qdrant_client.close()
|
|
self.qdrant_client = None
|
|
|
|
if self.embedding_service:
|
|
await self.embedding_service.cleanup()
|
|
self.embedding_service = None
|
|
|
|
# Cleanup content processing resources
|
|
self.nlp_model = None
|
|
self.lemmatizer = None
|
|
self.markitdown = None
|
|
self.stop_words.clear()
|
|
|
|
self.enabled = False
|
|
self.search_cache.clear()
|
|
log_module_event("rag", "cleanup", {"success": True})
|
|
|
|
async def _initialize_embedding_model(self):
|
|
"""Initialize embedding model"""
|
|
# Prefer enhanced embedding service (rate limiting + retry)
|
|
from app.services.enhanced_embedding_service import enhanced_embedding_service as embedding_service
|
|
|
|
# Use intfloat/multilingual-e5-large-instruct for LLM service integration
|
|
model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct")
|
|
embedding_service.model_name = model_name
|
|
|
|
# Initialize the embedding service
|
|
success = await embedding_service.initialize()
|
|
|
|
if success:
|
|
self.embedding_service = embedding_service
|
|
logger.info(f"Successfully initialized embedding service with {model_name}")
|
|
return {
|
|
"model_name": model_name,
|
|
"dimension": embedding_service.dimension or 768
|
|
}
|
|
else:
|
|
# Fallback to mock implementation
|
|
logger.warning("Failed to initialize embedding model, using fallback")
|
|
self.embedding_service = None
|
|
return {
|
|
"model_name": model_name,
|
|
"dimension": 1024 # Default dimension for intfloat/multilingual-e5-large-instruct
|
|
}
|
|
|
|
async def _initialize_content_processing(self):
|
|
"""Initialize content processing components"""
|
|
try:
|
|
# Download required NLTK data
|
|
await self._download_nltk_data()
|
|
|
|
# Initialize NLP components
|
|
if NLTK_AVAILABLE:
|
|
self.lemmatizer = WordNetLemmatizer()
|
|
self.stop_words = set(stopwords.words('english'))
|
|
else:
|
|
self.lemmatizer = None
|
|
self.stop_words = set()
|
|
|
|
# Initialize spaCy model
|
|
await self._initialize_spacy_model()
|
|
|
|
# Initialize MarkItDown
|
|
if MARKITDOWN_AVAILABLE:
|
|
self.markitdown = MarkItDown()
|
|
else:
|
|
self.markitdown = None
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Failed to initialize some content processing components: {e}")
|
|
|
|
async def _download_nltk_data(self):
|
|
"""Download required NLTK data"""
|
|
if not NLTK_AVAILABLE:
|
|
return
|
|
|
|
try:
|
|
nltk.download('punkt', quiet=True)
|
|
nltk.download('stopwords', quiet=True)
|
|
nltk.download('wordnet', quiet=True)
|
|
nltk.download('averaged_perceptron_tagger', quiet=True)
|
|
nltk.download('omw-1.4', quiet=True)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to download NLTK data: {e}")
|
|
|
|
async def _initialize_spacy_model(self):
|
|
"""Initialize spaCy model for NLP tasks"""
|
|
if not SPACY_AVAILABLE:
|
|
self.nlp_model = None
|
|
return
|
|
|
|
try:
|
|
self.nlp_model = spacy.load("en_core_web_sm")
|
|
except OSError:
|
|
logger.warning("spaCy model 'en_core_web_sm' not found. NLP features will be limited.")
|
|
self.nlp_model = None
|
|
|
|
async def _get_collections_safely(self) -> List[str]:
|
|
"""Get list of collections using raw HTTP to avoid Pydantic validation issues"""
|
|
try:
|
|
import httpx
|
|
qdrant_host = getattr(settings, 'QDRANT_HOST', 'localhost')
|
|
qdrant_port = getattr(settings, 'QDRANT_PORT', 6333)
|
|
qdrant_url = f"http://{qdrant_host}:{qdrant_port}"
|
|
|
|
async with httpx.AsyncClient() as client:
|
|
response = await client.get(f"{qdrant_url}/collections")
|
|
if response.status_code == 200:
|
|
data = response.json()
|
|
result = data.get("result", {})
|
|
collections = result.get("collections", [])
|
|
return [col.get("name", "") for col in collections if col.get("name")]
|
|
else:
|
|
logger.warning(f"Failed to get collections via HTTP: {response.status_code}")
|
|
return []
|
|
except Exception as e:
|
|
logger.error(f"Error getting collections safely: {e}")
|
|
# Fallback to direct client call with error handling
|
|
try:
|
|
collections = self.qdrant_client.get_collections()
|
|
return [col.name for col in collections.collections]
|
|
except Exception as fallback_error:
|
|
logger.error(f"Fallback collection fetch also failed: {fallback_error}")
|
|
return []
|
|
|
|
async def _get_collection_info_safely(self, collection_name: str) -> Dict[str, Any]:
|
|
"""Get collection information using raw HTTP to avoid Pydantic validation issues"""
|
|
try:
|
|
import httpx
|
|
qdrant_host = getattr(settings, 'QDRANT_HOST', 'localhost')
|
|
qdrant_port = getattr(settings, 'QDRANT_PORT', 6333)
|
|
qdrant_url = f"http://{qdrant_host}:{qdrant_port}"
|
|
|
|
async with httpx.AsyncClient() as client:
|
|
response = await client.get(f"{qdrant_url}/collections/{collection_name}")
|
|
if response.status_code == 200:
|
|
data = response.json()
|
|
result = data.get("result", {})
|
|
|
|
# Extract relevant information safely
|
|
collection_info = {
|
|
"points_count": result.get("points_count", 0),
|
|
"status": result.get("status", "unknown"),
|
|
"vector_size": 384 # Default fallback
|
|
}
|
|
|
|
# Try to get vector dimension from config
|
|
try:
|
|
config = result.get("config", {})
|
|
params = config.get("params", {})
|
|
vectors = params.get("vectors", {})
|
|
|
|
if isinstance(vectors, dict) and "size" in vectors:
|
|
collection_info["vector_size"] = vectors["size"]
|
|
elif isinstance(vectors, dict):
|
|
# Handle named vectors or default vector
|
|
if 'default' in vectors:
|
|
collection_info["vector_size"] = vectors['default'].get('size', 384)
|
|
else:
|
|
# Take first vector config if no default
|
|
first_vector = next(iter(vectors.values()), {})
|
|
collection_info["vector_size"] = first_vector.get('size', 384)
|
|
except Exception:
|
|
# Keep default fallback
|
|
pass
|
|
|
|
return collection_info
|
|
else:
|
|
logger.warning(f"Failed to get collection info via HTTP: {response.status_code}")
|
|
return {"points_count": 0, "status": "error", "vector_size": 384}
|
|
except Exception as e:
|
|
logger.error(f"Error getting collection info safely: {e}")
|
|
return {"points_count": 0, "status": "error", "vector_size": 384}
|
|
|
|
async def _ensure_collection_exists(self, collection_name: str = None):
|
|
"""Ensure the specified collection exists"""
|
|
collection_name = collection_name or self.default_collection_name
|
|
|
|
try:
|
|
# Use safe collection fetching to avoid Pydantic validation errors
|
|
collection_names = await self._get_collections_safely()
|
|
|
|
if collection_name not in collection_names:
|
|
# Create collection
|
|
self.qdrant_client.create_collection(
|
|
collection_name=collection_name,
|
|
vectors_config=VectorParams(
|
|
size=self.embedding_model.get("dimension", 768),
|
|
distance=Distance.COSINE
|
|
)
|
|
)
|
|
log_module_event("rag", "collection_created", {"collection": collection_name})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error ensuring collection exists: {e}")
|
|
raise
|
|
|
|
async def create_collection(self, collection_name: str) -> bool:
|
|
"""Create a new Qdrant collection"""
|
|
try:
|
|
await self._ensure_collection_exists(collection_name)
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error creating collection {collection_name}: {e}")
|
|
return False
|
|
|
|
async def delete_collection(self, collection_name: str) -> bool:
|
|
"""Delete a Qdrant collection"""
|
|
try:
|
|
# Use safe collection fetching to avoid Pydantic validation errors
|
|
collection_names = await self._get_collections_safely()
|
|
|
|
if collection_name in collection_names:
|
|
self.qdrant_client.delete_collection(collection_name)
|
|
log_module_event("rag", "collection_deleted", {"collection": collection_name})
|
|
return True
|
|
else:
|
|
logger.warning(f"Collection {collection_name} does not exist")
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error deleting collection {collection_name}: {e}")
|
|
return False
|
|
|
|
async def _generate_embedding(self, text: str) -> List[float]:
|
|
"""Generate embedding for text"""
|
|
if self.embedding_service:
|
|
# Use real embedding service
|
|
return await self.embedding_service.get_embedding(text)
|
|
else:
|
|
# Fallback to deterministic random embedding for consistency
|
|
np.random.seed(hash(text) % 2**32)
|
|
return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
|
|
|
|
async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
|
|
"""Generate embeddings for multiple texts (batch processing)"""
|
|
if self.embedding_service:
|
|
# Add task-specific prefixes for better E5 model performance
|
|
if is_document:
|
|
# For document passages, use "passage:" prefix
|
|
prefixed_texts = [f"passage: {text}" for text in texts]
|
|
else:
|
|
# For queries, use "query:" prefix (handled in search method)
|
|
prefixed_texts = texts
|
|
|
|
# Use real embedding service for batch processing
|
|
return await self.embedding_service.get_embeddings(prefixed_texts)
|
|
else:
|
|
# Fallback to individual processing
|
|
embeddings = []
|
|
for text in texts:
|
|
embedding = await self._generate_embedding(text)
|
|
embeddings.append(embedding)
|
|
return embeddings
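# Illustrative sketch of the E5-style prefixing above (hypothetical strings): document
# chunks are embedded as "passage: ...", while search_documents() prepends "query: "
# to the user query before calling _generate_embedding().
#
#   await self._generate_embeddings(["refund policy details"], is_document=True)
#   # -> embeds "passage: refund policy details"
#   # search side: _generate_embedding("query: what is the refund policy?")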
|
|
|
|
def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
|
|
"""Split text into overlapping chunks for better context preservation"""
|
|
chunk_size = chunk_size or self.config.get("chunk_size", 300)
|
|
chunk_overlap = self.config.get("chunk_overlap", 50)
|
|
|
|
# Tokenize text
|
|
tokens = self.tokenizer.encode(text)
|
|
|
|
# Split into chunks with overlap
|
|
chunks = []
|
|
start_idx = 0
|
|
|
|
while start_idx < len(tokens):
|
|
end_idx = min(start_idx + chunk_size, len(tokens))
|
|
chunk_tokens = tokens[start_idx:end_idx]
|
|
chunk_text = self.tokenizer.decode(chunk_tokens)
|
|
|
|
# Only add non-empty chunks
|
|
if chunk_text.strip():
|
|
chunks.append(chunk_text)
|
|
|
|
# Move to next chunk with overlap
|
|
start_idx = end_idx - chunk_overlap
|
|
|
|
# Ensure progress (in case overlap >= chunk_size)
|
|
if start_idx >= end_idx:
|
|
start_idx = end_idx
|
|
|
|
return chunks
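# Worked example of the windowing above, assuming the defaults chunk_size=300 and
# chunk_overlap=50 (token indices, not characters):
#
#   a 700-token document yields windows [0, 300), [250, 550), [500, 700)
#   # each next window starts at end_idx - 50, so 50 tokens of context are shared
#   # between neighbouring chunks; the progress guard inside the loop prevents an
#   # infinite loop if overlap were ever configured >= chunk_size.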
|
|
|
|
async def _process_text(self, content: bytes, filename: str) -> str:
|
|
"""Process plain text files"""
|
|
try:
|
|
# Try different encodings
|
|
for encoding in ['utf-8', 'latin-1', 'cp1252']:
|
|
try:
|
|
return content.decode(encoding)
|
|
except UnicodeDecodeError:
|
|
continue
|
|
|
|
# Fallback to utf-8 with error handling
|
|
return content.decode('utf-8', errors='replace')
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing text file: {e}")
|
|
return ""
|
|
|
|
async def _process_with_markitdown(self, content: bytes, filename: str) -> str:
|
|
"""Process documents using MarkItDown (PDF, DOCX, DOC, XLSX, XLS)"""
|
|
try:
|
|
if not self.markitdown:
|
|
raise RuntimeError("MarkItDown not initialized")
|
|
|
|
# Create a temporary file path for the content
|
|
import tempfile
|
|
import os
|
|
|
|
# Get file extension from filename
|
|
file_ext = Path(filename).suffix.lower()
|
|
if not file_ext:
|
|
# Try to determine extension from mime type
|
|
mime_type = self._detect_mime_type(filename, content)
|
|
if mime_type == 'application/pdf':
|
|
file_ext = '.pdf'
|
|
elif mime_type in ['application/vnd.openxmlformats-officedocument.wordprocessingml.document']:
|
|
file_ext = '.docx'
|
|
elif mime_type == 'application/msword':
|
|
file_ext = '.doc'
|
|
elif mime_type == 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet':
|
|
file_ext = '.xlsx'
|
|
elif mime_type == 'application/vnd.ms-excel':
|
|
file_ext = '.xls'
|
|
else:
|
|
file_ext = '.bin'
|
|
|
|
# Write content to temporary file
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
|
|
temp_file.write(content)
|
|
temp_path = temp_file.name
|
|
|
|
try:
|
|
# Convert document to markdown using MarkItDown in a thread pool to avoid blocking
|
|
import concurrent.futures
|
|
import asyncio
|
|
|
|
logger.info(f"Starting MarkItDown conversion for {filename}")
|
|
|
|
def convert_sync():
|
|
"""Synchronous conversion function to run in thread pool"""
|
|
return self.markitdown.convert(temp_path)
|
|
|
|
# Run the synchronous conversion in a thread pool with timeout
|
|
loop = asyncio.get_event_loop()
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
try:
|
|
result = await asyncio.wait_for(
|
|
loop.run_in_executor(executor, convert_sync),
|
|
timeout=120.0 # 2 minute timeout for MarkItDown conversion
|
|
)
|
|
except asyncio.TimeoutError:
|
|
logger.error(f"MarkItDown conversion timed out for {filename}")
|
|
raise RuntimeError(f"Document conversion timed out after 2 minutes for {filename}")
|
|
|
|
if result and hasattr(result, 'text_content'):
|
|
converted_text = result.text_content
|
|
elif result and isinstance(result, str):
|
|
converted_text = result
|
|
else:
|
|
# Fallback if result format is unexpected
|
|
converted_text = str(result) if result else ""
|
|
|
|
logger.info(f"Successfully converted {filename} using MarkItDown ({len(converted_text)} characters)")
|
|
return converted_text
|
|
|
|
finally:
|
|
# Clean up temporary file
|
|
try:
|
|
os.unlink(temp_path)
|
|
except OSError:
|
|
pass
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing {filename} with MarkItDown: {e}")
|
|
# Fallback to basic text extraction attempt
|
|
try:
|
|
return content.decode('utf-8', errors='replace')
|
|
except Exception:
|
|
return f"Error processing {filename}: {str(e)}"
|
|
|
|
async def _process_docx(self, content: bytes, filename: str) -> str:
|
|
"""Process DOCX files using python-docx (more reliable than MarkItDown)"""
|
|
try:
|
|
if not PYTHON_DOCX_AVAILABLE:
|
|
logger.warning(f"python-docx not available, falling back to MarkItDown for {filename}")
|
|
return await self._process_with_markitdown(content, filename)
|
|
|
|
# Create a temporary file for python-docx processing
|
|
import tempfile
|
|
import os
|
|
|
|
logger.info(f"Starting DOCX processing for {filename} using python-docx")
|
|
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as temp_file:
|
|
temp_file.write(content)
|
|
temp_path = temp_file.name
|
|
|
|
try:
|
|
# Process in a thread pool to avoid blocking
|
|
import concurrent.futures
|
|
import asyncio
|
|
|
|
def extract_docx_text():
|
|
"""Extract text from DOCX file synchronously"""
|
|
doc = DocxDocument(temp_path)
|
|
text_parts = []
|
|
|
|
# Extract paragraphs
|
|
for paragraph in doc.paragraphs:
|
|
if paragraph.text.strip():
|
|
text_parts.append(paragraph.text.strip())
|
|
|
|
# Extract text from tables
|
|
for table in doc.tables:
|
|
for row in table.rows:
|
|
row_text = []
|
|
for cell in row.cells:
|
|
if cell.text.strip():
|
|
row_text.append(cell.text.strip())
|
|
if row_text:
|
|
text_parts.append(" | ".join(row_text))
|
|
|
|
return "\n\n".join(text_parts)
|
|
|
|
# Run extraction in thread pool with timeout
|
|
loop = asyncio.get_event_loop()
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
try:
|
|
extracted_text = await asyncio.wait_for(
|
|
loop.run_in_executor(executor, extract_docx_text),
|
|
timeout=30.0 # 30 second timeout for DOCX processing
|
|
)
|
|
except asyncio.TimeoutError:
|
|
logger.error(f"DOCX processing timed out for {filename}")
|
|
raise RuntimeError(f"DOCX processing timed out after 30 seconds for {filename}")
|
|
|
|
logger.info(f"Successfully processed {filename} using python-docx ({len(extracted_text)} characters)")
|
|
return extracted_text
|
|
|
|
finally:
|
|
# Clean up temporary file
|
|
try:
|
|
os.unlink(temp_path)
|
|
except OSError:
|
|
pass
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing DOCX file {filename}: {e}")
|
|
# Fallback to MarkItDown if python-docx fails
|
|
try:
|
|
logger.info(f"Falling back to MarkItDown for {filename}")
|
|
return await self._process_with_markitdown(content, filename)
|
|
except Exception as fallback_error:
|
|
logger.error(f"Both python-docx and MarkItDown failed for {filename}: {fallback_error}")
|
|
return f"Error processing DOCX {filename}: {str(e)}"
|
|
|
|
async def _process_html(self, content: bytes, filename: str) -> str:
|
|
"""Process HTML files"""
|
|
try:
|
|
html_content = content.decode('utf-8', errors='replace')
|
|
# Simple HTML tag removal
|
|
text = re.sub(r'<[^>]+>', '', html_content)
|
|
# Decode HTML entities
|
|
text = text.replace('&amp;', '&').replace('&lt;', '<').replace('&gt;', '>').replace('&quot;', '"').replace('&#39;', "'")
|
|
return text
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing HTML file: {e}")
|
|
return ""
|
|
|
|
async def _process_json(self, content: bytes, filename: str) -> str:
|
|
"""Process JSON files"""
|
|
try:
|
|
json_data = json.loads(content.decode('utf-8'))
|
|
# Convert JSON to readable text
|
|
return json.dumps(json_data, indent=2)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing JSON file: {e}")
|
|
return ""
|
|
|
|
async def _process_markdown(self, content: bytes, filename: str) -> str:
|
|
"""Process Markdown files"""
|
|
try:
|
|
md_content = content.decode('utf-8', errors='replace')
|
|
# Simple markdown processing - remove formatting
|
|
text = re.sub(r'#+\s*', '', md_content) # Remove headers
|
|
text = re.sub(r'\*\*(.+?)\*\*', r'\1', text) # Bold
|
|
text = re.sub(r'\*(.+?)\*', r'\1', text) # Italic
|
|
text = re.sub(r'`(.+?)`', r'\1', text) # Code
|
|
text = re.sub(r'\[(.+?)\]\(.+?\)', r'\1', text) # Links
|
|
return text
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing Markdown file: {e}")
|
|
return ""
|
|
|
|
async def _process_csv(self, content: bytes, filename: str) -> str:
|
|
"""Process CSV files"""
|
|
try:
|
|
csv_content = content.decode('utf-8', errors='replace')
|
|
# Convert CSV to readable text
|
|
lines = csv_content.split('\n')
|
|
processed_lines = []
|
|
|
|
for line in lines[:100]: # Limit to first 100 lines
|
|
if line.strip():
|
|
processed_lines.append(line.replace(',', ' | '))
|
|
|
|
return '\n'.join(processed_lines)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing CSV file: {e}")
|
|
return ""
|
|
|
|
async def _process_jsonl(self, content: bytes, filename: str) -> str:
|
|
"""Process JSONL files (newline-delimited JSON)
|
|
|
|
Specifically optimized for helpjuice-export.jsonl format:
|
|
- Each line contains a JSON object with 'id' and 'payload'
|
|
- Payload contains 'question', 'language', and 'answer' fields
|
|
- Combines question and answer into searchable content
|
|
|
|
Performance optimizations:
|
|
- Processes articles in smaller batches to reduce memory usage
|
|
- Uses streaming approach for large files
|
|
"""
|
|
try:
|
|
# Use streaming approach for large files
|
|
jsonl_content = content.decode('utf-8', errors='replace')
|
|
lines = jsonl_content.strip().split('\n')
|
|
|
|
processed_articles = []
|
|
batch_size = 50 # Process in batches of 50 articles
|
|
|
|
for line_num, line in enumerate(lines, 1):
|
|
if not line.strip():
|
|
continue
|
|
|
|
try:
|
|
# Parse each JSON line
|
|
data = json.loads(line)
|
|
|
|
# Handle helpjuice export format
|
|
if 'payload' in data:
|
|
payload = data['payload']
|
|
article_id = data.get('id', f'article_{line_num}')
|
|
|
|
# Extract fields
|
|
question = payload.get('question', '')
|
|
answer = payload.get('answer', '')
|
|
language = payload.get('language', 'EN')
|
|
|
|
# Combine question and answer for better search
|
|
if question or answer:
|
|
# Format as Q&A for better context
|
|
article_text = f"## {question}\n\n{answer}\n\n"
|
|
|
|
# Add language tag if not English
|
|
if language != 'EN':
|
|
article_text = f"[{language}] {article_text}"
|
|
|
|
# Add metadata separator
|
|
article_text += f"---\nArticle ID: {article_id}\nLanguage: {language}\n\n"
|
|
|
|
processed_articles.append(article_text)
|
|
|
|
# Handle generic JSONL format
|
|
else:
|
|
# Convert the entire JSON object to readable text
|
|
json_text = json.dumps(data, indent=2, ensure_ascii=False)
|
|
processed_articles.append(json_text + "\n\n")
|
|
|
|
except json.JSONDecodeError as e:
|
|
logger.warning(f"Error parsing JSONL line {line_num}: {e}")
|
|
continue
|
|
except Exception as e:
|
|
logger.warning(f"Error processing JSONL line {line_num}: {e}")
|
|
continue
|
|
|
|
# Combine all articles
|
|
combined_text = '\n'.join(processed_articles)
|
|
|
|
logger.info(f"Successfully processed {len(processed_articles)} articles from JSONL file {filename}")
|
|
return combined_text
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing JSONL file {filename}: {e}")
|
|
return ""
|
|
|
|
def _generate_document_id(self, content: str, metadata: Dict[str, Any]) -> str:
|
|
"""Generate unique document ID"""
|
|
content_hash = hashlib.sha256(content.encode()).hexdigest()[:16]
|
|
metadata_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode()).hexdigest()[:8]
|
|
return f"{content_hash}_{metadata_hash}"
|
|
|
|
async def process_document(self, file_data: bytes, filename: str, metadata: Dict[str, Any] = None) -> ProcessedDocument:
|
|
"""Process a document and extract content"""
|
|
if not self.enabled:
|
|
raise RuntimeError("RAG module not initialized")
|
|
|
|
import time
|
|
start_time = time.time()
|
|
|
|
try:
|
|
logger.info(f"Starting document processing pipeline for {filename}")
|
|
|
|
# Generate file hash and ID
|
|
file_hash = self._generate_file_hash(file_data)
|
|
doc_id = f"{file_hash}_{int(time.time())}"
|
|
logger.info(f"Generated document ID: {doc_id}")
|
|
|
|
# Detect MIME type
|
|
mime_type = self._detect_mime_type(filename, file_data)
|
|
file_type = mime_type.split('/')[0]
|
|
logger.info(f"Detected MIME type: {mime_type}, file type: {file_type}")
|
|
|
|
# Check if file type is supported
|
|
if mime_type not in self.supported_types:
|
|
raise ValueError(f"Unsupported file type: {mime_type}")
|
|
|
|
# Extract content using appropriate processor
|
|
processor = self.supported_types[mime_type]
|
|
logger.info(f"Using processor: {processor.__name__} for {filename}")
|
|
extracted_text = await processor(file_data, filename)
|
|
logger.info(f"Content extraction completed for {filename}, extracted {len(extracted_text)} characters")
|
|
|
|
# Clean the extracted text
|
|
logger.info(f"Starting text cleaning for {filename}")
|
|
cleaned_text = self._clean_text(extracted_text)
|
|
logger.info(f"Text cleaning completed for {filename}, final text length: {len(cleaned_text)}")
|
|
|
|
# Validate content
|
|
logger.info(f"Starting content validation for {filename}")
|
|
validation_result = self._validate_content(cleaned_text, file_type)
|
|
logger.info(f"Content validation completed for {filename}")
|
|
|
|
if not validation_result.is_valid:
|
|
logger.warning(f"Content validation issues: {validation_result.issues}")
|
|
|
|
# Extract linguistic features
|
|
logger.info(f"Starting linguistic analysis for {filename}")
|
|
if NLTK_AVAILABLE and cleaned_text:
|
|
logger.info(f"Using NLTK for tokenization of {filename}")
|
|
sentences = sent_tokenize(cleaned_text)
|
|
words = word_tokenize(cleaned_text)
|
|
elif cleaned_text:
|
|
logger.info(f"Using fallback tokenization for {filename}")
|
|
# Fallback to simple tokenization
|
|
sentences = cleaned_text.split('.')
|
|
words = cleaned_text.split()
|
|
else:
|
|
logger.warning(f"No text content for linguistic analysis in {filename}")
|
|
sentences = []
|
|
words = []
|
|
|
|
logger.info(f"Tokenization completed for {filename}: {len(sentences)} sentences, {len(words)} words")
|
|
|
|
# Detect language
|
|
logger.info(f"Starting language detection for {filename}")
|
|
language, lang_confidence = self._detect_language(cleaned_text)
|
|
logger.info(f"Language detection completed for {filename}: {language} (confidence: {lang_confidence:.2f})")
|
|
|
|
# Extract entities and keywords
|
|
logger.info(f"Starting entity extraction for {filename}")
|
|
entities = self._extract_entities(cleaned_text)
|
|
logger.info(f"Entity extraction completed for {filename}: found {len(entities)} entities")
|
|
|
|
logger.info(f"Starting keyword extraction for {filename}")
|
|
keywords = self._extract_keywords(cleaned_text)
|
|
logger.info(f"Keyword extraction completed for {filename}: found {len(keywords)} keywords")
|
|
|
|
# Calculate processing time
|
|
processing_time = time.time() - start_time
|
|
|
|
# Create processed document
|
|
logger.info(f"Creating ProcessedDocument object for {filename}")
|
|
processed_doc = ProcessedDocument(
|
|
id=doc_id,
|
|
original_filename=filename,
|
|
file_type=file_type,
|
|
mime_type=mime_type,
|
|
content=cleaned_text,
|
|
extracted_text=extracted_text,
|
|
metadata={
|
|
**(metadata or {}),
|
|
"validation": asdict(validation_result),
|
|
"file_size": len(file_data),
|
|
"processing_stats": {
|
|
"processing_time": processing_time,
|
|
"processor_used": processor.__name__
|
|
}
|
|
},
|
|
word_count=len(words),
|
|
sentence_count=len(sentences),
|
|
language=language,
|
|
entities=entities,
|
|
keywords=keywords,
|
|
processing_time=processing_time,
|
|
processed_at=datetime.utcnow(),
|
|
file_hash=file_hash,
|
|
file_size=len(file_data)
|
|
)
|
|
logger.info(f"ProcessedDocument created for {filename}")
|
|
|
|
# Update stats
|
|
self.stats["documents_processed"] += 1
|
|
self.stats["total_processing_time"] += processing_time
|
|
self.stats["average_processing_time"] = (
|
|
self.stats["total_processing_time"] / self.stats["documents_processed"]
|
|
)
|
|
|
|
log_module_event("rag", "document_processed", {
|
|
"document_id": doc_id,
|
|
"filename": filename,
|
|
"file_type": file_type,
|
|
"word_count": len(words),
|
|
"processing_time": processing_time,
|
|
"language": language,
|
|
"entities_count": len(entities),
|
|
"keywords_count": len(keywords)
|
|
})
|
|
|
|
logger.info(f"Document processing completed successfully for {filename} in {processing_time:.2f} seconds")
|
|
return processed_doc
|
|
|
|
except Exception as e:
|
|
self.stats["errors"] += 1
|
|
logger.error(f"Error processing document {filename}: {e}")
|
|
log_module_event("rag", "processing_failed", {
|
|
"filename": filename,
|
|
"error": str(e)
|
|
})
|
|
raise
|
|
|
|
async def index_document(self, content: str, metadata: Dict[str, Any] = None, collection_name: str = None) -> str:
|
|
"""Index a document in the vector database (backward compatibility method)"""
|
|
if not self.enabled:
|
|
raise RuntimeError("RAG module not initialized")
|
|
|
|
collection_name = collection_name or self.default_collection_name
|
|
metadata = metadata or {}
|
|
|
|
try:
|
|
# Ensure collection exists
|
|
await self._ensure_collection_exists(collection_name)
|
|
|
|
# Generate document ID
|
|
doc_id = self._generate_document_id(content, metadata)
|
|
|
|
# Check if document already exists
|
|
if await self._document_exists(doc_id, collection_name):
|
|
log_module_event("rag", "document_exists", {"document_id": doc_id, "collection": collection_name})
|
|
return doc_id
|
|
|
|
# Chunk the document
|
|
chunks = self._chunk_text(content)
|
|
|
|
# Generate embeddings for all chunks in batch (more efficient)
|
|
embeddings = await self._generate_embeddings(chunks, is_document=True)
|
|
|
|
# Create document points
|
|
points = []
|
|
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
|
chunk_id = str(uuid.uuid4())
|
|
|
|
chunk_metadata = {
|
|
**metadata,
|
|
"document_id": doc_id,
|
|
"chunk_index": i,
|
|
"chunk_count": len(chunks),
|
|
"content": chunk,
|
|
"indexed_at": datetime.utcnow().isoformat()
|
|
}
|
|
|
|
points.append(PointStruct(
|
|
id=chunk_id,
|
|
vector=embedding,
|
|
payload=chunk_metadata
|
|
))
|
|
|
|
# Insert points into Qdrant
|
|
self.qdrant_client.upsert(
|
|
collection_name=collection_name,
|
|
points=points
|
|
)
|
|
|
|
self.stats["documents_indexed"] += 1
|
|
log_module_event("rag", "document_indexed", {
|
|
"document_id": doc_id,
|
|
"collection": collection_name,
|
|
"chunks": len(chunks),
|
|
"metadata": metadata
|
|
})
|
|
|
|
return doc_id
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error indexing document: {e}")
|
|
log_module_event("rag", "indexing_failed", {"error": str(e)})
|
|
raise
|
|
|
|
async def index_processed_document(self, processed_doc: ProcessedDocument, collection_name: str = None) -> str:
|
|
"""Index a processed document in the vector database"""
|
|
if not self.enabled:
|
|
raise RuntimeError("RAG module not initialized")
|
|
|
|
collection_name = collection_name or self.default_collection_name
|
|
|
|
try:
|
|
# Special handling for JSONL files
|
|
if processed_doc.file_type == 'jsonl':
|
|
# Import the optimized JSONL processor
|
|
from app.services.jsonl_processor import JSONLProcessor
|
|
jsonl_processor = JSONLProcessor(self)
|
|
|
|
# Read the original file content
|
|
with open(processed_doc.metadata.get('file_path', ''), 'rb') as f:
|
|
file_content = f.read()
|
|
|
|
# Process using the optimized JSONL processor
|
|
return await jsonl_processor.process_and_index_jsonl(
|
|
collection_name=collection_name,
|
|
content=file_content,
|
|
filename=processed_doc.original_filename,
|
|
metadata=processed_doc.metadata
|
|
)
|
|
|
|
# Ensure collection exists
|
|
await self._ensure_collection_exists(collection_name)
|
|
|
|
# Check if document already exists
|
|
if await self._document_exists(processed_doc.id, collection_name):
|
|
log_module_event("rag", "document_exists", {"document_id": processed_doc.id, "collection": collection_name})
|
|
return processed_doc.id
|
|
|
|
# Chunk the document
|
|
chunks = self._chunk_text(processed_doc.content)
|
|
|
|
# Generate embeddings for all chunks in batch (more efficient)
|
|
embeddings = await self._generate_embeddings(chunks, is_document=True)
|
|
|
|
# Create document points with enhanced metadata
|
|
points = []
|
|
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
|
chunk_id = str(uuid.uuid4())
|
|
|
|
chunk_metadata = {
|
|
**processed_doc.metadata,
|
|
"document_id": processed_doc.id,
|
|
"original_filename": processed_doc.original_filename,
|
|
"file_type": processed_doc.file_type,
|
|
"mime_type": processed_doc.mime_type,
|
|
"language": processed_doc.language,
|
|
"entities": processed_doc.entities,
|
|
"keywords": processed_doc.keywords,
|
|
"word_count": processed_doc.word_count,
|
|
"sentence_count": processed_doc.sentence_count,
|
|
"file_hash": processed_doc.file_hash,
|
|
"processed_at": processed_doc.processed_at.isoformat(),
|
|
"chunk_index": i,
|
|
"chunk_count": len(chunks),
|
|
"content": chunk,
|
|
"indexed_at": datetime.utcnow().isoformat()
|
|
}
|
|
|
|
points.append(PointStruct(
|
|
id=chunk_id,
|
|
vector=embedding,
|
|
payload=chunk_metadata
|
|
))
|
|
|
|
# Insert points into Qdrant
|
|
self.qdrant_client.upsert(
|
|
collection_name=collection_name,
|
|
points=points
|
|
)
|
|
|
|
self.stats["documents_indexed"] += 1
|
|
log_module_event("rag", "processed_document_indexed", {
|
|
"document_id": processed_doc.id,
|
|
"filename": processed_doc.original_filename,
|
|
"collection": collection_name,
|
|
"chunks": len(chunks),
|
|
"file_type": processed_doc.file_type,
|
|
"language": processed_doc.language
|
|
})
|
|
|
|
return processed_doc.id
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error indexing processed document: {e}")
|
|
log_module_event("rag", "indexing_failed", {"error": str(e)})
|
|
raise
|
|
|
|
async def _document_exists(self, document_id: str, collection_name: str = None) -> bool:
|
|
"""Check if document exists in the collection"""
|
|
collection_name = collection_name or self.default_collection_name
|
|
|
|
try:
|
|
result = self.qdrant_client.search(
|
|
collection_name=collection_name,
|
|
query_filter=Filter(
|
|
must=[FieldCondition(key="document_id", match=MatchValue(value=document_id))]
|
|
),
|
|
limit=1
|
|
)
|
|
return len(result) > 0
|
|
except Exception:
|
|
return False
|
|
|
|
async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float],
|
|
query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]:
|
|
"""Perform hybrid search combining vector similarity and BM25 scoring"""
|
|
|
|
# Preprocess query for BM25
|
|
query_terms = self._preprocess_text_for_bm25(query)
|
|
|
|
# Get all documents from the collection (for BM25 scoring)
|
|
# Note: In production, you'd want to optimize this with a proper BM25 index
|
|
scroll_filter = query_filter or Filter()
|
|
all_points = []
|
|
|
|
# Use scroll to get all points
|
|
offset = None
|
|
batch_size = 100
|
|
while True:
|
|
search_result = self.qdrant_client.scroll(
|
|
collection_name=collection_name,
|
|
scroll_filter=scroll_filter,
|
|
limit=batch_size,
|
|
offset=offset,
|
|
with_payload=True,
|
|
with_vectors=False
|
|
)
|
|
|
|
points = search_result[0]
|
|
all_points.extend(points)
|
|
|
|
if len(points) < batch_size:
|
|
break
|
|
|
|
offset = points[-1].id
|
|
|
|
# Calculate BM25 scores for each document
|
|
bm25_scores = {}
|
|
for point in all_points:
|
|
doc_id = point.payload.get("document_id", "")
|
|
content = point.payload.get("content", "")
|
|
|
|
# Calculate BM25 score
|
|
bm25_score = self._calculate_bm25_score(query_terms, content)
|
|
bm25_scores[doc_id] = bm25_score
|
|
|
|
# Perform vector search
|
|
vector_results = self.qdrant_client.search(
|
|
collection_name=collection_name,
|
|
query_vector=query_vector,
|
|
query_filter=query_filter,
|
|
limit=limit * 2, # Get more results for re-ranking
|
|
score_threshold=score_threshold / 2 # Lower threshold for initial search
|
|
)
|
|
|
|
# Combine scores with improved normalization
|
|
hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
|
|
vector_weight = hybrid_weights.get("vector", 0.7)
|
|
bm25_weight = hybrid_weights.get("bm25", 0.3)
|
|
|
|
# Get score distributions for better normalization
|
|
vector_scores = [r.score for r in vector_results]
|
|
bm25_scores_list = list(bm25_scores.values())
|
|
|
|
# Calculate statistics for normalization
|
|
if vector_scores:
|
|
v_max = max(vector_scores)
|
|
v_min = min(vector_scores)
|
|
v_range = v_max - v_min if v_max != v_min else 1
|
|
else:
|
|
v_max, v_min, v_range = 1, 0, 1
|
|
|
|
if bm25_scores_list:
|
|
bm25_max = max(bm25_scores_list)
|
|
bm25_min = min(bm25_scores_list)
|
|
bm25_range = bm25_max - bm25_min if bm25_max != bm25_min else 1
|
|
else:
|
|
bm25_max, bm25_min, bm25_range = 1, 0, 1
|
|
|
|
# Create hybrid results with improved scoring
|
|
hybrid_results = []
|
|
for result in vector_results:
|
|
doc_id = result.payload.get("document_id", "")
|
|
vector_score = result.score
|
|
bm25_score = bm25_scores.get(doc_id, 0.0)
|
|
|
|
# Improved normalization using actual score distributions
|
|
vector_norm = (vector_score - v_min) / v_range if v_range > 0 else 0.5
|
|
bm25_norm = (bm25_score - bm25_min) / bm25_range if bm25_range > 0 else 0.5
|
|
|
|
# Apply reciprocal rank fusion for better combination
|
|
# This gives more weight to documents that rank highly in both methods
|
|
rrf_vector = 1.0 / (1.0 + vector_results.index(result) + 1) # +1 to avoid division by zero
|
|
rrf_bm25 = 1.0 / (1.0 + sorted(bm25_scores_list, reverse=True).index(bm25_score) + 1) if bm25_score in bm25_scores_list else 0
|
|
|
|
# Calculate hybrid score using both normalized scores and RRF
|
|
hybrid_score = (vector_weight * vector_norm + bm25_weight * bm25_norm) * 0.7 + (rrf_vector + rrf_bm25) * 0.3
|
|
|
|
# Create new point with hybrid score
|
|
hybrid_point = ScoredPoint(
    id=result.id,
    version=0,  # assumed placeholder; recent qdrant-client releases require this field on ScoredPoint
    payload=result.payload,
    score=hybrid_score,
    vector=result.vector,
    shard_key=None,
    order_value=None
)
|
|
hybrid_results.append(hybrid_point)
|
|
|
|
# Sort by hybrid score and apply final threshold
|
|
hybrid_results.sort(key=lambda x: x.score, reverse=True)
|
|
final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit]
|
|
|
|
logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results")
|
|
return final_results
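# Worked example of the score fusion above (hypothetical numbers), using the default
# weights vector=0.7 / bm25=0.3: a chunk that ranks first in both lists with
# vector_norm=1.0 and bm25_norm=0.8 gets
#
#   weighted = 0.7 * 1.0 + 0.3 * 0.8            = 0.94
#   rrf      = 1/(1+0+1) + 1/(1+0+1)            = 1.0
#   hybrid   = 0.94 * 0.7 + 1.0 * 0.3           = 0.958
#
# so agreement between the two rankings is rewarded before the score_threshold cut.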
|
|
|
|
def _preprocess_text_for_bm25(self, text: str) -> List[str]:
|
|
"""Preprocess text for BM25 scoring"""
|
|
if not NLTK_AVAILABLE:
|
|
return text.lower().split()
|
|
|
|
try:
|
|
# Tokenize
|
|
tokens = word_tokenize(text.lower())
|
|
|
|
# Remove stopwords and non-alphabetic tokens
|
|
stop_words = set(stopwords.words('english'))
|
|
filtered_tokens = [
|
|
token for token in tokens
|
|
if token.isalpha() and token not in stop_words and len(token) > 2
|
|
]
|
|
|
|
return filtered_tokens
|
|
except Exception:
|
|
# Fallback to simple splitting
|
|
return text.lower().split()
|
|
|
|
def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float:
|
|
"""Calculate BM25 score for a document against query terms"""
|
|
if not query_terms:
|
|
return 0.0
|
|
|
|
# Preprocess document
|
|
doc_terms = self._preprocess_text_for_bm25(document)
|
|
if not doc_terms:
|
|
return 0.0
|
|
|
|
# Calculate term frequencies
|
|
doc_len = len(doc_terms)
|
|
avg_doc_len = 300 # Average document length (configurable)
|
|
|
|
# BM25 parameters
|
|
k1 = 1.2 # Controls term frequency saturation
|
|
b = 0.75 # Controls document length normalization
|
|
|
|
score = 0.0
|
|
|
|
# Calculate IDF for each query term
|
|
for term in set(query_terms):
|
|
# Term frequency in document
|
|
tf = doc_terms.count(term)
|
|
|
|
# Simple IDF (log(N/n) + 1)
|
|
# In production, you'd use the actual document frequency
|
|
idf = 2.0 # Simplified IDF
|
|
|
|
# BM25 formula
|
|
numerator = tf * (k1 + 1)
|
|
denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len))
|
|
|
|
score += idf * (numerator / denominator)
|
|
|
|
# Normalize score to 0-1 range
|
|
return min(score / 10.0, 1.0) # Simple normalization
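# Worked example of the simplified BM25 above (hypothetical document): with k1=1.2,
# b=0.75, avg_doc_len=300 and the fixed idf=2.0, a 100-term document containing a
# query term 3 times contributes:
#
#   numerator   = 3 * (1.2 + 1)                                  = 6.6
#   denominator = 3 + 1.2 * (1 - 0.75 + 0.75 * (100 / 300))      = 3.6
#   term score  = 2.0 * 6.6 / 3.6                                ~= 3.67
#   # with a single query term the returned value is min(3.67 / 10, 1.0) ~= 0.37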
|
|
|
|
async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
|
|
"""Search for relevant documents"""
|
|
if not self.enabled:
|
|
raise RuntimeError("RAG module not initialized")
|
|
|
|
collection_name = collection_name or self.default_collection_name
|
|
max_results = max_results or self.config.get("max_results", 10)
|
|
|
|
# Check cache (include collection name in cache key)
|
|
cache_key = f"{collection_name}_{query}_{max_results}_{hash(str(filters))}"
|
|
if cache_key in self.search_cache:
|
|
self.stats["cache_hits"] += 1
|
|
return self.search_cache[cache_key]
|
|
|
|
try:
|
|
import time
|
|
start_time = time.time()
|
|
|
|
# Generate query embedding with task-specific prefix for better retrieval
|
|
# The E5 model works better with "query:" prefix for search queries
|
|
optimized_query = f"query: {query}"
|
|
query_embedding = await self._generate_embedding(optimized_query)
|
|
|
|
# Build filter
|
|
search_filter = None
|
|
if filters:
|
|
conditions = []
|
|
for key, value in filters.items():
|
|
conditions.append(FieldCondition(key=key, match=MatchValue(value=value)))
|
|
search_filter = Filter(must=conditions)
|
|
|
|
# Enhanced debugging for search
|
|
logger.info("=== ENHANCED RAG SEARCH DEBUGGING ===")
|
|
logger.info(f"Collection: {collection_name}")
|
|
logger.info(f"Query: '{query}'")
|
|
logger.info(f"Max results requested: {max_results}")
|
|
logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}")
|
|
logger.info(f"Embedding service available: {self.embedding_service is not None}")
|
|
|
|
# Check if hybrid search is enabled
|
|
enable_hybrid = self.config.get("enable_hybrid", False)
|
|
# Use provided score_threshold or fall back to config
|
|
search_score_threshold = score_threshold if score_threshold is not None else self.config.get("score_threshold", 0.3)
|
|
|
|
if enable_hybrid and NLTK_AVAILABLE:
|
|
# Perform hybrid search (vector + BM25)
|
|
search_results = await self._hybrid_search(
|
|
collection_name=collection_name,
|
|
query=query,
|
|
query_vector=query_embedding,
|
|
query_filter=search_filter,
|
|
limit=max_results,
|
|
score_threshold=search_score_threshold
|
|
)
|
|
else:
|
|
# Pure vector search with improved threshold
|
|
search_results = self.qdrant_client.search(
|
|
collection_name=collection_name,
|
|
query_vector=query_embedding,
|
|
query_filter=search_filter,
|
|
limit=max_results,
|
|
score_threshold=search_score_threshold
|
|
)
|
|
|
|
logger.info(f"Raw search results count: {len(search_results)}")
|
|
|
|
# Process results
|
|
results = []
|
|
document_scores = {}
|
|
|
|
for i, result in enumerate(search_results):
|
|
doc_id = result.payload.get("document_id")
|
|
content = result.payload.get("content", "")
|
|
score = result.score
|
|
|
|
# Log each raw result for debugging
|
|
logger.info(f"\n--- Raw Result {i+1} ---")
|
|
logger.info(f"Score: {score}")
|
|
logger.info(f"Document ID: {doc_id}")
|
|
logger.info(f"Content preview (first 200 chars): {content[:200]}")
|
|
logger.info(f"Metadata keys: {list(result.payload.keys())}")
|
|
|
|
# Aggregate scores by document
|
|
if doc_id in document_scores:
|
|
document_scores[doc_id]["score"] = max(document_scores[doc_id]["score"], score)
|
|
document_scores[doc_id]["content"] += "\n" + content
|
|
else:
|
|
document_scores[doc_id] = {
|
|
"score": score,
|
|
"content": content,
|
|
"metadata": {k: v for k, v in result.payload.items() if k not in ["content", "document_id"]}
|
|
}
|
|
|
|
logger.info(f"\nAggregated documents count: {len(document_scores)}")
|
|
logger.info("=== END ENHANCED RAG SEARCH DEBUGGING ===")
|
|
|
|
# Create SearchResult objects
|
|
for doc_id, data in document_scores.items():
|
|
document = Document(
|
|
id=doc_id,
|
|
content=data["content"],
|
|
metadata=data["metadata"]
|
|
)
|
|
|
|
search_result = SearchResult(
|
|
document=document,
|
|
score=data["score"],
|
|
relevance_score=min(data["score"] * 100, 100)
|
|
)
|
|
|
|
results.append(search_result)
|
|
|
|
# Sort by score
|
|
results.sort(key=lambda x: x.score, reverse=True)
|
|
|
|
# Update stats
|
|
search_time = time.time() - start_time
|
|
self.stats["searches_performed"] += 1
|
|
self.stats["average_search_time"] = (
|
|
(self.stats["average_search_time"] * (self.stats["searches_performed"] - 1) + search_time) /
|
|
self.stats["searches_performed"]
|
|
)
|
|
|
|
# Cache results
|
|
self.search_cache[cache_key] = results
|
|
|
|
log_module_event("rag", "search_completed", {
|
|
"query": query,
|
|
"collection": collection_name,
|
|
"results_count": len(results),
|
|
"search_time": search_time
|
|
})
|
|
|
|
return results
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error searching documents in collection {collection_name}: {e}")
|
|
log_module_event("rag", "search_failed", {"error": str(e), "collection": collection_name})
|
|
raise
|
|
|
|
    async def delete_document(self, document_id: str, collection_name: str = None) -> bool:
        """Delete a document from the vector database"""
        if not self.enabled:
            raise RuntimeError("RAG module not initialized")

        collection_name = collection_name or self.default_collection_name

        try:
            # Delete all chunks for this document
            self.qdrant_client.delete(
                collection_name=collection_name,
                points_selector=models.FilterSelector(
                    filter=Filter(
                        must=[FieldCondition(key="document_id", match=MatchValue(value=document_id))]
                    )
                )
            )

            log_module_event("rag", "document_deleted", {"document_id": document_id, "collection": collection_name})
            return True

        except Exception as e:
            logger.error(f"Error deleting document from collection {collection_name}: {e}")
            log_module_event("rag", "deletion_failed", {"error": str(e), "collection": collection_name})
            return False
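
    # Deletion sketch (illustrative only). The filter above removes every point whose
    # payload field "document_id" matches, i.e. all chunks of one logical document:
    #
    #   deleted = await rag_module.delete_document(
    #       document_id=some_id,            # id previously returned by the indexing path
    #       collection_name="documents",    # assumed name; omit to use the module's default collection
    #   )
    #   # deleted is True on success, False if the delete failed
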
    async def get_stats(self) -> Dict[str, Any]:
        """Get RAG module statistics"""
        stats = self.stats.copy()

        if self.enabled:
            try:
                # Query the Qdrant REST API directly (raw HTTP) to avoid Pydantic
                # validation issues in the qdrant-client response models
                import httpx

                qdrant_url = f"http://{settings.QDRANT_HOST}:{settings.QDRANT_PORT}"

                async with httpx.AsyncClient() as client:
                    response = await client.get(f"{qdrant_url}/collections/{self.default_collection_name}")

                    if response.status_code == 200:
                        collection_data = response.json()

                        # Safely extract stats from the raw JSON payload
                        result = collection_data.get("result", {})

                        basic_stats = {
                            "total_points": result.get("points_count", 0),
                            "collection_status": result.get("status", "unknown"),
                        }

                        # Try to get the vector dimension from the collection config
                        try:
                            config = result.get("config", {})
                            params = config.get("params", {})
                            vectors = params.get("vectors", {})

                            if isinstance(vectors, dict) and "size" in vectors:
                                basic_stats["vector_dimension"] = vectors["size"]
                            else:
                                basic_stats["vector_dimension"] = "unknown"
                        except Exception as config_error:
                            logger.debug(f"Could not get vector dimension: {config_error}")
                            basic_stats["vector_dimension"] = "unknown"

                        stats.update(basic_stats)
                    else:
                        # Collection does not exist or the request failed
                        stats.update({
                            "total_points": 0,
                            "collection_status": "not_found",
                            "vector_dimension": "unknown"
                        })

            except Exception as e:
                # Not critical, so log at debug level and fall back to default values
                logger.debug(f"Could not get Qdrant stats (using fallback): {e}")
                stats.update({
                    "total_points": 0,
                    "collection_status": "unavailable",
                    "vector_dimension": "unknown"
                })
        else:
            stats.update({
                "total_points": 0,
                "collection_status": "disabled",
                "vector_dimension": "unknown"
            })

        return stats
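
    # Stats sketch (illustrative only). get_stats() merges the module's in-memory
    # counters with fields read from Qdrant's GET /collections/<name> response
    # (result.points_count, result.status, result.config.params.vectors.size).
    # A typical return value might look like:
    #
    #   {
    #       "searches_performed": 42,
    #       "average_search_time": 0.12,     # plus any other counters kept in self.stats
    #       "total_points": 1280,            # example values, not real data
    #       "collection_status": "green",
    #       "vector_dimension": 1536,        # "unknown" if the collection config is missing
    #   }
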
    async def process_request(self, request: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
        """Process a module request through the interceptor pattern"""
        if not self.enabled:
            raise RuntimeError("RAG module not initialized")

        action = request.get("action", "search")

        if action == "search":
            query = request.get("query")
            if not query:
                raise ValueError("Query is required for search action")

            max_results = request.get("max_results", self.config.get("max_results", 10))
            filters = request.get("filters", {})

            results = await self.search_documents(query, max_results, filters)

            return {
                "action": "search",
                "query": query,
                "results": [
                    {
                        "document_id": result.document.id,
                        "content": result.document.content,
                        "metadata": result.document.metadata,
                        "score": result.score,
                        "relevance_score": result.relevance_score
                    }
                    for result in results
                ],
                "total_results": len(results),
                "cache_hit": False  # Would be determined by search logic
            }

        elif action == "index":
            content = request.get("content")
            if not content:
                raise ValueError("Content is required for index action")

            metadata = request.get("metadata", {})
            document_id = await self.index_document(content, metadata)

            return {
                "action": "index",
                "document_id": document_id,
                "status": "success",
                "message": "Document indexed successfully"
            }

        elif action == "process":
            file_data = request.get("file_data")
            filename = request.get("filename")
            if not file_data or not filename:
                raise ValueError("File data and filename are required for process action")

            # Decode base64 file data if it was provided as a string
            # (base64 is already imported at module level)
            if isinstance(file_data, str):
                file_data = base64.b64decode(file_data)

            metadata = request.get("metadata", {})
            processed_doc = await self.process_document(file_data, filename, metadata)

            return {
                "action": "process",
                "document_id": processed_doc.id,
                "filename": processed_doc.original_filename,
                "file_type": processed_doc.file_type,
                "mime_type": processed_doc.mime_type,
                "word_count": processed_doc.word_count,
                "sentence_count": processed_doc.sentence_count,
                "language": processed_doc.language,
                "entities_count": len(processed_doc.entities),
                "keywords_count": len(processed_doc.keywords),
                "processing_time": processed_doc.processing_time,
                "status": "success",
                "message": "Document processed successfully"
            }

        elif action == "delete":
            document_id = request.get("document_id")
            if not document_id:
                raise ValueError("Document ID is required for delete action")

            success = await self.delete_document(document_id)

            return {
                "action": "delete",
                "document_id": document_id,
                "status": "success" if success else "failed",
                "message": "Document deleted successfully" if success else "Failed to delete document"
            }

        elif action == "stats":
            stats = await self.get_stats()

            return {
                "action": "stats",
                "statistics": stats
            }

        else:
            raise ValueError(f"Unsupported action: {action}")
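
    # Request sketch (illustrative only). process_request() dispatches on the "action"
    # key; the payloads below show the fields each branch reads:
    #
    #   search_req  = {"action": "search", "query": "refund policy", "max_results": 5}
    #   index_req   = {"action": "index", "content": "plain text ...", "metadata": {"source": "faq"}}
    #   process_req = {"action": "process", "filename": "handbook.docx",
    #                  "file_data": "<base64-encoded bytes>", "metadata": {}}
    #   delete_req  = {"action": "delete", "document_id": "<id from index/process>"}
    #   stats_req   = {"action": "stats"}
    #
    #   response = await rag_module.process_request(search_req, context={})
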
    async def pre_request_interceptor(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Pre-request interceptor for RAG enhancement"""
        if not self.enabled:
            return context

        request = context.get("request")
        if not request:
            return context

        # Only chat/completions requests can benefit from RAG context injection
        if request.url.path.startswith("/api/v1/chat") or request.url.path.startswith("/api/v1/completions"):
            # Extract the query/prompt from the request body
            request_body = await request.body() if hasattr(request, 'body') else b""

            if request_body:
                try:
                    data = json.loads(request_body.decode())
                    query = data.get("message", data.get("prompt", ""))

                    if query:
                        # Search for relevant documents
                        search_results = await self.search_documents(query, max_results=3)

                        if search_results:
                            # Attach the retrieved context to the request context
                            context["rag_context"] = [
                                {
                                    "content": result.document.content,
                                    "metadata": result.document.metadata,
                                    "relevance_score": result.relevance_score
                                }
                                for result in search_results
                            ]

                            log_module_event("rag", "context_added", {
                                "query": query[:100],
                                "results_count": len(search_results)
                            })

                except Exception as e:
                    logger.error(f"Error processing RAG request: {e}")

        return context
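
# Interceptor sketch (illustrative only). For a chat/completions request whose JSON body
# contains a "message" or "prompt", the interceptor above attaches up to three retrieved
# documents under context["rag_context"], e.g.:
#
#   context["rag_context"] = [
#       {"content": "...chunk text...", "metadata": {"source": "handbook.docx"}, "relevance_score": 87.5},
#       ...
#   ]
#
# Downstream handlers can fold this into the LLM prompt; the interceptor itself only
# reads the request body and never modifies it.
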
# Global RAG instance
rag_module = RAGModule()


# Module interface functions
async def initialize(config: Dict[str, Any]):
    """Initialize RAG module"""
    await rag_module.initialize(config)


async def cleanup():
    """Cleanup RAG module"""
    await rag_module.cleanup()


async def pre_request_interceptor(context: Dict[str, Any]) -> Dict[str, Any]:
    """Pre-request interceptor"""
    return await rag_module.pre_request_interceptor(context)


# Additional exported functions
async def process_document(file_data: bytes, filename: str, metadata: Dict[str, Any] = None) -> ProcessedDocument:
    """Process a document with full content analysis"""
    return await rag_module.process_document(file_data, filename, metadata)


async def index_document(content: str, metadata: Dict[str, Any] = None, collection_name: str = None) -> str:
    """Index a document (backward compatibility)"""
    return await rag_module.index_document(content, metadata, collection_name)


async def index_processed_document(processed_doc: ProcessedDocument, collection_name: str = None) -> str:
    """Index a processed document"""
    return await rag_module.index_processed_document(processed_doc, collection_name)


async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
    """Search documents"""
    return await rag_module.search_documents(query, max_results, filters, collection_name, score_threshold)


async def delete_document(document_id: str, collection_name: str = None) -> bool:
    """Delete a document"""
    return await rag_module.delete_document(document_id, collection_name)


async def create_collection(collection_name: str) -> bool:
    """Create a new Qdrant collection"""
    return await rag_module.create_collection(collection_name)


async def delete_collection(collection_name: str) -> bool:
    """Delete a Qdrant collection"""
    return await rag_module.delete_collection(collection_name)


async def get_supported_types() -> List[str]:
    """Get list of supported file types"""
    return list(rag_module.supported_types.keys())
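

# Minimal end-to-end sketch (illustrative only; not part of the module's public API).
# The config keys shown are the ones read elsewhere in this file ("score_threshold",
# "max_results", "enable_hybrid"); a real run also needs Qdrant and the embedding
# backend reachable, so this is a sketch rather than a test:
#
#   import asyncio
#
#   async def _demo():
#       await initialize({"score_threshold": 0.3, "max_results": 10, "enable_hybrid": False})
#       doc_id = await index_document("Enclava supports retrieval-augmented generation.",
#                                     metadata={"source": "demo"})
#       hits = await search_documents("What does Enclava support?", max_results=3)
#       print(doc_id, [h.relevance_score for h in hits])
#       await cleanup()
#
#   # asyncio.run(_demo())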