From 967c9270ce40da947270d57b6343ccbcba794d5f Mon Sep 17 00:00:00 2001 From: Tymec Date: Fri, 14 Apr 2023 14:45:44 +0200 Subject: [PATCH] feat: ability to use local embeddings model (sBERT) --- requirements.txt | 2 +- scripts/config.py | 1 + scripts/memory/base.py | 24 ++++++++++++++++++------ scripts/memory/local.py | 10 ++++++---- scripts/memory/pinecone.py | 9 ++++----- scripts/memory/redismem.py | 13 ++++++++----- 6 files changed, 38 insertions(+), 21 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3f7fd228..88278435 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,4 @@ orjson Pillow coverage flake8 -numpy +sentence_transformers diff --git a/scripts/config.py b/scripts/config.py index 37be1b21..38cbd142 100644 --- a/scripts/config.py +++ b/scripts/config.py @@ -82,6 +82,7 @@ class Config(metaclass=Singleton): # Note that indexes must be created on db 0 in redis, this is not configurable. self.memory_backend = os.getenv("MEMORY_BACKEND", 'local') + self.memory_embeder = os.getenv("MEMORY_EMBEDER", 'ada') # Initialize the OpenAI API client openai.api_key = self.openai_api_key diff --git a/scripts/memory/base.py b/scripts/memory/base.py index 4dbf6791..e3924d7e 100644 --- a/scripts/memory/base.py +++ b/scripts/memory/base.py @@ -3,16 +3,28 @@ import abc from config import AbstractSingleton, Config import openai +try: + from sentence_transformers import SentenceTransformer +except ImportError: + SentenceTransformer = None + if cfg.memory_embeder == "sbert": + print("Error: Sentence Transformers is not installed. Please install sentence_transformers" + " to use BERT as an embeder. Defaulting to Ada.") + cfg.memory_embeder = "ada" + + cfg = Config() - -def get_ada_embedding(text): +def get_embedding(text): text = text.replace("\n", " ") - if cfg.use_azure: - return openai.Embedding.create(input=[text], engine=cfg.get_azure_deployment_id_for_model("text-embedding-ada-002"))["data"][0]["embedding"] - else: - return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"] + if cfg.memory_embeder == "sbert": + embedding = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device="cpu").encode(text, show_progress_bar=False) + else: + embedding = openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"] + + return embedding + class MemoryProviderSingleton(AbstractSingleton): @abc.abstractmethod diff --git a/scripts/memory/local.py b/scripts/memory/local.py index b0afacf6..40a08f66 100644 --- a/scripts/memory/local.py +++ b/scripts/memory/local.py @@ -3,10 +3,12 @@ import orjson from typing import Any, List, Optional import numpy as np import os -from memory.base import MemoryProviderSingleton, get_ada_embedding +from memory.base import MemoryProviderSingleton, get_embedding +from config import Config +cfg = Config() -EMBED_DIM = 1536 +EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768 SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS @@ -58,7 +60,7 @@ class LocalCache(MemoryProviderSingleton): return "" self.data.texts.append(text) - embedding = get_ada_embedding(text) + embedding = get_embedding(text) vector = np.array(embedding).astype(np.float32) vector = vector[np.newaxis, :] @@ -109,7 +111,7 @@ class LocalCache(MemoryProviderSingleton): Returns: List[str] """ - embedding = get_ada_embedding(text) + embedding = get_embedding(text) scores = np.dot(self.data.embeddings, embedding) diff --git a/scripts/memory/pinecone.py b/scripts/memory/pinecone.py index 20a905b3..b3aab33a 100644 --- a/scripts/memory/pinecone.py +++ b/scripts/memory/pinecone.py @@ -1,17 +1,16 @@ import pinecone -from memory.base import MemoryProviderSingleton, get_ada_embedding +from memory.base import MemoryProviderSingleton, get_embedding from logger import logger from colorama import Fore, Style - class PineconeMemory(MemoryProviderSingleton): def __init__(self, cfg): pinecone_api_key = cfg.pinecone_api_key pinecone_region = cfg.pinecone_region pinecone.init(api_key=pinecone_api_key, environment=pinecone_region) - dimension = 1536 + dimension = 1536 if cfg.memory_embeder == "ada" else 768 metric = "cosine" pod_type = "p1" table_name = "auto-gpt" @@ -33,7 +32,7 @@ class PineconeMemory(MemoryProviderSingleton): self.index = pinecone.Index(table_name) def add(self, data): - vector = get_ada_embedding(data) + vector = get_embedding(data) # no metadata here. We may wish to change that long term. resp = self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})]) _text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}" @@ -53,7 +52,7 @@ class PineconeMemory(MemoryProviderSingleton): :param data: The data to compare to. :param num_relevant: The number of relevant data to return. Defaults to 5 """ - query_embedding = get_ada_embedding(data) + query_embedding = get_embedding(data) results = self.index.query(query_embedding, top_k=num_relevant, include_metadata=True) sorted_results = sorted(results.matches, key=lambda x: x.score) return [str(item['metadata']["raw_text"]) for item in sorted_results] diff --git a/scripts/memory/redismem.py b/scripts/memory/redismem.py index 49045dd8..8f325835 100644 --- a/scripts/memory/redismem.py +++ b/scripts/memory/redismem.py @@ -6,10 +6,14 @@ from redis.commands.search.query import Query from redis.commands.search.indexDefinition import IndexDefinition, IndexType import numpy as np -from memory.base import MemoryProviderSingleton, get_ada_embedding +from memory.base import MemoryProviderSingleton, get_embedding from logger import logger from colorama import Fore, Style +from config import Config +cfg = Config() + +EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768 SCHEMA = [ TextField("data"), @@ -18,7 +22,7 @@ SCHEMA = [ "HNSW", { "TYPE": "FLOAT32", - "DIM": 1536, + "DIM": EMBED_DIM, "DISTANCE_METRIC": "COSINE" } ), @@ -38,7 +42,6 @@ class RedisMemory(MemoryProviderSingleton): redis_host = cfg.redis_host redis_port = cfg.redis_port redis_password = cfg.redis_password - self.dimension = 1536 self.redis = redis.Redis( host=redis_host, port=redis_port, @@ -83,7 +86,7 @@ class RedisMemory(MemoryProviderSingleton): """ if 'Command Error:' in data: return "" - vector = get_ada_embedding(data) + vector = get_embedding(data) vector = np.array(vector).astype(np.float32).tobytes() data_dict = { b"data": data, @@ -131,7 +134,7 @@ class RedisMemory(MemoryProviderSingleton): Returns: A list of the most relevant data. """ - query_embedding = get_ada_embedding(data) + query_embedding = get_embedding(data) base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]" query = Query(base_query).return_fields( "data",