3 Commits

15 changed files with 386 additions and 109 deletions

View File

@@ -0,0 +1 @@
from .cacheable import Cacheable

View File

@@ -0,0 +1,81 @@
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Optional
from libbot.cache.classes import Cache
from pymongo.asynchronous.collection import AsyncCollection
class Cacheable(ABC):
"""Abstract class for cacheable"""
__short_name__: str
__collection__: ClassVar[AsyncCollection]
@classmethod
@abstractmethod
async def from_id(cls, *args: Any, cache: Optional[Cache] = None, **kwargs: Any) -> Any:
pass
@abstractmethod
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
pass
@abstractmethod
async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
pass
@abstractmethod
def _get_cache_key(self) -> str:
pass
@abstractmethod
def _update_cache(self, cache: Optional[Cache] = None) -> None:
pass
@abstractmethod
def _delete_cache(self, cache: Optional[Cache] = None) -> None:
pass
@staticmethod
@abstractmethod
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
pass
@staticmethod
@abstractmethod
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
pass
@abstractmethod
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
pass
@staticmethod
@abstractmethod
def get_defaults(**kwargs: Any) -> Dict[str, Any]:
pass
@staticmethod
@abstractmethod
def get_default_value(key: str) -> Any:
pass
@abstractmethod
async def update(
self,
cache: Optional[Cache] = None,
**kwargs: Any,
) -> None:
pass
@abstractmethod
async def reset(
self,
*args: str,
cache: Optional[Cache] = None,
) -> None:
pass
@abstractmethod
async def purge(self, cache: Optional[Cache] = None) -> None:
pass

1
classes/base/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .base_cacheable import BaseCacheable

View File

@@ -0,0 +1,110 @@
from abc import ABC
from logging import Logger
from typing import Any, Dict, Optional
from bson import ObjectId
from libbot.cache.classes import Cache
from classes.abstract import Cacheable
from modules.utils import get_logger
logger: Logger = get_logger(__name__)
class BaseCacheable(Cacheable, ABC):
"""Base implementation of Cacheable used by all cachable classes."""
_id: ObjectId
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
for key, value in kwargs.items():
if not hasattr(self, key):
raise AttributeError()
setattr(self, key, value)
await self.__collection__.update_one({"_id": self._id}, {"$set": kwargs})
self._update_cache(cache)
logger.info("Set attributes of %s to %s", self._id, kwargs)
async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
attributes: Dict[str, Any] = {}
for key in args:
if not hasattr(self, key):
raise AttributeError()
default_value: Any = self.get_default_value(key)
setattr(self, key, default_value)
attributes[key] = default_value
await self.__collection__.update_one({"_id": self._id}, {"$set": attributes})
self._update_cache(cache)
logger.info("Reset attributes %s of %s to default values", args, self._id)
def _update_cache(self, cache: Optional[Cache] = None) -> None:
if cache is None:
return
object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
if object_dict is not None:
cache.set_json(self._get_cache_key(), object_dict)
else:
self._delete_cache(cache)
def _delete_cache(self, cache: Optional[Cache] = None) -> None:
if cache is None:
return
cache.delete(self._get_cache_key())
async def update(
self,
cache: Optional[Cache] = None,
**kwargs: Any,
) -> None:
"""Update attribute(s) on the object and save the updated entry into the database.
Args:
cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
**kwargs (Any): Mapping of attributes in format `attribute_name=attribute_value` to update.
Raises:
AttributeError: Provided attribute does not exist in the class.
"""
await self._set(cache=cache, **kwargs)
async def reset(
self,
*args: str,
cache: Optional[Cache] = None,
) -> None:
"""Remove attribute(s) on the object, replace them with a default value and save the updated entry into the database.
Args:
*args (str): List of attributes to remove.
cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
Raises:
AttributeError: Provided attribute does not exist in the class.
"""
await self._remove(*args, cache=cache)
async def purge(self, cache: Optional[Cache] = None) -> None:
"""Completely remove object data from database. Currently only removes the record from a respective collection.
Args:
cache (:obj:`Cache`, optional): Cache engine that will be used to update the cache.
"""
await self.__collection__.delete_one({"_id": self._id})
self._delete_cache(cache)
logger.info("Purged %s from the database", self._id)

View File

@@ -1,7 +1,8 @@
class UserNotFoundError(Exception):
"""PycordUser could not find user with such an ID in the database"""
def __init__(self, user_id: int) -> None:
self.user_id = user_id
def __init__(self, user_id: int, guild_id: int) -> None:
self.user_id: int = user_id
self.guild_id: int = guild_id
super().__init__(f"User with id {self.user_id} was not found")
super().__init__(f"User with id {self.user_id} was not found in guild {self.guild_id}")

View File

@@ -1,13 +1,17 @@
import logging
from datetime import datetime
from logging import Logger
from typing import Any
from typing import Any, Literal
from zoneinfo import ZoneInfo
from aiohttp import ClientSession
from discord import User
from libbot.cache.classes import CacheMemcached, CacheRedis
from libbot.cache.manager import create_cache_client
from libbot.pycord.classes import PycordBot as LibPycordBot
from classes import PycordUser
from modules.database import _update_database_indexes
logger: Logger = logging.getLogger(__name__)
@@ -16,6 +20,11 @@ logger: Logger = logging.getLogger(__name__)
class PycordBot(LibPycordBot):
__version__ = "0.0.1"
started: datetime
cache: CacheMemcached | CacheRedis | None = None
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -30,13 +39,13 @@ class PycordBot(LibPycordBot):
# i18n formats than provided by libbot
self._ = self._modified_string_getter
# Scheduler job for DHL parcel tracking
# self.scheduler.add_job(
# update_tracks_dhl,
# trigger="cron",
# hour=self.config["modules"]["tracking"]["fetch_hours"],
# args=[self, self.client_session],
# )
def _set_cache_engine(self) -> None:
cache_type: Literal["redis", "memcached"] | None = self.config["cache"]["type"]
if "cache" in self.config and cache_type is not None:
self.cache = create_cache_client(
self.config, cache_type, prefix=self.config["cache"][cache_type]["prefix"]
)
def _modified_string_getter(self, key: str, *args: str, locale: str | None = None) -> Any:
"""This method exists because of the different i18n formats than provided by libbot.
@@ -47,10 +56,6 @@ class PycordBot(LibPycordBot):
key, *args, locale=None if locale is None else locale.split("-")[0]
)
def _set_cache_engine(self) -> None:
if "cache" in self.config and self.config["cache"]["type"] is not None:
self.cache = create_cache_client(self.config, self.config["cache"]["type"])
async def find_user(self, user: int | User) -> PycordUser:
"""Find User by its ID or User object.
@@ -70,6 +75,11 @@ class PycordBot(LibPycordBot):
)
async def start(self, *args: Any, **kwargs: Any) -> None:
await self._schedule_tasks()
await _update_database_indexes()
self.started = datetime.now(tz=ZoneInfo("UTC"))
await super().start(*args, **kwargs)
async def close(self, **kwargs) -> None:
@@ -79,3 +89,13 @@ class PycordBot(LibPycordBot):
self.scheduler.shutdown()
await super().close(**kwargs)
async def _schedule_tasks(self) -> None:
# Scheduler job for DHL parcel tracking
# self.scheduler.add_job(
# update_tracks_dhl,
# trigger="cron",
# hour=self.config["modules"]["tracking"]["fetch_hours"],
# args=[self, self.client_session],
# )
pass

View File

@@ -7,30 +7,37 @@ from bson import ObjectId
from libbot.cache.classes import Cache
from pymongo.results import InsertOneResult
from classes.base import BaseCacheable
from classes.errors.pycord_user import UserNotFoundError
from classes.wallet import Wallet
from modules.database import col_users
from modules.utils import restore_from_cache
logger: Logger = logging.getLogger(__name__)
@dataclass
class PycordUser:
class PycordUser(BaseCacheable):
"""Dataclass of DB entry of a user"""
__slots__ = ("_id", "id")
__slots__ = ("_id", "id", "guild_id")
_id: ObjectId
id: int
guild_id: int
@classmethod
async def from_id(
cls, user_id: int, allow_creation: bool = True, cache: Optional[Cache] = None
cls,
user_id: int,
guild_id: int,
allow_creation: bool = True,
cache: Optional[Cache] = None,
) -> "PycordUser":
"""Find user in database and create new record if user does not exist.
Args:
user_id (int): User's Discord ID
guild_id (int): User's guild Discord ID
allow_creation (:obj:`bool`, optional): Create new user record if none found in the database
cache (:obj:`Cache`, optional): Cache engine to get the cache from
@@ -40,26 +47,31 @@ class PycordUser:
Raises:
UserNotFoundError: User was not found and creation was not allowed
"""
if cache is not None:
cached_entry: Dict[str, Any] | None = cache.get_json(f"user_{user_id}")
cached_entry: Dict[str, Any] | None = restore_from_cache(
cls.__short_name__, f"{user_id}_{guild_id}", cache=cache
)
if cached_entry is not None:
return cls(**cached_entry)
if cached_entry is not None:
return cls(**cls._entry_from_cache(cached_entry))
db_entry = await col_users.find_one({"id": user_id})
db_entry: Dict[str, Any] | None = await cls.__collection__.find_one(
{"id": user_id, "guild_id": guild_id}
)
if db_entry is None:
if not allow_creation:
raise UserNotFoundError(user_id)
raise UserNotFoundError(user_id, guild_id)
db_entry = PycordUser.get_defaults(user_id)
db_entry = PycordUser.get_defaults(user_id, guild_id)
insert_result: InsertOneResult = await col_users.insert_one(db_entry)
insert_result: InsertOneResult = await cls.__collection__.insert_one(db_entry)
db_entry["_id"] = insert_result.inserted_id
if cache is not None:
cache.set_json(f"user_{user_id}", db_entry)
cache.set_json(
f"{cls.__short_name__}_{user_id}_{guild_id}", cls._entry_to_cache(db_entry)
)
return cls(**db_entry)
@@ -75,75 +87,51 @@ class PycordUser:
return {
"_id": self._id if not json_compatible else str(self._id),
"id": self.id,
"guild_id": self.guild_id,
}
async def _set(self, key: str, value: Any, cache: Optional[Cache] = None) -> None:
"""Set attribute data and save it into the database.
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
await super()._set(cache, **kwargs)
Args:
key (str): Attribute to change
value (Any): Value to set
cache (:obj:`Cache`, optional): Cache engine to write the update into
"""
if not hasattr(self, key):
raise AttributeError()
setattr(self, key, value)
await col_users.update_one({"_id": self._id}, {"$set": {key: value}}, upsert=True)
self._update_cache(cache)
logger.info("Set attribute '%s' of user %s to '%s'", key, self.id, value)
async def _remove(self, key: str, cache: Optional[Cache] = None) -> None:
"""Remove attribute data and save it into the database.
Args:
key (str): Attribute to remove
cache (:obj:`Cache`, optional): Cache engine to write the update into
"""
if not hasattr(self, key):
raise AttributeError()
default_value: Any = PycordUser.get_default_value(key)
setattr(self, key, default_value)
await col_users.update_one(
{"_id": self._id}, {"$set": {key: default_value}}, upsert=True
)
self._update_cache(cache)
logger.info("Removed attribute '%s' of user %s", key, self.id)
async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
await super()._remove(*args, cache=cache)
def _get_cache_key(self) -> str:
return f"user_{self.id}"
return f"{self.__short_name__}_{self.id}_{self.guild_id}"
def _update_cache(self, cache: Optional[Cache] = None) -> None:
if cache is None:
return
user_dict: Dict[str, Any] = self.to_dict()
if user_dict is not None:
cache.set_json(self._get_cache_key(), user_dict)
else:
self._delete_cache(cache)
super()._update_cache(cache)
def _delete_cache(self, cache: Optional[Cache] = None) -> None:
if cache is None:
return
cache.delete(self._get_cache_key())
super()._delete_cache(cache)
@staticmethod
def get_defaults(user_id: Optional[int] = None) -> Dict[str, Any]:
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
cache_entry: Dict[str, Any] = db_entry.copy()
cache_entry["_id"] = str(cache_entry["_id"])
return cache_entry
@staticmethod
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
db_entry: Dict[str, Any] = cache_entry.copy()
db_entry["_id"] = ObjectId(db_entry["_id"])
return db_entry
# TODO Add documentation
@staticmethod
def get_defaults(
user_id: Optional[int] = None, guild_id: Optional[int] = None
) -> Dict[str, Any]:
return {
"id": user_id,
"guild_id": guild_id,
}
# TODO Add documentation
@staticmethod
def get_default_value(key: str) -> Any:
if key not in PycordUser.get_defaults():
@@ -151,14 +139,22 @@ class PycordUser:
return PycordUser.get_defaults()[key]
async def purge(self, cache: Optional[Cache] = None) -> None:
"""Completely remove user data from database. Currently only removes the user record from users collection.
async def update(
self,
cache: Optional[Cache] = None,
**kwargs: Any,
) -> None:
await super().update(cache=cache, **kwargs)
Args:
cache (:obj:`Cache`, optional): Cache engine to write the update into
"""
await col_users.delete_one({"_id": self._id})
self._delete_cache(cache)
async def reset(
self,
*args: str,
cache: Optional[Cache] = None,
) -> None:
await super().reset(*args, cache=cache)
async def purge(self, cache: Optional[Cache] = None) -> None:
await super().purge(cache)
async def get_wallet(self, guild_id: int) -> Wallet:
"""Get wallet of the user.

View File

@@ -145,8 +145,10 @@ class Wallet:
if amount < 0:
raise ValueError()
# allow_creation might need to be set to False in the future
# if users will be able to opt out from having a wallet
wallet: Wallet = await self.from_id(
wallet_owner_id, wallet_guild_id, allow_creation=False
wallet_owner_id, wallet_guild_id, allow_creation=True
)
if balance_limit is not None and amount + wallet.balance > balance_limit:

View File

@@ -26,10 +26,12 @@
"cache": {
"type": null,
"memcached": {
"uri": "127.0.0.1:11211"
"uri": "127.0.0.1:11211",
"prefix": null
},
"redis": {
"uri": "redis://127.0.0.1:6379/0"
"uri": "redis://127.0.0.1:6379/0",
"prefix": null
}
},
"privacy": {

22
main.py
View File

@@ -1,9 +1,11 @@
import asyncio
import contextlib
import logging
import logging.config
from logging import Logger
from os import getpid
from os import getpid, makedirs
from pathlib import Path
from discord import LoginFailure
from libbot.utils import config_get
# Import required for uvicorn
@@ -11,14 +13,13 @@ from api.app import app # noqa
from classes.pycord_bot import PycordBot
from modules.extensions_loader import dynamic_import_from_src
from modules.scheduler import scheduler
from modules.utils import get_logging_config, get_logger
logging.basicConfig(
level=logging.DEBUG if config_get("debug") else logging.INFO,
format="%(name)s.%(funcName)s | %(levelname)s | %(message)s",
datefmt="[%X]",
)
makedirs(Path("logs/"), exist_ok=True)
logger: Logger = logging.getLogger(__name__)
logging.config.dictConfig(get_logging_config())
logger: Logger = get_logger(__name__)
# Try to import the module that improves performance
# and ignore errors when module is not installed
@@ -38,9 +39,14 @@ async def main():
try:
await bot.start(config_get("bot_token", "bot"))
except LoginFailure as exc:
logger.error("Provided bot token is invalid: %s", exc)
except KeyboardInterrupt:
logger.warning("Forcefully shutting down with PID %s...", getpid())
await bot.close()
except Exception as exc:
logger.error("An unexpected error has occurred: %s", exc, exc_info=exc)
exit(1)
asyncio.create_task(main())

View File

@@ -2,8 +2,10 @@
from typing import Any, Mapping
from async_pymongo import AsyncClient, AsyncCollection, AsyncDatabase
from libbot.utils import config_get
from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.database import AsyncDatabase
db_config: Mapping[str, Any] = config_get("database")
@@ -21,10 +23,11 @@ else:
)
# Async declarations
db_client = AsyncClient(con_string)
db_client = AsyncMongoClient(con_string)
db: AsyncDatabase = db_client.get_database(name=db_config["name"])
col_users: AsyncCollection = db.get_collection("users")
col_guilds: AsyncCollection = db.get_collection("guilds")
col_wallets: AsyncCollection = db.get_collection("wallets")
# col_messages: AsyncCollection = db.get_collection("messages")
# col_warnings: AsyncCollection = db.get_collection("warnings")
@@ -33,6 +36,11 @@ col_wallets: AsyncCollection = db.get_collection("wallets")
# col_authorized: AsyncCollection = db.get_collection("authorized")
# col_transactions: AsyncCollection = db.get_collection("transactions")
# Update indexes
db.dispatch.get_collection("users").create_index("id", unique=True)
db.dispatch.get_collection("wallets").create_index(["owner_id", "guild_id"], unique=False)
async def _update_database_indexes() -> None:
await col_users.create_index(["id", "guild_id"], name="user_id-guild_id", unique=True)
await col_guilds.create_index("guild_id", name="guild_id", unique=True)
await col_wallets.create_index(
["owner_id", "guild_id"], name="owner_id-guild_id", unique=True
)

View File

@@ -0,0 +1,2 @@
from .cache_utils import restore_from_cache
from .logging_utils import get_logger, get_logging_config

View File

@@ -0,0 +1,10 @@
from typing import Any, Dict, Optional
from bson import ObjectId
from libbot.cache.classes import Cache
def restore_from_cache(
cache_prefix: str, cache_key: str | int | ObjectId, cache: Optional[Cache] = None
) -> Dict[str, Any] | None:
return None if cache is None else cache.get_json(f"{cache_prefix}_{cache_key}")

View File

@@ -0,0 +1,35 @@
import logging
from logging import Logger
from pathlib import Path
from typing import Any, Dict
from libbot.utils import config_get
def get_logging_config() -> Dict[str, Any]:
return {
"version": 1,
"disable_existing_loggers": False,
"handlers": {
"file": {
"class": "logging.handlers.RotatingFileHandler",
"filename": str(Path("logs/latest.log")),
"maxBytes": 500000,
"backupCount": 10,
"formatter": "simple",
},
"console": {"class": "logging.StreamHandler", "formatter": "systemd"},
},
"formatters": {
"simple": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"},
"systemd": {"format": "%(name)s - %(levelname)s - %(message)s"},
},
"root": {
"level": "DEBUG" if config_get("debug") else "INFO",
"handlers": ["file", "console"],
},
}
def get_logger(name: str) -> Logger:
return logging.getLogger(name)

View File

@@ -1,11 +1,13 @@
aiohttp>=3.6.0
apscheduler~=3.11.0
async_pymongo==0.1.11
colorthief==0.2.1
deepl==1.22.0
fastapi[all]~=0.115.0
libbot[speed,pycord,cache]==4.2.0
mongodb-migrations==1.3.1
pynacl~=1.5.0
pyrmv==0.5.0
pytz~=2025.1
pytz~=2025.1
# Temporarily disabled because
# these are still unused for now
# colorthief==0.2.1
# deepl==1.22.0
# pyrmv==0.5.0