Added a memory backend argument

This commit is contained in:
Eesa Hamza
2023-04-12 16:58:39 +03:00
parent 16b37fff1d
commit 083ccb6bd3
2 changed files with 18 additions and 2 deletions

View File

@@ -2,7 +2,7 @@ import json
import random import random
import commands as cmd import commands as cmd
import utils import utils
from memory import get_memory from memory import get_memory, get_supported_memory_backends
import data import data
import chat import chat
from colorama import Fore, Style from colorama import Fore, Style
@@ -276,6 +276,7 @@ def parse_arguments():
parser.add_argument('--debug', action='store_true', help='Enable Debug Mode') parser.add_argument('--debug', action='store_true', help='Enable Debug Mode')
parser.add_argument('--gpt3only', action='store_true', help='Enable GPT3.5 Only Mode') parser.add_argument('--gpt3only', action='store_true', help='Enable GPT3.5 Only Mode')
parser.add_argument('--gpt4only', action='store_true', help='Enable GPT4 Only Mode') parser.add_argument('--gpt4only', action='store_true', help='Enable GPT4 Only Mode')
parser.add_argument('--use-memory', '-m', dest="memory_type", help='Defines which Memory backend to use')
args = parser.parse_args() args = parser.parse_args()
if args.continuous: if args.continuous:
@@ -302,6 +303,15 @@ def parse_arguments():
print_to_console("Debug Mode: ", Fore.GREEN, "ENABLED") print_to_console("Debug Mode: ", Fore.GREEN, "ENABLED")
cfg.set_debug_mode(True) cfg.set_debug_mode(True)
if args.memory_type:
supported_memory = get_supported_memory_backends()
chosen = args.memory_type
if not chosen in supported_memory:
print_to_console("ONLY THE FOLLOWING MEMORY BACKENDS ARE SUPPORTED: ", Fore.RED, f'{supported_memory}')
print_to_console(f"Defaulting to: ", Fore.YELLOW, cfg.memory_backend)
else:
cfg.memory_backend = chosen
# TODO: fill in llm values here # TODO: fill in llm values here
check_openai_api_key() check_openai_api_key()

View File

@@ -1,17 +1,21 @@
from memory.local import LocalCache from memory.local import LocalCache
supported_memory = ['local']
try: try:
from memory.redismem import RedisMemory from memory.redismem import RedisMemory
supported_memory.append('redis')
except ImportError: except ImportError:
print("Redis not installed. Skipping import.") print("Redis not installed. Skipping import.")
RedisMemory = None RedisMemory = None
try: try:
from memory.pinecone import PineconeMemory from memory.pinecone import PineconeMemory
supported_memory.append('pinecone')
except ImportError: except ImportError:
print("Pinecone not installed. Skipping import.") print("Pinecone not installed. Skipping import.")
PineconeMemory = None PineconeMemory = None
def get_memory(cfg, init=False): def get_memory(cfg, init=False):
memory = None memory = None
if cfg.memory_backend == "pinecone": if cfg.memory_backend == "pinecone":
@@ -35,6 +39,8 @@ def get_memory(cfg, init=False):
memory.clear() memory.clear()
return memory return memory
def get_supported_memory_backends():
return supported_memory
__all__ = [ __all__ = [
"get_memory", "get_memory",