mirror of
https://github.com/aljazceru/chatgpt-telegram-bot.git
synced 2026-01-06 06:26:23 +01:00
added guest budget for group chats
This commit is contained in:
1
main.py
1
main.py
@@ -60,6 +60,7 @@ def main():
|
||||
'token': os.environ['TELEGRAM_BOT_TOKEN'],
|
||||
'allowed_user_ids': os.environ.get('ALLOWED_TELEGRAM_USER_IDS', '*'),
|
||||
'monthly_user_budgets': os.environ.get('MONTHLY_USER_BUDGETS', '*'),
|
||||
'monthly_guest_budget': float(os.environ.get('MONTHLY_GUEST_BUDGET', '100.0')),
|
||||
'proxy': os.environ.get('PROXY', None),
|
||||
'voice_reply_transcript': os.environ.get('VOICE_REPLY_WITH_TRANSCRIPT_ONLY', 'true').lower() == 'true',
|
||||
'token_price': float(os.environ.get('TOKEN_PRICE', 0.002)),
|
||||
|
||||
102
telegram_bot.py
102
telegram_bot.py
@@ -9,7 +9,6 @@ from telegram.ext import ApplicationBuilder, ContextTypes, CommandHandler, Messa
|
||||
from pydub import AudioSegment
|
||||
from openai_helper import OpenAIHelper
|
||||
from usage_tracker import UsageTracker
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class ChatGPT3TelegramBot:
|
||||
@@ -108,10 +107,8 @@ class ChatGPT3TelegramBot:
|
||||
await self.send_disallowed_message(update, context)
|
||||
return
|
||||
|
||||
enough_budget, billed_user = await self.get_budget_info(update)
|
||||
|
||||
if not enough_budget:
|
||||
logging.warning(f'User {billed_user} reached their usage limit')
|
||||
if not await self.is_within_budget(update):
|
||||
logging.warning(f'User {update.message.from_user.name} reached their usage limit')
|
||||
await self.send_budget_reached_message(update, context)
|
||||
return
|
||||
|
||||
@@ -131,8 +128,12 @@ class ChatGPT3TelegramBot:
|
||||
reply_to_message_id=update.message.message_id,
|
||||
photo=image_url
|
||||
)
|
||||
# add image request to usage tracker
|
||||
# add image request to users usage tracker
|
||||
self.usage[update.message.from_user.id].add_image_request(image_size, self.config['image_prices'])
|
||||
# add guest chat request to guest usage tracker
|
||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
||||
if str(update.message.from_user.id) not in allowed_user_ids:
|
||||
self.usage["guests"].add_image_request(image_size, self.config['image_prices'])
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
@@ -151,10 +152,8 @@ class ChatGPT3TelegramBot:
|
||||
await self.send_disallowed_message(update, context)
|
||||
return
|
||||
|
||||
enough_budget, billed_user = await self.get_budget_info(update)
|
||||
|
||||
if not enough_budget:
|
||||
logging.warning(f'User {billed_user} reached their usage limit')
|
||||
if not await self.is_within_budget(update):
|
||||
logging.warning(f'User {update.message.from_user.name} reached their usage limit')
|
||||
await self.send_budget_reached_message(update, context)
|
||||
return
|
||||
|
||||
@@ -196,8 +195,14 @@ class ChatGPT3TelegramBot:
|
||||
|
||||
# Transcribe the audio file
|
||||
transcript = self.openai.transcribe(filename_mp3)
|
||||
|
||||
# add transcription seconds to usage tracker
|
||||
self.usage[billed_user].add_transcription_seconds(audio_track.duration_seconds, self.config['transcription_price'])
|
||||
self.usage[update.message.from_user.id].add_transcription_seconds(audio_track.duration_seconds, self.config['transcription_price'])
|
||||
# add guest chat request to guest usage tracker
|
||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
||||
if str(update.message.from_user.id) not in allowed_user_ids:
|
||||
self.usage["guests"].add_transcription_seconds(audio_track.duration_seconds, self.config['transcription_price'])
|
||||
|
||||
if self.config['voice_reply_transcript']:
|
||||
# Send the transcript
|
||||
await context.bot.send_message(
|
||||
@@ -209,8 +214,14 @@ class ChatGPT3TelegramBot:
|
||||
else:
|
||||
# Send the response of the transcript
|
||||
response, total_tokens = self.openai.get_chat_response(chat_id=chat_id, query=transcript)
|
||||
# add chat request to usage tracker
|
||||
self.usage[billed_user].add_chat_tokens(total_tokens, self.config['token_price'])
|
||||
|
||||
# add chat request to users usage tracker
|
||||
self.usage[update.message.from_user.id].add_chat_tokens(total_tokens, self.config['token_price'])
|
||||
# add guest chat request to guest usage tracker
|
||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
||||
if str(update.message.from_user.id) not in allowed_user_ids:
|
||||
self.usage["guests"].add_chat_tokens(total_tokens, self.config['token_price'])
|
||||
|
||||
await context.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
reply_to_message_id=update.message.message_id,
|
||||
@@ -240,10 +251,8 @@ class ChatGPT3TelegramBot:
|
||||
await self.send_disallowed_message(update, context)
|
||||
return
|
||||
|
||||
enough_budget, billed_user = await self.get_budget_info(update)
|
||||
|
||||
if not enough_budget:
|
||||
logging.warning(f'User {billed_user} reached their usage limit')
|
||||
if not await self.is_within_budget(update):
|
||||
logging.warning(f'User {update.message.from_user.name} reached their usage limit')
|
||||
await self.send_budget_reached_message(update, context)
|
||||
return
|
||||
|
||||
@@ -253,8 +262,12 @@ class ChatGPT3TelegramBot:
|
||||
await context.bot.send_chat_action(chat_id=chat_id, action=constants.ChatAction.TYPING)
|
||||
response, total_tokens = self.openai.get_chat_response(chat_id=chat_id, query=update.message.text)
|
||||
|
||||
# add chat request to usage tracker
|
||||
self.usage[billed_user].add_chat_tokens(total_tokens, self.config['token_price'])
|
||||
# add chat request to users usage tracker
|
||||
self.usage[update.message.from_user.id].add_chat_tokens(total_tokens, self.config['token_price'])
|
||||
# add guest chat request to guest usage tracker
|
||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
||||
if str(update.message.from_user.id) not in allowed_user_ids:
|
||||
self.usage["guests"].add_chat_tokens(total_tokens, self.config['token_price'])
|
||||
|
||||
await context.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
@@ -353,43 +366,48 @@ class ChatGPT3TelegramBot:
|
||||
|
||||
return False
|
||||
|
||||
async def get_budget_info(self, update: Update) -> Tuple[bool, str]:
|
||||
async def is_within_budget(self, update: Update) -> bool:
|
||||
"""
|
||||
Checks if the reached their monthly usage limit.
|
||||
Checks if the user reached their monthly usage limit.
|
||||
Initializes self.usage for user and guest when needed.
|
||||
"""
|
||||
user_id = update.message.from_user.id
|
||||
if user_id not in self.usage:
|
||||
self.usage[user_id] = UsageTracker(user_id, update.message.from_user.name)
|
||||
|
||||
if self.config['monthly_user_budgets'] == '*':
|
||||
return True, user_id
|
||||
return True
|
||||
|
||||
allowed_user_ids = self.config['allowed_user_ids'].split(',')
|
||||
user_index = allowed_user_ids.index(str(user_id))
|
||||
user_budgets = self.config['monthly_user_budgets'].split(',')
|
||||
if len(user_budgets) <= user_index:
|
||||
logging.warning(f'user {update.message.from_user.name} ({user_id}) does not have a budget.')
|
||||
return False, user_id
|
||||
user_budget = float(user_budgets[user_index])
|
||||
cost_month = self.usage[user_id].get_current_cost()[1]
|
||||
# Check if user is within budget
|
||||
if user_budget > cost_month:
|
||||
return True, user_id
|
||||
if str(user_id) in allowed_user_ids:
|
||||
# find budget for allowed user
|
||||
user_index = allowed_user_ids.index(str(user_id))
|
||||
user_budgets = self.config['monthly_user_budgets'].split(',')
|
||||
# check if user is included in budgets list
|
||||
if len(user_budgets) <= user_index:
|
||||
logging.warning(f'No budget set for user: {update.message.from_user.name} ({user_id}).')
|
||||
return False
|
||||
user_budget = float(user_budgets[user_index])
|
||||
cost_month = self.usage[user_id].get_current_cost()[1]
|
||||
# Check if allowed user is within budget
|
||||
if user_budget > cost_month:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# Check if group member is within budget
|
||||
if self.is_group_chat(update):
|
||||
for user, index in enumerate(allowed_user_ids):
|
||||
logging.info(user)
|
||||
for user in allowed_user_ids:
|
||||
if await self.is_user_in_group(update, user):
|
||||
user_budget = float(user_budgets[index])
|
||||
if user_budget >= self.usage[user].get_current_cost()[1]:
|
||||
if user not in self.usage:
|
||||
self.usage[user_id] = UsageTracker(user_id, "n.a.") # How to get user name here?
|
||||
logging.info(f'Billing {user} for request in group chat by {user_id}({update.message.from_user.name}).')
|
||||
return True, user_id
|
||||
if 'guests' not in self.usage:
|
||||
self.usage['guests'] = UsageTracker('guests', 'non-users in group chats')
|
||||
if self.config['monthly_guest_budget'] >= self.usage['guests'].get_current_cost()[1]:
|
||||
return True
|
||||
else:
|
||||
logging.warning(f'Monthly guest budget for group chats used up.')
|
||||
return False
|
||||
logging.info(f'Group chat messages from user {update.message.from_user.name} are not allowed')
|
||||
|
||||
return False, user_id
|
||||
return False
|
||||
|
||||
async def post_init(self, application: Application) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user