Improved tmp files system

This commit is contained in:
Profitroll 2023-01-03 13:01:46 +01:00
parent a59a7b738c
commit a7038e9d8f
2 changed files with 41 additions and 18 deletions

View File

@ -11,7 +11,7 @@ from dateutil.relativedelta import relativedelta
from classes.errors.geo import PlaceNotFoundError
from modules.database import col_tmp, col_users, col_context, col_warnings, col_applications, col_sponsorships, col_messages
from modules.logging import logWrite
from modules.utils import configGet, find_location, locale, should_quote
from modules.utils import configGet, create_tmp, download_tmp, find_location, locale, should_quote
class DefaultApplicationTemp(dict):
def __init__(self, user: int, reapply: bool = False):
@ -349,14 +349,14 @@ class HoloUser():
* msg (`Message`): Message that should receive replies
"""
if col_tmp.find_one({"user": self.id, "type": "application"}) is None:
# if col_tmp.find_one({"user": self.id, "type": "application"}) is None:
if self.sponsorship_state()[0] == "fill":
return
if self.sponsorship_state()[0] == "fill":
return
col_tmp.insert_one(
document=DefaultApplicationTemp(self.id).dict
)
# col_tmp.insert_one(
# document=DefaultApplicationTemp(self.id).dict
# )
progress = col_tmp.find_one({"user": self.id, "type": "application"})
@ -365,9 +365,9 @@ class HoloUser():
stage = progress["stage"]
if self.sponsorship_state()[0] == "fill":
await msg.reply_text(locale("finish_sponsorship", "message"), quote=should_quote(msg))
return
# if self.sponsorship_state()[0] == "fill":
# await msg.reply_text(locale("finish_sponsorship", "message"), quote=should_quote(msg))
# return
if progress["state"] == "fill" and progress["sent"] is False:
@ -522,11 +522,7 @@ class HoloUser():
elif stage == 3:
if photo is not None:
filename = uuid1()
await app.download_media(photo.file_id, f"tmp{sep}{filename}")
with open(f"tmp{sep}{filename}", "rb") as f:
photo_bytes = f.read()
progress["sponsorship"]["proof"] = photo_bytes
progress["sponsorship"]["proof"] = await download_tmp(app, photo.file_id)
col_tmp.update_one({"user": {"$eq": self.id}, "type": {"$eq": "sponsorship"}}, {"$set": {"sponsorship": progress["sponsorship"], "stage": progress["stage"]+1}})
await msg.reply_text(locale(f"sponsor{stage+1}", "message", locale=self.locale), reply_markup=ForceReply(placeholder=str(locale(f"sponsor{stage+1}", "force_reply", locale=self.locale))))
@ -536,7 +532,14 @@ class HoloUser():
return
progress["sponsorship"]["label"] = query
col_tmp.update_one({"user": {"$eq": self.id}, "type": {"$eq": "sponsorship"}}, {"$set": {"sponsorship": progress["sponsorship"], "complete": True}})
await msg.reply_text(locale("sponsor_confirm", "message", locale=self.locale), reply_markup=ReplyKeyboardMarkup(locale("confirm", "keyboard", locale=self.locale), resize_keyboard=True))
await msg.reply_photo(
photo=create_tmp(progress["sponsorship"]["proof"], kind="image"),
caption=locale("sponsor_confirm", "message", locale=self.locale).format(
progress["sponsorship"]["streamer"],
progress["sponsorship"]["expires"].strftime("%d.%m.%Y"),
progress["sponsorship"]["label"]
),
reply_markup=ReplyKeyboardMarkup(locale("confirm", "keyboard", locale=self.locale), resize_keyboard=True))
else:
return

View File

@ -1,4 +1,5 @@
from typing import Any, Union
from typing import Any, Literal, Union
from uuid import uuid1
from requests import get
from pyrogram.enums.chat_type import ChatType
from pyrogram.types import User
@ -8,7 +9,7 @@ from ujson import JSONDecodeError as JSONDecodeError
from ujson import loads, dumps
from sys import exit
from os import kill, listdir, sep
from os import kill, listdir, makedirs, path, sep
from os import name as osname
from traceback import print_exc
from classes.errors.geo import PlaceNotFoundError
@ -191,6 +192,25 @@ def find_location(query: str) -> dict:
except (ValueError, KeyError, IndexError):
raise PlaceNotFoundError(query)
def create_tmp(bytedata: Union[bytes, bytearray], kind: Union[Literal["image", "video"], None]) -> str:
filename = str(uuid1())
if kind == "image":
filename += ".jpg"
elif kind == "video":
filename += ".mp4"
makedirs("tmp", exist_ok=True)
with open(path.join("tmp", filename), "wb") as file:
file.write(bytedata)
return path.join("tmp", filename)
async def download_tmp(app: Client, file_id: str) -> bytes:
filename = str(uuid1())
makedirs("tmp", exist_ok=True)
await app.download_media(file_id, path.join("tmp", filename))
with open(path.join("tmp", filename), "rb") as f:
bytedata = f.read()
return bytedata
try:
from psutil import Process
except ModuleNotFoundError: