from pathlib import Path
from typing import Any, List, Mapping, Union

import cv2
import numpy as np
from numpy.typing import NDArray
from scipy import spatial

from modules.database import col_photos


def hash_array_to_hash_hex(hash_array) -> str:
    """Encode an array of 0/1 bits as a ``0x``-prefixed hexadecimal string.

    The input is cast to ``uint8`` and flattened in row-major order. The
    first element is treated as the most significant bit.

    Note that ``hex()`` drops leading zero bits, so ``[0, 1]`` and ``[1]``
    both encode to ``"0x1"``. The original bit count is not recoverable from
    the result alone.

    Raises:
        ValueError: if the input is empty or contains values other than 0/1
            after the ``uint8`` cast.
    """
    bits = np.asarray(hash_array, dtype=np.uint8).ravel().tolist()
    bit_string = "".join(map(str, bits))
    return hex(int(bit_string, 2))


def hash_hex_to_hash_array(hash_hex, length: int = 0) -> NDArray:
    """Decode a hexadecimal hash string into an array of 0.0/1.0 bits.

    Args:
        hash_hex: Hexadecimal string such as ``"0x1f"`` (the ``0x`` prefix is
            optional), e.g. as produced by ``hash_array_to_hash_hex``.
        length: Minimum number of bits to return. Hex encoding drops leading
            zero bits, so pass the original hash width (e.g. 64 for an 8x8
            hash) to restore them by left-padding with zeros. The default of
            0 keeps the shortest representation, which was the previous
            behaviour. Values shorter than the decoded bit string do not
            truncate it.

    Returns:
        A 1-D ``float32`` array, most significant bit first.

    Raises:
        ValueError: if ``hash_hex`` is not valid hexadecimal.
    """
    value = int(hash_hex, 16)
    # bin() yields "0b..."; strip the prefix and re-insert lost leading zeros.
    bit_string = bin(value)[2:].zfill(length)
    return np.array(list(bit_string), dtype=np.float32)


async def get_duplicates_cache(album: str) -> Mapping[str, Any]:
    """Collect the stored hash of every photo in ``album``.

    Queries ``col_photos`` for documents whose ``album`` field equals
    ``album``.

    Returns:
        A dict mapping each photo's ``filename`` to a two-item list
        ``[str(_id), hash]``. If several documents share a filename, the one
        yielded last by the cursor wins.
    """
    lookup = {}
    async for document in col_photos.find({"album": album}):
        filename = document["filename"]
        lookup[filename] = [str(document["_id"]), document["hash"]]
    return lookup


async def get_phash(filepath: Union[str, Path]) -> str:
    """Compute a DCT-based perceptual hash of the image at ``filepath``.

    Steps:
        1. Resize the image to 64x64 and convert it to grayscale.
        2. Apply a 2-D DCT.
        3. Threshold the top-left 8x8 block of coefficients (64 bits)
           against the mean of that block excluding the DC term at (0, 0).

    Returns:
        The bits, in row-major order, encoded by ``hash_array_to_hash_hex``.

    NOTE(review): this coroutine never awaits. The image decode and DCT run
    synchronously and block the event loop while they execute.
    """
    image = cv2.imread(str(filepath))
    gray = cv2.cvtColor(cv2.resize(image, (64, 64)), cv2.COLOR_BGR2GRAY)
    coefficients = cv2.dct(np.array(gray, dtype=np.float32))
    # Keep only the low-frequency corner so the hash stays 8*8 = 64 bits.
    low_freq = coefficients[:8, :8]
    # Mean of the block with the DC coefficient left out.
    count = low_freq.size
    threshold = (low_freq.mean() * count - low_freq[0, 0]) / (count - 1)
    # A bit is set when its coefficient is not below the threshold AND is
    # nonzero. NOTE(review): a coefficient exactly equal to 0.0 always yields
    # 0 even if it is >= threshold. Presumably this is kept so new hashes stay
    # comparable with previously stored ones; confirm before changing.
    bits = ~(low_freq < threshold) & (low_freq != 0)
    return hash_array_to_hash_hex(bits.flatten())


async def get_duplicates(
    hash_string: str, album: str, threshold: float = 0.1
) -> List[Mapping[str, Any]]:
    """Find photos in ``album`` whose stored hash is close to ``hash_string``.

    Closeness is the normalized Hamming distance (fraction of differing
    bits), as computed by ``scipy.spatial.distance.hamming``.

    Args:
        hash_string: Hex hash to compare against, e.g. from ``get_phash``.
        album: Album whose photos are searched.
        threshold: Maximum distance (inclusive) for a photo to be reported.
            Defaults to 0.1, the previously hard-coded cut-off.

    Returns:
        One dict per match with keys ``"id"``, ``"filename"`` and
        ``"difference"`` (the distance).
        - If ``hash_string`` is not valid hexadecimal, the result is an empty
          list.
        - Photos whose stored hash is not valid hexadecimal are skipped.
    """
    cache = await get_duplicates_cache(album)
    try:
        # Decode the query hash once instead of once per cached photo.
        target_bits = hash_hex_to_hash_array(hash_string)
    except ValueError:
        # An unparsable query hash cannot match anything.
        return []

    duplicates: List[Mapping[str, Any]] = []
    for filename, (photo_id, photo_hash) in cache.items():
        try:
            candidate_bits = hash_hex_to_hash_array(photo_hash)
        except ValueError:
            continue
        # The hex encoding drops leading zero bits, so hashes of equal width
        # can decode to arrays of different lengths. hamming() would then
        # raise, and such photos were silently never matched. Restore the
        # missing zeros on the left so both arrays have the same width.
        width = max(candidate_bits.size, target_bits.size)
        distance = spatial.distance.hamming(
            np.pad(candidate_bits, (width - candidate_bits.size, 0)),
            np.pad(target_bits, (width - target_bits.size, 0)),
        )
        if distance <= threshold:
            duplicates.append(
                {"id": photo_id, "filename": filename, "difference": distance}
            )
    return duplicates