WIP: Wallets

This commit is contained in:
Profitroll 2025-02-18 08:04:02 +01:00
parent 654034491a
commit 8883c8eda8
5 changed files with 124 additions and 7 deletions

View File

@ -1,3 +1,4 @@
from .pycord_guild import PycordGuild
from .pycord_guild_colors import PycordGuildColors
from .pycord_user import PycordUser
from .wallet import Wallet

10
classes/errors/wallet.py Normal file
View File

@ -0,0 +1,10 @@
class WalletNotFoundError(Exception):
"""Wallet could not find user with such an ID from a guild in the database"""
def __init__(self, owner_id: int, guild_id: int) -> None:
self.owner_id = owner_id
self.guild_id = guild_id
super().__init__(
f"Wallet of a user with id {self.owner_id} was not found for the guild with id {self.guild_id}"
)

View File

@ -1,6 +1,7 @@
import logging
from dataclasses import dataclass
from typing import Dict, Any, Optional
from logging import Logger
from typing import Any, Dict, Optional
from bson import ObjectId
from libbot.cache.classes import Cache
@ -8,8 +9,9 @@ from pymongo.results import InsertOneResult
from classes.errors.pycord_user import UserNotFoundError
from modules.database import col_users
from classes import Wallet
logger = logging.getLogger(__name__)
logger: Logger = logging.getLogger(__name__)
@dataclass
@ -131,7 +133,7 @@ class PycordUser:
cache.delete(self._get_cache_key())
@staticmethod
def get_defaults(user_id: int | None = None) -> Dict[str, Any]:
def get_defaults(user_id: Optional[int] = None) -> Dict[str, Any]:
return {
"id": user_id,
}
@ -151,3 +153,14 @@ class PycordUser:
"""
await col_users.delete_one({"_id": self._id})
self._delete_cache(cache)
async def get_wallet(self, guild_id: int) -> Wallet:
"""Get wallet of the user.
Args:
guild_id (int): Guild ID of the wallet
Returns:
Wallet: Wallet object of the user
"""
return await Wallet.from_id(self.id, guild_id)

View File

@ -1,15 +1,104 @@
import logging
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timezone
from logging import Logger
from typing import Any, Dict, Optional
from bson import ObjectId
from pymongo.results import InsertOneResult
from classes.errors import WalletNotFoundError
from modules.database import col_wallets
logger: Logger = logging.getLogger(__name__)
@dataclass
class Wallet:
_id: ObjectId
id: int
balance: float
owner_id: int
guild_id: int
balance: float
is_frozen: bool
created: datetime
# TODO Write a docstring
@classmethod
async def from_id(
cls, owner_id: int, guild_id: int, allow_creation: bool = True
) -> "Wallet":
db_entry = await col_wallets.find_one(
{"owner_id": owner_id, "guild_id": guild_id}
)
if db_entry is None:
if not allow_creation:
raise WalletNotFoundError(owner_id, guild_id)
db_entry = Wallet.get_defaults(owner_id, guild_id)
insert_result: InsertOneResult = await col_wallets.insert_one(db_entry)
db_entry["_id"] = insert_result.inserted_id
return cls(**db_entry)
def _to_dict(self) -> Dict[str, Any]:
return {
"_id": self._id,
"owner_id": self.owner_id,
"guild_id": self.guild_id,
"balance": self.balance,
"is_frozen": self.is_frozen,
"created": self.created,
}
async def _set(self, key: str, value: Any) -> None:
if not hasattr(self, key):
raise AttributeError()
setattr(self, key, value)
await col_wallets.update_one(
{"_id": self._id}, {"$set": {key: value}}, upsert=True
)
logger.info(
"Set attribute '%s' of the wallet %s to '%s'", key, str(self._id), value
)
@staticmethod
def get_defaults(
owner_id: Optional[int] = None, guild_id: Optional[int] = None
) -> Dict[str, Any]:
return {
"owner_id": owner_id,
"guild_id": guild_id,
"balance": 0.0,
"is_frozen": False,
"created": datetime.now(tz=timezone.utc),
}
@staticmethod
def get_default_value(key: str) -> Any:
if key not in Wallet.get_defaults():
raise KeyError(f"There's no default value for key '{key}' in Wallet")
return Wallet.get_defaults()[key]
# TODO Write a docstring
async def freeze(self) -> None:
await self._set("is_frozen", True)
# TODO Write a docstring
async def unfreeze(self) -> None:
await self._set("is_frozen", False)
# TODO Write a dosctring
async def deposit(self, amount: float) -> None:
await self._set("balance", round(self.balance + amount, 2))
# TODO Add a check to prevent negative balances
# TODO Write a dosctring
async def withdraw(self, amount: float) -> None:
await self._set("balance", round(self.balance - amount, 2))

View File

@ -25,6 +25,7 @@ db_client = AsyncClient(con_string)
db: AsyncDatabase = db_client.get_database(name=db_config["name"])
col_users: AsyncCollection = db.get_collection("users")
col_wallets: AsyncCollection = db.get_collection("wallets")
# col_messages: AsyncCollection = db.get_collection("messages")
# col_warnings: AsyncCollection = db.get_collection("warnings")
# col_checkouts: AsyncCollection = db.get_collection("checkouts")
@ -33,4 +34,7 @@ col_users: AsyncCollection = db.get_collection("users")
# col_transactions: AsyncCollection = db.get_collection("transactions")
# Update indexes
db.dispatch.get_collection("users").create_index("id", unique=True)
db.dispatch.get_collection("users").create_index("id", unique=True)
db.dispatch.get_collection("wallets").create_index(
["owner_id", "guild_id"], unique=False
)