added stream support

This commit is contained in:
gilcu3
2023-11-20 15:10:54 +01:00
parent 7329cd7fc5
commit 6de497dac8
2 changed files with 259 additions and 52 deletions

View File

@@ -518,46 +518,128 @@ class ChatGPTTelegramBot:
if user_id not in self.usage:
self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name)
try:
interpretation, tokens = await self.openai.interpret_image(chat_id, temp_file_png, prompt=prompt)
if self.config['stream']:
vision_token_price = self.config['vision_token_price']
self.usage[user_id].add_vision_tokens(tokens, vision_token_price)
stream_response = self.openai.interpret_image_stream(chat_id=chat_id, fileobj=temp_file_png, prompt=prompt)
i = 0
prev = ''
sent_message = None
backoff = 0
stream_chunk = 0
allowed_user_ids = self.config['allowed_user_ids'].split(',')
if str(user_id) not in allowed_user_ids and 'guests' in self.usage:
self.usage["guests"].add_vision_tokens(tokens, vision_token_price)
async for content, tokens in stream_response:
if is_direct_result(content):
return await handle_direct_result(self.config, update, content)
if len(content.strip()) == 0:
continue
stream_chunks = split_into_chunks(content)
if len(stream_chunks) > 1:
content = stream_chunks[-1]
if stream_chunk != len(stream_chunks) - 1:
stream_chunk += 1
try:
await edit_message_with_retry(context, chat_id, str(sent_message.message_id),
stream_chunks[-2])
except:
pass
try:
sent_message = await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
text=content if len(content) > 0 else "..."
)
except:
pass
continue
cutoff = get_stream_cutoff_values(update, content)
cutoff += backoff
if i == 0:
try:
if sent_message is not None:
await context.bot.delete_message(chat_id=sent_message.chat_id,
message_id=sent_message.message_id)
sent_message = await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=content,
)
except:
continue
elif abs(len(content) - len(prev)) > cutoff or tokens != 'not_finished':
prev = content
try:
use_markdown = tokens != 'not_finished'
await edit_message_with_retry(context, chat_id, str(sent_message.message_id),
text=content, markdown=use_markdown)
except RetryAfter as e:
backoff += 5
await asyncio.sleep(e.retry_after)
continue
except TimedOut:
backoff += 5
await asyncio.sleep(0.5)
continue
except Exception:
backoff += 5
continue
await asyncio.sleep(0.01)
i += 1
if tokens != 'not_finished':
total_tokens = int(tokens)
else:
try:
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=interpretation,
parse_mode=constants.ParseMode.MARKDOWN
)
except BadRequest:
interpretation, total_tokens = await self.openai.interpret_image(chat_id, temp_file_png, prompt=prompt)
try:
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=interpretation
)
except Exception as e:
logging.exception(e)
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=f"{localized_text('vision_fail', bot_language)}: {str(e)}",
text=interpretation,
parse_mode=constants.ParseMode.MARKDOWN
)
except Exception as e:
logging.exception(e)
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=f"{localized_text('vision_fail', bot_language)}: {str(e)}",
parse_mode=constants.ParseMode.MARKDOWN
)
except BadRequest:
try:
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=interpretation
)
except Exception as e:
logging.exception(e)
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=f"{localized_text('vision_fail', bot_language)}: {str(e)}",
parse_mode=constants.ParseMode.MARKDOWN
)
except Exception as e:
logging.exception(e)
await update.effective_message.reply_text(
message_thread_id=get_thread_id(update),
reply_to_message_id=get_reply_to_message_id(self.config, update),
text=f"{localized_text('vision_fail', bot_language)}: {str(e)}",
parse_mode=constants.ParseMode.MARKDOWN
)
vision_token_price = self.config['vision_token_price']
self.usage[user_id].add_vision_tokens(total_tokens, vision_token_price)
allowed_user_ids = self.config['allowed_user_ids'].split(',')
if str(user_id) not in allowed_user_ids and 'guests' in self.usage:
self.usage["guests"].add_vision_tokens(total_tokens, vision_token_price)
await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)