From cd26990b7e68956c4349a67aff109daadaa14a6a Mon Sep 17 00:00:00 2001 From: profitroll Date: Mon, 14 Aug 2023 14:52:02 +0200 Subject: [PATCH] Migrate to async_pymongo --- classes/pyroclient.py | 10 +++++----- classes/pyrouser.py | 20 +++++++++----------- modules/database.py | 10 ++-------- modules/sender.py | 4 ++-- plugins/callbacks/submission.py | 4 ++-- plugins/handlers/submission.py | 4 ++-- requirements.txt | 2 +- 7 files changed, 23 insertions(+), 31 deletions(-) diff --git a/classes/pyroclient.py b/classes/pyroclient.py index ac6e23b..a8aa076 100644 --- a/classes/pyroclient.py +++ b/classes/pyroclient.py @@ -141,7 +141,7 @@ class PyroClient(PyroClient): async def submit_media( self, id: str ) -> Tuple[Union[Message, None], Union[str, None]]: - db_entry = col_submitted.find_one({"_id": ObjectId(id)}) + db_entry = await col_submitted.find_one({"_id": ObjectId(id)}) submission = None if db_entry is None: @@ -226,7 +226,7 @@ class PyroClient(PyroClient): ) raise SubmissionDuplicatesError(str(filepath), duplicates) - col_submitted.find_one_and_update( + await col_submitted.find_one_and_update( {"_id": ObjectId(id)}, {"$set": {"done": True}} ) @@ -258,12 +258,12 @@ class PyroClient(PyroClient): * `PyroUser`: PyroUser object """ if ( - col_users.find_one( + await col_users.find_one( {"id": user.id if isinstance(user, User) else user} ) # type: ignore is None ): - col_users.insert_one( + await col_users.insert_one( { "id": user.id if isinstance(user, User) else user, "locale": user.language_code if isinstance(user, User) else None, @@ -273,7 +273,7 @@ class PyroClient(PyroClient): } ) # type: ignore - db_record = col_users.find_one( + db_record = await col_users.find_one( {"id": user.id if isinstance(user, User) else user} ) # type: ignore diff --git a/classes/pyrouser.py b/classes/pyrouser.py index 6f77db6..249e6f6 100644 --- a/classes/pyrouser.py +++ b/classes/pyrouser.py @@ -20,19 +20,19 @@ class PyroUser: cooldown: datetime subscription: dict - async def update_locale(self, locale: str): - col_users.update_one({"_id": self._id}, {"$set": {"locale": locale}}) + async def update_locale(self, locale: str) -> None: + await col_users.update_one({"_id": self._id}, {"$set": {"locale": locale}}) - async def update_cooldown(self, time: datetime = datetime.now()): - col_users.update_one({"_id": self._id}, {"$set": {"cooldown": time}}) + async def update_cooldown(self, time: datetime = datetime.now()) -> None: + await col_users.update_one({"_id": self._id}, {"$set": {"cooldown": time}}) async def block(self) -> None: """Ban user from using command and submitting content.""" - col_users.update_one({"_id": self._id}, {"$set": {"banned": True}}) + await col_users.update_one({"_id": self._id}, {"$set": {"banned": True}}) async def unblock(self) -> None: """Allow user to use command and submit posts again.""" - col_users.update_one({"_id": self._id}, {"$set": {"banned": False}}) + await col_users.update_one({"_id": self._id}, {"$set": {"banned": False}}) async def is_limited(self, app: Union[PyroClient, None] = None) -> bool: """Check if user is on a cooldown after submitting something. @@ -41,11 +41,9 @@ class PyroUser: `bool`: Must be `True` if on the cooldown and `False` if not """ admins = ( - app.admins - if app is not None - else ( - await config_get("admins", "bot") + [await config_get("owner", "bot")] - ) + await config_get("admins", "bot") + [await config_get("owner", "bot")] + if app is None + else app.admins ) return (datetime.now() - self.cooldown).total_seconds() < ( diff --git a/modules/database.py b/modules/database.py index acdb504..57fcf4d 100644 --- a/modules/database.py +++ b/modules/database.py @@ -1,6 +1,6 @@ """Module that provides all database columns""" -from pymongo import MongoClient +from async_pymongo import AsyncClient from ujson import loads with open("config.json", "r", encoding="utf-8") as f: @@ -20,15 +20,9 @@ else: db_config["host"], db_config["port"], db_config["name"] ) -db_client = MongoClient(con_string) +db_client = AsyncClient(con_string) db = db_client.get_database(name=db_config["name"]) -collections = db.list_collection_names() - -for collection in ["sent", "users", "submitted"]: - if collection not in collections: - db.create_collection(collection) - col_sent = db.get_collection("sent") col_users = db.get_collection("users") col_submitted = db.get_collection("submitted") diff --git a/modules/sender.py b/modules/sender.py index 0d9af0a..e7bdedf 100644 --- a/modules/sender.py +++ b/modules/sender.py @@ -179,7 +179,7 @@ async def send_content(app: PyroClient, http_session: ClientSession) -> None: del response - submitted = col_submitted.find_one({"temp.file": media.filename}) + submitted = await col_submitted.find_one({"temp.file": media.filename}) if submitted is not None and submitted["caption"] is not None: caption = submitted["caption"].strip() @@ -229,7 +229,7 @@ async def send_content(app: PyroClient, http_session: ClientSession) -> None: # rmtree(path.join(app.config['locations']['tmp'], tmp_dir), ignore_errors=True) return - col_sent.insert_one( + await col_sent.insert_one( { "date": datetime.now(), "image": media.id, diff --git a/plugins/callbacks/submission.py b/plugins/callbacks/submission.py index 47f70a0..f2fe34c 100644 --- a/plugins/callbacks/submission.py +++ b/plugins/callbacks/submission.py @@ -25,7 +25,7 @@ async def callback_query_yes(app: PyroClient, clb: CallbackQuery): user = await app.find_user(clb.from_user) fullclb = str(clb.data).split("_") - db_entry = col_submitted.find_one({"_id": ObjectId(fullclb[2])}) + db_entry = await col_submitted.find_one({"_id": ObjectId(fullclb[2])}) try: submission = await app.submit_media(fullclb[2]) @@ -127,7 +127,7 @@ async def callback_query_no(app: PyroClient, clb: CallbackQuery): user = await app.find_user(clb.from_user) fullclb = str(clb.data).split("_") - db_entry = col_submitted.find_one_and_delete({"_id": ObjectId(fullclb[2])}) + db_entry = await col_submitted.find_one_and_delete({"_id": ObjectId(fullclb[2])}) if ( db_entry["temp"]["uuid"] is not None diff --git a/plugins/handlers/submission.py b/plugins/handlers/submission.py index 3206c85..f6b7bfb 100644 --- a/plugins/handlers/submission.py +++ b/plugins/handlers/submission.py @@ -152,7 +152,7 @@ async def get_submission(app: PyroClient, msg: Message): + sep, ) - inserted = col_submitted.insert_one( + inserted = await col_submitted.insert_one( { "user": msg.from_user.id, "date": datetime.now(), @@ -165,7 +165,7 @@ async def get_submission(app: PyroClient, msg: Message): ) else: - inserted = col_submitted.insert_one( + inserted = await col_submitted.insert_one( { "user": msg.from_user.id, "date": datetime.now(), diff --git a/requirements.txt b/requirements.txt index d0bda96..6049c93 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,12 +4,12 @@ convopyro==0.5 pillow~=10.0.0 psutil~=5.9.4 pykeyboard==0.1.5 -pymongo~=4.4.0 pyrogram==2.0.106 python_dateutil==2.8.2 pytimeparse~=1.1.8 tgcrypto==1.2.5 uvloop==0.17.0 --extra-index-url https://git.end-play.xyz/api/packages/profitroll/pypi/simple +async_pymongo==0.1.4 libbot[speed,pyrogram]==2.0.1 photosapi_client==0.5.0 \ No newline at end of file