Implemented find_group and find_user
This commit is contained in:
parent
ab39c111eb
commit
9e00d38877
@ -1,30 +1,37 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from libbot.pyrogram.classes import PyroClient
|
from libbot.pyrogram.classes import PyroClient as LibPyroClient
|
||||||
from pyrogram.types import User
|
from pyrogram.types import User
|
||||||
|
|
||||||
|
from classes.pyrogroup import PyroGroup
|
||||||
from classes.pyrouser import PyroUser
|
from classes.pyrouser import PyroUser
|
||||||
from modules.database import col_users
|
|
||||||
|
|
||||||
|
|
||||||
class PyroClient(PyroClient):
|
class PyroClient(LibPyroClient):
|
||||||
async def find_user(self, user: Union[int, User], group: int) -> PyroUser:
|
async def find_user(self, user: Union[int, User], *args, **kwargs) -> PyroUser:
|
||||||
"""Find User by it's ID or User object
|
"""Find User by it's ID or User object.
|
||||||
|
|
||||||
### Args:
|
### Args:
|
||||||
* user (`Union[int, User]`): ID or User object to extract ID from
|
* user (`Union[int, User]`): ID or User object to extract ID from
|
||||||
* group (`int`): ID of the group
|
|
||||||
|
|
||||||
### Returns:
|
### Returns:
|
||||||
* `PyroUser`: PyroUser object
|
* `PyroUser`: PyroUser object
|
||||||
"""
|
"""
|
||||||
db_record = await col_users.find_one(
|
|
||||||
{"id": user.id if isinstance(user, User) else user, "group": group}
|
return (
|
||||||
|
await PyroUser.find(user, *args, **kwargs)
|
||||||
|
if isinstance(user, int)
|
||||||
|
else await PyroUser.find(user.id, *args, **kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_record is None:
|
async def find_group(self, id: int, *args, **kwargs) -> PyroGroup:
|
||||||
raise KeyError(
|
"""Find Group by it's ID.
|
||||||
f"User with ID {user.id if isinstance(user, User) else user} was not found in the database"
|
|
||||||
)
|
|
||||||
|
|
||||||
return PyroUser(**db_record)
|
### Args:
|
||||||
|
* id (`int`): Group ID
|
||||||
|
|
||||||
|
### Returns:
|
||||||
|
* `PyroGroup`: PyroGroup object
|
||||||
|
"""
|
||||||
|
|
||||||
|
return await PyroGroup.find(id, *args, **kwargs)
|
||||||
|
@ -35,7 +35,7 @@ class PyroGroup:
|
|||||||
timeout_verify: int
|
timeout_verify: int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_if_not_exists(
|
async def find(
|
||||||
cls,
|
cls,
|
||||||
id: int,
|
id: int,
|
||||||
locale: Union[str, None] = sync.config_get("locale", "defaults", "group"),
|
locale: Union[str, None] = sync.config_get("locale", "defaults", "group"),
|
||||||
|
@ -34,7 +34,7 @@ class PyroUser:
|
|||||||
mistakes: int
|
mistakes: int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_if_not_exists(
|
async def find(
|
||||||
cls,
|
cls,
|
||||||
id: int,
|
id: int,
|
||||||
group: int,
|
group: int,
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
from pyrogram.types import Message
|
from pyrogram.types import Message
|
||||||
|
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
|
|
||||||
|
|
||||||
async def kick_unstarted(
|
async def kick_unstarted(
|
||||||
app: PyroClient, user_id: int, group_id: int, message_id: int
|
app: PyroClient, user_id: int, group_id: int, message_id: int
|
||||||
) -> None:
|
) -> None:
|
||||||
user = await app.find_user(user_id, group_id)
|
user = await app.find_user(user_id, group_id)
|
||||||
group = await PyroGroup.create_if_not_exists(group_id)
|
group = await app.find_group(group_id)
|
||||||
|
|
||||||
if user.score == 0 and user.failed == 0:
|
if user.score == 0 and user.failed == 0:
|
||||||
if group.ban_failed:
|
if group.ban_failed:
|
||||||
@ -25,7 +24,7 @@ async def kick_unverified(
|
|||||||
app: PyroClient, user_id: int, group_id: int, message_id: int
|
app: PyroClient, user_id: int, group_id: int, message_id: int
|
||||||
) -> None:
|
) -> None:
|
||||||
user = await app.find_user(user_id, group_id)
|
user = await app.find_user(user_id, group_id)
|
||||||
group = await PyroGroup.create_if_not_exists(group_id)
|
group = await app.find_group(group_id)
|
||||||
|
|
||||||
if user.score < 6 or user.failed:
|
if user.score < 6 or user.failed:
|
||||||
if group.ban_failed:
|
if group.ban_failed:
|
||||||
|
@ -6,14 +6,13 @@ from pyrogram.types import CallbackQuery, Message
|
|||||||
|
|
||||||
from classes.callbacks import CallbackBan
|
from classes.callbacks import CallbackBan
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@PyroClient.on_callback_query(filters.regex(r"ban:[\s\S]*")) # type: ignore
|
@PyroClient.on_callback_query(filters.regex(r"ban:[\s\S]*")) # type: ignore
|
||||||
async def callback_ban(app: PyroClient, callback: CallbackQuery):
|
async def callback_ban(app: PyroClient, callback: CallbackQuery):
|
||||||
group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True)
|
group = await app.find_group(callback.message.chat.id)
|
||||||
locale = group.select_locale(app, callback.message.from_user)
|
locale = group.select_locale(app, callback.message.from_user)
|
||||||
|
|
||||||
if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [
|
if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [
|
||||||
|
@ -11,7 +11,6 @@ from pyrogram.types import (
|
|||||||
|
|
||||||
from classes.callbacks import CallbackEmoji
|
from classes.callbacks import CallbackEmoji
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -19,7 +18,7 @@ logger = logging.getLogger(__name__)
|
|||||||
@PyroClient.on_callback_query(filters.regex(r"emoji:[\s\S]*")) # type: ignore
|
@PyroClient.on_callback_query(filters.regex(r"emoji:[\s\S]*")) # type: ignore
|
||||||
async def callback_emoji_button(app: PyroClient, callback: CallbackQuery):
|
async def callback_emoji_button(app: PyroClient, callback: CallbackQuery):
|
||||||
parsed = CallbackEmoji.from_callback(callback)
|
parsed = CallbackEmoji.from_callback(callback)
|
||||||
group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True)
|
group = await app.find_group(callback.message.chat.id)
|
||||||
locale = group.select_locale(app, callback.message.from_user)
|
locale = group.select_locale(app, callback.message.from_user)
|
||||||
|
|
||||||
if callback.from_user.id != parsed.user_id:
|
if callback.from_user.id != parsed.user_id:
|
||||||
|
@ -2,12 +2,11 @@ from pyrogram import filters
|
|||||||
from pyrogram.types import CallbackQuery
|
from pyrogram.types import CallbackQuery
|
||||||
|
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
|
|
||||||
|
|
||||||
@PyroClient.on_callback_query(filters.regex(r"nothing")) # type: ignore
|
@PyroClient.on_callback_query(filters.regex(r"nothing")) # type: ignore
|
||||||
async def callback_nothing(app: PyroClient, callback: CallbackQuery):
|
async def callback_nothing(app: PyroClient, callback: CallbackQuery):
|
||||||
group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True)
|
group = await app.find_group(callback.message.chat.id)
|
||||||
locale = group.select_locale(app, callback.message.from_user)
|
locale = group.select_locale(app, callback.message.from_user)
|
||||||
|
|
||||||
await callback.answer(app._("nothing", "callbacks", locale=locale))
|
await callback.answer(app._("nothing", "callbacks", locale=locale))
|
||||||
|
@ -9,7 +9,6 @@ from pyrogram.types import CallbackQuery
|
|||||||
|
|
||||||
from classes.callbacks import CallbackVerify
|
from classes.callbacks import CallbackVerify
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
from modules.database import col_schedule
|
from modules.database import col_schedule
|
||||||
from modules.kicker import kick_unverified
|
from modules.kicker import kick_unverified
|
||||||
from modules.utils import get_captcha_image
|
from modules.utils import get_captcha_image
|
||||||
@ -20,7 +19,7 @@ logger = logging.getLogger(__name__)
|
|||||||
@PyroClient.on_callback_query(filters.regex(r"verify:[\s\S]*")) # type: ignore
|
@PyroClient.on_callback_query(filters.regex(r"verify:[\s\S]*")) # type: ignore
|
||||||
async def callback_verify(app: PyroClient, callback: CallbackQuery):
|
async def callback_verify(app: PyroClient, callback: CallbackQuery):
|
||||||
parsed = CallbackVerify.from_callback(callback)
|
parsed = CallbackVerify.from_callback(callback)
|
||||||
group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True)
|
group = await app.find_group(callback.message.chat.id)
|
||||||
locale = group.select_locale(app, callback.message.from_user)
|
locale = group.select_locale(app, callback.message.from_user)
|
||||||
|
|
||||||
if callback.from_user.id != parsed.user_id:
|
if callback.from_user.id != parsed.user_id:
|
||||||
|
@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus
|
|||||||
from pyrogram.types import Message
|
from pyrogram.types import Message
|
||||||
|
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
|
|||||||
& filters.command(["ban_failed"], prefixes=["/"]) # type: ignore
|
& filters.command(["ban_failed"], prefixes=["/"]) # type: ignore
|
||||||
)
|
)
|
||||||
async def command_ban_failed(app: PyroClient, message: Message):
|
async def command_ban_failed(app: PyroClient, message: Message):
|
||||||
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True)
|
group = await app.find_group(message.chat.id)
|
||||||
locale = group.select_locale(app, message.from_user)
|
locale = group.select_locale(app, message.from_user)
|
||||||
|
|
||||||
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
|
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
|
||||||
|
@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus
|
|||||||
from pyrogram.types import Message
|
from pyrogram.types import Message
|
||||||
|
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
|
|||||||
& filters.command(["language_auto"], prefixes=["/"]) # type: ignore
|
& filters.command(["language_auto"], prefixes=["/"]) # type: ignore
|
||||||
)
|
)
|
||||||
async def command_language_auto(app: PyroClient, message: Message):
|
async def command_language_auto(app: PyroClient, message: Message):
|
||||||
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True)
|
group = await app.find_group(message.chat.id)
|
||||||
locale = group.select_locale(app, message.from_user)
|
locale = group.select_locale(app, message.from_user)
|
||||||
|
|
||||||
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
|
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
|
||||||
|
@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus
|
|||||||
from pyrogram.types import Message
|
from pyrogram.types import Message
|
||||||
|
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
|
|||||||
& filters.command(["timeout_join"], prefixes=["/"]) # type: ignore
|
& filters.command(["timeout_join"], prefixes=["/"]) # type: ignore
|
||||||
)
|
)
|
||||||
async def command_timeout_join(app: PyroClient, message: Message):
|
async def command_timeout_join(app: PyroClient, message: Message):
|
||||||
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True)
|
group = await app.find_group(message.chat.id)
|
||||||
locale = group.select_locale(app, message.from_user)
|
locale = group.select_locale(app, message.from_user)
|
||||||
|
|
||||||
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
|
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
|
||||||
|
@ -5,7 +5,6 @@ from pyrogram.enums.chat_member_status import ChatMemberStatus
|
|||||||
from pyrogram.types import Message
|
from pyrogram.types import Message
|
||||||
|
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
|
|||||||
& filters.command(["timeout_verify"], prefixes=["/"]) # type: ignore
|
& filters.command(["timeout_verify"], prefixes=["/"]) # type: ignore
|
||||||
)
|
)
|
||||||
async def command_timeout_verify(app: PyroClient, message: Message):
|
async def command_timeout_verify(app: PyroClient, message: Message):
|
||||||
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True)
|
group = await app.find_group(message.chat.id)
|
||||||
locale = group.select_locale(app, message.from_user)
|
locale = group.select_locale(app, message.from_user)
|
||||||
|
|
||||||
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
|
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
|
||||||
|
@ -4,7 +4,6 @@ from pyrogram import filters
|
|||||||
from pyrogram.types import Message
|
from pyrogram.types import Message
|
||||||
|
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -12,4 +11,4 @@ logger = logging.getLogger(__name__)
|
|||||||
@PyroClient.on_message(filters.new_chat_members & filters.group & filters.me) # type: ignore
|
@PyroClient.on_message(filters.new_chat_members & filters.group & filters.me) # type: ignore
|
||||||
async def handler_bot_join(app: PyroClient, message: Message):
|
async def handler_bot_join(app: PyroClient, message: Message):
|
||||||
logger.info("Bot has joined the group %s")
|
logger.info("Bot has joined the group %s")
|
||||||
await PyroGroup.create_if_not_exists(message.chat.id, None, True)
|
await app.find_group(message.chat.id)
|
||||||
|
@ -12,8 +12,6 @@ from pyrogram.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
from classes.pyrouser import PyroUser
|
|
||||||
from modules.database import col_schedule
|
from modules.database import col_schedule
|
||||||
from modules.kicker import kick_unstarted
|
from modules.kicker import kick_unstarted
|
||||||
|
|
||||||
@ -24,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|||||||
filters.new_chat_members & filters.group & ~filters.me & ~filters.bot # type: ignore
|
filters.new_chat_members & filters.group & ~filters.me & ~filters.bot # type: ignore
|
||||||
)
|
)
|
||||||
async def handler_user_join(app: PyroClient, message: Message):
|
async def handler_user_join(app: PyroClient, message: Message):
|
||||||
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True)
|
group = await app.find_group(message.chat.id)
|
||||||
locale = group.select_locale(app, message.from_user)
|
locale = group.select_locale(app, message.from_user)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -64,7 +62,7 @@ async def handler_user_join(app: PyroClient, message: Message):
|
|||||||
permissions=ChatPermissions(can_send_messages=False),
|
permissions=ChatPermissions(can_send_messages=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
user = await PyroUser.create_if_not_exists(message.from_user.id, group.id)
|
user = await app.find_user(message.from_user, group=group.id)
|
||||||
|
|
||||||
if user.mistakes > 0 or user.score > 0:
|
if user.mistakes > 0 or user.score > 0:
|
||||||
await user.set_score(0)
|
await user.set_score(0)
|
||||||
|
@ -8,7 +8,6 @@ from pyrogram.types import CallbackQuery, Message
|
|||||||
|
|
||||||
from classes.callbacks import CallbackLanguage
|
from classes.callbacks import CallbackLanguage
|
||||||
from classes.pyroclient import PyroClient
|
from classes.pyroclient import PyroClient
|
||||||
from classes.pyrogroup import PyroGroup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||||||
~filters.scheduled & filters.group & filters.command(["language"], prefixes=["/"]) # type: ignore
|
~filters.scheduled & filters.group & filters.command(["language"], prefixes=["/"]) # type: ignore
|
||||||
)
|
)
|
||||||
async def command_language(app: PyroClient, message: Message):
|
async def command_language(app: PyroClient, message: Message):
|
||||||
group = await PyroGroup.create_if_not_exists(message.chat.id, None, True)
|
group = await app.find_group(message.chat.id)
|
||||||
locale = group.select_locale(app, message.from_user)
|
locale = group.select_locale(app, message.from_user)
|
||||||
|
|
||||||
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
|
if (await app.get_chat_member(group.id, message.from_user.id)).status not in [
|
||||||
@ -54,7 +53,7 @@ async def command_language(app: PyroClient, message: Message):
|
|||||||
|
|
||||||
@Client.on_callback_query(filters.regex(r"language:[\s\S]*")) # type: ignore
|
@Client.on_callback_query(filters.regex(r"language:[\s\S]*")) # type: ignore
|
||||||
async def callback_language(app: PyroClient, callback: CallbackQuery):
|
async def callback_language(app: PyroClient, callback: CallbackQuery):
|
||||||
group = await PyroGroup.create_if_not_exists(callback.message.chat.id, None, True)
|
group = await app.find_group(callback.message.chat.id)
|
||||||
locale = group.select_locale(app, callback.message.from_user)
|
locale = group.select_locale(app, callback.message.from_user)
|
||||||
|
|
||||||
if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [
|
if (await app.get_chat_member(group.id, callback.from_user.id)).status not in [
|
||||||
|
Reference in New Issue
Block a user