rag improvements

2025-09-21 18:44:02 +02:00
parent f58a76ac59
commit a2ee959ec9
3 changed files with 401 additions and 86 deletions

View File

@@ -0,0 +1,153 @@
"""
Token-based rate limiting for LLM service
"""
import time
import redis
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
from ..core.config import settings
from ..core.logging import get_logger
logger = get_logger(__name__)
class TokenRateLimiter:
    """Token-based rate limiting implementation"""

    def __init__(self):
        # Always keep an in-memory store so the fallback paths work even if Redis fails later
        self.in_memory_store = {}
        try:
            self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
            self.redis_client.ping()
            logger.info("Token rate limiter initialized with Redis backend")
        except Exception as e:
            logger.warning(f"Redis not available for token rate limiting: {e}")
            self.redis_client = None
            # Fall back to in-memory rate limiting
            logger.info("Token rate limiter using in-memory fallback")
    async def check_token_limits(
        self,
        provider: str,
        prompt_tokens: int,
        completion_tokens: int = 0
    ) -> Tuple[bool, Dict[str, str]]:
        """
        Check if token usage is within limits

        Args:
            provider: Provider name (e.g., "privatemode")
            prompt_tokens: Number of prompt tokens to use
            completion_tokens: Number of completion tokens to use

        Returns:
            Tuple of (is_allowed, headers)
        """
        # Get token limits from configuration
        from .config import get_config
        config = get_config()
        token_limits = config.token_limits_per_minute

        # Check organization-wide limits
        org_key = f"tokens:org:{provider}"

        # Get current usage
        current_usage = await self._get_token_usage(org_key)

        # Calculate new usage
        new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens
        new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens

        # Check limits
        prompt_limit = token_limits.get("prompt_tokens", 20000)
        completion_limit = token_limits.get("completion_tokens", 10000)
        is_allowed = (
            new_prompt_tokens <= prompt_limit and
            new_completion_tokens <= completion_limit
        )

        if is_allowed:
            # Update usage
            await self._update_token_usage(org_key, prompt_tokens, completion_tokens)
            logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, "
                         f"{new_completion_tokens}/{completion_limit} completion")

        # Calculate remaining tokens
        remaining_prompt = max(0, prompt_limit - new_prompt_tokens)
        remaining_completion = max(0, completion_limit - new_completion_tokens)

        # Create headers
        headers = {
            "X-TokenLimit-Prompt-Remaining": str(remaining_prompt),
            "X-TokenLimit-Completion-Remaining": str(remaining_completion),
            "X-TokenLimit-Prompt-Limit": str(prompt_limit),
            "X-TokenLimit-Completion-Limit": str(completion_limit),
            "X-TokenLimit-Reset": str(int(time.time() + 60))  # Reset in 1 minute
        }

        if not is_allowed:
            logger.warning(f"Token rate limit exceeded for {provider}. "
                           f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. "
                           f"Current: {current_usage}")

        return is_allowed, headers
    async def _get_token_usage(self, key: str) -> Dict[str, int]:
        """Get current token usage"""
        if self.redis_client:
            try:
                data = self.redis_client.hgetall(key)
                if data:
                    return {
                        "prompt_tokens": int(data.get("prompt_tokens", 0)),
                        "completion_tokens": int(data.get("completion_tokens", 0)),
                        "updated_at": float(data.get("updated_at", time.time()))
                    }
            except Exception as e:
                logger.error(f"Error getting token usage from Redis: {e}")
        # Fallback to in-memory
        return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0})

    async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int):
        """Update token usage"""
        if self.redis_client:
            try:
                pipe = self.redis_client.pipeline()
                pipe.hincrby(key, "prompt_tokens", prompt_tokens)
                pipe.hincrby(key, "completion_tokens", completion_tokens)
                pipe.hset(key, "updated_at", time.time())
                pipe.expire(key, 60)  # Expire after 1 minute
                pipe.execute()
            except Exception as e:
                logger.error(f"Error updating token usage in Redis: {e}")
                # Fallback to in-memory
                self._update_in_memory(key, prompt_tokens, completion_tokens)
        else:
            self._update_in_memory(key, prompt_tokens, completion_tokens)

    def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int):
        """Update in-memory token usage"""
        if key not in self.in_memory_store:
            self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0}
        self.in_memory_store[key]["prompt_tokens"] += prompt_tokens
        self.in_memory_store[key]["completion_tokens"] += completion_tokens
        self.in_memory_store[key]["updated_at"] = time.time()

    def cleanup_expired(self):
        """Clean up expired entries (for in-memory store)"""
        if not self.redis_client:
            current_time = time.time()
            expired_keys = [
                key for key, data in self.in_memory_store.items()
                if current_time - data.get("updated_at", 0) > 60
            ]
            for key in expired_keys:
                del self.in_memory_store[key]
# Global token rate limiter instance
token_rate_limiter = TokenRateLimiter()
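
A minimal caller sketch (the handler below is illustrative, not part of this commit; it assumes a FastAPI-style endpoint): reserve the prompt budget before forwarding the request, and surface the rate-limit headers on rejection.

from fastapi import HTTPException

async def forward_completion(provider: str, prompt_tokens: int) -> Dict[str, str]:
    # Hypothetical handler showing the intended call pattern
    allowed, headers = await token_rate_limiter.check_token_limits(provider, prompt_tokens)
    if not allowed:
        raise HTTPException(status_code=429, detail="Token rate limit exceeded", headers=headers)
    # ... forward to the LLM provider here, then record completion tokens, e.g.:
    # await token_rate_limiter.check_token_limits(provider, 0, completion_tokens=used)
    return headers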

View File

@@ -265,7 +265,6 @@ class ChatbotModule(BaseModule):
    async def chat_completion(self, request: ChatRequest, user_id: str, db: Session) -> ChatResponse:
        """Generate chat completion response"""
-       logger.info("=== CHAT COMPLETION METHOD CALLED ===")
        # Get chatbot configuration from database
        db_chatbot = db.query(DBChatbotInstance).filter(DBChatbotInstance.id == request.chatbot_id).first()
@@ -364,11 +363,10 @@ class ChatbotModule(BaseModule):
metadata={"error": str(e), "fallback": True} metadata={"error": str(e), "fallback": True}
) )
async def _generate_response(self, message: str, db_messages: List[DBMessage], async def _generate_response(self, message: str, db_messages: List[DBMessage],
config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]: config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]:
"""Generate response using LLM with optional RAG""" """Generate response using LLM with optional RAG"""
logger.info("=== _generate_response METHOD CALLED ===")
# Lazy load dependencies if not available # Lazy load dependencies if not available
await self._ensure_dependencies() await self._ensure_dependencies()
@@ -397,8 +395,8 @@ class ChatbotModule(BaseModule):
                                   for i, result in enumerate(rag_results)]
                    # Build full RAG context from all results
-                   rag_context = "\\n\\nRelevant information from knowledge base:\\n" + "\\n\\n".join([
-                       f"[Document {i+1}]:\\n{result.document.content}" for i, result in enumerate(rag_results)
+                   rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([
+                       f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results)
                    ])

                    # Detailed RAG logging - ALWAYS log for debugging
@@ -407,14 +405,14 @@ class ChatbotModule(BaseModule):
logger.info(f"Collection: {qdrant_collection_name}") logger.info(f"Collection: {qdrant_collection_name}")
logger.info(f"Number of results: {len(rag_results)}") logger.info(f"Number of results: {len(rag_results)}")
for i, result in enumerate(rag_results): for i, result in enumerate(rag_results):
logger.info(f"\\n--- RAG Result {i+1} ---") logger.info(f"\n--- RAG Result {i+1} ---")
logger.info(f"Score: {getattr(result, 'score', 'N/A')}") logger.info(f"Score: {getattr(result, 'score', 'N/A')}")
logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}") logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}")
logger.info(f"Full Content ({len(result.document.content)} chars):") logger.info(f"Full Content ({len(result.document.content)} chars):")
logger.info(f"{result.document.content}") logger.info(f"{result.document.content}")
if hasattr(result.document, 'metadata'): if hasattr(result.document, 'metadata'):
logger.info(f"Metadata: {result.document.metadata}") logger.info(f"Metadata: {result.document.metadata}")
logger.info(f"\\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===") logger.info(f"\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
logger.info(rag_context) logger.info(rag_context)
logger.info("=== END RAG SEARCH RESULTS ===") logger.info("=== END RAG SEARCH RESULTS ===")
else: else:
@@ -428,11 +426,6 @@ class ChatbotModule(BaseModule):
logger.warning(f"RAG search traceback: {traceback.format_exc()}") logger.warning(f"RAG search traceback: {traceback.format_exc()}")
# Build conversation context (includes the current message from db_messages) # Build conversation context (includes the current message from db_messages)
logger.info(f"=== CRITICAL DEBUG ===")
logger.info(f"rag_context length: {len(rag_context)}")
logger.info(f"rag_context empty: {not rag_context}")
logger.info(f"rag_context preview: {rag_context[:200] if rag_context else 'EMPTY'}")
logger.info(f"=== END CRITICAL DEBUG ===")
messages = self._build_conversation_messages(db_messages, config, rag_context, context) messages = self._build_conversation_messages(db_messages, config, rag_context, context)
# Note: Current user message is already included in db_messages from the query # Note: Current user message is already included in db_messages from the query
@@ -452,9 +445,9 @@ class ChatbotModule(BaseModule):
        if config.use_rag and rag_context:
            logger.info(f"RAG context added: {len(rag_context)} characters")
            logger.info(f"RAG sources: {len(sources) if sources else 0} documents")
-       logger.info("\\n=== COMPLETE MESSAGES SENT TO LLM ===")
+       logger.info("\n=== COMPLETE MESSAGES SENT TO LLM ===")
        for i, msg in enumerate(messages):
-           logger.info(f"\\n--- Message {i+1} ---")
+           logger.info(f"\n--- Message {i+1} ---")
            logger.info(f"Role: {msg['role']}")
            logger.info(f"Content ({len(msg['content'])} chars):")
            # Truncate long content for logging (full RAG context can be very long)
@@ -518,38 +511,34 @@ class ChatbotModule(BaseModule):
            # Return fallback if available
            return "I'm currently unable to process your request. Please try again later.", None

    def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig,
                                     rag_context: str = "", context: Optional[Dict] = None) -> List[Dict]:
        """Build messages array for LLM completion"""
        messages = []
-       logger.info(f"DEBUG: _build_conversation_messages called. rag_context length: {len(rag_context)}")
-       # System prompt - keep it clean without RAG context
+       # System prompt
        system_prompt = config.system_prompt
+       if rag_context:
+           # Add explicit instruction to use RAG context
+           system_prompt += "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. " \
+               "This information is directly relevant to their query and should be your primary source:\n" + rag_context
        if context and context.get('additional_instructions'):
            system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}"
        messages.append({"role": "system", "content": system_prompt})
        logger.info(f"Building messages from {len(db_messages)} database messages")

        # Conversation history (messages are already limited by memory_length in the query)
        # Reverse to get chronological order
        # Include ALL messages - the current user message is needed for the LLM to respond!
        for idx, msg in enumerate(reversed(db_messages)):
            logger.info(f"Processing message {idx}: role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...")
            if msg.role in ["user", "assistant"]:
-               # For user messages, prepend RAG context if available
-               content = msg.content
-               if msg.role == "user" and rag_context and idx == 0:
-                   # Add RAG context to the current user message (first in reversed order)
-                   content = f"Relevant information from knowledge base:\n{rag_context}\n\nQuestion: {msg.content}"
-                   logger.info("Added RAG context to user message")
                messages.append({
                    "role": msg.role,
-                   "content": content
+                   "content": msg.content
                })
                logger.info(f"Added message with role {msg.role} to LLM messages")
            else:
@@ -690,10 +679,9 @@ class ChatbotModule(BaseModule):
        return router

    # API Compatibility Methods
    async def chat(self, chatbot_config: Dict[str, Any], message: str,
                   conversation_history: List = None, user_id: str = "anonymous") -> Dict[str, Any]:
        """Chat method for API compatibility"""
-       logger.info("=== CHAT METHOD (API COMPATIBILITY) CALLED ===")
        logger.info(f"Chat method called with message: {message[:50]}... by user: {user_id}")

        # Lazy load dependencies
@@ -723,20 +711,21 @@ class ChatbotModule(BaseModule):
            fallback_responses=chatbot_config.get("fallback_responses", [])
        )

-       # For API compatibility, create a temporary DBMessage for the current message
-       # so RAG context can be properly added
-       from app.models.chatbot import ChatbotMessage as DBMessage
-       # Create a temporary user message with the current message
-       temp_user_message = DBMessage(
-           conversation_id="temp_conversation",
-           role=MessageRole.USER.value,
-           content=message
-       )
-       # Generate response using internal method with the current message included
+       # Generate response using internal method
+       # Create a temporary message object for the current user message
+       temp_messages = [
+           DBMessage(
+               id=0,
+               conversation_id=0,
+               role="user",
+               content=message,
+               timestamp=datetime.utcnow(),
+               metadata={}
+           )
+       ]
        response_content, sources = await self._generate_response(
-           message, [temp_user_message], config, None, db
+           message, temp_messages, config, None, db
        )
        return {

View File

@@ -53,14 +53,13 @@ except ImportError:
    PYTHON_DOCX_AVAILABLE = False

from qdrant_client import QdrantClient
-from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
+from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue
from qdrant_client.http import models
import tiktoken

from app.core.config import settings
from app.core.logging import log_module_event
from app.services.base_module import BaseModule, Permission
-from app.services.enhanced_embedding_service import enhanced_embedding_service

@dataclass
@@ -134,6 +133,19 @@ class RAGModule(BaseModule):
        self.embedding_model = None
        self.embedding_service = None
        self.tokenizer = None

+       # Set improved default configuration
+       self.config = {
+           "chunk_size": 300,  # Reduced from 400 for better precision
+           "chunk_overlap": 50,  # Added overlap for context preservation
+           "max_results": 10,
+           "score_threshold": 0.3,  # Increased from 0.0 to filter low-quality results
+           "enable_hybrid": True,  # Enable hybrid search (vector + BM25)
+           "hybrid_weights": {"vector": 0.7, "bm25": 0.3}  # Weight for hybrid scoring
+       }
+       # Update with any provided config
+       if config:
+           self.config.update(config)

        # Content processing components
        self.nlp_model = None
@@ -640,19 +652,33 @@ class RAGModule(BaseModule):
        return embeddings

    def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
-       """Split text into chunks"""
-       chunk_size = chunk_size or self.config.get("chunk_size", 400)
+       """Split text into overlapping chunks for better context preservation"""
+       chunk_size = chunk_size or self.config.get("chunk_size", 300)
+       chunk_overlap = self.config.get("chunk_overlap", 50)

        # Tokenize text
        tokens = self.tokenizer.encode(text)

-       # Split into chunks
+       # Split into chunks with overlap
        chunks = []
-       for i in range(0, len(tokens), chunk_size):
-           chunk_tokens = tokens[i:i + chunk_size]
+       start_idx = 0
+       while start_idx < len(tokens):
+           end_idx = min(start_idx + chunk_size, len(tokens))
+           chunk_tokens = tokens[start_idx:end_idx]
            chunk_text = self.tokenizer.decode(chunk_tokens)
-           chunks.append(chunk_text)
+           # Only add non-empty chunks
+           if chunk_text.strip():
+               chunks.append(chunk_text)
+           # Stop once the final window has been emitted; otherwise the overlap
+           # step would keep re-emitting the last chunk and never terminate
+           if end_idx >= len(tokens):
+               break
+           # Move to next chunk with overlap
+           start_idx = end_idx - chunk_overlap
+           # Ensure progress (in case overlap >= chunk_size)
+           if start_idx >= end_idx:
+               start_idx = end_idx
        return chunks
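
With the defaults above (chunk_size=300, chunk_overlap=50) each window advances 250 tokens; a quick sketch of the boundaries the loop produces for a 700-token document:

# Toy walk-through of the sliding window above (token offsets only, no tokenizer)
chunk_size, chunk_overlap, total_tokens = 300, 50, 700
start, spans = 0, []
while start < total_tokens:
    end = min(start + chunk_size, total_tokens)
    spans.append((start, end))
    if end >= total_tokens:
        break
    start = end - chunk_overlap
print(spans)  # [(0, 300), (250, 550), (500, 700)]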
    async def _process_text(self, content: bytes, filename: str) -> str:
@@ -1126,17 +1152,9 @@ class RAGModule(BaseModule):
        # Chunk the document
        chunks = self._chunk_text(content)

-       # Generate embeddings with enhanced rate limiting handling
-       embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
-       # Log if fallback embeddings were used
-       if not success:
-           logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded")
-           log_module_event("rag", "fallback_embeddings_used", {
-               "document_id": doc_id,
-               "content_preview": content[:100] + "..." if len(content) > 100 else content
-           })
+       # Generate embeddings for all chunks in batch (more efficient)
+       embeddings = await self._generate_embeddings(chunks)

        # Create document points
        points = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
@@ -1197,17 +1215,9 @@ class RAGModule(BaseModule):
        # Chunk the document
        chunks = self._chunk_text(processed_doc.content)

-       # Generate embeddings with enhanced rate limiting handling
-       embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
-       # Log if fallback embeddings were used
-       if not success:
-           logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded")
-           log_module_event("rag", "fallback_embeddings_used", {
-               "document_id": processed_doc.id,
-               "filename": processed_doc.original_filename
-           })
+       # Generate embeddings for all chunks in batch (more efficient)
+       embeddings = await self._generate_embeddings(chunks)

        # Create document points with enhanced metadata
        points = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
@@ -1277,6 +1287,154 @@ class RAGModule(BaseModule):
        except Exception:
            return False

+   async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float],
+                            query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]:
+       """Perform hybrid search combining vector similarity and BM25 scoring"""
+       # Preprocess query for BM25
+       query_terms = self._preprocess_text_for_bm25(query)
+
+       # Get all documents from the collection (for BM25 scoring)
+       # Note: In production, you'd want to optimize this with a proper BM25 index
+       scroll_filter = query_filter or Filter()
+       all_points = []
+
+       # Use scroll to get all points
+       offset = None
+       batch_size = 100
+       while True:
+           search_result = self.qdrant_client.scroll(
+               collection_name=collection_name,
+               scroll_filter=scroll_filter,
+               limit=batch_size,
+               offset=offset,
+               with_payload=True,
+               with_vectors=False
+           )
+           points = search_result[0]
+           all_points.extend(points)
+           if len(points) < batch_size:
+               break
+           offset = points[-1].id
+
+       # Calculate BM25 scores for each document
+       bm25_scores = {}
+       for point in all_points:
+           doc_id = point.payload.get("document_id", "")
+           content = point.payload.get("content", "")
+           # Calculate BM25 score
+           bm25_score = self._calculate_bm25_score(query_terms, content)
+           bm25_scores[doc_id] = bm25_score
+
+       # Perform vector search
+       vector_results = self.qdrant_client.search(
+           collection_name=collection_name,
+           query_vector=query_vector,
+           query_filter=query_filter,
+           limit=limit * 2,  # Get more results for re-ranking
+           score_threshold=score_threshold / 2  # Lower threshold for initial search
+       )
+
+       # Combine scores
+       hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
+       vector_weight = hybrid_weights.get("vector", 0.7)
+       bm25_weight = hybrid_weights.get("bm25", 0.3)
+
+       # Create hybrid results
+       hybrid_results = []
+       for result in vector_results:
+           doc_id = result.payload.get("document_id", "")
+           vector_score = result.score
+           bm25_score = bm25_scores.get(doc_id, 0.0)
+           # Normalize scores (simple min-max normalization)
+           vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0
+           bm25_norm = min(bm25_score, 1.0)  # BM25 scores are typically 0-1
+           # Calculate hybrid score
+           hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm)
+           # Create new point with hybrid score
+           hybrid_point = ScoredPoint(
+               id=result.id,
+               payload=result.payload,
+               score=hybrid_score,
+               vector=result.vector,
+               shard_key=None,
+               order_value=None
+           )
+           hybrid_results.append(hybrid_point)
+
+       # Sort by hybrid score and apply final threshold
+       hybrid_results.sort(key=lambda x: x.score, reverse=True)
+       final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit]
+
+       logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results")
+       return final_results
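
To make the weighting concrete, here is a toy fusion of a single result with the default weights (0.7 vector, 0.3 BM25) and score_threshold=0.3, using the same normalization as above; the numbers are illustrative only:

score_threshold = 0.3
vector_score, bm25_score = 0.65, 0.40  # example values
vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold)  # 0.5
bm25_norm = min(bm25_score, 1.0)                                          # 0.4
hybrid = 0.7 * vector_norm + 0.3 * bm25_norm
print(round(hybrid, 2))  # 0.47 -> kept, since it clears the 0.3 threshold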
+   def _preprocess_text_for_bm25(self, text: str) -> List[str]:
+       """Preprocess text for BM25 scoring"""
+       if not NLTK_AVAILABLE:
+           return text.lower().split()
+       try:
+           # Tokenize
+           tokens = word_tokenize(text.lower())
+           # Remove stopwords and non-alphabetic tokens
+           stop_words = set(stopwords.words('english'))
+           filtered_tokens = [
+               token for token in tokens
+               if token.isalpha() and token not in stop_words and len(token) > 2
+           ]
+           return filtered_tokens
+       except Exception:
+           # Fallback to simple splitting
+           return text.lower().split()
+   def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float:
+       """Calculate BM25 score for a document against query terms"""
+       if not query_terms:
+           return 0.0
+
+       # Preprocess document
+       doc_terms = self._preprocess_text_for_bm25(document)
+       if not doc_terms:
+           return 0.0
+
+       # Calculate term frequencies
+       doc_len = len(doc_terms)
+       avg_doc_len = 300  # Average document length (configurable)
+
+       # BM25 parameters
+       k1 = 1.2  # Controls term frequency saturation
+       b = 0.75  # Controls document length normalization
+
+       score = 0.0
+       # Calculate IDF for each query term
+       for term in set(query_terms):
+           # Term frequency in document
+           tf = doc_terms.count(term)
+           # Simple IDF (log(N/n) + 1)
+           # In production, you'd use the actual document frequency
+           idf = 2.0  # Simplified IDF
+           # BM25 formula
+           numerator = tf * (k1 + 1)
+           denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len))
+           score += idf * (numerator / denominator)
+
+       # Normalize score to 0-1 range
+       return min(score / 10.0, 1.0)  # Simple normalization
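
A worked example of the simplified per-term score above (k1=1.2, b=0.75, fixed idf=2.0, and doc_len equal to avg_doc_len so length normalization drops out), showing how repeated occurrences of a term saturate rather than grow linearly:

k1, b, idf = 1.2, 0.75, 2.0
doc_len, avg_doc_len = 300, 300  # length-normalization factor becomes 1.0
for tf in (1, 2, 5, 20):
    term_score = idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * doc_len / avg_doc_len))
    print(tf, round(term_score, 2))  # 1 -> 2.0, 2 -> 2.75, 5 -> 3.55, 20 -> 4.15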
    async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
        """Search for relevant documents"""
        if not self.enabled:
@@ -1314,14 +1472,29 @@ class RAGModule(BaseModule):
logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}") logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}")
logger.info(f"Embedding service available: {self.embedding_service is not None}") logger.info(f"Embedding service available: {self.embedding_service is not None}")
# Search in Qdrant # Check if hybrid search is enabled
search_results = self.qdrant_client.search( enable_hybrid = self.config.get("enable_hybrid", False)
collection_name=collection_name, score_threshold = self.config.get("score_threshold", 0.3)
query_vector=query_embedding,
query_filter=search_filter, if enable_hybrid and NLTK_AVAILABLE:
limit=max_results, # Perform hybrid search (vector + BM25)
score_threshold=0.0 # Lowered from 0.5 to see all results including low scores search_results = await self._hybrid_search(
) collection_name=collection_name,
query=query,
query_vector=query_embedding,
query_filter=search_filter,
limit=max_results,
score_threshold=score_threshold
)
else:
# Pure vector search with improved threshold
search_results = self.qdrant_client.search(
collection_name=collection_name,
query_vector=query_embedding,
query_filter=search_filter,
limit=max_results,
score_threshold=score_threshold
)
logger.info(f"Raw search results count: {len(search_results)}") logger.info(f"Raw search results count: {len(search_results)}")