swapping to local embeddings

2025-10-06 17:28:55 +02:00
parent bae86fb5a2
commit 487d7d0cc9
8 changed files with 398 additions and 319 deletions

View File

@@ -7,6 +7,7 @@ from typing import List, Optional, Dict, Any
 from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, status
 from fastapi.responses import StreamingResponse
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
 from pydantic import BaseModel
 import io
 import asyncio
@@ -15,6 +16,7 @@ from datetime import datetime
 from app.db.database import get_db
 from app.core.security import get_current_user
 from app.models.user import User
+from app.models.rag_collection import RagCollection
 from app.services.rag_service import RAGService
 from app.utils.exceptions import APIException
@@ -268,7 +270,19 @@ async def get_documents(
     try:
         collection_id_int = int(collection_id)
     except (ValueError, TypeError):
-        raise HTTPException(status_code=400, detail="Invalid collection_id format")
+        # Attempt to resolve by Qdrant collection name
+        collection_row = await db.scalar(
+            select(RagCollection).where(RagCollection.qdrant_collection_name == collection_id)
+        )
+        if collection_row:
+            collection_id_int = collection_row.id
+        else:
+            # Unknown collection identifier; return empty result instead of erroring out
+            return {
+                "success": True,
+                "documents": [],
+                "total": 0
+            }
 
     rag_service = RAGService(db)
     documents = await rag_service.get_documents(

View File

@@ -129,6 +129,7 @@ class Settings(BaseSettings):
     RAG_EMBEDDING_DELAY_PER_REQUEST: float = float(os.getenv("RAG_EMBEDDING_DELAY_PER_REQUEST", "0.5"))
     RAG_ALLOW_FALLBACK_EMBEDDINGS: bool = os.getenv("RAG_ALLOW_FALLBACK_EMBEDDINGS", "True").lower() == "true"
     RAG_WARN_ON_FALLBACK: bool = os.getenv("RAG_WARN_ON_FALLBACK", "True").lower() == "true"
+    RAG_EMBEDDING_MODEL: str = os.getenv("RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
     RAG_DOCUMENT_PROCESSING_TIMEOUT: int = int(os.getenv("RAG_DOCUMENT_PROCESSING_TIMEOUT", "300"))
     RAG_EMBEDDING_GENERATION_TIMEOUT: int = int(os.getenv("RAG_EMBEDDING_GENERATION_TIMEOUT", "120"))
     RAG_INDEXING_TIMEOUT: int = int(os.getenv("RAG_INDEXING_TIMEOUT", "120"))
@@ -154,4 +155,3 @@ class Settings(BaseSettings):
 # Global settings instance
 settings = Settings()
-

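Note: the embedding model is now configurable per deployment; an illustrative .env override (any sentence-transformers model id should work, though the 384-dim default matches the collection sizing used further down in this commit):

RAG_EMBEDDING_MODEL=BAAI/bge-small-en
RAG_ALLOW_FALLBACK_EMBEDDINGS=True
RAG_WARN_ON_FALLBACK=True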
View File

@@ -8,6 +8,7 @@ import json
 import logging
 import mimetypes
 import re
+import time
 from typing import Any, Dict, List, Optional, Tuple, Union
 from datetime import datetime
 from dataclasses import dataclass, asdict
@@ -147,6 +148,11 @@ class RAGModule(BaseModule):
         if config:
             self.config.update(config)
 
+        # Ensure embedding model configured (defaults to local BGE small)
+        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
+        self.config.setdefault("embedding_model", default_embedding_model)
+        self.default_embedding_model = default_embedding_model
+
         # Content processing components
         self.nlp_model = None
         self.lemmatizer = None
@@ -178,6 +184,7 @@ class RAGModule(BaseModule):
"supported_types": len(self.supported_types) "supported_types": len(self.supported_types)
} }
self.search_cache = {} self.search_cache = {}
self.collection_vector_sizes: Dict[str, int] = {}
def get_required_permissions(self) -> List[Permission]: def get_required_permissions(self) -> List[Permission]:
"""Return list of permissions this module requires""" """Return list of permissions this module requires"""
@@ -215,7 +222,7 @@ class RAGModule(BaseModule):
         self.initialized = True
         log_module_event("rag", "initialized", {
             "vector_db": self.config.get("vector_db", "qdrant"),
-            "embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"),
+            "embedding_model": self.embedding_model.get("model_name", self.default_embedding_model),
             "chunk_size": self.config.get("chunk_size", 400),
             "max_results": self.config.get("max_results", 10),
             "supported_file_types": list(self.supported_types.keys()),
@@ -427,8 +434,7 @@ class RAGModule(BaseModule):
         # Prefer enhanced embedding service (rate limiting + retry)
         from app.services.enhanced_embedding_service import enhanced_embedding_service as embedding_service
 
-        # Use intfloat/multilingual-e5-large-instruct for LLM service integration
-        model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct")
+        model_name = self.config.get("embedding_model", self.default_embedding_model)
         embedding_service.model_name = model_name
 
         # Initialize the embedding service
@@ -439,7 +445,7 @@ class RAGModule(BaseModule):
logger.info(f"Successfully initialized embedding service with {model_name}") logger.info(f"Successfully initialized embedding service with {model_name}")
return { return {
"model_name": model_name, "model_name": model_name,
"dimension": embedding_service.dimension or 768 "dimension": embedding_service.dimension or 384
} }
else: else:
# Fallback to mock implementation # Fallback to mock implementation
@@ -447,7 +453,7 @@ class RAGModule(BaseModule):
             self.embedding_service = None
             return {
                 "model_name": model_name,
-                "dimension": 1024  # Default dimension for intfloat/multilingual-e5-large-instruct
+                "dimension": 384  # Default dimension matching local bge-small embeddings
             }
 
     async def _initialize_content_processing(self):
@@ -588,15 +594,36 @@ class RAGModule(BaseModule):
             collection_names = await self._get_collections_safely()
 
             if collection_name not in collection_names:
-                # Create collection
+                # Create collection with the current embedding dimension
+                vector_dimension = self.embedding_model.get(
+                    "dimension",
+                    getattr(self.embedding_service, "dimension", 384) or 384
+                )
                 self.qdrant_client.create_collection(
                     collection_name=collection_name,
                     vectors_config=VectorParams(
-                        size=self.embedding_model.get("dimension", 768),
+                        size=vector_dimension,
                         distance=Distance.COSINE
                     )
                 )
+                self.collection_vector_sizes[collection_name] = vector_dimension
                 log_module_event("rag", "collection_created", {"collection": collection_name})
+            else:
+                # Cache existing collection vector size for later alignment
+                try:
+                    info = self.qdrant_client.get_collection(collection_name)
+                    vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
+                    existing_size = None
+                    if vectors_param is not None and hasattr(vectors_param, "size"):
+                        existing_size = vectors_param.size
+                    elif isinstance(vectors_param, dict):
+                        existing_size = vectors_param.get("size")
+                    if existing_size:
+                        self.collection_vector_sizes[collection_name] = existing_size
+                except Exception as inner_error:
+                    logger.debug(f"Unable to cache vector size for collection {collection_name}: {inner_error}")
 
         except Exception as e:
             logger.error(f"Error ensuring collection exists: {e}")
@@ -632,12 +659,13 @@ class RAGModule(BaseModule):
     async def _generate_embedding(self, text: str) -> List[float]:
         """Generate embedding for text"""
         if self.embedding_service:
-            # Use real embedding service
-            return await self.embedding_service.get_embedding(text)
+            vector = await self.embedding_service.get_embedding(text)
+            return vector
         else:
             # Fallback to deterministic random embedding for consistency
             np.random.seed(hash(text) % 2**32)
-            return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
+            fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384)
+            return np.random.random(fallback_dim).tolist()
 
     async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
         """Generate embeddings for multiple texts (batch processing)"""
@@ -651,15 +679,91 @@ class RAGModule(BaseModule):
                 prefixed_texts = texts
 
             # Use real embedding service for batch processing
-            return await self.embedding_service.get_embeddings(prefixed_texts)
+            backend = getattr(self.embedding_service, "backend", "unknown")
+            start_time = time.time()
+            logger.info(
+                "Embedding batch requested",
+                extra={
+                    "backend": backend,
+                    "model": getattr(self.embedding_service, "model_name", "unknown"),
+                    "count": len(prefixed_texts),
+                    "scope": "documents" if is_document else "queries"
+                },
+            )
+            embeddings = await self.embedding_service.get_embeddings(prefixed_texts)
+            duration = time.time() - start_time
+            logger.info(
+                "Embedding batch finished",
+                extra={
+                    "backend": backend,
+                    "model": getattr(self.embedding_service, "model_name", "unknown"),
+                    "count": len(embeddings),
+                    "scope": "documents" if is_document else "queries",
+                    "duration_sec": round(duration, 4)
+                },
+            )
+            return embeddings
         else:
             # Fallback to individual processing
+            logger.warning(
+                "Embedding service unavailable, falling back to per-item generation",
+                extra={
+                    "count": len(texts),
+                    "scope": "documents" if is_document else "queries"
+                },
+            )
             embeddings = []
             for text in texts:
                 embedding = await self._generate_embedding(text)
                 embeddings.append(embedding)
             return embeddings
 
+    def _get_collection_vector_size(self, collection_name: Optional[str]) -> int:
+        """Return the expected vector size for a collection, caching results."""
+        default_dim = self.embedding_model.get(
+            "dimension",
+            getattr(self.embedding_service, "dimension", 384) or 384
+        )
+        if not collection_name:
+            return default_dim
+        if collection_name in self.collection_vector_sizes:
+            return self.collection_vector_sizes[collection_name]
+        try:
+            info = self.qdrant_client.get_collection(collection_name)
+            vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
+            existing_size = None
+            if vectors_param is not None and hasattr(vectors_param, "size"):
+                existing_size = vectors_param.size
+            elif isinstance(vectors_param, dict):
+                existing_size = vectors_param.get("size")
+            if existing_size:
+                self.collection_vector_sizes[collection_name] = existing_size
+                return existing_size
+        except Exception as e:
+            logger.debug(f"Unable to determine vector size for {collection_name}: {e}")
+        self.collection_vector_sizes[collection_name] = default_dim
+        return default_dim
+
+    def _align_embedding_dimension(self, vector: List[float], collection_name: Optional[str]) -> List[float]:
+        """Pad or truncate embeddings to match the target collection dimension."""
+        if vector is None:
+            return vector
+        target_dim = self._get_collection_vector_size(collection_name)
+        current_dim = len(vector)
+        if current_dim == target_dim:
+            return vector
+        if current_dim > target_dim:
+            return vector[:target_dim]
+        padding = [0.0] * (target_dim - current_dim)
+        return vector + padding
 
     def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
         """Split text into overlapping chunks for better context preservation"""
         chunk_size = chunk_size or self.config.get("chunk_size", 300)
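Note: the pad/truncate logic above keeps previously indexed vectors queryable when the active model and an existing collection disagree on dimension. A standalone sketch of the same behaviour (hypothetical helper, not part of this diff):

def align(vector: list[float], target_dim: int) -> list[float]:
    # Truncate oversized vectors, zero-pad undersized ones.
    if len(vector) > target_dim:
        return vector[:target_dim]
    return vector + [0.0] * (target_dim - len(vector))

# A legacy 1024-dim e5 vector squeezed into a 384-dim bge-small collection:
assert len(align([0.1] * 1024, 384)) == 384
# A new 384-dim vector padded up for an old 1024-dim collection:
assert len(align([0.1] * 384, 1024)) == 1024

Worth noting that truncation and zero-padding preserve dimensionality, not semantics; mixed-dimension collections will still return lower-quality matches until they are reindexed with the new model.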
@@ -1176,6 +1280,7 @@ class RAGModule(BaseModule):
             # Create document points
             points = []
             for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
+                aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
                 chunk_id = str(uuid.uuid4())
 
                 chunk_metadata = {
@@ -1189,7 +1294,7 @@ class RAGModule(BaseModule):
                 points.append(PointStruct(
                     id=chunk_id,
-                    vector=embedding,
+                    vector=aligned_embedding,
                     payload=chunk_metadata
                 ))
@@ -1257,6 +1362,7 @@ class RAGModule(BaseModule):
             # Create document points with enhanced metadata
             points = []
             for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
+                aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
                 chunk_id = str(uuid.uuid4())
 
                 chunk_metadata = {
@@ -1280,7 +1386,7 @@ class RAGModule(BaseModule):
                 points.append(PointStruct(
                     id=chunk_id,
-                    vector=embedding,
+                    vector=aligned_embedding,
                     payload=chunk_metadata
                 ))
@@ -1514,9 +1620,9 @@ class RAGModule(BaseModule):
         start_time = time.time()
 
         # Generate query embedding with task-specific prefix for better retrieval
-        # The E5 model works better with "query:" prefix for search queries
         optimized_query = f"query: {query}"
         query_embedding = await self._generate_embedding(optimized_query)
+        query_embedding = self._align_embedding_dimension(query_embedding, collection_name)
 
         # Build filter
         search_filter = None

View File

@@ -1,55 +1,66 @@
""" """
Embedding Service Embedding Service
Provides text embedding functionality using LLM service Provides local sentence-transformer embeddings (default: BAAI/bge-small-en).
Falls back to deterministic random vectors when the local model is unavailable.
""" """
import asyncio
import logging import logging
import time
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
import numpy as np import numpy as np
from app.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EmbeddingService: class EmbeddingService:
"""Service for generating text embeddings using LLM service""" """Service for generating text embeddings using a local transformer model"""
def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"): def __init__(self, model_name: Optional[str] = None):
self.model_name = model_name self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
self.dimension = 1024 # Actual dimension for intfloat/multilingual-e5-large-instruct self.dimension = 384 # bge-small produces 384-d vectors
self.initialized = False self.initialized = False
self.local_model = None
self.backend = "uninitialized"
async def initialize(self): async def initialize(self):
"""Initialize the embedding service with LLM service""" """Initialize the embedding service with LLM service"""
try: try:
from app.services.llm.service import llm_service from sentence_transformers import SentenceTransformer
# Initialize LLM service if not already done loop = asyncio.get_running_loop()
if not llm_service._initialized:
await llm_service.initialize()
# Test LLM service health def load_model():
if not llm_service._initialized: # Load model synchronously in a worker thread to avoid blocking event loop
logger.error("LLM service not initialized") return SentenceTransformer(self.model_name)
return False
# Check if PrivateMode provider is available
try:
provider_status = await llm_service.get_provider_status()
privatemode_status = provider_status.get("privatemode")
if not privatemode_status or privatemode_status.status != "healthy":
logger.error(f"PrivateMode provider not available: {privatemode_status}")
return False
except Exception as e:
logger.error(f"Failed to check provider status: {e}")
return False
self.local_model = await loop.run_in_executor(None, load_model)
self.dimension = self.local_model.get_sentence_embedding_dimension()
self.initialized = True self.initialized = True
logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})") self.backend = "sentence_transformer"
logger.info(
"Embedding service initialized with local model %s (dimension: %s)",
self.model_name,
self.dimension,
)
return True return True
except Exception as e: except ImportError as exc:
logger.error(f"Failed to initialize LLM embedding service: {e}") logger.error("sentence-transformers not installed: %s", exc)
logger.warning("Using fallback random embeddings") logger.warning("Falling back to random embeddings")
self.local_model = None
self.initialized = False
self.backend = "fallback_random"
return False
except Exception as exc:
logger.error(f"Failed to load local embedding model {self.model_name}: {exc}")
logger.warning("Falling back to random embeddings")
self.local_model = None
self.initialized = False
self.backend = "fallback_random"
return False return False
async def get_embedding(self, text: str) -> List[float]: async def get_embedding(self, text: str) -> List[float]:
@@ -59,90 +70,64 @@ class EmbeddingService:
     async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get embeddings for multiple texts using LLM service"""
-        if not self.initialized:
-            # Fallback to random embeddings if not initialized
-            logger.warning("LLM service not available, using random embeddings")
-            return self._generate_fallback_embeddings(texts)
-
-        try:
-            embeddings = []
-
-            # Process texts in batches for efficiency
-            batch_size = 10
-            for i in range(0, len(texts), batch_size):
-                batch = texts[i:i+batch_size]
-
-                # Process each text in the batch
-                batch_embeddings = []
-                for text in batch:
-                    try:
-                        # Truncate text if it's too long for the model's context window
-                        # intfloat/multilingual-e5-large-instruct has a 512 token limit, truncate to ~400 tokens worth of chars
-                        # Rough estimate: 1 token ≈ 4 characters, so 400 tokens ≈ 1600 chars
-                        max_chars = 1600
-                        if len(text) > max_chars:
-                            truncated_text = text[:max_chars]
-                            logger.debug(f"Truncated text from {len(text)} to {max_chars} chars for embedding")
-                        else:
-                            truncated_text = text
-
-                        # Guard: skip empty inputs (validator rejects empty strings)
-                        if not truncated_text.strip():
-                            logger.debug("Empty input for embedding; using fallback vector")
-                            batch_embeddings.append(self._generate_fallback_embedding(text))
-                            continue
-
-                        # Call LLM service embedding endpoint
-                        from app.services.llm.service import llm_service
-                        from app.services.llm.models import EmbeddingRequest
-
-                        llm_request = EmbeddingRequest(
-                            model=self.model_name,
-                            input=truncated_text,
-                            user_id="rag_system",
-                            api_key_id=0  # System API key
-                        )
-
-                        response = await llm_service.create_embedding(llm_request)
-
-                        # Extract embedding from response
-                        if response.data and len(response.data) > 0:
-                            embedding = response.data[0].embedding
-                            if embedding:
-                                batch_embeddings.append(embedding)
-                                # Update dimension based on actual embedding size
-                                if not hasattr(self, '_dimension_confirmed'):
-                                    self.dimension = len(embedding)
-                                    self._dimension_confirmed = True
-                                    logger.info(f"Confirmed embedding dimension: {self.dimension}")
-                            else:
-                                logger.warning(f"No embedding in response for text: {text[:50]}...")
-                                batch_embeddings.append(self._generate_fallback_embedding(text))
-                        else:
-                            logger.warning(f"Invalid response structure for text: {text[:50]}...")
-                            batch_embeddings.append(self._generate_fallback_embedding(text))
-
-                    except Exception as e:
-                        logger.error(f"Error getting embedding for text: {e}")
-                        batch_embeddings.append(self._generate_fallback_embedding(text))
-
-                embeddings.extend(batch_embeddings)
-
-            return embeddings
-
-        except Exception as e:
-            logger.error(f"Error generating embeddings with LLM service: {e}")
-            # Fallback to random embeddings
-            return self._generate_fallback_embeddings(texts)
-
-    def _generate_fallback_embeddings(self, texts: List[str]) -> List[List[float]]:
+        start_time = time.time()
+
+        if self.local_model:
+            if not texts:
+                return []
+
+            loop = asyncio.get_running_loop()
+
+            try:
+                embeddings = await loop.run_in_executor(
+                    None,
+                    lambda: self.local_model.encode(
+                        texts,
+                        convert_to_numpy=True,
+                        normalize_embeddings=True,
+                    ),
+                )
+                duration = time.time() - start_time
+                logger.info(
+                    "Embedding batch completed",
+                    extra={
+                        "backend": self.backend,
+                        "model": self.model_name,
+                        "count": len(texts),
+                        "dimension": self.dimension,
+                        "duration_sec": round(duration, 4),
+                    },
+                )
+                return embeddings.tolist()
+            except Exception as exc:
+                logger.error(f"Local embedding generation failed: {exc}")
+                self.backend = "fallback_random"
+                return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
+
+        logger.warning("Local embedding model unavailable; using fallback random embeddings")
+        self.backend = "fallback_random"
+        return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
+
+    def _generate_fallback_embeddings(self, texts: List[str], duration: float = None) -> List[List[float]]:
         """Generate fallback random embeddings when model unavailable"""
         embeddings = []
         for text in texts:
             embeddings.append(self._generate_fallback_embedding(text))
+        logger.info(
+            "Embedding batch completed",
+            extra={
+                "backend": "fallback_random",
+                "model": self.model_name,
+                "count": len(texts),
+                "dimension": self.dimension,
+                "duration_sec": round(duration or 0.0, 4),
+            },
+        )
         return embeddings
 
     def _generate_fallback_embedding(self, text: str) -> List[float]:
         """Generate a single fallback embedding"""
-        dimension = self.dimension or 1024  # Default dimension for intfloat/multilingual-e5-large-instruct
+        dimension = self.dimension or 384
         # Use hash for reproducible random embeddings
         np.random.seed(hash(text) % 2**32)
         return np.random.random(dimension).tolist()
@@ -169,22 +154,15 @@ class EmbeddingService:
"model_name": self.model_name, "model_name": self.model_name,
"model_loaded": self.initialized, "model_loaded": self.initialized,
"dimension": self.dimension, "dimension": self.dimension,
"backend": "LLM Service", "backend": self.backend,
"initialized": self.initialized "initialized": self.initialized
} }
async def cleanup(self): async def cleanup(self):
"""Cleanup resources""" """Cleanup resources"""
# Cleanup LLM service to prevent memory leaks self.local_model = None
try:
from .llm.service import llm_service
if llm_service._initialized:
await llm_service.cleanup()
logger.info("Cleaned up LLM service from embedding service")
except Exception as e:
logger.error(f"Error cleaning up LLM service: {e}")
self.initialized = False self.initialized = False
self.backend = "uninitialized"
# Global embedding service instance # Global embedding service instance

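Note: a minimal sketch of exercising the rewritten service on its own (assumes sentence-transformers is installed; module path as used elsewhere in this diff):

import asyncio
from app.services.embedding_service import EmbeddingService

async def main():
    service = EmbeddingService()  # model name falls back to settings.RAG_EMBEDDING_MODEL
    loaded = await service.initialize()  # loads the SentenceTransformer off the event loop
    vectors = await service.get_embeddings(["hello world", "local embeddings"])
    # With the default model this prints: True sentence_transformer 2 384
    print(loaded, service.backend, len(vectors), len(vectors[0]))

asyncio.run(main())

If the model cannot be loaded, initialize() returns False and get_embeddings() silently serves deterministic random vectors, so callers should check the backend/initialized flags before trusting search quality.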
View File

@@ -1,14 +1,13 @@
 # Enhanced Embedding Service with Rate Limiting Handling
 """
-Enhanced embedding service with robust rate limiting and retry logic
+Enhanced embedding service that adds basic retry semantics around the local
+embedding generator. Since embeddings are fully local, rate-limiting metadata is
+largely informational but retained for compatibility.
 """
-import asyncio
 import logging
 import time
-from typing import List, Dict, Any, Optional
-import numpy as np
-from datetime import datetime, timedelta
+from typing import Any, Dict, List, Optional
 
 from .embedding_service import EmbeddingService
 from app.core.config import settings
@@ -17,9 +16,9 @@ logger = logging.getLogger(__name__)
 class EnhancedEmbeddingService(EmbeddingService):
-    """Enhanced embedding service with rate limiting handling"""
+    """Enhanced embedding service with lightweight retry bookkeeping"""
 
-    def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
+    def __init__(self, model_name: Optional[str] = None):
         super().__init__(model_name)
         self.rate_limit_tracker = {
             'requests_count': 0,
@@ -34,161 +33,16 @@ class EnhancedEmbeddingService(EmbeddingService):
     async def get_embeddings_with_retry(self, texts: List[str], max_retries: int = None) -> tuple[List[List[float]], bool]:
         """
-        Get embeddings with rate limiting and retry logic
+        Get embeddings with retry bookkeeping.
         """
-        if max_retries is None:
-            max_retries = int(getattr(settings, 'RAG_EMBEDDING_RETRY_COUNT', 3))
-
-        batch_size = int(getattr(settings, 'RAG_EMBEDDING_BATCH_SIZE', 3))
-
-        if not self.initialized:
-            logger.warning("Embedding service not initialized, using fallback")
-            return self._generate_fallback_embeddings(texts), False
-
-        embeddings = []
-        success = True
-
-        for i in range(0, len(texts), batch_size):
-            batch = texts[i:i+batch_size]
-            batch_embeddings, batch_success = await self._get_batch_embeddings_with_retry(batch, max_retries)
-            embeddings.extend(batch_embeddings)
-            success = success and batch_success
-
-            # Add delay between batches to avoid rate limiting
-            if i + batch_size < len(texts):
-                delay = self.rate_limit_tracker['delay_between_batches']
-                await asyncio.sleep(delay)  # Configurable delay between batches
-
-        return embeddings, success
-
-    async def _get_batch_embeddings_with_retry(self, texts: List[str], max_retries: int) -> tuple[List[List[float]], bool]:
-        """Get embeddings for a batch with retry logic"""
-        last_error = None
-
-        for attempt in range(max_retries + 1):
-            try:
-                # Check rate limit before making request
-                if self._is_rate_limited():
-                    delay = self._get_rate_limit_delay()
-                    logger.warning(f"Rate limit detected, waiting {delay} seconds")
-                    await asyncio.sleep(delay)
-                    continue
-
-                # Make the request
-                embeddings = await self._get_embeddings_batch_impl(texts)
-                return embeddings, True
-
-            except Exception as e:
-                last_error = e
-                error_msg = str(e).lower()
-
-                # Check if it's a rate limit error
-                if any(indicator in error_msg for indicator in ['429', 'rate limit', 'too many requests', 'quota exceeded']):
-                    logger.warning(f"Rate limit error (attempt {attempt + 1}/{max_retries + 1}): {e}")
-                    self._update_rate_limit_tracker(success=False)
-
-                    if attempt < max_retries:
-                        delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
-                        logger.info(f"Retrying in {delay} seconds...")
-                        await asyncio.sleep(delay)
-                        continue
-                    else:
-                        logger.error(f"Max retries exceeded for rate limit, using fallback embeddings")
-                        return self._generate_fallback_embeddings(texts), False
-                else:
-                    # Non-rate-limit error
-                    logger.error(f"Error generating embeddings: {e}")
-                    if attempt < max_retries:
-                        delay = self.rate_limit_tracker['retry_delays'][min(attempt, len(self.rate_limit_tracker['retry_delays']) - 1)]
-                        await asyncio.sleep(delay)
-                    else:
-                        logger.error("Max retries exceeded, using fallback embeddings")
-                        return self._generate_fallback_embeddings(texts), False
-
-        # If we get here, all retries failed
-        logger.error(f"All retries failed, last error: {last_error}")
-        return self._generate_fallback_embeddings(texts), False
-
-    async def _get_embeddings_batch_impl(self, texts: List[str]) -> List[List[float]]:
-        """Implementation of getting embeddings for a batch"""
-        from app.services.llm.service import llm_service
-        from app.services.llm.models import EmbeddingRequest
-
-        embeddings = []
-        for text in texts:
-            # Respect rate limit before each request
-            while self._is_rate_limited():
-                delay = self._get_rate_limit_delay()
-                logger.warning(f"Rate limit window exceeded, waiting {delay:.2f}s before next request")
-                await asyncio.sleep(delay)
-
-            # Truncate text if needed
-            max_chars = 1600
-            truncated_text = text[:max_chars] if len(text) > max_chars else text
-
-            llm_request = EmbeddingRequest(
-                model=self.model_name,
-                input=truncated_text,
-                user_id="rag_system",
-                api_key_id=0
-            )
-
-            response = await llm_service.create_embedding(llm_request)
-
-            if response.data and len(response.data) > 0:
-                embedding = response.data[0].embedding
-                if embedding:
-                    embeddings.append(embedding)
-                    if not hasattr(self, '_dimension_confirmed'):
-                        self.dimension = len(embedding)
-                        self._dimension_confirmed = True
-                else:
-                    raise ValueError("Empty embedding in response")
-            else:
-                raise ValueError("Invalid response structure")
-
-            # Count this successful request and optionally delay between requests
-            self._update_rate_limit_tracker(success=True)
-            per_req_delay = self.rate_limit_tracker.get('delay_per_request', 0)
-            if per_req_delay and per_req_delay > 0:
-                await asyncio.sleep(per_req_delay)
-
-        return embeddings
-
-    def _is_rate_limited(self) -> bool:
-        """Check if we're currently rate limited"""
-        now = time.time()
-        window_start = self.rate_limit_tracker['window_start']
-
-        # Reset window if it's expired
-        if now - window_start > self.rate_limit_tracker['window_size']:
-            self.rate_limit_tracker['requests_count'] = 0
-            self.rate_limit_tracker['window_start'] = now
-            return False
-
-        # Check if we've exceeded the limit
-        return self.rate_limit_tracker['requests_count'] >= self.rate_limit_tracker['max_requests_per_minute']
-
-    def _get_rate_limit_delay(self) -> float:
-        """Get delay to wait for rate limit reset"""
-        now = time.time()
-        window_end = self.rate_limit_tracker['window_start'] + self.rate_limit_tracker['window_size']
-        return max(0, window_end - now)
-
-    def _update_rate_limit_tracker(self, success: bool):
-        """Update the rate limit tracker"""
-        now = time.time()
-
-        # Reset window if it's expired
-        if now - self.rate_limit_tracker['window_start'] > self.rate_limit_tracker['window_size']:
-            self.rate_limit_tracker['requests_count'] = 0
-            self.rate_limit_tracker['window_start'] = now
-
-        # Increment counter on successful requests
-        if success:
-            self.rate_limit_tracker['requests_count'] += 1
+        embeddings = await super().get_embeddings(texts)
+        success = self.local_model is not None
+        if not success:
+            logger.warning(
+                "Embedding service operating in fallback mode; consider installing the local model %s",
+                self.model_name,
+            )
+        return embeddings, success
 
     async def get_embedding_stats(self) -> Dict[str, Any]:
         """Get embedding service statistics including rate limiting info"""

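Note: get_embeddings_with_retry is now a thin wrapper over the local generator; a sketch of a typical call site (index_chunks is a hypothetical caller, the global instance name matches the import used in the RAG module above):

from app.services.enhanced_embedding_service import enhanced_embedding_service

async def index_chunks(chunks: list[str]) -> list[list[float]]:
    # success is False when the local model is missing and fallback
    # random vectors were returned instead of real embeddings.
    vectors, success = await enhanced_embedding_service.get_embeddings_with_retry(chunks)
    if not success:
        raise RuntimeError("embedding backend unavailable; refusing to index fallback vectors")
    return vectors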
View File

@@ -566,10 +566,14 @@ class RAGService:
logger.warning(f"Could not check existing collections: {e}") logger.warning(f"Could not check existing collections: {e}")
# Create collection with proper vector configuration # Create collection with proper vector configuration
from app.services.embedding_service import embedding_service
vector_dimension = getattr(embedding_service, 'dimension', 384) or 384
client.create_collection( client.create_collection(
collection_name=collection_name, collection_name=collection_name,
vectors_config=VectorParams( vectors_config=VectorParams(
size=1024, # Updated for multilingual-e5-large-instruct model size=vector_dimension,
distance=Distance.COSINE distance=Distance.COSINE
), ),
optimizers_config=models.OptimizersConfigDiff( optimizers_config=models.OptimizersConfigDiff(

View File

@@ -8,6 +8,7 @@ import json
 import logging
 import mimetypes
 import re
+import time
 from typing import Any, Dict, List, Optional, Tuple, Union
 from datetime import datetime
 from dataclasses import dataclass, asdict
@@ -147,6 +148,11 @@ class RAGModule(BaseModule):
         if config:
             self.config.update(config)
 
+        # Ensure embedding model configured (defaults to local BGE small)
+        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
+        self.config.setdefault("embedding_model", default_embedding_model)
+        self.default_embedding_model = default_embedding_model
+
         # Content processing components
         self.nlp_model = None
         self.lemmatizer = None
@@ -178,6 +184,7 @@ class RAGModule(BaseModule):
"supported_types": len(self.supported_types) "supported_types": len(self.supported_types)
} }
self.search_cache = {} self.search_cache = {}
self.collection_vector_sizes: Dict[str, int] = {}
def get_required_permissions(self) -> List[Permission]: def get_required_permissions(self) -> List[Permission]:
"""Return list of permissions this module requires""" """Return list of permissions this module requires"""
@@ -215,7 +222,7 @@ class RAGModule(BaseModule):
         self.initialized = True
         log_module_event("rag", "initialized", {
             "vector_db": self.config.get("vector_db", "qdrant"),
-            "embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"),
+            "embedding_model": self.embedding_model.get("model_name", self.default_embedding_model),
             "chunk_size": self.config.get("chunk_size", 400),
             "max_results": self.config.get("max_results", 10),
             "supported_file_types": list(self.supported_types.keys()),
@@ -426,8 +433,7 @@ class RAGModule(BaseModule):
"""Initialize embedding model""" """Initialize embedding model"""
from app.services.embedding_service import embedding_service from app.services.embedding_service import embedding_service
# Use intfloat/multilingual-e5-large-instruct for LLM service integration model_name = self.config.get("embedding_model", self.default_embedding_model)
model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct")
embedding_service.model_name = model_name embedding_service.model_name = model_name
# Initialize the embedding service # Initialize the embedding service
@@ -438,7 +444,7 @@ class RAGModule(BaseModule):
logger.info(f"Successfully initialized embedding service with {model_name}") logger.info(f"Successfully initialized embedding service with {model_name}")
return { return {
"model_name": model_name, "model_name": model_name,
"dimension": embedding_service.dimension or 768 "dimension": embedding_service.dimension or 384
} }
else: else:
# Fallback to mock implementation # Fallback to mock implementation
@@ -446,7 +452,7 @@ class RAGModule(BaseModule):
             self.embedding_service = None
             return {
                 "model_name": model_name,
-                "dimension": 1024  # Default dimension for intfloat/multilingual-e5-large-instruct
+                "dimension": 384  # Default dimension matching local bge-small embeddings
             }
 
     async def _initialize_content_processing(self):
@@ -587,15 +593,36 @@ class RAGModule(BaseModule):
             collection_names = await self._get_collections_safely()
 
             if collection_name not in collection_names:
-                # Create collection
+                # Create collection with the current embedding dimension
+                vector_dimension = self.embedding_model.get(
+                    "dimension",
+                    getattr(self.embedding_service, "dimension", 384) or 384
+                )
                 self.qdrant_client.create_collection(
                     collection_name=collection_name,
                     vectors_config=VectorParams(
-                        size=self.embedding_model.get("dimension", 768),
+                        size=vector_dimension,
                         distance=Distance.COSINE
                     )
                 )
+                self.collection_vector_sizes[collection_name] = vector_dimension
                 log_module_event("rag", "collection_created", {"collection": collection_name})
+            else:
+                # Cache existing collection vector size for later alignment
+                try:
+                    info = self.qdrant_client.get_collection(collection_name)
+                    vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
+                    existing_size = None
+                    if vectors_param is not None and hasattr(vectors_param, "size"):
+                        existing_size = vectors_param.size
+                    elif isinstance(vectors_param, dict):
+                        existing_size = vectors_param.get("size")
+                    if existing_size:
+                        self.collection_vector_sizes[collection_name] = existing_size
+                except Exception as inner_error:
+                    logger.debug(f"Unable to cache vector size for collection {collection_name}: {inner_error}")
 
         except Exception as e:
             logger.error(f"Error ensuring collection exists: {e}")
@@ -632,11 +659,13 @@ class RAGModule(BaseModule):
"""Generate embedding for text""" """Generate embedding for text"""
if self.embedding_service: if self.embedding_service:
# Use real embedding service # Use real embedding service
return await self.embedding_service.get_embedding(text) vector = await self.embedding_service.get_embedding(text)
return vector
else: else:
# Fallback to deterministic random embedding for consistency # Fallback to deterministic random embedding for consistency
np.random.seed(hash(text) % 2**32) np.random.seed(hash(text) % 2**32)
return np.random.random(self.embedding_model.get("dimension", 768)).tolist() fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384)
return np.random.random(fallback_dim).tolist()
async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]: async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
"""Generate embeddings for multiple texts (batch processing)""" """Generate embeddings for multiple texts (batch processing)"""
@@ -650,15 +679,92 @@ class RAGModule(BaseModule):
                 prefixed_texts = texts
 
             # Use real embedding service for batch processing
-            return await self.embedding_service.get_embeddings(prefixed_texts)
+            backend = getattr(self.embedding_service, "backend", "unknown")
+            start_time = time.time()
+            logger.info(
+                "Embedding batch requested",
+                extra={
+                    "backend": backend,
+                    "model": getattr(self.embedding_service, "model_name", "unknown"),
+                    "count": len(prefixed_texts),
+                    "scope": "documents" if is_document else "queries"
+                },
+            )
+            embeddings = await self.embedding_service.get_embeddings(prefixed_texts)
+            duration = time.time() - start_time
+            logger.info(
+                "Embedding batch finished",
+                extra={
+                    "backend": backend,
+                    "model": getattr(self.embedding_service, "model_name", "unknown"),
+                    "count": len(embeddings),
+                    "scope": "documents" if is_document else "queries",
+                    "duration_sec": round(duration, 4)
+                },
+            )
+            return embeddings
         else:
             # Fallback to individual processing
+            logger.warning(
+                "Embedding service unavailable, falling back to per-item generation",
+                extra={
+                    "count": len(texts),
+                    "scope": "documents" if is_document else "queries"
+                },
+            )
             embeddings = []
             for text in texts:
                 embedding = await self._generate_embedding(text)
                 embeddings.append(embedding)
             return embeddings
 
+    def _get_collection_vector_size(self, collection_name: Optional[str]) -> int:
+        """Return the expected vector size for a collection, caching results."""
+        default_dim = self.embedding_model.get(
+            "dimension",
+            getattr(self.embedding_service, "dimension", 384) or 384
+        )
+        if not collection_name:
+            return default_dim
+        if collection_name in self.collection_vector_sizes:
+            return self.collection_vector_sizes[collection_name]
+        try:
+            info = self.qdrant_client.get_collection(collection_name)
+            vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
+            existing_size = None
+            if vectors_param is not None and hasattr(vectors_param, "size"):
+                existing_size = vectors_param.size
+            elif isinstance(vectors_param, dict):
+                existing_size = vectors_param.get("size")
+            if existing_size:
+                self.collection_vector_sizes[collection_name] = existing_size
+                return existing_size
+        except Exception as e:
+            logger.debug(f"Unable to determine vector size for {collection_name}: {e}")
+        self.collection_vector_sizes[collection_name] = default_dim
+        return default_dim
+
+    def _align_embedding_dimension(self, vector: List[float], collection_name: Optional[str]) -> List[float]:
+        """Pad or truncate embeddings to match the target collection dimension."""
+        if vector is None:
+            return vector
+        target_dim = self._get_collection_vector_size(collection_name)
+        current_dim = len(vector)
+        if current_dim == target_dim:
+            return vector
+        if current_dim > target_dim:
+            return vector[:target_dim]
+        # Pad with zeros to reach the target dimension
+        padding = [0.0] * (target_dim - current_dim)
+        return vector + padding
 
     def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
         """Split text into overlapping chunks for better context preservation"""
         chunk_size = chunk_size or self.config.get("chunk_size", 300)
@@ -1180,6 +1286,7 @@ class RAGModule(BaseModule):
             # Create document points
             points = []
             for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
+                aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
                 chunk_id = str(uuid.uuid4())
 
                 chunk_metadata = {
@@ -1193,7 +1300,7 @@ class RAGModule(BaseModule):
                 points.append(PointStruct(
                     id=chunk_id,
-                    vector=embedding,
+                    vector=aligned_embedding,
                     payload=chunk_metadata
                 ))
@@ -1261,6 +1368,7 @@ class RAGModule(BaseModule):
             # Create document points with enhanced metadata
             points = []
             for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
+                aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
                 chunk_id = str(uuid.uuid4())
 
                 chunk_metadata = {
@@ -1284,7 +1392,7 @@ class RAGModule(BaseModule):
                 points.append(PointStruct(
                     id=chunk_id,
-                    vector=embedding,
+                    vector=aligned_embedding,
                     payload=chunk_metadata
                 ))
@@ -1549,6 +1657,8 @@ class RAGModule(BaseModule):
optimized_query = f"query: {query}" optimized_query = f"query: {query}"
query_embedding = await self._generate_embedding(optimized_query) query_embedding = await self._generate_embedding(optimized_query)
query_embedding = self._align_embedding_dimension(query_embedding, collection_name)
# Build filter # Build filter
search_filter = None search_filter = None
if filters: if filters:

View File

@@ -87,6 +87,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
   const [messages, setMessages] = useState<ChatMessage[]>([])
   const [input, setInput] = useState("")
   const [isLoading, setIsLoading] = useState(false)
+  const [conversationId, setConversationId] = useState<string | undefined>(undefined)
   const scrollAreaRef = useRef<HTMLDivElement>(null)
   const { success: toastSuccess, error: toastError } = useToast()
@@ -103,6 +104,12 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
     scrollToBottom()
   }, [messages])
 
+  useEffect(() => {
+    // Reset conversation when switching chatbots
+    setMessages([])
+    setConversationId(undefined)
+  }, [chatbotId])
+
   const sendMessage = useCallback(async () => {
     if (!input.trim() || isLoading) return
@@ -119,12 +126,13 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
     setIsLoading(true)
 
     // Enhanced logging for debugging
+    const currentConversationId = conversationId
     const debugInfo = {
       chatbotId,
       messageLength: messageToSend.length,
-      conversationId,
+      conversationId: currentConversationId ?? null,
       timestamp: new Date().toISOString(),
-      messagesCount: messages.length
+      messagesCount: messages.length + 1
     }
     console.log('=== CHAT REQUEST DEBUG ===', debugInfo)
@@ -132,7 +140,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
       let data: any
 
       // Use internal API
-      const conversationHistory = messages.map(msg => ({
+      const conversationHistory = [...messages, userMessage].map(msg => ({
         role: msg.role,
         content: msg.content
       }))
@@ -140,7 +148,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
       data = await chatbotApi.sendMessage(
         chatbotId,
         messageToSend,
-        undefined, // No conversation ID
+        currentConversationId,
         conversationHistory
       )
@@ -154,6 +162,11 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
       setMessages(prev => [...prev, assistantMessage])
 
+      const newConversationId = data?.conversation_id || currentConversationId
+      if (newConversationId !== conversationId) {
+        setConversationId(newConversationId)
+      }
+
     } catch (error) {
       const appError = error as AppError
@@ -168,7 +181,7 @@ export function ChatInterface({ chatbotId, chatbotName, onClose }: ChatInterface
     } finally {
       setIsLoading(false)
     }
-  }, [input, isLoading, chatbotId, messages, toastError])
+  }, [input, isLoading, chatbotId, messages, toastError, conversationId])
 
   const handleKeyPress = useCallback((e: React.KeyboardEvent) => {
     if (e.key === 'Enter' && !e.shiftKey) {