diff --git a/dev_gpt/apis/gpt.py b/dev_gpt/apis/gpt.py index 55b74ab..2f07a02 100644 --- a/dev_gpt/apis/gpt.py +++ b/dev_gpt/apis/gpt.py @@ -2,16 +2,15 @@ import json import os from copy import deepcopy from time import sleep - from typing import List, Any import openai from langchain import PromptTemplate from langchain.callbacks import CallbackManager -from langchain.chat_models import ChatOpenAI -from openai.error import RateLimitError -from langchain.schema import HumanMessage, SystemMessage, BaseMessage, AIMessage from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain.chat_models import ChatOpenAI +from langchain.schema import HumanMessage, SystemMessage, BaseMessage, AIMessage +from openai.error import RateLimitError from requests.exceptions import ConnectionError, ChunkedEncodingError from urllib3.exceptions import InvalidChunkLength @@ -32,6 +31,7 @@ If you have updated it already, please restart your terminal. exit(1) openai.api_key = os.environ['OPENAI_API_KEY'] + class GPTSession: _instance = None _initialized = False @@ -64,7 +64,6 @@ class GPTSession: self.model_name, self.cost_callback, messages, print_stream, print_costs ) - @staticmethod def is_gpt4_available(): try: @@ -120,7 +119,6 @@ class _GPTConversation: self.print_stream = print_stream self.print_costs = print_costs - def print_messages(self, messages): for i, message in enumerate(messages): if os.environ['VERBOSE'].lower() == 'true': @@ -167,8 +165,9 @@ class _GPTConversation: def ask_gpt(prompt_template, parser, **kwargs): template_parameters = get_template_parameters(prompt_template) if set(template_parameters) != set(kwargs.keys()): - raise ValueError(f'Prompt template parameters {get_template_parameters(prompt_template)} do not match ' - f'provided parameters {kwargs.keys()}') + raise ValueError( + f'Prompt template parameters {set(template_parameters)} do not match provided parameters {set(kwargs.keys())}' + ) for key, value in kwargs.items(): if isinstance(value, dict): kwargs[key] = json.dumps(value, indent=4) diff --git a/dev_gpt/options/generate/chains/extract_information.py b/dev_gpt/options/generate/chains/extract_information.py index 5816a85..70aae59 100644 --- a/dev_gpt/options/generate/chains/extract_information.py +++ b/dev_gpt/options/generate/chains/extract_information.py @@ -1,29 +1,29 @@ -from typing import Dict - -from dev_gpt.apis.gpt import ask_gpt -from dev_gpt.options.generate.chains.question_answering import answer_yes_no_question -from dev_gpt.options.generate.parser import identity_parser, boolean_parser - - -def extract_information(text, info_keys) -> Dict[str, str]: - extracted_infos = {} - for info_key in info_keys: - is_information_in_text = answer_yes_no_question(text, f'Is a {info_key} mentioned above?') - if is_information_in_text: - extracted_info = ask_gpt( - extract_information_prompt, - identity_parser, - text=text, - info_key=info_key - ) - extracted_infos[info_key] = extracted_info - return extracted_infos - - -extract_information_prompt = '''\ -{text} - -Your task: -Return all {info_key}s from above.' -Note: you must only output your answer. -''' \ No newline at end of file +# from typing import Dict +# +# from dev_gpt.apis.gpt import ask_gpt +# from dev_gpt.options.generate.chains.question_answering import answer_yes_no_question +# from dev_gpt.options.generate.parser import identity_parser, boolean_parser +# +# +# def extract_information(text, info_keys) -> Dict[str, str]: +# extracted_infos = {} +# for info_key in info_keys: +# is_information_in_text = answer_yes_no_question(text, f'Is a {info_key} mentioned above?') +# if is_information_in_text: +# extracted_info = ask_gpt( +# extract_information_prompt, +# identity_parser, +# text=text, +# info_key=info_key +# ) +# extracted_infos[info_key] = extracted_info +# return extracted_infos +# +# +# extract_information_prompt = '''\ +# {text} +# +# Your task: +# Return all {info_key}s from above.' +# Note: you must only output your answer. +# ''' \ No newline at end of file diff --git a/dev_gpt/options/generate/pm/pm.py b/dev_gpt/options/generate/pm/pm.py index cedf931..9c28b44 100644 --- a/dev_gpt/options/generate/pm/pm.py +++ b/dev_gpt/options/generate/pm/pm.py @@ -6,6 +6,7 @@ from dev_gpt.options.generate.chains.translation import translation from dev_gpt.options.generate.chains.user_confirmation_feedback_loop import user_feedback_loop from dev_gpt.options.generate.chains.get_user_input_if_needed import get_user_input_if_needed from dev_gpt.options.generate.parser import identity_parser +from dev_gpt.options.generate.prompt_factory import make_prompt_friendly from dev_gpt.options.generate.ui import get_random_employee @@ -88,6 +89,7 @@ Description of the microservice: ) if user_answer: if post_transformation_fn: + user_answer = make_prompt_friendly(user_answer) user_answer = post_transformation_fn(user_answer) return f'\n{extension_name}: {user_answer}' else: diff --git a/dev_gpt/options/generate/prompt_factory.py b/dev_gpt/options/generate/prompt_factory.py index 26236a9..fa42f72 100644 --- a/dev_gpt/options/generate/prompt_factory.py +++ b/dev_gpt/options/generate/prompt_factory.py @@ -1,12 +1,14 @@ import json +def make_prompt_friendly(text): + return text.replace('{', '{{').replace('}', '}}') def context_to_string(context): context_strings = [] for k, v in context.items(): if isinstance(v, dict): v = json.dumps(v, indent=4) - v = v.replace('{', '{{').replace('}', '}}') + v = make_prompt_friendly(v) context_strings.append(f'''\ {k}: ``` diff --git a/test/integration/test_generator.py b/test/integration/test_generator.py index 180d0c7..46b051b 100644 --- a/test/integration/test_generator.py +++ b/test/integration/test_generator.py @@ -99,7 +99,7 @@ Example input: 'AAPL' 'mock_input_sequence', [ [ 'y', - 'https://www.signalogic.com/melp/EngSamples/Orig/ENG_M.wav', + 'https://www2.cs.uic.edu/~i101/SoundFiles/taunt.wav', f'''\ import requests url = "https://transcribe.whisperapi.com"