219 lines
6.6 KiB
Python
219 lines
6.6 KiB
Python
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from logging import Logger
|
|
from typing import Any, Dict, Optional
|
|
from zoneinfo import ZoneInfo
|
|
|
|
from bson import ObjectId
|
|
from libbot.cache.classes import Cache
|
|
from pymongo.results import InsertOneResult
|
|
|
|
from classes import Consent
|
|
from classes.base import BaseCacheable
|
|
from classes.enums import ConsentScope
|
|
from classes.errors.pycord_user import UserNotFoundError
|
|
from classes.wallet import Wallet
|
|
from modules.database import col_users
|
|
from modules.utils import restore_from_cache
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger: Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class PycordUser(BaseCacheable):
    """Dataclass of DB entry of a user.

    A record is identified by the pair of Discord user ID and guild ID; the same
    pair is used to build the cache key (``user_<id>_<guild_id>``).

    Attributes:
        _id (ObjectId): Database ID of the record
        id (int): User's Discord ID
        guild_id (int): User's guild Discord ID
    """

    __slots__ = ("_id", "id", "guild_id")
    __short_name__ = "user"
    __collection__ = col_users

    _id: ObjectId
    id: int
    guild_id: int

    @classmethod
    async def from_id(
        cls,
        user_id: int,
        guild_id: int,
        allow_creation: bool = True,
        cache: Optional[Cache] = None,
    ) -> "PycordUser":
        """Find user in database and create new record if user does not exist.

        The cache is consulted first; on a miss the database is queried and the
        resulting entry is written back to the cache (when a cache is provided).

        Args:
            user_id (int): User's Discord ID
            guild_id (int): User's guild Discord ID
            allow_creation (:obj:`bool`, optional): Create new user record if none found in the database
            cache (:obj:`Cache`, optional): Cache engine to get the cache from

        Returns:
            PycordUser: User object

        Raises:
            UserNotFoundError: User was not found and creation was not allowed
        """
        cached_entry: Dict[str, Any] | None = restore_from_cache(
            cls.__short_name__, f"{user_id}_{guild_id}", cache=cache
        )

        if cached_entry is not None:
            return cls(**cls._entry_from_cache(cached_entry))

        db_entry: Dict[str, Any] | None = await cls.__collection__.find_one(
            {"id": user_id, "guild_id": guild_id}
        )

        if db_entry is None:
            if not allow_creation:
                raise UserNotFoundError(user_id, guild_id)

            # Use ``cls`` rather than the hard-coded class name so subclasses
            # that override the defaults are honoured.
            db_entry = cls.get_defaults(user_id, guild_id)

            insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry)

            db_entry["_id"] = insert_result.inserted_id

        if cache is not None:
            # Same key format as the one produced by ``_get_cache_key``.
            cache.set_json(
                f"{cls.__short_name__}_{user_id}_{guild_id}", cls._entry_to_cache(db_entry)
            )

        return cls(**db_entry)

    def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
        """Convert PycordUser object to a JSON representation.

        Args:
            json_compatible (bool): Whether the JSON-incompatible objects like ObjectId need to be converted

        Returns:
            Dict[str, Any]: JSON representation of PycordUser
        """
        return {
            "_id": str(self._id) if json_compatible else self._id,
            "id": self.id,
            "guild_id": self.guild_id,
        }

    async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
        """Delegate to the parent class' ``_set`` with the same arguments."""
        await super()._set(cache, **kwargs)

    async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
        """Delegate to the parent class' ``_remove`` with the same arguments."""
        await super()._remove(*args, cache=cache)

    def _get_cache_key(self) -> str:
        """Build the cache key of this user.

        Returns:
            str: Key in the form ``<short name>_<user id>_<guild id>``
        """
        return f"{self.__short_name__}_{self.id}_{self.guild_id}"

    def _update_cache(self, cache: Optional[Cache] = None) -> None:
        """Delegate to the parent class' ``_update_cache``."""
        super()._update_cache(cache)

    def _delete_cache(self, cache: Optional[Cache] = None) -> None:
        """Delegate to the parent class' ``_delete_cache``."""
        super()._delete_cache(cache)

    @staticmethod
    def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a database entry to its cacheable form.

        Args:
            db_entry (Dict[str, Any]): Database entry; it is not modified

        Returns:
            Dict[str, Any]: Shallow copy of the entry with ``_id`` converted to a string
        """
        cache_entry: Dict[str, Any] = db_entry.copy()

        cache_entry["_id"] = str(cache_entry["_id"])

        return cache_entry

    @staticmethod
    def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a cached entry back to its database form.

        Args:
            cache_entry (Dict[str, Any]): Cached entry; it is not modified

        Returns:
            Dict[str, Any]: Shallow copy of the entry with ``_id`` converted to an ObjectId
        """
        db_entry: Dict[str, Any] = cache_entry.copy()

        db_entry["_id"] = ObjectId(db_entry["_id"])

        return db_entry

    @staticmethod
    def get_defaults(
        user_id: Optional[int] = None, guild_id: Optional[int] = None
    ) -> Dict[str, Any]:
        """Get the default database entry of a user.

        Args:
            user_id (:obj:`int`, optional): User's Discord ID
            guild_id (:obj:`int`, optional): User's guild Discord ID

        Returns:
            Dict[str, Any]: Entry containing the ``id`` and ``guild_id`` keys
        """
        return {
            "id": user_id,
            "guild_id": guild_id,
        }

    @staticmethod
    def get_default_value(key: str) -> Any:
        """Get the default value of a single key of the user entry.

        Args:
            key (str): Name of the key

        Returns:
            Any: Default value of the key

        Raises:
            KeyError: There is no default value for the provided key
        """
        # Build the defaults once instead of once for the check and once for the lookup.
        defaults: Dict[str, Any] = PycordUser.get_defaults()

        if key not in defaults:
            raise KeyError(f"There's no default value for key '{key}' in PycordUser")

        return defaults[key]

    async def update(
        self,
        cache: Optional[Cache] = None,
        **kwargs: Any,
    ) -> None:
        """Delegate to the parent class' ``update`` with the same arguments."""
        await super().update(cache=cache, **kwargs)

    async def reset(
        self,
        *args: str,
        cache: Optional[Cache] = None,
    ) -> None:
        """Delegate to the parent class' ``reset`` with the same arguments."""
        await super().reset(*args, cache=cache)

    async def purge(self, cache: Optional[Cache] = None) -> None:
        """Delegate to the parent class' ``purge``."""
        await super().purge(cache)

    async def get_wallet(self, guild_id: int) -> Wallet:
        """Get wallet of the user.

        Args:
            guild_id (int): Guild ID of the wallet

        Returns:
            Wallet: Wallet object of the user
        """
        return await Wallet.from_id(self.id, guild_id)

    def _active_consent_filter(self, scope: ConsentScope) -> Dict[str, Any]:
        """Build the MongoDB filter matching this user's active consents for a scope.

        A consent matches when it has no withdrawal date and has not expired.

        Args:
            scope (ConsentScope): Scope of the consent

        Returns:
            Dict[str, Any]: Filter to be passed to ``find`` or ``find_one``
        """
        # NOTE(review): the timestamp is made naive on purpose, presumably because
        # the stored dates are naive UTC - confirm against Consent.
        now_utc: datetime = datetime.now(tz=ZoneInfo("UTC")).replace(tzinfo=None)

        # TODO Test this query
        return {
            "user_id": self.id,
            "guild_id": self.guild_id,
            "scope": scope.value,
            "withdrawal_date": None,
            # Comparison operators belong under the field name; a top-level "$gt"
            # key is rejected by MongoDB as an unknown top-level operator.
            # NOTE(review): entries without an expiration date are treated as never
            # expiring, because give_consent() allows expiration_date=None - confirm.
            "$or": [
                {"expiration_date": None},
                {"expiration_date": {"$gt": now_utc}},
            ],
        }

    async def has_active_consent(self, scope: ConsentScope) -> bool:
        """Check whether the user has an active consent for the scope.

        Args:
            scope (ConsentScope): Scope of the consent

        Returns:
            bool: `True` if at least one matching consent entry exists and `False` otherwise
        """
        consent: Dict[str, Any] | None = await Consent.__collection__.find_one(
            self._active_consent_filter(scope)
        )

        return consent is not None

    async def give_consent(
        self, scope: ConsentScope, expiration_date: Optional[datetime] = None
    ) -> None:
        """Give consent for the scope on behalf of the user.

        Delegates to ``Consent.give`` using this user's ID and guild ID.

        Args:
            scope (ConsentScope): Scope of the consent
            expiration_date (:obj:`datetime`, optional): Date the consent expires on
        """
        await Consent.give(self.id, self.guild_id, scope, expiration_date)

    async def withdraw_consent(
        self,
        scope: ConsentScope,
        cache: Optional[Cache] = None,
    ) -> None:
        """Withdraw every active consent of the user for the scope.

        Args:
            scope (ConsentScope): Scope of the consent
            cache (:obj:`Cache`, optional): Cache engine passed on to ``Consent.withdraw``
        """
        async for consent_entry in Consent.__collection__.find(
            self._active_consent_filter(scope)
        ):
            await Consent.from_entry(consent_entry).withdraw(cache)
|