Worked on #13 and #4. There are some caching issues left, though. Introduced abstract class Cacheable. Replaced async_pymongo with pymongo

This commit is contained in:
2025-05-06 02:54:30 +02:00
parent 9d562e2e9d
commit 86c75d06fa
22 changed files with 412 additions and 137 deletions

View File

@@ -1,7 +1,7 @@
"""Module with class PycordEvent."""
from dataclasses import dataclass
from datetime import datetime, timezone
from datetime import datetime, tzinfo
from logging import Logger
from typing import Any, Dict, List, Optional
from zoneinfo import ZoneInfo
@@ -11,6 +11,7 @@ from libbot.cache.classes import Cache
from pymongo import DESCENDING
from pymongo.results import InsertOneResult
from classes.abstract import Cacheable
from classes.errors import EventNotFoundError
from modules.database import col_events
from modules.utils import get_logger, restore_from_cache
@@ -19,7 +20,7 @@ logger: Logger = get_logger(__name__)
@dataclass
class PycordEvent:
class PycordEvent(Cacheable):
"""Object representation of an event in the database.
Attributes:
@@ -82,7 +83,7 @@ class PycordEvent:
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)
if cached_entry is not None:
return cls(**cached_entry)
return cls(**cls._entry_from_cache(cached_entry))
db_entry = await cls.__collection__.find_one(
{"_id": event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)}
@@ -92,7 +93,7 @@ class PycordEvent:
raise EventNotFoundError(event_id=event_id)
if cache is not None:
cache.set_json(f"{cls.__short_name__}_{event_id}", db_entry)
cache.set_json(f"{cls.__short_name__}_{event_id}", cls._entry_to_cache(dict(db_entry)))
return cls(**db_entry)
@@ -123,7 +124,7 @@ class PycordEvent:
raise EventNotFoundError(event_name=event_name, guild_id=guild_id)
if cache is not None:
cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry)
cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", cls._entry_to_cache(db_entry))
return cls(**db_entry)
@@ -172,7 +173,7 @@ class PycordEvent:
db_entry["_id"] = insert_result.inserted_id
if cache is not None:
cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry)
cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry))
return cls(**db_entry)
@@ -215,10 +216,10 @@ class PycordEvent:
if cache is None:
return
user_dict: Dict[str, Any] = self.to_dict()
object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
if user_dict is not None:
cache.set_json(self._get_cache_key(), user_dict)
if object_dict is not None:
cache.set_json(self._get_cache_key(), object_dict)
else:
self._delete_cache(cache)
@@ -253,6 +254,32 @@ class PycordEvent:
if stage_index != old_stage_index:
await (await bot.find_event_stage(event_stage_id)).update(cache, sequence=stage_index)
@staticmethod
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
cache_entry: Dict[str, Any] = db_entry.copy()
cache_entry["_id"] = str(cache_entry["_id"])
cache_entry["created"] = cache_entry["created"].isoformat()
cache_entry["ended"] = None if cache_entry["ended"] is None else cache_entry["ended"].isoformat()
cache_entry["starts"] = cache_entry["starts"].isoformat()
cache_entry["ends"] = cache_entry["ends"].isoformat()
cache_entry["stage_ids"] = [str(stage_id) for stage_id in cache_entry["stage_ids"]]
return cache_entry
@staticmethod
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
db_entry: Dict[str, Any] = cache_entry.copy()
db_entry["_id"] = ObjectId(db_entry["_id"])
db_entry["created"] = datetime.fromisoformat(db_entry["created"])
db_entry["ended"] = None if db_entry["ended"] is None else datetime.fromisoformat(db_entry["ended"])
db_entry["starts"] = datetime.fromisoformat(db_entry["starts"])
db_entry["ends"] = datetime.fromisoformat(db_entry["ends"])
db_entry["stage_ids"] = [ObjectId(stage_id) for stage_id in db_entry["stage_ids"]]
return db_entry
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
"""Convert the object to a JSON representation.
@@ -266,14 +293,20 @@ class PycordEvent:
"_id": self._id if not json_compatible else str(self._id),
"name": self.name,
"guild_id": self.guild_id,
"created": self.created,
"ended": self.ended,
"created": self.created if not json_compatible else self.created.isoformat(),
"ended": (
self.ended
if not json_compatible
else (None if self.ended is None else self.ended.isoformat())
),
"is_cancelled": self.is_cancelled,
"creator_id": self.creator_id,
"starts": self.starts,
"ends": self.ends,
"starts": self.starts if not json_compatible else self.starts.isoformat(),
"ends": self.ends if not json_compatible else self.ends.isoformat(),
"thumbnail": self.thumbnail,
"stage_ids": self.stage_ids,
"stage_ids": (
self.stage_ids if not json_compatible else [str(stage_id) for stage_id in self.stage_ids]
),
}
@staticmethod
@@ -456,7 +489,7 @@ class PycordEvent:
return self.ends.replace(tzinfo=ZoneInfo("UTC"))
def get_start_date_localized(self, tz: str | timezone | ZoneInfo) -> datetime:
def get_start_date_localized(self, tz: tzinfo) -> datetime:
"""Get the event start date in the provided timezone.
Returns:
@@ -470,7 +503,7 @@ class PycordEvent:
return self.starts.replace(tzinfo=tz)
def get_end_date_localized(self, tz: str | timezone | ZoneInfo) -> datetime:
def get_end_date_localized(self, tz: tzinfo) -> datetime:
"""Get the event end date in the provided timezone.
Returns: