This commit is contained in:
BillSchumacher
2023-04-15 16:40:12 -05:00
parent f86ca43b2f
commit 4a19124cb7
3 changed files with 22 additions and 16 deletions

View File

@@ -21,12 +21,14 @@ def fix_json(json_string: str, schema: str) -> str:
# Try to fix the JSON using GPT: # Try to fix the JSON using GPT:
function_string = "def fix_json(json_string: str, schema:str=None) -> str:" function_string = "def fix_json(json_string: str, schema:str=None) -> str:"
args = [f"'''{json_string}'''", f"'''{schema}'''"] args = [f"'''{json_string}'''", f"'''{schema}'''"]
description_string = "This function takes a JSON string and ensures that it"\ description_string = (
" is parseable and fully compliant with the provided schema. If an object"\ "This function takes a JSON string and ensures that it"
" or field specified in the schema isn't contained within the correct JSON,"\ " is parseable and fully compliant with the provided schema. If an object"
" it is omitted. The function also escapes any double quotes within JSON"\ " or field specified in the schema isn't contained within the correct JSON,"
" string values to ensure that they are valid. If the JSON string contains"\ " it is omitted. The function also escapes any double quotes within JSON"
" string values to ensure that they are valid. If the JSON string contains"
" any None or NaN values, they are replaced with null before being parsed." " any None or NaN values, they are replaced with null before being parsed."
)
# If it doesn't already start with a "`", add one: # If it doesn't already start with a "`", add one:
if not json_string.startswith("`"): if not json_string.startswith("`"):

View File

@@ -126,13 +126,16 @@ def create_embedding_with_ada(text) -> list:
backoff = 2 ** (attempt + 2) backoff = 2 ** (attempt + 2)
try: try:
if CFG.use_azure: if CFG.use_azure:
return openai.Embedding.create(input=[text], return openai.Embedding.create(
engine=CFG.get_azure_deployment_id_for_model("text-embedding-ada-002"), input=[text],
engine=CFG.get_azure_deployment_id_for_model(
"text-embedding-ada-002"
),
)["data"][0]["embedding"] )["data"][0]["embedding"]
else: else:
return openai.Embedding.create(input=[text], model="text-embedding-ada-002")[ return openai.Embedding.create(
"data" input=[text], model="text-embedding-ada-002"
][0]["embedding"] )["data"][0]["embedding"]
except RateLimitError: except RateLimitError:
pass pass
except APIError as e: except APIError as e:
@@ -148,4 +151,3 @@ def create_embedding_with_ada(text) -> list:
f"API Bad gateway. Waiting {backoff} seconds..." + Fore.RESET, f"API Bad gateway. Waiting {backoff} seconds..." + Fore.RESET,
) )
time.sleep(backoff) time.sleep(backoff)

View File

@@ -18,11 +18,13 @@ class MemoryDB:
# As last resort, open in dynamic memory. Won't be persistent. # As last resort, open in dynamic memory. Won't be persistent.
self.db_file = ":memory:" self.db_file = ":memory:"
self.cnx = sqlite3.connect(self.db_file) self.cnx = sqlite3.connect(self.db_file)
self.cnx.execute("CREATE VIRTUAL TABLE \ self.cnx.execute(
"CREATE VIRTUAL TABLE \
IF NOT EXISTS text USING FTS5 \ IF NOT EXISTS text USING FTS5 \
(session, \ (session, \
key, \ key, \
block);") block);"
)
self.session_id = int(self.get_max_session_id()) + 1 self.session_id = int(self.get_max_session_id()) + 1
self.cnx.commit() self.cnx.commit()
@@ -66,7 +68,7 @@ class MemoryDB:
cnx = self.get_cnx() cnx = self.get_cnx()
cnx.execute(cmd_str, (session_id, key, text)) cnx.execute(cmd_str, (session_id, key, text))
cnx.commit() cnx.commit()
# Overwrite text at key. # Overwrite text at key.
def overwrite(self, key, text): def overwrite(self, key, text):
self.delete_memory(key) self.delete_memory(key)
@@ -76,8 +78,8 @@ class MemoryDB:
cnx = self.get_cnx() cnx = self.get_cnx()
cnx.execute(cmd_str, (session_id, key, text)) cnx.execute(cmd_str, (session_id, key, text))
cnx.commit() cnx.commit()
def delete_memory(self, key, session_id = None): def delete_memory(self, key, session_id=None):
session = session_id session = session_id
if session is None: if session is None:
session = self.session_id session = self.session_id