diff --git a/classes/abstract/__init__.py b/classes/abstract/__init__.py new file mode 100644 index 0000000..f33fd40 --- /dev/null +++ b/classes/abstract/__init__.py @@ -0,0 +1 @@ +from .cacheable import Cacheable diff --git a/classes/abstract/cacheable.py b/classes/abstract/cacheable.py new file mode 100644 index 0000000..936b611 --- /dev/null +++ b/classes/abstract/cacheable.py @@ -0,0 +1,81 @@ +from abc import ABC, abstractmethod +from typing import Any, ClassVar, Dict, Optional + +from libbot.cache.classes import Cache +from pymongo.asynchronous.collection import AsyncCollection + + +class Cacheable(ABC): + """Abstract class for cacheable""" + + __short_name__: str + __collection__: ClassVar[AsyncCollection] + + @classmethod + @abstractmethod + async def from_id(cls, *args: Any, cache: Optional[Cache] = None, **kwargs: Any) -> Any: + pass + + @abstractmethod + async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None: + pass + + @abstractmethod + async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None: + pass + + @abstractmethod + def _get_cache_key(self) -> str: + pass + + @abstractmethod + def _update_cache(self, cache: Optional[Cache] = None) -> None: + pass + + @abstractmethod + def _delete_cache(self, cache: Optional[Cache] = None) -> None: + pass + + @staticmethod + @abstractmethod + def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]: + pass + + @staticmethod + @abstractmethod + def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]: + pass + + @abstractmethod + def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]: + pass + + @staticmethod + @abstractmethod + def get_defaults(**kwargs: Any) -> Dict[str, Any]: + pass + + @staticmethod + @abstractmethod + def get_default_value(key: str) -> Any: + pass + + @abstractmethod + async def update( + self, + cache: Optional[Cache] = None, + **kwargs: Any, + ) -> None: + pass + + @abstractmethod + async def reset( + self, + *args: str, + cache: Optional[Cache] = None, + ) -> None: + pass + + @abstractmethod + async def purge(self, cache: Optional[Cache] = None) -> None: + pass diff --git a/classes/pycord_bot.py b/classes/pycord_bot.py index 86da158..f6e64a2 100644 --- a/classes/pycord_bot.py +++ b/classes/pycord_bot.py @@ -14,13 +14,13 @@ from typing_extensions import override from classes import PycordEvent, PycordEventStage, PycordGuild, PycordUser from classes.errors import ( + DiscordGuildMemberNotFoundError, EventNotFoundError, EventStageMissingSequenceError, EventStageNotFoundError, GuildNotFoundError, - DiscordGuildMemberNotFoundError, ) -from modules.database import col_events, col_users +from modules.database import col_events, col_users, _update_database_indexes from modules.utils import get_logger logger: Logger = get_logger(__name__) @@ -58,6 +58,7 @@ class PycordBot(LibPycordBot): @override async def start(self, *args: Any, **kwargs: Any) -> None: await self._schedule_tasks() + await _update_database_indexes() self.started = datetime.now(tz=ZoneInfo("UTC")) diff --git a/classes/pycord_event.py b/classes/pycord_event.py index b272a4c..a0cd3b9 100644 --- a/classes/pycord_event.py +++ b/classes/pycord_event.py @@ -1,7 +1,7 @@ """Module with class PycordEvent.""" from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import datetime, tzinfo from logging import Logger from typing import Any, Dict, List, Optional from zoneinfo import ZoneInfo @@ -11,6 +11,7 @@ from libbot.cache.classes import Cache from pymongo import DESCENDING from pymongo.results import InsertOneResult +from classes.abstract import Cacheable from classes.errors import EventNotFoundError from modules.database import col_events from modules.utils import get_logger, restore_from_cache @@ -19,7 +20,7 @@ logger: Logger = get_logger(__name__) @dataclass -class PycordEvent: +class PycordEvent(Cacheable): """Object representation of an event in the database. Attributes: @@ -82,7 +83,7 @@ class PycordEvent: cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache) if cached_entry is not None: - return cls(**cached_entry) + return cls(**cls._entry_from_cache(cached_entry)) db_entry = await cls.__collection__.find_one( {"_id": event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)} @@ -92,7 +93,7 @@ class PycordEvent: raise EventNotFoundError(event_id=event_id) if cache is not None: - cache.set_json(f"{cls.__short_name__}_{event_id}", db_entry) + cache.set_json(f"{cls.__short_name__}_{event_id}", cls._entry_to_cache(dict(db_entry))) return cls(**db_entry) @@ -123,7 +124,7 @@ class PycordEvent: raise EventNotFoundError(event_name=event_name, guild_id=guild_id) if cache is not None: - cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry) + cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", cls._entry_to_cache(db_entry)) return cls(**db_entry) @@ -172,7 +173,7 @@ class PycordEvent: db_entry["_id"] = insert_result.inserted_id if cache is not None: - cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry) + cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry)) return cls(**db_entry) @@ -215,10 +216,10 @@ class PycordEvent: if cache is None: return - user_dict: Dict[str, Any] = self.to_dict() + object_dict: Dict[str, Any] = self.to_dict(json_compatible=True) - if user_dict is not None: - cache.set_json(self._get_cache_key(), user_dict) + if object_dict is not None: + cache.set_json(self._get_cache_key(), object_dict) else: self._delete_cache(cache) @@ -253,6 +254,32 @@ class PycordEvent: if stage_index != old_stage_index: await (await bot.find_event_stage(event_stage_id)).update(cache, sequence=stage_index) + @staticmethod + def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]: + cache_entry: Dict[str, Any] = db_entry.copy() + + cache_entry["_id"] = str(cache_entry["_id"]) + cache_entry["created"] = cache_entry["created"].isoformat() + cache_entry["ended"] = None if cache_entry["ended"] is None else cache_entry["ended"].isoformat() + cache_entry["starts"] = cache_entry["starts"].isoformat() + cache_entry["ends"] = cache_entry["ends"].isoformat() + cache_entry["stage_ids"] = [str(stage_id) for stage_id in cache_entry["stage_ids"]] + + return cache_entry + + @staticmethod + def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]: + db_entry: Dict[str, Any] = cache_entry.copy() + + db_entry["_id"] = ObjectId(db_entry["_id"]) + db_entry["created"] = datetime.fromisoformat(db_entry["created"]) + db_entry["ended"] = None if db_entry["ended"] is None else datetime.fromisoformat(db_entry["ended"]) + db_entry["starts"] = datetime.fromisoformat(db_entry["starts"]) + db_entry["ends"] = datetime.fromisoformat(db_entry["ends"]) + db_entry["stage_ids"] = [ObjectId(stage_id) for stage_id in db_entry["stage_ids"]] + + return db_entry + def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]: """Convert the object to a JSON representation. @@ -266,14 +293,20 @@ class PycordEvent: "_id": self._id if not json_compatible else str(self._id), "name": self.name, "guild_id": self.guild_id, - "created": self.created, - "ended": self.ended, + "created": self.created if not json_compatible else self.created.isoformat(), + "ended": ( + self.ended + if not json_compatible + else (None if self.ended is None else self.ended.isoformat()) + ), "is_cancelled": self.is_cancelled, "creator_id": self.creator_id, - "starts": self.starts, - "ends": self.ends, + "starts": self.starts if not json_compatible else self.starts.isoformat(), + "ends": self.ends if not json_compatible else self.ends.isoformat(), "thumbnail": self.thumbnail, - "stage_ids": self.stage_ids, + "stage_ids": ( + self.stage_ids if not json_compatible else [str(stage_id) for stage_id in self.stage_ids] + ), } @staticmethod @@ -456,7 +489,7 @@ class PycordEvent: return self.ends.replace(tzinfo=ZoneInfo("UTC")) - def get_start_date_localized(self, tz: str | timezone | ZoneInfo) -> datetime: + def get_start_date_localized(self, tz: tzinfo) -> datetime: """Get the event start date in the provided timezone. Returns: @@ -470,7 +503,7 @@ class PycordEvent: return self.starts.replace(tzinfo=tz) - def get_end_date_localized(self, tz: str | timezone | ZoneInfo) -> datetime: + def get_end_date_localized(self, tz: tzinfo) -> datetime: """Get the event end date in the provided timezone. Returns: diff --git a/classes/pycord_event_stage.py b/classes/pycord_event_stage.py index 95ee2a9..33f8154 100644 --- a/classes/pycord_event_stage.py +++ b/classes/pycord_event_stage.py @@ -10,6 +10,7 @@ from discord import File from libbot.cache.classes import Cache from pymongo.results import InsertOneResult +from classes.abstract import Cacheable from classes.errors import EventStageNotFoundError from modules.database import col_stages from modules.utils import get_logger, restore_from_cache @@ -18,7 +19,7 @@ logger: Logger = get_logger(__name__) @dataclass -class PycordEventStage: +class PycordEventStage(Cacheable): __slots__ = ( "_id", "event_id", @@ -61,7 +62,7 @@ class PycordEventStage: cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, stage_id, cache=cache) if cached_entry is not None: - return cls(**cached_entry) + return cls(**cls._entry_from_cache(cached_entry)) db_entry = await cls.__collection__.find_one( {"_id": stage_id if isinstance(stage_id, ObjectId) else ObjectId(stage_id)} @@ -71,7 +72,7 @@ class PycordEventStage: raise EventStageNotFoundError(stage_id) if cache is not None: - cache.set_json(f"{cls.__short_name__}_{stage_id}", db_entry) + cache.set_json(f"{cls.__short_name__}_{stage_id}", cls._entry_to_cache(dict(db_entry))) return cls(**db_entry) @@ -104,7 +105,7 @@ class PycordEventStage: db_entry["_id"] = insert_result.inserted_id if cache is not None: - cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry) + cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry)) return cls(**db_entry) @@ -159,10 +160,10 @@ class PycordEventStage: if cache is None: return - user_dict: Dict[str, Any] = self.to_dict() + object_dict: Dict[str, Any] = self.to_dict(json_compatible=True) - if user_dict is not None: - cache.set_json(self._get_cache_key(), user_dict) + if object_dict is not None: + cache.set_json(self._get_cache_key(), object_dict) else: self._delete_cache(cache) @@ -172,6 +173,26 @@ class PycordEventStage: cache.delete(self._get_cache_key()) + @staticmethod + def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]: + cache_entry: Dict[str, Any] = db_entry.copy() + + cache_entry["_id"] = str(cache_entry["_id"]) + cache_entry["event_id"] = str(cache_entry["event_id"]) + cache_entry["created"] = cache_entry["created"].isoformat() + + return cache_entry + + @staticmethod + def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]: + db_entry: Dict[str, Any] = cache_entry.copy() + + db_entry["_id"] = ObjectId(db_entry["_id"]) + db_entry["event_id"] = ObjectId(db_entry["event_id"]) + db_entry["created"] = datetime.fromisoformat(db_entry["created"]) + + return db_entry + def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]: """Convert the object to a JSON representation. @@ -186,7 +207,7 @@ class PycordEventStage: "event_id": self.event_id if not json_compatible else str(self.event_id), "guild_id": self.guild_id, "sequence": self.sequence, - "created": self.created, + "created": self.created if not json_compatible else self.created.isoformat(), "creator_id": self.creator_id, "question": self.question, "answer": self.answer, diff --git a/classes/pycord_guild.py b/classes/pycord_guild.py index bfcfb3a..c7c274b 100644 --- a/classes/pycord_guild.py +++ b/classes/pycord_guild.py @@ -6,6 +6,7 @@ from bson import ObjectId from libbot.cache.classes import Cache from pymongo.results import InsertOneResult +from classes.abstract import Cacheable from classes.errors import GuildNotFoundError from modules.database import col_guilds from modules.utils import get_logger, restore_from_cache @@ -14,7 +15,7 @@ logger: Logger = get_logger(__name__) @dataclass -class PycordGuild: +class PycordGuild(Cacheable): """Dataclass of DB entry of a guild""" __slots__ = ( @@ -57,7 +58,7 @@ class PycordGuild: cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, guild_id, cache=cache) if cached_entry is not None: - return cls(**cached_entry) + return cls(**cls._entry_from_cache(cached_entry)) db_entry = await cls.__collection__.find_one({"id": guild_id}) @@ -72,7 +73,7 @@ class PycordGuild: db_entry["_id"] = insert_result.inserted_id if cache is not None: - cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry) + cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry)) return cls(**db_entry) @@ -115,10 +116,10 @@ class PycordGuild: if cache is None: return - user_dict: Dict[str, Any] = self.to_dict() + object_dict: Dict[str, Any] = self.to_dict(json_compatible=True) - if user_dict is not None: - cache.set_json(self._get_cache_key(), user_dict) + if object_dict is not None: + cache.set_json(self._get_cache_key(), object_dict) else: self._delete_cache(cache) @@ -128,6 +129,22 @@ class PycordGuild: cache.delete(self._get_cache_key()) + @staticmethod + def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]: + cache_entry: Dict[str, Any] = db_entry.copy() + + cache_entry["_id"] = str(cache_entry["_id"]) + + return cache_entry + + @staticmethod + def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]: + db_entry: Dict[str, Any] = cache_entry.copy() + + db_entry["_id"] = ObjectId(db_entry["_id"]) + + return db_entry + def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]: """Convert the object to a JSON representation. diff --git a/classes/pycord_user.py b/classes/pycord_user.py index 3473a5b..f6f8260 100644 --- a/classes/pycord_user.py +++ b/classes/pycord_user.py @@ -17,6 +17,7 @@ from discord.abc import GuildChannel from libbot.cache.classes import Cache from pymongo.results import InsertOneResult +from classes.abstract import Cacheable from classes.errors import ( DiscordCategoryNotFoundError, DiscordChannelNotFoundError, @@ -33,9 +34,17 @@ logger: Logger = get_logger(__name__) @dataclass -class PycordUser: +class PycordUser(Cacheable): """Dataclass of DB entry of a user""" + # TODO Implement this + async def update(self, cache: Optional[Cache] = None, **kwargs: Any) -> None: + pass + + # TODO Implement this + async def reset(self, *args: str, cache: Optional[Cache] = None) -> None: + pass + __slots__ = ( "_id", "id", @@ -83,7 +92,7 @@ class PycordUser: ) if cached_entry is not None: - return cls(**cached_entry) + return cls(**cls._entry_from_cache(cached_entry)) db_entry = await cls.__collection__.find_one({"id": user_id, "guild_id": guild_id}) @@ -98,7 +107,7 @@ class PycordUser: db_entry["_id"] = insert_result.inserted_id if cache is not None: - cache.set_json(f"{cls.__short_name__}_{user_id}_{guild_id}", db_entry) + cache.set_json(f"{cls.__short_name__}_{user_id}_{guild_id}", cls._entry_to_cache(db_entry)) return cls(**db_entry) @@ -186,10 +195,10 @@ class PycordUser: if cache is None: return - user_dict: Dict[str, Any] = self.to_dict() + object_dict: Dict[str, Any] = self.to_dict(json_compatible=True) - if user_dict is not None: - cache.set_json(self._get_cache_key(), user_dict) + if object_dict is not None: + cache.set_json(self._get_cache_key(), object_dict) else: self._delete_cache(cache) @@ -199,6 +208,46 @@ class PycordUser: cache.delete(self._get_cache_key()) + @staticmethod + def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]: + cache_entry: Dict[str, Any] = db_entry.copy() + + cache_entry["_id"] = str(cache_entry["_id"]) + cache_entry["current_event_id"] = ( + None if cache_entry["current_event_id"] is None else str(cache_entry["current_event_id"]) + ) + cache_entry["current_stage_id"] = ( + None if cache_entry["current_stage_id"] is None else str(cache_entry["current_stage_id"]) + ) + cache_entry["registered_event_ids"] = [ + str(event_id) for event_id in cache_entry["registered_event_ids"] + ] + cache_entry["completed_event_ids"] = [ + str(event_id) for event_id in cache_entry["completed_event_ids"] + ] + + return cache_entry + + @staticmethod + def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]: + db_entry: Dict[str, Any] = cache_entry.copy() + + db_entry["_id"] = ObjectId(db_entry["_id"]) + db_entry["current_event_id"] = ( + None if db_entry["current_event_id"] is None else ObjectId(db_entry["current_event_id"]) + ) + db_entry["current_stage_id"] = ( + None if db_entry["current_stage_id"] is None else ObjectId(db_entry["current_stage_id"]) + ) + db_entry["registered_event_ids"] = [ + ObjectId(event_id) for event_id in db_entry["registered_event_ids"] + ] + db_entry["completed_event_ids"] = [ + ObjectId(event_id) for event_id in db_entry["completed_event_ids"] + ] + + return db_entry + # TODO Add documentation @staticmethod def get_defaults(user_id: Optional[int] = None, guild_id: Optional[int] = None) -> Dict[str, Any]: diff --git a/cogs/cog_config.py b/cogs/cog_config.py index 69f69b6..91e49dc 100644 --- a/cogs/cog_config.py +++ b/cogs/cog_config.py @@ -88,14 +88,15 @@ class CogConfig(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return try: timezone_parsed: ZoneInfo = ZoneInfo(timezone) except ZoneInfoNotFoundError: await ctx.respond( - self.bot._("timezone_invalid", "messages", locale=ctx.locale).format(timezone=timezone) + self.bot._("timezone_invalid", "messages", locale=ctx.locale).format(timezone=timezone), + ephemeral=True, ) return @@ -130,7 +131,7 @@ class CogConfig(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return await guild.purge(self.bot.cache) @@ -146,11 +147,13 @@ class CogConfig(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return if not guild.is_configured(): - await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True + ) return await ctx.respond( diff --git a/cogs/cog_event.py b/cogs/cog_event.py index e8641ed..e6561dd 100644 --- a/cogs/cog_event.py +++ b/cogs/cog_event.py @@ -84,11 +84,13 @@ class CogEvent(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return if not guild.is_configured(): - await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True + ) return guild_timezone: ZoneInfo = ZoneInfo(guild.timezone) @@ -97,7 +99,9 @@ class CogEvent(Cog): start_date: datetime = datetime.strptime(start, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone) end_date: datetime = datetime.strptime(end, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone) except ValueError: - await ctx.respond(self.bot._("event_dates_parsing_failed", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("event_dates_parsing_failed", "messages", locale=ctx.locale), ephemeral=True + ) return if not await validate_event_validity(ctx, name, start_date, end_date, to_utc=True): @@ -180,7 +184,7 @@ class CogEvent(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return try: @@ -190,7 +194,9 @@ class CogEvent(Cog): return if not guild.is_configured(): - await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True + ) return guild_timezone: ZoneInfo = ZoneInfo(guild.timezone) @@ -202,7 +208,9 @@ class CogEvent(Cog): else datetime.strptime(start, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone) ) except ValueError: - await ctx.respond(self.bot._("event_start_date_parsing_failed", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("event_start_date_parsing_failed", "messages", locale=ctx.locale), ephemeral=True + ) return try: @@ -212,7 +220,9 @@ class CogEvent(Cog): else datetime.strptime(end, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone) ) except ValueError: - await ctx.respond(self.bot._("event_end_date_parsing_failed", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("event_end_date_parsing_failed", "messages", locale=ctx.locale), ephemeral=True + ) return if not await validate_event_validity( @@ -280,7 +290,7 @@ class CogEvent(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return try: @@ -290,7 +300,9 @@ class CogEvent(Cog): return if not guild.is_configured(): - await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True + ) return start_date: datetime = pycord_event.starts.replace(tzinfo=ZoneInfo("UTC")) @@ -305,7 +317,8 @@ class CogEvent(Cog): await ctx.respond( self.bot._("event_not_editable", "messages", locale=ctx.locale).format( event_name=pycord_event.name - ) + ), + ephemeral=True, ) return diff --git a/cogs/cog_guess.py b/cogs/cog_guess.py index 61b6785..661050a 100644 --- a/cogs/cog_guess.py +++ b/cogs/cog_guess.py @@ -34,11 +34,13 @@ class CogGuess(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return if not guild.is_configured(): - await ctx.respond(self.bot._("guild_unconfigured", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("guild_unconfigured", "messages", locale=ctx.locale), ephemeral=True + ) return user: PycordUser = await self.bot.find_user(ctx.author, ctx.guild) diff --git a/cogs/cog_register.py b/cogs/cog_register.py index 0f4b2f4..27810eb 100644 --- a/cogs/cog_register.py +++ b/cogs/cog_register.py @@ -4,7 +4,7 @@ from pathlib import Path from zoneinfo import ZoneInfo from bson.errors import InvalidId -from discord import ApplicationContext, Cog, TextChannel, option, slash_command, File +from discord import ApplicationContext, Cog, File, TextChannel, option, slash_command from discord.utils import basic_autocomplete from libbot.i18n import _, in_every_locale @@ -39,7 +39,7 @@ class CogRegister(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return try: @@ -49,17 +49,21 @@ class CogRegister(Cog): return if not guild.is_configured(): - await ctx.respond(self.bot._("guild_unconfigured", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("guild_unconfigured", "messages", locale=ctx.locale), ephemeral=True + ) return user: PycordUser = await self.bot.find_user(ctx.author, ctx.guild) if user.is_jailed: - await ctx.respond(self.bot._("jailed_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("jailed_error", "messages", locale=ctx.locale), ephemeral=True) return if pycord_event._id in user.registered_event_ids: - await ctx.respond(self.bot._("register_already_registered", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("register_already_registered", "messages", locale=ctx.locale), ephemeral=True + ) return await user.event_register(pycord_event._id, cache=self.bot.cache) diff --git a/cogs/cog_stage.py b/cogs/cog_stage.py index d219bda..3a28379 100644 --- a/cogs/cog_stage.py +++ b/cogs/cog_stage.py @@ -84,11 +84,13 @@ class CogStage(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return if not guild.is_configured(): - await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True + ) return try: @@ -196,11 +198,13 @@ class CogStage(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return if not guild.is_configured(): - await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True + ) return try: @@ -215,11 +219,13 @@ class CogStage(Cog): try: event_stage: PycordEventStage = await self.bot.find_event_stage(stage) except (InvalidId, EventStageNotFoundError): - await ctx.respond(self.bot._("stage_not_found", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("stage_not_found", "messages", locale=ctx.locale), ephemeral=True) return if order is not None and order > len(pycord_event.stage_ids): - await ctx.respond(self.bot._("stage_sequence_out_of_range", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("stage_sequence_out_of_range", "messages", locale=ctx.locale), ephemeral=True + ) return processed_media: List[Dict[str, Any]] = ( @@ -278,11 +284,13 @@ class CogStage(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return if not guild.is_configured(): - await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True + ) return try: @@ -297,7 +305,7 @@ class CogStage(Cog): try: event_stage: PycordEventStage = await self.bot.find_event_stage(stage) except (InvalidId, EventStageNotFoundError): - await ctx.respond(self.bot._("stage_not_found", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("stage_not_found", "messages", locale=ctx.locale), ephemeral=True) return await pycord_event.remove_stage(self.bot, event_stage._id, cache=self.bot.cache) diff --git a/cogs/cog_unregister.py b/cogs/cog_unregister.py index 4e18604..4b55468 100644 --- a/cogs/cog_unregister.py +++ b/cogs/cog_unregister.py @@ -43,7 +43,7 @@ class CogUnregister(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return try: @@ -53,17 +53,22 @@ class CogUnregister(Cog): return if not guild.is_configured(): - await ctx.respond(self.bot._("guild_unconfigured", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("guild_unconfigured", "messages", locale=ctx.locale), ephemeral=True + ) return user: PycordUser = await self.bot.find_user(ctx.author, ctx.guild) if user.is_jailed: - await ctx.respond(self.bot._("jailed_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("jailed_error", "messages", locale=ctx.locale), ephemeral=True) return + # TODO Fix a bug where registered_event_ids is invalid because of caching if pycord_event._id not in user.registered_event_ids: - await ctx.respond(self.bot._("unregister_not_registered", "messages", locale=ctx.locale)) + await ctx.respond( + self.bot._("unregister_not_registered", "messages", locale=ctx.locale), ephemeral=True + ) return await user.event_unregister(pycord_event._id, cache=self.bot.cache) diff --git a/cogs/cog_user.py b/cogs/cog_user.py index b50ace2..e337030 100644 --- a/cogs/cog_user.py +++ b/cogs/cog_user.py @@ -1,27 +1,27 @@ from datetime import datetime from logging import Logger from pathlib import Path -from typing import List, Dict, Any +from typing import Any, Dict, List from zoneinfo import ZoneInfo from bson import ObjectId from bson.errors import InvalidId from discord import ( ApplicationContext, + File, SlashCommandGroup, + TextChannel, User, option, - File, - TextChannel, ) from discord.ext.commands import Cog from libbot.i18n import _, in_every_locale -from classes import PycordUser, PycordEvent, PycordGuild +from classes import PycordEvent, PycordGuild, PycordUser from classes.errors import GuildNotFoundError from classes.pycord_bot import PycordBot from modules.database import col_users -from modules.utils import is_operation_confirmed, get_logger +from modules.utils import get_logger, is_operation_confirmed, get_utc_now logger: Logger = get_logger(__name__) @@ -54,35 +54,45 @@ class CogUser(Cog): try: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) except (InvalidId, GuildNotFoundError): - await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) + await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True) return pycord_user: PycordUser = await self.bot.find_user(user.id, ctx.guild.id) events: List[PycordEvent] = [] + utc_now: datetime = get_utc_now() + pipeline: List[Dict[str, Any]] = [ {"$match": {"id": pycord_user.id}}, { "$lookup": { "from": "events", - "localField": "registered_event_ids", - "foreignField": "_id", + "let": {"event_ids": "$registered_event_ids"}, + "pipeline": [ + { + "$match": { + "$expr": { + "$and": [ + {"$in": ["$_id", "$$event_ids"]}, + {"$eq": ["$ended", None]}, + {"$gt": ["$ends", utc_now]}, + {"$lt": ["$starts", utc_now]}, + {"$eq": ["$is_cancelled", False]}, + ] + } + } + } + ], "as": "registered_events", } }, - { - "$match": { - "registered_events.ended": None, - "registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, - "registered_events.starts": {"$lt": datetime.now(tz=ZoneInfo("UTC"))}, - "registered_events.is_cancelled": False, - } - }, + {"$match": {"registered_events.0": {"$exists": True}}}, ] - async for result in col_users.aggregate(pipeline): - for registered_event in result["registered_events"]: - events.append(PycordEvent(**registered_event)) + async with await col_users.aggregate(pipeline) as cursor: + async for result in cursor: + for registered_event in result["registered_events"]: + events.append(PycordEvent(**registered_event)) for event in events: if pycord_user.current_event_id is not None and pycord_user.current_event_id != event._id: diff --git a/cogs/cog_utility.py b/cogs/cog_utility.py index b8d63a7..67e32f5 100644 --- a/cogs/cog_utility.py +++ b/cogs/cog_utility.py @@ -6,13 +6,13 @@ from zoneinfo import ZoneInfo from bson import ObjectId from bson.errors import InvalidId -from discord import Activity, ActivityType, Cog, Member, TextChannel, File +from discord import Activity, ActivityType, Cog, File, Member, TextChannel from classes import PycordEvent, PycordGuild, PycordUser from classes.errors import GuildNotFoundError from classes.pycord_bot import PycordBot from modules.database import col_users -from modules.utils import get_logger +from modules.utils import get_logger, get_utc_now logger: Logger = get_logger(__name__) @@ -78,29 +78,39 @@ class CogUtility(Cog): user: PycordUser = await self.bot.find_user(member.id, member.guild.id) events: List[PycordEvent] = [] + utc_now: datetime = get_utc_now() + pipeline: List[Dict[str, Any]] = [ {"$match": {"id": user.id}}, { "$lookup": { "from": "events", - "localField": "registered_event_ids", - "foreignField": "_id", + "let": {"event_ids": "$registered_event_ids"}, + "pipeline": [ + { + "$match": { + "$expr": { + "$and": [ + {"$in": ["$_id", "$$event_ids"]}, + {"$eq": ["$ended", None]}, + {"$gt": ["$ends", utc_now]}, + {"$lt": ["$starts", utc_now]}, + {"$eq": ["$is_cancelled", False]}, + ] + } + } + } + ], "as": "registered_events", } }, - { - "$match": { - "registered_events.ended": None, - "registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, - "registered_events.starts": {"$lt": datetime.now(tz=ZoneInfo("UTC"))}, - "registered_events.is_cancelled": False, - } - }, + {"$match": {"registered_events.0": {"$exists": True}}}, ] - async for result in col_users.aggregate(pipeline): - for registered_event in result["registered_events"]: - events.append(PycordEvent(**registered_event)) + async with await col_users.aggregate(pipeline) as cursor: + async for result in cursor: + for registered_event in result["registered_events"]: + events.append(PycordEvent(**registered_event)) for event in events: if user.current_event_id is not None and user.current_event_id != event._id: diff --git a/modules/database.py b/modules/database.py index 8d1a360..cbcb38c 100644 --- a/modules/database.py +++ b/modules/database.py @@ -2,8 +2,10 @@ from typing import Any, Mapping -from async_pymongo import AsyncClient, AsyncCollection, AsyncDatabase from libbot.utils import config_get +from pymongo import AsyncMongoClient +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.database import AsyncDatabase db_config: Mapping[str, Any] = config_get("database") @@ -19,7 +21,7 @@ else: con_string = "mongodb://{0}:{1}/{2}".format(db_config["host"], db_config["port"], db_config["name"]) # Async declarations -db_client = AsyncClient(con_string) +db_client = AsyncMongoClient(con_string) db: AsyncDatabase = db_client.get_database(name=db_config["name"]) col_users: AsyncCollection = db.get_collection("users") @@ -27,10 +29,10 @@ col_guilds: AsyncCollection = db.get_collection("guilds") col_events: AsyncCollection = db.get_collection("events") col_stages: AsyncCollection = db.get_collection("stages") + # Update indexes -db.dispatch.get_collection("users").create_index("id", name="user_id", unique=True) -db.dispatch.get_collection("guilds").create_index("id", name="guild_id", unique=True) -db.dispatch.get_collection("events").create_index("guild_id", name="guild_id", unique=False) -db.dispatch.get_collection("stages").create_index( - ["event_id", "guild_id"], name="event_id-and-guild_id", unique=False -) +async def _update_database_indexes() -> None: + await col_users.create_index("id", name="user_id", unique=True) + await col_guilds.create_index("id", name="guild_id", unique=True) + await col_events.create_index("guild_id", name="guild_id", unique=False) + await col_stages.create_index(["event_id", "guild_id"], name="event_id-and-guild_id", unique=False) diff --git a/modules/utils/__init__.py b/modules/utils/__init__.py index f15d19e..9c4698d 100644 --- a/modules/utils/__init__.py +++ b/modules/utils/__init__.py @@ -7,7 +7,7 @@ from .autocomplete_utils import ( autocomplete_user_registered_events, ) from .cache_utils import restore_from_cache -from .datetime_utils import get_unix_timestamp +from .datetime_utils import get_unix_timestamp, get_utc_now from .event_utils import validate_event_validity from .git_utils import get_current_commit from .logging_utils import get_logger, get_logging_config diff --git a/modules/utils/autocomplete_utils.py b/modules/utils/autocomplete_utils.py index 30056ba..815cedb 100644 --- a/modules/utils/autocomplete_utils.py +++ b/modules/utils/autocomplete_utils.py @@ -49,31 +49,41 @@ async def autocomplete_user_available_events(ctx: AutocompleteContext) -> List[O async def autocomplete_user_registered_events(ctx: AutocompleteContext) -> List[OptionChoice]: """Return list of active events user is registered in""" + utc_now: datetime = datetime.now(tz=ZoneInfo("UTC")) + pipeline: List[Dict[str, Any]] = [ {"$match": {"id": ctx.interaction.user.id}}, { "$lookup": { "from": "events", - "localField": "registered_event_ids", - "foreignField": "_id", + "let": {"event_ids": "$registered_event_ids"}, + "pipeline": [ + { + "$match": { + "$expr": { + "$and": [ + {"$in": ["$_id", "$$event_ids"]}, + {"$eq": ["$ended", None]}, + {"$gt": ["$ends", utc_now]}, + {"$gt": ["$starts", utc_now]}, + {"$eq": ["$is_cancelled", False]}, + ] + } + } + } + ], "as": "registered_events", } }, - { - "$match": { - "registered_events.ended": None, - "registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, - "registered_events.starts": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, - "registered_events.is_cancelled": False, - } - }, + {"$match": {"registered_events.0": {"$exists": True}}}, ] event_names: List[OptionChoice] = [] - async for result in col_users.aggregate(pipeline): - for registered_event in result["registered_events"]: - event_names.append(OptionChoice(registered_event["name"], str(registered_event["_id"]))) + async with await col_users.aggregate(pipeline) as cursor: + async for result in cursor: + for registered_event in result["registered_events"]: + event_names.append(OptionChoice(registered_event["name"], str(registered_event["_id"]))) return event_names diff --git a/modules/utils/datetime_utils.py b/modules/utils/datetime_utils.py index a267962..19dcdf0 100644 --- a/modules/utils/datetime_utils.py +++ b/modules/utils/datetime_utils.py @@ -5,3 +5,8 @@ from zoneinfo import ZoneInfo # TODO Add documentation def get_unix_timestamp(date: datetime, to_utc: bool = False) -> int: return int((date if not to_utc else date.replace(tzinfo=ZoneInfo("UTC"))).timestamp()) + + +# TODO Add documentation +def get_utc_now() -> datetime: + return datetime.now(tz=ZoneInfo("UTC")) diff --git a/modules/utils/event_utils.py b/modules/utils/event_utils.py index 28500ab..85911c4 100644 --- a/modules/utils/event_utils.py +++ b/modules/utils/event_utils.py @@ -23,15 +23,15 @@ async def validate_event_validity( end_date_internal: datetime = end_date.astimezone(ZoneInfo("UTC")) if to_utc else end_date if start_date_internal < datetime.now(tz=ZoneInfo("UTC")): - await ctx.respond(_("event_start_past", "messages", locale=ctx.locale)) + await ctx.respond(_("event_start_past", "messages", locale=ctx.locale), ephemeral=True) return False if end_date_internal < datetime.now(tz=ZoneInfo("UTC")): - await ctx.respond(_("event_end_past", "messages", locale=ctx.locale)) + await ctx.respond(_("event_end_past", "messages", locale=ctx.locale), ephemeral=True) return False if start_date_internal >= end_date_internal: - await ctx.respond(_("event_end_before_start", "messages", locale=ctx.locale)) + await ctx.respond(_("event_end_before_start", "messages", locale=ctx.locale), ephemeral=True) return False # TODO Add validation for concurrent events. @@ -47,7 +47,7 @@ async def validate_event_validity( query["_id"] = {"$ne": event_id} if (await col_events.find_one(query)) is not None: - await ctx.respond(_("event_name_duplicate", "messages", locale=ctx.locale)) + await ctx.respond(_("event_name_duplicate", "messages", locale=ctx.locale), ephemeral=True) return False return True diff --git a/modules/utils/validation_utils.py b/modules/utils/validation_utils.py index 924bf06..50cfd10 100644 --- a/modules/utils/validation_utils.py +++ b/modules/utils/validation_utils.py @@ -7,7 +7,7 @@ from libbot.i18n import _ async def is_operation_confirmed(ctx: ApplicationContext, confirm: bool) -> bool: if confirm is None or not confirm: - await ctx.respond(ctx.bot._("operation_unconfirmed", "messages", locale=ctx.locale)) + await ctx.respond(ctx.bot._("operation_unconfirmed", "messages", locale=ctx.locale), ephemeral=True) return False return True @@ -18,7 +18,7 @@ async def is_event_status_valid( event: "PycordEvent", ) -> bool: if event.is_cancelled: - await ctx.respond(_("event_is_cancelled", "messages", locale=ctx.locale)) + await ctx.respond(_("event_is_cancelled", "messages", locale=ctx.locale), ephemeral=True) return False if ( @@ -26,7 +26,7 @@ async def is_event_status_valid( <= datetime.now(tz=ZoneInfo("UTC")) <= event.ends.replace(tzinfo=ZoneInfo("UTC")) ): - await ctx.respond(_("event_ongoing_not_editable", "messages", locale=ctx.locale)) + await ctx.respond(_("event_ongoing_not_editable", "messages", locale=ctx.locale), ephemeral=True) return False return True diff --git a/requirements.txt b/requirements.txt index b63f5a7..69e36aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ aiodns~=3.2.0 apscheduler~=3.11.0 -async_pymongo==0.1.11 brotlipy~=0.7.0 faust-cchardet~=2.1.19 libbot[speed,pycord,cache]==4.1.0 mongodb-migrations==1.3.1 msgspec~=0.19.0 +pymongo~=4.12.1,>=4.9 pytz~=2025.1 typing_extensions>=4.11.0 \ No newline at end of file