Improved tmp files system

This commit is contained in:
Profitroll 2023-01-03 13:01:46 +01:00
parent a59a7b738c
commit a7038e9d8f
2 changed files with 41 additions and 18 deletions

View File

@ -11,7 +11,7 @@ from dateutil.relativedelta import relativedelta
from classes.errors.geo import PlaceNotFoundError from classes.errors.geo import PlaceNotFoundError
from modules.database import col_tmp, col_users, col_context, col_warnings, col_applications, col_sponsorships, col_messages from modules.database import col_tmp, col_users, col_context, col_warnings, col_applications, col_sponsorships, col_messages
from modules.logging import logWrite from modules.logging import logWrite
from modules.utils import configGet, find_location, locale, should_quote from modules.utils import configGet, create_tmp, download_tmp, find_location, locale, should_quote
class DefaultApplicationTemp(dict): class DefaultApplicationTemp(dict):
def __init__(self, user: int, reapply: bool = False): def __init__(self, user: int, reapply: bool = False):
@ -349,14 +349,14 @@ class HoloUser():
* msg (`Message`): Message that should receive replies * msg (`Message`): Message that should receive replies
""" """
if col_tmp.find_one({"user": self.id, "type": "application"}) is None: # if col_tmp.find_one({"user": self.id, "type": "application"}) is None:
if self.sponsorship_state()[0] == "fill": if self.sponsorship_state()[0] == "fill":
return return
col_tmp.insert_one( # col_tmp.insert_one(
document=DefaultApplicationTemp(self.id).dict # document=DefaultApplicationTemp(self.id).dict
) # )
progress = col_tmp.find_one({"user": self.id, "type": "application"}) progress = col_tmp.find_one({"user": self.id, "type": "application"})
@ -365,9 +365,9 @@ class HoloUser():
stage = progress["stage"] stage = progress["stage"]
if self.sponsorship_state()[0] == "fill": # if self.sponsorship_state()[0] == "fill":
await msg.reply_text(locale("finish_sponsorship", "message"), quote=should_quote(msg)) # await msg.reply_text(locale("finish_sponsorship", "message"), quote=should_quote(msg))
return # return
if progress["state"] == "fill" and progress["sent"] is False: if progress["state"] == "fill" and progress["sent"] is False:
@ -522,11 +522,7 @@ class HoloUser():
elif stage == 3: elif stage == 3:
if photo is not None: if photo is not None:
filename = uuid1() progress["sponsorship"]["proof"] = await download_tmp(app, photo.file_id)
await app.download_media(photo.file_id, f"tmp{sep}{filename}")
with open(f"tmp{sep}{filename}", "rb") as f:
photo_bytes = f.read()
progress["sponsorship"]["proof"] = photo_bytes
col_tmp.update_one({"user": {"$eq": self.id}, "type": {"$eq": "sponsorship"}}, {"$set": {"sponsorship": progress["sponsorship"], "stage": progress["stage"]+1}}) col_tmp.update_one({"user": {"$eq": self.id}, "type": {"$eq": "sponsorship"}}, {"$set": {"sponsorship": progress["sponsorship"], "stage": progress["stage"]+1}})
await msg.reply_text(locale(f"sponsor{stage+1}", "message", locale=self.locale), reply_markup=ForceReply(placeholder=str(locale(f"sponsor{stage+1}", "force_reply", locale=self.locale)))) await msg.reply_text(locale(f"sponsor{stage+1}", "message", locale=self.locale), reply_markup=ForceReply(placeholder=str(locale(f"sponsor{stage+1}", "force_reply", locale=self.locale))))
@ -536,7 +532,14 @@ class HoloUser():
return return
progress["sponsorship"]["label"] = query progress["sponsorship"]["label"] = query
col_tmp.update_one({"user": {"$eq": self.id}, "type": {"$eq": "sponsorship"}}, {"$set": {"sponsorship": progress["sponsorship"], "complete": True}}) col_tmp.update_one({"user": {"$eq": self.id}, "type": {"$eq": "sponsorship"}}, {"$set": {"sponsorship": progress["sponsorship"], "complete": True}})
await msg.reply_text(locale("sponsor_confirm", "message", locale=self.locale), reply_markup=ReplyKeyboardMarkup(locale("confirm", "keyboard", locale=self.locale), resize_keyboard=True)) await msg.reply_photo(
photo=create_tmp(progress["sponsorship"]["proof"], kind="image"),
caption=locale("sponsor_confirm", "message", locale=self.locale).format(
progress["sponsorship"]["streamer"],
progress["sponsorship"]["expires"].strftime("%d.%m.%Y"),
progress["sponsorship"]["label"]
),
reply_markup=ReplyKeyboardMarkup(locale("confirm", "keyboard", locale=self.locale), resize_keyboard=True))
else: else:
return return

View File

@ -1,4 +1,5 @@
from typing import Any, Union from typing import Any, Literal, Union
from uuid import uuid1
from requests import get from requests import get
from pyrogram.enums.chat_type import ChatType from pyrogram.enums.chat_type import ChatType
from pyrogram.types import User from pyrogram.types import User
@ -8,7 +9,7 @@ from ujson import JSONDecodeError as JSONDecodeError
from ujson import loads, dumps from ujson import loads, dumps
from sys import exit from sys import exit
from os import kill, listdir, sep from os import kill, listdir, makedirs, path, sep
from os import name as osname from os import name as osname
from traceback import print_exc from traceback import print_exc
from classes.errors.geo import PlaceNotFoundError from classes.errors.geo import PlaceNotFoundError
@ -191,6 +192,25 @@ def find_location(query: str) -> dict:
except (ValueError, KeyError, IndexError): except (ValueError, KeyError, IndexError):
raise PlaceNotFoundError(query) raise PlaceNotFoundError(query)
def create_tmp(bytedata: Union[bytes, bytearray], kind: Union[Literal["image", "video"], None]) -> str:
filename = str(uuid1())
if kind == "image":
filename += ".jpg"
elif kind == "video":
filename += ".mp4"
makedirs("tmp", exist_ok=True)
with open(path.join("tmp", filename), "wb") as file:
file.write(bytedata)
return path.join("tmp", filename)
async def download_tmp(app: Client, file_id: str) -> bytes:
filename = str(uuid1())
makedirs("tmp", exist_ok=True)
await app.download_media(file_id, path.join("tmp", filename))
with open(path.join("tmp", filename), "rb") as f:
bytedata = f.read()
return bytedata
try: try:
from psutil import Process from psutil import Process
except ModuleNotFoundError: except ModuleNotFoundError: