Langchain integration (#512)

* Added LangChain integration

* Fixed issue created by git checkin process

* Added ':' to characters to remove from end of file path

* Tested initial migration to LangChain, removed comments and logging used for debugging

* Tested initial migration to LangChain, removed comments and logging used for debugging

* Converted camelCase to snake_case

* Turns out we need the exception handling

* Testing Hugging Face Integrations via LangChain

* Added LangChain loadable models

* Renames "qa" prompt to "clarify", since it's used in the "clarify" step, asking for clarification

* Fixed loading model yaml files

* Fixed streaming

* Added modeldir cli option

* Fixed typing

* Fixed interaction with token logging

* Fix spelling + dependency issues + typing

* Fix spelling + tests

* Removed unneeded logging which caused test to fail

* Cleaned up code

* Incorporated feedback

- deleted unnecessary functions & logger.info
- used LangChain ChatLLM instead of LLM to naturally communicate with gpt-4
- deleted loading model from yaml file, as LangChain doesn't offer this for ChatModels

* Update gpt_engineer/steps.py

Co-authored-by: Anton Osika <anton.osika@gmail.com>

* Incorporated feedback

- Fixed failing test
- Removed parsing complexity by using # type: ignore
- Replace every occurrence of ai.last_message_content with its content

* Fixed test

* Update gpt_engineer/steps.py

---------

Co-authored-by: H <holden.robbins@gmail.com>
Co-authored-by: Anton Osika <anton.osika@gmail.com>
This commit is contained in:
UmerHA
2023-07-23 23:30:09 +02:00
committed by GitHub
parent 07ba335ecf
commit 19a4c10b6e
9 changed files with 132 additions and 92 deletions

View File

@@ -1,11 +1,11 @@
import inspect
import json
import re
import subprocess
from enum import Enum
from typing import List
from typing import List, Union
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from termcolor import colored
from gpt_engineer.ai import AI
@@ -13,6 +13,8 @@ from gpt_engineer.chat_to_files import to_files
from gpt_engineer.db import DBs
from gpt_engineer.learning import human_input
Message = Union[AIMessage, HumanMessage, SystemMessage]
def setup_sys_prompt(dbs: DBs) -> str:
return (
@@ -44,26 +46,27 @@ def curr_fn() -> str:
# All steps below have the signature Step
def simple_gen(ai: AI, dbs: DBs) -> List[dict]:
def simple_gen(ai: AI, dbs: DBs) -> List[Message]:
"""Run the AI on the main prompt and save the results"""
messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs), step_name=curr_fn())
to_files(messages[-1]["content"], dbs.workspace)
to_files(messages[-1].content.strip(), dbs.workspace)
return messages
def clarify(ai: AI, dbs: DBs) -> List[dict]:
def clarify(ai: AI, dbs: DBs) -> List[Message]:
"""
Ask the user if they want to clarify anything and save the results to the workspace
"""
messages = [ai.fsystem(dbs.preprompts["qa"])]
messages: List[Message] = [ai.fsystem(dbs.preprompts["clarify"])]
user_input = get_prompt(dbs)
while True:
messages = ai.next(messages, user_input, step_name=curr_fn())
msg = messages[-1].content.strip()
if messages[-1]["content"].strip() == "Nothing more to clarify.":
if msg == "Nothing more to clarify.":
break
if messages[-1]["content"].strip().lower().startswith("no"):
if msg.lower().startswith("no"):
print("Nothing more to clarify.")
break
@@ -94,7 +97,7 @@ def clarify(ai: AI, dbs: DBs) -> List[dict]:
return messages
def gen_spec(ai: AI, dbs: DBs) -> List[dict]:
def gen_spec(ai: AI, dbs: DBs) -> List[Message]:
"""
Generate a spec from the main prompt + clarifications and save the results to
the workspace
@@ -106,13 +109,13 @@ def gen_spec(ai: AI, dbs: DBs) -> List[dict]:
messages = ai.next(messages, dbs.preprompts["spec"], step_name=curr_fn())
dbs.memory["specification"] = messages[-1]["content"]
dbs.memory["specification"] = messages[-1].content.strip()
return messages
def respec(ai: AI, dbs: DBs) -> List[dict]:
messages = json.loads(dbs.logs[gen_spec.__name__])
def respec(ai: AI, dbs: DBs) -> List[Message]:
messages = AI.deserialize_messages(dbs.logs[gen_spec.__name__])
messages += [ai.fsystem(dbs.preprompts["respec"])]
messages = ai.next(messages, step_name=curr_fn())
@@ -129,7 +132,7 @@ def respec(ai: AI, dbs: DBs) -> List[dict]:
step_name=curr_fn(),
)
dbs.memory["specification"] = messages[-1]["content"]
dbs.memory["specification"] = messages[-1].content.strip()
return messages
@@ -145,7 +148,7 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
messages = ai.next(messages, dbs.preprompts["unit_tests"], step_name=curr_fn())
dbs.memory["unit_tests"] = messages[-1]["content"]
dbs.memory["unit_tests"] = messages[-1].content.strip()
to_files(dbs.memory["unit_tests"], dbs.workspace)
return messages
@@ -153,14 +156,14 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
def gen_clarified_code(ai: AI, dbs: DBs) -> List[dict]:
"""Takes clarification and generates code"""
messages = json.loads(dbs.logs[clarify.__name__])
messages = AI.deserialize_messages(dbs.logs[clarify.__name__])
messages = [
ai.fsystem(setup_sys_prompt(dbs)),
] + messages[1:]
messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
to_files(messages[-1]["content"], dbs.workspace)
to_files(messages[-1].content.strip(), dbs.workspace)
return messages
@@ -173,7 +176,7 @@ def gen_code(ai: AI, dbs: DBs) -> List[dict]:
ai.fuser(f"Unit tests:\n\n{dbs.memory['unit_tests']}"),
]
messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
to_files(messages[-1]["content"], dbs.workspace)
to_files(messages[-1].content.strip(), dbs.workspace)
return messages
@@ -235,7 +238,7 @@ def gen_entrypoint(ai: AI, dbs: DBs) -> List[dict]:
print()
regex = r"```\S*\n(.+?)```"
matches = re.finditer(regex, messages[-1]["content"], re.DOTALL)
matches = re.finditer(regex, messages[-1].content.strip(), re.DOTALL)
dbs.workspace["run.sh"] = "\n".join(match.group(1) for match in matches)
return messages
@@ -248,12 +251,13 @@ def use_feedback(ai: AI, dbs: DBs):
ai.fsystem(dbs.preprompts["use_feedback"]),
]
messages = ai.next(messages, dbs.input["feedback"], step_name=curr_fn())
to_files(messages[-1]["content"], dbs.workspace)
to_files(messages[-1].content.strip(), dbs.workspace)
return messages
def fix_code(ai: AI, dbs: DBs):
code_output = json.loads(dbs.logs[gen_code.__name__])[-1]["content"]
messages = AI.deserialize_messages(dbs.logs[gen_code.__name__])
code_output = messages[-1].content.strip()
messages = [
ai.fsystem(setup_sys_prompt(dbs)),
ai.fuser(f"Instructions: {dbs.input['prompt']}"),
@@ -263,7 +267,7 @@ def fix_code(ai: AI, dbs: DBs):
messages = ai.next(
messages, "Please fix any errors in the code above.", step_name=curr_fn()
)
to_files(messages[-1]["content"], dbs.workspace)
to_files(messages[-1].content.strip(), dbs.workspace)
return messages