diff --git a/backend/app/api/v1/rag.py b/backend/app/api/v1/rag.py index 7894174..6dad828 100644 --- a/backend/app/api/v1/rag.py +++ b/backend/app/api/v1/rag.py @@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status from fastapi.responses import StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select from pydantic import BaseModel import io import asyncio @@ -15,6 +16,7 @@ from datetime import datetime from app.db.database import get_db from app.core.security import get_current_user from app.models.user import User +from app.models.rag_collection import RagCollection from app.services.rag_service import RAGService from app.utils.exceptions import APIException @@ -268,7 +270,19 @@ async def get_documents( try: collection_id_int = int(collection_id) except (ValueError, TypeError): - raise HTTPException(status_code=400, detail="Invalid collection_id format") + # Attempt to resolve by Qdrant collection name + collection_row = await db.scalar( + select(RagCollection).where(RagCollection.qdrant_collection_name == collection_id) + ) + if collection_row: + collection_id_int = collection_row.id + else: + # Unknown collection identifier; return empty result instead of erroring out + return { + "success": True, + "documents": [], + "total": 0 + } rag_service = RAGService(db) documents = await rag_service.get_documents( diff --git a/backend/app/core/config.py b/backend/app/core/config.py index b6c716e..a091936 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -129,6 +129,7 @@ class Settings(BaseSettings): RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5")) RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true" RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true" + RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "BAAI/bge-small-en") RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300")) RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120")) RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120")) @@ -154,4 +155,3 @@ class Settings(BaseSettings): # Global settings instance settings = Settings() - diff --git a/backend/app/modules/rag/main.py b/backend/app/modules/rag/main.py index 1164c1a..b3c5ac3 100644 --- a/backend/app/modules/rag/main.py +++ b/backend/app/modules/rag/main.py @@ -8,6 +8,7 @@ import json import logging import mimetypes import re +import time from typing import Any, Dict, List, Optional, Tuple, Union from datetime import datetime from dataclasses import dataclass, asdict @@ -146,6 +147,11 @@ class RAGModule(BaseModule): # Update with any provided config if config: self.config.update(config) + + # Ensure embedding model configured (defaults to local BGE small) + default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en') + self.config.setdefault("embedding_model", default_embedding_model) + self.default_embedding_model = default_embedding_model # Content processing components self.nlp_model = None @@ -178,6 +184,7 @@ class RAGModule(BaseModule): "supported_types": len(self.supported_types) } self.search_cache = {} + self.collection_vector_sizes: Dict[str, int] = {} def get_required_permissions(self) -> List[Permission]: """Return list of permissions this 
module requires""" @@ -215,7 +222,7 @@ class RAGModule(BaseModule): self.initialized = True log_module_event("rag", "initialized", { "vector_db": self.config.get("vector_db", "qdrant"), - "embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"), + "embedding_model": self.embedding_model.get("model_name", self.default_embedding_model), "chunk_size": self.config.get("chunk_size", 400), "max_results": self.config.get("max_results", 10), "supported_file_types": list(self.supported_types.keys()), @@ -427,8 +434,7 @@ class RAGModule(BaseModule): # Prefer enhanced embedding service (rate limiting + retry) from app.services.enhanced_embedding_service import enhanced_embedding_service as embedding_service - # Use intfloat/multilingual-e5-large-instruct for LLM service integration - model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct") + model_name = self.config.get("embedding_model", self.default_embedding_model) embedding_service.model_name = model_name # Initialize the embedding service @@ -439,7 +445,7 @@ class RAGModule(BaseModule): logger.info(f"Successfully initialized embedding service with {model_name}") return { "model_name": model_name, - "dimension": embedding_service.dimension or 768 + "dimension": embedding_service.dimension or 384 } else: # Fallback to mock implementation @@ -447,7 +453,7 @@ class RAGModule(BaseModule): self.embedding_service = None return { "model_name": model_name, - "dimension": 1024 # Default dimension for intfloat/multilingual-e5-large-instruct + "dimension": 384 # Default dimension matching local bge-small embeddings } async def _initialize_content_processing(self): @@ -588,16 +594,37 @@ class RAGModule(BaseModule): collection_names = await self._get_collections_safely() if collection_name not in collection_names: - # Create collection + # Create collection with the current embedding dimension + vector_dimension = self.embedding_model.get( + "dimension", + getattr(self.embedding_service, "dimension", 384) or 384 + ) + self.qdrant_client.create_collection( collection_name=collection_name, vectors_config=VectorParams( - size=self.embedding_model.get("dimension", 768), + size=vector_dimension, distance=Distance.COSINE ) ) + self.collection_vector_sizes[collection_name] = vector_dimension log_module_event("rag", "collection_created", {"collection": collection_name}) - + else: + # Cache existing collection vector size for later alignment + try: + info = self.qdrant_client.get_collection(collection_name) + vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None + existing_size = None + if vectors_param is not None and hasattr(vectors_param, "size"): + existing_size = vectors_param.size + elif isinstance(vectors_param, dict): + existing_size = vectors_param.get("size") + + if existing_size: + self.collection_vector_sizes[collection_name] = existing_size + except Exception as inner_error: + logger.debug(f"Unable to cache vector size for collection {collection_name}: {inner_error}") + except Exception as e: logger.error(f"Error ensuring collection exists: {e}") raise @@ -632,12 +659,13 @@ class RAGModule(BaseModule): async def _generate_embedding(self, text: str) -> List[float]: """Generate embedding for text""" if self.embedding_service: - # Use real embedding service - return await self.embedding_service.get_embedding(text) + vector = await self.embedding_service.get_embedding(text) + return vector else: # Fallback to deterministic random embedding 
for consistency np.random.seed(hash(text) % 2**32) - return np.random.random(self.embedding_model.get("dimension", 768)).tolist() + fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384) + return np.random.random(fallback_dim).tolist() async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]: """Generate embeddings for multiple texts (batch processing)""" @@ -651,14 +679,90 @@ class RAGModule(BaseModule): prefixed_texts = texts # Use real embedding service for batch processing - return await self.embedding_service.get_embeddings(prefixed_texts) + backend = getattr(self.embedding_service, "backend", "unknown") + start_time = time.time() + logger.info( + "Embedding batch requested", + extra={ + "backend": backend, + "model": getattr(self.embedding_service, "model_name", "unknown"), + "count": len(prefixed_texts), + "scope": "documents" if is_document else "queries" + }, + ) + embeddings = await self.embedding_service.get_embeddings(prefixed_texts) + duration = time.time() - start_time + logger.info( + "Embedding batch finished", + extra={ + "backend": backend, + "model": getattr(self.embedding_service, "model_name", "unknown"), + "count": len(embeddings), + "scope": "documents" if is_document else "queries", + "duration_sec": round(duration, 4) + }, + ) + return embeddings else: # Fallback to individual processing + logger.warning( + "Embedding service unavailable, falling back to per-item generation", + extra={ + "count": len(texts), + "scope": "documents" if is_document else "queries" + }, + ) embeddings = [] for text in texts: embedding = await self._generate_embedding(text) embeddings.append(embedding) return embeddings + + def _get_collection_vector_size(self, collection_name: Optional[str]) -> int: + """Return the expected vector size for a collection, caching results.""" + default_dim = self.embedding_model.get( + "dimension", + getattr(self.embedding_service, "dimension", 384) or 384 + ) + + if not collection_name: + return default_dim + + if collection_name in self.collection_vector_sizes: + return self.collection_vector_sizes[collection_name] + + try: + info = self.qdrant_client.get_collection(collection_name) + vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None + existing_size = None + if vectors_param is not None and hasattr(vectors_param, "size"): + existing_size = vectors_param.size + elif isinstance(vectors_param, dict): + existing_size = vectors_param.get("size") + + if existing_size: + self.collection_vector_sizes[collection_name] = existing_size + return existing_size + except Exception as e: + logger.debug(f"Unable to determine vector size for {collection_name}: {e}") + + self.collection_vector_sizes[collection_name] = default_dim + return default_dim + + def _align_embedding_dimension(self, vector: List[float], collection_name: Optional[str]) -> List[float]: + """Pad or truncate embeddings to match the target collection dimension.""" + if vector is None: + return vector + + target_dim = self._get_collection_vector_size(collection_name) + current_dim = len(vector) + + if current_dim == target_dim: + return vector + if current_dim > target_dim: + return vector[:target_dim] + padding = [0.0] * (target_dim - current_dim) + return vector + padding def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]: """Split text into overlapping chunks for better context preservation""" @@ -1176,6 +1280,7 @@ class 
RAGModule(BaseModule): # Create document points points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): + aligned_embedding = self._align_embedding_dimension(embedding, collection_name) chunk_id = str(uuid.uuid4()) chunk_metadata = { @@ -1189,7 +1294,7 @@ class RAGModule(BaseModule): points.append(PointStruct( id=chunk_id, - vector=embedding, + vector=aligned_embedding, payload=chunk_metadata )) @@ -1257,6 +1362,7 @@ class RAGModule(BaseModule): # Create document points with enhanced metadata points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): + aligned_embedding = self._align_embedding_dimension(embedding, collection_name) chunk_id = str(uuid.uuid4()) chunk_metadata = { @@ -1280,7 +1386,7 @@ class RAGModule(BaseModule): points.append(PointStruct( id=chunk_id, - vector=embedding, + vector=aligned_embedding, payload=chunk_metadata )) @@ -1514,9 +1620,9 @@ class RAGModule(BaseModule): start_time = time.time() # Generate query embedding with task-specific prefix for better retrieval - # The E5 model works better with "query:" prefix for search queries optimized_query = f"query: {query}" query_embedding = await self._generate_embedding(optimized_query) + query_embedding = self._align_embedding_dimension(query_embedding, collection_name) # Build filter search_filter = None diff --git a/backend/app/services/embedding_service.py b/backend/app/services/embedding_service.py index df1eba7..78198a9 100644 --- a/backend/app/services/embedding_service.py +++ b/backend/app/services/embedding_service.py @@ -1,148 +1,133 @@ """ Embedding Service -Provides text embedding functionality using LLM service +Provides local sentence-transformer embeddings (default: BAAI/bge-small-en). +Falls back to deterministic random vectors when the local model is unavailable. 
""" +import asyncio import logging +import time from typing import List, Dict, Any, Optional import numpy as np +from app.core.config import settings + logger = logging.getLogger(__name__) class EmbeddingService: - """Service for generating text embeddings using LLM service""" - - def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"): - self.model_name = model_name - self.dimension = 1024 # Actual dimension for intfloat/multilingual-e5-large-instruct + """Service for generating text embeddings using a local transformer model""" + + def __init__(self, model_name: Optional[str] = None): + self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-small-en") + self.dimension = 384 # bge-small produces 384-d vectors self.initialized = False - + self.local_model = None + self.backend = "uninitialized" + async def initialize(self): """Initialize the embedding service with LLM service""" try: - from app.services.llm.service import llm_service - - # Initialize LLM service if not already done - if not llm_service._initialized: - await llm_service.initialize() - - # Test LLM service health - if not llm_service._initialized: - logger.error("LLM service not initialized") - return False + from sentence_transformers import SentenceTransformer - # Check if PrivateMode provider is available - try: - provider_status = await llm_service.get_provider_status() - privatemode_status = provider_status.get("privatemode") - if not privatemode_status or privatemode_status.status != "healthy": - logger.error(f"PrivateMode provider not available: {privatemode_status}") - return False - except Exception as e: - logger.error(f"Failed to check provider status: {e}") - return False - + loop = asyncio.get_running_loop() + + def load_model(): + # Load model synchronously in a worker thread to avoid blocking event loop + return SentenceTransformer(self.model_name) + + self.local_model = await loop.run_in_executor(None, load_model) + self.dimension = self.local_model.get_sentence_embedding_dimension() self.initialized = True - logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})") + self.backend = "sentence_transformer" + logger.info( + "Embedding service initialized with local model %s (dimension: %s)", + self.model_name, + self.dimension, + ) return True - - except Exception as e: - logger.error(f"Failed to initialize LLM embedding service: {e}") - logger.warning("Using fallback random embeddings") + + except ImportError as exc: + logger.error("sentence-transformers not installed: %s", exc) + logger.warning("Falling back to random embeddings") + self.local_model = None + self.initialized = False + self.backend = "fallback_random" + return False + + except Exception as exc: + logger.error(f"Failed to load local embedding model {self.model_name}: {exc}") + logger.warning("Falling back to random embeddings") + self.local_model = None + self.initialized = False + self.backend = "fallback_random" return False async def get_embedding(self, text: str) -> List[float]: """Get embedding for a single text""" embeddings = await self.get_embeddings([text]) return embeddings[0] - + async def get_embeddings(self, texts: List[str]) -> List[List[float]]: """Get embeddings for multiple texts using LLM service""" - if not self.initialized: - # Fallback to random embeddings if not initialized - logger.warning("LLM service not available, using random embeddings") - return self._generate_fallback_embeddings(texts) - - try: - embeddings = 
[] - # Process texts in batches for efficiency - batch_size = 10 - for i in range(0, len(texts), batch_size): - batch = texts[i:i+batch_size] - - # Process each text in the batch - batch_embeddings = [] - for text in batch: - try: - # Truncate text if it's too long for the model's context window - # intfloat/multilingual-e5-large-instruct has a 512 token limit, truncate to ~400 tokens worth of chars - # Rough estimate: 1 token ≈ 4 characters, so 400 tokens ≈ 1600 chars - max_chars = 1600 - if len(text) > max_chars: - truncated_text = text[:max_chars] - logger.debug(f"Truncated text from {len(text)} to {max_chars} chars for embedding") - else: - truncated_text = text - - # Guard: skip empty inputs (validator rejects empty strings) - if not truncated_text.strip(): - logger.debug("Empty input for embedding; using fallback vector") - batch_embeddings.append(self._generate_fallback_embedding(text)) - continue + start_time = time.time() - # Call LLM service embedding endpoint - from app.services.llm.service import llm_service - from app.services.llm.models import EmbeddingRequest - - llm_request = EmbeddingRequest( - model=self.model_name, - input=truncated_text, - user_id="rag_system", - api_key_id=0 # System API key - ) - - response = await llm_service.create_embedding(llm_request) - - # Extract embedding from response - if response.data and len(response.data) > 0: - embedding = response.data[0].embedding - if embedding: - batch_embeddings.append(embedding) - # Update dimension based on actual embedding size - if not hasattr(self, '_dimension_confirmed'): - self.dimension = len(embedding) - self._dimension_confirmed = True - logger.info(f"Confirmed embedding dimension: {self.dimension}") - else: - logger.warning(f"No embedding in response for text: {text[:50]}...") - batch_embeddings.append(self._generate_fallback_embedding(text)) - else: - logger.warning(f"Invalid response structure for text: {text[:50]}...") - batch_embeddings.append(self._generate_fallback_embedding(text)) - except Exception as e: - logger.error(f"Error getting embedding for text: {e}") - batch_embeddings.append(self._generate_fallback_embedding(text)) - - embeddings.extend(batch_embeddings) - - return embeddings - - except Exception as e: - logger.error(f"Error generating embeddings with LLM service: {e}") - # Fallback to random embeddings - return self._generate_fallback_embeddings(texts) + if self.local_model: + if not texts: + return [] + + loop = asyncio.get_running_loop() + + try: + embeddings = await loop.run_in_executor( + None, + lambda: self.local_model.encode( + texts, + convert_to_numpy=True, + normalize_embeddings=True, + ), + ) + duration = time.time() - start_time + logger.info( + "Embedding batch completed", + extra={ + "backend": self.backend, + "model": self.model_name, + "count": len(texts), + "dimension": self.dimension, + "duration_sec": round(duration, 4), + }, + ) + return embeddings.tolist() + except Exception as exc: + logger.error(f"Local embedding generation failed: {exc}") + self.backend = "fallback_random" + return self._generate_fallback_embeddings(texts, duration=time.time() - start_time) + + logger.warning("Local embedding model unavailable; using fallback random embeddings") + self.backend = "fallback_random" + return self._generate_fallback_embeddings(texts, duration=time.time() - start_time) - def _generate_fallback_embeddings(self, texts: List[str]) -> List[List[float]]: + def _generate_fallback_embeddings(self, texts: List[str], duration: float = None) -> List[List[float]]: """Generate 
fallback random embeddings when model unavailable""" embeddings = [] for text in texts: embeddings.append(self._generate_fallback_embedding(text)) + logger.info( + "Embedding batch completed", + extra={ + "backend": "fallback_random", + "model": self.model_name, + "count": len(texts), + "dimension": self.dimension, + "duration_sec": round(duration or 0.0, 4), + }, + ) return embeddings def _generate_fallback_embedding(self, text: str) -> List[float]: """Generate a single fallback embedding""" - dimension = self.dimension or 1024 # Default dimension for intfloat/multilingual-e5-large-instruct + dimension = self.dimension or 384 # Use hash for reproducible random embeddings np.random.seed(hash(text) % 2**32) return np.random.random(dimension).tolist() @@ -169,22 +154,15 @@ class EmbeddingService: "model_name": self.model_name, "model_loaded": self.initialized, "dimension": self.dimension, - "backend": "LLM Service", + "backend": self.backend, "initialized": self.initialized } - + async def cleanup(self): """Cleanup resources""" - # Cleanup LLM service to prevent memory leaks - try: - from .llm.service import llm_service - if llm_service._initialized: - await llm_service.cleanup() - logger.info("Cleaned up LLM service from embedding service") - except Exception as e: - logger.error(f"Error cleaning up LLM service: {e}") - + self.local_model = None self.initialized = False + self.backend = "uninitialized" # Global embedding service instance diff --git a/backend/app/services/enhanced_embedding_service.py b/backend/app/services/enhanced_embedding_service.py index 87d4be4..846bfd7 100644 --- a/backend/app/services/enhanced_embedding_service.py +++ b/backend/app/services/enhanced_embedding_service.py @@ -1,14 +1,13 @@ # Enhanced Embedding Service with Rate Limiting Handling """ -Enhanced embedding service with robust rate limiting and retry logic +Enhanced embedding service that adds basic retry semantics around the local +embedding generator. Since embeddings are fully local, rate-limiting metadata is +largely informational but retained for compatibility. """ -import asyncio import logging import time -from typing import List, Dict, Any, Optional -import numpy as np -from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional from .embedding_service import EmbeddingService from app.core.config import settings @@ -17,9 +16,9 @@ logger = logging.getLogger(__name__) class EnhancedEmbeddingService(EmbeddingService): - """Enhanced embedding service with rate limiting handling""" + """Enhanced embedding service with lightweight retry bookkeeping""" - def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"): + def __init__(self, model_name: Optional[str] = None): super().__init__(model_name) self.rate_limit_tracker = { 'requests_count': 0, @@ -34,161 +33,16 @@ class EnhancedEmbeddingService(EmbeddingService): async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]: """ - Get embeddings with rate limiting and retry logic + Get embeddings with retry bookkeeping. 
""" - if max_retries is None: - max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3)) - - batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3)) - - if not self.initialized: - logger.warning("Embedding service not initialized, using fallback") - return self._generate_fallback_embeddings(texts), False - - embeddings = [] - success = True - - for i in range(0, len(texts), batch_size): - batch = texts[i:i+batch_size] - batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries) - embeddings.extend(batch_embeddings) - success = success and batch_success - - # Add delay between batches to avoid rate limiting - if i + batch_size < len(texts): - delay = self.rate_limit_tracker['delay_between_batches'] - await asyncio.sleep(delay) # Configurable delay between batches - - return embeddings, success - - async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]: - """Get embeddings for a batch with retry logic""" - last_error = None - - for attempt in range(max_retries + 1): - try: - # Check rate limit before making request - if self._is_rate_limited(): - delay = self._get_rate_limit_delay() - logger.warning(f"Rate limit detected, waiting {delay} seconds") - await asyncio.sleep(delay) - continue - - # Make the request - embeddings = await self._get_embeddings_batch_impl(texts) - - return embeddings, True - - except Exception as e: - last_error = e - error_msg = str(e).lower() - - # Check if it's a rate limit error - if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']): - logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}") - self._update_rate_limit_tracker(success=False) - - if attempt < max_retries: - delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)] - logger.info(f"Retrying in {delay} seconds...") - await asyncio.sleep(delay) - continue - else: - logger.error(f"Max retries exceeded for rate limit, using fallback embeddings") - return self._generate_fallback_embeddings(texts), False - else: - # Non-rate-limit error - logger.error(f"Error generating embeddings: {e}") - if attempt < max_retries: - delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)] - await asyncio.sleep(delay) - else: - logger.error("Max retries exceeded, using fallback embeddings") - return self._generate_fallback_embeddings(texts), False - - # If we get here, all retries failed - logger.error(f"All retries failed, last error: {last_error}") - return self._generate_fallback_embeddings(texts), False - - async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]: - """Implementation of getting embeddings for a batch""" - from app.services.llm.service import llm_service - from app.services.llm.models import EmbeddingRequest - - embeddings = [] - - for text in texts: - # Respect rate limit before each request - while self._is_rate_limited(): - delay = self._get_rate_limit_delay() - logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request") - await asyncio.sleep(delay) - - # Truncate text if needed - max_chars = 1600 - truncated_text = text[:max_chars] if len(text) > max_chars else text - - llm_request = EmbeddingRequest( - model=self.model_name, - input=truncated_text, - user_id="rag_system", - api_key_id=0 + embeddings = await super().get_embeddings(texts) + 
success = self.local_model is not None + if not success: + logger.warning( + "Embedding service operating in fallback mode; consider installing the local model %s", + self.model_name, ) - - response = await llm_service.create_embedding(llm_request) - - if response.data and len(response.data) > 0: - embedding = response.data[0].embedding - if embedding: - embeddings.append(embedding) - if not hasattr(self, '_dimension_confirmed'): - self.dimension = len(embedding) - self._dimension_confirmed = True - else: - raise ValueError("Empty embedding in response") - else: - raise ValueError("Invalid response structure") - - # Count this successful request and optionally delay between requests - self._update_rate_limit_tracker(success=True) - per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0) - if per_req_delay and per_req_delay > 0: - await asyncio.sleep(per_req_delay) - - return embeddings - - def _is_rate_limited(self) -> bool: - """Check if we're currently rate limited""" - now = time.time() - window_start = self.rate_limit_tracker['window_start'] - - # Reset window if it's expired - if now - window_start > self.rate_limit_tracker['window_size']: - self.rate_limit_tracker['requests_count'] = 0 - self.rate_limit_tracker['window_start'] = now - return False - - # Check if we've exceeded the limit - return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute'] - - def _get_rate_limit_delay(self) -> float: - """Get delay to wait for rate limit reset""" - now = time.time() - window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size'] - return max(0, window_end - now) - - def _update_rate_limit_tracker(self, success: bool): - """Update the rate limit tracker""" - now = time.time() - - # Reset window if it's expired - if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']: - self.rate_limit_tracker['requests_count'] = 0 - self.rate_limit_tracker['window_start'] = now - - # Increment counter on successful requests - if success: - self.rate_limit_tracker['requests_count'] += 1 + return embeddings, success async def get_embedding_stats(self) -> Dict[str, Any]: """Get embedding service statistics including rate limiting info""" diff --git a/backend/app/services/rag_service.py b/backend/app/services/rag_service.py index 9974500..74443d4 100644 --- a/backend/app/services/rag_service.py +++ b/backend/app/services/rag_service.py @@ -566,10 +566,14 @@ class RAGService: logger.warning(f"Could not check existing collections: {e}") # Create collection with proper vector configuration + from app.services.embedding_service import embedding_service + + vector_dimension = getattr(embedding_service, 'dimension', 384) or 384 + client.create_collection( collection_name=collection_name, vectors_config=VectorParams( - size=1024, # Updated for multilingual-e5-large-instruct model + size=vector_dimension, distance=Distance.COSINE ), optimizers_config=models.OptimizersConfigDiff( diff --git a/backend/modules/rag/main.py b/backend/modules/rag/main.py index eb5f2eb..8bebd9e 100644 --- a/backend/modules/rag/main.py +++ b/backend/modules/rag/main.py @@ -8,6 +8,7 @@ import json import logging import mimetypes import re +import time from typing import Any, Dict, List, Optional, Tuple, Union from datetime import datetime from dataclasses import dataclass, asdict @@ -146,6 +147,11 @@ class RAGModule(BaseModule): # Update with any provided config if config: self.config.update(config) + + # Ensure embedding 
model configured (defaults to local BGE small) + default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en') + self.config.setdefault("embedding_model", default_embedding_model) + self.default_embedding_model = default_embedding_model # Content processing components self.nlp_model = None @@ -178,6 +184,7 @@ class RAGModule(BaseModule): "supported_types": len(self.supported_types) } self.search_cache = {} + self.collection_vector_sizes: Dict[str, int] = {} def get_required_permissions(self) -> List[Permission]: """Return list of permissions this module requires""" @@ -215,7 +222,7 @@ class RAGModule(BaseModule): self.initialized = True log_module_event("rag", "initialized", { "vector_db": self.config.get("vector_db", "qdrant"), - "embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"), + "embedding_model": self.embedding_model.get("model_name", self.default_embedding_model), "chunk_size": self.config.get("chunk_size", 400), "max_results": self.config.get("max_results", 10), "supported_file_types": list(self.supported_types.keys()), @@ -426,8 +433,7 @@ class RAGModule(BaseModule): """Initialize embedding model""" from app.services.embedding_service import embedding_service - # Use intfloat/multilingual-e5-large-instruct for LLM service integration - model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct") + model_name = self.config.get("embedding_model", self.default_embedding_model) embedding_service.model_name = model_name # Initialize the embedding service @@ -438,7 +444,7 @@ class RAGModule(BaseModule): logger.info(f"Successfully initialized embedding service with {model_name}") return { "model_name": model_name, - "dimension": embedding_service.dimension or 768 + "dimension": embedding_service.dimension or 384 } else: # Fallback to mock implementation @@ -446,7 +452,7 @@ class RAGModule(BaseModule): self.embedding_service = None return { "model_name": model_name, - "dimension": 1024 # Default dimension for intfloat/multilingual-e5-large-instruct + "dimension": 384 # Default dimension matching local bge-small embeddings } async def _initialize_content_processing(self): @@ -585,18 +591,39 @@ class RAGModule(BaseModule): try: # Use safe collection fetching to avoid Pydantic validation errors collection_names = await self._get_collections_safely() - + if collection_name not in collection_names: - # Create collection + # Create collection with the current embedding dimension + vector_dimension = self.embedding_model.get( + "dimension", + getattr(self.embedding_service, "dimension", 384) or 384 + ) + self.qdrant_client.create_collection( collection_name=collection_name, vectors_config=VectorParams( - size=self.embedding_model.get("dimension", 768), + size=vector_dimension, distance=Distance.COSINE ) ) + self.collection_vector_sizes[collection_name] = vector_dimension log_module_event("rag", "collection_created", {"collection": collection_name}) - + else: + # Cache existing collection vector size for later alignment + try: + info = self.qdrant_client.get_collection(collection_name) + vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None + existing_size = None + if vectors_param is not None and hasattr(vectors_param, "size"): + existing_size = vectors_param.size + elif isinstance(vectors_param, dict): + existing_size = vectors_param.get("size") + + if existing_size: + self.collection_vector_sizes[collection_name] = existing_size + except 
Exception as inner_error: + logger.debug(f"Unable to cache vector size for collection {collection_name}: {inner_error}") + except Exception as e: logger.error(f"Error ensuring collection exists: {e}") raise @@ -632,11 +659,13 @@ class RAGModule(BaseModule): """Generate embedding for text""" if self.embedding_service: # Use real embedding service - return await self.embedding_service.get_embedding(text) + vector = await self.embedding_service.get_embedding(text) + return vector else: # Fallback to deterministic random embedding for consistency np.random.seed(hash(text) % 2**32) - return np.random.random(self.embedding_model.get("dimension", 768)).tolist() + fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384) + return np.random.random(fallback_dim).tolist() async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]: """Generate embeddings for multiple texts (batch processing)""" @@ -650,14 +679,91 @@ class RAGModule(BaseModule): prefixed_texts = texts # Use real embedding service for batch processing - return await self.embedding_service.get_embeddings(prefixed_texts) + backend = getattr(self.embedding_service, "backend", "unknown") + start_time = time.time() + logger.info( + "Embedding batch requested", + extra={ + "backend": backend, + "model": getattr(self.embedding_service, "model_name", "unknown"), + "count": len(prefixed_texts), + "scope": "documents" if is_document else "queries" + }, + ) + embeddings = await self.embedding_service.get_embeddings(prefixed_texts) + duration = time.time() - start_time + logger.info( + "Embedding batch finished", + extra={ + "backend": backend, + "model": getattr(self.embedding_service, "model_name", "unknown"), + "count": len(embeddings), + "scope": "documents" if is_document else "queries", + "duration_sec": round(duration, 4) + }, + ) + return embeddings else: # Fallback to individual processing + logger.warning( + "Embedding service unavailable, falling back to per-item generation", + extra={ + "count": len(texts), + "scope": "documents" if is_document else "queries" + }, + ) embeddings = [] for text in texts: embedding = await self._generate_embedding(text) embeddings.append(embedding) return embeddings + + def _get_collection_vector_size(self, collection_name: Optional[str]) -> int: + """Return the expected vector size for a collection, caching results.""" + default_dim = self.embedding_model.get( + "dimension", + getattr(self.embedding_service, "dimension", 384) or 384 + ) + + if not collection_name: + return default_dim + + if collection_name in self.collection_vector_sizes: + return self.collection_vector_sizes[collection_name] + + try: + info = self.qdrant_client.get_collection(collection_name) + vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None + existing_size = None + if vectors_param is not None and hasattr(vectors_param, "size"): + existing_size = vectors_param.size + elif isinstance(vectors_param, dict): + existing_size = vectors_param.get("size") + + if existing_size: + self.collection_vector_sizes[collection_name] = existing_size + return existing_size + except Exception as e: + logger.debug(f"Unable to determine vector size for {collection_name}: {e}") + + self.collection_vector_sizes[collection_name] = default_dim + return default_dim + + def _align_embedding_dimension(self, vector: List[float], collection_name: Optional[str]) -> List[float]: + """Pad or truncate embeddings to match the 
target collection dimension.""" + if vector is None: + return vector + + target_dim = self._get_collection_vector_size(collection_name) + current_dim = len(vector) + + if current_dim == target_dim: + return vector + if current_dim > target_dim: + return vector[:target_dim] + # Pad with zeros to reach the target dimension + padding = [0.0] * (target_dim - current_dim) + return vector + padding def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]: """Split text into overlapping chunks for better context preservation""" @@ -1180,8 +1286,9 @@ class RAGModule(BaseModule): # Create document points points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): + aligned_embedding = self._align_embedding_dimension(embedding, collection_name) chunk_id = str(uuid.uuid4()) - + chunk_metadata = { **metadata, "document_id": doc_id, @@ -1193,7 +1300,7 @@ class RAGModule(BaseModule): points.append(PointStruct( id=chunk_id, - vector=embedding, + vector=aligned_embedding, payload=chunk_metadata )) @@ -1261,8 +1368,9 @@ class RAGModule(BaseModule): # Create document points with enhanced metadata points = [] for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)): + aligned_embedding = self._align_embedding_dimension(embedding, collection_name) chunk_id = str(uuid.uuid4()) - + chunk_metadata = { **processed_doc.metadata, "document_id": processed_doc.id, @@ -1284,7 +1392,7 @@ class RAGModule(BaseModule): points.append(PointStruct( id=chunk_id, - vector=embedding, + vector=aligned_embedding, payload=chunk_metadata )) @@ -1548,6 +1656,8 @@ class RAGModule(BaseModule): logger.warning(f"sentence-transformers not available, falling back to default embedding for {collection_name}") optimized_query = f"query: {query}" query_embedding = await self._generate_embedding(optimized_query) + + query_embedding = self._align_embedding_dimension(query_embedding, collection_name) # Build filter search_filter = None diff --git a/frontend/src/components/chatbot/ChatInterface.tsx b/frontend/src/components/chatbot/ChatInterface.tsx index cdfcaf6..94dc7fa 100644 --- a/frontend/src/components/chatbot/ChatInterface.tsx +++ b/frontend/src/components/chatbot/ChatInterface.tsx @@ -87,6 +87,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface const [messages, setMessages] = useState([]) const [input, setInput] = useState("") const [isLoading, setIsLoading] = useState(false) + const [conversationId, setConversationId] = useState(undefined) const scrollAreaRef = useRef(null) const { success: toastSuccess, error: toastError } = useToast() @@ -103,6 +104,12 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface scrollToBottom() }, [messages]) + useEffect(() => { + // Reset conversation when switching chatbots + setMessages([]) + setConversationId(undefined) + }, [chatbotId]) + const sendMessage = useCallback(async () => { if (!input.trim() || isLoading) return @@ -119,12 +126,13 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface setIsLoading(true) // Enhanced logging for debugging + const currentConversationId = conversationId const debugInfo = { chatbotId, messageLength: messageToSend.length, - conversationId, + conversationId: currentConversationId ?? 
null, timestamp: new Date().toISOString(), - messagesCount: messages.length + messagesCount: messages.length + 1 } console.log('=== CHAT REQUEST DEBUG ===', debugInfo) @@ -132,7 +140,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface let data: any // Use internal API - const conversationHistory = messages.map(msg => ({ + const conversationHistory = [...messages, userMessage].map(msg => ({ role: msg.role, content: msg.content })) @@ -140,7 +148,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface data = await chatbotApi.sendMessage( chatbotId, messageToSend, - undefined, // No conversation ID + currentConversationId, conversationHistory ) @@ -154,6 +162,11 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface setMessages(prev => [...prev, assistantMessage]) + const newConversationId = data?.conversation_id || currentConversationId + if (newConversationId !== conversationId) { + setConversationId(newConversationId) + } + } catch (error) { const appError = error as AppError @@ -168,7 +181,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface } finally { setIsLoading(false) } - }, [input, isLoading, chatbotId, messages, toastError]) + }, [input, isLoading, chatbotId, messages, toastError, conversationId]) const handleKeyPress = useCallback((e: React.KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { @@ -358,4 +371,4 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface ) -} \ No newline at end of file +}
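
The `_align_embedding_dimension` helper added to both copies of `modules/rag/main.py` pads short vectors with zeros and truncates long ones so stored points always match the target collection's configured size. A minimal standalone sketch of that behaviour, with illustrative dimensions (the function name and sample sizes below are for demonstration only, not part of the patch):

from typing import List, Optional

def align_embedding_dimension(vector: Optional[List[float]], target_dim: int) -> Optional[List[float]]:
    # Pad with zeros or truncate so the vector matches the collection's vector size.
    if vector is None:
        return vector
    if len(vector) == target_dim:
        return vector
    if len(vector) > target_dim:
        return vector[:target_dim]                      # e.g. legacy 1024-d vector into a 384-d collection
    return vector + [0.0] * (target_dim - len(vector))  # e.g. 384-d vector into a legacy 1024-d collection

assert len(align_embedding_dimension([0.1] * 1024, 384)) == 384
assert len(align_embedding_dimension([0.1] * 384, 1024)) == 1024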
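
The rewritten `EmbeddingService` loads a sentence-transformers model in a worker thread and returns normalized vectors. A condensed sketch of that flow, assuming the `BAAI/bge-small-en` default supplied by `RAG_EMBEDDING_MODEL` (the function name and example input are illustrative):

import asyncio
from typing import List

from sentence_transformers import SentenceTransformer

async def embed_locally(texts: List[str], model_name: str = "BAAI/bge-small-en") -> List[List[float]]:
    loop = asyncio.get_running_loop()
    # Load and encode off the event loop so FastAPI request handling is not blocked.
    model = await loop.run_in_executor(None, lambda: SentenceTransformer(model_name))
    vectors = await loop.run_in_executor(
        None,
        lambda: model.encode(texts, convert_to_numpy=True, normalize_embeddings=True),
    )
    return vectors.tolist()  # bge-small-en produces 384-dimensional vectors

# Example usage: asyncio.run(embed_locally(["query: what is retrieval-augmented generation?"]))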
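
Both `_get_collection_vector_size` and the `RAGService` collection creation now key off the active embedding dimension instead of a hard-coded 1024. A sketch of creating a collection at the new size and reading the size back from Qdrant, mirroring the caching logic in the patch (the URL and collection name are placeholders; this assumes the collection does not already exist):

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(url="http://localhost:6333")  # placeholder URL

client.create_collection(
    collection_name="rag_example",
    vectors_config=VectorParams(size=384, distance=Distance.COSINE),  # 384 matches bge-small-en
)

info = client.get_collection("rag_example")
vectors_param = info.config.params.vectors
# Unnamed vectors come back as a single VectorParams; named vectors come back as a dict of VectorParams.
size = vectors_param.size if hasattr(vectors_param, "size") else next(iter(vectors_param.values())).size
print(size)  # -> 384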