diff --git a/.env.template b/.env.template
index 99cb5688..352fd9c7 100644
--- a/.env.template
+++ b/.env.template
@@ -52,6 +52,8 @@ SMART_TOKEN_LIMIT=8000
 
 # MEMORY_BACKEND - Memory backend type (Default: local)
 MEMORY_BACKEND=local
+# MEMORY_EMBEDDER - Embeddings model to use (Default: ada)
+MEMORY_EMBEDDER=ada
 
 ### PINECONE
 # PINECONE_API_KEY - Pinecone API Key (Example: my-pinecone-api-key)
diff --git a/autogpt/config/config.py b/autogpt/config/config.py
index 60977b8d..e3800b29 100644
--- a/autogpt/config/config.py
+++ b/autogpt/config/config.py
@@ -86,6 +86,8 @@ class Config(metaclass=Singleton):
         # Note that indexes must be created on db 0 in redis, this is not configurable.
 
         self.memory_backend = os.getenv("MEMORY_BACKEND", "local")
+        self.memory_embedder = os.getenv("MEMORY_EMBEDDER", "ada")
+
         # Initialize the OpenAI API client
         openai.api_key = self.openai_api_key
 
diff --git a/autogpt/memory/base.py b/autogpt/memory/base.py
index 691e2299..bf703660 100644
--- a/autogpt/memory/base.py
+++ b/autogpt/memory/base.py
@@ -1,24 +1,73 @@
 """Base class for memory providers."""
 import abc
 
 import openai
 
 from autogpt.config import AbstractSingleton, Config
 
+# sentence_transformers is optional; only the "sbert" embedder needs it.
+try:
+    from sentence_transformers import SentenceTransformer
+except ImportError:
+    SentenceTransformer = None
+
 cfg = Config()
+
+# Fall back to Ada when sBERT was requested but cannot be imported. This runs
+# after cfg is created because it reads and rewrites the configured embedder.
+if SentenceTransformer is None and cfg.memory_embedder == "sbert":
+    print(
+        "Error: Sentence Transformers is not installed. Please install sentence_transformers"
+        " to use sBERT as an embedder. Defaulting to Ada."
+    )
+    cfg.memory_embedder = "ada"
+
+# Dimension of embeddings encoded by embedders
+EMBED_DIM = {"ada": 1536, "sbert": 768}.get(cfg.memory_embedder, 1536)
+
+# Shared sBERT model; created on first use because loading it is slow.
+_sbert_model = None
 
 
+def _get_sbert_model():
+    """Return the shared sBERT model, loading it on the first call."""
+    global _sbert_model
+    if SentenceTransformer is None:
+        raise ImportError("sentence_transformers is not installed")
+    if _sbert_model is None:
+        _sbert_model = SentenceTransformer(
+            "sentence-transformers/all-mpnet-base-v2", device="cpu"
+        )
+    return _sbert_model
+
+
-def get_ada_embedding(text):
+def get_embedding(text):
+    """Return the embedding of ``text`` from the configured embedder.
+
+    Args:
+        text: String to embed; newlines are replaced with spaces first.
+
+    Returns:
+        The embedding as a list of floats, from sBERT when
+        ``cfg.memory_embedder`` is "sbert" and from OpenAI Ada otherwise.
+    """
     text = text.replace("\n", " ")
+    if cfg.memory_embedder == "sbert":
+        # encode() returns a numpy array; convert it so both embedders return a list.
+        return _get_sbert_model().encode(text, show_progress_bar=False).tolist()
     if cfg.use_azure:
         return openai.Embedding.create(
             input=[text],
             engine=cfg.get_azure_deployment_id_for_model("text-embedding-ada-002"),
         )["data"][0]["embedding"]
     else:
         return openai.Embedding.create(input=[text], model="text-embedding-ada-002")[
             "data"
         ][0]["embedding"]
+
+
+# Kept so code that still imports the old name continues to work.
+get_ada_embedding = get_embedding
 
 
 class MemoryProviderSingleton(AbstractSingleton):
diff --git a/autogpt/memory/local.py b/autogpt/memory/local.py
index a5f6076e..28ddfeb8 100644
--- a/autogpt/memory/local.py
+++ b/autogpt/memory/local.py
@@ -5,9 +5,8 @@ from typing import Any, List, Optional, Tuple
 import numpy as np
 import orjson
 
-from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
+from autogpt.memory.base import EMBED_DIM, MemoryProviderSingleton, get_embedding
 
-EMBED_DIM = 1536
 SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS
 
 
@@ -70,7 +69,7 @@ class LocalCache(MemoryProviderSingleton):
             return ""
         self.data.texts.append(text)
 
-        embedding = get_ada_embedding(text)
+        embedding = get_embedding(text)
 
         vector = np.array(embedding).astype(np.float32)
         vector = vector[np.newaxis, :]
@@ -118,7 +117,7 @@ class LocalCache(MemoryProviderSingleton):
 
         Returns: List[str]
         """
-        embedding = get_ada_embedding(text)
+        embedding = get_embedding(text)
 
         scores = np.dot(self.data.embeddings, embedding)
 
diff --git a/autogpt/memory/pinecone.py b/autogpt/memory/pinecone.py
index a7dbfa82..e95fe4d4 100644
--- a/autogpt/memory/pinecone.py
+++ b/autogpt/memory/pinecone.py
@@ -2,7 +2,7 @@ import pinecone
 from colorama import Fore, Style
 
 from autogpt.logs import logger
-from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
+from autogpt.memory.base import EMBED_DIM, MemoryProviderSingleton, get_embedding
 
 
 class PineconeMemory(MemoryProviderSingleton):
@@ -10,7 +10,6 @@ class PineconeMemory(MemoryProviderSingleton):
         pinecone_api_key = cfg.pinecone_api_key
         pinecone_region = cfg.pinecone_region
         pinecone.init(api_key=pinecone_api_key, environment=pinecone_region)
-        dimension = 1536
         metric = "cosine"
         pod_type = "p1"
         table_name = "auto-gpt"
@@ -37,13 +36,13 @@ class PineconeMemory(MemoryProviderSingleton):
             exit(1)
 
         if table_name not in pinecone.list_indexes():
             pinecone.create_index(
-                table_name, dimension=dimension, metric=metric, pod_type=pod_type
+                table_name, dimension=EMBED_DIM, metric=metric, pod_type=pod_type
             )
         self.index = pinecone.Index(table_name)
 
     def add(self, data):
-        vector = get_ada_embedding(data)
+        vector = get_embedding(data)
         # no metadata here. We may wish to change that long term.
         self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})])
         _text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}"
@@ -63,7 +62,7 @@ class PineconeMemory(MemoryProviderSingleton):
         :param data: The data to compare to.
         :param num_relevant: The number of relevant data to return. Defaults to 5
         """
-        query_embedding = get_ada_embedding(data)
+        query_embedding = get_embedding(data)
         results = self.index.query(
             query_embedding, top_k=num_relevant, include_metadata=True
         )
diff --git a/autogpt/memory/redismem.py b/autogpt/memory/redismem.py
index df6d8fc0..541923b8 100644
--- a/autogpt/memory/redismem.py
+++ b/autogpt/memory/redismem.py
@@ -9,14 +9,14 @@ from redis.commands.search.indexDefinition import IndexDefinition, IndexType
 from redis.commands.search.query import Query
 
 from autogpt.logs import logger
-from autogpt.memory.base import MemoryProviderSingleton, get_ada_embedding
+from autogpt.memory.base import EMBED_DIM, MemoryProviderSingleton, get_embedding
 
 SCHEMA = [
     TextField("data"),
     VectorField(
         "embedding",
         "HNSW",
-        {"TYPE": "FLOAT32", "DIM": 1536, "DISTANCE_METRIC": "COSINE"},
+        {"TYPE": "FLOAT32", "DIM": EMBED_DIM, "DISTANCE_METRIC": "COSINE"},
     ),
 ]
 
@@ -34,7 +34,7 @@ class RedisMemory(MemoryProviderSingleton):
         redis_host = cfg.redis_host
         redis_port = cfg.redis_port
         redis_password = cfg.redis_password
-        self.dimension = 1536
+        self.dimension = EMBED_DIM
         self.redis = redis.Redis(
             host=redis_host,
             port=redis_port,
@@ -85,7 +85,7 @@ class RedisMemory(MemoryProviderSingleton):
         """
         if "Command Error:" in data:
             return ""
-        vector = get_ada_embedding(data)
+        vector = get_embedding(data)
         vector = np.array(vector).astype(np.float32).tobytes()
         data_dict = {b"data": data, "embedding": vector}
         pipe = self.redis.pipeline()
@@ -127,7 +127,7 @@ class RedisMemory(MemoryProviderSingleton):
 
         Returns: A list of the most relevant data.
         """
-        query_embedding = get_ada_embedding(data)
+        query_embedding = get_embedding(data)
         base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
         query = (
             Query(base_query)
diff --git a/requirements.txt b/requirements.txt
index 2fec1b16..5b862bd5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,6 +20,7 @@ selenium
 webdriver-manager
 coverage
 flake8
+sentence_transformers
 numpy
 pre-commit
 black
diff --git a/tests/embedder_test.py b/tests/embedder_test.py
new file mode 100644
index 00000000..5b9ca798
--- /dev/null
+++ b/tests/embedder_test.py
@@ -0,0 +1,34 @@
+import unittest
+
+from autogpt.config import Config
+from autogpt.memory.base import get_embedding
+
+# Required, because the get_embedding function uses it
+cfg = Config()
+
+
+class TestMemoryEmbedder(unittest.TestCase):
+    """Check the embedding size produced by each supported embedder."""
+
+    def setUp(self):
+        # Config is a shared singleton, so remember the embedder to restore it.
+        self._original_embedder = cfg.memory_embedder
+
+    def tearDown(self):
+        cfg.memory_embedder = self._original_embedder
+
+    def test_ada(self):
+        cfg.memory_embedder = "ada"
+        text = "Sample text"
+        result = get_embedding(text)
+        self.assertEqual(len(result), 1536)
+
+    def test_sbert(self):
+        cfg.memory_embedder = "sbert"
+        text = "Sample text"
+        result = get_embedding(text)
+        self.assertEqual(len(result), 768)
+
+
+if __name__ == "__main__":
+    unittest.main()