mirror of
https://github.com/aljazceru/dev-gpt.git
synced 2025-12-21 23:54:19 +01:00
refactor: langchain
This commit is contained in:
124
src/apis/gpt.py
124
src/apis/gpt.py
@@ -1,15 +1,21 @@
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
from typing import List, Any
|
||||
|
||||
import openai
|
||||
from openai.error import RateLimitError, Timeout, APIConnectionError
|
||||
from langchain.callbacks import CallbackManager
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from openai.error import RateLimitError
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
BaseMessage
|
||||
)
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
from src.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \
|
||||
PRICING_GPT3_5_TURBO_GENERATION, CHARS_PER_TOKEN
|
||||
from src.options.generate.prompt_system import system_base_definition, executor_example, docarray_example, client_example
|
||||
from src.utils.io import timeout_generator_wrapper, GenerationTimeoutError
|
||||
from src.options.generate.prompt_system import system_message_base, executor_example, docarray_example, client_example
|
||||
from src.utils.string_tools import print_colored
|
||||
|
||||
|
||||
@@ -18,19 +24,14 @@ class GPTSession:
|
||||
self.configure_openai_api_key()
|
||||
if model == 'gpt-4' and self.is_gpt4_available():
|
||||
self.supported_model = 'gpt-4'
|
||||
self.pricing_prompt = PRICING_GPT4_PROMPT
|
||||
self.pricing_generation = PRICING_GPT4_GENERATION
|
||||
else:
|
||||
if model == 'gpt-4':
|
||||
print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
|
||||
model = 'gpt-3.5-turbo'
|
||||
self.supported_model = model
|
||||
self.pricing_prompt = PRICING_GPT3_5_TURBO_PROMPT
|
||||
self.pricing_generation = PRICING_GPT3_5_TURBO_GENERATION
|
||||
self.chars_prompt_so_far = 0
|
||||
self.chars_generation_so_far = 0
|
||||
|
||||
def configure_openai_api_key(self):
|
||||
@staticmethod
|
||||
def configure_openai_api_key():
|
||||
if 'OPENAI_API_KEY' not in os.environ:
|
||||
raise Exception('''
|
||||
You need to set OPENAI_API_KEY in your environment.
|
||||
@@ -39,7 +40,8 @@ If you have updated it already, please restart your terminal.
|
||||
)
|
||||
openai.api_key = os.environ['OPENAI_API_KEY']
|
||||
|
||||
def is_gpt4_available(self):
|
||||
@staticmethod
|
||||
def is_gpt4_available():
|
||||
try:
|
||||
for i in range(5):
|
||||
try:
|
||||
@@ -58,87 +60,47 @@ If you have updated it already, please restart your terminal.
|
||||
except openai.error.InvalidRequestError:
|
||||
return False
|
||||
|
||||
def cost_callback(self, chars_prompt, chars_generation):
|
||||
self.chars_prompt_so_far += chars_prompt
|
||||
self.chars_generation_so_far += chars_generation
|
||||
print('\n')
|
||||
money_prompt = self.calculate_money_spent(self.chars_prompt_so_far, self.pricing_prompt)
|
||||
money_generation = self.calculate_money_spent(self.chars_generation_so_far, self.pricing_generation)
|
||||
print('Total money spent so far on openai.com:', f'${money_prompt + money_generation}')
|
||||
print('\n')
|
||||
|
||||
def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
|
||||
return _GPTConversation(self.supported_model, self.cost_callback, system_definition_examples)
|
||||
return _GPTConversation(self.supported_model, system_definition_examples)
|
||||
|
||||
def calculate_money_spent(self, num_chars, price):
|
||||
return round(num_chars / CHARS_PER_TOKEN * price / 1000, 3)
|
||||
|
||||
class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
if os.environ['VERBOSE'].lower() == 'true':
|
||||
print_colored('', token, 'green', end='')
|
||||
|
||||
|
||||
class _GPTConversation:
|
||||
def __init__(self, model: str, cost_callback, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
|
||||
self.model = model
|
||||
self.cost_callback = cost_callback
|
||||
self.prompt_list: List[Optional[Tuple]] = [None]
|
||||
self.set_system_definition(system_definition_examples)
|
||||
def __init__(self, model: str, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
|
||||
self.chat = ChatOpenAI(
|
||||
model_name=model,
|
||||
streaming=True,
|
||||
callback_manager=CallbackManager([AssistantStreamingStdOutCallbackHandler()]),
|
||||
temperature=0
|
||||
)
|
||||
self.messages: List[BaseMessage] = []
|
||||
self.system_message = self._create_system_message(system_definition_examples)
|
||||
if os.environ['VERBOSE'].lower() == 'true':
|
||||
print_colored('system', self.prompt_list[0][1], 'magenta')
|
||||
print_colored('system', self.system_message.content, 'magenta')
|
||||
|
||||
def query(self, prompt: str):
|
||||
def chat(self, prompt: str):
|
||||
chat_message = HumanMessage(content=prompt)
|
||||
self.messages.append(chat_message)
|
||||
if os.environ['VERBOSE'].lower() == 'true':
|
||||
print_colored('user', prompt, 'blue')
|
||||
self.prompt_list.append(('user', prompt))
|
||||
response = self.get_response(self.prompt_list)
|
||||
self.prompt_list.append(('assistant', response))
|
||||
print_colored('assistant', '', 'green', end='')
|
||||
response = self.chat([self.system_message] + self.messages)
|
||||
self.messages.append(AIMessage(content=response))
|
||||
return response
|
||||
|
||||
def set_system_definition(self, system_definition_examples: List[str] = []):
|
||||
system_message = system_base_definition
|
||||
@staticmethod
|
||||
def _create_system_message(system_definition_examples: List[str] = []) -> SystemMessage:
|
||||
system_message = system_message_base
|
||||
if 'executor' in system_definition_examples:
|
||||
system_message += f'\n{executor_example}'
|
||||
if 'docarray' in system_definition_examples:
|
||||
system_message += f'\n{docarray_example}'
|
||||
if 'client' in system_definition_examples:
|
||||
system_message += f'\n{client_example}'
|
||||
self.prompt_list[0] = ('system', system_message)
|
||||
|
||||
def get_response_from_stream(self, response_generator):
|
||||
response_generator_with_timeout = timeout_generator_wrapper(response_generator, 10)
|
||||
complete_string = ''
|
||||
for chunk in response_generator_with_timeout:
|
||||
delta = chunk['choices'][0]['delta']
|
||||
if 'content' in delta:
|
||||
content = delta['content']
|
||||
print_colored('' if complete_string else 'assistant', content, 'green', end='')
|
||||
complete_string += content
|
||||
return complete_string
|
||||
|
||||
def get_response(self, prompt_list: List[Tuple[str, str]]):
|
||||
for i in range(10):
|
||||
try:
|
||||
response_generator = openai.ChatCompletion.create(
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
model=self.model,
|
||||
stream=True,
|
||||
messages=[
|
||||
{
|
||||
"role": prompt[0],
|
||||
"content": prompt[1]
|
||||
}
|
||||
for prompt in prompt_list
|
||||
]
|
||||
)
|
||||
|
||||
complete_string = self.get_response_from_stream(response_generator)
|
||||
|
||||
except (RateLimitError, Timeout, ConnectionError, APIConnectionError, GenerationTimeoutError) as e:
|
||||
print('/n', e)
|
||||
print('retrying...')
|
||||
sleep(3)
|
||||
continue
|
||||
chars_prompt = sum(len(prompt[1]) for prompt in prompt_list)
|
||||
chars_generation = len(complete_string)
|
||||
self.cost_callback(chars_prompt, chars_generation)
|
||||
return complete_string
|
||||
raise Exception('Failed to get response')
|
||||
|
||||
return SystemMessage(content=system_message)
|
||||
|
||||
Reference in New Issue
Block a user