added support for vision

Author: gilcu3
Date: 2023-11-08 17:56:27 +01:00
parent 450b86d5a7
commit 9e7d8701cd
9 changed files with 245 additions and 8 deletions


@@ -12,10 +12,11 @@ import json
import httpx
from datetime import date
from calendar import monthrange
from PIL import Image
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
from utils import is_direct_result
from utils import is_direct_result, encode_image
from plugin_manager import PluginManager
# Models can be found here: https://platform.openai.com/docs/models/overview
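This hunk imports encode_image from utils, but the matching utils.py change is not shown here. A minimal sketch of what such a helper could look like, assuming it simply base64-encodes the file on disk for use in a data: URL (the commit's actual implementation lives in one of the other changed files):

import base64

def encode_image(filename: str) -> str:
    # Read the image file and return its bytes as a base64 string,
    # ready to be embedded in a data:image/...;base64 URL.
    with open(filename, 'rb') as f:
        return base64.b64encode(f.read()).decode('ascii')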
@@ -23,7 +24,8 @@ GPT_3_MODELS = ("gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613")
GPT_3_16K_MODELS = ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613")
GPT_4_MODELS = ("gpt-4", "gpt-4-0314", "gpt-4-0613")
GPT_4_32K_MODELS = ("gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613")
GPT_ALL_MODELS = GPT_3_MODELS + GPT_3_16K_MODELS + GPT_4_MODELS + GPT_4_32K_MODELS
GPT_4_VISION_MODELS = ("gpt-4-vision-preview",)
GPT_ALL_MODELS = GPT_3_MODELS + GPT_3_16K_MODELS + GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_VISION_MODELS
def default_max_tokens(model: str) -> int:
@@ -41,6 +43,8 @@ def default_max_tokens(model: str) -> int:
        return base * 4
    elif model in GPT_4_32K_MODELS:
        return base * 8
    elif model in GPT_4_VISION_MODELS:
        return 4096
def are_functions_available(model: str) -> bool:
@@ -347,6 +351,26 @@ class OpenAIHelper:
            logging.exception(e)
            raise Exception(f"⚠️ _{localized_text('error', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e
    async def interpret_image(self, filename, prompt=None):
        """
        Interprets a given image file using the Vision model.
        """
        try:
            image = encode_image(filename)
            prompt = self.config['vision_prompt'] if prompt is None else prompt
            message = {'role': 'user', 'content': [{'type': 'text', 'text': prompt},
                       {'type': 'image_url',
                        'image_url': {'url': f'data:image/jpeg;base64,{image}', 'detail': self.config['vision_detail']}}]}
            response = await self.client.chat.completions.create(model=self.config['model'], messages=[message],
                                                                  max_tokens=self.config['vision_max_tokens'])
            return response.choices[0].message.content, self.__count_tokens_vision(filename)
        except openai.RateLimitError as e:
            raise e
        except openai.BadRequestError as e:
            raise Exception(f"⚠️ _{localized_text('openai_invalid', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e
        except Exception as e:
            logging.exception(e)
            raise Exception(f"⚠️ _{localized_text('error', self.config['bot_language'])}._ ⚠️\n{str(e)}") from e
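interpret_image builds a single user message whose content is a list with one text part and one image_url part carrying a base64 data: URL plus a detail setting, and it pulls the prompt, detail level and token cap from the new vision_prompt, vision_detail and vision_max_tokens config keys. A standalone sketch of the same request shape against the async OpenAI v1 client, where the file path, prompt, detail and max_tokens values are illustrative assumptions rather than the bot's defaults:

import asyncio
import base64
import openai

async def describe_image(filename: str, prompt: str = 'What is in this image?') -> str:
    # Encode the image and send it as a data: URL inside a vision-style chat message.
    image = base64.b64encode(open(filename, 'rb').read()).decode('ascii')
    client = openai.AsyncOpenAI()  # picks up OPENAI_API_KEY from the environment
    message = {
        'role': 'user',
        'content': [
            {'type': 'text', 'text': prompt},
            {'type': 'image_url',
             'image_url': {'url': f'data:image/jpeg;base64,{image}', 'detail': 'low'}},
        ],
    }
    response = await client.chat.completions.create(
        model='gpt-4-vision-preview', messages=[message], max_tokens=300)
    return response.choices[0].message.content

print(asyncio.run(describe_image('photo.jpg')))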
    def reset_chat_history(self, chat_id, content=''):
        """
        Resets the conversation history.
@@ -410,6 +434,8 @@ class OpenAIHelper:
            return base * 2
        if self.config['model'] in GPT_4_32K_MODELS:
            return base * 8
        if self.config['model'] in GPT_4_VISION_MODELS:
            return base * 31
        raise NotImplementedError(
            f"Max tokens for model {self.config['model']} is not implemented yet."
        )
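Assuming base is 4096 in this method (its assignment sits outside the hunk), base * 31 = 126,976 tokens, which stays just below gpt-4-vision-preview's 128,000-token context window while leaving headroom for the completion.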
@@ -430,7 +456,7 @@ class OpenAIHelper:
        if model in GPT_3_MODELS + GPT_3_16K_MODELS:
            tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
            tokens_per_name = -1  # if there's a name, the role is omitted
        elif model in GPT_4_MODELS + GPT_4_32K_MODELS:
        elif model in GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_VISION_MODELS:
            tokens_per_message = 3
            tokens_per_name = 1
        else:
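These tokens_per_message / tokens_per_name constants follow the OpenAI cookbook counting recipe, which the vision preview model now shares with the rest of the GPT-4 family: 3 overhead tokens per message plus its encoded content, and a final 3-token assistant primer. A minimal sketch of that formula, assuming tiktoken's cl100k_base encoding and plain-string message contents:

import tiktoken

def count_chat_tokens(messages, tokens_per_message=3, tokens_per_name=1) -> int:
    # Cookbook-style estimate of the prompt size for GPT-4-family models.
    encoding = tiktoken.get_encoding('cl100k_base')
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == 'name':
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens

print(count_chat_tokens([{'role': 'user', 'content': 'Hello there'}]))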
@@ -445,6 +471,35 @@ class OpenAIHelper:
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
        return num_tokens
    def __count_tokens_vision(self, filename) -> int:
        """
        Counts the number of tokens for interpreting an image.
        :param filename: path of the image file to interpret
        :return: the number of tokens required
        """
        image = Image.open(filename)
        model = self.config['model']
        if model not in GPT_4_VISION_MODELS:
            raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}.""")
        w, h = image.size
        if w > h:
            w, h = h, w
        # this computation follows https://platform.openai.com/docs/guides/vision and https://openai.com/pricing#gpt-4-turbo
        base_tokens = 85
        detail = self.config['vision_detail']
        if detail == 'low':
            return base_tokens
        elif detail == 'high':
            f = max(w / 768, h / 2048)
            if f > 1:
                w, h = int(w / f), int(h / f)
            tw, th = (w + 511) // 512, (h + 511) // 512
            tiles = tw * th
            num_tokens = base_tokens + tiles * 170
            return num_tokens
        else:
            raise NotImplementedError(f"""unknown parameter detail={detail} for model {model}.""")
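As a worked example of the 'high' branch: a 2048×4096 image is not swapped (w ≤ h), f = max(2048/768, 4096/2048) ≈ 2.67, so it is scaled down to 768×1536; that yields (768 + 511) // 512 = 2 by (1536 + 511) // 512 = 3 tiles, i.e. 6 tiles of 512 px, for 85 + 6 × 170 = 1105 tokens, matching the example in OpenAI's vision guide. At detail='low' the same image costs the flat 85 tokens.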
# No longer works as of July 21st 2023, as OpenAI has removed the billing API
# def get_billing_current_month(self):
# """Gets billed usage for current month from OpenAI API.