Worked on #13 and #4. There are some caching issues left, though. Introduced abstract class Cacheable. Replaced async_pymongo with pymongo
This commit is contained in:
1
classes/abstract/__init__.py
Normal file
1
classes/abstract/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .cacheable import Cacheable
|
81
classes/abstract/cacheable.py
Normal file
81
classes/abstract/cacheable.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, ClassVar, Dict, Optional
|
||||
|
||||
from libbot.cache.classes import Cache
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
|
||||
|
||||
class Cacheable(ABC):
|
||||
"""Abstract class for cacheable"""
|
||||
|
||||
__short_name__: str
|
||||
__collection__: ClassVar[AsyncCollection]
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def from_id(cls, *args: Any, cache: Optional[Cache] = None, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_cache_key(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _update_cache(self, cache: Optional[Cache] = None) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _delete_cache(self, cache: Optional[Cache] = None) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_defaults(**kwargs: Any) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_default_value(key: str) -> Any:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def update(
|
||||
self,
|
||||
cache: Optional[Cache] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def reset(
|
||||
self,
|
||||
*args: str,
|
||||
cache: Optional[Cache] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def purge(self, cache: Optional[Cache] = None) -> None:
|
||||
pass
|
@@ -14,13 +14,13 @@ from typing_extensions import override
|
||||
|
||||
from classes import PycordEvent, PycordEventStage, PycordGuild, PycordUser
|
||||
from classes.errors import (
|
||||
DiscordGuildMemberNotFoundError,
|
||||
EventNotFoundError,
|
||||
EventStageMissingSequenceError,
|
||||
EventStageNotFoundError,
|
||||
GuildNotFoundError,
|
||||
DiscordGuildMemberNotFoundError,
|
||||
)
|
||||
from modules.database import col_events, col_users
|
||||
from modules.database import col_events, col_users, _update_database_indexes
|
||||
from modules.utils import get_logger
|
||||
|
||||
logger: Logger = get_logger(__name__)
|
||||
@@ -58,6 +58,7 @@ class PycordBot(LibPycordBot):
|
||||
@override
|
||||
async def start(self, *args: Any, **kwargs: Any) -> None:
|
||||
await self._schedule_tasks()
|
||||
await _update_database_indexes()
|
||||
|
||||
self.started = datetime.now(tz=ZoneInfo("UTC"))
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""Module with class PycordEvent."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, tzinfo
|
||||
from logging import Logger
|
||||
from typing import Any, Dict, List, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
@@ -11,6 +11,7 @@ from libbot.cache.classes import Cache
|
||||
from pymongo import DESCENDING
|
||||
from pymongo.results import InsertOneResult
|
||||
|
||||
from classes.abstract import Cacheable
|
||||
from classes.errors import EventNotFoundError
|
||||
from modules.database import col_events
|
||||
from modules.utils import get_logger, restore_from_cache
|
||||
@@ -19,7 +20,7 @@ logger: Logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PycordEvent:
|
||||
class PycordEvent(Cacheable):
|
||||
"""Object representation of an event in the database.
|
||||
|
||||
Attributes:
|
||||
@@ -82,7 +83,7 @@ class PycordEvent:
|
||||
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)
|
||||
|
||||
if cached_entry is not None:
|
||||
return cls(**cached_entry)
|
||||
return cls(**cls._entry_from_cache(cached_entry))
|
||||
|
||||
db_entry = await cls.__collection__.find_one(
|
||||
{"_id": event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)}
|
||||
@@ -92,7 +93,7 @@ class PycordEvent:
|
||||
raise EventNotFoundError(event_id=event_id)
|
||||
|
||||
if cache is not None:
|
||||
cache.set_json(f"{cls.__short_name__}_{event_id}", db_entry)
|
||||
cache.set_json(f"{cls.__short_name__}_{event_id}", cls._entry_to_cache(dict(db_entry)))
|
||||
|
||||
return cls(**db_entry)
|
||||
|
||||
@@ -123,7 +124,7 @@ class PycordEvent:
|
||||
raise EventNotFoundError(event_name=event_name, guild_id=guild_id)
|
||||
|
||||
if cache is not None:
|
||||
cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry)
|
||||
cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", cls._entry_to_cache(db_entry))
|
||||
|
||||
return cls(**db_entry)
|
||||
|
||||
@@ -172,7 +173,7 @@ class PycordEvent:
|
||||
db_entry["_id"] = insert_result.inserted_id
|
||||
|
||||
if cache is not None:
|
||||
cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry)
|
||||
cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry))
|
||||
|
||||
return cls(**db_entry)
|
||||
|
||||
@@ -215,10 +216,10 @@ class PycordEvent:
|
||||
if cache is None:
|
||||
return
|
||||
|
||||
user_dict: Dict[str, Any] = self.to_dict()
|
||||
object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
|
||||
|
||||
if user_dict is not None:
|
||||
cache.set_json(self._get_cache_key(), user_dict)
|
||||
if object_dict is not None:
|
||||
cache.set_json(self._get_cache_key(), object_dict)
|
||||
else:
|
||||
self._delete_cache(cache)
|
||||
|
||||
@@ -253,6 +254,32 @@ class PycordEvent:
|
||||
if stage_index != old_stage_index:
|
||||
await (await bot.find_event_stage(event_stage_id)).update(cache, sequence=stage_index)
|
||||
|
||||
@staticmethod
|
||||
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
cache_entry: Dict[str, Any] = db_entry.copy()
|
||||
|
||||
cache_entry["_id"] = str(cache_entry["_id"])
|
||||
cache_entry["created"] = cache_entry["created"].isoformat()
|
||||
cache_entry["ended"] = None if cache_entry["ended"] is None else cache_entry["ended"].isoformat()
|
||||
cache_entry["starts"] = cache_entry["starts"].isoformat()
|
||||
cache_entry["ends"] = cache_entry["ends"].isoformat()
|
||||
cache_entry["stage_ids"] = [str(stage_id) for stage_id in cache_entry["stage_ids"]]
|
||||
|
||||
return cache_entry
|
||||
|
||||
@staticmethod
|
||||
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
db_entry: Dict[str, Any] = cache_entry.copy()
|
||||
|
||||
db_entry["_id"] = ObjectId(db_entry["_id"])
|
||||
db_entry["created"] = datetime.fromisoformat(db_entry["created"])
|
||||
db_entry["ended"] = None if db_entry["ended"] is None else datetime.fromisoformat(db_entry["ended"])
|
||||
db_entry["starts"] = datetime.fromisoformat(db_entry["starts"])
|
||||
db_entry["ends"] = datetime.fromisoformat(db_entry["ends"])
|
||||
db_entry["stage_ids"] = [ObjectId(stage_id) for stage_id in db_entry["stage_ids"]]
|
||||
|
||||
return db_entry
|
||||
|
||||
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
|
||||
"""Convert the object to a JSON representation.
|
||||
|
||||
@@ -266,14 +293,20 @@ class PycordEvent:
|
||||
"_id": self._id if not json_compatible else str(self._id),
|
||||
"name": self.name,
|
||||
"guild_id": self.guild_id,
|
||||
"created": self.created,
|
||||
"ended": self.ended,
|
||||
"created": self.created if not json_compatible else self.created.isoformat(),
|
||||
"ended": (
|
||||
self.ended
|
||||
if not json_compatible
|
||||
else (None if self.ended is None else self.ended.isoformat())
|
||||
),
|
||||
"is_cancelled": self.is_cancelled,
|
||||
"creator_id": self.creator_id,
|
||||
"starts": self.starts,
|
||||
"ends": self.ends,
|
||||
"starts": self.starts if not json_compatible else self.starts.isoformat(),
|
||||
"ends": self.ends if not json_compatible else self.ends.isoformat(),
|
||||
"thumbnail": self.thumbnail,
|
||||
"stage_ids": self.stage_ids,
|
||||
"stage_ids": (
|
||||
self.stage_ids if not json_compatible else [str(stage_id) for stage_id in self.stage_ids]
|
||||
),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -456,7 +489,7 @@ class PycordEvent:
|
||||
|
||||
return self.ends.replace(tzinfo=ZoneInfo("UTC"))
|
||||
|
||||
def get_start_date_localized(self, tz: str | timezone | ZoneInfo) -> datetime:
|
||||
def get_start_date_localized(self, tz: tzinfo) -> datetime:
|
||||
"""Get the event start date in the provided timezone.
|
||||
|
||||
Returns:
|
||||
@@ -470,7 +503,7 @@ class PycordEvent:
|
||||
|
||||
return self.starts.replace(tzinfo=tz)
|
||||
|
||||
def get_end_date_localized(self, tz: str | timezone | ZoneInfo) -> datetime:
|
||||
def get_end_date_localized(self, tz: tzinfo) -> datetime:
|
||||
"""Get the event end date in the provided timezone.
|
||||
|
||||
Returns:
|
||||
|
@@ -10,6 +10,7 @@ from discord import File
|
||||
from libbot.cache.classes import Cache
|
||||
from pymongo.results import InsertOneResult
|
||||
|
||||
from classes.abstract import Cacheable
|
||||
from classes.errors import EventStageNotFoundError
|
||||
from modules.database import col_stages
|
||||
from modules.utils import get_logger, restore_from_cache
|
||||
@@ -18,7 +19,7 @@ logger: Logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PycordEventStage:
|
||||
class PycordEventStage(Cacheable):
|
||||
__slots__ = (
|
||||
"_id",
|
||||
"event_id",
|
||||
@@ -61,7 +62,7 @@ class PycordEventStage:
|
||||
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, stage_id, cache=cache)
|
||||
|
||||
if cached_entry is not None:
|
||||
return cls(**cached_entry)
|
||||
return cls(**cls._entry_from_cache(cached_entry))
|
||||
|
||||
db_entry = await cls.__collection__.find_one(
|
||||
{"_id": stage_id if isinstance(stage_id, ObjectId) else ObjectId(stage_id)}
|
||||
@@ -71,7 +72,7 @@ class PycordEventStage:
|
||||
raise EventStageNotFoundError(stage_id)
|
||||
|
||||
if cache is not None:
|
||||
cache.set_json(f"{cls.__short_name__}_{stage_id}", db_entry)
|
||||
cache.set_json(f"{cls.__short_name__}_{stage_id}", cls._entry_to_cache(dict(db_entry)))
|
||||
|
||||
return cls(**db_entry)
|
||||
|
||||
@@ -104,7 +105,7 @@ class PycordEventStage:
|
||||
db_entry["_id"] = insert_result.inserted_id
|
||||
|
||||
if cache is not None:
|
||||
cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry)
|
||||
cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry))
|
||||
|
||||
return cls(**db_entry)
|
||||
|
||||
@@ -159,10 +160,10 @@ class PycordEventStage:
|
||||
if cache is None:
|
||||
return
|
||||
|
||||
user_dict: Dict[str, Any] = self.to_dict()
|
||||
object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
|
||||
|
||||
if user_dict is not None:
|
||||
cache.set_json(self._get_cache_key(), user_dict)
|
||||
if object_dict is not None:
|
||||
cache.set_json(self._get_cache_key(), object_dict)
|
||||
else:
|
||||
self._delete_cache(cache)
|
||||
|
||||
@@ -172,6 +173,26 @@ class PycordEventStage:
|
||||
|
||||
cache.delete(self._get_cache_key())
|
||||
|
||||
@staticmethod
|
||||
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
cache_entry: Dict[str, Any] = db_entry.copy()
|
||||
|
||||
cache_entry["_id"] = str(cache_entry["_id"])
|
||||
cache_entry["event_id"] = str(cache_entry["event_id"])
|
||||
cache_entry["created"] = cache_entry["created"].isoformat()
|
||||
|
||||
return cache_entry
|
||||
|
||||
@staticmethod
|
||||
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
db_entry: Dict[str, Any] = cache_entry.copy()
|
||||
|
||||
db_entry["_id"] = ObjectId(db_entry["_id"])
|
||||
db_entry["event_id"] = ObjectId(db_entry["event_id"])
|
||||
db_entry["created"] = datetime.fromisoformat(db_entry["created"])
|
||||
|
||||
return db_entry
|
||||
|
||||
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
|
||||
"""Convert the object to a JSON representation.
|
||||
|
||||
@@ -186,7 +207,7 @@ class PycordEventStage:
|
||||
"event_id": self.event_id if not json_compatible else str(self.event_id),
|
||||
"guild_id": self.guild_id,
|
||||
"sequence": self.sequence,
|
||||
"created": self.created,
|
||||
"created": self.created if not json_compatible else self.created.isoformat(),
|
||||
"creator_id": self.creator_id,
|
||||
"question": self.question,
|
||||
"answer": self.answer,
|
||||
|
@@ -6,6 +6,7 @@ from bson import ObjectId
|
||||
from libbot.cache.classes import Cache
|
||||
from pymongo.results import InsertOneResult
|
||||
|
||||
from classes.abstract import Cacheable
|
||||
from classes.errors import GuildNotFoundError
|
||||
from modules.database import col_guilds
|
||||
from modules.utils import get_logger, restore_from_cache
|
||||
@@ -14,7 +15,7 @@ logger: Logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PycordGuild:
|
||||
class PycordGuild(Cacheable):
|
||||
"""Dataclass of DB entry of a guild"""
|
||||
|
||||
__slots__ = (
|
||||
@@ -57,7 +58,7 @@ class PycordGuild:
|
||||
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, guild_id, cache=cache)
|
||||
|
||||
if cached_entry is not None:
|
||||
return cls(**cached_entry)
|
||||
return cls(**cls._entry_from_cache(cached_entry))
|
||||
|
||||
db_entry = await cls.__collection__.find_one({"id": guild_id})
|
||||
|
||||
@@ -72,7 +73,7 @@ class PycordGuild:
|
||||
db_entry["_id"] = insert_result.inserted_id
|
||||
|
||||
if cache is not None:
|
||||
cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry)
|
||||
cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry))
|
||||
|
||||
return cls(**db_entry)
|
||||
|
||||
@@ -115,10 +116,10 @@ class PycordGuild:
|
||||
if cache is None:
|
||||
return
|
||||
|
||||
user_dict: Dict[str, Any] = self.to_dict()
|
||||
object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
|
||||
|
||||
if user_dict is not None:
|
||||
cache.set_json(self._get_cache_key(), user_dict)
|
||||
if object_dict is not None:
|
||||
cache.set_json(self._get_cache_key(), object_dict)
|
||||
else:
|
||||
self._delete_cache(cache)
|
||||
|
||||
@@ -128,6 +129,22 @@ class PycordGuild:
|
||||
|
||||
cache.delete(self._get_cache_key())
|
||||
|
||||
@staticmethod
|
||||
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
cache_entry: Dict[str, Any] = db_entry.copy()
|
||||
|
||||
cache_entry["_id"] = str(cache_entry["_id"])
|
||||
|
||||
return cache_entry
|
||||
|
||||
@staticmethod
|
||||
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
db_entry: Dict[str, Any] = cache_entry.copy()
|
||||
|
||||
db_entry["_id"] = ObjectId(db_entry["_id"])
|
||||
|
||||
return db_entry
|
||||
|
||||
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
|
||||
"""Convert the object to a JSON representation.
|
||||
|
||||
|
@@ -17,6 +17,7 @@ from discord.abc import GuildChannel
|
||||
from libbot.cache.classes import Cache
|
||||
from pymongo.results import InsertOneResult
|
||||
|
||||
from classes.abstract import Cacheable
|
||||
from classes.errors import (
|
||||
DiscordCategoryNotFoundError,
|
||||
DiscordChannelNotFoundError,
|
||||
@@ -33,9 +34,17 @@ logger: Logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PycordUser:
|
||||
class PycordUser(Cacheable):
|
||||
"""Dataclass of DB entry of a user"""
|
||||
|
||||
# TODO Implement this
|
||||
async def update(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
# TODO Implement this
|
||||
async def reset(self, *args: str, cache: Optional[Cache] = None) -> None:
|
||||
pass
|
||||
|
||||
__slots__ = (
|
||||
"_id",
|
||||
"id",
|
||||
@@ -83,7 +92,7 @@ class PycordUser:
|
||||
)
|
||||
|
||||
if cached_entry is not None:
|
||||
return cls(**cached_entry)
|
||||
return cls(**cls._entry_from_cache(cached_entry))
|
||||
|
||||
db_entry = await cls.__collection__.find_one({"id": user_id, "guild_id": guild_id})
|
||||
|
||||
@@ -98,7 +107,7 @@ class PycordUser:
|
||||
db_entry["_id"] = insert_result.inserted_id
|
||||
|
||||
if cache is not None:
|
||||
cache.set_json(f"{cls.__short_name__}_{user_id}_{guild_id}", db_entry)
|
||||
cache.set_json(f"{cls.__short_name__}_{user_id}_{guild_id}", cls._entry_to_cache(db_entry))
|
||||
|
||||
return cls(**db_entry)
|
||||
|
||||
@@ -186,10 +195,10 @@ class PycordUser:
|
||||
if cache is None:
|
||||
return
|
||||
|
||||
user_dict: Dict[str, Any] = self.to_dict()
|
||||
object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
|
||||
|
||||
if user_dict is not None:
|
||||
cache.set_json(self._get_cache_key(), user_dict)
|
||||
if object_dict is not None:
|
||||
cache.set_json(self._get_cache_key(), object_dict)
|
||||
else:
|
||||
self._delete_cache(cache)
|
||||
|
||||
@@ -199,6 +208,46 @@ class PycordUser:
|
||||
|
||||
cache.delete(self._get_cache_key())
|
||||
|
||||
@staticmethod
|
||||
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
cache_entry: Dict[str, Any] = db_entry.copy()
|
||||
|
||||
cache_entry["_id"] = str(cache_entry["_id"])
|
||||
cache_entry["current_event_id"] = (
|
||||
None if cache_entry["current_event_id"] is None else str(cache_entry["current_event_id"])
|
||||
)
|
||||
cache_entry["current_stage_id"] = (
|
||||
None if cache_entry["current_stage_id"] is None else str(cache_entry["current_stage_id"])
|
||||
)
|
||||
cache_entry["registered_event_ids"] = [
|
||||
str(event_id) for event_id in cache_entry["registered_event_ids"]
|
||||
]
|
||||
cache_entry["completed_event_ids"] = [
|
||||
str(event_id) for event_id in cache_entry["completed_event_ids"]
|
||||
]
|
||||
|
||||
return cache_entry
|
||||
|
||||
@staticmethod
|
||||
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
|
||||
db_entry: Dict[str, Any] = cache_entry.copy()
|
||||
|
||||
db_entry["_id"] = ObjectId(db_entry["_id"])
|
||||
db_entry["current_event_id"] = (
|
||||
None if db_entry["current_event_id"] is None else ObjectId(db_entry["current_event_id"])
|
||||
)
|
||||
db_entry["current_stage_id"] = (
|
||||
None if db_entry["current_stage_id"] is None else ObjectId(db_entry["current_stage_id"])
|
||||
)
|
||||
db_entry["registered_event_ids"] = [
|
||||
ObjectId(event_id) for event_id in db_entry["registered_event_ids"]
|
||||
]
|
||||
db_entry["completed_event_ids"] = [
|
||||
ObjectId(event_id) for event_id in db_entry["completed_event_ids"]
|
||||
]
|
||||
|
||||
return db_entry
|
||||
|
||||
# TODO Add documentation
|
||||
@staticmethod
|
||||
def get_defaults(user_id: Optional[int] = None, guild_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
|
Reference in New Issue
Block a user