Files
QuizBot/classes/pycord_user.py

364 lines
13 KiB
Python

from dataclasses import dataclass
from hashlib import shake_256
from logging import Logger
from typing import Any, Dict, List, Optional
from bson import ObjectId
from discord import Bot, Guild, Member, PermissionOverwrite, TextChannel, Forbidden, Role
from discord.abc import GuildChannel
from libbot.cache.classes import Cache
from pymongo.results import InsertOneResult
from classes.errors.pycord_user import UserNotFoundError
from modules.database import col_users
from modules.utils import get_logger, restore_from_cache
# Module-level logger for this file, obtained from the project's get_logger helper and named after this module.
logger: Logger = get_logger(__name__)
@dataclass
class PycordUser:
    """Dataclass of DB entry of a user.

    Each instance mirrors one document of the users collection. Mutating helpers
    (``_set``, ``_remove`` and the public methods built on top of them) persist the
    change with ``update_one`` and, when a cache engine is supplied, refresh the cached
    copy stored under the key ``user_<discord id>``.
    """

    __slots__ = (
        "_id",
        "id",
        "guild_id",
        "event_channels",
        "is_jailed",
        "current_event_id",
        "current_stage_id",
        "registered_event_ids",
        "completed_event_ids",
    )
    __short_name__ = "user"
    __collection__ = col_users

    _id: ObjectId
    id: int
    guild_id: int
    # Maps str(event ObjectId) -> Discord ID of the user's private channel for that event
    event_channels: Dict[str, int]
    is_jailed: bool
    current_event_id: ObjectId | None
    current_stage_id: ObjectId | None
    registered_event_ids: List[ObjectId]
    completed_event_ids: List[ObjectId]

    # TODO Review the redesign
    # event_channel_ids: {
    #     "%event_id%": %channel_id%
    # }

    @classmethod
    async def from_id(
        cls, user_id: int, allow_creation: bool = True, cache: Optional[Cache] = None
    ) -> "PycordUser":
        """Find user in database and create new record if user does not exist.

        The cache (if given) is consulted first. On a cache miss the database is
        queried and, if no record exists and creation is allowed, a record filled with
        default values is inserted. The entry obtained from the database is then
        written into the cache before the object is built.

        Args:
            user_id (int): User's Discord ID
            allow_creation (:obj:`bool`, optional): Create new user record if none found in the database
            cache (:obj:`Cache`, optional): Cache engine to get the cache from

        Returns:
            PycordUser: User object

        Raises:
            UserNotFoundError: User was not found and creation was not allowed
        """
        cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, user_id, cache=cache)

        if cached_entry is not None:
            return cls(**cached_entry)

        db_entry: Dict[str, Any] | None = await cls.__collection__.find_one({"id": user_id})

        if db_entry is None:
            if not allow_creation:
                raise UserNotFoundError(user_id)

            # Fresh records only get default values; no guild ID is known here, so guild_id stays None
            db_entry = cls.get_defaults(user_id)

            insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry)

            db_entry["_id"] = insert_result.inserted_id

        if cache is not None:
            cache.set_json(f"{cls.__short_name__}_{user_id}", db_entry)

        return cls(**db_entry)

    def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
        """Convert PycordUser object to a JSON representation.

        Args:
            json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted

        Returns:
            Dict[str, Any]: JSON representation of PycordUser. With ``json_compatible`` set,
            ObjectIds become strings while ``None`` values stay ``None`` (JSON null).
        """

        def _convert_id(value: ObjectId | None) -> ObjectId | str | None:
            # str(None) would yield the literal "None", so missing IDs are passed through untouched
            if value is None or not json_compatible:
                return value
            return str(value)

        return {
            "_id": _convert_id(self._id),
            "id": self.id,
            "guild_id": self.guild_id,
            "event_channels": self.event_channels,
            "is_jailed": self.is_jailed,
            "current_event_id": _convert_id(self.current_event_id),
            "current_stage_id": _convert_id(self.current_stage_id),
            "registered_event_ids": (
                self.registered_event_ids
                if not json_compatible
                else [str(event_id) for event_id in self.registered_event_ids]
            ),
            "completed_event_ids": (
                self.completed_event_ids
                if not json_compatible
                else [str(event_id) for event_id in self.completed_event_ids]
            ),
        }

    async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
        """Set attribute data and save it into the database.

        Every key is validated before anything is assigned, so an unknown key leaves
        the object unchanged instead of half-updated.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
            **kwargs (Any): Mapping of attribute names and respective values to be set

        Raises:
            AttributeError: One of the keys is not an attribute of the user
        """
        for key in kwargs:
            if not hasattr(self, key):
                raise AttributeError(f"PycordUser has no attribute '{key}'")

        for key, value in kwargs.items():
            setattr(self, key, value)

        # upsert=True recreates the document if it no longer exists in the collection
        await self.__collection__.update_one({"_id": self._id}, {"$set": kwargs}, upsert=True)

        self._update_cache(cache)

        logger.info("Set attributes of user %s to %s", self.id, kwargs)

    async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
        """Remove attribute data and save it into the database.

        "Removing" resets each attribute to its value from ``get_defaults`` rather than
        deleting the field. All keys are validated before anything is assigned.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
            *args (str): List of attributes to remove

        Raises:
            AttributeError: One of the keys is not an attribute of the user
            KeyError: One of the keys has no default value
        """
        attributes: Dict[str, Any] = {}

        for key in args:
            if not hasattr(self, key):
                raise AttributeError(f"PycordUser has no attribute '{key}'")

            attributes[key] = self.get_default_value(key)

        for key, default_value in attributes.items():
            setattr(self, key, default_value)

        await self.__collection__.update_one({"_id": self._id}, {"$set": attributes}, upsert=True)

        self._update_cache(cache)

        logger.info("Reset attributes %s of user %s to default values", args, self.id)

    def _get_cache_key(self) -> str:
        """Return the cache key of this user, e.g. ``user_<discord id>``."""
        return f"{self.__short_name__}_{self.id}"

    def _update_cache(self, cache: Optional[Cache] = None) -> None:
        """Write the current (non JSON-converted) state of the user into the cache, if one is given."""
        if cache is None:
            return

        cache.set_json(self._get_cache_key(), self.to_dict())

    def _delete_cache(self, cache: Optional[Cache] = None) -> None:
        """Delete this user's entry from the cache, if one is given."""
        if cache is None:
            return

        cache.delete(self._get_cache_key())

    @staticmethod
    def get_defaults(user_id: Optional[int] = None, guild_id: Optional[int] = None) -> Dict[str, Any]:
        """Return a fresh dict of default field values for a user record.

        ``_id`` is not included. New containers are created on every call, so the
        returned dict can be mutated safely.
        """
        return {
            "id": user_id,
            "guild_id": guild_id,
            "event_channels": {},
            "is_jailed": False,
            "current_event_id": None,
            "current_stage_id": None,
            "registered_event_ids": [],
            "completed_event_ids": [],
        }

    @staticmethod
    def get_default_value(key: str) -> Any:
        """Return the default value of a single field.

        Raises:
            KeyError: ``key`` has no default value
        """
        defaults: Dict[str, Any] = PycordUser.get_defaults()

        if key not in defaults:
            raise KeyError(f"There's no default value for key '{key}' in PycordUser")

        return defaults[key]

    async def purge(self, cache: Optional[Cache] = None) -> None:
        """Completely remove user data from database. Currently only removes the user record from users collection.

        Args:
            cache (:obj:`Cache`, optional): Cache engine to write the update into
        """
        await self.__collection__.delete_one({"_id": self._id})
        self._delete_cache(cache)

    @staticmethod
    def _to_object_id(event_id: str | ObjectId) -> ObjectId:
        """Return ``event_id`` as an ObjectId, converting it when it is given as a string."""
        return ObjectId(event_id) if isinstance(event_id, str) else event_id

    async def event_register(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
        """Register the user for an event and persist the updated registration list.

        Args:
            event_id (str | ObjectId): Event to register for
            cache (:obj:`Cache`, optional): Cache engine to write the update into

        Raises:
            RuntimeError: The user is already registered for the event
        """
        object_id: ObjectId = self._to_object_id(event_id)

        if object_id in self.registered_event_ids:
            # TODO Add a unique exception
            # raise UserAlreadyRegisteredForEventError(event_name)
            raise RuntimeError(f"User is already registered for event {object_id}")

        self.registered_event_ids.append(object_id)

        await self._set(cache, registered_event_ids=self.registered_event_ids)

    async def event_unregister(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
        """Unregister the user from an event and persist the updated registration list.

        Args:
            event_id (str | ObjectId): Event to unregister from
            cache (:obj:`Cache`, optional): Cache engine to write the update into

        Raises:
            RuntimeError: The user is not registered for the event
        """
        object_id: ObjectId = self._to_object_id(event_id)

        if object_id not in self.registered_event_ids:
            # TODO Add a unique exception
            # raise UserNotRegisteredForEventError(event_name)
            raise RuntimeError(f"User is not registered for event {object_id}")

        self.registered_event_ids.remove(object_id)

        await self._set(cache, registered_event_ids=self.registered_event_ids)

    async def event_complete(self, event_id: str | ObjectId, cache: Optional[Cache] = None) -> None:
        """Mark an event as completed by the user and persist the updated completion list.

        Registration for the event is not checked or changed here.

        Args:
            event_id (str | ObjectId): Event that was completed
            cache (:obj:`Cache`, optional): Cache engine to write the update into

        Raises:
            RuntimeError: The user has already completed the event
        """
        object_id: ObjectId = self._to_object_id(event_id)

        if object_id in self.completed_event_ids:
            # TODO Add a unique exception
            # raise UserAlreadyCompletedEventError(event_name)
            raise RuntimeError(f"User has already completed event {object_id}")

        self.completed_event_ids.append(object_id)

        await self._set(cache, completed_event_ids=self.completed_event_ids)

    async def setup_event_channel(
        self,
        bot: Bot,
        guild: Guild,
        pycord_guild: "PycordGuild",
        pycord_event: "PycordEvent",
        cache: Optional[Cache] = None,
    ) -> None:
        """Create the user's private text channel for an event, unless one is already recorded.

        The channel is created in the guild's configured category and is hidden from
        ``@everyone``. The bot and the user can see it, and the user can also send
        messages and use application commands. Its name is the member's name plus
        6 hex characters derived from the event ID.

        Args:
            bot (Bot): Bot used to resolve the category channel
            guild (Guild): Guild to create the channel in
            pycord_guild (PycordGuild): Guild record providing ``category_id``
            pycord_event (PycordEvent): Event the channel is created for
            cache (:obj:`Cache`, optional): Cache engine to write the update into

        Raises:
            RuntimeError: The guild member or the category could not be found
        """
        if str(pycord_event._id) in self.event_channels:
            return

        discord_member: Member | None = guild.get_member(self.id)
        discord_category: GuildChannel | None = bot.get_channel(pycord_guild.category_id)

        if discord_member is None:
            # TODO Add a unique exception
            # raise DiscordGuildMemberNotFoundError(self.id, guild.id)
            raise RuntimeError(
                f"Discord guild member with ID {self.id} in guild with ID {guild.id} could not be found!"
            )

        if discord_category is None:
            # TODO Add a unique exception
            # raise DiscordCategoryNotFoundError(pycord_guild.category_id, guild.id)
            raise RuntimeError(
                f"Discord category with ID {pycord_guild.category_id} in guild with ID {guild.id} could not be found!"
            )

        # guild.self_role is Optional (the bot may have no managed role); a None key would break
        # channel creation, so fall back to the bot's own member in that case
        bot_target: Role | Member = guild.self_role if guild.self_role is not None else guild.me

        permission_overwrites: Dict[Role | Member, PermissionOverwrite] = {
            guild.default_role: PermissionOverwrite(
                view_channel=False,
            ),
            bot_target: PermissionOverwrite(
                view_channel=True,
            ),
            discord_member: PermissionOverwrite(
                view_channel=True,
                send_messages=True,
                use_application_commands=True,
            ),
        }

        channel: TextChannel = await guild.create_text_channel(
            f"{discord_member.name}_{shake_256(str(pycord_event._id).encode()).hexdigest(3)}",
            category=discord_category,
            overwrites=permission_overwrites,
            reason=f"Event channel of {self.id} for event {pycord_event._id}",
        )

        await self.set_event_channel(pycord_event._id, channel.id, cache=cache)

    async def lock_event_channel(
        self,
        guild: Guild,
        event_id: str | ObjectId,
        completely: bool = False,
        channel: Optional[TextChannel] = None,
    ) -> None:
        """Revoke the user's ability to write in their event channel.

        Args:
            guild (Guild): Guild the event channel belongs to
            event_id (str | ObjectId): Event whose channel should be locked
            completely (bool): Also hide the channel from the user instead of leaving it read-only
            channel (:obj:`TextChannel`, optional): Channel to lock; looked up via ``event_channels`` when omitted

        Raises:
            KeyError: ``channel`` was not given and no channel is recorded for the event
            RuntimeError: The guild member or the channel could not be found
        """
        discord_member: Member | None = guild.get_member(self.id)
        discord_channel: TextChannel | None = (
            channel if channel is not None else guild.get_channel(self.event_channels[str(event_id)])
        )

        if discord_member is None:
            # TODO Add a unique exception
            # raise DiscordGuildMemberNotFoundError(self.id, guild.id)
            raise RuntimeError(
                f"Discord guild member with ID {self.id} in guild with ID {guild.id} could not be found!"
            )

        # The channel must be checked here, otherwise set_permissions below would be called on None
        if discord_channel is None:
            # TODO Add a unique exception
            # raise DiscordChannelNotFoundError(self.event_channels[str(event_id)], guild.id)
            raise RuntimeError(
                f"Discord channel with ID {self.event_channels[str(event_id)]} in guild with ID {guild.id} could not be found!"
            )

        permission_overwrite: PermissionOverwrite = PermissionOverwrite(
            view_channel=not completely,
            send_messages=False,
            use_application_commands=False,
        )

        try:
            await discord_channel.set_permissions(
                discord_member, overwrite=permission_overwrite, reason="Invoked from the user method"
            )
        except Forbidden:
            # Best-effort: a permission failure is logged rather than propagated to the caller
            logger.error(
                "Could not update channel permissions of %s for %s due to user having higher privileges.",
                discord_channel.id,
                self.id,
            )

    async def set_event_channel(
        self, event_id: str | ObjectId, channel_id: int, cache: Optional[Cache] = None
    ) -> None:
        """Record the user's channel for an event and persist the updated mapping.

        Args:
            event_id (str | ObjectId): Event the channel belongs to; stored under its string form
            channel_id (int): Discord ID of the channel
            cache (:obj:`Cache`, optional): Cache engine to write the update into
        """
        self.event_channels[str(event_id)] = channel_id

        await self._set(cache, event_channels=self.event_channels)