change to use image files in memory

gilcu3
2023-11-09 18:18:15 +01:00
parent f69f9e5034
commit b2f2114e36
3 changed files with 13 additions and 18 deletions

View File

@@ -353,12 +353,12 @@ class OpenAIHelper:
             logging.exception(e)
             raise Exception(f"⚠️ _{localized_text('error', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e
 
-    async def interpret_image(self, chat_id, filename, prompt=None):
+    async def interpret_image(self, chat_id, fileobj, prompt=None):
         """
         Interprets a given PNG image file using the Vision model.
         """
         try:
-            image = encode_image(filename)
+            image = encode_image(fileobj)
             prompt = self.config['vision_prompt'] if prompt is None else prompt
 
             # for now I am not adding the image itself to the history
@@ -390,7 +390,7 @@ class OpenAIHelper:
             self.__add_to_history(chat_id, role="assistant", content=content)
 
-            return content, self.__count_tokens_vision(filename)
+            return content, self.__count_tokens_vision(fileobj)
 
         except openai.RateLimitError as e:
             raise e
@@ -500,13 +500,13 @@ class OpenAIHelper:
         num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
         return num_tokens
 
-    def __count_tokens_vision(self, filename) -> int:
+    def __count_tokens_vision(self, fileobj) -> int:
         """
         Counts the number of tokens for interpreting an image.
         :param image: image to interpret
         :return: the number of tokens required
         """
-        image = Image.open(filename)
+        image = Image.open(fileobj)
         model = self.config['model']
         if model not in GPT_4_VISION_MODELS:
             raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}.""")
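
Note on why the OpenAIHelper side is a drop-in change: Pillow's Image.open accepts any seekable file-like object, not just a filename, so __count_tokens_vision keeps working when handed an in-memory buffer. A minimal sketch, outside the repo (the generated test image is illustrative only):

import io

from PIL import Image

buf = io.BytesIO()
Image.new('RGB', (512, 512)).save(buf, format='PNG')  # stand-in for the converted Telegram photo
image = Image.open(buf)  # Pillow rewinds the buffer before reading, no manual seek(0) needed
print(image.size)  # (512, 512)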

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import asyncio
 import logging
 import os
-import tempfile
+import io
 from uuid import uuid4
 
 from telegram import BotCommandScopeAllGroupChats, Update, constants
@@ -406,13 +406,12 @@ class ChatGPTTelegramBot:
         image = update.message.effective_attachment[-1]
-        temp_file = tempfile.NamedTemporaryFile()
 
         async def _execute():
             bot_language = self.config['bot_language']
             try:
                 media_file = await context.bot.get_file(image.file_id)
-                await media_file.download_to_drive(temp_file.name)
+                temp_file = io.BytesIO(await media_file.download_as_bytearray())
             except Exception as e:
                 logging.exception(e)
                 await update.effective_message.reply_text(
@@ -428,12 +427,12 @@ class ChatGPTTelegramBot:
             # convert jpg from telegram to png as understood by openai
-            temp_file_png = tempfile.NamedTemporaryFile()
+            temp_file_png = io.BytesIO()
             try:
-                original_image = Image.open(temp_file.name)
-                original_image.save(temp_file_png.name, format='PNG')
+                original_image = Image.open(temp_file)
+                original_image.save(temp_file_png, format='PNG')
 
                 logging.info(f'New vision request received from user {update.message.from_user.name} '
                              f'(id: {update.message.from_user.id})')
@@ -452,7 +451,7 @@ class ChatGPTTelegramBot:
                 self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name)
 
             try:
-                interpretation, tokens = await self.openai.interpret_image(chat_id, temp_file_png.name, prompt=prompt)
+                interpretation, tokens = await self.openai.interpret_image(chat_id, temp_file_png, prompt=prompt)
 
                 vision_token_price = self.config['vision_token_price']
                 self.usage[user_id].add_vision_tokens(tokens, vision_token_price)
@@ -477,9 +476,6 @@ class ChatGPTTelegramBot:
                         text=f"{localized_text('vision_fail', bot_language)}: {str(e)}",
                         parse_mode=constants.ParseMode.MARKDOWN
                     )
-            finally:
-                temp_file.close()
-                temp_file_png.close()
 
         await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)
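
The Telegram-bot side now keeps both the downloaded photo and its PNG conversion in memory, so the finally block that closed the temp files is no longer needed. Condensed, the new flow looks roughly like this (the to_png_buffer helper is illustrative and not part of the repo; media_file is a telegram.File, as in the hunk above):

import io

from PIL import Image


async def to_png_buffer(media_file) -> io.BytesIO:
    # download the Telegram photo straight into memory instead of a temp file
    temp_file = io.BytesIO(await media_file.download_as_bytearray())
    # convert the JPEG to PNG inside a second in-memory buffer
    temp_file_png = io.BytesIO()
    Image.open(temp_file).save(temp_file_png, format='PNG')
    return temp_file_png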

View File

@@ -381,6 +381,5 @@ def cleanup_intermediate_files(response: any):
 
 
 # Function to encode the image
-def encode_image(image_path):
-    with open(image_path, "rb") as image_file:
-        return base64.b64encode(image_file.read()).decode('utf-8')
+def encode_image(fileobj):
+    return base64.b64encode(fileobj.getvalue()).decode('utf-8')
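
With this change encode_image expects an in-memory buffer instead of a path; getvalue() returns the buffer's full contents regardless of the current read position, so no seek(0) is needed after the PNG conversion. A minimal usage sketch (the sample bytes are just the PNG signature, not a full image):

import base64
import io


def encode_image(fileobj):
    return base64.b64encode(fileobj.getvalue()).decode('utf-8')


buf = io.BytesIO(b'\x89PNG\r\n\x1a\n')  # PNG magic bytes, illustrative only
print(encode_image(buf))  # iVBORw0KGgo=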