mirror of
https://github.com/aljazceru/dev-gpt.git
synced 2025-12-20 23:24:20 +01:00
refactor: cleanup
This commit is contained in:
@@ -2,16 +2,15 @@ import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from time import sleep
|
||||
|
||||
from typing import List, Any
|
||||
|
||||
import openai
|
||||
from langchain import PromptTemplate
|
||||
from langchain.callbacks import CallbackManager
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from openai.error import RateLimitError
|
||||
from langchain.schema import HumanMessage, SystemMessage, BaseMessage, AIMessage
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import HumanMessage, SystemMessage, BaseMessage, AIMessage
|
||||
from openai.error import RateLimitError
|
||||
from requests.exceptions import ConnectionError, ChunkedEncodingError
|
||||
from urllib3.exceptions import InvalidChunkLength
|
||||
|
||||
@@ -32,6 +31,7 @@ If you have updated it already, please restart your terminal.
|
||||
exit(1)
|
||||
openai.api_key = os.environ['OPENAI_API_KEY']
|
||||
|
||||
|
||||
class GPTSession:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
@@ -64,7 +64,6 @@ class GPTSession:
|
||||
self.model_name, self.cost_callback, messages, print_stream, print_costs
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def is_gpt4_available():
|
||||
try:
|
||||
@@ -120,7 +119,6 @@ class _GPTConversation:
|
||||
self.print_stream = print_stream
|
||||
self.print_costs = print_costs
|
||||
|
||||
|
||||
def print_messages(self, messages):
|
||||
for i, message in enumerate(messages):
|
||||
if os.environ['VERBOSE'].lower() == 'true':
|
||||
@@ -167,8 +165,9 @@ class _GPTConversation:
|
||||
def ask_gpt(prompt_template, parser, **kwargs):
|
||||
template_parameters = get_template_parameters(prompt_template)
|
||||
if set(template_parameters) != set(kwargs.keys()):
|
||||
raise ValueError(f'Prompt template parameters {get_template_parameters(prompt_template)} do not match '
|
||||
f'provided parameters {kwargs.keys()}')
|
||||
raise ValueError(
|
||||
f'Prompt template parameters {set(template_parameters)} do not match provided parameters {set(kwargs.keys())}'
|
||||
)
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, dict):
|
||||
kwargs[key] = json.dumps(value, indent=4)
|
||||
|
||||
Reference in New Issue
Block a user