Files
Auto-GPT/tests/integration/weaviate_memory_tests.py
BillSchumacher 3fadf2c90b Blacked
2023-04-16 14:15:38 -05:00

126 lines
3.7 KiB
Python

import unittest
from unittest import mock
import sys
import os
from weaviate import Client
from weaviate.util import get_valid_uuid
from uuid import uuid4
from autogpt.config import Config
from autogpt.memory.weaviate import WeaviateMemory
from autogpt.memory.base import get_ada_embedding
@mock.patch.dict(
os.environ,
{
"WEAVIATE_HOST": "127.0.0.1",
"WEAVIATE_PROTOCOL": "http",
"WEAVIATE_PORT": "8080",
"WEAVIATE_USERNAME": "",
"WEAVIATE_PASSWORD": "",
"MEMORY_INDEX": "AutogptTests",
},
)
class TestWeaviateMemory(unittest.TestCase):
cfg = None
client = None
@classmethod
def setUpClass(cls):
# only create the connection to weaviate once
cls.cfg = Config()
if cls.cfg.use_weaviate_embedded:
from weaviate.embedded import EmbeddedOptions
cls.client = Client(
embedded_options=EmbeddedOptions(
hostname=cls.cfg.weaviate_host,
port=int(cls.cfg.weaviate_port),
persistence_data_path=cls.cfg.weaviate_embedded_path,
)
)
else:
cls.client = Client(
f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}"
)
"""
In order to run these tests you will need a local instance of
Weaviate running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose
for creating local instances using docker.
Alternatively in your .env file set the following environmental variables to run Weaviate embedded (see: https://weaviate.io/developers/weaviate/installation/embedded):
USE_WEAVIATE_EMBEDDED=True
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
"""
def setUp(self):
try:
self.client.schema.delete_class(self.cfg.memory_index)
except:
pass
self.memory = WeaviateMemory(self.cfg)
def test_add(self):
doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
self.memory.add(doc)
result = self.client.query.get(self.cfg.memory_index, ["raw_text"]).do()
actual = result["data"]["Get"][self.cfg.memory_index]
self.assertEqual(len(actual), 1)
self.assertEqual(actual[0]["raw_text"], doc)
def test_get(self):
doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"
with self.client.batch as batch:
batch.add_data_object(
uuid=get_valid_uuid(uuid4()),
data_object={"raw_text": doc},
class_name=self.cfg.memory_index,
vector=get_ada_embedding(doc),
)
batch.flush()
actual = self.memory.get(doc)
self.assertEqual(len(actual), 1)
self.assertEqual(actual[0], doc)
def test_get_stats(self):
docs = [
"You are now about to count the number of docs in this index",
"And then you about to find out if you can count correctly",
]
[self.memory.add(doc) for doc in docs]
stats = self.memory.get_stats()
self.assertTrue(stats)
self.assertTrue("count" in stats)
self.assertEqual(stats["count"], 2)
def test_clear(self):
docs = [
"Shame this is the last test for this class",
"Testing is fun when someone else is doing it",
]
[self.memory.add(doc) for doc in docs]
self.assertEqual(self.memory.get_stats()["count"], 2)
self.memory.clear()
self.assertEqual(self.memory.get_stats()["count"], 0)
if __name__ == "__main__":
unittest.main()