# File: QuizBot/classes/pycord_bot.py
# (listing metadata: 267 lines, 10 KiB, Python)
from datetime import datetime
from logging import Logger
from typing import Any, override, List, Dict
from zoneinfo import ZoneInfo
from bson import ObjectId
from discord import Guild, User, TextChannel
from libbot.cache.classes import CacheMemcached, CacheRedis
from libbot.cache.manager import create_cache_client
from libbot.pycord.classes import PycordBot as LibPycordBot
from classes import PycordEvent, PycordEventStage, PycordGuild, PycordUser
from modules.database import col_users, col_events
from modules.utils import get_logger
# Module-level logger, obtained from the project's get_logger helper for this module's __name__
logger: Logger = get_logger(__name__)
class PycordBot(LibPycordBot):
    """QuizBot's Pycord client built on top of libbot's ``PycordBot``.

    Adds an optional cache client, a locale-normalizing string getter and a
    scheduled "event controller" job that starts and ends quiz events.
    """

    cache: CacheMemcached | CacheRedis | None = None

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self._set_cache_engine()

        # This replacement exists because of the different i18n formats than
        # provided by libbot. It does not depend on the scheduler, so it is
        # applied even when no scheduler has been configured.
        self._ = self._modified_string_getter

    def _set_cache_engine(self) -> None:
        """Create the cache client when the config defines a cache type.

        ``self.cache`` stays ``None`` when the ``cache`` section is missing or
        empty, or when its ``type`` is missing or ``None``.
        """
        cache_config: Dict[str, Any] | None = self.config.get("cache")

        if not cache_config or cache_config.get("type") is None:
            return

        self.cache = create_cache_client(self.config, cache_config["type"])

    def _modified_string_getter(self, key: str, *args: str, locale: str | None = None) -> Any:
        """Get a localized string, reducing complex locale codes to their first part.

        This method exists because of the different i18n formats than provided by libbot.
        The locale is split on "-" and only the first part is passed on, so that
        codes like "en-US" work with libbot's simpler approach to i18n.

        Args:
            key (str): Key of the string to get.
            *args (str): Extra positional arguments, passed to libbot's getter as-is.
            locale (str | None): Locale code such as "uk" or "en-US". ``None`` is passed through unchanged.

        Returns:
            Any: Whatever libbot's string getter returns.
        """
        if locale is not None:
            locale = locale.split("-")[0]

        return self.bot_locale._(key, *args, locale=locale)

    @override
    async def start(self, *args: Any, **kwargs: Any) -> None:
        """Schedule the bot's periodic tasks, then start the client."""
        await self._schedule_tasks()
        await super().start(*args, **kwargs)

    @override
    async def close(self, **kwargs) -> None:
        """Close the client. Currently this only delegates to libbot's implementation."""
        await super().close(**kwargs)

    async def _schedule_tasks(self) -> None:
        """Register the event controller job to run every minute.

        When no scheduler is configured, a warning is logged and nothing is
        scheduled, instead of failing on ``None.add_job``.
        """
        if self.scheduler is None:
            logger.warning(
                "Scheduler is not configured, events will not be started or ended automatically"
            )
            return

        # replace_existing keeps a repeated start() from failing on the already registered job ID
        self.scheduler.add_job(
            self._execute_event_controller,
            trigger="cron",
            minute="*/1",
            id="event_controller",
            replace_existing=True,
        )

    async def _execute_event_controller(self) -> None:
        """Start and end the events scheduled for the current minute."""
        await self._process_events_start()
        await self._process_events_end()
        # await self._process_events_post_end()

    @staticmethod
    def _get_current_minute() -> datetime:
        """Return the current UTC time truncated to the whole minute.

        Events are looked up by exact equality on their start/end time, so
        seconds and microseconds are dropped. NOTE(review): this assumes stored
        event times are themselves minute-aligned UTC values — confirm.
        """
        return datetime.now(tz=ZoneInfo("UTC")).replace(second=0, microsecond=0)

    def _get_event_guild(self, event: PycordEvent) -> Guild | None:
        """Return the Discord guild of an event, or ``None`` (logged) when it is unavailable."""
        guild: Guild | None = self.get_guild(event.guild_id)

        if guild is None:
            logger.error(
                "Discord guild with ID %s of the event with ID %s could not be found!",
                event.guild_id,
                event._id,
            )

        return guild

    @staticmethod
    def _get_user_event_channel(
        guild: Guild, user: PycordUser, event: PycordEvent
    ) -> TextChannel | None:
        """Return a participant's channel for the event, or ``None`` (logged) when it is missing.

        The channel ID is read from ``user.event_channels`` under the event ID as a string.
        """
        channel_id: int | None = user.event_channels.get(str(event._id))

        if channel_id is None:
            logger.error(
                "A participant of the event with ID %s in guild with ID %s has no event channel!",
                event._id,
                guild.id,
            )
            return None

        user_channel: TextChannel | None = guild.get_channel(channel_id)

        if user_channel is None:
            logger.error(
                "Discord channel with ID %s in guild with ID %s could not be found!",
                channel_id,
                guild.id,
            )

        return user_channel

    async def _process_events_start(self) -> None:
        """Start every event whose start time equals the current minute.

        Each event is processed independently. A failure is logged and does
        not keep the remaining events from starting.
        """
        events: List[PycordEvent] = await self._get_events({"starts": self._get_current_minute()})

        for event in events:
            try:
                await self._start_event(event)
            except Exception:
                logger.exception("Could not process the start of the event with ID %s", event._id)

    async def _start_event(self, event: PycordEvent) -> None:
        """Set up participants' channels for an event and announce its start."""
        guild: Guild | None = self._get_event_guild(event)

        if guild is None:
            return

        pycord_guild: PycordGuild = await self.find_guild(guild)

        # Get list of participants
        users: List[PycordUser] = await self._get_event_participants(event._id)

        for user in users:
            # Create a channel for each participant
            await user.setup_event_channel(self, guild, pycord_guild, event, cache=self.cache)

            # Send a notification about event start
            user_channel: TextChannel | None = self._get_user_event_channel(guild, user, event)

            if user_channel is None:
                continue

            # TODO Make a nice message
            # TODO Also send a thumbnail, event info and short explanation on how to play
            await user_channel.send(f"Event **{event.name}** is starting!")

        # TODO Make a nice message
        await self._notify_admins(
            guild,
            pycord_guild,
            f"Event **{event.name}** has started! Users have gotten their channels and can already start submitting their answers.",
        )

    async def _process_events_end(self) -> None:
        """End every event whose end time equals the current minute.

        Each event is processed independently. A failure is logged and does
        not keep the remaining events from ending.
        """
        events: List[PycordEvent] = await self._get_events({"ends": self._get_current_minute()})

        for event in events:
            try:
                await self._end_event(event)
            except Exception:
                logger.exception("Could not process the end of the event with ID %s", event._id)

    async def _end_event(self, event: PycordEvent) -> None:
        """Announce the end of an event to its participants and lock their channels."""
        guild: Guild | None = self._get_event_guild(event)

        if guild is None:
            return

        pycord_guild: PycordGuild = await self.find_guild(guild)

        # Get list of participants
        users: List[PycordUser] = await self._get_event_participants(event._id)

        for user in users:
            # Send a notification about event end
            user_channel: TextChannel | None = self._get_user_event_channel(guild, user, event)

            if user_channel is None:
                continue

            # TODO Make a nice message
            # TODO Reveal answers to stages
            await user_channel.send(f"Event **{event.name}** has ended!")

            # Lock each participant out
            await user.lock_event_channel(guild, event._id, channel=user_channel)

        # TODO Make a nice message
        await self._notify_admins(
            guild,
            pycord_guild,
            f"Event **{event.name}** has ended! Users can no longer submit their answers.",
        )

    # NOTE: re-enabling this requires importing timedelta from datetime
    # async def _process_events_post_end(self) -> None:
    #     # Get events that ended an hour ago
    #     # TODO Replace with 1 hour after testing!
    #     events: List[PycordEvent] = await self._get_events(
    #         {
    #             "ends": datetime.now(tz=ZoneInfo("UTC")).replace(second=0, microsecond=0)
    #             - timedelta(minutes=1)
    #         }
    #     )
    #
    #     # Process each event
    #     for event in events:
    #         guild: Guild = self.get_guild(event.guild_id)
    #         pycord_guild: PycordGuild = await self.find_guild(guild)
    #
    #         # Get list of participants
    #         users: List[PycordUser] = await self._get_event_participants(event._id)
    #
    #         for user in users:
    #             # Get the participant's event channel
    #             user_channel: TextChannel = guild.get_channel(user.event_channels[str(event._id)])
    #
    #             # Remove their view permissions
    #             await user.lock_event_channel(guild, event._id, completely=True, channel=user_channel)
    #
    #         await self._notify_admins(
    #             guild,
    #             pycord_guild,
    #             f"Access has been updated, users can no longer access their channels for the event **{event.name}**.",
    #         )

    @staticmethod
    async def _get_events(query: Dict[str, Any]) -> List[PycordEvent]:
        """Return all events from ``col_events`` that match the given query."""
        return [PycordEvent(**event_entry) async for event_entry in col_events.find(query)]

    @staticmethod
    async def _get_event_participants(event_id: str | ObjectId) -> List[PycordUser]:
        """Return all users from ``col_users`` whose ``registered_event_ids`` match the event ID."""
        return [
            PycordUser(**user_entry)
            async for user_entry in col_users.find({"registered_event_ids": event_id})
        ]

    @staticmethod
    async def _notify_admins(guild: Guild, pycord_guild: PycordGuild, message: str) -> None:
        """Send a message to the guild's management channel.

        If the channel cannot be found, an error is logged and nothing is sent.
        """
        management_channel: TextChannel | None = guild.get_channel(pycord_guild.channel_id)

        if management_channel is None:
            logger.error(
                "Discord channel with ID %s in guild with ID %s could not be found!",
                pycord_guild.channel_id,
                guild.id,
            )
            return

        await management_channel.send(message)

    async def find_user(self, user: int | User) -> PycordUser:
        """Find User by its ID or User object.

        Args:
            user (int | User): ID or User object to extract ID from

        Returns:
            PycordUser: User object

        Raises:
            UserNotFoundException: User was not found and creation was not allowed
        """
        user_id: int = user if isinstance(user, int) else user.id

        return await PycordUser.from_id(user_id, cache=self.cache)

    async def find_guild(self, guild: int | Guild) -> PycordGuild:
        """Find Guild by its ID or Guild object.

        Args:
            guild (int | Guild): ID or Guild object to extract ID from

        Returns:
            PycordGuild: Guild object

        Raises:
            GuildNotFoundException: Guild was not found and creation was not allowed
        """
        guild_id: int = guild if isinstance(guild, int) else guild.id

        return await PycordGuild.from_id(guild_id, cache=self.cache)

    async def create_event(self, **kwargs) -> PycordEvent:
        """Create a new event.

        Args:
            **kwargs: Event fields, passed to ``PycordEvent.create`` together with the bot's cache.

        Returns:
            PycordEvent: The created event.
        """
        return await PycordEvent.create(**kwargs, cache=self.cache)

    async def create_event_stage(self, event: PycordEvent, **kwargs) -> PycordEventStage:
        """Create a new stage and register it in the given event.

        Args:
            event (PycordEvent): Event the stage belongs to.
            **kwargs: Stage fields, passed to ``PycordEventStage.create``. Must include ``sequence``.

        Returns:
            PycordEventStage: The created stage.

        Raises:
            RuntimeError: ``sequence`` was not provided.
        """
        # TODO Validation is handled by the caller for now, but
        # ideally this should not be the case at all.
        #
        # if "event_id" not in kwargs:
        #     # TODO Create a nicer exception
        #     raise RuntimeError("Event ID must be provided while creating an event stage")
        #
        # event: PycordEvent = await self.find_event(event_id=kwargs["event_id"])

        if "sequence" not in kwargs:
            # TODO Create a nicer exception
            raise RuntimeError("Stage must have a defined sequence")

        event_stage: PycordEventStage = await PycordEventStage.create(**kwargs, cache=self.cache)

        await event.insert_stage(self, event_stage._id, kwargs["sequence"], cache=self.cache)

        return event_stage

    async def find_event(
        self, event_id: str | ObjectId | None = None, event_name: str | None = None
    ) -> PycordEvent:
        """Find an event by its ID or name.

        If both are provided, the ID takes precedence.

        Args:
            event_id (str | ObjectId | None): ID of the event.
            event_name (str | None): Name of the event.

        Returns:
            PycordEvent: The found event.

        Raises:
            AttributeError: Neither ``event_id`` nor ``event_name`` was provided.
        """
        if event_id is None and event_name is None:
            raise AttributeError("Either event's ID or name must be provided!")

        if event_id is not None:
            return await PycordEvent.from_id(event_id, cache=self.cache)

        return await PycordEvent.from_name(event_name, cache=self.cache)

    async def find_event_stage(self, stage_id: str | ObjectId) -> PycordEventStage:
        """Find an event stage by its ID.

        Args:
            stage_id (str | ObjectId): ID of the stage.

        Returns:
            PycordEventStage: The found stage.
        """
        return await PycordEventStage.from_id(stage_id, cache=self.cache)