125 lines
3.8 KiB
Python
125 lines
3.8 KiB
Python
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Union
|
|
|
|
from bson import ObjectId
|
|
from libbot import sync
|
|
from pyrogram.types import User
|
|
|
|
from classes.pyroclient import PyroClient
|
|
from modules.database import col_groups
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class PyroGroup:
    """In-memory copy of a group's settings document from ``col_groups``.

    Instances are normally obtained through :meth:`create_if_not_exists`,
    which builds the object from the stored document's keys.

    The ``set_*`` coroutines write a single field to the database and then
    mirror the new value on the instance, so later reads on the same object
    (for example :meth:`select_locale`) see what was just stored.
    """

    __slots__ = (
        "_id",
        "id",
        "locale",
        "locale_auto",
        "ban_failed",
        "timeout_join",
        "timeout_verify",
    )

    # ``_id`` of the stored document; used as the filter for every update.
    _id: ObjectId
    # Group identifier the document is looked up by.
    id: int
    # Group locale; ``None`` makes select_locale() fall back to the app default.
    locale: Union[str, None]
    # When ``True``, select_locale() prefers the user's own language code.
    locale_auto: bool
    # Per the log message below: whether users who failed the captcha are banned.
    ban_failed: bool
    # Join / verification timeouts (units are not visible here - TODO confirm).
    timeout_join: int
    timeout_verify: int

    @classmethod
    async def create_if_not_exists(
        cls,
        id: int,
        locale: Union[str, None] = sync.config_get("locale", "defaults", "group"),
        locale_auto: bool = sync.config_get("locale_auto", "defaults", "group"),
        ban_failed: bool = sync.config_get("ban_failed", "defaults", "group"),
        timeout_join: int = sync.config_get("timeout_join", "defaults", "group"),
        timeout_verify: int = sync.config_get("timeout_verify", "defaults", "group"),
    ) -> "PyroGroup":
        """Return the entry of group ``id``, inserting it first if it is missing.

        The settings arguments are only used when a new document has to be
        inserted; an existing document is returned unchanged.

        Note:
            The default values are read from the configuration once, when
            this method is defined, not on every call.

        Args:
            id: Group identifier to look up (or to store in the new entry).
            locale: Locale stored in a newly created entry.
            locale_auto: ``locale_auto`` flag stored in a newly created entry.
            ban_failed: ``ban_failed`` flag stored in a newly created entry.
            timeout_join: ``timeout_join`` value stored in a newly created entry.
            timeout_verify: ``timeout_verify`` value stored in a newly created entry.

        Returns:
            A ``PyroGroup`` built from the stored document.

        Raises:
            RuntimeError: If the freshly inserted document cannot be read back.
        """
        db_entry = await col_groups.find_one({"id": id})

        if db_entry is None:
            inserted = await col_groups.insert_one(
                {
                    "id": id,
                    "locale": locale,
                    "locale_auto": locale_auto,
                    "ban_failed": ban_failed,
                    "timeout_join": timeout_join,
                    "timeout_verify": timeout_verify,
                }
            )
            # Read the document back so the instance also carries its "_id".
            db_entry = await col_groups.find_one({"_id": inserted.inserted_id})

        if db_entry is None:
            raise RuntimeError("Could not find inserted group entry.")

        return cls(**db_entry)

    async def _store_field(self, field: str, value: object) -> None:
        """Write one field to this group's document and mirror it on the instance."""
        await col_groups.update_one({"_id": self._id}, {"$set": {field: value}})
        # Only reached when the write did not raise, so the in-memory value
        # never gets ahead of the database.
        setattr(self, field, value)

    async def set_locale(self, locale: Union[str, None]) -> None:
        """Store ``locale`` as the group's locale (``None`` is stored as-is)."""
        await self._store_field("locale", locale)
        logger.debug("Locale of group %s has been set to %s", self.id, locale)

    async def set_locale_auto(self, enabled: bool) -> None:
        """Store the ``locale_auto`` flag of the group."""
        await self._store_field("locale_auto", enabled)
        logger.debug(
            "Automatic locale selection of group %s has been set to %s",
            self.id,
            enabled,
        )

    async def set_ban_failed(self, enabled: bool) -> None:
        """Store the ``ban_failed`` flag of the group."""
        await self._store_field("ban_failed", enabled)
        logger.debug(
            "Banning users that failed the captcha in group %s has been set to %s",
            self.id,
            enabled,
        )

    async def set_timeout_join(self, timeout: int) -> None:
        """Store the ``timeout_join`` value of the group."""
        await self._store_field("timeout_join", timeout)
        logger.debug(
            "Join timeout in group %s has been set to %s",
            self.id,
            timeout,
        )

    async def set_timeout_verify(self, timeout: int) -> None:
        """Store the ``timeout_verify`` value of the group."""
        await self._store_field("timeout_verify", timeout)
        logger.debug(
            "Verification timeout in group %s has been set to %s",
            self.id,
            timeout,
        )

    def select_locale(
        self, app: PyroClient, user: User, ignore_auto: bool = False
    ) -> str:
        """Pick the locale to use for ``user`` in this group.

        The user's own ``language_code`` wins when automatic selection is
        enabled for the group (``locale_auto is True``), ``ignore_auto`` is
        not set and the user actually has a language code. In every other
        case the group's locale is used, falling back to
        ``app.default_locale`` when the group has none.

        Args:
            app: Client providing ``default_locale`` as the last fallback.
            user: User whose ``language_code`` may be preferred.
            ignore_auto: Skip the user's language even if ``locale_auto`` is on.

        Returns:
            The selected locale code.
        """
        prefer_user_language = (
            not ignore_auto
            and self.locale_auto is True
            and user.language_code is not None
        )
        if prefer_user_language:
            return user.language_code

        return self.locale if self.locale is not None else app.default_locale
|