Fully documented and updated PycordEvent (#4)

This commit is contained in:
2025-04-27 17:41:14 +02:00
parent 638658af75
commit 12a88d5a23
5 changed files with 292 additions and 136 deletions

View File

@@ -6,9 +6,15 @@ from bson import ObjectId
class EventNotFoundError(Exception):
"""PycordEvent could not find event with such an ID in the database"""
def __init__(self, event_id: Optional[str | ObjectId] = None, event_name: Optional[str] = None) -> None:
self.event_id = event_id
self.event_name = event_name
def __init__(
self,
event_id: Optional[str | ObjectId] = None,
event_name: Optional[str] = None,
guild_id: Optional[int] = None,
) -> None:
self.event_id: str | ObjectId | None = event_id
self.event_name: str | None = event_name
self.guild_id: int | None = guild_id
if self.event_id is None and self.event_name is None:
raise AttributeError("Either event id or name must be provided")
@@ -16,5 +22,5 @@ class EventNotFoundError(Exception):
super().__init__(
f"Event with id {self.event_id} was not found"
if event_id is not None
else f"Event with name {self.event_name} was not found"
else f"Event with name {self.event_name} was not found for the guild {self.guild_id}"
)

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from zoneinfo import ZoneInfo
from bson import ObjectId
@@ -293,14 +293,19 @@ class PycordBot(LibPycordBot):
return event_stage
# TODO Document this method
async def find_event(self, event_id: str | ObjectId | None = None, event_name: str | None = None) -> PycordEvent:
if event_id is None and event_name is None:
raise AttributeError("Either event's ID or name must be provided!")
async def find_event(
self,
event_id: Optional[str | ObjectId] = None,
event_name: Optional[str] = None,
guild_id: Optional[int] = None,
) -> PycordEvent:
if event_id is None or (event_name is None and guild_id is None):
raise AttributeError("Either event ID or name with guild ID must be provided")
if event_id is not None:
return await PycordEvent.from_id(event_id, cache=self.cache)
else:
return await PycordEvent.from_name(event_name, cache=self.cache)
return await PycordEvent.from_name(event_name, guild_id, cache=self.cache)
# TODO Document this method
async def find_event_stage(self, stage_id: str | ObjectId) -> PycordEventStage:

View File

@@ -1,12 +1,14 @@
"""Module with class PycordEvent."""
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timezone
from logging import Logger
from typing import Any, Dict, List, Optional
from zoneinfo import ZoneInfo
from bson import ObjectId
from discord import Bot
from libbot.cache.classes import Cache
from pymongo import DESCENDING
from pymongo.results import InsertOneResult
from classes.errors import EventNotFoundError
@@ -18,6 +20,22 @@ logger: Logger = get_logger(__name__)
@dataclass
class PycordEvent:
"""Object representation of an event in the database.
Attributes:
_id (ObjectId): ID of the event generated by the database.
name (str): Name of the event.
guild_id (int): Discord ID of the guild where the event takes place.
created (datetime): Date of event's creation in UTC.
ended (datetime | None): Date of the event's actual end in UTC.
is_cancelled (bool): Whether the event is cancelled.
creator_id (int): Discord ID of the creator.
starts (datetime): Date of the event's planned start in UTC.
ends (datetime): Date of the event's planned end in UTC.
thumbnail (Dict[str, Any] | None): Thumbnail to use for the event in format `{"id": thumbnail_id (int), "filename": thumbnail_filename (str)}`.
stage_ids (List[ObjectId]): Database ID's of the event's stages ordered in the completion order.
"""
__slots__ = (
"_id",
"name",
@@ -48,18 +66,18 @@ class PycordEvent:
@classmethod
async def from_id(cls, event_id: str | ObjectId, cache: Optional[Cache] = None) -> "PycordEvent":
"""Find event in the database.
"""Find the event by its ID and construct PycordEvent from database entry.
Args:
event_id (str | ObjectId): Event's ID
cache (:obj:`Cache`, optional): Cache engine to get the cache from
event_id (str | ObjectId): ID of the event to look up.
cache (:obj:`Cache`, optional): Cache engine that will be used to fetch and update the cache.
Returns:
PycordEvent: Event object
PycordEvent: Object of the found event.
Raises:
EventNotFoundError: Event was not found
InvalidId: Invalid event ID was provided
EventNotFoundError: Event with such ID does not exist.
InvalidId: Provided event ID is of invalid format.
"""
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)
@@ -78,22 +96,35 @@ class PycordEvent:
return cls(**db_entry)
# TODO Add documentation
@classmethod
async def from_name(cls, event_name: str, cache: Optional[Cache] = None) -> "PycordEvent":
# TODO Add sorting by creation date or something.
# Duplicate events should be avoided, latest active event should be returned.
db_entry: Dict[str, Any] | None = await cls.__collection__.find_one({"name": event_name})
async def from_name(cls, event_name: str, guild_id: int, cache: Optional[Cache] = None) -> "PycordEvent":
"""Find the event by its name and construct PycordEvent from database entry.
If multiple events with the same name exist, the one with the greatest start date will be returned.
Args:
event_name (str): Name of the event to look up.
guild_id (int): Discord ID of the guild where the event takes place.
cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
Returns:
PycordEvent: Object of the found event.
Raises:
EventNotFoundError: Event with such name does not exist.
"""
db_entry: Dict[str, Any] | None = await cls.__collection__.find_one(
{"name": event_name, "guild_id": guild_id}, sort=[("starts", DESCENDING)]
)
if db_entry is None:
raise EventNotFoundError(event_name=event_name)
raise EventNotFoundError(event_name=event_name, guild_id=guild_id)
if cache is not None:
cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry)
return cls(**db_entry)
# TODO Add documentation
@classmethod
async def create(
cls,
@@ -105,6 +136,22 @@ class PycordEvent:
thumbnail: Dict[str, Any] | None,
cache: Optional[Cache] = None,
) -> "PycordEvent":
"""Create an event, write it to the database and return the constructed PycordEvent object.
Creation date will be set to current time in UTC automatically.
Args:
name (str): Name of the event.
guild_id (int): Guild ID where the event takes place.
creator_id (int): Discord ID of the event creator.
starts (datetime): Date when the event starts. Must be UTC.
ends (datetime): Date when the event ends. Must be UTC.
thumbnail (:obj:`Dict[str, Any]`, optional): Thumbnail to use for the event in format `{"id": thumbnail_id (int), "filename": thumbnail_filename (str)}`.
cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
Returns:
PycordEvent: Object of the created event.
"""
db_entry: Dict[str, Any] = {
"name": name,
"guild_id": guild_id,
@@ -128,15 +175,9 @@ class PycordEvent:
return cls(**db_entry)
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
"""Set attribute data and save it into the database.
Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into
**kwargs (Any): Mapping of attribute names and respective values to be set
"""
for key, value in kwargs.items():
if not hasattr(self, key):
raise AttributeError()
raise AttributeError(f"Attribute '{key}' does not exist in PycordEvent")
setattr(self, key, value)
@@ -147,17 +188,11 @@ class PycordEvent:
logger.info("Set attributes of event %s to %s", self._id, kwargs)
async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
"""Remove attribute data and save it into the database.
Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into
*args (str): List of attributes to remove
"""
attributes: Dict[str, Any] = {}
for key in args:
if not hasattr(self, key):
raise AttributeError()
raise AttributeError(f"Attribute '{key}' does not exist in PycordEvent")
default_value: Any = self.get_default_value(key)
@@ -191,86 +226,6 @@ class PycordEvent:
cache.delete(self._get_cache_key())
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
"""Convert PycordEvent object to a JSON representation.
Args:
json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted
Returns:
Dict[str, Any]: JSON representation of PycordEvent
"""
return {
"_id": self._id if not json_compatible else str(self._id),
"name": self.name,
"guild_id": self.guild_id,
"created": self.created,
"ended": self.ended,
"is_cancelled": self.is_cancelled,
"creator_id": self.creator_id,
"starts": self.starts,
"ends": self.ends,
"thumbnail": self.thumbnail,
"stage_ids": self.stage_ids,
}
# TODO Add documentation
@staticmethod
def get_defaults() -> Dict[str, Any]:
return {
"name": None,
"guild_id": None,
"created": None,
"ended": None,
"is_cancelled": False,
"creator_id": None,
"starts": None,
"ends": None,
"thumbnail": None,
"stage_ids": [],
}
# TODO Add documentation
@staticmethod
def get_default_value(key: str) -> Any:
if key not in PycordEvent.get_defaults():
raise KeyError(f"There's no default value for key '{key}' in PycordEvent")
return PycordEvent.get_defaults()[key]
# TODO Add documentation
async def update(
self,
cache: Optional[Cache] = None,
**kwargs,
):
await self._set(cache=cache, **kwargs)
# TODO Add documentation
async def reset(
self,
cache: Optional[Cache] = None,
*args,
):
await self._remove(cache, *args)
async def purge(self, cache: Optional[Cache] = None) -> None:
"""Completely remove event data from database. Currently only removes the event record from events collection.
Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into
"""
await self.__collection__.delete_one({"_id": self._id})
self._delete_cache(cache)
# TODO Add documentation
async def cancel(self, cache: Optional[Cache] = None):
await self._set(cache, is_cancelled=True)
# # TODO Add documentation
async def end(self, cache: Optional[Cache] = None) -> None:
await self._set(cache, ended=datetime.now(tz=ZoneInfo("UTC")))
async def _update_event_stage_order(
self,
bot: Any,
@@ -294,10 +249,139 @@ class PycordEvent:
if stage_index != old_stage_index:
await (await bot.find_event_stage(event_stage_id)).update(cache, sequence=stage_index)
# TODO Add documentation
async def insert_stage(
self, bot: Bot, event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
"""Convert the object to a JSON representation.
Args:
json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted.
Returns:
Dict[str, Any]: JSON representation of the object.
"""
return {
"_id": self._id if not json_compatible else str(self._id),
"name": self.name,
"guild_id": self.guild_id,
"created": self.created,
"ended": self.ended,
"is_cancelled": self.is_cancelled,
"creator_id": self.creator_id,
"starts": self.starts,
"ends": self.ends,
"thumbnail": self.thumbnail,
"stage_ids": self.stage_ids,
}
@staticmethod
def get_defaults() -> Dict[str, Any]:
"""Get default values for the object attributes.
Returns:
Dict[str, Any]: Mapping of attributes and their respective values in format `{"attribute_name:" attribute_value}`.
"""
return {
"name": None,
"guild_id": None,
"created": None,
"ended": None,
"is_cancelled": False,
"creator_id": None,
"starts": None,
"ends": None,
"thumbnail": None,
"stage_ids": [],
}
@staticmethod
def get_default_value(key: str) -> Any:
"""Get default value of the attribute for the object.
Args:
key (str): Name of the attribute.
Returns:
Any: Default value of the attribute.
Raises:
KeyError: There's no default value for the provided attribute.
"""
if key not in PycordEvent.get_defaults():
raise KeyError(f"There's no default value for key '{key}' in PycordEvent")
return PycordEvent.get_defaults()[key]
async def update(
self,
cache: Optional[Cache] = None,
**kwargs,
) -> None:
"""Update attribute(s) on the object and save the updated entry into the database.
Args:
cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
**kwargs (Any): Mapping of attributes in format `attribute_name=attribute_value` to update.
Raises:
AttributeError: Provided attribute does not exist in the class.
"""
await self._set(cache=cache, **kwargs)
async def reset(
self,
*args,
cache: Optional[Cache] = None,
) -> None:
"""Remove attribute(s) on the object, replace them with a default value and save the updated entry into the database.
Args:
*args (str): List of attributes to remove.
cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
Raises:
AttributeError: Provided attribute does not exist in the class.
"""
await self._remove(*args, cache=cache)
async def purge(self, cache: Optional[Cache] = None) -> None:
"""Completely remove event data from database. Currently only removes the event record from events collection.
Args:
cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
"""
await self.__collection__.delete_one({"_id": self._id})
self._delete_cache(cache)
async def cancel(self, cache: Optional[Cache] = None) -> None:
"""Cancel the event.
Attribute `is_cancelled` will be set to `True`.
Args:
cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
"""
await self._set(cache, is_cancelled=True)
async def end(self, cache: Optional[Cache] = None) -> None:
"""End the event.
Attribute `ended` will be set to the current date in UTC.
Args:
cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
"""
await self._set(cache, ended=datetime.now(tz=ZoneInfo("UTC")))
async def insert_stage(
self, bot: "PycordBot", event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
) -> None:
"""Insert a stage at the provided index.
Args:
bot (PycordBot): Bot object.
event_stage_id (ObjectId): Stage ID to be inserted.
index (int): Index to be inserted at.
cache: Cache engine that will be used to update the cache.
"""
old_stage_ids: List[ObjectId] = self.stage_ids.copy()
self.stage_ids.insert(index, event_stage_id)
@@ -305,10 +389,17 @@ class PycordEvent:
await self._set(cache, stage_ids=self.stage_ids)
await self._update_event_stage_order(bot, old_stage_ids, cache=cache)
# TODO Add documentation
async def reorder_stage(
self, bot: Any, event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
self, bot: "PycordBot", event_stage_id: ObjectId, index: int, cache: Optional[Cache] = None
) -> None:
"""Reorder a stage to the provided index.
Args:
bot (PycordBot): Bot object.
event_stage_id (ObjectId): Stage ID to be reordered.
index (int): Index to be reordered to.
cache: Cache engine that will be used to update the cache.
"""
old_stage_ids: List[ObjectId] = self.stage_ids.copy()
self.stage_ids.insert(index, self.stage_ids.pop(self.stage_ids.index(event_stage_id)))
@@ -316,8 +407,14 @@ class PycordEvent:
await self._set(cache, stage_ids=self.stage_ids)
await self._update_event_stage_order(bot, old_stage_ids, cache=cache)
# TODO Add documentation
async def remove_stage(self, bot: Bot, event_stage_id: ObjectId, cache: Optional[Cache] = None) -> None:
async def remove_stage(self, bot: "PycordBot", event_stage_id: ObjectId, cache: Optional[Cache] = None) -> None:
"""Remove a stage from the event.
Args:
bot (PycordBot): Bot object.
event_stage_id (ObjectId): Stage ID to be reordered.
cache: Cache engine that will be used to update the cache.
"""
old_stage_ids: List[ObjectId] = self.stage_ids.copy()
self.stage_ids.pop(self.stage_ids.index(event_stage_id))
@@ -325,10 +422,58 @@ class PycordEvent:
await self._set(cache, stage_ids=self.stage_ids)
await self._update_event_stage_order(bot, old_stage_ids, cache=cache)
# # TODO Add documentation
# def get_localized_start_date(self, tz: str | timezone | ZoneInfo) -> datetime:
# return self.starts.replace(tzinfo=tz)
def get_start_date_utc(self) -> datetime:
"""Get the event start date in UTC timezone.
# TODO Add documentation
# def get_localized_end_date(self, tz: str | timezone | ZoneInfo) -> datetime:
# return self.ends.replace(tzinfo=tz)
Returns:
datetime: Start date in UTC.
Raises:
ValueError: Event does not have a start date.
"""
if self.starts is None:
raise ValueError("Event does not have a start date")
return self.starts.replace(tzinfo=ZoneInfo("UTC"))
def get_end_date_utc(self) -> datetime:
"""Get the event end date in UTC timezone.
Returns:
datetime: End date in UTC.
Raises:
ValueError: Event does not have an end date.
"""
if self.ends is None:
raise ValueError("Event does not have an end date")
return self.ends.replace(tzinfo=ZoneInfo("UTC"))
def get_start_date_localized(self, tz: str | timezone | ZoneInfo) -> datetime:
"""Get the event start date in the provided timezone.
Returns:
datetime: Start date in the provided timezone.
Raises:
ValueError: Event does not have a start date.
"""
if self.starts is None:
raise ValueError("Event does not have a start date")
return self.starts.replace(tzinfo=tz)
def get_end_date_localized(self, tz: str | timezone | ZoneInfo) -> datetime:
"""Get the event end date in the provided timezone.
Returns:
datetime: End date in the provided timezone.
Raises:
ValueError: Event does not have an end date.
"""
if self.ends is None:
raise ValueError("Event does not have an end date")
return self.ends.replace(tzinfo=tz)