/cancel, /identify, sponsorships improvements and fixes #3

Merged
profitroll merged 19 commits from dev into master 2023-01-03 16:45:20 +02:00
2 changed files with 41 additions and 18 deletions
Showing only changes of commit a7038e9d8f - Show all commits

View File

@ -11,7 +11,7 @@ from dateutil.relativedelta import relativedelta
from classes.errors.geo import PlaceNotFoundError
from modules.database import col_tmp, col_users, col_context, col_warnings, col_applications, col_sponsorships, col_messages
from modules.logging import logWrite
from modules.utils import configGet, find_location, locale, should_quote
from modules.utils import configGet, create_tmp, download_tmp, find_location, locale, should_quote
class DefaultApplicationTemp(dict):
def __init__(self, user: int, reapply: bool = False):
@ -349,14 +349,14 @@ class HoloUser():
* msg (`Message`): Message that should receive replies
"""
if col_tmp.find_one({"user": self.id, "type": "application"}) is None:
# if col_tmp.find_one({"user": self.id, "type": "application"}) is None:
if self.sponsorship_state()[0] == "fill":
return
col_tmp.insert_one(
document=DefaultApplicationTemp(self.id).dict
)
# col_tmp.insert_one(
# document=DefaultApplicationTemp(self.id).dict
# )
progress = col_tmp.find_one({"user": self.id, "type": "application"})
@ -365,9 +365,9 @@ class HoloUser():
stage = progress["stage"]
if self.sponsorship_state()[0] == "fill":
await msg.reply_text(locale("finish_sponsorship", "message"), quote=should_quote(msg))
return
# if self.sponsorship_state()[0] == "fill":
# await msg.reply_text(locale("finish_sponsorship", "message"), quote=should_quote(msg))
# return
if progress["state"] == "fill" and progress["sent"] is False:
@ -522,11 +522,7 @@ class HoloUser():
elif stage == 3:
if photo is not None:
filename = uuid1()
await app.download_media(photo.file_id, f"tmp{sep}{filename}")
with open(f"tmp{sep}{filename}", "rb") as f:
photo_bytes = f.read()
progress["sponsorship"]["proof"] = photo_bytes
progress["sponsorship"]["proof"] = await download_tmp(app, photo.file_id)
col_tmp.update_one({"user": {"$eq": self.id}, "type": {"$eq": "sponsorship"}}, {"$set": {"sponsorship": progress["sponsorship"], "stage": progress["stage"]+1}})
await msg.reply_text(locale(f"sponsor{stage+1}", "message", locale=self.locale), reply_markup=ForceReply(placeholder=str(locale(f"sponsor{stage+1}", "force_reply", locale=self.locale))))
@ -536,7 +532,14 @@ class HoloUser():
return
progress["sponsorship"]["label"] = query
col_tmp.update_one({"user": {"$eq": self.id}, "type": {"$eq": "sponsorship"}}, {"$set": {"sponsorship": progress["sponsorship"], "complete": True}})
await msg.reply_text(locale("sponsor_confirm", "message", locale=self.locale), reply_markup=ReplyKeyboardMarkup(locale("confirm", "keyboard", locale=self.locale), resize_keyboard=True))
await msg.reply_photo(
photo=create_tmp(progress["sponsorship"]["proof"], kind="image"),
caption=locale("sponsor_confirm", "message", locale=self.locale).format(
progress["sponsorship"]["streamer"],
progress["sponsorship"]["expires"].strftime("%d.%m.%Y"),
progress["sponsorship"]["label"]
),
reply_markup=ReplyKeyboardMarkup(locale("confirm", "keyboard", locale=self.locale), resize_keyboard=True))
else:
return

View File

@ -1,4 +1,5 @@
from typing import Any, Union
from typing import Any, Literal, Union
from uuid import uuid1
from requests import get
from pyrogram.enums.chat_type import ChatType
from pyrogram.types import User
@ -8,7 +9,7 @@ from ujson import JSONDecodeError as JSONDecodeError
from ujson import loads, dumps
from sys import exit
from os import kill, listdir, sep
from os import kill, listdir, makedirs, path, sep
from os import name as osname
from traceback import print_exc
from classes.errors.geo import PlaceNotFoundError
@ -191,6 +192,25 @@ def find_location(query: str) -> dict:
except (ValueError, KeyError, IndexError):
raise PlaceNotFoundError(query)
def create_tmp(bytedata: Union[bytes, bytearray], kind: Union[Literal["image", "video"], None]) -> str:
filename = str(uuid1())
if kind == "image":
filename += ".jpg"
elif kind == "video":
filename += ".mp4"
makedirs("tmp", exist_ok=True)
with open(path.join("tmp", filename), "wb") as file:
file.write(bytedata)
return path.join("tmp", filename)
async def download_tmp(app: Client, file_id: str) -> bytes:
filename = str(uuid1())
makedirs("tmp", exist_ok=True)
await app.download_media(file_id, path.join("tmp", filename))
with open(path.join("tmp", filename), "rb") as f:
bytedata = f.read()
return bytedata
try:
from psutil import Process
except ModuleNotFoundError: