From 19a4c10b6e61bf431e5ca7f6cd68a892b3fee122 Mon Sep 17 00:00:00 2001
From: UmerHA <40663591+UmerHA@users.noreply.github.com>
Date: Sun, 23 Jul 2023 23:30:09 +0200
Subject: [PATCH] Langchain integration (#512)

* Added LangChain integration

* Fixed issue created by git checkin process

* Added ':' to characters to remove from end of file path

* Tested initial migration to LangChain, removed comments and logging used for debugging

* Tested initial migration to LangChain, removed comments and logging used for debugging

* Converted camelCase to snake_case

* Turns out we need the exception handling

* Testing Hugging Face Integrations via LangChain

* Added LangChain loadable models

* Renames "qa" prompt to "clarify", since it's used in the "clarify" step, asking for clarification

* Fixed loading model yaml files

* Fixed streaming

* Added modeldir cli option

* Fixed typing

* Fixed interaction with token logging

* Fix spelling + dependency issues + typing

* Fix spelling + tests

* Removed unneeded logging which caused test to fail

* Cleaned up code

* Incorporated feedback
- deleted unnecessary functions & logger.info
- used LangChain ChatLLM instead of LLM to naturally communicate with gpt-4
- deleted loading model from yaml file, as LC doesn't offer this for ChatModels

* Update gpt_engineer/steps.py

Co-authored-by: Anton Osika

* Incorporated feedback
- Fixed failing test
- Removed parsing complexity by using # type: ignore
- Replace every occurrence of ai.last_message_content with its content

* Fixed test

* Update gpt_engineer/steps.py

---------

Co-authored-by: H
Co-authored-by: Anton Osika
---
 gpt_engineer/ai.py                      | 143 +++++++++++++++---
 gpt_engineer/chat_to_files.py           |   4 +-
 gpt_engineer/learning.py                |  19 +---
 gpt_engineer/main.py                    |   5 +-
 gpt_engineer/preprompts/{qa => clarify} |   0
 gpt_engineer/steps.py                   |  46 ++++----
 pyproject.toml                          |   1 +
 scripts/rerun_edited_message_logs.py    |   4 +-
 tests/test_collect.py                   |   2 +-
 9 files changed, 132 insertions(+), 92 deletions(-)
 rename gpt_engineer/preprompts/{qa => clarify} (100%)

diff --git a/gpt_engineer/ai.py b/gpt_engineer/ai.py
index 37c85a1..9c96453 100644
--- a/gpt_engineer/ai.py
+++ b/gpt_engineer/ai.py
@@ -1,13 +1,27 @@
 from __future__ import annotations
 
+import json
 import logging
 
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import List, Optional, Union
 
 import openai
 import tiktoken
 
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.chat_models import ChatOpenAI
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    messages_from_dict,
+    messages_to_dict,
+)
+
+Message = Union[AIMessage, HumanMessage, SystemMessage]
+
 logger = logging.getLogger(__name__)
 
 
@@ -23,9 +37,11 @@ class TokenUsage:
 
 
 class AI:
-    def __init__(self, model="gpt-4", temperature=0.1):
+    def __init__(self, model_name="gpt-4", temperature=0.1):
         self.temperature = temperature
-        self.model = model
+        self.model_name = fallback_model(model_name)
+        self.llm = create_chat_model(self.model_name, temperature)
+        self.tokenizer = get_tokenizer(self.model_name)
 
         # initialize token usage log
         self.cumulative_prompt_tokens = 0
@@ -33,62 +49,57 @@ class AI:
         self.cumulative_total_tokens = 0
         self.token_usage_log = []
 
-        try:
-            self.tokenizer = tiktoken.encoding_for_model(model)
-        except KeyError:
-            logger.debug(
-                f"Tiktoken encoder for model {model} not found. Using "
-                "cl100k_base encoder instead. The results may therefore be "
-                "inaccurate and should only be used as estimate."
-            )
-            self.tokenizer = tiktoken.get_encoding("cl100k_base")
-
-    def start(self, system, user, step_name):
-        messages = [
-            {"role": "system", "content": system},
-            {"role": "user", "content": user},
+    def start(self, system: str, user: str, step_name: str) -> List[Message]:
+        messages: List[Message] = [
+            SystemMessage(content=system),
+            HumanMessage(content=user),
         ]
-
         return self.next(messages, step_name=step_name)
 
-    def fsystem(self, msg):
-        return {"role": "system", "content": msg}
+    def fsystem(self, msg: str) -> SystemMessage:
+        return SystemMessage(content=msg)
 
-    def fuser(self, msg):
-        return {"role": "user", "content": msg}
+    def fuser(self, msg: str) -> HumanMessage:
+        return HumanMessage(content=msg)
 
-    def fassistant(self, msg):
-        return {"role": "assistant", "content": msg}
+    def fassistant(self, msg: str) -> AIMessage:
+        return AIMessage(content=msg)
 
-    def next(self, messages: List[Dict[str, str]], prompt=None, *, step_name=None):
+    def next(
+        self,
+        messages: List[Message],
+        prompt: Optional[str] = None,
+        *,
+        step_name: str,
+    ) -> List[Message]:
         if prompt:
-            messages += [{"role": "user", "content": prompt}]
+            messages.append(self.fuser(prompt))
 
         logger.debug(f"Creating a new chat completion: {messages}")
-        response = openai.ChatCompletion.create(
-            messages=messages,
-            stream=True,
-            model=self.model,
-            temperature=self.temperature,
-        )
-        chat = []
-        for chunk in response:
-            delta = chunk["choices"][0]["delta"]  # type: ignore
-            msg = delta.get("content", "")
-            print(msg, end="")
-            chat.append(msg)
-        print()
-        messages += [{"role": "assistant", "content": "".join(chat)}]
+        callsbacks = [StreamingStdOutCallbackHandler()]
+        response = self.llm(messages, callbacks=callsbacks)  # type: ignore
+        messages.append(response)
+        logger.debug(f"Chat completion finished: {messages}")
 
         self.update_token_usage_log(
-            messages=messages, answer="".join(chat), step_name=step_name
+            messages=messages, answer=response.content, step_name=step_name
        )
 
         return messages
 
-    def update_token_usage_log(self, messages, answer, step_name):
+    @staticmethod
+    def serialize_messages(messages: List[Message]) -> str:
+        return json.dumps(messages_to_dict(messages))
+
+    @staticmethod
+    def deserialize_messages(jsondictstr: str) -> List[Message]:
+        return list(messages_from_dict(json.loads(jsondictstr)))  # type: ignore
+
+    def update_token_usage_log(
+        self, messages: List[Message], answer: str, step_name: str
+    ) -> None:
         prompt_tokens = self.num_tokens_from_messages(messages)
         completion_tokens = self.num_tokens(answer)
         total_tokens = prompt_tokens + completion_tokens
@@ -109,7 +120,7 @@ class AI:
             )
         )
 
-    def format_token_usage_log(self):
+    def format_token_usage_log(self) -> str:
         result = "step_name,"
         result += "prompt_tokens_in_step,completion_tokens_in_step,total_tokens_in_step"
         result += ",total_prompt_tokens,total_completion_tokens,total_tokens\n"
@@ -123,20 +134,17 @@ class AI:
             result += str(log.total_tokens) + "\n"
         return result
 
-    def num_tokens(self, txt):
+    def num_tokens(self, txt: str) -> int:
         return len(self.tokenizer.encode(txt))
 
-    def num_tokens_from_messages(self, messages):
+    def num_tokens_from_messages(self, messages: List[Message]) -> int:
         """Returns the number of tokens used by a list of messages."""
         n_tokens = 0
         for message in messages:
             n_tokens += (
                 4  # every message follows {role/name}\n{content}\n
             )
-            for key, value in message.items():
-                n_tokens += self.num_tokens(value)
-                if key == "name":  # if there's a name, the role is omitted
-                    n_tokens += -1  # role is always required and always 1 token
+            n_tokens += self.num_tokens(message.content)
         n_tokens += 2  # every reply is primed with assistant
         return n_tokens
 
@@ -151,4 +159,39 @@ def fallback_model(model: str) -> str:
             "to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: "
             "https://openai.com/waitlist/gpt-4-api\n"
         )
-        return "gpt-3.5-turbo-16k"
+        return "gpt-3.5-turbo"
+
+
+def create_chat_model(model: str, temperature) -> BaseChatModel:
+    if model == "gpt-4":
+        return ChatOpenAI(
+            model="gpt-4",
+            temperature=temperature,
+            streaming=True,
+            client=openai.ChatCompletion,
+        )
+    elif model == "gpt-3.5-turbo":
+        return ChatOpenAI(
+            model="gpt-3.5-turbo",
+            temperature=temperature,
+            streaming=True,
+            client=openai.ChatCompletion,
+        )
+    else:
+        raise ValueError(f"Model {model} is not supported.")
+
+
+def get_tokenizer(model: str):
+    if "gpt-4" in model or "gpt-3.5" in model:
+        return tiktoken.encoding_for_model(model)
+
+    logger.debug(
+        f"No encoder implemented for model {model}."
+        "Defaulting to tiktoken cl100k_base encoder."
+        "Use results only as estimates."
+    )
+    return tiktoken.get_encoding("cl100k_base")
+
+
+def serialize_messages(messages: List[Message]) -> str:
+    return AI.serialize_messages(messages)
diff --git a/gpt_engineer/chat_to_files.py b/gpt_engineer/chat_to_files.py
index 8e251cb..6a3c781 100644
--- a/gpt_engineer/chat_to_files.py
+++ b/gpt_engineer/chat_to_files.py
@@ -9,7 +9,7 @@ def parse_chat(chat):  # -> List[Tuple[str, str]]:
     files = []
     for match in matches:
         # Strip the filename of any non-allowed characters and convert / to \
-        path = re.sub(r'[<>"|?*]', "", match.group(1))
+        path = re.sub(r'[\:<>"|?*]', "", match.group(1))
 
         # Remove leading and trailing brackets
         path = re.sub(r"^\[(.*)\]$", r"\1", path)
@@ -18,7 +18,7 @@ def parse_chat(chat):  # -> List[Tuple[str, str]]:
         path = re.sub(r"^`(.*)`$", r"\1", path)
 
         # Remove trailing ]
-        path = re.sub(r"\]$", "", path)
+        path = re.sub(r"[\]\:]$", "", path)
 
         # Get the code
         code = match.group(2)
diff --git a/gpt_engineer/learning.py b/gpt_engineer/learning.py
index 4864473..add72cd 100644
--- a/gpt_engineer/learning.py
+++ b/gpt_engineer/learning.py
@@ -100,11 +100,11 @@ def check_consent():
     path = Path(".gpte_consent")
     if path.exists() and path.read_text() == "true":
         return
-    ans = input("Is it ok if we store your prompts to learn? (y/n)")
-    while ans.lower() not in ("y", "n"):
-        ans = input("Invalid input. Please enter y or n: ")
+    answer = input("Is it ok if we store your prompts to learn? (y/n)")
+    while answer.lower() not in ("y", "n"):
+        answer = input("Invalid input. Please enter y or n: ")
Please enter y or n: ") - if ans.lower() == "y": + if answer.lower() == "y": path.write_text("true") print(colored("Thank you️", "light_green")) print() @@ -153,21 +153,14 @@ def ask_if_can_store() -> bool: return can_store == "y" -def logs_to_string(steps: List[Step], logs: DB): +def logs_to_string(steps: List[Step], logs: DB) -> str: chunks = [] for step in steps: chunks.append(f"--- {step.__name__} ---\n") - messages = json.loads(logs[step.__name__]) - chunks.append(format_messages(messages)) + chunks.append(logs[step.__name__]) return "\n".join(chunks) -def format_messages(messages: List[dict]) -> str: - return "\n".join( - [f"{message['role']}:\n\n{message['content']}" for message in messages] - ) - - def extract_learning( model: str, temperature: float, steps: List[Step], dbs: DBs, steps_file_hash ) -> Learning: diff --git a/gpt_engineer/main.py b/gpt_engineer/main.py index e3cd0a5..72db1c5 100644 --- a/gpt_engineer/main.py +++ b/gpt_engineer/main.py @@ -1,4 +1,3 @@ -import json import logging from pathlib import Path @@ -28,7 +27,7 @@ def main( model = fallback_model(model) ai = AI( - model=model, + model_name=model, temperature=temperature, ) @@ -56,7 +55,7 @@ def main( steps = STEPS[steps_config] for step in steps: messages = step(ai, dbs) - dbs.logs[step.__name__] = json.dumps(messages) + dbs.logs[step.__name__] = AI.serialize_messages(messages) if collect_consent(): collect_learnings(model, temperature, steps, dbs) diff --git a/gpt_engineer/preprompts/qa b/gpt_engineer/preprompts/clarify similarity index 100% rename from gpt_engineer/preprompts/qa rename to gpt_engineer/preprompts/clarify diff --git a/gpt_engineer/steps.py b/gpt_engineer/steps.py index 304d646..1401e27 100644 --- a/gpt_engineer/steps.py +++ b/gpt_engineer/steps.py @@ -1,11 +1,11 @@ import inspect -import json import re import subprocess from enum import Enum -from typing import List +from typing import List, Union +from langchain.schema import AIMessage, HumanMessage, SystemMessage from termcolor import colored from gpt_engineer.ai import AI @@ -13,6 +13,8 @@ from gpt_engineer.chat_to_files import to_files from gpt_engineer.db import DBs from gpt_engineer.learning import human_input +Message = Union[AIMessage, HumanMessage, SystemMessage] + def setup_sys_prompt(dbs: DBs) -> str: return ( @@ -44,26 +46,27 @@ def curr_fn() -> str: # All steps below have the signature Step -def simple_gen(ai: AI, dbs: DBs) -> List[dict]: +def simple_gen(ai: AI, dbs: DBs) -> List[Message]: """Run the AI on the main prompt and save the results""" messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs), step_name=curr_fn()) - to_files(messages[-1]["content"], dbs.workspace) + to_files(messages[-1].content.strip(), dbs.workspace) return messages -def clarify(ai: AI, dbs: DBs) -> List[dict]: +def clarify(ai: AI, dbs: DBs) -> List[Message]: """ Ask the user if they want to clarify anything and save the results to the workspace """ - messages = [ai.fsystem(dbs.preprompts["qa"])] + messages: List[Message] = [ai.fsystem(dbs.preprompts["clarify"])] user_input = get_prompt(dbs) while True: messages = ai.next(messages, user_input, step_name=curr_fn()) + msg = messages[-1].content.strip() - if messages[-1]["content"].strip() == "Nothing more to clarify.": + if msg == "Nothing more to clarify.": break - if messages[-1]["content"].strip().lower().startswith("no"): + if msg.lower().startswith("no"): print("Nothing more to clarify.") break @@ -94,7 +97,7 @@ def clarify(ai: AI, dbs: DBs) -> List[dict]: return messages -def gen_spec(ai: AI, dbs: 
+def gen_spec(ai: AI, dbs: DBs) -> List[Message]:
     """
     Generate a spec from the main prompt + clarifications and save the results to
     the workspace
@@ -106,13 +109,13 @@ def gen_spec(ai: AI, dbs: DBs) -> List[dict]:
 
     messages = ai.next(messages, dbs.preprompts["spec"], step_name=curr_fn())
 
-    dbs.memory["specification"] = messages[-1]["content"]
+    dbs.memory["specification"] = messages[-1].content.strip()
 
     return messages
 
 
-def respec(ai: AI, dbs: DBs) -> List[dict]:
-    messages = json.loads(dbs.logs[gen_spec.__name__])
+def respec(ai: AI, dbs: DBs) -> List[Message]:
+    messages = AI.deserialize_messages(dbs.logs[gen_spec.__name__])
     messages += [ai.fsystem(dbs.preprompts["respec"])]
 
     messages = ai.next(messages, step_name=curr_fn())
@@ -129,7 +132,7 @@ def respec(ai: AI, dbs: DBs) -> List[dict]:
         step_name=curr_fn(),
     )
 
-    dbs.memory["specification"] = messages[-1]["content"]
+    dbs.memory["specification"] = messages[-1].content.strip()
 
     return messages
 
@@ -145,7 +148,7 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
 
     messages = ai.next(messages, dbs.preprompts["unit_tests"], step_name=curr_fn())
 
-    dbs.memory["unit_tests"] = messages[-1]["content"]
+    dbs.memory["unit_tests"] = messages[-1].content.strip()
     to_files(dbs.memory["unit_tests"], dbs.workspace)
 
     return messages
@@ -153,14 +156,14 @@ def gen_clarified_code(ai: AI, dbs: DBs) -> List[dict]:
     """Takes clarification and generates code"""
-    messages = json.loads(dbs.logs[clarify.__name__])
+    messages = AI.deserialize_messages(dbs.logs[clarify.__name__])
     messages = [
         ai.fsystem(setup_sys_prompt(dbs)),
     ] + messages[1:]
 
     messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
 
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
 
     return messages
 
@@ -173,7 +176,7 @@ def gen_code(ai: AI, dbs: DBs) -> List[dict]:
         ai.fuser(f"Unit tests:\n\n{dbs.memory['unit_tests']}"),
     ]
     messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages
 
 
@@ -235,7 +238,7 @@ def gen_entrypoint(ai: AI, dbs: DBs) -> List[dict]:
     print()
 
     regex = r"```\S*\n(.+?)```"
-    matches = re.finditer(regex, messages[-1]["content"], re.DOTALL)
+    matches = re.finditer(regex, messages[-1].content.strip(), re.DOTALL)
     dbs.workspace["run.sh"] = "\n".join(match.group(1) for match in matches)
     return messages
 
@@ -248,12 +251,13 @@ def use_feedback(ai: AI, dbs: DBs):
         ai.fsystem(dbs.preprompts["use_feedback"]),
     ]
     messages = ai.next(messages, dbs.input["feedback"], step_name=curr_fn())
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages
 
 
 def fix_code(ai: AI, dbs: DBs):
-    code_output = json.loads(dbs.logs[gen_code.__name__])[-1]["content"]
+    messages = AI.deserialize_messages(dbs.logs[gen_code.__name__])
+    code_output = messages[-1].content.strip()
     messages = [
         ai.fsystem(setup_sys_prompt(dbs)),
         ai.fuser(f"Instructions: {dbs.input['prompt']}"),
@@ -263,7 +267,7 @@ def fix_code(ai: AI, dbs: DBs):
     messages = ai.next(
         messages, "Please fix any errors in the code above.", step_name=curr_fn()
     )
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages
 
 
diff --git a/pyproject.toml b/pyproject.toml
index c61ff08..087836a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
     'dataclasses-json == 0.5.7',
     'tiktoken',
     'tabulate == 0.9.0',
+    'langchain',
 ]
 
 classifiers = [
diff --git a/scripts/rerun_edited_message_logs.py b/scripts/rerun_edited_message_logs.py
index 8d9b258..c49d9a2 100644
--- a/scripts/rerun_edited_message_logs.py
+++ b/scripts/rerun_edited_message_logs.py
@@ -19,14 +19,14 @@ def main(
     temperature: float = 0.1,
 ):
     ai = AI(
-        model=model,
+        model_name=model,
         temperature=temperature,
     )
 
     with open(messages_path) as f:
         messages = json.load(f)
 
-    messages = ai.next(messages)
+    messages = ai.next(messages, step_name="rerun")
 
     if out_path:
         to_files(messages[-1]["content"], out_path)
diff --git a/tests/test_collect.py b/tests/test_collect.py
index c02ed8b..df4f5a3 100644
--- a/tests/test_collect.py
+++ b/tests/test_collect.py
@@ -43,7 +43,7 @@ def test_collect_learnings(monkeypatch):
     b = {k: v for k, v in learnings.to_dict().items() if k != "timestamp"}
     assert a == b
 
-    assert code in learnings.logs
+    assert json.dumps(code) in learnings.logs
     assert code in learnings.workspace
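
For reference, a rough usage sketch (not part of the diff) of the migrated AI wrapper: it assumes an OPENAI_API_KEY is configured and network access is available; the prompt text and step name are made up for illustration. Replies stream to stdout via the LangChain callback handler added in this patch.

from gpt_engineer.ai import AI

# model_name is passed through fallback_model(), so it drops to gpt-3.5-turbo
# when the key has no GPT-4 access.
ai = AI(model_name="gpt-4", temperature=0.1)

# start() now builds LangChain SystemMessage/HumanMessage objects instead of
# {"role": ..., "content": ...} dicts and returns the full message list.
messages = ai.start(
    system="You are a helpful assistant.",
    user="Say hello in one short sentence.",
    step_name="demo",
)

# The reply is an AIMessage, so steps read .content rather than ["content"].
print(messages[-1].content)
print(ai.format_token_usage_log())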
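
Likewise, a minimal sketch of the new log format: dbs.logs[step.__name__] now holds the JSON produced by AI.serialize_messages instead of a dumped list of role/content dicts. The round-trip below runs offline against the helpers added in this patch; the message contents are invented for illustration.

from langchain.schema import AIMessage, HumanMessage, SystemMessage

from gpt_engineer.ai import AI

# A tiny chat history built from LangChain message objects.
messages = [
    SystemMessage(content="You are gpt-engineer."),
    HumanMessage(content="Build a CLI todo app."),
    AIMessage(content="Nothing more to clarify."),
]

# serialize_messages wraps langchain's messages_to_dict in json.dumps; this is
# the string main.py writes into dbs.logs[step.__name__].
serialized = AI.serialize_messages(messages)

# deserialize_messages reverses it, as respec, gen_clarified_code and fix_code now do.
restored = AI.deserialize_messages(serialized)

assert [m.content for m in restored] == [m.content for m in messages]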