refactor: langchain

Joschka Braun
2023-04-17 15:01:28 +02:00
parent badf295f71
commit deaea68f4f
7 changed files with 76 additions and 132 deletions

View File

@@ -5,3 +5,4 @@ psutil
jina
jcloud
jina-hubble-sdk
langchain

View File

@@ -1,15 +1,21 @@
import os
from time import sleep
from typing import List, Tuple, Optional
from typing import List, Any
import openai
from openai.error import RateLimitError, Timeout, APIConnectionError
from langchain.callbacks import CallbackManager
from langchain.chat_models import ChatOpenAI
from openai.error import RateLimitError
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage,
BaseMessage
)
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from src.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \
PRICING_GPT3_5_TURBO_GENERATION, CHARS_PER_TOKEN
from src.options.generate.prompt_system import system_base_definition, executor_example, docarray_example, client_example
from src.utils.io import timeout_generator_wrapper, GenerationTimeoutError
from src.options.generate.prompt_system import system_message_base, executor_example, docarray_example, client_example
from src.utils.string_tools import print_colored
@@ -18,19 +24,14 @@ class GPTSession:
self.configure_openai_api_key()
if model == 'gpt-4' and self.is_gpt4_available():
self.supported_model = 'gpt-4'
self.pricing_prompt = PRICING_GPT4_PROMPT
self.pricing_generation = PRICING_GPT4_GENERATION
else:
if model == 'gpt-4':
print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
model = 'gpt-3.5-turbo'
self.supported_model = model
self.pricing_prompt = PRICING_GPT3_5_TURBO_PROMPT
self.pricing_generation = PRICING_GPT3_5_TURBO_GENERATION
self.chars_prompt_so_far = 0
self.chars_generation_so_far = 0
def configure_openai_api_key(self):
@staticmethod
def configure_openai_api_key():
if 'OPENAI_API_KEY' not in os.environ:
raise Exception('''
You need to set OPENAI_API_KEY in your environment.
@@ -39,7 +40,8 @@ If you have updated it already, please restart your terminal.
)
openai.api_key = os.environ['OPENAI_API_KEY']
def is_gpt4_available(self):
@staticmethod
def is_gpt4_available():
try:
for i in range(5):
try:
@@ -58,87 +60,47 @@ If you have updated it already, please restart your terminal.
except openai.error.InvalidRequestError:
return False
def cost_callback(self, chars_prompt, chars_generation):
self.chars_prompt_so_far += chars_prompt
self.chars_generation_so_far += chars_generation
print('\n')
money_prompt = self.calculate_money_spent(self.chars_prompt_so_far, self.pricing_prompt)
money_generation = self.calculate_money_spent(self.chars_generation_so_far, self.pricing_generation)
print('Total money spent so far on openai.com:', f'${money_prompt + money_generation}')
print('\n')
def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
return _GPTConversation(self.supported_model, self.cost_callback, system_definition_examples)
return _GPTConversation(self.supported_model, system_definition_examples)
def calculate_money_spent(self, num_chars, price):
return round(num_chars / CHARS_PER_TOKEN * price / 1000, 3)
class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
if os.environ['VERBOSE'].lower() == 'true':
print_colored('', token, 'green', end='')
class _GPTConversation:
def __init__(self, model: str, cost_callback, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
self.model = model
self.cost_callback = cost_callback
self.prompt_list: List[Optional[Tuple]] = [None]
self.set_system_definition(system_definition_examples)
def __init__(self, model: str, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
self.chat = ChatOpenAI(
model_name=model,
streaming=True,
callback_manager=CallbackManager([AssistantStreamingStdOutCallbackHandler()]),
temperature=0
)
self.messages: List[BaseMessage] = []
self.system_message = self._create_system_message(system_definition_examples)
if os.environ['VERBOSE'].lower() == 'true':
print_colored('system', self.prompt_list[0][1], 'magenta')
print_colored('system', self.system_message.content, 'magenta')
def query(self, prompt: str):
def chat(self, prompt: str):
chat_message = HumanMessage(content=prompt)
self.messages.append(chat_message)
if os.environ['VERBOSE'].lower() == 'true':
print_colored('user', prompt, 'blue')
self.prompt_list.append(('user', prompt))
response = self.get_response(self.prompt_list)
self.prompt_list.append(('assistant', response))
print_colored('assistant', '', 'green', end='')
response = self.chat([self.system_message] + self.messages)
self.messages.append(AIMessage(content=response))
return response
def set_system_definition(self, system_definition_examples: List[str] = []):
system_message = system_base_definition
@staticmethod
def _create_system_message(system_definition_examples: List[str] = []) -> SystemMessage:
system_message = system_message_base
if 'executor' in system_definition_examples:
system_message += f'\n{executor_example}'
if 'docarray' in system_definition_examples:
system_message += f'\n{docarray_example}'
if 'client' in system_definition_examples:
system_message += f'\n{client_example}'
self.prompt_list[0] = ('system', system_message)
def get_response_from_stream(self, response_generator):
response_generator_with_timeout = timeout_generator_wrapper(response_generator, 10)
complete_string = ''
for chunk in response_generator_with_timeout:
delta = chunk['choices'][0]['delta']
if 'content' in delta:
content = delta['content']
print_colored('' if complete_string else 'assistant', content, 'green', end='')
complete_string += content
return complete_string
def get_response(self, prompt_list: List[Tuple[str, str]]):
for i in range(10):
try:
response_generator = openai.ChatCompletion.create(
temperature=0,
max_tokens=None,
model=self.model,
stream=True,
messages=[
{
"role": prompt[0],
"content": prompt[1]
}
for prompt in prompt_list
]
)
complete_string = self.get_response_from_stream(response_generator)
except (RateLimitError, Timeout, ConnectionError, APIConnectionError, GenerationTimeoutError) as e:
print('/n', e)
print('retrying...')
sleep(3)
continue
chars_prompt = sum(len(prompt[1]) for prompt in prompt_list)
chars_generation = len(complete_string)
self.cost_callback(chars_prompt, chars_generation)
return complete_string
raise Exception('Failed to get response')
return SystemMessage(content=system_message)
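For orientation, this refactor swaps the hand-rolled openai.ChatCompletion streaming loop for langchain's chat-model API: a ChatOpenAI instance is configured with a streaming stdout callback, and each turn sends the accumulated SystemMessage/HumanMessage/AIMessage list. A minimal sketch of that pattern, independent of this repository (model choice and prompt text are illustrative only):

from langchain.callbacks import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage, SystemMessage

# assumes OPENAI_API_KEY is set in the environment
chat_model = ChatOpenAI(
    model_name='gpt-3.5-turbo',
    streaming=True,
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    temperature=0,
)
messages = [SystemMessage(content='You are a terse assistant.')]
messages.append(HumanMessage(content='Reply with a single word.'))
response = chat_model(messages)  # tokens stream to stdout as they arrive
messages.append(AIMessage(content=response.content))

In the langchain release current at the time of this commit, calling the chat model with a message list returns an AIMessage, so the generated text is read from its .content attribute.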

View File

@@ -2,6 +2,7 @@ import functools
import os
import click
from langchain.callbacks import get_openai_callback
from src.apis.jina_cloud import jina_auth_login
from src.options.configure.key_handling import set_api_key
@@ -64,7 +65,11 @@ def generate(
from src.options.generate.generator import Generator
generator = Generator(model=model)
generator.generate(description, test, path)
with get_openai_callback() as cb:
generator.generate(description, test, path)
print(f"Prompt/Completion/Total Tokens: {cb.prompt_tokens}/{cb.completion_tokens}/{cb.total_tokens}")
print(f"Total Cost on OpenAI (USD): ${cb.total_cost}")
@main.command()
@path_param
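The new cost reporting relies on langchain's get_openai_callback context manager, which accumulates the token usage reported by OpenAI calls made inside the with-block and exposes an estimated USD cost, replacing the manual pricing constants removed below. A minimal usage sketch outside this codebase (model and prompt are placeholders):

from langchain.callbacks import get_openai_callback
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
with get_openai_callback() as cb:
    llm([HumanMessage(content='What is 2 + 2?')])
# the counters reflect every OpenAI call made inside the block
print(f"Prompt/Completion/Total Tokens: {cb.prompt_tokens}/{cb.completion_tokens}/{cb.total_tokens}")
print(f"Total Cost on OpenAI (USD): ${cb.total_cost}")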

View File

@@ -23,12 +23,6 @@ FILE_AND_TAG_PAIRS = [
FLOW_URL_PLACEHOLDER = 'jcloud.jina.ai'
PRICING_GPT4_PROMPT = 0.03
PRICING_GPT4_GENERATION = 0.06
PRICING_GPT3_5_TURBO_PROMPT = 0.002
PRICING_GPT3_5_TURBO_GENERATION = 0.002
CHARS_PER_TOKEN = 3.4
NUM_IMPLEMENTATION_STRATEGIES = 5
MAX_DEBUGGING_ITERATIONS = 10

View File

@@ -74,15 +74,15 @@ class Generator:
+ executor_file_task(microservice_name, description, test, package)
)
conversation = self.gpt_session.get_conversation()
microservice_content_raw = conversation.query(user_query)
microservice_content_raw = conversation.chat(user_query)
if is_chain_of_thought:
microservice_content_raw = conversation.query(
microservice_content_raw = conversation.chat(
f"General rules: " + not_allowed_executor() + chain_of_thought_optimization('python',
'microservice.py'))
microservice_content = self.extract_content_from_result(microservice_content_raw, 'microservice.py',
match_single_block=True)
if microservice_content == '':
microservice_content_raw = conversation.query('You must add the executor code.')
microservice_content_raw = conversation.chat('You must add the executor code.')
microservice_content = self.extract_content_from_result(
microservice_content_raw, 'microservice.py', match_single_block=True
)
@@ -95,9 +95,9 @@ class Generator:
+ test_executor_file_task(microservice_name, test)
)
conversation = self.gpt_session.get_conversation()
test_microservice_content_raw = conversation.query(user_query)
test_microservice_content_raw = conversation.chat(user_query)
if is_chain_of_thought:
test_microservice_content_raw = conversation.query(
test_microservice_content_raw = conversation.chat(
f"General rules: " + not_allowed_executor() +
chain_of_thought_optimization('python', 'test_microservice.py')
+ "Don't add any additional tests. "
@@ -116,9 +116,9 @@ class Generator:
+ requirements_file_task()
)
conversation = self.gpt_session.get_conversation()
requirements_content_raw = conversation.query(user_query)
requirements_content_raw = conversation.chat(user_query)
if is_chain_of_thought:
requirements_content_raw = conversation.query(
requirements_content_raw = conversation.chat(
chain_of_thought_optimization('', requirements_path) + "Keep the same version of jina ")
requirements_content = self.extract_content_from_result(requirements_content_raw, 'requirements.txt',
@@ -134,9 +134,9 @@ class Generator:
+ docker_file_task()
)
conversation = self.gpt_session.get_conversation()
dockerfile_content_raw = conversation.query(user_query)
dockerfile_content_raw = conversation.chat(user_query)
if is_chain_of_thought:
dockerfile_content_raw = conversation.query(
dockerfile_content_raw = conversation.chat(
f"General rules: " + not_allowed_executor() + chain_of_thought_optimization('dockerfile', 'Dockerfile'))
dockerfile_content = self.extract_content_from_result(dockerfile_content_raw, 'Dockerfile',
match_single_block=True)
@@ -172,8 +172,8 @@ The playground (app.py) must not let the user configure the host on the ui.
'''
)
conversation = self.gpt_session.get_conversation([])
conversation.query(user_query)
playground_content_raw = conversation.query(chain_of_thought_optimization('python', 'app.py', 'the playground'))
conversation.chat(user_query)
playground_content_raw = conversation.chat(chain_of_thought_optimization('python', 'app.py', 'the playground'))
playground_content = self.extract_content_from_result(playground_content_raw, 'app.py', match_single_block=True)
persist_file(playground_content, os.path.join(microservice_path, 'app.py'))
@@ -213,7 +213,7 @@ The playground (app.py) must not let the user configure the host on the ui.
user_query = self.get_user_query_code_issue(description, error, file_name_to_content,
test)
conversation = self.gpt_session.get_conversation()
returned_files_raw = conversation.query(user_query)
returned_files_raw = conversation.chat(user_query)
for file_name, tag in FILE_AND_TAG_PAIRS:
updated_file = self.extract_content_from_result(returned_files_raw, file_name)
if updated_file and (not is_dependency_issue or file_name in ['requirements.txt', 'Dockerfile']):
@@ -280,7 +280,7 @@ complete file. Use the exact same syntax to wrap the code:
print_colored('', 'Is it a dependency issue?', 'blue')
conversation = self.gpt_session.get_conversation([])
answer = conversation.query(
answer = conversation.chat(
f'Your task is to assist in identifying the root cause of a Docker build error for a python application. '
f'The error message is as follows::\n\n{error}\n\n'
f'The docker file is as follows:\n\n{docker_file}\n\n'
@@ -305,7 +305,7 @@ The output is a the raw string wrapped into ``` and starting with **name.txt** l
PDFParserExecutor
```
'''
name_raw = conversation.query(user_query)
name_raw = conversation.chat(user_query)
name = self.extract_content_from_result(name_raw, 'name.txt')
return name
@@ -341,7 +341,7 @@ package5
```
'''
conversation = self.gpt_session.get_conversation()
packages_raw = conversation.query(user_query)
packages_raw = conversation.chat(user_query)
packages_csv_string = self.extract_content_from_result(packages_raw, 'packages.csv')
packages = [package.split(',') for package in packages_csv_string.split('\n')]
packages = packages[:NUM_IMPLEMENTATION_STRATEGIES]
@@ -351,13 +351,17 @@ package5
generated_name = self.generate_microservice_name(description)
microservice_name = f'{generated_name}{random.randint(0, 10_000_000)}'
packages_list = self.get_possible_packages(description)
packages_list = [packages for packages in packages_list if len(set(packages).intersection(set(PROBLEMATIC_PACKAGES))) == 0]
packages_list = [
packages for packages in packages_list if len(set(packages).intersection(set(PROBLEMATIC_PACKAGES))) == 0
]
for num_approach, packages in enumerate(packages_list):
try:
self.generate_microservice(description, test, microservice_path, microservice_name, packages,
num_approach)
final_version_path = self.debug_microservice(microservice_path, microservice_name, num_approach,
packages, description, test)
self.generate_microservice(
description, test, microservice_path, microservice_name, packages, num_approach
)
final_version_path = self.debug_microservice(
microservice_path, microservice_name, num_approach, packages, description, test
)
self.generate_playground(microservice_name, final_version_path)
except self.MaxDebugTimeReachedException:
print('Could not debug the Microservice with the approach:', packages)
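All call sites in the generator change mechanically from conversation.query(...) to conversation.chat(...); each step still requests a fresh conversation from the session and sends one prompt per generated file. A hypothetical call against the new interface (the model name and prompt are illustrative, and GPTSession is assumed to take the model name as its first argument, as in the hunk above):

session = GPTSession('gpt-3.5-turbo')
conversation = session.get_conversation()
raw = conversation.chat('Return the file wrapped in a **requirements.txt** code block.')
# the raw reply is then parsed with extract_content_from_result, as in the hunks above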

View File

@@ -71,7 +71,7 @@ print(response[0].text)
```'''
system_base_definition = f'''
system_message_base = f'''
It is the year 2021.
You are a principal engineer working at Jina - an open source company.
You accurately satisfy all of the user's requirements.

View File

@@ -31,28 +31,6 @@ def get_all_microservice_files_with_content(folder_path):
return file_name_to_content
class GenerationTimeoutError(Exception):
pass
def timeout_generator_wrapper(generator, timeout):
def generator_func():
for item in generator:
yield item
def wrapper() -> Generator:
gen = generator_func()
while True:
try:
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(next, gen)
yield future.result(timeout=timeout)
except StopIteration:
break
except concurrent.futures.TimeoutError:
raise GenerationTimeoutError(f"Generation took too long")
return wrapper()
@contextmanager
def suppress_stdout():
original_stdout = sys.stdout