added support for weaviate embedded

This commit is contained in:
cs0lar
2023-04-12 05:40:24 +01:00
parent 3c7767fab0
commit 96c5e929be
3 changed files with 17 additions and 3 deletions

View File

@@ -11,7 +11,9 @@ OPENAI_API_BASE=your-base-url-for-azure
OPENAI_API_VERSION=api-version-for-azure
OPENAI_DEPLOYMENT_ID=deployment-id-for-azure
WEAVIATE_HOST="http://127.0.0.1"
WEAVIATE_PORT="8080"
WEAVIATE_PORT=8080
USE_WEAVIATE_EMBEDDED=False
WEAVIATE_EMBEDDED_PATH="~/.local/share/weaviate"
WEAVIATE_USERNAME=
WEAVIATE_PASSWORD=
IMAGE_PROVIDER=dalle

View File

@@ -69,6 +69,8 @@ class Config(metaclass=Singleton):
self.weaviate_username = os.getenv("WEAVIATE_USERNAME", None)
self.weaviate_password = os.getenv("WEAVIATE_PASSWORD", None)
self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None)
self.weaviate_embedded_path = os.getenv('WEAVIATE_EMBEDDED_PATH', '~/.local/share/weaviate')
self.use_weaviate_embedded = os.getenv("USE_WEAVIATE_EMBEDDED", False)
self.image_provider = os.getenv("IMAGE_PROVIDER")
self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN")

View File

@@ -3,6 +3,7 @@ from memory.base import MemoryProviderSingleton, get_ada_embedding
import uuid
import weaviate
from weaviate import Client
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5
def default_schema(weaviate_index):
@@ -23,7 +24,17 @@ class WeaviateMemory(MemoryProviderSingleton):
url = f'{cfg.weaviate_host}:{cfg.weaviate_port}'
self.client = Client(url, auth_client_secret=auth_credentials)
if cfg.use_weaviate_embedded:
self.client = Client(embedded_options=EmbeddedOptions(
hostname=cfg.weaviate_host,
port=int(cfg.weaviate_port),
persistence_data_path=cfg.weaviate_embedded_path
))
print(f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}")
else:
self.client = Client(url, auth_client_secret=auth_credentials)
self.index = cfg.memory_index
self._create_schema()
@@ -59,7 +70,6 @@ class WeaviateMemory(MemoryProviderSingleton):
return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}"
def get(self, data):
return self.get_relevant(data, 1)