mirror of https://github.com/aljazceru/enclava.git

Commit: swapping to local embeddings
@@ -8,6 +8,7 @@ import json
 import logging
 import mimetypes
 import re
+import time
 from typing import Any, Dict, List, Optional, Tuple, Union
 from datetime import datetime
 from dataclasses import dataclass, asdict
@@ -146,6 +147,11 @@ class RAGModule(BaseModule):
         # Update with any provided config
         if config:
             self.config.update(config)
 
+        # Ensure embedding model configured (defaults to local BGE small)
+        default_embedding_model = getattr(settings, 'RAG_EMBEDDING_MODEL', 'BAAI/bge-small-en')
+        self.config.setdefault("embedding_model", default_embedding_model)
+        self.default_embedding_model = default_embedding_model
+
         # Content processing components
         self.nlp_model = None
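Note: with setdefault, an explicit "embedding_model" key in a caller-supplied
config always wins over the settings-derived default. A minimal sketch of that
precedence (the Settings class here is an illustrative stand-in, not this
repo's settings object):

    import os

    # Hypothetical stand-in for the app's settings object
    class Settings:
        RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")

    settings = Settings()
    config = {"chunk_size": 400}  # no embedding_model supplied

    default_embedding_model = getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
    config.setdefault("embedding_model", default_embedding_model)
    assert config["embedding_model"] == "BAAI/bge-small-en"

    # An explicitly configured model is left untouched by setdefault
    config2 = {"embedding_model": "intfloat/multilingual-e5-large-instruct"}
    config2.setdefault("embedding_model", default_embedding_model)
    assert config2["embedding_model"] == "intfloat/multilingual-e5-large-instruct"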
@@ -178,6 +184,7 @@ class RAGModule(BaseModule):
             "supported_types": len(self.supported_types)
         }
         self.search_cache = {}
+        self.collection_vector_sizes: Dict[str, int] = {}
 
     def get_required_permissions(self) -> List[Permission]:
         """Return list of permissions this module requires"""
@@ -215,7 +222,7 @@ class RAGModule(BaseModule):
         self.initialized = True
         log_module_event("rag", "initialized", {
             "vector_db": self.config.get("vector_db", "qdrant"),
-            "embedding_model": self.embedding_model.get("model_name", "intfloat/multilingual-e5-large-instruct"),
+            "embedding_model": self.embedding_model.get("model_name", self.default_embedding_model),
             "chunk_size": self.config.get("chunk_size", 400),
             "max_results": self.config.get("max_results", 10),
             "supported_file_types": list(self.supported_types.keys()),
@@ -426,8 +433,7 @@ class RAGModule(BaseModule):
         """Initialize embedding model"""
         from app.services.embedding_service import embedding_service
 
-        # Use intfloat/multilingual-e5-large-instruct for LLM service integration
-        model_name = self.config.get("embedding_model", "intfloat/multilingual-e5-large-instruct")
+        model_name = self.config.get("embedding_model", self.default_embedding_model)
         embedding_service.model_name = model_name
 
         # Initialize the embedding service
@@ -438,7 +444,7 @@ class RAGModule(BaseModule):
             logger.info(f"Successfully initialized embedding service with {model_name}")
             return {
                 "model_name": model_name,
-                "dimension": embedding_service.dimension or 768
+                "dimension": embedding_service.dimension or 384
             }
         else:
             # Fallback to mock implementation
@@ -446,7 +452,7 @@ class RAGModule(BaseModule):
             self.embedding_service = None
             return {
                 "model_name": model_name,
-                "dimension": 1024  # Default dimension for intfloat/multilingual-e5-large-instruct
+                "dimension": 384  # Default dimension matching local bge-small embeddings
             }
 
     async def _initialize_content_processing(self):
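Note: the dimension fallbacks change because the default model changes.
BAAI/bge-small-en produces 384-dimensional vectors, while the previous
default, intfloat/multilingual-e5-large-instruct, produces 1024-dimensional
ones. A sketch of verifying a model's dimension locally with
sentence-transformers (requires downloading the model; not part of this diff):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("BAAI/bge-small-en")
    print(model.get_sentence_embedding_dimension())  # 384 for bge-small-en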
@@ -585,18 +591,39 @@ class RAGModule(BaseModule):
         try:
             # Use safe collection fetching to avoid Pydantic validation errors
             collection_names = await self._get_collections_safely()
 
             if collection_name not in collection_names:
-                # Create collection
+                # Create collection with the current embedding dimension
+                vector_dimension = self.embedding_model.get(
+                    "dimension",
+                    getattr(self.embedding_service, "dimension", 384) or 384
+                )
+
                 self.qdrant_client.create_collection(
                     collection_name=collection_name,
                     vectors_config=VectorParams(
-                        size=self.embedding_model.get("dimension", 768),
+                        size=vector_dimension,
                         distance=Distance.COSINE
                     )
                 )
+                self.collection_vector_sizes[collection_name] = vector_dimension
                 log_module_event("rag", "collection_created", {"collection": collection_name})
+            else:
+                # Cache existing collection vector size for later alignment
+                try:
+                    info = self.qdrant_client.get_collection(collection_name)
+                    vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
+                    existing_size = None
+                    if vectors_param is not None and hasattr(vectors_param, "size"):
+                        existing_size = vectors_param.size
+                    elif isinstance(vectors_param, dict):
+                        existing_size = vectors_param.get("size")
+
+                    if existing_size:
+                        self.collection_vector_sizes[collection_name] = existing_size
+                except Exception as inner_error:
+                    logger.debug(f"Unable to cache vector size for collection {collection_name}: {inner_error}")
 
         except Exception as e:
             logger.error(f"Error ensuring collection exists: {e}")
             raise
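Note: the attribute-vs-dict handling above exists because qdrant-client
returns config.params.vectors as a single VectorParams object for an unnamed
vector, but as a dict of name -> VectorParams when named vectors are used. A
sketch of reading the size back (in-memory client and collection name are
illustrative; the dict branch only fires for named-vector collections):

    from qdrant_client import QdrantClient
    from qdrant_client.models import Distance, VectorParams

    client = QdrantClient(":memory:")
    client.create_collection(
        collection_name="demo",
        vectors_config=VectorParams(size=384, distance=Distance.COSINE),
    )

    vectors = client.get_collection("demo").config.params.vectors
    # Unnamed vector config: a VectorParams object with a .size attribute
    size = vectors.size if hasattr(vectors, "size") else None
    print(size)  # 384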
@@ -632,11 +659,13 @@ class RAGModule(BaseModule):
         """Generate embedding for text"""
         if self.embedding_service:
             # Use real embedding service
-            return await self.embedding_service.get_embedding(text)
+            vector = await self.embedding_service.get_embedding(text)
+            return vector
         else:
             # Fallback to deterministic random embedding for consistency
             np.random.seed(hash(text) % 2**32)
-            return np.random.random(self.embedding_model.get("dimension", 768)).tolist()
+            fallback_dim = self.embedding_model.get("dimension", getattr(self.embedding_service, "dimension", 384) or 384)
+            return np.random.random(fallback_dim).tolist()
 
     async def _generate_embeddings(self, texts: List[str], is_document: bool = True) -> List[List[float]]:
         """Generate embeddings for multiple texts (batch processing)"""
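Note: the "deterministic" fallback is only deterministic within a single
process. Python salts str hashes per process (see PYTHONHASHSEED), so
hash(text) % 2**32 yields a different seed after every restart. A sketch of a
cross-process-stable variant using hashlib (stable_seed is a hypothetical
helper, not part of this diff):

    import hashlib

    import numpy as np

    def stable_seed(text: str) -> int:
        # Digest of the UTF-8 bytes is stable across processes and machines
        return int.from_bytes(hashlib.md5(text.encode("utf-8")).digest()[:4], "big")

    rng = np.random.default_rng(stable_seed("hello world"))
    vector = rng.random(384).tolist()  # same vector on every run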
@@ -650,14 +679,91 @@ class RAGModule(BaseModule):
                 prefixed_texts = texts
 
             # Use real embedding service for batch processing
-            return await self.embedding_service.get_embeddings(prefixed_texts)
+            backend = getattr(self.embedding_service, "backend", "unknown")
+            start_time = time.time()
+            logger.info(
+                "Embedding batch requested",
+                extra={
+                    "backend": backend,
+                    "model": getattr(self.embedding_service, "model_name", "unknown"),
+                    "count": len(prefixed_texts),
+                    "scope": "documents" if is_document else "queries"
+                },
+            )
+            embeddings = await self.embedding_service.get_embeddings(prefixed_texts)
+            duration = time.time() - start_time
+            logger.info(
+                "Embedding batch finished",
+                extra={
+                    "backend": backend,
+                    "model": getattr(self.embedding_service, "model_name", "unknown"),
+                    "count": len(embeddings),
+                    "scope": "documents" if is_document else "queries",
+                    "duration_sec": round(duration, 4)
+                },
+            )
+            return embeddings
         else:
             # Fallback to individual processing
+            logger.warning(
+                "Embedding service unavailable, falling back to per-item generation",
+                extra={
+                    "count": len(texts),
+                    "scope": "documents" if is_document else "queries"
+                },
+            )
             embeddings = []
             for text in texts:
                 embedding = await self._generate_embedding(text)
                 embeddings.append(embedding)
             return embeddings
 
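Note: keys passed via extra= become attributes on the LogRecord rather than
part of the message, so "backend", "model", "count", "scope", and
"duration_sec" only appear in output if the formatter (or a structured/JSON
handler) references them. A minimal sketch of a formatter that surfaces two
of these fields (format string is illustrative; every record logged through
it must then supply those keys or formatting fails):

    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s %(message)s backend=%(backend)s count=%(count)s",
    )
    logger = logging.getLogger("rag")
    logger.info("Embedding batch requested", extra={"backend": "local", "count": 3})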
+    def _get_collection_vector_size(self, collection_name: Optional[str]) -> int:
+        """Return the expected vector size for a collection, caching results."""
+        default_dim = self.embedding_model.get(
+            "dimension",
+            getattr(self.embedding_service, "dimension", 384) or 384
+        )
+
+        if not collection_name:
+            return default_dim
+
+        if collection_name in self.collection_vector_sizes:
+            return self.collection_vector_sizes[collection_name]
+
+        try:
+            info = self.qdrant_client.get_collection(collection_name)
+            vectors_param = getattr(info.config.params, "vectors", None) if hasattr(info, "config") else None
+            existing_size = None
+            if vectors_param is not None and hasattr(vectors_param, "size"):
+                existing_size = vectors_param.size
+            elif isinstance(vectors_param, dict):
+                existing_size = vectors_param.get("size")
+
+            if existing_size:
+                self.collection_vector_sizes[collection_name] = existing_size
+                return existing_size
+        except Exception as e:
+            logger.debug(f"Unable to determine vector size for {collection_name}: {e}")
+
+        self.collection_vector_sizes[collection_name] = default_dim
+        return default_dim
+
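Note: the helper caches whatever it resolves, including the fallback
default_dim when Qdrant is unreachable, so a wrong guess can stick for the
process lifetime. A sketch of how the lookup order plays out (standalone
function and sizes are illustrative, not the module's code):

    from typing import Dict, Optional

    collection_vector_sizes: Dict[str, int] = {"docs": 1024}  # pre-cached legacy size

    def get_collection_vector_size(name: Optional[str], default_dim: int = 384) -> int:
        if not name:
            return default_dim                        # no collection: use the model's dimension
        if name in collection_vector_sizes:
            return collection_vector_sizes[name]      # cached size wins
        collection_vector_sizes[name] = default_dim   # the fallback is cached too
        return default_dim

    assert get_collection_vector_size("docs") == 1024   # cached legacy size
    assert get_collection_vector_size("fresh") == 384   # falls back and caches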
+    def _align_embedding_dimension(self, vector: List[float], collection_name: Optional[str]) -> List[float]:
+        """Pad or truncate embeddings to match the target collection dimension."""
+        if vector is None:
+            return vector
+
+        target_dim = self._get_collection_vector_size(collection_name)
+        current_dim = len(vector)
+
+        if current_dim == target_dim:
+            return vector
+        if current_dim > target_dim:
+            return vector[:target_dim]
+        # Pad with zeros to reach the target dimension
+        padding = [0.0] * (target_dim - current_dim)
+        return vector + padding
+
     def _chunk_text(self, text: str, chunk_size: int = None) -> List[str]:
         """Split text into overlapping chunks for better context preservation"""
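Note: zero-padding leaves cosine similarity among padded vectors intact, but
truncating an embedding discards components and changes its norm, so this
alignment is a stopgap for collections indexed under the old 1024-dim model
rather than a substitute for re-indexing. A quick demonstration of the
pad/truncate behavior (standalone function, values illustrative):

    def align(vector, target_dim):
        if len(vector) == target_dim:
            return vector
        if len(vector) > target_dim:
            return vector[:target_dim]                         # truncate extras
        return vector + [0.0] * (target_dim - len(vector))     # zero-pad

    assert align([0.1, 0.2, 0.3], 2) == [0.1, 0.2]
    assert align([0.1, 0.2], 4) == [0.1, 0.2, 0.0, 0.0]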
@@ -1180,8 +1286,9 @@ class RAGModule(BaseModule):
         # Create document points
         points = []
         for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
+            aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
             chunk_id = str(uuid.uuid4())
 
             chunk_metadata = {
                 **metadata,
                 "document_id": doc_id,
@@ -1193,7 +1300,7 @@ class RAGModule(BaseModule):
 
             points.append(PointStruct(
                 id=chunk_id,
-                vector=embedding,
+                vector=aligned_embedding,
                 payload=chunk_metadata
             ))
 
@@ -1261,8 +1368,9 @@ class RAGModule(BaseModule):
         # Create document points with enhanced metadata
         points = []
         for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
+            aligned_embedding = self._align_embedding_dimension(embedding, collection_name)
             chunk_id = str(uuid.uuid4())
 
             chunk_metadata = {
                 **processed_doc.metadata,
                 "document_id": processed_doc.id,
@@ -1284,7 +1392,7 @@ class RAGModule(BaseModule):
 
             points.append(PointStruct(
                 id=chunk_id,
-                vector=embedding,
+                vector=aligned_embedding,
                 payload=chunk_metadata
             ))
 
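Note: both ingestion paths apply the same alignment before building
PointStruct values, so every vector written matches the collection's
configured size. A sketch of the surrounding upsert flow (in-memory client,
collection name, and payload are illustrative; upsert itself is standard
qdrant-client API):

    import uuid

    from qdrant_client import QdrantClient
    from qdrant_client.models import Distance, PointStruct, VectorParams

    client = QdrantClient(":memory:")
    client.create_collection("demo", vectors_config=VectorParams(size=4, distance=Distance.COSINE))

    embedding = [0.1, 0.2, 0.3]                         # model produced 3 dims
    aligned = embedding + [0.0] * (4 - len(embedding))  # align to the collection's 4

    client.upsert("demo", points=[
        PointStruct(id=str(uuid.uuid4()), vector=aligned, payload={"chunk_index": 0}),
    ])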
@@ -1548,6 +1656,8 @@ class RAGModule(BaseModule):
                 logger.warning(f"sentence-transformers not available, falling back to default embedding for {collection_name}")
                 optimized_query = f"query: {query}"
                 query_embedding = await self._generate_embedding(optimized_query)
 
+        query_embedding = self._align_embedding_dimension(query_embedding, collection_name)
+
         # Build filter
         search_filter = None
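Note: the "query: " prefix follows the E5 family's convention of marking
queries and passages differently at embedding time; BGE models recommend a
different query instruction, so the prefix choice should track the configured
model. Aligning the query vector here keeps search working against
collections indexed before the model swap. A sketch of per-scope prefixing
(standalone function; prefix strings follow the E5 convention):

    def prefix_for_e5(text: str, is_document: bool) -> str:
        # E5-style models expect "passage: " for documents and "query: " for queries
        return ("passage: " if is_document else "query: ") + text

    print(prefix_for_e5("what is RAG?", is_document=False))  # "query: what is RAG?"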