diff --git a/classes/pycord_bot.py b/classes/pycord_bot.py index 02d6718..fd74200 100644 --- a/classes/pycord_bot.py +++ b/classes/pycord_bot.py @@ -1,13 +1,16 @@ +from datetime import datetime from logging import Logger -from typing import Any, override +from typing import Any, override, List, Dict +from zoneinfo import ZoneInfo from bson import ObjectId -from discord import Guild, User +from discord import Guild, User, TextChannel from libbot.cache.classes import CacheMemcached, CacheRedis from libbot.cache.manager import create_cache_client from libbot.pycord.classes import PycordBot as LibPycordBot from classes import PycordEvent, PycordEventStage, PycordGuild, PycordUser +from modules.database import col_users, col_events from modules.utils import get_logger logger: Logger = get_logger(__name__) @@ -39,18 +42,152 @@ class PycordBot(LibPycordBot): """ return self.bot_locale._(key, *args, locale=None if locale is None else locale.split("-")[0]) - def _set_cache_engine(self) -> None: - if "cache" in self.config and self.config["cache"]["type"] is not None: - self.cache = create_cache_client(self.config, self.config["cache"]["type"]) - @override async def start(self, *args: Any, **kwargs: Any) -> None: + await self._schedule_tasks() await super().start(*args, **kwargs) @override async def close(self, **kwargs) -> None: await super().close(**kwargs) + async def _schedule_tasks(self) -> None: + self.scheduler.add_job( + self._execute_event_controller, trigger="cron", minute="*/1", id="event_controller" + ) + + async def _execute_event_controller(self) -> None: + await self._process_events_start() + await self._process_events_end() + # await self._process_events_post_end() + + async def _process_events_start(self) -> None: + # Get events to start + events: List[PycordEvent] = await self._get_events( + {"starts": datetime.now(tz=ZoneInfo("UTC")).replace(second=0, microsecond=0)} + ) + + # Process each event + for event in events: + guild: Guild = self.get_guild(event.guild_id) + pycord_guild: PycordGuild = await self.find_guild(guild) + + # Get list of participants + users: List[PycordUser] = await self._get_event_participants(event._id) + + for user in users: + # Create a channel for each participant + await user.setup_event_channel(self, guild, pycord_guild, event, cache=self.cache) + + # Send a notification about event start + user_channel: TextChannel = guild.get_channel(user.event_channels[str(event._id)]) + + # TODO Make a nice message + # TODO Also send a thumbnail, event info and short explanation on how to play + await user_channel.send(f"Event **{event.name}** is starting!") + + # TODO Make a nice message + await self._notify_admins( + guild, + pycord_guild, + f"Event **{event.name}** has started! Users have gotten their channels and can already start submitting their answers.", + ) + + async def _process_events_end(self) -> None: + # Get events to end + events: List[PycordEvent] = await self._get_events( + {"ends": datetime.now(tz=ZoneInfo("UTC")).replace(second=0, microsecond=0)} + ) + + # Process each event + for event in events: + guild: Guild = self.get_guild(event.guild_id) + pycord_guild: PycordGuild = await self.find_guild(guild) + + # Get list of participants + users: List[PycordUser] = await self._get_event_participants(event._id) + + for user in users: + # Send a notification about event start + user_channel: TextChannel = guild.get_channel(user.event_channels[str(event._id)]) + + # TODO Make a nice message + # TODO Reveal answers to stages + await user_channel.send(f"Event **{event.name}** has ended!") + + # Lock each participant out + await user.lock_event_channel(guild, event._id, channel=user_channel) + + # TODO Make a nice message + await self._notify_admins( + guild, + pycord_guild, + f"Event **{event.name}** has ended! Users can no longer submit their answers.", + ) + + # async def _process_events_post_end(self) -> None: + # # Get events that ended an hour ago + # # TODO Replace with 1 hour after testing! + # events: List[PycordEvent] = await self._get_events( + # { + # "ends": datetime.now(tz=ZoneInfo("UTC")).replace(second=0, microsecond=0) + # - timedelta(minutes=1) + # } + # ) + # + # # Process each event + # for event in events: + # guild: Guild = self.get_guild(event.guild_id) + # pycord_guild: PycordGuild = await self.find_guild(guild) + # + # # Get list of participants + # users: List[PycordUser] = await self._get_event_participants(event._id) + # + # for user in users: + # # Send a notification about event start + # user_channel: TextChannel = guild.get_channel(user.event_channels[str(event._id)]) + # + # # Remove their view permissions + # await user.lock_event_channel(guild, event._id, completely=True, channel=user_channel) + # + # await self._notify_admins( + # guild, + # pycord_guild, + # f"Access has been updated, users can no longer access their channels for the event **{event.name}**.", + # ) + + @staticmethod + async def _get_events(query: Dict[str, Any]) -> List[PycordEvent]: + events: List[PycordEvent] = [] + + async for event_entry in col_events.find(query): + events.append(PycordEvent(**event_entry)) + + return events + + @staticmethod + async def _get_event_participants(event_id: str | ObjectId) -> List[PycordUser]: + users: List[PycordUser] = [] + + async for user_entry in col_users.find({"registered_event_ids": event_id}): + users.append(PycordUser(**user_entry)) + + return users + + @staticmethod + async def _notify_admins(guild: Guild, pycord_guild: PycordGuild, message: str) -> None: + management_channel: TextChannel | None = guild.get_channel(pycord_guild.channel_id) + + if management_channel is None: + logger.error( + "Discord channel with ID %s in guild with ID %s could not be found!", + pycord_guild.channel_id, + guild.id, + ) + return + + await management_channel.send(message) + async def find_user(self, user: int | User) -> PycordUser: """Find User by its ID or User object. diff --git a/classes/pycord_user.py b/classes/pycord_user.py index 3c4c2be..6dd51a5 100644 --- a/classes/pycord_user.py +++ b/classes/pycord_user.py @@ -1,8 +1,11 @@ from dataclasses import dataclass +from hashlib import shake_256 from logging import Logger from typing import Any, Dict, List, Optional from bson import ObjectId +from discord import Bot, Guild, Member, PermissionOverwrite, TextChannel, Forbidden, Role +from discord.abc import GuildChannel from libbot.cache.classes import Cache from pymongo.results import InsertOneResult @@ -21,7 +24,7 @@ class PycordUser: "_id", "id", "guild_id", - "channel_id", + "event_channels", "is_jailed", "current_event_id", "current_stage_id", @@ -34,13 +37,18 @@ class PycordUser: _id: ObjectId id: int guild_id: int - channel_id: int | None + event_channels: Dict[str, int] is_jailed: bool current_event_id: ObjectId | None current_stage_id: ObjectId | None registered_event_ids: List[ObjectId] completed_event_ids: List[ObjectId] + # TODO Review the redesign + # event_channel_ids: { + # "%event_id%": %channel_id% + # } + @classmethod async def from_id( cls, user_id: int, allow_creation: bool = True, cache: Optional[Cache] = None @@ -93,7 +101,7 @@ class PycordUser: "_id": self._id if not json_compatible else str(self._id), "id": self.id, "guild_id": self.guild_id, - "channel_id": self.channel_id, + "event_channels": self.event_channels, "is_jailed": self.is_jailed, "current_event_id": ( self.current_event_id if not json_compatible else str(self.current_event_id) @@ -182,7 +190,7 @@ class PycordUser: return { "id": user_id, "guild_id": guild_id, - "channel_id": None, + "event_channels": {}, "is_jailed": False, "current_event_id": None, "current_stage_id": None, @@ -247,3 +255,109 @@ class PycordUser: self.completed_event_ids.append(event_id) await self._set(cache, completed_event_ids=self.completed_event_ids) + + async def setup_event_channel( + self, + bot: Bot, + guild: Guild, + pycord_guild: "PycordGuild", + pycord_event: "PycordEvent", + cache: Optional[Cache] = None, + ): + if str(pycord_event._id) in self.event_channels.keys(): + return + + discord_member: Member | None = guild.get_member(self.id) + discord_category: GuildChannel | None = bot.get_channel(pycord_guild.category_id) + + if discord_member is None: + raise RuntimeError( + f"Discord guild member with ID {self.id} in guild with ID {guild.id} could not be found!" + ) + + # TODO Add a unique exception + # raise DiscordGuildMemberNotFoundError(self.id, guild.id) + + if discord_category is None: + raise RuntimeError( + f"Discord category with ID {pycord_guild.category_id} in guild with ID {guild.id} could not be found!" + ) + + # TODO Add a unique exception + # raise DiscordCategoryNotFoundError(pycord_guild.category_id, guild.id) + + permission_overwrites: Dict[Role | Member, PermissionOverwrite] = { + guild.default_role: PermissionOverwrite( + view_channel=False, + ), + guild.self_role: PermissionOverwrite( + view_channel=True, + ), + discord_member: PermissionOverwrite( + view_channel=True, + send_messages=True, + use_application_commands=True, + ), + } + + channel: TextChannel = await guild.create_text_channel( + f"{discord_member.name}_{shake_256(str(pycord_event._id).encode()).hexdigest(3)}", + category=discord_category, + overwrites=permission_overwrites, + reason=f"Event channel of {self.id} for event {pycord_event._id}", + ) + + await self.set_event_channel(pycord_event._id, channel.id, cache=cache) + + async def lock_event_channel( + self, + guild: Guild, + event_id: str | ObjectId, + completely: bool = False, + channel: Optional[TextChannel] = None, + ): + discord_member: Member | None = guild.get_member(self.id) + discord_channel: TextChannel | None = ( + channel if channel is not None else guild.get_channel(self.event_channels[str(event_id)]) + ) + + if discord_member is None: + raise RuntimeError( + f"Discord guild member with ID {self.id} in guild with ID {guild.id} could not be found!" + ) + + # TODO Add a unique exception + # raise DiscordGuildMemberNotFoundError(self.id, guild.id) + + if discord_member is None: + raise RuntimeError( + f"Discord channel with ID {self.event_channels[str(event_id)]} in guild with ID {guild.id} could not be found!" + ) + + # TODO Add a unique exception + # raise DiscordChannelNotFoundError(self.event_channels[str(event_id)], guild.id) + + permission_overwrite: PermissionOverwrite = PermissionOverwrite( + view_channel=not completely, + send_messages=False, + use_application_commands=False, + ) + + try: + await discord_channel.set_permissions( + discord_member, overwrite=permission_overwrite, reason="Invoked from the user method" + ) + except Forbidden: + logger.error( + "Could not update channel permissions of %s for %s due to user having higher privileges.", + discord_channel.id, + self.id, + ) + + # TODO Add documentation + async def set_event_channel( + self, event_id: str | ObjectId, channel_id: int, cache: Optional[Cache] = None + ) -> None: + self.event_channels[event_id if isinstance(event_id, str) else str(event_id)] = channel_id + + await self._set(cache, event_channels=self.event_channels) diff --git a/main.py b/main.py index 494f84d..58f70ed 100644 --- a/main.py +++ b/main.py @@ -6,7 +6,7 @@ from os import makedirs from pathlib import Path from sys import exit -from discord import LoginFailure +from discord import LoginFailure, Intents from libbot.utils import config_get from classes.pycord_bot import PycordBot @@ -55,7 +55,11 @@ def main(): # downgrade_database() # exit() - bot: PycordBot = PycordBot(scheduler=scheduler) + intents = Intents.default() + + intents.members = True + + bot: PycordBot = PycordBot(scheduler=scheduler, intents=intents) bot.load_extension("cogs")