Files
QuizBot/classes/pycord_event.py

519 lines
18 KiB
Python

"""Module with class PycordEvent."""
from dataclasses import dataclass
from datetime import datetime, tzinfo
from logging import Logger
from typing import Any, Dict, List, Optional
from zoneinfo import ZoneInfo
from bson import ObjectId
from libbot.cache.classes import Cache
from pymongo import DESCENDING
from pymongo.results import InsertOneResult
from classes.abstract import Cacheable
from classes.errors import EventNotFoundError
from modules.database import col_events
from modules.utils import get_logger, restore_from_cache
# Module-level logger obtained from the project's `get_logger` helper for this module's name.
# The PycordEvent methods below use it for their info/debug messages.
logger: Logger = get_logger(__name__)
@dataclass
class PycordEvent(Cacheable):
"""Object representation of an event stored in the database.

Instances are built from database entries or from cached JSON-compatible
entries by the alternate constructors of this class.

Attributes:
    _id (ObjectId): ID of the event generated by the database.
    name (str): Name of the event.
    guild_id (int): Discord ID of the guild where the event takes place.
    created (datetime): Date of the event's creation in UTC.
    ended (datetime | None): Date of the event's actual end in UTC, or None if it has not ended.
    is_cancelled (bool): Whether the event is cancelled.
    creator_id (int): Discord ID of the creator.
    starts (datetime): Date of the event's planned start in UTC.
    ends (datetime): Date of the event's planned end in UTC.
    thumbnail (Dict[str, Any] | None): Thumbnail to use for the event in format
        `{"id": thumbnail_id (int), "filename": thumbnail_filename (str)}`.
    stage_ids (List[ObjectId]): Database IDs of the event's stages, ordered in the completion order.
"""
__slots__ = (
"_id",
"name",
"guild_id",
"created",
"ended",
"is_cancelled",
"creator_id",
"starts",
"ends",
"thumbnail",
"stage_ids",
)
# Prefix used when building cache keys for events (`<short name>_<event id>`).
__short_name__ = "event"
# Database collection that the methods of this class query and update.
__collection__ = col_events
_id: ObjectId
name: str
guild_id: int
created: datetime
ended: datetime | None
is_cancelled: bool
creator_id: int
starts: datetime
ends: datetime
thumbnail: Dict[str, Any] | None
stage_ids: List[ObjectId]
@classmethod
async def from_id(cls, event_id: str | ObjectId, cache: Optional[Cache] = None) -> "PycordEvent":
    """Load an event by its database ID, trying the cache before the database.

    Args:
        event_id (str | ObjectId): ID of the event to look up.
        cache (:obj:`Cache`, optional): Cache engine used to read a cached entry and to store a freshly
            fetched one.

    Returns:
        PycordEvent: Object built from the cached or stored entry.

    Raises:
        EventNotFoundError: Event with such ID does not exist.
        InvalidId: Provided event ID is of invalid format.
    """
    from_cache: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)

    # Cache hit: turn the JSON-friendly values back into native ones and skip the database.
    if from_cache is not None:
        restored: Dict[str, Any] = cls._entry_from_cache(from_cache)
        return cls(**restored)

    # String IDs are only converted once the database is actually queried.
    object_id: ObjectId = event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)
    document = await cls.__collection__.find_one({"_id": object_id})

    if document is None:
        raise EventNotFoundError(event_id=event_id)

    if cache is not None:
        cache_key: str = f"{cls.__short_name__}_{event_id}"
        cache.set_json(cache_key, cls._entry_to_cache(dict(document)))

    return cls(**document)
@classmethod
async def from_name(
    cls, event_name: str, guild_id: int, cache: Optional[Cache] = None
) -> "PycordEvent":
    """Load an event of a guild by its name.

    When several events share the name, the query sorts by `starts` in descending order, so the
    event with the greatest start date is returned.

    Args:
        event_name (str): Name of the event to look up.
        guild_id (int): Discord ID of the guild where the event takes place.
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.

    Returns:
        PycordEvent: Object of the found event.

    Raises:
        EventNotFoundError: Event with such name does not exist.
    """
    lookup: Dict[str, Any] = {"name": event_name, "guild_id": guild_id}
    newest_start_first = [("starts", DESCENDING)]

    found: Dict[str, Any] | None = await cls.__collection__.find_one(lookup, sort=newest_start_first)

    if found is None:
        raise EventNotFoundError(event_name=event_name, guild_id=guild_id)

    if cache is not None:
        cache_key: str = f"{cls.__short_name__}_{found['_id']}"
        cache.set_json(cache_key, cls._entry_to_cache(found))

    return cls(**found)
@classmethod
async def create(
    cls,
    name: str,
    guild_id: int,
    creator_id: int,
    starts: datetime,
    ends: datetime,
    thumbnail: Dict[str, Any] | None,
    cache: Optional[Cache] = None,
) -> "PycordEvent":
    """Create an event, write it to the database and return the constructed PycordEvent object.

    Creation date will be set to current time in UTC automatically.

    Args:
        name (str): Name of the event.
        guild_id (int): Guild ID where the event takes place.
        creator_id (int): Discord ID of the event creator.
        starts (datetime): Date when the event starts. Must be UTC.
        ends (datetime): Date when the event ends. Must be UTC.
        thumbnail (Dict[str, Any] | None): Thumbnail to use for the event in format
            `{"id": thumbnail_id (int), "filename": thumbnail_filename (str)}`, or None.
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.

    Returns:
        PycordEvent: Object of the created event.
    """
    db_entry: Dict[str, Any] = {
        "name": name,
        "guild_id": guild_id,
        "created": datetime.now(tz=ZoneInfo("UTC")),
        "ended": None,
        "is_cancelled": False,
        "creator_id": creator_id,
        "starts": starts,
        "ends": ends,
        "thumbnail": thumbnail,
        "stage_ids": [],
    }

    insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry)
    db_entry["_id"] = insert_result.inserted_id

    if cache is not None:
        # The key must be built from the event's own ID, i.e. the same `<short name>_<event id>`
        # shape that ID lookups and cache updates of this class use. Keying by guild ID would make
        # the entry unreachable and let events of the same guild overwrite each other.
        cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", cls._entry_to_cache(db_entry))

    return cls(**db_entry)
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
    """Assign attributes on the object and persist them to the database and the cache.

    All attribute names are validated before anything is changed, so an unknown name leaves
    the object exactly as it was.

    Args:
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
        **kwargs (Any): Mapping of attributes in format `attribute_name=attribute_value` to set.

    Raises:
        AttributeError: Provided attribute does not exist in the class.
    """
    # Validate first: raising halfway through would leave the in-memory object out of sync with
    # the database, because the write below would never happen.
    for key in kwargs:
        if not hasattr(self, key):
            raise AttributeError(f"Attribute '{key}' does not exist in PycordEvent")

    for key, value in kwargs.items():
        setattr(self, key, value)

    # NOTE(review): `upsert=True` inserts a new document when none matches this `_id`
    # (for example after a purge) — confirm this is intended.
    await self.__collection__.update_one({"_id": self._id}, {"$set": kwargs}, upsert=True)

    self._update_cache(cache)

    logger.info("Set attributes of event %s to %s", self._id, kwargs)
async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
    """Reset attributes to their default values and persist them to the database and the cache.

    All attribute names and their defaults are resolved before anything is changed, so an
    unknown name or a missing default leaves the object exactly as it was.

    Args:
        *args (str): Names of the attributes to reset.
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.

    Raises:
        AttributeError: Provided attribute does not exist in the class.
        KeyError: There's no default value for the provided attribute.
    """
    attributes: Dict[str, Any] = {}

    # Resolve every default up front: raising halfway through would leave the in-memory object
    # out of sync with the database, because the write below would never happen.
    for key in args:
        if not hasattr(self, key):
            raise AttributeError(f"Attribute '{key}' does not exist in PycordEvent")

        attributes[key] = self.get_default_value(key)

    for key, default_value in attributes.items():
        setattr(self, key, default_value)

    # NOTE(review): `upsert=True` inserts a new document when none matches this `_id`
    # (for example after a purge) — confirm this is intended.
    await self.__collection__.update_one({"_id": self._id}, {"$set": attributes}, upsert=True)

    self._update_cache(cache)

    logger.info("Reset attributes %s of event %s to default values", args, self._id)
def _get_cache_key(self) -> str:
return f"{self.__short_name__}_{self._id}"
def _update_cache(self, cache: Optional[Cache] = None) -> None:
if cache is None:
return
object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
if object_dict is not None:
cache.set_json(self._get_cache_key(), object_dict)
else:
self._delete_cache(cache)
def _delete_cache(self, cache: Optional[Cache] = None) -> None:
if cache is None:
return
cache.delete(self._get_cache_key())
async def _update_event_stage_order(
    self,
    bot: Any,
    old_stage_ids: List[ObjectId],
    cache: Optional[Cache] = None,
) -> None:
    """Update the `sequence` of every stage whose position in the event changed.

    Compares the current `stage_ids` with the previous order. Stages that were not part of the
    previous order are skipped.

    Args:
        bot (Any): Bot object; its `find_event_stage` coroutine is used to fetch each stage.
        old_stage_ids (List[ObjectId]): Stage IDs in the order they had before the change.
        cache (:obj:`Cache`, optional): Cache engine passed on to the stage update.
    """
    logger.info("Updating event stages order for %s...", self._id)

    logger.debug("Old stage IDs: %s", old_stage_ids)
    logger.debug("New stage IDs: %s", self.stage_ids)

    for stage_id in self.stage_ids:
        if stage_id in old_stage_ids:
            position_now: int = self.stage_ids.index(stage_id)
            position_before: int = old_stage_ids.index(stage_id)

            logger.debug(
                "Indexes for %s: was %s and is now %s", stage_id, position_before, position_now
            )

            # Stages that kept their position need no write.
            if position_now == position_before:
                continue

            stage = await bot.find_event_stage(stage_id)
            await stage.update(cache, sequence=position_now)
@staticmethod
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
cache_entry: Dict[str, Any] = db_entry.copy()
cache_entry["_id"] = str(cache_entry["_id"])
cache_entry["created"] = cache_entry["created"].isoformat()
cache_entry["ended"] = None if cache_entry["ended"] is None else cache_entry["ended"].isoformat()
cache_entry["starts"] = cache_entry["starts"].isoformat()
cache_entry["ends"] = cache_entry["ends"].isoformat()
cache_entry["stage_ids"] = [str(stage_id) for stage_id in cache_entry["stage_ids"]]
return cache_entry
@staticmethod
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a JSON-compatible cache entry back into its database form.

    The ID and the stage IDs become `ObjectId` objects and the ISO 8601 strings become
    `datetime` objects. All other keys are carried over unchanged. The provided entry itself
    is not modified.

    Args:
        cache_entry (Dict[str, Any]): Entry as stored in the cache.

    Returns:
        Dict[str, Any]: Copy of the entry with native types restored.
    """
    return {
        **cache_entry,
        "_id": ObjectId(cache_entry["_id"]),
        "created": datetime.fromisoformat(cache_entry["created"]),
        # `ended` is the only date that may be missing.
        "ended": (
            None if cache_entry["ended"] is None else datetime.fromisoformat(cache_entry["ended"])
        ),
        "starts": datetime.fromisoformat(cache_entry["starts"]),
        "ends": datetime.fromisoformat(cache_entry["ends"]),
        "stage_ids": [ObjectId(stage_id) for stage_id in cache_entry["stage_ids"]],
    }
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
"""Convert the object to a JSON representation.
Args:
json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted.
Returns:
Dict[str, Any]: JSON representation of the object.
"""
return {
"_id": self._id if not json_compatible else str(self._id),
"name": self.name,
"guild_id": self.guild_id,
"created": self.created if not json_compatible else self.created.isoformat(),
"ended": (
self.ended
if not json_compatible
else (None if self.ended is None else self.ended.isoformat())
),
"is_cancelled": self.is_cancelled,
"creator_id": self.creator_id,
"starts": self.starts if not json_compatible else self.starts.isoformat(),
"ends": self.ends if not json_compatible else self.ends.isoformat(),
"thumbnail": self.thumbnail,
"stage_ids": (
self.stage_ids if not json_compatible else [str(stage_id) for stage_id in self.stage_ids]
),
}
@staticmethod
def get_defaults() -> Dict[str, Any]:
"""Get default values for the object attributes.
Returns:
Dict[str, Any]: Mapping of attributes and their respective values in format `{"attribute_name:" attribute_value}`.
"""
return {
"name": None,
"guild_id": None,
"created": None,
"ended": None,
"is_cancelled": False,
"creator_id": None,
"starts": None,
"ends": None,
"thumbnail": None,
"stage_ids": [],
}
@staticmethod
def get_default_value(key: str) -> Any:
"""Get default value of the attribute for the object.
Args:
key (str): Name of the attribute.
Returns:
Any: Default value of the attribute.
Raises:
KeyError: There's no default value for the provided attribute.
"""
if key not in PycordEvent.get_defaults():
raise KeyError(f"There's no default value for key '{key}' in PycordEvent")
return PycordEvent.get_defaults()[key]
async def update(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
    """Set attribute(s) on the object and save the updated entry into the database.

    Public entry point that hands the work over to `_set`.

    Args:
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
        **kwargs (Any): Mapping of attributes in format `attribute_name=attribute_value` to update.

    Raises:
        AttributeError: Provided attribute does not exist in the class.
    """
    new_values: Dict[str, Any] = kwargs
    await self._set(cache, **new_values)
async def reset(self, *args: str, cache: Optional[Cache] = None) -> None:
    """Replace attribute(s) of the object with their default values and save the entry into the database.

    Public entry point that hands the work over to `_remove`.

    Args:
        *args (str): Names of the attributes to reset.
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.

    Raises:
        AttributeError: Provided attribute does not exist in the class.
    """
    attribute_names = args
    await self._remove(*attribute_names, cache=cache)
async def purge(self, cache: Optional[Cache] = None) -> None:
    """Delete the event's record from the events collection and drop its cache entry.

    Only the event record itself is removed by this method.

    Args:
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
    """
    own_record: Dict[str, Any] = {"_id": self._id}
    await self.__collection__.delete_one(own_record)

    self._delete_cache(cache)
async def cancel(self, cache: Optional[Cache] = None) -> None:
    """Mark the event as cancelled.

    Sets `is_cancelled` to `True` and saves the change through `_set`.

    Args:
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
    """
    await self._set(cache=cache, is_cancelled=True)
async def end(self, cache: Optional[Cache] = None) -> None:
    """Mark the event as ended.

    Sets `ended` to the current date in UTC and saves the change through `_set`.

    Args:
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
    """
    finished_at: datetime = datetime.now(tz=ZoneInfo("UTC"))
    await self._set(cache=cache, ended=finished_at)
async def insert_stage(
    self, bot: "PycordBot", event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
) -> None:
    """Put a stage into the event's stage order at the provided index.

    The new order is saved, then stages whose position changed are updated.

    Args:
        bot (PycordBot): Bot object.
        event_stage_id (ObjectId): Stage ID to be inserted.
        index (int): Index to be inserted at.
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
    """
    # Snapshot of the order before the change, needed to detect which stages moved.
    previous_order: List[ObjectId] = list(self.stage_ids)

    self.stage_ids.insert(index, event_stage_id)

    await self._set(cache=cache, stage_ids=self.stage_ids)
    await self._update_event_stage_order(bot, previous_order, cache=cache)
async def reorder_stage(
    self, bot: "PycordBot", event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
) -> None:
    """Move a stage of the event to the provided index.

    The new order is saved, then stages whose position changed are updated.

    Args:
        bot (PycordBot): Bot object.
        event_stage_id (ObjectId): Stage ID to be reordered.
        index (int): Index to be reordered to.
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.

    Raises:
        ValueError: The stage is not part of the event's stage order.
    """
    # Snapshot of the order before the change, needed to detect which stages moved.
    previous_order: List[ObjectId] = list(self.stage_ids)

    # Take the stage out of its current slot, then put it back at the requested one.
    current_index: int = self.stage_ids.index(event_stage_id)
    moved_stage_id: ObjectId = self.stage_ids.pop(current_index)
    self.stage_ids.insert(index, moved_stage_id)

    await self._set(cache=cache, stage_ids=self.stage_ids)
    await self._update_event_stage_order(bot, previous_order, cache=cache)
async def remove_stage(
    self, bot: "PycordBot", event_stage_id: ObjectId, cache: Optional[Cache] = None
) -> None:
    """Remove a stage from the event's stage order.

    The new order is saved, then stages whose position changed are updated.

    Args:
        bot (PycordBot): Bot object.
        event_stage_id (ObjectId): Stage ID to be removed.
        cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.

    Raises:
        ValueError: The stage is not part of the event's stage order.
    """
    # Snapshot of the order before the change, needed to detect which stages moved.
    previous_order: List[ObjectId] = list(self.stage_ids)

    position: int = self.stage_ids.index(event_stage_id)
    del self.stage_ids[position]

    await self._set(cache=cache, stage_ids=self.stage_ids)
    await self._update_event_stage_order(bot, previous_order, cache=cache)
def get_start_date_utc(self) -> datetime:
"""Get the event start date in UTC timezone.
Returns:
datetime: Start date in UTC.
Raises:
ValueError: Event does not have a start date.
"""
if self.starts is None:
raise ValueError("Event does not have a start date")
return self.starts.replace(tzinfo=ZoneInfo("UTC"))
def get_end_date_utc(self) -> datetime:
"""Get the event end date in UTC timezone.
Returns:
datetime: End date in UTC.
Raises:
ValueError: Event does not have an end date.
"""
if self.ends is None:
raise ValueError("Event does not have an end date")
return self.ends.replace(tzinfo=ZoneInfo("UTC"))
def get_start_date_localized(self, tz: tzinfo) -> datetime:
"""Get the event start date in the provided timezone.
Returns:
datetime: Start date in the provided timezone.
Raises:
ValueError: Event does not have a start date.
"""
if self.starts is None:
raise ValueError("Event does not have a start date")
return self.starts.replace(tzinfo=tz)
def get_end_date_localized(self, tz: tzinfo) -> datetime:
"""Get the event end date in the provided timezone.
Returns:
datetime: End date in the provided timezone.
Raises:
ValueError: Event does not have an end date.
"""
if self.ends is None:
raise ValueError("Event does not have an end date")
return self.ends.replace(tzinfo=tz)