diff --git a/classes/pycord_bot.py b/classes/pycord_bot.py index c32f237..5d6775c 100644 --- a/classes/pycord_bot.py +++ b/classes/pycord_bot.py @@ -6,7 +6,7 @@ from libbot.cache.classes import CacheMemcached, CacheRedis from libbot.cache.manager import create_cache_client from libbot.pycord.classes import PycordBot as LibPycordBot -from classes import PycordGuild, PycordUser +from classes import PycordEvent, PycordGuild, PycordUser from modules.logging_utils import get_logger logger: Logger = get_logger(__name__) @@ -78,6 +78,10 @@ class PycordBot(LibPycordBot): else await PycordGuild.from_id(guild.id, cache=self.cache) ) + # TODO Document this method + async def create_event(self, **kwargs) -> PycordEvent: + return await PycordEvent.create(**kwargs, cache=self.cache) + async def start(self, *args: Any, **kwargs: Any) -> None: await super().start(*args, **kwargs) diff --git a/classes/pycord_event.py b/classes/pycord_event.py index 3ab7eb5..0f45922 100644 --- a/classes/pycord_event.py +++ b/classes/pycord_event.py @@ -1,17 +1,216 @@ from dataclasses import dataclass -from datetime import datetime -from typing import List +from datetime import datetime, timezone +from logging import Logger +from typing import Any, Dict, List, Optional from bson import ObjectId +from libbot.cache.classes import Cache +from pymongo.results import InsertOneResult + +from modules.database import col_events +from modules.logging_utils import get_logger + +logger: Logger = get_logger(__name__) @dataclass class PycordEvent: + __slots__ = ( + "_id", + "name", + "guild_id", + "created", + "creator_id", + "starts", + "ends", + "thumbnail_id", + "stage_ids", + ) + __short_name__ = "event" + __collection__ = col_events + _id: ObjectId - id: int + name: str guild_id: int created: datetime creator_id: int starts: datetime ends: datetime + thumbnail_id: str | None stage_ids: List[int] + + @classmethod + async def from_id(cls, event_id: str | ObjectId, cache: Optional[Cache] = None) -> "PycordEvent": + """Find event in the database. + + Args: + event_id (str | ObjectId): Event's ID + cache (:obj:`Cache`, optional): Cache engine to get the cache from + + Returns: + PycordEvent: Event object + + Raises: + EventNotFoundError: Event was not found + """ + if cache is not None: + cached_entry: Dict[str, Any] | None = cache.get_json(f"{cls.__short_name__}_{event_id}") + + if cached_entry is not None: + return cls(**cached_entry) + + db_entry = await cls.__collection__.find_one( + {"_id": event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)} + ) + + if db_entry is None: + raise RuntimeError(f"Event {event_id} not found") + + # TODO Add a unique exception + # raise EventNotFoundError(event_id) + + if cache is not None: + cache.set_json(f"{cls.__short_name__}_{event_id}", db_entry) + + return cls(**db_entry) + + # TODO Implement this method + @classmethod + async def create( + cls, + name: str, + guild_id: int, + creator_id: int, + starts: datetime, + ends: datetime, + thumbnail_id: str | None, + cache: Optional[Cache] = None, + ) -> "PycordEvent": + db_entry: Dict[str, Any] = { + "name": name, + "guild_id": guild_id, + "created": datetime.now(tz=timezone.utc), + "creator_id": creator_id, + "starts": starts, + "ends": ends, + "thumbnail_id": thumbnail_id, + "stage_ids": [], + } + + insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry) + + db_entry["_id"] = insert_result.inserted_id + + if cache is not None: + cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry) + + return cls(**db_entry) + + async def _set(self, key: str, value: Any, cache: Optional[Cache] = None) -> None: + """Set attribute data and save it into the database. + + Args: + key (str): Attribute to change + value (Any): Value to set + cache (:obj:`Cache`, optional): Cache engine to write the update into + """ + if not hasattr(self, key): + raise AttributeError() + + setattr(self, key, value) + + await self.__collection__.update_one({"_id": self._id}, {"$set": {key: value}}, upsert=True) + + self._update_cache(cache) + + logger.info("Set attribute '%s' of event %s to '%s'", key, self._id, value) + + async def _remove(self, key: str, cache: Optional[Cache] = None) -> None: + """Remove attribute data and save it into the database. + + Args: + key (str): Attribute to remove + cache (:obj:`Cache`, optional): Cache engine to write the update into + """ + if not hasattr(self, key): + raise AttributeError() + + default_value: Any = PycordEvent.get_default_value(key) + + setattr(self, key, default_value) + + await self.__collection__.update_one({"_id": self._id}, {"$set": {key: default_value}}, upsert=True) + + self._update_cache(cache) + + logger.info("Removed attribute '%s' of event %s", key, self._id) + + def _get_cache_key(self) -> str: + return f"{self.__short_name__}_{self._id}" + + def _update_cache(self, cache: Optional[Cache] = None) -> None: + if cache is None: + return + + user_dict: Dict[str, Any] = self.to_dict() + + if user_dict is not None: + cache.set_json(self._get_cache_key(), user_dict) + else: + self._delete_cache(cache) + + def _delete_cache(self, cache: Optional[Cache] = None) -> None: + if cache is None: + return + + cache.delete(self._get_cache_key()) + + def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]: + """Convert PycordEvent object to a JSON representation. + + Args: + json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted + + Returns: + Dict[str, Any]: JSON representation of PycordEvent + """ + return { + "_id": self._id if not json_compatible else str(self._id), + "name": self.name, + "guild_id": self.guild_id, + "created": self.created, + "creator_id": self.creator_id, + "starts": self.starts, + "ends": self.ends, + "thumbnail_id": self.thumbnail_id, + "stage_ids": self.stage_ids, + } + + @staticmethod + def get_defaults() -> Dict[str, Any]: + return { + "name": None, + "guild_id": None, + "created": None, + "creator_id": None, + "starts": None, + "ends": None, + "thumbnail_id": None, + "stage_ids": [], + } + + @staticmethod + def get_default_value(key: str) -> Any: + if key not in PycordEvent.get_defaults(): + raise KeyError(f"There's no default value for key '{key}' in PycordEvent") + + return PycordEvent.get_defaults()[key] + + async def purge(self, cache: Optional[Cache] = None) -> None: + """Completely remove event data from database. Currently only removes the event record from events collection. + + Args: + cache (:obj:`Cache`, optional): Cache engine to write the update into + """ + await self.__collection__.delete_one({"_id": self._id}) + self._delete_cache(cache) diff --git a/classes/pycord_guild.py b/classes/pycord_guild.py index 68f8db9..4b6fea3 100644 --- a/classes/pycord_guild.py +++ b/classes/pycord_guild.py @@ -17,7 +17,7 @@ logger: Logger = get_logger(__name__) class PycordGuild: """Dataclass of DB entry of a guild""" - __slots__ = ("_id", "id", "channel_id", "category_id") + __slots__ = ("_id", "id", "channel_id", "category_id", "timezone") __short_name__ = "guild" __collection__ = col_guilds @@ -25,6 +25,7 @@ class PycordGuild: id: int channel_id: Optional[int] category_id: Optional[int] + timezone: str @classmethod async def from_id( @@ -83,7 +84,7 @@ class PycordGuild: self._update_cache(cache) - logger.info("Set attribute '%s' of user %s to '%s'", key, self.id, value) + logger.info("Set attribute '%s' of guild %s to '%s'", key, self.id, value) async def _remove(self, key: str, cache: Optional[Cache] = None) -> None: """Remove attribute data and save it into the database. @@ -139,11 +140,12 @@ class PycordGuild: "id": self.id, "channel_id": self.channel_id, "category_id": self.category_id, + "timezone": self.timezone, } @staticmethod def get_defaults(guild_id: Optional[int] = None) -> Dict[str, Any]: - return {"id": guild_id, "channel_id": None, "category_id": None} + return {"id": guild_id, "channel_id": None, "category_id": None, "timezone": "UTC"} @staticmethod def get_default_value(key: str) -> Any: @@ -162,17 +164,34 @@ class PycordGuild: self._delete_cache(cache) # TODO Add documentation - async def set_channel(self, channel_id: Optional[int] = None, cache: Optional[Cache] = None) -> None: - await self._set("channel_id", channel_id, cache) + def is_configured(self) -> bool: + return ( + (self.id is not None) + and (self.channel_id is not None) + and (self.category_id is not None) + and (self.timezone is not None) + ) # TODO Add documentation - async def set_category(self, category_id: Optional[int] = None, cache: Optional[Cache] = None) -> None: - await self._set("category_id", category_id, cache) + async def set_channel(self, channel_id: Optional[int] = None, cache: Optional[Cache] = None) -> None: + await self._set("channel_id", channel_id, cache) # TODO Add documentation async def reset_channel(self, cache: Optional[Cache] = None) -> None: await self._remove("channel_id", cache) + # TODO Add documentation + async def set_category(self, category_id: Optional[int] = None, cache: Optional[Cache] = None) -> None: + await self._set("category_id", category_id, cache) + # TODO Add documentation async def reset_category(self, cache: Optional[Cache] = None) -> None: await self._remove("category_id", cache) + + # TODO Add documentation + async def set_timezone(self, timezone: str, cache: Optional[Cache] = None) -> None: + await self._set("timezone", timezone, cache) + + # TODO Add documentation + async def reset_timezone(self, cache: Optional[Cache] = None) -> None: + await self._remove("timezone", cache) diff --git a/cogs/config.py b/cogs/config.py index 303f711..4a72c80 100644 --- a/cogs/config.py +++ b/cogs/config.py @@ -1,3 +1,5 @@ +from zoneinfo import ZoneInfo, ZoneInfoNotFoundError + from discord import ( ApplicationContext, CategoryChannel, @@ -27,13 +29,21 @@ class Config(Cog): ) @option("category", description="Category where channels for each user will be created", required=True) @option("channel", description="Text channel for admin notifications", required=True) + @option("timezone", description="Timezone in which events take place", required=True) async def command_config_set( - self, ctx: ApplicationContext, category: CategoryChannel, channel: TextChannel + self, ctx: ApplicationContext, category: CategoryChannel, channel: TextChannel, timezone: str ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) + try: + timezone_parsed: ZoneInfo = ZoneInfo(timezone) + except ZoneInfoNotFoundError: + await ctx.respond(f"Timezone {timezone} was not found.") + return + await guild.set_channel(channel.id, cache=self.bot.cache) await guild.set_category(category.id, cache=self.bot.cache) + await guild.set_timezone(str(timezone_parsed), cache=self.bot.cache) # TODO Make a nice message await ctx.respond("Okay.") @@ -54,6 +64,7 @@ class Config(Cog): await guild.reset_channel(cache=self.bot.cache) await guild.reset_category(cache=self.bot.cache) + await guild.reset_timezone(cache=self.bot.cache) # TODO Make a nice message await ctx.respond("Okay.") diff --git a/cogs/event.py b/cogs/event.py index 2b1a583..f19fd8c 100644 --- a/cogs/event.py +++ b/cogs/event.py @@ -1,7 +1,10 @@ +from datetime import datetime +from zoneinfo import ZoneInfo + from discord import ApplicationContext, Attachment, SlashCommandGroup, option from discord.ext.commands import Cog -from classes import PycordGuild +from classes import PycordEvent, PycordGuild from classes.pycord_bot import PycordBot @@ -20,8 +23,8 @@ class Event(Cog): description="Create new event", ) @option("name", description="Name of the event", required=True) - @option("start", description="Date when the event starts (DD.MM.YYYY)", required=True) - @option("finish", description="Date when the event finishes (DD.MM.YYYY)", required=True) + @option("start", description="Date when the event starts (DD.MM.YYYY HH:MM)", required=True) + @option("finish", description="Date when the event finishes (DD.MM.YYYY HH:MM)", required=True) @option("thumbnail", description="Thumbnail of the event", required=False) async def command_event_create( self, @@ -33,7 +36,39 @@ class Event(Cog): ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) - await ctx.respond("Not implemented.") + if not guild.is_configured(): + await ctx.respond("Guild is not configured.") + return + + guild_timezone: ZoneInfo = ZoneInfo(guild.timezone) + + try: + start_date: datetime = datetime.strptime(start, "%d.%m.%Y %H:%M") + finish_date: datetime = datetime.strptime(finish, "%d.%m.%Y %H:%M") + + start_date = start_date.replace(tzinfo=guild_timezone) + finish_date = finish_date.replace(tzinfo=guild_timezone) + except ValueError: + await ctx.respond("Could not parse start and finish dates.") + return + + if start_date > finish_date: + await ctx.respond("Start date must be before finish date") + return + elif start_date < datetime.now(tz=guild_timezone): + await ctx.respond("Start date must not be in the past") + return + + event: PycordEvent = await self.bot.create_event( + name=name, + guild_id=guild.id, + creator_id=ctx.author.id, + starts=start_date.astimezone(ZoneInfo("UTC")), + ends=finish_date.astimezone(ZoneInfo("UTC")), + thumbnail_id=thumbnail.id if thumbnail else None, + ) + + await ctx.respond("Event has been created.") # TODO Implement the command @command_group.command( @@ -42,8 +77,8 @@ class Event(Cog): ) @option("event", description="Name of the event", required=True) @option("name", description="New name of the event", required=False) - @option("start", description="Date when the event starts (DD.MM.YYYY)", required=False) - @option("finish", description="Date when the event finishes (DD.MM.YYYY)", required=False) + @option("start", description="Date when the event starts (DD.MM.YYYY HH:MM)", required=False) + @option("finish", description="Date when the event finishes (DD.MM.YYYY HH:MM)", required=False) @option("thumbnail", description="Thumbnail of the event", required=False) async def command_event_edit( self, @@ -56,6 +91,12 @@ class Event(Cog): ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) + if not guild.is_configured(): + await ctx.respond("Guild is not configured.") + return + + guild_timezone: ZoneInfo = ZoneInfo(guild.timezone) + await ctx.respond("Not implemented.") # TODO Implement the command diff --git a/modules/database.py b/modules/database.py index 3e41929..8d1a360 100644 --- a/modules/database.py +++ b/modules/database.py @@ -28,9 +28,9 @@ col_events: AsyncCollection = db.get_collection("events") col_stages: AsyncCollection = db.get_collection("stages") # Update indexes -db.dispatch.get_collection("users").create_index("id", unique=True) -db.dispatch.get_collection("guilds").create_index("id", unique=True) -db.dispatch.get_collection("events").create_index("id", unique=True) -db.dispatch.get_collection("events").create_index("guild_id", unique=False) -db.dispatch.get_collection("stages").create_index("id", unique=True) -db.dispatch.get_collection("stages").create_index(["event_id", "guild_id"], unique=False) +db.dispatch.get_collection("users").create_index("id", name="user_id", unique=True) +db.dispatch.get_collection("guilds").create_index("id", name="guild_id", unique=True) +db.dispatch.get_collection("events").create_index("guild_id", name="guild_id", unique=False) +db.dispatch.get_collection("stages").create_index( + ["event_id", "guild_id"], name="event_id-and-guild_id", unique=False +)