@@ -113,7 +113,9 @@ class PycordBot(LibPycordBot):
|
||||
return event_stage
|
||||
|
||||
# TODO Document this method
|
||||
async def find_event(self, event_id: str | ObjectId | None = None, event_name: str | None = None):
|
||||
async def find_event(
|
||||
self, event_id: str | ObjectId | None = None, event_name: str | None = None
|
||||
) -> PycordEvent:
|
||||
if event_id is None and event_name is None:
|
||||
raise AttributeError("Either event's ID or name must be provided!")
|
||||
|
||||
|
@@ -23,7 +23,7 @@ class PycordEvent:
|
||||
"guild_id",
|
||||
"created",
|
||||
"ended",
|
||||
"cancelled",
|
||||
"is_cancelled",
|
||||
"creator_id",
|
||||
"starts",
|
||||
"ends",
|
||||
@@ -38,7 +38,7 @@ class PycordEvent:
|
||||
guild_id: int
|
||||
created: datetime
|
||||
ended: datetime | None
|
||||
cancelled: bool
|
||||
is_cancelled: bool
|
||||
creator_id: int
|
||||
starts: datetime
|
||||
ends: datetime
|
||||
@@ -58,6 +58,7 @@ class PycordEvent:
|
||||
|
||||
Raises:
|
||||
EventNotFoundError: Event was not found
|
||||
InvalidId: Invalid event ID was provided
|
||||
"""
|
||||
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)
|
||||
|
||||
@@ -112,7 +113,7 @@ class PycordEvent:
|
||||
"guild_id": guild_id,
|
||||
"created": datetime.now(tz=ZoneInfo("UTC")),
|
||||
"ended": None,
|
||||
"cancelled": False,
|
||||
"is_cancelled": False,
|
||||
"creator_id": creator_id,
|
||||
"starts": starts,
|
||||
"ends": ends,
|
||||
@@ -129,12 +130,12 @@ class PycordEvent:
|
||||
|
||||
return cls(**db_entry)
|
||||
|
||||
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None:
|
||||
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
|
||||
"""Set attribute data and save it into the database.
|
||||
|
||||
Args:
|
||||
cache (:obj:`Cache`, optional): Cache engine to write the update into
|
||||
**kwargs (str): Mapping of attribute names and respective values to be set
|
||||
**kwargs (Any): Mapping of attribute names and respective values to be set
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if not hasattr(self, key):
|
||||
@@ -208,7 +209,7 @@ class PycordEvent:
|
||||
"guild_id": self.guild_id,
|
||||
"created": self.created,
|
||||
"ended": self.ended,
|
||||
"cancelled": self.cancelled,
|
||||
"is_cancelled": self.is_cancelled,
|
||||
"creator_id": self.creator_id,
|
||||
"starts": self.starts,
|
||||
"ends": self.ends,
|
||||
@@ -223,7 +224,7 @@ class PycordEvent:
|
||||
"guild_id": None,
|
||||
"created": None,
|
||||
"ended": None,
|
||||
"cancelled": False,
|
||||
"is_cancelled": False,
|
||||
"creator_id": None,
|
||||
"starts": None,
|
||||
"ends": None,
|
||||
@@ -265,7 +266,7 @@ class PycordEvent:
|
||||
|
||||
# TODO Add documentation
|
||||
async def cancel(self, cache: Optional[Cache] = None):
|
||||
await self._set(cache, cancelled=True)
|
||||
await self._set(cache, is_cancelled=True)
|
||||
|
||||
async def _update_event_stage_order(
|
||||
self,
|
||||
|
@@ -106,12 +106,12 @@ class PycordEventStage:
|
||||
|
||||
return cls(**db_entry)
|
||||
|
||||
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None:
|
||||
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
|
||||
"""Set attribute data and save it into the database.
|
||||
|
||||
Args:
|
||||
cache (:obj:`Cache`, optional): Cache engine to write the update into
|
||||
**kwargs (str): Mapping of attribute names and respective values to be set
|
||||
**kwargs (Any): Mapping of attribute names and respective values to be set
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if not hasattr(self, key):
|
||||
|
@@ -67,12 +67,12 @@ class PycordGuild:
|
||||
|
||||
return cls(**db_entry)
|
||||
|
||||
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None:
|
||||
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
|
||||
"""Set attribute data and save it into the database.
|
||||
|
||||
Args:
|
||||
cache (:obj:`Cache`, optional): Cache engine to write the update into
|
||||
**kwargs (str): Mapping of attribute names and respective values to be set
|
||||
**kwargs (Any): Mapping of attribute names and respective values to be set
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if not hasattr(self, key):
|
||||
|
@@ -1,12 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from logging import Logger
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
from libbot.cache.classes import Cache
|
||||
from pymongo.results import InsertOneResult
|
||||
|
||||
from classes.abstract.cacheable import Cacheable
|
||||
from classes.errors.pycord_user import UserNotFoundError
|
||||
from modules.database import col_users
|
||||
from modules.utils import get_logger, restore_from_cache
|
||||
@@ -15,10 +14,20 @@ logger: Logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PycordUser(Cacheable):
|
||||
class PycordUser:
|
||||
"""Dataclass of DB entry of a user"""
|
||||
|
||||
__slots__ = ("_id", "id", "guild_id", "channel_id", "current_event_id", "current_stage_id")
|
||||
__slots__ = (
|
||||
"_id",
|
||||
"id",
|
||||
"guild_id",
|
||||
"channel_id",
|
||||
"is_jailed",
|
||||
"current_event_id",
|
||||
"current_stage_id",
|
||||
"registered_event_ids",
|
||||
"completed_event_ids",
|
||||
)
|
||||
__short_name__ = "user"
|
||||
__collection__ = col_users
|
||||
|
||||
@@ -26,8 +35,11 @@ class PycordUser(Cacheable):
|
||||
id: int
|
||||
guild_id: int
|
||||
channel_id: int | None
|
||||
is_jailed: bool
|
||||
current_event_id: ObjectId | None
|
||||
current_stage_id: ObjectId | None
|
||||
registered_event_ids: List[ObjectId]
|
||||
completed_event_ids: List[ObjectId]
|
||||
|
||||
@classmethod
|
||||
async def from_id(
|
||||
@@ -82,20 +94,31 @@ class PycordUser(Cacheable):
|
||||
"id": self.id,
|
||||
"guild_id": self.guild_id,
|
||||
"channel_id": self.channel_id,
|
||||
"is_jailed": self.is_jailed,
|
||||
"current_event_id": (
|
||||
self.current_event_id if not json_compatible else str(self.current_event_id)
|
||||
),
|
||||
"current_stage_id": (
|
||||
self.current_stage_id if not json_compatible else str(self.current_stage_id)
|
||||
),
|
||||
"registered_event_ids": (
|
||||
self.registered_event_ids
|
||||
if not json_compatible
|
||||
else [str(event_id) for event_id in self.registered_event_ids]
|
||||
),
|
||||
"completed_event_ids": (
|
||||
self.completed_event_ids
|
||||
if not json_compatible
|
||||
else [str(event_id) for event_id in self.completed_event_ids]
|
||||
),
|
||||
}
|
||||
|
||||
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None:
|
||||
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
|
||||
"""Set attribute data and save it into the database.
|
||||
|
||||
Args:
|
||||
cache (:obj:`Cache`, optional): Cache engine to write the update into
|
||||
**kwargs (str): Mapping of attribute names and respective values to be set
|
||||
**kwargs (Any): Mapping of attribute names and respective values to be set
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if not hasattr(self, key):
|
||||
@@ -160,8 +183,11 @@ class PycordUser(Cacheable):
|
||||
"id": user_id,
|
||||
"guild_id": guild_id,
|
||||
"channel_id": None,
|
||||
"is_jailed": False,
|
||||
"current_event_id": None,
|
||||
"current_stage_id": None,
|
||||
"registered_event_ids": [],
|
||||
"completed_event_ids": [],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -179,3 +205,45 @@ class PycordUser(Cacheable):
|
||||
"""
|
||||
await self.__collection__.delete_one({"_id": self._id})
|
||||
self._delete_cache(cache)
|
||||
|
||||
# TODO Add documentation
|
||||
async def event_register(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
|
||||
event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id
|
||||
|
||||
if event_id in self.registered_event_ids:
|
||||
raise RuntimeError(f"User is already registered for event {event_id}")
|
||||
|
||||
# TODO Add a unique exception
|
||||
# raise UserAlreadyRegisteredForEventError(event_name)
|
||||
|
||||
self.registered_event_ids.append(event_id)
|
||||
|
||||
await self._set(cache, registered_event_ids=self.registered_event_ids)
|
||||
|
||||
# TODO Add documentation
|
||||
async def event_unregister(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
|
||||
event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id
|
||||
|
||||
if event_id not in self.registered_event_ids:
|
||||
raise RuntimeError(f"User is not registered for event {event_id}")
|
||||
|
||||
# TODO Add a unique exception
|
||||
# raise UserNotRegisteredForEventError(event_name)
|
||||
|
||||
self.registered_event_ids.remove(event_id)
|
||||
|
||||
await self._set(cache, registered_event_ids=self.registered_event_ids)
|
||||
|
||||
# TODO Add documentation
|
||||
async def event_complete(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
|
||||
event_id: ObjectId = ObjectId(event_id) if isinstance(event_id, str) else event_id
|
||||
|
||||
if event_id in self.completed_event_ids:
|
||||
raise RuntimeError(f"User has already completed event {event_id}")
|
||||
|
||||
# TODO Add a unique exception
|
||||
# raise UserAlreadyCompletedEventError(event_name)
|
||||
|
||||
self.completed_event_ids.append(event_id)
|
||||
|
||||
await self._set(cache, completed_event_ids=self.completed_event_ids)
|
||||
|
@@ -1,8 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from bson import ObjectId
|
||||
from bson.errors import InvalidId
|
||||
from discord import (
|
||||
ApplicationContext,
|
||||
Attachment,
|
||||
@@ -14,42 +13,7 @@ from discord.utils import basic_autocomplete
|
||||
|
||||
from classes import PycordEvent, PycordGuild
|
||||
from classes.pycord_bot import PycordBot
|
||||
from modules.database import col_events
|
||||
from modules.utils import autocomplete_active_events
|
||||
|
||||
|
||||
# TODO Move to staticmethod or to a separate module
|
||||
async def validate_event_validity(
|
||||
ctx: ApplicationContext,
|
||||
name: str,
|
||||
start_date: datetime | None,
|
||||
finish_date: datetime | None,
|
||||
guild_timezone: ZoneInfo,
|
||||
event_id: ObjectId | None = None,
|
||||
) -> None:
|
||||
if start_date > finish_date:
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Start date must be before finish date")
|
||||
return
|
||||
elif start_date < datetime.now(tz=guild_timezone):
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Start date must not be in the past")
|
||||
return
|
||||
|
||||
query: Dict[str, Any] = {
|
||||
"name": name,
|
||||
"ended": None,
|
||||
"ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
|
||||
"cancelled": {"$ne": True},
|
||||
}
|
||||
|
||||
if event_id is not None:
|
||||
query["_id"] = {"$ne": event_id}
|
||||
|
||||
if (await col_events.find_one(query)) is not None:
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("There can only be one active event with the same name")
|
||||
return
|
||||
from modules.utils import autocomplete_active_events, validate_event_validity
|
||||
|
||||
|
||||
class Event(Cog):
|
||||
@@ -137,9 +101,10 @@ class Event(Cog):
|
||||
thumbnail: Attachment = None,
|
||||
) -> None:
|
||||
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
|
||||
|
||||
if pycord_event is None:
|
||||
try:
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
|
||||
except (InvalidId, RuntimeError):
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Event was not found.")
|
||||
return
|
||||
@@ -208,9 +173,10 @@ class Event(Cog):
|
||||
return
|
||||
|
||||
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
|
||||
|
||||
if pycord_event is None:
|
||||
try:
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
|
||||
except (InvalidId, RuntimeError):
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Event was not found.")
|
||||
return
|
||||
@@ -251,9 +217,10 @@ class Event(Cog):
|
||||
)
|
||||
async def command_event_show(self, ctx: ApplicationContext, event: str) -> None:
|
||||
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
|
||||
|
||||
if pycord_event is None:
|
||||
try:
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
|
||||
except (InvalidId, RuntimeError):
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Event was not found.")
|
||||
return
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from bson.errors import InvalidId
|
||||
from discord import ApplicationContext, Attachment, SlashCommandGroup, option
|
||||
from discord.ext.commands import Cog
|
||||
from discord.utils import basic_autocomplete
|
||||
@@ -14,7 +15,7 @@ async def validate_event_status(
|
||||
ctx: ApplicationContext,
|
||||
event: PycordEvent,
|
||||
) -> None:
|
||||
if event.cancelled:
|
||||
if event.is_cancelled:
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("This event was cancelled.")
|
||||
return
|
||||
@@ -63,7 +64,12 @@ class Stage(Cog):
|
||||
await ctx.respond("Guild is not configured.")
|
||||
return
|
||||
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event)
|
||||
try:
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
|
||||
except (InvalidId, RuntimeError):
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Event was not found.")
|
||||
return
|
||||
|
||||
await validate_event_status(ctx, pycord_event)
|
||||
|
||||
@@ -123,11 +129,21 @@ class Stage(Cog):
|
||||
await ctx.respond("Guild is not configured.")
|
||||
return
|
||||
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event)
|
||||
try:
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
|
||||
except (InvalidId, RuntimeError):
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Event was not found.")
|
||||
return
|
||||
|
||||
await validate_event_status(ctx, pycord_event)
|
||||
|
||||
event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
|
||||
try:
|
||||
event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
|
||||
except (InvalidId, RuntimeError):
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Event stage was not found.")
|
||||
return
|
||||
|
||||
if order is not None and order > len(pycord_event.stage_ids):
|
||||
# TODO Make a nice message
|
||||
@@ -180,11 +196,21 @@ class Stage(Cog):
|
||||
await ctx.respond("Guild is not configured.")
|
||||
return
|
||||
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event)
|
||||
try:
|
||||
pycord_event: PycordEvent = await self.bot.find_event(event_id=event)
|
||||
except (InvalidId, RuntimeError):
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Event was not found.")
|
||||
return
|
||||
|
||||
await validate_event_status(ctx, pycord_event)
|
||||
|
||||
event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
|
||||
try:
|
||||
event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
|
||||
except (InvalidId, RuntimeError):
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Event stage was not found.")
|
||||
return
|
||||
|
||||
await pycord_event.remove_stage(self.bot, event_stage._id, cache=self.bot.cache)
|
||||
await event_stage.purge(cache=self.bot.cache)
|
||||
|
@@ -1 +1 @@
|
||||
from . import utils, database, migrator, scheduler
|
||||
from . import database, migrator, scheduler, utils
|
||||
|
@@ -3,6 +3,9 @@ from .autocomplete_utils import (
|
||||
autocomplete_event_stages,
|
||||
autocomplete_languages,
|
||||
autocomplete_timezones,
|
||||
autocomplete_user_available_events,
|
||||
autocomplete_user_registered_events,
|
||||
)
|
||||
from .cache_utils import restore_from_cache
|
||||
from .event_utils import validate_event_validity
|
||||
from .logging_utils import get_logger, get_logging_config
|
||||
|
@@ -6,7 +6,7 @@ from bson import ObjectId
|
||||
from discord import AutocompleteContext, OptionChoice
|
||||
from pymongo import ASCENDING
|
||||
|
||||
from modules.database import col_events, col_stages
|
||||
from modules.database import col_events, col_stages, col_users
|
||||
|
||||
|
||||
async def autocomplete_timezones(ctx: AutocompleteContext) -> List[str]:
|
||||
@@ -29,7 +29,7 @@ async def autocomplete_active_events(ctx: AutocompleteContext) -> List[OptionCho
|
||||
query: Dict[str, Any] = {
|
||||
"ended": None,
|
||||
"ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
|
||||
"cancelled": {"$ne": True},
|
||||
"is_cancelled": {"$ne": True},
|
||||
}
|
||||
|
||||
event_names: List[OptionChoice] = []
|
||||
@@ -40,6 +40,43 @@ async def autocomplete_active_events(ctx: AutocompleteContext) -> List[OptionCho
|
||||
return event_names
|
||||
|
||||
|
||||
async def autocomplete_user_available_events(ctx: AutocompleteContext) -> List[OptionChoice]:
|
||||
"""Return list of active events user can register in"""
|
||||
|
||||
return await autocomplete_active_events(ctx)
|
||||
|
||||
|
||||
async def autocomplete_user_registered_events(ctx: AutocompleteContext) -> List[OptionChoice]:
|
||||
"""Return list of active events user is registered in"""
|
||||
|
||||
pipeline: List[Dict[str, Any]] = [
|
||||
{
|
||||
"$lookup": {
|
||||
"from": "events",
|
||||
"localField": "registered_event_ids",
|
||||
"foreignField": "_id",
|
||||
"as": "registered_events",
|
||||
}
|
||||
},
|
||||
{
|
||||
"$match": {
|
||||
"registered_events.ended": None,
|
||||
"registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
|
||||
"registered_events.starts": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
|
||||
"registered_events.is_cancelled": {"$ne": True},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
event_names: List[OptionChoice] = []
|
||||
|
||||
async for result in col_users.aggregate(pipeline):
|
||||
for registered_event in result["registered_events"]:
|
||||
event_names.append(OptionChoice(registered_event["name"], str(registered_event["_id"])))
|
||||
|
||||
return event_names
|
||||
|
||||
|
||||
async def autocomplete_event_stages(ctx: AutocompleteContext) -> List[OptionChoice]:
|
||||
"""Return list of stages of the event"""
|
||||
|
||||
|
46
modules/utils/event_utils.py
Normal file
46
modules/utils/event_utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from bson import ObjectId
|
||||
from discord import (
|
||||
ApplicationContext,
|
||||
)
|
||||
|
||||
from modules.database import col_events
|
||||
|
||||
|
||||
async def validate_event_validity(
|
||||
ctx: ApplicationContext,
|
||||
name: str,
|
||||
start_date: datetime | None,
|
||||
finish_date: datetime | None,
|
||||
guild_timezone: ZoneInfo,
|
||||
event_id: ObjectId | None = None,
|
||||
) -> None:
|
||||
if start_date > finish_date:
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Start date must be before finish date")
|
||||
return
|
||||
|
||||
if start_date < datetime.now(tz=guild_timezone):
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("Start date must not be in the past")
|
||||
return
|
||||
|
||||
# TODO Add validation for concurrent events.
|
||||
# Only one event can take place at the same time.
|
||||
query: Dict[str, Any] = {
|
||||
"name": name,
|
||||
"ended": None,
|
||||
"ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
|
||||
"is_cancelled": {"$ne": True},
|
||||
}
|
||||
|
||||
if event_id is not None:
|
||||
query["_id"] = {"$ne": event_id}
|
||||
|
||||
if (await col_events.find_one(query)) is not None:
|
||||
# TODO Make a nice message
|
||||
await ctx.respond("There can only be one active event with the same name")
|
||||
return
|
Reference in New Issue
Block a user