Mirror of https://github.com/aljazceru/gpt-engineer.git (synced 2025-12-17 12:45:26 +01:00)
Langchain integration (#512)
* Added LangChain integration
* Fixed issue created by the git check-in process
* Added ':' to the characters to remove from the end of a file path
* Tested initial migration to LangChain; removed comments and logging used for debugging
* Converted camelCase to snake_case
* Turns out we need the exception handling
* Testing Hugging Face integrations via LangChain
* Added LangChain loadable models
* Renamed the "qa" prompt to "clarify", since it is used in the "clarify" step to ask for clarification
* Fixed loading model yaml files
* Fixed streaming
* Added modeldir cli option
* Fixed typing
* Fixed interaction with token logging
* Fixed spelling, dependency issues, and typing
* Fixed spelling and tests
* Removed unneeded logging which caused a test to fail
* Cleaned up code
* Incorporated feedback:
  - deleted unnecessary functions and logger.info calls
  - used LangChain ChatLLM instead of LLM to naturally communicate with gpt-4
  - deleted loading models from yaml files, as LangChain doesn't offer this for ChatModels
* Update gpt_engineer/steps.py (Co-authored-by: Anton Osika <anton.osika@gmail.com>)
* Incorporated feedback:
  - fixed a failing test
  - removed parsing complexity by using # type: ignore
  - replaced every occurrence of ai.last_message_content with its content
* Fixed test
* Update gpt_engineer/steps.py

---------

Co-authored-by: H <holden.robbins@gmail.com>
Co-authored-by: Anton Osika <anton.osika@gmail.com>
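For orientation, here is a minimal usage sketch of the migrated AI class (illustrative only, not part of the diff below; it assumes an OpenAI API key is configured, and the prompt strings are made up):

    from gpt_engineer.ai import AI

    # AI now wraps a LangChain ChatOpenAI model instead of calling
    # openai.ChatCompletion directly; model_name replaces the old model kwarg.
    ai = AI(model_name="gpt-4", temperature=0.1)

    # start() builds SystemMessage/HumanMessage objects; next() appends the
    # assistant's AIMessage and returns the full message list.
    messages = ai.start(
        system="You are a helpful engineer.",  # hypothetical prompt
        user="Write a hello world script.",  # hypothetical prompt
        step_name="simple_gen",
    )
    print(messages[-1].content)

    # Message logs are now persisted via LangChain's dict helpers rather than
    # json.dumps of raw role/content dicts.
    log = AI.serialize_messages(messages)
    restored = AI.deserialize_messages(log)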
@@ -1,13 +1,27 @@
 from __future__ import annotations
 
+import json
 import logging
 
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import List, Optional, Union
 
 import openai
 import tiktoken
 
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.chat_models import ChatOpenAI
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    messages_from_dict,
+    messages_to_dict,
+)
+
+Message = Union[AIMessage, HumanMessage, SystemMessage]
+
 logger = logging.getLogger(__name__)
 
 
@@ -23,9 +37,11 @@ class TokenUsage:
 
 
 class AI:
-    def __init__(self, model="gpt-4", temperature=0.1):
+    def __init__(self, model_name="gpt-4", temperature=0.1):
         self.temperature = temperature
-        self.model = model
+        self.model_name = fallback_model(model_name)
+        self.llm = create_chat_model(self.model_name, temperature)
+        self.tokenizer = get_tokenizer(self.model_name)
 
         # initialize token usage log
         self.cumulative_prompt_tokens = 0
@@ -33,62 +49,57 @@ class AI:
         self.cumulative_total_tokens = 0
         self.token_usage_log = []
 
-        try:
-            self.tokenizer = tiktoken.encoding_for_model(model)
-        except KeyError:
-            logger.debug(
-                f"Tiktoken encoder for model {model} not found. Using "
-                "cl100k_base encoder instead. The results may therefore be "
-                "inaccurate and should only be used as estimate."
-            )
-            self.tokenizer = tiktoken.get_encoding("cl100k_base")
-
-    def start(self, system, user, step_name):
-        messages = [
-            {"role": "system", "content": system},
-            {"role": "user", "content": user},
+    def start(self, system: str, user: str, step_name: str) -> List[Message]:
+        messages: List[Message] = [
+            SystemMessage(content=system),
+            HumanMessage(content=user),
         ]
 
         return self.next(messages, step_name=step_name)
 
-    def fsystem(self, msg):
-        return {"role": "system", "content": msg}
+    def fsystem(self, msg: str) -> SystemMessage:
+        return SystemMessage(content=msg)
 
-    def fuser(self, msg):
-        return {"role": "user", "content": msg}
+    def fuser(self, msg: str) -> HumanMessage:
+        return HumanMessage(content=msg)
 
-    def fassistant(self, msg):
-        return {"role": "assistant", "content": msg}
+    def fassistant(self, msg: str) -> AIMessage:
+        return AIMessage(content=msg)
 
-    def next(self, messages: List[Dict[str, str]], prompt=None, *, step_name=None):
+    def next(
+        self,
+        messages: List[Message],
+        prompt: Optional[str] = None,
+        *,
+        step_name: str,
+    ) -> List[Message]:
         if prompt:
-            messages += [{"role": "user", "content": prompt}]
+            messages.append(self.fuser(prompt))
 
         logger.debug(f"Creating a new chat completion: {messages}")
-        response = openai.ChatCompletion.create(
-            messages=messages,
-            stream=True,
-            model=self.model,
-            temperature=self.temperature,
-        )
 
-        chat = []
-        for chunk in response:
-            delta = chunk["choices"][0]["delta"]  # type: ignore
-            msg = delta.get("content", "")
-            print(msg, end="")
-            chat.append(msg)
-        print()
-        messages += [{"role": "assistant", "content": "".join(chat)}]
+        callsbacks = [StreamingStdOutCallbackHandler()]
+        response = self.llm(messages, callbacks=callsbacks)  # type: ignore
+        messages.append(response)
         logger.debug(f"Chat completion finished: {messages}")
 
         self.update_token_usage_log(
-            messages=messages, answer="".join(chat), step_name=step_name
+            messages=messages, answer=response.content, step_name=step_name
         )
 
         return messages
 
-    def update_token_usage_log(self, messages, answer, step_name):
+    @staticmethod
+    def serialize_messages(messages: List[Message]) -> str:
+        return json.dumps(messages_to_dict(messages))
+
+    @staticmethod
+    def deserialize_messages(jsondictstr: str) -> List[Message]:
+        return list(messages_from_dict(json.loads(jsondictstr)))  # type: ignore
+
+    def update_token_usage_log(
+        self, messages: List[Message], answer: str, step_name: str
+    ) -> None:
         prompt_tokens = self.num_tokens_from_messages(messages)
         completion_tokens = self.num_tokens(answer)
         total_tokens = prompt_tokens + completion_tokens
@@ -109,7 +120,7 @@ class AI:
             )
         )
 
-    def format_token_usage_log(self):
+    def format_token_usage_log(self) -> str:
         result = "step_name,"
         result += "prompt_tokens_in_step,completion_tokens_in_step,total_tokens_in_step"
         result += ",total_prompt_tokens,total_completion_tokens,total_tokens\n"
@@ -123,20 +134,17 @@ class AI:
             result += str(log.total_tokens) + "\n"
         return result
 
-    def num_tokens(self, txt):
+    def num_tokens(self, txt: str) -> int:
         return len(self.tokenizer.encode(txt))
 
-    def num_tokens_from_messages(self, messages):
+    def num_tokens_from_messages(self, messages: List[Message]) -> int:
         """Returns the number of tokens used by a list of messages."""
         n_tokens = 0
         for message in messages:
             n_tokens += (
                 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
             )
-            for key, value in message.items():
-                n_tokens += self.num_tokens(value)
-                if key == "name":  # if there's a name, the role is omitted
-                    n_tokens += -1  # role is always required and always 1 token
+            n_tokens += self.num_tokens(message.content)
         n_tokens += 2  # every reply is primed with <im_start>assistant
         return n_tokens
 
@@ -151,4 +159,39 @@ def fallback_model(model: str) -> str:
             "to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: "
             "https://openai.com/waitlist/gpt-4-api\n"
         )
-        return "gpt-3.5-turbo-16k"
+        return "gpt-3.5-turbo"
+
+
+def create_chat_model(model: str, temperature) -> BaseChatModel:
+    if model == "gpt-4":
+        return ChatOpenAI(
+            model="gpt-4",
+            temperature=temperature,
+            streaming=True,
+            client=openai.ChatCompletion,
+        )
+    elif model == "gpt-3.5-turbo":
+        return ChatOpenAI(
+            model="gpt-3.5-turbo",
+            temperature=temperature,
+            streaming=True,
+            client=openai.ChatCompletion,
+        )
+    else:
+        raise ValueError(f"Model {model} is not supported.")
+
+
+def get_tokenizer(model: str):
+    if "gpt-4" in model or "gpt-3.5" in model:
+        return tiktoken.encoding_for_model(model)
+
+    logger.debug(
+        f"No encoder implemented for model {model}."
+        "Defaulting to tiktoken cl100k_base encoder."
+        "Use results only as estimates."
+    )
+    return tiktoken.get_encoding("cl100k_base")
+
+
+def serialize_messages(messages: List[Message]) -> str:
+    return AI.serialize_messages(messages)
@@ -9,7 +9,7 @@ def parse_chat(chat):  # -> List[Tuple[str, str]]:
     files = []
     for match in matches:
         # Strip the filename of any non-allowed characters and convert / to \
-        path = re.sub(r'[<>"|?*]', "", match.group(1))
+        path = re.sub(r'[\:<>"|?*]', "", match.group(1))
 
         # Remove leading and trailing brackets
         path = re.sub(r"^\[(.*)\]$", r"\1", path)
@@ -18,7 +18,7 @@ def parse_chat(chat):  # -> List[Tuple[str, str]]:
         path = re.sub(r"^`(.*)`$", r"\1", path)
 
         # Remove trailing ]
-        path = re.sub(r"\]$", "", path)
+        path = re.sub(r"[\]\:]$", "", path)
 
         # Get the code
         code = match.group(2)
@@ -100,11 +100,11 @@ def check_consent():
     path = Path(".gpte_consent")
     if path.exists() and path.read_text() == "true":
         return
-    ans = input("Is it ok if we store your prompts to learn? (y/n)")
-    while ans.lower() not in ("y", "n"):
-        ans = input("Invalid input. Please enter y or n: ")
+    answer = input("Is it ok if we store your prompts to learn? (y/n)")
+    while answer.lower() not in ("y", "n"):
+        answer = input("Invalid input. Please enter y or n: ")
 
-    if ans.lower() == "y":
+    if answer.lower() == "y":
         path.write_text("true")
         print(colored("Thank you", "light_green"))
         print()
@@ -153,21 +153,14 @@ def ask_if_can_store() -> bool:
     return can_store == "y"
 
 
-def logs_to_string(steps: List[Step], logs: DB):
+def logs_to_string(steps: List[Step], logs: DB) -> str:
     chunks = []
     for step in steps:
         chunks.append(f"--- {step.__name__} ---\n")
-        messages = json.loads(logs[step.__name__])
-        chunks.append(format_messages(messages))
+        chunks.append(logs[step.__name__])
     return "\n".join(chunks)
 
 
-def format_messages(messages: List[dict]) -> str:
-    return "\n".join(
-        [f"{message['role']}:\n\n{message['content']}" for message in messages]
-    )
-
-
 def extract_learning(
     model: str, temperature: float, steps: List[Step], dbs: DBs, steps_file_hash
 ) -> Learning:
@@ -1,4 +1,3 @@
-import json
 import logging
 
 from pathlib import Path
@@ -28,7 +27,7 @@ def main(
 
     model = fallback_model(model)
     ai = AI(
-        model=model,
+        model_name=model,
         temperature=temperature,
     )
 
@@ -56,7 +55,7 @@ def main(
     steps = STEPS[steps_config]
     for step in steps:
        messages = step(ai, dbs)
-        dbs.logs[step.__name__] = json.dumps(messages)
+        dbs.logs[step.__name__] = AI.serialize_messages(messages)
 
     if collect_consent():
         collect_learnings(model, temperature, steps, dbs)
@@ -1,11 +1,11 @@
 import inspect
-import json
 import re
 import subprocess
 
 from enum import Enum
-from typing import List
+from typing import List, Union
 
+from langchain.schema import AIMessage, HumanMessage, SystemMessage
 from termcolor import colored
 
 from gpt_engineer.ai import AI
@@ -13,6 +13,8 @@ from gpt_engineer.chat_to_files import to_files
 from gpt_engineer.db import DBs
 from gpt_engineer.learning import human_input
 
+Message = Union[AIMessage, HumanMessage, SystemMessage]
+
 
 def setup_sys_prompt(dbs: DBs) -> str:
     return (
@@ -44,26 +46,27 @@ def curr_fn() -> str:
 # All steps below have the signature Step
 
 
-def simple_gen(ai: AI, dbs: DBs) -> List[dict]:
+def simple_gen(ai: AI, dbs: DBs) -> List[Message]:
     """Run the AI on the main prompt and save the results"""
     messages = ai.start(setup_sys_prompt(dbs), get_prompt(dbs), step_name=curr_fn())
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages
 
 
-def clarify(ai: AI, dbs: DBs) -> List[dict]:
+def clarify(ai: AI, dbs: DBs) -> List[Message]:
     """
     Ask the user if they want to clarify anything and save the results to the workspace
     """
-    messages = [ai.fsystem(dbs.preprompts["qa"])]
+    messages: List[Message] = [ai.fsystem(dbs.preprompts["clarify"])]
     user_input = get_prompt(dbs)
     while True:
         messages = ai.next(messages, user_input, step_name=curr_fn())
+        msg = messages[-1].content.strip()
 
-        if messages[-1]["content"].strip() == "Nothing more to clarify.":
+        if msg == "Nothing more to clarify.":
             break
 
-        if messages[-1]["content"].strip().lower().startswith("no"):
+        if msg.lower().startswith("no"):
             print("Nothing more to clarify.")
             break
 
@@ -94,7 +97,7 @@ def clarify(ai: AI, dbs: DBs) -> List[dict]:
     return messages
 
 
-def gen_spec(ai: AI, dbs: DBs) -> List[dict]:
+def gen_spec(ai: AI, dbs: DBs) -> List[Message]:
     """
     Generate a spec from the main prompt + clarifications and save the results to
     the workspace
@@ -106,13 +109,13 @@ def gen_spec(ai: AI, dbs: DBs) -> List[dict]:
 
     messages = ai.next(messages, dbs.preprompts["spec"], step_name=curr_fn())
 
-    dbs.memory["specification"] = messages[-1]["content"]
+    dbs.memory["specification"] = messages[-1].content.strip()
 
     return messages
 
 
-def respec(ai: AI, dbs: DBs) -> List[dict]:
-    messages = json.loads(dbs.logs[gen_spec.__name__])
+def respec(ai: AI, dbs: DBs) -> List[Message]:
+    messages = AI.deserialize_messages(dbs.logs[gen_spec.__name__])
     messages += [ai.fsystem(dbs.preprompts["respec"])]
 
     messages = ai.next(messages, step_name=curr_fn())
@@ -129,7 +132,7 @@ def respec(ai: AI, dbs: DBs) -> List[dict]:
         step_name=curr_fn(),
     )
 
-    dbs.memory["specification"] = messages[-1]["content"]
+    dbs.memory["specification"] = messages[-1].content.strip()
     return messages
 
 
@@ -145,7 +148,7 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
 
     messages = ai.next(messages, dbs.preprompts["unit_tests"], step_name=curr_fn())
 
-    dbs.memory["unit_tests"] = messages[-1]["content"]
+    dbs.memory["unit_tests"] = messages[-1].content.strip()
     to_files(dbs.memory["unit_tests"], dbs.workspace)
 
     return messages
@@ -153,14 +156,14 @@ def gen_unit_tests(ai: AI, dbs: DBs) -> List[dict]:
 
 def gen_clarified_code(ai: AI, dbs: DBs) -> List[dict]:
     """Takes clarification and generates code"""
-    messages = json.loads(dbs.logs[clarify.__name__])
+    messages = AI.deserialize_messages(dbs.logs[clarify.__name__])
 
     messages = [
         ai.fsystem(setup_sys_prompt(dbs)),
     ] + messages[1:]
     messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
 
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages
 
 
@@ -173,7 +176,7 @@ def gen_code(ai: AI, dbs: DBs) -> List[dict]:
         ai.fuser(f"Unit tests:\n\n{dbs.memory['unit_tests']}"),
     ]
     messages = ai.next(messages, dbs.preprompts["use_qa"], step_name=curr_fn())
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages
 
 
@@ -235,7 +238,7 @@ def gen_entrypoint(ai: AI, dbs: DBs) -> List[dict]:
     print()
 
     regex = r"```\S*\n(.+?)```"
-    matches = re.finditer(regex, messages[-1]["content"], re.DOTALL)
+    matches = re.finditer(regex, messages[-1].content.strip(), re.DOTALL)
     dbs.workspace["run.sh"] = "\n".join(match.group(1) for match in matches)
     return messages
 
@@ -248,12 +251,13 @@ def use_feedback(ai: AI, dbs: DBs):
         ai.fsystem(dbs.preprompts["use_feedback"]),
     ]
     messages = ai.next(messages, dbs.input["feedback"], step_name=curr_fn())
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages
 
 
 def fix_code(ai: AI, dbs: DBs):
-    code_output = json.loads(dbs.logs[gen_code.__name__])[-1]["content"]
+    messages = AI.deserialize_messages(dbs.logs[gen_code.__name__])
+    code_output = messages[-1].content.strip()
     messages = [
         ai.fsystem(setup_sys_prompt(dbs)),
         ai.fuser(f"Instructions: {dbs.input['prompt']}"),
@@ -263,7 +267,7 @@ def fix_code(ai: AI, dbs: DBs):
     messages = ai.next(
         messages, "Please fix any errors in the code above.", step_name=curr_fn()
     )
-    to_files(messages[-1]["content"], dbs.workspace)
+    to_files(messages[-1].content.strip(), dbs.workspace)
     return messages
 
 
@@ -21,6 +21,7 @@ dependencies = [
     'dataclasses-json == 0.5.7',
     'tiktoken',
     'tabulate == 0.9.0',
+    'langchain',
 ]
 
 classifiers = [
@@ -19,14 +19,14 @@ def main(
     temperature: float = 0.1,
 ):
     ai = AI(
-        model=model,
+        model_name=model,
         temperature=temperature,
     )
 
     with open(messages_path) as f:
         messages = json.load(f)
 
-    messages = ai.next(messages)
+    messages = ai.next(messages, step_name="rerun")
 
     if out_path:
         to_files(messages[-1]["content"], out_path)
@@ -43,7 +43,7 @@ def test_collect_learnings(monkeypatch):
     b = {k: v for k, v in learnings.to_dict().items() if k != "timestamp"}
     assert a == b
 
-    assert code in learnings.logs
+    assert json.dumps(code) in learnings.logs
     assert code in learnings.workspace
 
 