2023-08-27 23:43:16 +03:00
|
|
|
from typing import List, Union
|
|
|
|
|
|
|
|
from apscheduler.triggers.cron import CronTrigger
|
|
|
|
from libbot.pyrogram.classes import PyroClient as LibPyroClient
|
2023-08-28 16:41:20 +03:00
|
|
|
from pymongo import ASCENDING, GEOSPHERE, TEXT
|
2023-08-27 23:43:16 +03:00
|
|
|
from pyrogram.types import User
|
|
|
|
|
|
|
|
from classes.location import Location
|
|
|
|
from classes.pyrouser import PyroUser
|
|
|
|
from modules.database import col_locations
|
|
|
|
from modules.reminder import remind
|
|
|
|
|
|
|
|
|
|
|
|
class PyroClient(LibPyroClient):
    """Bot client built on libbot's `PyroClient` with user and location helpers.

    On construction it registers the per-minute reminder job (when a scheduler
    is available); on start it ensures the locations collection is indexed.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # `self.scheduler` is not assigned in this class, so it is presumably
        # provided by the parent class; it may be None, in which case no
        # periodic reminder job is registered.
        if self.scheduler is not None:
            # Crontab "* * * * *" fires every minute; `remind` receives this client.
            self.scheduler.add_job(
                remind, CronTrigger.from_crontab("* * * * *"), args=(self,)
            )

        # NOTE(review): contents of `contexts` are not visible in this class;
        # presumably populated by handlers elsewhere — confirm.
        self.contexts = []

    async def start(self, **kwargs):
        """Ensure the locations collection indexes exist, then start the client.

        ### Args:
            * **kwargs: Passed through unchanged to the parent class' `start`.

        ### Returns:
            * Whatever the parent class' `start` returns.
        """
        # Unique index: no two location documents may share the same `id`.
        await col_locations.create_index(
            [("id", ASCENDING)], name="location_id", unique=True
        )
        # 2dsphere index on `location` so geospatial queries can be used.
        await col_locations.create_index(
            [("location", GEOSPHERE)],
            name="location_location",
        )
        # Text index on `name` so `$text` searches can be used.
        await col_locations.create_index([("name", TEXT)], name="location_name")

        return await super().start(**kwargs)

    async def find_user(self, user: Union[int, User]) -> PyroUser:
        """Find a user by its ID or by a `User` object.

        When a `User` object is given, its `language_code` is forwarded as the
        `locale` argument of `PyroUser.find`.

        ### Args:
            * user (`Union[int, User]`): User ID, or a `User` object to take the ID from.

        ### Returns:
            * `PyroUser`: User in database representation.
        """
        if isinstance(user, int):
            return await PyroUser.find(user)

        return await PyroUser.find(user.id, locale=user.language_code)

    async def get_location(self, id: int) -> Location:
        """Get a location by its ID.

        ### Args:
            * id (`int`): Location's ID.

        ### Returns:
            * `Location`: Location from database as an object.
        """
        return await Location.get(id)

    async def list_locations(self) -> List[Location]:
        """Get all locations stored in the database.

        ### Returns:
            * `List[Location]`: List of `Location` objects.
        """
        # Each record is only used for its `id` (resolved via `Location.get`),
        # so project that single field instead of transferring whole documents.
        # NOTE(review): this still issues one `Location.get` call per record; if
        # `Location` can be built from a fetched document, that would avoid it.
        return [
            await Location.get(record["id"])
            async for record in col_locations.find({}, projection={"id": 1})
        ]
|