@@ -1,13 +1,16 @@
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from typing import Any, override
|
||||
from typing import Any, override, List, Dict
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from bson import ObjectId
|
||||
from discord import Guild, User
|
||||
from discord import Guild, User, TextChannel
|
||||
from libbot.cache.classes import CacheMemcached, CacheRedis
|
||||
from libbot.cache.manager import create_cache_client
|
||||
from libbot.pycord.classes import PycordBot as LibPycordBot
|
||||
|
||||
from classes import PycordEvent, PycordEventStage, PycordGuild, PycordUser
|
||||
from modules.database import col_users, col_events
|
||||
from modules.utils import get_logger
|
||||
|
||||
logger: Logger = get_logger(__name__)
|
||||
@@ -39,18 +42,152 @@ class PycordBot(LibPycordBot):
|
||||
"""
|
||||
return self.bot_locale._(key, *args, locale=None if locale is None else locale.split("-")[0])
|
||||
|
||||
def _set_cache_engine(self) -> None:
|
||||
if "cache" in self.config and self.config["cache"]["type"] is not None:
|
||||
self.cache = create_cache_client(self.config, self.config["cache"]["type"])
|
||||
|
||||
@override
|
||||
async def start(self, *args: Any, **kwargs: Any) -> None:
|
||||
await self._schedule_tasks()
|
||||
await super().start(*args, **kwargs)
|
||||
|
||||
@override
|
||||
async def close(self, **kwargs) -> None:
|
||||
await super().close(**kwargs)
|
||||
|
||||
async def _schedule_tasks(self) -> None:
|
||||
self.scheduler.add_job(
|
||||
self._execute_event_controller, trigger="cron", minute="*/1", id="event_controller"
|
||||
)
|
||||
|
||||
async def _execute_event_controller(self) -> None:
|
||||
await self._process_events_start()
|
||||
await self._process_events_end()
|
||||
# await self._process_events_post_end()
|
||||
|
||||
async def _process_events_start(self) -> None:
|
||||
# Get events to start
|
||||
events: List[PycordEvent] = await self._get_events(
|
||||
{"starts": datetime.now(tz=ZoneInfo("UTC")).replace(second=0, microsecond=0)}
|
||||
)
|
||||
|
||||
# Process each event
|
||||
for event in events:
|
||||
guild: Guild = self.get_guild(event.guild_id)
|
||||
pycord_guild: PycordGuild = await self.find_guild(guild)
|
||||
|
||||
# Get list of participants
|
||||
users: List[PycordUser] = await self._get_event_participants(event._id)
|
||||
|
||||
for user in users:
|
||||
# Create a channel for each participant
|
||||
await user.setup_event_channel(self, guild, pycord_guild, event, cache=self.cache)
|
||||
|
||||
# Send a notification about event start
|
||||
user_channel: TextChannel = guild.get_channel(user.event_channels[str(event._id)])
|
||||
|
||||
# TODO Make a nice message
|
||||
# TODO Also send a thumbnail, event info and short explanation on how to play
|
||||
await user_channel.send(f"Event **{event.name}** is starting!")
|
||||
|
||||
# TODO Make a nice message
|
||||
await self._notify_admins(
|
||||
guild,
|
||||
pycord_guild,
|
||||
f"Event **{event.name}** has started! Users have gotten their channels and can already start submitting their answers.",
|
||||
)
|
||||
|
||||
async def _process_events_end(self) -> None:
|
||||
# Get events to end
|
||||
events: List[PycordEvent] = await self._get_events(
|
||||
{"ends": datetime.now(tz=ZoneInfo("UTC")).replace(second=0, microsecond=0)}
|
||||
)
|
||||
|
||||
# Process each event
|
||||
for event in events:
|
||||
guild: Guild = self.get_guild(event.guild_id)
|
||||
pycord_guild: PycordGuild = await self.find_guild(guild)
|
||||
|
||||
# Get list of participants
|
||||
users: List[PycordUser] = await self._get_event_participants(event._id)
|
||||
|
||||
for user in users:
|
||||
# Send a notification about event start
|
||||
user_channel: TextChannel = guild.get_channel(user.event_channels[str(event._id)])
|
||||
|
||||
# TODO Make a nice message
|
||||
# TODO Reveal answers to stages
|
||||
await user_channel.send(f"Event **{event.name}** has ended!")
|
||||
|
||||
# Lock each participant out
|
||||
await user.lock_event_channel(guild, event._id, channel=user_channel)
|
||||
|
||||
# TODO Make a nice message
|
||||
await self._notify_admins(
|
||||
guild,
|
||||
pycord_guild,
|
||||
f"Event **{event.name}** has ended! Users can no longer submit their answers.",
|
||||
)
|
||||
|
||||
# async def _process_events_post_end(self) -> None:
|
||||
# # Get events that ended an hour ago
|
||||
# # TODO Replace with 1 hour after testing!
|
||||
# events: List[PycordEvent] = await self._get_events(
|
||||
# {
|
||||
# "ends": datetime.now(tz=ZoneInfo("UTC")).replace(second=0, microsecond=0)
|
||||
# - timedelta(minutes=1)
|
||||
# }
|
||||
# )
|
||||
#
|
||||
# # Process each event
|
||||
# for event in events:
|
||||
# guild: Guild = self.get_guild(event.guild_id)
|
||||
# pycord_guild: PycordGuild = await self.find_guild(guild)
|
||||
#
|
||||
# # Get list of participants
|
||||
# users: List[PycordUser] = await self._get_event_participants(event._id)
|
||||
#
|
||||
# for user in users:
|
||||
# # Send a notification about event start
|
||||
# user_channel: TextChannel = guild.get_channel(user.event_channels[str(event._id)])
|
||||
#
|
||||
# # Remove their view permissions
|
||||
# await user.lock_event_channel(guild, event._id, completely=True, channel=user_channel)
|
||||
#
|
||||
# await self._notify_admins(
|
||||
# guild,
|
||||
# pycord_guild,
|
||||
# f"Access has been updated, users can no longer access their channels for the event **{event.name}**.",
|
||||
# )
|
||||
|
||||
@staticmethod
|
||||
async def _get_events(query: Dict[str, Any]) -> List[PycordEvent]:
|
||||
events: List[PycordEvent] = []
|
||||
|
||||
async for event_entry in col_events.find(query):
|
||||
events.append(PycordEvent(**event_entry))
|
||||
|
||||
return events
|
||||
|
||||
@staticmethod
|
||||
async def _get_event_participants(event_id: str | ObjectId) -> List[PycordUser]:
|
||||
users: List[PycordUser] = []
|
||||
|
||||
async for user_entry in col_users.find({"registered_event_ids": event_id}):
|
||||
users.append(PycordUser(**user_entry))
|
||||
|
||||
return users
|
||||
|
||||
@staticmethod
|
||||
async def _notify_admins(guild: Guild, pycord_guild: PycordGuild, message: str) -> None:
|
||||
management_channel: TextChannel | None = guild.get_channel(pycord_guild.channel_id)
|
||||
|
||||
if management_channel is None:
|
||||
logger.error(
|
||||
"Discord channel with ID %s in guild with ID %s could not be found!",
|
||||
pycord_guild.channel_id,
|
||||
guild.id,
|
||||
)
|
||||
return
|
||||
|
||||
await management_channel.send(message)
|
||||
|
||||
async def find_user(self, user: int | User) -> PycordUser:
|
||||
"""Find User by its ID or User object.
|
||||
|
||||
|
@@ -1,8 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from hashlib import shake_256
|
||||
from logging import Logger
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
from discord import Bot, Guild, Member, PermissionOverwrite, TextChannel, Forbidden, Role
|
||||
from discord.abc import GuildChannel
|
||||
from libbot.cache.classes import Cache
|
||||
from pymongo.results import InsertOneResult
|
||||
|
||||
@@ -21,7 +24,7 @@ class PycordUser:
|
||||
"_id",
|
||||
"id",
|
||||
"guild_id",
|
||||
"channel_id",
|
||||
"event_channels",
|
||||
"is_jailed",
|
||||
"current_event_id",
|
||||
"current_stage_id",
|
||||
@@ -34,13 +37,18 @@ class PycordUser:
|
||||
_id: ObjectId
|
||||
id: int
|
||||
guild_id: int
|
||||
channel_id: int | None
|
||||
event_channels: Dict[str, int]
|
||||
is_jailed: bool
|
||||
current_event_id: ObjectId | None
|
||||
current_stage_id: ObjectId | None
|
||||
registered_event_ids: List[ObjectId]
|
||||
completed_event_ids: List[ObjectId]
|
||||
|
||||
# TODO Review the redesign
|
||||
# event_channel_ids: {
|
||||
# "%event_id%": %channel_id%
|
||||
# }
|
||||
|
||||
@classmethod
|
||||
async def from_id(
|
||||
cls, user_id: int, allow_creation: bool = True, cache: Optional[Cache] = None
|
||||
@@ -93,7 +101,7 @@ class PycordUser:
|
||||
"_id": self._id if not json_compatible else str(self._id),
|
||||
"id": self.id,
|
||||
"guild_id": self.guild_id,
|
||||
"channel_id": self.channel_id,
|
||||
"event_channels": self.event_channels,
|
||||
"is_jailed": self.is_jailed,
|
||||
"current_event_id": (
|
||||
self.current_event_id if not json_compatible else str(self.current_event_id)
|
||||
@@ -182,7 +190,7 @@ class PycordUser:
|
||||
return {
|
||||
"id": user_id,
|
||||
"guild_id": guild_id,
|
||||
"channel_id": None,
|
||||
"event_channels": {},
|
||||
"is_jailed": False,
|
||||
"current_event_id": None,
|
||||
"current_stage_id": None,
|
||||
@@ -247,3 +255,109 @@ class PycordUser:
|
||||
self.completed_event_ids.append(event_id)
|
||||
|
||||
await self._set(cache, completed_event_ids=self.completed_event_ids)
|
||||
|
||||
async def setup_event_channel(
|
||||
self,
|
||||
bot: Bot,
|
||||
guild: Guild,
|
||||
pycord_guild: "PycordGuild",
|
||||
pycord_event: "PycordEvent",
|
||||
cache: Optional[Cache] = None,
|
||||
):
|
||||
if str(pycord_event._id) in self.event_channels.keys():
|
||||
return
|
||||
|
||||
discord_member: Member | None = guild.get_member(self.id)
|
||||
discord_category: GuildChannel | None = bot.get_channel(pycord_guild.category_id)
|
||||
|
||||
if discord_member is None:
|
||||
raise RuntimeError(
|
||||
f"Discord guild member with ID {self.id} in guild with ID {guild.id} could not be found!"
|
||||
)
|
||||
|
||||
# TODO Add a unique exception
|
||||
# raise DiscordGuildMemberNotFoundError(self.id, guild.id)
|
||||
|
||||
if discord_category is None:
|
||||
raise RuntimeError(
|
||||
f"Discord category with ID {pycord_guild.category_id} in guild with ID {guild.id} could not be found!"
|
||||
)
|
||||
|
||||
# TODO Add a unique exception
|
||||
# raise DiscordCategoryNotFoundError(pycord_guild.category_id, guild.id)
|
||||
|
||||
permission_overwrites: Dict[Role | Member, PermissionOverwrite] = {
|
||||
guild.default_role: PermissionOverwrite(
|
||||
view_channel=False,
|
||||
),
|
||||
guild.self_role: PermissionOverwrite(
|
||||
view_channel=True,
|
||||
),
|
||||
discord_member: PermissionOverwrite(
|
||||
view_channel=True,
|
||||
send_messages=True,
|
||||
use_application_commands=True,
|
||||
),
|
||||
}
|
||||
|
||||
channel: TextChannel = await guild.create_text_channel(
|
||||
f"{discord_member.name}_{shake_256(str(pycord_event._id).encode()).hexdigest(3)}",
|
||||
category=discord_category,
|
||||
overwrites=permission_overwrites,
|
||||
reason=f"Event channel of {self.id} for event {pycord_event._id}",
|
||||
)
|
||||
|
||||
await self.set_event_channel(pycord_event._id, channel.id, cache=cache)
|
||||
|
||||
async def lock_event_channel(
|
||||
self,
|
||||
guild: Guild,
|
||||
event_id: str | ObjectId,
|
||||
completely: bool = False,
|
||||
channel: Optional[TextChannel] = None,
|
||||
):
|
||||
discord_member: Member | None = guild.get_member(self.id)
|
||||
discord_channel: TextChannel | None = (
|
||||
channel if channel is not None else guild.get_channel(self.event_channels[str(event_id)])
|
||||
)
|
||||
|
||||
if discord_member is None:
|
||||
raise RuntimeError(
|
||||
f"Discord guild member with ID {self.id} in guild with ID {guild.id} could not be found!"
|
||||
)
|
||||
|
||||
# TODO Add a unique exception
|
||||
# raise DiscordGuildMemberNotFoundError(self.id, guild.id)
|
||||
|
||||
if discord_member is None:
|
||||
raise RuntimeError(
|
||||
f"Discord channel with ID {self.event_channels[str(event_id)]} in guild with ID {guild.id} could not be found!"
|
||||
)
|
||||
|
||||
# TODO Add a unique exception
|
||||
# raise DiscordChannelNotFoundError(self.event_channels[str(event_id)], guild.id)
|
||||
|
||||
permission_overwrite: PermissionOverwrite = PermissionOverwrite(
|
||||
view_channel=not completely,
|
||||
send_messages=False,
|
||||
use_application_commands=False,
|
||||
)
|
||||
|
||||
try:
|
||||
await discord_channel.set_permissions(
|
||||
discord_member, overwrite=permission_overwrite, reason="Invoked from the user method"
|
||||
)
|
||||
except Forbidden:
|
||||
logger.error(
|
||||
"Could not update channel permissions of %s for %s due to user having higher privileges.",
|
||||
discord_channel.id,
|
||||
self.id,
|
||||
)
|
||||
|
||||
# TODO Add documentation
|
||||
async def set_event_channel(
|
||||
self, event_id: str | ObjectId, channel_id: int, cache: Optional[Cache] = None
|
||||
) -> None:
|
||||
self.event_channels[event_id if isinstance(event_id, str) else str(event_id)] = channel_id
|
||||
|
||||
await self._set(cache, event_channels=self.event_channels)
|
||||
|
Reference in New Issue
Block a user