diff --git a/classes/pyroclient.py b/classes/pyroclient.py index 54f2edc..bb966bf 100644 --- a/classes/pyroclient.py +++ b/classes/pyroclient.py @@ -16,7 +16,7 @@ from libbot import json_write from libbot.i18n.sync import _ from photosapi_client.errors import UnexpectedStatus from pyrogram.errors import bad_request_400 -from pyrogram.types import Message +from pyrogram.types import Message, User from pytimeparse.timeparse import timeparse from ujson import dumps, loads @@ -26,6 +26,7 @@ from classes.exceptions import ( SubmissionUnavailableError, SubmissionUnsupportedError, ) +from classes.pyrouser import PyroUser from modules.api_client import ( BodyPhotoUpload, BodyVideoUpload, @@ -36,7 +37,7 @@ from modules.api_client import ( photo_upload, video_upload, ) -from modules.database import col_submitted +from modules.database import col_submitted, col_users from modules.http_client import http_session from modules.sender import send_content @@ -252,8 +253,31 @@ class PyroClient(PyroClient): response.id if not hasattr(response, "parsed") else response.parsed.id, ) - async def ban_user(self, id: int) -> None: - pass + async def find_user(self, user: Union[int, User]) -> PyroUser: + """Find User by it's ID or User object - async def unban_user(self, id: int) -> None: - pass + ### Args: + * user (`Union[int, User]`): ID or User object to extract ID from + + ### Returns: + * `PyroUser`: PyroUser object + """ + if ( + col_users.find_one( + {"id": user.id if isinstance(user, User) else user} + ) # type: ignore + is None + ): + col_users.insert_one( + { + "id": user.id if isinstance(user, User) else user, + "locale": user.language_code if isinstance(user, User) else None, + "subscription": {"expires": datetime(1970, 1, 1, 0, 0)}, + } + ) # type: ignore + + db_record = col_users.find_one( + {"id": user.id if isinstance(user, User) else user} + ) # type: ignore + + return PyroUser(**db_record) diff --git a/classes/pyrouser.py b/classes/pyrouser.py new file mode 100644 index 0000000..6f77db6 --- /dev/null +++ b/classes/pyrouser.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Union + +from bson import ObjectId +from libbot import config_get +from libbot.pyrogram.classes import PyroClient + +from modules.database import col_users + + +@dataclass +class PyroUser: + """Dataclass of DB entry of a user""" + + _id: ObjectId + id: int + locale: Union[str, None] + banned: bool + cooldown: datetime + subscription: dict + + async def update_locale(self, locale: str): + col_users.update_one({"_id": self._id}, {"$set": {"locale": locale}}) + + async def update_cooldown(self, time: datetime = datetime.now()): + col_users.update_one({"_id": self._id}, {"$set": {"cooldown": time}}) + + async def block(self) -> None: + """Ban user from using command and submitting content.""" + col_users.update_one({"_id": self._id}, {"$set": {"banned": True}}) + + async def unblock(self) -> None: + """Allow user to use command and submit posts again.""" + col_users.update_one({"_id": self._id}, {"$set": {"banned": False}}) + + async def is_limited(self, app: Union[PyroClient, None] = None) -> bool: + """Check if user is on a cooldown after submitting something. + + ### Returns: + `bool`: Must be `True` if on the cooldown and `False` if not + """ + admins = ( + app.admins + if app is not None + else ( + await config_get("admins", "bot") + [await config_get("owner", "bot")] + ) + ) + + return (datetime.now() - self.cooldown).total_seconds() < ( + app.config["submission"]["timeout"] + if app is not None + else await config_get("timeout", "submission") + )