swapping to local embeddings

2025-10-06 17:28:55 +02:00
parent bae86fb5a2
commit 487d7d0cc9
8 changed files with 398 additions and 319 deletions


@@ -1,148 +1,133 @@
"""
Embedding Service
Provides text embedding functionality using LLM service
Provides local sentence-transformer embeddings (default: BAAI/bge-small-en).
Falls back to deterministic random vectors when the local model is unavailable.
"""
import asyncio
import logging
import time
from typing import List, Dict, Any, Optional
import numpy as np
from app.core.config import settings
logger = logging.getLogger(__name__)
class EmbeddingService:
"""Service for generating text embeddings using LLM service"""
def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
self.model_name = model_name
self.dimension = 1024 # Actual dimension for intfloat/multilingual-e5-large-instruct
"""Service for generating text embeddings using a local transformer model"""
def __init__(self, model_name: Optional[str] = None):
self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
self.dimension = 384 # bge-small produces 384-d vectors
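# Provisional default; initialize() overwrites this with the loaded model's reported dimension.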
self.initialized = False
self.local_model = None
self.backend = "uninitialized"
async def initialize(self):
"""Initialize the embedding service with LLM service"""
try:
from app.services.llm.service import llm_service
# Initialize LLM service if not already done
if not llm_service._initialized:
await llm_service.initialize()
# Test LLM service health
if not llm_service._initialized:
logger.error("LLM service not initialized")
return False
from sentence_transformers import SentenceTransformer
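# Imported lazily inside initialize() so this module still imports when sentence-transformers
# is not installed; the ImportError handler below switches to the random fallback.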
# Check if PrivateMode provider is available
try:
provider_status = await llm_service.get_provider_status()
privatemode_status = provider_status.get("privatemode")
if not privatemode_status or privatemode_status.status != "healthy":
logger.error(f"PrivateMode provider not available: {privatemode_status}")
return False
except Exception as e:
logger.error(f"Failed to check provider status: {e}")
return False
loop = asyncio.get_running_loop()
def load_model():
# Load model synchronously in a worker thread to avoid blocking event loop
return SentenceTransformer(self.model_name)
self.local_model = await loop.run_in_executor(None, load_model)
self.dimension = self.local_model.get_sentence_embedding_dimension()
self.initialized = True
logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})")
self.backend = "sentence_transformer"
logger.info(
"Embedding service initialized with local model %s (dimension: %s)",
self.model_name,
self.dimension,
)
return True
except Exception as e:
logger.error(f"Failed to initialize LLM embedding service: {e}")
logger.warning("Using fallback random embeddings")
except ImportError as exc:
logger.error("sentence-transformers not installed: %s", exc)
logger.warning("Falling back to random embeddings")
self.local_model = None
self.initialized = False
self.backend = "fallback_random"
return False
except Exception as exc:
logger.error(f"Failed to load local embedding model {self.model_name}: {exc}")
logger.warning("Falling back to random embeddings")
self.local_model = None
self.initialized = False
self.backend = "fallback_random"
return False
async def get_embedding(self, text: str) -> List[float]:
"""Get embedding for a single text"""
embeddings = await self.get_embeddings([text])
return embeddings[0]
async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get embeddings for multiple texts using LLM service"""
if not self.initialized:
# Fallback to random embeddings if not initialized
logger.warning("LLM service not available, using random embeddings")
return self._generate_fallback_embeddings(texts)
try:
embeddings = []
# Process texts in batches for efficiency
batch_size = 10
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
# Process each text in the batch
batch_embeddings = []
for text in batch:
try:
# Truncate text if it's too long for the model's context window
# intfloat/multilingual-e5-large-instruct has a 512 token limit, truncate to ~400 tokens worth of chars
# Rough estimate: 1 token ≈ 4 characters, so 400 tokens ≈ 1600 chars
max_chars = 1600
if len(text) > max_chars:
truncated_text = text[:max_chars]
logger.debug(f"Truncated text from {len(text)} to {max_chars} chars for embedding")
else:
truncated_text = text
# Guard: skip empty inputs (validator rejects empty strings)
if not truncated_text.strip():
logger.debug("Empty input for embedding; using fallback vector")
batch_embeddings.append(self._generate_fallback_embedding(text))
continue
start_time = time.time()
# Call LLM service embedding endpoint
from app.services.llm.service import llm_service
from app.services.llm.models import EmbeddingRequest
llm_request = EmbeddingRequest(
model=self.model_name,
input=truncated_text,
user_id="rag_system",
api_key_id=0 # System API key
)
response = await llm_service.create_embedding(llm_request)
# Extract embedding from response
if response.data and len(response.data) > 0:
embedding = response.data[0].embedding
if embedding:
batch_embeddings.append(embedding)
# Update dimension based on actual embedding size
if not hasattr(self, '_dimension_confirmed'):
self.dimension = len(embedding)
self._dimension_confirmed = True
logger.info(f"Confirmed embedding dimension: {self.dimension}")
else:
logger.warning(f"No embedding in response for text: {text[:50]}...")
batch_embeddings.append(self._generate_fallback_embedding(text))
else:
logger.warning(f"Invalid response structure for text: {text[:50]}...")
batch_embeddings.append(self._generate_fallback_embedding(text))
except Exception as e:
logger.error(f"Error getting embedding for text: {e}")
batch_embeddings.append(self._generate_fallback_embedding(text))
embeddings.extend(batch_embeddings)
return embeddings
except Exception as e:
logger.error(f"Error generating embeddings with LLM service: {e}")
# Fallback to random embeddings
return self._generate_fallback_embeddings(texts)
start_time = time.time()  # needed for the duration reported by both the local and fallback paths
if self.local_model:
if not texts:
return []
loop = asyncio.get_running_loop()
try:
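# encode() is CPU-bound, so it runs in a worker thread to keep the event loop free.
# normalize_embeddings=True yields unit-length vectors, so dot product equals cosine similarity;
# tolist() below converts the NumPy result to plain Python floats for callers.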
embeddings = await loop.run_in_executor(
None,
lambda: self.local_model.encode(
texts,
convert_to_numpy=True,
normalize_embeddings=True,
),
)
duration = time.time() - start_time
logger.info(
"Embedding batch completed",
extra={
"backend": self.backend,
"model": self.model_name,
"count": len(texts),
"dimension": self.dimension,
"duration_sec": round(duration, 4),
},
)
return embeddings.tolist()
except Exception as exc:
logger.error(f"Local embedding generation failed: {exc}")
self.backend = "fallback_random"
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
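# Defensive path: reached only if the initialized flag and the model object get out of sync;
# serve random vectors rather than failing the caller.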
logger.warning("Local embedding model unavailable; using fallback random embeddings")
self.backend = "fallback_random"
return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
def _generate_fallback_embeddings(self, texts: List[str]) -> List[List[float]]:
def _generate_fallback_embeddings(self, texts: List[str], duration: Optional[float] = None) -> List[List[float]]:
"""Generate fallback random embeddings when model unavailable"""
embeddings = []
for text in texts:
embeddings.append(self._generate_fallback_embedding(text))
logger.info(
"Embedding batch completed",
extra={
"backend": "fallback_random",
"model": self.model_name,
"count": len(texts),
"dimension": self.dimension,
"duration_sec": round(duration or 0.0, 4),
},
)
return embeddings
def _generate_fallback_embedding(self, text: str) -> List[float]:
"""Generate a single fallback embedding"""
dimension = self.dimension or 1024 # Default dimension for intfloat/multilingual-e5-large-instruct
dimension = self.dimension or 384
# Use hash for reproducible random embeddings
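# Caveat: Python salts str hash() per process (PYTHONHASHSEED), so these vectors are only
# reproducible within a single run unless the hash seed is pinned; np.random.seed also
# mutates NumPy's global RNG state.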
np.random.seed(hash(text) % 2**32)
return np.random.random(dimension).tolist()
@@ -169,22 +154,15 @@ class EmbeddingService:
"model_name": self.model_name,
"model_loaded": self.initialized,
"dimension": self.dimension,
"backend": "LLM Service",
"backend": self.backend,
"initialized": self.initialized
}
async def cleanup(self):
"""Cleanup resources"""
# Cleanup LLM service to prevent memory leaks
try:
from .llm.service import llm_service
if llm_service._initialized:
await llm_service.cleanup()
logger.info("Cleaned up LLM service from embedding service")
except Exception as e:
logger.error(f"Error cleaning up LLM service: {e}")
self.local_model = None
self.initialized = False
self.backend = "uninitialized"
# Global embedding service instance
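# Minimal usage sketch (illustrative only; the variable names below are not part of this module):
#
#     service = EmbeddingService()
#     if await service.initialize():
#         vectors = await service.get_embeddings(["first chunk", "second chunk"])
#         # each vector has service.dimension floats (384 for BAAI/bge-small-en)
#     else:
#         vectors = await service.get_embeddings(["first chunk"])  # hash-seeded random fallback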