diff --git a/scripts/memory/base.py b/scripts/memory/base.py index a0b3f25f..c3fec628 100644 --- a/scripts/memory/base.py +++ b/scripts/memory/base.py @@ -15,6 +15,12 @@ except ImportError: cfg = Config() +# Dimension of embeddings encoded by models +EMBED_DIM = { + "ada": 1536, + "sbert": 768 +}.get(cfg.memory_embeder, default=1536) + def get_embedding(text): text = text.replace("\n", " ") diff --git a/scripts/memory/local.py b/scripts/memory/local.py index 728723cb..e699a891 100644 --- a/scripts/memory/local.py +++ b/scripts/memory/local.py @@ -3,14 +3,9 @@ import orjson from typing import Any, List, Optional import numpy as np import os -from memory.base import MemoryProviderSingleton, get_embedding -from config import Config +from memory.base import MemoryProviderSingleton, get_embedding, EMBED_DIM -# TODO: get the embeddings dimension without importing config -cfg = Config() -# set the embedding dimension based on the embeder -EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768 SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS diff --git a/scripts/memory/pinecone.py b/scripts/memory/pinecone.py index e8a71316..1f95bb23 100644 --- a/scripts/memory/pinecone.py +++ b/scripts/memory/pinecone.py @@ -1,7 +1,7 @@ import pinecone -from memory.base import MemoryProviderSingleton, get_embedding +from memory.base import MemoryProviderSingleton, get_embedding, EMBED_DIM from logger import logger from colorama import Fore, Style @@ -10,8 +10,6 @@ class PineconeMemory(MemoryProviderSingleton): pinecone_api_key = cfg.pinecone_api_key pinecone_region = cfg.pinecone_region pinecone.init(api_key=pinecone_api_key, environment=pinecone_region) - # set the embedding dimension based on the embeder - dimension = 1536 if cfg.memory_embeder == "ada" else 768 metric = "cosine" pod_type = "p1" table_name = "auto-gpt" @@ -29,7 +27,7 @@ class PineconeMemory(MemoryProviderSingleton): exit(1) if table_name not in pinecone.list_indexes(): - pinecone.create_index(table_name, dimension=dimension, metric=metric, pod_type=pod_type) + pinecone.create_index(table_name, dimension=EMBED_DIM, metric=metric, pod_type=pod_type) self.index = pinecone.Index(table_name) def add(self, data): diff --git a/scripts/memory/redismem.py b/scripts/memory/redismem.py index 3da52827..43e8a33b 100644 --- a/scripts/memory/redismem.py +++ b/scripts/memory/redismem.py @@ -6,16 +6,9 @@ from redis.commands.search.query import Query from redis.commands.search.indexDefinition import IndexDefinition, IndexType import numpy as np -from memory.base import MemoryProviderSingleton, get_embedding +from memory.base import MemoryProviderSingleton, get_embedding, EMBED_DIM from logger import logger from colorama import Fore, Style -from config import Config - -# TODO: get the embeddings dimension without importing config -cfg = Config() - -# set the embedding dimension based on the embeder -EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768 SCHEMA = [ TextField("data"),