refactor: langchain

Joschka Braun
2023-04-17 15:01:28 +02:00
parent badf295f71
commit deaea68f4f
7 changed files with 76 additions and 132 deletions

File 1 of 7: requirements file (likely requirements.txt)

@@ -5,3 +5,4 @@ psutil
 jina
 jcloud
 jina-hubble-sdk
+langchain

File 2 of 7: the GPTSession module (path not shown in this view)

@@ -1,15 +1,21 @@
 import os
 from time import sleep
-from typing import List, Tuple, Optional
+from typing import List, Any
 
 import openai
-from openai.error import RateLimitError, Timeout, APIConnectionError
-from src.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \
-    PRICING_GPT3_5_TURBO_GENERATION, CHARS_PER_TOKEN
-from src.options.generate.prompt_system import system_base_definition, executor_example, docarray_example, client_example
-from src.utils.io import timeout_generator_wrapper, GenerationTimeoutError
+from langchain.callbacks import CallbackManager
+from langchain.chat_models import ChatOpenAI
+from openai.error import RateLimitError
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    BaseMessage
+)
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from src.options.generate.prompt_system import system_message_base, executor_example, docarray_example, client_example
 from src.utils.string_tools import print_colored
@@ -18,19 +24,14 @@ class GPTSession:
         self.configure_openai_api_key()
         if model == 'gpt-4' and self.is_gpt4_available():
             self.supported_model = 'gpt-4'
-            self.pricing_prompt = PRICING_GPT4_PROMPT
-            self.pricing_generation = PRICING_GPT4_GENERATION
         else:
             if model == 'gpt-4':
                 print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
             model = 'gpt-3.5-turbo'
             self.supported_model = model
-            self.pricing_prompt = PRICING_GPT3_5_TURBO_PROMPT
-            self.pricing_generation = PRICING_GPT3_5_TURBO_GENERATION
-        self.chars_prompt_so_far = 0
-        self.chars_generation_so_far = 0
 
-    def configure_openai_api_key(self):
+    @staticmethod
+    def configure_openai_api_key():
         if 'OPENAI_API_KEY' not in os.environ:
             raise Exception('''
 You need to set OPENAI_API_KEY in your environment.
@@ -39,7 +40,8 @@ If you have updated it already, please restart your terminal.
         )
         openai.api_key = os.environ['OPENAI_API_KEY']
 
-    def is_gpt4_available(self):
+    @staticmethod
+    def is_gpt4_available():
         try:
             for i in range(5):
                 try:
@@ -58,87 +60,47 @@ If you have updated it already, please restart your terminal.
         except openai.error.InvalidRequestError:
             return False
 
-    def cost_callback(self, chars_prompt, chars_generation):
-        self.chars_prompt_so_far += chars_prompt
-        self.chars_generation_so_far += chars_generation
-        print('\n')
-        money_prompt = self.calculate_money_spent(self.chars_prompt_so_far, self.pricing_prompt)
-        money_generation = self.calculate_money_spent(self.chars_generation_so_far, self.pricing_generation)
-        print('Total money spent so far on openai.com:', f'${money_prompt + money_generation}')
-        print('\n')
-
     def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
-        return _GPTConversation(self.supported_model, self.cost_callback, system_definition_examples)
+        return _GPTConversation(self.supported_model, system_definition_examples)
 
-    def calculate_money_spent(self, num_chars, price):
-        return round(num_chars / CHARS_PER_TOKEN * price / 1000, 3)
+
+class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        if os.environ['VERBOSE'].lower() == 'true':
+            print_colored('', token, 'green', end='')
 
 
 class _GPTConversation:
-    def __init__(self, model: str, cost_callback, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
-        self.model = model
-        self.cost_callback = cost_callback
-        self.prompt_list: List[Optional[Tuple]] = [None]
-        self.set_system_definition(system_definition_examples)
+    def __init__(self, model: str, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
+        self.chat = ChatOpenAI(
+            model_name=model,
+            streaming=True,
+            callback_manager=CallbackManager([AssistantStreamingStdOutCallbackHandler()]),
+            temperature=0
+        )
+        self.messages: List[BaseMessage] = []
+        self.system_message = self._create_system_message(system_definition_examples)
         if os.environ['VERBOSE'].lower() == 'true':
-            print_colored('system', self.prompt_list[0][1], 'magenta')
+            print_colored('system', self.system_message.content, 'magenta')
 
-    def query(self, prompt: str):
+    def chat(self, prompt: str):
+        chat_message = HumanMessage(content=prompt)
+        self.messages.append(chat_message)
         if os.environ['VERBOSE'].lower() == 'true':
             print_colored('user', prompt, 'blue')
-        self.prompt_list.append(('user', prompt))
-        response = self.get_response(self.prompt_list)
-        self.prompt_list.append(('assistant', response))
+        print_colored('assistant', '', 'green', end='')
+        response = self.chat([self.system_message] + self.messages)
+        self.messages.append(AIMessage(content=response))
         return response
 
-    def set_system_definition(self, system_definition_examples: List[str] = []):
-        system_message = system_base_definition
+    @staticmethod
+    def _create_system_message(system_definition_examples: List[str] = []) -> SystemMessage:
+        system_message = system_message_base
         if 'executor' in system_definition_examples:
             system_message += f'\n{executor_example}'
         if 'docarray' in system_definition_examples:
             system_message += f'\n{docarray_example}'
         if 'client' in system_definition_examples:
             system_message += f'\n{client_example}'
-        self.prompt_list[0] = ('system', system_message)
-
-    def get_response_from_stream(self, response_generator):
-        response_generator_with_timeout = timeout_generator_wrapper(response_generator, 10)
-        complete_string = ''
-        for chunk in response_generator_with_timeout:
-            delta = chunk['choices'][0]['delta']
-            if 'content' in delta:
-                content = delta['content']
-                print_colored('' if complete_string else 'assistant', content, 'green', end='')
-                complete_string += content
-        return complete_string
-
-    def get_response(self, prompt_list: List[Tuple[str, str]]):
-        for i in range(10):
-            try:
-                response_generator = openai.ChatCompletion.create(
-                    temperature=0,
-                    max_tokens=None,
-                    model=self.model,
-                    stream=True,
-                    messages=[
-                        {
-                            "role": prompt[0],
-                            "content": prompt[1]
-                        }
-                        for prompt in prompt_list
-                    ]
-                )
-                complete_string = self.get_response_from_stream(response_generator)
-            except (RateLimitError, Timeout, ConnectionError, APIConnectionError, GenerationTimeoutError) as e:
-                print('/n', e)
-                print('retrying...')
-                sleep(3)
-                continue
-            chars_prompt = sum(len(prompt[1]) for prompt in prompt_list)
-            chars_generation = len(complete_string)
-            self.cost_callback(chars_prompt, chars_generation)
-            return complete_string
-        raise Exception('Failed to get response')
+        return SystemMessage(content=system_message)
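
Note on the new wiring: as rendered above, `__init__` stores the ChatOpenAI handle in `self.chat` while the class also defines a `chat` method; since instance attributes shadow methods in Python, `conversation.chat(...)` would resolve to the raw model object rather than the method, so the handle presumably lives under a different name. A minimal sketch of the intended conversation loop, assuming the early-2023 langchain API used by the imports above; the attribute name `_chat_model` is hypothetical, chosen so it cannot shadow the method:

```python
from typing import Any, List

from langchain.callbacks import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, HumanMessage, SystemMessage


class TokenEchoHandler(StreamingStdOutCallbackHandler):
    """Prints each streamed token as soon as it arrives."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(token, end='', flush=True)


class Conversation:
    def __init__(self, model: str, system_prompt: str):
        self._chat_model = ChatOpenAI(
            model_name=model,
            streaming=True,
            callback_manager=CallbackManager([TokenEchoHandler()]),
            temperature=0,
        )
        self._system = SystemMessage(content=system_prompt)
        self._history: List[BaseMessage] = []

    def chat(self, prompt: str) -> str:
        self._history.append(HumanMessage(content=prompt))
        # Calling the chat model returns the assistant message; the streamed
        # tokens were already printed by the callback handler along the way.
        response = self._chat_model([self._system] + self._history)
        self._history.append(response)
        return response.content


# usage (requires OPENAI_API_KEY in the environment):
# conv = Conversation('gpt-3.5-turbo', 'You are a helpful assistant.')
# answer = conv.chat('Name one prime number.')
```

The old manual retry loop around `openai.ChatCompletion.create` also disappears here; `ChatOpenAI` of this era retried rate-limit errors internally (tenacity-based, `max_retries` defaulting to 6), and the surviving `RateLimitError` import is presumably still used by the `is_gpt4_available` probe.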

File 3 of 7: the click-based CLI entry point (path not shown in this view)

@@ -2,6 +2,7 @@ import functools
 import os
 
 import click
+from langchain.callbacks import get_openai_callback
 
 from src.apis.jina_cloud import jina_auth_login
 from src.options.configure.key_handling import set_api_key
@@ -64,7 +65,11 @@ def generate(
     from src.options.generate.generator import Generator
     generator = Generator(model=model)
-    generator.generate(description, test, path)
+    with get_openai_callback() as cb:
+        generator.generate(description, test, path)
+
+    print(f"Prompt/Completion/Total Tokens: {cb.prompt_tokens}/{cb.completion_tokens}/{cb.total_tokens}")
+    print(f"Total Cost on OpenAI (USD): ${cb.total_cost}")
 
 
 @main.command()
 @path_param
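
`get_openai_callback` is LangChain's stock token accountant: every OpenAI call issued inside the context manager adds its reported usage to shared counters, and cost is derived from built-in per-model pricing. A standalone sketch, with one caveat stated as an assumption: streamed responses in this era carried no usage data, so the counters can read zero when `streaming=True`.

```python
from langchain.callbacks import get_openai_callback
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)

with get_openai_callback() as cb:
    # every OpenAI call inside the block accumulates into the same counters
    chat([HumanMessage(content='Say hello in exactly one word.')])

print(f'Prompt/Completion/Total Tokens: {cb.prompt_tokens}/{cb.completion_tokens}/{cb.total_tokens}')
print(f'Total Cost on OpenAI (USD): ${cb.total_cost}')
```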

File 4 of 7: src/constants.py (per the constants imported above)

@@ -23,12 +23,6 @@ FILE_AND_TAG_PAIRS = [
 FLOW_URL_PLACEHOLDER = 'jcloud.jina.ai'
 
-PRICING_GPT4_PROMPT = 0.03
-PRICING_GPT4_GENERATION = 0.06
-PRICING_GPT3_5_TURBO_PROMPT = 0.002
-PRICING_GPT3_5_TURBO_GENERATION = 0.002
-CHARS_PER_TOKEN = 3.4
-
 NUM_IMPLEMENTATION_STRATEGIES = 5
 MAX_DEBUGGING_ITERATIONS = 10
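
For reference, the deleted constants fed the character-based estimator removed from `GPTSession` above. Reconstructed from the deleted lines, with a worked example:

```python
# Reconstructed from the removed code: estimate cost from character counts,
# assuming ~3.4 characters per token and USD prices per 1K tokens.
CHARS_PER_TOKEN = 3.4
PRICING_GPT4_PROMPT = 0.03  # $ per 1K prompt tokens

def calculate_money_spent(num_chars: int, price: float) -> float:
    return round(num_chars / CHARS_PER_TOKEN * price / 1000, 3)

# 34,000 prompt characters ~= 10,000 tokens -> $0.30 at GPT-4 prompt pricing
assert calculate_money_spent(34_000, PRICING_GPT4_PROMPT) == 0.3
```

The LangChain callback replaces this heuristic with the token counts the API actually reports.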

File 5 of 7: src/options/generate/generator.py (per the import in the CLI hunk)

@@ -74,15 +74,15 @@ class Generator:
             + executor_file_task(microservice_name, description, test, package)
         )
         conversation = self.gpt_session.get_conversation()
-        microservice_content_raw = conversation.query(user_query)
+        microservice_content_raw = conversation.chat(user_query)
         if is_chain_of_thought:
-            microservice_content_raw = conversation.query(
+            microservice_content_raw = conversation.chat(
                 f"General rules: " + not_allowed_executor() + chain_of_thought_optimization('python',
                                                                                             'microservice.py'))
         microservice_content = self.extract_content_from_result(microservice_content_raw, 'microservice.py',
                                                                 match_single_block=True)
         if microservice_content == '':
-            microservice_content_raw = conversation.query('You must add the executor code.')
+            microservice_content_raw = conversation.chat('You must add the executor code.')
             microservice_content = self.extract_content_from_result(
                 microservice_content_raw, 'microservice.py', match_single_block=True
             )
@@ -95,9 +95,9 @@ class Generator:
             + test_executor_file_task(microservice_name, test)
         )
         conversation = self.gpt_session.get_conversation()
-        test_microservice_content_raw = conversation.query(user_query)
+        test_microservice_content_raw = conversation.chat(user_query)
         if is_chain_of_thought:
-            test_microservice_content_raw = conversation.query(
+            test_microservice_content_raw = conversation.chat(
                 f"General rules: " + not_allowed_executor() +
                 chain_of_thought_optimization('python', 'test_microservice.py')
                 + "Don't add any additional tests. "
@@ -116,9 +116,9 @@ class Generator:
             + requirements_file_task()
         )
         conversation = self.gpt_session.get_conversation()
-        requirements_content_raw = conversation.query(user_query)
+        requirements_content_raw = conversation.chat(user_query)
         if is_chain_of_thought:
-            requirements_content_raw = conversation.query(
+            requirements_content_raw = conversation.chat(
                 chain_of_thought_optimization('', requirements_path) + "Keep the same version of jina ")
         requirements_content = self.extract_content_from_result(requirements_content_raw, 'requirements.txt',
@@ -134,9 +134,9 @@ class Generator:
             + docker_file_task()
         )
         conversation = self.gpt_session.get_conversation()
-        dockerfile_content_raw = conversation.query(user_query)
+        dockerfile_content_raw = conversation.chat(user_query)
         if is_chain_of_thought:
-            dockerfile_content_raw = conversation.query(
+            dockerfile_content_raw = conversation.chat(
                 f"General rules: " + not_allowed_executor() + chain_of_thought_optimization('dockerfile', 'Dockerfile'))
         dockerfile_content = self.extract_content_from_result(dockerfile_content_raw, 'Dockerfile',
                                                               match_single_block=True)
@@ -172,8 +172,8 @@ The playground (app.py) must not let the user configure the host on the ui.
             '''
         )
         conversation = self.gpt_session.get_conversation([])
-        conversation.query(user_query)
-        playground_content_raw = conversation.query(chain_of_thought_optimization('python', 'app.py', 'the playground'))
+        conversation.chat(user_query)
+        playground_content_raw = conversation.chat(chain_of_thought_optimization('python', 'app.py', 'the playground'))
         playground_content = self.extract_content_from_result(playground_content_raw, 'app.py', match_single_block=True)
         persist_file(playground_content, os.path.join(microservice_path, 'app.py'))
@@ -213,7 +213,7 @@ The playground (app.py) must not let the user configure the host on the ui.
         user_query = self.get_user_query_code_issue(description, error, file_name_to_content,
                                                     test)
         conversation = self.gpt_session.get_conversation()
-        returned_files_raw = conversation.query(user_query)
+        returned_files_raw = conversation.chat(user_query)
         for file_name, tag in FILE_AND_TAG_PAIRS:
             updated_file = self.extract_content_from_result(returned_files_raw, file_name)
             if updated_file and (not is_dependency_issue or file_name in ['requirements.txt', 'Dockerfile']):
@@ -280,7 +280,7 @@ complete file. Use the exact same syntax to wrap the code:
         print_colored('', 'Is it a dependency issue?', 'blue')
         conversation = self.gpt_session.get_conversation([])
-        answer = conversation.query(
+        answer = conversation.chat(
             f'Your task is to assist in identifying the root cause of a Docker build error for a python application. '
             f'The error message is as follows::\n\n{error}\n\n'
             f'The docker file is as follows:\n\n{docker_file}\n\n'
@@ -305,7 +305,7 @@ The output is a the raw string wrapped into ``` and starting with **name.txt** l
 PDFParserExecutor
 ```
 '''
-        name_raw = conversation.query(user_query)
+        name_raw = conversation.chat(user_query)
         name = self.extract_content_from_result(name_raw, 'name.txt')
         return name
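
The prompt above fixes an output convention: each generated file comes back as a fenced block introduced by a bold file name such as `**name.txt**`. `extract_content_from_result` itself is not part of this diff; a hypothetical regex-based version, purely to illustrate the convention:

```python
import re

def extract_content_from_result(raw: str, file_name: str) -> str:
    # Match "**<file name>**" followed by a fenced block; capture the body.
    pattern = rf'\*\*{re.escape(file_name)}\*\*\s*```(?:[a-z]*)\n?(.*?)```'
    match = re.search(pattern, raw, re.DOTALL)
    return match.group(1).strip() if match else ''

# extract_content_from_result('**name.txt**\n```\nPDFParserExecutor\n```', 'name.txt')
# -> 'PDFParserExecutor'
```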
@@ -341,7 +341,7 @@ package5
 ```
 '''
         conversation = self.gpt_session.get_conversation()
-        packages_raw = conversation.query(user_query)
+        packages_raw = conversation.chat(user_query)
         packages_csv_string = self.extract_content_from_result(packages_raw, 'packages.csv')
         packages = [package.split(',') for package in packages_csv_string.split('\n')]
         packages = packages[:NUM_IMPLEMENTATION_STRATEGIES]
@@ -351,13 +351,17 @@ package5
         generated_name = self.generate_microservice_name(description)
         microservice_name = f'{generated_name}{random.randint(0, 10_000_000)}'
         packages_list = self.get_possible_packages(description)
-        packages_list = [packages for packages in packages_list if len(set(packages).intersection(set(PROBLEMATIC_PACKAGES))) == 0]
+        packages_list = [
+            packages for packages in packages_list if len(set(packages).intersection(set(PROBLEMATIC_PACKAGES))) == 0
+        ]
         for num_approach, packages in enumerate(packages_list):
             try:
-                self.generate_microservice(description, test, microservice_path, microservice_name, packages,
-                                           num_approach)
-                final_version_path = self.debug_microservice(microservice_path, microservice_name, num_approach,
-                                                             packages, description, test)
+                self.generate_microservice(
+                    description, test, microservice_path, microservice_name, packages, num_approach
+                )
+                final_version_path = self.debug_microservice(
+                    microservice_path, microservice_name, num_approach, packages, description, test
+                )
                 self.generate_playground(microservice_name, final_version_path)
             except self.MaxDebugTimeReachedException:
                 print('Could not debug the Microservice with the approach:', packages)

File 6 of 7: src/options/generate/prompt_system.py (per the imports in the GPTSession module)

@@ -71,7 +71,7 @@ print(response[0].text)
 ```'''
 
-system_base_definition = f'''
+system_message_base = f'''
 It is the year 2021.
 You are a principal engineer working at Jina - an open source company.
 You accurately satisfy all of the user's requirements.

File 7 of 7: src/utils/io.py (per the import removed from the GPTSession module)

@@ -31,28 +31,6 @@ def get_all_microservice_files_with_content(folder_path):
     return file_name_to_content
 
 
-class GenerationTimeoutError(Exception):
-    pass
-
-
-def timeout_generator_wrapper(generator, timeout):
-    def generator_func():
-        for item in generator:
-            yield item
-
-    def wrapper() -> Generator:
-        gen = generator_func()
-        while True:
-            try:
-                with concurrent.futures.ThreadPoolExecutor() as executor:
-                    future = executor.submit(next, gen)
-                    yield future.result(timeout=timeout)
-            except StopIteration:
-                break
-            except concurrent.futures.TimeoutError:
-                raise GenerationTimeoutError(f"Generation took too long")
-
-    return wrapper()
-
-
 @contextmanager
 def suppress_stdout():
     original_stdout = sys.stdout
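
The deleted wrapper bounded how long a streamed response could stall between chunks by pulling each item through a thread-pool future. After the refactor, LangChain drives the streaming loop, leaving the wrapper without a call site. If a bound is still wanted, `ChatOpenAI` of this era accepted a `request_timeout` parameter; note this is a request-level timeout, not the per-chunk bound the wrapper enforced, and its applicability here is an assumption:

```python
from langchain.chat_models import ChatOpenAI

# request_timeout caps the whole OpenAI request rather than the gap between
# streamed chunks, so it is a looser guarantee than the removed wrapper gave.
chat = ChatOpenAI(model_name='gpt-3.5-turbo', streaming=True, request_timeout=10)
```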