From 2d941e2cb303b060966d1197a1031c58306cd74a Mon Sep 17 00:00:00 2001 From: profitroll Date: Fri, 26 May 2023 16:32:56 +0200 Subject: [PATCH] PyroClient overhaul --- classes/pyroclient.py | 247 ++++++++++++++++++++++++++++++++++++++++++ main.py | 4 +- modules/app.py | 60 ---------- plugins/callback.py | 2 +- plugins/command.py | 2 +- plugins/handler.py | 2 +- plugins/inline.py | 2 +- 7 files changed, 254 insertions(+), 65 deletions(-) create mode 100644 classes/pyroclient.py delete mode 100644 modules/app.py diff --git a/classes/pyroclient.py b/classes/pyroclient.py new file mode 100644 index 0000000..27bc4cb --- /dev/null +++ b/classes/pyroclient.py @@ -0,0 +1,247 @@ +import logging +from datetime import datetime, timedelta +from os import getpid +from time import time +from typing import List, Union + +import pyrogram +from libbot import config_get +from libbot.i18n import in_every_locale +from libbot.i18n.sync import _ +from pyrogram.client import Client +from pyrogram.errors import BadRequest +from pyrogram.handlers.message_handler import MessageHandler +from pyrogram.raw.all import layer +from pyrogram.types import ( + BotCommand, + BotCommandScopeAllChatAdministrators, + BotCommandScopeAllGroupChats, + BotCommandScopeAllPrivateChats, + BotCommandScopeChat, + BotCommandScopeChatAdministrators, + BotCommandScopeChatMember, + BotCommandScopeDefault, +) +from ujson import dumps, loads + +from classes.commandset import CommandSet +from classes.pyrocommand import PyroCommand +from modules.scheduler import scheduler + +logger = logging.getLogger(__name__) + + +class PyroClient(Client): + def __init__(self): + with open("config.json", "r", encoding="utf-8") as f: + config = loads(f.read()) + super().__init__( + name="bot_client", + api_id=config["bot"]["api_id"], + api_hash=config["bot"]["api_hash"], + bot_token=config["bot"]["bot_token"], + workers=config["bot"]["workers"], + plugins=dict(root="plugins", exclude=config["disabled_plugins"]), + sleep_threshold=120, + ) + self.commands: List[PyroCommand] = [] + self.scoped_commands = config["bot"]["scoped_commands"] + self.start_time = 0 + + async def start(self): + await super().start() + + self.start_time = time() + + logger.info( + "Bot is running with Pyrogram v%s (Layer %s) and has started as @%s on PID %s.", + pyrogram.__version__, + layer, + self.me.username, + getpid(), + ) + + try: + await self.send_message( + chat_id=await config_get("chat_id", "reports"), + text=f"Bot started PID `{getpid()}`", + ) + + scheduler.add_job( + self.register_commands, + trigger="date", + run_date=datetime.now() + timedelta(seconds=5), + kwargs={"command_sets": await self.collect_commands()}, + ) + + scheduler.start() + except BadRequest: + logger.warning("Unable to send message to report chat.") + + async def stop(self): + try: + await self.send_message( + chat_id=await config_get("chat_id", "reports"), + text=f"Bot stopped with PID `{getpid()}`", + ) + except BadRequest: + logger.warning("Unable to send message to report chat.") + await super().stop() + logger.warning("Bot stopped with PID %s.", getpid()) + + async def collect_commands(self) -> Union[List[CommandSet], None]: + """Gather list of the bot's commands + + ### Returns: + * `List[CommandSet]`: List of the commands' sets + """ + command_sets = None + + # If config get bot.scoped_commands is true - more complicated + # scopes system will be used instead of simple global commands + if self.scoped_commands: + scopes = {} + command_sets = [] + + # Iterate through all commands in config + for command, contents in (await config_get("commands")).items(): + # Iterate through all scopes of a command + for scope in contents["scopes"]: + if dumps(scope) not in scopes: + scopes[dumps(scope)] = {"_": []} + + # Add command to the scope's flattened key in scopes dict + scopes[dumps(scope)]["_"].append( + BotCommand(command, _(command, "commands")) + ) + + for locale, string in ( + await in_every_locale(command, "commands") + ).items(): + if locale not in scopes[dumps(scope)]: + scopes[dumps(scope)][locale] = [] + + scopes[dumps(scope)][locale].append(BotCommand(command, string)) + + # Iterate through all scopes and its commands + for scope, locales in scopes.items(): + # Make flat key a dict again + scope_dict = loads(scope) + + # Create object with the same name and args from the dict + try: + scope_obj = globals()[scope_dict["name"]]( + **{ + key: value + for key, value in scope_dict.items() + if key != "name" + } + ) + except NameError: + logger.error( + "Could not register commands of the scope '%s' due to an invalid scope class provided!", + scope_dict["name"], + ) + continue + except TypeError: + logger.error( + "Could not register commands of the scope '%s' due to an invalid class arguments provided!", + scope_dict["name"], + ) + continue + + # Add set of commands to the list of the command sets + for locale, commands in locales.items(): + if locale == "_": + command_sets.append( + CommandSet(commands, scope=scope_obj, language_code="") + ) + continue + command_sets.append( + CommandSet(commands, scope=scope_obj, language_code=locale) + ) + + logger.info("Registering the following command sets: %s", command_sets) + + else: + # This part here looks into the handlers and looks for commands + # in it, if there are any. Then adds them to self.commands + for handler in self.dispatcher.groups[0]: + if isinstance(handler, MessageHandler): + for entry in [handler.filters.base, handler.filters.other]: + if hasattr(entry, "commands"): + for command in entry.commands: + logger.info("I see a command %s in my filters", command) + self.add_command(command) + + return command_sets + + def add_command( + self, + command: str, + ): + """Add command to the bot's internal commands list + + ### Args: + * command (`str`) + """ + self.commands.append( + PyroCommand( + command, + _(command, "commands"), + ) + ) + logger.info( + "Added command '%s' to the bot's internal commands list", + command, + ) + + async def register_commands( + self, command_sets: Union[List[CommandSet], None] = None + ): + """Register commands stored in bot's 'commands' attribute""" + + if command_sets is None: + commands = [ + BotCommand(command=command.command, description=command.description) + for command in self.commands + ] + + logger.info( + "Registering commands %s with a default scope 'BotCommandScopeDefault'" + ) + + await self.set_bot_commands(commands) + return + + for command_set in command_sets: + await self.set_bot_commands( + command_set.commands, + command_set.scope, + language_code=command_set.language_code, + ) + logger.info( + "Registering command set with commands %s and scope '%s'", + command_set.commands, + command_set.scope, + ) + + async def remove_commands(self, command_sets: Union[List[CommandSet], None] = None): + """Remove commands stored in bot's 'commands' attribute""" + + if command_sets is None: + logger.info( + "Removing commands with a default scope 'BotCommandScopeDefault'" + ) + await self.delete_bot_commands(BotCommandScopeDefault()) + return + + for command_set in command_sets: + logger.info( + "Removing command set with scope '%s'", + command_set.scope, + ) + await self.delete_bot_commands( + command_set.scope, + language_code=command_set.language_code, + ) diff --git a/main.py b/main.py index bc65869..60dbe2e 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,8 @@ import contextlib import logging from os import getpid -from modules.app import PyroClient +from classes.pyroclient import PyroClient +from modules.scheduler import scheduler # Uncomment this and the line below client declaration # in order to use context manager in your commands. @@ -32,6 +33,7 @@ def main(): except KeyboardInterrupt: logger.warning("Forcefully shutting down with PID %s...", getpid()) finally: + scheduler.shutdown() exit() diff --git a/modules/app.py b/modules/app.py deleted file mode 100644 index b2c84a3..0000000 --- a/modules/app.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging -from os import getpid -from time import time - -import pyrogram -from libbot import config_get -from pyrogram.client import Client -from pyrogram.errors import BadRequest -from pyrogram.raw.all import layer -from ujson import loads - -logger = logging.getLogger(__name__) - - -class PyroClient(Client): - def __init__(self): - with open("config.json", "r", encoding="utf-8") as f: - config = loads(f.read()) - super().__init__( - name="bot_client", - api_id=config["bot"]["api_id"], - api_hash=config["bot"]["api_hash"], - bot_token=config["bot"]["bot_token"], - workers=config["bot"]["workers"], - plugins=dict(root="plugins", exclude=config["disabled_plugins"]), - sleep_threshold=120, - ) - self.start_time = 0 - - async def start(self): - await super().start() - - self.start_time = time() - - logger.info( - "Bot is running with Pyrogram v%s (Layer %s) and has started as @%s on PID %s.", - pyrogram.__version__, - layer, - self.me.username, - getpid(), - ) - - try: - await self.send_message( - chat_id=await config_get("chat_id", "reports"), - text=f"Bot started PID `{getpid()}`", - ) - except BadRequest: - logger.warning("Unable to send message to report chat.") - - async def stop(self): - try: - await self.send_message( - chat_id=await config_get("chat_id", "reports"), - text=f"Bot stopped with PID `{getpid()}`", - ) - except BadRequest: - logger.warning("Unable to send message to report chat.") - await super().stop() - logger.warning("Bot stopped with PID %s.", getpid()) diff --git a/plugins/callback.py b/plugins/callback.py index 3636deb..9233d97 100644 --- a/plugins/callback.py +++ b/plugins/callback.py @@ -2,7 +2,7 @@ from pyrogram import filters from pyrogram.client import Client from pyrogram.types import CallbackQuery -from modules.app import PyroClient +from classes.pyroclient import PyroClient @Client.on_callback_query(filters.regex("nothing")) # type: ignore diff --git a/plugins/command.py b/plugins/command.py index 4d2007e..97a136f 100644 --- a/plugins/command.py +++ b/plugins/command.py @@ -2,7 +2,7 @@ from pyrogram import filters from pyrogram.client import Client from pyrogram.types import Message -from modules.app import PyroClient +from classes.pyroclient import PyroClient @Client.on_message( diff --git a/plugins/handler.py b/plugins/handler.py index df99764..9eb8622 100644 --- a/plugins/handler.py +++ b/plugins/handler.py @@ -2,7 +2,7 @@ from pyrogram import filters from pyrogram.client import Client from pyrogram.types import Message -from modules.app import PyroClient +from classes.pyroclient import PyroClient @Client.on_message(filters.text & filters.private) # type: ignore diff --git a/plugins/inline.py b/plugins/inline.py index f9285c9..1daa5a8 100644 --- a/plugins/inline.py +++ b/plugins/inline.py @@ -5,7 +5,7 @@ from pyrogram.types import ( InputTextMessageContent, ) -from modules.app import PyroClient +from classes.pyroclient import PyroClient @Client.on_inline_query() # type: ignore