diff --git a/classes/__init__.py b/classes/__init__.py index 08813ce..ac0044a 100644 --- a/classes/__init__.py +++ b/classes/__init__.py @@ -1,3 +1,4 @@ from .pycord_guild import PycordGuild from .pycord_guild_colors import PycordGuildColors from .pycord_user import PycordUser +from .wallet import Wallet diff --git a/classes/errors/wallet.py b/classes/errors/wallet.py new file mode 100644 index 0000000..3e94916 --- /dev/null +++ b/classes/errors/wallet.py @@ -0,0 +1,10 @@ +class WalletNotFoundError(Exception): + """Wallet could not find user with such an ID from a guild in the database""" + + def __init__(self, owner_id: int, guild_id: int) -> None: + self.owner_id = owner_id + self.guild_id = guild_id + + super().__init__( + f"Wallet of a user with id {self.owner_id} was not found for the guild with id {self.guild_id}" + ) diff --git a/classes/pycord_user.py b/classes/pycord_user.py index 34e5fb2..ed0e833 100644 --- a/classes/pycord_user.py +++ b/classes/pycord_user.py @@ -1,6 +1,7 @@ import logging from dataclasses import dataclass -from typing import Dict, Any, Optional +from logging import Logger +from typing import Any, Dict, Optional from bson import ObjectId from libbot.cache.classes import Cache @@ -8,8 +9,9 @@ from pymongo.results import InsertOneResult from classes.errors.pycord_user import UserNotFoundError from modules.database import col_users +from classes import Wallet -logger = logging.getLogger(__name__) +logger: Logger = logging.getLogger(__name__) @dataclass @@ -131,7 +133,7 @@ class PycordUser: cache.delete(self._get_cache_key()) @staticmethod - def get_defaults(user_id: int | None = None) -> Dict[str, Any]: + def get_defaults(user_id: Optional[int] = None) -> Dict[str, Any]: return { "id": user_id, } @@ -151,3 +153,14 @@ class PycordUser: """ await col_users.delete_one({"_id": self._id}) self._delete_cache(cache) + + async def get_wallet(self, guild_id: int) -> Wallet: + """Get wallet of the user. + + Args: + guild_id (int): Guild ID of the wallet + + Returns: + Wallet: Wallet object of the user + """ + return await Wallet.from_id(self.id, guild_id) diff --git a/classes/wallet.py b/classes/wallet.py index f6772f4..19d6bb7 100644 --- a/classes/wallet.py +++ b/classes/wallet.py @@ -1,15 +1,104 @@ +import logging from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timezone +from logging import Logger +from typing import Any, Dict, Optional from bson import ObjectId +from pymongo.results import InsertOneResult + +from classes.errors import WalletNotFoundError +from modules.database import col_wallets + +logger: Logger = logging.getLogger(__name__) @dataclass class Wallet: _id: ObjectId - id: int - balance: float owner_id: int guild_id: int + balance: float is_frozen: bool created: datetime + + # TODO Write a docstring + @classmethod + async def from_id( + cls, owner_id: int, guild_id: int, allow_creation: bool = True + ) -> "Wallet": + db_entry = await col_wallets.find_one( + {"owner_id": owner_id, "guild_id": guild_id} + ) + + if db_entry is None: + if not allow_creation: + raise WalletNotFoundError(owner_id, guild_id) + + db_entry = Wallet.get_defaults(owner_id, guild_id) + + insert_result: InsertOneResult = await col_wallets.insert_one(db_entry) + + db_entry["_id"] = insert_result.inserted_id + + return cls(**db_entry) + + def _to_dict(self) -> Dict[str, Any]: + return { + "_id": self._id, + "owner_id": self.owner_id, + "guild_id": self.guild_id, + "balance": self.balance, + "is_frozen": self.is_frozen, + "created": self.created, + } + + async def _set(self, key: str, value: Any) -> None: + if not hasattr(self, key): + raise AttributeError() + + setattr(self, key, value) + + await col_wallets.update_one( + {"_id": self._id}, {"$set": {key: value}}, upsert=True + ) + + logger.info( + "Set attribute '%s' of the wallet %s to '%s'", key, str(self._id), value + ) + + @staticmethod + def get_defaults( + owner_id: Optional[int] = None, guild_id: Optional[int] = None + ) -> Dict[str, Any]: + return { + "owner_id": owner_id, + "guild_id": guild_id, + "balance": 0.0, + "is_frozen": False, + "created": datetime.now(tz=timezone.utc), + } + + @staticmethod + def get_default_value(key: str) -> Any: + if key not in Wallet.get_defaults(): + raise KeyError(f"There's no default value for key '{key}' in Wallet") + + return Wallet.get_defaults()[key] + + # TODO Write a docstring + async def freeze(self) -> None: + await self._set("is_frozen", True) + + # TODO Write a docstring + async def unfreeze(self) -> None: + await self._set("is_frozen", False) + + # TODO Write a dosctring + async def deposit(self, amount: float) -> None: + await self._set("balance", round(self.balance + amount, 2)) + + # TODO Add a check to prevent negative balances + # TODO Write a dosctring + async def withdraw(self, amount: float) -> None: + await self._set("balance", round(self.balance - amount, 2)) diff --git a/modules/database.py b/modules/database.py index 4b02ac2..08f161f 100644 --- a/modules/database.py +++ b/modules/database.py @@ -25,6 +25,7 @@ db_client = AsyncClient(con_string) db: AsyncDatabase = db_client.get_database(name=db_config["name"]) col_users: AsyncCollection = db.get_collection("users") +col_wallets: AsyncCollection = db.get_collection("wallets") # col_messages: AsyncCollection = db.get_collection("messages") # col_warnings: AsyncCollection = db.get_collection("warnings") # col_checkouts: AsyncCollection = db.get_collection("checkouts") @@ -33,4 +34,7 @@ col_users: AsyncCollection = db.get_collection("users") # col_transactions: AsyncCollection = db.get_collection("transactions") # Update indexes -db.dispatch.get_collection("users").create_index("id", unique=True) \ No newline at end of file +db.dispatch.get_collection("users").create_index("id", unique=True) +db.dispatch.get_collection("wallets").create_index( + ["owner_id", "guild_id"], unique=False +)