Worked on #13 and #4. There are some caching issues left, though. Introduced abstract class Cacheable. Replaced async_pymongo with pymongo

This commit is contained in:
2025-05-06 02:54:30 +02:00
parent 9d562e2e9d
commit 86c75d06fa
22 changed files with 412 additions and 137 deletions

View File

@@ -0,0 +1 @@
from .cacheable import Cacheable

View File

@@ -0,0 +1,81 @@
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Optional
from libbot.cache.classes import Cache
from pymongo.asynchronous.collection import AsyncCollection
class Cacheable(ABC):
"""Abstract class for cacheable"""
__short_name__: str
__collection__: ClassVar[AsyncCollection]
@classmethod
@abstractmethod
async def from_id(cls, *args: Any, cache: Optional[Cache] = None, **kwargs: Any) -> Any:
pass
@abstractmethod
async def _set(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
pass
@abstractmethod
async def _remove(self, *args: str, cache: Optional[Cache] = None) -> None:
pass
@abstractmethod
def _get_cache_key(self) -> str:
pass
@abstractmethod
def _update_cache(self, cache: Optional[Cache] = None) -> None:
pass
@abstractmethod
def _delete_cache(self, cache: Optional[Cache] = None) -> None:
pass
@staticmethod
@abstractmethod
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
pass
@staticmethod
@abstractmethod
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
pass
@abstractmethod
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
pass
@staticmethod
@abstractmethod
def get_defaults(**kwargs: Any) -> Dict[str, Any]:
pass
@staticmethod
@abstractmethod
def get_default_value(key: str) -> Any:
pass
@abstractmethod
async def update(
self,
cache: Optional[Cache] = None,
**kwargs: Any,
) -> None:
pass
@abstractmethod
async def reset(
self,
*args: str,
cache: Optional[Cache] = None,
) -> None:
pass
@abstractmethod
async def purge(self, cache: Optional[Cache] = None) -> None:
pass

View File

@@ -14,13 +14,13 @@ from typing_extensions import override
from classes import PycordEvent, PycordEventStage, PycordGuild, PycordUser from classes import PycordEvent, PycordEventStage, PycordGuild, PycordUser
from classes.errors import ( from classes.errors import (
DiscordGuildMemberNotFoundError,
EventNotFoundError, EventNotFoundError,
EventStageMissingSequenceError, EventStageMissingSequenceError,
EventStageNotFoundError, EventStageNotFoundError,
GuildNotFoundError, GuildNotFoundError,
DiscordGuildMemberNotFoundError,
) )
from modules.database import col_events, col_users from modules.database import col_events, col_users, _update_database_indexes
from modules.utils import get_logger from modules.utils import get_logger
logger: Logger = get_logger(__name__) logger: Logger = get_logger(__name__)
@@ -58,6 +58,7 @@ class PycordBot(LibPycordBot):
@override @override
async def start(self, *args: Any, **kwargs: Any) -> None: async def start(self, *args: Any, **kwargs: Any) -> None:
await self._schedule_tasks() await self._schedule_tasks()
await _update_database_indexes()
self.started = datetime.now(tz=ZoneInfo("UTC")) self.started = datetime.now(tz=ZoneInfo("UTC"))

View File

@@ -1,7 +1,7 @@
"""Module with class PycordEvent.""" """Module with class PycordEvent."""
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, tzinfo
from logging import Logger from logging import Logger
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
@@ -11,6 +11,7 @@ from libbot.cache.classes import Cache
from pymongo import DESCENDING from pymongo import DESCENDING
from pymongo.results import InsertOneResult from pymongo.results import InsertOneResult
from classes.abstract import Cacheable
from classes.errors import EventNotFoundError from classes.errors import EventNotFoundError
from modules.database import col_events from modules.database import col_events
from modules.utils import get_logger, restore_from_cache from modules.utils import get_logger, restore_from_cache
@@ -19,7 +20,7 @@ logger: Logger = get_logger(__name__)
@dataclass @dataclass
class PycordEvent: class PycordEvent(Cacheable):
"""Object representation of an event in the database. """Object representation of an event in the database.
Attributes: Attributes:
@@ -82,7 +83,7 @@ class PycordEvent:
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache) cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, event_id, cache=cache)
if cached_entry is not None: if cached_entry is not None:
return cls(**cached_entry) return cls(**cls._entry_from_cache(cached_entry))
db_entry = await cls.__collection__.find_one( db_entry = await cls.__collection__.find_one(
{"_id": event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)} {"_id": event_id if isinstance(event_id, ObjectId) else ObjectId(event_id)}
@@ -92,7 +93,7 @@ class PycordEvent:
raise EventNotFoundError(event_id=event_id) raise EventNotFoundError(event_id=event_id)
if cache is not None: if cache is not None:
cache.set_json(f"{cls.__short_name__}_{event_id}", db_entry) cache.set_json(f"{cls.__short_name__}_{event_id}", cls._entry_to_cache(dict(db_entry)))
return cls(**db_entry) return cls(**db_entry)
@@ -123,7 +124,7 @@ class PycordEvent:
raise EventNotFoundError(event_name=event_name, guild_id=guild_id) raise EventNotFoundError(event_name=event_name, guild_id=guild_id)
if cache is not None: if cache is not None:
cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", db_entry) cache.set_json(f"{cls.__short_name__}_{db_entry['_id']}", cls._entry_to_cache(db_entry))
return cls(**db_entry) return cls(**db_entry)
@@ -172,7 +173,7 @@ class PycordEvent:
db_entry["_id"] = insert_result.inserted_id db_entry["_id"] = insert_result.inserted_id
if cache is not None: if cache is not None:
cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry) cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry))
return cls(**db_entry) return cls(**db_entry)
@@ -215,10 +216,10 @@ class PycordEvent:
if cache is None: if cache is None:
return return
user_dict: Dict[str, Any] = self.to_dict() object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
if user_dict is not None: if object_dict is not None:
cache.set_json(self._get_cache_key(), user_dict) cache.set_json(self._get_cache_key(), object_dict)
else: else:
self._delete_cache(cache) self._delete_cache(cache)
@@ -253,6 +254,32 @@ class PycordEvent:
if stage_index != old_stage_index: if stage_index != old_stage_index:
await (await bot.find_event_stage(event_stage_id)).update(cache, sequence=stage_index) await (await bot.find_event_stage(event_stage_id)).update(cache, sequence=stage_index)
@staticmethod
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
cache_entry: Dict[str, Any] = db_entry.copy()
cache_entry["_id"] = str(cache_entry["_id"])
cache_entry["created"] = cache_entry["created"].isoformat()
cache_entry["ended"] = None if cache_entry["ended"] is None else cache_entry["ended"].isoformat()
cache_entry["starts"] = cache_entry["starts"].isoformat()
cache_entry["ends"] = cache_entry["ends"].isoformat()
cache_entry["stage_ids"] = [str(stage_id) for stage_id in cache_entry["stage_ids"]]
return cache_entry
@staticmethod
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
db_entry: Dict[str, Any] = cache_entry.copy()
db_entry["_id"] = ObjectId(db_entry["_id"])
db_entry["created"] = datetime.fromisoformat(db_entry["created"])
db_entry["ended"] = None if db_entry["ended"] is None else datetime.fromisoformat(db_entry["ended"])
db_entry["starts"] = datetime.fromisoformat(db_entry["starts"])
db_entry["ends"] = datetime.fromisoformat(db_entry["ends"])
db_entry["stage_ids"] = [ObjectId(stage_id) for stage_id in db_entry["stage_ids"]]
return db_entry
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]: def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
"""Convert the object to a JSON representation. """Convert the object to a JSON representation.
@@ -266,14 +293,20 @@ class PycordEvent:
"_id": self._id if not json_compatible else str(self._id), "_id": self._id if not json_compatible else str(self._id),
"name": self.name, "name": self.name,
"guild_id": self.guild_id, "guild_id": self.guild_id,
"created": self.created, "created": self.created if not json_compatible else self.created.isoformat(),
"ended": self.ended, "ended": (
self.ended
if not json_compatible
else (None if self.ended is None else self.ended.isoformat())
),
"is_cancelled": self.is_cancelled, "is_cancelled": self.is_cancelled,
"creator_id": self.creator_id, "creator_id": self.creator_id,
"starts": self.starts, "starts": self.starts if not json_compatible else self.starts.isoformat(),
"ends": self.ends, "ends": self.ends if not json_compatible else self.ends.isoformat(),
"thumbnail": self.thumbnail, "thumbnail": self.thumbnail,
"stage_ids": self.stage_ids, "stage_ids": (
self.stage_ids if not json_compatible else [str(stage_id) for stage_id in self.stage_ids]
),
} }
@staticmethod @staticmethod
@@ -456,7 +489,7 @@ class PycordEvent:
return self.ends.replace(tzinfo=ZoneInfo("UTC")) return self.ends.replace(tzinfo=ZoneInfo("UTC"))
def get_start_date_localized(self, tz: str | timezone | ZoneInfo) -> datetime: def get_start_date_localized(self, tz: tzinfo) -> datetime:
"""Get the event start date in the provided timezone. """Get the event start date in the provided timezone.
Returns: Returns:
@@ -470,7 +503,7 @@ class PycordEvent:
return self.starts.replace(tzinfo=tz) return self.starts.replace(tzinfo=tz)
def get_end_date_localized(self, tz: str | timezone | ZoneInfo) -> datetime: def get_end_date_localized(self, tz: tzinfo) -> datetime:
"""Get the event end date in the provided timezone. """Get the event end date in the provided timezone.
Returns: Returns:

View File

@@ -10,6 +10,7 @@ from discord import File
from libbot.cache.classes import Cache from libbot.cache.classes import Cache
from pymongo.results import InsertOneResult from pymongo.results import InsertOneResult
from classes.abstract import Cacheable
from classes.errors import EventStageNotFoundError from classes.errors import EventStageNotFoundError
from modules.database import col_stages from modules.database import col_stages
from modules.utils import get_logger, restore_from_cache from modules.utils import get_logger, restore_from_cache
@@ -18,7 +19,7 @@ logger: Logger = get_logger(__name__)
@dataclass @dataclass
class PycordEventStage: class PycordEventStage(Cacheable):
__slots__ = ( __slots__ = (
"_id", "_id",
"event_id", "event_id",
@@ -61,7 +62,7 @@ class PycordEventStage:
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, stage_id, cache=cache) cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, stage_id, cache=cache)
if cached_entry is not None: if cached_entry is not None:
return cls(**cached_entry) return cls(**cls._entry_from_cache(cached_entry))
db_entry = await cls.__collection__.find_one( db_entry = await cls.__collection__.find_one(
{"_id": stage_id if isinstance(stage_id, ObjectId) else ObjectId(stage_id)} {"_id": stage_id if isinstance(stage_id, ObjectId) else ObjectId(stage_id)}
@@ -71,7 +72,7 @@ class PycordEventStage:
raise EventStageNotFoundError(stage_id) raise EventStageNotFoundError(stage_id)
if cache is not None: if cache is not None:
cache.set_json(f"{cls.__short_name__}_{stage_id}", db_entry) cache.set_json(f"{cls.__short_name__}_{stage_id}", cls._entry_to_cache(dict(db_entry)))
return cls(**db_entry) return cls(**db_entry)
@@ -104,7 +105,7 @@ class PycordEventStage:
db_entry["_id"] = insert_result.inserted_id db_entry["_id"] = insert_result.inserted_id
if cache is not None: if cache is not None:
cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry) cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry))
return cls(**db_entry) return cls(**db_entry)
@@ -159,10 +160,10 @@ class PycordEventStage:
if cache is None: if cache is None:
return return
user_dict: Dict[str, Any] = self.to_dict() object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
if user_dict is not None: if object_dict is not None:
cache.set_json(self._get_cache_key(), user_dict) cache.set_json(self._get_cache_key(), object_dict)
else: else:
self._delete_cache(cache) self._delete_cache(cache)
@@ -172,6 +173,26 @@ class PycordEventStage:
cache.delete(self._get_cache_key()) cache.delete(self._get_cache_key())
@staticmethod
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
cache_entry: Dict[str, Any] = db_entry.copy()
cache_entry["_id"] = str(cache_entry["_id"])
cache_entry["event_id"] = str(cache_entry["event_id"])
cache_entry["created"] = cache_entry["created"].isoformat()
return cache_entry
@staticmethod
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
db_entry: Dict[str, Any] = cache_entry.copy()
db_entry["_id"] = ObjectId(db_entry["_id"])
db_entry["event_id"] = ObjectId(db_entry["event_id"])
db_entry["created"] = datetime.fromisoformat(db_entry["created"])
return db_entry
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]: def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
"""Convert the object to a JSON representation. """Convert the object to a JSON representation.
@@ -186,7 +207,7 @@ class PycordEventStage:
"event_id": self.event_id if not json_compatible else str(self.event_id), "event_id": self.event_id if not json_compatible else str(self.event_id),
"guild_id": self.guild_id, "guild_id": self.guild_id,
"sequence": self.sequence, "sequence": self.sequence,
"created": self.created, "created": self.created if not json_compatible else self.created.isoformat(),
"creator_id": self.creator_id, "creator_id": self.creator_id,
"question": self.question, "question": self.question,
"answer": self.answer, "answer": self.answer,

View File

@@ -6,6 +6,7 @@ from bson import ObjectId
from libbot.cache.classes import Cache from libbot.cache.classes import Cache
from pymongo.results import InsertOneResult from pymongo.results import InsertOneResult
from classes.abstract import Cacheable
from classes.errors import GuildNotFoundError from classes.errors import GuildNotFoundError
from modules.database import col_guilds from modules.database import col_guilds
from modules.utils import get_logger, restore_from_cache from modules.utils import get_logger, restore_from_cache
@@ -14,7 +15,7 @@ logger: Logger = get_logger(__name__)
@dataclass @dataclass
class PycordGuild: class PycordGuild(Cacheable):
"""Dataclass of DB entry of a guild""" """Dataclass of DB entry of a guild"""
__slots__ = ( __slots__ = (
@@ -57,7 +58,7 @@ class PycordGuild:
cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, guild_id, cache=cache) cached_entry: Dict[str, Any] | None = restore_from_cache(cls.__short_name__, guild_id, cache=cache)
if cached_entry is not None: if cached_entry is not None:
return cls(**cached_entry) return cls(**cls._entry_from_cache(cached_entry))
db_entry = await cls.__collection__.find_one({"id": guild_id}) db_entry = await cls.__collection__.find_one({"id": guild_id})
@@ -72,7 +73,7 @@ class PycordGuild:
db_entry["_id"] = insert_result.inserted_id db_entry["_id"] = insert_result.inserted_id
if cache is not None: if cache is not None:
cache.set_json(f"{cls.__short_name__}_{guild_id}", db_entry) cache.set_json(f"{cls.__short_name__}_{guild_id}", cls._entry_to_cache(db_entry))
return cls(**db_entry) return cls(**db_entry)
@@ -115,10 +116,10 @@ class PycordGuild:
if cache is None: if cache is None:
return return
user_dict: Dict[str, Any] = self.to_dict() object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
if user_dict is not None: if object_dict is not None:
cache.set_json(self._get_cache_key(), user_dict) cache.set_json(self._get_cache_key(), object_dict)
else: else:
self._delete_cache(cache) self._delete_cache(cache)
@@ -128,6 +129,22 @@ class PycordGuild:
cache.delete(self._get_cache_key()) cache.delete(self._get_cache_key())
@staticmethod
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
cache_entry: Dict[str, Any] = db_entry.copy()
cache_entry["_id"] = str(cache_entry["_id"])
return cache_entry
@staticmethod
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
db_entry: Dict[str, Any] = cache_entry.copy()
db_entry["_id"] = ObjectId(db_entry["_id"])
return db_entry
def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]: def to_dict(self, json_compatible: bool = False) -> Dict[str, Any]:
"""Convert the object to a JSON representation. """Convert the object to a JSON representation.

View File

@@ -17,6 +17,7 @@ from discord.abc import GuildChannel
from libbot.cache.classes import Cache from libbot.cache.classes import Cache
from pymongo.results import InsertOneResult from pymongo.results import InsertOneResult
from classes.abstract import Cacheable
from classes.errors import ( from classes.errors import (
DiscordCategoryNotFoundError, DiscordCategoryNotFoundError,
DiscordChannelNotFoundError, DiscordChannelNotFoundError,
@@ -33,9 +34,17 @@ logger: Logger = get_logger(__name__)
@dataclass @dataclass
class PycordUser: class PycordUser(Cacheable):
"""Dataclass of DB entry of a user""" """Dataclass of DB entry of a user"""
# TODO Implement this
async def update(self, cache: Optional[Cache] = None, **kwargs: Any) -> None:
pass
# TODO Implement this
async def reset(self, *args: str, cache: Optional[Cache] = None) -> None:
pass
__slots__ = ( __slots__ = (
"_id", "_id",
"id", "id",
@@ -83,7 +92,7 @@ class PycordUser:
) )
if cached_entry is not None: if cached_entry is not None:
return cls(**cached_entry) return cls(**cls._entry_from_cache(cached_entry))
db_entry = await cls.__collection__.find_one({"id": user_id, "guild_id": guild_id}) db_entry = await cls.__collection__.find_one({"id": user_id, "guild_id": guild_id})
@@ -98,7 +107,7 @@ class PycordUser:
db_entry["_id"] = insert_result.inserted_id db_entry["_id"] = insert_result.inserted_id
if cache is not None: if cache is not None:
cache.set_json(f"{cls.__short_name__}_{user_id}_{guild_id}", db_entry) cache.set_json(f"{cls.__short_name__}_{user_id}_{guild_id}", cls._entry_to_cache(db_entry))
return cls(**db_entry) return cls(**db_entry)
@@ -186,10 +195,10 @@ class PycordUser:
if cache is None: if cache is None:
return return
user_dict: Dict[str, Any] = self.to_dict() object_dict: Dict[str, Any] = self.to_dict(json_compatible=True)
if user_dict is not None: if object_dict is not None:
cache.set_json(self._get_cache_key(), user_dict) cache.set_json(self._get_cache_key(), object_dict)
else: else:
self._delete_cache(cache) self._delete_cache(cache)
@@ -199,6 +208,46 @@ class PycordUser:
cache.delete(self._get_cache_key()) cache.delete(self._get_cache_key())
@staticmethod
def _entry_to_cache(db_entry: Dict[str, Any]) -> Dict[str, Any]:
cache_entry: Dict[str, Any] = db_entry.copy()
cache_entry["_id"] = str(cache_entry["_id"])
cache_entry["current_event_id"] = (
None if cache_entry["current_event_id"] is None else str(cache_entry["current_event_id"])
)
cache_entry["current_stage_id"] = (
None if cache_entry["current_stage_id"] is None else str(cache_entry["current_stage_id"])
)
cache_entry["registered_event_ids"] = [
str(event_id) for event_id in cache_entry["registered_event_ids"]
]
cache_entry["completed_event_ids"] = [
str(event_id) for event_id in cache_entry["completed_event_ids"]
]
return cache_entry
@staticmethod
def _entry_from_cache(cache_entry: Dict[str, Any]) -> Dict[str, Any]:
db_entry: Dict[str, Any] = cache_entry.copy()
db_entry["_id"] = ObjectId(db_entry["_id"])
db_entry["current_event_id"] = (
None if db_entry["current_event_id"] is None else ObjectId(db_entry["current_event_id"])
)
db_entry["current_stage_id"] = (
None if db_entry["current_stage_id"] is None else ObjectId(db_entry["current_stage_id"])
)
db_entry["registered_event_ids"] = [
ObjectId(event_id) for event_id in db_entry["registered_event_ids"]
]
db_entry["completed_event_ids"] = [
ObjectId(event_id) for event_id in db_entry["completed_event_ids"]
]
return db_entry
# TODO Add documentation # TODO Add documentation
@staticmethod @staticmethod
def get_defaults(user_id: Optional[int] = None, guild_id: Optional[int] = None) -> Dict[str, Any]: def get_defaults(user_id: Optional[int] = None, guild_id: Optional[int] = None) -> Dict[str, Any]:

View File

@@ -88,14 +88,15 @@ class CogConfig(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
try: try:
timezone_parsed: ZoneInfo = ZoneInfo(timezone) timezone_parsed: ZoneInfo = ZoneInfo(timezone)
except ZoneInfoNotFoundError: except ZoneInfoNotFoundError:
await ctx.respond( await ctx.respond(
self.bot._("timezone_invalid", "messages", locale=ctx.locale).format(timezone=timezone) self.bot._("timezone_invalid", "messages", locale=ctx.locale).format(timezone=timezone),
ephemeral=True,
) )
return return
@@ -130,7 +131,7 @@ class CogConfig(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
await guild.purge(self.bot.cache) await guild.purge(self.bot.cache)
@@ -146,11 +147,13 @@ class CogConfig(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
if not guild.is_configured(): if not guild.is_configured():
await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True
)
return return
await ctx.respond( await ctx.respond(

View File

@@ -84,11 +84,13 @@ class CogEvent(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
if not guild.is_configured(): if not guild.is_configured():
await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True
)
return return
guild_timezone: ZoneInfo = ZoneInfo(guild.timezone) guild_timezone: ZoneInfo = ZoneInfo(guild.timezone)
@@ -97,7 +99,9 @@ class CogEvent(Cog):
start_date: datetime = datetime.strptime(start, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone) start_date: datetime = datetime.strptime(start, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone)
end_date: datetime = datetime.strptime(end, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone) end_date: datetime = datetime.strptime(end, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone)
except ValueError: except ValueError:
await ctx.respond(self.bot._("event_dates_parsing_failed", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("event_dates_parsing_failed", "messages", locale=ctx.locale), ephemeral=True
)
return return
if not await validate_event_validity(ctx, name, start_date, end_date, to_utc=True): if not await validate_event_validity(ctx, name, start_date, end_date, to_utc=True):
@@ -180,7 +184,7 @@ class CogEvent(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
try: try:
@@ -190,7 +194,9 @@ class CogEvent(Cog):
return return
if not guild.is_configured(): if not guild.is_configured():
await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True
)
return return
guild_timezone: ZoneInfo = ZoneInfo(guild.timezone) guild_timezone: ZoneInfo = ZoneInfo(guild.timezone)
@@ -202,7 +208,9 @@ class CogEvent(Cog):
else datetime.strptime(start, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone) else datetime.strptime(start, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone)
) )
except ValueError: except ValueError:
await ctx.respond(self.bot._("event_start_date_parsing_failed", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("event_start_date_parsing_failed", "messages", locale=ctx.locale), ephemeral=True
)
return return
try: try:
@@ -212,7 +220,9 @@ class CogEvent(Cog):
else datetime.strptime(end, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone) else datetime.strptime(end, "%d.%m.%Y %H:%M").replace(tzinfo=guild_timezone)
) )
except ValueError: except ValueError:
await ctx.respond(self.bot._("event_end_date_parsing_failed", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("event_end_date_parsing_failed", "messages", locale=ctx.locale), ephemeral=True
)
return return
if not await validate_event_validity( if not await validate_event_validity(
@@ -280,7 +290,7 @@ class CogEvent(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
try: try:
@@ -290,7 +300,9 @@ class CogEvent(Cog):
return return
if not guild.is_configured(): if not guild.is_configured():
await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True
)
return return
start_date: datetime = pycord_event.starts.replace(tzinfo=ZoneInfo("UTC")) start_date: datetime = pycord_event.starts.replace(tzinfo=ZoneInfo("UTC"))
@@ -305,7 +317,8 @@ class CogEvent(Cog):
await ctx.respond( await ctx.respond(
self.bot._("event_not_editable", "messages", locale=ctx.locale).format( self.bot._("event_not_editable", "messages", locale=ctx.locale).format(
event_name=pycord_event.name event_name=pycord_event.name
) ),
ephemeral=True,
) )
return return

View File

@@ -34,11 +34,13 @@ class CogGuess(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
if not guild.is_configured(): if not guild.is_configured():
await ctx.respond(self.bot._("guild_unconfigured", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("guild_unconfigured", "messages", locale=ctx.locale), ephemeral=True
)
return return
user: PycordUser = await self.bot.find_user(ctx.author, ctx.guild) user: PycordUser = await self.bot.find_user(ctx.author, ctx.guild)

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from bson.errors import InvalidId from bson.errors import InvalidId
from discord import ApplicationContext, Cog, TextChannel, option, slash_command, File from discord import ApplicationContext, Cog, File, TextChannel, option, slash_command
from discord.utils import basic_autocomplete from discord.utils import basic_autocomplete
from libbot.i18n import _, in_every_locale from libbot.i18n import _, in_every_locale
@@ -39,7 +39,7 @@ class CogRegister(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
try: try:
@@ -49,17 +49,21 @@ class CogRegister(Cog):
return return
if not guild.is_configured(): if not guild.is_configured():
await ctx.respond(self.bot._("guild_unconfigured", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("guild_unconfigured", "messages", locale=ctx.locale), ephemeral=True
)
return return
user: PycordUser = await self.bot.find_user(ctx.author, ctx.guild) user: PycordUser = await self.bot.find_user(ctx.author, ctx.guild)
if user.is_jailed: if user.is_jailed:
await ctx.respond(self.bot._("jailed_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("jailed_error", "messages", locale=ctx.locale), ephemeral=True)
return return
if pycord_event._id in user.registered_event_ids: if pycord_event._id in user.registered_event_ids:
await ctx.respond(self.bot._("register_already_registered", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("register_already_registered", "messages", locale=ctx.locale), ephemeral=True
)
return return
await user.event_register(pycord_event._id, cache=self.bot.cache) await user.event_register(pycord_event._id, cache=self.bot.cache)

View File

@@ -84,11 +84,13 @@ class CogStage(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
if not guild.is_configured(): if not guild.is_configured():
await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True
)
return return
try: try:
@@ -196,11 +198,13 @@ class CogStage(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
if not guild.is_configured(): if not guild.is_configured():
await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True
)
return return
try: try:
@@ -215,11 +219,13 @@ class CogStage(Cog):
try: try:
event_stage: PycordEventStage = await self.bot.find_event_stage(stage) event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
except (InvalidId, EventStageNotFoundError): except (InvalidId, EventStageNotFoundError):
await ctx.respond(self.bot._("stage_not_found", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("stage_not_found", "messages", locale=ctx.locale), ephemeral=True)
return return
if order is not None and order > len(pycord_event.stage_ids): if order is not None and order > len(pycord_event.stage_ids):
await ctx.respond(self.bot._("stage_sequence_out_of_range", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("stage_sequence_out_of_range", "messages", locale=ctx.locale), ephemeral=True
)
return return
processed_media: List[Dict[str, Any]] = ( processed_media: List[Dict[str, Any]] = (
@@ -278,11 +284,13 @@ class CogStage(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
if not guild.is_configured(): if not guild.is_configured():
await ctx.respond(self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("guild_unconfigured_admin", "messages", locale=ctx.locale), ephemeral=True
)
return return
try: try:
@@ -297,7 +305,7 @@ class CogStage(Cog):
try: try:
event_stage: PycordEventStage = await self.bot.find_event_stage(stage) event_stage: PycordEventStage = await self.bot.find_event_stage(stage)
except (InvalidId, EventStageNotFoundError): except (InvalidId, EventStageNotFoundError):
await ctx.respond(self.bot._("stage_not_found", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("stage_not_found", "messages", locale=ctx.locale), ephemeral=True)
return return
await pycord_event.remove_stage(self.bot, event_stage._id, cache=self.bot.cache) await pycord_event.remove_stage(self.bot, event_stage._id, cache=self.bot.cache)

View File

@@ -43,7 +43,7 @@ class CogUnregister(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
try: try:
@@ -53,17 +53,22 @@ class CogUnregister(Cog):
return return
if not guild.is_configured(): if not guild.is_configured():
await ctx.respond(self.bot._("guild_unconfigured", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("guild_unconfigured", "messages", locale=ctx.locale), ephemeral=True
)
return return
user: PycordUser = await self.bot.find_user(ctx.author, ctx.guild) user: PycordUser = await self.bot.find_user(ctx.author, ctx.guild)
if user.is_jailed: if user.is_jailed:
await ctx.respond(self.bot._("jailed_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("jailed_error", "messages", locale=ctx.locale), ephemeral=True)
return return
# TODO Fix a bug where registered_event_ids is invalid because of caching
if pycord_event._id not in user.registered_event_ids: if pycord_event._id not in user.registered_event_ids:
await ctx.respond(self.bot._("unregister_not_registered", "messages", locale=ctx.locale)) await ctx.respond(
self.bot._("unregister_not_registered", "messages", locale=ctx.locale), ephemeral=True
)
return return
await user.event_unregister(pycord_event._id, cache=self.bot.cache) await user.event_unregister(pycord_event._id, cache=self.bot.cache)

View File

@@ -1,27 +1,27 @@
from datetime import datetime from datetime import datetime
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any from typing import Any, Dict, List
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from bson import ObjectId from bson import ObjectId
from bson.errors import InvalidId from bson.errors import InvalidId
from discord import ( from discord import (
ApplicationContext, ApplicationContext,
File,
SlashCommandGroup, SlashCommandGroup,
TextChannel,
User, User,
option, option,
File,
TextChannel,
) )
from discord.ext.commands import Cog from discord.ext.commands import Cog
from libbot.i18n import _, in_every_locale from libbot.i18n import _, in_every_locale
from classes import PycordUser, PycordEvent, PycordGuild from classes import PycordEvent, PycordGuild, PycordUser
from classes.errors import GuildNotFoundError from classes.errors import GuildNotFoundError
from classes.pycord_bot import PycordBot from classes.pycord_bot import PycordBot
from modules.database import col_users from modules.database import col_users
from modules.utils import is_operation_confirmed, get_logger from modules.utils import get_logger, is_operation_confirmed, get_utc_now
logger: Logger = get_logger(__name__) logger: Logger = get_logger(__name__)
@@ -54,35 +54,45 @@ class CogUser(Cog):
try: try:
guild: PycordGuild = await self.bot.find_guild(ctx.guild.id) guild: PycordGuild = await self.bot.find_guild(ctx.guild.id)
except (InvalidId, GuildNotFoundError): except (InvalidId, GuildNotFoundError):
await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale)) await ctx.respond(self.bot._("unexpected_error", "messages", locale=ctx.locale), ephemeral=True)
return return
pycord_user: PycordUser = await self.bot.find_user(user.id, ctx.guild.id) pycord_user: PycordUser = await self.bot.find_user(user.id, ctx.guild.id)
events: List[PycordEvent] = [] events: List[PycordEvent] = []
utc_now: datetime = get_utc_now()
pipeline: List[Dict[str, Any]] = [ pipeline: List[Dict[str, Any]] = [
{"$match": {"id": pycord_user.id}}, {"$match": {"id": pycord_user.id}},
{ {
"$lookup": { "$lookup": {
"from": "events", "from": "events",
"localField": "registered_event_ids", "let": {"event_ids": "$registered_event_ids"},
"foreignField": "_id", "pipeline": [
{
"$match": {
"$expr": {
"$and": [
{"$in": ["$_id", "$$event_ids"]},
{"$eq": ["$ended", None]},
{"$gt": ["$ends", utc_now]},
{"$lt": ["$starts", utc_now]},
{"$eq": ["$is_cancelled", False]},
]
}
}
}
],
"as": "registered_events", "as": "registered_events",
} }
}, },
{ {"$match": {"registered_events.0": {"$exists": True}}},
"$match": {
"registered_events.ended": None,
"registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"registered_events.starts": {"$lt": datetime.now(tz=ZoneInfo("UTC"))},
"registered_events.is_cancelled": False,
}
},
] ]
async for result in col_users.aggregate(pipeline): async with await col_users.aggregate(pipeline) as cursor:
for registered_event in result["registered_events"]: async for result in cursor:
events.append(PycordEvent(**registered_event)) for registered_event in result["registered_events"]:
events.append(PycordEvent(**registered_event))
for event in events: for event in events:
if pycord_user.current_event_id is not None and pycord_user.current_event_id != event._id: if pycord_user.current_event_id is not None and pycord_user.current_event_id != event._id:

View File

@@ -6,13 +6,13 @@ from zoneinfo import ZoneInfo
from bson import ObjectId from bson import ObjectId
from bson.errors import InvalidId from bson.errors import InvalidId
from discord import Activity, ActivityType, Cog, Member, TextChannel, File from discord import Activity, ActivityType, Cog, File, Member, TextChannel
from classes import PycordEvent, PycordGuild, PycordUser from classes import PycordEvent, PycordGuild, PycordUser
from classes.errors import GuildNotFoundError from classes.errors import GuildNotFoundError
from classes.pycord_bot import PycordBot from classes.pycord_bot import PycordBot
from modules.database import col_users from modules.database import col_users
from modules.utils import get_logger from modules.utils import get_logger, get_utc_now
logger: Logger = get_logger(__name__) logger: Logger = get_logger(__name__)
@@ -78,29 +78,39 @@ class CogUtility(Cog):
user: PycordUser = await self.bot.find_user(member.id, member.guild.id) user: PycordUser = await self.bot.find_user(member.id, member.guild.id)
events: List[PycordEvent] = [] events: List[PycordEvent] = []
utc_now: datetime = get_utc_now()
pipeline: List[Dict[str, Any]] = [ pipeline: List[Dict[str, Any]] = [
{"$match": {"id": user.id}}, {"$match": {"id": user.id}},
{ {
"$lookup": { "$lookup": {
"from": "events", "from": "events",
"localField": "registered_event_ids", "let": {"event_ids": "$registered_event_ids"},
"foreignField": "_id", "pipeline": [
{
"$match": {
"$expr": {
"$and": [
{"$in": ["$_id", "$$event_ids"]},
{"$eq": ["$ended", None]},
{"$gt": ["$ends", utc_now]},
{"$lt": ["$starts", utc_now]},
{"$eq": ["$is_cancelled", False]},
]
}
}
}
],
"as": "registered_events", "as": "registered_events",
} }
}, },
{ {"$match": {"registered_events.0": {"$exists": True}}},
"$match": {
"registered_events.ended": None,
"registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"registered_events.starts": {"$lt": datetime.now(tz=ZoneInfo("UTC"))},
"registered_events.is_cancelled": False,
}
},
] ]
async for result in col_users.aggregate(pipeline): async with await col_users.aggregate(pipeline) as cursor:
for registered_event in result["registered_events"]: async for result in cursor:
events.append(PycordEvent(**registered_event)) for registered_event in result["registered_events"]:
events.append(PycordEvent(**registered_event))
for event in events: for event in events:
if user.current_event_id is not None and user.current_event_id != event._id: if user.current_event_id is not None and user.current_event_id != event._id:

View File

@@ -2,8 +2,10 @@
from typing import Any, Mapping from typing import Any, Mapping
from async_pymongo import AsyncClient, AsyncCollection, AsyncDatabase
from libbot.utils import config_get from libbot.utils import config_get
from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.database import AsyncDatabase
db_config: Mapping[str, Any] = config_get("database") db_config: Mapping[str, Any] = config_get("database")
@@ -19,7 +21,7 @@ else:
con_string = "mongodb://{0}:{1}/{2}".format(db_config["host"], db_config["port"], db_config["name"]) con_string = "mongodb://{0}:{1}/{2}".format(db_config["host"], db_config["port"], db_config["name"])
# Async declarations # Async declarations
db_client = AsyncClient(con_string) db_client = AsyncMongoClient(con_string)
db: AsyncDatabase = db_client.get_database(name=db_config["name"]) db: AsyncDatabase = db_client.get_database(name=db_config["name"])
col_users: AsyncCollection = db.get_collection("users") col_users: AsyncCollection = db.get_collection("users")
@@ -27,10 +29,10 @@ col_guilds: AsyncCollection = db.get_collection("guilds")
col_events: AsyncCollection = db.get_collection("events") col_events: AsyncCollection = db.get_collection("events")
col_stages: AsyncCollection = db.get_collection("stages") col_stages: AsyncCollection = db.get_collection("stages")
# Update indexes # Update indexes
db.dispatch.get_collection("users").create_index("id", name="user_id", unique=True) async def _update_database_indexes() -> None:
db.dispatch.get_collection("guilds").create_index("id", name="guild_id", unique=True) await col_users.create_index("id", name="user_id", unique=True)
db.dispatch.get_collection("events").create_index("guild_id", name="guild_id", unique=False) await col_guilds.create_index("id", name="guild_id", unique=True)
db.dispatch.get_collection("stages").create_index( await col_events.create_index("guild_id", name="guild_id", unique=False)
["event_id", "guild_id"], name="event_id-and-guild_id", unique=False await col_stages.create_index(["event_id", "guild_id"], name="event_id-and-guild_id", unique=False)
)

View File

@@ -7,7 +7,7 @@ from .autocomplete_utils import (
autocomplete_user_registered_events, autocomplete_user_registered_events,
) )
from .cache_utils import restore_from_cache from .cache_utils import restore_from_cache
from .datetime_utils import get_unix_timestamp from .datetime_utils import get_unix_timestamp, get_utc_now
from .event_utils import validate_event_validity from .event_utils import validate_event_validity
from .git_utils import get_current_commit from .git_utils import get_current_commit
from .logging_utils import get_logger, get_logging_config from .logging_utils import get_logger, get_logging_config

View File

@@ -49,31 +49,41 @@ async def autocomplete_user_available_events(ctx: AutocompleteContext) -> List[O
async def autocomplete_user_registered_events(ctx: AutocompleteContext) -> List[OptionChoice]: async def autocomplete_user_registered_events(ctx: AutocompleteContext) -> List[OptionChoice]:
"""Return list of active events user is registered in""" """Return list of active events user is registered in"""
utc_now: datetime = datetime.now(tz=ZoneInfo("UTC"))
pipeline: List[Dict[str, Any]] = [ pipeline: List[Dict[str, Any]] = [
{"$match": {"id": ctx.interaction.user.id}}, {"$match": {"id": ctx.interaction.user.id}},
{ {
"$lookup": { "$lookup": {
"from": "events", "from": "events",
"localField": "registered_event_ids", "let": {"event_ids": "$registered_event_ids"},
"foreignField": "_id", "pipeline": [
{
"$match": {
"$expr": {
"$and": [
{"$in": ["$_id", "$$event_ids"]},
{"$eq": ["$ended", None]},
{"$gt": ["$ends", utc_now]},
{"$gt": ["$starts", utc_now]},
{"$eq": ["$is_cancelled", False]},
]
}
}
}
],
"as": "registered_events", "as": "registered_events",
} }
}, },
{ {"$match": {"registered_events.0": {"$exists": True}}},
"$match": {
"registered_events.ended": None,
"registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"registered_events.starts": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
"registered_events.is_cancelled": False,
}
},
] ]
event_names: List[OptionChoice] = [] event_names: List[OptionChoice] = []
async for result in col_users.aggregate(pipeline): async with await col_users.aggregate(pipeline) as cursor:
for registered_event in result["registered_events"]: async for result in cursor:
event_names.append(OptionChoice(registered_event["name"], str(registered_event["_id"]))) for registered_event in result["registered_events"]:
event_names.append(OptionChoice(registered_event["name"], str(registered_event["_id"])))
return event_names return event_names

View File

@@ -5,3 +5,8 @@ from zoneinfo import ZoneInfo
# TODO Add documentation # TODO Add documentation
def get_unix_timestamp(date: datetime, to_utc: bool = False) -> int: def get_unix_timestamp(date: datetime, to_utc: bool = False) -> int:
return int((date if not to_utc else date.replace(tzinfo=ZoneInfo("UTC"))).timestamp()) return int((date if not to_utc else date.replace(tzinfo=ZoneInfo("UTC"))).timestamp())
# TODO Add documentation
def get_utc_now() -> datetime:
return datetime.now(tz=ZoneInfo("UTC"))

View File

@@ -23,15 +23,15 @@ async def validate_event_validity(
end_date_internal: datetime = end_date.astimezone(ZoneInfo("UTC")) if to_utc else end_date end_date_internal: datetime = end_date.astimezone(ZoneInfo("UTC")) if to_utc else end_date
if start_date_internal < datetime.now(tz=ZoneInfo("UTC")): if start_date_internal < datetime.now(tz=ZoneInfo("UTC")):
await ctx.respond(_("event_start_past", "messages", locale=ctx.locale)) await ctx.respond(_("event_start_past", "messages", locale=ctx.locale), ephemeral=True)
return False return False
if end_date_internal < datetime.now(tz=ZoneInfo("UTC")): if end_date_internal < datetime.now(tz=ZoneInfo("UTC")):
await ctx.respond(_("event_end_past", "messages", locale=ctx.locale)) await ctx.respond(_("event_end_past", "messages", locale=ctx.locale), ephemeral=True)
return False return False
if start_date_internal >= end_date_internal: if start_date_internal >= end_date_internal:
await ctx.respond(_("event_end_before_start", "messages", locale=ctx.locale)) await ctx.respond(_("event_end_before_start", "messages", locale=ctx.locale), ephemeral=True)
return False return False
# TODO Add validation for concurrent events. # TODO Add validation for concurrent events.
@@ -47,7 +47,7 @@ async def validate_event_validity(
query["_id"] = {"$ne": event_id} query["_id"] = {"$ne": event_id}
if (await col_events.find_one(query)) is not None: if (await col_events.find_one(query)) is not None:
await ctx.respond(_("event_name_duplicate", "messages", locale=ctx.locale)) await ctx.respond(_("event_name_duplicate", "messages", locale=ctx.locale), ephemeral=True)
return False return False
return True return True

View File

@@ -7,7 +7,7 @@ from libbot.i18n import _
async def is_operation_confirmed(ctx: ApplicationContext, confirm: bool) -> bool: async def is_operation_confirmed(ctx: ApplicationContext, confirm: bool) -> bool:
if confirm is None or not confirm: if confirm is None or not confirm:
await ctx.respond(ctx.bot._("operation_unconfirmed", "messages", locale=ctx.locale)) await ctx.respond(ctx.bot._("operation_unconfirmed", "messages", locale=ctx.locale), ephemeral=True)
return False return False
return True return True
@@ -18,7 +18,7 @@ async def is_event_status_valid(
event: "PycordEvent", event: "PycordEvent",
) -> bool: ) -> bool:
if event.is_cancelled: if event.is_cancelled:
await ctx.respond(_("event_is_cancelled", "messages", locale=ctx.locale)) await ctx.respond(_("event_is_cancelled", "messages", locale=ctx.locale), ephemeral=True)
return False return False
if ( if (
@@ -26,7 +26,7 @@ async def is_event_status_valid(
<= datetime.now(tz=ZoneInfo("UTC")) <= datetime.now(tz=ZoneInfo("UTC"))
<= event.ends.replace(tzinfo=ZoneInfo("UTC")) <= event.ends.replace(tzinfo=ZoneInfo("UTC"))
): ):
await ctx.respond(_("event_ongoing_not_editable", "messages", locale=ctx.locale)) await ctx.respond(_("event_ongoing_not_editable", "messages", locale=ctx.locale), ephemeral=True)
return False return False
return True return True

View File

@@ -1,10 +1,10 @@
aiodns~=3.2.0 aiodns~=3.2.0
apscheduler~=3.11.0 apscheduler~=3.11.0
async_pymongo==0.1.11
brotlipy~=0.7.0 brotlipy~=0.7.0
faust-cchardet~=2.1.19 faust-cchardet~=2.1.19
libbot[speed,pycord,cache]==4.1.0 libbot[speed,pycord,cache]==4.1.0
mongodb-migrations==1.3.1 mongodb-migrations==1.3.1
msgspec~=0.19.0 msgspec~=0.19.0
pymongo~=4.12.1,>=4.9
pytz~=2025.1 pytz~=2025.1
typing_extensions>=4.11.0 typing_extensions>=4.11.0