diff --git a/scripts/factory.py b/scripts/factory.py index 44901631..19f67ba0 100644 --- a/scripts/factory.py +++ b/scripts/factory.py @@ -8,4 +8,6 @@ class MemoryFactory: return PineconeMemory() if mem_type == 'weaviate': - return WeaviateMemory() \ No newline at end of file + return WeaviateMemory() + + raise ValueError('Unknown memory provider') \ No newline at end of file diff --git a/tests/memory_tests.py b/tests/memory_tests.py new file mode 100644 index 00000000..4985c022 --- /dev/null +++ b/tests/memory_tests.py @@ -0,0 +1,55 @@ +import unittest +from unittest import mock +import sys +import os + +sys.path.append(os.path.abspath('./scripts')) + +from factory import MemoryFactory +from providers.weaviate import WeaviateMemory +from providers.pinecone import PineconeMemory + +class TestMemoryFactory(unittest.TestCase): + + def test_invalid_memory_provider(self): + + with self.assertRaises(ValueError): + memory = MemoryFactory.get_memory('Thanos') + + def test_create_pinecone_provider(self): + + # mock the init function of the provider to bypass + # connection to the external pinecone service + def __init__(self): + pass + + with mock.patch.object(PineconeMemory, '__init__', __init__): + memory = MemoryFactory.get_memory('pinecone') + self.assertIsInstance(memory, PineconeMemory) + + def test_create_weaviate_provider(self): + + # mock the init function of the provider to bypass + # connection to the external weaviate service + def __init__(self): + pass + + with mock.patch.object(WeaviateMemory, '__init__', __init__): + memory = MemoryFactory.get_memory('weaviate') + self.assertIsInstance(memory, WeaviateMemory) + + def test_provider_is_singleton(self): + + def __init__(self): + pass + + with mock.patch.object(WeaviateMemory, '__init__', __init__): + instance = MemoryFactory.get_memory('weaviate') + other_instance = MemoryFactory.get_memory('weaviate') + + self.assertIs(instance, other_instance) + + +if __name__ == '__main__': + unittest.main() +