332 lines
10 KiB
Python
332 lines
10 KiB
Python
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from logging import Logger
|
|
from typing import Any, Dict, List, Optional
|
|
from zoneinfo import ZoneInfo
|
|
|
|
from bson import ObjectId
|
|
from discord import Bot
|
|
from libbot.cache.classes import Cache
|
|
from pymongo.results import InsertOneResult
|
|
|
|
from modules.database import col_events
|
|
from modules.utils import get_logger, restore_from_cache
|
|
|
|
# Module-level logger; PycordEvent uses it to report attribute updates/resets
# and the reordering of event stages.
logger: Logger = get_logger(__name__)
|
|
|
|
|
|
@dataclass
class PycordEvent:
    """Event document from the events collection, with optional cache write-through.

    Each instance mirrors one document of ``__collection__``. Mutating helpers write the
    change to the database and, when a cache engine is passed, refresh the cached copy
    stored under the key returned by :meth:`_get_cache_key` (``"event_<_id>"``).
    """

    __slots__ = (
        "_id",
        "name",
        "guild_id",
        "created",
        "ended",
        "is_cancelled",
        "creator_id",
        "starts",
        "ends",
        "thumbnail",
        "stage_ids",
    )
    __short_name__ = "event"
    __collection__ = col_events

    _id: ObjectId
    name: str
    guild_id: int
    created: datetime
    ended: datetime | None
    is_cancelled: bool
    creator_id: int
    starts: datetime
    ends: datetime
    thumbnail: Dict[str, Any] | None
    stage_ids: List[ObjectId]

    @classmethod
    async def from_id(cls, event_id: str | ObjectId, cache: Optional[Cache] = None) -> "PycordEvent":
        """Find event in the database.

        The cache is consulted first; on a miss the event is loaded from the database
        and, if ``cache`` is given, written to it.

        Args:
            event_id (str | ObjectId): Event's ID
            cache (:obj:`Cache`, optional): Cache engine to get the cache from

        Returns:
            PycordEvent: Event object

        Raises:
            RuntimeError: Event was not found
            InvalidId: Invalid event ID was provided
        """
        cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)

        if cached_entry is not None:
            return cls(**cached_entry)

        db_entry: Dict[str, Any] | None = await cls.__collection__.find_one(
            {"_id": event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)}
        )

        if db_entry is None:
            # TODO Add a unique exception
            # raise EventNotFoundError(event_id)
            raise RuntimeError(f"Event {event_id} not found")

        if cache is not None:
            # Key by the stored _id so the entry matches _get_cache_key() regardless of how
            # the caller spelled event_id.
            cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry)

        return cls(**db_entry)

    @classmethod
    async def from_name(cls, event_name: str, cache: Optional[Cache] = None) -> "PycordEvent":
        """Find event in the database by its name.

        Returns the first document matching ``event_name`` as returned by the database;
        no explicit sort is applied.

        Args:
            event_name (str): Event's name
            cache (:obj:`Cache`, optional): Cache engine to write the found event into

        Returns:
            PycordEvent: Event object

        Raises:
            RuntimeError: Event was not found
        """
        # TODO Add sorting by creation date or something.
        # Duplicate events should be avoided, latest active event should be returned.
        db_entry: Dict[str, Any] | None = await cls.__collection__.find_one({"name": event_name})

        if db_entry is None:
            # TODO Add a unique exception
            # raise EventNotFoundError(event_name)
            raise RuntimeError(f"Event with name {event_name} not found")

        if cache is not None:
            cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry)

        return cls(**db_entry)

    @classmethod
    async def create(
        cls,
        name: str,
        guild_id: int,
        creator_id: int,
        starts: datetime,
        ends: datetime,
        thumbnail: Dict[str, Any] | None,
        cache: Optional[Cache] = None,
    ) -> "PycordEvent":
        """Insert a new event into the database.

        The new document is created with ``created`` set to the current UTC time,
        ``ended=None``, ``is_cancelled=False`` and no stages.

        Args:
            name (str): Event's name
            guild_id (int): ID of the guild the event belongs to
            creator_id (int): ID of the user who created the event
            starts (datetime): Start date of the event
            ends (datetime): End date of the event
            thumbnail (Dict[str, Any] | None): Thumbnail data, if any
            cache (:obj:`Cache`, optional): Cache engine to write the new event into

        Returns:
            PycordEvent: Newly created event object
        """
        db_entry: Dict[str, Any] = {
            "name": name,
            "guild_id": guild_id,
            "created": datetime.now(tz=ZoneInfo("UTC")),
            "ended": None,
            "is_cancelled": False,
            "creator_id": creator_id,
            "starts": starts,
            "ends": ends,
            "thumbnail": thumbnail,
            "stage_ids": [],
        }

        insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry)

        db_entry["_id"] = insert_result.inserted_id

        if cache is not None:
            # Cache by the event's own _id (previously keyed by guild_id, which no
            # reader of the cache ever looks up).
            cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry)

        return cls(**db_entry)

    async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
        """Set attribute data and save it into the database.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
            **kwargs (Any): Mapping of attribute names and respective values to be set

        Raises:
            AttributeError: One of the keys is not an attribute of the event. Nothing is
                changed in that case.
        """
        # Validate every key before mutating anything, so a bad key cannot leave the
        # object partially updated and out of sync with the database.
        for key in kwargs:
            if not hasattr(self, key):
                raise AttributeError(f"PycordEvent has no attribute '{key}'")

        for key, value in kwargs.items():
            setattr(self, key, value)

        await self.__collection__.update_one({"_id": self._id}, {"$set": kwargs}, upsert=True)

        self._update_cache(cache)

        logger.info("Set attributes of event %s to %s", self._id, kwargs)

    async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
        """Reset attributes to their default values and save them into the database.

        Args:
            *args (str): List of attributes to reset
            cache (:obj:`Cache`, optional): Cache engine to write the update into

        Raises:
            AttributeError: One of the names is not an attribute of the event
            KeyError: One of the attributes has no default value (see :meth:`get_defaults`)
        """
        # Resolve all defaults first; any AttributeError/KeyError is raised before the
        # object is touched.
        attributes: Dict[str, Any] = {}

        for key in args:
            if not hasattr(self, key):
                raise AttributeError(f"PycordEvent has no attribute '{key}'")

            attributes[key] = self.get_default_value(key)

        for key, default_value in attributes.items():
            setattr(self, key, default_value)

        await self.__collection__.update_one({"_id": self._id}, {"$set": attributes}, upsert=True)

        self._update_cache(cache)

        logger.info("Reset attributes %s of event %s to default values", args, self._id)

    def _get_cache_key(self) -> str:
        """Return the cache key of this event (``"<__short_name__>_<_id>"``)."""
        return f"{self.__short_name__}_{self._id}"

    def _update_cache(self, cache: Optional[Cache] = None) -> None:
        """Write the current state of the event into the cache, if a cache is given."""
        if cache is None:
            return

        cache.set_json(self._get_cache_key(), self.to_dict())

    def _delete_cache(self, cache: Optional[Cache] = None) -> None:
        """Remove this event's entry from the cache, if a cache is given."""
        if cache is None:
            return

        cache.delete(self._get_cache_key())

    def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
        """Convert PycordEvent object to a dictionary representation.

        Args:
            json_compatible (bool): Whether ObjectId values (``_id`` and ``stage_ids``) need
                to be converted to strings. Datetime values are returned unchanged.

        Returns:
            Dict[str, Any]: Dictionary representation of PycordEvent
        """
        return {
            "_id": self._id if not json_compatible else str(self._id),
            "name": self.name,
            "guild_id": self.guild_id,
            "created": self.created,
            "ended": self.ended,
            "is_cancelled": self.is_cancelled,
            "creator_id": self.creator_id,
            "starts": self.starts,
            "ends": self.ends,
            "thumbnail": self.thumbnail,
            "stage_ids": (
                self.stage_ids if not json_compatible else [str(stage_id) for stage_id in self.stage_ids]
            ),
        }

    @staticmethod
    def get_defaults() -> Dict[str, Any]:
        """Return a fresh mapping of attribute names to their default values."""
        return {
            "name": None,
            "guild_id": None,
            "created": None,
            "ended": None,
            "is_cancelled": False,
            "creator_id": None,
            "starts": None,
            "ends": None,
            "thumbnail": None,
            "stage_ids": [],
        }

    @staticmethod
    def get_default_value(key: str) -> Any:
        """Return the default value of a single attribute.

        Raises:
            KeyError: The attribute has no default value
        """
        defaults: Dict[str, Any] = PycordEvent.get_defaults()

        if key not in defaults:
            raise KeyError(f"There's no default value for key '{key}' in PycordEvent")

        return defaults[key]

    async def update(
        self,
        cache: Optional[Cache] = None,
        **kwargs: Any,
    ) -> None:
        """Set the given attributes and persist them (see :meth:`_set`).

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
            **kwargs (Any): Mapping of attribute names and respective values to be set
        """
        await self._set(cache=cache, **kwargs)

    async def reset(
        self,
        cache: Optional[Cache] = None,
        *args: str,
    ) -> None:
        """Reset the given attributes to their defaults and persist them (see :meth:`_remove`).

        Note that the first positional argument is ``cache``; attribute names follow it.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
            *args (str): Names of the attributes to reset
        """
        # _remove takes cache as keyword-only; passing it positionally would treat it as
        # an attribute name.
        await self._remove(*args, cache=cache)

    async def purge(self, cache: Optional[Cache] = None) -> None:
        """Completely remove event data from database. Currently only removes the event record from events collection.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
        """
        await self.__collection__.delete_one({"_id": self._id})
        self._delete_cache(cache)

    async def cancel(self, cache: Optional[Cache] = None) -> None:
        """Mark the event as cancelled and persist the change.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
        """
        await self._set(cache=cache, is_cancelled=True)

    async def _update_event_stage_order(
        self,
        bot: Any,
        old_stage_ids: List[ObjectId],
        cache: Optional[Cache] = None,
    ) -> None:
        """Update the ``sequence`` of every stage whose position changed.

        Stages that are not present in ``old_stage_ids`` (e.g. a freshly inserted one)
        are skipped.

        Args:
            bot (Any): Bot instance providing ``find_event_stage``
            old_stage_ids (List[ObjectId]): Stage order before the change
            cache (:obj:`Cache`, optional): Cache engine to write the updates into
        """
        logger.info("Updating event stages order for %s...", self._id)

        logger.debug("Old stage IDs: %s", old_stage_ids)
        logger.debug("New stage IDs: %s", self.stage_ids)

        # Map each old stage ID to its first position, built once instead of calling
        # list.index() for every stage.
        old_indexes: Dict[ObjectId, int] = {}

        for old_index, old_stage_id in enumerate(old_stage_ids):
            old_indexes.setdefault(old_stage_id, old_index)

        for stage_index, event_stage_id in enumerate(self.stage_ids):
            old_stage_index: int | None = old_indexes.get(event_stage_id)

            if old_stage_index is None:
                continue

            logger.debug("Indexes for %s: was %s and is now %s", event_stage_id, old_stage_index, stage_index)

            if stage_index != old_stage_index:
                event_stage = await bot.find_event_stage(event_stage_id)
                await event_stage.update(cache, sequence=stage_index)

    async def insert_stage(
        self, bot: Bot, event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
    ) -> None:
        """Insert a stage at ``index`` and shift the sequence of the following stages.

        The inserted stage's own sequence is not updated here.

        Args:
            bot (Bot): Bot instance providing ``find_event_stage``
            event_stage_id (ObjectId): ID of the stage to insert
            index (int): Position to insert the stage at (``list.insert`` semantics)
            cache (:obj:`Cache`, optional): Cache engine to write the updates into
        """
        old_stage_ids: List[ObjectId] = self.stage_ids.copy()

        self.stage_ids.insert(index, event_stage_id)

        await self._set(cache=cache, stage_ids=self.stage_ids)
        await self._update_event_stage_order(bot, old_stage_ids, cache=cache)

    async def reorder_stage(
        self, bot: Any, event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
    ) -> None:
        """Move an existing stage to ``index`` and update the affected sequences.

        Args:
            bot (Any): Bot instance providing ``find_event_stage``
            event_stage_id (ObjectId): ID of the stage to move
            index (int): New position of the stage
            cache (:obj:`Cache`, optional): Cache engine to write the updates into

        Raises:
            ValueError: The stage is not part of this event
        """
        old_stage_ids: List[ObjectId] = self.stage_ids.copy()

        self.stage_ids.insert(index, self.stage_ids.pop(self.stage_ids.index(event_stage_id)))

        await self._set(cache=cache, stage_ids=self.stage_ids)
        await self._update_event_stage_order(bot, old_stage_ids, cache=cache)

    async def remove_stage(self, bot: Bot, event_stage_id: ObjectId, cache: Optional[Cache] = None) -> None:
        """Remove a stage from the event and shift the sequence of the following stages.

        Args:
            bot (Bot): Bot instance providing ``find_event_stage``
            event_stage_id (ObjectId): ID of the stage to remove
            cache (:obj:`Cache`, optional): Cache engine to write the updates into

        Raises:
            ValueError: The stage is not part of this event
        """
        old_stage_ids: List[ObjectId] = self.stage_ids.copy()

        self.stage_ids.pop(self.stage_ids.index(event_stage_id))

        await self._set(cache=cache, stage_ids=self.stage_ids)
        await self._update_event_stage_order(bot, old_stage_ids, cache=cache)

    # TODO Add documentation
    # NOTE(review): ``datetime.replace(tzinfo=...)`` relabels the timezone without converting
    # the wall-clock time and does not accept a ``str``; ``astimezone`` is probably intended.
    # def get_localized_start_date(self, tz: str | timezone | ZoneInfo) -> datetime:
    #     return self.starts.replace(tzinfo=tz)
    #
    # # TODO Add documentation
    # def get_localized_end_date(self, tz: str | timezone | ZoneInfo) -> datetime:
    #     return self.ends.replace(tzinfo=tz)
|