mirror of https://github.com/aljazceru/enclava.git
synced 2025-12-17 07:24:34 +01:00
removing LiteLLM and going directly to privatemode
@@ -1,6 +1,6 @@
 """
 Embedding Service
-Provides text embedding functionality using LiteLLM proxy
+Provides text embedding functionality using LLM service
 """
 
 import logging
@@ -11,32 +11,34 @@ logger = logging.getLogger(__name__)
 
 
 class EmbeddingService:
-    """Service for generating text embeddings using LiteLLM"""
+    """Service for generating text embeddings using LLM service"""
 
     def __init__(self, model_name: str = "privatemode-embeddings"):
         self.model_name = model_name
-        self.litellm_client = None
         self.dimension = 1024  # Actual dimension for privatemode-embeddings
         self.initialized = False
 
     async def initialize(self):
-        """Initialize the embedding service with LiteLLM"""
+        """Initialize the embedding service with LLM service"""
         try:
-            from app.services.litellm_client import litellm_client
-            self.litellm_client = litellm_client
+            from app.services.llm.service import llm_service
 
-            # Test connection to LiteLLM
-            health = await self.litellm_client.health_check()
-            if health.get("status") == "unhealthy":
-                logger.error(f"LiteLLM service unhealthy: {health.get('error')}")
+            # Initialize LLM service if not already done
+            if not llm_service._initialized:
+                await llm_service.initialize()
+
+            # Test LLM service health
+            health_summary = llm_service.get_health_summary()
+            if health_summary.get("service_status") != "healthy":
+                logger.error(f"LLM service unhealthy: {health_summary}")
                 return False
 
             self.initialized = True
-            logger.info(f"Embedding service initialized with LiteLLM: {self.model_name} (dimension: {self.dimension})")
+            logger.info(f"Embedding service initialized with LLM service: {self.model_name} (dimension: {self.dimension})")
             return True
 
         except Exception as e:
-            logger.error(f"Failed to initialize LiteLLM embedding service: {e}")
+            logger.error(f"Failed to initialize LLM embedding service: {e}")
             logger.warning("Using fallback random embeddings")
             return False
 
@@ -46,10 +48,10 @@ class EmbeddingService:
         return embeddings[0]
 
     async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """Get embeddings for multiple texts using LiteLLM"""
-        if not self.initialized or not self.litellm_client:
+        """Get embeddings for multiple texts using LLM service"""
+        if not self.initialized:
             # Fallback to random embeddings if not initialized
-            logger.warning("LiteLLM not available, using random embeddings")
+            logger.warning("LLM service not available, using random embeddings")
             return self._generate_fallback_embeddings(texts)
 
         try:
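Note: both the gate above and the exception handler further down fall back to `_generate_fallback_embeddings`, which this commit never shows. A minimal sketch of what such a fallback could look like, assuming the goal is a deterministic pseudo-random unit vector per text; the hashing scheme and normalization here are illustrative, not the repo's actual implementation:

import hashlib
import random
from typing import List

def _generate_fallback_embeddings_sketch(texts: List[str], dimension: int = 1024) -> List[List[float]]:
    """Assumed stand-in for the helper referenced in the diff."""
    embeddings = []
    for text in texts:
        # Seed per text so the same input always maps to the same vector.
        seed = int.from_bytes(hashlib.sha256(text.encode("utf-8")).digest()[:8], "big")
        rng = random.Random(seed)
        vector = [rng.uniform(-1.0, 1.0) for _ in range(dimension)]
        # Normalize to unit length so cosine similarity stays well-behaved.
        norm = sum(v * v for v in vector) ** 0.5 or 1.0
        embeddings.append([v / norm for v in vector])
    return embeddings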
@@ -73,17 +75,22 @@ class EmbeddingService:
                 else:
                     truncated_text = text
 
-                # Call LiteLLM embedding endpoint
-                response = await self.litellm_client.create_embedding(
+                # Call LLM service embedding endpoint
+                from app.services.llm.service import llm_service
+                from app.services.llm.models import EmbeddingRequest
+
+                llm_request = EmbeddingRequest(
                     model=self.model_name,
-                    input_text=truncated_text,
+                    input=truncated_text,
+                    user_id="rag_system",
+                    api_key_id=0  # System API key
                 )
 
+                response = await llm_service.create_embedding(llm_request)
+
                 # Extract embedding from response
-                if "data" in response and len(response["data"]) > 0:
-                    embedding = response["data"][0].get("embedding", [])
+                if response.data and len(response.data) > 0:
+                    embedding = response.data[0].embedding
                 if embedding:
                     batch_embeddings.append(embedding)
                     # Update dimension based on actual embedding size
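The hunk above also switches the response from a plain dict (`response["data"][0]`) to attribute access (`response.data[0].embedding`). A sketch of the response shapes that access pattern implies; the real classes live in `app.services.llm.models`, which this commit does not show, so these dataclasses are assumptions for illustration:

from dataclasses import dataclass, field
from typing import List

@dataclass
class EmbeddingData:
    embedding: List[float]  # one vector per input
    index: int = 0

@dataclass
class EmbeddingResponse:
    data: List[EmbeddingData] = field(default_factory=list)

# With these shapes, the diff's extraction logic reads end to end:
response = EmbeddingResponse(data=[EmbeddingData(embedding=[0.1] * 1024)])
if response.data and len(response.data) > 0:
    embedding = response.data[0].embedding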
@@ -106,7 +113,7 @@ class EmbeddingService:
             return embeddings
 
         except Exception as e:
-            logger.error(f"Error generating embeddings with LiteLLM: {e}")
+            logger.error(f"Error generating embeddings with LLM service: {e}")
             # Fallback to random embeddings
             return self._generate_fallback_embeddings(texts)
 
@@ -146,14 +153,13 @@ class EmbeddingService:
             "model_name": self.model_name,
             "model_loaded": self.initialized,
             "dimension": self.dimension,
-            "backend": "LiteLLM",
+            "backend": "LLM Service",
             "initialized": self.initialized
         }
 
     async def cleanup(self):
         """Cleanup resources"""
         self.initialized = False
-        self.litellm_client = None
 
 
 # Global embedding service instance
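
The diff truncates right after the trailing comment, so the module-level instance itself is not shown. A hedged end-to-end usage sketch, assuming the module exposes a global `embedding_service = EmbeddingService()` under `app/services/embedding_service.py`; both the path and the instance name are inferences, not part of this commit:

import asyncio

from app.services.embedding_service import embedding_service  # assumed path and name

async def main() -> None:
    ok = await embedding_service.initialize()
    if not ok:
        # Per the diff, initialization failure does not raise; the service
        # quietly serves fallback random embeddings instead.
        print("LLM service unhealthy; falling back to random embeddings")

    vectors = await embedding_service.get_embeddings(["hello world", "a second document"])
    print(len(vectors), len(vectors[0]))  # expect 2 vectors of dimension 1024

asyncio.run(main())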