WIP: Migration to async_pymongo
This commit is contained in:
		| @@ -1,4 +1,5 @@ | ||||
| from pymongo import GEOSPHERE, MongoClient | ||||
| from async_pymongo import AsyncClient | ||||
| from pymongo import GEOSPHERE | ||||
|  | ||||
| from modules.utils import configGet | ||||
|  | ||||
| @@ -17,16 +18,10 @@ else: | ||||
|         db_config["host"], db_config["port"], db_config["name"] | ||||
|     ) | ||||
|  | ||||
| db_client = MongoClient(con_string) | ||||
| db_client = AsyncClient(con_string) | ||||
|  | ||||
| db = db_client.get_database(name=db_config["name"]) | ||||
|  | ||||
| collections = db.list_collection_names() | ||||
|  | ||||
| for collection in ["users", "albums", "photos", "videos", "tokens", "emails"]: | ||||
|     if collection not in collections: | ||||
|         db.create_collection(collection) | ||||
|  | ||||
| col_users = db.get_collection("users") | ||||
| col_albums = db.get_collection("albums") | ||||
| col_photos = db.get_collection("photos") | ||||
|   | ||||
| @@ -1,4 +1,6 @@ | ||||
| import contextlib | ||||
| from pathlib import Path | ||||
| from typing import Mapping, Union | ||||
|  | ||||
| from exif import Image | ||||
|  | ||||
| @@ -21,7 +23,7 @@ def decimal_coords(coords: float, ref: str) -> float: | ||||
|     return round(decimal_degrees, 5) | ||||
|  | ||||
|  | ||||
| def extract_location(filepath: str) -> dict: | ||||
| def extract_location(filepath: Union[str, Path]) -> Mapping[str, float]: | ||||
|     """Get location data from image | ||||
|  | ||||
|     ### Args: | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| from importlib.util import module_from_spec, spec_from_file_location | ||||
| from os import getcwd, path, walk | ||||
| from pathlib import Path | ||||
| from typing import Union | ||||
|  | ||||
| # ================================================================================= | ||||
|  | ||||
| @@ -17,11 +18,15 @@ def get_py_files(src): | ||||
|     return py_files | ||||
|  | ||||
|  | ||||
| def dynamic_import(module_name, py_path): | ||||
| def dynamic_import(module_name: str, py_path: str): | ||||
|     try: | ||||
|         module_spec = spec_from_file_location(module_name, py_path) | ||||
|         module = module_from_spec(module_spec)  # type: ignore | ||||
|         module_spec.loader.exec_module(module)  # type: ignore | ||||
|         if module_spec is None: | ||||
|             raise RuntimeError( | ||||
|                 f"Module spec from module name {module_name} and path {py_path} is None" | ||||
|             ) | ||||
|         module = module_from_spec(module_spec) | ||||
|         module_spec.loader.exec_module(module) | ||||
|         return module | ||||
|     except SyntaxError: | ||||
|         print( | ||||
| @@ -29,12 +34,12 @@ def dynamic_import(module_name, py_path): | ||||
|             flush=True, | ||||
|         ) | ||||
|         return | ||||
|     except Exception as exp: | ||||
|         print(f"Could not load extension {module_name} due to {exp}", flush=True) | ||||
|     except Exception as exc: | ||||
|         print(f"Could not load extension {module_name} due to {exc}", flush=True) | ||||
|         return | ||||
|  | ||||
|  | ||||
| def dynamic_import_from_src(src, star_import=False): | ||||
| def dynamic_import_from_src(src: Union[str, Path], star_import=False): | ||||
|     my_py_files = get_py_files(src) | ||||
|     for py_file in my_py_files: | ||||
|         module_name = Path(py_file).stem | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| from pathlib import Path | ||||
| from typing import Union | ||||
| from typing import Any, List, Mapping, Union | ||||
|  | ||||
| import cv2 | ||||
| import numpy as np | ||||
| @@ -9,7 +9,7 @@ from scipy import spatial | ||||
| from modules.database import col_photos | ||||
|  | ||||
|  | ||||
| def hash_array_to_hash_hex(hash_array): | ||||
| def hash_array_to_hash_hex(hash_array) -> str: | ||||
|     # convert hash array of 0 or 1 to hash string in hex | ||||
|     hash_array = np.array(hash_array, dtype=np.uint8) | ||||
|     hash_str = "".join(str(i) for i in 1 * hash_array.flatten()) | ||||
| @@ -23,10 +23,10 @@ def hash_hex_to_hash_array(hash_hex) -> NDArray: | ||||
|     return np.array(list(array_str), dtype=np.float32) | ||||
|  | ||||
|  | ||||
| def get_duplicates_cache(album: str) -> dict: | ||||
| async def get_duplicates_cache(album: str) -> Mapping[str, Any]: | ||||
|     return { | ||||
|         photo["filename"]: [photo["_id"].__str__(), photo["hash"]] | ||||
|         for photo in col_photos.find({"album": album}) | ||||
|         async for photo in col_photos.find({"album": album}) | ||||
|     } | ||||
|  | ||||
|  | ||||
| @@ -52,9 +52,9 @@ async def get_phash(filepath: Union[str, Path]) -> str: | ||||
|     return hash_array_to_hash_hex(dct_block.flatten()) | ||||
|  | ||||
|  | ||||
| async def get_duplicates(hash_string: str, album: str) -> list: | ||||
| async def get_duplicates(hash_string: str, album: str) -> List[Mapping[str, Any]]: | ||||
|     duplicates = [] | ||||
|     cache = get_duplicates_cache(album) | ||||
|     cache = await get_duplicates_cache(album) | ||||
|     for image_name, image_object in cache.items(): | ||||
|         try: | ||||
|             distance = spatial.distance.hamming( | ||||
|   | ||||
| @@ -28,8 +28,8 @@ try: | ||||
|         ) | ||||
|         mail_sender.ehlo() | ||||
|         logger.info("Initialized SMTP connection") | ||||
| except Exception as exp: | ||||
|     logger.error("Could not initialize SMTP connection to: %s", exp) | ||||
| except Exception as exc: | ||||
|     logger.error("Could not initialize SMTP connection to: %s", exc) | ||||
|     print_exc() | ||||
|  | ||||
| try: | ||||
| @@ -37,5 +37,5 @@ try: | ||||
|         configGet("login", "mailer", "smtp"), configGet("password", "mailer", "smtp") | ||||
|     ) | ||||
|     logger.info("Successfully initialized mailer") | ||||
| except Exception as exp: | ||||
|     logger.error("Could not login into provided SMTP account due to: %s", exp) | ||||
| except Exception as exc: | ||||
|     logger.error("Could not login into provided SMTP account due to: %s", exc) | ||||
|   | ||||
| @@ -54,16 +54,20 @@ oauth2_scheme = OAuth2PasswordBearer( | ||||
| ) | ||||
|  | ||||
|  | ||||
| def verify_password(plain_password, hashed_password): | ||||
| def verify_password(plain_password, hashed_password) -> bool: | ||||
|     return pwd_context.verify(plain_password, hashed_password) | ||||
|  | ||||
|  | ||||
| def get_password_hash(password): | ||||
| def get_password_hash(password) -> str: | ||||
|     return pwd_context.hash(password) | ||||
|  | ||||
|  | ||||
| def get_user(user: str): | ||||
|     found_user = col_users.find_one({"user": user}) | ||||
| async def get_user(user: str) -> UserInDB: | ||||
|     found_user = await col_users.find_one({"user": user}) | ||||
|  | ||||
|     if found_user is None: | ||||
|         raise RuntimeError(f"User {user} does not exist") | ||||
|  | ||||
|     return UserInDB( | ||||
|         user=found_user["user"], | ||||
|         email=found_user["email"], | ||||
| @@ -72,14 +76,16 @@ def get_user(user: str): | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def authenticate_user(user_name: str, password: str): | ||||
|     if user := get_user(user_name): | ||||
| async def authenticate_user(user_name: str, password: str) -> Union[UserInDB, bool]: | ||||
|     if user := await get_user(user_name): | ||||
|         return user if verify_password(password, user.hash) else False | ||||
|     else: | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None): | ||||
| def create_access_token( | ||||
|     data: dict, expires_delta: Union[timedelta, None] = None | ||||
| ) -> str: | ||||
|     to_encode = data.copy() | ||||
|     if expires_delta: | ||||
|         expire = datetime.now(tz=timezone.utc) + expires_delta | ||||
| @@ -93,7 +99,7 @@ def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None | ||||
|  | ||||
| async def get_current_user( | ||||
|     security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme) | ||||
| ): | ||||
| ) -> UserInDB: | ||||
|     if security_scopes.scopes: | ||||
|         authenticate_value = f'Bearer scope="{security_scopes.scope_str}"' | ||||
|     else: | ||||
| @@ -112,12 +118,12 @@ async def get_current_user( | ||||
|             raise credentials_exception | ||||
|         token_scopes = payload.get("scopes", []) | ||||
|         token_data = TokenData(scopes=token_scopes, user=user) | ||||
|     except (JWTError, ValidationError): | ||||
|         raise credentials_exception | ||||
|     except (JWTError, ValidationError) as exc: | ||||
|         raise credentials_exception from exc | ||||
|  | ||||
|     user = get_user(user=token_data.user) | ||||
|     user_record = await get_user(user=token_data.user) | ||||
|  | ||||
|     if user is None: | ||||
|     if user_record is None: | ||||
|         raise credentials_exception | ||||
|  | ||||
|     for scope in security_scopes.scopes: | ||||
| @@ -127,7 +133,7 @@ async def get_current_user( | ||||
|                 detail="Not enough permissions", | ||||
|                 headers={"WWW-Authenticate": authenticate_value}, | ||||
|             ) | ||||
|     return user | ||||
|     return user_record | ||||
|  | ||||
|  | ||||
| async def get_current_active_user( | ||||
|   | ||||
| @@ -49,8 +49,8 @@ def jsonSave(contents: Union[list, dict], filepath: Union[str, Path]) -> None: | ||||
|         with open(filepath, "w", encoding="utf8") as file: | ||||
|             file.write(dumps(contents, ensure_ascii=False, indent=4)) | ||||
|             file.close() | ||||
|     except Exception as exp: | ||||
|         logger.error("Could not save json file %s: %s\n%s", filepath, exp, format_exc()) | ||||
|     except Exception as exc: | ||||
|         logger.error("Could not save json file %s: %s\n%s", filepath, exc, format_exc()) | ||||
|     return | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user