diff --git a/autogpt/llm_utils.py b/autogpt/llm_utils.py index da6cb979..ba7521a4 100644 --- a/autogpt/llm_utils.py +++ b/autogpt/llm_utils.py @@ -154,6 +154,13 @@ def create_chat_completion( return resp +def get_ada_embedding(text): + text = text.replace("\n", " ") + return api_manager.embedding_create( + text_list=[text], model="text-embedding-ada-002" + ) + + def create_embedding_with_ada(text) -> list: """Create an embedding with text-ada-002 using the OpenAI SDK""" num_retries = 10 diff --git a/autogpt/memory/base.py b/autogpt/memory/base.py index b69f795c..b6252464 100644 --- a/autogpt/memory/base.py +++ b/autogpt/memory/base.py @@ -1,21 +1,11 @@ """Base class for memory providers.""" import abc -import openai - -from autogpt.api_manager import api_manager from autogpt.config import AbstractSingleton, Config cfg = Config() -def get_ada_embedding(text): - text = text.replace("\n", " ") - return api_manager.embedding_create( - text_list=[text], model="text-embedding-ada-002" - ) - - class MemoryProviderSingleton(AbstractSingleton): @abc.abstractmethod def add(self, data): diff --git a/autogpt/memory/milvus.py b/autogpt/memory/milvus.py index 1849a9e6..085f50b4 100644 --- a/autogpt/memory/milvus.py +++ b/autogpt/memory/milvus.py @@ -4,7 +4,8 @@ import re from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections from autogpt.config import Config -from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding +from autogpt.llm_utils import get_ada_embedding +from autogpt.memory.base import MemoryProviderSingleton class MilvusMemory(MemoryProviderSingleton): diff --git a/autogpt/memory/weaviate.py b/autogpt/memory/weaviate.py index 0225ae04..fbebbfd7 100644 --- a/autogpt/memory/weaviate.py +++ b/autogpt/memory/weaviate.py @@ -1,12 +1,10 @@ -import uuid - import weaviate from weaviate import Client from weaviate.embedded import EmbeddedOptions from weaviate.util import generate_uuid5 -from autogpt.config import Config -from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding +from autogpt.llm_utils import get_ada_embedding +from autogpt.memory.base import MemoryProviderSingleton def default_schema(weaviate_index): diff --git a/tests/integration/weaviate_memory_tests.py b/tests/integration/weaviate_memory_tests.py index 015eab05..5448b79e 100644 --- a/tests/integration/weaviate_memory_tests.py +++ b/tests/integration/weaviate_memory_tests.py @@ -1,14 +1,11 @@ -import os -import sys import unittest -from unittest import mock from uuid import uuid4 from weaviate import Client from weaviate.util import get_valid_uuid from autogpt.config import Config -from autogpt.memory.base import get_ada_embedding +from autogpt.llm_utils import get_ada_embedding from autogpt.memory.weaviate import WeaviateMemory