Mirror of https://github.com/aljazceru/enclava.git (synced 2025-12-17 23:44:24 +01:00)

Commit: rag improvements
@@ -638,11 +638,19 @@ class RAGModule(BaseModule):
         np.random.seed(hash(text) % 2**32)
         return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
 
-    async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
+    async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
         """Generate embeddings for multiple texts (batch processing)"""
         if self.embedding_service:
+            # Add task-specific prefixes for better E5 model performance
+            if is_document:
+                # For document passages, use "passage:" prefix
+                prefixed_texts = [f"passage: {text}" for text in texts]
+            else:
+                # For queries, use "query:" prefix (handled in search method)
+                prefixed_texts = texts
+
             # Use real embedding service for batch processing
-            return await self.embedding_service.get_embeddings(texts)
+            return await self.embedding_service.get_embeddings(prefixed_texts)
         else:
            # Fallback to individual processing
            embeddings = []
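For context on the prefixing change above: E5-family embedding models are trained with asymmetric task prefixes, so indexed passages should be embedded as "passage: ..." and search queries as "query: ...". A minimal sketch of the convention (the with_e5_prefixes helper is hypothetical, not part of the module):

from typing import List

def with_e5_prefixes(texts: List[str], is_document: bool = True) -> List[str]:
    # E5 models expect "passage: " on indexed documents and "query: " on
    # search queries; mismatched prefixes degrade retrieval quality.
    prefix = "passage: " if is_document else "query: "
    return [f"{prefix}{t}" for t in texts]

docs = with_e5_prefixes(["Refunds are processed within 5 days."])
qry = with_e5_prefixes(["how long do refunds take"], is_document=False)
# docs[0] -> "passage: Refunds are processed within 5 days."
# qry[0]  -> "query: how long do refunds take"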
@@ -917,69 +925,75 @@ class RAGModule(BaseModule):
 
     async def _process_jsonl(self, content: bytes, filename: str) -> str:
         """Process JSONL files (newline-delimited JSON)
 
         Specifically optimized for helpjuice-export.jsonl format:
         - Each line contains a JSON object with 'id' and 'payload'
         - Payload contains 'question', 'language', and 'answer' fields
         - Combines question and answer into searchable content
+
+        Performance optimizations:
+        - Processes articles in smaller batches to reduce memory usage
+        - Uses streaming approach for large files
         """
         try:
             # Use streaming approach for large files
             jsonl_content = content.decode('utf-8', errors='replace')
             lines = jsonl_content.strip().split('\n')
 
             processed_articles = []
 
+            batch_size = 50  # Process in batches of 50 articles
+
             for line_num, line in enumerate(lines, 1):
                 if not line.strip():
                     continue
 
                 try:
                     # Parse each JSON line
                     data = json.loads(line)
 
                     # Handle helpjuice export format
                     if 'payload' in data:
                         payload = data['payload']
                         article_id = data.get('id', f'article_{line_num}')
 
                         # Extract fields
                         question = payload.get('question', '')
                         answer = payload.get('answer', '')
                         language = payload.get('language', 'EN')
 
                         # Combine question and answer for better search
                         if question or answer:
                             # Format as Q&A for better context
                             article_text = f"## {question}\n\n{answer}\n\n"
 
                             # Add language tag if not English
                             if language != 'EN':
                                 article_text = f"[{language}] {article_text}"
 
                             # Add metadata separator
                             article_text += f"---\nArticle ID: {article_id}\nLanguage: {language}\n\n"
 
                             processed_articles.append(article_text)
 
                     # Handle generic JSONL format
                     else:
                         # Convert the entire JSON object to readable text
                         json_text = json.dumps(data, indent=2, ensure_ascii=False)
                         processed_articles.append(json_text + "\n\n")
 
                 except json.JSONDecodeError as e:
                     logger.warning(f"Error parsing JSONL line {line_num}: {e}")
                     continue
                 except Exception as e:
                     logger.warning(f"Error processing JSONL line {line_num}: {e}")
                     continue
 
             # Combine all articles
             combined_text = '\n'.join(processed_articles)
 
             logger.info(f"Successfully processed {len(processed_articles)} articles from JSONL file {filename}")
             return combined_text
 
         except Exception as e:
             logger.error(f"Error processing JSONL file {filename}: {e}")
             return ""
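To make the expected input concrete, here is a sketch of a helpjuice-style JSONL line and the searchable text _process_jsonl derives from it (all field values invented for illustration):

import json

# Hypothetical helpjuice-export.jsonl line; values are invented.
sample_line = json.dumps({
    "id": "article_42",
    "payload": {
        "question": "How do I reset my password?",
        "answer": "Use the 'Forgot password' link on the login page.",
        "language": "EN",
    },
})

data = json.loads(sample_line)
payload = data["payload"]
# Mirrors the Q&A formatting above: heading, answer, then metadata separator.
article_text = f"## {payload['question']}\n\n{payload['answer']}\n\n"
article_text += f"---\nArticle ID: {data['id']}\nLanguage: {payload['language']}\n\n"
print(article_text)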
@@ -1153,7 +1167,7 @@ class RAGModule(BaseModule):
         chunks = self._chunk_text(content)
 
         # Generate embeddings for all chunks in batch (more efficient)
-        embeddings = await self._generate_embeddings(chunks)
+        embeddings = await self._generate_embeddings(chunks, is_document=True)
 
         # Create document points
         points = []
@@ -1200,10 +1214,28 @@ class RAGModule(BaseModule):
         """Index a processed document in the vector database"""
         if not self.enabled:
             raise RuntimeError("RAG module not initialized")
 
         collection_name = collection_name or self.default_collection_name
 
         try:
+            # Special handling for JSONL files
+            if processed_doc.file_type == 'jsonl':
+                # Import the optimized JSONL processor
+                from app.services.jsonl_processor import JSONLProcessor
+                jsonl_processor = JSONLProcessor(self)
+
+                # Read the original file content
+                with open(processed_doc.metadata.get('file_path', ''), 'rb') as f:
+                    file_content = f.read()
+
+                # Process using the optimized JSONL processor
+                return await jsonl_processor.process_and_index_jsonl(
+                    collection_name=collection_name,
+                    content=file_content,
+                    filename=processed_doc.original_filename,
+                    metadata=processed_doc.metadata
+                )
+
             # Ensure collection exists
             await self._ensure_collection_exists(collection_name)
 
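The indexing path now short-circuits JSONL uploads and delegates to a dedicated processor before any generic chunking happens. A minimal sketch of that dispatch pattern, assuming a ProcessedDocument-like object with the fields referenced above (this standalone index_document function is illustrative, not the module's actual method):

from pathlib import Path

async def index_document(doc, collection_name: str, jsonl_processor, default_indexer):
    # JSONL files bypass the generic pipeline; re-read the raw bytes so the
    # processor can parse the file line by line.
    if doc.file_type == "jsonl":
        file_content = Path(doc.metadata["file_path"]).read_bytes()
        return await jsonl_processor.process_and_index_jsonl(
            collection_name=collection_name,
            content=file_content,
            filename=doc.original_filename,
            metadata=doc.metadata,
        )
    # Everything else goes through the regular chunk-and-embed path.
    return await default_indexer(doc, collection_name)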
@@ -1216,7 +1248,7 @@ class RAGModule(BaseModule):
         chunks = self._chunk_text(processed_doc.content)
 
         # Generate embeddings for all chunks in batch (more efficient)
-        embeddings = await self._generate_embeddings(chunks)
+        embeddings = await self._generate_embeddings(chunks, is_document=True)
 
         # Create document points with enhanced metadata
         points = []
@@ -1339,24 +1371,48 @@ class RAGModule(BaseModule):
             score_threshold=score_threshold / 2  # Lower threshold for initial search
         )
 
-        # Combine scores
+        # Combine scores with improved normalization
         hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
         vector_weight = hybrid_weights.get("vector", 0.7)
         bm25_weight = hybrid_weights.get("bm25", 0.3)
 
-        # Create hybrid results
+        # Get score distributions for better normalization
+        vector_scores = [r.score for r in vector_results]
+        bm25_scores_list = list(bm25_scores.values())
+
+        # Calculate statistics for normalization
+        if vector_scores:
+            v_max = max(vector_scores)
+            v_min = min(vector_scores)
+            v_range = v_max - v_min if v_max != v_min else 1
+        else:
+            v_max, v_min, v_range = 1, 0, 1
+
+        if bm25_scores_list:
+            bm25_max = max(bm25_scores_list)
+            bm25_min = min(bm25_scores_list)
+            bm25_range = bm25_max - bm25_min if bm25_max != bm25_min else 1
+        else:
+            bm25_max, bm25_min, bm25_range = 1, 0, 1
+
+        # Create hybrid results with improved scoring
         hybrid_results = []
         for result in vector_results:
             doc_id = result.payload.get("document_id", "")
             vector_score = result.score
             bm25_score = bm25_scores.get(doc_id, 0.0)
 
-            # Normalize scores (simple min-max normalization)
-            vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0
-            bm25_norm = min(bm25_score, 1.0)  # BM25 scores are typically 0-1
+            # Improved normalization using actual score distributions
+            vector_norm = (vector_score - v_min) / v_range if v_range > 0 else 0.5
+            bm25_norm = (bm25_score - bm25_min) / bm25_range if bm25_range > 0 else 0.5
 
-            # Calculate hybrid score
-            hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm)
+            # Apply reciprocal rank fusion for better combination
+            # This gives more weight to documents that rank highly in both methods
+            rrf_vector = 1.0 / (1.0 + vector_results.index(result) + 1)  # +1 to avoid division by zero
+            rrf_bm25 = 1.0 / (1.0 + sorted(bm25_scores_list, reverse=True).index(bm25_score) + 1) if bm25_score in bm25_scores_list else 0
+
+            # Calculate hybrid score using both normalized scores and RRF
+            hybrid_score = (vector_weight * vector_norm + bm25_weight * bm25_norm) * 0.7 + (rrf_vector + rrf_bm25) * 0.3
 
             # Create new point with hybrid score
             hybrid_point = ScoredPoint(
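The new scoring blends two signals: min-max-normalized raw scores (70%) and reciprocal-rank-fusion scores (30%), so a document must both score well and rank well in each method. A self-contained worked example with invented scores (not tied to the module's data structures):

vector_scores = {"doc_a": 0.82, "doc_b": 0.74, "doc_c": 0.55}
bm25_scores = {"doc_a": 3.1, "doc_b": 6.4, "doc_c": 0.9}

def min_max(scores):
    lo, hi = min(scores.values()), max(scores.values())
    rng = hi - lo if hi != lo else 1
    return {k: (v - lo) / rng for k, v in scores.items()}

def rrf(scores):
    # Reciprocal rank fusion: best rank -> 1/2, next -> 1/3, ...
    ranked = sorted(scores, key=scores.get, reverse=True)
    return {doc: 1.0 / (1.0 + ranked.index(doc) + 1) for doc in ranked}

v_norm, b_norm = min_max(vector_scores), min_max(bm25_scores)
v_rrf, b_rrf = rrf(vector_scores), rrf(bm25_scores)
for doc in vector_scores:
    hybrid = (0.7 * v_norm[doc] + 0.3 * b_norm[doc]) * 0.7 + (v_rrf[doc] + b_rrf[doc]) * 0.3
    print(doc, round(hybrid, 3))

One caveat worth noting: in the diff, sorted(bm25_scores_list, reverse=True).index(bm25_score) gives tied BM25 scores the same (best) rank, and vector_results.index(result) is a linear scan per result, so the loop is quadratic in the candidate count.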
@@ -1435,7 +1491,7 @@ class RAGModule(BaseModule):
         # Normalize score to 0-1 range
         return min(score / 10.0, 1.0)  # Simple normalization
 
-    async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
+    async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
         """Search for relevant documents"""
         if not self.enabled:
             raise RuntimeError("RAG module not initialized")
@@ -1453,8 +1509,10 @@ class RAGModule(BaseModule):
         import time
         start_time = time.time()
 
-        # Generate query embedding
-        query_embedding = await self._generate_embedding(query)
+        # Generate query embedding with task-specific prefix for better retrieval
+        # The E5 model works better with "query:" prefix for search queries
+        optimized_query = f"query: {query}"
+        query_embedding = await self._generate_embedding(optimized_query)
 
         # Build filter
         search_filter = None
@@ -1474,7 +1532,8 @@ class RAGModule(BaseModule):
 
         # Check if hybrid search is enabled
         enable_hybrid = self.config.get("enable_hybrid", False)
-        score_threshold = self.config.get("score_threshold", 0.3)
+        # Use provided score_threshold or fall back to config
+        search_score_threshold = score_threshold if score_threshold is not None else self.config.get("score_threshold", 0.3)
 
         if enable_hybrid and NLTK_AVAILABLE:
             # Perform hybrid search (vector + BM25)
@@ -1484,7 +1543,7 @@ class RAGModule(BaseModule):
                 query_vector=query_embedding,
                 query_filter=search_filter,
                 limit=max_results,
-                score_threshold=score_threshold
+                score_threshold=search_score_threshold
             )
         else:
             # Pure vector search with improved threshold
@@ -1493,7 +1552,7 @@ class RAGModule(BaseModule):
                 query_vector=query_embedding,
                 query_filter=search_filter,
                 limit=max_results,
-                score_threshold=score_threshold
+                score_threshold=search_score_threshold
             )
 
         logger.info(f"Raw search results count: {len(search_results)}")
@@ -1841,9 +1900,9 @@ async def index_processed_document(processed_doc: ProcessedDocument, collection_
     """Index a processed document"""
     return await rag_module.index_processed_document(processed_doc, collection_name)
 
-async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
+async def search_documents(query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None, score_threshold: float = None) -> List[SearchResult]:
     """Search documents"""
-    return await rag_module.search_documents(query, max_results, filters, collection_name)
+    return await rag_module.search_documents(query, max_results, filters, collection_name, score_threshold)
 
 async def delete_document(document_id: str, collection_name: str = None) -> bool:
     """Delete a document"""
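With score_threshold threaded through to the module-level wrapper, callers can tune retrieval strictness per query instead of relying on the configured default of 0.3. An illustrative call site (collection name and threshold values are invented):

# Inside an async context; values are illustrative only.
async def example():
    strict = await search_documents(
        query="how do refunds work",
        max_results=5,
        collection_name="helpjuice_articles",
        score_threshold=0.5,  # stricter than the 0.3 config default
    )
    lenient = await search_documents(
        query="refund policy",
        score_threshold=0.1,  # recall-oriented: keep weaker matches
    )
    return strict, lenient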