diff --git a/autogpt/memory/__init__.py b/autogpt/memory/__init__.py index 670efda1..102b93ae 100644 --- a/autogpt/memory/__init__.py +++ b/autogpt/memory/__init__.py @@ -22,7 +22,7 @@ except ImportError: PineconeMemory = None try: - from memory.milvus import MilvusMemory + from autogpt.memory.milvus import MilvusMemory except ImportError: print("pymilvus not installed. Skipping import.") MilvusMemory = None @@ -48,14 +48,14 @@ def get_memory(cfg, init=False): ) else: memory = RedisMemory(cfg) - elif cfg.memory_backend == "no_memory": - memory = NoMemory(cfg) elif cfg.memory_backend == "milvus": if not MilvusMemory: print("Error: Milvus sdk is not installed." "Please install pymilvus to use Milvus as memory backend.") else: memory = MilvusMemory(cfg) + elif cfg.memory_backend == "no_memory": + memory = NoMemory(cfg) if memory is None: memory = LocalCache(cfg) diff --git a/scripts/memory/milvus.py b/autogpt/memory/milvus.py similarity index 78% rename from scripts/memory/milvus.py rename to autogpt/memory/milvus.py index c6d31750..fce46a89 100644 --- a/scripts/memory/milvus.py +++ b/autogpt/memory/milvus.py @@ -6,7 +6,7 @@ from pymilvus import ( Collection, ) -from memory.base import MemoryProviderSingleton, get_ada_embedding +from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding class MilvusMemory(MemoryProviderSingleton): @@ -28,15 +28,16 @@ class MilvusMemory(MemoryProviderSingleton): ] # create collection if not exist and load it. - schema = CollectionSchema(fields, "auto-gpt memory storage") - self.collection = Collection(cfg.milvus_collection, schema) + self.milvus_collection = cfg.milvus_collection + self.schema = CollectionSchema(fields, "auto-gpt memory storage") + self.collection = Collection(self.milvus_collection, self.schema) # create index if not exist. - if not self.collection.has_index(index_name="embeddings"): + if not self.collection.has_index(): self.collection.release() self.collection.create_index("embeddings", { - "index_type": "IVF_FLAT", "metric_type": "IP", - "params": {"nlist": 128}, + "index_type": "HNSW", + "params": {"M": 8, "efConstruction": 64}, }, index_name="embeddings") self.collection.load() @@ -65,6 +66,13 @@ class MilvusMemory(MemoryProviderSingleton): """ Drop the index in memory. """ self.collection.drop() + self.collection = Collection(self.milvus_collection, self.schema) + self.collection.create_index("embeddings", { + "metric_type": "IP", + "index_type": "HNSW", + "params": {"M": 8, "efConstruction": 64}, + }, index_name="embeddings") + self.collection.load() return "Obliviated" def get_relevant(self, data, num_relevant=5): diff --git a/tests/integration/milvus_memory_tests.py b/tests/integration/milvus_memory_tests.py new file mode 100644 index 00000000..96934cd6 --- /dev/null +++ b/tests/integration/milvus_memory_tests.py @@ -0,0 +1,48 @@ +import random +import string +import unittest + +from autogpt.config import Config +from autogpt.memory.milvus import MilvusMemory + + +class TestMilvusMemory(unittest.TestCase): + def random_string(self, length): + return "".join(random.choice(string.ascii_letters) for _ in range(length)) + + def setUp(self): + cfg = Config() + cfg.milvus_addr = "localhost:19530" + self.memory = MilvusMemory(cfg) + self.memory.clear() + + # Add example texts to the cache + self.example_texts = [ + "The quick brown fox jumps over the lazy dog", + "I love machine learning and natural language processing", + "The cake is a lie, but the pie is always true", + "ChatGPT is an advanced AI model for conversation", + ] + + for text in self.example_texts: + self.memory.add(text) + + # Add some random strings to test noise + for _ in range(5): + self.memory.add(self.random_string(10)) + + def test_get_relevant(self): + query = "I'm interested in artificial intelligence and NLP" + k = 3 + relevant_texts = self.memory.get_relevant(query, k) + + print(f"Top {k} relevant texts for the query '{query}':") + for i, text in enumerate(relevant_texts, start=1): + print(f"{i}. {text}") + + self.assertEqual(len(relevant_texts), k) + self.assertIn(self.example_texts[1], relevant_texts) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/milvus_memory_test.py b/tests/milvus_memory_test.py new file mode 100644 index 00000000..c64924f7 --- /dev/null +++ b/tests/milvus_memory_test.py @@ -0,0 +1,64 @@ +import os +import sys +import unittest + +from autogpt.memory.milvus import MilvusMemory + + +def MockConfig(): + return type( + "MockConfig", + (object,), + { + "debug_mode": False, + "continuous_mode": False, + "speak_mode": False, + "milvus_collection": "autogpt", + "milvus_addr": "localhost:19530", + + }, + ) + + +class TestMilvusMemory(unittest.TestCase): + def setUp(self): + self.cfg = MockConfig() + self.memory = MilvusMemory(self.cfg) + + def test_add(self): + text = "Sample text" + self.memory.clear() + self.memory.add(text) + result = self.memory.get(text) + self.assertEqual([text], result) + + def test_clear(self): + self.memory.clear() + self.assertEqual(self.memory.collection.num_entities, 0) + + def test_get(self): + text = "Sample text" + self.memory.clear() + self.memory.add(text) + result = self.memory.get(text) + self.assertEqual(result, [text]) + + def test_get_relevant(self): + text1 = "Sample text 1" + text2 = "Sample text 2" + self.memory.clear() + self.memory.add(text1) + self.memory.add(text2) + result = self.memory.get_relevant(text1, 1) + self.assertEqual(result, [text1]) + + def test_get_stats(self): + text = "Sample text" + self.memory.clear() + self.memory.add(text) + stats = self.memory.get_stats() + self.assertEqual(15, len(stats)) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file