Cleanups and bugfixes for (#2 and #8)

This commit is contained in:
2025-04-24 00:16:53 +02:00
parent 57c4ff3bf9
commit c1d8620478
11 changed files with 222 additions and 72 deletions

View File

@@ -113,7 +113,9 @@ class PycordBot(LibPycordBot):
return event_stage return event_stage
# TODO Document this method # TODO Document this method
async def find_event(self, event_id: str | ObjectId | None = None, event_name: str | None = None): async def find_event(
self, event_id: str | ObjectId | None = None, event_name: str | None = None
) -> PycordEvent:
if event_id is None and event_name is None: if event_id is None and event_name is None:
raise AttributeError("Either event's ID or name must be provided!") raise AttributeError("Either event's ID or name must be provided!")

View File

@@ -23,7 +23,7 @@ class PycordEvent:
"guild_id", "guild_id",
"created", "created",
"ended", "ended",
"cancelled", "is_cancelled",
"creator_id", "creator_id",
"starts", "starts",
"ends", "ends",
@@ -38,7 +38,7 @@ class PycordEvent:
guild_id: int guild_id: int
created: datetime created: datetime
ended: datetime | None ended: datetime | None
cancelled: bool is_cancelled: bool
creator_id: int creator_id: int
starts: datetime starts: datetime
ends: datetime ends: datetime
@@ -58,6 +58,7 @@ class PycordEvent:
Raises: Raises:
EventNotFoundError: Event was not found EventNotFoundError: Event was not found
InvalidId: Invalid event ID was provided
""" """
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache) cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)
@@ -112,7 +113,7 @@ class PycordEvent:
"guild_id": guild_id, "guild_id": guild_id,
"created": datetime.now(tz=ZoneInfo("UTC")), "created": datetime.now(tz=ZoneInfo("UTC")),
"ended": None, "ended": None,
"cancelled": False, "is_cancelled": False,
"creator_id": creator_id, "creator_id": creator_id,
"starts": starts, "starts": starts,
"ends": ends, "ends": ends,
@@ -129,12 +130,12 @@ class PycordEvent:
return cls(**db_entry) return cls(**db_entry)
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
"""Set attribute data and save it into the database. """Set attribute data and save it into the database.
Args: Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into cache (:obj:`Cache`, optional): Cache engine to write the update into
**kwargs (str): Mapping of attribute names and respective values to be set **kwargs (Any): Mapping of attribute names and respective values to be set
""" """
for key, value in kwargs.items(): for key, value in kwargs.items():
if not hasattr(self, key): if not hasattr(self, key):
@@ -208,7 +209,7 @@ class PycordEvent:
"guild_id": self.guild_id, "guild_id": self.guild_id,
"created": self.created, "created": self.created,
"ended": self.ended, "ended": self.ended,
"cancelled": self.cancelled, "is_cancelled": self.is_cancelled,
"creator_id": self.creator_id, "creator_id": self.creator_id,
"starts": self.starts, "starts": self.starts,
"ends": self.ends, "ends": self.ends,
@@ -223,7 +224,7 @@ class PycordEvent:
"guild_id": None, "guild_id": None,
"created": None, "created": None,
"ended": None, "ended": None,
"cancelled": False, "is_cancelled": False,
"creator_id": None, "creator_id": None,
"starts": None, "starts": None,
"ends": None, "ends": None,
@@ -265,7 +266,7 @@ class PycordEvent:
# TODO Add documentation # TODO Add documentation
async def cancel(self, cache: Optional[Cache] = None): async def cancel(self, cache: Optional[Cache] = None):
await self._set(cache, cancelled=True) await self._set(cache, is_cancelled=True)
async def _update_event_stage_order( async def _update_event_stage_order(
self, self,

View File

@@ -106,12 +106,12 @@ class PycordEventStage:
return cls(**db_entry) return cls(**db_entry)
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
"""Set attribute data and save it into the database. """Set attribute data and save it into the database.
Args: Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into cache (:obj:`Cache`, optional): Cache engine to write the update into
**kwargs (str): Mapping of attribute names and respective values to be set **kwargs (Any): Mapping of attribute names and respective values to be set
""" """
for key, value in kwargs.items(): for key, value in kwargs.items():
if not hasattr(self, key): if not hasattr(self, key):

View File

@@ -67,12 +67,12 @@ class PycordGuild:
return cls(**db_entry) return cls(**db_entry)
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
"""Set attribute data and save it into the database. """Set attribute data and save it into the database.
Args: Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into cache (:obj:`Cache`, optional): Cache engine to write the update into
**kwargs (str): Mapping of attribute names and respective values to be set **kwargs (Any): Mapping of attribute names and respective values to be set
""" """
for key, value in kwargs.items(): for key, value in kwargs.items():
if not hasattr(self, key): if not hasattr(self, key):

View File

@@ -1,12 +1,11 @@
from dataclasses import dataclass from dataclasses import dataclass
from logging import Logger from logging import Logger
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
from bson import ObjectId from bson import ObjectId
from libbot.cache.classes import Cache from libbot.cache.classes import Cache
from pymongo.results import InsertOneResult from pymongo.results import InsertOneResult
from classes.abstract.cacheable import Cacheable
from classes.errors.pycord_user import UserNotFoundError from classes.errors.pycord_user import UserNotFoundError
from modules.database import col_users from modules.database import col_users
from modules.utils import get_logger, restore_from_cache from modules.utils import get_logger, restore_from_cache
@@ -15,10 +14,20 @@ logger: Logger = get_logger(__name__)
@dataclass @dataclass
class PycordUser(Cacheable): class PycordUser:
"""Dataclass of DB entry of a user""" """Dataclass of DB entry of a user"""
__slots__ = ("_id", "id", "guild_id", "channel_id", "current_event_id", "current_stage_id") __slots__ = (
"_id",
"id",
"guild_id",
"channel_id",
"is_jailed",
"current_event_id",
"current_stage_id",
"registered_event_ids",
"completed_event_ids",
)
__short_name__ = "user" __short_name__ = "user"
__collection__ = col_users __collection__ = col_users
@@ -26,8 +35,11 @@ class PycordUser(Cacheable):
id: int id: int
guild_id: int guild_id: int
channel_id: int | None channel_id: int | None
is_jailed: bool
current_event_id: ObjectId | None current_event_id: ObjectId | None
current_stage_id: ObjectId | None current_stage_id: ObjectId | None
registered_event_ids: List[ObjectId]
completed_event_ids: List[ObjectId]
@classmethod @classmethod
async def from_id( async def from_id(
@@ -82,20 +94,31 @@ class PycordUser(Cacheable):
"id": self.id, "id": self.id,
"guild_id": self.guild_id, "guild_id": self.guild_id,
"channel_id": self.channel_id, "channel_id": self.channel_id,
"is_jailed": self.is_jailed,
"current_event_id": ( "current_event_id": (
self.current_event_id if not json_compatible else str(self.current_event_id) self.current_event_id if not json_compatible else str(self.current_event_id)
), ),
"current_stage_id": ( "current_stage_id": (
self.current_stage_id if not json_compatible else str(self.current_stage_id) self.current_stage_id if not json_compatible else str(self.current_stage_id)
), ),
"registered_event_ids": (
self.registered_event_ids
if not json_compatible
else [str(event_id) for event_id in self.registered_event_ids]
),
"completed_event_ids": (
self.completed_event_ids
if not json_compatible
else [str(event_id) for event_id in self.completed_event_ids]
),
} }
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
"""Set attribute data and save it into the database. """Set attribute data and save it into the database.
Args: Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into cache (:obj:`Cache`, optional): Cache engine to write the update into
**kwargs (str): Mapping of attribute names and respective values to be set **kwargs (Any): Mapping of attribute names and respective values to be set
""" """
for key, value in kwargs.items(): for key, value in kwargs.items():
if not hasattr(self, key): if not hasattr(self, key):
@@ -160,8 +183,11 @@ class PycordUser(Cacheable):
"id": user_id, "id": user_id,
"guild_id": guild_id, "guild_id": guild_id,
"channel_id": None, "channel_id": None,
"is_jailed": False,
"current_event_id": None, "current_event_id": None,
"current_stage_id": None, "current_stage_id": None,
"registered_event_ids": [],
"completed_event_ids": [],
} }
@staticmethod @staticmethod
@@ -179,3 +205,45 @@ class PycordUser(Cacheable):
""" """
await self.__collection__.delete_one({"_id": self._id}) await self.__collection__.delete_one({"_id": self._id})
self._delete_cache(cache) self._delete_cache(cache)
# TODO Add documentation
async def event_register(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id
if event_id in self.registered_event_ids:
raise RuntimeError(f"User is already registered for event {event_id}")
# TODO Add a unique exception
# raise UserAlreadyRegisteredForEventError(event_name)
self.registered_event_ids.append(event_id)
await self._set(cache, registered_event_ids=self.registered_event_ids)
# TODO Add documentation
async def event_unregister(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id
if event_id not in self.registered_event_ids:
raise RuntimeError(f"User is not registered for event {event_id}")
# TODO Add a unique exception
# raise UserNotRegisteredForEventError(event_name)
self.registered_event_ids.remove(event_id)
await self._set(cache, registered_event_ids=self.registered_event_ids)
# TODO Add documentation
async def event_complete(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id
if event_id in self.completed_event_ids:
raise RuntimeError(f"User has already completed event {event_id}")
# TODO Add a unique exception
# raise UserAlreadyCompletedEventError(event_name)
self.completed_event_ids.append(event_id)
await self._set(cache, completed_event_ids=self.completed_event_ids)

View File

@@ -1,8 +1,7 @@
from datetime import datetime from datetime import datetime
from typing import Any, Dict
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from bson import ObjectId from bson.errors import InvalidId
from discord import ( from discord import (
ApplicationContext, ApplicationContext,
Attachment, Attachment,
@@ -14,42 +13,7 @@ from discord.utils import basic_autocomplete
from classes import PycordEvent, PycordGuild from classes import PycordEvent, PycordGuild
from classes.pycord_bot import PycordBot from classes.pycord_bot import PycordBot
from modules.database import col_events from modules.utils import autocomplete_active_events, validate_event_validity
from modules.utils import autocomplete_active_events
# TODO Move to staticmethod or to a separate module
async def validate_event_validity(
ctx: ApplicationContext,
name: str,
start_date: datetime | None,
finish_date: datetime | None,
guild_timezone: ZoneInfo,
event_id: ObjectId | None = None,
) -> None:
if start_date > finish_date:
# TODO Make a nice message
await ctx.respond("Start date must be before finish date")
return
elif start_date < datetime.now(tz=guild_timezone):
# TODO Make a nice message
await ctx.respond("Start date must not be in the past")
return
query: Dict[str, Any] = {
"name": name,
"ended": None,
"ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"cancelled": {"$ne": True},
}
if event_id is not None:
query["_id"] = {"$ne": event_id}
if (await col_events.find_one(query)) is not None:
# TODO Make a nice message
await ctx.respond("There can only be one active event with the same name")
return
class Event(Cog): class Event(Cog):
@@ -137,9 +101,10 @@ class Event(Cog):
thumbnail: Attachment = None, thumbnail: Attachment = None,
) -> None: ) -> None:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
if pycord_event is None: try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message # TODO Make a nice message
await ctx.respond("Event was not found.") await ctx.respond("Event was not found.")
return return
@@ -208,9 +173,10 @@ class Event(Cog):
return return
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
if pycord_event is None: try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message # TODO Make a nice message
await ctx.respond("Event was not found.") await ctx.respond("Event was not found.")
return return
@@ -251,9 +217,10 @@ class Event(Cog):
) )
async def command_event_show(self, ctx: ApplicationContext, event: str) -> None: async def command_event_show(self, ctx: ApplicationContext, event: str) -> None:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
if pycord_event is None: try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message # TODO Make a nice message
await ctx.respond("Event was not found.") await ctx.respond("Event was not found.")
return return

View File

@@ -1,6 +1,7 @@
from datetime import datetime from datetime import datetime
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from bson.errors import InvalidId
from discord import ApplicationContext, Attachment, SlashCommandGroup, option from discord import ApplicationContext, Attachment, SlashCommandGroup, option
from discord.ext.commands import Cog from discord.ext.commands import Cog
from discord.utils import basic_autocomplete from discord.utils import basic_autocomplete
@@ -14,7 +15,7 @@ async def validate_event_status(
ctx: ApplicationContext, ctx: ApplicationContext,
event: PycordEvent, event: PycordEvent,
) -> None: ) -> None:
if event.cancelled: if event.is_cancelled:
# TODO Make a nice message # TODO Make a nice message
await ctx.respond("This event was cancelled.") await ctx.respond("This event was cancelled.")
return return
@@ -63,7 +64,12 @@ class Stage(Cog):
await ctx.respond("Guild is not configured.") await ctx.respond("Guild is not configured.")
return return
pycord_event: PycordEvent = await self.bot.find_event(event) try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event was not found.")
return
await validate_event_status(ctx, pycord_event) await validate_event_status(ctx, pycord_event)
@@ -123,11 +129,21 @@ class Stage(Cog):
await ctx.respond("Guild is not configured.") await ctx.respond("Guild is not configured.")
return return
pycord_event: PycordEvent = await self.bot.find_event(event) try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event was not found.")
return
await validate_event_status(ctx, pycord_event) await validate_event_status(ctx, pycord_event)
event_stage: PycordEventStage = await self.bot.find_event_stage(stage) try:
event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event stage was not found.")
return
if order is not None and order > len(pycord_event.stage_ids): if order is not None and order > len(pycord_event.stage_ids):
# TODO Make a nice message # TODO Make a nice message
@@ -180,11 +196,21 @@ class Stage(Cog):
await ctx.respond("Guild is not configured.") await ctx.respond("Guild is not configured.")
return return
pycord_event: PycordEvent = await self.bot.find_event(event) try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event was not found.")
return
await validate_event_status(ctx, pycord_event) await validate_event_status(ctx, pycord_event)
event_stage: PycordEventStage = await self.bot.find_event_stage(stage) try:
event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event stage was not found.")
return
await pycord_event.remove_stage(self.bot, event_stage._id, cache=self.bot.cache) await pycord_event.remove_stage(self.bot, event_stage._id, cache=self.bot.cache)
await event_stage.purge(cache=self.bot.cache) await event_stage.purge(cache=self.bot.cache)

View File

@@ -1 +1 @@
from . import utils, database, migrator, scheduler from . import database, migrator, scheduler, utils

View File

@@ -3,6 +3,9 @@ from .autocomplete_utils import (
autocomplete_event_stages, autocomplete_event_stages,
autocomplete_languages, autocomplete_languages,
autocomplete_timezones, autocomplete_timezones,
autocomplete_user_available_events,
autocomplete_user_registered_events,
) )
from .cache_utils import restore_from_cache from .cache_utils import restore_from_cache
from .event_utils import validate_event_validity
from .logging_utils import get_logger, get_logging_config from .logging_utils import get_logger, get_logging_config

View File

@@ -6,7 +6,7 @@ from bson import ObjectId
from discord import AutocompleteContext, OptionChoice from discord import AutocompleteContext, OptionChoice
from pymongo import ASCENDING from pymongo import ASCENDING
from modules.database import col_events, col_stages from modules.database import col_events, col_stages, col_users
async def autocomplete_timezones(ctx: AutocompleteContext) -> List[str]: async def autocomplete_timezones(ctx: AutocompleteContext) -> List[str]:
@@ -29,7 +29,7 @@ async def autocomplete_active_events(ctx: AutocompleteContext) -> List[OptionCho
query: Dict[str, Any] = { query: Dict[str, Any] = {
"ended": None, "ended": None,
"ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}, "ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"cancelled": {"$ne": True}, "is_cancelled": {"$ne": True},
} }
event_names: List[OptionChoice] = [] event_names: List[OptionChoice] = []
@@ -40,6 +40,43 @@ async def autocomplete_active_events(ctx: AutocompleteContext) -> List[OptionCho
return event_names return event_names
async def autocomplete_user_available_events(ctx: AutocompleteContext) -> List[OptionChoice]:
"""Return list of active events user can register in"""
return await autocomplete_active_events(ctx)
async def autocomplete_user_registered_events(ctx: AutocompleteContext) -> List[OptionChoice]:
"""Return list of active events user is registered in"""
pipeline: List[Dict[str, Any]] = [
{
"$lookup": {
"from": "events",
"localField": "registered_event_ids",
"foreignField": "_id",
"as": "registered_events",
}
},
{
"$match": {
"registered_events.ended": None,
"registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"registered_events.starts": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"registered_events.is_cancelled": {"$ne": True},
}
},
]
event_names: List[OptionChoice] = []
async for result in col_users.aggregate(pipeline):
for registered_event in result["registered_events"]:
event_names.append(OptionChoice(registered_event["name"], str(registered_event["_id"])))
return event_names
async def autocomplete_event_stages(ctx: AutocompleteContext) -> List[OptionChoice]: async def autocomplete_event_stages(ctx: AutocompleteContext) -> List[OptionChoice]:
"""Return list of stages of the event""" """Return list of stages of the event"""

View File

@@ -0,0 +1,46 @@
from datetime import datetime
from typing import Any, Dict
from zoneinfo import ZoneInfo
from bson import ObjectId
from discord import (
ApplicationContext,
)
from modules.database import col_events
async def validate_event_validity(
ctx: ApplicationContext,
name: str,
start_date: datetime | None,
finish_date: datetime | None,
guild_timezone: ZoneInfo,
event_id: ObjectId | None = None,
) -> None:
if start_date > finish_date:
# TODO Make a nice message
await ctx.respond("Start date must be before finish date")
return
if start_date < datetime.now(tz=guild_timezone):
# TODO Make a nice message
await ctx.respond("Start date must not be in the past")
return
# TODO Add validation for concurrent events.
# Only one event can take place at the same time.
query: Dict[str, Any] = {
"name": name,
"ended": None,
"ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"is_cancelled": {"$ne": True},
}
if event_id is not None:
query["_id"] = {"$ne": event_id}
if (await col_events.find_one(query)) is not None:
# TODO Make a nice message
await ctx.respond("There can only be one active event with the same name")
return