Implement local memory.

This commit is contained in:
BillSchumacher
2023-04-07 18:13:18 -05:00
parent ea6b970509
commit cb14c8d999
6 changed files with 125 additions and 7 deletions

View File

@@ -12,4 +12,5 @@ docker
duckduckgo-search duckduckgo-search
google-api-python-client #(https://developers.google.com/custom-search/v1/overview) google-api-python-client #(https://developers.google.com/custom-search/v1/overview)
pinecone-client==2.2.1 pinecone-client==2.2.1
redis redis
orjson

View File

@@ -1,5 +1,6 @@
import browse import browse
import json import json
from memory.local import LocalCache
from memory.pinecone import PineconeMemory from memory.pinecone import PineconeMemory
from memory.redismem import RedisMemory from memory.redismem import RedisMemory
import datetime import datetime
@@ -55,11 +56,14 @@ def get_command(response):
def execute_command(command_name, arguments): def execute_command(command_name, arguments):
if cfg.memory_backend == "pinecone": if cfg.memory_backend == "pinecone":
memory = PineconeMemory(cfg=cfg) memory = PineconeMemory(cfg=cfg)
else: elif cfg.memory_backend == "redis":
memory = RedisMemory(cfg=cfg) memory = RedisMemory(cfg=cfg)
else:
memory = LocalCache(cfg=cfg)
try: try:
if command_name == "google": if command_name == "google":
# Check if the Google API key is set and use the official search method # Check if the Google API key is set and use the official search method
# If the API key is not set or has only whitespaces, use the unofficial search method # If the API key is not set or has only whitespaces, use the unofficial search method
if cfg.google_api_key and (cfg.google_api_key.strip() if cfg.google_api_key else None): if cfg.google_api_key and (cfg.google_api_key.strip() if cfg.google_api_key else None):

View File

@@ -65,10 +65,10 @@ class Config(metaclass=Singleton):
self.redis_port = os.getenv("REDIS_PORT") self.redis_port = os.getenv("REDIS_PORT")
self.redis_password = os.getenv("REDIS_PASSWORD") self.redis_password = os.getenv("REDIS_PASSWORD")
self.wipe_redis_on_start = os.getenv("WIPE_REDIS_ON_START", "True") == 'True' self.wipe_redis_on_start = os.getenv("WIPE_REDIS_ON_START", "True") == 'True'
self.memory_index = os.getenv("MEMORY_INDEX", 'gpt') self.memory_index = os.getenv("MEMORY_INDEX", 'auto-gpt')
# Note that indexes must be created on db 0 in redis, this is not configureable. # Note that indexes must be created on db 0 in redis, this is not configureable.
self.memory_backend = os.getenv("MEMORY_BACKEND", 'pinecone') self.memory_backend = os.getenv("MEMORY_BACKEND", 'local')
# Initialize the OpenAI API client # Initialize the OpenAI API client
openai.api_key = self.openai_api_key openai.api_key = self.openai_api_key

View File

@@ -1,6 +1,7 @@
import json import json
import random import random
import commands as cmd import commands as cmd
from memory.local import LocalCache
from memory.pinecone import PineconeMemory from memory.pinecone import PineconeMemory
from memory.redismem import RedisMemory from memory.redismem import RedisMemory
import data import data
@@ -287,8 +288,10 @@ user_input = "Determine which next command to use, and respond using the format
if cfg.memory_backend == "pinecone": if cfg.memory_backend == "pinecone":
memory = PineconeMemory(cfg) memory = PineconeMemory(cfg)
memory.clear() memory.clear()
else: elif cfg.memory_backend == "redis":
memory = RedisMemory(cfg) memory = RedisMemory(cfg)
else:
memory = LocalCache(cfg)
print('Using memory of type: ' + memory.__class__.__name__) print('Using memory of type: ' + memory.__class__.__name__)

111
scripts/memory/local.py Normal file
View File

@@ -0,0 +1,111 @@
import dataclasses
import orjson
from typing import Any, List, Optional
import numpy as np
import os
from memory.base import MemoryProviderSingleton, get_ada_embedding
EMBED_DIM = 1536
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS
def create_default_embeddings():
    """Return an empty float32 embedding matrix with zero rows and EMBED_DIM columns."""
    # Allocate with the target dtype directly; there are no rows yet, only the
    # column count matters so later rows can be concatenated onto it.
    empty_matrix = np.zeros((0, EMBED_DIM), dtype=np.float32)
    return empty_matrix
@dataclasses.dataclass
class CacheContent:
    """Container for the cached texts and their embedding matrix."""

    # Stored texts; a fresh list is created for every instance.
    texts: List[str] = dataclasses.field(default_factory=list)
    # Embedding matrix; defaults to whatever create_default_embeddings() returns.
    embeddings: np.ndarray = dataclasses.field(default_factory=create_default_embeddings)
class LocalCache(MemoryProviderSingleton):
    """Memory provider that keeps texts and their embeddings in a local JSON file.

    The file is named ``<cfg.memory_index>.json``. It is read once when the
    cache is constructed and rewritten after every ``add``.
    """

    def __init__(self, cfg) -> None:
        """Load the cache from disk, or start empty when there is nothing to load.

        Args:
            cfg: Configuration object; only ``cfg.memory_index`` is read here.
        """
        self.filename = f"{cfg.memory_index}.json"
        self.data = CacheContent()
        if not os.path.exists(self.filename):
            return

        with open(self.filename, 'rb') as f:
            raw = f.read()
        # An empty file is treated as an empty cache rather than a JSON
        # decode error.
        if not raw.strip():
            return

        data = CacheContent(**orjson.loads(raw))
        # JSON has no ndarray type, so the embeddings come back as nested
        # lists. Restore the float32 matrix so .shape, np.dot and
        # np.concatenate behave the same as for a freshly created cache.
        embeddings = np.asarray(data.embeddings, dtype=np.float32)
        if embeddings.size == 0:
            embeddings = create_default_embeddings()
        data.embeddings = embeddings
        self.data = data

    def add(self, text: str):
        """Store a text together with its embedding and persist the cache.

        Args:
            text: The text to remember.

        Returns: None
        """
        # Compute the embedding first: if the call fails, nothing has been
        # appended yet and texts/embeddings stay the same length.
        embedding = get_ada_embedding(text)
        vector = np.array(embedding).astype(np.float32)[np.newaxis, :]

        self.data.texts.append(text)
        # Append the row at the END so that embeddings[i] always belongs to
        # texts[i]; get_relevant() relies on that index correspondence.
        self.data.embeddings = np.concatenate(
            [
                self.data.embeddings,
                vector,
            ],
            axis=0,
        )

        with open(self.filename, 'wb') as f:
            f.write(orjson.dumps(self.data, option=SAVE_OPTIONS))

    def clear(self) -> str:
        """Reset the in-memory cache to an empty state.

        Only the in-memory data is replaced; the file on disk is not touched
        by this method.

        Returns: A message indicating that the memory has been cleared.
        """
        self.data = CacheContent()
        return "Obliviated"

    def get(self, data: str) -> Optional[List[Any]]:
        """Get the single stored text most relevant to the given data.

        Args:
            data: The data to compare to.

        Returns: A list holding the most relevant text (empty if nothing is stored).
        """
        return self.get_relevant(data, 1)

    def get_relevant(self, text: str, k: int) -> List[Any]:
        """Return the ``k`` stored texts whose embeddings score highest against ``text``.

        Scores are the dot product of every stored embedding row with the
        embedding of ``text``; results are ordered from highest to lowest score.

        Args:
            text: The query text.
            k: Maximum number of texts to return.

        Returns: Up to ``k`` texts, best match first.
        """
        # A non-positive k must yield nothing: scores[-0:] would otherwise
        # select the whole array. An empty cache needs no embedding call.
        if k <= 0 or not self.data.texts:
            return []

        embedding = get_ada_embedding(text)
        scores = np.dot(self.data.embeddings, embedding)
        # argsort is ascending: take the last k indices and reverse them.
        top_k_indices = np.argsort(scores)[-k:][::-1]
        return [self.data.texts[i] for i in top_k_indices]

    def get_stats(self):
        """Return the number of stored texts and the shape of the embedding matrix.

        Returns: A tuple ``(len(texts), embeddings.shape)``.
        """
        return len(self.data.texts), self.data.embeddings.shape

View File

@@ -4,7 +4,6 @@ import redis
from redis.commands.search.field import VectorField, TextField from redis.commands.search.field import VectorField, TextField
from redis.commands.search.query import Query from redis.commands.search.query import Query
from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.indexDefinition import IndexDefinition, IndexType
import traceback
import numpy as np import numpy as np
from memory.base import MemoryProviderSingleton, get_ada_embedding from memory.base import MemoryProviderSingleton, get_ada_embedding