Merge branch 'master' of https://github.com/BillSchumacher/Auto-GPT into plugin-support
117
tests/integration/weaviate_memory_tests.py
Normal file
@@ -0,0 +1,117 @@
import unittest
from unittest import mock
import os

from weaviate import Client
from weaviate.util import get_valid_uuid
from uuid import uuid4

from autogpt.config import Config
from autogpt.memory.weaviate import WeaviateMemory
from autogpt.memory.base import get_ada_embedding


@mock.patch.dict(os.environ, {
    "WEAVIATE_HOST": "127.0.0.1",
    "WEAVIATE_PROTOCOL": "http",
    "WEAVIATE_PORT": "8080",
    "WEAVIATE_USERNAME": "",
    "WEAVIATE_PASSWORD": "",
    "MEMORY_INDEX": "AutogptTests"
})
class TestWeaviateMemory(unittest.TestCase):
    """
    In order to run these tests you will need a local instance of Weaviate
    running. Refer to https://weaviate.io/developers/weaviate/installation/docker-compose
    for creating local instances using docker.
    Alternatively, set the following environment variables in your .env file
    to run Weaviate embedded (see: https://weaviate.io/developers/weaviate/installation/embedded):

        USE_WEAVIATE_EMBEDDED=True
        WEAVIATE_EMBEDDED_PATH="/home/me/.local/share/weaviate"
    """

    cfg = None
    client = None

    @classmethod
    def setUpClass(cls):
        # Only create the connection to Weaviate once for the whole class
        cls.cfg = Config()

        if cls.cfg.use_weaviate_embedded:
            from weaviate.embedded import EmbeddedOptions

            cls.client = Client(embedded_options=EmbeddedOptions(
                hostname=cls.cfg.weaviate_host,
                port=int(cls.cfg.weaviate_port),
                persistence_data_path=cls.cfg.weaviate_embedded_path
            ))
        else:
            cls.client = Client(
                f"{cls.cfg.weaviate_protocol}://{cls.cfg.weaviate_host}:{cls.cfg.weaviate_port}"
            )

    def setUp(self):
        # Start every test from an empty index; ignore the error raised
        # when the class does not exist yet
        try:
            self.client.schema.delete_class(self.cfg.memory_index)
        except Exception:
            pass

        self.memory = WeaviateMemory(self.cfg)

    def test_add(self):
        doc = 'You are a Titan named Thanos and you are looking for the Infinity Stones'
        self.memory.add(doc)
        result = self.client.query.get(self.cfg.memory_index, ['raw_text']).do()
        actual = result['data']['Get'][self.cfg.memory_index]

        self.assertEqual(len(actual), 1)
        self.assertEqual(actual[0]['raw_text'], doc)

    def test_get(self):
        doc = 'You are an Avenger and swore to defend the Galaxy from a menace called Thanos'

        with self.client.batch as batch:
            batch.add_data_object(
                uuid=get_valid_uuid(uuid4()),
                data_object={'raw_text': doc},
                class_name=self.cfg.memory_index,
                vector=get_ada_embedding(doc)
            )

            batch.flush()

        actual = self.memory.get(doc)

        self.assertEqual(len(actual), 1)
        self.assertEqual(actual[0], doc)

    def test_get_stats(self):
        docs = [
            'You are now about to count the number of docs in this index',
            'And then you are about to find out if you can count correctly'
        ]

        for doc in docs:
            self.memory.add(doc)

        stats = self.memory.get_stats()

        self.assertTrue(stats)
        self.assertIn('count', stats)
        self.assertEqual(stats['count'], 2)

    def test_clear(self):
        docs = [
            'Shame this is the last test for this class',
            'Testing is fun when someone else is doing it'
        ]

        for doc in docs:
            self.memory.add(doc)

        self.assertEqual(self.memory.get_stats()['count'], 2)

        self.memory.clear()

        self.assertEqual(self.memory.get_stats()['count'], 0)


if __name__ == '__main__':
    unittest.main()
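For reference, here is a minimal sketch of driving the WeaviateMemory API exercised above directly, assuming the same Config fields, the add/get/get_stats/clear methods shown in these tests, and a reachable Weaviate instance; the index name "AutogptDemo" is illustrative, not part of the diff:

import os

from autogpt.config import Config
from autogpt.memory.weaviate import WeaviateMemory

# Mirror the environment the test class patches into os.environ
os.environ.update({
    "WEAVIATE_HOST": "127.0.0.1",
    "WEAVIATE_PROTOCOL": "http",
    "WEAVIATE_PORT": "8080",
    "MEMORY_INDEX": "AutogptDemo",  # illustrative index name
})

memory = WeaviateMemory(Config())
memory.add("Weaviate stores each document alongside its ada embedding")
print(memory.get("How are documents stored?"))  # nearest stored text, as in test_get
print(memory.get_stats())                       # e.g. a dict with 'count': 1, as in test_get_stats
memory.clear()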
86
tests/unit/test_chat.py
Normal file
@@ -0,0 +1,86 @@
# Generated by CodiumAI
import unittest
import time
from unittest.mock import patch

from autogpt.chat import create_chat_message, generate_context


class TestChat(unittest.TestCase):
    # Tests that the function returns a dictionary with the correct keys and
    # values when valid strings are provided for role and content.
    def test_happy_path_role_content(self):
        result = create_chat_message("system", "Hello, world!")
        self.assertEqual(result, {"role": "system", "content": "Hello, world!"})

    # Tests that the function returns a dictionary with the correct keys and
    # values when empty strings are provided for role and content.
    def test_empty_role_content(self):
        result = create_chat_message("", "")
        self.assertEqual(result, {"role": "", "content": ""})

    # Tests the behavior of the generate_context function when all input
    # parameters are empty.
    @patch("time.strftime")
    def test_generate_context_empty_inputs(self, mock_strftime):
        # Mock time.strftime so the timestamp in the context is deterministic
        mock_strftime.return_value = "Sat Apr 15 00:00:00 2023"
        # Arrange
        prompt = ""
        relevant_memory = ""
        full_message_history = []
        model = "gpt-3.5-turbo-0301"

        # Act
        result = generate_context(prompt, relevant_memory, full_message_history, model)

        # Assert
        expected_result = (
            -1,
            47,
            3,
            [
                {"role": "system", "content": ""},
                {
                    "role": "system",
                    # time.strftime is patched above, so this matches the mocked value
                    "content": f"The current time and date is {time.strftime('%c')}",
                },
                {
                    "role": "system",
                    "content": "This reminds you of these events from your past:\n\n\n",
                },
            ],
        )
        self.assertEqual(result, expected_result)

    # Tests that the function successfully generates a current_context given valid inputs.
    def test_generate_context_valid_inputs(self):
        # Given
        prompt = "What is your favorite color?"
        relevant_memory = "You once painted your room blue."
        full_message_history = [
            create_chat_message("user", "Hi there!"),
            create_chat_message("assistant", "Hello! How can I assist you today?"),
            create_chat_message("user", "Can you tell me a joke?"),
            create_chat_message(
                "assistant",
                "Why did the tomato turn red? Because it saw the salad dressing!",
            ),
            create_chat_message("user", "Haha, that's funny."),
        ]
        model = "gpt-3.5-turbo-0301"

        # When
        result = generate_context(prompt, relevant_memory, full_message_history, model)

        # Then
        self.assertIsInstance(result[0], int)
        self.assertIsInstance(result[1], int)
        self.assertIsInstance(result[2], int)
        self.assertIsInstance(result[3], list)
        self.assertGreaterEqual(result[0], 0)
        self.assertGreaterEqual(result[1], 0)
        self.assertGreaterEqual(result[2], 0)
        self.assertGreaterEqual(
            len(result[3]), 3
        )  # current_context should have at least 3 messages
        self.assertLessEqual(
            result[1], 2048
        )  # conservative bound; gpt-3.5-turbo-0301's context window is 4,096 tokens
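As a usage note, the assertions above pin down the shape of generate_context's return value: an index into the message history (-1 when the history is empty), a token count, the insertion index into the context, and the context list itself. A minimal sketch under that inferred interpretation; the unpacked names are illustrative, not a documented API:

from autogpt.chat import create_chat_message, generate_context

history = [
    create_chat_message("user", "Hi there!"),
    create_chat_message("assistant", "Hello! How can I assist you today?"),
]
next_history_index, tokens_used, insertion_index, context = generate_context(
    "You are a helpful assistant.",  # prompt
    "",                              # relevant_memory
    history,                         # full_message_history
    "gpt-3.5-turbo-0301",            # model used for token accounting
)
# The base context always carries the prompt, a timestamp message, and the
# relevant-memory reminder, so it holds at least three messages.
assert len(context) >= 3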