diff --git a/classes/holo_user.py b/classes/holo_user.py index 6403786..a633c7f 100644 --- a/classes/holo_user.py +++ b/classes/holo_user.py @@ -5,7 +5,7 @@ import discord import discord.member from libbot import config_get -from modules.database import col_users, col_warnings +from modules.database import col_warnings, sync_col_users, sync_col_warnings, col_users logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class HoloUser: else: self.id = user - jav_user = col_users.find_one({"user": self.id}) + jav_user = sync_col_users.find_one({"user": self.id}) if jav_user is None: raise UserNotFoundError(user=user, user_id=self.id) @@ -57,29 +57,29 @@ class HoloUser: ### Returns: * `int`: Number of warnings """ - warns = col_warnings.find_one({"user": self.id}) + warns = sync_col_warnings.find_one({"user": self.id}) return 0 if warns is None else warns["warns"] - def warn(self, count=1, reason: str = "Not provided") -> None: + async def warn(self, count=1, reason: str = "Not provided") -> None: """Warn and add count to warns number ### Args: * `count` (int, optional): Count of warnings to be added. Defaults to 1. """ - warns = col_warnings.find_one({"user": self.id}) + warns = await col_warnings.find_one({"user": self.id}) if warns is not None: - col_warnings.update_one( - filter={"_id": self.db_id}, - update={"$set": {"warns": warns["warns"] + count}}, + await col_warnings.update_one( + {"_id": self.db_id}, + {"$set": {"warns": warns["warns"] + count}}, ) else: - col_warnings.insert_one(document={"user": self.id, "warns": count}) + await col_warnings.insert_one(document={"user": self.id, "warns": count}) logger.info(f"User {self.id} was warned {count} times due to: {reason}") - def set(self, key: str, value: Any) -> None: + async def set(self, key: str, value: Any) -> None: """Set attribute data and save it into database ### Args: @@ -90,14 +90,16 @@ class HoloUser: raise AttributeError() setattr(self, key, value) - col_users.update_one( - filter={"_id": self.db_id}, update={"$set": {key: value}}, upsert=True + + await col_users.update_one( + {"_id": self.db_id}, {"$set": {key: value}}, upsert=True ) logger.info(f"Set attribute {key} of user {self.id} to {value}") + @staticmethod async def is_moderator( - self, member: Union[discord.User, discord.Member, discord.member.Member] + member: Union[discord.User, discord.Member, discord.member.Member] ) -> bool: """Check if user is moderator or council member @@ -119,8 +121,9 @@ class HoloUser: return False + @staticmethod async def is_council( - self, member: Union[discord.User, discord.Member, discord.member.Member] + member: Union[discord.User, discord.Member, discord.member.Member] ) -> bool: """Check if user is a member of council diff --git a/cogs/analytics.py b/cogs/analytics.py index ebb6a13..9f884a3 100644 --- a/cogs/analytics.py +++ b/cogs/analytics.py @@ -46,7 +46,7 @@ class Analytics(commands.Cog): } ) - col_analytics.insert_one( + await col_analytics.insert_one( { "user": message.author.id, "channel": message.channel.id, diff --git a/cogs/custom_channels.py b/cogs/custom_channels.py index 2e20e2f..0824d4c 100644 --- a/cogs/custom_channels.py +++ b/cogs/custom_channels.py @@ -19,7 +19,7 @@ class CustomChannels(commands.Cog): @commands.Cog.listener() async def on_guild_channel_delete(self, channel: GuildChannel): - col_users.find_one_and_update( + await col_users.find_one_and_update( {"customchannel": channel.id}, {"$set": {"customchannel": None}} ) @@ -77,7 +77,7 @@ class CustomChannels(commands.Cog): manage_channels=True, ) - holo_user_ctx.set("customchannel", created_channel.id) + await holo_user_ctx.set("customchannel", created_channel.id) await ctx.respond( embed=Embed( @@ -178,7 +178,7 @@ class CustomChannels(commands.Cog): color=Color.fail, ) ) - holo_user_ctx.set("customchannel", None) + await holo_user_ctx.set("customchannel", None) return # Return if the confirmation is missing @@ -194,7 +194,7 @@ class CustomChannels(commands.Cog): await custom_channel.delete(reason="Власник запросив видалення") - holo_user_ctx.set("customchannel", None) + await holo_user_ctx.set("customchannel", None) await ctx.respond( embed=Embed( diff --git a/cogs/data.py b/cogs/data.py index d45f38b..ca9b902 100644 --- a/cogs/data.py +++ b/cogs/data.py @@ -40,7 +40,7 @@ class Data(commands.Cog): # Return if the user is not an owner and not in the council if (ctx.user.id not in self.client.owner_ids) and not ( - await holo_user.is_council(ctx.author) + await HoloUser.is_council(ctx.author) ): logging.info( "User %s tried to use /export but permission denied", @@ -108,11 +108,9 @@ class Data(commands.Cog): async def data_migrate_cmd(self, ctx: ApplicationContext, kind: str): await ctx.defer() - holo_user = HoloUser(ctx.author) - # Return if the user is not an owner and not in the council if (ctx.user.id not in self.client.owner_ids) and not ( - await holo_user.is_council(ctx.author) + await HoloUser.is_council(ctx.author) ): logging.info( "User %s tried to use /migrate but permission denied", @@ -156,7 +154,7 @@ class Data(commands.Cog): if member.bot: continue - if col_users.find_one({"user": member.id}) is None: + if (await col_users.find_one({"user": member.id})) is None: user = {} defaults = await config_get("user", "defaults") @@ -165,7 +163,7 @@ class Data(commands.Cog): for key in defaults: user[key] = defaults[key] - col_users.insert_one(document=user) + await col_users.insert_one(document=user) logging.info( "Added DB record for user %s during migration", member.id diff --git a/cogs/logger.py b/cogs/logger.py index c45c131..f404d0d 100644 --- a/cogs/logger.py +++ b/cogs/logger.py @@ -18,7 +18,7 @@ class Logger(commands.Cog): and (message.author.bot is False) and (message.author.system is False) ): - if col_users.find_one({"user": message.author.id}) is None: + if (await col_users.find_one({"user": message.author.id})) is None: user = {} defaults = await config_get("user", "defaults") @@ -27,7 +27,7 @@ class Logger(commands.Cog): for key in defaults: user[key] = defaults[key] - col_users.insert_one(document=user) + await col_users.insert_one(document=user) @commands.Cog.listener() async def on_member_join(self, member: Member): @@ -51,7 +51,7 @@ class Logger(commands.Cog): ) ) - if col_users.find_one({"user": member.id}) is None: + if (await col_users.find_one({"user": member.id})) is None: user = {} defaults = await config_get("user", "defaults") @@ -60,7 +60,7 @@ class Logger(commands.Cog): for key in defaults: user[key] = defaults[key] - col_users.insert_one(document=user) + await col_users.insert_one(document=user) def setup(client: PycordBot): diff --git a/modules/database.py b/modules/database.py index bb7eaaf..8d2d0ae 100644 --- a/modules/database.py +++ b/modules/database.py @@ -1,12 +1,19 @@ +from typing import Dict, Any + +from async_pymongo import AsyncClient, AsyncCollection, AsyncDatabase +from libbot.sync import config_get as sync_config_get from pymongo import MongoClient -from ujson import loads +from pymongo.synchronous.collection import Collection +from pymongo.synchronous.database import Database -with open("config.json", "r", encoding="utf-8") as f: - db_config = loads(f.read())["database"] - f.close() +db_config: Dict[str, Any] = sync_config_get("database") -db_client = MongoClient( - "mongodb://{0}:{1}@{2}:{3}/{4}".format( +con_string: str = ( + "mongodb://{0}:{1}/{2}".format( + db_config["host"], db_config["port"], db_config["name"] + ) + if db_config["user"] is None or db_config["password"] is None + else "mongodb://{0}:{1}@{2}:{3}/{4}".format( db_config["user"], db_config["password"], db_config["host"], @@ -14,14 +21,20 @@ db_client = MongoClient( db_config["name"], ) ) -db = db_client.get_database(name=db_config["name"]) -collections = db.list_collection_names() +db_client: AsyncClient = AsyncClient(con_string) +db_client_sync: MongoClient = MongoClient(con_string) -for collection in ["users", "warnings", "scheduler", "analytics"]: - if not collection in collections: - db.create_collection(collection) +# Async declarations per default +db: AsyncDatabase = db_client.get_database(name=db_config["name"]) -col_users = db.get_collection("users") -col_warnings = db.get_collection("warnings") -col_analytics = db.get_collection("analytics") +col_users: AsyncCollection = db.get_collection("users") +col_warnings: AsyncCollection = db.get_collection("warnings") +col_analytics: AsyncCollection = db.get_collection("analytics") + +# Sync declarations as a fallback +sync_db: Database = db_client_sync.get_database(name=db_config["name"]) + +sync_col_users: Collection = sync_db.get_collection("users") +sync_col_warnings: Collection = sync_db.get_collection("warnings") +sync_col_analytics: Collection = sync_db.get_collection("analytics") diff --git a/requirements.txt b/requirements.txt index 3eac856..e3d96e8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,6 @@ -aiofiles==24.1.0 -apscheduler==3.11.0 -pymongo~=4.10.0 -requests~=2.32.3 +aiofiles~=24.1.0 +apscheduler~=3.11.0 +async_pymongo==0.1.11 +libbot[speed,pycord]==3.2.3 ujson~=5.10.0 -WaifuPicsPython==0.2.0 ---extra-index-url https://git.end-play.xyz/api/packages/profitroll/pypi/simple -libbot[speed,pycord]==3.2.3 \ No newline at end of file +WaifuPicsPython==0.2.0 \ No newline at end of file