WIP: Migration to async_pymongo

This commit is contained in:
2023-08-14 13:44:07 +02:00
parent 80ec8eb4f3
commit a1acaed6dd
13 changed files with 196 additions and 175 deletions

View File

@@ -1,4 +1,5 @@
from pymongo import GEOSPHERE, MongoClient
from async_pymongo import AsyncClient
from pymongo import GEOSPHERE
from modules.utils import configGet
@@ -17,16 +18,10 @@ else:
db_config["host"], db_config["port"], db_config["name"]
)
db_client = MongoClient(con_string)
db_client = AsyncClient(con_string)
db = db_client.get_database(name=db_config["name"])
collections = db.list_collection_names()
for collection in ["users", "albums", "photos", "videos", "tokens", "emails"]:
if collection not in collections:
db.create_collection(collection)
col_users = db.get_collection("users")
col_albums = db.get_collection("albums")
col_photos = db.get_collection("photos")

View File

@@ -1,4 +1,6 @@
import contextlib
from pathlib import Path
from typing import Mapping, Union
from exif import Image
@@ -21,7 +23,7 @@ def decimal_coords(coords: float, ref: str) -> float:
return round(decimal_degrees, 5)
def extract_location(filepath: str) -> dict:
def extract_location(filepath: Union[str, Path]) -> Mapping[str, float]:
"""Get location data from image
### Args:

View File

@@ -1,6 +1,7 @@
from importlib.util import module_from_spec, spec_from_file_location
from os import getcwd, path, walk
from pathlib import Path
from typing import Union
# =================================================================================
@@ -17,11 +18,15 @@ def get_py_files(src):
return py_files
def dynamic_import(module_name, py_path):
def dynamic_import(module_name: str, py_path: str):
try:
module_spec = spec_from_file_location(module_name, py_path)
module = module_from_spec(module_spec) # type: ignore
module_spec.loader.exec_module(module) # type: ignore
if module_spec is None:
raise RuntimeError(
f"Module spec from module name {module_name} and path {py_path} is None"
)
module = module_from_spec(module_spec)
module_spec.loader.exec_module(module)
return module
except SyntaxError:
print(
@@ -29,12 +34,12 @@ def dynamic_import(module_name, py_path):
flush=True,
)
return
except Exception as exp:
print(f"Could not load extension {module_name} due to {exp}", flush=True)
except Exception as exc:
print(f"Could not load extension {module_name} due to {exc}", flush=True)
return
def dynamic_import_from_src(src, star_import=False):
def dynamic_import_from_src(src: Union[str, Path], star_import=False):
my_py_files = get_py_files(src)
for py_file in my_py_files:
module_name = Path(py_file).stem

View File

@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Union
from typing import Any, List, Mapping, Union
import cv2
import numpy as np
@@ -9,7 +9,7 @@ from scipy import spatial
from modules.database import col_photos
def hash_array_to_hash_hex(hash_array):
def hash_array_to_hash_hex(hash_array) -> str:
# convert hash array of 0 or 1 to hash string in hex
hash_array = np.array(hash_array, dtype=np.uint8)
hash_str = "".join(str(i) for i in 1 * hash_array.flatten())
@@ -23,10 +23,10 @@ def hash_hex_to_hash_array(hash_hex) -> NDArray:
return np.array(list(array_str), dtype=np.float32)
def get_duplicates_cache(album: str) -> dict:
async def get_duplicates_cache(album: str) -> Mapping[str, Any]:
return {
photo["filename"]: [photo["_id"].__str__(), photo["hash"]]
for photo in col_photos.find({"album": album})
async for photo in col_photos.find({"album": album})
}
@@ -52,9 +52,9 @@ async def get_phash(filepath: Union[str, Path]) -> str:
return hash_array_to_hash_hex(dct_block.flatten())
async def get_duplicates(hash_string: str, album: str) -> list:
async def get_duplicates(hash_string: str, album: str) -> List[Mapping[str, Any]]:
duplicates = []
cache = get_duplicates_cache(album)
cache = await get_duplicates_cache(album)
for image_name, image_object in cache.items():
try:
distance = spatial.distance.hamming(

View File

@@ -28,8 +28,8 @@ try:
)
mail_sender.ehlo()
logger.info("Initialized SMTP connection")
except Exception as exp:
logger.error("Could not initialize SMTP connection to: %s", exp)
except Exception as exc:
logger.error("Could not initialize SMTP connection to: %s", exc)
print_exc()
try:
@@ -37,5 +37,5 @@ try:
configGet("login", "mailer", "smtp"), configGet("password", "mailer", "smtp")
)
logger.info("Successfully initialized mailer")
except Exception as exp:
logger.error("Could not login into provided SMTP account due to: %s", exp)
except Exception as exc:
logger.error("Could not login into provided SMTP account due to: %s", exc)

View File

@@ -54,16 +54,20 @@ oauth2_scheme = OAuth2PasswordBearer(
)
def verify_password(plain_password, hashed_password):
def verify_password(plain_password, hashed_password) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
def get_password_hash(password) -> str:
return pwd_context.hash(password)
def get_user(user: str):
found_user = col_users.find_one({"user": user})
async def get_user(user: str) -> UserInDB:
found_user = await col_users.find_one({"user": user})
if found_user is None:
raise RuntimeError(f"User {user} does not exist")
return UserInDB(
user=found_user["user"],
email=found_user["email"],
@@ -72,14 +76,16 @@ def get_user(user: str):
)
def authenticate_user(user_name: str, password: str):
if user := get_user(user_name):
async def authenticate_user(user_name: str, password: str) -> Union[UserInDB, bool]:
if user := await get_user(user_name):
return user if verify_password(password, user.hash) else False
else:
return False
def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None):
def create_access_token(
data: dict, expires_delta: Union[timedelta, None] = None
) -> str:
to_encode = data.copy()
if expires_delta:
expire = datetime.now(tz=timezone.utc) + expires_delta
@@ -93,7 +99,7 @@ def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None
async def get_current_user(
security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme)
):
) -> UserInDB:
if security_scopes.scopes:
authenticate_value = f'Bearer scope="{security_scopes.scope_str}"'
else:
@@ -112,12 +118,12 @@ async def get_current_user(
raise credentials_exception
token_scopes = payload.get("scopes", [])
token_data = TokenData(scopes=token_scopes, user=user)
except (JWTError, ValidationError):
raise credentials_exception
except (JWTError, ValidationError) as exc:
raise credentials_exception from exc
user = get_user(user=token_data.user)
user_record = await get_user(user=token_data.user)
if user is None:
if user_record is None:
raise credentials_exception
for scope in security_scopes.scopes:
@@ -127,7 +133,7 @@ async def get_current_user(
detail="Not enough permissions",
headers={"WWW-Authenticate": authenticate_value},
)
return user
return user_record
async def get_current_active_user(

View File

@@ -49,8 +49,8 @@ def jsonSave(contents: Union[list, dict], filepath: Union[str, Path]) -> None:
with open(filepath, "w", encoding="utf8") as file:
file.write(dumps(contents, ensure_ascii=False, indent=4))
file.close()
except Exception as exp:
logger.error("Could not save json file %s: %s\n%s", filepath, exp, format_exc())
except Exception as exc:
logger.error("Could not save json file %s: %s\n%s", filepath, exc, format_exc())
return