mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2026-01-25 17:04:24 +01:00
feat: ability to use local embeddings model (sBERT)
This commit is contained in:
@@ -17,4 +17,4 @@ orjson
|
||||
Pillow
|
||||
coverage
|
||||
flake8
|
||||
numpy
|
||||
sentence_transformers
|
||||
|
||||
@@ -82,6 +82,7 @@ class Config(metaclass=Singleton):
|
||||
# Note that indexes must be created on db 0 in redis, this is not configurable.
|
||||
|
||||
self.memory_backend = os.getenv("MEMORY_BACKEND", 'local')
|
||||
self.memory_embeder = os.getenv("MEMORY_EMBEDER", 'ada')
|
||||
# Initialize the OpenAI API client
|
||||
openai.api_key = self.openai_api_key
|
||||
|
||||
|
||||
@@ -3,16 +3,28 @@ import abc
|
||||
from config import AbstractSingleton, Config
|
||||
import openai
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError:
|
||||
SentenceTransformer = None
|
||||
if cfg.memory_embeder == "sbert":
|
||||
print("Error: Sentence Transformers is not installed. Please install sentence_transformers"
|
||||
" to use BERT as an embeder. Defaulting to Ada.")
|
||||
cfg.memory_embeder = "ada"
|
||||
|
||||
|
||||
cfg = Config()
|
||||
|
||||
|
||||
def get_ada_embedding(text):
|
||||
def get_embedding(text):
|
||||
text = text.replace("\n", " ")
|
||||
if cfg.use_azure:
|
||||
return openai.Embedding.create(input=[text], engine=cfg.get_azure_deployment_id_for_model("text-embedding-ada-002"))["data"][0]["embedding"]
|
||||
else:
|
||||
return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]
|
||||
|
||||
if cfg.memory_embeder == "sbert":
|
||||
embedding = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device="cpu").encode(text, show_progress_bar=False)
|
||||
else:
|
||||
embedding = openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
class MemoryProviderSingleton(AbstractSingleton):
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -3,10 +3,12 @@ import orjson
|
||||
from typing import Any, List, Optional
|
||||
import numpy as np
|
||||
import os
|
||||
from memory.base import MemoryProviderSingleton, get_ada_embedding
|
||||
from memory.base import MemoryProviderSingleton, get_embedding
|
||||
from config import Config
|
||||
|
||||
cfg = Config()
|
||||
|
||||
EMBED_DIM = 1536
|
||||
EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768
|
||||
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS
|
||||
|
||||
|
||||
@@ -58,7 +60,7 @@ class LocalCache(MemoryProviderSingleton):
|
||||
return ""
|
||||
self.data.texts.append(text)
|
||||
|
||||
embedding = get_ada_embedding(text)
|
||||
embedding = get_embedding(text)
|
||||
|
||||
vector = np.array(embedding).astype(np.float32)
|
||||
vector = vector[np.newaxis, :]
|
||||
@@ -109,7 +111,7 @@ class LocalCache(MemoryProviderSingleton):
|
||||
|
||||
Returns: List[str]
|
||||
"""
|
||||
embedding = get_ada_embedding(text)
|
||||
embedding = get_embedding(text)
|
||||
|
||||
scores = np.dot(self.data.embeddings, embedding)
|
||||
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
|
||||
import pinecone
|
||||
|
||||
from memory.base import MemoryProviderSingleton, get_ada_embedding
|
||||
from memory.base import MemoryProviderSingleton, get_embedding
|
||||
from logger import logger
|
||||
from colorama import Fore, Style
|
||||
|
||||
|
||||
class PineconeMemory(MemoryProviderSingleton):
|
||||
def __init__(self, cfg):
|
||||
pinecone_api_key = cfg.pinecone_api_key
|
||||
pinecone_region = cfg.pinecone_region
|
||||
pinecone.init(api_key=pinecone_api_key, environment=pinecone_region)
|
||||
dimension = 1536
|
||||
dimension = 1536 if cfg.memory_embeder == "ada" else 768
|
||||
metric = "cosine"
|
||||
pod_type = "p1"
|
||||
table_name = "auto-gpt"
|
||||
@@ -33,7 +32,7 @@ class PineconeMemory(MemoryProviderSingleton):
|
||||
self.index = pinecone.Index(table_name)
|
||||
|
||||
def add(self, data):
|
||||
vector = get_ada_embedding(data)
|
||||
vector = get_embedding(data)
|
||||
# no metadata here. We may wish to change that long term.
|
||||
resp = self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})])
|
||||
_text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}"
|
||||
@@ -53,7 +52,7 @@ class PineconeMemory(MemoryProviderSingleton):
|
||||
:param data: The data to compare to.
|
||||
:param num_relevant: The number of relevant data to return. Defaults to 5
|
||||
"""
|
||||
query_embedding = get_ada_embedding(data)
|
||||
query_embedding = get_embedding(data)
|
||||
results = self.index.query(query_embedding, top_k=num_relevant, include_metadata=True)
|
||||
sorted_results = sorted(results.matches, key=lambda x: x.score)
|
||||
return [str(item['metadata']["raw_text"]) for item in sorted_results]
|
||||
|
||||
@@ -6,10 +6,14 @@ from redis.commands.search.query import Query
|
||||
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
||||
import numpy as np
|
||||
|
||||
from memory.base import MemoryProviderSingleton, get_ada_embedding
|
||||
from memory.base import MemoryProviderSingleton, get_embedding
|
||||
from logger import logger
|
||||
from colorama import Fore, Style
|
||||
from config import Config
|
||||
|
||||
cfg = Config()
|
||||
|
||||
EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768
|
||||
|
||||
SCHEMA = [
|
||||
TextField("data"),
|
||||
@@ -18,7 +22,7 @@ SCHEMA = [
|
||||
"HNSW",
|
||||
{
|
||||
"TYPE": "FLOAT32",
|
||||
"DIM": 1536,
|
||||
"DIM": EMBED_DIM,
|
||||
"DISTANCE_METRIC": "COSINE"
|
||||
}
|
||||
),
|
||||
@@ -38,7 +42,6 @@ class RedisMemory(MemoryProviderSingleton):
|
||||
redis_host = cfg.redis_host
|
||||
redis_port = cfg.redis_port
|
||||
redis_password = cfg.redis_password
|
||||
self.dimension = 1536
|
||||
self.redis = redis.Redis(
|
||||
host=redis_host,
|
||||
port=redis_port,
|
||||
@@ -83,7 +86,7 @@ class RedisMemory(MemoryProviderSingleton):
|
||||
"""
|
||||
if 'Command Error:' in data:
|
||||
return ""
|
||||
vector = get_ada_embedding(data)
|
||||
vector = get_embedding(data)
|
||||
vector = np.array(vector).astype(np.float32).tobytes()
|
||||
data_dict = {
|
||||
b"data": data,
|
||||
@@ -131,7 +134,7 @@ class RedisMemory(MemoryProviderSingleton):
|
||||
|
||||
Returns: A list of the most relevant data.
|
||||
"""
|
||||
query_embedding = get_ada_embedding(data)
|
||||
query_embedding = get_embedding(data)
|
||||
base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
|
||||
query = Query(base_query).return_fields(
|
||||
"data",
|
||||
|
||||
Reference in New Issue
Block a user