Make the json_parser more robust

For some reason the bot keeps prefacing its JSON. This fixes it for now.
This commit is contained in:
Taylor Brown
2023-04-02 18:50:51 -05:00
parent a47da497b5
commit 3e587bc7fb
5 changed files with 160 additions and 58 deletions

View File

@@ -3,46 +3,23 @@ import json
import openai
import dirtyjson
from config import Config
from call_ai_function import call_ai_function
from json_parser import fix_and_parse_json
cfg = Config()
# This is a magic function that can do anything with no-code. See
# https://github.com/Torantulino/AI-Functions for more info.
def call_ai_function(function, args, description, model=cfg.smart_llm_model):
    """Have the chat model impersonate a Python function and return its value.

    `function` is the function's signature, `args` the argument values, and
    `description` a one-line comment describing what the function should do.
    The raw text of the model's reply is returned.
    """
    # Render each argument as text; a missing value becomes the string "None".
    rendered = ["None" if arg is None else str(arg) for arg in args]
    # Join into a comma separated argument string for the user message.
    arg_string = ", ".join(rendered)

    system_prompt = (
        f"You are now the following python function: ```# {description}\n"
        f"{function}```\n\nOnly respond with your `return` value."
    )
    response = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": arg_string},
        ],
        temperature=0,
    )
    return response.choices[0].message["content"]
# Evaluating code
# Evaluating code
def evaluate_code(code: str) -> List[str]:
    """Return a list of improvement suggestions for the given source code.

    Delegates the analysis to the LLM via `call_ai_function`; the model's
    reply is expected to be a JSON list of suggestion strings, which is
    parsed (and repaired if necessary) by `fix_and_parse_json`.
    """
    function_string = "def analyze_code(code: str) -> List[str]:"
    args = [code]
    description_string = """Analyzes the given code and returns a list of suggestions for improvements."""
    result_string = call_ai_function(function_string, args, description_string)
    # `fix_and_parse_json` is a function, not a module: call it directly.
    # (The previous body had a leftover `return json.loads(...)` above an
    # unreachable `fix_and_parse_json.loads(...)` that would have raised
    # AttributeError had it ever run.)
    return fix_and_parse_json(result_string)
# Improving code
def improve_code(suggestions: List[str], code: str) -> str:
function_string = (
"def generate_improved_code(suggestions: List[str], code: str) -> str:"
@@ -68,28 +45,3 @@ def write_tests(code: str, focus: List[str]) -> str:
return result_string
# TODO: Make debug a global config var
def fix_json(json_str: str, schema: str = None, debug=True) -> str:
    """Ask the fast LLM to repair an unparseable JSON string.

    If `schema` is given the model is asked to shape the JSON like the
    schema. Returns the parsed object on success, or {} if the repaired
    string still fails to parse — best-effort, never raises.
    """
    function_string = "def fix_json(json_str: str, schema:str=None) -> str:"
    args = [json_str, schema]
    description_string = """Fixes the provided JSON string to make it parseable. If the schema is provided, the JSON will be made to look like the schema, otherwise it will be made to look like a valid JSON object."""
    result_string = call_ai_function(
        function_string, args, description_string, model=cfg.fast_llm_model
    )
    if debug:
        print("------------ JSON FIX ATTEMPT ---------------")
        print(f"Original JSON: {json_str}")
        print(f"Fixed JSON: {result_string}")
        print("----------- END OF FIX ATTEMPT ----------------")
    try:
        return dirtyjson.loads(result_string)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # not swallowed; still best-effort — log the traceback and return {}.
        import traceback
        print("Failed to fix JSON")
        print(traceback.format_exc())
        return {}

View File

@@ -0,0 +1,27 @@
from typing import List, Optional
import json
import openai
import dirtyjson
from config import Config
cfg = Config()
# This is a magic function that can do anything with no-code. See
# https://github.com/Torantulino/AI-Functions for more info.
def call_ai_function(function, args, description, model=cfg.smart_llm_model):
    """Run `function` (given as a signature string) via the chat model.

    The model is instructed to behave as the described Python function and
    reply with only its return value; that reply text is returned verbatim.
    """
    # Stringify the arguments, mapping None to the literal text "None",
    # then build a comma separated argument list for the user turn.
    arg_string = ", ".join(
        str(arg) if arg is not None else "None" for arg in args
    )
    messages = [
        {
            "role": "system",
            "content": (
                f"You are now the following python function: ```# {description}\n"
                f"{function}```\n\nOnly respond with your `return` value."
            ),
        },
        {"role": "user", "content": arg_string},
    ]
    # temperature=0 keeps the "function" as deterministic as the API allows.
    response = openai.ChatCompletion.create(
        model=model, messages=messages, temperature=0
    )
    return response.choices[0].message["content"]

View File

@@ -8,12 +8,13 @@ from config import Config
import ai_functions as ai
from file_operations import read_file, write_to_file, append_to_file, delete_file
from execute_code import execute_python_file
from json_parser import fix_and_parse_json
cfg = Config()
def get_command(response):
try:
response_json = json.loads(response)
response_json = fix_and_parse_json(response)
command = response_json["command"]
command_name = command["name"]
arguments = command["args"]

View File

@@ -1,12 +1,53 @@
import dirtyjson
from ai_functions import fix_json
from call_ai_function import call_ai_function
from config import Config
cfg = Config()
def fix_and_parse_json(json_str: str, try_to_fix_with_gpt: bool = True):
    """Parse `json_str`, attempting increasingly aggressive repairs.

    Order of attempts:
      1. Parse the string as-is with dirtyjson.
      2. Strip any chatter outside the outermost {...} braces — the model
         sometimes prefaces its JSON with prose — and parse again.
      3. If `try_to_fix_with_gpt`, hand the string to the GPT-based
         `fix_json` helper as a last resort; otherwise re-raise.
    """
    try:
        return dirtyjson.loads(json_str)
    except Exception:
        # Fall through to the manual brace-trimming repair below.
        # NOTE: in the previous version this fallback sat after an
        # unconditional return/raise and was unreachable dead code.
        pass
    # Sometimes GPT responds with text BEFORE the braces, e.g.
    # "I'm sorry, ..."{"text": "...", "confidence": 0.0}
    # so keep only the span from the first "{" to the last "}".
    try:
        brace_index = json_str.index("{")
        json_str = json_str[brace_index:]
        last_brace_index = json_str.rindex("}")
        json_str = json_str[: last_brace_index + 1]
        return dirtyjson.loads(json_str)
    except Exception as e:
        if try_to_fix_with_gpt:
            # Now try to fix this up using the ai_functions
            return fix_json(json_str, None, True)
        else:
            raise e
# TODO: Make debug a global config var
def fix_json(json_str: str, schema: str = None, debug=True) -> str:
    """Ask the fast LLM to repair an unparseable JSON string.

    The JSON is wrapped in a ```json fenced block (to keep the model from
    treating it as prose) before being sent for repair. Returns the parsed
    object on success, or {} if the repaired string still fails to parse —
    best-effort, never raises.
    """
    function_string = "def fix_json(json_str: str, schema:str=None) -> str:"
    # Fence the JSON BEFORE building the argument list — previously the
    # fencing happened after `args` captured the raw string, so the fenced
    # version never reached the model, only the debug print.
    if not json_str.startswith("`"):
        json_str = "```json\n" + json_str + "\n```"
    args = [json_str, schema]
    description_string = """Fixes the provided JSON string to make it parseable. If the schema is provided, the JSON will be made to look like the schema, otherwise it will be made to look like a valid JSON object."""
    result_string = call_ai_function(
        function_string, args, description_string, model=cfg.fast_llm_model
    )
    if debug:
        print("------------ JSON FIX ATTEMPT ---------------")
        print(f"Original JSON: {json_str}")
        print(f"Fixed JSON: {result_string}")
        print("----------- END OF FIX ATTEMPT ----------------")
    try:
        return dirtyjson.loads(result_string)
    except Exception:
        # Narrowed from a bare `except:`; log the traceback and fall back
        # to an empty dict rather than propagating.
        import traceback
        print("Failed to fix JSON")
        print(traceback.format_exc())
        return {}