Prototype: Files import

This commit is contained in:
Profitroll 2023-03-20 12:03:03 +01:00
parent 1749d49a52
commit 03b6bbe039
4 changed files with 144 additions and 15 deletions

View File

@ -10,3 +10,5 @@ app = PosterClient(
)
Conversation(app)
users_with_context = []

View File

@ -1,10 +1,12 @@
from os import kill, path
from os import kill, makedirs
from os import name as osname
from os import sep
from os import path, sep
from sys import exit
from traceback import print_exc
from typing import Any
from zipfile import ZipFile
import aiofiles
from ujson import JSONDecodeError, dumps, loads
from modules.logger import logWrite
@ -207,6 +209,28 @@ def locale(key: str, *args: str, locale=configGet("locale")):
return f'⚠️ Locale in config is invalid: could not get "{key}" in {str(args)} from locale "{locale}"'
async def extract_and_save(handle: ZipFile, filename: str, destpath: str):
"""Extract and save file from archive
Args:
* handle (ZipFile): ZipFile handler
* filename (str): File base name
* path (str): Path where to store
"""
data = handle.read(filename)
filepath = path.join(destpath, filename)
try:
makedirs(path.dirname(filepath), exist_ok=True)
async with aiofiles.open(filepath, "wb") as fd:
await fd.write(data)
logWrite(f"Unzipped {filename}", debug=True)
except IsADirectoryError:
makedirs(filepath, exist_ok=True)
except FileNotFoundError:
pass
return
try:
from psutil import Process
except ModuleNotFoundError:

View File

@ -1,22 +1,122 @@
import asyncio
from glob import iglob
from os import getcwd, makedirs, path, remove
from shutil import disk_usage, rmtree
from traceback import format_exc
from uuid import uuid4
from zipfile import ZipFile
from convopyro import listen_message
from pyrogram import filters
from pyrogram.types import Message
from convopyro import listen_message
from classes.poster_client import PosterClient
from modules.app import app
from modules.api_client import upload_pic
from modules.app import app, users_with_context
from modules.logger import logWrite
from modules.utils import configGet, extract_and_save
@app.on_message(~filters.scheduled & filters.command(["import"], prefixes=["", "/"]))
async def cmd_import(app: PosterClient, msg: Message):
if msg.from_user.id in app.admins:
print("Listening to file...", flush=True)
answer = await app.listen.Message(
filters.document, id=filters.user(msg.from_user.id), timeout=None
)
if answer is None:
global users_with_context
if msg.from_user.id not in users_with_context:
users_with_context.append(msg.from_user.id)
else:
return
print("Gotcha", flush=True)
await answer.reply_text("Gotcha")
await msg.reply_text(
f"Alright, please send me a zip archive with your media to be imported. Use /cancel if you want to abort this operation."
)
answer = await listen_message(app, msg.chat.id, timeout=600)
users_with_context.remove(msg.from_user.id)
if answer is None:
await msg.reply_text("No response, aborting import.", quote=True)
return
if answer.text == "/cancel":
await answer.reply_text("Okay, aborting.")
return
if answer.document is None:
await answer.reply_text(
"File to import must be a zip archive. Aborting.", quote=True
)
return
if answer.document.mime_type != "application/zip":
await answer.reply_text(
"Provided file is not supported. Please send `application/zip`. Aborting.",
quote=True,
)
return
if disk_usage(getcwd())[2] < (answer.document.file_size) * 3:
await msg.reply_text(
f"You archive is `{answer.document.file_size//(2**30)} GiB` big, but system has only `{disk_usage(getcwd())[2]//(2**30)} GiB` free. Unpacking may take even more space. Aborting."
)
return
tmp_dir = str(uuid4())
logWrite(
f"Importing '{answer.document.file_name}' file {answer.document.file_size} bytes big (TMP ID {tmp_dir})"
)
makedirs(path.join(configGet("tmp", "locations"), tmp_dir), exist_ok=True)
tmp_path = path.join(configGet("tmp", "locations"), answer.document.file_id)
downloading = await answer.reply_text("Okay, downloading...", quote=True)
await app.download_media(answer, file_name=tmp_path)
await downloading.edit("Downloaded, unpacking...")
try:
with ZipFile(tmp_path, "r") as handle:
tasks = [
extract_and_save(
handle, name, path.join(configGet("tmp", "locations"), tmp_dir)
)
for name in handle.namelist()
]
_ = await asyncio.gather(*tasks)
except Exception as exp:
logWrite(
f"Could not import '{answer.document.file_name}' due to {exp}: {format_exc}"
)
await answer.reply_text(
f"Could not unpack the archive\n\nException: {exp}\n\nTraceback:\n```python\n{format_exc}\n```"
)
return
logWrite(f"Downloaded '{answer.document.file_name}' - awaiting upload")
await downloading.edit("Unpacked, uploading...")
remove(tmp_path)
for filename in iglob(
path.join(configGet("tmp", "locations"), tmp_dir) + "**/**", recursive=True
):
if not path.isfile(filename):
continue
# upload filename
uploaded = await upload_pic(filename)
if uploaded[0] is False:
logWrite(
f"Could not upload '{filename}' from '{path.join(configGet('tmp', 'locations'), tmp_dir)}'. Duplicates: {str(uploaded[1])}",
debug=True,
)
if len(uploaded[1]) > 0:
await msg.reply_text(
f"Could not upload `{path.basename(filename)}` because there're duplicates on server.",
disable_notification=True,
)
else:
await msg.reply_text(
f"Could not upload `{path.basename(filename)}`. Probably disallowed filetype",
disable_notification=True,
)
else:
logWrite(
f"Uploaded '{filename}' from '{path.join(configGet('tmp', 'locations'), tmp_dir)}' and got ID {uploaded[2]}",
debug=True,
)
await downloading.delete()
logWrite(
f"Removing '{path.join(configGet('tmp', 'locations'), tmp_dir)}' after uploading",
debug=True,
)
rmtree(path.join(configGet("tmp", "locations"), tmp_dir), ignore_errors=True)
await answer.reply_text("Done.", quote=True)
return

View File

@ -4,17 +4,17 @@ from traceback import format_exc
from uuid import uuid4
from pyrogram import filters
from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup, Message
from pyrogram.enums.chat_action import ChatAction
from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup, Message
from classes.enums.submission_types import SubmissionType
from classes.exceptions import SubmissionDuplicatesError
from classes.poster_client import PosterClient
from classes.user import PosterUser
from modules.app import app
from modules.app import app, users_with_context
from modules.database import col_banned, col_submitted
from modules.logger import logWrite
from modules.utils import configGet, locale
from classes.enums.submission_types import SubmissionType
@app.on_message(
@ -24,6 +24,9 @@ from classes.enums.submission_types import SubmissionType
| filters.document
)
async def get_submission(app: PosterClient, msg: Message):
global users_with_context
if msg.from_user.id in users_with_context:
return
try:
if col_banned.find_one({"user": msg.from_user.id}) is not None:
return