Improved init flexibility

This commit is contained in:
Profitroll 2023-06-29 15:58:50 +02:00
parent fe9cc3674f
commit 8c2054f496
Signed by: profitroll
GPG Key ID: FA35CAB49DACD3B2

View File

@ -1,9 +1,10 @@
import asyncio
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from os import cpu_count, getpid from os import cpu_count, getpid
from pathlib import Path from pathlib import Path
from time import time from time import time
from typing import Dict, List, Union from typing import Any, Dict, List, Union
try: try:
import pyrogram import pyrogram
@ -43,34 +44,68 @@ logger = logging.getLogger(__name__)
class PyroClient(Client): class PyroClient(Client):
def __init__( def __init__(
self, scheduler: Union[AsyncIOScheduler, BackgroundScheduler, None] = None self,
name: str = "bot_client",
config: Union[Dict[str, Any], None] = None,
config_path: Union[str, Path] = Path("config.json"),
api_id: Union[int, None] = None,
api_hash: Union[str, None] = None,
bot_token: Union[str, None] = None,
workers: int = min(32, cpu_count() + 4),
locales_root: Union[str, Path, None] = None,
plugins_root: str = "plugins",
plugins_exclude: Union[List[str], None] = None,
sleep_threshold: int = 120,
max_concurrent_transmissions: int = 1,
commands_source: Union[Dict[str, dict], None] = None,
scheduler: Union[AsyncIOScheduler, BackgroundScheduler, None] = None,
): ):
with open("config.json", "r", encoding="utf-8") as f: if config is None:
with open(config_path, "r", encoding="utf-8") as f:
self.config: dict = loads(f.read()) self.config: dict = loads(f.read())
else:
self.config = config
super().__init__( super().__init__(
name="bot_client", name=name,
api_id=self.config["bot"]["api_id"], api_id=self.config["bot"]["api_id"] if api_id is None else api_id,
api_hash=self.config["bot"]["api_hash"], api_hash=self.config["bot"]["api_hash"] if api_hash is None else api_hash,
bot_token=self.config["bot"]["bot_token"], bot_token=self.config["bot"]["bot_token"]
# Workers should be commented when using convopyro, otherwise if bot_token is None
else bot_token,
# Workers should be `min(32, cpu_count() + 4)`, otherwise
# handlers land in another event loop and you won't see them # handlers land in another event loop and you won't see them
workers=self.config["bot"]["workers"] workers=self.config["bot"]["workers"]
if "workers" in self.config["bot"] if "workers" in self.config["bot"]
else min(32, cpu_count() + 4), else workers,
plugins=dict(root="plugins", exclude=self.config["disabled_plugins"]), plugins=dict(
sleep_threshold=120, root=plugins_root,
exclude=self.config["disabled_plugins"]
if plugins_exclude is None
else plugins_exclude,
),
sleep_threshold=sleep_threshold,
max_concurrent_transmissions=self.config["bot"][ max_concurrent_transmissions=self.config["bot"][
"max_concurrent_transmissions" "max_concurrent_transmissions"
] ]
if "max_concurrent_transmissions" in self.config["bot"] if "max_concurrent_transmissions" in self.config["bot"]
else 1, else max_concurrent_transmissions,
) )
self.owner: int = self.config["bot"]["owner"] self.owner: int = self.config["bot"]["owner"]
self.commands: List[PyroCommand] = [] self.commands: List[PyroCommand] = []
self.commands_source: Dict[str, dict] = (
self.config["commands"] if commands_source is None else commands_source
)
self.scoped_commands: bool = self.config["bot"]["scoped_commands"] self.scoped_commands: bool = self.config["bot"]["scoped_commands"]
self.start_time: float = 0 self.start_time: float = 0
self.bot_locale: BotLocale = BotLocale(Path(self.config["locations"]["locale"])) self.bot_locale: BotLocale = BotLocale(
(
Path(self.config["locations"]["locale"])
if locales_root is None
else locales_root
)
)
self.default_locale: str = self.bot_locale.default self.default_locale: str = self.bot_locale.default
self.locales: dict = self.bot_locale.locales self.locales: dict = self.bot_locale.locales
@ -82,7 +117,7 @@ class PyroClient(Client):
self.scopes_placeholders: Dict[str, int] = {"owner": self.owner} self.scopes_placeholders: Dict[str, int] = {"owner": self.owner}
async def start(self): async def start(self, register_commands: bool = True):
await super().start() await super().start()
self.start_time = time() self.start_time = time()
@ -104,6 +139,7 @@ class PyroClient(Client):
if self.scheduler is None: if self.scheduler is None:
return return
if register_commands:
self.scheduler.add_job( self.scheduler.add_job(
self.register_commands, self.register_commands,
trigger="date", trigger="date",
@ -115,17 +151,27 @@ class PyroClient(Client):
except BadRequest: except BadRequest:
logger.warning("Unable to send message to report chat.") logger.warning("Unable to send message to report chat.")
async def stop(self): async def stop(self, exit_completely: bool = True):
try: try:
await self.send_message( await self.send_message(
chat_id=self.config["reports"]["chat_id"], chat_id=self.config["reports"]["chat_id"],
text=f"Bot stopped with PID `{getpid()}`", text=f"Bot stopped with PID `{getpid()}`",
) )
await asyncio.sleep(0.5)
except BadRequest: except BadRequest:
logger.warning("Unable to send message to report chat.") logger.warning("Unable to send message to report chat.")
await super().stop() await super().stop()
logger.warning("Bot stopped with PID %s.", getpid()) logger.warning("Bot stopped with PID %s.", getpid())
if exit_completely:
try:
exit()
except SystemExit as exp:
raise SystemExit(
"Bot has been shut down, this is not an application error!"
) from exp
async def collect_commands(self) -> Union[List[CommandSet], None]: async def collect_commands(self) -> Union[List[CommandSet], None]:
"""Gather list of the bot's commands """Gather list of the bot's commands
@ -141,7 +187,7 @@ class PyroClient(Client):
command_sets = [] command_sets = []
# Iterate through all commands in config # Iterate through all commands in config
for command, contents in self.config["commands"].items(): for command, contents in self.commands_source.items():
# Iterate through all scopes of a command # Iterate through all scopes of a command
for scope in contents["scopes"]: for scope in contents["scopes"]:
if dumps(scope) not in scopes: if dumps(scope) not in scopes: