mirror of
https://github.com/aljazceru/enclava.git
synced 2025-12-17 23:44:24 +01:00
mega changes
This commit is contained in:
@@ -15,7 +15,9 @@ logger = logging.getLogger(__name__)
|
||||
class OllamaEmbeddingService:
|
||||
"""Service for generating text embeddings using Ollama"""
|
||||
|
||||
def __init__(self, model_name: str = "bge-m3", base_url: str = "http://172.17.0.1:11434"):
|
||||
def __init__(
|
||||
self, model_name: str = "bge-m3", base_url: str = "http://172.17.0.1:11434"
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.base_url = base_url
|
||||
self.dimension = 1024 # bge-m3 dimension
|
||||
@@ -37,20 +39,28 @@ class OllamaEmbeddingService:
|
||||
return False
|
||||
|
||||
data = await resp.json()
|
||||
models = [model['name'].split(':')[0] for model in data.get('models', [])]
|
||||
models = [
|
||||
model["name"].split(":")[0] for model in data.get("models", [])
|
||||
]
|
||||
|
||||
if self.model_name not in models:
|
||||
logger.error(f"Model {self.model_name} not found in Ollama. Available: {models}")
|
||||
logger.error(
|
||||
f"Model {self.model_name} not found in Ollama. Available: {models}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Test embedding generation
|
||||
test_embedding = await self.get_embedding("test")
|
||||
if not test_embedding or len(test_embedding) != self.dimension:
|
||||
logger.error(f"Failed to generate test embedding with {self.model_name}")
|
||||
logger.error(
|
||||
f"Failed to generate test embedding with {self.model_name}"
|
||||
)
|
||||
return False
|
||||
|
||||
self.initialized = True
|
||||
logger.info(f"Ollama embedding service initialized with model: {self.model_name} (dimension: {self.dimension})")
|
||||
logger.info(
|
||||
f"Ollama embedding service initialized with model: {self.model_name} (dimension: {self.dimension})"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -85,27 +95,32 @@ class OllamaEmbeddingService:
|
||||
# Call Ollama embedding API
|
||||
async with self._session.post(
|
||||
f"{self.base_url}/api/embeddings",
|
||||
json={
|
||||
"model": self.model_name,
|
||||
"prompt": text
|
||||
}
|
||||
json={"model": self.model_name, "prompt": text},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"Ollama embedding request failed: {resp.status}")
|
||||
logger.error(
|
||||
f"Ollama embedding request failed: {resp.status}"
|
||||
)
|
||||
embeddings.append(self._generate_fallback_embedding(text))
|
||||
continue
|
||||
|
||||
result = await resp.json()
|
||||
|
||||
if 'embedding' in result:
|
||||
embedding = result['embedding']
|
||||
if "embedding" in result:
|
||||
embedding = result["embedding"]
|
||||
if len(embedding) == self.dimension:
|
||||
embeddings.append(embedding)
|
||||
else:
|
||||
logger.warning(f"Embedding dimension mismatch: expected {self.dimension}, got {len(embedding)}")
|
||||
embeddings.append(self._generate_fallback_embedding(text))
|
||||
logger.warning(
|
||||
f"Embedding dimension mismatch: expected {self.dimension}, got {len(embedding)}"
|
||||
)
|
||||
embeddings.append(
|
||||
self._generate_fallback_embedding(text)
|
||||
)
|
||||
else:
|
||||
logger.error(f"No embedding in Ollama response for text: {text[:50]}...")
|
||||
logger.error(
|
||||
f"No embedding in Ollama response for text: {text[:50]}..."
|
||||
)
|
||||
embeddings.append(self._generate_fallback_embedding(text))
|
||||
|
||||
except Exception as e:
|
||||
@@ -156,7 +171,7 @@ class OllamaEmbeddingService:
|
||||
"dimension": self.dimension,
|
||||
"backend": "Ollama",
|
||||
"base_url": self.base_url,
|
||||
"initialized": self.initialized
|
||||
"initialized": self.initialized,
|
||||
}
|
||||
|
||||
async def cleanup(self):
|
||||
@@ -167,4 +182,4 @@ class OllamaEmbeddingService:
|
||||
|
||||
|
||||
# Global Ollama embedding service instance
|
||||
ollama_embedding_service = OllamaEmbeddingService()
|
||||
ollama_embedding_service = OllamaEmbeddingService()
|
||||
|
||||
Reference in New Issue
Block a user