mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
rag improvements
This commit is contained in:
153
backend/app/services/llm/token_rate_limiter.py
Normal file
153
backend/app/services/llm/token_rate_limiter.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Token-based rate limiting for LLM service
|
||||
"""
|
||||
|
||||
import time
|
||||
import redis
|
||||
from typing import Dict, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from ..core.config import settings
|
||||
from ..core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenRateLimiter:
|
||||
"""Token-based rate limiting implementation"""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
|
||||
self.redis_client.ping()
|
||||
logger.info("Token rate limiter initialized with Redis backend")
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis not available for token rate limiting: {e}")
|
||||
self.redis_client = None
|
||||
# Fall back to in-memory rate limiting
|
||||
self.in_memory_store = {}
|
||||
logger.info("Token rate limiter using in-memory fallback")
|
||||
|
||||
async def check_token_limits(
|
||||
self,
|
||||
provider: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int = 0
|
||||
) -> Tuple[bool, Dict[str, str]]:
|
||||
"""
|
||||
Check if token usage is within limits
|
||||
|
||||
Args:
|
||||
provider: Provider name (e.g., "privatemode")
|
||||
prompt_tokens: Number of prompt tokens to use
|
||||
completion_tokens: Number of completion tokens to use
|
||||
|
||||
Returns:
|
||||
Tuple of (is_allowed, headers)
|
||||
"""
|
||||
# Get token limits from configuration
|
||||
from .config import get_config
|
||||
config = get_config()
|
||||
token_limits = config.token_limits_per_minute
|
||||
|
||||
# Check organization-wide limits
|
||||
org_key = f"tokens:org:{provider}"
|
||||
|
||||
# Get current usage
|
||||
current_usage = await self._get_token_usage(org_key)
|
||||
|
||||
# Calculate new usage
|
||||
new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens
|
||||
new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens
|
||||
|
||||
# Check limits
|
||||
prompt_limit = token_limits.get("prompt_tokens", 20000)
|
||||
completion_limit = token_limits.get("completion_tokens", 10000)
|
||||
|
||||
is_allowed = (
|
||||
new_prompt_tokens <= prompt_limit and
|
||||
new_completion_tokens <= completion_limit
|
||||
)
|
||||
|
||||
if is_allowed:
|
||||
# Update usage
|
||||
await self._update_token_usage(org_key, prompt_tokens, completion_tokens)
|
||||
logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, "
|
||||
f"{new_completion_tokens}/{completion_limit} completion")
|
||||
|
||||
# Calculate remaining tokens
|
||||
remaining_prompt = max(0, prompt_limit - new_prompt_tokens)
|
||||
remaining_completion = max(0, completion_limit - new_completion_tokens)
|
||||
|
||||
# Create headers
|
||||
headers = {
|
||||
"X-TokenLimit-Prompt-Remaining": str(remaining_prompt),
|
||||
"X-TokenLimit-Completion-Remaining": str(remaining_completion),
|
||||
"X-TokenLimit-Prompt-Limit": str(prompt_limit),
|
||||
"X-TokenLimit-Completion-Limit": str(completion_limit),
|
||||
"X-TokenLimit-Reset": str(int(time.time() + 60)) # Reset in 1 minute
|
||||
}
|
||||
|
||||
if not is_allowed:
|
||||
logger.warning(f"Token rate limit exceeded for {provider}. "
|
||||
f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. "
|
||||
f"Current: {current_usage}")
|
||||
|
||||
return is_allowed, headers
|
||||
|
||||
async def _get_token_usage(self, key: str) -> Dict[str, int]:
|
||||
"""Get current token usage"""
|
||||
if self.redis_client:
|
||||
try:
|
||||
data = self.redis_client.hgetall(key)
|
||||
if data:
|
||||
return {
|
||||
"prompt_tokens": int(data.get("prompt_tokens", 0)),
|
||||
"completion_tokens": int(data.get("completion_tokens", 0)),
|
||||
"updated_at": float(data.get("updated_at", time.time()))
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting token usage from Redis: {e}")
|
||||
|
||||
# Fallback to in-memory
|
||||
return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0})
|
||||
|
||||
async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int):
|
||||
"""Update token usage"""
|
||||
if self.redis_client:
|
||||
try:
|
||||
pipe = self.redis_client.pipeline()
|
||||
pipe.hincrby(key, "prompt_tokens", prompt_tokens)
|
||||
pipe.hincrby(key, "completion_tokens", completion_tokens)
|
||||
pipe.hset(key, "updated_at", time.time())
|
||||
pipe.expire(key, 60) # Expire after 1 minute
|
||||
pipe.execute()
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating token usage in Redis: {e}")
|
||||
# Fallback to in-memory
|
||||
self._update_in_memory(key, prompt_tokens, completion_tokens)
|
||||
else:
|
||||
self._update_in_memory(key, prompt_tokens, completion_tokens)
|
||||
|
||||
def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int):
|
||||
"""Update in-memory token usage"""
|
||||
if key not in self.in_memory_store:
|
||||
self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
|
||||
self.in_memory_store[key]["prompt_tokens"] += prompt_tokens
|
||||
self.in_memory_store[key]["completion_tokens"] += completion_tokens
|
||||
self.in_memory_store[key]["updated_at"] = time.time()
|
||||
|
||||
def cleanup_expired(self):
|
||||
"""Clean up expired entries (for in-memory store)"""
|
||||
if not self.redis_client:
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
key for key, data in self.in_memory_store.items()
|
||||
if current_time - data.get("updated_at", 0) > 60
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self.in_memory_store[key]
|
||||
|
||||
|
||||
# Global token rate limiter instance
|
||||
token_rate_limiter = TokenRateLimiter()
|
||||
@@ -265,7 +265,6 @@ class ChatbotModule(BaseModule):
|
||||
|
||||
async def chat_completion(self, request: ChatRequest, user_id: str, db: Session) -> ChatResponse:
|
||||
"""Generate chat completion response"""
|
||||
logger.info("=== CHAT COMPLETION METHOD CALLED ===")
|
||||
|
||||
# Get chatbot configuration from database
|
||||
db_chatbot = db.query(DBChatbotInstance).filter(DBChatbotInstance.id == request.chatbot_id).first()
|
||||
@@ -367,7 +366,6 @@ class ChatbotModule(BaseModule):
|
||||
async def _generate_response(self, message: str, db_messages: List[DBMessage],
|
||||
config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]:
|
||||
"""Generate response using LLM with optional RAG"""
|
||||
logger.info("=== _generate_response METHOD CALLED ===")
|
||||
|
||||
# Lazy load dependencies if not available
|
||||
await self._ensure_dependencies()
|
||||
@@ -397,8 +395,8 @@ class ChatbotModule(BaseModule):
|
||||
for i, result in enumerate(rag_results)]
|
||||
|
||||
# Build full RAG context from all results
|
||||
rag_context = "\\n\\nRelevant information from knowledge base:\\n" + "\\n\\n".join([
|
||||
f"[Document {i+1}]:\\n{result.document.content}" for i, result in enumerate(rag_results)
|
||||
rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([
|
||||
f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results)
|
||||
])
|
||||
|
||||
# Detailed RAG logging - ALWAYS log for debugging
|
||||
@@ -407,14 +405,14 @@ class ChatbotModule(BaseModule):
|
||||
logger.info(f"Collection: {qdrant_collection_name}")
|
||||
logger.info(f"Number of results: {len(rag_results)}")
|
||||
for i, result in enumerate(rag_results):
|
||||
logger.info(f"\\n--- RAG Result {i+1} ---")
|
||||
logger.info(f"\n--- RAG Result {i+1} ---")
|
||||
logger.info(f"Score: {getattr(result, 'score', 'N/A')}")
|
||||
logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}")
|
||||
logger.info(f"Full Content ({len(result.document.content)} chars):")
|
||||
logger.info(f"{result.document.content}")
|
||||
if hasattr(result.document, 'metadata'):
|
||||
logger.info(f"Metadata: {result.document.metadata}")
|
||||
logger.info(f"\\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
|
||||
logger.info(f"\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
|
||||
logger.info(rag_context)
|
||||
logger.info("=== END RAG SEARCH RESULTS ===")
|
||||
else:
|
||||
@@ -428,11 +426,6 @@ class ChatbotModule(BaseModule):
|
||||
logger.warning(f"RAG search traceback: {traceback.format_exc()}")
|
||||
|
||||
# Build conversation context (includes the current message from db_messages)
|
||||
logger.info(f"=== CRITICAL DEBUG ===")
|
||||
logger.info(f"rag_context length: {len(rag_context)}")
|
||||
logger.info(f"rag_context empty: {not rag_context}")
|
||||
logger.info(f"rag_context preview: {rag_context[:200] if rag_context else 'EMPTY'}")
|
||||
logger.info(f"=== END CRITICAL DEBUG ===")
|
||||
messages = self._build_conversation_messages(db_messages, config, rag_context, context)
|
||||
|
||||
# Note: Current user message is already included in db_messages from the query
|
||||
@@ -452,9 +445,9 @@ class ChatbotModule(BaseModule):
|
||||
if config.use_rag and rag_context:
|
||||
logger.info(f"RAG context added: {len(rag_context)} characters")
|
||||
logger.info(f"RAG sources: {len(sources) if sources else 0} documents")
|
||||
logger.info("\\n=== COMPLETE MESSAGES SENT TO LLM ===")
|
||||
logger.info("\n=== COMPLETE MESSAGES SENT TO LLM ===")
|
||||
for i, msg in enumerate(messages):
|
||||
logger.info(f"\\n--- Message {i+1} ---")
|
||||
logger.info(f"\n--- Message {i+1} ---")
|
||||
logger.info(f"Role: {msg['role']}")
|
||||
logger.info(f"Content ({len(msg['content'])} chars):")
|
||||
# Truncate long content for logging (full RAG context can be very long)
|
||||
@@ -523,10 +516,13 @@ class ChatbotModule(BaseModule):
|
||||
"""Build messages array for LLM completion"""
|
||||
|
||||
messages = []
|
||||
logger.info(f"DEBUG: _build_conversation_messages called. rag_context length: {len(rag_context)}")
|
||||
|
||||
# System prompt - keep it clean without RAG context
|
||||
# System prompt
|
||||
system_prompt = config.system_prompt
|
||||
if rag_context:
|
||||
# Add explicit instruction to use RAG context
|
||||
system_prompt += "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. " \
|
||||
"This information is directly relevant to their query and should be your primary source:\n" + rag_context
|
||||
if context and context.get('additional_instructions'):
|
||||
system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}"
|
||||
|
||||
@@ -540,16 +536,9 @@ class ChatbotModule(BaseModule):
|
||||
for idx, msg in enumerate(reversed(db_messages)):
|
||||
logger.info(f"Processing message {idx}: role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...")
|
||||
if msg.role in ["user", "assistant"]:
|
||||
# For user messages, prepend RAG context if available
|
||||
content = msg.content
|
||||
if msg.role == "user" and rag_context and idx == 0:
|
||||
# Add RAG context to the current user message (first in reversed order)
|
||||
content = f"Relevant information from knowledge base:\n{rag_context}\n\nQuestion: {msg.content}"
|
||||
logger.info("Added RAG context to user message")
|
||||
|
||||
messages.append({
|
||||
"role": msg.role,
|
||||
"content": content
|
||||
"content": msg.content
|
||||
})
|
||||
logger.info(f"Added message with role {msg.role} to LLM messages")
|
||||
else:
|
||||
@@ -693,7 +682,6 @@ class ChatbotModule(BaseModule):
|
||||
async def chat(self, chatbot_config: Dict[str, Any], message: str,
|
||||
conversation_history: List = None, user_id: str = "anonymous") -> Dict[str, Any]:
|
||||
"""Chat method for API compatibility"""
|
||||
logger.info("=== CHAT METHOD (API COMPATIBILITY) CALLED ===")
|
||||
logger.info(f"Chat method called with message: {message[:50]}... by user: {user_id}")
|
||||
|
||||
# Lazy load dependencies
|
||||
@@ -723,20 +711,21 @@ class ChatbotModule(BaseModule):
|
||||
fallback_responses=chatbot_config.get("fallback_responses", [])
|
||||
)
|
||||
|
||||
# For API compatibility, create a temporary DBMessage for the current message
|
||||
# so RAG context can be properly added
|
||||
from app.models.chatbot import ChatbotMessage as DBMessage
|
||||
# Generate response using internal method
|
||||
# Create a temporary message object for the current user message
|
||||
temp_messages = [
|
||||
DBMessage(
|
||||
id=0,
|
||||
conversation_id=0,
|
||||
role="user",
|
||||
content=message,
|
||||
timestamp=datetime.utcnow(),
|
||||
metadata={}
|
||||
)
|
||||
]
|
||||
|
||||
# Create a temporary user message with the current message
|
||||
temp_user_message = DBMessage(
|
||||
conversation_id="temp_conversation",
|
||||
role=MessageRole.USER.value,
|
||||
content=message
|
||||
)
|
||||
|
||||
# Generate response using internal method with the current message included
|
||||
response_content, sources = await self._generate_response(
|
||||
message, [temp_user_message], config, None, db
|
||||
message, temp_messages, config, None, db
|
||||
)
|
||||
|
||||
return {
|
||||
|
||||
@@ -53,14 +53,13 @@ except ImportError:
|
||||
PYTHON_DOCX_AVAILABLE = False
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
|
||||
from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue
|
||||
from qdrant_client.http import models
|
||||
import tiktoken
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import log_module_event
|
||||
from app.services.base_module import BaseModule, Permission
|
||||
from app.services.enhanced_embedding_service import enhanced_embedding_service
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -135,6 +134,19 @@ class RAGModule(BaseModule):
|
||||
self.embedding_service = None
|
||||
self.tokenizer = None
|
||||
|
||||
# Set improved default configuration
|
||||
self.config = {
|
||||
"chunk_size": 300, # Reduced from 400 for better precision
|
||||
"chunk_overlap": 50, # Added overlap for context preservation
|
||||
"max_results": 10,
|
||||
"score_threshold": 0.3, # Increased from 0.0 to filter low-quality results
|
||||
"enable_hybrid": True, # Enable hybrid search (vector + BM25)
|
||||
"hybrid_weights": {"vector": 0.7, "bm25": 0.3} # Weight for hybrid scoring
|
||||
}
|
||||
# Update with any provided config
|
||||
if config:
|
||||
self.config.update(config)
|
||||
|
||||
# Content processing components
|
||||
self.nlp_model = None
|
||||
self.lemmatizer = None
|
||||
@@ -640,18 +652,32 @@ class RAGModule(BaseModule):
|
||||
return embeddings
|
||||
|
||||
def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
|
||||
"""Split text into chunks"""
|
||||
chunk_size = chunk_size or self.config.get("chunk_size", 400)
|
||||
"""Split text into overlapping chunks for better context preservation"""
|
||||
chunk_size = chunk_size or self.config.get("chunk_size", 300)
|
||||
chunk_overlap = self.config.get("chunk_overlap", 50)
|
||||
|
||||
# Tokenize text
|
||||
tokens = self.tokenizer.encode(text)
|
||||
|
||||
# Split into chunks
|
||||
# Split into chunks with overlap
|
||||
chunks = []
|
||||
for i in range(0, len(tokens), chunk_size):
|
||||
chunk_tokens = tokens[i:i + chunk_size]
|
||||
start_idx = 0
|
||||
|
||||
while start_idx < len(tokens):
|
||||
end_idx = min(start_idx + chunk_size, len(tokens))
|
||||
chunk_tokens = tokens[start_idx:end_idx]
|
||||
chunk_text = self.tokenizer.decode(chunk_tokens)
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Only add non-empty chunks
|
||||
if chunk_text.strip():
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Move to next chunk with overlap
|
||||
start_idx = end_idx - chunk_overlap
|
||||
|
||||
# Ensure progress (in case overlap >= chunk_size)
|
||||
if start_idx >= end_idx:
|
||||
start_idx = end_idx
|
||||
|
||||
return chunks
|
||||
|
||||
@@ -1126,16 +1152,8 @@ class RAGModule(BaseModule):
|
||||
# Chunk the document
|
||||
chunks = self._chunk_text(content)
|
||||
|
||||
# Generate embeddings with enhanced rate limiting handling
|
||||
embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
|
||||
|
||||
# Log if fallback embeddings were used
|
||||
if not success:
|
||||
logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded")
|
||||
log_module_event("rag", "fallback_embeddings_used", {
|
||||
"document_id": doc_id,
|
||||
"content_preview": content[:100] + "..." if len(content) > 100 else content
|
||||
})
|
||||
# Generate embeddings for all chunks in batch (more efficient)
|
||||
embeddings = await self._generate_embeddings(chunks)
|
||||
|
||||
# Create document points
|
||||
points = []
|
||||
@@ -1197,16 +1215,8 @@ class RAGModule(BaseModule):
|
||||
# Chunk the document
|
||||
chunks = self._chunk_text(processed_doc.content)
|
||||
|
||||
# Generate embeddings with enhanced rate limiting handling
|
||||
embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
|
||||
|
||||
# Log if fallback embeddings were used
|
||||
if not success:
|
||||
logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded")
|
||||
log_module_event("rag", "fallback_embeddings_used", {
|
||||
"document_id": processed_doc.id,
|
||||
"filename": processed_doc.original_filename
|
||||
})
|
||||
# Generate embeddings for all chunks in batch (more efficient)
|
||||
embeddings = await self._generate_embeddings(chunks)
|
||||
|
||||
# Create document points with enhanced metadata
|
||||
points = []
|
||||
@@ -1277,6 +1287,154 @@ class RAGModule(BaseModule):
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float],
|
||||
query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]:
|
||||
"""Perform hybrid search combining vector similarity and BM25 scoring"""
|
||||
|
||||
# Preprocess query for BM25
|
||||
query_terms = self._preprocess_text_for_bm25(query)
|
||||
|
||||
# Get all documents from the collection (for BM25 scoring)
|
||||
# Note: In production, you'd want to optimize this with a proper BM25 index
|
||||
scroll_filter = query_filter or Filter()
|
||||
all_points = []
|
||||
|
||||
# Use scroll to get all points
|
||||
offset = None
|
||||
batch_size = 100
|
||||
while True:
|
||||
search_result = self.qdrant_client.scroll(
|
||||
collection_name=collection_name,
|
||||
scroll_filter=scroll_filter,
|
||||
limit=batch_size,
|
||||
offset=offset,
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
)
|
||||
|
||||
points = search_result[0]
|
||||
all_points.extend(points)
|
||||
|
||||
if len(points) < batch_size:
|
||||
break
|
||||
|
||||
offset = points[-1].id
|
||||
|
||||
# Calculate BM25 scores for each document
|
||||
bm25_scores = {}
|
||||
for point in all_points:
|
||||
doc_id = point.payload.get("document_id", "")
|
||||
content = point.payload.get("content", "")
|
||||
|
||||
# Calculate BM25 score
|
||||
bm25_score = self._calculate_bm25_score(query_terms, content)
|
||||
bm25_scores[doc_id] = bm25_score
|
||||
|
||||
# Perform vector search
|
||||
vector_results = self.qdrant_client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=query_filter,
|
||||
limit=limit * 2, # Get more results for re-ranking
|
||||
score_threshold=score_threshold / 2 # Lower threshold for initial search
|
||||
)
|
||||
|
||||
# Combine scores
|
||||
hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
|
||||
vector_weight = hybrid_weights.get("vector", 0.7)
|
||||
bm25_weight = hybrid_weights.get("bm25", 0.3)
|
||||
|
||||
# Create hybrid results
|
||||
hybrid_results = []
|
||||
for result in vector_results:
|
||||
doc_id = result.payload.get("document_id", "")
|
||||
vector_score = result.score
|
||||
bm25_score = bm25_scores.get(doc_id, 0.0)
|
||||
|
||||
# Normalize scores (simple min-max normalization)
|
||||
vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0
|
||||
bm25_norm = min(bm25_score, 1.0) # BM25 scores are typically 0-1
|
||||
|
||||
# Calculate hybrid score
|
||||
hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm)
|
||||
|
||||
# Create new point with hybrid score
|
||||
hybrid_point = ScoredPoint(
|
||||
id=result.id,
|
||||
payload=result.payload,
|
||||
score=hybrid_score,
|
||||
vector=result.vector,
|
||||
shard_key=None,
|
||||
order_value=None
|
||||
)
|
||||
hybrid_results.append(hybrid_point)
|
||||
|
||||
# Sort by hybrid score and apply final threshold
|
||||
hybrid_results.sort(key=lambda x: x.score, reverse=True)
|
||||
final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit]
|
||||
|
||||
logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results")
|
||||
return final_results
|
||||
|
||||
def _preprocess_text_for_bm25(self, text: str) -> List[str]:
|
||||
"""Preprocess text for BM25 scoring"""
|
||||
if not NLTK_AVAILABLE:
|
||||
return text.lower().split()
|
||||
|
||||
try:
|
||||
# Tokenize
|
||||
tokens = word_tokenize(text.lower())
|
||||
|
||||
# Remove stopwords and non-alphabetic tokens
|
||||
stop_words = set(stopwords.words('english'))
|
||||
filtered_tokens = [
|
||||
token for token in tokens
|
||||
if token.isalpha() and token not in stop_words and len(token) > 2
|
||||
]
|
||||
|
||||
return filtered_tokens
|
||||
except:
|
||||
# Fallback to simple splitting
|
||||
return text.lower().split()
|
||||
|
||||
def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float:
|
||||
"""Calculate BM25 score for a document against query terms"""
|
||||
if not query_terms:
|
||||
return 0.0
|
||||
|
||||
# Preprocess document
|
||||
doc_terms = self._preprocess_text_for_bm25(document)
|
||||
if not doc_terms:
|
||||
return 0.0
|
||||
|
||||
# Calculate term frequencies
|
||||
doc_len = len(doc_terms)
|
||||
avg_doc_len = 300 # Average document length (configurable)
|
||||
|
||||
# BM25 parameters
|
||||
k1 = 1.2 # Controls term frequency saturation
|
||||
b = 0.75 # Controls document length normalization
|
||||
|
||||
score = 0.0
|
||||
|
||||
# Calculate IDF for each query term
|
||||
for term in set(query_terms):
|
||||
# Term frequency in document
|
||||
tf = doc_terms.count(term)
|
||||
|
||||
# Simple IDF (log(N/n) + 1)
|
||||
# In production, you'd use the actual document frequency
|
||||
idf = 2.0 # Simplified IDF
|
||||
|
||||
# BM25 formula
|
||||
numerator = tf * (k1 + 1)
|
||||
denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len))
|
||||
|
||||
score += idf * (numerator / denominator)
|
||||
|
||||
# Normalize score to 0-1 range
|
||||
return min(score / 10.0, 1.0) # Simple normalization
|
||||
|
||||
async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
|
||||
"""Search for relevant documents"""
|
||||
if not self.enabled:
|
||||
@@ -1314,14 +1472,29 @@ class RAGModule(BaseModule):
|
||||
logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}")
|
||||
logger.info(f"Embedding service available: {self.embedding_service is not None}")
|
||||
|
||||
# Search in Qdrant
|
||||
search_results = self.qdrant_client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_embedding,
|
||||
query_filter=search_filter,
|
||||
limit=max_results,
|
||||
score_threshold=0.0 # Lowered from 0.5 to see all results including low scores
|
||||
)
|
||||
# Check if hybrid search is enabled
|
||||
enable_hybrid = self.config.get("enable_hybrid", False)
|
||||
score_threshold = self.config.get("score_threshold", 0.3)
|
||||
|
||||
if enable_hybrid and NLTK_AVAILABLE:
|
||||
# Perform hybrid search (vector + BM25)
|
||||
search_results = await self._hybrid_search(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
query_vector=query_embedding,
|
||||
query_filter=search_filter,
|
||||
limit=max_results,
|
||||
score_threshold=score_threshold
|
||||
)
|
||||
else:
|
||||
# Pure vector search with improved threshold
|
||||
search_results = self.qdrant_client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_embedding,
|
||||
query_filter=search_filter,
|
||||
limit=max_results,
|
||||
score_threshold=score_threshold
|
||||
)
|
||||
|
||||
logger.info(f"Raw search results count: {len(search_results)}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user