Worked on #13 and #4. There are some caching issues left, though. Introduced abstract class Cacheable. Replaced async_pymongo with pymongo
This commit is contained in:
@@ -2,8 +2,10 @@
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
from async_pymongo import AsyncClient, AsyncCollection, AsyncDatabase
|
||||
from libbot.utils import config_get
|
||||
from pymongo import AsyncMongoClient
|
||||
from pymongo.asynchronous.collection import AsyncCollection
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
|
||||
db_config: Mapping[str, Any] = config_get("database")
|
||||
|
||||
@@ -19,7 +21,7 @@ else:
|
||||
con_string = "mongodb://{0}:{1}/{2}".format(db_config["host"], db_config["port"], db_config["name"])
|
||||
|
||||
# Async declarations
|
||||
db_client = AsyncClient(con_string)
|
||||
db_client = AsyncMongoClient(con_string)
|
||||
db: AsyncDatabase = db_client.get_database(name=db_config["name"])
|
||||
|
||||
col_users: AsyncCollection = db.get_collection("users")
|
||||
@@ -27,10 +29,10 @@ col_guilds: AsyncCollection = db.get_collection("guilds")
|
||||
col_events: AsyncCollection = db.get_collection("events")
|
||||
col_stages: AsyncCollection = db.get_collection("stages")
|
||||
|
||||
|
||||
# Update indexes
|
||||
db.dispatch.get_collection("users").create_index("id", name="user_id", unique=True)
|
||||
db.dispatch.get_collection("guilds").create_index("id", name="guild_id", unique=True)
|
||||
db.dispatch.get_collection("events").create_index("guild_id", name="guild_id", unique=False)
|
||||
db.dispatch.get_collection("stages").create_index(
|
||||
["event_id", "guild_id"], name="event_id-and-guild_id", unique=False
|
||||
)
|
||||
async def _update_database_indexes() -> None:
|
||||
await col_users.create_index("id", name="user_id", unique=True)
|
||||
await col_guilds.create_index("id", name="guild_id", unique=True)
|
||||
await col_events.create_index("guild_id", name="guild_id", unique=False)
|
||||
await col_stages.create_index(["event_id", "guild_id"], name="event_id-and-guild_id", unique=False)
|
||||
|
@@ -7,7 +7,7 @@ from .autocomplete_utils import (
|
||||
autocomplete_user_registered_events,
|
||||
)
|
||||
from .cache_utils import restore_from_cache
|
||||
from .datetime_utils import get_unix_timestamp
|
||||
from .datetime_utils import get_unix_timestamp, get_utc_now
|
||||
from .event_utils import validate_event_validity
|
||||
from .git_utils import get_current_commit
|
||||
from .logging_utils import get_logger, get_logging_config
|
||||
|
@@ -49,31 +49,41 @@ async def autocomplete_user_available_events(ctx: AutocompleteContext) -> List[O
|
||||
async def autocomplete_user_registered_events(ctx: AutocompleteContext) -> List[OptionChoice]:
|
||||
"""Return list of active events user is registered in"""
|
||||
|
||||
utc_now: datetime = datetime.now(tz=ZoneInfo("UTC"))
|
||||
|
||||
pipeline: List[Dict[str, Any]] = [
|
||||
{"$match": {"id": ctx.interaction.user.id}},
|
||||
{
|
||||
"$lookup": {
|
||||
"from": "events",
|
||||
"localField": "registered_event_ids",
|
||||
"foreignField": "_id",
|
||||
"let": {"event_ids": "$registered_event_ids"},
|
||||
"pipeline": [
|
||||
{
|
||||
"$match": {
|
||||
"$expr": {
|
||||
"$and": [
|
||||
{"$in": ["$_id", "$$event_ids"]},
|
||||
{"$eq": ["$ended", None]},
|
||||
{"$gt": ["$ends", utc_now]},
|
||||
{"$gt": ["$starts", utc_now]},
|
||||
{"$eq": ["$is_cancelled", False]},
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"as": "registered_events",
|
||||
}
|
||||
},
|
||||
{
|
||||
"$match": {
|
||||
"registered_events.ended": None,
|
||||
"registered_events.ends": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
|
||||
"registered_events.starts": {"$gt": datetime.now(tz=ZoneInfo("UTC"))},
|
||||
"registered_events.is_cancelled": False,
|
||||
}
|
||||
},
|
||||
{"$match": {"registered_events.0": {"$exists": True}}},
|
||||
]
|
||||
|
||||
event_names: List[OptionChoice] = []
|
||||
|
||||
async for result in col_users.aggregate(pipeline):
|
||||
for registered_event in result["registered_events"]:
|
||||
event_names.append(OptionChoice(registered_event["name"], str(registered_event["_id"])))
|
||||
async with await col_users.aggregate(pipeline) as cursor:
|
||||
async for result in cursor:
|
||||
for registered_event in result["registered_events"]:
|
||||
event_names.append(OptionChoice(registered_event["name"], str(registered_event["_id"])))
|
||||
|
||||
return event_names
|
||||
|
||||
|
@@ -5,3 +5,8 @@ from zoneinfo import ZoneInfo
|
||||
# TODO Add documentation
|
||||
def get_unix_timestamp(date: datetime, to_utc: bool = False) -> int:
|
||||
return int((date if not to_utc else date.replace(tzinfo=ZoneInfo("UTC"))).timestamp())
|
||||
|
||||
|
||||
# TODO Add documentation
|
||||
def get_utc_now() -> datetime:
|
||||
return datetime.now(tz=ZoneInfo("UTC"))
|
||||
|
@@ -23,15 +23,15 @@ async def validate_event_validity(
|
||||
end_date_internal: datetime = end_date.astimezone(ZoneInfo("UTC")) if to_utc else end_date
|
||||
|
||||
if start_date_internal < datetime.now(tz=ZoneInfo("UTC")):
|
||||
await ctx.respond(_("event_start_past", "messages", locale=ctx.locale))
|
||||
await ctx.respond(_("event_start_past", "messages", locale=ctx.locale), ephemeral=True)
|
||||
return False
|
||||
|
||||
if end_date_internal < datetime.now(tz=ZoneInfo("UTC")):
|
||||
await ctx.respond(_("event_end_past", "messages", locale=ctx.locale))
|
||||
await ctx.respond(_("event_end_past", "messages", locale=ctx.locale), ephemeral=True)
|
||||
return False
|
||||
|
||||
if start_date_internal >= end_date_internal:
|
||||
await ctx.respond(_("event_end_before_start", "messages", locale=ctx.locale))
|
||||
await ctx.respond(_("event_end_before_start", "messages", locale=ctx.locale), ephemeral=True)
|
||||
return False
|
||||
|
||||
# TODO Add validation for concurrent events.
|
||||
@@ -47,7 +47,7 @@ async def validate_event_validity(
|
||||
query["_id"] = {"$ne": event_id}
|
||||
|
||||
if (await col_events.find_one(query)) is not None:
|
||||
await ctx.respond(_("event_name_duplicate", "messages", locale=ctx.locale))
|
||||
await ctx.respond(_("event_name_duplicate", "messages", locale=ctx.locale), ephemeral=True)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
@@ -7,7 +7,7 @@ from libbot.i18n import _
|
||||
|
||||
async def is_operation_confirmed(ctx: ApplicationContext, confirm: bool) -> bool:
|
||||
if confirm is None or not confirm:
|
||||
await ctx.respond(ctx.bot._("operation_unconfirmed", "messages", locale=ctx.locale))
|
||||
await ctx.respond(ctx.bot._("operation_unconfirmed", "messages", locale=ctx.locale), ephemeral=True)
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -18,7 +18,7 @@ async def is_event_status_valid(
|
||||
event: "PycordEvent",
|
||||
) -> bool:
|
||||
if event.is_cancelled:
|
||||
await ctx.respond(_("event_is_cancelled", "messages", locale=ctx.locale))
|
||||
await ctx.respond(_("event_is_cancelled", "messages", locale=ctx.locale), ephemeral=True)
|
||||
return False
|
||||
|
||||
if (
|
||||
@@ -26,7 +26,7 @@ async def is_event_status_valid(
|
||||
<= datetime.now(tz=ZoneInfo("UTC"))
|
||||
<= event.ends.replace(tzinfo=ZoneInfo("UTC"))
|
||||
):
|
||||
await ctx.respond(_("event_ongoing_not_editable", "messages", locale=ctx.locale))
|
||||
await ctx.respond(_("event_ongoing_not_editable", "messages", locale=ctx.locale), ephemeral=True)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
Reference in New Issue
Block a user