rag improvements

This commit is contained in:
2025-09-21 18:44:02 +02:00
parent f58a76ac59
commit a2ee959ec9
3 changed files with 401 additions and 86 deletions

View File

@@ -0,0 +1,153 @@
"""
Token-based rate limiting for LLM service
"""
import time
import redis
from typing import Dict, Optional, Tuple
from datetime import datetime, timedelta
from ..core.config import settings
from ..core.logging import get_logger
logger = get_logger(__name__)
class TokenRateLimiter:
"""Token-based rate limiting implementation"""
def __init__(self):
try:
self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
self.redis_client.ping()
logger.info("Token rate limiter initialized with Redis backend")
except Exception as e:
logger.warning(f"Redis not available for token rate limiting: {e}")
self.redis_client = None
# Fall back to in-memory rate limiting
self.in_memory_store = {}
logger.info("Token rate limiter using in-memory fallback")
async def check_token_limits(
self,
provider: str,
prompt_tokens: int,
completion_tokens: int = 0
) -> Tuple[bool, Dict[str, str]]:
"""
Check if token usage is within limits
Args:
provider: Provider name (e.g., "privatemode")
prompt_tokens: Number of prompt tokens to use
completion_tokens: Number of completion tokens to use
Returns:
Tuple of (is_allowed, headers)
"""
# Get token limits from configuration
from .config import get_config
config = get_config()
token_limits = config.token_limits_per_minute
# Check organization-wide limits
org_key = f"tokens:org:{provider}"
# Get current usage
current_usage = await self._get_token_usage(org_key)
# Calculate new usage
new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens
new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens
# Check limits
prompt_limit = token_limits.get("prompt_tokens", 20000)
completion_limit = token_limits.get("completion_tokens", 10000)
is_allowed = (
new_prompt_tokens <= prompt_limit and
new_completion_tokens <= completion_limit
)
if is_allowed:
# Update usage
await self._update_token_usage(org_key, prompt_tokens, completion_tokens)
logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, "
f"{new_completion_tokens}/{completion_limit} completion")
# Calculate remaining tokens
remaining_prompt = max(0, prompt_limit - new_prompt_tokens)
remaining_completion = max(0, completion_limit - new_completion_tokens)
# Create headers
headers = {
"X-TokenLimit-Prompt-Remaining": str(remaining_prompt),
"X-TokenLimit-Completion-Remaining": str(remaining_completion),
"X-TokenLimit-Prompt-Limit": str(prompt_limit),
"X-TokenLimit-Completion-Limit": str(completion_limit),
"X-TokenLimit-Reset": str(int(time.time() + 60)) # Reset in 1 minute
}
if not is_allowed:
logger.warning(f"Token rate limit exceeded for {provider}. "
f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. "
f"Current: {current_usage}")
return is_allowed, headers
async def _get_token_usage(self, key: str) -> Dict[str, int]:
"""Get current token usage"""
if self.redis_client:
try:
data = self.redis_client.hgetall(key)
if data:
return {
"prompt_tokens": int(data.get("prompt_tokens", 0)),
"completion_tokens": int(data.get("completion_tokens", 0)),
"updated_at": float(data.get("updated_at", time.time()))
}
except Exception as e:
logger.error(f"Error getting token usage from Redis: {e}")
# Fallback to in-memory
return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0})
async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int):
"""Update token usage"""
if self.redis_client:
try:
pipe = self.redis_client.pipeline()
pipe.hincrby(key, "prompt_tokens", prompt_tokens)
pipe.hincrby(key, "completion_tokens", completion_tokens)
pipe.hset(key, "updated_at", time.time())
pipe.expire(key, 60) # Expire after 1 minute
pipe.execute()
except Exception as e:
logger.error(f"Error updating token usage in Redis: {e}")
# Fallback to in-memory
self._update_in_memory(key, prompt_tokens, completion_tokens)
else:
self._update_in_memory(key, prompt_tokens, completion_tokens)
def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int):
"""Update in-memory token usage"""
if key not in self.in_memory_store:
self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0}
self.in_memory_store[key]["prompt_tokens"] += prompt_tokens
self.in_memory_store[key]["completion_tokens"] += completion_tokens
self.in_memory_store[key]["updated_at"] = time.time()
def cleanup_expired(self):
"""Clean up expired entries (for in-memory store)"""
if not self.redis_client:
current_time = time.time()
expired_keys = [
key for key, data in self.in_memory_store.items()
if current_time - data.get("updated_at", 0) > 60
]
for key in expired_keys:
del self.in_memory_store[key]
# Global token rate limiter instance
token_rate_limiter = TokenRateLimiter()

View File

@@ -265,7 +265,6 @@ class ChatbotModule(BaseModule):
async def chat_completion(self, request: ChatRequest, user_id: str, db: Session) -> ChatResponse:
"""Generate chat completion response"""
logger.info("=== CHAT COMPLETION METHOD CALLED ===")
# Get chatbot configuration from database
db_chatbot = db.query(DBChatbotInstance).filter(DBChatbotInstance.id == request.chatbot_id).first()
@@ -364,11 +363,10 @@ class ChatbotModule(BaseModule):
metadata={"error": str(e), "fallback": True}
)
async def _generate_response(self, message: str, db_messages: List[DBMessage],
async def _generate_response(self, message: str, db_messages: List[DBMessage],
config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]:
"""Generate response using LLM with optional RAG"""
logger.info("=== _generate_response METHOD CALLED ===")
# Lazy load dependencies if not available
await self._ensure_dependencies()
@@ -397,8 +395,8 @@ class ChatbotModule(BaseModule):
for i, result in enumerate(rag_results)]
# Build full RAG context from all results
rag_context = "\\n\\nRelevant information from knowledge base:\\n" + "\\n\\n".join([
f"[Document {i+1}]:\\n{result.document.content}" for i, result in enumerate(rag_results)
rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([
f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results)
])
# Detailed RAG logging - ALWAYS log for debugging
@@ -407,14 +405,14 @@ class ChatbotModule(BaseModule):
logger.info(f"Collection: {qdrant_collection_name}")
logger.info(f"Number of results: {len(rag_results)}")
for i, result in enumerate(rag_results):
logger.info(f"\\n--- RAG Result {i+1} ---")
logger.info(f"\n--- RAG Result {i+1} ---")
logger.info(f"Score: {getattr(result, 'score', 'N/A')}")
logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}")
logger.info(f"Full Content ({len(result.document.content)} chars):")
logger.info(f"{result.document.content}")
if hasattr(result.document, 'metadata'):
logger.info(f"Metadata: {result.document.metadata}")
logger.info(f"\\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
logger.info(f"\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===")
logger.info(rag_context)
logger.info("=== END RAG SEARCH RESULTS ===")
else:
@@ -428,11 +426,6 @@ class ChatbotModule(BaseModule):
logger.warning(f"RAG search traceback: {traceback.format_exc()}")
# Build conversation context (includes the current message from db_messages)
logger.info(f"=== CRITICAL DEBUG ===")
logger.info(f"rag_context length: {len(rag_context)}")
logger.info(f"rag_context empty: {not rag_context}")
logger.info(f"rag_context preview: {rag_context[:200] if rag_context else 'EMPTY'}")
logger.info(f"=== END CRITICAL DEBUG ===")
messages = self._build_conversation_messages(db_messages, config, rag_context, context)
# Note: Current user message is already included in db_messages from the query
@@ -452,9 +445,9 @@ class ChatbotModule(BaseModule):
if config.use_rag and rag_context:
logger.info(f"RAG context added: {len(rag_context)} characters")
logger.info(f"RAG sources: {len(sources) if sources else 0} documents")
logger.info("\\n=== COMPLETE MESSAGES SENT TO LLM ===")
logger.info("\n=== COMPLETE MESSAGES SENT TO LLM ===")
for i, msg in enumerate(messages):
logger.info(f"\\n--- Message {i+1} ---")
logger.info(f"\n--- Message {i+1} ---")
logger.info(f"Role: {msg['role']}")
logger.info(f"Content ({len(msg['content'])} chars):")
# Truncate long content for logging (full RAG context can be very long)
@@ -518,38 +511,34 @@ class ChatbotModule(BaseModule):
# Return fallback if available
return "I'm currently unable to process your request. Please try again later.", None
def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig,
def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig,
rag_context: str = "", context: Optional[Dict] = None) -> List[Dict]:
"""Build messages array for LLM completion"""
messages = []
logger.info(f"DEBUG: _build_conversation_messages called. rag_context length: {len(rag_context)}")
# System prompt - keep it clean without RAG context
# System prompt
system_prompt = config.system_prompt
if rag_context:
# Add explicit instruction to use RAG context
system_prompt += "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. " \
"This information is directly relevant to their query and should be your primary source:\n" + rag_context
if context and context.get('additional_instructions'):
system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}"
messages.append({"role": "system", "content": system_prompt})
logger.info(f"Building messages from {len(db_messages)} database messages")
# Conversation history (messages are already limited by memory_length in the query)
# Reverse to get chronological order
# Include ALL messages - the current user message is needed for the LLM to respond!
for idx, msg in enumerate(reversed(db_messages)):
logger.info(f"Processing message {idx}: role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...")
if msg.role in ["user", "assistant"]:
# For user messages, prepend RAG context if available
content = msg.content
if msg.role == "user" and rag_context and idx == 0:
# Add RAG context to the current user message (first in reversed order)
content = f"Relevant information from knowledge base:\n{rag_context}\n\nQuestion: {msg.content}"
logger.info("Added RAG context to user message")
messages.append({
"role": msg.role,
"content": content
"content": msg.content
})
logger.info(f"Added message with role {msg.role} to LLM messages")
else:
@@ -690,10 +679,9 @@ class ChatbotModule(BaseModule):
return router
# API Compatibility Methods
async def chat(self, chatbot_config: Dict[str, Any], message: str,
async def chat(self, chatbot_config: Dict[str, Any], message: str,
conversation_history: List = None, user_id: str = "anonymous") -> Dict[str, Any]:
"""Chat method for API compatibility"""
logger.info("=== CHAT METHOD (API COMPATIBILITY) CALLED ===")
logger.info(f"Chat method called with message: {message[:50]}... by user: {user_id}")
# Lazy load dependencies
@@ -723,20 +711,21 @@ class ChatbotModule(BaseModule):
fallback_responses=chatbot_config.get("fallback_responses", [])
)
# For API compatibility, create a temporary DBMessage for the current message
# so RAG context can be properly added
from app.models.chatbot import ChatbotMessage as DBMessage
# Generate response using internal method
# Create a temporary message object for the current user message
temp_messages = [
DBMessage(
id=0,
conversation_id=0,
role="user",
content=message,
timestamp=datetime.utcnow(),
metadata={}
)
]
# Create a temporary user message with the current message
temp_user_message = DBMessage(
conversation_id="temp_conversation",
role=MessageRole.USER.value,
content=message
)
# Generate response using internal method with the current message included
response_content, sources = await self._generate_response(
message, [temp_user_message], config, None, db
message, temp_messages, config, None, db
)
return {

View File

@@ -53,14 +53,13 @@ except ImportError:
PYTHON_DOCX_AVAILABLE = False
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue
from qdrant_client.http import models
import tiktoken
from app.core.config import settings
from app.core.logging import log_module_event
from app.services.base_module import BaseModule, Permission
from app.services.enhanced_embedding_service import enhanced_embedding_service
@dataclass
@@ -134,6 +133,19 @@ class RAGModule(BaseModule):
self.embedding_model = None
self.embedding_service = None
self.tokenizer = None
# Set improved default configuration
self.config = {
"chunk_size": 300, # Reduced from 400 for better precision
"chunk_overlap": 50, # Added overlap for context preservation
"max_results": 10,
"score_threshold": 0.3, # Increased from 0.0 to filter low-quality results
"enable_hybrid": True, # Enable hybrid search (vector + BM25)
"hybrid_weights": {"vector": 0.7, "bm25": 0.3} # Weight for hybrid scoring
}
# Update with any provided config
if config:
self.config.update(config)
# Content processing components
self.nlp_model = None
@@ -640,19 +652,33 @@ class RAGModule(BaseModule):
return embeddings
def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
"""Split text into chunks"""
chunk_size = chunk_size or self.config.get("chunk_size", 400)
"""Split text into overlapping chunks for better context preservation"""
chunk_size = chunk_size or self.config.get("chunk_size", 300)
chunk_overlap = self.config.get("chunk_overlap", 50)
# Tokenize text
tokens = self.tokenizer.encode(text)
# Split into chunks
# Split into chunks with overlap
chunks = []
for i in range(0, len(tokens), chunk_size):
chunk_tokens = tokens[i:i + chunk_size]
start_idx = 0
while start_idx < len(tokens):
end_idx = min(start_idx + chunk_size, len(tokens))
chunk_tokens = tokens[start_idx:end_idx]
chunk_text = self.tokenizer.decode(chunk_tokens)
chunks.append(chunk_text)
# Only add non-empty chunks
if chunk_text.strip():
chunks.append(chunk_text)
# Move to next chunk with overlap
start_idx = end_idx - chunk_overlap
# Ensure progress (in case overlap >= chunk_size)
if start_idx >= end_idx:
start_idx = end_idx
return chunks
async def _process_text(self, content: bytes, filename: str) -> str:
@@ -1126,17 +1152,9 @@ class RAGModule(BaseModule):
# Chunk the document
chunks = self._chunk_text(content)
# Generate embeddings with enhanced rate limiting handling
embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
# Log if fallback embeddings were used
if not success:
logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded")
log_module_event("rag", "fallback_embeddings_used", {
"document_id": doc_id,
"content_preview": content[:100] + "..." if len(content) > 100 else content
})
# Generate embeddings for all chunks in batch (more efficient)
embeddings = await self._generate_embeddings(chunks)
# Create document points
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
@@ -1197,17 +1215,9 @@ class RAGModule(BaseModule):
# Chunk the document
chunks = self._chunk_text(processed_doc.content)
# Generate embeddings with enhanced rate limiting handling
embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
# Log if fallback embeddings were used
if not success:
logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded")
log_module_event("rag", "fallback_embeddings_used", {
"document_id": processed_doc.id,
"filename": processed_doc.original_filename
})
# Generate embeddings for all chunks in batch (more efficient)
embeddings = await self._generate_embeddings(chunks)
# Create document points with enhanced metadata
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
@@ -1277,6 +1287,154 @@ class RAGModule(BaseModule):
except Exception:
return False
async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float],
query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]:
"""Perform hybrid search combining vector similarity and BM25 scoring"""
# Preprocess query for BM25
query_terms = self._preprocess_text_for_bm25(query)
# Get all documents from the collection (for BM25 scoring)
# Note: In production, you'd want to optimize this with a proper BM25 index
scroll_filter = query_filter or Filter()
all_points = []
# Use scroll to get all points
offset = None
batch_size = 100
while True:
search_result = self.qdrant_client.scroll(
collection_name=collection_name,
scroll_filter=scroll_filter,
limit=batch_size,
offset=offset,
with_payload=True,
with_vectors=False
)
points = search_result[0]
all_points.extend(points)
if len(points) < batch_size:
break
offset = points[-1].id
# Calculate BM25 scores for each document
bm25_scores = {}
for point in all_points:
doc_id = point.payload.get("document_id", "")
content = point.payload.get("content", "")
# Calculate BM25 score
bm25_score = self._calculate_bm25_score(query_terms, content)
bm25_scores[doc_id] = bm25_score
# Perform vector search
vector_results = self.qdrant_client.search(
collection_name=collection_name,
query_vector=query_vector,
query_filter=query_filter,
limit=limit * 2, # Get more results for re-ranking
score_threshold=score_threshold / 2 # Lower threshold for initial search
)
# Combine scores
hybrid_weights = self.config.get("hybrid_weights", {"vector": 0.7, "bm25": 0.3})
vector_weight = hybrid_weights.get("vector", 0.7)
bm25_weight = hybrid_weights.get("bm25", 0.3)
# Create hybrid results
hybrid_results = []
for result in vector_results:
doc_id = result.payload.get("document_id", "")
vector_score = result.score
bm25_score = bm25_scores.get(doc_id, 0.0)
# Normalize scores (simple min-max normalization)
vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0
bm25_norm = min(bm25_score, 1.0) # BM25 scores are typically 0-1
# Calculate hybrid score
hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm)
# Create new point with hybrid score
hybrid_point = ScoredPoint(
id=result.id,
payload=result.payload,
score=hybrid_score,
vector=result.vector,
shard_key=None,
order_value=None
)
hybrid_results.append(hybrid_point)
# Sort by hybrid score and apply final threshold
hybrid_results.sort(key=lambda x: x.score, reverse=True)
final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit]
logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results")
return final_results
def _preprocess_text_for_bm25(self, text: str) -> List[str]:
"""Preprocess text for BM25 scoring"""
if not NLTK_AVAILABLE:
return text.lower().split()
try:
# Tokenize
tokens = word_tokenize(text.lower())
# Remove stopwords and non-alphabetic tokens
stop_words = set(stopwords.words('english'))
filtered_tokens = [
token for token in tokens
if token.isalpha() and token not in stop_words and len(token) > 2
]
return filtered_tokens
except:
# Fallback to simple splitting
return text.lower().split()
def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float:
"""Calculate BM25 score for a document against query terms"""
if not query_terms:
return 0.0
# Preprocess document
doc_terms = self._preprocess_text_for_bm25(document)
if not doc_terms:
return 0.0
# Calculate term frequencies
doc_len = len(doc_terms)
avg_doc_len = 300 # Average document length (configurable)
# BM25 parameters
k1 = 1.2 # Controls term frequency saturation
b = 0.75 # Controls document length normalization
score = 0.0
# Calculate IDF for each query term
for term in set(query_terms):
# Term frequency in document
tf = doc_terms.count(term)
# Simple IDF (log(N/n) + 1)
# In production, you'd use the actual document frequency
idf = 2.0 # Simplified IDF
# BM25 formula
numerator = tf * (k1 + 1)
denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len))
score += idf * (numerator / denominator)
# Normalize score to 0-1 range
return min(score / 10.0, 1.0) # Simple normalization
async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]:
"""Search for relevant documents"""
if not self.enabled:
@@ -1314,14 +1472,29 @@ class RAGModule(BaseModule):
logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}")
logger.info(f"Embedding service available: {self.embedding_service is not None}")
# Search in Qdrant
search_results = self.qdrant_client.search(
collection_name=collection_name,
query_vector=query_embedding,
query_filter=search_filter,
limit=max_results,
score_threshold=0.0 # Lowered from 0.5 to see all results including low scores
)
# Check if hybrid search is enabled
enable_hybrid = self.config.get("enable_hybrid", False)
score_threshold = self.config.get("score_threshold", 0.3)
if enable_hybrid and NLTK_AVAILABLE:
# Perform hybrid search (vector + BM25)
search_results = await self._hybrid_search(
collection_name=collection_name,
query=query,
query_vector=query_embedding,
query_filter=search_filter,
limit=max_results,
score_threshold=score_threshold
)
else:
# Pure vector search with improved threshold
search_results = self.qdrant_client.search(
collection_name=collection_name,
query_vector=query_embedding,
query_filter=search_filter,
limit=max_results,
score_threshold=score_threshold
)
logger.info(f"Raw search results count: {len(search_results)}")