mirror of
https://github.com/aljazceru/chatgpt-telegram-bot.git
synced 2025-12-21 22:54:52 +01:00
some refactoring
This commit is contained in:
@@ -8,7 +8,7 @@ from telegram import constants
|
||||
from telegram import Message, MessageEntity, Update, InlineQueryResultArticle, InputTextMessageContent, BotCommand
|
||||
from telegram.error import RetryAfter, TimedOut
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, CommandHandler, MessageHandler, \
|
||||
filters, InlineQueryHandler, Application
|
||||
filters, InlineQueryHandler, Application, CallbackContext
|
||||
|
||||
from pydub import AudioSegment
|
||||
from openai_helper import OpenAIHelper
|
||||
@@ -184,13 +184,7 @@ class ChatGPTTelegramBot:
|
||||
parse_mode=constants.ParseMode.MARKDOWN
|
||||
)
|
||||
|
||||
task = context.application.create_task(_generate(), update=update)
|
||||
while not task.done():
|
||||
context.application.create_task(update.effective_chat.send_action(constants.ChatAction.UPLOAD_PHOTO))
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task), 4.5)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
await self.wrap_with_indicator(update, context, constants.ChatAction.UPLOAD_PHOTO, _generate)
|
||||
|
||||
async def transcribe(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""
|
||||
@@ -317,13 +311,7 @@ class ChatGPTTelegramBot:
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
|
||||
task = context.application.create_task(_execute(), update=update)
|
||||
while not task.done():
|
||||
context.application.create_task(update.effective_chat.send_action(constants.ChatAction.TYPING))
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task), 4.5)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
await self.wrap_with_indicator(update, context, constants.ChatAction.TYPING, _execute)
|
||||
|
||||
async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""
|
||||
@@ -451,13 +439,7 @@ class ChatGPTTelegramBot:
|
||||
parse_mode=constants.ParseMode.MARKDOWN
|
||||
)
|
||||
|
||||
task = context.application.create_task(_reply(), update=update)
|
||||
while not task.done():
|
||||
context.application.create_task(update.effective_chat.send_action(constants.ChatAction.TYPING))
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task), 4.5)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
await self.wrap_with_indicator(update, context, constants.ChatAction.TYPING, _reply)
|
||||
|
||||
try:
|
||||
# add chat request to users usage tracker
|
||||
@@ -532,6 +514,18 @@ class ChatGPTTelegramBot:
|
||||
logging.warning(str(e))
|
||||
raise e
|
||||
|
||||
async def wrap_with_indicator(self, update: Update, context: CallbackContext, chat_action: constants.ChatAction, coroutine):
|
||||
"""
|
||||
Wraps a coroutine while repeatedly sending a chat action to the user.
|
||||
"""
|
||||
task = context.application.create_task(coroutine(), update=update)
|
||||
while not task.done():
|
||||
context.application.create_task(update.effective_chat.send_action(chat_action))
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task), 4.5)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
async def send_disallowed_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
"""
|
||||
Sends the disallowed message to the user.
|
||||
|
||||
Reference in New Issue
Block a user