import contextlib
import logging
import re
from datetime import datetime, timedelta
from io import BytesIO
from os import getpid, makedirs, remove, sep
from pathlib import Path
from shutil import rmtree
from time import time
from traceback import format_exc
from typing import List, Tuple, Union

import pyrogram
from aiohttp import ClientSession
from bson import ObjectId
from dateutil.relativedelta import relativedelta
from libbot import json_read, json_write
from libbot.i18n import BotLocale
from libbot.i18n.sync import _
from photosapi_client.errors import UnexpectedStatus
from pyrogram.client import Client
from pyrogram.errors import BadRequest, bad_request_400
from pyrogram.handlers.message_handler import MessageHandler
from pyrogram.raw.all import layer

# NOTE: the BotCommandScope* classes look unused, but collect_commands()
# resolves them by name through globals(), so they must stay imported.
from pyrogram.types import (
    BotCommand,
    BotCommandScopeAllChatAdministrators,
    BotCommandScopeAllGroupChats,
    BotCommandScopeAllPrivateChats,
    BotCommandScopeChat,
    BotCommandScopeChatAdministrators,
    BotCommandScopeChatMember,
    BotCommandScopeDefault,
    Message,
)
from pytimeparse.timeparse import timeparse
from ujson import dumps, loads

from classes.commandset import CommandSet
from classes.exceptions import (
    SubmissionDuplicatesError,
    SubmissionUnavailableError,
    SubmissionUnsupportedError,
)
from classes.pyrocommand import PyroCommand
from modules.api_client import (
    BodyPhotoUploadAlbumsAlbumPhotosPost,
    File,
    client,
    photo_upload,
)
from modules.database import col_submitted
from modules.http_client import http_session
from modules.scheduler import scheduler
from modules.sender import send_content

logger = logging.getLogger(__name__)


def _version_tuple(version: str) -> Tuple[int, ...]:
    """Convert a dotted version string into a tuple of ints for comparison.

    Each dot-separated part contributes its leading digits; parsing stops at
    the first part without leading digits. E.g. ``"0.10.1"`` -> ``(0, 10, 1)``
    and ``"0.3-rc1"`` -> ``(0, 3)``. Comparing tuples avoids the pitfalls of
    ``float()`` (``"0.10" < "0.2"``, ``"0.2.1"`` not parseable).
    """
    numbers: List[int] = []
    for part in version.split("."):
        match = re.match(r"\d+", part)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)


class PyroClient(Client):
    """Pyrogram client that loads ``config.json`` and manages bot lifecycle,
    command registration and photo submissions."""

    def __init__(self):
        # Configuration is read synchronously from the working directory.
        with open("config.json", "r", encoding="utf-8") as f:
            self.config: dict = loads(f.read())

        super().__init__(
            name="bot_client",
            api_id=self.config["bot"]["api_id"],
            api_hash=self.config["bot"]["api_hash"],
            bot_token=self.config["bot"]["bot_token"],
            plugins=dict(root="plugins", exclude=self.config["disabled_plugins"]),
            sleep_threshold=120,
            max_concurrent_transmissions=self.config["bot"][
                "max_concurrent_transmissions"
            ],
        )

        self.version: float = 0.2

        self.owner: int = self.config["bot"]["owner"]
        # The owner is always treated as an admin as well.
        self.admins: List[int] = self.config["bot"]["admins"] + [
            self.config["bot"]["owner"]
        ]

        self.commands: List[PyroCommand] = []
        self.scoped_commands: bool = self.config["bot"]["scoped_commands"]
        self.start_time: float = 0

        self.bot_locale: BotLocale = BotLocale(Path(self.config["locations"]["locale"]))
        self.default_locale: str = self.bot_locale.default
        self.locales: dict = self.bot_locale.locales

        self._ = self.bot_locale._
        self.in_all_locales = self.bot_locale.in_all_locales
        self.in_every_locale = self.bot_locale.in_every_locale

        # Session handed to send_content() by the posting jobs; closed in stop().
        self.sender_session = ClientSession()

    async def start(self):
        """Start the client, report startup and schedule background jobs.

        Sends a startup report to the reports chat, optionally checks for a
        newer bot release, schedules command registration and, when posting
        mode is enabled, schedules ``send_content`` jobs before starting the
        scheduler.
        """
        await super().start()

        self.start_time = time()

        logger.info(
            "Bot is running with Pyrogram v%s (Layer %s) and has started as @%s on PID %s.",
            pyrogram.__version__,
            layer,
            self.me.username,
            getpid(),
        )

        startup_message = await self._build_startup_message()

        # Reporting is best-effort: an unreachable report chat must not stop
        # commands registration and posting jobs from being scheduled below.
        try:
            await self.send_message(
                chat_id=self.config["reports"]["chat_id"],
                text=startup_message,
            )
        except BadRequest:
            logger.warning("Unable to send message to report chat.")

        if self.config["reports"]["update"]:
            await self._check_for_updates()

        # Commands are registered a few seconds after startup.
        scheduler.add_job(
            self.register_commands,
            trigger="date",
            run_date=datetime.now() + timedelta(seconds=5),
            kwargs={"command_sets": await self.collect_commands()},
        )

        if self.config["mode"]["post"]:
            if self.config["posting"]["use_interval"]:
                scheduler.add_job(
                    send_content,
                    "interval",
                    seconds=timeparse(self.config["posting"]["interval"]),
                    args=[self, self.sender_session],
                )
            else:
                # Each "HH:MM" entry becomes a daily cron job.
                for entry in self.config["posting"]["time"]:
                    dt_obj = datetime.strptime(entry, "%H:%M")
                    scheduler.add_job(
                        send_content,
                        "cron",
                        hour=dt_obj.hour,
                        minute=dt_obj.minute,
                        args=[self, self.sender_session],
                    )

        scheduler.start()

    async def _build_startup_message(self) -> str:
        """Build the startup report text, mentioning downtime when known.

        Downtime is measured from the timestamp that stop() writes into
        ``<cache>/shutdown_time``.
        """
        shutdown_file = Path(f"{self.config['locations']['cache']}/shutdown_time")

        if not shutdown_file.exists():
            return self._("startup", "message").format(getpid())

        shutdown_timestamp = (await json_read(shutdown_file))["timestamp"]

        # A plain timedelta keeps the TOTAL number of days; relativedelta would
        # fold long downtimes into months/years and under-report them.
        downtime = datetime.now() - datetime.fromtimestamp(shutdown_timestamp)
        downtime_hours, remainder = divmod(downtime.seconds, 3600)

        if downtime.days >= 1:
            return self._(
                "startup_downtime_days",
                "message",
            ).format(getpid(), downtime.days)

        if downtime_hours >= 1:
            return self._(
                "startup_downtime_hours",
                "message",
            ).format(getpid(), downtime_hours)

        return self._(
            "startup_downtime_minutes",
            "message",
        ).format(getpid(), remainder // 60)

    async def _check_for_updates(self) -> None:
        """Notify the owner when a newer release is published on the git server.

        Any failure is logged and swallowed, so startup is never interrupted.
        """
        try:
            async with ClientSession(
                json_serialize=dumps,
            ) as session:
                check_update = await session.get(
                    "https://git.end-play.xyz/api/v1/repos/profitroll/TelegramPoster/releases?page=1&limit=1"
                )
                response = await check_update.json()

            if len(response) == 0:
                raise ValueError("No bot releases on git found.")

            latest = response[0]
            latest_version = latest["tag_name"].replace("v", "")

            if _version_tuple(latest_version) > _version_tuple(str(self.version)):
                logger.info(
                    "New version %s found (current %s)",
                    latest_version,
                    self.version,
                )
                await self.send_message(
                    self.owner,
                    self._(
                        "update_available",
                        "message",
                    ).format(
                        latest["tag_name"],
                        latest["html_url"],
                        latest["body"],
                    ),
                )
            else:
                logger.info("No updates found, bot is up to date.")
        except bad_request_400.PeerIdInvalid:
            logger.warning(
                "Could not send startup message to bot owner. Perhaps user has not started the bot yet."
            )
        except Exception as exp:
            logger.exception("Update check failed due to %s: %s", exp, format_exc())

    async def stop(self):
        """Persist the shutdown timestamp, report the stop and close sessions."""
        makedirs(self.config["locations"]["cache"], exist_ok=True)
        # Read back by the next start() to compute downtime.
        await json_write(
            {"timestamp": time()},
            Path(f"{self.config['locations']['cache']}/shutdown_time"),
        )

        try:
            await self.send_message(
                chat_id=self.config["reports"]["chat_id"],
                text=f"Bot stopped with PID `{getpid()}`",
            )
        except BadRequest:
            logger.warning("Unable to send message to report chat.")

        await http_session.close()
        await self.sender_session.close()

        await super().stop()
        logger.warning("Bot stopped with PID %s.", getpid())

    async def collect_commands(self) -> Union[List[CommandSet], None]:
        """Gather list of the bot's commands

        ### Returns:
            * `List[CommandSet]`: List of the commands' sets when
              `bot.scoped_commands` is enabled, otherwise `None` (in that case
              commands found in handlers are appended to `self.commands`)
        """

        command_sets = None

        # If config get bot.scoped_commands is true - more complicated
        # scopes system will be used instead of simple global commands
        if self.scoped_commands:
            # Maps a JSON-dumped scope (dicts are unhashable) to
            # {locale: [BotCommand, ...]}; "_" holds the locale-less set.
            scopes = {}
            command_sets = []

            for command, contents in self.config["commands"].items():
                for scope in contents["scopes"]:
                    scope_key = dumps(scope)
                    scope_commands = scopes.setdefault(scope_key, {"_": []})

                    scope_commands["_"].append(
                        BotCommand(command, _(command, "commands"))
                    )

                    for locale, string in (
                        self.in_every_locale(command, "commands")
                    ).items():
                        scope_commands.setdefault(locale, []).append(
                            BotCommand(command, string)
                        )

            for scope, locales in scopes.items():
                # Make flat key a dict again
                scope_dict = loads(scope)
                scope_name = scope_dict.get("name")

                # Replace "owner" in the bot scope with owner's id
                if scope_dict.get("chat_id") == "owner":
                    scope_dict["chat_id"] = self.owner

                # Create object with the same name and args from the dict.
                # globals() lookup raises KeyError (not NameError) for an
                # unknown or missing class name.
                try:
                    scope_obj = globals()[scope_name](
                        **{
                            key: value
                            for key, value in scope_dict.items()
                            if key != "name"
                        }
                    )
                except (KeyError, NameError):
                    logger.error(
                        "Could not register commands of the scope '%s' due to an invalid scope class provided!",
                        scope_name,
                    )
                    continue
                except TypeError:
                    logger.error(
                        "Could not register commands of the scope '%s' due to an invalid class arguments provided!",
                        scope_name,
                    )
                    continue

                # "_" is the default (no language) set and maps to language_code=""
                for locale, commands in locales.items():
                    command_sets.append(
                        CommandSet(
                            commands,
                            scope=scope_obj,
                            language_code="" if locale == "_" else locale,
                        )
                    )

            logger.info("Registering the following command sets: %s", command_sets)

        else:
            # This part here looks into the handlers and looks for commands
            # in it, if there are any. Then adds them to self.commands
            for handler in self.dispatcher.groups.get(0, []):
                if not isinstance(handler, MessageHandler):
                    continue

                # Only combined filters expose base/other; plain filters don't.
                for entry in (
                    getattr(handler.filters, "base", None),
                    getattr(handler.filters, "other", None),
                ):
                    for command in getattr(entry, "commands", ()):
                        logger.info("I see a command %s in my filters", command)
                        self.add_command(command)

        return command_sets

    def add_command(
        self,
        command: str,
    ):
        """Add command to the bot's internal commands list

        ### Args:
            * command (`str`): Command name; its description is looked up
              in the "commands" locale section
        """
        self.commands.append(
            PyroCommand(
                command,
                _(command, "commands"),
            )
        )
        logger.info(
            "Added command '%s' to the bot's internal commands list",
            command,
        )

    async def register_commands(
        self, command_sets: Union[List[CommandSet], None] = None
    ):
        """Register commands stored in bot's 'commands' attribute

        ### Args:
            * command_sets (`Union[List[CommandSet], None]`): Scoped sets to
              register; when `None`, `self.commands` is registered with the
              default scope
        """

        if command_sets is None:
            commands = [
                BotCommand(command=command.command, description=command.description)
                for command in self.commands
            ]

            logger.info(
                "Registering commands %s with a default scope 'BotCommandScopeDefault'",
                commands,
            )

            await self.set_bot_commands(commands)
            return

        for command_set in command_sets:
            logger.info(
                "Registering command set with commands %s and scope '%s' (%s)",
                command_set.commands,
                command_set.scope,
                command_set.language_code,
            )
            await self.set_bot_commands(
                command_set.commands,
                command_set.scope,
                language_code=command_set.language_code,
            )

    async def remove_commands(self, command_sets: Union[List[CommandSet], None] = None):
        """Remove commands stored in bot's 'commands' attribute

        ### Args:
            * command_sets (`Union[List[CommandSet], None]`): Scoped sets to
              remove; when `None`, the default scope's commands are removed
        """

        if command_sets is None:
            logger.info(
                "Removing commands with a default scope 'BotCommandScopeDefault'"
            )
            await self.delete_bot_commands(BotCommandScopeDefault())
            return

        for command_set in command_sets:
            logger.info(
                "Removing command set with scope '%s' (%s)",
                command_set.scope,
                command_set.language_code,
            )
            await self.delete_bot_commands(
                command_set.scope,
                language_code=command_set.language_code,
            )

    async def submit_photo(
        self, id: str
    ) -> Tuple[Union[Message, None], Union[str, None]]:
        """Upload a submitted photo to the photos API and mark it as done.

        ### Args:
            * id (`str`): Hex ObjectId of the entry in `col_submitted`

        ### Returns:
            * `Tuple[Union[Message, None], Union[str, None]]`: The original
              submission message (or `None` if it could not be fetched) and
              the id of the uploaded photo from the API response

        ### Raises:
            * `SubmissionUnavailableError`: Entry or its file cannot be found
            * `SubmissionUnsupportedError`: API rejected the upload
            * `SubmissionDuplicatesError`: API reported duplicates
        """
        db_entry = col_submitted.find_one({"_id": ObjectId(id)})
        submission = None

        if db_entry is None:
            raise SubmissionUnavailableError()

        if db_entry["temp"]["uuid"] is None:
            # No stored copy: download the media from the original message.
            try:
                submission = await self.get_messages(
                    db_entry["user"], db_entry["telegram"]["msg_id"]
                )
                downloaded = await self.download_media(
                    submission, file_name=self.config["locations"]["tmp"] + sep
                )
            except Exception as exp:
                raise SubmissionUnavailableError() from exp

            # download_media returns a str path (or None); a Path is needed
            # below for `.name`.
            if downloaded is None:
                raise SubmissionUnavailableError()
            filepath = Path(downloaded)
        else:
            filepath = Path(
                f"{self.config['locations']['data']}/submissions/{db_entry['temp']['uuid']}/{db_entry['temp']['file']}",
            )

            if not filepath.exists():
                raise SubmissionUnavailableError()

            # The message is optional here: the file is already stored locally.
            with contextlib.suppress(Exception):
                submission = await self.get_messages(
                    db_entry["user"], db_entry["telegram"]["msg_id"]
                )

        with open(str(filepath), "rb") as fh:
            photo_bytes = BytesIO(fh.read())

        try:
            response = await photo_upload(
                self.config["posting"]["api"]["album"],
                client=client,
                multipart_data=BodyPhotoUploadAlbumsAlbumPhotosPost(
                    File(photo_bytes, filepath.name, "image/jpeg")
                ),
                ignore_duplicates=self.config["submission"]["allow_duplicates"],
                compress=False,
                caption="queue",
            )
        except UnexpectedStatus as exp:
            raise SubmissionUnsupportedError(str(filepath)) from exp

        response_dict = loads(response.content.decode("utf-8"))

        if "duplicates" in response_dict and len(response_dict["duplicates"]) > 0:
            duplicates = []
            for index, duplicate in enumerate(response_dict["duplicates"]):  # type: ignore
                if response_dict["access_token"] is None:
                    duplicates.append(
                        f"`{duplicate['id']}`:\n{self.config['posting']['api']['address_external']}/photos/{duplicate['id']}"
                    )
                else:
                    duplicates.append(
                        f"`{duplicate['id']}`:\n{self.config['posting']['api']['address_external']}/token/photo/{response_dict['access_token']}?id={index}"
                    )
            raise SubmissionDuplicatesError(str(filepath), duplicates)

        col_submitted.find_one_and_update(
            {"_id": ObjectId(id)}, {"$set": {"done": True}}
        )

        # Clean up the stored submission directory or the downloaded temp file.
        try:
            if db_entry["temp"]["uuid"] is not None:
                rmtree(
                    Path(
                        f"{self.config['locations']['data']}/submissions/{db_entry['temp']['uuid']}",
                    ),
                    ignore_errors=True,
                )
            else:
                remove(str(filepath))
        except (FileNotFoundError, NotADirectoryError):
            logger.error("Could not delete '%s' on submission accepted", filepath)

        return submission, response.parsed.id

    async def ban_user(self, id: int) -> None:
        """Ban a user. Not implemented yet: currently does nothing."""
        pass

    async def unban_user(self, id: int) -> None:
        """Unban a user. Not implemented yet: currently does nothing."""
        pass