From ffcfbbfc3b04f9134a691c1712f735468486a46b Mon Sep 17 00:00:00 2001 From: profitroll Date: Sun, 16 Feb 2025 13:11:48 +0100 Subject: [PATCH] Added caching, updated libbot, refactored PycordUser --- README.md | 6 ++ classes/cache/__init__.py | 3 + classes/cache/cache.py | 44 +++++++++ classes/cache/cache_memcached.py | 89 ++++++++++++++++++ classes/cache/cache_redis.py | 89 ++++++++++++++++++ classes/errors/__init__.py | 1 + classes/errors/pycord_user.py | 9 ++ classes/pycord_bot.py | 62 +++++++++++++ classes/pycord_user.py | 153 +++++++++++++++++++++++++++++++ classes/pycordbot.py | 50 ---------- classes/pycorduser.py | 42 --------- config_example.json | 9 ++ main.py | 8 +- modules/cache_manager.py | 29 ++++++ modules/cache_utils.py | 25 +++++ modules/database.py | 18 ++-- modules/migrator.py | 2 +- requirements.txt | 5 +- 18 files changed, 538 insertions(+), 106 deletions(-) create mode 100644 classes/cache/__init__.py create mode 100644 classes/cache/cache.py create mode 100644 classes/cache/cache_memcached.py create mode 100644 classes/cache/cache_redis.py create mode 100644 classes/errors/__init__.py create mode 100644 classes/errors/pycord_user.py create mode 100644 classes/pycord_bot.py create mode 100644 classes/pycord_user.py delete mode 100644 classes/pycordbot.py delete mode 100644 classes/pycorduser.py create mode 100644 modules/cache_manager.py create mode 100644 modules/cache_utils.py diff --git a/README.md b/README.md index 0351282..62892dd 100644 --- a/README.md +++ b/README.md @@ -14,3 +14,9 @@ Discord

+ +## Starting the bot + +```shell +uvicorn main:app +``` \ No newline at end of file diff --git a/classes/cache/__init__.py b/classes/cache/__init__.py new file mode 100644 index 0000000..dce54a8 --- /dev/null +++ b/classes/cache/__init__.py @@ -0,0 +1,3 @@ +from .cache import Cache +from .cache_memcached import CacheMemcached +from .cache_redis import CacheRedis diff --git a/classes/cache/cache.py b/classes/cache/cache.py new file mode 100644 index 0000000..8b0f617 --- /dev/null +++ b/classes/cache/cache.py @@ -0,0 +1,44 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + +import pymemcache +import redis + + +class Cache(ABC): + client: pymemcache.Client | redis.Redis + + @classmethod + @abstractmethod + def from_config(cls, engine_config: Dict[str, Any]) -> Any: + pass + + @abstractmethod + def get_json(self, key: str) -> Any | None: + # TODO This method must also carry out ObjectId conversion! + pass + + @abstractmethod + def get_string(self, key: str) -> str | None: + pass + + @abstractmethod + def get_object(self, key: str) -> Any | None: + pass + + @abstractmethod + def set_json(self, key: str, value: Any) -> None: + # TODO This method must also carry out ObjectId conversion! + pass + + @abstractmethod + def set_string(self, key: str, value: str) -> None: + pass + + @abstractmethod + def set_object(self, key: str, value: Any) -> None: + pass + + @abstractmethod + def delete(self, key: str) -> None: + pass diff --git a/classes/cache/cache_memcached.py b/classes/cache/cache_memcached.py new file mode 100644 index 0000000..81e91f9 --- /dev/null +++ b/classes/cache/cache_memcached.py @@ -0,0 +1,89 @@ +import logging +from logging import Logger +from typing import Dict, Any + +from pymemcache import Client + +from modules.cache_utils import string_to_json, json_to_string +from . 
import Cache + +logger: Logger = logging.getLogger(__name__) + + +class CacheMemcached(Cache): + client: Client + + def __init__(self, client: Client): + self.client = client + + logger.info("Initialized Memcached for caching") + + @classmethod + def from_config(cls, engine_config: Dict[str, Any]) -> "CacheMemcached": + if "uri" not in engine_config: + raise KeyError( + "Cache configuration is invalid. Please check if all keys are set (engine: memcached)" + ) + + return cls(Client(engine_config["uri"], default_noreply=True)) + + def get_json(self, key: str) -> Any | None: + try: + result: Any | None = self.client.get(key, None) + + logger.debug( + "Got json cache key '%s'%s", + key, + "" if result is not None else " (not found)", + ) + except Exception as exc: + logger.error("Could not get json cache key '%s' due to: %s", key, exc) + return None + + return None if result is None else string_to_json(result) + + def get_string(self, key: str) -> str | None: + try: + result: str | None = self.client.get(key, None) + + logger.debug( + "Got string cache key '%s'%s", + key, + "" if result is not None else " (not found)", + ) + + return result + except Exception as exc: + logger.error("Could not get string cache key '%s' due to: %s", key, exc) + return None + + # TODO Implement binary deserialization + def get_object(self, key: str) -> Any | None: + raise NotImplementedError() + + def set_json(self, key: str, value: Any) -> None: + try: + self.client.set(key, json_to_string(value)) + logger.debug("Set json cache key '%s'", key) + except Exception as exc: + logger.error("Could not set json cache key '%s' due to: %s", key, exc) + return None + + def set_string(self, key: str, value: str) -> None: + try: + self.client.set(key, value) + logger.debug("Set string cache key '%s'", key) + except Exception as exc: + logger.error("Could not set string cache key '%s' due to: %s", key, exc) + return None + + # TODO Implement binary serialization + def set_object(self, key: str, 
value: Any) -> None: + raise NotImplementedError() + + def delete(self, key: str) -> None: + try: + self.client.delete(key) + logger.debug("Deleted cache key '%s'", key) + except Exception as exc: + logger.error("Could not delete cache key '%s' due to: %s", key, exc) diff --git a/classes/cache/cache_redis.py b/classes/cache/cache_redis.py new file mode 100644 index 0000000..83fc022 --- /dev/null +++ b/classes/cache/cache_redis.py @@ -0,0 +1,89 @@ +import logging +from logging import Logger +from typing import Dict, Any + +from redis import Redis + +from classes.cache import Cache +from modules.cache_utils import string_to_json, json_to_string + +logger: Logger = logging.getLogger(__name__) + + +class CacheRedis(Cache): + client: Redis + + def __init__(self, client: Redis): + self.client = client + + logger.info("Initialized Redis for caching") + + @classmethod + def from_config(cls, engine_config: Dict[str, Any]) -> Any: + if "uri" not in engine_config: + raise KeyError( + "Cache configuration is invalid. 
Please check if all keys are set (engine: redis)" + ) + + return cls(Redis.from_url(engine_config["uri"])) + + def get_json(self, key: str) -> Any | None: + try: + result: Any | None = self.client.get(key) + + logger.debug( + "Got json cache key '%s'%s", + key, + "" if result is not None else " (not found)", + ) + except Exception as exc: + logger.error("Could not get json cache key '%s' due to: %s", key, exc) + return None + + return None if result is None else string_to_json(result) + + def get_string(self, key: str) -> str | None: + try: + result: str | None = self.client.get(key) + + logger.debug( + "Got string cache key '%s'%s", + key, + "" if result is not None else " (not found)", + ) + + return result + except Exception as exc: + logger.error("Could not get string cache key '%s' due to: %s", key, exc) + return None + + # TODO Implement binary deserialization + def get_object(self, key: str) -> Any | None: + raise NotImplementedError() + + def set_json(self, key: str, value: Any) -> None: + try: + self.client.set(key, json_to_string(value)) + logger.debug("Set json cache key '%s'", key) + except Exception as exc: + logger.error("Could not set json cache key '%s' due to: %s", key, exc) + return None + + def set_string(self, key: str, value: str) -> None: + try: + self.client.set(key, value) + logger.debug("Set string cache key '%s'", key) + except Exception as exc: + logger.error("Could not set string cache key '%s' due to: %s", key, exc) + return None + + # TODO Implement binary serialization + def set_object(self, key: str, value: Any) -> None: + raise NotImplementedError() + + def delete(self, key: str) -> None: + try: + self.client.delete(key) + logger.debug("Deleted cache key '%s'", key) + except Exception as exc: + logger.error("Could not delete cache key '%s' due to: %s", key, exc) diff --git a/classes/errors/__init__.py b/classes/errors/__init__.py new file mode 100644 index 0000000..d172d33 --- /dev/null +++ b/classes/errors/__init__.py @@ -0,0 
+1 @@ +from .pycord_user import UserNotFoundError diff --git a/classes/errors/pycord_user.py b/classes/errors/pycord_user.py new file mode 100644 index 0000000..5d6f718 --- /dev/null +++ b/classes/errors/pycord_user.py @@ -0,0 +1,9 @@ +class UserNotFoundError(Exception): + """PycordUser could not find user with such an ID in the database""" + + def __init__(self, user_id: int) -> None: + self.user_id = user_id + + super().__init__( + f"User with id {self.user_id} was not found" + ) diff --git a/classes/pycord_bot.py b/classes/pycord_bot.py new file mode 100644 index 0000000..c34c059 --- /dev/null +++ b/classes/pycord_bot.py @@ -0,0 +1,62 @@ +from typing import Any, Union + +from aiohttp import ClientSession +from discord import User +from libbot.pycord.classes import PycordBot as LibPycordBot + +from classes.pycord_user import PycordUser +from modules.cache_manager import create_cache_client + + +# from modules.tracking.dhl import update_tracks_dhl + + +class PycordBot(LibPycordBot): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._set_cache_engine() + + self.client_session = ClientSession() + + if self.scheduler is None: + return + + # Scheduler job for DHL parcel tracking + # self.scheduler.add_job( + # update_tracks_dhl, + # trigger="cron", + # hour=self.config["modules"]["tracking"]["fetch_hours"], + # args=[self, self.client_session], + # ) + + def _set_cache_engine(self) -> None: + if "cache" in self.config and self.config["cache"]["type"] is not None: + self.cache = create_cache_client(self.config, self.config["cache"]["type"]) + + async def find_user(self, user: Union[int, User]) -> PycordUser: + """Find User by its ID or User object. + + ### Args: + * user (`Union[int, User]`): ID or User object to extract ID from. + + ### Returns: + * `PycordUser`: User in database representation. 
+ """ + + return ( + await PycordUser.from_id(user, cache=self.cache) + if isinstance(user, int) + else await PycordUser.from_id(user.id, cache=self.cache) + ) + + async def start(self, *args: Any, **kwargs: Any) -> None: + await super().start(*args, **kwargs) + + async def close(self, **kwargs) -> None: + await self.client_session.close() + + if self.scheduler is not None: + self.scheduler.shutdown() + + await super().close(**kwargs) diff --git a/classes/pycord_user.py b/classes/pycord_user.py new file mode 100644 index 0000000..57a055d --- /dev/null +++ b/classes/pycord_user.py @@ -0,0 +1,153 @@ +import logging +from dataclasses import dataclass +from typing import Dict, Any, Optional + +from bson import ObjectId +from pymongo.results import InsertOneResult + +from classes.cache import Cache +from classes.errors.pycord_user import UserNotFoundError +from modules.database import col_users + +logger = logging.getLogger(__name__) + + +@dataclass +class PycordUser: + """Dataclass of DB entry of a user""" + + __slots__ = ("_id", "id") + + _id: ObjectId + id: int + + @classmethod + async def from_id( + cls, user_id: int, allow_creation: bool = True, cache: Optional[Cache] = None + ) -> "PycordUser": + """Find user in database and create new record if user does not exist. 
+ + Args: + user_id (int): User's Discord ID + allow_creation (:obj:`bool`, optional): Create new user record if none found in the database + cache (:obj:`Cache`, optional): Cache engine to get the cache from + + Returns: + PycordUser: User object + + Raises: + UserNotFoundError: User was not found and creation was not allowed + """ + if cache is not None: + cached_entry: Dict[str, Any] | None = cache.get_json(f"user_{user_id}") + + if cached_entry is not None: + return cls(**cached_entry) + + db_entry = await col_users.find_one({"id": user_id}) + + if db_entry is None: + if not allow_creation: + raise UserNotFoundError(user_id) + + db_entry = PycordUser.get_defaults(user_id) + + insert_result: InsertOneResult = await col_users.insert_one(db_entry) + + db_entry["_id"] = insert_result.inserted_id + + if cache is not None: + cache.set_json(f"user_{user_id}", db_entry) + + return cls(**db_entry) + + def _to_dict(self) -> Dict[str, Any]: + return { + "_id": self._id, + "id": self.id, + } + + async def _set(self, key: str, value: Any, cache: Optional[Cache] = None) -> None: + """Set attribute data and save it into the database. + + Args: + key (str): Attribute to change + value (Any): Value to set + cache (:obj:`Cache`, optional): Cache engine to write the update into + """ + if not hasattr(self, key): + raise AttributeError() + + setattr(self, key, value) + + await col_users.update_one( + {"_id": self._id}, {"$set": {key: value}}, upsert=True + ) + + self._update_cache(cache) + + logger.info("Set attribute '%s' of user %s to '%s'", key, self.id, value) + + async def _remove(self, key: str, cache: Optional[Cache] = None) -> None: + """Remove attribute data and save it into the database. 
+ + Args: + key (str): Attribute to remove + cache (:obj:`Cache`, optional): Cache engine to write the update into + """ + if not hasattr(self, key): + raise AttributeError() + + default_value: Any = PycordUser.get_default_value(key) + + setattr(self, key, default_value) + + await col_users.update_one( + {"_id": self._id}, {"$set": {key: default_value}}, upsert=True + ) + + self._update_cache(cache) + + logger.info("Removed attribute '%s' of user %s", key, self.id) + + def _get_cache_key(self) -> str: + return f"user_{self.id}" + + def _update_cache(self, cache: Optional[Cache] = None) -> None: + if cache is None: + return + + user_dict: Dict[str, Any] = self._to_dict() + + if user_dict is not None: + cache.set_json(self._get_cache_key(), user_dict) + else: + self._delete_cache(cache) + + def _delete_cache(self, cache: Optional[Cache] = None) -> None: + if cache is None: + return + + cache.delete(self._get_cache_key()) + + @staticmethod + def get_defaults(user_id: int | None = None) -> Dict[str, Any]: + return { + "id": user_id, + } + + @staticmethod + def get_default_value(key: str) -> Any: + if key not in PycordUser.get_defaults(): + raise KeyError(f"There's no default value for key '{key}' in PycordUser") + + return PycordUser.get_defaults()[key] + + async def purge(self, cache: Optional[Cache] = None) -> None: + """Completely remove user data from database. Currently only removes the user record from users collection. 
+ + Args: + cache (:obj:`Cache`, optional): Cache engine to write the update into + """ + await col_users.delete_one({"_id": self._id}) + self._delete_cache(cache) diff --git a/classes/pycordbot.py b/classes/pycordbot.py deleted file mode 100644 index 4f770bf..0000000 --- a/classes/pycordbot.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Any, Union - -from aiohttp import ClientSession -from discord import User -from libbot.pycord.classes import PycordBot as LibPycordBot - -from classes.pycorduser import PycordUser -from modules.tracking.dhl import update_tracks_dhl - - -class PycordBot(LibPycordBot): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.client_session = ClientSession() - - if self.scheduler is None: - return - - # Scheduler job for DHL parcel tracking - self.scheduler.add_job( - update_tracks_dhl, - trigger="cron", - hour=self.config["modules"]["tracking"]["fetch_hours"], - args=[self, self.client_session], - ) - - async def find_user(self, user: Union[int, User]) -> PycordUser: - """Find User by it's ID or User object. - - ### Args: - * user (`Union[int, User]`): ID or User object to extract ID from. - - ### Returns: - * `PycordUser`: User in database representation. 
- """ - - return ( - await PycordUser.find(user) - if isinstance(user, int) - else await PycordUser.find(user.id) - ) - - async def close(self, *args: Any, **kwargs: Any) -> None: - await self.client_session.close() - - if self.scheduler is not None: - self.scheduler.shutdown() - - await super().close(*args, **kwargs) diff --git a/classes/pycorduser.py b/classes/pycorduser.py deleted file mode 100644 index b82d83d..0000000 --- a/classes/pycorduser.py +++ /dev/null @@ -1,42 +0,0 @@ -import logging -from dataclasses import dataclass - -from bson import ObjectId - -from modules.database import col_users - -logger = logging.getLogger(__name__) - - -@dataclass -class PycordUser: - """Dataclass of DB entry of a user""" - - __slots__ = ("_id", "id") - - _id: ObjectId - id: int - - @classmethod - async def find(cls, id: int): - """Find user in database and create new record if user does not exist. - - ### Args: - * id (`int`): User's Discord ID - - ### Raises: - * `RuntimeError`: Raised when user entry after insertion could not be found. - - ### Returns: - * `PycordUser`: User with its database data. 
- """ - db_entry = await col_users.find_one({"id": id}) - - if db_entry is None: - inserted = await col_users.insert_one({"id": id}) - db_entry = await col_users.find_one({"_id": inserted.inserted_id}) - - if db_entry is None: - raise RuntimeError("Could not find inserted user entry.") - - return cls(**db_entry) diff --git a/config_example.json b/config_example.json index 508704c..ebb51f8 100644 --- a/config_example.json +++ b/config_example.json @@ -23,6 +23,15 @@ "port": 27017, "name": "javelina" }, + "cache": { + "type": null, + "memcached": { + "uri": "127.0.0.1:11211" + }, + "redis": { + "uri": "redis://127.0.0.1:6379/0" + } + }, "privacy": { "api_endpoint": "https://api.javelina.eu/v1" }, diff --git a/main.py b/main.py index 0dafb83..5f400b1 100644 --- a/main.py +++ b/main.py @@ -2,9 +2,9 @@ import asyncio import logging from os import getpid -from libbot import sync +from libbot.utils import config_get -from classes.pycordbot import PycordBot +from classes.pycord_bot import PycordBot from modules.extensions_loader import dynamic_import_from_src from modules.scheduler import scheduler @@ -12,7 +12,7 @@ from modules.scheduler import scheduler from api.app import app logging.basicConfig( - level=logging.DEBUG if sync.config_get("debug") else logging.INFO, + level=logging.DEBUG if config_get("debug") else logging.INFO, format="%(name)s.%(funcName)s | %(levelname)s | %(message)s", datefmt="[%X]", ) @@ -29,7 +29,7 @@ async def main(): dynamic_import_from_src("api.extensions", star_import=True) try: - await bot.start(sync.config_get("bot_token", "bot")) + await bot.start(config_get("bot_token", "bot")) except KeyboardInterrupt: logger.warning("Forcefully shutting down with PID %s...", getpid()) await bot.close() diff --git a/modules/cache_manager.py b/modules/cache_manager.py new file mode 100644 index 0000000..c852f8e --- /dev/null +++ b/modules/cache_manager.py @@ -0,0 +1,29 @@ +from typing import Dict, Any, Literal + +from classes.cache.cache_memcached import 
CacheMemcached +from classes.cache.cache_redis import CacheRedis + + +def create_cache_client( + config: Dict[str, Any], + engine: Literal["memcached", "redis"] | None = None, +) -> CacheMemcached | CacheRedis: + if engine not in ["memcached", "redis"] or engine is None: + raise KeyError( + f"Incorrect cache engine provided. Expected 'memcached' or 'redis', got '{engine}'" + ) + + if "cache" not in config or engine not in config["cache"]: + raise KeyError( + f"Cache configuration is invalid. Please check if all keys are set (engine: '{engine}')" + ) + + match engine: + case "memcached": + return CacheMemcached.from_config(config["cache"][engine]) + case "redis": + return CacheRedis.from_config(config["cache"][engine]) + case _: + raise KeyError( + f"Cache implementation for the engine '{engine}' is not present." + ) diff --git a/modules/cache_utils.py b/modules/cache_utils.py new file mode 100644 index 0000000..2f95204 --- /dev/null +++ b/modules/cache_utils.py @@ -0,0 +1,25 @@ +from copy import deepcopy +from typing import Any + +from bson import ObjectId +from ujson import dumps, loads + + +def json_to_string(json_object: Any) -> str: + json_object_copy: Any = deepcopy(json_object) + + if isinstance(json_object_copy, dict) and "_id" in json_object_copy: + json_object_copy["_id"] = str(json_object_copy["_id"]) + + return dumps( + json_object_copy, ensure_ascii=False, indent=0, escape_forward_slashes=False + ) + + +def string_to_json(json_string: str) -> Any: + json_object: Any = loads(json_string) + + if isinstance(json_object, dict) and "_id" in json_object: + json_object["_id"] = ObjectId(json_object["_id"]) + + return json_object diff --git a/modules/database.py b/modules/database.py index 6581e81..4b02ac2 100644 --- a/modules/database.py +++ b/modules/database.py @@ -3,7 +3,7 @@ from typing import Any, Mapping from async_pymongo import AsyncClient, AsyncCollection, AsyncDatabase -from libbot.sync import config_get +from libbot.utils import config_get db_config: Mapping[str, Any] = 
config_get("database") @@ -20,13 +20,17 @@ else: db_config["host"], db_config["port"], db_config["name"] ) +# Async declarations db_client = AsyncClient(con_string) db: AsyncDatabase = db_client.get_database(name=db_config["name"]) col_users: AsyncCollection = db.get_collection("users") -col_messages: AsyncCollection = db.get_collection("messages") -col_warnings: AsyncCollection = db.get_collection("warnings") -col_checkouts: AsyncCollection = db.get_collection("checkouts") -col_trackings: AsyncCollection = db.get_collection("trackings") -col_authorized: AsyncCollection = db.get_collection("authorized") -col_transactions: AsyncCollection = db.get_collection("transactions") +# col_messages: AsyncCollection = db.get_collection("messages") +# col_warnings: AsyncCollection = db.get_collection("warnings") +# col_checkouts: AsyncCollection = db.get_collection("checkouts") +# col_trackings: AsyncCollection = db.get_collection("trackings") +# col_authorized: AsyncCollection = db.get_collection("authorized") +# col_transactions: AsyncCollection = db.get_collection("transactions") + +# Update indexes +db.dispatch.get_collection("users").create_index("id", unique=True) \ No newline at end of file diff --git a/modules/migrator.py b/modules/migrator.py index 5ebeb91..ec5f1ad 100644 --- a/modules/migrator.py +++ b/modules/migrator.py @@ -1,6 +1,6 @@ from typing import Any, Mapping -from libbot.sync import config_get +from libbot.utils import config_get from mongodb_migrations.cli import MigrationManager from mongodb_migrations.config import Configuration diff --git a/requirements.txt b/requirements.txt index bcbefed..805cd94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,10 @@ async_pymongo==0.1.11 colorthief==0.2.1 deepl==1.21.0 fastapi[all]~=0.115.0 +libbot[speed,pycord]==4.0.2 mongodb-migrations==1.3.1 +pymemcache~=4.0.0 pynacl~=1.5.0 pyrmv==0.4.0 pytz~=2025.1 ---extra-index-url https://git.end-play.xyz/api/packages/profitroll/pypi/simple 
-libbot[speed,pycord]~=3.3.0,<4.0.0 +redis~=5.2.1 \ No newline at end of file