"""Base class for memory providers."""
|
|
import abc
|
|
from config import AbstractSingleton, Config
|
|
import openai
|
|
|
|
# try to import sentence transformers, if it fails, default to ada
|
|
try:
|
|
from sentence_transformers import SentenceTransformer
|
|
except ImportError:
|
|
SentenceTransformer = None
|
|
if cfg.memory_embeder == "sbert":
|
|
print("Error: Sentence Transformers is not installed. Please install sentence_transformers"
|
|
" to use BERT as an embeder. Defaulting to Ada.")
|
|
cfg.memory_embeder = "ada"
|
|
|
|
|
|
cfg = Config()
|
|
# Dimension of embeddings encoded by models
|
|
EMBED_DIM = {
|
|
"ada": 1536,
|
|
"sbert": 768
|
|
}.get(cfg.memory_embeder, default=1536)
|
|
|
|
|
|
def get_embedding(text):
    """Embed `text` with the embedder specified in the config."""
    text = text.replace("\n", " ")

    # Use the embedder selected in the config; fall back to Ada otherwise.
    if cfg.memory_embeder == "sbert":
        embedding = SentenceTransformer(
            "sentence-transformers/all-mpnet-base-v2", device="cpu"
        ).encode(text, show_progress_bar=False)
    else:
        embedding = openai.Embedding.create(
            input=[text], model="text-embedding-ada-002"
        )["data"][0]["embedding"]

    return embedding
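
# Illustrative usage (a sketch, not part of the original module). It assumes
# either valid OpenAI credentials for the "ada" embedder or a working
# sentence-transformers install for "sbert":
#
#     vector = get_embedding("The quick brown fox")
#     assert len(vector) == EMBED_DIM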
class MemoryProviderSingleton(AbstractSingleton):
    @abc.abstractmethod
    def add(self, data):
        pass

    @abc.abstractmethod
    def get(self, data):
        pass

    @abc.abstractmethod
    def clear(self):
        pass

    @abc.abstractmethod
    def get_relevant(self, data, num_relevant=5):
        pass

    @abc.abstractmethod
    def get_stats(self):
        pass
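
# --- Illustrative sketch, not part of the original module -------------------
# A minimal provider built on this interface, showing how a subclass is
# expected to implement the abstract methods. The name ListMemory, the
# plain-list storage, and the dot-product ranking are assumptions for
# demonstration only; real providers (Pinecone, Redis, local caches) differ.
class ListMemory(MemoryProviderSingleton):
    def __init__(self):
        self.texts = []
        self.embeddings = []

    def add(self, data):
        # Store the raw text alongside its embedding.
        self.texts.append(data)
        self.embeddings.append(get_embedding(data))
        return f"Stored: {data}"

    def get(self, data):
        return self.get_relevant(data, 1)

    def clear(self):
        self.texts, self.embeddings = [], []
        return "Memory cleared."

    def get_relevant(self, data, num_relevant=5):
        # Rank stored texts by dot-product similarity to the query embedding.
        query = get_embedding(data)
        scores = [
            sum(q * e for q, e in zip(query, emb)) for emb in self.embeddings
        ]
        ranked = sorted(
            zip(scores, self.texts), key=lambda pair: pair[0], reverse=True
        )
        return [text for _, text in ranked[:num_relevant]]

    def get_stats(self):
        return {"items": len(self.texts)}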