Implemented logging token usage (solves #322) (#438)

* Implemented logging token usage

Token usage is now tracked and logged into memory/logs/token_usage

* Step names are now inferred from function name

* Incorporated Anton's feedback

- Made LogUsage a dataclass
- For token logging, step name is now inferred via inspect module

* Formatted (black/ruff)

* Update gpt_engineer/ai.py

Co-authored-by: Anton Osika <anton.osika@gmail.com>

* formatting

---------

Co-authored-by: Anton Osika <anton.osika@gmail.com>
This commit is contained in:
UmerHA
2023-07-03 21:28:34 +02:00
committed by GitHub
parent 2b8e056d5d
commit 8fd315d264
4 changed files with 112 additions and 14 deletions

View File

@@ -1,3 +1,4 @@
import inspect
import json
import re
import subprocess
@@ -35,12 +36,17 @@ def get_prompt(dbs: DBs) -> str:
return dbs.input["prompt"]
def curr_fn() -> str:
"""Get the name of the current function"""
return inspect.stack()[1].function
# All steps below have the signature Step
def simple_gen(ai: AI, dbs: DBs) -> List[dict]:
"""Run the AI on the main prompt and save the results"""
messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs))
messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs), step_name=curr_fn())
to_files(messages[-1]["content"], dbs.workspace)
return messages
@@ -52,7 +58,7 @@ def clarify(ai: AI, dbs: DBs) -> List[dict]:
messages = [ai.fsystem(dbs.preprompts["qa"])]
user_input = get_prompt(dbs)
while True:
messages = ai.next(messages, user_input)
messages = ai.next(messages, user_input, step_name=curr_fn())
if messages[-1]["content"].strip() == "Nothing more to clarify.":
break
@@ -71,6 +77,7 @@ def clarify(ai: AI, dbs: DBs) -> List[dict]:
messages = ai.next(
messages,
"Make your own assumptions and state them explicitly before starting",
step_name=curr_fn(),
)
print()
return messages
@@ -97,7 +104,7 @@ def gen_spec(ai: AI, dbs: DBs) -> List[dict]:
ai.fsystem(f"Instructions: {dbs.input['prompt']}"),
]
messages = ai.next(messages, dbs.preprompts["spec"])
messages = ai.next(messages, dbs.preprompts["spec"], step_name=curr_fn())
dbs.memory["specification"] = messages[-1]["content"]
@@ -108,7 +115,7 @@ def respec(ai: AI, dbs: DBs) -> List[dict]:
messages = json.loads(dbs.logs[gen_spec.__name__])
messages += [ai.fsystem(dbs.preprompts["respec"])]
messages = ai.next(messages)
messages = ai.next(messages, step_name=curr_fn())
messages = ai.next(
messages,
(
@@ -119,6 +126,7 @@ def respec(ai: AI, dbs: DBs) -> List[dict]:
"If you are satisfied with the specification, just write out the "
"specification word by word again."
),
step_name=curr_fn(),
)
dbs.memory["specification"] = messages[-1]["content"]
@@ -135,7 +143,7 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
ai.fuser(f"Specification:\n\n{dbs.memory['specification']}"),
]
messages = ai.next(messages, dbs.preprompts["unit_tests"])
messages = ai.next(messages, dbs.preprompts["unit_tests"], step_name=curr_fn())
dbs.memory["unit_tests"] = messages[-1]["content"]
to_files(dbs.memory["unit_tests"], dbs.workspace)
@@ -145,13 +153,12 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
def gen_clarified_code(ai: AI, dbs: DBs) -> List[dict]:
"""Takes clarification and generates code"""
messages = json.loads(dbs.logs[clarify.__name__])
messages = [
ai.fsystem(setup_sys_prompt(dbs)),
] + messages[1:]
messages = ai.next(messages, dbs.preprompts["use_qa"])
messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
to_files(messages[-1]["content"], dbs.workspace)
return messages
@@ -159,14 +166,13 @@ def gen_clarified_code(ai: AI, dbs: DBs) -> List[dict]:
def gen_code(ai: AI, dbs: DBs) -> List[dict]:
# get the messages from previous step
messages = [
ai.fsystem(setup_sys_prompt(dbs)),
ai.fuser(f"Instructions: {dbs.input['prompt']}"),
ai.fuser(f"Specification:\n\n{dbs.memory['specification']}"),
ai.fuser(f"Unit tests:\n\n{dbs.memory['unit_tests']}"),
]
messages = ai.next(messages, dbs.preprompts["use_qa"])
messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
to_files(messages[-1]["content"], dbs.workspace)
return messages
@@ -224,6 +230,7 @@ def gen_entrypoint(ai: AI, dbs: DBs) -> List[dict]:
"if necessary.\n"
),
user="Information about the codebase:\n\n" + dbs.workspace["all_output.txt"],
step_name=curr_fn(),
)
print()
@@ -240,7 +247,7 @@ def use_feedback(ai: AI, dbs: DBs):
ai.fassistant(dbs.workspace["all_output.txt"]),
ai.fsystem(dbs.preprompts["use_feedback"]),
]
messages = ai.next(messages, dbs.input["feedback"])
messages = ai.next(messages, dbs.input["feedback"], step_name=curr_fn())
to_files(messages[-1]["content"], dbs.workspace)
return messages
@@ -253,7 +260,9 @@ def fix_code(ai: AI, dbs: DBs):
ai.fuser(code_output),
ai.fsystem(dbs.preprompts["fix_code"]),
]
messages = ai.next(messages, "Please fix any errors in the code above.")
messages = ai.next(
messages, "Please fix any errors in the code above.", step_name=curr_fn()
)
to_files(messages[-1]["content"], dbs.workspace)
return messages