mirror of
https://github.com/aljazceru/gpt-engineer.git
synced 2025-12-17 20:55:09 +01:00
64 lines
1.8 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
import openai
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AI:
    """Thin wrapper around the OpenAI chat-completion API.

    At construction time it verifies that the requested model is available
    to the provided API key, falling back to ``gpt-3.5-turbo`` when it is
    not.  Conversations are plain lists of ``{"role": ..., "content": ...}``
    dicts; replies are streamed to stdout as they arrive.
    """

    def __init__(self, model="gpt-4", temperature=0.1):
        """Store the sampling temperature and resolve a usable model name.

        Args:
            model: Preferred chat model; reverted to ``gpt-3.5-turbo`` when
                the API key cannot access it.
            temperature: Sampling temperature forwarded to every completion.
        """
        self.temperature = temperature

        try:
            # Probe availability for this API key; raises when the model
            # cannot be used with the supplied credentials.
            openai.Model.retrieve(model)
            self.model = model
        except openai.InvalidRequestError:
            print(
                f"Model {model} not available for provided API key. Reverting "
                "to gpt-3.5-turbo. Sign up for the GPT-4 wait list here: "
                "https://openai.com/waitlist/gpt-4-api"
            )
            self.model = "gpt-3.5-turbo"

    def start(self, system, user):
        """Start a conversation from a system prompt and a first user message.

        Returns:
            The full message history, including the assistant's reply.
        """
        # Build the seed history with the class's own message helpers so the
        # dict shape is defined in exactly one place.
        messages = [
            self.fsystem(system),
            self.fuser(user),
        ]

        return self.next(messages)

    def fsystem(self, msg):
        """Wrap *msg* as a system-role chat message."""
        return {"role": "system", "content": msg}

    def fuser(self, msg):
        """Wrap *msg* as a user-role chat message."""
        return {"role": "user", "content": msg}

    def fassistant(self, msg):
        """Wrap *msg* as an assistant-role chat message."""
        return {"role": "assistant", "content": msg}

    def next(self, messages: list[dict[str, str]], prompt=None):
        """Request the next assistant reply, streaming it to stdout.

        Args:
            messages: Conversation history as role/content dicts.  The input
                list is copied, not mutated in place; use the return value.
            prompt: Optional extra user message appended before the request.

        Returns:
            A new message list: the given history (plus *prompt*, if any)
            followed by the assistant's complete reply.
        """
        # Copy so the caller's list is never mutated as a side effect.
        messages = list(messages)
        if prompt:
            messages.append(self.fuser(prompt))

        # Lazy %-args: the debug string is only built when DEBUG is enabled.
        logger.debug("Creating a new chat completion: %s", messages)
        response = openai.ChatCompletion.create(
            messages=messages,
            stream=True,
            model=self.model,
            temperature=self.temperature,
        )

        # Stream each token to stdout while accumulating the full reply.
        chat = []
        for chunk in response:
            delta = chunk["choices"][0]["delta"]
            msg = delta.get("content", "")
            print(msg, end="")
            chat.append(msg)
        print()

        messages.append(self.fassistant("".join(chat)))
        logger.debug("Chat completion finished: %s", messages)
        return messages
|