diff --git a/.env.template b/.env.template index bc3bea20..b9869fb9 100644 --- a/.env.template +++ b/.env.template @@ -10,8 +10,11 @@ USE_AZURE=False OPENAI_API_BASE=your-base-url-for-azure OPENAI_API_VERSION=api-version-for-azure OPENAI_DEPLOYMENT_ID=deployment-id-for-azure -WEAVIATE_HOST="http://127.0.0.1" -WEAVIATE_PORT="8080" +WEAVIATE_HOST="127.0.0.1" +WEAVIATE_PORT=8080 +WEAVIATE_PROTOCOL="http" +USE_WEAVIATE_EMBEDDED=False +WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate" WEAVIATE_USERNAME= WEAVIATE_PASSWORD= IMAGE_PROVIDER=dalle diff --git a/requirements.txt b/requirements.txt index 8004319d..d86ebe97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,4 @@ pinecone-client==2.2.1 redis orjson Pillow -weaviate-client==3.15.4 +weaviate-client==3.15.5 diff --git a/scripts/config.py b/scripts/config.py index 2a0c0a94..ba43dca4 100644 --- a/scripts/config.py +++ b/scripts/config.py @@ -66,9 +66,12 @@ class Config(metaclass=Singleton): self.weaviate_host = os.getenv("WEAVIATE_HOST") self.weaviate_port = os.getenv("WEAVIATE_PORT") + self.weaviate_protocol = os.getenv("WEAVIATE_PROTOCOL", "http") self.weaviate_username = os.getenv("WEAVIATE_USERNAME", None) self.weaviate_password = os.getenv("WEAVIATE_PASSWORD", None) self.weaviate_scopes = os.getenv("WEAVIATE_SCOPES", None) + self.weaviate_embedded_path = os.getenv('WEAVIATE_EMBEDDED_PATH', '~/.local/share/weaviate') + self.use_weaviate_embedded = os.getenv("USE_WEAVIATE_EMBEDDED", "False") == "True" self.image_provider = os.getenv("IMAGE_PROVIDER") self.huggingface_api_token = os.getenv("HUGGINGFACE_API_TOKEN") diff --git a/scripts/memory/weaviate.py b/scripts/memory/weaviate.py index d29192ea..2eac5839 100644 --- a/scripts/memory/weaviate.py +++ b/scripts/memory/weaviate.py @@ -3,6 +3,7 @@ from memory.base import MemoryProviderSingleton, get_ada_embedding import uuid import weaviate from weaviate import Client +from weaviate.embedded import EmbeddedOptions from weaviate.util import generate_uuid5 def default_schema(weaviate_index): @@ -21,9 +22,19 @@ class WeaviateMemory(MemoryProviderSingleton): def __init__(self, cfg): auth_credentials = self._build_auth_credentials(cfg) - url = f'{cfg.weaviate_host}:{cfg.weaviate_port}' + url = f'{cfg.weaviate_protocol}://{cfg.weaviate_host}:{cfg.weaviate_port}' + + if cfg.use_weaviate_embedded: + self.client = Client(embedded_options=EmbeddedOptions( + hostname=cfg.weaviate_host, + port=int(cfg.weaviate_port), + persistence_data_path=cfg.weaviate_embedded_path + )) + + print(f"Weaviate Embedded running on: {url} with persistence path: {cfg.weaviate_embedded_path}") + else: + self.client = Client(url, auth_client_secret=auth_credentials) - self.client = Client(url, auth_client_secret=auth_credentials) self.index = cfg.memory_index self._create_schema() @@ -59,7 +70,6 @@ class WeaviateMemory(MemoryProviderSingleton): return f"Inserting data into memory at uuid: {doc_uuid}:\n data: {data}" - def get(self, data): return self.get_relevant(data, 1) diff --git a/tests/test_weaviate_memory.py b/tests/test_weaviate_memory.py index 41709bc0..6f39a203 100644 --- a/tests/test_weaviate_memory.py +++ b/tests/test_weaviate_memory.py @@ -13,10 +13,11 @@ from memory.weaviate import WeaviateMemory from memory.base import get_ada_embedding @mock.patch.dict(os.environ, { - "WEAVIATE_HOST": "http://127.0.0.1", + "WEAVIATE_HOST": "127.0.0.1", + "WEAVIATE_PROTOCOL": "http", "WEAVIATE_PORT": "8080", - "WEAVIATE_USERNAME": '', - "WEAVIATE_PASSWORD": '', + "WEAVIATE_USERNAME": "", + "WEAVIATE_PASSWORD": "", "MEMORY_INDEX": "AutogptTests" }) class TestWeaviateMemory(unittest.TestCase): @@ -24,11 +25,24 @@ class TestWeaviateMemory(unittest.TestCase): In order to run these tests you will need a local instance of Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose for creating local instances using docker. + Alternatively in your .env file set the following environmental variables to run Weaviate embedded (see: https://weaviate.io/developers/weaviate/installation/embedded): + + USE_WEAVIATE_EMBEDDED=True + WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate" """ def setUp(self): self.cfg = Config() - self.client = Client('http://127.0.0.1:8080') + if self.cfg.use_weaviate_embedded: + from weaviate.embedded import EmbeddedOptions + + self.client = Client(embedded_options=EmbeddedOptions( + hostname=self.cfg.weaviate_host, + port=int(self.cfg.weaviate_port), + persistence_data_path=self.cfg.weaviate_embedded_path + )) + else: + self.client = Client(f"{self.cfg.weaviate_protocol}://{self.cfg.weaviate_host}:{self.cfg.weaviate_port}") try: self.client.schema.delete_class(self.cfg.memory_index)