519 lines
18 KiB
Python
519 lines
18 KiB
Python
"""Module with class PycordEvent."""
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, tzinfo
|
|
from logging import Logger
|
|
from typing import Any, Dict, List, Optional
|
|
from zoneinfo import ZoneInfo
|
|
|
|
from bson import ObjectId
|
|
from libbot.cache.classes import Cache
|
|
from pymongo import DESCENDING
|
|
from pymongo.results import InsertOneResult
|
|
|
|
from classes.abstract import Cacheable
|
|
from classes.errors import EventNotFoundError
|
|
from modules.database import col_events
|
|
from modules.utils import get_logger, restore_from_cache
|
|
|
|
# Module-level logger; PycordEvent uses it to record attribute changes and stage reordering.
logger: Logger = get_logger(__name__)
|
|
|
|
|
|
@dataclass
class PycordEvent(Cacheable):
    """Object representation of an event in the database.

    Attributes:
        _id (ObjectId): ID of the event generated by the database.
        name (str): Name of the event.
        guild_id (int): Discord ID of the guild where the event takes place.
        created (datetime): Date of event's creation in UTC.
        ended (datetime | None): Date of the event's actual end in UTC.
        is_cancelled (bool): Whether the event is cancelled.
        creator_id (int): Discord ID of the creator.
        starts (datetime): Date of the event's planned start in UTC.
        ends (datetime): Date of the event's planned end in UTC.
        thumbnail (Dict[str, Any] | None): Thumbnail to use for the event in format `{"id": thumbnail_id (int), "filename": thumbnail_filename (str)}`.
        stage_ids (List[ObjectId]): Database ID's of the event's stages ordered in the completion order.
    """

    __slots__ = (
        "_id",
        "name",
        "guild_id",
        "created",
        "ended",
        "is_cancelled",
        "creator_id",
        "starts",
        "ends",
        "thumbnail",
        "stage_ids",
    )
    # Prefix of every cache key for this class; keys look like "event_<_id>".
    __short_name__ = "event"
    # Database collection holding event documents.
    __collection__ = col_events

    _id: ObjectId
    name: str
    guild_id: int
    created: datetime
    ended: datetime | None
    is_cancelled: bool
    creator_id: int
    starts: datetime
    ends: datetime
    thumbnail: Dict[str, Any] | None
    stage_ids: List[ObjectId]

    @classmethod
    async def from_id(cls, event_id: str | ObjectId, cache: Optional[Cache] = None) -> "PycordEvent":
        """Find the event by its ID and construct PycordEvent from database entry.

        The cache is consulted first; on a miss the database entry is fetched and, if a cache
        engine is provided, written back to the cache.

        Args:
            event_id (str | ObjectId): ID of the event to look up.
            cache (:obj:`Cache`, optional): Cache engine that will be used to fetch and update the cache.

        Returns:
            PycordEvent: Object of the found event.

        Raises:
            EventNotFoundError: Event with such ID does not exist.
            InvalidId: Provided event ID is of invalid format.
        """
        cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)

        if cached_entry is not None:
            return cls(**cls._entry_from_cache(cached_entry))

        db_entry = await cls.__collection__.find_one(
            {"_id": event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)}
        )

        if db_entry is None:
            raise EventNotFoundError(event_id=event_id)

        if cache is not None:
            cache.set_json(f"{cls.__short_name__}_{event_id}", cls._entry_to_cache(dict(db_entry)))

        return cls(**db_entry)

    @classmethod
    async def from_name(
        cls, event_name: str, guild_id: int, cache: Optional[Cache] = None
    ) -> "PycordEvent":
        """Find the event by its name and construct PycordEvent from database entry.

        If multiple events with the same name exist, the one with the greatest start date will be returned.

        Args:
            event_name (str): Name of the event to look up.
            guild_id (int): Discord ID of the guild where the event takes place.
            cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.

        Returns:
            PycordEvent: Object of the found event.

        Raises:
            EventNotFoundError: Event with such name does not exist.
        """
        db_entry: Dict[str, Any] | None = await cls.__collection__.find_one(
            {"name": event_name, "guild_id": guild_id}, sort=[("starts", DESCENDING)]
        )

        if db_entry is None:
            raise EventNotFoundError(event_name=event_name, guild_id=guild_id)

        if cache is not None:
            cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", cls._entry_to_cache(db_entry))

        return cls(**db_entry)

    @classmethod
    async def create(
        cls,
        name: str,
        guild_id: int,
        creator_id: int,
        starts: datetime,
        ends: datetime,
        thumbnail: Dict[str, Any] | None = None,
        cache: Optional[Cache] = None,
    ) -> "PycordEvent":
        """Create an event, write it to the database and return the constructed PycordEvent object.

        Creation date will be set to current time in UTC automatically.

        Args:
            name (str): Name of the event.
            guild_id (int): Guild ID where the event takes place.
            creator_id (int): Discord ID of the event creator.
            starts (datetime): Date when the event starts. Must be UTC.
            ends (datetime): Date when the event ends. Must be UTC.
            thumbnail (:obj:`Dict[str, Any]`, optional): Thumbnail to use for the event in format `{"id": thumbnail_id (int), "filename": thumbnail_filename (str)}`.
            cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.

        Returns:
            PycordEvent: Object of the created event.
        """
        db_entry: Dict[str, Any] = {
            "name": name,
            "guild_id": guild_id,
            "created": datetime.now(tz=ZoneInfo("UTC")),
            "ended": None,
            "is_cancelled": False,
            "creator_id": creator_id,
            "starts": starts,
            "ends": ends,
            "thumbnail": thumbnail,
            "stage_ids": [],
        }

        insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry)

        db_entry["_id"] = insert_result.inserted_id

        if cache is not None:
            # Key by the event's own ID (not the guild ID) so from_id() and _get_cache_key()
            # can find this entry and events of the same guild don't overwrite each other.
            cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", cls._entry_to_cache(db_entry))

        return cls(**db_entry)

    async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
        """Set attributes on the object, persist them with `$set` (upsert) and refresh the cache.

        Raises:
            AttributeError: One of the provided attributes does not exist on the object.
        """
        for key, value in kwargs.items():
            if not hasattr(self, key):
                raise AttributeError(f"Attribute '{key}' does not exist in PycordEvent")

            setattr(self, key, value)

        await self.__collection__.update_one({"_id": self._id}, {"$set": kwargs}, upsert=True)

        self._update_cache(cache)

        logger.info("Set attributes of event %s to %s", self._id, kwargs)

    async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
        """Reset the named attributes to their defaults, persist them (upsert) and refresh the cache.

        Raises:
            AttributeError: One of the provided attributes does not exist on the object.
            KeyError: One of the provided attributes has no default value.
        """
        attributes: Dict[str, Any] = {}

        for key in args:
            if not hasattr(self, key):
                raise AttributeError(f"Attribute '{key}' does not exist in PycordEvent")

            default_value: Any = self.get_default_value(key)

            setattr(self, key, default_value)

            attributes[key] = default_value

        await self.__collection__.update_one({"_id": self._id}, {"$set": attributes}, upsert=True)

        self._update_cache(cache)

        logger.info("Reset attributes %s of event %s to default values", args, self._id)

    def _get_cache_key(self) -> str:
        """Return the cache key of this event ("<short name>_<_id>")."""
        return f"{self.__short_name__}_{self._id}"

    def _update_cache(self, cache: Optional[Cache] = None) -> None:
        """Write the JSON-compatible representation of the object to the cache, if a cache is given."""
        if cache is None:
            return

        object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)

        # Defensive: to_dict() always returns a dict today, but drop the entry rather than
        # caching a missing value should that ever change.
        if object_dict is not None:
            cache.set_json(self._get_cache_key(), object_dict)
        else:
            self._delete_cache(cache)

    def _delete_cache(self, cache: Optional[Cache] = None) -> None:
        """Delete this event's cache entry, if a cache is given."""
        if cache is None:
            return

        cache.delete(self._get_cache_key())

    async def _update_event_stage_order(
        self,
        bot: Any,
        old_stage_ids: List[ObjectId],
        cache: Optional[Cache] = None,
    ) -> None:
        """Update the `sequence` of stages whose position changed between the old and current order.

        Stages that are not present in `old_stage_ids` (e.g. a freshly inserted one) are skipped.

        Args:
            bot (Any): Bot object providing `find_event_stage()`.
            old_stage_ids (List[ObjectId]): Stage order before the change.
            cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
        """
        logger.info("Updating event stages order for %s...", self._id)

        logger.debug("Old stage IDs: %s", old_stage_ids)
        logger.debug("New stage IDs: %s", self.stage_ids)

        for event_stage_id in self.stage_ids:
            if event_stage_id not in old_stage_ids:
                continue

            stage_index: int = self.stage_ids.index(event_stage_id)
            old_stage_index: int = old_stage_ids.index(event_stage_id)

            logger.debug(
                "Indexes for %s: was %s and is now %s", event_stage_id, old_stage_index, stage_index
            )

            if stage_index != old_stage_index:
                await (await bot.find_event_stage(event_stage_id)).update(cache, sequence=stage_index)

    @staticmethod
    def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a database entry into a JSON-compatible dict (IDs and dates become strings)."""
        cache_entry: Dict[str, Any] = db_entry.copy()

        cache_entry["_id"] = str(cache_entry["_id"])
        cache_entry["created"] = cache_entry["created"].isoformat()
        cache_entry["ended"] = None if cache_entry["ended"] is None else cache_entry["ended"].isoformat()
        cache_entry["starts"] = cache_entry["starts"].isoformat()
        cache_entry["ends"] = cache_entry["ends"].isoformat()
        cache_entry["stage_ids"] = [str(stage_id) for stage_id in cache_entry["stage_ids"]]

        return cache_entry

    @staticmethod
    def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a cached dict back into a database-like entry (inverse of `_entry_to_cache`)."""
        db_entry: Dict[str, Any] = cache_entry.copy()

        db_entry["_id"] = ObjectId(db_entry["_id"])
        db_entry["created"] = datetime.fromisoformat(db_entry["created"])
        db_entry["ended"] = None if db_entry["ended"] is None else datetime.fromisoformat(db_entry["ended"])
        db_entry["starts"] = datetime.fromisoformat(db_entry["starts"])
        db_entry["ends"] = datetime.fromisoformat(db_entry["ends"])
        db_entry["stage_ids"] = [ObjectId(stage_id) for stage_id in db_entry["stage_ids"]]

        return db_entry

    def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
        """Convert the object to a JSON representation.

        Args:
            json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted.

        Returns:
            Dict[str, Any]: JSON representation of the object.
        """
        return {
            "_id": self._id if not json_compatible else str(self._id),
            "name": self.name,
            "guild_id": self.guild_id,
            "created": self.created if not json_compatible else self.created.isoformat(),
            "ended": (
                self.ended
                if not json_compatible
                else (None if self.ended is None else self.ended.isoformat())
            ),
            "is_cancelled": self.is_cancelled,
            "creator_id": self.creator_id,
            "starts": self.starts if not json_compatible else self.starts.isoformat(),
            "ends": self.ends if not json_compatible else self.ends.isoformat(),
            "thumbnail": self.thumbnail,
            "stage_ids": (
                self.stage_ids if not json_compatible else [str(stage_id) for stage_id in self.stage_ids]
            ),
        }

    @staticmethod
    def get_defaults() -> Dict[str, Any]:
        """Get default values for the object attributes.

        Returns:
            Dict[str, Any]: Mapping of attributes and their respective values in format `{"attribute_name:" attribute_value}`.
        """
        return {
            "name": None,
            "guild_id": None,
            "created": None,
            "ended": None,
            "is_cancelled": False,
            "creator_id": None,
            "starts": None,
            "ends": None,
            "thumbnail": None,
            "stage_ids": [],
        }

    @staticmethod
    def get_default_value(key: str) -> Any:
        """Get default value of the attribute for the object.

        Args:
            key (str): Name of the attribute.

        Returns:
            Any: Default value of the attribute.

        Raises:
            KeyError: There's no default value for the provided attribute.
        """
        defaults: Dict[str, Any] = PycordEvent.get_defaults()

        if key not in defaults:
            raise KeyError(f"There's no default value for key '{key}' in PycordEvent")

        return defaults[key]

    async def update(
        self,
        cache: Optional[Cache] = None,
        **kwargs: Any,
    ) -> None:
        """Update attribute(s) on the object and save the updated entry into the database.

        Args:
            cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
            **kwargs (Any): Mapping of attributes in format `attribute_name=attribute_value` to update.

        Raises:
            AttributeError: Provided attribute does not exist in the class.
        """
        await self._set(cache=cache, **kwargs)

    async def reset(
        self,
        *args: str,
        cache: Optional[Cache] = None,
    ) -> None:
        """Remove attribute(s) on the object, replace them with a default value and save the updated entry into the database.

        Args:
            *args (str): List of attributes to remove.
            cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.

        Raises:
            AttributeError: Provided attribute does not exist in the class.
        """
        await self._remove(*args, cache=cache)

    async def purge(self, cache: Optional[Cache] = None) -> None:
        """Completely remove event data from database. Currently only removes the event record from events collection.

        Args:
            cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
        """
        await self.__collection__.delete_one({"_id": self._id})
        self._delete_cache(cache)

    async def cancel(self, cache: Optional[Cache] = None) -> None:
        """Cancel the event.

        Attribute `is_cancelled` will be set to `True`.

        Args:
            cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
        """
        await self._set(cache, is_cancelled=True)

    async def end(self, cache: Optional[Cache] = None) -> None:
        """End the event.

        Attribute `ended` will be set to the current date in UTC.

        Args:
            cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
        """
        await self._set(cache, ended=datetime.now(tz=ZoneInfo("UTC")))

    async def insert_stage(
        self, bot: "PycordBot", event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
    ) -> None:
        """Insert a stage at the provided index.

        Args:
            bot (PycordBot): Bot object.
            event_stage_id (ObjectId): Stage ID to be inserted.
            index (int): Index to be inserted at.
            cache: Cache engine that will be used to update the cache.
        """
        old_stage_ids: List[ObjectId] = self.stage_ids.copy()

        self.stage_ids.insert(index, event_stage_id)

        await self._set(cache, stage_ids=self.stage_ids)
        await self._update_event_stage_order(bot, old_stage_ids, cache=cache)

    async def reorder_stage(
        self, bot: "PycordBot", event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
    ) -> None:
        """Reorder a stage to the provided index.

        Args:
            bot (PycordBot): Bot object.
            event_stage_id (ObjectId): Stage ID to be reordered.
            index (int): Index to be reordered to.
            cache: Cache engine that will be used to update the cache.

        Raises:
            ValueError: The stage is not part of the event.
        """
        old_stage_ids: List[ObjectId] = self.stage_ids.copy()

        self.stage_ids.insert(index, self.stage_ids.pop(self.stage_ids.index(event_stage_id)))

        await self._set(cache, stage_ids=self.stage_ids)
        await self._update_event_stage_order(bot, old_stage_ids, cache=cache)

    async def remove_stage(
        self, bot: "PycordBot", event_stage_id: ObjectId, cache: Optional[Cache] = None
    ) -> None:
        """Remove a stage from the event.

        Args:
            bot (PycordBot): Bot object.
            event_stage_id (ObjectId): Stage ID to be removed.
            cache: Cache engine that will be used to update the cache.

        Raises:
            ValueError: The stage is not part of the event.
        """
        old_stage_ids: List[ObjectId] = self.stage_ids.copy()

        # list.remove drops the first occurrence and raises ValueError when absent.
        self.stage_ids.remove(event_stage_id)

        await self._set(cache, stage_ids=self.stage_ids)
        await self._update_event_stage_order(bot, old_stage_ids, cache=cache)

    @staticmethod
    def _to_utc(value: datetime) -> datetime:
        """Return `value` as an aware datetime in UTC.

        Naive values are treated as already holding UTC wall-clock time (the convention for the
        dates stored on this class) and are only labelled; aware values are converted.
        """
        if value.tzinfo is None:
            return value.replace(tzinfo=ZoneInfo("UTC"))

        return value.astimezone(ZoneInfo("UTC"))

    def get_start_date_utc(self) -> datetime:
        """Get the event start date in UTC timezone.

        Returns:
            datetime: Start date in UTC.

        Raises:
            ValueError: Event does not have a start date.
        """
        if self.starts is None:
            raise ValueError("Event does not have a start date")

        return self._to_utc(self.starts)

    def get_end_date_utc(self) -> datetime:
        """Get the event end date in UTC timezone.

        Returns:
            datetime: End date in UTC.

        Raises:
            ValueError: Event does not have an end date.
        """
        if self.ends is None:
            raise ValueError("Event does not have an end date")

        return self._to_utc(self.ends)

    def get_start_date_localized(self, tz: tzinfo) -> datetime:
        """Get the event start date in the provided timezone.

        Args:
            tz (tzinfo): Timezone to convert the start date to.

        Returns:
            datetime: Start date in the provided timezone.

        Raises:
            ValueError: Event does not have a start date.
        """
        # Convert (not relabel) the stored UTC instant so the result is the same moment in `tz`.
        return self.get_start_date_utc().astimezone(tz)

    def get_end_date_localized(self, tz: tzinfo) -> datetime:
        """Get the event end date in the provided timezone.

        Args:
            tz (tzinfo): Timezone to convert the end date to.

        Returns:
            datetime: End date in the provided timezone.

        Raises:
            ValueError: Event does not have an end date.
        """
        # Convert (not relabel) the stored UTC instant so the result is the same moment in `tz`.
        return self.get_end_date_utc().astimezone(tz)
|