Mirror of https://github.com/aljazceru/enclava.git (synced 2025-12-17 07:24:34 +01:00)

Commit: rag improvements
backend/app/services/llm/token_rate_limiter.py (new file, +153 lines)

@@ -0,0 +1,153 @@
"""
Token-based rate limiting for LLM service
"""

import time
import redis
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
from ..core.config import settings
from ..core.logging import get_logger

logger = get_logger(__name__)


class TokenRateLimiter:
    """Token-based rate limiting implementation"""

    def __init__(self):
        try:
            self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
            self.redis_client.ping()
            logger.info("Token rate limiter initialized with Redis backend")
        except Exception as e:
            logger.warning(f"Redis not available for token rate limiting: {e}")
            self.redis_client = None
            # Fall back to in-memory rate limiting
            self.in_memory_store = {}
            logger.info("Token rate limiter using in-memory fallback")

    async def check_token_limits(
        self,
        provider: str,
        prompt_tokens: int,
        completion_tokens: int = 0
    ) -> Tuple[bool, Dict[str, str]]:
        """
        Check if token usage is within limits

        Args:
            provider: Provider name (e.g., "privatemode")
            prompt_tokens: Number of prompt tokens to use
            completion_tokens: Number of completion tokens to use

        Returns:
            Tuple of (is_allowed, headers)
        """
        # Get token limits from configuration
        from .config import get_config
        config = get_config()
        token_limits = config.token_limits_per_minute

        # Check organization-wide limits
        org_key = f"tokens:org:{provider}"

        # Get current usage
        current_usage = await self._get_token_usage(org_key)

        # Calculate new usage
        new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens
        new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens

        # Check limits
        prompt_limit = token_limits.get("prompt_tokens", 20000)
        completion_limit = token_limits.get("completion_tokens", 10000)

        is_allowed = (
            new_prompt_tokens <= prompt_limit and
            new_completion_tokens <= completion_limit
        )

        if is_allowed:
            # Update usage
            await self._update_token_usage(org_key, prompt_tokens, completion_tokens)
            logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, "
                         f"{new_completion_tokens}/{completion_limit} completion")

        # Calculate remaining tokens
        remaining_prompt = max(0, prompt_limit - new_prompt_tokens)
        remaining_completion = max(0, completion_limit - new_completion_tokens)

        # Create headers
        headers = {
            "X-TokenLimit-Prompt-Remaining": str(remaining_prompt),
            "X-TokenLimit-Completion-Remaining": str(remaining_completion),
            "X-TokenLimit-Prompt-Limit": str(prompt_limit),
            "X-TokenLimit-Completion-Limit": str(completion_limit),
            "X-TokenLimit-Reset": str(int(time.time() + 60))  # Reset in 1 minute
        }

        if not is_allowed:
            logger.warning(f"Token rate limit exceeded for {provider}. "
                           f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. "
                           f"Current: {current_usage}")

        return is_allowed, headers

    async def _get_token_usage(self, key: str) -> Dict[str, int]:
        """Get current token usage"""
        if self.redis_client:
            try:
                data = self.redis_client.hgetall(key)
                if data:
                    return {
                        "prompt_tokens": int(data.get("prompt_tokens", 0)),
                        "completion_tokens": int(data.get("completion_tokens", 0)),
                        "updated_at": float(data.get("updated_at", time.time()))
                    }
            except Exception as e:
                logger.error(f"Error getting token usage from Redis: {e}")

        # Fallback to in-memory
        return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0})

    async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int):
        """Update token usage"""
        if self.redis_client:
            try:
                pipe = self.redis_client.pipeline()
                pipe.hincrby(key, "prompt_tokens", prompt_tokens)
                pipe.hincrby(key, "completion_tokens", completion_tokens)
                pipe.hset(key, "updated_at", time.time())
                pipe.expire(key, 60)  # Expire after 1 minute
                pipe.execute()
            except Exception as e:
                logger.error(f"Error updating token usage in Redis: {e}")
                # Fallback to in-memory
                self._update_in_memory(key, prompt_tokens, completion_tokens)
        else:
            self._update_in_memory(key, prompt_tokens, completion_tokens)

    def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int):
        """Update in-memory token usage"""
        if key not in self.in_memory_store:
            self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0}

        self.in_memory_store[key]["prompt_tokens"] += prompt_tokens
        self.in_memory_store[key]["completion_tokens"] += completion_tokens
        self.in_memory_store[key]["updated_at"] = time.time()

    def cleanup_expired(self):
        """Clean up expired entries (for in-memory store)"""
        if not self.redis_client:
            current_time = time.time()
            expired_keys = [
                key for key, data in self.in_memory_store.items()
                if current_time - data.get("updated_at", 0) > 60
            ]
            for key in expired_keys:
                del self.in_memory_store[key]


# Global token rate limiter instance
token_rate_limiter = TokenRateLimiter()
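Note (reviewer): the limiter keeps a fixed one-minute bucket per provider, and self.in_memory_store is only created in the Redis-failure branch of __init__, so the in-memory fallback inside _get_token_usage assumes Redis was unavailable from the start. A minimal usage sketch follows; the FastAPI endpoint shape and the forward_chat_request helper are my assumptions, only token_rate_limiter.check_token_limits and the X-TokenLimit-* headers come from this file.

# Sketch only: forward_chat_request and the FastAPI error handling are assumed,
# not part of this commit; the limiter call and headers are from the new module.
from fastapi import HTTPException

from app.services.llm.token_rate_limiter import token_rate_limiter


async def forward_chat_request(provider: str, prompt_tokens: int, max_completion_tokens: int) -> dict:
    allowed, headers = await token_rate_limiter.check_token_limits(
        provider=provider,
        prompt_tokens=prompt_tokens,
        completion_tokens=max_completion_tokens,
    )
    if not allowed:
        # Surface the X-TokenLimit-* headers so clients can back off until
        # the one-minute window resets.
        raise HTTPException(status_code=429, detail="Token rate limit exceeded", headers=headers)
    # ...forward the request to the LLM provider here...
    return headers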

@@ -265,7 +265,6 @@ class ChatbotModule(BaseModule):
     async def chat_completion(self, request: ChatRequest, user_id: str, db: Session) -> ChatResponse:
         """Generate chat completion response"""
-        logger.info("=== CHAT COMPLETION METHOD CALLED ===")

         # Get chatbot configuration from database
         db_chatbot = db.query(DBChatbotInstance).filter(DBChatbotInstance.id == request.chatbot_id).first()
@@ -367,7 +366,6 @@ class ChatbotModule(BaseModule):
     async def _generate_response(self, message: str, db_messages: List[DBMessage],
                                  config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]:
         """Generate response using LLM with optional RAG"""
-        logger.info("=== _generate_response METHOD CALLED ===")

         # Lazy load dependencies if not available
         await self._ensure_dependencies()
@@ -397,8 +395,8 @@ class ChatbotModule(BaseModule):
                                for i, result in enumerate(rag_results)]

                 # Build full RAG context from all results
-                rag_context = "\\n\\nRelevant information from knowledge base:\\n" + "\\n\\n".join([
-                    f"[Document {i+1}]:\\n{result.document.content}" for i, result in enumerate(rag_results)
+                rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([
+                    f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results)
                 ])

                 # Detailed RAG logging - ALWAYS log for debugging
@@ -407,14 +405,14 @@ class ChatbotModule(BaseModule):
                 logger.info(f"Collection: {qdrant_collection_name}")
                 logger.info(f"Number of results: {len(rag_results)}")
                 for i, result in enumerate(rag_results):
-                    logger.info(f"\\n--- RAG Result {i+1} ---")
+                    logger.info(f"\n--- RAG Result {i+1} ---")
                     logger.info(f"Score: {getattr(result, 'score', 'N/A')}")
                     logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}")
                     logger.info(f"Full Content ({len(result.document.content)} chars):")
                     logger.info(f"{result.document.content}")
                     if hasattr(result.document, 'metadata'):
                         logger.info(f"Metadata: {result.document.metadata}")
-                logger.info(f"\\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
+                logger.info(f"\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
                 logger.info(rag_context)
                 logger.info("=== END RAG SEARCH RESULTS ===")
             else:
@@ -428,11 +426,6 @@ class ChatbotModule(BaseModule):
                 logger.warning(f"RAG search traceback: {traceback.format_exc()}")

         # Build conversation context (includes the current message from db_messages)
-        logger.info(f"=== CRITICAL DEBUG ===")
-        logger.info(f"rag_context length: {len(rag_context)}")
-        logger.info(f"rag_context empty: {not rag_context}")
-        logger.info(f"rag_context preview: {rag_context[:200] if rag_context else 'EMPTY'}")
-        logger.info(f"=== END CRITICAL DEBUG ===")
         messages = self._build_conversation_messages(db_messages, config, rag_context, context)

         # Note: Current user message is already included in db_messages from the query
@@ -452,9 +445,9 @@ class ChatbotModule(BaseModule):
         if config.use_rag and rag_context:
             logger.info(f"RAG context added: {len(rag_context)} characters")
             logger.info(f"RAG sources: {len(sources) if sources else 0} documents")
-        logger.info("\\n=== COMPLETE MESSAGES SENT TO LLM ===")
+        logger.info("\n=== COMPLETE MESSAGES SENT TO LLM ===")
         for i, msg in enumerate(messages):
-            logger.info(f"\\n--- Message {i+1} ---")
+            logger.info(f"\n--- Message {i+1} ---")
             logger.info(f"Role: {msg['role']}")
             logger.info(f"Content ({len(msg['content'])} chars):")
             # Truncate long content for logging (full RAG context can be very long)
@@ -523,10 +516,13 @@ class ChatbotModule(BaseModule):
         """Build messages array for LLM completion"""

         messages = []
-        logger.info(f"DEBUG: _build_conversation_messages called. rag_context length: {len(rag_context)}")

-        # System prompt - keep it clean without RAG context
+        # System prompt
         system_prompt = config.system_prompt
+        if rag_context:
+            # Add explicit instruction to use RAG context
+            system_prompt += "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. " \
+                             "This information is directly relevant to their query and should be your primary source:\n" + rag_context
         if context and context.get('additional_instructions'):
             system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}"

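Note (reviewer): together with the later _build_conversation_messages hunk, this moves the retrieved context out of the user message and into the system prompt. A standalone sketch of the resulting prompt assembly, with invented system_prompt and rag_context values, mirroring the concatenation above:

# Illustration only: the prompt and context strings are made up; the
# concatenation follows the hunk above.
system_prompt = "You are a helpful support assistant."
rag_context = "[Document 1]:\nRefunds are processed within 5 business days."

if rag_context:
    system_prompt += (
        "\n\nIMPORTANT: Use the following information from the knowledge base "
        "to answer the user's question. This information is directly relevant "
        "to their query and should be your primary source:\n" + rag_context
    )

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": "How long do refunds take?"},
]
print(messages[0]["content"])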
@@ -540,16 +536,9 @@ class ChatbotModule(BaseModule):
         for idx, msg in enumerate(reversed(db_messages)):
             logger.info(f"Processing message {idx}: role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...")
             if msg.role in ["user", "assistant"]:
-                # For user messages, prepend RAG context if available
-                content = msg.content
-                if msg.role == "user" and rag_context and idx == 0:
-                    # Add RAG context to the current user message (first in reversed order)
-                    content = f"Relevant information from knowledge base:\n{rag_context}\n\nQuestion: {msg.content}"
-                    logger.info("Added RAG context to user message")
-
                 messages.append({
                     "role": msg.role,
-                    "content": content
+                    "content": msg.content
                 })
                 logger.info(f"Added message with role {msg.role} to LLM messages")
             else:
@@ -693,7 +682,6 @@ class ChatbotModule(BaseModule):
     async def chat(self, chatbot_config: Dict[str, Any], message: str,
                    conversation_history: List = None, user_id: str = "anonymous") -> Dict[str, Any]:
         """Chat method for API compatibility"""
-        logger.info("=== CHAT METHOD (API COMPATIBILITY) CALLED ===")
         logger.info(f"Chat method called with message: {message[:50]}... by user: {user_id}")

         # Lazy load dependencies
@@ -723,20 +711,21 @@ class ChatbotModule(BaseModule):
                 fallback_responses=chatbot_config.get("fallback_responses", [])
             )

-            # For API compatibility, create a temporary DBMessage for the current message
-            # so RAG context can be properly added
-            from app.models.chatbot import ChatbotMessage as DBMessage
-
-            # Create a temporary user message with the current message
-            temp_user_message = DBMessage(
-                conversation_id="temp_conversation",
-                role=MessageRole.USER.value,
-                content=message
-            )
-
-            # Generate response using internal method with the current message included
+            # Generate response using internal method
+            # Create a temporary message object for the current user message
+            temp_messages = [
+                DBMessage(
+                    id=0,
+                    conversation_id=0,
+                    role="user",
+                    content=message,
+                    timestamp=datetime.utcnow(),
+                    metadata={}
+                )
+            ]
+
             response_content, sources = await self._generate_response(
-                message, [temp_user_message], config, None, db
+                message, temp_messages, config, None, db
             )

             return {
@@ -53,14 +53,13 @@ except ImportError:
     PYTHON_DOCX_AVAILABLE = False

 from qdrant_client import QdrantClient
-from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
+from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue
 from qdrant_client.http import models
 import tiktoken

 from app.core.config import settings
 from app.core.logging import log_module_event
 from app.services.base_module import BaseModule, Permission
-from app.services.enhanced_embedding_service import enhanced_embedding_service


 @dataclass
@@ -135,6 +134,19 @@ class RAGModule(BaseModule):
         self.embedding_service = None
         self.tokenizer = None

+        # Set improved default configuration
+        self.config = {
+            "chunk_size": 300,  # Reduced from 400 for better precision
+            "chunk_overlap": 50,  # Added overlap for context preservation
+            "max_results": 10,
+            "score_threshold": 0.3,  # Increased from 0.0 to filter low-quality results
+            "enable_hybrid": True,  # Enable hybrid search (vector + BM25)
+            "hybrid_weights": {"vector": 0.7, "bm25": 0.3}  # Weight for hybrid scoring
+        }
+        # Update with any provided config
+        if config:
+            self.config.update(config)
+
         # Content processing components
         self.nlp_model = None
         self.lemmatizer = None
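Note (reviewer): since the defaults are merged with any caller-supplied config via dict.update, a partial config only overrides the keys it names, and nested values such as hybrid_weights are replaced wholesale rather than deep-merged. A small standalone illustration with invented override values:

# Invented override values; the merge mirrors the dict.update call in the hunk above.
defaults = {
    "chunk_size": 300,
    "chunk_overlap": 50,
    "max_results": 10,
    "score_threshold": 0.3,
    "enable_hybrid": True,
    "hybrid_weights": {"vector": 0.7, "bm25": 0.3},
}
override = {"score_threshold": 0.5, "enable_hybrid": False}

config = dict(defaults)
config.update(override)  # only the named keys change; untouched keys keep their defaults
assert config["chunk_size"] == 300 and config["score_threshold"] == 0.5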
@@ -640,18 +652,32 @@ class RAGModule(BaseModule):
         return embeddings

     def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
-        """Split text into chunks"""
-        chunk_size = chunk_size or self.config.get("chunk_size", 400)
+        """Split text into overlapping chunks for better context preservation"""
+        chunk_size = chunk_size or self.config.get("chunk_size", 300)
+        chunk_overlap = self.config.get("chunk_overlap", 50)
+
         # Tokenize text
         tokens = self.tokenizer.encode(text)

-        # Split into chunks
+        # Split into chunks with overlap
         chunks = []
-        for i in range(0, len(tokens), chunk_size):
-            chunk_tokens = tokens[i:i + chunk_size]
+        start_idx = 0
+
+        while start_idx < len(tokens):
+            end_idx = min(start_idx + chunk_size, len(tokens))
+            chunk_tokens = tokens[start_idx:end_idx]
             chunk_text = self.tokenizer.decode(chunk_tokens)
-            chunks.append(chunk_text)
+
+            # Only add non-empty chunks
+            if chunk_text.strip():
+                chunks.append(chunk_text)
+
+            # Move to next chunk with overlap
+            start_idx = end_idx - chunk_overlap
+
+            # Ensure progress (in case overlap >= chunk_size)
+            if start_idx >= end_idx:
+                start_idx = end_idx

         return chunks
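Note (reviewer): with chunk_size=300 and chunk_overlap=50, consecutive windows start 250 tokens apart, so neighbouring chunks share 50 tokens. One thing to watch: once end_idx reaches len(tokens), start_idx is still stepped back by the overlap, and the progress guard only fires when the overlap is >= chunk_size, so the final window can be re-emitted indefinitely. Below is a minimal sketch of the same overlapping walk with an explicit stop at the end of the token list; the break is my addition, not part of this commit.

def chunk_with_overlap(tokens, chunk_size=300, chunk_overlap=50):
    """Overlapping token windows that stop once the final window reaches the end."""
    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        chunks.append(tokens[start:end])
        if end == len(tokens):       # added safeguard: stop at the final window
            break
        start = end - chunk_overlap  # step back by the overlap

    return chunks


# 1000 tokens -> windows [0:300], [250:550], [500:800], [750:1000]
lengths = [len(c) for c in chunk_with_overlap(list(range(1000)))]
assert lengths == [300, 300, 300, 250]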
@@ -1126,16 +1152,8 @@ class RAGModule(BaseModule):
             # Chunk the document
             chunks = self._chunk_text(content)

-            # Generate embeddings with enhanced rate limiting handling
-            embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
-
-            # Log if fallback embeddings were used
-            if not success:
-                logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded")
-                log_module_event("rag", "fallback_embeddings_used", {
-                    "document_id": doc_id,
-                    "content_preview": content[:100] + "..." if len(content) > 100 else content
-                })
+            # Generate embeddings for all chunks in batch (more efficient)
+            embeddings = await self._generate_embeddings(chunks)

             # Create document points
             points = []
@@ -1197,16 +1215,8 @@ class RAGModule(BaseModule):
             # Chunk the document
             chunks = self._chunk_text(processed_doc.content)

-            # Generate embeddings with enhanced rate limiting handling
-            embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
-
-            # Log if fallback embeddings were used
-            if not success:
-                logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded")
-                log_module_event("rag", "fallback_embeddings_used", {
-                    "document_id": processed_doc.id,
-                    "filename": processed_doc.original_filename
-                })
+            # Generate embeddings for all chunks in batch (more efficient)
+            embeddings = await self._generate_embeddings(chunks)

             # Create document points with enhanced metadata
             points = []
@@ -1277,6 +1287,154 @@ class RAGModule(BaseModule):
         except Exception:
             return False

+    async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float],
+                             query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]:
+        """Perform hybrid search combining vector similarity and BM25 scoring"""
+
+        # Preprocess query for BM25
+        query_terms = self._preprocess_text_for_bm25(query)
+
+        # Get all documents from the collection (for BM25 scoring)
+        # Note: In production, you'd want to optimize this with a proper BM25 index
+        scroll_filter = query_filter or Filter()
+        all_points = []
+
+        # Use scroll to get all points
+        offset = None
+        batch_size = 100
+        while True:
+            search_result = self.qdrant_client.scroll(
+                collection_name=collection_name,
+                scroll_filter=scroll_filter,
+                limit=batch_size,
+                offset=offset,
+                with_payload=True,
+                with_vectors=False
+            )
+
+            points = search_result[0]
+            all_points.extend(points)
+
+            if len(points) < batch_size:
+                break
+
+            offset = points[-1].id
+
+        # Calculate BM25 scores for each document
+        bm25_scores = {}
+        for point in all_points:
+            doc_id = point.payload.get("document_id", "")
+            content = point.payload.get("content", "")
+
+            # Calculate BM25 score
+            bm25_score = self._calculate_bm25_score(query_terms, content)
+            bm25_scores[doc_id] = bm25_score
+
+        # Perform vector search
+        vector_results = self.qdrant_client.search(
+            collection_name=collection_name,
+            query_vector=query_vector,
+            query_filter=query_filter,
+            limit=limit * 2,  # Get more results for re-ranking
+            score_threshold=score_threshold / 2  # Lower threshold for initial search
+        )

+        # Combine scores
+        hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
+        vector_weight = hybrid_weights.get("vector", 0.7)
+        bm25_weight = hybrid_weights.get("bm25", 0.3)
+
+        # Create hybrid results
+        hybrid_results = []
+        for result in vector_results:
+            doc_id = result.payload.get("document_id", "")
+            vector_score = result.score
+            bm25_score = bm25_scores.get(doc_id, 0.0)
+
+            # Normalize scores (simple min-max normalization)
+            vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0
+            bm25_norm = min(bm25_score, 1.0)  # BM25 scores are typically 0-1
+
+            # Calculate hybrid score
+            hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm)
+
+            # Create new point with hybrid score
+            hybrid_point = ScoredPoint(
+                id=result.id,
+                payload=result.payload,
+                score=hybrid_score,
+                vector=result.vector,
+                shard_key=None,
+                order_value=None
+            )
+            hybrid_results.append(hybrid_point)
+
+        # Sort by hybrid score and apply final threshold
+        hybrid_results.sort(key=lambda x: x.score, reverse=True)
+        final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit]
+
+        logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results")
+        return final_results
+
+    def _preprocess_text_for_bm25(self, text: str) -> List[str]:
+        """Preprocess text for BM25 scoring"""
+        if not NLTK_AVAILABLE:
+            return text.lower().split()
+
+        try:
+            # Tokenize
+            tokens = word_tokenize(text.lower())
+
+            # Remove stopwords and non-alphabetic tokens
+            stop_words = set(stopwords.words('english'))
+            filtered_tokens = [
+                token for token in tokens
+                if token.isalpha() and token not in stop_words and len(token) > 2
+            ]
+
+            return filtered_tokens
+        except:
+            # Fallback to simple splitting
+            return text.lower().split()
+
+    def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float:
+        """Calculate BM25 score for a document against query terms"""
+        if not query_terms:
+            return 0.0
+
+        # Preprocess document
+        doc_terms = self._preprocess_text_for_bm25(document)
+        if not doc_terms:
+            return 0.0
+
+        # Calculate term frequencies
+        doc_len = len(doc_terms)
+        avg_doc_len = 300  # Average document length (configurable)
+
+        # BM25 parameters
+        k1 = 1.2  # Controls term frequency saturation
+        b = 0.75  # Controls document length normalization
+
+        score = 0.0
+
+        # Calculate IDF for each query term
+        for term in set(query_terms):
+            # Term frequency in document
+            tf = doc_terms.count(term)
+
+            # Simple IDF (log(N/n) + 1)
+            # In production, you'd use the actual document frequency
+            idf = 2.0  # Simplified IDF
+
+            # BM25 formula
+            numerator = tf * (k1 + 1)
+            denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len))
+
+            score += idf * (numerator / denominator)
+
+        # Normalize score to 0-1 range
+        return min(score / 10.0, 1.0)  # Simple normalization
+
     async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
         """Search for relevant documents"""
         if not self.enabled:
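Note (reviewer): a standalone check of the scoring arithmetic in this hunk, with invented tf, document length and vector score values. Because idf is a constant 2.0 here, the BM25 term acts as a length-normalized term-frequency boost rather than full BM25, which matches the hunk's own comment that real document frequencies would be used in production.

# Invented inputs; the formulas mirror _calculate_bm25_score and _hybrid_search above.
k1, b = 1.2, 0.75
idf = 2.0                                   # simplified constant IDF from the hunk
tf, doc_len, avg_doc_len = 3, 200, 300

term_score = idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * (doc_len / avg_doc_len)))
bm25_norm = min(term_score / 10.0, 1.0)     # same 0-1 normalization as the hunk

score_threshold = 0.3
vector_score = 0.62
vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold)

hybrid = 0.7 * vector_norm + 0.3 * bm25_norm
print(round(term_score, 3), round(bm25_norm, 3), round(hybrid, 3))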
@@ -1314,14 +1472,29 @@ class RAGModule(BaseModule):
             logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}")
             logger.info(f"Embedding service available: {self.embedding_service is not None}")

-            # Search in Qdrant
-            search_results = self.qdrant_client.search(
-                collection_name=collection_name,
-                query_vector=query_embedding,
-                query_filter=search_filter,
-                limit=max_results,
-                score_threshold=0.0  # Lowered from 0.5 to see all results including low scores
-            )
+            # Check if hybrid search is enabled
+            enable_hybrid = self.config.get("enable_hybrid", False)
+            score_threshold = self.config.get("score_threshold", 0.3)
+
+            if enable_hybrid and NLTK_AVAILABLE:
+                # Perform hybrid search (vector + BM25)
+                search_results = await self._hybrid_search(
+                    collection_name=collection_name,
+                    query=query,
+                    query_vector=query_embedding,
+                    query_filter=search_filter,
+                    limit=max_results,
+                    score_threshold=score_threshold
+                )
+            else:
+                # Pure vector search with improved threshold
+                search_results = self.qdrant_client.search(
+                    collection_name=collection_name,
+                    query_vector=query_embedding,
+                    query_filter=search_filter,
+                    limit=max_results,
+                    score_threshold=score_threshold
+                )

             logger.info(f"Raw search results count: {len(search_results)}")