mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-17 22:14:28 +01:00
Blacked
This commit is contained in:
@@ -12,14 +12,17 @@ from autogpt.memory.weaviate import WeaviateMemory
|
||||
from autogpt.memory.base import get_ada_embedding
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {
|
||||
"WEAVIATE_HOST": "127.0.0.1",
|
||||
"WEAVIATE_PROTOCOL": "http",
|
||||
"WEAVIATE_PORT": "8080",
|
||||
"WEAVIATE_USERNAME": "",
|
||||
"WEAVIATE_PASSWORD": "",
|
||||
"MEMORY_INDEX": "AutogptTests"
|
||||
})
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"WEAVIATE_HOST": "127.0.0.1",
|
||||
"WEAVIATE_PROTOCOL": "http",
|
||||
"WEAVIATE_PORT": "8080",
|
||||
"WEAVIATE_USERNAME": "",
|
||||
"WEAVIATE_PASSWORD": "",
|
||||
"MEMORY_INDEX": "AutogptTests",
|
||||
},
|
||||
)
|
||||
class TestWeaviateMemory(unittest.TestCase):
|
||||
cfg = None
|
||||
client = None
|
||||
@@ -32,13 +35,17 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
if cls.cfg.use_weaviate_embedded:
|
||||
from weaviate.embedded import EmbeddedOptions
|
||||
|
||||
cls.client = Client(embedded_options=EmbeddedOptions(
|
||||
hostname=cls.cfg.weaviate_host,
|
||||
port=int(cls.cfg.weaviate_port),
|
||||
persistence_data_path=cls.cfg.weaviate_embedded_path
|
||||
))
|
||||
cls.client = Client(
|
||||
embedded_options=EmbeddedOptions(
|
||||
hostname=cls.cfg.weaviate_host,
|
||||
port=int(cls.cfg.weaviate_port),
|
||||
persistence_data_path=cls.cfg.weaviate_embedded_path,
|
||||
)
|
||||
)
|
||||
else:
|
||||
cls.client = Client(f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}")
|
||||
cls.client = Client(
|
||||
f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{self.cfg.weaviate_port}"
|
||||
)
|
||||
|
||||
"""
|
||||
In order to run these tests you will need a local instance of
|
||||
@@ -49,6 +56,7 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
USE_WEAVIATE_EMBEDDED=True
|
||||
WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
try:
|
||||
self.client.schema.delete_class(self.cfg.memory_index)
|
||||
@@ -58,23 +66,23 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
self.memory = WeaviateMemory(self.cfg)
|
||||
|
||||
def test_add(self):
|
||||
doc = 'You are a Titan name Thanos and you are looking for the Infinity Stones'
|
||||
doc = "You are a Titan name Thanos and you are looking for the Infinity Stones"
|
||||
self.memory.add(doc)
|
||||
result = self.client.query.get(self.cfg.memory_index, ['raw_text']).do()
|
||||
actual = result['data']['Get'][self.cfg.memory_index]
|
||||
result = self.client.query.get(self.cfg.memory_index, ["raw_text"]).do()
|
||||
actual = result["data"]["Get"][self.cfg.memory_index]
|
||||
|
||||
self.assertEqual(len(actual), 1)
|
||||
self.assertEqual(actual[0]['raw_text'], doc)
|
||||
self.assertEqual(actual[0]["raw_text"], doc)
|
||||
|
||||
def test_get(self):
|
||||
doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos'
|
||||
doc = "You are an Avenger and swore to defend the Galaxy from a menace called Thanos"
|
||||
|
||||
with self.client.batch as batch:
|
||||
batch.add_data_object(
|
||||
uuid=get_valid_uuid(uuid4()),
|
||||
data_object={'raw_text': doc},
|
||||
data_object={"raw_text": doc},
|
||||
class_name=self.cfg.memory_index,
|
||||
vector=get_ada_embedding(doc)
|
||||
vector=get_ada_embedding(doc),
|
||||
)
|
||||
|
||||
batch.flush()
|
||||
@@ -86,8 +94,8 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
|
||||
def test_get_stats(self):
|
||||
docs = [
|
||||
'You are now about to count the number of docs in this index',
|
||||
'And then you about to find out if you can count correctly'
|
||||
"You are now about to count the number of docs in this index",
|
||||
"And then you about to find out if you can count correctly",
|
||||
]
|
||||
|
||||
[self.memory.add(doc) for doc in docs]
|
||||
@@ -95,23 +103,23 @@ class TestWeaviateMemory(unittest.TestCase):
|
||||
stats = self.memory.get_stats()
|
||||
|
||||
self.assertTrue(stats)
|
||||
self.assertTrue('count' in stats)
|
||||
self.assertEqual(stats['count'], 2)
|
||||
self.assertTrue("count" in stats)
|
||||
self.assertEqual(stats["count"], 2)
|
||||
|
||||
def test_clear(self):
|
||||
docs = [
|
||||
'Shame this is the last test for this class',
|
||||
'Testing is fun when someone else is doing it'
|
||||
"Shame this is the last test for this class",
|
||||
"Testing is fun when someone else is doing it",
|
||||
]
|
||||
|
||||
[self.memory.add(doc) for doc in docs]
|
||||
|
||||
self.assertEqual(self.memory.get_stats()['count'], 2)
|
||||
self.assertEqual(self.memory.get_stats()["count"], 2)
|
||||
|
||||
self.memory.clear()
|
||||
|
||||
self.assertEqual(self.memory.get_stats()['count'], 0)
|
||||
self.assertEqual(self.memory.get_stats()["count"], 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user