mirror of https://github.com/aljazceru/enclava.git (synced 2025-12-17 07:24:34 +01:00)
rag improvements
@@ -53,14 +53,13 @@ except ImportError:
     PYTHON_DOCX_AVAILABLE = False
 
 from qdrant_client import QdrantClient
-from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
+from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue
 from qdrant_client.http import models
 import tiktoken
 
 from app.core.config import settings
 from app.core.logging import log_module_event
 from app.services.base_module import BaseModule, Permission
 from app.services.enhanced_embedding_service import enhanced_embedding_service
-
 
 @dataclass
@@ -134,6 +133,19 @@ class RAGModule(BaseModule):
         self.embedding_model = None
         self.embedding_service = None
         self.tokenizer = None
 
+        # Set improved default configuration
+        self.config = {
+            "chunk_size": 300,  # Reduced from 400 for better precision
+            "chunk_overlap": 50,  # Added overlap for context preservation
+            "max_results": 10,
+            "score_threshold": 0.3,  # Increased from 0.0 to filter low-quality results
+            "enable_hybrid": True,  # Enable hybrid search (vector + BM25)
+            "hybrid_weights": {"vector": 0.7, "bm25": 0.3}  # Weights for hybrid scoring
+        }
+        # Update with any provided config
+        if config:
+            self.config.update(config)
+
         # Content processing components
         self.nlp_model = None
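
A note on the merge step above: dict.update() is a shallow merge, so a caller overriding hybrid_weights replaces the whole nested dict rather than individual weights. A minimal standalone illustration:

    # Shallow-merge behavior of dict.update(), as used for the defaults above
    defaults = {
        "chunk_size": 300,
        "hybrid_weights": {"vector": 0.7, "bm25": 0.3},
    }
    overrides = {"hybrid_weights": {"vector": 0.9}}  # no "bm25" key given

    defaults.update(overrides)
    print(defaults["hybrid_weights"])  # {'vector': 0.9} - the bm25 weight is gone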
@@ -640,19 +652,33 @@ class RAGModule(BaseModule):
         return embeddings
 
     def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
-        """Split text into chunks"""
-        chunk_size = chunk_size or self.config.get("chunk_size", 400)
+        """Split text into overlapping chunks for better context preservation"""
+        chunk_size = chunk_size or self.config.get("chunk_size", 300)
+        chunk_overlap = self.config.get("chunk_overlap", 50)
 
         # Tokenize text
         tokens = self.tokenizer.encode(text)
 
-        # Split into chunks
+        # Split into chunks with overlap
         chunks = []
-        for i in range(0, len(tokens), chunk_size):
-            chunk_tokens = tokens[i:i + chunk_size]
+        start_idx = 0
+
+        while start_idx < len(tokens):
+            end_idx = min(start_idx + chunk_size, len(tokens))
+            chunk_tokens = tokens[start_idx:end_idx]
             chunk_text = self.tokenizer.decode(chunk_tokens)
-            chunks.append(chunk_text)
+
+            # Only add non-empty chunks
+            if chunk_text.strip():
+                chunks.append(chunk_text)
+
+            # Stop once the end of the token stream is reached; otherwise the
+            # final overlap window would repeat indefinitely
+            if end_idx >= len(tokens):
+                break
+
+            # Move to next chunk with overlap
+            next_start = end_idx - chunk_overlap
+
+            # Ensure forward progress (in case overlap >= chunk_size)
+            if next_start <= start_idx:
+                next_start = end_idx
+            start_idx = next_start
 
         return chunks
 
     async def _process_text(self, content: bytes, filename: str) -> str:
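
To make the new overlap concrete: with chunk_size=300 and chunk_overlap=50, each window starts 250 tokens after the previous one, and the last window stops at the end of the stream. A standalone sketch of the same windowing over token indices (no tokenizer needed):

    # Illustration of the overlapping windowing in _chunk_text above,
    # tracking (start, end) token indices instead of decoded text
    def window_bounds(n_tokens: int, chunk_size: int = 300, overlap: int = 50):
        bounds = []
        start = 0
        while start < n_tokens:
            end = min(start + chunk_size, n_tokens)
            bounds.append((start, end))
            if end >= n_tokens:
                break  # final window reached
            start = end - overlap  # re-read the last `overlap` tokens
        return bounds

    print(window_bounds(700))  # [(0, 300), (250, 550), (500, 700)]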
@@ -1126,17 +1152,9 @@ class RAGModule(BaseModule):
             # Chunk the document
             chunks = self._chunk_text(content)
 
-            # Generate embeddings with enhanced rate limiting handling
-            embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
-
-            # Log if fallback embeddings were used
-            if not success:
-                logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded")
-                log_module_event("rag", "fallback_embeddings_used", {
-                    "document_id": doc_id,
-                    "content_preview": content[:100] + "..." if len(content) > 100 else content
-                })
+            # Generate embeddings for all chunks in batch (more efficient)
+            embeddings = await self._generate_embeddings(chunks)
 
             # Create document points
             points = []
             for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
@@ -1197,17 +1215,9 @@ class RAGModule(BaseModule):
             # Chunk the document
             chunks = self._chunk_text(processed_doc.content)
 
-            # Generate embeddings with enhanced rate limiting handling
-            embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
-
-            # Log if fallback embeddings were used
-            if not success:
-                logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded")
-                log_module_event("rag", "fallback_embeddings_used", {
-                    "document_id": processed_doc.id,
-                    "filename": processed_doc.original_filename
-                })
+            # Generate embeddings for all chunks in batch (more efficient)
+            embeddings = await self._generate_embeddings(chunks)
 
             # Create document points with enhanced metadata
             points = []
             for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
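
One caveat with the batched pattern in the two hunks above: zip(chunks, embeddings) pairs elements positionally and silently drops trailing chunks if the embedding call ever returns fewer vectors than inputs, so a length check is cheap insurance. A short illustration:

    chunks = ["a", "b", "c"]
    embeddings = [[0.1], [0.2]]           # one embedding short
    print(list(zip(chunks, embeddings)))  # [('a', [0.1]), ('b', [0.2])] - 'c' is dropped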
@@ -1277,6 +1287,154 @@ class RAGModule(BaseModule):
         except Exception:
             return False
 
+    async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float],
+                             query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]:
+        """Perform hybrid search combining vector similarity and BM25 scoring"""
+
+        # Preprocess query for BM25
+        query_terms = self._preprocess_text_for_bm25(query)
+
+        # Get all documents from the collection (for BM25 scoring)
+        # Note: In production, you'd want to optimize this with a proper BM25 index
+        scroll_filter = query_filter or Filter()
+        all_points = []
+
+        # Use scroll to get all points
+        offset = None
+        batch_size = 100
+        while True:
+            search_result = self.qdrant_client.scroll(
+                collection_name=collection_name,
+                scroll_filter=scroll_filter,
+                limit=batch_size,
+                offset=offset,
+                with_payload=True,
+                with_vectors=False
+            )
+
+            points = search_result[0]
+            all_points.extend(points)
+
+            if len(points) < batch_size:
+                break
+
+            offset = points[-1].id
+
+        # Calculate BM25 scores for each document
+        bm25_scores = {}
+        for point in all_points:
+            doc_id = point.payload.get("document_id", "")
+            content = point.payload.get("content", "")
+
+            # Calculate BM25 score
+            bm25_score = self._calculate_bm25_score(query_terms, content)
+            bm25_scores[doc_id] = bm25_score
+
+        # Perform vector search
+        vector_results = self.qdrant_client.search(
+            collection_name=collection_name,
+            query_vector=query_vector,
+            query_filter=query_filter,
+            limit=limit * 2,  # Get more results for re-ranking
+            score_threshold=score_threshold / 2  # Lower threshold for initial search
+        )
+
+        # Combine scores
+        hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
+        vector_weight = hybrid_weights.get("vector", 0.7)
+        bm25_weight = hybrid_weights.get("bm25", 0.3)
+
+        # Create hybrid results
+        hybrid_results = []
+        for result in vector_results:
+            doc_id = result.payload.get("document_id", "")
+            vector_score = result.score
+            bm25_score = bm25_scores.get(doc_id, 0.0)
+
+            # Normalize scores (simple min-max normalization)
+            vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0
+            bm25_norm = min(bm25_score, 1.0)  # BM25 scores are typically 0-1
+
+            # Calculate hybrid score
+            hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm)
+
+            # Create new point with hybrid score
+            hybrid_point = ScoredPoint(
+                id=result.id,
+                version=result.version,  # required by ScoredPoint; carried over from the vector result
+                payload=result.payload,
+                score=hybrid_score,
+                vector=result.vector,
+                shard_key=None,
+                order_value=None
+            )
+            hybrid_results.append(hybrid_point)
+
+        # Sort by hybrid score and apply final threshold
+        hybrid_results.sort(key=lambda x: x.score, reverse=True)
+        final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit]
+
+        logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results")
+        return final_results
+
+    def _preprocess_text_for_bm25(self, text: str) -> List[str]:
+        """Preprocess text for BM25 scoring"""
+        if not NLTK_AVAILABLE:
+            return text.lower().split()
+
+        try:
+            # Tokenize
+            tokens = word_tokenize(text.lower())
+
+            # Remove stopwords and non-alphabetic tokens
+            stop_words = set(stopwords.words('english'))
+            filtered_tokens = [
+                token for token in tokens
+                if token.isalpha() and token not in stop_words and len(token) > 2
+            ]
+
+            return filtered_tokens
+        except Exception:
+            # Fallback to simple splitting
+            return text.lower().split()
+
+    def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float:
+        """Calculate BM25 score for a document against query terms"""
+        if not query_terms:
+            return 0.0
+
+        # Preprocess document
+        doc_terms = self._preprocess_text_for_bm25(document)
+        if not doc_terms:
+            return 0.0
+
+        # Calculate term frequencies
+        doc_len = len(doc_terms)
+        avg_doc_len = 300  # Average document length (configurable)
+
+        # BM25 parameters
+        k1 = 1.2  # Controls term frequency saturation
+        b = 0.75  # Controls document length normalization
+
+        score = 0.0
+
+        # Calculate IDF for each query term
+        for term in set(query_terms):
+            # Term frequency in document
+            tf = doc_terms.count(term)
+
+            # Simple IDF (log(N/n) + 1)
+            # In production, you'd use the actual document frequency
+            idf = 2.0  # Simplified IDF
+
+            # BM25 formula
+            numerator = tf * (k1 + 1)
+            denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len))
+
+            score += idf * (numerator / denominator)
+
+        # Normalize score to 0-1 range
+        return min(score / 10.0, 1.0)  # Simple normalization
+
     async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
         """Search for relevant documents"""
         if not self.enabled:
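
A worked example of the simplified BM25 above, using its fixed assumptions (idf = 2.0, avg_doc_len = 300): a term appearing twice in a 300-token chunk gives tf = 2, a per-term contribution of 2.0 * (2 * 2.2) / (2 + 1.2 * (1 - 0.75 + 0.75 * 1)) = 2.0 * 4.4 / 3.2 = 2.75, and 0.275 after the /10 normalization:

    # Reproducing the per-term arithmetic of _calculate_bm25_score
    k1, b = 1.2, 0.75
    idf = 2.0                             # fixed simplified IDF
    tf, doc_len, avg_doc_len = 2, 300, 300

    numerator = tf * (k1 + 1)                                      # 4.4
    denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len))  # 3.2
    score = idf * (numerator / denominator)                        # 2.75
    print(round(min(score / 10.0, 1.0), 3))                        # 0.275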
@@ -1314,14 +1472,29 @@ class RAGModule(BaseModule):
             logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}")
             logger.info(f"Embedding service available: {self.embedding_service is not None}")
 
-            # Search in Qdrant
-            search_results = self.qdrant_client.search(
-                collection_name=collection_name,
-                query_vector=query_embedding,
-                query_filter=search_filter,
-                limit=max_results,
-                score_threshold=0.0  # Lowered from 0.5 to see all results including low scores
-            )
+            # Check if hybrid search is enabled
+            enable_hybrid = self.config.get("enable_hybrid", False)
+            score_threshold = self.config.get("score_threshold", 0.3)
+
+            if enable_hybrid and NLTK_AVAILABLE:
+                # Perform hybrid search (vector + BM25)
+                search_results = await self._hybrid_search(
+                    collection_name=collection_name,
+                    query=query,
+                    query_vector=query_embedding,
+                    query_filter=search_filter,
+                    limit=max_results,
+                    score_threshold=score_threshold
+                )
+            else:
+                # Pure vector search with improved threshold
+                search_results = self.qdrant_client.search(
+                    collection_name=collection_name,
+                    query_vector=query_embedding,
+                    query_filter=search_filter,
+                    limit=max_results,
+                    score_threshold=score_threshold
+                )
 
             logger.info(f"Raw search results count: {len(search_results)}")
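
Tying the pieces together: with the default weights and score_threshold = 0.3, a chunk scoring 0.65 on vector similarity and 0.275 on BM25 (the example above) normalizes to 0.5 and 0.275 respectively, and the combined score 0.7 * 0.5 + 0.3 * 0.275 = 0.4325 clears the final threshold:

    # Worked example of the score combination in _hybrid_search
    score_threshold = 0.3
    vector_weight, bm25_weight = 0.7, 0.3
    vector_score, bm25_score = 0.65, 0.275

    vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold)  # 0.5
    bm25_norm = min(bm25_score, 1.0)                                          # 0.275
    hybrid = vector_weight * vector_norm + bm25_weight * bm25_norm
    print(round(hybrid, 4))  # 0.4325 -> kept, since it clears the 0.3 threshold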