RAG improvements: overlapping chunks, batched embeddings, hybrid vector + BM25 search

2025-09-21 18:44:02 +02:00
parent f58a76ac59
commit a2ee959ec9
3 changed files with 401 additions and 86 deletions


@@ -53,14 +53,13 @@ except ImportError:
    PYTHON_DOCX_AVAILABLE = False
from qdrant_client import QdrantClient
-from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
+from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue
from qdrant_client.http import models
import tiktoken
from app.core.config import settings
from app.core.logging import log_module_event
from app.services.base_module import BaseModule, Permission
-from app.services.enhanced_embedding_service import enhanced_embedding_service
@dataclass
@@ -134,6 +133,19 @@ class RAGModule(BaseModule):
        self.embedding_model = None
        self.embedding_service = None
        self.tokenizer = None
+        # Set improved default configuration
+        self.config = {
+            "chunk_size": 300,  # Reduced from 400 for better precision
+            "chunk_overlap": 50,  # Added overlap for context preservation
+            "max_results": 10,
+            "score_threshold": 0.3,  # Increased from 0.0 to filter low-quality results
+            "enable_hybrid": True,  # Enable hybrid search (vector + BM25)
+            "hybrid_weights": {"vector": 0.7, "bm25": 0.3}  # Weights for hybrid scoring
+        }
+        # Update with any provided config
+        if config:
+            self.config.update(config)
        # Content processing components
        self.nlp_model = None
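# Illustrative sketch of the config-merge behavior above. The RAGModule
# constructor signature is assumed from the surrounding diff, not confirmed
# by this commit:
module = RAGModule(config={"chunk_size": 200, "enable_hybrid": False})
assert module.config["chunk_size"] == 200       # caller override wins
assert module.config["chunk_overlap"] == 50     # untouched defaults remain
assert module.config["enable_hybrid"] is False  # hybrid search disabled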
@@ -640,19 +652,33 @@ class RAGModule(BaseModule):
        return embeddings

    def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
-        """Split text into chunks"""
-        chunk_size = chunk_size or self.config.get("chunk_size", 400)
+        """Split text into overlapping chunks for better context preservation"""
+        chunk_size = chunk_size or self.config.get("chunk_size", 300)
+        chunk_overlap = self.config.get("chunk_overlap", 50)
        # Tokenize text
        tokens = self.tokenizer.encode(text)
-        # Split into chunks
+        # Split into chunks with overlap
        chunks = []
-        for i in range(0, len(tokens), chunk_size):
-            chunk_tokens = tokens[i:i + chunk_size]
+        start_idx = 0
+        while start_idx < len(tokens):
+            end_idx = min(start_idx + chunk_size, len(tokens))
+            chunk_tokens = tokens[start_idx:end_idx]
            chunk_text = self.tokenizer.decode(chunk_tokens)
-            chunks.append(chunk_text)
+            # Only add non-empty chunks
+            if chunk_text.strip():
+                chunks.append(chunk_text)
+            # Stop after the final window; otherwise the overlap step below
+            # keeps start_idx pinned at len(tokens) - chunk_overlap and the
+            # loop re-emits the tail chunk forever
+            if end_idx >= len(tokens):
+                break
+            # Move to next chunk with overlap
+            start_idx = end_idx - chunk_overlap
+            # Ensure progress (in case overlap >= chunk_size)
+            if start_idx >= end_idx:
+                start_idx = end_idx
        return chunks
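# Worked sketch of the overlap arithmetic above: the window advances by
# chunk_size - chunk_overlap = 250 tokens per step. Standalone, so plain
# span indices stand in for tiktoken-encoded tokens:
def chunk_spans(n_tokens: int, chunk_size: int = 300, overlap: int = 50):
    spans, start = [], 0
    while start < n_tokens:
        end = min(start + chunk_size, n_tokens)
        spans.append((start, end))
        if end >= n_tokens:
            break  # same guard as above; prevents re-emitting the tail
        start = end - overlap
    return spans

print(chunk_spans(1000))  # [(0, 300), (250, 550), (500, 800), (750, 1000)]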
    async def _process_text(self, content: bytes, filename: str) -> str:
@@ -1126,17 +1152,9 @@ class RAGModule(BaseModule):
        # Chunk the document
        chunks = self._chunk_text(content)

-        # Generate embeddings with enhanced rate limiting handling
-        embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
-        # Log if fallback embeddings were used
-        if not success:
-            logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded")
-            log_module_event("rag", "fallback_embeddings_used", {
-                "document_id": doc_id,
-                "content_preview": content[:100] + "..." if len(content) > 100 else content
-            })
+        # Generate embeddings for all chunks in batch (more efficient)
+        embeddings = await self._generate_embeddings(chunks)

        # Create document points
        points = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
@@ -1197,17 +1215,9 @@ class RAGModule(BaseModule):
        # Chunk the document
        chunks = self._chunk_text(processed_doc.content)

-        # Generate embeddings with enhanced rate limiting handling
-        embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
-        # Log if fallback embeddings were used
-        if not success:
-            logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded")
-            log_module_event("rag", "fallback_embeddings_used", {
-                "document_id": processed_doc.id,
-                "filename": processed_doc.original_filename
-            })
+        # Generate embeddings for all chunks in batch (more efficient)
+        embeddings = await self._generate_embeddings(chunks)

        # Create document points with enhanced metadata
        points = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
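# _generate_embeddings is not shown in this diff; the two call sites above
# assume a batched contract: exactly one vector per chunk, returned in input
# order (otherwise the zip() in the point-building loop silently drops
# chunks). A minimal sketch under that assumption -- embed_batch is a
# hypothetical method name, not confirmed by this commit:
async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
    # One round trip for the whole list; this is the "batch" the comment
    # above refers to
    return await self.embedding_service.embed_batch(texts)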
@@ -1277,6 +1287,154 @@ class RAGModule(BaseModule):
        except Exception:
            return False

+    async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float],
+                             query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]:
+        """Perform hybrid search combining vector similarity and BM25 scoring"""
+        # Preprocess query for BM25
+        query_terms = self._preprocess_text_for_bm25(query)
+        # Get all documents from the collection (for BM25 scoring)
+        # Note: In production, you'd want to optimize this with a proper BM25 index
+        scroll_filter = query_filter or Filter()
+        all_points = []
+        # Use scroll to page through all points; scroll returns
+        # (points, next_page_offset)
+        offset = None
+        batch_size = 100
+        while True:
+            points, next_offset = self.qdrant_client.scroll(
+                collection_name=collection_name,
+                scroll_filter=scroll_filter,
+                limit=batch_size,
+                offset=offset,
+                with_payload=True,
+                with_vectors=False
+            )
+            all_points.extend(points)
+            # A None offset means the collection is exhausted
+            if next_offset is None:
+                break
+            offset = next_offset
+        # Calculate a BM25 score per chunk, keyed by point id (keying by
+        # document_id would let the last chunk of a document overwrite the rest)
+        bm25_scores = {}
+        for point in all_points:
+            content = point.payload.get("content", "")
+            bm25_scores[point.id] = self._calculate_bm25_score(query_terms, content)
+        # Perform vector search
+        vector_results = self.qdrant_client.search(
+            collection_name=collection_name,
+            query_vector=query_vector,
+            query_filter=query_filter,
+            limit=limit * 2,  # Get more results for re-ranking
+            score_threshold=score_threshold / 2  # Lower threshold for initial search
+        )
+        # Combine scores
+        hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
+        vector_weight = hybrid_weights.get("vector", 0.7)
+        bm25_weight = hybrid_weights.get("bm25", 0.3)
+        # Create hybrid results
+        hybrid_results = []
+        for result in vector_results:
+            vector_score = result.score
+            bm25_score = bm25_scores.get(result.id, 0.0)
+            # Normalize scores (simple min-max normalization)
+            vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0
+            bm25_norm = min(bm25_score, 1.0)  # BM25 scores are normalized to 0-1 below
+            # Calculate hybrid score
+            hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm)
+            # Create new point with hybrid score
+            hybrid_point = ScoredPoint(
+                id=result.id,
+                version=result.version,  # required by ScoredPoint; carried over from the vector hit
+                payload=result.payload,
+                score=hybrid_score,
+                vector=result.vector,
+                shard_key=None,
+                order_value=None
+            )
+            hybrid_results.append(hybrid_point)
+        # Sort by hybrid score and apply final threshold
+        hybrid_results.sort(key=lambda x: x.score, reverse=True)
+        final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit]
+        logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results")
+        return final_results
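# Worked example of the score fusion above (illustrative numbers, not taken
# from the diff): score_threshold = 0.3, vector hit at 0.65, chunk BM25 0.4.
vector_norm = (0.65 - 0.3) / (1.0 - 0.3)      # = 0.5 (min-max above threshold)
hybrid_score = 0.7 * vector_norm + 0.3 * 0.4  # = 0.47, survives the 0.3 cut
# Neither signal alone is strong, but the weighted sum keeps the point.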
+    def _preprocess_text_for_bm25(self, text: str) -> List[str]:
+        """Preprocess text for BM25 scoring"""
+        if not NLTK_AVAILABLE:
+            return text.lower().split()
+        try:
+            # Tokenize
+            tokens = word_tokenize(text.lower())
+            # Remove stopwords and non-alphabetic tokens
+            stop_words = set(stopwords.words('english'))
+            filtered_tokens = [
+                token for token in tokens
+                if token.isalpha() and token not in stop_words and len(token) > 2
+            ]
+            return filtered_tokens
+        except Exception:
+            # Fallback to simple splitting
+            return text.lower().split()
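# Example of the preprocessing above (exact output depends on the installed
# NLTK punkt and stopwords data):
#   _preprocess_text_for_bm25("How does the hybrid search work?")
#   -> ["hybrid", "search", "work"]
# "how"/"does"/"the" are English stopwords and "?" fails isalpha().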
+    def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float:
+        """Calculate BM25 score for a document against query terms"""
+        if not query_terms:
+            return 0.0
+        # Preprocess document
+        doc_terms = self._preprocess_text_for_bm25(document)
+        if not doc_terms:
+            return 0.0
+        # Calculate term frequencies
+        doc_len = len(doc_terms)
+        avg_doc_len = 300  # Average document length (configurable)
+        # BM25 parameters
+        k1 = 1.2  # Controls term frequency saturation
+        b = 0.75  # Controls document length normalization
+        score = 0.0
+        # Score each distinct query term
+        for term in set(query_terms):
+            # Term frequency in document
+            tf = doc_terms.count(term)
+            # Simplified constant IDF; a real implementation would compute
+            # log(N/n) + 1 from actual document frequencies
+            idf = 2.0
+            # BM25 formula
+            numerator = tf * (k1 + 1)
+            denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len))
+            score += idf * (numerator / denominator)
+        # Normalize score to 0-1 range
+        return min(score / 10.0, 1.0)  # Simple normalization
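# Worked example of the formula above: one query term appearing tf=2 times in
# a 300-term chunk (so doc_len / avg_doc_len = 1):
#   numerator   = 2 * (1.2 + 1)                    = 4.4
#   denominator = 2 + 1.2 * (1 - 0.75 + 0.75 * 1)  = 3.2
#   term score  = 2.0 * 4.4 / 3.2                  = 2.75
# After the /10 normalization the chunk contributes 0.275 toward bm25_norm.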
    async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
        """Search for relevant documents"""
        if not self.enabled:
@@ -1314,14 +1472,29 @@ class RAGModule(BaseModule):
        logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}")
        logger.info(f"Embedding service available: {self.embedding_service is not None}")

-        # Search in Qdrant
-        search_results = self.qdrant_client.search(
-            collection_name=collection_name,
-            query_vector=query_embedding,
-            query_filter=search_filter,
-            limit=max_results,
-            score_threshold=0.0  # Lowered from 0.5 to see all results including low scores
-        )
+        # Check if hybrid search is enabled
+        enable_hybrid = self.config.get("enable_hybrid", False)
+        score_threshold = self.config.get("score_threshold", 0.3)
+        if enable_hybrid and NLTK_AVAILABLE:
+            # Perform hybrid search (vector + BM25)
+            search_results = await self._hybrid_search(
+                collection_name=collection_name,
+                query=query,
+                query_vector=query_embedding,
+                query_filter=search_filter,
+                limit=max_results,
+                score_threshold=score_threshold
+            )
+        else:
+            # Pure vector search with improved threshold
+            search_results = self.qdrant_client.search(
+                collection_name=collection_name,
+                query_vector=query_embedding,
+                query_filter=search_filter,
+                limit=max_results,
+                score_threshold=score_threshold
+            )

        logger.info(f"Raw search results count: {len(search_results)}")