diff --git a/.env.example b/.env.example
index e151186..d3d368d 100644
--- a/.env.example
+++ b/.env.example
@@ -52,4 +52,6 @@ ALLOWED_TELEGRAM_USER_IDS=USER_ID_1,USER_ID_2
 # TTS_MODEL="tts-1"
 # TTS_VOICE="alloy"
 # TTS_PRICES=0.015,0.030
-# BOT_LANGUAGE=en
\ No newline at end of file
+# BOT_LANGUAGE=en
+# ENABLE_VISION_FOLLOW_UP_QUESTIONS="true"
+# VISION_MODEL="gpt-4-vision-preview"
\ No newline at end of file
diff --git a/README.md b/README.md
index 0c065e0..86d73fe 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,8 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/di
 | `STREAM` | Whether to stream responses. **Note**: incompatible, if enabled, with `N_CHOICES` higher than 1 | `true` |
 | `MAX_TOKENS` | Upper bound on how many tokens the ChatGPT API will return | `1200` for GPT-3, `2400` for GPT-4 |
 | `VISION_MAX_TOKENS` | Upper bound on how many tokens vision models will return | `300` for gpt-4-vision-preview |
+| `VISION_MODEL` | The vision model to use. Allowed values: `gpt-4-vision-preview` | `gpt-4-vision-preview` |
+| `ENABLE_VISION_FOLLOW_UP_QUESTIONS` | If true, once an image is sent to the bot, the configured `VISION_MODEL` is used for the rest of the conversation; otherwise, follow-up messages use `OPENAI_MODEL`. Allowed values: `true` or `false` | `true` |
 | `MAX_HISTORY_SIZE` | Max number of messages to keep in memory, after which the conversation will be summarised to avoid excessive token usage | `15` |
 | `MAX_CONVERSATION_AGE_MINUTES` | Maximum number of minutes a conversation should live since the last message, after which the conversation will be reset | `180` |
 | `VOICE_REPLY_WITH_TRANSCRIPT_ONLY` | Whether to answer to voice messages with the transcript only or with a ChatGPT response of the transcript | `false` |
diff --git a/bot/main.py b/bot/main.py
index 8c87925..d7673d1 100644
--- a/bot/main.py
+++ b/bot/main.py
@@ -53,6 +53,8 @@ def main():
         'bot_language': os.environ.get('BOT_LANGUAGE', 'en'),
         'show_plugins_used': os.environ.get('SHOW_PLUGINS_USED', 'false').lower() == 'true',
         'whisper_prompt': os.environ.get('WHISPER_PROMPT', ''),
+        'vision_model': os.environ.get('VISION_MODEL', 'gpt-4-vision-preview'),
+        'enable_vision_follow_up_questions': os.environ.get('ENABLE_VISION_FOLLOW_UP_QUESTIONS', 'true').lower() == 'true',
         'vision_prompt': os.environ.get('VISION_PROMPT', 'What is in this image'),
         'vision_detail': os.environ.get('VISION_DETAIL', 'auto'),
         'vision_max_tokens': int(os.environ.get('VISION_MAX_TOKENS', '300')),
diff --git a/bot/openai_helper.py b/bot/openai_helper.py
index 11e9df3..6346eff 100644
--- a/bot/openai_helper.py
+++ b/bot/openai_helper.py
@@ -17,7 +17,7 @@ from PIL import Image
 
 from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
 
-from utils import is_direct_result, encode_image
+from utils import is_direct_result, encode_image, decode_image
 from plugin_manager import PluginManager
 
 # Models can be found here: https://platform.openai.com/docs/models/overview
@@ -109,6 +109,7 @@ class OpenAIHelper:
         self.config = config
         self.plugin_manager = plugin_manager
         self.conversations: dict[int: list] = {} # {chat_id: history}
+        self.conversations_vision: dict[int: bool] = {} # {chat_id: is_vision}
         self.last_updated: dict[int: datetime] = {} # {chat_id: last_update_timestamp}
 
     def get_conversation_stats(self, chat_id: int) -> tuple[int, int]:
@@ -130,7 +131,7 @@
         """
         plugins_used = ()
         response = await self.__common_get_chat_response(chat_id, query)
-        if self.config['enable_functions']:
+        if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
             response, plugins_used = await self.__handle_function_call(chat_id, response)
             if is_direct_result(response):
                 return response, '0'
@@ -173,7 +174,7 @@
         """
         plugins_used = ()
         response = await self.__common_get_chat_response(chat_id, query, stream=True)
-        if self.config['enable_functions']:
+        if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
             response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True)
             if is_direct_result(response):
                 yield response, '0'
@@ -242,7 +243,7 @@
                     self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:]
 
             common_args = {
-                'model': self.config['model'],
+                'model': self.config['model'] if not self.conversations_vision[chat_id] else self.config['vision_model'],
                 'messages': self.conversations[chat_id],
                 'temperature': self.config['temperature'],
                 'n': self.config['n_choices'],
@@ -252,7 +253,7 @@
                 'stream': stream
             }
 
-            if self.config['enable_functions']:
+            if self.config['enable_functions'] and not self.conversations_vision[chat_id]:
                 functions = self.plugin_manager.get_functions_specs()
                 if len(functions) > 0:
                     common_args['functions'] = self.plugin_manager.get_functions_specs()
@@ -404,12 +405,15 @@
 
             self.last_updated[chat_id] = datetime.datetime.now()
 
-            for message in content:
-                if message['type'] == 'text':
-                    query = message['text']
-                    break
-
-            self.__add_to_history(chat_id, role="user", content=query)
+            if self.config['enable_vision_follow_up_questions']:
+                self.conversations_vision[chat_id] = True
+                self.__add_to_history(chat_id, role="user", content=content)
+            else:
+                for message in content:
+                    if message['type'] == 'text':
+                        query = message['text']
+                        break
+                self.__add_to_history(chat_id, role="user", content=query)
 
             # Summarize the chat history if it's too long to avoid excessive token usage
             token_count = self.__count_tokens(self.conversations[chat_id])
@@ -419,11 +423,13 @@
             if exceeded_max_tokens or exceeded_max_history_size:
                 logging.info(f'Chat history for chat ID {chat_id} is too long. Summarising...')
                 try:
+
+                    last = self.conversations[chat_id][-1]
                     summary = await self.__summarise(self.conversations[chat_id][:-1])
                     logging.debug(f'Summary: {summary}')
                     self.reset_chat_history(chat_id, self.conversations[chat_id][0]['content'])
                     self.__add_to_history(chat_id, role="assistant", content=summary)
-                    self.__add_to_history(chat_id, role="user", content=query)
+                    self.conversations[chat_id] += [last]
                 except Exception as e:
                     logging.warning(f'Error while summarising chat history: {str(e)}. Popping elements instead...')
                     self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:]
@@ -431,7 +437,7 @@
             message = {'role':'user', 'content':content}
 
             common_args = {
-                'model': 'gpt-4-vision-preview', # the only one that currently makes sense here
+                'model': self.config['vision_model'],
                 'messages': self.conversations[chat_id][:-1] + [message],
                 'temperature': self.config['temperature'],
                 'n': 1, # several choices is not implemented yet
@@ -470,7 +476,7 @@
         prompt = self.config['vision_prompt'] if prompt is None else prompt
 
         content = [{'type':'text', 'text':prompt}, {'type':'image_url', \
-                    'image_url': {'url':f'data:image/jpeg;base64,{image}', 'detail':self.config['vision_detail'] } }]
+                    'image_url': {'url':image, 'detail':self.config['vision_detail'] } }]
 
         response = await self.__common_get_chat_response_vision(chat_id, content)
 
@@ -521,7 +527,7 @@
         prompt = self.config['vision_prompt'] if prompt is None else prompt
 
         content = [{'type':'text', 'text':prompt}, {'type':'image_url', \
-                    'image_url': {'url':f'data:image/jpeg;base64,{image}', 'detail':self.config['vision_detail'] } }]
+                    'image_url': {'url':image, 'detail':self.config['vision_detail'] } }]
 
         response = await self.__common_get_chat_response_vision(chat_id, content, stream=True)
 
@@ -563,6 +569,7 @@
         if content == '':
             content = self.config['assistant_prompt']
         self.conversations[chat_id] = [{"role": "system", "content": content}]
+        self.conversations_vision[chat_id] = False
 
     def __max_age_reached(self, chat_id) -> bool:
         """
@@ -652,42 +659,54 @@
         for message in messages:
             num_tokens += tokens_per_message
             for key, value in message.items():
-                num_tokens += len(encoding.encode(value))
-                if key == "name":
-                    num_tokens += tokens_per_name
+                if key == 'content':
+                    if isinstance(value, str):
+                        num_tokens += len(encoding.encode(value))
+                    else:
+                        for message1 in value:
+                            if message1['type'] == 'image_url':
+                                image = decode_image(message1['image_url']['url'])
+                                num_tokens += self.__count_tokens_vision(image)
+                            else:
+                                num_tokens += len(encoding.encode(message1['text']))
+                else:
+                    num_tokens += len(encoding.encode(value))
+                    if key == "name":
+                        num_tokens += tokens_per_name
         num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
         return num_tokens
 
     # no longer needed
 
-    # def __count_tokens_vision(self, fileobj) -> int:
-    #     """
-    #     Counts the number of tokens for interpreting an image.
-    #     :param image: image to interpret
-    #     :return: the number of tokens required
-    #     """
-    #     image = Image.open(fileobj)
-    #     model = 'gpt-4-vision-preview' # fixed for now
-    #     if model not in GPT_4_VISION_MODELS:
-    #         raise NotImplementedError(f"""count_tokens_vision() is not implemented for model {model}.""")
+    def __count_tokens_vision(self, image_bytes: bytes) -> int:
+        """
+        Counts the number of tokens for interpreting an image.
+        :param image_bytes: image to interpret
+        :return: the number of tokens required
+        """
+        image_file = io.BytesIO(image_bytes)
+        image = Image.open(image_file)
+        model = self.config['vision_model']
+        if model not in GPT_4_VISION_MODELS:
+            raise NotImplementedError(f"""count_tokens_vision() is not implemented for model {model}.""")
 
-    #     w, h = image.size
-    #     if w > h: w, h = h, w
-    #     # this computation follows https://platform.openai.com/docs/guides/vision and https://openai.com/pricing#gpt-4-turbo
-    #     base_tokens = 85
-    #     detail = self.config['vision_detail']
-    #     if detail == 'low':
-    #         return base_tokens
-    #     elif detail == 'high':
-    #         f = max(w / 768, h / 2048)
-    #         if f > 1:
-    #             w, h = int(w / f), int(h / f)
-    #         tw, th = (w + 511) // 512, (h + 511) // 512
-    #         tiles = tw * th
-    #         num_tokens = base_tokens + tiles * 170
-    #         return num_tokens
-    #     else:
-    #         raise NotImplementedError(f"""unknown parameter detail={detail} for model {model}.""")
+        w, h = image.size
+        if w > h: w, h = h, w
+        # this computation follows https://platform.openai.com/docs/guides/vision and https://openai.com/pricing#gpt-4-turbo
+        base_tokens = 85
+        detail = self.config['vision_detail']
+        if detail == 'low':
+            return base_tokens
+        elif detail == 'high' or detail == 'auto': # assuming worst cost for auto
+            f = max(w / 768, h / 2048)
+            if f > 1:
+                w, h = int(w / f), int(h / f)
+            tw, th = (w + 511) // 512, (h + 511) // 512
+            tiles = tw * th
+            num_tokens = base_tokens + tiles * 170
+            return num_tokens
+        else:
+            raise NotImplementedError(f"""unknown parameter detail={detail} for model {model}.""")
 
     # No longer works as of July 21st 2023, as OpenAI has removed the billing API
     # def get_billing_current_month(self):
diff --git a/bot/utils.py b/bot/utils.py
index 989cfc7..d306dc6 100644
--- a/bot/utils.py
+++ b/bot/utils.py
@@ -382,4 +382,9 @@ def cleanup_intermediate_files(response: any):
 
 # Function to encode the image
 def encode_image(fileobj):
-    return base64.b64encode(fileobj.getvalue()).decode('utf-8')
\ No newline at end of file
+    image = base64.b64encode(fileobj.getvalue()).decode('utf-8')
+    return f'data:image/jpeg;base64,{image}'
+
+def decode_image(imgbase64):
+    image = imgbase64[len('data:image/jpeg;base64,'):]
+    return base64.b64decode(image)
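Note for reviewers, not part of the patch: the new `ENABLE_VISION_FOLLOW_UP_QUESTIONS` flag changes what gets stored in the conversation history after an image message, which is why `__count_tokens` now needs a branch for non-string content. A minimal sketch of the two history shapes this produces, using a placeholder prompt and data URL:

```python
# ENABLE_VISION_FOLLOW_UP_QUESTIONS=true: the full multimodal content list is kept in
# history, so the rest of the conversation uses VISION_MODEL and plugins are skipped.
vision_history_entry = {
    'role': 'user',
    'content': [
        {'type': 'text', 'text': 'What is in this image'},
        {'type': 'image_url', 'image_url': {
            'url': 'data:image/jpeg;base64,<BASE64_BYTES>',  # placeholder; produced by encode_image()
            'detail': 'auto',
        }},
    ],
}

# ENABLE_VISION_FOLLOW_UP_QUESTIONS=false: only the text prompt is stored, and
# follow-up messages keep using OPENAI_MODEL as before.
text_history_entry = {
    'role': 'user',
    'content': 'What is in this image',
}
```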
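Also not part of the patch: the re-enabled `__count_tokens_vision` estimates image cost the way OpenAI documents for `gpt-4-vision-preview`, 85 base tokens plus 170 tokens per 512x512 tile after downscaling. A standalone sketch of the same arithmetic (the function name is illustrative, and the scaling rule mirrors the code above rather than OpenAI's exact two-step resize):

```python
def estimate_vision_tokens(width: int, height: int, detail: str = 'high') -> int:
    base_tokens = 85
    if detail == 'low':
        return base_tokens  # 'low' detail is a flat 85 tokens regardless of size
    # put the shorter side first, as the patched code does
    w, h = (width, height) if width <= height else (height, width)
    # scale down so the image fits roughly within 768 x 2048
    f = max(w / 768, h / 2048)
    if f > 1:
        w, h = int(w / f), int(h / f)
    # count 512x512 tiles, rounding up on each axis
    tw, th = (w + 511) // 512, (h + 511) // 512
    return base_tokens + tw * th * 170

# Example: a 1024x2048 photo at 'high' detail is scaled to roughly 768x1536,
# which is 2 x 3 = 6 tiles: 85 + 6 * 170 = 1105 tokens.
print(estimate_vision_tokens(1024, 2048))  # 1105
```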