added stream support

gilcu3
2023-11-20 15:10:54 +01:00
parent 7329cd7fc5
commit 6de497dac8
2 changed files with 259 additions and 52 deletions


@@ -384,52 +384,177 @@ class OpenAIHelper:
            logging.exception(e)
            raise Exception(f"⚠️ _{localized_text('error', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e
    @retry(
        reraise=True,
        retry=retry_if_exception_type(openai.RateLimitError),
        wait=wait_fixed(20),
        stop=stop_after_attempt(3)
    )
    async def __common_get_chat_response_vision(self, chat_id: int, content: list, stream=False):
        """
        Request a response from the Vision model.
        :param chat_id: The chat ID
        :param content: The message content (text and image parts) to send to the model
        :param stream: Whether to stream the response
        :return: The raw chat completion response (an async stream when stream=True)
        """
        bot_language = self.config['bot_language']
        try:
            if chat_id not in self.conversations or self.__max_age_reached(chat_id):
                self.reset_chat_history(chat_id)

            self.last_updated[chat_id] = datetime.datetime.now()

            # for now only the text part of the query is added to the
            # conversation history, not the image itself
            for message in content:
                if message['type'] == 'text':
                    query = message['text']
                    break

            self.__add_to_history(chat_id, role="user", content=query)
            # Summarize the chat history if it's too long to avoid excessive token usage
            token_count = self.__count_tokens(self.conversations[chat_id])
            exceeded_max_tokens = token_count + self.config['max_tokens'] > self.__max_model_tokens()
            exceeded_max_history_size = len(self.conversations[chat_id]) > self.config['max_history_size']

            if exceeded_max_tokens or exceeded_max_history_size:
                logging.info(f'Chat history for chat ID {chat_id} is too long. Summarising...')
                try:
                    summary = await self.__summarise(self.conversations[chat_id][:-1])
                    logging.debug(f'Summary: {summary}')
                    self.reset_chat_history(chat_id, self.conversations[chat_id][0]['content'])
                    self.__add_to_history(chat_id, role="assistant", content=summary)
                    self.__add_to_history(chat_id, role="user", content=query)
                except Exception as e:
                    logging.warning(f'Error while summarising chat history: {str(e)}. Popping elements instead...')
                    self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:]
            message = {'role': 'user', 'content': content}

            common_args = {
                'model': 'gpt-4-vision-preview',  # the only one that currently makes sense here
                # replace the last (text-only) user message with the full
                # multimodal message, so the image is sent but not stored
                'messages': self.conversations[chat_id][:-1] + [message],
                'temperature': self.config['temperature'],
                'n': 1,  # several choices is not implemented yet
                'max_tokens': self.config['vision_max_tokens'],
                'presence_penalty': self.config['presence_penalty'],
                'frequency_penalty': self.config['frequency_penalty'],
                'stream': stream
            }

            # the vision model does not yet support functions
            # if self.config['enable_functions']:
            #     functions = self.plugin_manager.get_functions_specs()
            #     if len(functions) > 0:
            #         common_args['functions'] = functions
            #         common_args['function_call'] = 'auto'

            return await self.client.chat.completions.create(**common_args)

        except openai.RateLimitError as e:
            raise e

        except openai.BadRequestError as e:
            raise Exception(f"⚠️ _{localized_text('openai_invalid', bot_language)}._ ⚠️\n{str(e)}") from e

        except Exception as e:
            logging.exception(e)
            raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e
    async def interpret_image(self, chat_id, fileobj, prompt=None):
        """
        Interprets a given image file using the Vision model.
        """
        image = encode_image(fileobj)
        prompt = self.config['vision_prompt'] if prompt is None else prompt

        content = [
            {'type': 'text', 'text': prompt},
            {'type': 'image_url', 'image_url': {
                'url': f'data:image/jpeg;base64,{image}',
                'detail': self.config['vision_detail']
            }}
        ]

        response = await self.__common_get_chat_response_vision(chat_id, content)

        # functions are not available for this model
        # if self.config['enable_functions']:
        #     response, plugins_used = await self.__handle_function_call(chat_id, response)
        #     if is_direct_result(response):
        #         return response, '0'

        answer = ''
        if len(response.choices) > 1 and self.config['n_choices'] > 1:
            for index, choice in enumerate(response.choices):
                content = choice.message.content.strip()
                if index == 0:
                    self.__add_to_history(chat_id, role="assistant", content=content)
                answer += f'{index + 1}\u20e3\n'
                answer += content
                answer += '\n\n'
        else:
            answer = response.choices[0].message.content.strip()
            self.__add_to_history(chat_id, role="assistant", content=answer)

        bot_language = self.config['bot_language']
        # plugins are not enabled for the vision model either
        # show_plugins_used = len(plugins_used) > 0 and self.config['show_plugins_used']
        # plugin_names = tuple(self.plugin_manager.get_plugin_source_name(plugin) for plugin in plugins_used)
        if self.config['show_usage']:
            answer += "\n\n---\n" \
                      f"💰 {str(response.usage.total_tokens)} {localized_text('stats_tokens', bot_language)}" \
                      f" ({str(response.usage.prompt_tokens)} {localized_text('prompt', bot_language)}," \
                      f" {str(response.usage.completion_tokens)} {localized_text('completion', bot_language)})"
            # if show_plugins_used:
            #     answer += f"\n🔌 {', '.join(plugin_names)}"
        # elif show_plugins_used:
        #     answer += f"\n\n---\n🔌 {', '.join(plugin_names)}"

        return answer, response.usage.total_tokens
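encode_image is a helper defined elsewhere in the repo; given how the data URL is assembled above, it presumably returns the bare base64 string for the image bytes. A minimal sketch under that assumption, with fileobj taken to be a BytesIO-like object:

    import base64

    def encode_image(fileobj) -> str:
        # hypothetical reimplementation: base64-encode the raw image bytes
        # for embedding in a data:image/jpeg;base64,... URL
        return base64.b64encode(fileobj.getvalue()).decode('utf-8')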
    async def interpret_image_stream(self, chat_id, fileobj, prompt=None):
        """
        Interprets a given image file using the Vision model, streaming the answer as it is generated.
        """
        image = encode_image(fileobj)
        prompt = self.config['vision_prompt'] if prompt is None else prompt

        content = [
            {'type': 'text', 'text': prompt},
            {'type': 'image_url', 'image_url': {
                'url': f'data:image/jpeg;base64,{image}',
                'detail': self.config['vision_detail']
            }}
        ]

        response = await self.__common_get_chat_response_vision(chat_id, content, stream=True)

        # functions are not available for this model
        # if self.config['enable_functions']:
        #     response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True)
        #     if is_direct_result(response):
        #         yield response, '0'
        #         return

        answer = ''
        async for chunk in response:
            if len(chunk.choices) == 0:
                continue
            delta = chunk.choices[0].delta
            if delta.content:
                answer += delta.content
                yield answer, 'not_finished'

        answer = answer.strip()
        self.__add_to_history(chat_id, role="assistant", content=answer)
        tokens_used = str(self.__count_tokens(self.conversations[chat_id]))

        # plugins are not enabled for the vision model either
        # show_plugins_used = len(plugins_used) > 0 and self.config['show_plugins_used']
        # plugin_names = tuple(self.plugin_manager.get_plugin_source_name(plugin) for plugin in plugins_used)
        if self.config['show_usage']:
            answer += f"\n\n---\n💰 {tokens_used} {localized_text('stats_tokens', self.config['bot_language'])}"
            # if show_plugins_used:
            #     answer += f"\n🔌 {', '.join(plugin_names)}"
        # elif show_plugins_used:
        #     answer += f"\n\n---\n🔌 {', '.join(plugin_names)}"

        yield answer, tokens_used
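Since interpret_image_stream is an async generator, callers iterate over it: every intermediate yield carries the partial answer with the 'not_finished' sentinel, and the final yield carries the full answer plus the token count. A minimal consumer sketch (stream_reply and the print calls are illustrative only; the bot would instead edit a Telegram message in place):

    async def stream_reply(helper, chat_id, fileobj):
        # helper is assumed to be an OpenAIHelper instance
        async for answer, tokens in helper.interpret_image_stream(chat_id, fileobj):
            if tokens == 'not_finished':
                print(f'partial: {answer}')
            else:
                print(f'final ({tokens} tokens): {answer}')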
    def reset_chat_history(self, chat_id, content=''):
        """