From 453b428d33aa573608d76d8a64ef448e5070d77a Mon Sep 17 00:00:00 2001 From: cs0lar Date: Wed, 12 Apr 2023 08:21:41 +0100 Subject: [PATCH] added support for weaviate embedded --- .env.template | 5 +++-- requirements.txt | 2 +- scripts/config.py | 3 ++- scripts/memory/weaviate.py | 2 +- tests/test_weaviate_memory.py | 22 ++++++++++++++++++---- 5 files changed, 25 insertions(+), 9 deletions(-) diff --git a/.env.template b/.env.template index 9e4a1a9f..b9869fb9 100644 --- a/.env.template +++ b/.env.template @@ -10,10 +10,11 @@ USE_AZURE=False OPENAI_API_BASE=your-base-url-for-azure OPENAI_API_VERSION=api-version-for-azure OPENAI_DEPLOYMENT_ID=deployment-id-for-azure -WEAVIATE_HOST="http://127.0.0.1" +WEAVIATE_HOST="127.0.0.1" WEAVIATE_PORT=8080 +WEAVIATE_PROTOCOL="http" USE_WEAVIATE_EMBEDDED=False -WEAVIATE_EMBEDDED_PATH="~/.local/share/weaviate" +WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate" WEAVIATE_USERNAME= WEAVIATE_PASSWORD= IMAGE_PROVIDER=dalle diff --git a/requirements.txt b/requirements.txt index 8004319d..d86ebe97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,4 @@ pinecone-client==2.2.1 redis orjson Pillow -weaviate-client==3.15.4 +weaviate-client==3.15.5 diff --git a/scripts/config.py b/scripts/config.py index 7608a504..ba43dca4 100644 --- a/scripts/config.py +++ b/scripts/config.py @@ -66,11 +66,12 @@ class Config(metaclass=Singleton): self.weaviate_host = os.getenv("WEAVIATE_HOST") self.weaviate_port = os.getenv("WEAVIATE_PORT") + self.weaviate_protocol = os.getenv("WEAVIATE_PROTOCOL", "http") self.weaviate_username = os.getenv("WEAVIATE_USERNAME", None) self.weaviate_password = os.getenv("WEAVIATE_PASSWORD", None) self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None) self.weaviate_embedded_path = os.getenv('WEAVIATE_EMBEDDED_PATH', '~/.local/share/weaviate') - self.use_weaviate_embedded = os.getenv("USE_WEAVIATE_EMBEDDED", False) + self.use_weaviate_embedded = os.getenv("USE_WEAVIATE_EMBEDDED", "False") == "True" self.image_provider = os.getenv("IMAGE_PROVIDER") self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN") diff --git a/scripts/memory/weaviate.py b/scripts/memory/weaviate.py index 10a76640..2eac5839 100644 --- a/scripts/memory/weaviate.py +++ b/scripts/memory/weaviate.py @@ -22,7 +22,7 @@ class WeaviateMemory(MemoryProviderSingleton): def __init__(self, cfg): auth_credentials = self._build_auth_credentials(cfg) - url = f'{cfg.weaviate_host}:{cfg.weaviate_port}' + url = f'{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}' if cfg.use_weaviate_embedded: self.client = Client(embedded_options=EmbeddedOptions( diff --git a/tests/test_weaviate_memory.py b/tests/test_weaviate_memory.py index 41709bc0..6f39a203 100644 --- a/tests/test_weaviate_memory.py +++ b/tests/test_weaviate_memory.py @@ -13,10 +13,11 @@ from memory.weaviate import WeaviateMemory from memory.base import get_ada_embedding @mock.patch.dict(os.environ, { - "WEAVIATE_HOST": "http://127.0.0.1", + "WEAVIATE_HOST": "127.0.0.1", + "WEAVIATE_PROTOCOL": "http", "WEAVIATE_PORT": "8080", - "WEAVIATE_USERNAME": '', - "WEAVIATE_PASSWORD": '', + "WEAVIATE_USERNAME": "", + "WEAVIATE_PASSWORD": "", "MEMORY_INDEX": "AutogptTests" }) class TestWeaviateMemory(unittest.TestCase): @@ -24,11 +25,24 @@ class TestWeaviateMemory(unittest.TestCase): In order to run these tests you will need a local instance of Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose for creating local instances using docker. + Alternatively in your .env file set the following environmental variables to run Weaviate embedded (see: https://weaviate.io/developers/weaviate/installation/embedded): + + USE_WEAVIATE_EMBEDDED=True + WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate" """ def setUp(self): self.cfg = Config() - self.client = Client('http://127.0.0.1:8080') + if self.cfg.use_weaviate_embedded: + from weaviate.embedded import EmbeddedOptions + + self.client = Client(embedded_options=EmbeddedOptions( + hostname=self.cfg.weaviate_host, + port=int(self.cfg.weaviate_port), + persistence_data_path=self.cfg.weaviate_embedded_path + )) + else: + self.client = Client(f"{self.cfg.weaviate_protocol}://{self.cfg.weaviate_host}:{self.cfg.weaviate_port}") try: self.client.schema.delete_class(self.cfg.memory_index)