from dataclasses import dataclass
from logging import Logger
from typing import Any, Dict, Optional

from bson import ObjectId
from libbot.cache.classes import Cache
from pymongo.results import InsertOneResult

from classes.base.base_cacheable import BaseCacheable
from classes.errors import GuildNotFoundError
from modules.database import col_guilds
from modules.utils import get_logger, restore_from_cache

logger: Logger = get_logger(__name__)


@dataclass
class PycordGuild(BaseCacheable):
    """Dataclass of DB entry of a guild.

    Each instance mirrors one document of the ``col_guilds`` collection. It is cached
    under the key ``"<__short_name__>_<id>"`` (i.e. ``guild_<id>``).
    """

    __slots__ = (
        "_id",
        "id",
        "general_channel_id",
        "management_channel_id",
        "category_id",
        "timezone",
        "prefer_emojis",
    )
    # Prefix used for cache keys of this entity.
    __short_name__ = "guild"
    # Database collection that stores guild documents.
    __collection__ = col_guilds

    # Database document identifier.
    _id: ObjectId
    # ID of the guild; this is the field `from_id` looks documents up by.
    id: int
    # Channel/category IDs; `None` until configured (see `get_defaults`).
    general_channel_id: int | None
    management_channel_id: int | None
    category_id: int | None
    # Timezone name; defaults to "UTC".
    timezone: str
    # Defaults to False.
    prefer_emojis: bool

    @classmethod
    async def from_id(
        cls, guild_id: int, allow_creation: bool = True, cache: Optional[Cache] = None
    ) -> "PycordGuild":
        """Find the guild by its ID and construct PycordGuild from database entry.

        The cache (when provided) is consulted first. Otherwise the database is
        queried. If there is no database record and ``allow_creation`` is true, a
        record with default values is inserted. A database-sourced entry is written
        back to the cache when a cache is provided.

        Args:
            guild_id (int): ID of the guild to look up.
            allow_creation (:obj:`bool`, optional): Create a new record if none found in the database.
            cache (:obj:`Cache`, optional): Cache engine that will be used to fetch and update the cache.

        Returns:
            PycordGuild: Object of the found or newly created guild.

        Raises:
            GuildNotFoundError: Guild with such ID does not exist and creation was not allowed.
        """
        cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, guild_id, cache=cache)

        if cached_entry is not None:
            return cls(**cls._entry_from_cache(cached_entry))

        db_entry: Dict[str, Any] | None = await cls.__collection__.find_one({"id": guild_id})

        if db_entry is None:
            if not allow_creation:
                raise GuildNotFoundError(guild_id)

            # Use `cls` (not a hard-coded class) so subclasses can override defaults.
            db_entry = cls.get_defaults(guild_id)

            insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry)

            db_entry["_id"] = insert_result.inserted_id

        if cache is not None:
            # Same key format as `_get_cache_key`; built from `guild_id` because no instance exists yet.
            cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry))

        return cls(**db_entry)

    async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
        """Delegate attribute setting to `BaseCacheable._set`."""
        await super()._set(cache, **kwargs)

    async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
        """Delegate attribute removal to `BaseCacheable._remove`."""
        await super()._remove(*args, cache=cache)

    def _get_cache_key(self) -> str:
        """Return the cache key of this guild, e.g. ``guild_<id>``."""
        return f"{self.__short_name__}_{self.id}"

    def _update_cache(self, cache: Optional[Cache] = None) -> None:
        """Delegate cache update to `BaseCacheable._update_cache`."""
        super()._update_cache(cache)

    def _delete_cache(self, cache: Optional[Cache] = None) -> None:
        """Delegate cache deletion to `BaseCacheable._delete_cache`."""
        super()._delete_cache(cache)

    @staticmethod
    def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of a DB entry with `_id` stringified so it can be stored as JSON."""
        cache_entry: Dict[str, Any] = db_entry.copy()

        cache_entry["_id"] = str(cache_entry["_id"])

        return cache_entry

    @staticmethod
    def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of a cached entry with `_id` converted back to `ObjectId`."""
        db_entry: Dict[str, Any] = cache_entry.copy()

        db_entry["_id"] = ObjectId(db_entry["_id"])

        return db_entry

    def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
        """Convert the object to a JSON representation.

        Args:
            json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted.

        Returns:
            Dict[str, Any]: JSON representation of the object.
        """
        return {
            "_id": self._id if not json_compatible else str(self._id),
            "id": self.id,
            "general_channel_id": self.general_channel_id,
            "management_channel_id": self.management_channel_id,
            "category_id": self.category_id,
            "timezone": self.timezone,
            "prefer_emojis": self.prefer_emojis,
        }

    @staticmethod
    def get_defaults(guild_id: Optional[int] = None) -> Dict[str, Any]:
        """Get default values for the object attributes.

        Args:
            guild_id (:obj:`int`, optional): Value to use for the `id` attribute.

        Returns:
            Dict[str, Any]: Mapping of attributes and their respective values in format `{"attribute_name": attribute_value}`.
        """
        return {
            "id": guild_id,
            "general_channel_id": None,
            "management_channel_id": None,
            "category_id": None,
            "timezone": "UTC",
            "prefer_emojis": False,
        }

    @staticmethod
    def get_default_value(key: str) -> Any:
        """Get default value of the attribute for the object.

        Args:
            key (str): Name of the attribute.

        Returns:
            Any: Default value of the attribute.

        Raises:
            KeyError: There's no default value for the provided attribute.
        """
        # Build the defaults mapping once instead of once per lookup.
        defaults: Dict[str, Any] = PycordGuild.get_defaults()

        if key not in defaults:
            raise KeyError(f"There's no default value for key '{key}' in PycordGuild")

        return defaults[key]

    async def update(
        self,
        cache: Optional[Cache] = None,
        **kwargs: Any,
    ) -> None:
        """Update the given attributes; delegates to `BaseCacheable.update`."""
        await super().update(cache=cache, **kwargs)

    async def reset(
        self,
        *args: str,
        cache: Optional[Cache] = None,
    ) -> None:
        """Reset the given attributes; delegates to `BaseCacheable.reset`."""
        await super().reset(*args, cache=cache)

    async def purge(self, cache: Optional[Cache] = None) -> None:
        """Purge this guild; delegates to `BaseCacheable.purge`."""
        await super().purge(cache)

    def is_configured(self) -> bool:
        """Return whether all attributes required for bot's use on the server are set.

        Returns:
            bool: `True` if yes and `False` if not.
        """
        return (
            (self.id is not None)
            and (self.general_channel_id is not None)
            and (self.management_channel_id is not None)
            and (self.category_id is not None)
            and (self.timezone is not None)
            and (self.prefer_emojis is not None)
        )