Mirror of https://github.com/aljazceru/chatgpt-telegram-bot.git (synced 2025-12-23 07:35:06 +01:00)

Commit: change to use image files in memory

The commit replaces on-disk temporary files with in-memory io.BytesIO buffers throughout the vision pipeline: the Telegram photo is downloaded as raw bytes, converted to PNG in a memory buffer, and base64-encoded straight from that buffer, so no temp files need to be created or cleaned up.

@@ -353,12 +353,12 @@ class OpenAIHelper:
             logging.exception(e)
             raise Exception(f"⚠️ _{localized_text('error', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e

-    async def interpret_image(self, chat_id, filename, prompt=None):
+    async def interpret_image(self, chat_id, fileobj, prompt=None):
         """
         Interprets a given PNG image file using the Vision model.
         """
         try:
-            image = encode_image(filename)
+            image = encode_image(fileobj)
             prompt = self.config['vision_prompt'] if prompt is None else prompt

             # for now I am not adding the image itself to the history
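For context on what interpret_image ultimately does with the buffer: a minimal sketch of sending a base64-encoded, in-memory PNG to a vision-capable model with the OpenAI Python client (v1 style). The helper name, model string, and prompt handling here are illustrative, not taken from the repository:

    import base64
    import io

    from openai import AsyncOpenAI

    client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set in the environment


    async def ask_about_image(fileobj: io.BytesIO, prompt: str) -> str:
        # Encode the in-memory PNG as a base64 data URL; no temporary file is involved.
        b64 = base64.b64encode(fileobj.getvalue()).decode('utf-8')
        response = await client.chat.completions.create(
            model='gpt-4o',  # illustrative vision-capable model
            messages=[{
                'role': 'user',
                'content': [
                    {'type': 'text', 'text': prompt},
                    {'type': 'image_url', 'image_url': {'url': f'data:image/png;base64,{b64}'}},
                ],
            }],
        )
        return response.choices[0].message.content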
@@ -390,7 +390,7 @@ class OpenAIHelper:
             self.__add_to_history(chat_id, role="assistant", content=content)


-            return content, self.__count_tokens_vision(filename)
+            return content, self.__count_tokens_vision(fileobj)

         except openai.RateLimitError as e:
             raise e
@@ -500,13 +500,13 @@ class OpenAIHelper:
         num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
         return num_tokens

-    def __count_tokens_vision(self, filename) -> int:
+    def __count_tokens_vision(self, fileobj) -> int:
         """
         Counts the number of tokens for interpreting an image.
         :param image: image to interpret
         :return: the number of tokens required
         """
-        image = Image.open(filename)
+        image = Image.open(fileobj)
         model = self.config['model']
         if model not in GPT_4_VISION_MODELS:
             raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}.""")
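For reference, a rough sketch of how a __count_tokens_vision-style estimate can be derived from the image dimensions, following OpenAI's published tiling rules (a fixed 85-token overhead plus 170 tokens per 512x512 tile after rescaling). This is an approximation under those assumptions, not the repository's exact implementation:

    import io
    import math

    from PIL import Image


    def estimate_vision_tokens(fileobj: io.BytesIO) -> int:
        """Approximate the token cost of a high-detail image from its dimensions."""
        width, height = Image.open(fileobj).size

        # Scale down so the image fits inside a 2048 x 2048 square.
        if max(width, height) > 2048:
            scale = 2048 / max(width, height)
            width, height = int(width * scale), int(height * scale)

        # Scale down so the shorter side is at most 768 pixels.
        if min(width, height) > 768:
            scale = 768 / min(width, height)
            width, height = int(width * scale), int(height * scale)

        # 170 tokens per 512 x 512 tile, plus a fixed 85-token overhead.
        tiles = math.ceil(width / 512) * math.ceil(height / 512)
        return 170 * tiles + 85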
@@ -3,7 +3,7 @@ from __future__ import annotations
 import asyncio
 import logging
 import os
-import tempfile
+import io

 from uuid import uuid4
 from telegram import BotCommandScopeAllGroupChats, Update, constants
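The substance of the import swap: an io.BytesIO buffer is a file-like object that Pillow and the base64 helpers accept just like an open file, except it lives entirely in memory and has no .name path on disk. A before/after sketch of the pattern around plain bytes; the helper names are illustrative and the "before" variant only mirrors the general temp-file round trip, not the repository's exact code:

    import io
    import tempfile

    from PIL import Image


    def open_via_temp_file(data: bytes) -> Image.Image:
        # Old pattern: spill the bytes to disk, then reopen them by path.
        with tempfile.NamedTemporaryFile(suffix='.jpg') as temp_file:
            temp_file.write(data)
            temp_file.flush()
            return Image.open(temp_file.name).copy()  # copy() loads the pixels before the file vanishes


    def open_in_memory(data: bytes) -> Image.Image:
        # New pattern: wrap the bytes in an in-memory buffer and hand it straight to Pillow.
        return Image.open(io.BytesIO(data))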
@@ -406,13 +406,12 @@ class ChatGPTTelegramBot:

         image = update.message.effective_attachment[-1]

-        temp_file = tempfile.NamedTemporaryFile()

         async def _execute():
             bot_language = self.config['bot_language']
             try:
                 media_file = await context.bot.get_file(image.file_id)
-                await media_file.download_to_drive(temp_file.name)
+                temp_file = io.BytesIO(await media_file.download_as_bytearray())
             except Exception as e:
                 logging.exception(e)
                 await update.effective_message.reply_text(
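With python-telegram-bot v20, File.download_as_bytearray() returns the raw bytes, so the photo never has to touch disk before Pillow sees it. A sketch of the download step in isolation (handler wiring and error handling omitted; the function name is illustrative):

    import io

    from telegram import Update
    from telegram.ext import ContextTypes


    async def download_photo_to_memory(update: Update, context: ContextTypes.DEFAULT_TYPE) -> io.BytesIO:
        # Telegram lists photo sizes smallest to largest; take the highest resolution.
        photo = update.message.effective_attachment[-1]
        media_file = await context.bot.get_file(photo.file_id)
        # download_as_bytearray() fetches the contents without writing to disk.
        return io.BytesIO(await media_file.download_as_bytearray())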
@@ -428,12 +427,12 @@ class ChatGPTTelegramBot:

             # convert jpg from telegram to png as understood by openai

-            temp_file_png = tempfile.NamedTemporaryFile()
+            temp_file_png = io.BytesIO()

             try:
-                original_image = Image.open(temp_file.name)
+                original_image = Image.open(temp_file)

-                original_image.save(temp_file_png.name, format='PNG')
+                original_image.save(temp_file_png, format='PNG')
                 logging.info(f'New vision request received from user {update.message.from_user.name} '
                              f'(id: {update.message.from_user.id})')

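The JPEG-to-PNG conversion also stays in memory, since Pillow can read from and write to a BytesIO buffer directly. One subtlety worth flagging in a sketch (an assumption about callers, not something the diff itself requires): a consumer that uses .read() needs a seek(0) after the save, whereas .getvalue(), as used by encode_image further below, returns the whole buffer regardless of the stream position:

    import io

    from PIL import Image


    def jpeg_to_png_in_memory(jpeg_buffer: io.BytesIO) -> io.BytesIO:
        png_buffer = io.BytesIO()
        Image.open(jpeg_buffer).save(png_buffer, format='PNG')
        # Rewind for .read()-style consumers; .getvalue() callers would not need this.
        png_buffer.seek(0)
        return png_buffer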
@@ -452,7 +451,7 @@ class ChatGPTTelegramBot:
                 self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name)

             try:
-                interpretation, tokens = await self.openai.interpret_image(chat_id, temp_file_png.name, prompt=prompt)
+                interpretation, tokens = await self.openai.interpret_image(chat_id, temp_file_png, prompt=prompt)

                 vision_token_price = self.config['vision_token_price']
                 self.usage[user_id].add_vision_tokens(tokens, vision_token_price)
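On the accounting side, a minimal sketch of how the token count returned by interpret_image might translate into a cost, assuming vision_token_price is expressed per 1,000 tokens (a common convention; the repository's UsageTracker may use a different unit):

    def vision_request_cost(tokens: int, vision_token_price: float) -> float:
        # Assumes the configured price is per 1,000 tokens; adjust if the tracker expects another unit.
        return tokens / 1000 * vision_token_price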
@@ -477,9 +476,6 @@ class ChatGPTTelegramBot:
                     text=f"{localized_text('vision_fail', bot_language)}: {str(e)}",
                     parse_mode=constants.ParseMode.MARKDOWN
                 )
-            finally:
-                temp_file.close()
-                temp_file_png.close()

         await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)

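Dropping the finally block is safe because a BytesIO buffer holds no operating-system handle; it is ordinary memory that the garbage collector reclaims once the last reference goes away. Closing it explicitly, or using it as a context manager as sketched here, only releases that memory slightly earlier:

    import io

    # BytesIO supports the context-manager protocol, so an explicit close() is optional.
    with io.BytesIO(b'example bytes') as buffer:
        data = buffer.getvalue()
    # On leaving the block the buffer is closed and its memory can be released.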
@@ -381,6 +381,5 @@ def cleanup_intermediate_files(response: any):


 # Function to encode the image
-def encode_image(image_path):
-    with open(image_path, "rb") as image_file:
-        return base64.b64encode(image_file.read()).decode('utf-8')
+def encode_image(fileobj):
+    return base64.b64encode(fileobj.getvalue()).decode('utf-8')
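The new encode_image reads the buffer with .getvalue() instead of opening a path. For comparison, a hedged sketch of a variant that accepts either an in-memory buffer or a file path; this flexible helper is illustrative and not part of the repository:

    import base64
    import io
    from pathlib import Path
    from typing import Union


    def encode_image_flexible(source: Union[io.BytesIO, str, Path]) -> str:
        """Base64-encode an image from an in-memory buffer or a file path."""
        if isinstance(source, io.BytesIO):
            # getvalue() returns the full contents regardless of the current stream position.
            data = source.getvalue()
        else:
            data = Path(source).read_bytes()
        return base64.b64encode(data).decode('utf-8')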