From c1d862047842ad2aea4498e8acdb7c00facd0591 Mon Sep 17 00:00:00 2001 From: profitroll Date: Thu, 24 Apr 2025 00:16:53 +0200 Subject: [PATCH] Cleanups and bugfixes for (#2 and #8) --- classes/pycord_bot.py | 4 +- classes/pycord_event.py | 17 +++--- classes/pycord_event_stage.py | 4 +- classes/pycord_guild.py | 4 +- classes/pycord_user.py | 80 ++++++++++++++++++++++++++--- cogs/event.py | 55 ++++---------------- cogs/stage.py | 38 +++++++++++--- modules/__init__.py | 2 +- modules/utils/__init__.py | 3 ++ modules/utils/autocomplete_utils.py | 41 ++++++++++++++- modules/utils/event_utils.py | 46 +++++++++++++++++ 11 files changed, 222 insertions(+), 72 deletions(-) create mode 100644 modules/utils/event_utils.py diff --git a/classes/pycord_bot.py b/classes/pycord_bot.py index d69104a..02d6718 100644 --- a/classes/pycord_bot.py +++ b/classes/pycord_bot.py @@ -113,7 +113,9 @@ class PycordBot(LibPycordBot): return event_stage # TODO Document this method - async def find_event(self, event_id: str | ObjectId | None = None, event_name: str | None = None): + async def find_event( + self, event_id: str | ObjectId | None = None, event_name: str | None = None + ) -> PycordEvent: if event_id is None and event_name is None: raise AttributeError("Either event's ID or name must be provided!") diff --git a/classes/pycord_event.py b/classes/pycord_event.py index 32518de..4c95e97 100644 --- a/classes/pycord_event.py +++ b/classes/pycord_event.py @@ -23,7 +23,7 @@ class PycordEvent: "guild_id", "created", "ended", - "cancelled", + "is_cancelled", "creator_id", "starts", "ends", @@ -38,7 +38,7 @@ class PycordEvent: guild_id: int created: datetime ended: datetime | None - cancelled: bool + is_cancelled: bool creator_id: int starts: datetime ends: datetime @@ -58,6 +58,7 @@ class PycordEvent: Raises: EventNotFoundError: Event was not found + InvalidId: Invalid event ID was provided """ cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache) @@ -112,7 +113,7 @@ class PycordEvent: "guild_id": guild_id, "created": datetime.now(tz=ZoneInfo("UTC")), "ended": None, - "cancelled": False, + "is_cancelled": False, "creator_id": creator_id, "starts": starts, "ends": ends, @@ -129,12 +130,12 @@ class PycordEvent: return cls(**db_entry) - async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: + async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None: """Set attribute data and save it into the database. Args: cache (:obj:`Cache`, optional): Cache engine to write the update into - **kwargs (str): Mapping of attribute names and respective values to be set + **kwargs (Any): Mapping of attribute names and respective values to be set """ for key, value in kwargs.items(): if not hasattr(self, key): @@ -208,7 +209,7 @@ class PycordEvent: "guild_id": self.guild_id, "created": self.created, "ended": self.ended, - "cancelled": self.cancelled, + "is_cancelled": self.is_cancelled, "creator_id": self.creator_id, "starts": self.starts, "ends": self.ends, @@ -223,7 +224,7 @@ class PycordEvent: "guild_id": None, "created": None, "ended": None, - "cancelled": False, + "is_cancelled": False, "creator_id": None, "starts": None, "ends": None, @@ -265,7 +266,7 @@ class PycordEvent: # TODO Add documentation async def cancel(self, cache: Optional[Cache] = None): - await self._set(cache, cancelled=True) + await self._set(cache, is_cancelled=True) async def _update_event_stage_order( self, diff --git a/classes/pycord_event_stage.py b/classes/pycord_event_stage.py index 82c84a5..49f9606 100644 --- a/classes/pycord_event_stage.py +++ b/classes/pycord_event_stage.py @@ -106,12 +106,12 @@ class PycordEventStage: return cls(**db_entry) - async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: + async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None: """Set attribute data and save it into the database. Args: cache (:obj:`Cache`, optional): Cache engine to write the update into - **kwargs (str): Mapping of attribute names and respective values to be set + **kwargs (Any): Mapping of attribute names and respective values to be set """ for key, value in kwargs.items(): if not hasattr(self, key): diff --git a/classes/pycord_guild.py b/classes/pycord_guild.py index aa37c54..11765ae 100644 --- a/classes/pycord_guild.py +++ b/classes/pycord_guild.py @@ -67,12 +67,12 @@ class PycordGuild: return cls(**db_entry) - async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: + async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None: """Set attribute data and save it into the database. Args: cache (:obj:`Cache`, optional): Cache engine to write the update into - **kwargs (str): Mapping of attribute names and respective values to be set + **kwargs (Any): Mapping of attribute names and respective values to be set """ for key, value in kwargs.items(): if not hasattr(self, key): diff --git a/classes/pycord_user.py b/classes/pycord_user.py index da25146..3c4c2be 100644 --- a/classes/pycord_user.py +++ b/classes/pycord_user.py @@ -1,12 +1,11 @@ from dataclasses import dataclass from logging import Logger -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from bson import ObjectId from libbot.cache.classes import Cache from pymongo.results import InsertOneResult -from classes.abstract.cacheable import Cacheable from classes.errors.pycord_user import UserNotFoundError from modules.database import col_users from modules.utils import get_logger, restore_from_cache @@ -15,10 +14,20 @@ logger: Logger = get_logger(__name__) @dataclass -class PycordUser(Cacheable): +class PycordUser: """Dataclass of DB entry of a user""" - __slots__ = ("_id", "id", "guild_id", "channel_id", "current_event_id", "current_stage_id") + __slots__ = ( + "_id", + "id", + "guild_id", + "channel_id", + "is_jailed", + "current_event_id", + "current_stage_id", + "registered_event_ids", + "completed_event_ids", + ) __short_name__ = "user" __collection__ = col_users @@ -26,8 +35,11 @@ class PycordUser(Cacheable): id: int guild_id: int channel_id: int | None + is_jailed: bool current_event_id: ObjectId | None current_stage_id: ObjectId | None + registered_event_ids: List[ObjectId] + completed_event_ids: List[ObjectId] @classmethod async def from_id( @@ -82,20 +94,31 @@ class PycordUser(Cacheable): "id": self.id, "guild_id": self.guild_id, "channel_id": self.channel_id, + "is_jailed": self.is_jailed, "current_event_id": ( self.current_event_id if not json_compatible else str(self.current_event_id) ), "current_stage_id": ( self.current_stage_id if not json_compatible else str(self.current_stage_id) ), + "registered_event_ids": ( + self.registered_event_ids + if not json_compatible + else [str(event_id) for event_id in self.registered_event_ids] + ), + "completed_event_ids": ( + self.completed_event_ids + if not json_compatible + else [str(event_id) for event_id in self.completed_event_ids] + ), } - async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: + async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None: """Set attribute data and save it into the database. Args: cache (:obj:`Cache`, optional): Cache engine to write the update into - **kwargs (str): Mapping of attribute names and respective values to be set + **kwargs (Any): Mapping of attribute names and respective values to be set """ for key, value in kwargs.items(): if not hasattr(self, key): @@ -160,8 +183,11 @@ class PycordUser(Cacheable): "id": user_id, "guild_id": guild_id, "channel_id": None, + "is_jailed": False, "current_event_id": None, "current_stage_id": None, + "registered_event_ids": [], + "completed_event_ids": [], } @staticmethod @@ -179,3 +205,45 @@ class PycordUser(Cacheable): """ await self.__collection__.delete_one({"_id": self._id}) self._delete_cache(cache) + + # TODO Add documentation + async def event_register(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None: + event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id + + if event_id in self.registered_event_ids: + raise RuntimeError(f"User is already registered for event {event_id}") + + # TODO Add a unique exception + # raise UserAlreadyRegisteredForEventError(event_name) + + self.registered_event_ids.append(event_id) + + await self._set(cache, registered_event_ids=self.registered_event_ids) + + # TODO Add documentation + async def event_unregister(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None: + event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id + + if event_id not in self.registered_event_ids: + raise RuntimeError(f"User is not registered for event {event_id}") + + # TODO Add a unique exception + # raise UserNotRegisteredForEventError(event_name) + + self.registered_event_ids.remove(event_id) + + await self._set(cache, registered_event_ids=self.registered_event_ids) + + # TODO Add documentation + async def event_complete(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None: + event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id + + if event_id in self.completed_event_ids: + raise RuntimeError(f"User has already completed event {event_id}") + + # TODO Add a unique exception + # raise UserAlreadyCompletedEventError(event_name) + + self.completed_event_ids.append(event_id) + + await self._set(cache, completed_event_ids=self.completed_event_ids) diff --git a/cogs/event.py b/cogs/event.py index 36dd577..ac758d0 100644 --- a/cogs/event.py +++ b/cogs/event.py @@ -1,8 +1,7 @@ from datetime import datetime -from typing import Any, Dict from zoneinfo import ZoneInfo -from bson import ObjectId +from bson.errors import InvalidId from discord import ( ApplicationContext, Attachment, @@ -14,42 +13,7 @@ from discord.utils import basic_autocomplete from classes import PycordEvent, PycordGuild from classes.pycord_bot import PycordBot -from modules.database import col_events -from modules.utils import autocomplete_active_events - - -# TODO Move to staticmethod or to a separate module -async def validate_event_validity( - ctx: ApplicationContext, - name: str, - start_date: datetime | None, - finish_date: datetime | None, - guild_timezone: ZoneInfo, - event_id: ObjectId | None = None, -) -> None: - if start_date > finish_date: - # TODO Make a nice message - await ctx.respond("Start date must be before finish date") - return - elif start_date < datetime.now(tz=guild_timezone): - # TODO Make a nice message - await ctx.respond("Start date must not be in the past") - return - - query: Dict[str, Any] = { - "name": name, - "ended": None, - "ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, - "cancelled": {"$ne": True}, - } - - if event_id is not None: - query["_id"] = {"$ne": event_id} - - if (await col_events.find_one(query)) is not None: - # TODO Make a nice message - await ctx.respond("There can only be one active event with the same name") - return +from modules.utils import autocomplete_active_events, validate_event_validity class Event(Cog): @@ -137,9 +101,10 @@ class Event(Cog): thumbnail: Attachment = None, ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) - pycord_event: PycordEvent = await self.bot.find_event(event_id=event) - if pycord_event is None: + try: + pycord_event: PycordEvent = await self.bot.find_event(event_id=event) + except (InvalidId, RuntimeError): # TODO Make a nice message await ctx.respond("Event was not found.") return @@ -208,9 +173,10 @@ class Event(Cog): return guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) - pycord_event: PycordEvent = await self.bot.find_event(event_id=event) - if pycord_event is None: + try: + pycord_event: PycordEvent = await self.bot.find_event(event_id=event) + except (InvalidId, RuntimeError): # TODO Make a nice message await ctx.respond("Event was not found.") return @@ -251,9 +217,10 @@ class Event(Cog): ) async def command_event_show(self, ctx: ApplicationContext, event: str) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) - pycord_event: PycordEvent = await self.bot.find_event(event_id=event) - if pycord_event is None: + try: + pycord_event: PycordEvent = await self.bot.find_event(event_id=event) + except (InvalidId, RuntimeError): # TODO Make a nice message await ctx.respond("Event was not found.") return diff --git a/cogs/stage.py b/cogs/stage.py index f6ff126..e1b1a82 100644 --- a/cogs/stage.py +++ b/cogs/stage.py @@ -1,6 +1,7 @@ from datetime import datetime from zoneinfo import ZoneInfo +from bson.errors import InvalidId from discord import ApplicationContext, Attachment, SlashCommandGroup, option from discord.ext.commands import Cog from discord.utils import basic_autocomplete @@ -14,7 +15,7 @@ async def validate_event_status( ctx: ApplicationContext, event: PycordEvent, ) -> None: - if event.cancelled: + if event.is_cancelled: # TODO Make a nice message await ctx.respond("This event was cancelled.") return @@ -63,7 +64,12 @@ class Stage(Cog): await ctx.respond("Guild is not configured.") return - pycord_event: PycordEvent = await self.bot.find_event(event) + try: + pycord_event: PycordEvent = await self.bot.find_event(event_id=event) + except (InvalidId, RuntimeError): + # TODO Make a nice message + await ctx.respond("Event was not found.") + return await validate_event_status(ctx, pycord_event) @@ -123,11 +129,21 @@ class Stage(Cog): await ctx.respond("Guild is not configured.") return - pycord_event: PycordEvent = await self.bot.find_event(event) + try: + pycord_event: PycordEvent = await self.bot.find_event(event_id=event) + except (InvalidId, RuntimeError): + # TODO Make a nice message + await ctx.respond("Event was not found.") + return await validate_event_status(ctx, pycord_event) - event_stage: PycordEventStage = await self.bot.find_event_stage(stage) + try: + event_stage: PycordEventStage = await self.bot.find_event_stage(stage) + except (InvalidId, RuntimeError): + # TODO Make a nice message + await ctx.respond("Event stage was not found.") + return if order is not None and order > len(pycord_event.stage_ids): # TODO Make a nice message @@ -180,11 +196,21 @@ class Stage(Cog): await ctx.respond("Guild is not configured.") return - pycord_event: PycordEvent = await self.bot.find_event(event) + try: + pycord_event: PycordEvent = await self.bot.find_event(event_id=event) + except (InvalidId, RuntimeError): + # TODO Make a nice message + await ctx.respond("Event was not found.") + return await validate_event_status(ctx, pycord_event) - event_stage: PycordEventStage = await self.bot.find_event_stage(stage) + try: + event_stage: PycordEventStage = await self.bot.find_event_stage(stage) + except (InvalidId, RuntimeError): + # TODO Make a nice message + await ctx.respond("Event stage was not found.") + return await pycord_event.remove_stage(self.bot, event_stage._id, cache=self.bot.cache) await event_stage.purge(cache=self.bot.cache) diff --git a/modules/__init__.py b/modules/__init__.py index e686dad..fe17ef5 100644 --- a/modules/__init__.py +++ b/modules/__init__.py @@ -1 +1 @@ -from . import utils, database, migrator, scheduler +from . import database, migrator, scheduler, utils diff --git a/modules/utils/__init__.py b/modules/utils/__init__.py index 0c02512..12b22d3 100644 --- a/modules/utils/__init__.py +++ b/modules/utils/__init__.py @@ -3,6 +3,9 @@ from .autocomplete_utils import ( autocomplete_event_stages, autocomplete_languages, autocomplete_timezones, + autocomplete_user_available_events, + autocomplete_user_registered_events, ) from .cache_utils import restore_from_cache +from .event_utils import validate_event_validity from .logging_utils import get_logger, get_logging_config diff --git a/modules/utils/autocomplete_utils.py b/modules/utils/autocomplete_utils.py index 74e751b..94d26c0 100644 --- a/modules/utils/autocomplete_utils.py +++ b/modules/utils/autocomplete_utils.py @@ -6,7 +6,7 @@ from bson import ObjectId from discord import AutocompleteContext, OptionChoice from pymongo import ASCENDING -from modules.database import col_events, col_stages +from modules.database import col_events, col_stages, col_users async def autocomplete_timezones(ctx: AutocompleteContext) -> List[str]: @@ -29,7 +29,7 @@ async def autocomplete_active_events(ctx: AutocompleteContext) -> List[OptionCho query: Dict[str, Any] = { "ended": None, "ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, - "cancelled": {"$ne": True}, + "is_cancelled": {"$ne": True}, } event_names: List[OptionChoice] = [] @@ -40,6 +40,43 @@ async def autocomplete_active_events(ctx: AutocompleteContext) -> List[OptionCho return event_names +async def autocomplete_user_available_events(ctx: AutocompleteContext) -> List[OptionChoice]: + """Return list of active events user can register in""" + + return await autocomplete_active_events(ctx) + + +async def autocomplete_user_registered_events(ctx: AutocompleteContext) -> List[OptionChoice]: + """Return list of active events user is registered in""" + + pipeline: List[Dict[str, Any]] = [ + { + "$lookup": { + "from": "events", + "localField": "registered_event_ids", + "foreignField": "_id", + "as": "registered_events", + } + }, + { + "$match": { + "registered_events.ended": None, + "registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, + "registered_events.starts": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, + "registered_events.is_cancelled": {"$ne": True}, + } + }, + ] + + event_names: List[OptionChoice] = [] + + async for result in col_users.aggregate(pipeline): + for registered_event in result["registered_events"]: + event_names.append(OptionChoice(registered_event["name"], str(registered_event["_id"]))) + + return event_names + + async def autocomplete_event_stages(ctx: AutocompleteContext) -> List[OptionChoice]: """Return list of stages of the event""" diff --git a/modules/utils/event_utils.py b/modules/utils/event_utils.py new file mode 100644 index 0000000..b40ae42 --- /dev/null +++ b/modules/utils/event_utils.py @@ -0,0 +1,46 @@ +from datetime import datetime +from typing import Any, Dict +from zoneinfo import ZoneInfo + +from bson import ObjectId +from discord import ( + ApplicationContext, +) + +from modules.database import col_events + + +async def validate_event_validity( + ctx: ApplicationContext, + name: str, + start_date: datetime | None, + finish_date: datetime | None, + guild_timezone: ZoneInfo, + event_id: ObjectId | None = None, +) -> None: + if start_date > finish_date: + # TODO Make a nice message + await ctx.respond("Start date must be before finish date") + return + + if start_date < datetime.now(tz=guild_timezone): + # TODO Make a nice message + await ctx.respond("Start date must not be in the past") + return + + # TODO Add validation for concurrent events. + # Only one event can take place at the same time. + query: Dict[str, Any] = { + "name": name, + "ended": None, + "ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, + "is_cancelled": {"$ne": True}, + } + + if event_id is not None: + query["_id"] = {"$ne": event_id} + + if (await col_events.find_one(query)) is not None: + # TODO Make a nice message + await ctx.respond("There can only be one active event with the same name") + return