386 lines
14 KiB
Python
Raw Normal View History

import asyncio
import logging
import sys
from collections import defaultdict
from datetime import datetime, timedelta
from os import cpu_count, getpid
from pathlib import Path
from time import time
from typing import Any, Dict, List

from typing_extensions import override

try:
    import pyrogram
    from apscheduler.schedulers.asyncio import AsyncIOScheduler
    from apscheduler.schedulers.background import BackgroundScheduler
    from pyrogram.client import Client
    from pyrogram.errors import BadRequest
    from pyrogram.handlers.message_handler import MessageHandler
    from pyrogram.raw.all import layer
    from pyrogram.types import (
        BotCommand,
        BotCommandScopeAllChatAdministrators,
        BotCommandScopeAllGroupChats,
        BotCommandScopeAllPrivateChats,
        BotCommandScopeChat,
        BotCommandScopeChatAdministrators,
        BotCommandScopeChatMember,
        BotCommandScopeDefault,
    )
except ImportError as exc:
    raise ImportError("You need to install libbot[pyrogram] in order to use this class.") from exc

try:
    from ujson import dumps, loads
except ImportError:
    from json import dumps, loads

from ...i18n import _
from ...i18n.classes import BotLocale
from .command import PyroCommand
from .commandset import CommandSet
# Module-level logger used by PyroClient for lifecycle and command-registration messages.
logger = logging.getLogger(__name__)
class PyroClient(Client):
    """Pyrogram client with config loading, localisation, command registration and scheduling.

    Settings that are not passed explicitly are mostly read from ``config`` (or from the JSON
    file at ``config_path`` when ``config`` is ``None``), primarily from its ``"bot"`` section.
    """

    @override
    def __init__(
        self,
        name: str = "bot_client",
        owner: int | None = None,
        config: Dict[str, Any] | None = None,
        config_path: str | Path = Path("config.json"),
        api_id: int | None = None,
        api_hash: str | None = None,
        bot_token: str | None = None,
        # `os.cpu_count()` may return None when the CPU count cannot be determined
        workers: int = min(32, (cpu_count() or 1) + 4),
        locales_root: str | Path | None = None,
        plugins_root: str = "plugins",
        plugins_exclude: List[str] | None = None,
        sleep_threshold: int = 120,
        max_concurrent_transmissions: int = 1,
        commands_source: Dict[str, dict] | None = None,
        scoped_commands: bool | None = None,
        i18n_bot_info: bool = False,
        scheduler: AsyncIOScheduler | BackgroundScheduler | None = None,
        **kwargs,
    ):
        """Create the client, loading missing settings from the config.

        ### Args:
            * name (`str`, *optional*): Client name passed to Pyrogram's `Client`. Defaults to `"bot_client"`.
            * owner (`int | None`, *optional*): Owner's ID. Defaults to `config["bot"]["owner"]`.
            * config (`Dict[str, Any] | None`, *optional*): Config dict. Loaded from `config_path` when `None`.
            * config_path (`str | Path`, *optional*): JSON config file read when `config` is `None`.
            * api_id (`int | None`, *optional*): Defaults to `config["bot"]["api_id"]`.
            * api_hash (`str | None`, *optional*): Defaults to `config["bot"]["api_hash"]`.
            * bot_token (`str | None`, *optional*): Defaults to `config["bot"]["bot_token"]`.
            * workers (`int`, *optional*): Used only when `config["bot"]` has no `"workers"` key.
            * locales_root (`str | Path | None`, *optional*): Locales directory. Defaults to `Path("locale")`.
            * plugins_root (`str`, *optional*): Plugins root passed to Pyrogram. Defaults to `"plugins"`.
            * plugins_exclude (`List[str] | None`, *optional*): Defaults to `config["disabled_plugins"]`.
            * sleep_threshold (`int`, *optional*): Passed to Pyrogram's `Client`. Defaults to `120`.
            * max_concurrent_transmissions (`int`, *optional*): Used only when `config["bot"]`
              has no `"max_concurrent_transmissions"` key.
            * commands_source (`Dict[str, dict] | None`, *optional*): Defaults to `config["commands"]`.
            * scoped_commands (`bool | None`, *optional*): Defaults to `config["bot"]["scoped_commands"]`.
            * i18n_bot_info (`bool`, *optional*): Set the bot's name/about/description per locale on start.
            * scheduler (`AsyncIOScheduler | BackgroundScheduler | None`, *optional*): Scheduler used to
              register commands after start; started in `start()` and shut down in `stop()`.
            * **kwargs: Passed to Pyrogram's `Client`.
        """
        if config is None:
            with open(config_path, "r", encoding="utf-8") as f:
                self.config: dict = loads(f.read())
        else:
            self.config = config

        bot_config: Dict[str, Any] = self.config["bot"]

        super().__init__(
            name=name,
            api_id=bot_config["api_id"] if api_id is None else api_id,
            api_hash=bot_config["api_hash"] if api_hash is None else api_hash,
            bot_token=bot_config["bot_token"] if bot_token is None else bot_token,
            # Workers should be `min(32, cpu_count() + 4)`, otherwise
            # handlers land in another event loop and you won't see them
            workers=bot_config.get("workers", workers),
            plugins=dict(
                root=plugins_root,
                exclude=self.config["disabled_plugins"] if plugins_exclude is None else plugins_exclude,
            ),
            sleep_threshold=sleep_threshold,
            max_concurrent_transmissions=bot_config.get(
                "max_concurrent_transmissions", max_concurrent_transmissions
            ),
            **kwargs,
        )

        self.owner: int = bot_config["owner"] if owner is None else owner
        self.commands: List[PyroCommand] = []
        self.commands_source: Dict[str, dict] = (
            self.config["commands"] if commands_source is None else commands_source
        )
        self.scoped_commands: bool = (
            bot_config["scoped_commands"] if scoped_commands is None else scoped_commands
        )
        self.start_time: float = 0

        self.bot_locale: BotLocale = BotLocale(
            default_locale=self.config["locale"],
            locales_root=(Path("locale") if locales_root is None else locales_root),
        )
        self.default_locale: str = self.bot_locale.default
        self.locales: dict = self.bot_locale.locales
        self._ = self.bot_locale._
        self.in_all_locales = self.bot_locale.in_all_locales
        self.in_every_locale = self.bot_locale.in_every_locale

        self.scheduler: AsyncIOScheduler | BackgroundScheduler | None = scheduler

        # Placeholder values allowed as `chat_id` in command scopes, replaced by real IDs
        self.scopes_placeholders: Dict[str, int] = {"owner": self.owner}
        self.i18n_bot_info: bool = i18n_bot_info

    def _get_report_chat_id(self) -> int | str:
        """Resolve the chat that startup/shutdown reports are sent to.

        ### Returns:
            * `int | str`: The owner's ID when `config["reports"]["chat_id"]` is `"owner"`,
              otherwise the configured value itself.
        """
        chat_id = self.config["reports"]["chat_id"]
        return self.owner if chat_id == "owner" else chat_id

    async def _register_bot_info(self) -> None:
        """Set the bot's name, about and description for the default locale and each locale's codes.

        Locales with missing keys are logged and skipped instead of aborting the startup.
        """
        # Register default bot's info
        try:
            await self.set_bot_info(
                name=self._("name", "bot"),
                about=self._("about", "bot"),
                description=self._("description", "bot"),
                lang_code="",
            )
            logger.info(
                "Bot's info for the default locale %s has been updated",
                self.default_locale,
            )
        except KeyError:
            logger.warning(
                "Default locale %s has incorrect keys or values in bot section",
                self.default_locale,
            )

        # Register bot's info for each available locale
        for locale_code, locale in self.locales.items():
            if "metadata" not in locale or "codes" not in locale["metadata"]:
                logger.warning(
                    "Locale %s is missing metadata or metadata.codes key",
                    locale_code,
                )
                continue

            for code in locale["metadata"]["codes"]:
                try:
                    await self.set_bot_info(
                        name=locale["bot"]["name"],
                        about=locale["bot"]["about"],
                        description=locale["bot"]["description"],
                        lang_code=code,
                    )
                    logger.info(
                        "Bot's info for the locale %s has been updated",
                        code,
                    )
                except KeyError:
                    logger.warning(
                        "Locale %s has incorrect keys or values in bot section",
                        locale_code,
                    )

    @override
    async def start(self, register_commands: bool = True, scheduler_start: bool = True) -> None:
        """Start the client, report the startup and set up the scheduler.

        ### Args:
            * register_commands (`bool`, *optional*): Schedule a job that registers the collected
              commands 5 seconds after startup. Only has an effect when a scheduler is set.
              Defaults to `True`.
            * scheduler_start (`bool`, *optional*): Start the scheduler, if one is set. Defaults to `True`.
        """
        await super().start()

        self.start_time = time()

        logger.info(
            "Bot is running with Pyrogram v%s (Layer %s) and has started as @%s on PID %s.",
            pyrogram.__version__,
            layer,
            self.me.username,
            getpid(),
        )

        if self.i18n_bot_info:
            await self._register_bot_info()

        # Send a message to the bot's reports chat about the startup
        try:
            await self.send_message(
                chat_id=self._get_report_chat_id(),
                text=f"Bot started PID `{getpid()}`",
            )
        except BadRequest:
            logger.warning("Unable to send message to report chat.")

        if self.scheduler is None:
            return

        # Schedule the task to register all commands
        if register_commands:
            self.scheduler.add_job(
                self.register_commands,
                trigger="date",
                run_date=datetime.now() + timedelta(seconds=5),
                kwargs={"command_sets": await self.collect_commands()},
            )

        if scheduler_start:
            self.scheduler.start()

    @override
    async def stop(
        self, exit_completely: bool = True, scheduler_shutdown: bool = True, scheduler_wait: bool = True
    ) -> None:
        """Report the shutdown, stop the scheduler and the client, and optionally exit.

        ### Args:
            * exit_completely (`bool`, *optional*): Raise `SystemExit` after stopping. Defaults to `True`.
            * scheduler_shutdown (`bool`, *optional*): Shut the scheduler down, if one is set.
              Defaults to `True`.
            * scheduler_wait (`bool`, *optional*): Passed to the scheduler's `shutdown()`. Defaults to `True`.

        ### Raises:
            * `SystemExit`: When `exit_completely` is `True`.
        """
        try:
            await self.send_message(
                chat_id=self._get_report_chat_id(),
                text=f"Bot stopped with PID `{getpid()}`",
            )
            # Short pause after the report is sent, before the client is stopped
            await asyncio.sleep(0.5)
        except BadRequest:
            logger.warning("Unable to send message to report chat.")

        if self.scheduler is not None and scheduler_shutdown:
            self.scheduler.shutdown(scheduler_wait)

        await super().stop()

        logger.warning("Bot stopped with PID %s.", getpid())

        if exit_completely:
            try:
                sys.exit()
            except SystemExit as exc:
                raise SystemExit("Bot has been shut down, this is not an application error!") from exc

    async def collect_commands(self) -> List[CommandSet] | None:
        """Gather list of the bot's commands

        With scoped commands enabled, command sets are built from `commands_source` (one per scope
        and language). Otherwise the commands found in the message handlers' filters are added to
        `self.commands` and `None` is returned.

        ### Returns:
            * `List[CommandSet] | None`: List of the commands' sets, or `None` when scoped commands are disabled.
        """
        # If config's bot.scoped_commands is true - more complicated
        # scopes system will be used instead of simple global commands
        if not self.scoped_commands:
            self._collect_handler_commands()
            return None

        command_sets = self._build_command_sets()

        logger.info("Registering the following command sets: %s", command_sets)

        return command_sets

    def _build_command_sets(self) -> List[CommandSet]:
        """Build command sets for every scope and language defined in `commands_source`."""
        # Only these scope classes may be referenced by name in the commands' config
        scope_classes = {
            scope_class.__name__: scope_class
            for scope_class in (
                BotCommandScopeAllChatAdministrators,
                BotCommandScopeAllGroupChats,
                BotCommandScopeAllPrivateChats,
                BotCommandScopeChat,
                BotCommandScopeChatAdministrators,
                BotCommandScopeChatMember,
                BotCommandScopeDefault,
            )
        }

        # Scopes are keyed by their JSON dump so that scope dicts can serve as keys; each one maps
        # "_" (default language) and locale codes to the list of commands in that language
        scopes: Dict[str, Dict[str, List[BotCommand]]] = defaultdict(lambda: defaultdict(list))

        for command, contents in self.commands_source.items():
            for scope in contents["scopes"]:
                scope_key = dumps(scope)
                scopes[scope_key]["_"].append(BotCommand(command, self._(command, "commands")))

                for locale, string in self.in_every_locale(command, "commands").items():
                    scopes[scope_key][locale].append(BotCommand(command, string))

        command_sets: List[CommandSet] = []

        for scope_key, locales in scopes.items():
            scope_dict = loads(scope_key)

            # Replace placeholders like "owner" in the scope's chat_id with the actual IDs
            for placeholder, chat_id in self.scopes_placeholders.items():
                if scope_dict.get("chat_id") == placeholder:
                    scope_dict["chat_id"] = chat_id

            scope_name = scope_dict.get("name")
            scope_class = scope_classes.get(scope_name)

            if scope_class is None:
                logger.error(
                    "Could not register commands of the scope '%s' due to an invalid scope class provided!",
                    scope_name,
                )
                continue

            # Create object of the scope class with the remaining keys as its arguments
            try:
                scope_obj = scope_class(**{key: value for key, value in scope_dict.items() if key != "name"})
            except TypeError:
                logger.error(
                    "Could not register commands of the scope '%s' due to an invalid class arguments provided!",
                    scope_name,
                )
                continue

            # "_" holds the default language's commands, registered with an empty language code
            for locale, commands in locales.items():
                command_sets.append(
                    CommandSet(commands, scope=scope_obj, language_code="" if locale == "_" else locale)
                )

        return command_sets

    def _collect_handler_commands(self) -> None:
        """Add commands found in the filters of group 0 message handlers to `self.commands`."""
        for handler in self.dispatcher.groups.get(0, []):
            if not isinstance(handler, MessageHandler):
                continue

            handler_filters = handler.filters

            # Check the handler's filter itself and both operands of a combined filter (e.g. `a & b`);
            # getattr() keeps single filters and `filters=None` from raising AttributeError
            for entry in (
                handler_filters,
                getattr(handler_filters, "base", None),
                getattr(handler_filters, "other", None),
            ):
                for command in getattr(entry, "commands", ()):
                    logger.info("I see a command %s in my filters", command)
                    self.add_command(command)

    def add_command(
        self,
        command: str,
    ) -> None:
        """Add command to the bot's internal commands list

        ### Args:
            * command (`str`): Command name. Its description is looked up with the bot's
              locale in the "commands" section.
        """
        self.commands.append(
            PyroCommand(
                command,
                self._(command, "commands"),
            )
        )

        logger.info(
            "Added command '%s' to the bot's internal commands list",
            command,
        )

    async def register_commands(self, command_sets: List[CommandSet] | None = None) -> None:
        """Register commands stored in bot's 'commands' attribute

        ### Args:
            * command_sets (`List[CommandSet] | None`, *optional*): Command sets to register.
              When `None`, `self.commands` is registered with the default scope.
        """
        if command_sets is None:
            commands = [
                BotCommand(command=command.command, description=command.description)
                for command in self.commands
            ]

            logger.info("Registering commands %s with a default scope 'BotCommandScopeDefault'", commands)

            await self.set_bot_commands(commands)
            return

        for command_set in command_sets:
            logger.info(
                "Registering command set with commands %s and scope '%s' (%s)",
                command_set.commands,
                command_set.scope,
                command_set.language_code,
            )

            await self.set_bot_commands(
                command_set.commands,
                command_set.scope,
                language_code=command_set.language_code,
            )

    async def remove_commands(self, command_sets: List[CommandSet] | None = None) -> None:
        """Remove commands stored in bot's 'commands' attribute

        ### Args:
            * command_sets (`List[CommandSet] | None`, *optional*): Command sets whose scopes and
              languages are cleared. When `None`, commands of the default scope are removed.
        """
        if command_sets is None:
            logger.info("Removing commands with a default scope 'BotCommandScopeDefault'")

            await self.delete_bot_commands(BotCommandScopeDefault())
            return

        for command_set in command_sets:
            logger.info(
                "Removing command set with scope '%s' (%s)",
                command_set.scope,
                command_set.language_code,
            )

            await self.delete_bot_commands(
                command_set.scope,
                language_code=command_set.language_code,
            )