diff --git a/autogpt/memory/__init__.py b/autogpt/memory/__init__.py index d56b2de2..7b545ea3 100644 --- a/autogpt/memory/__init__.py +++ b/autogpt/memory/__init__.py @@ -27,6 +27,7 @@ except ImportError: print("Weaviate not installed. Skipping import.") WeaviateMemory = None + def get_memory(cfg, init=False): memory = None if cfg.memory_backend == "pinecone": @@ -53,7 +54,7 @@ def get_memory(cfg, init=False): " use Weaviate as a memory backend.") else: memory = WeaviateMemory(cfg) - + elif cfg.memory_backend == "no_memory": memory = NoMemory(cfg) diff --git a/autogpt/memory/base.py b/autogpt/memory/base.py index 784483fa..691e2299 100644 --- a/autogpt/memory/base.py +++ b/autogpt/memory/base.py @@ -7,6 +7,7 @@ from autogpt.config import AbstractSingleton, Config cfg = Config() + def get_ada_embedding(text): text = text.replace("\n", " ") if cfg.use_azure: diff --git a/autogpt/memory/weaviate.py b/autogpt/memory/weaviate.py index fdac8e85..48816973 100644 --- a/autogpt/memory/weaviate.py +++ b/autogpt/memory/weaviate.py @@ -6,6 +6,7 @@ from weaviate import Client from weaviate.embedded import EmbeddedOptions from weaviate.util import generate_uuid5 + def default_schema(weaviate_index): return { "class": weaviate_index, @@ -18,6 +19,7 @@ def default_schema(weaviate_index): ], } + class WeaviateMemory(MemoryProviderSingleton): def __init__(self, cfg): auth_credentials = self._build_auth_credentials(cfg) @@ -72,12 +74,11 @@ class WeaviateMemory(MemoryProviderSingleton): def get(self, data): return self.get_relevant(data, 1) - def clear(self): self.client.schema.delete_all() # weaviate does not yet have a neat way to just remove the items in an index - # without removing the entire schema, therefore we need to re-create it + # without removing the entire schema, therefore we need to re-create it # after a call to delete_all self._create_schema() diff --git a/tests/integration/weaviate_memory_tests.py b/tests/integration/weaviate_memory_tests.py index fa456c8a..503fe9d2 100644 --- a/tests/integration/weaviate_memory_tests.py +++ b/tests/integration/weaviate_memory_tests.py @@ -11,6 +11,7 @@ from autogpt.config import Config from autogpt.memory.weaviate import WeaviateMemory from autogpt.memory.base import get_ada_embedding + @mock.patch.dict(os.environ, { "WEAVIATE_HOST": "127.0.0.1", "WEAVIATE_PROTOCOL": "http", @@ -38,13 +39,13 @@ class TestWeaviateMemory(unittest.TestCase): )) else: cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}") - + """ In order to run these tests you will need a local instance of Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose for creating local instances using docker. Alternatively in your .env file set the following environmental variables to run Weaviate embedded (see: https://weaviate.io/developers/weaviate/installation/embedded): - + USE_WEAVIATE_EMBEDDED=True WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate" """ @@ -53,7 +54,7 @@ class TestWeaviateMemory(unittest.TestCase): self.client.schema.delete_class(self.cfg.memory_index) except: pass - + self.memory = WeaviateMemory(self.cfg) def test_add(self): @@ -67,7 +68,7 @@ class TestWeaviateMemory(unittest.TestCase): def test_get(self): doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos' - + with self.client.batch as batch: batch.add_data_object( uuid=get_valid_uuid(uuid4()), @@ -83,7 +84,6 @@ class TestWeaviateMemory(unittest.TestCase): self.assertEqual(len(actual), 1) self.assertEqual(actual[0], doc) - def test_get_stats(self): docs = [ 'You are now about to count the number of docs in this index', @@ -98,7 +98,6 @@ class TestWeaviateMemory(unittest.TestCase): self.assertTrue('count' in stats) self.assertEqual(stats['count'], 2) - def test_clear(self): docs = [ 'Shame this is the last test for this class',