Files
Auto-GPT/scripts/memory/base.py

57 lines
1.5 KiB
Python

"""Base class for memory providers."""
import abc
from config import AbstractSingleton, Config
import openai
# try to import sentence transformers, if it fails, default to ada
try:
from sentence_transformers import SentenceTransformer
except ImportError:
SentenceTransformer = None
if cfg.memory_embeder == "sbert":
print("Error: Sentence Transformers is not installed. Please install sentence_transformers"
" to use BERT as an embeder. Defaulting to Ada.")
cfg.memory_embeder = "ada"
cfg = Config()
# Dimension of embeddings encoded by models
EMBED_DIM = {
"ada": 1536,
"sbert": 768
}.get(cfg.memory_embeder, default=1536)
def get_embedding(text):
text = text.replace("\n", " ")
# use the embeder specified in the config
if cfg.memory_embeder == "sbert":
embedding = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device="cpu").encode(text, show_progress_bar=False)
else:
embedding = openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]
return embedding
class MemoryProviderSingleton(AbstractSingleton):
@abc.abstractmethod
def add(self, data):
pass
@abc.abstractmethod
def get(self, data):
pass
@abc.abstractmethod
def clear(self):
pass
@abc.abstractmethod
def get_relevant(self, data, num_relevant=5):
pass
@abc.abstractmethod
def get_stats(self):
pass