import asyncio import logging from datetime import datetime, timedelta from os import cpu_count, getpid from pathlib import Path from time import time from typing import Any, Dict, List, Union try: import pyrogram from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.background import BackgroundScheduler from pyrogram.client import Client from pyrogram.errors import BadRequest from pyrogram.handlers.message_handler import MessageHandler from pyrogram.raw.all import layer from pyrogram.types import ( BotCommand, BotCommandScopeAllChatAdministrators, BotCommandScopeAllGroupChats, BotCommandScopeAllPrivateChats, BotCommandScopeChat, BotCommandScopeChatAdministrators, BotCommandScopeChatMember, BotCommandScopeDefault, ) except ImportError as exc: raise ImportError( "You need to install libbot[pyrogram] in order to use this class." ) from exc try: from ujson import dumps, loads except ImportError: from json import dumps, loads from libbot.i18n import BotLocale from libbot.i18n.sync import _ from libbot.pyrogram.classes.command import PyroCommand from libbot.pyrogram.classes.commandset import CommandSet logger = logging.getLogger(__name__) class PyroClient(Client): def __init__( self, name: str = "bot_client", config: Union[Dict[str, Any], None] = None, config_path: Union[str, Path] = Path("config.json"), api_id: Union[int, None] = None, api_hash: Union[str, None] = None, bot_token: Union[str, None] = None, workers: int = min(32, cpu_count() + 4), locales_root: Union[str, Path, None] = None, plugins_root: str = "plugins", plugins_exclude: Union[List[str], None] = None, sleep_threshold: int = 120, max_concurrent_transmissions: int = 1, commands_source: Union[Dict[str, dict], None] = None, scheduler: Union[AsyncIOScheduler, BackgroundScheduler, None] = None, ): if config is None: with open(config_path, "r", encoding="utf-8") as f: self.config: dict = loads(f.read()) else: self.config = config super().__init__( name=name, api_id=self.config["bot"]["api_id"] if api_id is None else api_id, api_hash=self.config["bot"]["api_hash"] if api_hash is None else api_hash, bot_token=self.config["bot"]["bot_token"] if bot_token is None else bot_token, # Workers should be `min(32, cpu_count() + 4)`, otherwise # handlers land in another event loop and you won't see them workers=self.config["bot"]["workers"] if "workers" in self.config["bot"] else workers, plugins=dict( root=plugins_root, exclude=self.config["disabled_plugins"] if plugins_exclude is None else plugins_exclude, ), sleep_threshold=sleep_threshold, max_concurrent_transmissions=self.config["bot"][ "max_concurrent_transmissions" ] if "max_concurrent_transmissions" in self.config["bot"] else max_concurrent_transmissions, ) self.owner: int = self.config["bot"]["owner"] self.commands: List[PyroCommand] = [] self.commands_source: Dict[str, dict] = ( self.config["commands"] if commands_source is None else commands_source ) self.scoped_commands: bool = self.config["bot"]["scoped_commands"] self.start_time: float = 0 self.bot_locale: BotLocale = BotLocale( (Path("locale") if locales_root is None else locales_root) ) self.default_locale: str = self.bot_locale.default self.locales: dict = self.bot_locale.locales self._ = self.bot_locale._ self.in_all_locales = self.bot_locale.in_all_locales self.in_every_locale = self.bot_locale.in_every_locale self.scheduler: Union[AsyncIOScheduler, BackgroundScheduler, None] = scheduler self.scopes_placeholders: Dict[str, int] = {"owner": self.owner} async def start(self, register_commands: bool = True): await super().start() self.start_time = time() logger.info( "Bot is running with Pyrogram v%s (Layer %s) and has started as @%s on PID %s.", pyrogram.__version__, layer, self.me.username, getpid(), ) try: await self.send_message( chat_id=self.owner if self.config["reports"]["chat_id"] == "owner" else self.config["reports"]["chat_id"], text=f"Bot started PID `{getpid()}`", ) if self.scheduler is None: return if register_commands: self.scheduler.add_job( self.register_commands, trigger="date", run_date=datetime.now() + timedelta(seconds=5), kwargs={"command_sets": await self.collect_commands()}, ) self.scheduler.start() except BadRequest: logger.warning("Unable to send message to report chat.") async def stop(self, exit_completely: bool = True): try: await self.send_message( chat_id=self.owner if self.config["reports"]["chat_id"] == "owner" else self.config["reports"]["chat_id"], text=f"Bot stopped with PID `{getpid()}`", ) await asyncio.sleep(0.5) except BadRequest: logger.warning("Unable to send message to report chat.") await super().stop() logger.warning("Bot stopped with PID %s.", getpid()) if exit_completely: try: exit() except SystemExit as exp: raise SystemExit( "Bot has been shut down, this is not an application error!" ) from exp async def collect_commands(self) -> Union[List[CommandSet], None]: """Gather list of the bot's commands ### Returns: * `List[CommandSet]`: List of the commands' sets """ command_sets = None # If config's bot.scoped_commands is true - more complicated # scopes system will be used instead of simple global commands if self.scoped_commands: scopes = {} command_sets = [] # Iterate through all commands in config for command, contents in self.commands_source.items(): # Iterate through all scopes of a command for scope in contents["scopes"]: if dumps(scope) not in scopes: scopes[dumps(scope)] = {"_": []} # Add command to the scope's flattened key in scopes dict scopes[dumps(scope)]["_"].append( BotCommand(command, _(command, "commands")) ) for locale, string in ( self.in_every_locale(command, "commands") ).items(): if locale not in scopes[dumps(scope)]: scopes[dumps(scope)][locale] = [] scopes[dumps(scope)][locale].append(BotCommand(command, string)) # Iterate through all scopes and its commands for scope, locales in scopes.items(): # Make flat key a dict again scope_dict = loads(scope) # Replace "owner" in the bot scope with owner's id for placeholder, chat_id in self.scopes_placeholders.items(): if "chat_id" in scope_dict and scope_dict["chat_id"] == placeholder: scope_dict["chat_id"] = chat_id # Create object with the same name and args from the dict try: scope_obj = globals()[scope_dict["name"]]( **{ key: value for key, value in scope_dict.items() if key != "name" } ) except NameError: logger.error( "Could not register commands of the scope '%s' due to an invalid scope class provided!", scope_dict["name"], ) continue except TypeError: logger.error( "Could not register commands of the scope '%s' due to an invalid class arguments provided!", scope_dict["name"], ) continue # Add set of commands to the list of the command sets for locale, commands in locales.items(): if locale == "_": command_sets.append( CommandSet(commands, scope=scope_obj, language_code="") ) continue command_sets.append( CommandSet(commands, scope=scope_obj, language_code=locale) ) logger.info("Registering the following command sets: %s", command_sets) else: # This part here looks into the handlers and looks for commands # in it, if there are any. Then adds them to self.commands for handler in self.dispatcher.groups[0]: if isinstance(handler, MessageHandler): for entry in [handler.filters.base, handler.filters.other]: if hasattr(entry, "commands"): for command in entry.commands: logger.info("I see a command %s in my filters", command) self.add_command(command) return command_sets def add_command( self, command: str, ): """Add command to the bot's internal commands list ### Args: * command (`str`) """ self.commands.append( PyroCommand( command, _(command, "commands"), ) ) logger.info( "Added command '%s' to the bot's internal commands list", command, ) async def register_commands( self, command_sets: Union[List[CommandSet], None] = None ): """Register commands stored in bot's 'commands' attribute""" if command_sets is None: commands = [ BotCommand(command=command.command, description=command.description) for command in self.commands ] logger.info( "Registering commands %s with a default scope 'BotCommandScopeDefault'" ) await self.set_bot_commands(commands) return for command_set in command_sets: logger.info( "Registering command set with commands %s and scope '%s' (%s)", command_set.commands, command_set.scope, command_set.language_code, ) await self.set_bot_commands( command_set.commands, command_set.scope, language_code=command_set.language_code, ) async def remove_commands(self, command_sets: Union[List[CommandSet], None] = None): """Remove commands stored in bot's 'commands' attribute""" if command_sets is None: logger.info( "Removing commands with a default scope 'BotCommandScopeDefault'" ) await self.delete_bot_commands(BotCommandScopeDefault()) return for command_set in command_sets: logger.info( "Removing command set with scope '%s' (%s)", command_set.scope, command_set.language_code, ) await self.delete_bot_commands( command_set.scope, language_code=command_set.language_code, )