import logging
from datetime import datetime, timedelta
from typing import List, Union

from aiohttp import ClientSession
from apscheduler.triggers.cron import CronTrigger
from libbot.pyrogram.classes import PyroClient as LibPyroClient
from pymongo import ASCENDING, GEOSPHERE, TEXT
from pyrogram.types import User

from classes.location import Location
from classes.pyrouser import PyroUser
from classes.updater import Updater
from modules.database_api import col_locations
from modules.reminder import remind

logger = logging.getLogger(__name__)


class PyroClient(LibPyroClient):
    def __init__(self, **kwargs):
        """Create the client and register its scheduled jobs.

        Registers `remind` to run every minute. If `config["update_checker"]`
        is truthy, also registers `check_updates` (cron `0 12 */3 * *`, first
        run ~10 seconds after construction). Jobs are only added when
        `self.scheduler` is not `None`.

        ### Args:
            * **kwargs: Passed through unchanged to the libbot `PyroClient`.
        """
        self.__version__ = (0, 1, 2)

        super().__init__(**kwargs)

        # NOTE(review): aiohttp recommends creating a ClientSession from inside
        # a coroutine; here it is created in a synchronous constructor. It is
        # closed in `stop()`.
        self.updater = Updater(ClientSession())
        self.contexts = []

        if self.scheduler is not None:
            self.scheduler.add_job(
                remind, CronTrigger.from_crontab("* * * * *"), args=(self,)
            )

            if self.config["update_checker"]:
                self.scheduler.add_job(
                    self.check_updates,
                    CronTrigger.from_crontab("0 12 */3 * *"),
                    # Also check shortly after startup instead of waiting
                    # for the first cron match.
                    next_run_time=datetime.now() + timedelta(seconds=10),
                )

    async def start(self, **kwargs):
        """Ensure the locations collection indexes exist, then start the client.

        Creates a unique ascending index on `id`, a geospatial (`GEOSPHERE`)
        index on `location` and a text index on `name`.

        ### Args:
            * **kwargs: Passed through to the parent `start()`.

        ### Returns:
            * Whatever the parent `start()` returns.
        """
        await col_locations.create_index(
            [("id", ASCENDING)], name="location_id", unique=True
        )
        await col_locations.create_index(
            [("location", GEOSPHERE)],
            name="location_location",
        )
        await col_locations.create_index([("name", TEXT)], name="location_name")
        return await super().start(**kwargs)

    async def stop(self, **kwargs):
        """Close the updater's HTTP session, then stop the client.

        The parent `stop()` is called in a `finally` block. The client is
        therefore stopped even if closing the session fails; that failure is
        re-raised afterwards.

        ### Args:
            * **kwargs: Passed through to the parent `stop()`.

        ### Returns:
            * Whatever the parent `stop()` returns (mirrors `start()`).
        """
        try:
            await self.updater.client_session.close()
        finally:
            result = await super().stop(**kwargs)
        return result

    async def find_user(self, user: Union[int, User]) -> PyroUser:
        """Find User by its ID or User object.

        ### Args:
            * user (`Union[int, User]`): ID or User object to extract ID from.

        ### Returns:
            * `PyroUser`: User in database representation.
        """
        if isinstance(user, int):
            return await PyroUser.find(user)

        # A full User object also carries the user's language, which is
        # forwarded as the locale.
        return await PyroUser.find(user.id, locale=user.language_code)

    async def get_location(self, id: int) -> Location:
        """Get Location by its ID.

        ### Args:
            * id (`int`): Location's ID.

        ### Returns:
            * `Location`: Location from database as an object.
        """
        return await Location.get(id)

    async def list_locations(self) -> List[Location]:
        """Get all locations stored in database.

        ### Returns:
            * `List[Location]`: List of `Location` objects.
        """
        # NOTE(review): each record is resolved again through `Location.get`
        # by its `id`; if that performs its own query this is one query per
        # location — confirm against `Location.get`.
        return [
            await Location.get(record["id"]) async for record in col_locations.find({})
        ]

    async def check_updates(self) -> None:
        """Notify the owner when a newer release is available.

        Asks the updater whether a version newer than `__version__` exists at
        `config["strings"]["url_updater"]`. If one does, fetches the latest
        release and sends the `update_available` message to `self.owner`. If
        either updater call fails, the error is logged and the check is skipped.
        """
        url_updater = self.config["strings"]["url_updater"]

        try:
            if not await self.updater.check_updates(self.__version__, url_updater):
                return
            release = await self.updater.get_latest_release(url_updater)
        except Exception as exc:
            # Runs as a scheduled job: log and skip instead of raising.
            logger.error("Could not fetch the latest version: %s", exc)
            return

        version_current = "v" + ".".join(str(part) for part in self.__version__)

        await self.send_message(
            self.owner,
            self._("update_available", "messages").format(
                version_current=version_current,
                version_new=release["tag_name"],
                release_url=release["html_url"],
            ),
        )