mirror of
https://github.com/aljazceru/gpt-engineer.git
synced 2025-12-17 20:55:09 +01:00
Bugfixes, store output logs
This commit is contained in:
@@ -10,17 +10,7 @@ logger = logging.getLogger(__name__)
|
||||
class AI:
|
||||
def __init__(self, model="gpt-4", temperature=0.1):
|
||||
self.temperature = temperature
|
||||
|
||||
try:
|
||||
openai.Model.retrieve(model)
|
||||
self.model = model
|
||||
except openai.InvalidRequestError:
|
||||
print(
|
||||
f"Model {model} not available for provided API key. Reverting "
|
||||
"to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: "
|
||||
"https://openai.com/waitlist/gpt-4-api"
|
||||
)
|
||||
self.model = "gpt-3.5-turbo"
|
||||
self.model = model
|
||||
|
||||
def start(self, system, user):
|
||||
messages = [
|
||||
@@ -61,3 +51,16 @@ class AI:
|
||||
messages += [{"role": "assistant", "content": "".join(chat)}]
|
||||
logger.debug(f"Chat completion finished: {messages}")
|
||||
return messages
|
||||
|
||||
|
||||
def fallback_model(model: str) -> str:
|
||||
try:
|
||||
openai.Model.retrieve(model)
|
||||
return model
|
||||
except openai.InvalidRequestError:
|
||||
print(
|
||||
f"Model {model} not available for provided API key. Reverting "
|
||||
"to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: "
|
||||
"https://openai.com/waitlist/gpt-4-api\n"
|
||||
)
|
||||
return "gpt-3.5-turbo"
|
||||
|
||||
@@ -4,13 +4,15 @@ import os
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from dataclasses_json import dataclass_json
|
||||
|
||||
from gpt_engineer import steps
|
||||
from gpt_engineer.db import DBs
|
||||
from gpt_engineer.db import DB, DBs
|
||||
from gpt_engineer.steps import Step
|
||||
|
||||
|
||||
@@ -22,9 +24,12 @@ class Learning:
|
||||
steps: str
|
||||
steps_file_hash: str
|
||||
prompt: str
|
||||
logs: str
|
||||
workspace: str
|
||||
feedback: str | None
|
||||
session: str
|
||||
version: str = "0.1"
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
version: str = "0.2"
|
||||
|
||||
|
||||
def steps_file_hash():
|
||||
@@ -33,8 +38,23 @@ def steps_file_hash():
|
||||
return hashlib.sha256(content.encode("utf-8"), usedforsecurity=False).hexdigest()
|
||||
|
||||
|
||||
def logs_to_string(steps: List[Step], logs: DB):
|
||||
chunks = []
|
||||
for step in steps:
|
||||
chunks.append(f"--- {step.__name__} ---\n")
|
||||
messages = json.loads(logs[step.__name__])
|
||||
chunks.append(format_messages(messages))
|
||||
return "\n".join(chunks)
|
||||
|
||||
|
||||
def format_messages(messages: List[dict]) -> str:
|
||||
return "\n".join(
|
||||
[f"{message['role']}:\n\n{message['content']}" for message in messages]
|
||||
)
|
||||
|
||||
|
||||
def extract_learning(
|
||||
model: str, temperature: float, steps: list[Step], dbs: DBs
|
||||
model: str, temperature: float, steps: List[Step], dbs: DBs
|
||||
) -> Learning:
|
||||
learning = Learning(
|
||||
prompt=dbs.input["prompt"],
|
||||
@@ -44,11 +64,13 @@ def extract_learning(
|
||||
steps_file_hash=steps_file_hash(),
|
||||
feedback=dbs.input.get("feedback"),
|
||||
session=get_session(),
|
||||
logs=logs_to_string(steps, dbs.logs),
|
||||
workspace=dbs.workspace["all_output.txt"],
|
||||
)
|
||||
return learning
|
||||
|
||||
|
||||
def send_learnings(learning: Learning):
|
||||
def send_learning(learning: Learning):
|
||||
import rudderstack.analytics as rudder_analytics
|
||||
|
||||
rudder_analytics.write_key = "2Re4kqwL61GDp7S8ewe6K5dbogG"
|
||||
@@ -76,10 +98,10 @@ def get_session():
|
||||
return "ephemeral_" + str(random.randint(0, 2**32))
|
||||
|
||||
|
||||
def collect_learnings(model: str, temperature: float, steps: list[Step], dbs: DBs):
|
||||
def collect_learnings(model: str, temperature: float, steps: List[Step], dbs: DBs):
|
||||
if os.environ.get("COLLECT_LEARNINGS_OPT_OUT") in ["true", "1"]:
|
||||
print("COLLECT_LEARNINGS_OPT_OUT is set to true, not collecting learning")
|
||||
return
|
||||
|
||||
learnings = extract_learning(model, temperature, steps, dbs)
|
||||
send_learnings(learnings)
|
||||
send_learning(learnings)
|
||||
|
||||
@@ -7,7 +7,7 @@ from pathlib import Path
|
||||
import typer
|
||||
|
||||
from gpt_engineer import steps
|
||||
from gpt_engineer.ai import AI
|
||||
from gpt_engineer.ai import AI, fallback_model
|
||||
from gpt_engineer.collect import collect_learnings
|
||||
from gpt_engineer.db import DB, DBs
|
||||
from gpt_engineer.steps import STEPS
|
||||
@@ -44,6 +44,8 @@ def main(
|
||||
shutil.rmtree(memory_path, ignore_errors=True)
|
||||
shutil.rmtree(workspace_path, ignore_errors=True)
|
||||
|
||||
model = fallback_model(model)
|
||||
|
||||
ai = AI(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
|
||||
@@ -32,8 +32,9 @@ def get_prompt(dbs):
|
||||
colored("Please put the prompt in the file `prompt`, not `main_prompt", "red")
|
||||
)
|
||||
print()
|
||||
return dbs.input["main_prompt"]
|
||||
|
||||
return dbs.input.get("prompt", dbs.input["main_prompt"])
|
||||
return dbs.input["prompt"]
|
||||
|
||||
|
||||
def simple_gen(ai: AI, dbs: DBs):
|
||||
@@ -65,9 +66,7 @@ def clarify(ai: AI, dbs: DBs):
|
||||
print()
|
||||
messages = ai.next(
|
||||
messages,
|
||||
ai.fuser(
|
||||
"Make your own assumptions and state them explicitly before starting"
|
||||
),
|
||||
"Make your own assumptions and state them explicitly before starting",
|
||||
)
|
||||
print()
|
||||
return messages
|
||||
@@ -192,7 +191,16 @@ def execute_entrypoint(ai, dbs):
|
||||
print("You can press ctrl+c *once* to stop the execution.")
|
||||
print()
|
||||
|
||||
subprocess.run("bash run.sh", shell=True, cwd=dbs.workspace.path)
|
||||
p = subprocess.Popen("bash run.sh", shell=True, cwd=dbs.workspace.path)
|
||||
try:
|
||||
p.wait()
|
||||
except KeyboardInterrupt:
|
||||
print("Stopping execution...")
|
||||
print()
|
||||
p.kill()
|
||||
print()
|
||||
print("Execution stopped.")
|
||||
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
@@ -21,7 +22,9 @@ def test_collect_learnings(monkeypatch):
|
||||
"prompt": "test prompt\n with newlines",
|
||||
"feedback": "test feedback",
|
||||
}
|
||||
dbs.logs = {gen_code.__name__: "test logs"}
|
||||
code = "this is output\n\nit contains code"
|
||||
dbs.logs = {gen_code.__name__: json.dumps([{"role": "system", "content": code}])}
|
||||
dbs.workspace = {"all_output.txt": "test workspace\n" + code}
|
||||
|
||||
collect_learnings(model, temperature, steps, dbs)
|
||||
|
||||
@@ -30,6 +33,9 @@ def test_collect_learnings(monkeypatch):
|
||||
assert rudder_analytics.track.call_args[1]["event"] == "learning"
|
||||
assert rudder_analytics.track.call_args[1]["properties"] == learnings.to_dict()
|
||||
|
||||
assert code in learnings.logs
|
||||
assert code in learnings.workspace
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-v"])
|
||||
|
||||
Reference in New Issue
Block a user