Langchain integration (#512)

* Added LangChain integration

* Fixed issue created by git checkin process

* Added ':' to the characters stripped from the end of file paths

* Tested initial migration to LangChain, removed comments and logging used for debugging

* Converted camelCase to snake_case

* Turns out we need the exception handling

* Testing Hugging Face Integrations via LangChain

* Added LangChain loadable models

* Renamed the "qa" prompt to "clarify", since it's used in the "clarify" step to ask for clarification

* Fixed loading model yaml files

* Fixed streaming

* Added modeldir cli option

* Fixed typing

* Fixed interaction with token logging

* Fix spelling + dependency issues + typing

* Fix spelling + tests

* Removed unneeded logging which caused test to fail

* Cleaned up code

* Incorporated feedback

- Deleted unnecessary functions and logger.info calls
- Used a LangChain chat model (ChatOpenAI) instead of a plain LLM, to communicate naturally with gpt-4 (see the sketch below)
- Deleted loading models from YAML files, as LangChain doesn't offer this for chat models
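
A minimal sketch (not part of the commit) of what the chat-model interface looks like; it assumes langchain is installed, OPENAI_API_KEY is set, and the prompt strings are invented:

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

# A chat model consumes role-tagged message objects, matching gpt-4's
# system/user/assistant conversation format, instead of one flat prompt string.
chat = ChatOpenAI(model="gpt-4", temperature=0.1, streaming=True)
reply = chat(
    [
        SystemMessage(content="You are a helpful coding assistant."),
        HumanMessage(content="Write a one-line hello world in Python."),
    ],
    callbacks=[StreamingStdOutCallbackHandler()],  # prints tokens as they arrive
)
print()  # the complete reply also comes back as an AIMessage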

* Update gpt_engineer/steps.py

Co-authored-by: Anton Osika <anton.osika@gmail.com>

* Incorporated feedback

- Fixed failing test
- Removed parsing complexity by using # type: ignore
- Replaced every occurrence of ai.last_message_content with direct access to the message's content

* Fixed test

* Update gpt_engineer/steps.py

---------

Co-authored-by: H <holden.robbins@gmail.com>
Co-authored-by: Anton Osika <anton.osika@gmail.com>
commit 19a4c10b6e
parent 07ba335ecf
Author: UmerHA
Date: 2023-07-23 23:30:09 +02:00
Committed by: GitHub

9 changed files with 132 additions and 92 deletions

View File: gpt_engineer/ai.py

@@ -1,13 +1,27 @@
 from __future__ import annotations

+import json
 import logging

 from dataclasses import dataclass
-from typing import Dict, List
+from typing import List, Optional, Union

 import openai
 import tiktoken
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.chat_models import ChatOpenAI
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    messages_from_dict,
+    messages_to_dict,
+)
+
+Message = Union[AIMessage, HumanMessage, SystemMessage]

 logger = logging.getLogger(__name__)
@@ -23,9 +37,11 @@ class TokenUsage:
 class AI:
-    def __init__(self, model="gpt-4", temperature=0.1):
+    def __init__(self, model_name="gpt-4", temperature=0.1):
         self.temperature = temperature
-        self.model = model
+        self.model_name = fallback_model(model_name)
+        self.llm = create_chat_model(self.model_name, temperature)
+        self.tokenizer = get_tokenizer(self.model_name)

         # initialize token usage log
         self.cumulative_prompt_tokens = 0
@@ -33,62 +49,57 @@ class AI:
         self.cumulative_total_tokens = 0
         self.token_usage_log = []

-        try:
-            self.tokenizer = tiktoken.encoding_for_model(model)
-        except KeyError:
-            logger.debug(
-                f"Tiktoken encoder for model {model} not found. Using "
-                "cl100k_base encoder instead. The results may therefore be "
-                "inaccurate and should only be used as estimate."
-            )
-            self.tokenizer = tiktoken.get_encoding("cl100k_base")
-
-    def start(self, system, user, step_name):
-        messages = [
-            {"role": "system", "content": system},
-            {"role": "user", "content": user},
+    def start(self, system: str, user: str, step_name: str) -> List[Message]:
+        messages: List[Message] = [
+            SystemMessage(content=system),
+            HumanMessage(content=user),
         ]

         return self.next(messages, step_name=step_name)

-    def fsystem(self, msg):
-        return {"role": "system", "content": msg}
+    def fsystem(self, msg: str) -> SystemMessage:
+        return SystemMessage(content=msg)

-    def fuser(self, msg):
-        return {"role": "user", "content": msg}
+    def fuser(self, msg: str) -> HumanMessage:
+        return HumanMessage(content=msg)

-    def fassistant(self, msg):
-        return {"role": "assistant", "content": msg}
+    def fassistant(self, msg: str) -> AIMessage:
+        return AIMessage(content=msg)

-    def next(self, messages: List[Dict[str, str]], prompt=None, *, step_name=None):
+    def next(
+        self,
+        messages: List[Message],
+        prompt: Optional[str] = None,
+        *,
+        step_name: str,
+    ) -> List[Message]:
         if prompt:
-            messages += [{"role": "user", "content": prompt}]
+            messages.append(self.fuser(prompt))

         logger.debug(f"Creating a new chat completion: {messages}")
-        response = openai.ChatCompletion.create(
-            messages=messages,
-            stream=True,
-            model=self.model,
-            temperature=self.temperature,
-        )

-        chat = []
-        for chunk in response:
-            delta = chunk["choices"][0]["delta"]  # type: ignore
-            msg = delta.get("content", "")
-            print(msg, end="")
-            chat.append(msg)
-        print()
-        messages += [{"role": "assistant", "content": "".join(chat)}]
+        callsbacks = [StreamingStdOutCallbackHandler()]
+        response = self.llm(messages, callbacks=callsbacks)  # type: ignore
+        messages.append(response)

         logger.debug(f"Chat completion finished: {messages}")

         self.update_token_usage_log(
-            messages=messages, answer="".join(chat), step_name=step_name
+            messages=messages, answer=response.content, step_name=step_name
         )

         return messages

-    def update_token_usage_log(self, messages, answer, step_name):
+    @staticmethod
+    def serialize_messages(messages: List[Message]) -> str:
+        return json.dumps(messages_to_dict(messages))
+
+    @staticmethod
+    def deserialize_messages(jsondictstr: str) -> List[Message]:
+        return list(messages_from_dict(json.loads(jsondictstr)))  # type: ignore
+
+    def update_token_usage_log(
+        self, messages: List[Message], answer: str, step_name: str
+    ) -> None:
         prompt_tokens = self.num_tokens_from_messages(messages)
         completion_tokens = self.num_tokens(answer)
         total_tokens = prompt_tokens + completion_tokens
@@ -109,7 +120,7 @@ class AI:
             )
         )

-    def format_token_usage_log(self):
+    def format_token_usage_log(self) -> str:
         result = "step_name,"
         result += "prompt_tokens_in_step,completion_tokens_in_step,total_tokens_in_step"
         result += ",total_prompt_tokens,total_completion_tokens,total_tokens\n"
@@ -123,20 +134,17 @@ class AI:
         result += str(log.total_tokens) + "\n"
         return result

-    def num_tokens(self, txt):
+    def num_tokens(self, txt: str) -> int:
         return len(self.tokenizer.encode(txt))

-    def num_tokens_from_messages(self, messages):
+    def num_tokens_from_messages(self, messages: List[Message]) -> int:
         """Returns the number of tokens used by a list of messages."""
         n_tokens = 0
         for message in messages:
             n_tokens += (
                 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
             )
-            for key, value in message.items():
-                n_tokens += self.num_tokens(value)
-                if key == "name":  # if there's a name, the role is omitted
-                    n_tokens += -1  # role is always required and always 1 token
+            n_tokens += self.num_tokens(message.content)
         n_tokens += 2  # every reply is primed with <im_start>assistant
         return n_tokens
@@ -151,4 +159,39 @@ def fallback_model(model: str) -> str:
         "to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: "
         "https://openai.com/waitlist/gpt-4-api\n"
     )
-    return "gpt-3.5-turbo-16k"
+    return "gpt-3.5-turbo"
+
+
+def create_chat_model(model: str, temperature) -> BaseChatModel:
+    if model == "gpt-4":
+        return ChatOpenAI(
+            model="gpt-4",
+            temperature=temperature,
+            streaming=True,
+            client=openai.ChatCompletion,
+        )
+    elif model == "gpt-3.5-turbo":
+        return ChatOpenAI(
+            model="gpt-3.5-turbo",
+            temperature=temperature,
+            streaming=True,
+            client=openai.ChatCompletion,
+        )
+    else:
+        raise ValueError(f"Model {model} is not supported.")
+
+
+def get_tokenizer(model: str):
+    if "gpt-4" in model or "gpt-3.5" in model:
+        return tiktoken.encoding_for_model(model)
+
+    logger.debug(
+        f"No encoder implemented for model {model}."
+        "Defaulting to tiktoken cl100k_base encoder."
+        "Use results only as estimates."
+    )
+    return tiktoken.get_encoding("cl100k_base")
+
+
+def serialize_messages(messages: List[Message]) -> str:
+    return AI.serialize_messages(messages)
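
The new serialize/deserialize helpers are thin wrappers over LangChain's message (de)serialization, so the round trip the logs DB now relies on can be sketched offline (a minimal sketch, not from the commit; no API key needed):

import json

from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    messages_from_dict,
    messages_to_dict,
)

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hi"),
    AIMessage(content="Hello!"),
]

# What AI.serialize_messages writes into dbs.logs ...
blob = json.dumps(messages_to_dict(messages))
# ... and what AI.deserialize_messages reads back out.
restored = messages_from_dict(json.loads(blob))
assert [m.content for m in restored] == [m.content for m in messages]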

View File: gpt_engineer/chat_to_files.py

@@ -9,7 +9,7 @@ def parse_chat(chat):  # -> List[Tuple[str, str]]:
     files = []
     for match in matches:
         # Strip the filename of any non-allowed characters and convert / to \
-        path = re.sub(r'[<>"|?*]', "", match.group(1))
+        path = re.sub(r'[\:<>"|?*]', "", match.group(1))

         # Remove leading and trailing brackets
         path = re.sub(r"^\[(.*)\]$", r"\1", path)
@@ -18,7 +18,7 @@ def parse_chat(chat):  # -> List[Tuple[str, str]]:
         path = re.sub(r"^`(.*)`$", r"\1", path)

         # Remove trailing ]
-        path = re.sub(r"\]$", "", path)
+        path = re.sub(r"[\]\:]$", "", path)

         # Get the code
         code = match.group(2)
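
The net effect of the two regex changes is that a colon, which models often append after a filename, no longer survives path sanitization. A quick runnable check:

import re

path = "src/main.py:"
path = re.sub(r'[\:<>"|?*]', "", path)  # ':' is now in the strip set
assert path == "src/main.py"

path = re.sub(r"[\]\:]$", "", "src/main.py]")  # a trailing ']' or ':' is also removed
assert path == "src/main.py"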

View File: gpt_engineer/learning.py

@@ -100,11 +100,11 @@ def check_consent():
     path = Path(".gpte_consent")
     if path.exists() and path.read_text() == "true":
         return

-    ans = input("Is it ok if we store your prompts to learn? (y/n)")
-    while ans.lower() not in ("y", "n"):
-        ans = input("Invalid input. Please enter y or n: ")
-    if ans.lower() == "y":
+    answer = input("Is it ok if we store your prompts to learn? (y/n)")
+    while answer.lower() not in ("y", "n"):
+        answer = input("Invalid input. Please enter y or n: ")
+    if answer.lower() == "y":
         path.write_text("true")
         print(colored("Thank you", "light_green"))
         print()
@@ -153,21 +153,14 @@ def ask_if_can_store() -> bool:
     return can_store == "y"


-def logs_to_string(steps: List[Step], logs: DB):
+def logs_to_string(steps: List[Step], logs: DB) -> str:
     chunks = []
     for step in steps:
         chunks.append(f"--- {step.__name__} ---\n")
-        messages = json.loads(logs[step.__name__])
-        chunks.append(format_messages(messages))
+        chunks.append(logs[step.__name__])
     return "\n".join(chunks)


-def format_messages(messages: List[dict]) -> str:
-    return "\n".join(
-        [f"{message['role']}:\n\n{message['content']}" for message in messages]
-    )
-
-
 def extract_learning(
     model: str, temperature: float, steps: List[Step], dbs: DBs, steps_file_hash
 ) -> Learning:
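
Because dbs.logs now holds the JSON text produced by AI.serialize_messages, logs_to_string can emit the stored entry verbatim. A toy illustration with stand-ins for the Step and DB types (all names below are hypothetical):

import json
from typing import Callable, Dict, List

def logs_to_string_demo(steps: List[Callable], logs: Dict[str, str]) -> str:
    # Mirrors the simplified logs_to_string: each log entry is already a
    # readable JSON string, so no re-parsing or per-role formatting is needed.
    chunks = []
    for step in steps:
        chunks.append(f"--- {step.__name__} ---\n")
        chunks.append(logs[step.__name__])
    return "\n".join(chunks)

def gen_spec():  # hypothetical step
    pass

logs = {"gen_spec": json.dumps([{"type": "system", "data": {"content": "the spec"}}])}
print(logs_to_string_demo([gen_spec], logs))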

View File: gpt_engineer/main.py

@@ -1,4 +1,3 @@
-import json
 import logging

 from pathlib import Path
@@ -28,7 +27,7 @@ def main(
     model = fallback_model(model)
     ai = AI(
-        model=model,
+        model_name=model,
         temperature=temperature,
     )
@@ -56,7 +55,7 @@ def main(
     steps = STEPS[steps_config]
     for step in steps:
         messages = step(ai, dbs)
-        dbs.logs[step.__name__] = json.dumps(messages)
+        dbs.logs[step.__name__] = AI.serialize_messages(messages)

     if collect_consent():
         collect_learnings(model, temperature, steps, dbs)

View File: gpt_engineer/steps.py

@@ -1,11 +1,11 @@
 import inspect
-import json
 import re
 import subprocess

 from enum import Enum
-from typing import List
+from typing import List, Union

+from langchain.schema import AIMessage, HumanMessage, SystemMessage
 from termcolor import colored

 from gpt_engineer.ai import AI
@@ -13,6 +13,8 @@ from gpt_engineer.chat_to_files import to_files
 from gpt_engineer.db import DBs
 from gpt_engineer.learning import human_input

+Message = Union[AIMessage, HumanMessage, SystemMessage]
+

 def setup_sys_prompt(dbs: DBs) -> str:
     return (
@@ -44,26 +46,27 @@ def curr_fn() -> str:
 # All steps below have the signature Step


-def simple_gen(ai: AI, dbs: DBs) -> List[dict]:
+def simple_gen(ai: AI, dbs: DBs) -> List[Message]:
     """Run the AI on the main prompt and save the results"""
     messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs), step_name=curr_fn())
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages


-def clarify(ai: AI, dbs: DBs) -> List[dict]:
+def clarify(ai: AI, dbs: DBs) -> List[Message]:
     """
     Ask the user if they want to clarify anything and save the results to the workspace
     """
-    messages = [ai.fsystem(dbs.preprompts["qa"])]
+    messages: List[Message] = [ai.fsystem(dbs.preprompts["clarify"])]
     user_input = get_prompt(dbs)
     while True:
         messages = ai.next(messages, user_input, step_name=curr_fn())
+        msg = messages[-1].content.strip()

-        if messages[-1]["content"].strip() == "Nothing more to clarify.":
+        if msg == "Nothing more to clarify.":
             break

-        if messages[-1]["content"].strip().lower().startswith("no"):
+        if msg.lower().startswith("no"):
             print("Nothing more to clarify.")
             break
@@ -94,7 +97,7 @@ def clarify(ai: AI, dbs: DBs) -> List[Message]:
     return messages


-def gen_spec(ai: AI, dbs: DBs) -> List[dict]:
+def gen_spec(ai: AI, dbs: DBs) -> List[Message]:
     """
     Generate a spec from the main prompt + clarifications and save the results to
     the workspace
@@ -106,13 +109,13 @@ def gen_spec(ai: AI, dbs: DBs) -> List[Message]:
     messages = ai.next(messages, dbs.preprompts["spec"], step_name=curr_fn())

-    dbs.memory["specification"] = messages[-1]["content"]
+    dbs.memory["specification"] = messages[-1].content.strip()

     return messages


-def respec(ai: AI, dbs: DBs) -> List[dict]:
-    messages = json.loads(dbs.logs[gen_spec.__name__])
+def respec(ai: AI, dbs: DBs) -> List[Message]:
+    messages = AI.deserialize_messages(dbs.logs[gen_spec.__name__])
     messages += [ai.fsystem(dbs.preprompts["respec"])]

     messages = ai.next(messages, step_name=curr_fn())
@@ -129,7 +132,7 @@ def respec(ai: AI, dbs: DBs) -> List[Message]:
         step_name=curr_fn(),
     )

-    dbs.memory["specification"] = messages[-1]["content"]
+    dbs.memory["specification"] = messages[-1].content.strip()
     return messages
@@ -145,7 +148,7 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
     messages = ai.next(messages, dbs.preprompts["unit_tests"], step_name=curr_fn())

-    dbs.memory["unit_tests"] = messages[-1]["content"]
+    dbs.memory["unit_tests"] = messages[-1].content.strip()
     to_files(dbs.memory["unit_tests"], dbs.workspace)

     return messages
@@ -153,14 +156,14 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
 def gen_clarified_code(ai: AI, dbs: DBs) -> List[dict]:
     """Takes clarification and generates code"""
-    messages = json.loads(dbs.logs[clarify.__name__])
+    messages = AI.deserialize_messages(dbs.logs[clarify.__name__])

     messages = [
         ai.fsystem(setup_sys_prompt(dbs)),
     ] + messages[1:]
     messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())

-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages
@@ -173,7 +176,7 @@ def gen_code(ai: AI, dbs: DBs) -> List[dict]:
         ai.fuser(f"Unit tests:\n\n{dbs.memory['unit_tests']}"),
     ]
     messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages
@@ -235,7 +238,7 @@ def gen_entrypoint(ai: AI, dbs: DBs) -> List[dict]:
     print()

     regex = r"```\S*\n(.+?)```"
-    matches = re.finditer(regex, messages[-1]["content"], re.DOTALL)
+    matches = re.finditer(regex, messages[-1].content.strip(), re.DOTALL)
     dbs.workspace["run.sh"] = "\n".join(match.group(1) for match in matches)
     return messages
@@ -248,12 +251,13 @@ def use_feedback(ai: AI, dbs: DBs):
         ai.fsystem(dbs.preprompts["use_feedback"]),
     ]
     messages = ai.next(messages, dbs.input["feedback"], step_name=curr_fn())
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages


 def fix_code(ai: AI, dbs: DBs):
-    code_output = json.loads(dbs.logs[gen_code.__name__])[-1]["content"]
+    messages = AI.deserialize_messages(dbs.logs[gen_code.__name__])
+    code_output = messages[-1].content.strip()
     messages = [
         ai.fsystem(setup_sys_prompt(dbs)),
         ai.fuser(f"Instructions: {dbs.input['prompt']}"),
@@ -263,7 +267,7 @@ def fix_code(ai: AI, dbs: DBs):
     messages = ai.next(
         messages, "Please fix any errors in the code above.", step_name=curr_fn()
     )
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages

View File: pyproject.toml

@@ -21,6 +21,7 @@ dependencies = [
     'dataclasses-json == 0.5.7',
     'tiktoken',
     'tabulate == 0.9.0',
+    'langchain',
 ]

 classifiers = [

View File: scripts/rerun_edited_message_logs.py

@@ -19,14 +19,14 @@ def main(
     temperature: float = 0.1,
 ):
     ai = AI(
-        model=model,
+        model_name=model,
         temperature=temperature,
     )

     with open(messages_path) as f:
         messages = json.load(f)
-    messages = ai.next(messages)
+    messages = ai.next(messages, step_name="rerun")

     if out_path:
         to_files(messages[-1]["content"], out_path)

View File: tests/test_collect.py

@@ -43,7 +43,7 @@ def test_collect_learnings(monkeypatch):
     b = {k: v for k, v in learnings.to_dict().items() if k != "timestamp"}
     assert a == b

-    assert code in learnings.logs
+    assert json.dumps(code) in learnings.logs
     assert code in learnings.workspace
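
The updated assertion accounts for the logs now being JSON-serialized: inside the stored text, quotes and newlines in the code are escaped, so the raw string no longer matches. A minimal demonstration (the message-dict shape follows LangChain's messages_to_dict):

import json

code = 'print("hello")\n'
serialized_log = json.dumps([{"type": "ai", "data": {"content": code}}])

assert code not in serialized_log          # raw quotes/newlines are JSON-escaped
assert json.dumps(code) in serialized_log  # hence the updated assertion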