From 08d1ca4d33105554fff85e9b128c1b7de9858992 Mon Sep 17 00:00:00 2001 From: profitroll Date: Mon, 21 Apr 2025 22:49:23 +0200 Subject: [PATCH] Implemented /event edit and /edit cancel --- classes/pycord_bot.py | 7 +- classes/pycord_event.py | 83 +++++++++++++------ classes/pycord_guild.py | 89 ++++++++++++++------ cogs/config.py | 51 ++++++++++-- cogs/event.py | 176 ++++++++++++++++++++++++++-------------- 5 files changed, 286 insertions(+), 120 deletions(-) diff --git a/classes/pycord_bot.py b/classes/pycord_bot.py index eca5b61..0758a1d 100644 --- a/classes/pycord_bot.py +++ b/classes/pycord_bot.py @@ -87,9 +87,6 @@ class PycordBot(LibPycordBot): await super().start(*args, **kwargs) async def close(self, **kwargs) -> None: - if self.scheduler is not None: - self.scheduler.shutdown() - await super().close(**kwargs) async def find_event(self, event_id: str | ObjectId | None = None, event_name: str | None = None): @@ -97,6 +94,6 @@ class PycordBot(LibPycordBot): raise AttributeError("Either event's ID or name must be provided!") if event_id is not None: - await PycordEvent.from_id(event_id, cache=self.cache) + return await PycordEvent.from_id(event_id, cache=self.cache) else: - await PycordEvent.from_name(event_name, cache=self.cache) + return await PycordEvent.from_name(event_name, cache=self.cache) diff --git a/classes/pycord_event.py b/classes/pycord_event.py index 18677f6..9146a8a 100644 --- a/classes/pycord_event.py +++ b/classes/pycord_event.py @@ -1,7 +1,8 @@ from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import datetime from logging import Logger from typing import Any, Dict, List, Optional +from zoneinfo import ZoneInfo from bson import ObjectId from libbot.cache.classes import Cache @@ -21,6 +22,7 @@ class PycordEvent: "guild_id", "created", "ended", + "cancelled", "creator_id", "starts", "ends", @@ -35,6 +37,7 @@ class PycordEvent: guild_id: int created: datetime ended: datetime | None + cancelled: bool creator_id: int starts: datetime ends: datetime @@ -85,28 +88,31 @@ class PycordEvent: if db_entry is None: raise RuntimeError(f"Event with name {event_name} not found") + # TODO Add a unique exception + # raise EventNotFoundError(event_name) + if cache is not None: cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry) return cls(**db_entry) - # TODO Implement this method @classmethod async def create( - cls, - name: str, - guild_id: int, - creator_id: int, - starts: datetime, - ends: datetime, - thumbnail_id: str | None, - cache: Optional[Cache] = None, + cls, + name: str, + guild_id: int, + creator_id: int, + starts: datetime, + ends: datetime, + thumbnail_id: str | None, + cache: Optional[Cache] = None, ) -> "PycordEvent": db_entry: Dict[str, Any] = { "name": name, "guild_id": guild_id, - "created": datetime.now(tz=timezone.utc), + "created": datetime.now(tz=ZoneInfo("UTC")), "ended": None, + "cancelled": False, "creator_id": creator_id, "starts": starts, "ends": ends, @@ -123,7 +129,8 @@ class PycordEvent: return cls(**db_entry) - async def _set(self, key: str, value: Any, cache: Optional[Cache] = None) -> None: + # TODO Update the docstring + async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: """Set attribute data and save it into the database. Args: @@ -131,36 +138,43 @@ class PycordEvent: value (Any): Value to set cache (:obj:`Cache`, optional): Cache engine to write the update into """ - if not hasattr(self, key): - raise AttributeError() + for key, value in kwargs.items(): + if not hasattr(self, key): + raise AttributeError() - setattr(self, key, value) + setattr(self, key, value) - await self.__collection__.update_one({"_id": self._id}, {"$set": {key: value}}, upsert=True) + await self.__collection__.update_one({"_id": self._id}, {"$set": kwargs}, upsert=True) self._update_cache(cache) - logger.info("Set attribute '%s' of event %s to '%s'", key, self._id, value) + logger.info("Set attributes of event %s to %s", self._id, kwargs) - async def _remove(self, key: str, cache: Optional[Cache] = None) -> None: + # TODO Update the docstring + async def _remove(self, cache: Optional[Cache] = None, *args: str) -> None: """Remove attribute data and save it into the database. Args: key (str): Attribute to remove cache (:obj:`Cache`, optional): Cache engine to write the update into """ - if not hasattr(self, key): - raise AttributeError() + attributes: Dict[str, Any] = {} - default_value: Any = PycordEvent.get_default_value(key) + for key in args: + if not hasattr(self, key): + raise AttributeError() - setattr(self, key, default_value) + default_value: Any = self.get_default_value(key) - await self.__collection__.update_one({"_id": self._id}, {"$set": {key: default_value}}, upsert=True) + setattr(self, key, default_value) + + attributes[key] = default_value + + await self.__collection__.update_one({"_id": self._id}, {"$set": attributes}, upsert=True) self._update_cache(cache) - logger.info("Removed attribute '%s' of event %s", key, self._id) + logger.info("Reset attributes %s of event %s to default values", args, self._id) def _get_cache_key(self) -> str: return f"{self.__short_name__}_{self._id}" @@ -197,6 +211,7 @@ class PycordEvent: "guild_id": self.guild_id, "created": self.created, "ended": self.ended, + "cancelled": self.cancelled, "creator_id": self.creator_id, "starts": self.starts, "ends": self.ends, @@ -211,6 +226,7 @@ class PycordEvent: "guild_id": None, "created": None, "ended": None, + "cancelled": False, "creator_id": None, "starts": None, "ends": None, @@ -225,6 +241,22 @@ class PycordEvent: return PycordEvent.get_defaults()[key] + # TODO Add documentation + async def update( + self, + cache: Optional[Cache] = None, + **kwargs, + ): + await self._set(cache=cache, **kwargs) + + # TODO Add documentation + async def reset( + self, + cache: Optional[Cache] = None, + *args, + ): + await self._remove(cache, *args) + async def purge(self, cache: Optional[Cache] = None) -> None: """Completely remove event data from database. Currently only removes the event record from events collection. @@ -233,3 +265,6 @@ class PycordEvent: """ await self.__collection__.delete_one({"_id": self._id}) self._delete_cache(cache) + + async def cancel(self, cache: Optional[Cache] = None): + await self._set(cache, cancelled=True) diff --git a/classes/pycord_guild.py b/classes/pycord_guild.py index 4b6fea3..a56a40f 100644 --- a/classes/pycord_guild.py +++ b/classes/pycord_guild.py @@ -17,15 +17,16 @@ logger: Logger = get_logger(__name__) class PycordGuild: """Dataclass of DB entry of a guild""" - __slots__ = ("_id", "id", "channel_id", "category_id", "timezone") + __slots__ = ("_id", "id", "channel_id", "category_id", "timezone", "language") __short_name__ = "guild" __collection__ = col_guilds _id: ObjectId id: int - channel_id: Optional[int] - category_id: Optional[int] + channel_id: int | None + category_id: int | None timezone: str + language: str | None @classmethod async def from_id( @@ -67,7 +68,8 @@ class PycordGuild: return cls(**db_entry) - async def _set(self, key: str, value: Any, cache: Optional[Cache] = None) -> None: + # TODO Update the docstring + async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: """Set attribute data and save it into the database. Args: @@ -75,36 +77,43 @@ class PycordGuild: value (Any): Value to set cache (:obj:`Cache`, optional): Cache engine to write the update into """ - if not hasattr(self, key): - raise AttributeError() + for key, value in kwargs.items(): + if not hasattr(self, key): + raise AttributeError() - setattr(self, key, value) + setattr(self, key, value) - await self.__collection__.update_one({"_id": self._id}, {"$set": {key: value}}, upsert=True) + await self.__collection__.update_one({"_id": self._id}, {"$set": kwargs}, upsert=True) self._update_cache(cache) - logger.info("Set attribute '%s' of guild %s to '%s'", key, self.id, value) + logger.info("Set attributes of guild %s to %s", self.id, kwargs) - async def _remove(self, key: str, cache: Optional[Cache] = None) -> None: + # TODO Update the docstring + async def _remove(self, cache: Optional[Cache] = None, *args: str) -> None: """Remove attribute data and save it into the database. Args: key (str): Attribute to remove cache (:obj:`Cache`, optional): Cache engine to write the update into """ - if not hasattr(self, key): - raise AttributeError() + attributes: Dict[str, Any] = {} - default_value: Any = PycordGuild.get_default_value(key) + for key in args: + if not hasattr(self, key): + raise AttributeError() - setattr(self, key, default_value) + default_value: Any = self.get_default_value(key) - await self.__collection__.update_one({"_id": self._id}, {"$set": {key: default_value}}, upsert=True) + setattr(self, key, default_value) + + attributes[key] = default_value + + await self.__collection__.update_one({"_id": self._id}, {"$set": attributes}, upsert=True) self._update_cache(cache) - logger.info("Removed attribute '%s' of guild %s", key, self.id) + logger.info("Reset attributes %s of guild %s to default values", args, self.id) def _get_cache_key(self) -> str: return f"{self.__short_name__}_{self.id}" @@ -141,11 +150,18 @@ class PycordGuild: "channel_id": self.channel_id, "category_id": self.category_id, "timezone": self.timezone, + "language": self.language, } @staticmethod def get_defaults(guild_id: Optional[int] = None) -> Dict[str, Any]: - return {"id": guild_id, "channel_id": None, "category_id": None, "timezone": "UTC"} + return { + "id": guild_id, + "channel_id": None, + "category_id": None, + "timezone": "UTC", + "language": None, + } @staticmethod def get_default_value(key: str) -> Any: @@ -154,6 +170,22 @@ class PycordGuild: return PycordGuild.get_defaults()[key] + # TODO Add documentation + async def update( + self, + cache: Optional[Cache] = None, + **kwargs, + ): + await self._set(cache=cache, **kwargs) + + # TODO Add documentation + async def reset( + self, + cache: Optional[Cache] = None, + *args, + ): + await self._remove(cache, *args) + async def purge(self, cache: Optional[Cache] = None) -> None: """Completely remove guild data from database. Currently only removes the guild record from guilds collection. @@ -161,8 +193,11 @@ class PycordGuild: cache (:obj:`Cache`, optional): Cache engine to write the update into """ await self.__collection__.delete_one({"_id": self._id}) + self._delete_cache(cache) + logger.info("Purged guild %s (%s) from the database", self.id, self._id) + # TODO Add documentation def is_configured(self) -> bool: return ( @@ -174,24 +209,32 @@ class PycordGuild: # TODO Add documentation async def set_channel(self, channel_id: Optional[int] = None, cache: Optional[Cache] = None) -> None: - await self._set("channel_id", channel_id, cache) + await self._set(cache, channel_id=channel_id) # TODO Add documentation async def reset_channel(self, cache: Optional[Cache] = None) -> None: - await self._remove("channel_id", cache) + await self._remove(cache, "channel_id") # TODO Add documentation async def set_category(self, category_id: Optional[int] = None, cache: Optional[Cache] = None) -> None: - await self._set("category_id", category_id, cache) + await self._set(cache, category_id=category_id) # TODO Add documentation async def reset_category(self, cache: Optional[Cache] = None) -> None: - await self._remove("category_id", cache) + await self._remove(cache, "category_id") # TODO Add documentation async def set_timezone(self, timezone: str, cache: Optional[Cache] = None) -> None: - await self._set("timezone", timezone, cache) + await self._set(cache, timezone=timezone) # TODO Add documentation async def reset_timezone(self, cache: Optional[Cache] = None) -> None: - await self._remove("timezone", cache) + await self._remove(cache, "timezone") + + # TODO Add documentation + async def set_language(self, language: str, cache: Optional[Cache] = None) -> None: + await self._set(cache, language=language) + + # TODO Add documentation + async def reset_language(self, cache: Optional[Cache] = None) -> None: + await self._remove(cache, "language") diff --git a/cogs/config.py b/cogs/config.py index 4a72c80..932fea5 100644 --- a/cogs/config.py +++ b/cogs/config.py @@ -1,4 +1,5 @@ -from zoneinfo import ZoneInfo, ZoneInfoNotFoundError +from typing import List +from zoneinfo import ZoneInfo, ZoneInfoNotFoundError, available_timezones from discord import ( ApplicationContext, @@ -6,13 +7,27 @@ from discord import ( SlashCommandGroup, TextChannel, option, + AutocompleteContext, ) from discord.ext.commands import Cog +from discord.utils import basic_autocomplete from classes import PycordGuild from classes.pycord_bot import PycordBot +# TODO Move to staticmethod or to a separate module +async def get_timezones(ctx: AutocompleteContext) -> List[str]: + return sorted(list(available_timezones())) + + +# TODO Move to staticmethod or to a separate module +async def get_languages(ctx: AutocompleteContext) -> List[str]: + # TODO Discord normally uses a different set of locales. + # For example, "en" being "en-US", etc. This will require changes to locale handling later. + return ctx.bot.locales.keys() + + class Config(Cog): """Cog with guild configuration commands.""" @@ -29,9 +44,25 @@ class Config(Cog): ) @option("category", description="Category where channels for each user will be created", required=True) @option("channel", description="Text channel for admin notifications", required=True) - @option("timezone", description="Timezone in which events take place", required=True) + @option( + "timezone", + description="Timezone in which events take place", + autocomplete=basic_autocomplete(get_timezones), + required=True, + ) + @option( + "language", + description="Language for bot's messages", + autocomplete=basic_autocomplete(get_languages), + required=True, + ) async def command_config_set( - self, ctx: ApplicationContext, category: CategoryChannel, channel: TextChannel, timezone: str + self, + ctx: ApplicationContext, + category: CategoryChannel, + channel: TextChannel, + timezone: str, + language: str, ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) @@ -41,9 +72,13 @@ class Config(Cog): await ctx.respond(f"Timezone {timezone} was not found.") return - await guild.set_channel(channel.id, cache=self.bot.cache) - await guild.set_category(category.id, cache=self.bot.cache) - await guild.set_timezone(str(timezone_parsed), cache=self.bot.cache) + await guild.update( + self.bot.cache, + channel_id=channel.id, + category_id=category.id, + timezone=str(timezone_parsed), + language=language, + ) # TODO Make a nice message await ctx.respond("Okay.") @@ -62,9 +97,7 @@ class Config(Cog): guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) - await guild.reset_channel(cache=self.bot.cache) - await guild.reset_category(cache=self.bot.cache) - await guild.reset_timezone(cache=self.bot.cache) + await guild.purge(self.bot.cache) # TODO Make a nice message await ctx.respond("Okay.") diff --git a/cogs/event.py b/cogs/event.py index 076b722..3c1442e 100644 --- a/cogs/event.py +++ b/cogs/event.py @@ -3,7 +3,14 @@ from typing import Dict, Any, List from zoneinfo import ZoneInfo from bson import ObjectId -from discord import ApplicationContext, Attachment, SlashCommandGroup, option, AutocompleteContext +from discord import ( + ApplicationContext, + Attachment, + SlashCommandGroup, + option, + AutocompleteContext, + OptionChoice, +) from discord.ext.commands import Cog from discord.utils import basic_autocomplete @@ -12,38 +19,51 @@ from classes.pycord_bot import PycordBot from modules.database import col_events -async def get_event(ctx: AutocompleteContext): - query: Dict[str, Any] = {"ended": None, "ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}} +# TODO Move to staticmethod or to a separate module +async def get_event(ctx: AutocompleteContext) -> List[OptionChoice]: + query: Dict[str, Any] = { + "ended": None, + "ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, + "cancelled": {"$ne": True}, + } - event_names: List[str] = [] + event_names: List[OptionChoice] = [] async for result in col_events.find(query): - event_names.append(result["name"]) + event_names.append(OptionChoice(result["name"], str(result["_id"]))) return event_names async def validate_event_validity( - ctx: ApplicationContext, - name: str, - start_date: datetime | None, - finish_date: datetime | None, - guild_timezone: ZoneInfo, - event_id: ObjectId | None = None, + ctx: ApplicationContext, + name: str, + start_date: datetime | None, + finish_date: datetime | None, + guild_timezone: ZoneInfo, + event_id: ObjectId | None = None, ) -> None: if start_date > finish_date: + # TODO Make a nice message await ctx.respond("Start date must be before finish date") return elif start_date < datetime.now(tz=guild_timezone): + # TODO Make a nice message await ctx.respond("Start date must not be in the past") return - query: Dict[str, Any] = {"name": name, "ended": None, "ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}} + query: Dict[str, Any] = { + "name": name, + "ended": None, + "ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, + "cancelled": {"$ne": True}, + } if event_id is not None: query["_id"] = {"$ne": event_id} if (await col_events.find_one(query)) is not None: + # TODO Make a nice message await ctx.respond("There can only be one active event with the same name") return @@ -57,26 +77,27 @@ class Event(Cog): # TODO Introduce i18n command_group: SlashCommandGroup = SlashCommandGroup("event", "Event management") - # TODO Implement the command + # TODO Introduce i18n @command_group.command( name="create", description="Create new event", ) @option("name", description="Name of the event", required=True) @option("start", description="Date when the event starts (DD.MM.YYYY HH:MM)", required=True) - @option("finish", description="Date when the event finishes (DD.MM.YYYY HH:MM)", required=True) + @option("end", description="Date when the event ends (DD.MM.YYYY HH:MM)", required=True) @option("thumbnail", description="Thumbnail of the event", required=False) async def command_event_create( - self, - ctx: ApplicationContext, - name: str, - start: str, - finish: str, - thumbnail: Attachment = None, + self, + ctx: ApplicationContext, + name: str, + start: str, + end: str, + thumbnail: Attachment = None, ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) if not guild.is_configured(): + # TODO Make a nice message await ctx.respond("Guild is not configured.") return @@ -84,28 +105,30 @@ class Event(Cog): try: start_date: datetime = datetime.strptime(start, "%d.%m.%Y %H:%M") - finish_date: datetime = datetime.strptime(finish, "%d.%m.%Y %H:%M") + end_date: datetime = datetime.strptime(end, "%d.%m.%Y %H:%M") start_date = start_date.replace(tzinfo=guild_timezone) - finish_date = finish_date.replace(tzinfo=guild_timezone) + end_date = end_date.replace(tzinfo=guild_timezone) except ValueError: - await ctx.respond("Could not parse start and finish dates.") + # TODO Make a nice message + await ctx.respond("Could not parse start and end dates.") return - await validate_event_validity(ctx, name, start_date, finish_date, guild_timezone) + await validate_event_validity(ctx, name, start_date, end_date, guild_timezone) event: PycordEvent = await self.bot.create_event( name=name, guild_id=guild.id, creator_id=ctx.author.id, starts=start_date.astimezone(ZoneInfo("UTC")), - ends=finish_date.astimezone(ZoneInfo("UTC")), + ends=end_date.astimezone(ZoneInfo("UTC")), thumbnail_id=thumbnail.id if thumbnail else None, ) + # TODO Make a nice message await ctx.respond("Event has been created.") - # TODO Implement the command + # TODO Introduce i18n @command_group.command( name="edit", description="Edit event", @@ -115,21 +138,23 @@ class Event(Cog): ) @option("name", description="New name of the event", required=False) @option("start", description="Date when the event starts (DD.MM.YYYY HH:MM)", required=False) - @option("finish", description="Date when the event finishes (DD.MM.YYYY HH:MM)", required=False) + @option("end", description="Date when the event ends (DD.MM.YYYY HH:MM)", required=False) @option("thumbnail", description="Thumbnail of the event", required=False) async def command_event_edit( - self, - ctx: ApplicationContext, - event: str, - name: str = None, - start: str = None, - finish: str = None, - thumbnail: Attachment = None, + self, + ctx: ApplicationContext, + event: str, + name: str = None, + start: str = None, + end: str = None, + thumbnail: Attachment = None, ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) - event: PycordEvent = await self.bot.find_event( - event_name=name - ) + pycord_event: PycordEvent = await self.bot.find_event(event_id=event) + + if pycord_event is None: + await ctx.respond("Event was not found.") + return if not guild.is_configured(): await ctx.respond("Guild is not configured.") @@ -137,42 +162,75 @@ class Event(Cog): guild_timezone: ZoneInfo = ZoneInfo(guild.timezone) - if start is not None: - try: - start_date: datetime = datetime.strptime(start, "%d.%m.%Y %H:%M") - start_date = start_date.replace(tzinfo=guild_timezone) + try: + start_date: datetime = ( + pycord_event.starts if start is None else datetime.strptime(start, "%d.%m.%Y %H:%M") + ) + start_date = start_date.replace(tzinfo=guild_timezone) + except ValueError: + await ctx.respond("Could not parse the start date.") + return - await event.set_start_date(start_date) - except ValueError: - await ctx.respond("Could not parse the start date.") - return + try: + end_date: datetime = ( + pycord_event.ends if end is None else datetime.strptime(end, "%d.%m.%Y %H:%M") + ) + end_date = end_date.replace(tzinfo=guild_timezone) + except ValueError: + await ctx.respond("Could not parse the end date.") + return - if finish is not None: - try: - finish_date: datetime = datetime.strptime(finish, "%d.%m.%Y %H:%M") - finish_date = finish_date.replace(tzinfo=guild_timezone) + await validate_event_validity(ctx, name, start_date, end_date, guild_timezone) - await event.set_end_date(finish_date) - except ValueError: - await ctx.respond("Could not parse the finish date.") - return + await pycord_event.update( + self.bot.cache, + starts=start_date, + ends=end_date, + name=pycord_event.name if name is None else name, + thumbnail_id=pycord_event.thumbnail_id if thumbnail is None else thumbnail.id, + ) - await validate_event_validity(ctx, name, start_date, finish_date, guild_timezone) + await ctx.respond("Event has been updated.") # TODO Implement the command @command_group.command( name="cancel", description="Cancel event", ) - @option("name", description="Name of the event", required=True) + @option( + "event", description="Name of the event", autocomplete=basic_autocomplete(get_event), required=True + ) async def command_event_cancel( - self, - ctx: ApplicationContext, - name: str, + self, + ctx: ApplicationContext, + event: str, ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) + pycord_event: PycordEvent = await self.bot.find_event(event_id=event) - await ctx.respond("Not implemented.") + if pycord_event is None: + await ctx.respond("Event was not found.") + return + + if not guild.is_configured(): + await ctx.respond("Guild is not configured.") + return + + start_date: datetime = pycord_event.starts.replace(tzinfo=ZoneInfo("UTC")) + end_date: datetime = pycord_event.ends.replace(tzinfo=ZoneInfo("UTC")) + + # TODO Make ongoing events cancellable + if ( + pycord_event.ended is not None + or end_date <= datetime.now(tz=ZoneInfo("UTC")) + or start_date <= datetime.now(tz=ZoneInfo("UTC")) + ): + await ctx.respond("Finished or ongoing events cannot be cancelled.") + return + + await pycord_event.cancel() + + await ctx.respond("Event was cancelled.") def setup(bot: PycordBot) -> None: