diff --git a/classes/pycord_bot.py b/classes/pycord_bot.py index 0758a1d..5a18c49 100644 --- a/classes/pycord_bot.py +++ b/classes/pycord_bot.py @@ -7,7 +7,7 @@ from libbot.cache.classes import CacheMemcached, CacheRedis from libbot.cache.manager import create_cache_client from libbot.pycord.classes import PycordBot as LibPycordBot -from classes import PycordEvent, PycordGuild, PycordUser +from classes import PycordEvent, PycordGuild, PycordUser, PycordEventStage from modules.logging_utils import get_logger logger: Logger = get_logger(__name__) @@ -83,6 +83,27 @@ class PycordBot(LibPycordBot): async def create_event(self, **kwargs) -> PycordEvent: return await PycordEvent.create(**kwargs, cache=self.cache) + # TODO Document this method + async def create_event_stage(self, event: PycordEvent, **kwargs) -> PycordEventStage: + # TODO Validation is handled by the caller for now, but + # ideally this should not be the case at all. + # + # if "event_id" not in kwargs: + # # TODO Create a nicer exception + # raise RuntimeError("Event ID must be provided while creating an event stage") + # + # event: PycordEvent = await self.find_event(event_id=kwargs["event_id"]) + + if "sequence" not in kwargs: + # TODO Create a nicer exception + raise RuntimeError("Stage must have a defined sequence") + + event_stage: PycordEventStage = await PycordEventStage.create(**kwargs, cache=self.cache) + + await event.insert_stage(event_stage._id, kwargs["sequence"], cache=self.cache) + + return event_stage + async def start(self, *args: Any, **kwargs: Any) -> None: await super().start(*args, **kwargs) diff --git a/classes/pycord_event.py b/classes/pycord_event.py index 9146a8a..88d16f8 100644 --- a/classes/pycord_event.py +++ b/classes/pycord_event.py @@ -8,7 +8,7 @@ from bson import ObjectId from libbot.cache.classes import Cache from pymongo.results import InsertOneResult -from modules.database import col_events +from modules.database import col_events, col_stages from modules.logging_utils import get_logger logger: Logger = get_logger(__name__) @@ -42,7 +42,7 @@ class PycordEvent: starts: datetime ends: datetime thumbnail_id: str | None - stage_ids: List[int] + stage_ids: List[ObjectId] @classmethod async def from_id(cls, event_id: str | ObjectId, cache: Optional[Cache] = None) -> "PycordEvent": @@ -268,3 +268,12 @@ class PycordEvent: async def cancel(self, cache: Optional[Cache] = None): await self._set(cache, cancelled=True) + + async def insert_stage( + self, event_stage_id: ObjectId, sequence: int, cache: Optional[Cache] = None + ) -> None: + self.stage_ids.insert(sequence, event_stage_id) + await self._set(cache, stage_ids=self.stage_ids) + + # TODO Check if this works + await col_stages.update_many({"_id": {"$eq": self.stage_ids[sequence:]}}, {"$inc": {"sequence": 1}}) diff --git a/classes/pycord_event_stage.py b/classes/pycord_event_stage.py index 51499d7..aa7b4d8 100644 --- a/classes/pycord_event_stage.py +++ b/classes/pycord_event_stage.py @@ -1,18 +1,241 @@ from dataclasses import dataclass from datetime import datetime -from typing import List +from logging import Logger +from typing import List, Dict, Any, Optional +from zoneinfo import ZoneInfo from bson import ObjectId +from libbot.cache.classes import Cache +from pymongo.results import InsertOneResult + +from modules.database import col_stages +from modules.logging_utils import get_logger + +logger: Logger = get_logger(__name__) @dataclass class PycordEventStage: + __slots__ = ( + "_id", + "event_id", + "guild_id", + "sequence", + "created", + "creator_id", + "question", + "answer", + "media", + ) + __short_name__ = "stage" + __collection__ = col_stages + _id: ObjectId - id: int - event_id: int + event_id: ObjectId guild_id: int sequence: int created: datetime creator_id: int - text: str | None + question: str + answer: str media: List[str] + + @classmethod + async def from_id(cls, stage_id: str | ObjectId, cache: Optional[Cache] = None) -> "PycordEventStage": + """Find event stage in the database. + + Args: + stage_id (str | ObjectId): Stage's ID + cache (:obj:`Cache`, optional): Cache engine to get the cache from + + Returns: + PycordEventStage: Event stage object + + Raises: + EventStageNotFoundError: Event stage was not found + """ + if cache is not None: + cached_entry: Dict[str, Any] | None = cache.get_json(f"{cls.__short_name__}_{stage_id}") + + if cached_entry is not None: + return cls(**cached_entry) + + db_entry = await cls.__collection__.find_one( + {"_id": stage_id if isinstance(stage_id, ObjectId) else ObjectId(stage_id)} + ) + + if db_entry is None: + raise RuntimeError(f"Event stage {stage_id} not found") + + # TODO Add a unique exception + # raise EventStageNotFoundError(event_id) + + if cache is not None: + cache.set_json(f"{cls.__short_name__}_{stage_id}", db_entry) + + return cls(**db_entry) + + @classmethod + async def create( + cls, + event_id: Optional[str | ObjectId], + guild_id: Optional[int], + sequence: int, + creator_id: int, + question: str, + answer: Optional[str] = None, + media: Optional[List[str]] = None, + cache: Optional[Cache] = None, + ) -> "PycordEventStage": + db_entry: Dict[str, Any] = { + "event_id": event_id, + "guild_id": guild_id, + "sequence": sequence, + "created": datetime.now(tz=ZoneInfo("UTC")), + "creator_id": creator_id, + "question": question, + "answer": answer, + "media": [] if media is None else media, + } + + insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry) + + db_entry["_id"] = insert_result.inserted_id + + if cache is not None: + cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry) + + return cls(**db_entry) + + # TODO Update the docstring + async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: + """Set attribute data and save it into the database. + + Args: + key (str): Attribute to change + value (Any): Value to set + cache (:obj:`Cache`, optional): Cache engine to write the update into + """ + for key, value in kwargs.items(): + if not hasattr(self, key): + raise AttributeError() + + setattr(self, key, value) + + await self.__collection__.update_one({"_id": self._id}, {"$set": kwargs}, upsert=True) + + self._update_cache(cache) + + logger.info("Set attributes of event stage %s to %s", self._id, kwargs) + + # TODO Update the docstring + async def _remove(self, cache: Optional[Cache] = None, *args: str) -> None: + """Remove attribute data and save it into the database. + + Args: + key (str): Attribute to remove + cache (:obj:`Cache`, optional): Cache engine to write the update into + """ + attributes: Dict[str, Any] = {} + + for key in args: + if not hasattr(self, key): + raise AttributeError() + + default_value: Any = self.get_default_value(key) + + setattr(self, key, default_value) + + attributes[key] = default_value + + await self.__collection__.update_one({"_id": self._id}, {"$set": attributes}, upsert=True) + + self._update_cache(cache) + + logger.info("Reset attributes %s of event stage %s to default values", args, self._id) + + def _get_cache_key(self) -> str: + return f"{self.__short_name__}_{self._id}" + + def _update_cache(self, cache: Optional[Cache] = None) -> None: + if cache is None: + return + + user_dict: Dict[str, Any] = self.to_dict() + + if user_dict is not None: + cache.set_json(self._get_cache_key(), user_dict) + else: + self._delete_cache(cache) + + def _delete_cache(self, cache: Optional[Cache] = None) -> None: + if cache is None: + return + + cache.delete(self._get_cache_key()) + + def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]: + """Convert PycordEventStage object to a JSON representation. + + Args: + json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted + + Returns: + Dict[str, Any]: JSON representation of PycordEventStage + """ + return { + "_id": self._id if not json_compatible else str(self._id), + "event_id": self.event_id if not json_compatible else str(self.event_id), + "guild_id": self.guild_id, + "sequence": self.sequence, + "created": self.created, + "creator_id": self.creator_id, + "question": self.question, + "answer": self.answer, + "media": self.media, + } + + @staticmethod + def get_defaults() -> Dict[str, Any]: + return { + "event_id": None, + "guild_id": None, + "sequence": 0, + "created": None, + "creator_id": None, + "question": None, + "answer": None, + "media": [], + } + + @staticmethod + def get_default_value(key: str) -> Any: + if key not in PycordEventStage.get_defaults(): + raise KeyError(f"There's no default value for key '{key}' in PycordEventStage") + + return PycordEventStage.get_defaults()[key] + + # TODO Add documentation + async def update( + self, + cache: Optional[Cache] = None, + **kwargs, + ): + await self._set(cache=cache, **kwargs) + + # TODO Add documentation + async def reset( + self, + cache: Optional[Cache] = None, + *args, + ): + await self._remove(cache, *args) + + async def purge(self, cache: Optional[Cache] = None) -> None: + """Completely remove event stage data from database. Currently only removes the event stage record from events collection. + + Args: + cache (:obj:`Cache`, optional): Cache engine to write the update into + """ + await self.__collection__.delete_one({"_id": self._id}) + self._delete_cache(cache) diff --git a/cogs/stage.py b/cogs/stage.py index 8b1d9e3..fb587ab 100644 --- a/cogs/stage.py +++ b/cogs/stage.py @@ -2,6 +2,7 @@ from discord import SlashCommandGroup, option, ApplicationContext, Attachment from discord.ext.commands import Cog from discord.utils import basic_autocomplete +from classes import PycordGuild, PycordEventStage, PycordEvent from classes.pycord_bot import PycordBot from modules.utils import autofill_active_events @@ -14,8 +15,7 @@ class Stage(Cog): command_group: SlashCommandGroup = SlashCommandGroup("stage", "Event stage management") - # TODO Implement the command - # /stage add + # TODO Introduce i18n # TODO Maybe add an option for order? @command_group.command( name="add", @@ -38,7 +38,27 @@ class Stage(Cog): answer: str, media: Attachment = None, ) -> None: - await ctx.respond("Not implemented.") + guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) + pycord_event: PycordEvent = await self.bot.find_event(event) + + if not guild.is_configured(): + # TODO Make a nice message + await ctx.respond("Guild is not configured.") + return + + event_stage: PycordEventStage = await self.bot.create_event_stage( + event=pycord_event, + event_id=pycord_event._id, + guild_id=guild.id, + sequence=len(pycord_event.stage_ids), + creator_id=ctx.author.id, + question=question, + answer=answer, + media=None if media is None else media.id, + ) + + # TODO Make a nice message + await ctx.respond("Event stage has been created.") # TODO Implement the command # /stage edit