From 8c2054f496373d8bdfbf032efe0c467fe6faf614 Mon Sep 17 00:00:00 2001 From: profitroll Date: Thu, 29 Jun 2023 15:58:50 +0200 Subject: [PATCH] Improved init flexibility --- libbot/pyrogram/classes/client.py | 92 +++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 23 deletions(-) diff --git a/libbot/pyrogram/classes/client.py b/libbot/pyrogram/classes/client.py index 7b2b029..d9d16a3 100644 --- a/libbot/pyrogram/classes/client.py +++ b/libbot/pyrogram/classes/client.py @@ -1,9 +1,10 @@ +import asyncio import logging from datetime import datetime, timedelta from os import cpu_count, getpid from pathlib import Path from time import time -from typing import Dict, List, Union +from typing import Any, Dict, List, Union try: import pyrogram @@ -43,34 +44,68 @@ logger = logging.getLogger(__name__) class PyroClient(Client): def __init__( - self, scheduler: Union[AsyncIOScheduler, BackgroundScheduler, None] = None + self, + name: str = "bot_client", + config: Union[Dict[str, Any], None] = None, + config_path: Union[str, Path] = Path("config.json"), + api_id: Union[int, None] = None, + api_hash: Union[str, None] = None, + bot_token: Union[str, None] = None, + workers: int = min(32, cpu_count() + 4), + locales_root: Union[str, Path, None] = None, + plugins_root: str = "plugins", + plugins_exclude: Union[List[str], None] = None, + sleep_threshold: int = 120, + max_concurrent_transmissions: int = 1, + commands_source: Union[Dict[str, dict], None] = None, + scheduler: Union[AsyncIOScheduler, BackgroundScheduler, None] = None, ): - with open("config.json", "r", encoding="utf-8") as f: - self.config: dict = loads(f.read()) + if config is None: + with open(config_path, "r", encoding="utf-8") as f: + self.config: dict = loads(f.read()) + else: + self.config = config + super().__init__( - name="bot_client", - api_id=self.config["bot"]["api_id"], - api_hash=self.config["bot"]["api_hash"], - bot_token=self.config["bot"]["bot_token"], - # Workers should be commented when using convopyro, otherwise + name=name, + api_id=self.config["bot"]["api_id"] if api_id is None else api_id, + api_hash=self.config["bot"]["api_hash"] if api_hash is None else api_hash, + bot_token=self.config["bot"]["bot_token"] + if bot_token is None + else bot_token, + # Workers should be `min(32, cpu_count() + 4)`, otherwise # handlers land in another event loop and you won't see them workers=self.config["bot"]["workers"] if "workers" in self.config["bot"] - else min(32, cpu_count() + 4), - plugins=dict(root="plugins", exclude=self.config["disabled_plugins"]), - sleep_threshold=120, + else workers, + plugins=dict( + root=plugins_root, + exclude=self.config["disabled_plugins"] + if plugins_exclude is None + else plugins_exclude, + ), + sleep_threshold=sleep_threshold, max_concurrent_transmissions=self.config["bot"][ "max_concurrent_transmissions" ] if "max_concurrent_transmissions" in self.config["bot"] - else 1, + else max_concurrent_transmissions, ) self.owner: int = self.config["bot"]["owner"] self.commands: List[PyroCommand] = [] + self.commands_source: Dict[str, dict] = ( + self.config["commands"] if commands_source is None else commands_source + ) self.scoped_commands: bool = self.config["bot"]["scoped_commands"] self.start_time: float = 0 - self.bot_locale: BotLocale = BotLocale(Path(self.config["locations"]["locale"])) + self.bot_locale: BotLocale = BotLocale( + ( + Path(self.config["locations"]["locale"]) + if locales_root is None + else locales_root + ) + ) self.default_locale: str = self.bot_locale.default self.locales: dict = self.bot_locale.locales @@ -82,7 +117,7 @@ class PyroClient(Client): self.scopes_placeholders: Dict[str, int] = {"owner": self.owner} - async def start(self): + async def start(self, register_commands: bool = True): await super().start() self.start_time = time() @@ -104,28 +139,39 @@ class PyroClient(Client): if self.scheduler is None: return - self.scheduler.add_job( - self.register_commands, - trigger="date", - run_date=datetime.now() + timedelta(seconds=5), - kwargs={"command_sets": await self.collect_commands()}, - ) + if register_commands: + self.scheduler.add_job( + self.register_commands, + trigger="date", + run_date=datetime.now() + timedelta(seconds=5), + kwargs={"command_sets": await self.collect_commands()}, + ) self.scheduler.start() except BadRequest: logger.warning("Unable to send message to report chat.") - async def stop(self): + async def stop(self, exit_completely: bool = True): try: await self.send_message( chat_id=self.config["reports"]["chat_id"], text=f"Bot stopped with PID `{getpid()}`", ) + await asyncio.sleep(0.5) except BadRequest: logger.warning("Unable to send message to report chat.") + await super().stop() logger.warning("Bot stopped with PID %s.", getpid()) + if exit_completely: + try: + exit() + except SystemExit as exp: + raise SystemExit( + "Bot has been shut down, this is not an application error!" + ) from exp + async def collect_commands(self) -> Union[List[CommandSet], None]: """Gather list of the bot's commands @@ -141,7 +187,7 @@ class PyroClient(Client): command_sets = [] # Iterate through all commands in config - for command, contents in self.config["commands"].items(): + for command, contents in self.commands_source.items(): # Iterate through all scopes of a command for scope in contents["scopes"]: if dumps(scope) not in scopes: