mirror of
https://github.com/aljazceru/chatgpt-telegram-bot.git
synced 2025-12-22 23:25:41 +01:00
added stream support
This commit is contained in:
@@ -518,46 +518,128 @@ class ChatGPTTelegramBot:
|
||||
if user_id not in self.usage:
|
||||
self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name)
|
||||
|
||||
try:
|
||||
interpretation, tokens = await self.openai.interpret_image(chat_id, temp_file_png, prompt=prompt)
|
||||
if self.config['stream']:
|
||||
|
||||
vision_token_price = self.config['vision_token_price']
|
||||
self.usage[user_id].add_vision_tokens(tokens, vision_token_price)
|
||||
stream_response = self.openai.interpret_image_stream(chat_id=chat_id, fileobj=temp_file_png, prompt=prompt)
|
||||
i = 0
|
||||
prev = ''
|
||||
sent_message = None
|
||||
backoff = 0
|
||||
stream_chunk = 0
|
||||
|
||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
||||
if str(user_id) not in allowed_user_ids and 'guests' in self.usage:
|
||||
self.usage["guests"].add_vision_tokens(tokens, vision_token_price)
|
||||
async for content, tokens in stream_response:
|
||||
if is_direct_result(content):
|
||||
return await handle_direct_result(self.config, update, content)
|
||||
|
||||
if len(content.strip()) == 0:
|
||||
continue
|
||||
|
||||
stream_chunks = split_into_chunks(content)
|
||||
if len(stream_chunks) > 1:
|
||||
content = stream_chunks[-1]
|
||||
if stream_chunk != len(stream_chunks) - 1:
|
||||
stream_chunk += 1
|
||||
try:
|
||||
await edit_message_with_retry(context, chat_id, str(sent_message.message_id),
|
||||
stream_chunks[-2])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
sent_message = await update.effective_message.reply_text(
|
||||
message_thread_id=get_thread_id(update),
|
||||
text=content if len(content) > 0 else "..."
|
||||
)
|
||||
except:
|
||||
pass
|
||||
continue
|
||||
|
||||
cutoff = get_stream_cutoff_values(update, content)
|
||||
cutoff += backoff
|
||||
|
||||
if i == 0:
|
||||
try:
|
||||
if sent_message is not None:
|
||||
await context.bot.delete_message(chat_id=sent_message.chat_id,
|
||||
message_id=sent_message.message_id)
|
||||
sent_message = await update.effective_message.reply_text(
|
||||
message_thread_id=get_thread_id(update),
|
||||
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||
text=content,
|
||||
)
|
||||
except:
|
||||
continue
|
||||
|
||||
elif abs(len(content) - len(prev)) > cutoff or tokens != 'not_finished':
|
||||
prev = content
|
||||
|
||||
try:
|
||||
use_markdown = tokens != 'not_finished'
|
||||
await edit_message_with_retry(context, chat_id, str(sent_message.message_id),
|
||||
text=content, markdown=use_markdown)
|
||||
|
||||
except RetryAfter as e:
|
||||
backoff += 5
|
||||
await asyncio.sleep(e.retry_after)
|
||||
continue
|
||||
|
||||
except TimedOut:
|
||||
backoff += 5
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
|
||||
except Exception:
|
||||
backoff += 5
|
||||
continue
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
i += 1
|
||||
if tokens != 'not_finished':
|
||||
total_tokens = int(tokens)
|
||||
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
await update.effective_message.reply_text(
|
||||
message_thread_id=get_thread_id(update),
|
||||
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||
text=interpretation,
|
||||
parse_mode=constants.ParseMode.MARKDOWN
|
||||
)
|
||||
except BadRequest:
|
||||
interpretation, total_tokens = await self.openai.interpret_image(chat_id, temp_file_png, prompt=prompt)
|
||||
|
||||
|
||||
try:
|
||||
await update.effective_message.reply_text(
|
||||
message_thread_id=get_thread_id(update),
|
||||
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||
text=interpretation
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
await update.effective_message.reply_text(
|
||||
message_thread_id=get_thread_id(update),
|
||||
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||
text=f"{localized_text('vision_fail', bot_language)}: {str(e)}",
|
||||
text=interpretation,
|
||||
parse_mode=constants.ParseMode.MARKDOWN
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
await update.effective_message.reply_text(
|
||||
message_thread_id=get_thread_id(update),
|
||||
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||
text=f"{localized_text('vision_fail', bot_language)}: {str(e)}",
|
||||
parse_mode=constants.ParseMode.MARKDOWN
|
||||
)
|
||||
except BadRequest:
|
||||
try:
|
||||
await update.effective_message.reply_text(
|
||||
message_thread_id=get_thread_id(update),
|
||||
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||
text=interpretation
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
await update.effective_message.reply_text(
|
||||
message_thread_id=get_thread_id(update),
|
||||
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||
text=f"{localized_text('vision_fail', bot_language)}: {str(e)}",
|
||||
parse_mode=constants.ParseMode.MARKDOWN
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
await update.effective_message.reply_text(
|
||||
message_thread_id=get_thread_id(update),
|
||||
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||
text=f"{localized_text('vision_fail', bot_language)}: {str(e)}",
|
||||
parse_mode=constants.ParseMode.MARKDOWN
|
||||
)
|
||||
vision_token_price = self.config['vision_token_price']
|
||||
self.usage[user_id].add_vision_tokens(total_tokens, vision_token_price)
|
||||
|
||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
||||
if str(user_id) not in allowed_user_ids and 'guests' in self.usage:
|
||||
self.usage["guests"].add_vision_tokens(total_tokens, vision_token_price)
|
||||
|
||||
await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user