WIP: Overhaul for 4.0.0

This commit is contained in:
kku
2024-12-18 14:16:37 +01:00
parent 5e479ddc79
commit 95d04308bd
32 changed files with 718 additions and 688 deletions

View File

View File

@@ -1,11 +1,14 @@
import asyncio
import logging
import sys
from datetime import datetime, timedelta
from os import cpu_count, getpid
from pathlib import Path
from time import time
from typing import Any, Dict, List, Union
from typing_extensions import override
try:
import pyrogram
from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -25,24 +28,23 @@ try:
BotCommandScopeDefault,
)
except ImportError as exc:
raise ImportError(
"You need to install libbot[pyrogram] in order to use this class."
) from exc
raise ImportError("You need to install libbot[pyrogram] in order to use this class.") from exc
try:
from ujson import dumps, loads
except ImportError:
from json import dumps, loads
from libbot.i18n import BotLocale
from libbot.i18n.sync import _
from libbot.pyrogram.classes.command import PyroCommand
from libbot.pyrogram.classes.commandset import CommandSet
from ...i18n.classes import BotLocale
from ...i18n import _
from .command import PyroCommand
from .commandset import CommandSet
logger = logging.getLogger(__name__)
class PyroClient(Client):
@override
def __init__(
self,
name: str = "bot_client",
@@ -74,26 +76,20 @@ class PyroClient(Client):
name=name,
api_id=self.config["bot"]["api_id"] if api_id is None else api_id,
api_hash=self.config["bot"]["api_hash"] if api_hash is None else api_hash,
bot_token=self.config["bot"]["bot_token"]
if bot_token is None
else bot_token,
bot_token=self.config["bot"]["bot_token"] if bot_token is None else bot_token,
# Workers should be `min(32, cpu_count() + 4)`, otherwise
# handlers land in another event loop and you won't see them
workers=self.config["bot"]["workers"]
if "workers" in self.config["bot"]
else workers,
workers=self.config["bot"]["workers"] if "workers" in self.config["bot"] else workers,
plugins=dict(
root=plugins_root,
exclude=self.config["disabled_plugins"]
if plugins_exclude is None
else plugins_exclude,
exclude=self.config["disabled_plugins"] if plugins_exclude is None else plugins_exclude,
),
sleep_threshold=sleep_threshold,
max_concurrent_transmissions=self.config["bot"][
"max_concurrent_transmissions"
]
if "max_concurrent_transmissions" in self.config["bot"]
else max_concurrent_transmissions,
max_concurrent_transmissions=(
self.config["bot"]["max_concurrent_transmissions"]
if "max_concurrent_transmissions" in self.config["bot"]
else max_concurrent_transmissions
),
**kwargs,
)
self.owner: int = self.config["bot"]["owner"] if owner is None else owner
@@ -102,9 +98,7 @@ class PyroClient(Client):
self.config["commands"] if commands_source is None else commands_source
)
self.scoped_commands: bool = (
self.config["bot"]["scoped_commands"]
if scoped_commands is None
else scoped_commands
self.config["bot"]["scoped_commands"] if scoped_commands is None else scoped_commands
)
self.start_time: float = 0
@@ -125,6 +119,7 @@ class PyroClient(Client):
self.i18n_bot_info: bool = i18n_bot_info
@override
async def start(self, register_commands: bool = True, scheduler_start: bool = True) -> None:
await super().start()
@@ -189,9 +184,11 @@ class PyroClient(Client):
# Send a message to the bot's reports chat about the startup
try:
await self.send_message(
chat_id=self.owner
if self.config["reports"]["chat_id"] == "owner"
else self.config["reports"]["chat_id"],
chat_id=(
self.owner
if self.config["reports"]["chat_id"] == "owner"
else self.config["reports"]["chat_id"]
),
text=f"Bot started PID `{getpid()}`",
)
except BadRequest:
@@ -212,14 +209,17 @@ class PyroClient(Client):
if scheduler_start:
self.scheduler.start()
@override
async def stop(
self, exit_completely: bool = True, scheduler_shutdown: bool = True, scheduler_wait: bool = True
) -> None:
try:
await self.send_message(
chat_id=self.owner
if self.config["reports"]["chat_id"] == "owner"
else self.config["reports"]["chat_id"],
chat_id=(
self.owner
if self.config["reports"]["chat_id"] == "owner"
else self.config["reports"]["chat_id"]
),
text=f"Bot stopped with PID `{getpid()}`",
)
await asyncio.sleep(0.5)
@@ -234,11 +234,9 @@ class PyroClient(Client):
if exit_completely:
try:
exit()
sys.exit()
except SystemExit as exc:
raise SystemExit(
"Bot has been shut down, this is not an application error!"
) from exc
raise SystemExit("Bot has been shut down, this is not an application error!") from exc
async def collect_commands(self) -> Union[List[CommandSet], None]:
"""Gather list of the bot's commands
@@ -262,13 +260,9 @@ class PyroClient(Client):
scopes[dumps(scope)] = {"_": []}
# Add command to the scope's flattened key in scopes dict
scopes[dumps(scope)]["_"].append(
BotCommand(command, _(command, "commands"))
)
scopes[dumps(scope)]["_"].append(BotCommand(command, _(command, "commands")))
for locale, string in (
self.in_every_locale(command, "commands")
).items():
for locale, string in (self.in_every_locale(command, "commands")).items():
if locale not in scopes[dumps(scope)]:
scopes[dumps(scope)][locale] = []
@@ -287,11 +281,7 @@ class PyroClient(Client):
# Create object with the same name and args from the dict
try:
scope_obj = globals()[scope_dict["name"]](
**{
key: value
for key, value in scope_dict.items()
if key != "name"
}
**{key: value for key, value in scope_dict.items() if key != "name"}
)
except NameError:
logger.error(
@@ -309,13 +299,9 @@ class PyroClient(Client):
# Add set of commands to the list of the command sets
for locale, commands in locales.items():
if locale == "_":
command_sets.append(
CommandSet(commands, scope=scope_obj, language_code="")
)
command_sets.append(CommandSet(commands, scope=scope_obj, language_code=""))
continue
command_sets.append(
CommandSet(commands, scope=scope_obj, language_code=locale)
)
command_sets.append(CommandSet(commands, scope=scope_obj, language_code=locale))
logger.info("Registering the following command sets: %s", command_sets)
@@ -352,9 +338,7 @@ class PyroClient(Client):
command,
)
async def register_commands(
self, command_sets: Union[List[CommandSet], None] = None
) -> None:
async def register_commands(self, command_sets: Union[List[CommandSet], None] = None) -> None:
"""Register commands stored in bot's 'commands' attribute"""
if command_sets is None:
@@ -363,10 +347,7 @@ class PyroClient(Client):
for command in self.commands
]
logger.info(
"Registering commands %s with a default scope 'BotCommandScopeDefault'",
commands
)
logger.info("Registering commands %s with a default scope 'BotCommandScopeDefault'", commands)
await self.set_bot_commands(commands)
return
@@ -384,15 +365,11 @@ class PyroClient(Client):
language_code=command_set.language_code,
)
async def remove_commands(
self, command_sets: Union[List[CommandSet], None] = None
) -> None:
async def remove_commands(self, command_sets: Union[List[CommandSet], None] = None) -> None:
"""Remove commands stored in bot's 'commands' attribute"""
if command_sets is None:
logger.info(
"Removing commands with a default scope 'BotCommandScopeDefault'"
)
logger.info("Removing commands with a default scope 'BotCommandScopeDefault'")
await self.delete_bot_commands(BotCommandScopeDefault())
return