swapping to local embeddings

2025-10-06 17:28:55 +02:00
parent bae86fb5a2
commit 487d7d0cc9
8 changed files with 398 additions and 319 deletions


@@ -8,6 +8,7 @@ import json
import logging
import mimetypes
import re
+import time
from typing import Any, Dict, List, Optional, Tuple, Union
from datetime import datetime
from dataclasses import dataclass, asdict
@@ -146,6 +147,11 @@ class RAGModule(BaseModule):
        # Update with any provided config
        if config:
            self.config.update(config)

+        # Ensure embedding model configured (defaults to local BGE small)
+        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
+        self.config.setdefault("embedding_model", default_embedding_model)
+        self.default_embedding_model = default_embedding_model

        # Content processing components
        self.nlp_model = None
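
Note on the config change above: an explicit `embedding_model` key in the module config still wins, because `setdefault` only fills the key when it is absent; the new `RAG_EMBEDDING_MODEL` setting merely supplies the default. A minimal sketch of that resolution order (the `Settings` class below is a hypothetical stand-in for the app's settings object):

from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class Settings:
    RAG_EMBEDDING_MODEL: str = "BAAI/bge-small-en"

def resolve_embedding_model(settings: Settings, config: Dict[str, Any]) -> str:
    # Explicit config wins; the settings value is only the fallback default.
    default = getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
    config.setdefault("embedding_model", default)
    return config["embedding_model"]

assert resolve_embedding_model(Settings(), {}) == "BAAI/bge-small-en"
assert resolve_embedding_model(Settings(), {"embedding_model": "custom"}) == "custom"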
@@ -178,6 +184,7 @@ class RAGModule(BaseModule):
"supported_types": len(self.supported_types)
}
self.search_cache = {}
self.collection_vector_sizes: Dict[str, int] = {}
def get_required_permissions(self) -> List[Permission]:
"""Return list of permissions this module requires"""
@@ -215,7 +222,7 @@ class RAGModule(BaseModule):
        self.initialized = True
        log_module_event("rag", "initialized", {
            "vector_db": self.config.get("vector_db", "qdrant"),
-            "embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"),
+            "embedding_model": self.embedding_model.get("model_name", self.default_embedding_model),
            "chunk_size": self.config.get("chunk_size", 400),
            "max_results": self.config.get("max_results", 10),
            "supported_file_types": list(self.supported_types.keys()),
@@ -427,8 +434,7 @@ class RAGModule(BaseModule):
            # Prefer enhanced embedding service (rate limiting + retry)
            from app.services.enhanced_embedding_service import enhanced_embedding_service as embedding_service
-            # Use intfloat/multilingual-e5-large-instruct for LLM service integration
-            model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct")
+            model_name = self.config.get("embedding_model", self.default_embedding_model)
            embedding_service.model_name = model_name

            # Initialize the embedding service
@@ -439,7 +445,7 @@ class RAGModule(BaseModule):
logger.info(f"Successfully initialized embedding service with {model_name}")
return {
"model_name": model_name,
"dimension": embedding_service.dimension or 768
"dimension": embedding_service.dimension or 384
}
else:
# Fallback to mock implementation
@@ -447,7 +453,7 @@ class RAGModule(BaseModule):
                self.embedding_service = None
                return {
                    "model_name": model_name,
-                    "dimension": 1024  # Default dimension for intfloat/multilingual-e5-large-instruct
+                    "dimension": 384  # Default dimension matching local bge-small embeddings
                }

    async def _initialize_content_processing(self):
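
The 1024 → 384 changes track the model swap: intfloat/multilingual-e5-large-instruct emits 1024-dimensional vectors, while BAAI/bge-small-en emits 384-dimensional ones. Rather than hard-coding the value, the dimension can also be read off the model itself; a quick check, assuming sentence-transformers is installed:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-small-en")
print(model.get_sentence_embedding_dimension())  # 384 for bge-small-en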
@@ -588,16 +594,37 @@ class RAGModule(BaseModule):
            collection_names = await self._get_collections_safely()

            if collection_name not in collection_names:
-                # Create collection
+                # Create collection with the current embedding dimension
+                vector_dimension = self.embedding_model.get(
+                    "dimension",
+                    getattr(self.embedding_service, "dimension", 384) or 384
+                )
                self.qdrant_client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(
-                        size=self.embedding_model.get("dimension", 768),
+                        size=vector_dimension,
                        distance=Distance.COSINE
                    )
                )
+                self.collection_vector_sizes[collection_name] = vector_dimension
                log_module_event("rag", "collection_created", {"collection": collection_name})
+            else:
+                # Cache existing collection vector size for later alignment
+                try:
+                    info = self.qdrant_client.get_collection(collection_name)
+                    vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
+                    existing_size = None
+                    if vectors_param is not None and hasattr(vectors_param, "size"):
+                        existing_size = vectors_param.size
+                    elif isinstance(vectors_param, dict):
+                        existing_size = vectors_param.get("size")
+                    if existing_size:
+                        self.collection_vector_sizes[collection_name] = existing_size
+                except Exception as inner_error:
+                    logger.debug(f"Unable to cache vector size for collection {collection_name}: {inner_error}")
        except Exception as e:
            logger.error(f"Error ensuring collection exists: {e}")
            raise
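
The probing of both attribute and dict shapes above reflects that `config.params.vectors` is a single `VectorParams` for unnamed-vector collections but a dict of named vectors otherwise. A small round-trip against an in-memory Qdrant instance showing the single-vector case (a sketch, assuming qdrant-client is installed):

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(":memory:")  # local in-memory mode, no server needed
client.create_collection(
    collection_name="docs",
    vectors_config=VectorParams(size=384, distance=Distance.COSINE),
)

vectors = client.get_collection("docs").config.params.vectors
# Unnamed vector config -> VectorParams with .size; named vectors would
# instead come back as {name: VectorParams}.
size = vectors.size if hasattr(vectors, "size") else next(iter(vectors.values())).size
print(size)  # 384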
@@ -632,12 +659,13 @@ class RAGModule(BaseModule):
    async def _generate_embedding(self, text: str) -> List[float]:
        """Generate embedding for text"""
        if self.embedding_service:
-            # Use real embedding service
-            return await self.embedding_service.get_embedding(text)
+            vector = await self.embedding_service.get_embedding(text)
+            return vector
        else:
            # Fallback to deterministic random embedding for consistency
            np.random.seed(hash(text) % 2**32)
-            return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
+            fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384)
+            return np.random.random(fallback_dim).tolist()

    async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
        """Generate embeddings for multiple texts (batch processing)"""
@@ -651,14 +679,90 @@ class RAGModule(BaseModule):
                prefixed_texts = texts

            # Use real embedding service for batch processing
-            return await self.embedding_service.get_embeddings(prefixed_texts)
+            backend = getattr(self.embedding_service, "backend", "unknown")
+            start_time = time.time()
+            logger.info(
+                "Embedding batch requested",
+                extra={
+                    "backend": backend,
+                    "model": getattr(self.embedding_service, "model_name", "unknown"),
+                    "count": len(prefixed_texts),
+                    "scope": "documents" if is_document else "queries"
+                },
+            )
+            embeddings = await self.embedding_service.get_embeddings(prefixed_texts)
+            duration = time.time() - start_time
+            logger.info(
+                "Embedding batch finished",
+                extra={
+                    "backend": backend,
+                    "model": getattr(self.embedding_service, "model_name", "unknown"),
+                    "count": len(embeddings),
+                    "scope": "documents" if is_document else "queries",
+                    "duration_sec": round(duration, 4)
+                },
+            )
+            return embeddings
        else:
            # Fallback to individual processing
+            logger.warning(
+                "Embedding service unavailable, falling back to per-item generation",
+                extra={
+                    "count": len(texts),
+                    "scope": "documents" if is_document else "queries"
+                },
+            )
            embeddings = []
            for text in texts:
                embedding = await self._generate_embedding(text)
                embeddings.append(embedding)
            return embeddings

+    def _get_collection_vector_size(self, collection_name: Optional[str]) -> int:
+        """Return the expected vector size for a collection, caching results."""
+        default_dim = self.embedding_model.get(
+            "dimension",
+            getattr(self.embedding_service, "dimension", 384) or 384
+        )
+        if not collection_name:
+            return default_dim
+        if collection_name in self.collection_vector_sizes:
+            return self.collection_vector_sizes[collection_name]
+        try:
+            info = self.qdrant_client.get_collection(collection_name)
+            vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
+            existing_size = None
+            if vectors_param is not None and hasattr(vectors_param, "size"):
+                existing_size = vectors_param.size
+            elif isinstance(vectors_param, dict):
+                existing_size = vectors_param.get("size")
+            if existing_size:
+                self.collection_vector_sizes[collection_name] = existing_size
+                return existing_size
+        except Exception as e:
+            logger.debug(f"Unable to determine vector size for {collection_name}: {e}")
+        self.collection_vector_sizes[collection_name] = default_dim
+        return default_dim
+
+    def _align_embedding_dimension(self, vector: List[float], collection_name: Optional[str]) -> List[float]:
+        """Pad or truncate embeddings to match the target collection dimension."""
+        if vector is None:
+            return vector
+        target_dim = self._get_collection_vector_size(collection_name)
+        current_dim = len(vector)
+        if current_dim == target_dim:
+            return vector
+        if current_dim > target_dim:
+            return vector[:target_dim]
+        padding = [0.0] * (target_dim - current_dim)
+        return vector + padding

    def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
        """Split text into overlapping chunks for better context preservation"""
@@ -1176,6 +1280,7 @@ class RAGModule(BaseModule):
        # Create document points
        points = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
+            aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
            chunk_id = str(uuid.uuid4())

            chunk_metadata = {
@@ -1189,7 +1294,7 @@ class RAGModule(BaseModule):
            points.append(PointStruct(
                id=chunk_id,
-                vector=embedding,
+                vector=aligned_embedding,
                payload=chunk_metadata
            ))
@@ -1257,6 +1362,7 @@ class RAGModule(BaseModule):
        # Create document points with enhanced metadata
        points = []
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
+            aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
            chunk_id = str(uuid.uuid4())

            chunk_metadata = {
@@ -1280,7 +1386,7 @@ class RAGModule(BaseModule):
            points.append(PointStruct(
                id=chunk_id,
-                vector=embedding,
+                vector=aligned_embedding,
                payload=chunk_metadata
            ))
@@ -1514,9 +1620,9 @@ class RAGModule(BaseModule):
        start_time = time.time()

        # Generate query embedding with task-specific prefix for better retrieval
        # The E5 model works better with "query:" prefix for search queries
        optimized_query = f"query: {query}"
        query_embedding = await self._generate_embedding(optimized_query)
+        query_embedding = self._align_embedding_dimension(query_embedding, collection_name)

        # Build filter
        search_filter = None
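
Taken together, the query path after this commit is roughly prefix → embed → align → search. A condensed sketch of that flow (method names follow the module above; the search call itself is assumed to match the module's existing Qdrant usage):

async def run_query(rag, query: str, collection_name: str, limit: int = 10):
    # "query:" prefix mirrors the retrieval-specific prompt used above.
    embedding = await rag._generate_embedding(f"query: {query}")
    embedding = rag._align_embedding_dimension(embedding, collection_name)
    return rag.qdrant_client.search(
        collection_name=collection_name,
        query_vector=embedding,
        limit=limit,
    )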