diff --git a/.env.example b/.env.example index 977f7ad..d3d368d 100644 --- a/.env.example +++ b/.env.example @@ -17,10 +17,12 @@ ALLOWED_TELEGRAM_USER_IDS=USER_ID_1,USER_ID_2 # TOKEN_PRICE=0.002 # IMAGE_PRICES=0.016,0.018,0.02 # TRANSCRIPTION_PRICE=0.006 +# VISION_TOKEN_PRICE=0.01 # ENABLE_QUOTING=true # ENABLE_IMAGE_GENERATION=true # ENABLE_TTS_GENERATION=true # ENABLE_TRANSCRIPTION=true +# ENABLE_VISION=true # PROXY=http://localhost:8080 # OPENAI_MODEL=gpt-3.5-turbo # OPENAI_BASE_URL=https://example.com/v1/ @@ -28,10 +30,12 @@ ALLOWED_TELEGRAM_USER_IDS=USER_ID_1,USER_ID_2 # SHOW_USAGE=false # STREAM=true # MAX_TOKENS=1200 +# VISION_MAX_TOKENS=300 # MAX_HISTORY_SIZE=15 # MAX_CONVERSATION_AGE_MINUTES=180 # VOICE_REPLY_WITH_TRANSCRIPT_ONLY=true # VOICE_REPLY_PROMPTS="Hi bot;Hey bot;Hi chat;Hey chat" +# VISION_PROMPT="What is in this image" # N_CHOICES=1 # TEMPERATURE=1.0 # PRESENCE_PENALTY=0.0 @@ -41,9 +45,13 @@ ALLOWED_TELEGRAM_USER_IDS=USER_ID_1,USER_ID_2 # IMAGE_STYLE=natural # IMAGE_SIZE=1024x1024 # IMAGE_FORMAT=document +# VISION_DETAIL="low" # GROUP_TRIGGER_KEYWORD="" # IGNORE_GROUP_TRANSCRIPTIONS=true +# IGNORE_GROUP_VISION=true # TTS_MODEL="tts-1" # TTS_VOICE="alloy" # TTS_PRICES=0.015,0.030 -# BOT_LANGUAGE=en \ No newline at end of file +# BOT_LANGUAGE=en +# ENABLE_VISION_FOLLOW_UP_QUESTIONS="true" +# VISION_MODEL="gpt-4-vision-preview" \ No newline at end of file diff --git a/README.md b/README.md index 435499d..0285567 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ The following parameters are optional and can be set in the `.env` file: | `TOKEN_PRICE` | $-price per 1000 tokens used to compute cost information in usage statistics. Source: https://openai.com/pricing | `0.002` | | `IMAGE_PRICES` | A comma-separated list with 3 elements of prices for the different image sizes: `256x256`, `512x512` and `1024x1024`. Source: https://openai.com/pricing | `0.016,0.018,0.02` | | `TRANSCRIPTION_PRICE` | USD-price for one minute of audio transcription. Source: https://openai.com/pricing | `0.006` | +| `VISION_TOKEN_PRICE` | USD-price per 1K tokens of image interpretation. Source: https://openai.com/pricing | `0.01` | | `TTS_PRICES` | A comma-separated list with prices for the tts models: `tts-1`, `tts-1-hd`. Source: https://openai.com/pricing | `0.015,0.030` | Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/discussions/184) for possible budget configurations. @@ -86,6 +87,7 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/di | `ENABLE_IMAGE_GENERATION` | Whether to enable image generation via the `/image` command | `true` | | `ENABLE_TRANSCRIPTION` | Whether to enable transcriptions of audio and video messages | `true` | | `ENABLE_TTS_GENERATION` | Whether to enable text to speech generation via the `/tts` | `true` | +| `ENABLE_VISION` | Whether to enable vision capabilities in supported models | `true` | | `PROXY` | Proxy to be used for OpenAI and Telegram bot (e.g. `http://localhost:8080`) | - | | `OPENAI_PROXY` | Proxy to be used only for OpenAI (e.g. `http://localhost:8080`) | - | | `TELEGRAM_PROXY` | Proxy to be used only for Telegram bot (e.g. `http://localhost:8080`) | - | @@ -95,10 +97,14 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/di | `SHOW_USAGE` | Whether to show OpenAI token usage information after each response | `false` | | `STREAM` | Whether to stream responses. 
**Note**: incompatible, if enabled, with `N_CHOICES` higher than 1 | `true` |
| `MAX_TOKENS` | Upper bound on how many tokens the ChatGPT API will return | `1200` for GPT-3, `2400` for GPT-4 |
+| `VISION_MAX_TOKENS` | Upper bound on how many tokens vision models will return | `300` for `gpt-4-vision-preview` |
+| `VISION_MODEL` | The vision model to use. Allowed values: `gpt-4-vision-preview` | `gpt-4-vision-preview` |
+| `ENABLE_VISION_FOLLOW_UP_QUESTIONS` | If enabled, once an image has been sent to the bot, the configured `VISION_MODEL` is used until the conversation ends; otherwise the `OPENAI_MODEL` answers follow-up questions. Allowed values: `true` or `false` | `true` |
| `MAX_HISTORY_SIZE` | Max number of messages to keep in memory, after which the conversation will be summarised to avoid excessive token usage | `15` |
| `MAX_CONVERSATION_AGE_MINUTES` | Maximum number of minutes a conversation should live since the last message, after which the conversation will be reset | `180` |
| `VOICE_REPLY_WITH_TRANSCRIPT_ONLY` | Whether to answer to voice messages with the transcript only or with a ChatGPT response of the transcript | `false` |
| `VOICE_REPLY_PROMPTS` | A semicolon separated list of phrases (i.e. `Hi bot;Hello chat`). If the transcript starts with any of them, it will be treated as a prompt even if `VOICE_REPLY_WITH_TRANSCRIPT_ONLY` is set to `true` | - |
+| `VISION_PROMPT` | A phrase (e.g. `What is in this image`). The vision models use it as the prompt to interpret a given image. If the image sent to the bot has a caption, the caption supersedes this parameter | `What is in this image` |
| `N_CHOICES` | Number of answers to generate for each input message. **Note**: setting this to a number higher than 1 will not work properly if `STREAM` is enabled | `1` |
| `TEMPERATURE` | Number between 0 and 2. Higher values will make the output more random | `1.0` |
| `PRESENCE_PENALTY` | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far | `0.0` |
@@ -108,8 +114,10 @@ Check out the [Budget Manual](https://github.com/n3d1117/chatgpt-telegram-bot/discussions/184) for possible budget configurations.
| `IMAGE_QUALITY` | Quality of DALL·E images, only available for `dall-e-3`-model. Possible options: `standard` or `hd`, beware of [pricing differences](https://openai.com/pricing#image-models). | `standard` |
| `IMAGE_STYLE` | Style for DALL·E image generation, only available for `dall-e-3`-model. Possible options: `vivid` or `natural`. Check availbe styles [here](https://platform.openai.com/docs/api-reference/images/create). | `vivid` |
| `IMAGE_SIZE` | The DALL·E generated image size. Must be `256x256`, `512x512`, or `1024x1024` for dall-e-2. Must be `1024x1024` for dall-e-3 models. | `512x512` |
+| `VISION_DETAIL` | The detail parameter for vision models, explained in the [Vision Guide](https://platform.openai.com/docs/guides/vision). Allowed values: `low`, `high` or `auto` | `auto` |
| `GROUP_TRIGGER_KEYWORD` | If set, the bot in group chats will only respond to messages that start with this keyword | - |
| `IGNORE_GROUP_TRANSCRIPTIONS` | If set to true, the bot will not process transcriptions in group chats | `true` |
+| `IGNORE_GROUP_VISION` | If set to true, the bot will not process vision queries in group chats | `true` |
| `BOT_LANGUAGE` | Language of general bot messages. Currently available: `en`, `de`, `ru`, `tr`, `it`, `fi`, `es`, `id`, `nl`, `zh-cn`, `zh-tw`, `vi`, `fa`, `pt-br`, `uk`, `ms`, `uz`. 
[Contribute with additional translations](https://github.com/n3d1117/chatgpt-telegram-bot/discussions/219) | `en` | | `WHISPER_PROMPT` | To improve the accuracy of Whisper's transcription service, especially for specific names or terms, you can set up a custom message. [Speech to text - Prompting](https://platform.openai.com/docs/guides/speech-to-text/prompting) | `-` | | `TTS_VOICE` | The Text to Speech voice to use. Allowed values: `alloy`, `echo`, `fable`, `onyx`, `nova`, or `shimmer` | `alloy` | diff --git a/bot/main.py b/bot/main.py index ed6c901..41604c8 100644 --- a/bot/main.py +++ b/bot/main.py @@ -53,6 +53,11 @@ def main(): 'bot_language': os.environ.get('BOT_LANGUAGE', 'en'), 'show_plugins_used': os.environ.get('SHOW_PLUGINS_USED', 'false').lower() == 'true', 'whisper_prompt': os.environ.get('WHISPER_PROMPT', ''), + 'vision_model': os.environ.get('VISION_MODEL', 'gpt-4-vision-preview'), + 'enable_vision_follow_up_questions': os.environ.get('ENABLE_VISION_FOLLOW_UP_QUESTIONS', 'true').lower() == 'true', + 'vision_prompt': os.environ.get('VISION_PROMPT', 'What is in this image'), + 'vision_detail': os.environ.get('VISION_DETAIL', 'auto'), + 'vision_max_tokens': int(os.environ.get('VISION_MAX_TOKENS', '300')), 'tts_model': os.environ.get('TTS_MODEL', 'tts-1'), 'tts_voice': os.environ.get('TTS_VOICE', 'alloy'), } @@ -75,6 +80,7 @@ def main(): 'enable_quoting': os.environ.get('ENABLE_QUOTING', 'true').lower() == 'true', 'enable_image_generation': os.environ.get('ENABLE_IMAGE_GENERATION', 'true').lower() == 'true', 'enable_transcription': os.environ.get('ENABLE_TRANSCRIPTION', 'true').lower() == 'true', + 'enable_vision': os.environ.get('ENABLE_VISION', 'true').lower() == 'true', 'enable_tts_generation': os.environ.get('ENABLE_TTS_GENERATION', 'true').lower() == 'true', 'budget_period': os.environ.get('BUDGET_PERIOD', 'monthly').lower(), 'user_budgets': os.environ.get('USER_BUDGETS', os.environ.get('MONTHLY_USER_BUDGETS', '*')), @@ -84,9 +90,11 @@ def main(): 'voice_reply_transcript': os.environ.get('VOICE_REPLY_WITH_TRANSCRIPT_ONLY', 'false').lower() == 'true', 'voice_reply_prompts': os.environ.get('VOICE_REPLY_PROMPTS', '').split(';'), 'ignore_group_transcriptions': os.environ.get('IGNORE_GROUP_TRANSCRIPTIONS', 'true').lower() == 'true', + 'ignore_group_vision': os.environ.get('IGNORE_GROUP_VISION', 'true').lower() == 'true', 'group_trigger_keyword': os.environ.get('GROUP_TRIGGER_KEYWORD', ''), 'token_price': float(os.environ.get('TOKEN_PRICE', 0.002)), 'image_prices': [float(i) for i in os.environ.get('IMAGE_PRICES', "0.016,0.018,0.02").split(",")], + 'vision_token_price': float(os.environ.get('VISION_TOKEN_PRICE', '0.01')), 'image_receive_mode': os.environ.get('IMAGE_FORMAT', "photo"), 'tts_model': os.environ.get('TTS_MODEL', 'tts-1'), 'tts_prices': [float(i) for i in os.environ.get('TTS_PRICES', "0.015,0.030").split(",")], diff --git a/bot/openai_helper.py b/bot/openai_helper.py index 43e1d5a..6346eff 100644 --- a/bot/openai_helper.py +++ b/bot/openai_helper.py @@ -13,10 +13,11 @@ import httpx import io from datetime import date from calendar import monthrange +from PIL import Image from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type -from utils import is_direct_result +from utils import is_direct_result, encode_image, decode_image from plugin_manager import PluginManager # Models can be found here: https://platform.openai.com/docs/models/overview @@ -24,8 +25,9 @@ GPT_3_MODELS = ("gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613") 
GPT_3_16K_MODELS = ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-1106") GPT_4_MODELS = ("gpt-4", "gpt-4-0314", "gpt-4-0613") GPT_4_32K_MODELS = ("gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613") +GPT_4_VISION_MODELS = ("gpt-4-vision-preview",) GPT_4_128K_MODELS = ("gpt-4-1106-preview",) -GPT_ALL_MODELS = GPT_3_MODELS + GPT_3_16K_MODELS + GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_128K_MODELS +GPT_ALL_MODELS = GPT_3_MODELS + GPT_3_16K_MODELS + GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_VISION_MODELS + GPT_4_128K_MODELS def default_max_tokens(model: str) -> int: @@ -45,6 +47,8 @@ def default_max_tokens(model: str) -> int: return base * 4 elif model in GPT_4_32K_MODELS: return base * 8 + elif model in GPT_4_VISION_MODELS: + return 4096 elif model in GPT_4_128K_MODELS: return 4096 @@ -59,6 +63,8 @@ def are_functions_available(model: str) -> bool: # Stable models will be updated to support functions on June 27, 2023 if model in ("gpt-3.5-turbo", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4-32k","gpt-4-1106-preview"): return datetime.date.today() > datetime.date(2023, 6, 27) + if model == 'gpt-4-vision-preview': + return False return True @@ -103,6 +109,7 @@ class OpenAIHelper: self.config = config self.plugin_manager = plugin_manager self.conversations: dict[int: list] = {} # {chat_id: history} + self.conversations_vision: dict[int: bool] = {} # {chat_id: is_vision} self.last_updated: dict[int: datetime] = {} # {chat_id: last_update_timestamp} def get_conversation_stats(self, chat_id: int) -> tuple[int, int]: @@ -124,7 +131,7 @@ class OpenAIHelper: """ plugins_used = () response = await self.__common_get_chat_response(chat_id, query) - if self.config['enable_functions']: + if self.config['enable_functions'] and not self.conversations_vision[chat_id]: response, plugins_used = await self.__handle_function_call(chat_id, response) if is_direct_result(response): return response, '0' @@ -167,7 +174,7 @@ class OpenAIHelper: """ plugins_used = () response = await self.__common_get_chat_response(chat_id, query, stream=True) - if self.config['enable_functions']: + if self.config['enable_functions'] and not self.conversations_vision[chat_id]: response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True) if is_direct_result(response): yield response, '0' @@ -236,7 +243,7 @@ class OpenAIHelper: self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:] common_args = { - 'model': self.config['model'], + 'model': self.config['model'] if not self.conversations_vision[chat_id] else self.config['vision_model'], 'messages': self.conversations[chat_id], 'temperature': self.config['temperature'], 'n': self.config['n_choices'], @@ -246,7 +253,7 @@ class OpenAIHelper: 'stream': stream } - if self.config['enable_functions']: + if self.config['enable_functions'] and not self.conversations_vision[chat_id]: functions = self.plugin_manager.get_functions_specs() if len(functions) > 0: common_args['functions'] = self.plugin_manager.get_functions_specs() @@ -378,6 +385,183 @@ class OpenAIHelper: logging.exception(e) raise Exception(f"⚠️ _{localized_text('error', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e + @retry( + reraise=True, + retry=retry_if_exception_type(openai.RateLimitError), + wait=wait_fixed(20), + stop=stop_after_attempt(3) + ) + async def __common_get_chat_response_vision(self, chat_id: int, content: list, stream=False): + """ + Request a response from the GPT model. 
+ :param chat_id: The chat ID + :param query: The query to send to the model + :return: The answer from the model and the number of tokens used + """ + bot_language = self.config['bot_language'] + try: + if chat_id not in self.conversations or self.__max_age_reached(chat_id): + self.reset_chat_history(chat_id) + + self.last_updated[chat_id] = datetime.datetime.now() + + if self.config['enable_vision_follow_up_questions']: + self.conversations_vision[chat_id] = True + self.__add_to_history(chat_id, role="user", content=content) + else: + for message in content: + if message['type'] == 'text': + query = message['text'] + break + self.__add_to_history(chat_id, role="user", content=query) + + # Summarize the chat history if it's too long to avoid excessive token usage + token_count = self.__count_tokens(self.conversations[chat_id]) + exceeded_max_tokens = token_count + self.config['max_tokens'] > self.__max_model_tokens() + exceeded_max_history_size = len(self.conversations[chat_id]) > self.config['max_history_size'] + + if exceeded_max_tokens or exceeded_max_history_size: + logging.info(f'Chat history for chat ID {chat_id} is too long. Summarising...') + try: + + last = self.conversations[chat_id][-1] + summary = await self.__summarise(self.conversations[chat_id][:-1]) + logging.debug(f'Summary: {summary}') + self.reset_chat_history(chat_id, self.conversations[chat_id][0]['content']) + self.__add_to_history(chat_id, role="assistant", content=summary) + self.conversations[chat_id] += [last] + except Exception as e: + logging.warning(f'Error while summarising chat history: {str(e)}. Popping elements instead...') + self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:] + + message = {'role':'user', 'content':content} + + common_args = { + 'model': self.config['vision_model'], + 'messages': self.conversations[chat_id][:-1] + [message], + 'temperature': self.config['temperature'], + 'n': 1, # several choices is not implemented yet + 'max_tokens': self.config['vision_max_tokens'], + 'presence_penalty': self.config['presence_penalty'], + 'frequency_penalty': self.config['frequency_penalty'], + 'stream': stream + } + + + # vision model does not yet support functions + + # if self.config['enable_functions']: + # functions = self.plugin_manager.get_functions_specs() + # if len(functions) > 0: + # common_args['functions'] = self.plugin_manager.get_functions_specs() + # common_args['function_call'] = 'auto' + + return await self.client.chat.completions.create(**common_args) + + except openai.RateLimitError as e: + raise e + + except openai.BadRequestError as e: + raise Exception(f"⚠️ _{localized_text('openai_invalid', bot_language)}._ ⚠️\n{str(e)}") from e + + except Exception as e: + raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e + + + async def interpret_image(self, chat_id, fileobj, prompt=None): + """ + Interprets a given PNG image file using the Vision model. 
+ """ + image = encode_image(fileobj) + prompt = self.config['vision_prompt'] if prompt is None else prompt + + content = [{'type':'text', 'text':prompt}, {'type':'image_url', \ + 'image_url': {'url':image, 'detail':self.config['vision_detail'] } }] + + response = await self.__common_get_chat_response_vision(chat_id, content) + + + + # functions are not available for this model + + # if self.config['enable_functions']: + # response, plugins_used = await self.__handle_function_call(chat_id, response) + # if is_direct_result(response): + # return response, '0' + + answer = '' + + if len(response.choices) > 1 and self.config['n_choices'] > 1: + for index, choice in enumerate(response.choices): + content = choice.message.content.strip() + if index == 0: + self.__add_to_history(chat_id, role="assistant", content=content) + answer += f'{index + 1}\u20e3\n' + answer += content + answer += '\n\n' + else: + answer = response.choices[0].message.content.strip() + self.__add_to_history(chat_id, role="assistant", content=answer) + + bot_language = self.config['bot_language'] + # Plugins are not enabled either + # show_plugins_used = len(plugins_used) > 0 and self.config['show_plugins_used'] + # plugin_names = tuple(self.plugin_manager.get_plugin_source_name(plugin) for plugin in plugins_used) + if self.config['show_usage']: + answer += "\n\n---\n" \ + f"💰 {str(response.usage.total_tokens)} {localized_text('stats_tokens', bot_language)}" \ + f" ({str(response.usage.prompt_tokens)} {localized_text('prompt', bot_language)}," \ + f" {str(response.usage.completion_tokens)} {localized_text('completion', bot_language)})" + # if show_plugins_used: + # answer += f"\n🔌 {', '.join(plugin_names)}" + # elif show_plugins_used: + # answer += f"\n\n---\n🔌 {', '.join(plugin_names)}" + + return answer, response.usage.total_tokens + + async def interpret_image_stream(self, chat_id, fileobj, prompt=None): + """ + Interprets a given PNG image file using the Vision model. + """ + image = encode_image(fileobj) + prompt = self.config['vision_prompt'] if prompt is None else prompt + + content = [{'type':'text', 'text':prompt}, {'type':'image_url', \ + 'image_url': {'url':image, 'detail':self.config['vision_detail'] } }] + + response = await self.__common_get_chat_response_vision(chat_id, content, stream=True) + + + + # if self.config['enable_functions']: + # response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True) + # if is_direct_result(response): + # yield response, '0' + # return + + answer = '' + async for chunk in response: + if len(chunk.choices) == 0: + continue + delta = chunk.choices[0].delta + if delta.content: + answer += delta.content + yield answer, 'not_finished' + answer = answer.strip() + self.__add_to_history(chat_id, role="assistant", content=answer) + tokens_used = str(self.__count_tokens(self.conversations[chat_id])) + + #show_plugins_used = len(plugins_used) > 0 and self.config['show_plugins_used'] + #plugin_names = tuple(self.plugin_manager.get_plugin_source_name(plugin) for plugin in plugins_used) + if self.config['show_usage']: + answer += f"\n\n---\n💰 {tokens_used} {localized_text('stats_tokens', self.config['bot_language'])}" + # if show_plugins_used: + # answer += f"\n🔌 {', '.join(plugin_names)}" + # elif show_plugins_used: + # answer += f"\n\n---\n🔌 {', '.join(plugin_names)}" + + yield answer, tokens_used + def reset_chat_history(self, chat_id, content=''): """ Resets the conversation history. 
@@ -385,6 +569,7 @@ class OpenAIHelper: if content == '': content = self.config['assistant_prompt'] self.conversations[chat_id] = [{"role": "system", "content": content}] + self.conversations_vision[chat_id] = False def __max_age_reached(self, chat_id) -> bool: """ @@ -441,6 +626,8 @@ class OpenAIHelper: return base * 2 if self.config['model'] in GPT_4_32K_MODELS: return base * 8 + if self.config['model'] in GPT_4_VISION_MODELS: + return base * 31 if self.config['model'] in GPT_4_128K_MODELS: return base * 31 raise NotImplementedError( @@ -463,7 +650,7 @@ class OpenAIHelper: if model in GPT_3_MODELS + GPT_3_16K_MODELS: tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n tokens_per_name = -1 # if there's a name, the role is omitted - elif model in GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_128K_MODELS: + elif model in GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_VISION_MODELS + GPT_4_128K_MODELS: tokens_per_message = 3 tokens_per_name = 1 else: @@ -472,12 +659,55 @@ class OpenAIHelper: for message in messages: num_tokens += tokens_per_message for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += tokens_per_name + if key == 'content': + if isinstance(value, str): + num_tokens += len(encoding.encode(value)) + else: + for message1 in value: + if message1['type'] == 'image_url': + image = decode_image(message1['image_url']['url']) + num_tokens += self.__count_tokens_vision(image) + else: + num_tokens += len(encoding.encode(message1['text'])) + else: + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> return num_tokens + # no longer needed + + def __count_tokens_vision(self, image_bytes: bytes) -> int: + """ + Counts the number of tokens for interpreting an image. + :param image_bytes: image to interpret + :return: the number of tokens required + """ + image_file = io.BytesIO(image_bytes) + image = Image.open(image_file) + model = self.config['vision_model'] + if model not in GPT_4_VISION_MODELS: + raise NotImplementedError(f"""count_tokens_vision() is not implemented for model {model}.""") + + w, h = image.size + if w > h: w, h = h, w + # this computation follows https://platform.openai.com/docs/guides/vision and https://openai.com/pricing#gpt-4-turbo + base_tokens = 85 + detail = self.config['vision_detail'] + if detail == 'low': + return base_tokens + elif detail == 'high' or detail == 'auto': # assuming worst cost for auto + f = max(w / 768, h / 2048) + if f > 1: + w, h = int(w / f), int(h / f) + tw, th = (w + 511) // 512, (h + 511) // 512 + tiles = tw * th + num_tokens = base_tokens + tiles * 170 + return num_tokens + else: + raise NotImplementedError(f"""unknown parameter detail={detail} for model {model}.""") + # No longer works as of July 21st 2023, as OpenAI has removed the billing API # def get_billing_current_month(self): # """Gets billed usage for current month from OpenAI API. 
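For reference, the tile arithmetic added in `__count_tokens_vision` above (85 base tokens plus 170 tokens per 512-pixel tile after downscaling, with `auto` costed as the worst-case `high` path) can be checked in isolation. The helper name and the 1024×1024 example below are illustrative and not part of the patch:

```python
import math

def estimate_vision_tokens(width: int, height: int, detail: str = "auto") -> int:
    # 'low' detail is a flat 85 tokens regardless of image size
    if detail == "low":
        return 85
    # 'auto' is costed as 'high' (worst case), mirroring the diff
    short, long_ = min(width, height), max(width, height)
    # fit the image within 768 x 2048 before tiling
    scale = max(short / 768, long_ / 2048)
    if scale > 1:
        short, long_ = int(short / scale), int(long_ / scale)
    # 85 base tokens + 170 tokens per 512px tile
    tiles = math.ceil(short / 512) * math.ceil(long_ / 512)
    return 85 + tiles * 170

# A 1024x1024 photo scales to 768x768 -> 2x2 tiles -> 85 + 4*170 = 765 tokens,
# i.e. roughly $0.00765 at the default VISION_TOKEN_PRICE of $0.01 per 1K tokens.
print(estimate_vision_tokens(1024, 1024))  # 765
```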
diff --git a/bot/telegram_bot.py b/bot/telegram_bot.py index 922faa9..7a536b1 100644 --- a/bot/telegram_bot.py +++ b/bot/telegram_bot.py @@ -3,16 +3,18 @@ from __future__ import annotations import asyncio import logging import os +import io from uuid import uuid4 from telegram import BotCommandScopeAllGroupChats, Update, constants from telegram import InlineKeyboardMarkup, InlineKeyboardButton, InlineQueryResultArticle from telegram import InputTextMessageContent, BotCommand -from telegram.error import RetryAfter, TimedOut +from telegram.error import RetryAfter, TimedOut, BadRequest from telegram.ext import ApplicationBuilder, CommandHandler, MessageHandler, \ filters, InlineQueryHandler, CallbackQueryHandler, Application, ContextTypes, CallbackContext from pydub import AudioSegment +from PIL import Image from utils import is_group_chat, get_thread_id, message_text, wrap_with_indicator, split_into_chunks, \ edit_message_with_retry, get_stream_cutoff_values, is_allowed, get_remaining_budget, is_admin, is_within_budget, \ @@ -97,6 +99,7 @@ class ChatGPTTelegramBot: images_today, images_month = self.usage[user_id].get_current_image_count() (transcribe_minutes_today, transcribe_seconds_today, transcribe_minutes_month, transcribe_seconds_month) = self.usage[user_id].get_current_transcription_duration() + vision_today, vision_month = self.usage[user_id].get_current_vision_tokens() characters_today, characters_month = self.usage[user_id].get_current_tts_usage() current_cost = self.usage[user_id].get_current_cost() @@ -117,6 +120,10 @@ class ChatGPTTelegramBot: if self.config.get('enable_image_generation', False): text_today_images = f"{images_today} {localized_text('stats_images', bot_language)}\n" + text_today_vision = "" + if self.config.get('enable_vision', False): + text_today_vision = f"{vision_today} {localized_text('stats_vision', bot_language)}\n" + text_today_tts = "" if self.config.get('enable_tts_generation', False): text_today_tts = f"{characters_today} {localized_text('stats_tts', bot_language)}\n" @@ -125,6 +132,7 @@ class ChatGPTTelegramBot: f"*{localized_text('usage_today', bot_language)}:*\n" f"{tokens_today} {localized_text('stats_tokens', bot_language)}\n" f"{text_today_images}" # Include the image statistics for today if applicable + f"{text_today_vision}" f"{text_today_tts}" f"{transcribe_minutes_today} {localized_text('stats_transcribe', bot_language)[0]} " f"{transcribe_seconds_today} {localized_text('stats_transcribe', bot_language)[1]}\n" @@ -136,6 +144,10 @@ class ChatGPTTelegramBot: if self.config.get('enable_image_generation', False): text_month_images = f"{images_month} {localized_text('stats_images', bot_language)}\n" + text_month_vision = "" + if self.config.get('enable_vision', False): + text_month_vision = f"{vision_month} {localized_text('stats_vision', bot_language)}\n" + text_month_tts = "" if self.config.get('enable_tts_generation', False): text_month_tts = f"{characters_month} {localized_text('stats_tts', bot_language)}\n" @@ -145,6 +157,7 @@ class ChatGPTTelegramBot: f"*{localized_text('usage_month', bot_language)}:*\n" f"{tokens_month} {localized_text('stats_tokens', bot_language)}\n" f"{text_month_images}" # Include the image statistics for the month if applicable + f"{text_month_vision}" f"{text_month_tts}" f"{transcribe_minutes_month} {localized_text('stats_transcribe', bot_language)[0]} " f"{transcribe_seconds_month} {localized_text('stats_transcribe', bot_language)[1]}\n" @@ -438,6 +451,198 @@ class ChatGPTTelegramBot: await wrap_with_indicator(update, 
context, _execute, constants.ChatAction.TYPING) + async def vision(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + """ + Interpret image using vision model. + """ + if not self.config['enable_vision'] or not await self.check_allowed_and_within_budget(update, context): + return + + chat_id = update.effective_chat.id + prompt = update.message.caption + + if is_group_chat(update): + if self.config['ignore_group_vision']: + logging.info(f'Vision coming from group chat, ignoring...') + return + else: + trigger_keyword = self.config['group_trigger_keyword'] + if (prompt is None and trigger_keyword != '') or \ + (prompt is not None and not prompt.lower().startswith(trigger_keyword.lower())): + logging.info(f'Vision coming from group chat with wrong keyword, ignoring...') + return + + image = update.message.effective_attachment[-1] + + + async def _execute(): + bot_language = self.config['bot_language'] + try: + media_file = await context.bot.get_file(image.file_id) + temp_file = io.BytesIO(await media_file.download_as_bytearray()) + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=( + f"{localized_text('media_download_fail', bot_language)[0]}: " + f"{str(e)}. {localized_text('media_download_fail', bot_language)[1]}" + ), + parse_mode=constants.ParseMode.MARKDOWN + ) + return + + # convert jpg from telegram to png as understood by openai + + temp_file_png = io.BytesIO() + + try: + original_image = Image.open(temp_file) + + original_image.save(temp_file_png, format='PNG') + logging.info(f'New vision request received from user {update.message.from_user.name} ' + f'(id: {update.message.from_user.id})') + + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=localized_text('media_type_fail', bot_language) + ) + + + + user_id = update.message.from_user.id + if user_id not in self.usage: + self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name) + + if self.config['stream']: + + stream_response = self.openai.interpret_image_stream(chat_id=chat_id, fileobj=temp_file_png, prompt=prompt) + i = 0 + prev = '' + sent_message = None + backoff = 0 + stream_chunk = 0 + + async for content, tokens in stream_response: + if is_direct_result(content): + return await handle_direct_result(self.config, update, content) + + if len(content.strip()) == 0: + continue + + stream_chunks = split_into_chunks(content) + if len(stream_chunks) > 1: + content = stream_chunks[-1] + if stream_chunk != len(stream_chunks) - 1: + stream_chunk += 1 + try: + await edit_message_with_retry(context, chat_id, str(sent_message.message_id), + stream_chunks[-2]) + except: + pass + try: + sent_message = await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + text=content if len(content) > 0 else "..." 
+ ) + except: + pass + continue + + cutoff = get_stream_cutoff_values(update, content) + cutoff += backoff + + if i == 0: + try: + if sent_message is not None: + await context.bot.delete_message(chat_id=sent_message.chat_id, + message_id=sent_message.message_id) + sent_message = await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=content, + ) + except: + continue + + elif abs(len(content) - len(prev)) > cutoff or tokens != 'not_finished': + prev = content + + try: + use_markdown = tokens != 'not_finished' + await edit_message_with_retry(context, chat_id, str(sent_message.message_id), + text=content, markdown=use_markdown) + + except RetryAfter as e: + backoff += 5 + await asyncio.sleep(e.retry_after) + continue + + except TimedOut: + backoff += 5 + await asyncio.sleep(0.5) + continue + + except Exception: + backoff += 5 + continue + + await asyncio.sleep(0.01) + + i += 1 + if tokens != 'not_finished': + total_tokens = int(tokens) + + + else: + + try: + interpretation, total_tokens = await self.openai.interpret_image(chat_id, temp_file_png, prompt=prompt) + + + try: + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=interpretation, + parse_mode=constants.ParseMode.MARKDOWN + ) + except BadRequest: + try: + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=interpretation + ) + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=f"{localized_text('vision_fail', bot_language)}: {str(e)}", + parse_mode=constants.ParseMode.MARKDOWN + ) + except Exception as e: + logging.exception(e) + await update.effective_message.reply_text( + message_thread_id=get_thread_id(update), + reply_to_message_id=get_reply_to_message_id(self.config, update), + text=f"{localized_text('vision_fail', bot_language)}: {str(e)}", + parse_mode=constants.ParseMode.MARKDOWN + ) + vision_token_price = self.config['vision_token_price'] + self.usage[user_id].add_vision_tokens(total_tokens, vision_token_price) + + allowed_user_ids = self.config['allowed_user_ids'].split(',') + if str(user_id) not in allowed_user_ids and 'guests' in self.usage: + self.usage["guests"].add_vision_tokens(total_tokens, vision_token_price) + + await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING) + async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ React to incoming messages and respond accordingly. 
@@ -861,6 +1066,9 @@ class ChatGPTTelegramBot:
        application.add_handler(CommandHandler(
            'chat', self.prompt, filters=filters.ChatType.GROUP | filters.ChatType.SUPERGROUP)
        )
+        application.add_handler(MessageHandler(
+            filters.PHOTO | filters.Document.IMAGE,
+            self.vision))
        application.add_handler(MessageHandler(
            filters.AUDIO | filters.VOICE | filters.Document.AUDIO |
            filters.VIDEO | filters.VIDEO_NOTE | filters.Document.VIDEO,
diff --git a/bot/usage_tracker.py b/bot/usage_tracker.py
index fd907f3..58cef70 100644
--- a/bot/usage_tracker.py
+++ b/bot/usage_tracker.py
@@ -56,6 +56,8 @@ class UsageTracker:
        if os.path.isfile(self.user_file):
            with open(self.user_file, "r") as file:
                self.usage = json.load(file)
+            if 'vision_tokens' not in self.usage['usage_history']:
+                self.usage['usage_history']['vision_tokens'] = {}
            if 'tts_characters' not in self.usage['usage_history']:
                self.usage['usage_history']['tts_characters'] = {}
        else:
@@ -65,7 +67,7 @@ class UsageTracker:
            self.usage = {
                "user_name": user_name,
                "current_cost": {"day": 0.0, "month": 0.0, "all_time": 0.0, "last_update": str(date.today())},
-                "usage_history": {"chat_tokens": {}, "transcription_seconds": {}, "number_images": {}, "tts_characters": {}}
+                "usage_history": {"chat_tokens": {}, "transcription_seconds": {}, "number_images": {}, "tts_characters": {}, "vision_tokens":{}}
            }

    # token usage functions:
@@ -153,6 +155,47 @@ class UsageTracker:
            usage_month += sum(images)
        return usage_day, usage_month
+
+    # vision usage functions
+    def add_vision_tokens(self, tokens, vision_token_price=0.01):
+        """
+        Adds requested vision tokens to a user's usage history and updates current cost.
+        :param tokens: total tokens used in last request
+        :param vision_token_price: price per 1K tokens of image interpretation, defaults to 0.01
+        """
+        today = date.today()
+        token_price = round(tokens * vision_token_price / 1000, 2)
+        self.add_current_costs(token_price)
+
+        # update usage_history
+        if str(today) in self.usage["usage_history"]["vision_tokens"]:
+            # add requested tokens to existing date
+            self.usage["usage_history"]["vision_tokens"][str(today)] += tokens
+        else:
+            # create new entry for current date
+            self.usage["usage_history"]["vision_tokens"][str(today)] = tokens
+
+        # write updated token usage to user file
+        with open(self.user_file, "w") as outfile:
+            json.dump(self.usage, outfile)
+
+    def get_current_vision_tokens(self):
+        """Get vision tokens for today and this month. 
+ + :return: total amount of vision tokens per day and per month + """ + today = date.today() + if str(today) in self.usage["usage_history"]["vision_tokens"]: + tokens_day = self.usage["usage_history"]["vision_tokens"][str(today)] + else: + tokens_day = 0 + month = str(today)[:7] # year-month as string + tokens_month = 0 + for today, tokens in self.usage["usage_history"]["vision_tokens"].items(): + if today.startswith(month): + tokens_month += tokens + return tokens_day, tokens_month + # tts usage functions: def add_tts_request(self, text_length, tts_model, tts_prices): @@ -289,14 +332,15 @@ class UsageTracker: cost_all_time = self.usage["current_cost"].get("all_time", self.initialize_all_time_cost()) return {"cost_today": cost_day, "cost_month": cost_month, "cost_all_time": cost_all_time} - def initialize_all_time_cost(self, tokens_price=0.002, image_prices="0.016,0.018,0.02", minute_price=0.006, tts_prices='0.015,0.030'): + def initialize_all_time_cost(self, tokens_price=0.002, image_prices="0.016,0.018,0.02", minute_price=0.006, vision_token_price=0.01, tts_prices='0.015,0.030'): """Get total USD amount of all requests in history :param tokens_price: price per 1000 tokens, defaults to 0.002 :param image_prices: prices for images of sizes ["256x256", "512x512", "1024x1024"], defaults to [0.016, 0.018, 0.02] :param minute_price: price per minute transcription, defaults to 0.006 - :param character_price: price per character tts per model ['tts-1', 'tts-1-hd'], defaults to [0.015, 0.030] + :param vision_token_price: price per 1K vision token interpretation, defaults to 0.01 + :param tts_prices: price per 1K characters tts per model ['tts-1', 'tts-1-hd'], defaults to [0.015, 0.030] :return: total cost of all requests """ total_tokens = sum(self.usage['usage_history']['chat_tokens'].values()) @@ -309,9 +353,12 @@ class UsageTracker: total_transcription_seconds = sum(self.usage['usage_history']['transcription_seconds'].values()) transcription_cost = round(total_transcription_seconds * minute_price / 60, 2) + total_vision_tokens = sum(self.usage['usage_history']['vision_tokens'].values()) + vision_cost = round(total_vision_tokens * vision_token_price / 1000, 2) + total_characters = [sum(tts_model.values()) for tts_model in self.usage['usage_history']['tts_characters'].values()] tts_prices_list = [float(x) for x in tts_prices.split(',')] tts_cost = round(sum([count * price / 1000 for count, price in zip(total_characters, tts_prices_list)]), 2) - all_time_cost = token_cost + transcription_cost + image_cost + tts_cost + all_time_cost = token_cost + transcription_cost + image_cost + vision_cost + tts_cost return all_time_cost diff --git a/bot/utils.py b/bot/utils.py index 6ce2e98..d306dc6 100644 --- a/bot/utils.py +++ b/bot/utils.py @@ -5,6 +5,7 @@ import itertools import json import logging import os +import base64 import telegram from telegram import Message, MessageEntity, Update, ChatMember, constants @@ -376,4 +377,14 @@ def cleanup_intermediate_files(response: any): if format == 'path': if os.path.exists(value): - os.remove(value) \ No newline at end of file + os.remove(value) + + +# Function to encode the image +def encode_image(fileobj): + image = base64.b64encode(fileobj.getvalue()).decode('utf-8') + return f'data:image/jpeg;base64,{image}' + +def decode_image(imgbase64): + image = imgbase64[len('data:image/jpeg;base64,'):] + return base64.b64decode(image) diff --git a/requirements.txt b/requirements.txt index 83b4367..6d9ff92 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 
+11,4 @@ spotipy~=2.23.0 pytube~=15.0.0 gtts~=2.3.2 whois~=0.9.27 +Pillow~=10.1.0 diff --git a/translations.json b/translations.json index 7277c51..7f36a72 100644 --- a/translations.json +++ b/translations.json @@ -15,6 +15,7 @@ "usage_month":"Usage this month", "stats_tokens":"tokens", "stats_images":"images generated", + "stats_vision":"image tokens interpreted", "stats_tts":"characters converted to speech", "stats_transcribe":["minutes and", "seconds transcribed"], "stats_total":"💰 For a total amount of $", @@ -27,6 +28,7 @@ "reset_done":"Done!", "image_no_prompt":"Please provide a prompt! (e.g. /image cat)", "image_fail":"Failed to generate image", + "vision_fail":"Failed to interpret image", "tts_no_prompt":"Please provide text! (e.g. /tts my house)", "tts_fail":"Failed to generate speech", "media_download_fail":["Failed to download audio file", "Make sure the file is not too large. (max 20MB)"],
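For quick reference, the new vision settings introduced by this patch can be combined in `.env` as follows. The values shown are simply the defaults read in `bot/main.py`, so this fragment is illustrative rather than required:

```
ENABLE_VISION=true
VISION_MODEL=gpt-4-vision-preview
VISION_PROMPT="What is in this image"
VISION_DETAIL=auto
VISION_MAX_TOKENS=300
VISION_TOKEN_PRICE=0.01
ENABLE_VISION_FOLLOW_UP_QUESTIONS=true
IGNORE_GROUP_VISION=true
```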
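The new `encode_image`/`decode_image` helpers in `bot/utils.py` are plain base64 data-URL wrappers. Below is a self-contained sketch of the round trip, with the two helpers copied from this diff. Note that the prefix is labelled `image/jpeg` even though `telegram_bot.py` converts uploads to PNG first; the round trip is still lossless because `decode_image` strips the same fixed-length prefix:

```python
import base64
import io

from PIL import Image

# Local copies of the helpers added to bot/utils.py in this diff.
def encode_image(fileobj):
    image = base64.b64encode(fileobj.getvalue()).decode('utf-8')
    return f'data:image/jpeg;base64,{image}'

def decode_image(imgbase64):
    image = imgbase64[len('data:image/jpeg;base64,'):]
    return base64.b64decode(image)

buf = io.BytesIO()
Image.new('RGB', (64, 64), 'white').save(buf, format='PNG')  # the bot converts uploads to PNG first
data_url = encode_image(buf)
assert decode_image(data_url) == buf.getvalue()  # prefix length matches, so decoding recovers the bytes
```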