Cleanups and bugfixes for (#2 and #8)

This commit is contained in:
2025-04-24 00:16:53 +02:00
parent 57c4ff3bf9
commit c1d8620478
11 changed files with 222 additions and 72 deletions

View File

@@ -113,7 +113,9 @@ class PycordBot(LibPycordBot):
return event_stage
# TODO Document this method
async def find_event(self, event_id: str | ObjectId | None = None, event_name: str | None = None):
async def find_event(
self, event_id: str | ObjectId | None = None, event_name: str | None = None
) -> PycordEvent:
if event_id is None and event_name is None:
raise AttributeError("Either event's ID or name must be provided!")

View File

@@ -23,7 +23,7 @@ class PycordEvent:
"guild_id",
"created",
"ended",
"cancelled",
"is_cancelled",
"creator_id",
"starts",
"ends",
@@ -38,7 +38,7 @@ class PycordEvent:
guild_id: int
created: datetime
ended: datetime | None
cancelled: bool
is_cancelled: bool
creator_id: int
starts: datetime
ends: datetime
@@ -58,6 +58,7 @@ class PycordEvent:
Raises:
EventNotFoundError: Event was not found
InvalidId: Invalid event ID was provided
"""
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)
@@ -112,7 +113,7 @@ class PycordEvent:
"guild_id": guild_id,
"created": datetime.now(tz=ZoneInfo("UTC")),
"ended": None,
"cancelled": False,
"is_cancelled": False,
"creator_id": creator_id,
"starts": starts,
"ends": ends,
@@ -129,12 +130,12 @@ class PycordEvent:
return cls(**db_entry)
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None:
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
"""Set attribute data and save it into the database.
Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into
**kwargs (str): Mapping of attribute names and respective values to be set
**kwargs (Any): Mapping of attribute names and respective values to be set
"""
for key, value in kwargs.items():
if not hasattr(self, key):
@@ -208,7 +209,7 @@ class PycordEvent:
"guild_id": self.guild_id,
"created": self.created,
"ended": self.ended,
"cancelled": self.cancelled,
"is_cancelled": self.is_cancelled,
"creator_id": self.creator_id,
"starts": self.starts,
"ends": self.ends,
@@ -223,7 +224,7 @@ class PycordEvent:
"guild_id": None,
"created": None,
"ended": None,
"cancelled": False,
"is_cancelled": False,
"creator_id": None,
"starts": None,
"ends": None,
@@ -265,7 +266,7 @@ class PycordEvent:
# TODO Add documentation
async def cancel(self, cache: Optional[Cache] = None):
await self._set(cache, cancelled=True)
await self._set(cache, is_cancelled=True)
async def _update_event_stage_order(
self,

View File

@@ -106,12 +106,12 @@ class PycordEventStage:
return cls(**db_entry)
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None:
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
"""Set attribute data and save it into the database.
Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into
**kwargs (str): Mapping of attribute names and respective values to be set
**kwargs (Any): Mapping of attribute names and respective values to be set
"""
for key, value in kwargs.items():
if not hasattr(self, key):

View File

@@ -67,12 +67,12 @@ class PycordGuild:
return cls(**db_entry)
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None:
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
"""Set attribute data and save it into the database.
Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into
**kwargs (str): Mapping of attribute names and respective values to be set
**kwargs (Any): Mapping of attribute names and respective values to be set
"""
for key, value in kwargs.items():
if not hasattr(self, key):

View File

@@ -1,12 +1,11 @@
from dataclasses import dataclass
from logging import Logger
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from bson import ObjectId
from libbot.cache.classes import Cache
from pymongo.results import InsertOneResult
from classes.abstract.cacheable import Cacheable
from classes.errors.pycord_user import UserNotFoundError
from modules.database import col_users
from modules.utils import get_logger, restore_from_cache
@@ -15,10 +14,20 @@ logger: Logger = get_logger(__name__)
@dataclass
class PycordUser(Cacheable):
class PycordUser:
"""Dataclass of DB entry of a user"""
__slots__ = ("_id", "id", "guild_id", "channel_id", "current_event_id", "current_stage_id")
__slots__ = (
"_id",
"id",
"guild_id",
"channel_id",
"is_jailed",
"current_event_id",
"current_stage_id",
"registered_event_ids",
"completed_event_ids",
)
__short_name__ = "user"
__collection__ = col_users
@@ -26,8 +35,11 @@ class PycordUser(Cacheable):
id: int
guild_id: int
channel_id: int | None
is_jailed: bool
current_event_id: ObjectId | None
current_stage_id: ObjectId | None
registered_event_ids: List[ObjectId]
completed_event_ids: List[ObjectId]
@classmethod
async def from_id(
@@ -82,20 +94,31 @@ class PycordUser(Cacheable):
"id": self.id,
"guild_id": self.guild_id,
"channel_id": self.channel_id,
"is_jailed": self.is_jailed,
"current_event_id": (
self.current_event_id if not json_compatible else str(self.current_event_id)
),
"current_stage_id": (
self.current_stage_id if not json_compatible else str(self.current_stage_id)
),
"registered_event_ids": (
self.registered_event_ids
if not json_compatible
else [str(event_id) for event_id in self.registered_event_ids]
),
"completed_event_ids": (
self.completed_event_ids
if not json_compatible
else [str(event_id) for event_id in self.completed_event_ids]
),
}
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None:
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
"""Set attribute data and save it into the database.
Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into
**kwargs (str): Mapping of attribute names and respective values to be set
**kwargs (Any): Mapping of attribute names and respective values to be set
"""
for key, value in kwargs.items():
if not hasattr(self, key):
@@ -160,8 +183,11 @@ class PycordUser(Cacheable):
"id": user_id,
"guild_id": guild_id,
"channel_id": None,
"is_jailed": False,
"current_event_id": None,
"current_stage_id": None,
"registered_event_ids": [],
"completed_event_ids": [],
}
@staticmethod
@@ -179,3 +205,45 @@ class PycordUser(Cacheable):
"""
await self.__collection__.delete_one({"_id": self._id})
self._delete_cache(cache)
# TODO Add documentation
async def event_register(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id
if event_id in self.registered_event_ids:
raise RuntimeError(f"User is already registered for event {event_id}")
# TODO Add a unique exception
# raise UserAlreadyRegisteredForEventError(event_name)
self.registered_event_ids.append(event_id)
await self._set(cache, registered_event_ids=self.registered_event_ids)
# TODO Add documentation
async def event_unregister(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id
if event_id not in self.registered_event_ids:
raise RuntimeError(f"User is not registered for event {event_id}")
# TODO Add a unique exception
# raise UserNotRegisteredForEventError(event_name)
self.registered_event_ids.remove(event_id)
await self._set(cache, registered_event_ids=self.registered_event_ids)
# TODO Add documentation
async def event_complete(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id
if event_id in self.completed_event_ids:
raise RuntimeError(f"User has already completed event {event_id}")
# TODO Add a unique exception
# raise UserAlreadyCompletedEventError(event_name)
self.completed_event_ids.append(event_id)
await self._set(cache, completed_event_ids=self.completed_event_ids)

View File

@@ -1,8 +1,7 @@
from datetime import datetime
from typing import Any, Dict
from zoneinfo import ZoneInfo
from bson import ObjectId
from bson.errors import InvalidId
from discord import (
ApplicationContext,
Attachment,
@@ -14,42 +13,7 @@ from discord.utils import basic_autocomplete
from classes import PycordEvent, PycordGuild
from classes.pycord_bot import PycordBot
from modules.database import col_events
from modules.utils import autocomplete_active_events
# TODO Move to staticmethod or to a separate module
async def validate_event_validity(
ctx: ApplicationContext,
name: str,
start_date: datetime | None,
finish_date: datetime | None,
guild_timezone: ZoneInfo,
event_id: ObjectId | None = None,
) -> None:
if start_date > finish_date:
# TODO Make a nice message
await ctx.respond("Start date must be before finish date")
return
elif start_date < datetime.now(tz=guild_timezone):
# TODO Make a nice message
await ctx.respond("Start date must not be in the past")
return
query: Dict[str, Any] = {
"name": name,
"ended": None,
"ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"cancelled": {"$ne": True},
}
if event_id is not None:
query["_id"] = {"$ne": event_id}
if (await col_events.find_one(query)) is not None:
# TODO Make a nice message
await ctx.respond("There can only be one active event with the same name")
return
from modules.utils import autocomplete_active_events, validate_event_validity
class Event(Cog):
@@ -137,9 +101,10 @@ class Event(Cog):
thumbnail: Attachment = None,
) -> None:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
if pycord_event is None:
try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event was not found.")
return
@@ -208,9 +173,10 @@ class Event(Cog):
return
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
if pycord_event is None:
try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event was not found.")
return
@@ -251,9 +217,10 @@ class Event(Cog):
)
async def command_event_show(self, ctx: ApplicationContext, event: str) -> None:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
if pycord_event is None:
try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event was not found.")
return

View File

@@ -1,6 +1,7 @@
from datetime import datetime
from zoneinfo import ZoneInfo
from bson.errors import InvalidId
from discord import ApplicationContext, Attachment, SlashCommandGroup, option
from discord.ext.commands import Cog
from discord.utils import basic_autocomplete
@@ -14,7 +15,7 @@ async def validate_event_status(
ctx: ApplicationContext,
event: PycordEvent,
) -> None:
if event.cancelled:
if event.is_cancelled:
# TODO Make a nice message
await ctx.respond("This event was cancelled.")
return
@@ -63,7 +64,12 @@ class Stage(Cog):
await ctx.respond("Guild is not configured.")
return
pycord_event: PycordEvent = await self.bot.find_event(event)
try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event was not found.")
return
await validate_event_status(ctx, pycord_event)
@@ -123,11 +129,21 @@ class Stage(Cog):
await ctx.respond("Guild is not configured.")
return
pycord_event: PycordEvent = await self.bot.find_event(event)
try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event was not found.")
return
await validate_event_status(ctx, pycord_event)
event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
try:
event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event stage was not found.")
return
if order is not None and order > len(pycord_event.stage_ids):
# TODO Make a nice message
@@ -180,11 +196,21 @@ class Stage(Cog):
await ctx.respond("Guild is not configured.")
return
pycord_event: PycordEvent = await self.bot.find_event(event)
try:
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event was not found.")
return
await validate_event_status(ctx, pycord_event)
event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
try:
event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
except (InvalidId, RuntimeError):
# TODO Make a nice message
await ctx.respond("Event stage was not found.")
return
await pycord_event.remove_stage(self.bot, event_stage._id, cache=self.bot.cache)
await event_stage.purge(cache=self.bot.cache)

View File

@@ -1 +1 @@
from . import utils, database, migrator, scheduler
from . import database, migrator, scheduler, utils

View File

@@ -3,6 +3,9 @@ from .autocomplete_utils import (
autocomplete_event_stages,
autocomplete_languages,
autocomplete_timezones,
autocomplete_user_available_events,
autocomplete_user_registered_events,
)
from .cache_utils import restore_from_cache
from .event_utils import validate_event_validity
from .logging_utils import get_logger, get_logging_config

View File

@@ -6,7 +6,7 @@ from bson import ObjectId
from discord import AutocompleteContext, OptionChoice
from pymongo import ASCENDING
from modules.database import col_events, col_stages
from modules.database import col_events, col_stages, col_users
async def autocomplete_timezones(ctx: AutocompleteContext) -> List[str]:
@@ -29,7 +29,7 @@ async def autocomplete_active_events(ctx: AutocompleteContext) -> List[OptionCho
query: Dict[str, Any] = {
"ended": None,
"ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"cancelled": {"$ne": True},
"is_cancelled": {"$ne": True},
}
event_names: List[OptionChoice] = []
@@ -40,6 +40,43 @@ async def autocomplete_active_events(ctx: AutocompleteContext) -> List[OptionCho
return event_names
async def autocomplete_user_available_events(ctx: AutocompleteContext) -> List[OptionChoice]:
"""Return list of active events user can register in"""
return await autocomplete_active_events(ctx)
async def autocomplete_user_registered_events(ctx: AutocompleteContext) -> List[OptionChoice]:
"""Return list of active events user is registered in"""
pipeline: List[Dict[str, Any]] = [
{
"$lookup": {
"from": "events",
"localField": "registered_event_ids",
"foreignField": "_id",
"as": "registered_events",
}
},
{
"$match": {
"registered_events.ended": None,
"registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"registered_events.starts": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"registered_events.is_cancelled": {"$ne": True},
}
},
]
event_names: List[OptionChoice] = []
async for result in col_users.aggregate(pipeline):
for registered_event in result["registered_events"]:
event_names.append(OptionChoice(registered_event["name"], str(registered_event["_id"])))
return event_names
async def autocomplete_event_stages(ctx: AutocompleteContext) -> List[OptionChoice]:
"""Return list of stages of the event"""

View File

@@ -0,0 +1,46 @@
from datetime import datetime
from typing import Any, Dict
from zoneinfo import ZoneInfo
from bson import ObjectId
from discord import (
ApplicationContext,
)
from modules.database import col_events
async def validate_event_validity(
ctx: ApplicationContext,
name: str,
start_date: datetime | None,
finish_date: datetime | None,
guild_timezone: ZoneInfo,
event_id: ObjectId | None = None,
) -> None:
if start_date > finish_date:
# TODO Make a nice message
await ctx.respond("Start date must be before finish date")
return
if start_date < datetime.now(tz=guild_timezone):
# TODO Make a nice message
await ctx.respond("Start date must not be in the past")
return
# TODO Add validation for concurrent events.
# Only one event can take place at the same time.
query: Dict[str, Any] = {
"name": name,
"ended": None,
"ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"is_cancelled": {"$ne": True},
}
if event_id is not None:
query["_id"] = {"$ne": event_id}
if (await col_events.find_one(query)) is not None:
# TODO Make a nice message
await ctx.respond("There can only be one active event with the same name")
return