mirror of
https://github.com/aljazceru/gpt-engineer.git
synced 2025-12-17 12:45:26 +01:00
Add types
This commit is contained in:
@@ -12,16 +12,13 @@ from gpt_engineer.chat_to_files import to_files
|
||||
from gpt_engineer.db import DBs
|
||||
|
||||
|
||||
def setup_sys_prompt(dbs):
|
||||
def setup_sys_prompt(dbs: DBs) -> str:
|
||||
return (
|
||||
dbs.preprompts["generate"] + "\nUseful to know:\n" + dbs.preprompts["philosophy"]
|
||||
)
|
||||
|
||||
|
||||
Step = TypeVar("Step", bound=Callable[[AI, DBs], List[dict]])
|
||||
|
||||
|
||||
def get_prompt(dbs):
|
||||
def get_prompt(dbs: DBs) -> str:
|
||||
"""While we migrate we have this fallback getter"""
|
||||
assert (
|
||||
"prompt" in dbs.input or "main_prompt" in dbs.input
|
||||
@@ -37,14 +34,18 @@ def get_prompt(dbs):
|
||||
return dbs.input["prompt"]
|
||||
|
||||
|
||||
def simple_gen(ai: AI, dbs: DBs):
|
||||
# All steps below have this signature
|
||||
Step = TypeVar("Step", bound=Callable[[AI, DBs], List[dict]])
|
||||
|
||||
|
||||
def simple_gen(ai: AI, dbs: DBs) -> List[dict]:
|
||||
"""Run the AI on the main prompt and save the results"""
|
||||
messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs))
|
||||
to_files(messages[-1]["content"], dbs.workspace)
|
||||
return messages
|
||||
|
||||
|
||||
def clarify(ai: AI, dbs: DBs):
|
||||
def clarify(ai: AI, dbs: DBs) -> List[dict]:
|
||||
"""
|
||||
Ask the user if they want to clarify anything and save the results to the workspace
|
||||
"""
|
||||
@@ -83,7 +84,7 @@ def clarify(ai: AI, dbs: DBs):
|
||||
return messages
|
||||
|
||||
|
||||
def gen_spec(ai: AI, dbs: DBs):
|
||||
def gen_spec(ai: AI, dbs: DBs) -> List[dict]:
|
||||
"""
|
||||
Generate a spec from the main prompt + clarifications and save the results to
|
||||
the workspace
|
||||
@@ -100,7 +101,7 @@ def gen_spec(ai: AI, dbs: DBs):
|
||||
return messages
|
||||
|
||||
|
||||
def respec(ai: AI, dbs: DBs):
|
||||
def respec(ai: AI, dbs: DBs) -> List[dict]:
|
||||
messages = json.loads(dbs.logs[gen_spec.__name__])
|
||||
messages += [ai.fsystem(dbs.preprompts["respec"])]
|
||||
|
||||
@@ -121,7 +122,7 @@ def respec(ai: AI, dbs: DBs):
|
||||
return messages
|
||||
|
||||
|
||||
def gen_unit_tests(ai: AI, dbs: DBs):
|
||||
def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
|
||||
"""
|
||||
Generate unit tests based on the specification, that should work.
|
||||
"""
|
||||
@@ -139,8 +140,8 @@ def gen_unit_tests(ai: AI, dbs: DBs):
|
||||
return messages
|
||||
|
||||
|
||||
def gen_clarified_code(ai: AI, dbs: DBs):
|
||||
# get the messages from previous step
|
||||
def gen_clarified_code(ai: AI, dbs: DBs) -> List[dict]:
|
||||
"""Takes clarification and generates code"""
|
||||
|
||||
messages = json.loads(dbs.logs[clarify.__name__])
|
||||
|
||||
@@ -153,7 +154,7 @@ def gen_clarified_code(ai: AI, dbs: DBs):
|
||||
return messages
|
||||
|
||||
|
||||
def gen_code(ai: AI, dbs: DBs):
|
||||
def gen_code(ai: AI, dbs: DBs) -> List[dict]:
|
||||
# get the messages from previous step
|
||||
|
||||
messages = [
|
||||
@@ -167,7 +168,7 @@ def gen_code(ai: AI, dbs: DBs):
|
||||
return messages
|
||||
|
||||
|
||||
def execute_entrypoint(ai, dbs):
|
||||
def execute_entrypoint(ai: AI, dbs: DBs) -> List[dict]:
|
||||
command = dbs.workspace["run.sh"]
|
||||
|
||||
print("Do you want to execute this code?")
|
||||
|
||||
Reference in New Issue
Block a user