diff --git a/classes/pyroclient.py b/classes/pyroclient.py index 2ca2662..0d03584 100644 --- a/classes/pyroclient.py +++ b/classes/pyroclient.py @@ -32,6 +32,7 @@ from pyrogram.types import ( BotCommandScopeDefault, Message, ) +from pytimeparse.timeparse import timeparse from ujson import dumps, loads from classes.commandset import CommandSet @@ -50,6 +51,7 @@ from modules.api_client import ( from modules.database import col_submitted from modules.http_client import http_session from modules.scheduler import scheduler +from modules.sender import send_content logger = logging.getLogger(__name__) @@ -89,6 +91,8 @@ class PyroClient(Client): self.in_all_locales = self.bot_locale.in_all_locales self.in_every_locale = self.bot_locale.in_every_locale + self.sender_session = ClientSession() + async def start(self): await super().start() @@ -189,6 +193,25 @@ class PyroClient(Client): kwargs={"command_sets": await self.collect_commands()}, ) + if self.config["mode"]["post"]: + if self.config["posting"]["use_interval"]: + scheduler.add_job( + send_content, + "interval", + seconds=timeparse(self.config["posting"]["interval"]), + args=[self, self.sender_session], + ) + else: + for entry in self.config["posting"]["time"]: + dt_obj = datetime.strptime(entry, "%H:%M") + scheduler.add_job( + send_content, + "cron", + hour=dt_obj.hour, + minute=dt_obj.minute, + args=[self, self.sender_session], + ) + scheduler.start() except BadRequest: logger.warning("Unable to send message to report chat.") @@ -208,6 +231,7 @@ class PyroClient(Client): except BadRequest: logger.warning("Unable to send message to report chat.") await http_session.close() + await self.sender_session.close() await super().stop() logger.warning("Bot stopped with PID %s.", getpid()) diff --git a/modules/api_client.py b/modules/api_client.py index fac6487..515f28d 100644 --- a/modules/api_client.py +++ b/modules/api_client.py @@ -3,8 +3,10 @@ import logging from base64 import b64decode, b64encode from os import makedirs, path, sep from pathlib import Path +from typing import Union import aiofiles +from aiohttp import ClientSession from libbot import config_get, i18n, sync from photosapi_client import AuthenticatedClient, Client from photosapi_client.api.default.album_create_albums_post import ( @@ -48,15 +50,17 @@ from modules.http_client import http_session logger = logging.getLogger(__name__) -async def authorize() -> str: +async def authorize(custom_session: Union[ClientSession, None] = None) -> str: makedirs(await config_get("cache", "locations"), exist_ok=True) + session = http_session if custom_session is None else custom_session + if path.exists(await config_get("cache", "locations") + sep + "api_access") is True: async with aiofiles.open( await config_get("cache", "locations") + sep + "api_access", "rb" ) as file: token = b64decode(await file.read()).decode("utf-8") if ( - await http_session.get( + await session.get( await config_get("address", "posting", "api") + "/users/me/", headers={"Authorization": f"Bearer {token}"}, ) @@ -68,7 +72,7 @@ async def authorize() -> str: "username": await config_get("username", "posting", "api"), "password": await config_get("password", "posting", "api"), } - response = await http_session.post( + response = await session.post( await config_get("address", "posting", "api") + "/token", data=payload ) if not response.ok: diff --git a/modules/scheduler.py b/modules/scheduler.py index b01330d..a5eb79d 100644 --- a/modules/scheduler.py +++ b/modules/scheduler.py @@ -1,24 +1,3 @@ -from datetime import datetime - from apscheduler.schedulers.asyncio import AsyncIOScheduler -from libbot import sync -from pytimeparse.timeparse import timeparse - -# from modules.sender import send_content scheduler = AsyncIOScheduler() - -# if sync.config_get("post", "mode"): -# if sync.config_get("use_interval", "posting"): -# scheduler.add_job( -# send_content, -# "interval", -# seconds=timeparse(sync.config_get("interval", "posting")), -# args=[app], -# ) -# else: -# for entry in sync.config_get("time", "posting"): -# dt_obj = datetime.strptime(entry, "%H:%M") -# scheduler.add_job( -# send_content, "cron", hour=dt_obj.hour, minute=dt_obj.minute, args=[app] -# ) diff --git a/modules/sender.py b/modules/sender.py index b4f2845..82bcff1 100644 --- a/modules/sender.py +++ b/modules/sender.py @@ -7,20 +7,22 @@ from traceback import format_exc from uuid import uuid4 from PIL import Image import aiofiles +from aiohttp import ClientSession -from classes.pyroclient import PyroClient +from pyrogram.client import Client -from modules.api_client import authorize, http_session, photo_patch, photo_find, client +from modules.api_client import authorize, photo_patch, photo_find, client from modules.database import col_sent, col_submitted from photosapi_client.errors import UnexpectedStatus + logger = logging.getLogger(__name__) -async def send_content(app: PyroClient) -> None: +async def send_content(app: Client, http_session: ClientSession) -> None: try: try: - token = await authorize() + token = await authorize(http_session) except ValueError: await app.send_message( app.owner, @@ -29,7 +31,16 @@ async def send_content(app: PyroClient) -> None: return try: - pic = choice((await photo_find(album=app.config["posting"]["api"]["album"], caption="queue", page_size=app.config["posting"]["page_size"], client=client)).results) + pic = choice( + ( + await photo_find( + album=app.config["posting"]["api"]["album"], + caption="queue", + page_size=app.config["posting"]["page_size"], + client=client, + ) + ).results + ) except (KeyError, AttributeError, TypeError): logger.info(app._("post_empty", "console")) if app.config["reports"]["error"]: @@ -67,42 +78,42 @@ async def send_content(app: PyroClient) -> None: tmp_dir = str(uuid4()) - makedirs(path.join(app.config['locations']['tmp'], tmp_dir), exist_ok=True) + makedirs(path.join(app.config["locations"]["tmp"], tmp_dir), exist_ok=True) tmp_path = path.join(tmp_dir, pic.filename) async with aiofiles.open( - path.join(app.config['locations']['tmp'], tmp_path), "wb" + path.join(app.config["locations"]["tmp"], tmp_path), "wb" ) as out_file: await out_file.write(await response.read()) logger.info( - f'Candidate {pic.filename} ({pic.id}) is {path.getsize(path.join(app.config['locations']['tmp'], tmp_path))} bytes big', + f"Candidate {pic.filename} ({pic.id}) is {path.getsize(path.join(app.config['locations']['tmp'], tmp_path))} bytes big", ) - if path.getsize(path.join(app.config['locations']['tmp'], tmp_path)) > 5242880: - image = Image.open(path.join(app.config['locations']['tmp'], tmp_path)) + if path.getsize(path.join(app.config["locations"]["tmp"], tmp_path)) > 5242880: + image = Image.open(path.join(app.config["locations"]["tmp"], tmp_path)) width, height = image.size image = image.resize((int(width / 2), int(height / 2)), Image.ANTIALIAS) if tmp_path.lower().endswith(".jpeg") or tmp_path.lower().endswith(".jpg"): image.save( - path.join(app.config['locations']['tmp'], tmp_path), + path.join(app.config["locations"]["tmp"], tmp_path), "JPEG", optimize=True, quality=50, ) elif tmp_path.lower().endswith(".png"): image.save( - path.join(app.config['locations']['tmp'], tmp_path), + path.join(app.config["locations"]["tmp"], tmp_path), "PNG", optimize=True, compress_level=8, ) image.close() - if path.getsize(path.join(app.config['locations']['tmp'], tmp_path)) > 5242880: + if path.getsize(path.join(app.config["locations"]["tmp"], tmp_path)) > 5242880: rmtree( - path.join(app.config['locations']['tmp'], tmp_dir), ignore_errors=True + path.join(app.config["locations"]["tmp"], tmp_dir), ignore_errors=True ) raise BytesWarning @@ -120,7 +131,9 @@ async def send_content(app: PyroClient) -> None: and app.config["posting"]["submitted_caption"]["enabled"] and ( (submitted["user"] not in app.admins) - or (app.config["posting"]["submitted_caption"]["ignore_admins"] is False) + or ( + app.config["posting"]["submitted_caption"]["ignore_admins"] is False + ) ) ): caption = ( @@ -140,7 +153,7 @@ async def send_content(app: PyroClient) -> None: try: sent = await app.send_photo( app.config["posting"]["channel"], - path.join(app.config['locations']['tmp'], tmp_path), + path.join(app.config["locations"]["tmp"], tmp_path), caption=caption, disable_notification=app.config["posting"]["silent"], ) @@ -168,7 +181,7 @@ async def send_content(app: PyroClient) -> None: await photo_patch(id=pic.id, client=client, caption="sent") - rmtree(path.join(app.config['locations']['tmp'], tmp_dir), ignore_errors=True) + rmtree(path.join(app.config["locations"]["tmp"], tmp_dir), ignore_errors=True) logger.info( app._("post_sent", "console").format( @@ -188,7 +201,7 @@ async def send_content(app: PyroClient) -> None: ) try: rmtree( - path.join(app.config['locations']['tmp'], tmp_dir), ignore_errors=True + path.join(app.config["locations"]["tmp"], tmp_dir), ignore_errors=True ) except: pass