diff --git a/gpt_engineer/ai.py b/gpt_engineer/ai.py index d837a07..ef9b65b 100644 --- a/gpt_engineer/ai.py +++ b/gpt_engineer/ai.py @@ -10,17 +10,7 @@ logger = logging.getLogger(__name__) class AI: def __init__(self, model="gpt-4", temperature=0.1): self.temperature = temperature - - try: - openai.Model.retrieve(model) - self.model = model - except openai.InvalidRequestError: - print( - f"Model {model} not available for provided API key. Reverting " - "to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: " - "https://openai.com/waitlist/gpt-4-api" - ) - self.model = "gpt-3.5-turbo" + self.model = model def start(self, system, user): messages = [ @@ -61,3 +51,16 @@ class AI: messages += [{"role": "assistant", "content": "".join(chat)}] logger.debug(f"Chat completion finished: {messages}") return messages + + +def fallback_model(model: str) -> str: + try: + openai.Model.retrieve(model) + return model + except openai.InvalidRequestError: + print( + f"Model {model} not available for provided API key. Reverting " + "to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: " + "https://openai.com/waitlist/gpt-4-api\n" + ) + return "gpt-3.5-turbo" diff --git a/gpt_engineer/collect.py b/gpt_engineer/collect.py index 491eef7..090ed0d 100644 --- a/gpt_engineer/collect.py +++ b/gpt_engineer/collect.py @@ -4,13 +4,15 @@ import os import random import tempfile -from dataclasses import dataclass +from dataclasses import dataclass, field +from datetime import datetime from pathlib import Path +from typing import List from dataclasses_json import dataclass_json from gpt_engineer import steps -from gpt_engineer.db import DBs +from gpt_engineer.db import DB, DBs from gpt_engineer.steps import Step @@ -22,9 +24,12 @@ class Learning: steps: str steps_file_hash: str prompt: str + logs: str + workspace: str feedback: str | None session: str - version: str = "0.1" + timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) + version: str = "0.2" def steps_file_hash(): @@ -33,8 +38,23 @@ def steps_file_hash(): return hashlib.sha256(content.encode("utf-8"), usedforsecurity=False).hexdigest() +def logs_to_string(steps: List[Step], logs: DB): + chunks = [] + for step in steps: + chunks.append(f"--- {step.__name__} ---\n") + messages = json.loads(logs[step.__name__]) + chunks.append(format_messages(messages)) + return "\n".join(chunks) + + +def format_messages(messages: List[dict]) -> str: + return "\n".join( + [f"{message['role']}:\n\n{message['content']}" for message in messages] + ) + + def extract_learning( - model: str, temperature: float, steps: list[Step], dbs: DBs + model: str, temperature: float, steps: List[Step], dbs: DBs ) -> Learning: learning = Learning( prompt=dbs.input["prompt"], @@ -44,11 +64,13 @@ def extract_learning( steps_file_hash=steps_file_hash(), feedback=dbs.input.get("feedback"), session=get_session(), + logs=logs_to_string(steps, dbs.logs), + workspace=dbs.workspace["all_output.txt"], ) return learning -def send_learnings(learning: Learning): +def send_learning(learning: Learning): import rudderstack.analytics as rudder_analytics rudder_analytics.write_key = "2Re4kqwL61GDp7S8ewe6K5dbogG" @@ -76,10 +98,10 @@ def get_session(): return "ephemeral_" + str(random.randint(0, 2**32)) -def collect_learnings(model: str, temperature: float, steps: list[Step], dbs: DBs): +def collect_learnings(model: str, temperature: float, steps: List[Step], dbs: DBs): if os.environ.get("COLLECT_LEARNINGS_OPT_OUT") in ["true", "1"]: print("COLLECT_LEARNINGS_OPT_OUT is set to true, not collecting learning") return learnings = extract_learning(model, temperature, steps, dbs) - send_learnings(learnings) + send_learning(learnings) diff --git a/gpt_engineer/main.py b/gpt_engineer/main.py index 934d80b..7e78ae1 100644 --- a/gpt_engineer/main.py +++ b/gpt_engineer/main.py @@ -7,7 +7,7 @@ from pathlib import Path import typer from gpt_engineer import steps -from gpt_engineer.ai import AI +from gpt_engineer.ai import AI, fallback_model from gpt_engineer.collect import collect_learnings from gpt_engineer.db import DB, DBs from gpt_engineer.steps import STEPS @@ -44,6 +44,8 @@ def main( shutil.rmtree(memory_path, ignore_errors=True) shutil.rmtree(workspace_path, ignore_errors=True) + model = fallback_model(model) + ai = AI( model=model, temperature=temperature, diff --git a/gpt_engineer/steps.py b/gpt_engineer/steps.py index 10d144e..ec41f3c 100644 --- a/gpt_engineer/steps.py +++ b/gpt_engineer/steps.py @@ -32,8 +32,9 @@ def get_prompt(dbs): colored("Please put the prompt in the file `prompt`, not `main_prompt", "red") ) print() + return dbs.input["main_prompt"] - return dbs.input.get("prompt", dbs.input["main_prompt"]) + return dbs.input["prompt"] def simple_gen(ai: AI, dbs: DBs): @@ -65,9 +66,7 @@ def clarify(ai: AI, dbs: DBs): print() messages = ai.next( messages, - ai.fuser( - "Make your own assumptions and state them explicitly before starting" - ), + "Make your own assumptions and state them explicitly before starting", ) print() return messages @@ -192,7 +191,16 @@ def execute_entrypoint(ai, dbs): print("You can press ctrl+c *once* to stop the execution.") print() - subprocess.run("bash run.sh", shell=True, cwd=dbs.workspace.path) + p = subprocess.Popen("bash run.sh", shell=True, cwd=dbs.workspace.path) + try: + p.wait() + except KeyboardInterrupt: + print("Stopping execution...") + print() + p.kill() + print() + print("Execution stopped.") + return [] diff --git a/tests/test_collect.py b/tests/test_collect.py index aba69eb..76fabde 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -1,3 +1,4 @@ +import json import os from unittest.mock import MagicMock @@ -21,7 +22,9 @@ def test_collect_learnings(monkeypatch): "prompt": "test prompt\n with newlines", "feedback": "test feedback", } - dbs.logs = {gen_code.__name__: "test logs"} + code = "this is output\n\nit contains code" + dbs.logs = {gen_code.__name__: json.dumps([{"role": "system", "content": code}])} + dbs.workspace = {"all_output.txt": "test workspace\n" + code} collect_learnings(model, temperature, steps, dbs) @@ -30,6 +33,9 @@ def test_collect_learnings(monkeypatch): assert rudder_analytics.track.call_args[1]["event"] == "learning" assert rudder_analytics.track.call_args[1]["properties"] == learnings.to_dict() + assert code in learnings.logs + assert code in learnings.workspace + if __name__ == "__main__": pytest.main(["-v"])