Fixed permissions check

This commit is contained in:
Profitroll 2023-08-23 11:12:26 +02:00
parent 8eef171391
commit e61516f51f
Signed by: profitroll
GPG Key ID: FA35CAB49DACD3B2
10 changed files with 66 additions and 58 deletions

View File

@ -37,7 +37,7 @@ class PyroGroup:
@classmethod @classmethod
async def find( async def find(
cls, cls,
id: int, id: Union[int, float],
locale: Union[str, None] = sync.config_get("locale", "defaults", "group"), locale: Union[str, None] = sync.config_get("locale", "defaults", "group"),
locale_auto: bool = sync.config_get("locale_auto", "defaults", "group"), locale_auto: bool = sync.config_get("locale_auto", "defaults", "group"),
ban_failed: bool = sync.config_get("ban_failed", "defaults", "group"), ban_failed: bool = sync.config_get("ban_failed", "defaults", "group"),
@ -53,7 +53,7 @@ class PyroGroup:
if db_entry is None: if db_entry is None:
inserted = await col_groups.insert_one( inserted = await col_groups.insert_one(
{ {
"id": id, "id": int(id),
"locale": locale, "locale": locale,
"locale_auto": locale_auto, "locale_auto": locale_auto,
"ban_failed": ban_failed, "ban_failed": ban_failed,

View File

@ -1,6 +1,6 @@
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List, Union
from bson import ObjectId from bson import ObjectId
@ -36,7 +36,7 @@ class PyroUser:
@classmethod @classmethod
async def find( async def find(
cls, cls,
id: int, id: Union[int, float],
group: int, group: int,
failed: bool = False, failed: bool = False,
emojis: List[str] = [], emojis: List[str] = [],
@ -49,7 +49,7 @@ class PyroUser:
if db_entry is None: if db_entry is None:
inserted = await col_users.insert_one( inserted = await col_users.insert_one(
{ {
"id": id, "id": int(id),
"group": group, "group": group,
"failed": failed, "failed": failed,
"emojis": emojis, "emojis": emojis,

View File

@ -1,12 +1,16 @@
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from random import randint, sample from random import randint, sample
from typing import List from typing import List, Union
from huepaper import generate from huepaper import generate
from PIL import Image from PIL import Image
from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import CallbackQuery, Message
from classes.captcha import Captcha from classes.captcha import Captcha
from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
def get_captcha_image(emojis: List[str]) -> Captcha: def get_captcha_image(emojis: List[str]) -> Captcha:
@ -49,3 +53,40 @@ def get_captcha_image(emojis: List[str]) -> Captcha:
base_img.save(output, format="jpeg") base_img.save(output, format="jpeg")
return Captcha(output, emojis_all, emojis_correct) return Captcha(output, emojis_all, emojis_correct)
async def is_permitted(
app: PyroClient,
group: PyroGroup,
message: Union[Message, None] = None,
callback: Union[CallbackQuery, None] = None,
) -> bool:
"""Check if User is an admin or a creator of a group. Alternatively, if the User is actually a group itself.
### Args:
* app (`PyroClient`): Pyrogram Client
* group (`PyroGroup`): Group
* message (`Union[Message, None]`, *optional*): Message if the request originates from a command. Defaults to `None`.
* callback (`Union[CallbackQuery, None]`, *optional*): CallbackQuery if the request originates from a callback. Defaults to `None`.
### Returns:
* `bool`: `True` if permitted and `False` if not. Also `False` if no message or callback provided.
"""
if message is not None:
return (
message.sender_chat is not None and message.sender_chat.id == group.id
) or (
message.from_user is not None
and (await app.get_chat_member(group.id, message.from_user.id)).status
) in [
ChatMemberStatus.ADMINISTRATOR,
ChatMemberStatus.OWNER,
]
if callback is not None:
return (await app.get_chat_member(group.id, callback.from_user.id)).status in [
ChatMemberStatus.ADMINISTRATOR,
ChatMemberStatus.OWNER,
]
return False

View File

@ -1,11 +1,11 @@
import logging import logging
from pyrogram import filters from pyrogram import filters
from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import CallbackQuery, Message from pyrogram.types import CallbackQuery, Message
from classes.callbacks import CallbackBan from classes.callbacks import CallbackBan
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from modules.utils import is_permitted
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,10 +15,7 @@ async def callback_ban(app: PyroClient, callback: CallbackQuery):
group = await app.find_group(callback.message.chat.id) group = await app.find_group(callback.message.chat.id)
locale = group.select_locale(app, callback.message.from_user) locale = group.select_locale(app, callback.message.from_user)
if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [ if not (await is_permitted(app, group, callback=callback)):
ChatMemberStatus.ADMINISTRATOR,
ChatMemberStatus.OWNER,
]:
await callback.answer( await callback.answer(
app._("wrong_user", "callbacks", locale=locale), show_alert=True app._("wrong_user", "callbacks", locale=locale), show_alert=True
) )

View File

@ -1,10 +1,10 @@
import logging import logging
from pyrogram import filters from pyrogram import filters
from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import Message from pyrogram.types import Message
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from modules.utils import is_permitted
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,12 +18,7 @@ async def command_ban_failed(app: PyroClient, message: Message):
group = await app.find_group(message.chat.id) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if (message.sender_chat is not None and (message.sender_chat.id != group.id)) or ( if not (await is_permitted(app, group, message=message)):
await app.get_chat_member(group.id, message.from_user.id)
).status not in [
ChatMemberStatus.ADMINISTRATOR,
ChatMemberStatus.OWNER,
]:
await message.reply_text( await message.reply_text(
app._("permission_denied", "messages", locale=locale), quote=True app._("permission_denied", "messages", locale=locale), quote=True
) )

View File

@ -5,6 +5,7 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import Message from pyrogram.types import Message
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from modules.utils import is_permitted
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,12 +19,7 @@ async def command_language_auto(app: PyroClient, message: Message):
group = await app.find_group(message.chat.id) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if (message.sender_chat is not None and (message.sender_chat.id != group.id)) or ( if not (await is_permitted(app, group, message=message)):
await app.get_chat_member(group.id, message.from_user.id)
).status not in [
ChatMemberStatus.ADMINISTRATOR,
ChatMemberStatus.OWNER,
]:
await message.reply_text( await message.reply_text(
app._("permission_denied", "messages", locale=locale), quote=True app._("permission_denied", "messages", locale=locale), quote=True
) )

View File

@ -1,10 +1,10 @@
import logging import logging
from pyrogram import filters from pyrogram import filters
from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import Message from pyrogram.types import Message
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from modules.utils import is_permitted
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,12 +18,7 @@ async def command_timeout_join(app: PyroClient, message: Message):
group = await app.find_group(message.chat.id) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if (message.sender_chat is not None and (message.sender_chat.id != group.id)) or ( if not (await is_permitted(app, group, message=message)):
await app.get_chat_member(group.id, message.from_user.id)
).status not in [
ChatMemberStatus.ADMINISTRATOR,
ChatMemberStatus.OWNER,
]:
await message.reply_text( await message.reply_text(
app._("permission_denied", "messages", locale=locale), quote=True app._("permission_denied", "messages", locale=locale), quote=True
) )

View File

@ -1,10 +1,10 @@
import logging import logging
from pyrogram import filters from pyrogram import filters
from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import Message from pyrogram.types import Message
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from modules.utils import is_permitted
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,12 +18,7 @@ async def command_timeout_verify(app: PyroClient, message: Message):
group = await app.find_group(message.chat.id) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if (message.sender_chat is not None and (message.sender_chat.id != group.id)) or ( if not (await is_permitted(app, group, message=message)):
await app.get_chat_member(group.id, message.from_user.id)
).status not in [
ChatMemberStatus.ADMINISTRATOR,
ChatMemberStatus.OWNER,
]:
await message.reply_text( await message.reply_text(
app._("permission_denied", "messages", locale=locale), quote=True app._("permission_denied", "messages", locale=locale), quote=True
) )

View File

@ -103,5 +103,5 @@ async def handler_user_join(app: PyroClient, message: Message):
run_date=datetime.now() + timedelta(seconds=group.timeout_join), run_date=datetime.now() + timedelta(seconds=group.timeout_join),
) )
await col_schedule.insert_one( await col_schedule.insert_one(
{"user": user.id, "group": group.id, "job_id": job.id} {"user": int(user.id), "group": int(group.id), "job_id": job.id}
) )

View File

@ -3,11 +3,11 @@ import logging
from pykeyboard import InlineButton, InlineKeyboard from pykeyboard import InlineButton, InlineKeyboard
from pyrogram import filters from pyrogram import filters
from pyrogram.client import Client from pyrogram.client import Client
from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import CallbackQuery, Message from pyrogram.types import CallbackQuery, Message
from classes.callbacks import CallbackLanguage from classes.callbacks import CallbackLanguage
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from modules.utils import is_permitted
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,26 +19,18 @@ async def command_language(app: PyroClient, message: Message):
group = await app.find_group(message.chat.id) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if (message.sender_chat is not None and (message.sender_chat.id != group.id)) or ( if not (await is_permitted(app, group, message=message)):
await app.get_chat_member(group.id, message.from_user.id)
).status not in [
ChatMemberStatus.ADMINISTRATOR,
ChatMemberStatus.OWNER,
]:
await message.reply_text( await message.reply_text(
app._("permission_denied", "messages", locale=locale), quote=True app._("permission_denied", "messages", locale=locale), quote=True
) )
return return
keyboard = InlineKeyboard(row_width=2) keyboard = InlineKeyboard(row_width=2)
buttons = [] buttons = [
InlineButton(f"{data['flag']} {data['name']}", f"language:{language}")
for language, data in app.in_every_locale("metadata").items(): for language, data in app.in_every_locale("metadata").items()
if data["selectable"]: if data["selectable"]
buttons.append( ]
InlineButton(f"{data['flag']} {data['name']}", f"language:{language}")
)
buttons.append( buttons.append(
InlineButton( InlineButton(
f"🤖 {app._('locale_default', 'buttons', locale=locale)}", "language:default" f"🤖 {app._('locale_default', 'buttons', locale=locale)}", "language:default"
@ -58,10 +50,7 @@ async def callback_language(app: PyroClient, callback: CallbackQuery):
group = await app.find_group(callback.message.chat.id) group = await app.find_group(callback.message.chat.id)
locale = group.select_locale(app, callback.message.from_user) locale = group.select_locale(app, callback.message.from_user)
if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [ if not (await is_permitted(app, group, callback=callback)):
ChatMemberStatus.ADMINISTRATOR,
ChatMemberStatus.OWNER,
]:
await callback.answer( await callback.answer(
app._("wrong_user", "callbacks", locale=locale), show_alert=True app._("wrong_user", "callbacks", locale=locale), show_alert=True
) )