mirror of https://github.com/aljazceru/enclava.git
swapping to local embeddings
@@ -1,148 +1,133 @@
 """
 Embedding Service
 
-Provides text embedding functionality using LLM service
+Provides local sentence-transformer embeddings (default: BAAI/bge-small-en).
+Falls back to deterministic random vectors when the local model is unavailable.
 """
 
 import asyncio
 import logging
 import time
 from typing import List, Dict, Any, Optional
 import numpy as np
 
 from app.core.config import settings
 
 logger = logging.getLogger(__name__)
 
 
 class EmbeddingService:
-    """Service for generating text embeddings using LLM service"""
-
-    def __init__(self, model_name: str = "intfloat/multilingual-e5-large-instruct"):
-        self.model_name = model_name
-        self.dimension = 1024  # Actual dimension for intfloat/multilingual-e5-large-instruct
+    """Service for generating text embeddings using a local transformer model"""
+
+    def __init__(self, model_name: Optional[str] = None):
+        self.model_name = model_name or getattr(settings, "RAG_EMBEDDING_MODEL", "BAAI/bge-small-en")
+        self.dimension = 384  # bge-small produces 384-d vectors
         self.initialized = False
+        self.local_model = None
+        self.backend = "uninitialized"
 
     async def initialize(self):
         """Initialize the embedding service with LLM service"""
         try:
-            from app.services.llm.service import llm_service
-
-            # Initialize LLM service if not already done
-            if not llm_service._initialized:
-                await llm_service.initialize()
-
-            # Test LLM service health
-            if not llm_service._initialized:
-                logger.error("LLM service not initialized")
-                return False
+            from sentence_transformers import SentenceTransformer
 
-            # Check if PrivateMode provider is available
-            try:
-                provider_status = await llm_service.get_provider_status()
-                privatemode_status = provider_status.get("privatemode")
-                if not privatemode_status or privatemode_status.status != "healthy":
-                    logger.error(f"PrivateMode provider not available: {privatemode_status}")
-                    return False
-            except Exception as e:
-                logger.error(f"Failed to check provider status: {e}")
-                return False
+            loop = asyncio.get_running_loop()
 
+            def load_model():
+                # Load model synchronously in a worker thread to avoid blocking event loop
+                return SentenceTransformer(self.model_name)
+
+            self.local_model = await loop.run_in_executor(None, load_model)
+            self.dimension = self.local_model.get_sentence_embedding_dimension()
             self.initialized = True
-            logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})")
+            self.backend = "sentence_transformer"
+            logger.info(
+                "Embedding service initialized with local model %s (dimension: %s)",
+                self.model_name,
+                self.dimension,
+            )
             return True
 
-        except Exception as e:
-            logger.error(f"Failed to initialize LLM embedding service: {e}")
-            logger.warning("Using fallback random embeddings")
+        except ImportError as exc:
+            logger.error("sentence-transformers not installed: %s", exc)
+            logger.warning("Falling back to random embeddings")
+            self.local_model = None
+            self.initialized = False
+            self.backend = "fallback_random"
+            return False
 
+        except Exception as exc:
+            logger.error(f"Failed to load local embedding model {self.model_name}: {exc}")
+            logger.warning("Falling back to random embeddings")
+            self.local_model = None
+            self.initialized = False
+            self.backend = "fallback_random"
+            return False
 
     async def get_embedding(self, text: str) -> List[float]:
         """Get embedding for a single text"""
         embeddings = await self.get_embeddings([text])
         return embeddings[0]
 
     async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get embeddings for multiple texts using LLM service"""
-        if not self.initialized:
-            # Fallback to random embeddings if not initialized
-            logger.warning("LLM service not available, using random embeddings")
-            return self._generate_fallback_embeddings(texts)
-
-        try:
-            embeddings = []
-            # Process texts in batches for efficiency
-            batch_size = 10
-            for i in range(0, len(texts), batch_size):
-                batch = texts[i:i+batch_size]
-
-                # Process each text in the batch
-                batch_embeddings = []
-                for text in batch:
-                    try:
-                        # Truncate text if it's too long for the model's context window
-                        # intfloat/multilingual-e5-large-instruct has a 512 token limit, truncate to ~400 tokens worth of chars
-                        # Rough estimate: 1 token ≈ 4 characters, so 400 tokens ≈ 1600 chars
-                        max_chars = 1600
-                        if len(text) > max_chars:
-                            truncated_text = text[:max_chars]
-                            logger.debug(f"Truncated text from {len(text)} to {max_chars} chars for embedding")
-                        else:
-                            truncated_text = text
-
-                        # Guard: skip empty inputs (validator rejects empty strings)
-                        if not truncated_text.strip():
-                            logger.debug("Empty input for embedding; using fallback vector")
-                            batch_embeddings.append(self._generate_fallback_embedding(text))
-                            continue
+        start_time = time.time()
-
-                        # Call LLM service embedding endpoint
-                        from app.services.llm.service import llm_service
-                        from app.services.llm.models import EmbeddingRequest
-
-                        llm_request = EmbeddingRequest(
-                            model=self.model_name,
-                            input=truncated_text,
-                            user_id="rag_system",
-                            api_key_id=0  # System API key
-                        )
-
-                        response = await llm_service.create_embedding(llm_request)
-
-                        # Extract embedding from response
-                        if response.data and len(response.data) > 0:
-                            embedding = response.data[0].embedding
-                            if embedding:
-                                batch_embeddings.append(embedding)
-                                # Update dimension based on actual embedding size
-                                if not hasattr(self, '_dimension_confirmed'):
-                                    self.dimension = len(embedding)
-                                    self._dimension_confirmed = True
-                                    logger.info(f"Confirmed embedding dimension: {self.dimension}")
-                            else:
-                                logger.warning(f"No embedding in response for text: {text[:50]}...")
-                                batch_embeddings.append(self._generate_fallback_embedding(text))
-                        else:
-                            logger.warning(f"Invalid response structure for text: {text[:50]}...")
-                            batch_embeddings.append(self._generate_fallback_embedding(text))
-                    except Exception as e:
-                        logger.error(f"Error getting embedding for text: {e}")
-                        batch_embeddings.append(self._generate_fallback_embedding(text))
-
-                embeddings.extend(batch_embeddings)
-
-            return embeddings
-
-        except Exception as e:
-            logger.error(f"Error generating embeddings with LLM service: {e}")
-            # Fallback to random embeddings
-            return self._generate_fallback_embeddings(texts)
+        if self.local_model:
+            if not texts:
+                return []
+
+            loop = asyncio.get_running_loop()
+
+            try:
+                embeddings = await loop.run_in_executor(
+                    None,
+                    lambda: self.local_model.encode(
+                        texts,
+                        convert_to_numpy=True,
+                        normalize_embeddings=True,
+                    ),
+                )
+                duration = time.time() - start_time
+                logger.info(
+                    "Embedding batch completed",
+                    extra={
+                        "backend": self.backend,
+                        "model": self.model_name,
+                        "count": len(texts),
+                        "dimension": self.dimension,
+                        "duration_sec": round(duration, 4),
+                    },
+                )
+                return embeddings.tolist()
+            except Exception as exc:
+                logger.error(f"Local embedding generation failed: {exc}")
+                self.backend = "fallback_random"
+                return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
+
+        logger.warning("Local embedding model unavailable; using fallback random embeddings")
+        self.backend = "fallback_random"
+        return self._generate_fallback_embeddings(texts, duration=time.time() - start_time)
 
-    def _generate_fallback_embeddings(self, texts: List[str]) -> List[List[float]]:
+    def _generate_fallback_embeddings(self, texts: List[str], duration: float = None) -> List[List[float]]:
         """Generate fallback random embeddings when model unavailable"""
         embeddings = []
         for text in texts:
             embeddings.append(self._generate_fallback_embedding(text))
+        logger.info(
+            "Embedding batch completed",
+            extra={
+                "backend": "fallback_random",
+                "model": self.model_name,
+                "count": len(texts),
+                "dimension": self.dimension,
+                "duration_sec": round(duration or 0.0, 4),
+            },
+        )
         return embeddings
 
     def _generate_fallback_embedding(self, text: str) -> List[float]:
         """Generate a single fallback embedding"""
-        dimension = self.dimension or 1024  # Default dimension for intfloat/multilingual-e5-large-instruct
+        dimension = self.dimension or 384
         # Use hash for reproducible random embeddings
         np.random.seed(hash(text) % 2**32)
         return np.random.random(dimension).tolist()
@@ -169,22 +154,15 @@ class EmbeddingService:
             "model_name": self.model_name,
             "model_loaded": self.initialized,
             "dimension": self.dimension,
-            "backend": "LLM Service",
+            "backend": self.backend,
             "initialized": self.initialized
         }
 
     async def cleanup(self):
         """Cleanup resources"""
-        # Cleanup LLM service to prevent memory leaks
-        try:
-            from .llm.service import llm_service
-            if llm_service._initialized:
-                await llm_service.cleanup()
-                logger.info("Cleaned up LLM service from embedding service")
-        except Exception as e:
-            logger.error(f"Error cleaning up LLM service: {e}")
-
+        self.local_model = None
+        self.initialized = False
+        self.backend = "uninitialized"
 
 
 # Global embedding service instance
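
A note on the fallback path above: Python's built-in hash() for strings is salted per interpreter run (unless PYTHONHASHSEED is pinned), so np.random.seed(hash(text) % 2**32) gives vectors that are reproducible only within a single process, not across restarts. A minimal sketch of a process-stable variant, using a content hash in place of hash(); the function name here is illustrative and not part of the commit:

    # Sketch only: process-stable fallback vectors via a content hash.
    import hashlib

    import numpy as np


    def stable_fallback_embedding(text: str, dimension: int = 384) -> list:
        # sha256 of the text yields the same bytes on every run, unlike hash()
        seed = int.from_bytes(hashlib.sha256(text.encode("utf-8")).digest()[:4], "big")
        rng = np.random.default_rng(seed)  # local generator, no global seed mutation
        return rng.random(dimension).tolist()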
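
For reference, a minimal sketch of driving the new local backend end to end. The import path is an assumption (the diff shows only the class, and the global instance line is cut off above); with sentence-transformers installed, initialize() loads the model off the event loop, and without it the service degrades to the random fallback instead of raising:

    # Sketch only, not from the commit.
    import asyncio

    from app.services.embedding import EmbeddingService  # assumed module path


    async def main() -> None:
        # Model name comes from settings.RAG_EMBEDDING_MODEL, else BAAI/bge-small-en
        service = EmbeddingService()
        await service.initialize()  # returns False and sets backend to "fallback_random" on failure
        vectors = await service.get_embeddings(["hello world", "local embeddings"])
        print(service.backend, len(vectors), len(vectors[0]))  # e.g. sentence_transformer 2 384
        await service.cleanup()


    asyncio.run(main())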