added ENABLE_VISION_FOLLOW_UP_QUESTIONS support

Author: gilcu3
Date: 2023-11-25 14:26:37 +01:00
Parent: 6de497dac8
Commit: 237705cec0
5 changed files with 77 additions and 47 deletions
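
In short: with ENABLE_VISION_FOLLOW_UP_QUESTIONS enabled (the default), a chat that has received an image is flagged as a vision conversation and keeps using the configured VISION_MODEL, with plugins/function calling disabled, until the conversation is reset; with it disabled, only the image prompt goes to the vision model and follow-up messages fall back to the regular chat model. A minimal sketch of that selection, reusing the config keys from the diff below (illustrative only, not the actual implementation):

    def model_for_next_message(config: dict, chat_is_vision: bool) -> str:
        # chat_is_vision mirrors the new conversations_vision flag: set when an image
        # arrives and ENABLE_VISION_FOLLOW_UP_QUESTIONS is true, cleared on chat reset.
        if chat_is_vision:
            return config['vision_model']   # e.g. 'gpt-4-vision-preview'
        return config['model']              # the regular chat model

    # model_for_next_message({'model': 'gpt-3.5-turbo', 'vision_model': 'gpt-4-vision-preview'}, True)
    #   -> 'gpt-4-vision-preview'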

View File

@@ -52,4 +52,6 @@ ALLOWED_TELEGRAM_USER_IDS=USER_ID_1,USER_ID_2
# TTS_MODEL="tts-1"
# TTS_VOICE="alloy"
# TTS_PRICES=0.015,0.030
# BOT_LANGUAGE=en
+ # ENABLE_VISION_FOLLOW_UP_QUESTIONS="true"
+ # VISION_MODEL="gpt-4-vision-preview"

View File

@@ -96,6 +96,8 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/di
| `STREAM` | Whether to stream responses. **Note**: incompatible, if enabled, with `N_CHOICES` higher than 1 | `true` |
| `MAX_TOKENS` | Upper bound on how many tokens the ChatGPT API will return | `1200` for GPT-3, `2400` for GPT-4 |
| `VISION_MAX_TOKENS` | Upper bound on how many tokens vision models will return | `300` for gpt-4-vision-preview |
+ | `VISION_MODEL` | The vision model to use. Allowed values: `gpt-4-vision-preview` | `gpt-4-vision-preview` |
+ | `ENABLE_VISION_FOLLOW_UP_QUESTIONS` | If true, once an image is sent to the bot, the configured `VISION_MODEL` is used for the rest of the conversation; otherwise, follow-up messages fall back to `OPENAI_MODEL`. Allowed values: `true` or `false` | `true` |
| `MAX_HISTORY_SIZE` | Max number of messages to keep in memory, after which the conversation will be summarised to avoid excessive token usage | `15` |
| `MAX_CONVERSATION_AGE_MINUTES` | Maximum number of minutes a conversation should live since the last message, after which the conversation will be reset | `180` |
| `VOICE_REPLY_WITH_TRANSCRIPT_ONLY` | Whether to answer to voice messages with the transcript only or with a ChatGPT response of the transcript | `false` |

View File

@@ -53,6 +53,8 @@ def main():
        'bot_language': os.environ.get('BOT_LANGUAGE', 'en'),
        'show_plugins_used': os.environ.get('SHOW_PLUGINS_USED', 'false').lower() == 'true',
        'whisper_prompt': os.environ.get('WHISPER_PROMPT', ''),
+       'vision_model': os.environ.get('VISION_MODEL', 'gpt-4-vision-preview'),
+       'enable_vision_follow_up_questions': os.environ.get('ENABLE_VISION_FOLLOW_UP_QUESTIONS', 'true').lower() == 'true',
        'vision_prompt': os.environ.get('VISION_PROMPT', 'What is in this image'),
        'vision_detail': os.environ.get('VISION_DETAIL', 'auto'),
        'vision_max_tokens': int(os.environ.get('VISION_MAX_TOKENS', '300')),
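
The two new entries follow the same parsing pattern as the surrounding settings: the flag is true only when the environment variable is literally 'true' (case-insensitive), and both fall back to the documented defaults. A small sketch with a hypothetical environment value:

    import os

    os.environ['ENABLE_VISION_FOLLOW_UP_QUESTIONS'] = 'False'   # hypothetical override
    enable_follow_up = os.environ.get('ENABLE_VISION_FOLLOW_UP_QUESTIONS', 'true').lower() == 'true'
    vision_model = os.environ.get('VISION_MODEL', 'gpt-4-vision-preview')
    print(enable_follow_up, vision_model)   # False gpt-4-vision-preview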

View File

@@ -17,7 +17,7 @@ from PIL import Image
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
- from utils import is_direct_result, encode_image
+ from utils import is_direct_result, encode_image, decode_image
from plugin_manager import PluginManager

# Models can be found here: https://platform.openai.com/docs/models/overview
@@ -109,6 +109,7 @@ class OpenAIHelper:
        self.config = config
        self.plugin_manager = plugin_manager
        self.conversations: dict[int: list] = {}  # {chat_id: history}
+         self.conversations_vision: dict[int: bool] = {}  # {chat_id: is_vision}
        self.last_updated: dict[int: datetime] = {}  # {chat_id: last_update_timestamp}

    def get_conversation_stats(self, chat_id: int) -> tuple[int, int]:
@@ -130,7 +131,7 @@ class OpenAIHelper:
""" """
plugins_used = () plugins_used = ()
response = await self.__common_get_chat_response(chat_id, query) response = await self.__common_get_chat_response(chat_id, query)
if self.config['enable_functions']: if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
response, plugins_used = await self.__handle_function_call(chat_id, response) response, plugins_used = await self.__handle_function_call(chat_id, response)
if is_direct_result(response): if is_direct_result(response):
return response, '0' return response, '0'
@@ -173,7 +174,7 @@ class OpenAIHelper:
""" """
plugins_used = () plugins_used = ()
response = await self.__common_get_chat_response(chat_id, query, stream=True) response = await self.__common_get_chat_response(chat_id, query, stream=True)
if self.config['enable_functions']: if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True) response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True)
if is_direct_result(response): if is_direct_result(response):
yield response, '0' yield response, '0'
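
Both the plain and the streaming path now apply the same guard: once a chat is marked as a vision conversation, plugin/function calling is skipped, presumably because gpt-4-vision-preview does not accept function definitions. Condensed, the condition amounts to (sketch only, not the actual method):

    def functions_allowed(config: dict, conversations_vision: dict, chat_id: int) -> bool:
        # Plugins stay enabled only for non-vision conversations.
        return config['enable_functions'] and not conversations_vision[chat_id]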
@@ -242,7 +243,7 @@ class OpenAIHelper:
                self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:]

        common_args = {
-             'model': self.config['model'],
+             'model': self.config['model'] if not self.conversations_vision[chat_id] else self.config['vision_model'],
            'messages': self.conversations[chat_id],
            'temperature': self.config['temperature'],
            'n': self.config['n_choices'],
@@ -252,7 +253,7 @@ class OpenAIHelper:
            'stream': stream
        }

-         if self.config['enable_functions']:
+         if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
            functions = self.plugin_manager.get_functions_specs()
            if len(functions) > 0:
                common_args['functions'] = self.plugin_manager.get_functions_specs()
@@ -404,12 +405,15 @@ class OpenAIHelper:
        self.last_updated[chat_id] = datetime.datetime.now()

-         for message in content:
-             if message['type'] == 'text':
-                 query = message['text']
-                 break
-
-         self.__add_to_history(chat_id, role="user", content=query)
+         if self.config['enable_vision_follow_up_questions']:
+             self.conversations_vision[chat_id] = True
+             self.__add_to_history(chat_id, role="user", content=content)
+         else:
+             for message in content:
+                 if message['type'] == 'text':
+                     query = message['text']
+                     break
+             self.__add_to_history(chat_id, role="user", content=query)

        # Summarize the chat history if it's too long to avoid excessive token usage
        token_count = self.__count_tokens(self.conversations[chat_id])
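
With the flag enabled, the full multimodal content list (a text part plus an image_url part, as assembled for the vision request later in this file) is stored as the user turn and the chat is flagged in conversations_vision; with it disabled, only the extracted text query is kept. This is also why the summarisation hunk below preserves the last message object instead of re-adding the plain query, and why __count_tokens gains a branch for list-valued content further down. A sketch of the two shapes a stored user turn can take (values illustrative):

    # Flag enabled: the whole content list is kept in the history.
    vision_turn = {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': 'What is in this image'},
            {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,<...>', 'detail': 'auto'}},
        ],
    }

    # Flag disabled: only the text part is kept.
    plain_turn = {'role': 'user', 'content': 'What is in this image'}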
@@ -419,11 +423,13 @@ class OpenAIHelper:
        if exceeded_max_tokens or exceeded_max_history_size:
            logging.info(f'Chat history for chat ID {chat_id} is too long. Summarising...')
            try:
+                 last = self.conversations[chat_id][-1]
                summary = await self.__summarise(self.conversations[chat_id][:-1])
                logging.debug(f'Summary: {summary}')
                self.reset_chat_history(chat_id, self.conversations[chat_id][0]['content'])
                self.__add_to_history(chat_id, role="assistant", content=summary)
-                 self.__add_to_history(chat_id, role="user", content=query)
+                 self.conversations[chat_id] += [last]
            except Exception as e:
                logging.warning(f'Error while summarising chat history: {str(e)}. Popping elements instead...')
                self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:]
@@ -431,7 +437,7 @@ class OpenAIHelper:
        message = {'role':'user', 'content':content}

        common_args = {
-             'model': 'gpt-4-vision-preview',  # the only one that currently makes sense here
+             'model': self.config['vision_model'],
            'messages': self.conversations[chat_id][:-1] + [message],
            'temperature': self.config['temperature'],
            'n': 1,  # several choices is not implemented yet
@@ -470,7 +476,7 @@ class OpenAIHelper:
        prompt = self.config['vision_prompt'] if prompt is None else prompt

        content = [{'type':'text', 'text':prompt}, {'type':'image_url', \
-                     'image_url': {'url':f'data:image/jpeg;base64,{image}', 'detail':self.config['vision_detail'] } }]
+                     'image_url': {'url':image, 'detail':self.config['vision_detail'] } }]

        response = await self.__common_get_chat_response_vision(chat_id, content)
@@ -521,7 +527,7 @@ class OpenAIHelper:
        prompt = self.config['vision_prompt'] if prompt is None else prompt

        content = [{'type':'text', 'text':prompt}, {'type':'image_url', \
-                     'image_url': {'url':f'data:image/jpeg;base64,{image}', 'detail':self.config['vision_detail'] } }]
+                     'image_url': {'url':image, 'detail':self.config['vision_detail'] } }]

        response = await self.__common_get_chat_response_vision(chat_id, content, stream=True)
@@ -563,6 +569,7 @@ class OpenAIHelper:
        if content == '':
            content = self.config['assistant_prompt']
        self.conversations[chat_id] = [{"role": "system", "content": content}]
+         self.conversations_vision[chat_id] = False

    def __max_age_reached(self, chat_id) -> bool:
        """
@@ -652,42 +659,54 @@ class OpenAIHelper:
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
-                 num_tokens += len(encoding.encode(value))
-                 if key == "name":
-                     num_tokens += tokens_per_name
+                 if key == 'content':
+                     if isinstance(value, str):
+                         num_tokens += len(encoding.encode(value))
+                     else:
+                         for message1 in value:
+                             if message1['type'] == 'image_url':
+                                 image = decode_image(message1['image_url']['url'])
+                                 num_tokens += self.__count_tokens_vision(image)
+                             else:
+                                 num_tokens += len(encoding.encode(message1['text']))
+                 else:
+                     num_tokens += len(encoding.encode(value))
+                 if key == "name":
+                     num_tokens += tokens_per_name
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
        return num_tokens

    # no longer needed
-     # def __count_tokens_vision(self, fileobj) -> int:
-     #     """
-     #     Counts the number of tokens for interpreting an image.
-     #     :param image: image to interpret
-     #     :return: the number of tokens required
-     #     """
-     #     image = Image.open(fileobj)
-     #     model = 'gpt-4-vision-preview'  # fixed for now
-     #     if model not in GPT_4_VISION_MODELS:
-     #         raise NotImplementedError(f"""count_tokens_vision() is not implemented for model {model}.""")
-     #     w, h = image.size
-     #     if w > h: w, h = h, w
-     #     # this computation follows https://platform.openai.com/docs/guides/vision and https://openai.com/pricing#gpt-4-turbo
-     #     base_tokens = 85
-     #     detail = self.config['vision_detail']
-     #     if detail == 'low':
-     #         return base_tokens
-     #     elif detail == 'high':
-     #         f = max(w / 768, h / 2048)
-     #         if f > 1:
-     #             w, h = int(w / f), int(h / f)
-     #         tw, th = (w + 511) // 512, (h + 511) // 512
-     #         tiles = tw * th
-     #         num_tokens = base_tokens + tiles * 170
-     #         return num_tokens
-     #     else:
-     #         raise NotImplementedError(f"""unknown parameter detail={detail} for model {model}.""")
+     def __count_tokens_vision(self, image_bytes: bytes) -> int:
+         """
+         Counts the number of tokens for interpreting an image.
+         :param image_bytes: image to interpret
+         :return: the number of tokens required
+         """
+         image_file = io.BytesIO(image_bytes)
+         image = Image.open(image_file)
+         model = self.config['vision_model']
+         if model not in GPT_4_VISION_MODELS:
+             raise NotImplementedError(f"""count_tokens_vision() is not implemented for model {model}.""")
+         w, h = image.size
+         if w > h: w, h = h, w
+         # this computation follows https://platform.openai.com/docs/guides/vision and https://openai.com/pricing#gpt-4-turbo
+         base_tokens = 85
+         detail = self.config['vision_detail']
+         if detail == 'low':
+             return base_tokens
+         elif detail == 'high' or detail == 'auto':  # assuming worst cost for auto
+             f = max(w / 768, h / 2048)
+             if f > 1:
+                 w, h = int(w / f), int(h / f)
+             tw, th = (w + 511) // 512, (h + 511) // 512
+             tiles = tw * th
+             num_tokens = base_tokens + tiles * 170
+             return num_tokens
+         else:
+             raise NotImplementedError(f"""unknown parameter detail={detail} for model {model}.""")

    # No longer works as of July 21st 2023, as OpenAI has removed the billing API
    # def get_billing_current_month(self):
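
A worked example of the tile computation in __count_tokens_vision above: a 1024×1024 image at detail 'high' (or 'auto', which is costed pessimistically) is scaled so its shorter side fits within 768 px, is covered by 2×2 tiles of 512 px, and therefore costs 85 + 4 × 170 = 765 tokens; at detail 'low' the cost is a flat 85 tokens. In code:

    w, h = 1024, 1024                                # example image size
    f = max(w / 768, h / 2048)                       # 1.333..., so the image is downscaled
    if f > 1:
        w, h = int(w / f), int(h / f)                # 768 x 768
    tw, th = (w + 511) // 512, (h + 511) // 512      # 2 x 2 tiles of 512 px
    print(85 + tw * th * 170)                        # 765 tokens at detail 'high'/'auto'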

View File

@@ -382,4 +382,9 @@ def cleanup_intermediate_files(response: any):
# Function to encode the image
def encode_image(fileobj):
-     return base64.b64encode(fileobj.getvalue()).decode('utf-8')
+     image = base64.b64encode(fileobj.getvalue()).decode('utf-8')
+     return f'data:image/jpeg;base64,{image}'
+
+
+ def decode_image(imgbase64):
+     image = imgbase64[len('data:image/jpeg;base64,'):]
+     return base64.b64decode(image)
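
encode_image now returns a complete data URL, so callers can pass it straight to the API and store it in the conversation history, while decode_image strips the prefix again to recover raw bytes for local token counting. A small round-trip sketch, assuming both helpers are imported from this utils module (the sample bytes are illustrative):

    import io
    from utils import encode_image, decode_image

    sample_bytes = b'\xff\xd8\xff\xe0 illustrative JPEG-ish bytes'
    data_url = encode_image(io.BytesIO(sample_bytes))        # 'data:image/jpeg;base64,...'
    assert data_url.startswith('data:image/jpeg;base64,')
    assert decode_image(data_url) == sample_bytes            # prefix stripped, original bytes back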