Files
QuizBot/classes/pycord_event.py

332 lines
10 KiB
Python

from dataclasses import dataclass
from datetime import datetime
from logging import Logger
from typing import Any, Dict, List, Optional
from zoneinfo import ZoneInfo
from bson import ObjectId
from discord import Bot
from libbot.cache.classes import Cache
from pymongo.results import InsertOneResult
from modules.database import col_events
from modules.utils import get_logger, restore_from_cache
# Module-level logger named after this module; used by PycordEvent to report attribute and stage-order changes.
logger: Logger = get_logger(__name__)
@dataclass
class PycordEvent:
    """Event record stored in the events collection, with an optional cache layer.

    Instances mirror a single database document. Mutating helpers write the
    change to the database first-class (``update_one``) and then refresh the
    cache entry stored under ``"<__short_name__>_<_id>"``.

    Attributes:
        _id (ObjectId): Database ID of the event
        name (str): Event's name
        guild_id (int): ID of the guild the event belongs to
        created (datetime): Moment the event record was created
        ended (datetime | None): Moment the event ended, if it did
        is_cancelled (bool): Whether the event has been cancelled
        creator_id (int): ID of the user who created the event
        starts (datetime): Event's start date
        ends (datetime): Event's end date
        thumbnail (Dict[str, Any] | None): Thumbnail data, if any
        stage_ids (List[ObjectId]): Ordered IDs of the event's stages
    """

    __slots__ = (
        "_id",
        "name",
        "guild_id",
        "created",
        "ended",
        "is_cancelled",
        "creator_id",
        "starts",
        "ends",
        "thumbnail",
        "stage_ids",
    )

    # Prefix used to build cache keys for this entity type.
    __short_name__ = "event"
    # Database collection that stores the events.
    __collection__ = col_events

    _id: ObjectId
    name: str
    guild_id: int
    created: datetime
    ended: datetime | None
    is_cancelled: bool
    creator_id: int
    starts: datetime
    ends: datetime
    thumbnail: Dict[str, Any] | None
    stage_ids: List[ObjectId]

    @classmethod
    async def from_id(cls, event_id: str | ObjectId, cache: Optional[Cache] = None) -> "PycordEvent":
        """Find event in the database.

        The cache is consulted first; on a miss the database is queried and the
        found entry is written to the cache.

        Args:
            event_id (str | ObjectId): Event's ID
            cache (:obj:`Cache`, optional): Cache engine to get the cache from

        Returns:
            PycordEvent: Event object

        Raises:
            RuntimeError: Event was not found
            InvalidId: Invalid event ID was provided
        """
        cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)

        if cached_entry is not None:
            return cls(**cached_entry)

        db_entry = await cls.__collection__.find_one(
            {"_id": event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)}
        )

        if db_entry is None:
            raise RuntimeError(f"Event {event_id} not found")

            # TODO Add a unique exception
            # raise EventNotFoundError(event_id)

        if cache is not None:
            # NOTE(review): db_entry holds ObjectId/datetime values; confirm that
            # Cache.set_json is able to serialize them.
            cache.set_json(f"{cls.__short_name__}_{event_id}", db_entry)

        return cls(**db_entry)

    @classmethod
    async def from_name(cls, event_name: str, cache: Optional[Cache] = None) -> "PycordEvent":
        """Find event in the database by its name.

        The cache is not consulted (entries are keyed by ID), but the found
        entry is written to it.

        Args:
            event_name (str): Event's name
            cache (:obj:`Cache`, optional): Cache engine to write the found entry into

        Returns:
            PycordEvent: Event object

        Raises:
            RuntimeError: Event with such name was not found
        """
        # TODO Add sorting by creation date or something.
        # Duplicate events should be avoided, latest active event should be returned.
        db_entry: Dict[str, Any] | None = await cls.__collection__.find_one({"name": event_name})

        if db_entry is None:
            raise RuntimeError(f"Event with name {event_name} not found")

            # TODO Add a unique exception
            # raise EventNotFoundError(event_name)

        if cache is not None:
            cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry)

        return cls(**db_entry)

    @classmethod
    async def create(
        cls,
        name: str,
        guild_id: int,
        creator_id: int,
        starts: datetime,
        ends: datetime,
        thumbnail: Dict[str, Any] | None,
        cache: Optional[Cache] = None,
    ) -> "PycordEvent":
        """Create a new event, insert it into the database and return it.

        The record is created with the current UTC time as ``created``, no end
        moment, not cancelled and with an empty list of stages.

        Args:
            name (str): Event's name
            guild_id (int): ID of the guild the event belongs to
            creator_id (int): ID of the user creating the event
            starts (datetime): Event's start date
            ends (datetime): Event's end date
            thumbnail (Dict[str, Any] | None): Thumbnail data, if any
            cache (:obj:`Cache`, optional): Cache engine to write the new entry into

        Returns:
            PycordEvent: Newly created event object
        """
        db_entry: Dict[str, Any] = {
            "name": name,
            "guild_id": guild_id,
            "created": datetime.now(tz=ZoneInfo("UTC")),
            "ended": None,
            "is_cancelled": False,
            "creator_id": creator_id,
            "starts": starts,
            "ends": ends,
            "thumbnail": thumbnail,
            "stage_ids": [],
        }

        insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry)

        db_entry["_id"] = insert_result.inserted_id

        if cache is not None:
            # The key must be built from the event's own ID (not the guild ID) so
            # that from_id() and _get_cache_key() address the very same entry.
            cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry)

        return cls(**db_entry)

    async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
        """Set attribute data and save it into the database.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
            **kwargs (Any): Mapping of attribute names and respective values to be set

        Raises:
            AttributeError: At least one of the provided names is not an attribute of the event
        """
        # Validate everything up front so that a bad key cannot leave the object
        # partially modified and out of sync with the database.
        unknown_keys: List[str] = [key for key in kwargs if not hasattr(self, key)]

        if unknown_keys:
            raise AttributeError(f"PycordEvent has no attribute(s): {', '.join(unknown_keys)}")

        for key, value in kwargs.items():
            setattr(self, key, value)

        await self.__collection__.update_one({"_id": self._id}, {"$set": kwargs}, upsert=True)

        self._update_cache(cache)

        logger.info("Set attributes of event %s to %s", self._id, kwargs)

    async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
        """Remove attribute data and save it into the database.

        "Removing" means resetting the attribute to its default value as
        returned by :meth:`get_default_value`.

        Args:
            *args (str): List of attributes to remove
            cache (:obj:`Cache`, optional): Cache engine to write the update into

        Raises:
            AttributeError: At least one of the provided names is not an attribute of the event
            KeyError: At least one of the provided attributes has no default value
        """
        attributes: Dict[str, Any] = {}

        # Resolve all defaults before touching the object so that a failure
        # cannot leave it partially reset.
        for key in args:
            if not hasattr(self, key):
                raise AttributeError(f"PycordEvent has no attribute '{key}'")

            attributes[key] = self.get_default_value(key)

        for key, default_value in attributes.items():
            setattr(self, key, default_value)

        await self.__collection__.update_one({"_id": self._id}, {"$set": attributes}, upsert=True)

        self._update_cache(cache)

        logger.info("Reset attributes %s of event %s to default values", args, self._id)

    def _get_cache_key(self) -> str:
        """Build the cache key of this event: ``"<__short_name__>_<_id>"``."""
        return f"{self.__short_name__}_{self._id}"

    def _update_cache(self, cache: Optional[Cache] = None) -> None:
        """Write the current state of the event into the cache.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write into. Nothing is done when omitted
        """
        if cache is None:
            return

        cache.set_json(self._get_cache_key(), self.to_dict())

    def _delete_cache(self, cache: Optional[Cache] = None) -> None:
        """Delete the cache entry of the event.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to delete from. Nothing is done when omitted
        """
        if cache is None:
            return

        cache.delete(self._get_cache_key())

    def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
        """Convert PycordEvent object to a JSON representation.

        Only ``_id`` is converted to a string when ``json_compatible`` is set;
        all other values are returned exactly as stored on the object.

        Args:
            json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted

        Returns:
            Dict[str, Any]: JSON representation of PycordEvent
        """
        return {
            "_id": self._id if not json_compatible else str(self._id),
            "name": self.name,
            "guild_id": self.guild_id,
            "created": self.created,
            "ended": self.ended,
            "is_cancelled": self.is_cancelled,
            "creator_id": self.creator_id,
            "starts": self.starts,
            "ends": self.ends,
            "thumbnail": self.thumbnail,
            "stage_ids": self.stage_ids,
        }

    @staticmethod
    def get_defaults() -> Dict[str, Any]:
        """Get default values of the event's attributes.

        A new mapping is built on every call, so the mutable defaults in it are
        never shared between callers. ``_id`` has no default.

        Returns:
            Dict[str, Any]: Mapping of attribute names and their default values
        """
        return {
            "name": None,
            "guild_id": None,
            "created": None,
            "ended": None,
            "is_cancelled": False,
            "creator_id": None,
            "starts": None,
            "ends": None,
            "thumbnail": None,
            "stage_ids": [],
        }

    @staticmethod
    def get_default_value(key: str) -> Any:
        """Get default value of a single attribute.

        Args:
            key (str): Name of the attribute

        Returns:
            Any: Default value of the attribute

        Raises:
            KeyError: The attribute has no default value
        """
        defaults: Dict[str, Any] = PycordEvent.get_defaults()

        if key not in defaults:
            raise KeyError(f"There's no default value for key '{key}' in PycordEvent")

        return defaults[key]

    async def update(
        self,
        cache: Optional[Cache] = None,
        **kwargs,
    ) -> None:
        """Update attributes of the event and save them into the database.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
            **kwargs (Any): Mapping of attribute names and respective values to be set

        Raises:
            AttributeError: At least one of the provided names is not an attribute of the event
        """
        await self._set(cache=cache, **kwargs)

    async def reset(
        self,
        cache: Optional[Cache] = None,
        *args,
    ) -> None:
        """Reset attributes of the event to their default values and save them into the database.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
            *args (str): List of attributes to reset

        Raises:
            AttributeError: At least one of the provided names is not an attribute of the event
            KeyError: At least one of the provided attributes has no default value
        """
        # cache is keyword-only in _remove(); passing it positionally would make
        # it be treated as an attribute name.
        await self._remove(*args, cache=cache)

    async def purge(self, cache: Optional[Cache] = None) -> None:
        """Completely remove event data from database. Currently only removes the event record from events collection.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
        """
        await self.__collection__.delete_one({"_id": self._id})

        self._delete_cache(cache)

    async def cancel(self, cache: Optional[Cache] = None) -> None:
        """Mark the event as cancelled and save it into the database.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
        """
        await self._set(cache, is_cancelled=True)

    async def _update_event_stage_order(
        self,
        bot: Any,
        old_stage_ids: List[ObjectId],
        cache: Optional[Cache] = None,
    ) -> None:
        """Propagate changed stage positions to the stages themselves.

        Every stage that is present in both the old and the current ordering and
        whose index differs between them is fetched with
        ``bot.find_event_stage`` and gets its ``sequence`` updated to the new index.

        Args:
            bot (Any): Bot object providing ``find_event_stage``
            old_stage_ids (List[ObjectId]): Stage IDs as they were ordered before the change
            cache (:obj:`Cache`, optional): Cache engine passed on to the stage updates
        """
        logger.info("Updating event stages order for %s...", self._id)

        logger.debug("Old stage IDs: %s", old_stage_ids)
        logger.debug("New stage IDs: %s", self.stage_ids)

        for event_stage_id in self.stage_ids:
            # Stages that did not exist before have no previous position to compare to.
            if event_stage_id not in old_stage_ids:
                continue

            stage_index: int = self.stage_ids.index(event_stage_id)
            old_stage_index: int = old_stage_ids.index(event_stage_id)

            logger.debug("Indexes for %s: was %s and is now %s", event_stage_id, old_stage_index, stage_index)

            if stage_index != old_stage_index:
                await (await bot.find_event_stage(event_stage_id)).update(cache, sequence=stage_index)

    async def insert_stage(
        self, bot: Bot, event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
    ) -> None:
        """Insert a stage into the event at the given position.

        Args:
            bot (Bot): Bot object providing ``find_event_stage``
            event_stage_id (ObjectId): ID of the stage to insert
            index (int): Position to insert the stage at
            cache (:obj:`Cache`, optional): Cache engine to write the update into
        """
        old_stage_ids: List[ObjectId] = self.stage_ids.copy()

        self.stage_ids.insert(index, event_stage_id)

        await self._set(cache, stage_ids=self.stage_ids)
        await self._update_event_stage_order(bot, old_stage_ids, cache=cache)

    async def reorder_stage(
        self, bot: Any, event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
    ) -> None:
        """Move a stage of the event to the given position.

        Args:
            bot (Any): Bot object providing ``find_event_stage``
            event_stage_id (ObjectId): ID of the stage to move
            index (int): Position to move the stage to (applied after the stage is taken out)
            cache (:obj:`Cache`, optional): Cache engine to write the update into

        Raises:
            ValueError: The stage does not belong to the event
        """
        old_stage_ids: List[ObjectId] = self.stage_ids.copy()

        self.stage_ids.insert(index, self.stage_ids.pop(self.stage_ids.index(event_stage_id)))

        await self._set(cache, stage_ids=self.stage_ids)
        await self._update_event_stage_order(bot, old_stage_ids, cache=cache)

    async def remove_stage(self, bot: Bot, event_stage_id: ObjectId, cache: Optional[Cache] = None) -> None:
        """Remove a stage from the event.

        Args:
            bot (Bot): Bot object providing ``find_event_stage``
            event_stage_id (ObjectId): ID of the stage to remove
            cache (:obj:`Cache`, optional): Cache engine to write the update into

        Raises:
            ValueError: The stage does not belong to the event
        """
        old_stage_ids: List[ObjectId] = self.stage_ids.copy()

        self.stage_ids.remove(event_stage_id)

        await self._set(cache, stage_ids=self.stage_ids)
        await self._update_event_stage_order(bot, old_stage_ids, cache=cache)
# # TODO Add documentation
# def get_localized_start_date(self, tz: str | timezone | ZoneInfo) -> datetime:
# return self.starts.replace(tzinfo=tz)
#
# # TODO Add documentation
# def get_localized_end_date(self, tz: str | timezone | ZoneInfo) -> datetime:
# return self.ends.replace(tzinfo=tz)