From b72f930c94d72c44576e0d22821fed1050c4fd62 Mon Sep 17 00:00:00 2001 From: Profitroll Date: Mon, 21 Apr 2025 14:34:13 +0200 Subject: [PATCH] WIP: Events and Stages --- classes/pycord_bot.py | 10 ++++ classes/pycord_event.py | 35 ++++++++++--- cogs/event.py | 110 +++++++++++++++++++++++++++++++--------- cogs/stage.py | 2 + 4 files changed, 124 insertions(+), 33 deletions(-) diff --git a/classes/pycord_bot.py b/classes/pycord_bot.py index 5d6775c..eca5b61 100644 --- a/classes/pycord_bot.py +++ b/classes/pycord_bot.py @@ -1,6 +1,7 @@ from logging import Logger from typing import Any +from bson import ObjectId from discord import Guild, User from libbot.cache.classes import CacheMemcached, CacheRedis from libbot.cache.manager import create_cache_client @@ -90,3 +91,12 @@ class PycordBot(LibPycordBot): self.scheduler.shutdown() await super().close(**kwargs) + + async def find_event(self, event_id: str | ObjectId | None = None, event_name: str | None = None): + if event_id is None and event_name is None: + raise AttributeError("Either event's ID or name must be provided!") + + if event_id is not None: + await PycordEvent.from_id(event_id, cache=self.cache) + else: + await PycordEvent.from_name(event_name, cache=self.cache) diff --git a/classes/pycord_event.py b/classes/pycord_event.py index 0f45922..18677f6 100644 --- a/classes/pycord_event.py +++ b/classes/pycord_event.py @@ -20,6 +20,7 @@ class PycordEvent: "name", "guild_id", "created", + "ended", "creator_id", "starts", "ends", @@ -33,6 +34,7 @@ class PycordEvent: name: str guild_id: int created: datetime + ended: datetime | None creator_id: int starts: datetime ends: datetime @@ -74,22 +76,37 @@ class PycordEvent: return cls(**db_entry) + @classmethod + async def from_name(cls, event_name: str, cache: Optional[Cache] = None) -> "PycordEvent": + # TODO Add sorting by creation date or something. + # Duplicate events should be avoided, latest active event should be returned. + db_entry: Dict[str, Any] | None = await cls.__collection__.find_one({"name": event_name}) + + if db_entry is None: + raise RuntimeError(f"Event with name {event_name} not found") + + if cache is not None: + cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry) + + return cls(**db_entry) + # TODO Implement this method @classmethod async def create( - cls, - name: str, - guild_id: int, - creator_id: int, - starts: datetime, - ends: datetime, - thumbnail_id: str | None, - cache: Optional[Cache] = None, + cls, + name: str, + guild_id: int, + creator_id: int, + starts: datetime, + ends: datetime, + thumbnail_id: str | None, + cache: Optional[Cache] = None, ) -> "PycordEvent": db_entry: Dict[str, Any] = { "name": name, "guild_id": guild_id, "created": datetime.now(tz=timezone.utc), + "ended": None, "creator_id": creator_id, "starts": starts, "ends": ends, @@ -179,6 +196,7 @@ class PycordEvent: "name": self.name, "guild_id": self.guild_id, "created": self.created, + "ended": self.ended, "creator_id": self.creator_id, "starts": self.starts, "ends": self.ends, @@ -192,6 +210,7 @@ class PycordEvent: "name": None, "guild_id": None, "created": None, + "ended": None, "creator_id": None, "starts": None, "ends": None, diff --git a/cogs/event.py b/cogs/event.py index f19fd8c..076b722 100644 --- a/cogs/event.py +++ b/cogs/event.py @@ -1,11 +1,51 @@ from datetime import datetime +from typing import Dict, Any, List from zoneinfo import ZoneInfo -from discord import ApplicationContext, Attachment, SlashCommandGroup, option +from bson import ObjectId +from discord import ApplicationContext, Attachment, SlashCommandGroup, option, AutocompleteContext from discord.ext.commands import Cog +from discord.utils import basic_autocomplete from classes import PycordEvent, PycordGuild from classes.pycord_bot import PycordBot +from modules.database import col_events + + +async def get_event(ctx: AutocompleteContext): + query: Dict[str, Any] = {"ended": None, "ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}} + + event_names: List[str] = [] + + async for result in col_events.find(query): + event_names.append(result["name"]) + + return event_names + + +async def validate_event_validity( + ctx: ApplicationContext, + name: str, + start_date: datetime | None, + finish_date: datetime | None, + guild_timezone: ZoneInfo, + event_id: ObjectId | None = None, +) -> None: + if start_date > finish_date: + await ctx.respond("Start date must be before finish date") + return + elif start_date < datetime.now(tz=guild_timezone): + await ctx.respond("Start date must not be in the past") + return + + query: Dict[str, Any] = {"name": name, "ended": None, "ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))}} + + if event_id is not None: + query["_id"] = {"$ne": event_id} + + if (await col_events.find_one(query)) is not None: + await ctx.respond("There can only be one active event with the same name") + return class Event(Cog): @@ -27,12 +67,12 @@ class Event(Cog): @option("finish", description="Date when the event finishes (DD.MM.YYYY HH:MM)", required=True) @option("thumbnail", description="Thumbnail of the event", required=False) async def command_event_create( - self, - ctx: ApplicationContext, - name: str, - start: str, - finish: str, - thumbnail: Attachment = None, + self, + ctx: ApplicationContext, + name: str, + start: str, + finish: str, + thumbnail: Attachment = None, ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) @@ -52,12 +92,7 @@ class Event(Cog): await ctx.respond("Could not parse start and finish dates.") return - if start_date > finish_date: - await ctx.respond("Start date must be before finish date") - return - elif start_date < datetime.now(tz=guild_timezone): - await ctx.respond("Start date must not be in the past") - return + await validate_event_validity(ctx, name, start_date, finish_date, guild_timezone) event: PycordEvent = await self.bot.create_event( name=name, @@ -75,21 +110,26 @@ class Event(Cog): name="edit", description="Edit event", ) - @option("event", description="Name of the event", required=True) + @option( + "event", description="Name of the event", autocomplete=basic_autocomplete(get_event), required=True + ) @option("name", description="New name of the event", required=False) @option("start", description="Date when the event starts (DD.MM.YYYY HH:MM)", required=False) @option("finish", description="Date when the event finishes (DD.MM.YYYY HH:MM)", required=False) @option("thumbnail", description="Thumbnail of the event", required=False) async def command_event_edit( - self, - ctx: ApplicationContext, - event: str, - name: str, - start: str, - finish: str, - thumbnail: Attachment = None, + self, + ctx: ApplicationContext, + event: str, + name: str = None, + start: str = None, + finish: str = None, + thumbnail: Attachment = None, ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) + event: PycordEvent = await self.bot.find_event( + event_name=name + ) if not guild.is_configured(): await ctx.respond("Guild is not configured.") @@ -97,7 +137,27 @@ class Event(Cog): guild_timezone: ZoneInfo = ZoneInfo(guild.timezone) - await ctx.respond("Not implemented.") + if start is not None: + try: + start_date: datetime = datetime.strptime(start, "%d.%m.%Y %H:%M") + start_date = start_date.replace(tzinfo=guild_timezone) + + await event.set_start_date(start_date) + except ValueError: + await ctx.respond("Could not parse the start date.") + return + + if finish is not None: + try: + finish_date: datetime = datetime.strptime(finish, "%d.%m.%Y %H:%M") + finish_date = finish_date.replace(tzinfo=guild_timezone) + + await event.set_end_date(finish_date) + except ValueError: + await ctx.respond("Could not parse the finish date.") + return + + await validate_event_validity(ctx, name, start_date, finish_date, guild_timezone) # TODO Implement the command @command_group.command( @@ -106,9 +166,9 @@ class Event(Cog): ) @option("name", description="Name of the event", required=True) async def command_event_cancel( - self, - ctx: ApplicationContext, - name: str, + self, + ctx: ApplicationContext, + name: str, ) -> None: guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) diff --git a/cogs/stage.py b/cogs/stage.py index 917f3ba..05cf505 100644 --- a/cogs/stage.py +++ b/cogs/stage.py @@ -9,6 +9,8 @@ class Stage(Cog): def __init__(self, bot: PycordBot): self.bot: PycordBot = bot + # command_group: SlashCommandGroup = SlashCommandGroup("stage", "Event stage management") + def setup(bot: PycordBot) -> None: bot.add_cog(Stage(bot))