This commit is contained in:
BillSchumacher
2023-04-16 14:15:38 -05:00
parent c544cebbe6
commit 3fadf2c90b
10 changed files with 92 additions and 73 deletions

View File

@@ -12,14 +12,17 @@ from autogpt.memory.weaviate import WeaviateMemory
from autogpt.memory.base import get_ada_embedding
@mock.patch.dict(os.environ, {
"WEAVIATE_HOST": "127.0.0.1",
"WEAVIATE_PROTOCOL": "http",
"WEAVIATE_PORT": "8080",
"WEAVIATE_USERNAME": "",
"WEAVIATE_PASSWORD": "",
"MEMORY_INDEX": "AutogptTests"
})
@mock.patch.dict(
os.environ,
{
"WEAVIATE_HOST": "127.0.0.1",
"WEAVIATE_PROTOCOL": "http",
"WEAVIATE_PORT": "8080",
"WEAVIATE_USERNAME": "",
"WEAVIATE_PASSWORD": "",
"MEMORY_INDEX": "AutogptTests",
},
)
class TestWeaviateMemory(unittest.TestCase):
cfg = None
client = None
@@ -32,13 +35,17 @@ class TestWeaviateMemory(unittest.TestCase):
if cls.cfg.use_weaviate_embedded:
from weaviate.embedded import EmbeddedOptions
cls.client = Client(embedded_options=EmbeddedOptions(
hostname=cls.cfg.weaviate_host,
port=int(cls.cfg.weaviate_port),
persistence_data_path=cls.cfg.weaviate_embedded_path
))
cls.client = Client(
embedded_options=EmbeddedOptions(
hostname=cls.cfg.weaviate_host,
port=int(cls.cfg.weaviate_port),
persistence_data_path=cls.cfg.weaviate_embedded_path,
)
)
else:
cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}")
cls.client = Client(
f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}"
)
"""
In order to run these tests you will need a local instance of
@@ -49,6 +56,7 @@ class TestWeaviateMemory(unittest.TestCase):
USE_WEAVIATE_EMBEDDED=True
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
"""
def setUp(self):
try:
self.client.schema.delete_class(self.cfg.memory_index)
@@ -58,23 +66,23 @@ class TestWeaviateMemory(unittest.TestCase):
self.memory = WeaviateMemory(self.cfg)
def test_add(self):
doc = 'You are a Titan name Thanos and you are looking for the Infinity Stones'
doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
self.memory.add(doc)
result = self.client.query.get(self.cfg.memory_index, ['raw_text']).do()
actual = result['data']['Get'][self.cfg.memory_index]
result = self.client.query.get(self.cfg.memory_index, ["raw_text"]).do()
actual = result["data"]["Get"][self.cfg.memory_index]
self.assertEqual(len(actual), 1)
self.assertEqual(actual[0]['raw_text'], doc)
self.assertEqual(actual[0]["raw_text"], doc)
def test_get(self):
doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos'
doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"
with self.client.batch as batch:
batch.add_data_object(
uuid=get_valid_uuid(uuid4()),
data_object={'raw_text': doc},
data_object={"raw_text": doc},
class_name=self.cfg.memory_index,
vector=get_ada_embedding(doc)
vector=get_ada_embedding(doc),
)
batch.flush()
@@ -86,8 +94,8 @@ class TestWeaviateMemory(unittest.TestCase):
def test_get_stats(self):
docs = [
'You are now about to count the number of docs in this index',
'And then you about to find out if you can count correctly'
"You are now about to count the number of docs in this index",
"And then you about to find out if you can count correctly",
]
[self.memory.add(doc) for doc in docs]
@@ -95,23 +103,23 @@ class TestWeaviateMemory(unittest.TestCase):
stats = self.memory.get_stats()
self.assertTrue(stats)
self.assertTrue('count' in stats)
self.assertEqual(stats['count'], 2)
self.assertTrue("count" in stats)
self.assertEqual(stats["count"], 2)
def test_clear(self):
docs = [
'Shame this is the last test for this class',
'Testing is fun when someone else is doing it'
"Shame this is the last test for this class",
"Testing is fun when someone else is doing it",
]
[self.memory.add(doc) for doc in docs]
self.assertEqual(self.memory.get_stats()['count'], 2)
self.assertEqual(self.memory.get_stats()["count"], 2)
self.memory.clear()
self.assertEqual(self.memory.get_stats()['count'], 0)
self.assertEqual(self.memory.get_stats()["count"], 0)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()