"""Module with class PycordEvent.""" from dataclasses import dataclass from datetime import datetime, tzinfo from logging import Logger from typing import Any, Dict, List, Optional from zoneinfo import ZoneInfo from bson import ObjectId from libbot.cache.classes import Cache from pymongo import DESCENDING from pymongo.results import InsertOneResult from classes.abstract import Cacheable from classes.errors import EventNotFoundError from modules.database import col_events from modules.utils import get_logger, restore_from_cache logger: Logger = get_logger(__name__) @dataclass class PycordEvent(Cacheable): """Object representation of an event in the database. Attributes: _id (ObjectId): ID of the event generated by the database. name (str): Name of the event. guild_id (int): Discord ID of the guild where the event takes place. created (datetime): Date of event's creation in UTC. ended (datetime | None): Date of the event's actual end in UTC. is_cancelled (bool): Whether the event is cancelled. creator_id (int): Discord ID of the creator. starts (datetime): Date of the event's planned start in UTC. ends (datetime): Date of the event's planned end in UTC. thumbnail (Dict[str, Any] | None): Thumbnail to use for the event in format `{"id": thumbnail_id (int), "filename": thumbnail_filename (str)}`. stage_ids (List[ObjectId]): Database ID's of the event's stages ordered in the completion order. """ __slots__ = ( "_id", "name", "guild_id", "created", "ended", "is_cancelled", "creator_id", "starts", "ends", "thumbnail", "stage_ids", ) __short_name__ = "event" __collection__ = col_events _id: ObjectId name: str guild_id: int created: datetime ended: datetime | None is_cancelled: bool creator_id: int starts: datetime ends: datetime thumbnail: Dict[str, Any] | None stage_ids: List[ObjectId] @classmethod async def from_id(cls, event_id: str | ObjectId, cache: Optional[Cache] = None) -> "PycordEvent": """Find the event by its ID and construct PycordEvent from database entry. Args: event_id (str | ObjectId): ID of the event to look up. cache (:obj:`Cache`, optional): Cache engine that will be used to fetch and update the cache. Returns: PycordEvent: Object of the found event. Raises: EventNotFoundError: Event with such ID does not exist. InvalidId: Provided event ID is of invalid format. """ cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache) if cached_entry is not None: return cls(**cls._entry_from_cache(cached_entry)) db_entry = await cls.__collection__.find_one( {"_id": event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)} ) if db_entry is None: raise EventNotFoundError(event_id=event_id) if cache is not None: cache.set_json(f"{cls.__short_name__}_{event_id}", cls._entry_to_cache(dict(db_entry))) return cls(**db_entry) @classmethod async def from_name( cls, event_name: str, guild_id: int, cache: Optional[Cache] = None ) -> "PycordEvent": """Find the event by its name and construct PycordEvent from database entry. If multiple events with the same name exist, the one with the greatest start date will be returned. Args: event_name (str): Name of the event to look up. guild_id (int): Discord ID of the guild where the event takes place. cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache. Returns: PycordEvent: Object of the found event. Raises: EventNotFoundError: Event with such name does not exist. """ db_entry: Dict[str, Any] | None = await cls.__collection__.find_one( {"name": event_name, "guild_id": guild_id}, sort=[("starts", DESCENDING)] ) if db_entry is None: raise EventNotFoundError(event_name=event_name, guild_id=guild_id) if cache is not None: cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", cls._entry_to_cache(db_entry)) return cls(**db_entry) @classmethod async def create( cls, name: str, guild_id: int, creator_id: int, starts: datetime, ends: datetime, thumbnail: Dict[str, Any] | None, cache: Optional[Cache] = None, ) -> "PycordEvent": """Create an event, write it to the database and return the constructed PycordEvent object. Creation date will be set to current time in UTC automatically. Args: name (str): Name of the event. guild_id (int): Guild ID where the event takes place. creator_id (int): Discord ID of the event creator. starts (datetime): Date when the event starts. Must be UTC. ends (datetime): Date when the event ends. Must be UTC. thumbnail (:obj:`Dict[str, Any]`, optional): Thumbnail to use for the event in format `{"id": thumbnail_id (int), "filename": thumbnail_filename (str)}`. cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache. Returns: PycordEvent: Object of the created event. """ db_entry: Dict[str, Any] = { "name": name, "guild_id": guild_id, "created": datetime.now(tz=ZoneInfo("UTC")), "ended": None, "is_cancelled": False, "creator_id": creator_id, "starts": starts, "ends": ends, "thumbnail": thumbnail, "stage_ids": [], } insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry) db_entry["_id"] = insert_result.inserted_id if cache is not None: cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry)) return cls(**db_entry) async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None: for key, value in kwargs.items(): if not hasattr(self, key): raise AttributeError(f"Attribute '{key}' does not exist in PycordEvent") setattr(self, key, value) await self.__collection__.update_one({"_id": self._id}, {"$set": kwargs}, upsert=True) self._update_cache(cache) logger.info("Set attributes of event %s to %s", self._id, kwargs) async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None: attributes: Dict[str, Any] = {} for key in args: if not hasattr(self, key): raise AttributeError(f"Attribute '{key}' does not exist in PycordEvent") default_value: Any = self.get_default_value(key) setattr(self, key, default_value) attributes[key] = default_value await self.__collection__.update_one({"_id": self._id}, {"$set": attributes}, upsert=True) self._update_cache(cache) logger.info("Reset attributes %s of event %s to default values", args, self._id) def _get_cache_key(self) -> str: return f"{self.__short_name__}_{self._id}" def _update_cache(self, cache: Optional[Cache] = None) -> None: if cache is None: return object_dict: Dict[str, Any] = self.to_dict(json_compatible=True) if object_dict is not None: cache.set_json(self._get_cache_key(), object_dict) else: self._delete_cache(cache) def _delete_cache(self, cache: Optional[Cache] = None) -> None: if cache is None: return cache.delete(self._get_cache_key()) async def _update_event_stage_order( self, bot: Any, old_stage_ids: List[ObjectId], cache: Optional[Cache] = None, ) -> None: logger.info("Updating event stages order for %s...", self._id) logger.debug("Old stage IDs: %s", old_stage_ids) logger.debug("New stage IDs: %s", self.stage_ids) for event_stage_id in self.stage_ids: if event_stage_id not in old_stage_ids: continue stage_index: int = self.stage_ids.index(event_stage_id) old_stage_index: int = old_stage_ids.index(event_stage_id) logger.debug( "Indexes for %s: was %s and is now %s", event_stage_id, old_stage_index, stage_index ) if stage_index != old_stage_index: await (await bot.find_event_stage(event_stage_id)).update(cache, sequence=stage_index) @staticmethod def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]: cache_entry: Dict[str, Any] = db_entry.copy() cache_entry["_id"] = str(cache_entry["_id"]) cache_entry["created"] = cache_entry["created"].isoformat() cache_entry["ended"] = None if cache_entry["ended"] is None else cache_entry["ended"].isoformat() cache_entry["starts"] = cache_entry["starts"].isoformat() cache_entry["ends"] = cache_entry["ends"].isoformat() cache_entry["stage_ids"] = [str(stage_id) for stage_id in cache_entry["stage_ids"]] return cache_entry @staticmethod def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]: db_entry: Dict[str, Any] = cache_entry.copy() db_entry["_id"] = ObjectId(db_entry["_id"]) db_entry["created"] = datetime.fromisoformat(db_entry["created"]) db_entry["ended"] = None if db_entry["ended"] is None else datetime.fromisoformat(db_entry["ended"]) db_entry["starts"] = datetime.fromisoformat(db_entry["starts"]) db_entry["ends"] = datetime.fromisoformat(db_entry["ends"]) db_entry["stage_ids"] = [ObjectId(stage_id) for stage_id in db_entry["stage_ids"]] return db_entry def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]: """Convert the object to a JSON representation. Args: json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted. Returns: Dict[str, Any]: JSON representation of the object. """ return { "_id": self._id if not json_compatible else str(self._id), "name": self.name, "guild_id": self.guild_id, "created": self.created if not json_compatible else self.created.isoformat(), "ended": ( self.ended if not json_compatible else (None if self.ended is None else self.ended.isoformat()) ), "is_cancelled": self.is_cancelled, "creator_id": self.creator_id, "starts": self.starts if not json_compatible else self.starts.isoformat(), "ends": self.ends if not json_compatible else self.ends.isoformat(), "thumbnail": self.thumbnail, "stage_ids": ( self.stage_ids if not json_compatible else [str(stage_id) for stage_id in self.stage_ids] ), } @staticmethod def get_defaults() -> Dict[str, Any]: """Get default values for the object attributes. Returns: Dict[str, Any]: Mapping of attributes and their respective values in format `{"attribute_name:" attribute_value}`. """ return { "name": None, "guild_id": None, "created": None, "ended": None, "is_cancelled": False, "creator_id": None, "starts": None, "ends": None, "thumbnail": None, "stage_ids": [], } @staticmethod def get_default_value(key: str) -> Any: """Get default value of the attribute for the object. Args: key (str): Name of the attribute. Returns: Any: Default value of the attribute. Raises: KeyError: There's no default value for the provided attribute. """ if key not in PycordEvent.get_defaults(): raise KeyError(f"There's no default value for key '{key}' in PycordEvent") return PycordEvent.get_defaults()[key] async def update( self, cache: Optional[Cache] = None, **kwargs: Any, ) -> None: """Update attribute(s) on the object and save the updated entry into the database. Args: cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache. **kwargs (Any): Mapping of attributes in format `attribute_name=attribute_value` to update. Raises: AttributeError: Provided attribute does not exist in the class. """ await self._set(cache=cache, **kwargs) async def reset( self, *args: str, cache: Optional[Cache] = None, ) -> None: """Remove attribute(s) on the object, replace them with a default value and save the updated entry into the database. Args: *args (str): List of attributes to remove. cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache. Raises: AttributeError: Provided attribute does not exist in the class. """ await self._remove(*args, cache=cache) async def purge(self, cache: Optional[Cache] = None) -> None: """Completely remove event data from database. Currently only removes the event record from events collection. Args: cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache. """ await self.__collection__.delete_one({"_id": self._id}) self._delete_cache(cache) async def cancel(self, cache: Optional[Cache] = None) -> None: """Cancel the event. Attribute `is_cancelled` will be set to `True`. Args: cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache. """ await self._set(cache, is_cancelled=True) async def end(self, cache: Optional[Cache] = None) -> None: """End the event. Attribute `ended` will be set to the current date in UTC. Args: cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache. """ await self._set(cache, ended=datetime.now(tz=ZoneInfo("UTC"))) async def insert_stage( self, bot: "PycordBot", event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None ) -> None: """Insert a stage at the provided index. Args: bot (PycordBot): Bot object. event_stage_id (ObjectId): Stage ID to be inserted. index (int): Index to be inserted at. cache: Cache engine that will be used to update the cache. """ old_stage_ids: List[ObjectId] = self.stage_ids.copy() self.stage_ids.insert(index, event_stage_id) await self._set(cache, stage_ids=self.stage_ids) await self._update_event_stage_order(bot, old_stage_ids, cache=cache) async def reorder_stage( self, bot: "PycordBot", event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None ) -> None: """Reorder a stage to the provided index. Args: bot (PycordBot): Bot object. event_stage_id (ObjectId): Stage ID to be reordered. index (int): Index to be reordered to. cache: Cache engine that will be used to update the cache. """ old_stage_ids: List[ObjectId] = self.stage_ids.copy() self.stage_ids.insert(index, self.stage_ids.pop(self.stage_ids.index(event_stage_id))) await self._set(cache, stage_ids=self.stage_ids) await self._update_event_stage_order(bot, old_stage_ids, cache=cache) async def remove_stage( self, bot: "PycordBot", event_stage_id: ObjectId, cache: Optional[Cache] = None ) -> None: """Remove a stage from the event. Args: bot (PycordBot): Bot object. event_stage_id (ObjectId): Stage ID to be reordered. cache: Cache engine that will be used to update the cache. """ old_stage_ids: List[ObjectId] = self.stage_ids.copy() self.stage_ids.pop(self.stage_ids.index(event_stage_id)) await self._set(cache, stage_ids=self.stage_ids) await self._update_event_stage_order(bot, old_stage_ids, cache=cache) def get_start_date_utc(self) -> datetime: """Get the event start date in UTC timezone. Returns: datetime: Start date in UTC. Raises: ValueError: Event does not have a start date. """ if self.starts is None: raise ValueError("Event does not have a start date") return self.starts.replace(tzinfo=ZoneInfo("UTC")) def get_end_date_utc(self) -> datetime: """Get the event end date in UTC timezone. Returns: datetime: End date in UTC. Raises: ValueError: Event does not have an end date. """ if self.ends is None: raise ValueError("Event does not have an end date") return self.ends.replace(tzinfo=ZoneInfo("UTC")) def get_start_date_localized(self, tz: tzinfo) -> datetime: """Get the event start date in the provided timezone. Returns: datetime: Start date in the provided timezone. Raises: ValueError: Event does not have a start date. """ if self.starts is None: raise ValueError("Event does not have a start date") return self.starts.replace(tzinfo=tz) def get_end_date_localized(self, tz: tzinfo) -> datetime: """Get the event end date in the provided timezone. Returns: datetime: End date in the provided timezone. Raises: ValueError: Event does not have an end date. """ if self.ends is None: raise ValueError("Event does not have an end date") return self.ends.replace(tzinfo=tz)