refactor: cleanup

This commit is contained in:
Florian Hönicke
2023-05-11 00:05:33 +02:00
parent bcb1804997
commit 69ebebf4a4
5 changed files with 42 additions and 39 deletions

View File

@@ -2,16 +2,15 @@ import json
import os import os
from copy import deepcopy from copy import deepcopy
from time import sleep from time import sleep
from typing import List, Any from typing import List, Any
import openai import openai
from langchain import PromptTemplate from langchain import PromptTemplate
from langchain.callbacks import CallbackManager from langchain.callbacks import CallbackManager
from langchain.chat_models import ChatOpenAI
from openai.error import RateLimitError
from langchain.schema import HumanMessage, SystemMessage, BaseMessage, AIMessage
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage, BaseMessage, AIMessage
from openai.error import RateLimitError
from requests.exceptions import ConnectionError, ChunkedEncodingError from requests.exceptions import ConnectionError, ChunkedEncodingError
from urllib3.exceptions import InvalidChunkLength from urllib3.exceptions import InvalidChunkLength
@@ -32,6 +31,7 @@ If you have updated it already, please restart your terminal.
exit(1) exit(1)
openai.api_key = os.environ['OPENAI_API_KEY'] openai.api_key = os.environ['OPENAI_API_KEY']
class GPTSession: class GPTSession:
_instance = None _instance = None
_initialized = False _initialized = False
@@ -64,7 +64,6 @@ class GPTSession:
self.model_name, self.cost_callback, messages, print_stream, print_costs self.model_name, self.cost_callback, messages, print_stream, print_costs
) )
@staticmethod @staticmethod
def is_gpt4_available(): def is_gpt4_available():
try: try:
@@ -120,7 +119,6 @@ class _GPTConversation:
self.print_stream = print_stream self.print_stream = print_stream
self.print_costs = print_costs self.print_costs = print_costs
def print_messages(self, messages): def print_messages(self, messages):
for i, message in enumerate(messages): for i, message in enumerate(messages):
if os.environ['VERBOSE'].lower() == 'true': if os.environ['VERBOSE'].lower() == 'true':
@@ -167,8 +165,9 @@ class _GPTConversation:
def ask_gpt(prompt_template, parser, **kwargs): def ask_gpt(prompt_template, parser, **kwargs):
template_parameters = get_template_parameters(prompt_template) template_parameters = get_template_parameters(prompt_template)
if set(template_parameters) != set(kwargs.keys()): if set(template_parameters) != set(kwargs.keys()):
raise ValueError(f'Prompt template parameters {get_template_parameters(prompt_template)} do not match ' raise ValueError(
f'provided parameters {kwargs.keys()}') f'Prompt template parameters {set(template_parameters)} do not match provided parameters {set(kwargs.keys())}'
)
for key, value in kwargs.items(): for key, value in kwargs.items():
if isinstance(value, dict): if isinstance(value, dict):
kwargs[key] = json.dumps(value, indent=4) kwargs[key] = json.dumps(value, indent=4)

View File

@@ -1,29 +1,29 @@
from typing import Dict # from typing import Dict
#
from dev_gpt.apis.gpt import ask_gpt # from dev_gpt.apis.gpt import ask_gpt
from dev_gpt.options.generate.chains.question_answering import answer_yes_no_question # from dev_gpt.options.generate.chains.question_answering import answer_yes_no_question
from dev_gpt.options.generate.parser import identity_parser, boolean_parser # from dev_gpt.options.generate.parser import identity_parser, boolean_parser
#
#
def extract_information(text, info_keys) -> Dict[str, str]: # def extract_information(text, info_keys) -> Dict[str, str]:
extracted_infos = {} # extracted_infos = {}
for info_key in info_keys: # for info_key in info_keys:
is_information_in_text = answer_yes_no_question(text, f'Is a {info_key} mentioned above?') # is_information_in_text = answer_yes_no_question(text, f'Is a {info_key} mentioned above?')
if is_information_in_text: # if is_information_in_text:
extracted_info = ask_gpt( # extracted_info = ask_gpt(
extract_information_prompt, # extract_information_prompt,
identity_parser, # identity_parser,
text=text, # text=text,
info_key=info_key # info_key=info_key
) # )
extracted_infos[info_key] = extracted_info # extracted_infos[info_key] = extracted_info
return extracted_infos # return extracted_infos
#
#
extract_information_prompt = '''\ # extract_information_prompt = '''\
{text} # {text}
#
Your task: # Your task:
Return all {info_key}s from above.' # Return all {info_key}s from above.'
Note: you must only output your answer. # Note: you must only output your answer.
''' # '''

View File

@@ -6,6 +6,7 @@ from dev_gpt.options.generate.chains.translation import translation
from dev_gpt.options.generate.chains.user_confirmation_feedback_loop import user_feedback_loop from dev_gpt.options.generate.chains.user_confirmation_feedback_loop import user_feedback_loop
from dev_gpt.options.generate.chains.get_user_input_if_needed import get_user_input_if_needed from dev_gpt.options.generate.chains.get_user_input_if_needed import get_user_input_if_needed
from dev_gpt.options.generate.parser import identity_parser from dev_gpt.options.generate.parser import identity_parser
from dev_gpt.options.generate.prompt_factory import make_prompt_friendly
from dev_gpt.options.generate.ui import get_random_employee from dev_gpt.options.generate.ui import get_random_employee
@@ -88,6 +89,7 @@ Description of the microservice:
) )
if user_answer: if user_answer:
if post_transformation_fn: if post_transformation_fn:
user_answer = make_prompt_friendly(user_answer)
user_answer = post_transformation_fn(user_answer) user_answer = post_transformation_fn(user_answer)
return f'\n{extension_name}: {user_answer}' return f'\n{extension_name}: {user_answer}'
else: else:

View File

@@ -1,12 +1,14 @@
import json import json
def make_prompt_friendly(text):
return text.replace('{', '{{').replace('}', '}}')
def context_to_string(context): def context_to_string(context):
context_strings = [] context_strings = []
for k, v in context.items(): for k, v in context.items():
if isinstance(v, dict): if isinstance(v, dict):
v = json.dumps(v, indent=4) v = json.dumps(v, indent=4)
v = v.replace('{', '{{').replace('}', '}}') v = make_prompt_friendly(v)
context_strings.append(f'''\ context_strings.append(f'''\
{k}: {k}:
``` ```

View File

@@ -99,7 +99,7 @@ Example input: 'AAPL'
'mock_input_sequence', [ 'mock_input_sequence', [
[ [
'y', 'y',
'https://www.signalogic.com/melp/EngSamples/Orig/ENG_M.wav', 'https://www2.cs.uic.edu/~i101/SoundFiles/taunt.wav',
f'''\ f'''\
import requests import requests
url = "https://transcribe.whisperapi.com" url = "https://transcribe.whisperapi.com"