Implemented find_group and find_user

This commit is contained in:
Profitroll 2023-08-17 16:37:42 +02:00
parent ab39c111eb
commit 9e00d38877
Signed by: profitroll
GPG Key ID: FA35CAB49DACD3B2
15 changed files with 37 additions and 43 deletions

View File

@ -1,30 +1,37 @@
from typing import Union from typing import Union
from libbot.pyrogram.classes import PyroClient from libbot.pyrogram.classes import PyroClient as LibPyroClient
from pyrogram.types import User from pyrogram.types import User
from classes.pyrogroup import PyroGroup
from classes.pyrouser import PyroUser from classes.pyrouser import PyroUser
from modules.database import col_users
class PyroClient(PyroClient): class PyroClient(LibPyroClient):
async def find_user(self, user: Union[int, User], group: int) -> PyroUser: async def find_user(self, user: Union[int, User], *args, **kwargs) -> PyroUser:
"""Find User by it's ID or User object """Find User by it's ID or User object.
### Args: ### Args:
* user (`Union[int, User]`): ID or User object to extract ID from * user (`Union[int, User]`): ID or User object to extract ID from
* group (`int`): ID of the group
### Returns: ### Returns:
* `PyroUser`: PyroUser object * `PyroUser`: PyroUser object
""" """
db_record = await col_users.find_one(
{"id": user.id if isinstance(user, User) else user, "group": group} return (
await PyroUser.find(user, *args, **kwargs)
if isinstance(user, int)
else await PyroUser.find(user.id, *args, **kwargs)
) )
if db_record is None: async def find_group(self, id: int, *args, **kwargs) -> PyroGroup:
raise KeyError( """Find Group by it's ID.
f"User with ID {user.id if isinstance(user, User) else user} was not found in the database"
)
return PyroUser(**db_record) ### Args:
* id (`int`): Group ID
### Returns:
* `PyroGroup`: PyroGroup object
"""
return await PyroGroup.find(id, *args, **kwargs)

View File

@ -35,7 +35,7 @@ class PyroGroup:
timeout_verify: int timeout_verify: int
@classmethod @classmethod
async def create_if_not_exists( async def find(
cls, cls,
id: int, id: int,
locale: Union[str, None] = sync.config_get("locale", "defaults", "group"), locale: Union[str, None] = sync.config_get("locale", "defaults", "group"),

View File

@ -34,7 +34,7 @@ class PyroUser:
mistakes: int mistakes: int
@classmethod @classmethod
async def create_if_not_exists( async def find(
cls, cls,
id: int, id: int,
group: int, group: int,

View File

@ -1,14 +1,13 @@
from pyrogram.types import Message from pyrogram.types import Message
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
async def kick_unstarted( async def kick_unstarted(
app: PyroClient, user_id: int, group_id: int, message_id: int app: PyroClient, user_id: int, group_id: int, message_id: int
) -> None: ) -> None:
user = await app.find_user(user_id, group_id) user = await app.find_user(user_id, group_id)
group = await PyroGroup.create_if_not_exists(group_id) group = await app.find_group(group_id)
if user.score == 0 and user.failed == 0: if user.score == 0 and user.failed == 0:
if group.ban_failed: if group.ban_failed:
@ -25,7 +24,7 @@ async def kick_unverified(
app: PyroClient, user_id: int, group_id: int, message_id: int app: PyroClient, user_id: int, group_id: int, message_id: int
) -> None: ) -> None:
user = await app.find_user(user_id, group_id) user = await app.find_user(user_id, group_id)
group = await PyroGroup.create_if_not_exists(group_id) group = await app.find_group(group_id)
if user.score < 6 or user.failed: if user.score < 6 or user.failed:
if group.ban_failed: if group.ban_failed:

View File

@ -6,14 +6,13 @@ from pyrogram.types import CallbackQuery, Message
from classes.callbacks import CallbackBan from classes.callbacks import CallbackBan
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@PyroClient.on_callback_query(filters.regex(r"ban:[\s\S]*")) # type: ignore @PyroClient.on_callback_query(filters.regex(r"ban:[\s\S]*")) # type: ignore
async def callback_ban(app: PyroClient, callback: CallbackQuery): async def callback_ban(app: PyroClient, callback: CallbackQuery):
group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True) group = await app.find_group(callback.message.chat.id)
locale = group.select_locale(app, callback.message.from_user) locale = group.select_locale(app, callback.message.from_user)
if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [ if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [

View File

@ -11,7 +11,6 @@ from pyrogram.types import (
from classes.callbacks import CallbackEmoji from classes.callbacks import CallbackEmoji
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -19,7 +18,7 @@ logger = logging.getLogger(__name__)
@PyroClient.on_callback_query(filters.regex(r"emoji:[\s\S]*")) # type: ignore @PyroClient.on_callback_query(filters.regex(r"emoji:[\s\S]*")) # type: ignore
async def callback_emoji_button(app: PyroClient, callback: CallbackQuery): async def callback_emoji_button(app: PyroClient, callback: CallbackQuery):
parsed = CallbackEmoji.from_callback(callback) parsed = CallbackEmoji.from_callback(callback)
group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True) group = await app.find_group(callback.message.chat.id)
locale = group.select_locale(app, callback.message.from_user) locale = group.select_locale(app, callback.message.from_user)
if callback.from_user.id != parsed.user_id: if callback.from_user.id != parsed.user_id:

View File

@ -2,12 +2,11 @@ from pyrogram import filters
from pyrogram.types import CallbackQuery from pyrogram.types import CallbackQuery
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
@PyroClient.on_callback_query(filters.regex(r"nothing")) # type: ignore @PyroClient.on_callback_query(filters.regex(r"nothing")) # type: ignore
async def callback_nothing(app: PyroClient, callback: CallbackQuery): async def callback_nothing(app: PyroClient, callback: CallbackQuery):
group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True) group = await app.find_group(callback.message.chat.id)
locale = group.select_locale(app, callback.message.from_user) locale = group.select_locale(app, callback.message.from_user)
await callback.answer(app._("nothing", "callbacks", locale=locale)) await callback.answer(app._("nothing", "callbacks", locale=locale))

View File

@ -9,7 +9,6 @@ from pyrogram.types import CallbackQuery
from classes.callbacks import CallbackVerify from classes.callbacks import CallbackVerify
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
from modules.database import col_schedule from modules.database import col_schedule
from modules.kicker import kick_unverified from modules.kicker import kick_unverified
from modules.utils import get_captcha_image from modules.utils import get_captcha_image
@ -20,7 +19,7 @@ logger = logging.getLogger(__name__)
@PyroClient.on_callback_query(filters.regex(r"verify:[\s\S]*")) # type: ignore @PyroClient.on_callback_query(filters.regex(r"verify:[\s\S]*")) # type: ignore
async def callback_verify(app: PyroClient, callback: CallbackQuery): async def callback_verify(app: PyroClient, callback: CallbackQuery):
parsed = CallbackVerify.from_callback(callback) parsed = CallbackVerify.from_callback(callback)
group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True) group = await app.find_group(callback.message.chat.id)
locale = group.select_locale(app, callback.message.from_user) locale = group.select_locale(app, callback.message.from_user)
if callback.from_user.id != parsed.user_id: if callback.from_user.id != parsed.user_id:

View File

@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import Message from pyrogram.types import Message
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
& filters.command(["ban_failed"], prefixes=["/"]) # type: ignore & filters.command(["ban_failed"], prefixes=["/"]) # type: ignore
) )
async def command_ban_failed(app: PyroClient, message: Message): async def command_ban_failed(app: PyroClient, message: Message):
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [ if (await app.get_chat_member(group.id, message.from_user.id)).status not in [

View File

@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import Message from pyrogram.types import Message
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
& filters.command(["language_auto"], prefixes=["/"]) # type: ignore & filters.command(["language_auto"], prefixes=["/"]) # type: ignore
) )
async def command_language_auto(app: PyroClient, message: Message): async def command_language_auto(app: PyroClient, message: Message):
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [ if (await app.get_chat_member(group.id, message.from_user.id)).status not in [

View File

@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import Message from pyrogram.types import Message
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
& filters.command(["timeout_join"], prefixes=["/"]) # type: ignore & filters.command(["timeout_join"], prefixes=["/"]) # type: ignore
) )
async def command_timeout_join(app: PyroClient, message: Message): async def command_timeout_join(app: PyroClient, message: Message):
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [ if (await app.get_chat_member(group.id, message.from_user.id)).status not in [

View File

@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus
from pyrogram.types import Message from pyrogram.types import Message
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
& filters.command(["timeout_verify"], prefixes=["/"]) # type: ignore & filters.command(["timeout_verify"], prefixes=["/"]) # type: ignore
) )
async def command_timeout_verify(app: PyroClient, message: Message): async def command_timeout_verify(app: PyroClient, message: Message):
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [ if (await app.get_chat_member(group.id, message.from_user.id)).status not in [

View File

@ -4,7 +4,6 @@ from pyrogram import filters
from pyrogram.types import Message from pyrogram.types import Message
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -12,4 +11,4 @@ logger = logging.getLogger(__name__)
@PyroClient.on_message(filters.new_chat_members & filters.group & filters.me) # type: ignore @PyroClient.on_message(filters.new_chat_members & filters.group & filters.me) # type: ignore
async def handler_bot_join(app: PyroClient, message: Message): async def handler_bot_join(app: PyroClient, message: Message):
logger.info("Bot has joined the group %s") logger.info("Bot has joined the group %s")
await PyroGroup.create_if_not_exists(message.chat.id, None, True) await app.find_group(message.chat.id)

View File

@ -12,8 +12,6 @@ from pyrogram.types import (
) )
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
from classes.pyrouser import PyroUser
from modules.database import col_schedule from modules.database import col_schedule
from modules.kicker import kick_unstarted from modules.kicker import kick_unstarted
@ -24,7 +22,7 @@ logger = logging.getLogger(__name__)
filters.new_chat_members & filters.group & ~filters.me & ~filters.bot # type: ignore filters.new_chat_members & filters.group & ~filters.me & ~filters.bot # type: ignore
) )
async def handler_user_join(app: PyroClient, message: Message): async def handler_user_join(app: PyroClient, message: Message):
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if ( if (
@ -64,7 +62,7 @@ async def handler_user_join(app: PyroClient, message: Message):
permissions=ChatPermissions(can_send_messages=False), permissions=ChatPermissions(can_send_messages=False),
) )
user = await PyroUser.create_if_not_exists(message.from_user.id, group.id) user = await app.find_user(message.from_user, group=group.id)
if user.mistakes > 0 or user.score > 0: if user.mistakes > 0 or user.score > 0:
await user.set_score(0) await user.set_score(0)

View File

@ -8,7 +8,6 @@ from pyrogram.types import CallbackQuery, Message
from classes.callbacks import CallbackLanguage from classes.callbacks import CallbackLanguage
from classes.pyroclient import PyroClient from classes.pyroclient import PyroClient
from classes.pyrogroup import PyroGroup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)
~filters.scheduled & filters.group & filters.command(["language"], prefixes=["/"]) # type: ignore ~filters.scheduled & filters.group & filters.command(["language"], prefixes=["/"]) # type: ignore
) )
async def command_language(app: PyroClient, message: Message): async def command_language(app: PyroClient, message: Message):
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True) group = await app.find_group(message.chat.id)
locale = group.select_locale(app, message.from_user) locale = group.select_locale(app, message.from_user)
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [ if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
@ -54,7 +53,7 @@ async def command_language(app: PyroClient, message: Message):
@Client.on_callback_query(filters.regex(r"language:[\s\S]*")) # type: ignore @Client.on_callback_query(filters.regex(r"language:[\s\S]*")) # type: ignore
async def callback_language(app: PyroClient, callback: CallbackQuery): async def callback_language(app: PyroClient, callback: CallbackQuery):
group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True) group = await app.find_group(callback.message.chat.id)
locale = group.select_locale(app, callback.message.from_user) locale = group.select_locale(app, callback.message.from_user)
if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [ if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [