diff --git a/tests/embedder_test.py b/tests/embedder_test.py index d64dae11..fddcf225 100644 --- a/tests/embedder_test.py +++ b/tests/embedder_test.py @@ -3,32 +3,27 @@ import sys # Probably a better way: sys.path.append(os.path.abspath('../scripts')) from memory.base import get_embedding +from config import Config +import unittest + + +# Required, because the get_embedding function uses it +cfg = Config() -def MockConfig(): - return type('MockConfig', (object,), { - 'debug_mode': False, - 'continuous_mode': False, - 'speak_mode': False, - 'memory_embedder': 'sbert' - }) class TestMemoryEmbedder(unittest.TestCase): - def setUp(self): - self.cfg = MockConfig() - def test_ada(self): - self.cfg.memory_embedder = "ada" + cfg.memory_embedder = "ada" text = "Sample text" result = get_embedding(text) - self.assertEqual(result.shape, (1536,)) + self.assertEqual(len(result), 1536) def test_sbert(self): - self.cfg.memory_embedder = "sbert" + cfg.memory_embedder = "sbert" text = "Sample text" result = get_embedding(text) - self.assertEqual(result.shape, (768,)) + self.assertEqual(len(result), 768) if __name__ == '__main__': unittest.main() - \ No newline at end of file