diff --git a/classes/pyrogroup.py b/classes/pyrogroup.py index 724b212..915ea9e 100644 --- a/classes/pyrogroup.py +++ b/classes/pyrogroup.py @@ -37,7 +37,7 @@ class PyroGroup: @classmethod async def find( cls, - id: int, + id: Union[int, float], locale: Union[str, None] = sync.config_get("locale", "defaults", "group"), locale_auto: bool = sync.config_get("locale_auto", "defaults", "group"), ban_failed: bool = sync.config_get("ban_failed", "defaults", "group"), @@ -53,7 +53,7 @@ class PyroGroup: if db_entry is None: inserted = await col_groups.insert_one( { - "id": id, + "id": int(id), "locale": locale, "locale_auto": locale_auto, "ban_failed": ban_failed, diff --git a/classes/pyrouser.py b/classes/pyrouser.py index 3b57004..19cae5b 100644 --- a/classes/pyrouser.py +++ b/classes/pyrouser.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import List +from typing import List, Union from bson import ObjectId @@ -36,7 +36,7 @@ class PyroUser: @classmethod async def find( cls, - id: int, + id: Union[int, float], group: int, failed: bool = False, emojis: List[str] = [], @@ -49,7 +49,7 @@ class PyroUser: if db_entry is None: inserted = await col_users.insert_one( { - "id": id, + "id": int(id), "group": group, "failed": failed, "emojis": emojis, diff --git a/modules/utils.py b/modules/utils.py index 645bd52..6bd9412 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -1,12 +1,16 @@ from io import BytesIO from pathlib import Path from random import randint, sample -from typing import List +from typing import List, Union from huepaper import generate from PIL import Image +from pyrogram.enums.chat_member_status import ChatMemberStatus +from pyrogram.types import CallbackQuery, Message from classes.captcha import Captcha +from classes.pyroclient import PyroClient +from classes.pyrogroup import PyroGroup def get_captcha_image(emojis: List[str]) -> Captcha: @@ -49,3 +53,40 @@ def get_captcha_image(emojis: List[str]) -> Captcha: base_img.save(output, format="jpeg") return Captcha(output, emojis_all, emojis_correct) + + +async def is_permitted( + app: PyroClient, + group: PyroGroup, + message: Union[Message, None] = None, + callback: Union[CallbackQuery, None] = None, +) -> bool: + """Check if User is an admin or a creator of a group. Alternatively, if the User is actually a group itself. + + ### Args: + * app (`PyroClient`): Pyrogram Client + * group (`PyroGroup`): Group + * message (`Union[Message, None]`, *optional*): Message if the request originates from a command. Defaults to `None`. + * callback (`Union[CallbackQuery, None]`, *optional*): CallbackQuery if the request originates from a callback. Defaults to `None`. + + ### Returns: + * `bool`: `True` if permitted and `False` if not. Also `False` if no message or callback provided. + """ + if message is not None: + return ( + message.sender_chat is not None and message.sender_chat.id == group.id + ) or ( + message.from_user is not None + and (await app.get_chat_member(group.id, message.from_user.id)).status + ) in [ + ChatMemberStatus.ADMINISTRATOR, + ChatMemberStatus.OWNER, + ] + + if callback is not None: + return (await app.get_chat_member(group.id, callback.from_user.id)).status in [ + ChatMemberStatus.ADMINISTRATOR, + ChatMemberStatus.OWNER, + ] + + return False diff --git a/plugins/callbacks/ban.py b/plugins/callbacks/ban.py index a368ad8..4b5bb35 100644 --- a/plugins/callbacks/ban.py +++ b/plugins/callbacks/ban.py @@ -1,11 +1,11 @@ import logging from pyrogram import filters -from pyrogram.enums.chat_member_status import ChatMemberStatus from pyrogram.types import CallbackQuery, Message from classes.callbacks import CallbackBan from classes.pyroclient import PyroClient +from modules.utils import is_permitted logger = logging.getLogger(__name__) @@ -15,10 +15,7 @@ async def callback_ban(app: PyroClient, callback: CallbackQuery): group = await app.find_group(callback.message.chat.id) locale = group.select_locale(app, callback.message.from_user) - if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [ - ChatMemberStatus.ADMINISTRATOR, - ChatMemberStatus.OWNER, - ]: + if not (await is_permitted(app, group, callback=callback)): await callback.answer( app._("wrong_user", "callbacks", locale=locale), show_alert=True ) diff --git a/plugins/commands/ban_failed.py b/plugins/commands/ban_failed.py index e6f04f6..3c24638 100644 --- a/plugins/commands/ban_failed.py +++ b/plugins/commands/ban_failed.py @@ -1,10 +1,10 @@ import logging from pyrogram import filters -from pyrogram.enums.chat_member_status import ChatMemberStatus from pyrogram.types import Message from classes.pyroclient import PyroClient +from modules.utils import is_permitted logger = logging.getLogger(__name__) @@ -18,12 +18,7 @@ async def command_ban_failed(app: PyroClient, message: Message): group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) - if (message.sender_chat is not None and (message.sender_chat.id != group.id)) or ( - await app.get_chat_member(group.id, message.from_user.id) - ).status not in [ - ChatMemberStatus.ADMINISTRATOR, - ChatMemberStatus.OWNER, - ]: + if not (await is_permitted(app, group, message=message)): await message.reply_text( app._("permission_denied", "messages", locale=locale), quote=True ) diff --git a/plugins/commands/language_auto.py b/plugins/commands/language_auto.py index 0496a70..ad7142a 100644 --- a/plugins/commands/language_auto.py +++ b/plugins/commands/language_auto.py @@ -5,6 +5,7 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus from pyrogram.types import Message from classes.pyroclient import PyroClient +from modules.utils import is_permitted logger = logging.getLogger(__name__) @@ -18,12 +19,7 @@ async def command_language_auto(app: PyroClient, message: Message): group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) - if (message.sender_chat is not None and (message.sender_chat.id != group.id)) or ( - await app.get_chat_member(group.id, message.from_user.id) - ).status not in [ - ChatMemberStatus.ADMINISTRATOR, - ChatMemberStatus.OWNER, - ]: + if not (await is_permitted(app, group, message=message)): await message.reply_text( app._("permission_denied", "messages", locale=locale), quote=True ) diff --git a/plugins/commands/timeout_join.py b/plugins/commands/timeout_join.py index 9d638a6..e852500 100644 --- a/plugins/commands/timeout_join.py +++ b/plugins/commands/timeout_join.py @@ -1,10 +1,10 @@ import logging from pyrogram import filters -from pyrogram.enums.chat_member_status import ChatMemberStatus from pyrogram.types import Message from classes.pyroclient import PyroClient +from modules.utils import is_permitted logger = logging.getLogger(__name__) @@ -18,12 +18,7 @@ async def command_timeout_join(app: PyroClient, message: Message): group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) - if (message.sender_chat is not None and (message.sender_chat.id != group.id)) or ( - await app.get_chat_member(group.id, message.from_user.id) - ).status not in [ - ChatMemberStatus.ADMINISTRATOR, - ChatMemberStatus.OWNER, - ]: + if not (await is_permitted(app, group, message=message)): await message.reply_text( app._("permission_denied", "messages", locale=locale), quote=True ) diff --git a/plugins/commands/timeout_verify.py b/plugins/commands/timeout_verify.py index a9a939f..23e2c92 100644 --- a/plugins/commands/timeout_verify.py +++ b/plugins/commands/timeout_verify.py @@ -1,10 +1,10 @@ import logging from pyrogram import filters -from pyrogram.enums.chat_member_status import ChatMemberStatus from pyrogram.types import Message from classes.pyroclient import PyroClient +from modules.utils import is_permitted logger = logging.getLogger(__name__) @@ -18,12 +18,7 @@ async def command_timeout_verify(app: PyroClient, message: Message): group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) - if (message.sender_chat is not None and (message.sender_chat.id != group.id)) or ( - await app.get_chat_member(group.id, message.from_user.id) - ).status not in [ - ChatMemberStatus.ADMINISTRATOR, - ChatMemberStatus.OWNER, - ]: + if not (await is_permitted(app, group, message=message)): await message.reply_text( app._("permission_denied", "messages", locale=locale), quote=True ) diff --git a/plugins/handlers/user_join.py b/plugins/handlers/user_join.py index 2885a0e..e08cd82 100644 --- a/plugins/handlers/user_join.py +++ b/plugins/handlers/user_join.py @@ -103,5 +103,5 @@ async def handler_user_join(app: PyroClient, message: Message): run_date=datetime.now() + timedelta(seconds=group.timeout_join), ) await col_schedule.insert_one( - {"user": user.id, "group": group.id, "job_id": job.id} + {"user": int(user.id), "group": int(group.id), "job_id": job.id} ) diff --git a/plugins/language.py b/plugins/language.py index 7651703..caa2f5e 100644 --- a/plugins/language.py +++ b/plugins/language.py @@ -3,11 +3,11 @@ import logging from pykeyboard import InlineButton, InlineKeyboard from pyrogram import filters from pyrogram.client import Client -from pyrogram.enums.chat_member_status import ChatMemberStatus from pyrogram.types import CallbackQuery, Message from classes.callbacks import CallbackLanguage from classes.pyroclient import PyroClient +from modules.utils import is_permitted logger = logging.getLogger(__name__) @@ -19,26 +19,18 @@ async def command_language(app: PyroClient, message: Message): group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) - if (message.sender_chat is not None and (message.sender_chat.id != group.id)) or ( - await app.get_chat_member(group.id, message.from_user.id) - ).status not in [ - ChatMemberStatus.ADMINISTRATOR, - ChatMemberStatus.OWNER, - ]: + if not (await is_permitted(app, group, message=message)): await message.reply_text( app._("permission_denied", "messages", locale=locale), quote=True ) return keyboard = InlineKeyboard(row_width=2) - buttons = [] - - for language, data in app.in_every_locale("metadata").items(): - if data["selectable"]: - buttons.append( - InlineButton(f"{data['flag']} {data['name']}", f"language:{language}") - ) - + buttons = [ + InlineButton(f"{data['flag']} {data['name']}", f"language:{language}") + for language, data in app.in_every_locale("metadata").items() + if data["selectable"] + ] buttons.append( InlineButton( f"🤖 {app._('locale_default', 'buttons', locale=locale)}", "language:default" @@ -58,10 +50,7 @@ async def callback_language(app: PyroClient, callback: CallbackQuery): group = await app.find_group(callback.message.chat.id) locale = group.select_locale(app, callback.message.from_user) - if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [ - ChatMemberStatus.ADMINISTRATOR, - ChatMemberStatus.OWNER, - ]: + if not (await is_permitted(app, group, callback=callback)): await callback.answer( app._("wrong_user", "callbacks", locale=locale), show_alert=True )