From fb6684450c2dedba12c598642def45e4df006b62 Mon Sep 17 00:00:00 2001 From: Tymec Date: Fri, 14 Apr 2023 14:56:58 +0200 Subject: [PATCH] test: added tests for memory embeder --- tests/embeder_test.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/embeder_test.py diff --git a/tests/embeder_test.py b/tests/embeder_test.py new file mode 100644 index 00000000..214b9110 --- /dev/null +++ b/tests/embeder_test.py @@ -0,0 +1,33 @@ +import os +import sys +# Probably a better way: +sys.path.append(os.path.abspath('../scripts')) +from memory.base import get_embedding + +def MockConfig(): + return type('MockConfig', (object,), { + 'debug_mode': False, + 'continuous_mode': False, + 'speak_mode': False, + 'memory_embeder': 'sbert' + }) + +class TestMemoryEmbeder(unittest.TestCase): + def setUp(self): + self.cfg = MockConfig() + + def test_ada(self): + self.cfg.memory_embeder = "ada" + text = "Sample text" + result = get_embedding(text) + self.assertEqual(result.shape, (1536,)) + + def test_sbert(self): + self.cfg.memory_embeder = "sbert" + text = "Sample text" + result = get_embedding(text) + self.assertEqual(result.shape, (768,)) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file