fix: more modular approach for embedding dimension

This commit is contained in:
Tymec
2023-04-14 17:17:10 +02:00
parent 653904a359
commit 121f4e606c
4 changed files with 10 additions and 18 deletions

View File

@@ -15,6 +15,12 @@ except ImportError:
cfg = Config()
# Dimension of embeddings encoded by models
EMBED_DIM = {
"ada": 1536,
"sbert": 768
}.get(cfg.memory_embeder, default=1536)
def get_embedding(text):
text = text.replace("\n", " ")

View File

@@ -3,14 +3,9 @@ import orjson
from typing import Any, List, Optional
import numpy as np
import os
from memory.base import MemoryProviderSingleton, get_embedding
from config import Config
from memory.base import MemoryProviderSingleton, get_embedding, EMBED_DIM
# TODO: get the embeddings dimension without importing config
cfg = Config()
# set the embedding dimension based on the embeder
EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS

View File

@@ -1,7 +1,7 @@
import pinecone
from memory.base import MemoryProviderSingleton, get_embedding
from memory.base import MemoryProviderSingleton, get_embedding, EMBED_DIM
from logger import logger
from colorama import Fore, Style
@@ -10,8 +10,6 @@ class PineconeMemory(MemoryProviderSingleton):
pinecone_api_key = cfg.pinecone_api_key
pinecone_region = cfg.pinecone_region
pinecone.init(api_key=pinecone_api_key, environment=pinecone_region)
# set the embedding dimension based on the embeder
dimension = 1536 if cfg.memory_embeder == "ada" else 768
metric = "cosine"
pod_type = "p1"
table_name = "auto-gpt"
@@ -29,7 +27,7 @@ class PineconeMemory(MemoryProviderSingleton):
exit(1)
if table_name not in pinecone.list_indexes():
pinecone.create_index(table_name, dimension=dimension, metric=metric, pod_type=pod_type)
pinecone.create_index(table_name, dimension=EMBED_DIM, metric=metric, pod_type=pod_type)
self.index = pinecone.Index(table_name)
def add(self, data):

View File

@@ -6,16 +6,9 @@ from redis.commands.search.query import Query
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
import numpy as np
from memory.base import MemoryProviderSingleton, get_embedding
from memory.base import MemoryProviderSingleton, get_embedding, EMBED_DIM
from logger import logger
from colorama import Fore, Style
from config import Config
# TODO: get the embeddings dimension without importing config
cfg = Config()
# set the embedding dimension based on the embeder
EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768
SCHEMA = [
TextField("data"),