mirror of
https://github.com/aljazceru/chatgpt-telegram-bot.git
synced 2025-12-20 14:14:52 +01:00
move utils funcs to standalone file
This commit is contained in:
@@ -1,48 +1,30 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import itertools
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import telegram
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from telegram import constants, BotCommandScopeAllGroupChats
|
from telegram import BotCommandScopeAllGroupChats, Update, constants
|
||||||
from telegram import InlineKeyboardMarkup, InlineKeyboardButton, InlineQueryResultArticle
|
from telegram import InlineKeyboardMarkup, InlineKeyboardButton, InlineQueryResultArticle
|
||||||
from telegram import Message, MessageEntity, Update, InputTextMessageContent, BotCommand, ChatMember
|
from telegram import InputTextMessageContent, BotCommand
|
||||||
from telegram.error import RetryAfter, TimedOut
|
from telegram.error import RetryAfter, TimedOut
|
||||||
from telegram.ext import ApplicationBuilder, ContextTypes, CommandHandler, MessageHandler, \
|
from telegram.ext import ApplicationBuilder, CommandHandler, MessageHandler, \
|
||||||
filters, InlineQueryHandler, CallbackQueryHandler, Application, CallbackContext
|
filters, InlineQueryHandler, CallbackQueryHandler, Application, ContextTypes, CallbackContext
|
||||||
|
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
|
||||||
|
from utils import is_group_chat, get_thread_id, message_text, wrap_with_indicator, split_into_chunks, \
|
||||||
|
edit_message_with_retry, get_stream_cutoff_values, is_allowed, get_remaining_budget, is_admin, is_within_budget, \
|
||||||
|
get_reply_to_message_id, add_chat_request_to_usage_tracker, error_handler
|
||||||
from openai_helper import OpenAIHelper, localized_text
|
from openai_helper import OpenAIHelper, localized_text
|
||||||
from usage_tracker import UsageTracker
|
from usage_tracker import UsageTracker
|
||||||
|
|
||||||
|
|
||||||
def message_text(message: Message) -> str:
|
|
||||||
"""
|
|
||||||
Returns the text of a message, excluding any bot commands.
|
|
||||||
"""
|
|
||||||
message_txt = message.text
|
|
||||||
if message_txt is None:
|
|
||||||
return ''
|
|
||||||
|
|
||||||
for _, text in sorted(message.parse_entities([MessageEntity.BOT_COMMAND]).items(),
|
|
||||||
key=(lambda item: item[0].offset)):
|
|
||||||
message_txt = message_txt.replace(text, '').strip()
|
|
||||||
|
|
||||||
return message_txt if len(message_txt) > 0 else ''
|
|
||||||
|
|
||||||
|
|
||||||
class ChatGPTTelegramBot:
|
class ChatGPTTelegramBot:
|
||||||
"""
|
"""
|
||||||
Class representing a ChatGPT Telegram Bot.
|
Class representing a ChatGPT Telegram Bot.
|
||||||
"""
|
"""
|
||||||
# Mapping of budget period to cost period
|
|
||||||
budget_cost_map = {
|
|
||||||
"monthly": "cost_month",
|
|
||||||
"daily": "cost_today",
|
|
||||||
"all-time": "cost_all_time"
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, config: dict, openai: OpenAIHelper):
|
def __init__(self, config: dict, openai: OpenAIHelper):
|
||||||
"""
|
"""
|
||||||
@@ -60,10 +42,9 @@ class ChatGPTTelegramBot:
|
|||||||
BotCommand(command='stats', description=localized_text('stats_description', bot_language)),
|
BotCommand(command='stats', description=localized_text('stats_description', bot_language)),
|
||||||
BotCommand(command='resend', description=localized_text('resend_description', bot_language))
|
BotCommand(command='resend', description=localized_text('resend_description', bot_language))
|
||||||
]
|
]
|
||||||
self.group_commands = [
|
self.group_commands = [BotCommand(
|
||||||
BotCommand(command='chat',
|
command='chat', description=localized_text('chat_description', bot_language)
|
||||||
description=localized_text('chat_description', bot_language))
|
)] + self.commands
|
||||||
] + self.commands
|
|
||||||
self.disallowed_message = localized_text('disallowed', bot_language)
|
self.disallowed_message = localized_text('disallowed', bot_language)
|
||||||
self.budget_limit_message = localized_text('budget_limit', bot_language)
|
self.budget_limit_message = localized_text('budget_limit', bot_language)
|
||||||
self.usage = {}
|
self.usage = {}
|
||||||
@@ -74,7 +55,7 @@ class ChatGPTTelegramBot:
|
|||||||
"""
|
"""
|
||||||
Shows the help menu.
|
Shows the help menu.
|
||||||
"""
|
"""
|
||||||
commands = self.group_commands if self.is_group_chat(update) else self.commands
|
commands = self.group_commands if is_group_chat(update) else self.commands
|
||||||
commands_description = [f'/{command.command} - {command.description}' for command in commands]
|
commands_description = [f'/{command.command} - {command.description}' for command in commands]
|
||||||
bot_language = self.config['bot_language']
|
bot_language = self.config['bot_language']
|
||||||
help_text = (
|
help_text = (
|
||||||
@@ -92,7 +73,7 @@ class ChatGPTTelegramBot:
|
|||||||
"""
|
"""
|
||||||
Returns token usage statistics for current day and month.
|
Returns token usage statistics for current day and month.
|
||||||
"""
|
"""
|
||||||
if not await self.is_allowed(update, context):
|
if not await is_allowed(self.config, update, context):
|
||||||
logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id}) '
|
logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id}) '
|
||||||
f'is not allowed to request their usage statistics')
|
f'is not allowed to request their usage statistics')
|
||||||
await self.send_disallowed_message(update, context)
|
await self.send_disallowed_message(update, context)
|
||||||
@@ -113,7 +94,7 @@ class ChatGPTTelegramBot:
|
|||||||
|
|
||||||
chat_id = update.effective_chat.id
|
chat_id = update.effective_chat.id
|
||||||
chat_messages, chat_token_length = self.openai.get_conversation_stats(chat_id)
|
chat_messages, chat_token_length = self.openai.get_conversation_stats(chat_id)
|
||||||
remaining_budget = self.get_remaining_budget(update)
|
remaining_budget = get_remaining_budget(self.config, self.usage, update)
|
||||||
bot_language = self.config['bot_language']
|
bot_language = self.config['bot_language']
|
||||||
text_current_conversation = (
|
text_current_conversation = (
|
||||||
f"*{localized_text('stats_conversation', bot_language)[0]}*:\n"
|
f"*{localized_text('stats_conversation', bot_language)[0]}*:\n"
|
||||||
@@ -148,7 +129,7 @@ class ChatGPTTelegramBot:
|
|||||||
f"${remaining_budget:.2f}.\n"
|
f"${remaining_budget:.2f}.\n"
|
||||||
)
|
)
|
||||||
# add OpenAI account information for admin request
|
# add OpenAI account information for admin request
|
||||||
if self.is_admin(user_id):
|
if is_admin(self.config, user_id):
|
||||||
text_budget += (
|
text_budget += (
|
||||||
f"{localized_text('stats_openai', bot_language)}"
|
f"{localized_text('stats_openai', bot_language)}"
|
||||||
f"{self.openai.get_billing_current_month():.2f}"
|
f"{self.openai.get_billing_current_month():.2f}"
|
||||||
@@ -161,7 +142,7 @@ class ChatGPTTelegramBot:
|
|||||||
"""
|
"""
|
||||||
Resend the last request
|
Resend the last request
|
||||||
"""
|
"""
|
||||||
if not await self.is_allowed(update, context):
|
if not await is_allowed(self.config, update, context):
|
||||||
logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id})'
|
logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id})'
|
||||||
f' is not allowed to resend the message')
|
f' is not allowed to resend the message')
|
||||||
await self.send_disallowed_message(update, context)
|
await self.send_disallowed_message(update, context)
|
||||||
@@ -172,7 +153,7 @@ class ChatGPTTelegramBot:
|
|||||||
logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id})'
|
logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id})'
|
||||||
f' does not have anything to resend')
|
f' does not have anything to resend')
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
text=localized_text('resend_failed', self.config['bot_language'])
|
text=localized_text('resend_failed', self.config['bot_language'])
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -189,7 +170,7 @@ class ChatGPTTelegramBot:
|
|||||||
"""
|
"""
|
||||||
Resets the conversation.
|
Resets the conversation.
|
||||||
"""
|
"""
|
||||||
if not await self.is_allowed(update, context):
|
if not await is_allowed(self.config, update, context):
|
||||||
logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id}) '
|
logging.warning(f'User {update.message.from_user.name} (id: {update.message.from_user.id}) '
|
||||||
f'is not allowed to reset the conversation')
|
f'is not allowed to reset the conversation')
|
||||||
await self.send_disallowed_message(update, context)
|
await self.send_disallowed_message(update, context)
|
||||||
@@ -202,7 +183,7 @@ class ChatGPTTelegramBot:
|
|||||||
reset_content = message_text(update.message)
|
reset_content = message_text(update.message)
|
||||||
self.openai.reset_chat_history(chat_id=chat_id, content=reset_content)
|
self.openai.reset_chat_history(chat_id=chat_id, content=reset_content)
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
text=localized_text('reset_done', self.config['bot_language'])
|
text=localized_text('reset_done', self.config['bot_language'])
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -210,15 +191,14 @@ class ChatGPTTelegramBot:
|
|||||||
"""
|
"""
|
||||||
Generates an image for the given prompt using DALL·E APIs
|
Generates an image for the given prompt using DALL·E APIs
|
||||||
"""
|
"""
|
||||||
if not self.config['enable_image_generation'] or not await self.check_allowed_and_within_budget(update,
|
if not self.config['enable_image_generation'] \
|
||||||
context):
|
or not await self.check_allowed_and_within_budget(update, context):
|
||||||
return
|
return
|
||||||
|
|
||||||
chat_id = update.effective_chat.id
|
|
||||||
image_query = message_text(update.message)
|
image_query = message_text(update.message)
|
||||||
if image_query == '':
|
if image_query == '':
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
text=localized_text('image_no_prompt', self.config['bot_language'])
|
text=localized_text('image_no_prompt', self.config['bot_language'])
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -230,7 +210,7 @@ class ChatGPTTelegramBot:
|
|||||||
try:
|
try:
|
||||||
image_url, image_size = await self.openai.generate_image(prompt=image_query)
|
image_url, image_size = await self.openai.generate_image(prompt=image_query)
|
||||||
await update.effective_message.reply_photo(
|
await update.effective_message.reply_photo(
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update),
|
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||||
photo=image_url
|
photo=image_url
|
||||||
)
|
)
|
||||||
# add image request to users usage tracker
|
# add image request to users usage tracker
|
||||||
@@ -243,13 +223,13 @@ class ChatGPTTelegramBot:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update),
|
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||||
text=f"{localized_text('image_fail', self.config['bot_language'])}: {str(e)}",
|
text=f"{localized_text('image_fail', self.config['bot_language'])}: {str(e)}",
|
||||||
parse_mode=constants.ParseMode.MARKDOWN
|
parse_mode=constants.ParseMode.MARKDOWN
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.wrap_with_indicator(update, context, _generate, constants.ChatAction.UPLOAD_PHOTO)
|
await wrap_with_indicator(update, context, _generate, constants.ChatAction.UPLOAD_PHOTO)
|
||||||
|
|
||||||
async def transcribe(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def transcribe(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
"""
|
"""
|
||||||
@@ -258,7 +238,7 @@ class ChatGPTTelegramBot:
|
|||||||
if not self.config['enable_transcription'] or not await self.check_allowed_and_within_budget(update, context):
|
if not self.config['enable_transcription'] or not await self.check_allowed_and_within_budget(update, context):
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.is_group_chat(update) and self.config['ignore_group_transcriptions']:
|
if is_group_chat(update) and self.config['ignore_group_transcriptions']:
|
||||||
logging.info(f'Transcription coming from group chat, ignoring...')
|
logging.info(f'Transcription coming from group chat, ignoring...')
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -274,8 +254,8 @@ class ChatGPTTelegramBot:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update),
|
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||||
text=(
|
text=(
|
||||||
f"{localized_text('media_download_fail', bot_language)[0]}: "
|
f"{localized_text('media_download_fail', bot_language)[0]}: "
|
||||||
f"{str(e)}. {localized_text('media_download_fail', bot_language)[1]}"
|
f"{str(e)}. {localized_text('media_download_fail', bot_language)[1]}"
|
||||||
@@ -284,7 +264,6 @@ class ChatGPTTelegramBot:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# detect and extract audio from the attachment with pydub
|
|
||||||
try:
|
try:
|
||||||
audio_track = AudioSegment.from_file(filename)
|
audio_track = AudioSegment.from_file(filename)
|
||||||
audio_track.export(filename_mp3, format="mp3")
|
audio_track.export(filename_mp3, format="mp3")
|
||||||
@@ -294,8 +273,8 @@ class ChatGPTTelegramBot:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update),
|
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||||
text=localized_text('media_type_fail', bot_language)
|
text=localized_text('media_type_fail', bot_language)
|
||||||
)
|
)
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
@@ -306,17 +285,12 @@ class ChatGPTTelegramBot:
|
|||||||
if user_id not in self.usage:
|
if user_id not in self.usage:
|
||||||
self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name)
|
self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name)
|
||||||
|
|
||||||
# send decoded audio to openai
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# Transcribe the audio file
|
|
||||||
transcript = await self.openai.transcribe(filename_mp3)
|
transcript = await self.openai.transcribe(filename_mp3)
|
||||||
|
|
||||||
# add transcription seconds to usage tracker
|
|
||||||
transcription_price = self.config['transcription_price']
|
transcription_price = self.config['transcription_price']
|
||||||
self.usage[user_id].add_transcription_seconds(audio_track.duration_seconds, transcription_price)
|
self.usage[user_id].add_transcription_seconds(audio_track.duration_seconds, transcription_price)
|
||||||
|
|
||||||
# add guest chat request to guest usage tracker
|
|
||||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
||||||
if str(user_id) not in allowed_user_ids and 'guests' in self.usage:
|
if str(user_id) not in allowed_user_ids and 'guests' in self.usage:
|
||||||
self.usage["guests"].add_transcription_seconds(audio_track.duration_seconds, transcription_price)
|
self.usage["guests"].add_transcription_seconds(audio_track.duration_seconds, transcription_price)
|
||||||
@@ -329,12 +303,12 @@ class ChatGPTTelegramBot:
|
|||||||
|
|
||||||
# Split into chunks of 4096 characters (Telegram's message limit)
|
# Split into chunks of 4096 characters (Telegram's message limit)
|
||||||
transcript_output = f"_{localized_text('transcript', bot_language)}:_\n\"{transcript}\""
|
transcript_output = f"_{localized_text('transcript', bot_language)}:_\n\"{transcript}\""
|
||||||
chunks = self.split_into_chunks(transcript_output)
|
chunks = split_into_chunks(transcript_output)
|
||||||
|
|
||||||
for index, transcript_chunk in enumerate(chunks):
|
for index, transcript_chunk in enumerate(chunks):
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None,
|
reply_to_message_id=get_reply_to_message_id(self.config, update) if index == 0 else None,
|
||||||
text=transcript_chunk,
|
text=transcript_chunk,
|
||||||
parse_mode=constants.ParseMode.MARKDOWN
|
parse_mode=constants.ParseMode.MARKDOWN
|
||||||
)
|
)
|
||||||
@@ -342,9 +316,7 @@ class ChatGPTTelegramBot:
|
|||||||
# Get the response of the transcript
|
# Get the response of the transcript
|
||||||
response, total_tokens = await self.openai.get_chat_response(chat_id=chat_id, query=transcript)
|
response, total_tokens = await self.openai.get_chat_response(chat_id=chat_id, query=transcript)
|
||||||
|
|
||||||
# add chat request to users usage tracker
|
|
||||||
self.usage[user_id].add_chat_tokens(total_tokens, self.config['token_price'])
|
self.usage[user_id].add_chat_tokens(total_tokens, self.config['token_price'])
|
||||||
# add guest chat request to guest usage tracker
|
|
||||||
if str(user_id) not in allowed_user_ids and 'guests' in self.usage:
|
if str(user_id) not in allowed_user_ids and 'guests' in self.usage:
|
||||||
self.usage["guests"].add_chat_tokens(total_tokens, self.config['token_price'])
|
self.usage["guests"].add_chat_tokens(total_tokens, self.config['token_price'])
|
||||||
|
|
||||||
@@ -353,12 +325,12 @@ class ChatGPTTelegramBot:
|
|||||||
f"_{localized_text('transcript', bot_language)}:_\n\"{transcript}\"\n\n"
|
f"_{localized_text('transcript', bot_language)}:_\n\"{transcript}\"\n\n"
|
||||||
f"_{localized_text('answer', bot_language)}:_\n{response}"
|
f"_{localized_text('answer', bot_language)}:_\n{response}"
|
||||||
)
|
)
|
||||||
chunks = self.split_into_chunks(transcript_output)
|
chunks = split_into_chunks(transcript_output)
|
||||||
|
|
||||||
for index, transcript_chunk in enumerate(chunks):
|
for index, transcript_chunk in enumerate(chunks):
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None,
|
reply_to_message_id=get_reply_to_message_id(self.config, update) if index == 0 else None,
|
||||||
text=transcript_chunk,
|
text=transcript_chunk,
|
||||||
parse_mode=constants.ParseMode.MARKDOWN
|
parse_mode=constants.ParseMode.MARKDOWN
|
||||||
)
|
)
|
||||||
@@ -366,19 +338,18 @@ class ChatGPTTelegramBot:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update),
|
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||||
text=f"{localized_text('transcribe_fail', bot_language)}: {str(e)}",
|
text=f"{localized_text('transcribe_fail', bot_language)}: {str(e)}",
|
||||||
parse_mode=constants.ParseMode.MARKDOWN
|
parse_mode=constants.ParseMode.MARKDOWN
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Cleanup files
|
|
||||||
if os.path.exists(filename_mp3):
|
if os.path.exists(filename_mp3):
|
||||||
os.remove(filename_mp3)
|
os.remove(filename_mp3)
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
os.remove(filename)
|
os.remove(filename)
|
||||||
|
|
||||||
await self.wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)
|
await wrap_with_indicator(update, context, _execute, constants.ChatAction.TYPING)
|
||||||
|
|
||||||
async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
async def prompt(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||||
"""
|
"""
|
||||||
@@ -397,7 +368,7 @@ class ChatGPTTelegramBot:
|
|||||||
prompt = message_text(update.message)
|
prompt = message_text(update.message)
|
||||||
self.last_message[chat_id] = prompt
|
self.last_message[chat_id] = prompt
|
||||||
|
|
||||||
if self.is_group_chat(update):
|
if is_group_chat(update):
|
||||||
trigger_keyword = self.config['group_trigger_keyword']
|
trigger_keyword = self.config['group_trigger_keyword']
|
||||||
if prompt.lower().startswith(trigger_keyword.lower()):
|
if prompt.lower().startswith(trigger_keyword.lower()):
|
||||||
prompt = prompt[len(trigger_keyword):].strip()
|
prompt = prompt[len(trigger_keyword):].strip()
|
||||||
@@ -405,10 +376,7 @@ class ChatGPTTelegramBot:
|
|||||||
if update.message.reply_to_message and \
|
if update.message.reply_to_message and \
|
||||||
update.message.reply_to_message.text and \
|
update.message.reply_to_message.text and \
|
||||||
update.message.reply_to_message.from_user.id != context.bot.id:
|
update.message.reply_to_message.from_user.id != context.bot.id:
|
||||||
prompt = '"{reply}" {prompt}'.format(
|
prompt = f'"{update.message.reply_to_message.text}" {prompt}'
|
||||||
reply=update.message.reply_to_message.text,
|
|
||||||
prompt=prompt
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if update.message.reply_to_message and update.message.reply_to_message.from_user.id == context.bot.id:
|
if update.message.reply_to_message and update.message.reply_to_message.from_user.id == context.bot.id:
|
||||||
logging.info('Message is a reply to the bot, allowing...')
|
logging.info('Message is a reply to the bot, allowing...')
|
||||||
@@ -422,7 +390,7 @@ class ChatGPTTelegramBot:
|
|||||||
if self.config['stream']:
|
if self.config['stream']:
|
||||||
await update.effective_message.reply_chat_action(
|
await update.effective_message.reply_chat_action(
|
||||||
action=constants.ChatAction.TYPING,
|
action=constants.ChatAction.TYPING,
|
||||||
message_thread_id=self.get_thread_id(update)
|
message_thread_id=get_thread_id(update)
|
||||||
)
|
)
|
||||||
|
|
||||||
stream_response = self.openai.get_chat_response_stream(chat_id=chat_id, query=prompt)
|
stream_response = self.openai.get_chat_response_stream(chat_id=chat_id, query=prompt)
|
||||||
@@ -436,26 +404,26 @@ class ChatGPTTelegramBot:
|
|||||||
if len(content.strip()) == 0:
|
if len(content.strip()) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
stream_chunks = self.split_into_chunks(content)
|
stream_chunks = split_into_chunks(content)
|
||||||
if len(stream_chunks) > 1:
|
if len(stream_chunks) > 1:
|
||||||
content = stream_chunks[-1]
|
content = stream_chunks[-1]
|
||||||
if stream_chunk != len(stream_chunks) - 1:
|
if stream_chunk != len(stream_chunks) - 1:
|
||||||
stream_chunk += 1
|
stream_chunk += 1
|
||||||
try:
|
try:
|
||||||
await self.edit_message_with_retry(context, chat_id, str(sent_message.message_id),
|
await edit_message_with_retry(context, chat_id, str(sent_message.message_id),
|
||||||
stream_chunks[-2])
|
stream_chunks[-2])
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
sent_message = await update.effective_message.reply_text(
|
sent_message = await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
text=content if len(content) > 0 else "..."
|
text=content if len(content) > 0 else "..."
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cutoff = self.get_stream_cutoff_values(update, content)
|
cutoff = get_stream_cutoff_values(update, content)
|
||||||
cutoff += backoff
|
cutoff += backoff
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -464,8 +432,8 @@ class ChatGPTTelegramBot:
|
|||||||
await context.bot.delete_message(chat_id=sent_message.chat_id,
|
await context.bot.delete_message(chat_id=sent_message.chat_id,
|
||||||
message_id=sent_message.message_id)
|
message_id=sent_message.message_id)
|
||||||
sent_message = await update.effective_message.reply_text(
|
sent_message = await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update),
|
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||||
text=content
|
text=content
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
@@ -476,7 +444,7 @@ class ChatGPTTelegramBot:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
use_markdown = tokens != 'not_finished'
|
use_markdown = tokens != 'not_finished'
|
||||||
await self.edit_message_with_retry(context, chat_id, str(sent_message.message_id),
|
await edit_message_with_retry(context, chat_id, str(sent_message.message_id),
|
||||||
text=content, markdown=use_markdown)
|
text=content, markdown=use_markdown)
|
||||||
|
|
||||||
except RetryAfter as e:
|
except RetryAfter as e:
|
||||||
@@ -505,35 +473,37 @@ class ChatGPTTelegramBot:
|
|||||||
response, total_tokens = await self.openai.get_chat_response(chat_id=chat_id, query=prompt)
|
response, total_tokens = await self.openai.get_chat_response(chat_id=chat_id, query=prompt)
|
||||||
|
|
||||||
# Split into chunks of 4096 characters (Telegram's message limit)
|
# Split into chunks of 4096 characters (Telegram's message limit)
|
||||||
chunks = self.split_into_chunks(response)
|
chunks = split_into_chunks(response)
|
||||||
|
|
||||||
for index, chunk in enumerate(chunks):
|
for index, chunk in enumerate(chunks):
|
||||||
try:
|
try:
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None,
|
reply_to_message_id=get_reply_to_message_id(self.config,
|
||||||
|
update) if index == 0 else None,
|
||||||
text=chunk,
|
text=chunk,
|
||||||
parse_mode=constants.ParseMode.MARKDOWN
|
parse_mode=constants.ParseMode.MARKDOWN
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update) if index == 0 else None,
|
reply_to_message_id=get_reply_to_message_id(self.config,
|
||||||
|
update) if index == 0 else None,
|
||||||
text=chunk
|
text=chunk
|
||||||
)
|
)
|
||||||
except Exception as exception:
|
except Exception as exception:
|
||||||
raise exception
|
raise exception
|
||||||
|
|
||||||
await self.wrap_with_indicator(update, context, _reply, constants.ChatAction.TYPING)
|
await wrap_with_indicator(update, context, _reply, constants.ChatAction.TYPING)
|
||||||
|
|
||||||
self.add_chat_request_to_usage_tracker(user_id, total_tokens)
|
add_chat_request_to_usage_tracker(self.usage, self.config, user_id, total_tokens)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
await update.effective_message.reply_text(
|
await update.effective_message.reply_text(
|
||||||
message_thread_id=self.get_thread_id(update),
|
message_thread_id=get_thread_id(update),
|
||||||
reply_to_message_id=self.get_reply_to_message_id(update),
|
reply_to_message_id=get_reply_to_message_id(self.config, update),
|
||||||
text=f"{localized_text('chat_fail', self.config['bot_language'])} {str(e)}",
|
text=f"{localized_text('chat_fail', self.config['bot_language'])} {str(e)}",
|
||||||
parse_mode=constants.ParseMode.MARKDOWN
|
parse_mode=constants.ParseMode.MARKDOWN
|
||||||
)
|
)
|
||||||
@@ -556,6 +526,9 @@ class ChatGPTTelegramBot:
|
|||||||
await self.send_inline_query_result(update, result_id, message_content=query, callback_data=callback_data)
|
await self.send_inline_query_result(update, result_id, message_content=query, callback_data=callback_data)
|
||||||
|
|
||||||
async def send_inline_query_result(self, update: Update, result_id, message_content, callback_data=""):
|
async def send_inline_query_result(self, update: Update, result_id, message_content, callback_data=""):
|
||||||
|
"""
|
||||||
|
Send inline query result
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
reply_markup = None
|
reply_markup = None
|
||||||
bot_language = self.config['bot_language']
|
bot_language = self.config['bot_language']
|
||||||
@@ -580,6 +553,9 @@ class ChatGPTTelegramBot:
|
|||||||
logging.error(f'An error occurred while generating the result card for inline query {e}')
|
logging.error(f'An error occurred while generating the result card for inline query {e}')
|
||||||
|
|
||||||
async def handle_callback_inline_query(self, update: Update, context: CallbackContext):
|
async def handle_callback_inline_query(self, update: Update, context: CallbackContext):
|
||||||
|
"""
|
||||||
|
Handle the callback query from the inline query result
|
||||||
|
"""
|
||||||
callback_data = update.callback_query.data
|
callback_data = update.callback_query.data
|
||||||
user_id = update.callback_query.from_user.id
|
user_id = update.callback_query.from_user.id
|
||||||
inline_message_id = update.callback_query.inline_message_id
|
inline_message_id = update.callback_query.inline_message_id
|
||||||
@@ -604,7 +580,7 @@ class ChatGPTTelegramBot:
|
|||||||
f'{localized_text("error", bot_language)}. '
|
f'{localized_text("error", bot_language)}. '
|
||||||
f'{localized_text("try_again", bot_language)}'
|
f'{localized_text("try_again", bot_language)}'
|
||||||
)
|
)
|
||||||
await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
|
await edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
|
||||||
text=f'{query}\n\n_{answer_tr}:_\n{error_message}',
|
text=f'{query}\n\n_{answer_tr}:_\n{error_message}',
|
||||||
is_inline=True)
|
is_inline=True)
|
||||||
return
|
return
|
||||||
@@ -619,13 +595,13 @@ class ChatGPTTelegramBot:
|
|||||||
if len(content.strip()) == 0:
|
if len(content.strip()) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cutoff = self.get_stream_cutoff_values(update, content)
|
cutoff = get_stream_cutoff_values(update, content)
|
||||||
cutoff += backoff
|
cutoff += backoff
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
try:
|
try:
|
||||||
if sent_message is not None:
|
if sent_message is not None:
|
||||||
await self.edit_message_with_retry(context, chat_id=None,
|
await edit_message_with_retry(context, chat_id=None,
|
||||||
message_id=inline_message_id,
|
message_id=inline_message_id,
|
||||||
text=f'{query}\n\n{answer_tr}:\n{content}',
|
text=f'{query}\n\n{answer_tr}:\n{content}',
|
||||||
is_inline=True)
|
is_inline=True)
|
||||||
@@ -642,7 +618,7 @@ class ChatGPTTelegramBot:
|
|||||||
# We only want to send the first 4096 characters. No chunking allowed in inline mode.
|
# We only want to send the first 4096 characters. No chunking allowed in inline mode.
|
||||||
text = text[:4096]
|
text = text[:4096]
|
||||||
|
|
||||||
await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
|
await edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
|
||||||
text=text, markdown=use_markdown, is_inline=True)
|
text=text, markdown=use_markdown, is_inline=True)
|
||||||
|
|
||||||
except RetryAfter as e:
|
except RetryAfter as e:
|
||||||
@@ -680,271 +656,22 @@ class ChatGPTTelegramBot:
|
|||||||
text_content = text_content[:4096]
|
text_content = text_content[:4096]
|
||||||
|
|
||||||
# Edit the original message with the generated content
|
# Edit the original message with the generated content
|
||||||
await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
|
await edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
|
||||||
text=text_content, is_inline=True)
|
text=text_content, is_inline=True)
|
||||||
|
|
||||||
await self.wrap_with_indicator(update, context, _send_inline_query_response,
|
await wrap_with_indicator(update, context, _send_inline_query_response,
|
||||||
constants.ChatAction.TYPING, is_inline=True)
|
constants.ChatAction.TYPING, is_inline=True)
|
||||||
|
|
||||||
self.add_chat_request_to_usage_tracker(user_id, total_tokens)
|
add_chat_request_to_usage_tracker(self.usage, self.config, user_id, total_tokens)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f'Failed to respond to an inline query via button callback: {e}')
|
logging.error(f'Failed to respond to an inline query via button callback: {e}')
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
localized_answer = localized_text('chat_fail', self.config['bot_language'])
|
localized_answer = localized_text('chat_fail', self.config['bot_language'])
|
||||||
await self.edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
|
await edit_message_with_retry(context, chat_id=None, message_id=inline_message_id,
|
||||||
text=f"{query}\n\n_{answer_tr}:_\n{localized_answer} {str(e)}",
|
text=f"{query}\n\n_{answer_tr}:_\n{localized_answer} {str(e)}",
|
||||||
is_inline=True)
|
is_inline=True)
|
||||||
|
|
||||||
async def edit_message_with_retry(self, context: ContextTypes.DEFAULT_TYPE, chat_id: int | None,
|
|
||||||
message_id: str, text: str, markdown: bool = True, is_inline: bool = False):
|
|
||||||
"""
|
|
||||||
Edit a message with retry logic in case of failure (e.g. broken markdown)
|
|
||||||
:param context: The context to use
|
|
||||||
:param chat_id: The chat id to edit the message in
|
|
||||||
:param message_id: The message id to edit
|
|
||||||
:param text: The text to edit the message with
|
|
||||||
:param markdown: Whether to use markdown parse mode
|
|
||||||
:param is_inline: Whether the message to edit is an inline message
|
|
||||||
:return: None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await context.bot.edit_message_text(
|
|
||||||
chat_id=chat_id,
|
|
||||||
message_id=int(message_id) if not is_inline else None,
|
|
||||||
inline_message_id=message_id if is_inline else None,
|
|
||||||
text=text,
|
|
||||||
parse_mode=constants.ParseMode.MARKDOWN if markdown else None
|
|
||||||
)
|
|
||||||
except telegram.error.BadRequest as e:
|
|
||||||
if str(e).startswith("Message is not modified"):
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
await context.bot.edit_message_text(
|
|
||||||
chat_id=chat_id,
|
|
||||||
message_id=int(message_id) if not is_inline else None,
|
|
||||||
inline_message_id=message_id if is_inline else None,
|
|
||||||
text=text
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f'Failed to edit message: {str(e)}')
|
|
||||||
raise e
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(str(e))
|
|
||||||
raise e
|
|
||||||
|
|
||||||
async def wrap_with_indicator(self, update: Update, context: CallbackContext, coroutine,
|
|
||||||
chat_action: constants.ChatAction = "", is_inline=False):
|
|
||||||
"""
|
|
||||||
Wraps a coroutine while repeatedly sending a chat action to the user.
|
|
||||||
"""
|
|
||||||
task = context.application.create_task(coroutine(), update=update)
|
|
||||||
while not task.done():
|
|
||||||
if not is_inline:
|
|
||||||
context.application.create_task(
|
|
||||||
update.effective_chat.send_action(chat_action, message_thread_id=self.get_thread_id(update))
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(asyncio.shield(task), 4.5)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def send_disallowed_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE, is_inline=False):
|
|
||||||
"""
|
|
||||||
Sends the disallowed message to the user.
|
|
||||||
"""
|
|
||||||
if not is_inline:
|
|
||||||
await update.effective_message.reply_text(
|
|
||||||
message_thread_id=self.get_thread_id(update),
|
|
||||||
text=self.disallowed_message,
|
|
||||||
disable_web_page_preview=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
result_id = str(uuid4())
|
|
||||||
await self.send_inline_query_result(update, result_id, message_content=self.disallowed_message)
|
|
||||||
|
|
||||||
async def send_budget_reached_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE, is_inline=False):
|
|
||||||
"""
|
|
||||||
Sends the budget reached message to the user.
|
|
||||||
"""
|
|
||||||
if not is_inline:
|
|
||||||
await update.effective_message.reply_text(
|
|
||||||
message_thread_id=self.get_thread_id(update),
|
|
||||||
text=self.budget_limit_message
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
result_id = str(uuid4())
|
|
||||||
await self.send_inline_query_result(update, result_id, message_content=self.budget_limit_message)
|
|
||||||
|
|
||||||
async def error_handler(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
|
||||||
"""
|
|
||||||
Handles errors in the telegram-python-bot library.
|
|
||||||
"""
|
|
||||||
logging.error(f'Exception while handling an update: {context.error}')
|
|
||||||
|
|
||||||
def get_thread_id(self, update: Update) -> int | None:
|
|
||||||
"""
|
|
||||||
Gets the message thread id for the update, if any
|
|
||||||
"""
|
|
||||||
if update.effective_message and update.effective_message.is_topic_message:
|
|
||||||
return update.effective_message.message_thread_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_stream_cutoff_values(self, update: Update, content: str) -> int:
|
|
||||||
"""
|
|
||||||
Gets the stream cutoff values for the message length
|
|
||||||
"""
|
|
||||||
if self.is_group_chat(update):
|
|
||||||
# group chats have stricter flood limits
|
|
||||||
return 180 if len(content) > 1000 else 120 if len(content) > 200 else 90 if len(
|
|
||||||
content) > 50 else 50
|
|
||||||
else:
|
|
||||||
return 90 if len(content) > 1000 else 45 if len(content) > 200 else 25 if len(
|
|
||||||
content) > 50 else 15
|
|
||||||
|
|
||||||
def is_group_chat(self, update: Update) -> bool:
|
|
||||||
"""
|
|
||||||
Checks if the message was sent from a group chat
|
|
||||||
"""
|
|
||||||
if not update.effective_chat:
|
|
||||||
return False
|
|
||||||
return update.effective_chat.type in [
|
|
||||||
constants.ChatType.GROUP,
|
|
||||||
constants.ChatType.SUPERGROUP
|
|
||||||
]
|
|
||||||
|
|
||||||
async def is_user_in_group(self, update: Update, context: CallbackContext, user_id: int) -> bool:
|
|
||||||
"""
|
|
||||||
Checks if user_id is a member of the group
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
chat_member = await context.bot.get_chat_member(update.message.chat_id, user_id)
|
|
||||||
return chat_member.status in [ChatMember.OWNER, ChatMember.ADMINISTRATOR, ChatMember.MEMBER]
|
|
||||||
except telegram.error.BadRequest as e:
|
|
||||||
if str(e) == "User not found":
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
async def is_allowed(self, update: Update, context: CallbackContext, is_inline=False) -> bool:
|
|
||||||
"""
|
|
||||||
Checks if the user is allowed to use the bot.
|
|
||||||
"""
|
|
||||||
if self.config['allowed_user_ids'] == '*':
|
|
||||||
return True
|
|
||||||
|
|
||||||
user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id
|
|
||||||
if self.is_admin(user_id):
|
|
||||||
return True
|
|
||||||
name = update.inline_query.from_user.name if is_inline else update.message.from_user.name
|
|
||||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
|
||||||
# Check if user is allowed
|
|
||||||
if str(user_id) in allowed_user_ids:
|
|
||||||
return True
|
|
||||||
# Check if it's a group a chat with at least one authorized member
|
|
||||||
if not is_inline and self.is_group_chat(update):
|
|
||||||
admin_user_ids = self.config['admin_user_ids'].split(',')
|
|
||||||
for user in itertools.chain(allowed_user_ids, admin_user_ids):
|
|
||||||
if not user.strip():
|
|
||||||
continue
|
|
||||||
if await self.is_user_in_group(update, context, user):
|
|
||||||
logging.info(f'{user} is a member. Allowing group chat message...')
|
|
||||||
return True
|
|
||||||
logging.info(f'Group chat messages from user {name} '
|
|
||||||
f'(id: {user_id}) are not allowed')
|
|
||||||
return False
|
|
||||||
|
|
||||||
def is_admin(self, user_id: int, log_no_admin=False) -> bool:
|
|
||||||
"""
|
|
||||||
Checks if the user is the admin of the bot.
|
|
||||||
The first user in the user list is the admin.
|
|
||||||
"""
|
|
||||||
if self.config['admin_user_ids'] == '-':
|
|
||||||
if log_no_admin:
|
|
||||||
logging.info('No admin user defined.')
|
|
||||||
return False
|
|
||||||
|
|
||||||
admin_user_ids = self.config['admin_user_ids'].split(',')
|
|
||||||
|
|
||||||
# Check if user is in the admin user list
|
|
||||||
if str(user_id) in admin_user_ids:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_user_budget(self, user_id) -> float | None:
|
|
||||||
"""
|
|
||||||
Get the user's budget based on their user ID and the bot configuration.
|
|
||||||
:param user_id: User id
|
|
||||||
:return: The user's budget as a float, or None if the user is not found in the allowed user list
|
|
||||||
"""
|
|
||||||
|
|
||||||
# no budget restrictions for admins and '*'-budget lists
|
|
||||||
if self.is_admin(user_id) or self.config['user_budgets'] == '*':
|
|
||||||
return float('inf')
|
|
||||||
|
|
||||||
user_budgets = self.config['user_budgets'].split(',')
|
|
||||||
if self.config['allowed_user_ids'] == '*':
|
|
||||||
# same budget for all users, use value in first position of budget list
|
|
||||||
if len(user_budgets) > 1:
|
|
||||||
logging.warning('multiple values for budgets set with unrestricted user list '
|
|
||||||
'only the first value is used as budget for everyone.')
|
|
||||||
return float(user_budgets[0])
|
|
||||||
|
|
||||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
|
||||||
if str(user_id) in allowed_user_ids:
|
|
||||||
user_index = allowed_user_ids.index(str(user_id))
|
|
||||||
if len(user_budgets) <= user_index:
|
|
||||||
logging.warning(f'No budget set for user id: {user_id}. Budget list shorter than user list.')
|
|
||||||
return 0.0
|
|
||||||
return float(user_budgets[user_index])
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_remaining_budget(self, update: Update, is_inline=False) -> float:
|
|
||||||
"""
|
|
||||||
Calculate the remaining budget for a user based on their current usage.
|
|
||||||
:param update: Telegram update object
|
|
||||||
:param is_inline: Boolean flag for inline queries
|
|
||||||
:return: The remaining budget for the user as a float
|
|
||||||
"""
|
|
||||||
user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id
|
|
||||||
name = update.inline_query.from_user.name if is_inline else update.message.from_user.name
|
|
||||||
if user_id not in self.usage:
|
|
||||||
self.usage[user_id] = UsageTracker(user_id, name)
|
|
||||||
|
|
||||||
# Get budget for users
|
|
||||||
user_budget = self.get_user_budget(user_id)
|
|
||||||
budget_period = self.config['budget_period']
|
|
||||||
if user_budget is not None:
|
|
||||||
cost = self.usage[user_id].get_current_cost()[self.budget_cost_map[budget_period]]
|
|
||||||
return user_budget - cost
|
|
||||||
|
|
||||||
# Get budget for guests
|
|
||||||
if 'guests' not in self.usage:
|
|
||||||
self.usage['guests'] = UsageTracker('guests', 'all guest users in group chats')
|
|
||||||
cost = self.usage['guests'].get_current_cost()[self.budget_cost_map[budget_period]]
|
|
||||||
return self.config['guest_budget'] - cost
|
|
||||||
|
|
||||||
def is_within_budget(self, update: Update, is_inline=False) -> bool:
|
|
||||||
"""
|
|
||||||
Checks if the user reached their usage limit.
|
|
||||||
Initializes UsageTracker for user and guest when needed.
|
|
||||||
:param update: Telegram update object
|
|
||||||
:param is_inline: Boolean flag for inline queries
|
|
||||||
:return: Boolean indicating if the user has a positive budget
|
|
||||||
"""
|
|
||||||
user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id
|
|
||||||
name = update.inline_query.from_user.name if is_inline else update.message.from_user.name
|
|
||||||
if user_id not in self.usage:
|
|
||||||
self.usage[user_id] = UsageTracker(user_id, name)
|
|
||||||
|
|
||||||
remaining_budget = self.get_remaining_budget(update, is_inline=is_inline)
|
|
||||||
|
|
||||||
return remaining_budget > 0
|
|
||||||
|
|
||||||
async def check_allowed_and_within_budget(self, update: Update, context: ContextTypes.DEFAULT_TYPE,
|
async def check_allowed_and_within_budget(self, update: Update, context: ContextTypes.DEFAULT_TYPE,
|
||||||
is_inline=False) -> bool:
|
is_inline=False) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -957,46 +684,43 @@ class ChatGPTTelegramBot:
|
|||||||
name = update.inline_query.from_user.name if is_inline else update.message.from_user.name
|
name = update.inline_query.from_user.name if is_inline else update.message.from_user.name
|
||||||
user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id
|
user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id
|
||||||
|
|
||||||
if not await self.is_allowed(update, context, is_inline=is_inline):
|
if not await is_allowed(self.config, update, context, is_inline=is_inline):
|
||||||
logging.warning(f'User {name} (id: {user_id}) '
|
logging.warning(f'User {name} (id: {user_id}) is not allowed to use the bot')
|
||||||
f'is not allowed to use the bot')
|
|
||||||
await self.send_disallowed_message(update, context, is_inline)
|
await self.send_disallowed_message(update, context, is_inline)
|
||||||
return False
|
return False
|
||||||
if not self.is_within_budget(update, is_inline=is_inline):
|
if not is_within_budget(self.config, self.usage, update, is_inline=is_inline):
|
||||||
logging.warning(f'User {name} (id: {user_id}) '
|
logging.warning(f'User {name} (id: {user_id}) reached their usage limit')
|
||||||
f'reached their usage limit')
|
|
||||||
await self.send_budget_reached_message(update, context, is_inline)
|
await self.send_budget_reached_message(update, context, is_inline)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_chat_request_to_usage_tracker(self, user_id, used_tokens):
|
async def send_disallowed_message(self, update: Update, _: ContextTypes.DEFAULT_TYPE, is_inline=False):
|
||||||
try:
|
"""
|
||||||
# add chat request to users usage tracker
|
Sends the disallowed message to the user.
|
||||||
self.usage[user_id].add_chat_tokens(used_tokens, self.config['token_price'])
|
"""
|
||||||
# add guest chat request to guest usage tracker
|
if not is_inline:
|
||||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
await update.effective_message.reply_text(
|
||||||
if str(user_id) not in allowed_user_ids and 'guests' in self.usage:
|
message_thread_id=get_thread_id(update),
|
||||||
self.usage["guests"].add_chat_tokens(used_tokens, self.config['token_price'])
|
text=self.disallowed_message,
|
||||||
except Exception as e:
|
disable_web_page_preview=True
|
||||||
logging.warning(f'Failed to add tokens to usage_logs: {str(e)}')
|
)
|
||||||
pass
|
else:
|
||||||
|
result_id = str(uuid4())
|
||||||
|
await self.send_inline_query_result(update, result_id, message_content=self.disallowed_message)
|
||||||
|
|
||||||
def get_reply_to_message_id(self, update: Update):
|
async def send_budget_reached_message(self, update: Update, _: ContextTypes.DEFAULT_TYPE, is_inline=False):
|
||||||
"""
|
"""
|
||||||
Returns the message id of the message to reply to
|
Sends the budget reached message to the user.
|
||||||
:param update: Telegram update object
|
|
||||||
:return: Message id of the message to reply to, or None if quoting is disabled
|
|
||||||
"""
|
"""
|
||||||
if self.config['enable_quoting'] or self.is_group_chat(update):
|
if not is_inline:
|
||||||
return update.message.message_id
|
await update.effective_message.reply_text(
|
||||||
return None
|
message_thread_id=get_thread_id(update),
|
||||||
|
text=self.budget_limit_message
|
||||||
def split_into_chunks(self, text: str, chunk_size: int = 4096) -> list[str]:
|
)
|
||||||
"""
|
else:
|
||||||
Splits a string into chunks of a given size.
|
result_id = str(uuid4())
|
||||||
"""
|
await self.send_inline_query_result(update, result_id, message_content=self.budget_limit_message)
|
||||||
return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
|
|
||||||
|
|
||||||
async def post_init(self, application: Application) -> None:
|
async def post_init(self, application: Application) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -1036,6 +760,6 @@ class ChatGPTTelegramBot:
|
|||||||
]))
|
]))
|
||||||
application.add_handler(CallbackQueryHandler(self.handle_callback_inline_query))
|
application.add_handler(CallbackQueryHandler(self.handle_callback_inline_query))
|
||||||
|
|
||||||
application.add_error_handler(self.error_handler)
|
application.add_error_handler(error_handler)
|
||||||
|
|
||||||
application.run_polling()
|
application.run_polling()
|
||||||
|
|||||||
@@ -175,6 +175,9 @@ class UsageTracker:
|
|||||||
json.dump(self.usage, outfile)
|
json.dump(self.usage, outfile)
|
||||||
|
|
||||||
def add_current_costs(self, request_cost):
|
def add_current_costs(self, request_cost):
|
||||||
|
"""
|
||||||
|
Add current cost to all_time, day and month cost and update last_update date.
|
||||||
|
"""
|
||||||
today = date.today()
|
today = date.today()
|
||||||
last_update = date.fromisoformat(self.usage["current_cost"]["last_update"])
|
last_update = date.fromisoformat(self.usage["current_cost"]["last_update"])
|
||||||
|
|
||||||
|
|||||||
307
bot/utils.py
Normal file
307
bot/utils.py
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import itertools
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import telegram
|
||||||
|
from telegram import Message, MessageEntity, Update, ChatMember, constants
|
||||||
|
from telegram.ext import CallbackContext, ContextTypes
|
||||||
|
|
||||||
|
from usage_tracker import UsageTracker
|
||||||
|
|
||||||
|
|
||||||
|
def message_text(message: Message) -> str:
|
||||||
|
"""
|
||||||
|
Returns the text of a message, excluding any bot commands.
|
||||||
|
"""
|
||||||
|
message_txt = message.text
|
||||||
|
if message_txt is None:
|
||||||
|
return ''
|
||||||
|
|
||||||
|
for _, text in sorted(message.parse_entities([MessageEntity.BOT_COMMAND]).items(),
|
||||||
|
key=(lambda item: item[0].offset)):
|
||||||
|
message_txt = message_txt.replace(text, '').strip()
|
||||||
|
|
||||||
|
return message_txt if len(message_txt) > 0 else ''
|
||||||
|
|
||||||
|
|
||||||
|
async def is_user_in_group(update: Update, context: CallbackContext, user_id: int) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if user_id is a member of the group
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
chat_member = await context.bot.get_chat_member(update.message.chat_id, user_id)
|
||||||
|
return chat_member.status in [ChatMember.OWNER, ChatMember.ADMINISTRATOR, ChatMember.MEMBER]
|
||||||
|
except telegram.error.BadRequest as e:
|
||||||
|
if str(e) == "User not found":
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_id(update: Update) -> int | None:
|
||||||
|
"""
|
||||||
|
Gets the message thread id for the update, if any
|
||||||
|
"""
|
||||||
|
if update.effective_message and update.effective_message.is_topic_message:
|
||||||
|
return update.effective_message.message_thread_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_stream_cutoff_values(update: Update, content: str) -> int:
|
||||||
|
"""
|
||||||
|
Gets the stream cutoff values for the message length
|
||||||
|
"""
|
||||||
|
if is_group_chat(update):
|
||||||
|
# group chats have stricter flood limits
|
||||||
|
return 180 if len(content) > 1000 else 120 if len(content) > 200 \
|
||||||
|
else 90 if len(content) > 50 else 50
|
||||||
|
return 90 if len(content) > 1000 else 45 if len(content) > 200 \
|
||||||
|
else 25 if len(content) > 50 else 15
|
||||||
|
|
||||||
|
|
||||||
|
def is_group_chat(update: Update) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if the message was sent from a group chat
|
||||||
|
"""
|
||||||
|
if not update.effective_chat:
|
||||||
|
return False
|
||||||
|
return update.effective_chat.type in [
|
||||||
|
constants.ChatType.GROUP,
|
||||||
|
constants.ChatType.SUPERGROUP
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def split_into_chunks(text: str, chunk_size: int = 4096) -> list[str]:
|
||||||
|
"""
|
||||||
|
Splits a string into chunks of a given size.
|
||||||
|
"""
|
||||||
|
return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
|
||||||
|
|
||||||
|
|
||||||
|
async def wrap_with_indicator(update: Update, context: CallbackContext, coroutine,
|
||||||
|
chat_action: constants.ChatAction = "", is_inline=False):
|
||||||
|
"""
|
||||||
|
Wraps a coroutine while repeatedly sending a chat action to the user.
|
||||||
|
"""
|
||||||
|
task = context.application.create_task(coroutine(), update=update)
|
||||||
|
while not task.done():
|
||||||
|
if not is_inline:
|
||||||
|
context.application.create_task(
|
||||||
|
update.effective_chat.send_action(chat_action, message_thread_id=get_thread_id(update))
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.shield(task), 4.5)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def edit_message_with_retry(context: ContextTypes.DEFAULT_TYPE, chat_id: int | None,
|
||||||
|
message_id: str, text: str, markdown: bool = True, is_inline: bool = False):
|
||||||
|
"""
|
||||||
|
Edit a message with retry logic in case of failure (e.g. broken markdown)
|
||||||
|
:param context: The context to use
|
||||||
|
:param chat_id: The chat id to edit the message in
|
||||||
|
:param message_id: The message id to edit
|
||||||
|
:param text: The text to edit the message with
|
||||||
|
:param markdown: Whether to use markdown parse mode
|
||||||
|
:param is_inline: Whether the message to edit is an inline message
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await context.bot.edit_message_text(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=int(message_id) if not is_inline else None,
|
||||||
|
inline_message_id=message_id if is_inline else None,
|
||||||
|
text=text,
|
||||||
|
parse_mode=constants.ParseMode.MARKDOWN if markdown else None
|
||||||
|
)
|
||||||
|
except telegram.error.BadRequest as e:
|
||||||
|
if str(e).startswith("Message is not modified"):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await context.bot.edit_message_text(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=int(message_id) if not is_inline else None,
|
||||||
|
inline_message_id=message_id if is_inline else None,
|
||||||
|
text=text
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f'Failed to edit message: {str(e)}')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(str(e))
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
async def error_handler(_: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||||
|
"""
|
||||||
|
Handles errors in the telegram-python-bot library.
|
||||||
|
"""
|
||||||
|
logging.error(f'Exception while handling an update: {context.error}')
|
||||||
|
|
||||||
|
|
||||||
|
async def is_allowed(config, update: Update, context: CallbackContext, is_inline=False) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if the user is allowed to use the bot.
|
||||||
|
"""
|
||||||
|
if config['allowed_user_ids'] == '*':
|
||||||
|
return True
|
||||||
|
|
||||||
|
user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id
|
||||||
|
if is_admin(config, user_id):
|
||||||
|
return True
|
||||||
|
name = update.inline_query.from_user.name if is_inline else update.message.from_user.name
|
||||||
|
allowed_user_ids = config['allowed_user_ids'].split(',')
|
||||||
|
# Check if user is allowed
|
||||||
|
if str(user_id) in allowed_user_ids:
|
||||||
|
return True
|
||||||
|
# Check if it's a group a chat with at least one authorized member
|
||||||
|
if not is_inline and is_group_chat(update):
|
||||||
|
admin_user_ids = config['admin_user_ids'].split(',')
|
||||||
|
for user in itertools.chain(allowed_user_ids, admin_user_ids):
|
||||||
|
if not user.strip():
|
||||||
|
continue
|
||||||
|
if await is_user_in_group(update, context, user):
|
||||||
|
logging.info(f'{user} is a member. Allowing group chat message...')
|
||||||
|
return True
|
||||||
|
logging.info(f'Group chat messages from user {name} '
|
||||||
|
f'(id: {user_id}) are not allowed')
|
||||||
|
return False
|
||||||
|
|
||||||
|
def is_admin(config, user_id: int, log_no_admin=False) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if the user is the admin of the bot.
|
||||||
|
The first user in the user list is the admin.
|
||||||
|
"""
|
||||||
|
if config['admin_user_ids'] == '-':
|
||||||
|
if log_no_admin:
|
||||||
|
logging.info('No admin user defined.')
|
||||||
|
return False
|
||||||
|
|
||||||
|
admin_user_ids = config['admin_user_ids'].split(',')
|
||||||
|
|
||||||
|
# Check if user is in the admin user list
|
||||||
|
if str(user_id) in admin_user_ids:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_budget(config, user_id) -> float | None:
|
||||||
|
"""
|
||||||
|
Get the user's budget based on their user ID and the bot configuration.
|
||||||
|
:param config: The bot configuration object
|
||||||
|
:param user_id: User id
|
||||||
|
:return: The user's budget as a float, or None if the user is not found in the allowed user list
|
||||||
|
"""
|
||||||
|
|
||||||
|
# no budget restrictions for admins and '*'-budget lists
|
||||||
|
if is_admin(config, user_id) or config['user_budgets'] == '*':
|
||||||
|
return float('inf')
|
||||||
|
|
||||||
|
user_budgets = config['user_budgets'].split(',')
|
||||||
|
if config['allowed_user_ids'] == '*':
|
||||||
|
# same budget for all users, use value in first position of budget list
|
||||||
|
if len(user_budgets) > 1:
|
||||||
|
logging.warning('multiple values for budgets set with unrestricted user list '
|
||||||
|
'only the first value is used as budget for everyone.')
|
||||||
|
return float(user_budgets[0])
|
||||||
|
|
||||||
|
allowed_user_ids = config['allowed_user_ids'].split(',')
|
||||||
|
if str(user_id) in allowed_user_ids:
|
||||||
|
user_index = allowed_user_ids.index(str(user_id))
|
||||||
|
if len(user_budgets) <= user_index:
|
||||||
|
logging.warning(f'No budget set for user id: {user_id}. Budget list shorter than user list.')
|
||||||
|
return 0.0
|
||||||
|
return float(user_budgets[user_index])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_remaining_budget(config, usage, update: Update, is_inline=False) -> float:
|
||||||
|
"""
|
||||||
|
Calculate the remaining budget for a user based on their current usage.
|
||||||
|
:param config: The bot configuration object
|
||||||
|
:param usage: The usage tracker object
|
||||||
|
:param update: Telegram update object
|
||||||
|
:param is_inline: Boolean flag for inline queries
|
||||||
|
:return: The remaining budget for the user as a float
|
||||||
|
"""
|
||||||
|
# Mapping of budget period to cost period
|
||||||
|
budget_cost_map = {
|
||||||
|
"monthly": "cost_month",
|
||||||
|
"daily": "cost_today",
|
||||||
|
"all-time": "cost_all_time"
|
||||||
|
}
|
||||||
|
|
||||||
|
user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id
|
||||||
|
name = update.inline_query.from_user.name if is_inline else update.message.from_user.name
|
||||||
|
if user_id not in usage:
|
||||||
|
usage[user_id] = UsageTracker(user_id, name)
|
||||||
|
|
||||||
|
# Get budget for users
|
||||||
|
user_budget = get_user_budget(config, user_id)
|
||||||
|
budget_period = config['budget_period']
|
||||||
|
if user_budget is not None:
|
||||||
|
cost = usage[user_id].get_current_cost()[budget_cost_map[budget_period]]
|
||||||
|
return user_budget - cost
|
||||||
|
|
||||||
|
# Get budget for guests
|
||||||
|
if 'guests' not in usage:
|
||||||
|
usage['guests'] = UsageTracker('guests', 'all guest users in group chats')
|
||||||
|
cost = usage['guests'].get_current_cost()[budget_cost_map[budget_period]]
|
||||||
|
return config['guest_budget'] - cost
|
||||||
|
|
||||||
|
|
||||||
|
def is_within_budget(config, usage, update: Update, is_inline=False) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if the user reached their usage limit.
|
||||||
|
Initializes UsageTracker for user and guest when needed.
|
||||||
|
:param config: The bot configuration object
|
||||||
|
:param usage: The usage tracker object
|
||||||
|
:param update: Telegram update object
|
||||||
|
:param is_inline: Boolean flag for inline queries
|
||||||
|
:return: Boolean indicating if the user has a positive budget
|
||||||
|
"""
|
||||||
|
user_id = update.inline_query.from_user.id if is_inline else update.message.from_user.id
|
||||||
|
name = update.inline_query.from_user.name if is_inline else update.message.from_user.name
|
||||||
|
if user_id not in usage:
|
||||||
|
usage[user_id] = UsageTracker(user_id, name)
|
||||||
|
remaining_budget = get_remaining_budget(config, usage, update, is_inline=is_inline)
|
||||||
|
return remaining_budget > 0
|
||||||
|
|
||||||
|
|
||||||
|
def add_chat_request_to_usage_tracker(usage, config, user_id, used_tokens):
|
||||||
|
"""
|
||||||
|
Add chat request to usage tracker
|
||||||
|
:param usage: The usage tracker object
|
||||||
|
:param config: The bot configuration object
|
||||||
|
:param user_id: The user id
|
||||||
|
:param used_tokens: The number of tokens used
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# add chat request to users usage tracker
|
||||||
|
usage[user_id].add_chat_tokens(used_tokens, config['token_price'])
|
||||||
|
# add guest chat request to guest usage tracker
|
||||||
|
allowed_user_ids = config['allowed_user_ids'].split(',')
|
||||||
|
if str(user_id) not in allowed_user_ids and 'guests' in usage:
|
||||||
|
usage["guests"].add_chat_tokens(used_tokens, config['token_price'])
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f'Failed to add tokens to usage_logs: {str(e)}')
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_reply_to_message_id(config, update: Update):
|
||||||
|
"""
|
||||||
|
Returns the message id of the message to reply to
|
||||||
|
:param config: Bot configuration object
|
||||||
|
:param update: Telegram update object
|
||||||
|
:return: Message id of the message to reply to, or None if quoting is disabled
|
||||||
|
"""
|
||||||
|
if config['enable_quoting'] or is_group_chat(update):
|
||||||
|
return update.message.message_id
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user