Added a couple of needed attributes (#8)

This commit is contained in:
2025-04-22 23:32:23 +02:00
parent 004b5336d8
commit 57c4ff3bf9

View File

@@ -6,6 +6,7 @@ from bson import ObjectId
from libbot.cache.classes import Cache from libbot.cache.classes import Cache
from pymongo.results import InsertOneResult from pymongo.results import InsertOneResult
from classes.abstract.cacheable import Cacheable
from classes.errors.pycord_user import UserNotFoundError from classes.errors.pycord_user import UserNotFoundError
from modules.database import col_users from modules.database import col_users
from modules.utils import get_logger, restore_from_cache from modules.utils import get_logger, restore_from_cache
@@ -14,15 +15,19 @@ logger: Logger = get_logger(__name__)
@dataclass @dataclass
class PycordUser: class PycordUser(Cacheable):
"""Dataclass of DB entry of a user""" """Dataclass of DB entry of a user"""
__slots__ = ("_id", "id") __slots__ = ("_id", "id", "guild_id", "channel_id", "current_event_id", "current_stage_id")
__short_name__ = "user" __short_name__ = "user"
__collection__ = col_users __collection__ = col_users
_id: ObjectId _id: ObjectId
id: int id: int
guild_id: int
channel_id: int | None
current_event_id: ObjectId | None
current_stage_id: ObjectId | None
@classmethod @classmethod
async def from_id( async def from_id(
@@ -75,6 +80,14 @@ class PycordUser:
return { return {
"_id": self._id if not json_compatible else str(self._id), "_id": self._id if not json_compatible else str(self._id),
"id": self.id, "id": self.id,
"guild_id": self.guild_id,
"channel_id": self.channel_id,
"current_event_id": (
self.current_event_id if not json_compatible else str(self.current_event_id)
),
"current_stage_id": (
self.current_stage_id if not json_compatible else str(self.current_stage_id)
),
} }
async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None: async def _set(self, cache: Optional[Cache] = None, **kwargs) -> None:
@@ -142,9 +155,13 @@ class PycordUser:
cache.delete(self._get_cache_key()) cache.delete(self._get_cache_key())
@staticmethod @staticmethod
def get_defaults(user_id: Optional[int] = None) -> Dict[str, Any]: def get_defaults(user_id: Optional[int] = None, guild_id: Optional[int] = None) -> Dict[str, Any]:
return { return {
"id": user_id, "id": user_id,
"guild_id": guild_id,
"channel_id": None,
"current_event_id": None,
"current_stage_id": None,
} }
@staticmethod @staticmethod