diff --git a/classes/pyroclient.py b/classes/pyroclient.py index 0496d03..8297ec8 100644 --- a/classes/pyroclient.py +++ b/classes/pyroclient.py @@ -1,30 +1,37 @@ from typing import Union -from libbot.pyrogram.classes import PyroClient +from libbot.pyrogram.classes import PyroClient as LibPyroClient from pyrogram.types import User +from classes.pyrogroup import PyroGroup from classes.pyrouser import PyroUser -from modules.database import col_users -class PyroClient(PyroClient): - async def find_user(self, user: Union[int, User], group: int) -> PyroUser: - """Find User by it's ID or User object +class PyroClient(LibPyroClient): + async def find_user(self, user: Union[int, User], *args, **kwargs) -> PyroUser: + """Find User by it's ID or User object. ### Args: * user (`Union[int, User]`): ID or User object to extract ID from - * group (`int`): ID of the group ### Returns: * `PyroUser`: PyroUser object """ - db_record = await col_users.find_one( - {"id": user.id if isinstance(user, User) else user, "group": group} + + return ( + await PyroUser.find(user, *args, **kwargs) + if isinstance(user, int) + else await PyroUser.find(user.id, *args, **kwargs) ) - if db_record is None: - raise KeyError( - f"User with ID {user.id if isinstance(user, User) else user} was not found in the database" - ) + async def find_group(self, id: int, *args, **kwargs) -> PyroGroup: + """Find Group by it's ID. - return PyroUser(**db_record) + ### Args: + * id (`int`): Group ID + + ### Returns: + * `PyroGroup`: PyroGroup object + """ + + return await PyroGroup.find(id, *args, **kwargs) diff --git a/classes/pyrogroup.py b/classes/pyrogroup.py index b72322b..792e1a5 100644 --- a/classes/pyrogroup.py +++ b/classes/pyrogroup.py @@ -35,7 +35,7 @@ class PyroGroup: timeout_verify: int @classmethod - async def create_if_not_exists( + async def find( cls, id: int, locale: Union[str, None] = sync.config_get("locale", "defaults", "group"), diff --git a/classes/pyrouser.py b/classes/pyrouser.py index 0293371..3b57004 100644 --- a/classes/pyrouser.py +++ b/classes/pyrouser.py @@ -34,7 +34,7 @@ class PyroUser: mistakes: int @classmethod - async def create_if_not_exists( + async def find( cls, id: int, group: int, diff --git a/modules/kicker.py b/modules/kicker.py index 2f03874..e5aeeb6 100644 --- a/modules/kicker.py +++ b/modules/kicker.py @@ -1,14 +1,13 @@ from pyrogram.types import Message from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup async def kick_unstarted( app: PyroClient, user_id: int, group_id: int, message_id: int ) -> None: user = await app.find_user(user_id, group_id) - group = await PyroGroup.create_if_not_exists(group_id) + group = await app.find_group(group_id) if user.score == 0 and user.failed == 0: if group.ban_failed: @@ -25,7 +24,7 @@ async def kick_unverified( app: PyroClient, user_id: int, group_id: int, message_id: int ) -> None: user = await app.find_user(user_id, group_id) - group = await PyroGroup.create_if_not_exists(group_id) + group = await app.find_group(group_id) if user.score < 6 or user.failed: if group.ban_failed: diff --git a/plugins/callbacks/ban.py b/plugins/callbacks/ban.py index e68bdc5..a368ad8 100644 --- a/plugins/callbacks/ban.py +++ b/plugins/callbacks/ban.py @@ -6,14 +6,13 @@ from pyrogram.types import CallbackQuery, Message from classes.callbacks import CallbackBan from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup logger = logging.getLogger(__name__) @PyroClient.on_callback_query(filters.regex(r"ban:[\s\S]*")) # type: ignore async def callback_ban(app: PyroClient, callback: CallbackQuery): - group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True) + group = await app.find_group(callback.message.chat.id) locale = group.select_locale(app, callback.message.from_user) if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [ diff --git a/plugins/callbacks/emoji_button.py b/plugins/callbacks/emoji_button.py index 7d36bcc..47a91dd 100644 --- a/plugins/callbacks/emoji_button.py +++ b/plugins/callbacks/emoji_button.py @@ -11,7 +11,6 @@ from pyrogram.types import ( from classes.callbacks import CallbackEmoji from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup logger = logging.getLogger(__name__) @@ -19,7 +18,7 @@ logger = logging.getLogger(__name__) @PyroClient.on_callback_query(filters.regex(r"emoji:[\s\S]*")) # type: ignore async def callback_emoji_button(app: PyroClient, callback: CallbackQuery): parsed = CallbackEmoji.from_callback(callback) - group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True) + group = await app.find_group(callback.message.chat.id) locale = group.select_locale(app, callback.message.from_user) if callback.from_user.id != parsed.user_id: diff --git a/plugins/callbacks/nothing.py b/plugins/callbacks/nothing.py index 58d6800..adcca16 100644 --- a/plugins/callbacks/nothing.py +++ b/plugins/callbacks/nothing.py @@ -2,12 +2,11 @@ from pyrogram import filters from pyrogram.types import CallbackQuery from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup @PyroClient.on_callback_query(filters.regex(r"nothing")) # type: ignore async def callback_nothing(app: PyroClient, callback: CallbackQuery): - group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True) + group = await app.find_group(callback.message.chat.id) locale = group.select_locale(app, callback.message.from_user) await callback.answer(app._("nothing", "callbacks", locale=locale)) diff --git a/plugins/callbacks/verify.py b/plugins/callbacks/verify.py index 0cc17d1..40920e8 100644 --- a/plugins/callbacks/verify.py +++ b/plugins/callbacks/verify.py @@ -9,7 +9,6 @@ from pyrogram.types import CallbackQuery from classes.callbacks import CallbackVerify from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup from modules.database import col_schedule from modules.kicker import kick_unverified from modules.utils import get_captcha_image @@ -20,7 +19,7 @@ logger = logging.getLogger(__name__) @PyroClient.on_callback_query(filters.regex(r"verify:[\s\S]*")) # type: ignore async def callback_verify(app: PyroClient, callback: CallbackQuery): parsed = CallbackVerify.from_callback(callback) - group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True) + group = await app.find_group(callback.message.chat.id) locale = group.select_locale(app, callback.message.from_user) if callback.from_user.id != parsed.user_id: diff --git a/plugins/commands/ban_failed.py b/plugins/commands/ban_failed.py index 46a0ac4..7831c7d 100644 --- a/plugins/commands/ban_failed.py +++ b/plugins/commands/ban_failed.py @@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus from pyrogram.types import Message from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup logger = logging.getLogger(__name__) @@ -16,7 +15,7 @@ logger = logging.getLogger(__name__) & filters.command(["ban_failed"], prefixes=["/"]) # type: ignore ) async def command_ban_failed(app: PyroClient, message: Message): - group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) + group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) if (await app.get_chat_member(group.id, message.from_user.id)).status not in [ diff --git a/plugins/commands/language_auto.py b/plugins/commands/language_auto.py index 3a1f782..3cd340e 100644 --- a/plugins/commands/language_auto.py +++ b/plugins/commands/language_auto.py @@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus from pyrogram.types import Message from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup logger = logging.getLogger(__name__) @@ -16,7 +15,7 @@ logger = logging.getLogger(__name__) & filters.command(["language_auto"], prefixes=["/"]) # type: ignore ) async def command_language_auto(app: PyroClient, message: Message): - group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) + group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) if (await app.get_chat_member(group.id, message.from_user.id)).status not in [ diff --git a/plugins/commands/timeout_join.py b/plugins/commands/timeout_join.py index 61ecb18..385fab6 100644 --- a/plugins/commands/timeout_join.py +++ b/plugins/commands/timeout_join.py @@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus from pyrogram.types import Message from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup logger = logging.getLogger(__name__) @@ -16,7 +15,7 @@ logger = logging.getLogger(__name__) & filters.command(["timeout_join"], prefixes=["/"]) # type: ignore ) async def command_timeout_join(app: PyroClient, message: Message): - group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) + group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) if (await app.get_chat_member(group.id, message.from_user.id)).status not in [ diff --git a/plugins/commands/timeout_verify.py b/plugins/commands/timeout_verify.py index cb7231b..dac7def 100644 --- a/plugins/commands/timeout_verify.py +++ b/plugins/commands/timeout_verify.py @@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus from pyrogram.types import Message from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup logger = logging.getLogger(__name__) @@ -16,7 +15,7 @@ logger = logging.getLogger(__name__) & filters.command(["timeout_verify"], prefixes=["/"]) # type: ignore ) async def command_timeout_verify(app: PyroClient, message: Message): - group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) + group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) if (await app.get_chat_member(group.id, message.from_user.id)).status not in [ diff --git a/plugins/handlers/bot_join.py b/plugins/handlers/bot_join.py index f8e82be..51cdf6d 100644 --- a/plugins/handlers/bot_join.py +++ b/plugins/handlers/bot_join.py @@ -4,7 +4,6 @@ from pyrogram import filters from pyrogram.types import Message from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup logger = logging.getLogger(__name__) @@ -12,4 +11,4 @@ logger = logging.getLogger(__name__) @PyroClient.on_message(filters.new_chat_members & filters.group & filters.me) # type: ignore async def handler_bot_join(app: PyroClient, message: Message): logger.info("Bot has joined the group %s") - await PyroGroup.create_if_not_exists(message.chat.id, None, True) + await app.find_group(message.chat.id) diff --git a/plugins/handlers/user_join.py b/plugins/handlers/user_join.py index ec48c03..2885a0e 100644 --- a/plugins/handlers/user_join.py +++ b/plugins/handlers/user_join.py @@ -12,8 +12,6 @@ from pyrogram.types import ( ) from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup -from classes.pyrouser import PyroUser from modules.database import col_schedule from modules.kicker import kick_unstarted @@ -24,7 +22,7 @@ logger = logging.getLogger(__name__) filters.new_chat_members & filters.group & ~filters.me & ~filters.bot # type: ignore ) async def handler_user_join(app: PyroClient, message: Message): - group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) + group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) if ( @@ -64,7 +62,7 @@ async def handler_user_join(app: PyroClient, message: Message): permissions=ChatPermissions(can_send_messages=False), ) - user = await PyroUser.create_if_not_exists(message.from_user.id, group.id) + user = await app.find_user(message.from_user, group=group.id) if user.mistakes > 0 or user.score > 0: await user.set_score(0) diff --git a/plugins/language.py b/plugins/language.py index 571282e..cac2d45 100644 --- a/plugins/language.py +++ b/plugins/language.py @@ -8,7 +8,6 @@ from pyrogram.types import CallbackQuery, Message from classes.callbacks import CallbackLanguage from classes.pyroclient import PyroClient -from classes.pyrogroup import PyroGroup logger = logging.getLogger(__name__) @@ -17,7 +16,7 @@ logger = logging.getLogger(__name__) ~filters.scheduled & filters.group & filters.command(["language"], prefixes=["/"]) # type: ignore ) async def command_language(app: PyroClient, message: Message): - group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) + group = await app.find_group(message.chat.id) locale = group.select_locale(app, message.from_user) if (await app.get_chat_member(group.id, message.from_user.id)).status not in [ @@ -54,7 +53,7 @@ async def command_language(app: PyroClient, message: Message): @Client.on_callback_query(filters.regex(r"language:[\s\S]*")) # type: ignore async def callback_language(app: PyroClient, callback: CallbackQuery): - group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True) + group = await app.find_group(callback.message.chat.id) locale = group.select_locale(app, callback.message.from_user) if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [