From a2ee959ec951b0ecedd5112b7e467dc237b7a504 Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Sun, 21 Sep 2025 18:44:02 +0200 Subject: [PATCH] rag improvements --- .../app/services/llm/token_rate_limiter.py | 153 +++++++++++ backend/modules/chatbot/main.py | 79 +++--- backend/modules/rag/main.py | 255 +++++++++++++++--- 3 files changed, 401 insertions(+), 86 deletions(-) create mode 100644 backend/app/services/llm/token_rate_limiter.py diff --git a/backend/app/services/llm/token_rate_limiter.py b/backend/app/services/llm/token_rate_limiter.py new file mode 100644 index 0000000..2338a03 --- /dev/null +++ b/backend/app/services/llm/token_rate_limiter.py @@ -0,0 +1,153 @@ +""" +Token-based rate limiting for LLM service +""" + +import time +import redis +from typing import Dict, Optional, Tuple +from datetime import datetime, timedelta +from ..core.config import settings +from ..core.logging import get_logger + +logger = get_logger(__name__) + + +class TokenRateLimiter: + """Token-based rate limiting implementation""" + + def __init__(self): + try: + self.redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True) + self.redis_client.ping() + logger.info("Token rate limiter initialized with Redis backend") + except Exception as e: + logger.warning(f"Redis not available for token rate limiting: {e}") + self.redis_client = None + # Fall back to in-memory rate limiting + self.in_memory_store = {} + logger.info("Token rate limiter using in-memory fallback") + + async def check_token_limits( + self, + provider: str, + prompt_tokens: int, + completion_tokens: int = 0 + ) -> Tuple[bool, Dict[str, str]]: + """ + Check if token usage is within limits + + Args: + provider: Provider name (e.g., "privatemode") + prompt_tokens: Number of prompt tokens to use + completion_tokens: Number of completion tokens to use + + Returns: + Tuple of (is_allowed, headers) + """ + # Get token limits from configuration + from .config import get_config + config = get_config() + token_limits = config.token_limits_per_minute + + # Check organization-wide limits + org_key = f"tokens:org:{provider}" + + # Get current usage + current_usage = await self._get_token_usage(org_key) + + # Calculate new usage + new_prompt_tokens = current_usage.get("prompt_tokens", 0) + prompt_tokens + new_completion_tokens = current_usage.get("completion_tokens", 0) + completion_tokens + + # Check limits + prompt_limit = token_limits.get("prompt_tokens", 20000) + completion_limit = token_limits.get("completion_tokens", 10000) + + is_allowed = ( + new_prompt_tokens <= prompt_limit and + new_completion_tokens <= completion_limit + ) + + if is_allowed: + # Update usage + await self._update_token_usage(org_key, prompt_tokens, completion_tokens) + logger.debug(f"Token usage updated: {new_prompt_tokens}/{prompt_limit} prompt, " + f"{new_completion_tokens}/{completion_limit} completion") + + # Calculate remaining tokens + remaining_prompt = max(0, prompt_limit - new_prompt_tokens) + remaining_completion = max(0, completion_limit - new_completion_tokens) + + # Create headers + headers = { + "X-TokenLimit-Prompt-Remaining": str(remaining_prompt), + "X-TokenLimit-Completion-Remaining": str(remaining_completion), + "X-TokenLimit-Prompt-Limit": str(prompt_limit), + "X-TokenLimit-Completion-Limit": str(completion_limit), + "X-TokenLimit-Reset": str(int(time.time() + 60)) # Reset in 1 minute + } + + if not is_allowed: + logger.warning(f"Token rate limit exceeded for {provider}. " + f"Requested: {prompt_tokens} prompt, {completion_tokens} completion. 
" + f"Current: {current_usage}") + + return is_allowed, headers + + async def _get_token_usage(self, key: str) -> Dict[str, int]: + """Get current token usage""" + if self.redis_client: + try: + data = self.redis_client.hgetall(key) + if data: + return { + "prompt_tokens": int(data.get("prompt_tokens", 0)), + "completion_tokens": int(data.get("completion_tokens", 0)), + "updated_at": float(data.get("updated_at", time.time())) + } + except Exception as e: + logger.error(f"Error getting token usage from Redis: {e}") + + # Fallback to in-memory + return self.in_memory_store.get(key, {"prompt_tokens": 0, "completion_tokens": 0}) + + async def _update_token_usage(self, key: str, prompt_tokens: int, completion_tokens: int): + """Update token usage""" + if self.redis_client: + try: + pipe = self.redis_client.pipeline() + pipe.hincrby(key, "prompt_tokens", prompt_tokens) + pipe.hincrby(key, "completion_tokens", completion_tokens) + pipe.hset(key, "updated_at", time.time()) + pipe.expire(key, 60) # Expire after 1 minute + pipe.execute() + except Exception as e: + logger.error(f"Error updating token usage in Redis: {e}") + # Fallback to in-memory + self._update_in_memory(key, prompt_tokens, completion_tokens) + else: + self._update_in_memory(key, prompt_tokens, completion_tokens) + + def _update_in_memory(self, key: str, prompt_tokens: int, completion_tokens: int): + """Update in-memory token usage""" + if key not in self.in_memory_store: + self.in_memory_store[key] = {"prompt_tokens": 0, "completion_tokens": 0} + + self.in_memory_store[key]["prompt_tokens"] += prompt_tokens + self.in_memory_store[key]["completion_tokens"] += completion_tokens + self.in_memory_store[key]["updated_at"] = time.time() + + def cleanup_expired(self): + """Clean up expired entries (for in-memory store)""" + if not self.redis_client: + current_time = time.time() + expired_keys = [ + key for key, data in self.in_memory_store.items() + if current_time - data.get("updated_at", 0) > 60 + ] + for key in expired_keys: + del self.in_memory_store[key] + + +# Global token rate limiter instance +token_rate_limiter = TokenRateLimiter() \ No newline at end of file diff --git a/backend/modules/chatbot/main.py b/backend/modules/chatbot/main.py index 6f42f09..5ab62c7 100644 --- a/backend/modules/chatbot/main.py +++ b/backend/modules/chatbot/main.py @@ -265,7 +265,6 @@ class ChatbotModule(BaseModule): async def chat_completion(self, request: ChatRequest, user_id: str, db: Session) -> ChatResponse: """Generate chat completion response""" - logger.info("=== CHAT COMPLETION METHOD CALLED ===") # Get chatbot configuration from database db_chatbot = db.query(DBChatbotInstance).filter(DBChatbotInstance.id == request.chatbot_id).first() @@ -364,11 +363,10 @@ class ChatbotModule(BaseModule): metadata={"error": str(e), "fallback": True} ) - async def _generate_response(self, message: str, db_messages: List[DBMessage], + async def _generate_response(self, message: str, db_messages: List[DBMessage], config: ChatbotConfig, context: Optional[Dict] = None, db: Session = None) -> tuple[str, Optional[List]]: """Generate response using LLM with optional RAG""" - logger.info("=== _generate_response METHOD CALLED ===") - + # Lazy load dependencies if not available await self._ensure_dependencies() @@ -397,8 +395,8 @@ class ChatbotModule(BaseModule): for i, result in enumerate(rag_results)] # Build full RAG context from all results - rag_context = "\\n\\nRelevant information from knowledge base:\\n" + "\\n\\n".join([ - f"[Document 
{i+1}]:\\n{result.document.content}" for i, result in enumerate(rag_results) + rag_context = "\n\nRelevant information from knowledge base:\n" + "\n\n".join([ + f"[Document {i+1}]:\n{result.document.content}" for i, result in enumerate(rag_results) ]) # Detailed RAG logging - ALWAYS log for debugging @@ -407,14 +405,14 @@ class ChatbotModule(BaseModule): logger.info(f"Collection: {qdrant_collection_name}") logger.info(f"Number of results: {len(rag_results)}") for i, result in enumerate(rag_results): - logger.info(f"\\n--- RAG Result {i+1} ---") + logger.info(f"\n--- RAG Result {i+1} ---") logger.info(f"Score: {getattr(result, 'score', 'N/A')}") logger.info(f"Document ID: {getattr(result.document, 'id', 'N/A')}") logger.info(f"Full Content ({len(result.document.content)} chars):") logger.info(f"{result.document.content}") if hasattr(result.document, 'metadata'): logger.info(f"Metadata: {result.document.metadata}") - logger.info(f"\\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===") + logger.info(f"\n=== RAG CONTEXT BEING ADDED TO PROMPT ({len(rag_context)} chars) ===") logger.info(rag_context) logger.info("=== END RAG SEARCH RESULTS ===") else: @@ -428,11 +426,6 @@ class ChatbotModule(BaseModule): logger.warning(f"RAG search traceback: {traceback.format_exc()}") # Build conversation context (includes the current message from db_messages) - logger.info(f"=== CRITICAL DEBUG ===") - logger.info(f"rag_context length: {len(rag_context)}") - logger.info(f"rag_context empty: {not rag_context}") - logger.info(f"rag_context preview: {rag_context[:200] if rag_context else 'EMPTY'}") - logger.info(f"=== END CRITICAL DEBUG ===") messages = self._build_conversation_messages(db_messages, config, rag_context, context) # Note: Current user message is already included in db_messages from the query @@ -452,9 +445,9 @@ class ChatbotModule(BaseModule): if config.use_rag and rag_context: logger.info(f"RAG context added: {len(rag_context)} characters") logger.info(f"RAG sources: {len(sources) if sources else 0} documents") - logger.info("\\n=== COMPLETE MESSAGES SENT TO LLM ===") + logger.info("\n=== COMPLETE MESSAGES SENT TO LLM ===") for i, msg in enumerate(messages): - logger.info(f"\\n--- Message {i+1} ---") + logger.info(f"\n--- Message {i+1} ---") logger.info(f"Role: {msg['role']}") logger.info(f"Content ({len(msg['content'])} chars):") # Truncate long content for logging (full RAG context can be very long) @@ -518,38 +511,34 @@ class ChatbotModule(BaseModule): # Return fallback if available return "I'm currently unable to process your request. Please try again later.", None - def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig, + def _build_conversation_messages(self, db_messages: List[DBMessage], config: ChatbotConfig, rag_context: str = "", context: Optional[Dict] = None) -> List[Dict]: """Build messages array for LLM completion""" - + messages = [] - logger.info(f"DEBUG: _build_conversation_messages called. rag_context length: {len(rag_context)}") - - # System prompt - keep it clean without RAG context + + # System prompt system_prompt = config.system_prompt + if rag_context: + # Add explicit instruction to use RAG context + system_prompt += "\n\nIMPORTANT: Use the following information from the knowledge base to answer the user's question. 
" \ + "This information is directly relevant to their query and should be your primary source:\n" + rag_context if context and context.get('additional_instructions'): system_prompt += f"\n\nAdditional instructions: {context['additional_instructions']}" - + messages.append({"role": "system", "content": system_prompt}) - + logger.info(f"Building messages from {len(db_messages)} database messages") - + # Conversation history (messages are already limited by memory_length in the query) # Reverse to get chronological order # Include ALL messages - the current user message is needed for the LLM to respond! for idx, msg in enumerate(reversed(db_messages)): logger.info(f"Processing message {idx}: role={msg.role}, content_preview={msg.content[:50] if msg.content else 'None'}...") if msg.role in ["user", "assistant"]: - # For user messages, prepend RAG context if available - content = msg.content - if msg.role == "user" and rag_context and idx == 0: - # Add RAG context to the current user message (first in reversed order) - content = f"Relevant information from knowledge base:\n{rag_context}\n\nQuestion: {msg.content}" - logger.info("Added RAG context to user message") - messages.append({ "role": msg.role, - "content": content + "content": msg.content }) logger.info(f"Added message with role {msg.role} to LLM messages") else: @@ -690,10 +679,9 @@ class ChatbotModule(BaseModule): return router # API Compatibility Methods - async def chat(self, chatbot_config: Dict[str, Any], message: str, + async def chat(self, chatbot_config: Dict[str, Any], message: str, conversation_history: List = None, user_id: str = "anonymous") -> Dict[str, Any]: """Chat method for API compatibility""" - logger.info("=== CHAT METHOD (API COMPATIBILITY) CALLED ===") logger.info(f"Chat method called with message: {message[:50]}... 
by user: {user_id}") # Lazy load dependencies @@ -723,20 +711,21 @@ class ChatbotModule(BaseModule): fallback_responses=chatbot_config.get("fallback_responses", []) ) - # For API compatibility, create a temporary DBMessage for the current message - # so RAG context can be properly added - from app.models.chatbot import ChatbotMessage as DBMessage + # Generate response using internal method + # Create a temporary message object for the current user message + temp_messages = [ + DBMessage( + id=0, + conversation_id=0, + role="user", + content=message, + timestamp=datetime.utcnow(), + metadata={} + ) + ] - # Create a temporary user message with the current message - temp_user_message = DBMessage( - conversation_id="temp_conversation", - role=MessageRole.USER.value, - content=message - ) - - # Generate response using internal method with the current message included response_content, sources = await self._generate_response( - message, [temp_user_message], config, None, db + message, temp_messages, config, None, db ) return { diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py index b6c90b7..7d75fbd 100644 --- a/backend/modules/rag/main.py +++ b/backend/modules/rag/main.py @@ -53,14 +53,13 @@ except ImportError: PYTHON_DOCX_AVAILABLE = False from qdrant_client import QdrantClient -from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue +from qdrant_client.models import Distance, VectorParams, PointStruct, ScoredPoint, Filter, FieldCondition, MatchValue from qdrant_client.http import models import tiktoken from app.core.config import settings from app.core.logging import log_module_event from app.services.base_module import BaseModule, Permission -from app.services.enhanced_embedding_service import enhanced_embedding_service @dataclass @@ -134,6 +133,19 @@ class RAGModule(BaseModule): self.embedding_model = None self.embedding_service = None self.tokenizer = None + + # Set improved default configuration + self.config = { + "chunk_size": 300, # Reduced from 400 for better precision + "chunk_overlap": 50, # Added overlap for context preservation + "max_results": 10, + "score_threshold": 0.3, # Increased from 0.0 to filter low-quality results + "enable_hybrid": True, # Enable hybrid search (vector + BM25) + "hybrid_weights": {"vector": 0.7, "bm25": 0.3} # Weight for hybrid scoring + } + # Update with any provided config + if config: + self.config.update(config) # Content processing components self.nlp_model = None @@ -640,19 +652,33 @@ class RAGModule(BaseModule): return embeddings def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]: - """Split text into chunks""" - chunk_size = chunk_size or self.config.get("chunk_size", 400) - + """Split text into overlapping chunks for better context preservation""" + chunk_size = chunk_size or self.config.get("chunk_size", 300) + chunk_overlap = self.config.get("chunk_overlap", 50) + # Tokenize text tokens = self.tokenizer.encode(text) - - # Split into chunks + + # Split into chunks with overlap chunks = [] - for i in range(0, len(tokens), chunk_size): - chunk_tokens = tokens[i:i + chunk_size] + start_idx = 0 + + while start_idx < len(tokens): + end_idx = min(start_idx + chunk_size, len(tokens)) + chunk_tokens = tokens[start_idx:end_idx] chunk_text = self.tokenizer.decode(chunk_tokens) - chunks.append(chunk_text) - + + # Only add non-empty chunks + if chunk_text.strip(): + chunks.append(chunk_text) + + # Move to next chunk with overlap + start_idx = end_idx - chunk_overlap + + 
# Ensure progress (in case overlap >= chunk_size) + if start_idx >= end_idx: + start_idx = end_idx + return chunks async def _process_text(self, content: bytes, filename: str) -> str: @@ -1126,17 +1152,9 @@ class RAGModule(BaseModule): # Chunk the document chunks = self._chunk_text(content) - # Generate embeddings with enhanced rate limiting handling - embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks) - - # Log if fallback embeddings were used - if not success: - logger.warning(f"Used fallback embeddings for document {doc_id} - search quality may be degraded") - log_module_event("rag", "fallback_embeddings_used", { - "document_id": doc_id, - "content_preview": content[:100] + "..." if len(content) > 100 else content - }) - + # Generate embeddings for all chunks in batch (more efficient) + embeddings = await self._generate_embeddings(chunks) + # Create document points points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): @@ -1197,17 +1215,9 @@ class RAGModule(BaseModule): # Chunk the document chunks = self._chunk_text(processed_doc.content) - # Generate embeddings with enhanced rate limiting handling - embeddings, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks) - - # Log if fallback embeddings were used - if not success: - logger.warning(f"Used fallback embeddings for document {processed_doc.id} - search quality may be degraded") - log_module_event("rag", "fallback_embeddings_used", { - "document_id": processed_doc.id, - "filename": processed_doc.original_filename - }) - + # Generate embeddings for all chunks in batch (more efficient) + embeddings = await self._generate_embeddings(chunks) + # Create document points with enhanced metadata points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): @@ -1277,6 +1287,154 @@ class RAGModule(BaseModule): except Exception: return False + async def _hybrid_search(self, collection_name: str, query: str, query_vector: List[float], + query_filter: Optional[Filter], limit: int, score_threshold: float) -> List[Any]: + """Perform hybrid search combining vector similarity and BM25 scoring""" + + # Preprocess query for BM25 + query_terms = self._preprocess_text_for_bm25(query) + + # Get all documents from the collection (for BM25 scoring) + # Note: In production, you'd want to optimize this with a proper BM25 index + scroll_filter = query_filter or Filter() + all_points = [] + + # Use scroll to get all points + offset = None + batch_size = 100 + while True: + search_result = self.qdrant_client.scroll( + collection_name=collection_name, + scroll_filter=scroll_filter, + limit=batch_size, + offset=offset, + with_payload=True, + with_vectors=False + ) + + points = search_result[0] + all_points.extend(points) + + if len(points) < batch_size: + break + + offset = points[-1].id + + # Calculate BM25 scores for each document + bm25_scores = {} + for point in all_points: + doc_id = point.payload.get("document_id", "") + content = point.payload.get("content", "") + + # Calculate BM25 score + bm25_score = self._calculate_bm25_score(query_terms, content) + bm25_scores[doc_id] = bm25_score + + # Perform vector search + vector_results = self.qdrant_client.search( + collection_name=collection_name, + query_vector=query_vector, + query_filter=query_filter, + limit=limit * 2, # Get more results for re-ranking + score_threshold=score_threshold / 2 # Lower threshold for initial search + ) + + # Combine scores + hybrid_weights = self.config.get("hybrid_weights", 
{"vector": 0.7, "bm25": 0.3}) + vector_weight = hybrid_weights.get("vector", 0.7) + bm25_weight = hybrid_weights.get("bm25", 0.3) + + # Create hybrid results + hybrid_results = [] + for result in vector_results: + doc_id = result.payload.get("document_id", "") + vector_score = result.score + bm25_score = bm25_scores.get(doc_id, 0.0) + + # Normalize scores (simple min-max normalization) + vector_norm = (vector_score - score_threshold) / (1.0 - score_threshold) if vector_score > score_threshold else 0 + bm25_norm = min(bm25_score, 1.0) # BM25 scores are typically 0-1 + + # Calculate hybrid score + hybrid_score = (vector_weight * vector_norm) + (bm25_weight * bm25_norm) + + # Create new point with hybrid score + hybrid_point = ScoredPoint( + id=result.id, + payload=result.payload, + score=hybrid_score, + vector=result.vector, + shard_key=None, + order_value=None + ) + hybrid_results.append(hybrid_point) + + # Sort by hybrid score and apply final threshold + hybrid_results.sort(key=lambda x: x.score, reverse=True) + final_results = [r for r in hybrid_results if r.score >= score_threshold][:limit] + + logger.info(f"Hybrid search: {len(vector_results)} vector results, {len(final_results)} final results") + return final_results + + def _preprocess_text_for_bm25(self, text: str) -> List[str]: + """Preprocess text for BM25 scoring""" + if not NLTK_AVAILABLE: + return text.lower().split() + + try: + # Tokenize + tokens = word_tokenize(text.lower()) + + # Remove stopwords and non-alphabetic tokens + stop_words = set(stopwords.words('english')) + filtered_tokens = [ + token for token in tokens + if token.isalpha() and token not in stop_words and len(token) > 2 + ] + + return filtered_tokens + except: + # Fallback to simple splitting + return text.lower().split() + + def _calculate_bm25_score(self, query_terms: List[str], document: str) -> float: + """Calculate BM25 score for a document against query terms""" + if not query_terms: + return 0.0 + + # Preprocess document + doc_terms = self._preprocess_text_for_bm25(document) + if not doc_terms: + return 0.0 + + # Calculate term frequencies + doc_len = len(doc_terms) + avg_doc_len = 300 # Average document length (configurable) + + # BM25 parameters + k1 = 1.2 # Controls term frequency saturation + b = 0.75 # Controls document length normalization + + score = 0.0 + + # Calculate IDF for each query term + for term in set(query_terms): + # Term frequency in document + tf = doc_terms.count(term) + + # Simple IDF (log(N/n) + 1) + # In production, you'd use the actual document frequency + idf = 2.0 # Simplified IDF + + # BM25 formula + numerator = tf * (k1 + 1) + denominator = tf + k1 * (1 - b + b * (doc_len / avg_doc_len)) + + score += idf * (numerator / denominator) + + # Normalize score to 0-1 range + return min(score / 10.0, 1.0) # Simple normalization + async def search_documents(self, query: str, max_results: int = None, filters: Dict[str, Any] = None, collection_name: str = None) -> List[SearchResult]: """Search for relevant documents""" if not self.enabled: @@ -1314,14 +1472,29 @@ class RAGModule(BaseModule): logger.info(f"Query embedding (first 10 values): {query_embedding[:10] if query_embedding else 'None'}") logger.info(f"Embedding service available: {self.embedding_service is not None}") - # Search in Qdrant - search_results = self.qdrant_client.search( - collection_name=collection_name, - query_vector=query_embedding, - query_filter=search_filter, - limit=max_results, - score_threshold=0.0 # Lowered from 0.5 to see all results including low scores 
-        )
+        # Check if hybrid search is enabled
+        enable_hybrid = self.config.get("enable_hybrid", False)
+        score_threshold = self.config.get("score_threshold", 0.3)
+
+        if enable_hybrid and NLTK_AVAILABLE:
+            # Perform hybrid search (vector + BM25)
+            search_results = await self._hybrid_search(
+                collection_name=collection_name,
+                query=query,
+                query_vector=query_embedding,
+                query_filter=search_filter,
+                limit=max_results,
+                score_threshold=score_threshold
+            )
+        else:
+            # Pure vector search with improved threshold
+            search_results = self.qdrant_client.search(
+                collection_name=collection_name,
+                query_vector=query_embedding,
+                query_filter=search_filter,
+                limit=max_results,
+                score_threshold=score_threshold
+            )

         logger.info(f"Raw search results count: {len(search_results)}")
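
Usage note (sketch): the new TokenRateLimiter exposes an async check_token_limits(provider, prompt_tokens, completion_tokens) returning an (is_allowed, headers) pair, with the X-TokenLimit-* headers meant to be passed back to the caller. Below is a minimal sketch of how a request handler might consult it before forwarding a prompt. The /v1/chat endpoint, the tiktoken cl100k_base counting, and the wiring are illustrative assumptions, not part of this patch; only the import path and the "privatemode" provider name come from the new module itself.

import tiktoken
from fastapi import FastAPI, HTTPException, Response

# The patch defines a module-level instance `token_rate_limiter` at the bottom
# of backend/app/services/llm/token_rate_limiter.py; import path assumes the
# service runs with `app` as the package root, as the module's own relative
# imports suggest.
from app.services.llm.token_rate_limiter import token_rate_limiter

app = FastAPI()
encoding = tiktoken.get_encoding("cl100k_base")  # token-counting choice is an assumption

@app.post("/v1/chat")  # hypothetical endpoint, for illustration only
async def chat(prompt: str, response: Response) -> dict:
    prompt_tokens = len(encoding.encode(prompt))
    allowed, headers = await token_rate_limiter.check_token_limits(
        provider="privatemode",  # provider name taken from the method docstring
        prompt_tokens=prompt_tokens,
    )
    # Surface the X-TokenLimit-* headers on success and failure alike
    for name, value in headers.items():
        response.headers[name] = value
    if not allowed:
        raise HTTPException(status_code=429, detail="Token rate limit exceeded", headers=headers)
    # ...forward the request to the LLM provider here...
    return {"prompt_tokens": prompt_tokens}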
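
Chunking note: in the new _chunk_text overlap loop, once end_idx reaches len(tokens) the update start_idx = end_idx - chunk_overlap recomputes the same value on every pass (the start_idx >= end_idx guard only fires when the overlap is zero or negative), so the final chunk is re-appended indefinitely and the loop never terminates for non-empty input with the default chunk_overlap=50. A self-contained sketch of the intended stride of chunk_size - chunk_overlap, written against a plain token-id list (tokenizer encode/decode omitted):

from typing import List

def chunk_tokens(tokens: List[int], chunk_size: int = 300, overlap: int = 50) -> List[List[int]]:
    """Sliding windows of chunk_size tokens, advancing by chunk_size - overlap."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    step = max(chunk_size - overlap, 1)  # always move forward, even if overlap >= chunk_size
    chunks = []
    for start in range(0, len(tokens), step):
        chunk = tokens[start:start + chunk_size]
        if chunk:
            chunks.append(chunk)
        if start + chunk_size >= len(tokens):
            break  # the last window already covers the tail
    return chunks

# chunk_tokens(list(range(1000))) yields windows starting at 0, 250, 500, 750,
# each sharing its last 50 tokens with the next window.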
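
Scoring note: _hybrid_search min-max normalises the vector score above the configured threshold, caps the simplified BM25 score at 1.0, mixes the two with the hybrid_weights (0.7 vector / 0.3 BM25), and then re-applies score_threshold to the combined value. A small worked sketch of that combination using the patch's defaults:

def hybrid_score(vector_score: float, bm25_score: float,
                 score_threshold: float = 0.3,
                 vector_weight: float = 0.7, bm25_weight: float = 0.3) -> float:
    """Combine a vector similarity score and a (simplified) BM25 score."""
    # Min-max normalise the vector score over (threshold, 1.0]
    vector_norm = ((vector_score - score_threshold) / (1.0 - score_threshold)
                   if vector_score > score_threshold else 0.0)
    # Simplified BM25 scores are already roughly 0-1; cap at 1.0
    bm25_norm = min(bm25_score, 1.0)
    return vector_weight * vector_norm + bm25_weight * bm25_norm

# Example: a chunk with cosine score 0.65 and BM25 score 0.4 gives
# vector_norm = (0.65 - 0.3) / 0.7 = 0.5, so hybrid = 0.7*0.5 + 0.3*0.4 = 0.47,
# which clears the 0.3 threshold and is kept in the final results.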