TelegramBot/classes/pyroclient.py
2024-05-26 22:56:11 +02:00

113 lines
3.6 KiB
Python

import logging
from datetime import datetime, timedelta
from typing import List, Union
from aiohttp import ClientSession
from apscheduler.triggers.cron import CronTrigger
from libbot.pyrogram.classes import PyroClient as LibPyroClient
from pymongo import ASCENDING, GEOSPHERE, TEXT
from pyrogram.types import User
from classes.location import Location
from classes.pyrouser import PyroUser
from classes.updater import Updater
from modules.database_api import col_locations
from modules.reminder import remind
# Module-level logger named after this module; used by PyroClient.check_updates
# to report failures while fetching the latest release.
logger = logging.getLogger(__name__)
class PyroClient(LibPyroClient):
    """Bot client built on top of libbot's `PyroClient`.

    On construction it creates an `Updater` backed by an aiohttp session and,
    when a scheduler is available, registers the periodic `remind` job and
    (if enabled in the config) the update check job.
    """

    def __init__(self, **kwargs):
        """Initialize the client and register scheduled jobs.

        ### Args:
            * **kwargs: Passed through unchanged to the parent constructor.
        """
        self.__version__ = (0, 1, 2)

        super().__init__(**kwargs)

        self.updater = Updater(ClientSession())
        self.contexts = []

        # Without a scheduler there is nothing to register.
        if self.scheduler is None:
            return

        # Reminder job: crontab "* * * * *" fires every minute.
        every_minute = CronTrigger.from_crontab("* * * * *")
        self.scheduler.add_job(remind, every_minute, args=(self,))

        if not self.config["update_checker"]:
            return

        # Update check: 12:00 on every third day of the month, with the very
        # first run forced to happen 10 seconds from now.
        update_trigger = CronTrigger.from_crontab("0 12 */3 * *")
        first_run = datetime.now() + timedelta(seconds=10)
        self.scheduler.add_job(
            self.check_updates,
            update_trigger,
            next_run_time=first_run,
        )

    async def start(self, **kwargs):
        """Create the indexes on the locations collection, then start the client.

        ### Args:
            * **kwargs: Passed through unchanged to the parent `start`.

        ### Returns:
            * Whatever the parent `start` returns.
        """
        # (keys, options) pairs, created in this exact order.
        index_specs = (
            ([("id", ASCENDING)], {"name": "location_id", "unique": True}),
            ([("location", GEOSPHERE)], {"name": "location_location"}),
            ([("name", TEXT)], {"name": "location_name"}),
        )

        for keys, options in index_specs:
            await col_locations.create_index(keys, **options)

        return await super().start(**kwargs)

    async def stop(self, **kwargs):
        """Close the updater's HTTP session, then stop the client.

        ### Args:
            * **kwargs: Passed through unchanged to the parent `stop`.
        """
        session = self.updater.client_session
        await session.close()

        await super().stop(**kwargs)

    async def find_user(self, user: Union[int, User]) -> PyroUser:
        """Look up a user by ID or by a pyrogram `User` object.

        ### Args:
            * user (`Union[int, User]`): User ID, or a `User` whose `id` is used
              for the lookup and whose `language_code` is passed as `locale`.

        ### Returns:
            * `PyroUser`: Result of `PyroUser.find` for that user.
        """
        if isinstance(user, int):
            return await PyroUser.find(user)

        return await PyroUser.find(user.id, locale=user.language_code)

    async def get_location(self, id: int) -> Location:
        """Fetch a single location by its ID.

        ### Args:
            * id (`int`): Location's ID.

        ### Returns:
            * `Location`: Result of `Location.get` for that ID.
        """
        location = await Location.get(id)
        return location

    async def list_locations(self) -> List[Location]:
        """Collect a `Location` object for every record in the locations collection.

        ### Returns:
            * `List[Location]`: One `Location` per record, in cursor order.
        """
        locations = []

        # Each record is re-fetched through `Location.get` using its "id" field.
        async for record in col_locations.find({}):
            locations.append(await Location.get(record["id"]))

        return locations

    async def check_updates(self) -> None:
        """Message the owner about a new release, if the updater reports one.

        Does nothing when `self.updater.check_updates` returns a falsy value.
        If fetching the latest release raises, the error is logged and no
        message is sent.
        """
        url_updater = self.config["strings"]["url_updater"]

        if not await self.updater.check_updates(self.__version__, url_updater):
            return

        try:
            release = await self.updater.get_latest_release(url_updater)
        except Exception as exc:
            logger.error("Could not fetch the latest version: %s", exc)
            return

        # Render the version tuple as a dotted string, e.g. (0, 1, 2) -> "0.1.2".
        dotted_version = ".".join(str(subversion) for subversion in self.__version__)

        text = self._("update_available", "messages").format(
            version_current=f"v{dotted_version}",
            version_new=release["tag_name"],
            release_url=release["html_url"],
        )

        await self.send_message(self.owner, text)