#!/usr/bin/env python
# coding: utf-8

from __future__ import (absolute_import, division, print_function)

from logging import getLogger, StreamHandler, DEBUG
# Module-level logger that writes through its own StreamHandler. Propagation
# is switched off so records are not also passed to ancestor loggers.
# For debug output, call handler.setLevel(DEBUG) and logger.setLevel(DEBUG).
logger = getLogger(__name__)
logger.propagate = False
handler = StreamHandler()
logger.addHandler(handler)


from pathlib import Path
from common.imagededuper import ImageDeduper
from termcolor import colored, cprint

import sys


def is_image(path):
    """Return True if ``path`` has a recognised image file extension.

    Args:
        path: A ``pathlib`` path object; only its ``suffix`` attribute is
            read.

    Returns:
        bool: True when the suffix, compared case-insensitively, is one of
        ``.png``, ``.jpg``, ``.jpeg`` or ``.gif``.
    """
    # Lower-casing the suffix also accepts mixed-case spellings such as
    # ``.Jpg`` that an explicit list of upper/lower variants would miss.
    return path.suffix.lower() in ('.png', '.jpg', '.jpeg', '.gif')


def is_image_with_str(filepath):
    """Return True if the path string has a recognised image file extension.

    The extension is taken from the final path component, so a name without
    an extension (for example ``png``) is not treated as an image, and a
    leading-dot name such as ``.png`` counts as having no extension.

    Args:
        filepath: File path as a string.

    Returns:
        bool: True when the extension, compared case-insensitively, is one of
        ``.png``, ``.jpg``, ``.jpeg`` or ``.gif``.
    """
    # Path.suffix only looks at the last component, unlike splitting the
    # whole string on '.', which also matched dot-less names.
    return Path(filepath).suffix.lower() in ('.png', '.jpg', '.jpeg', '.gif')


def package_check(args):
    """Verify that the Python binding of the selected search backend imports.

    The flags are examined in the order ``ngt``, ``hnsw``, ``faiss_flat`` and
    only the first one that is set is checked. If its module cannot be
    imported, an error is logged and the process exits with status 1. When
    none of the flags is set, nothing is checked.

    Args:
        args: Parsed command-line namespace providing the ``ngt``, ``hnsw``
            and ``faiss_flat`` attributes.

    Raises:
        SystemExit: With code 1 when the required module fails to import.
    """
    import importlib

    # (flag attribute, module to import, message logged on failure)
    backends = (
        ('ngt', 'ngtpy',
         "Error: Unable to load NGT. Please install NGT and python binding first."),
        ('hnsw', 'hnswlib',
         "Error: Unable to load hnsw. Please install hnsw python binding first."),
        ('faiss_flat', 'faiss',
         "Error: Unable to load faiss. Please install faiss python binding first."),
    )
    for flag, module_name, message in backends:
        if not getattr(args, flag):
            continue
        try:
            importlib.import_module(module_name)
        except Exception:
            # Exception rather than a bare except, so that Ctrl-C or a
            # SystemExit raised during the import is not misreported as a
            # missing package.
            logger.error(colored(message, 'red'))
            sys.exit(1)
        # Only the first enabled backend is checked.
        return


def gen_image_filenames(target_dir, recursive, sort_type):
    """Collect the image files found in ``target_dir``.

    Args:
        target_dir: Directory to scan.
        recursive: When true, every level below ``target_dir`` is searched
            (glob pattern ``**/*``); otherwise only its direct children
            (``*``).
        sort_type: The result is sorted in ascending string order unless this
            is the string ``'none'``.

    Returns:
        list[str]: Paths, as strings, of the files accepted by ``is_image``.

    Raises:
        SystemExit: When no image is found, after logging an error. The exit
            status used is 0.
    """
    pattern = '**/*' if recursive else '*'
    # is_file() drops directories (and other non-file entries) whose name
    # merely ends in an image extension; the cheaper suffix test runs first.
    image_filenames = [
        str(path)
        for path in Path(target_dir).glob(pattern)
        if is_image(path) and path.is_file()
    ]
    if not image_filenames:
        logger.error("Image not found. To search the directory recursively, add the --recursive option.")
        sys.exit(0)
    if sort_type != 'none':
        image_filenames.sort()
    return image_filenames


def gen_image_filenames_from_list(target_files, sort_type):
    """Read image file paths from a text file containing one path per line.

    Trailing whitespace (including the newline) is stripped from each line,
    and only lines accepted by ``is_image_with_str`` are kept.

    Args:
        target_files: Path of the text file that lists the image files.
        sort_type: The result is sorted in ascending string order unless this
            is the string ``'none'``.

    Returns:
        list[str]: The image file paths read from the list.

    Raises:
        SystemExit: When the list contains no image path, after logging an
            error. The exit status used is 0.
    """
    with open(target_files, 'r') as f:
        stripped_lines = (line.rstrip() for line in f)
        image_filenames = [name for name in stripped_lines if is_image_with_str(name)]
    if not image_filenames:
        # The --recursive hint used for directory scans does not apply to a
        # file list, so the message points at the list instead.
        logger.error("Image not found. The list given by --files-from contains no image file paths.")
        sys.exit(0)
    if sort_type != 'none':
        image_filenames.sort()
    return image_filenames


def dedupe_images(args):
    """Run the whole dedupe workflow for the parsed command-line options.

    Steps, in order: check the search backend with ``package_check``, build
    the list of image paths (from ``args.files_from`` when it is set,
    otherwise by scanning ``args.target_dir``), create an ``ImageDeduper``
    and call its ``dedupe`` method, then call either ``preserve`` (when
    ``args.delete`` is set) or ``print_duplicates``, and finally
    ``summarize`` when ``args.summarize`` is set.

    Args:
        args: Parsed command-line namespace.

    Raises:
        SystemExit: With code 1 when interrupted by ``KeyboardInterrupt``.
    """
    try:
        package_check(args)

        image_filenames = (
            gen_image_filenames_from_list(args.files_from, args.sort)
            if args.files_from
            else gen_image_filenames(args.target_dir, args.recursive, args.sort)
        )

        deduper = ImageDeduper(args, image_filenames)
        deduper.dedupe(args)

        # Pick the follow-up step once, then invoke it.
        follow_up = deduper.preserve if args.delete else deduper.print_duplicates
        follow_up(args)

        if args.summarize:
            deduper.summarize(args)
    except KeyboardInterrupt:
        sys.exit(1)


def _build_parser():
    """Create the argument parser describing every command-line option."""
    import argparse
    parser = argparse.ArgumentParser(
        description="finding and deleting duplicate image files based on perceptual hash")
    parser.add_argument("target_dir", type=str, nargs='?')
    parser.add_argument("hash_method", type=str,
        choices=['ahash', 'phash', 'dhash', 'whash', 'phash_org'],
        help="""method of perceptual hashing.
            ahash(Average hash) phash(Perceptual hash) dhash(Difference hash)
            whash(Haar wavelet hash) phash_org(Perceptual hash faithful implementation)""")
    parser.add_argument("hamming_distance", type=int,
        help="allowable Hamming distances.")
    # TODO: a --preserve_largest option (choices: filesize, pixelsize) that
    # keeps the largest image of each duplicate set was sketched here but is
    # not implemented.
    parser.add_argument('--files-from', '-T', type=str, default=None,
        help="use list of target image filename for input")
    parser.add_argument("-r", "--recursive", action="store_true", default=False,
        help="search images recursively from the target directory")
    parser.add_argument("-d", "--delete", action="store_true",
        help="prompt user for files to preserve and delete")
    parser.add_argument("-c", "--imgcat", action="store_true",
        help="display duplicate images for iTerm2")
    parser.add_argument("-m", "--summarize", action="store_true",
        help="summarize dupe information")
    parser.add_argument("-N", "--noprompt", action="store_true",
        help="together with --delete, preserve the first file in each set of duplicates and delete the rest without prompting the user")
    parser.add_argument("--query", type=str, default=None,
        help="find image files that are duplicated or similar to the specified image file from the target directory")
    parser.add_argument("--hash-bits", type=int, default=64,
        help="""bits of perceptual hash.
            The number of bits specifies the value that is the square of n.
            For example, you can specify 64(8^2), 144(12^2), 256(16^2)
            """)
    parser.add_argument("--sort", type=str,
        choices=['filesize', 'filepath', 'imagesize', 'width', 'height', 'none'],
        help="how to sort duplicate image files (default=filesize)")
    parser.add_argument("--reverse", action="store_true",
        help="reverse order while sorting")
    parser.add_argument("--num-proc", type=int, default=None,
        help="number of hash calculation and ngt processes (default=cpu_count-1)")
    parser.add_argument("--log", action="store_true",
        help="output logs of duplicate and delete files")
    parser.add_argument("--no-cache", dest="cache", action="store_false",
        help="not create or use image hash cache")
    parser.add_argument("--no-subdir-warning", dest="print_warning", action="store_false",
        help="stop warnings that appear when similar images are in different subdirectories")
    parser.add_argument("--sameline", action="store_true",
        help="list each set of matches on a single line")
    parser.add_argument("--dry-run", dest="run", action="store_false",
        help="dry run (do not delete any files)")
    parser.add_argument("--faiss-flat", action="store_true", default=False,
        help="use faiss exact search (IndexFlatL2) for calculating Hamming distance between hash of images")
    parser.add_argument("--faiss-flat-k", type=int, default=20,
        help="number of searched objects when using faiss-flat (default=20)")

    # imgcat options
    parser.add_argument("--size", type=str, default="256x256",
        help="resize image (default=256x256)")
    parser.add_argument("--space", type=int, default=0,
        help="space between images (default=0)")
    parser.add_argument("--space-color", type=str, default='black',
        help="space color between images (default=black)")
    parser.add_argument("--tile-num", type=int, default=4,
        help="horizontal tile number (default=4)")
    parser.add_argument("--interpolation", type=str, default="INTER_LINEAR",
        help="interpolation methods")
    parser.add_argument("--no-keep-aspect", dest="keep_aspect", action="store_false",
        help="do not keep aspect when displaying imagest")

    # ngt options
    parser.add_argument("--ngt", action="store_true", default=True,
        help="use NGT for calculating Hamming distance between hash of images")
    parser.add_argument("--ngt-k", type=int, default=20,
        help="""number of searched objects when using NGT.
            Increasing this value, improves accuracy and increases computation time.
            (default=20)""")
    parser.add_argument("--ngt-epsilon", type=float, default=0.1,
        help="""search range when using NGT.
            Increasing this value, improves accuracy and increases computation time.
            (default=0.1)""")
    parser.add_argument("--ngt-edges", type=int, default=10,
        help="number of initial edges of each node at graph generation time (default=10)")
    parser.add_argument("--ngt-edges-for-search", type=int, default=40,
        help="number of edges at search time (default=40)")

    # hnsw options
    parser.add_argument("--hnsw", action="store_true", default=False,
        help="use hnsw for calculating Hamming distance between hash of images")
    parser.add_argument("--hnsw-k", type=int, default=20,
        help="""number of searched objects when using hnsw.
            Increasing this value, improves accuracy and increases computation time.
            (default=20)""")
    parser.add_argument("--hnsw-ef-construction", type=int, default=100,
        help="controls index search speed/build speed tradeoff (default=100)")
    parser.add_argument("--hnsw-m", type=int, default=16,
        help="""m is tightly connected with internal dimensionality of the data
            stronlgy affects the memory consumption (default=16)""")
    parser.add_argument("--hnsw-ef", type=int, default=50,
        help="controls recall. higher ef leads to better accuracy, but slower search (default=50)")

    # faiss options
    parser.add_argument("--faiss-cuda", action="store_true", default=False,
        help="attempt to use CUDA enabled GPU for searches")

    # CUDA options
    parser.add_argument("--cuda-device", type=int, default=-1,
        help="uses the specific CUDA device passed (default=device with lowest load)")

    return parser


def main(argv=None):
    """Parse the command line, validate the options and run the deduper.

    Args:
        argv: List of argument strings to parse. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]`` at call time.

    Raises:
        SystemExit: With code 1 when neither ``target_dir`` nor
            ``--files-from`` is given, or when ``--delete`` and
            ``--summarize`` are combined. ``argparse`` itself exits on
            malformed arguments.
    """
    parser = _build_parser()

    # Pass argv through so that callers supplying their own argument list are
    # honoured instead of silently falling back to sys.argv.
    args = parser.parse_args(argv)

    if (args.target_dir is None) and (args.files_from is None):
        print("Positional argument 'target_dir' is required when not specified --files-from option.")
        sys.exit(1)
    elif args.target_dir and args.files_from:
        logger.warning(colored("Both 'target_dir' and `--files-from` option are specified. `target_dir` is ignored.", 'red'))

    # check options
    if args.delete and args.summarize:
        print("options --summarize and --delete are not compatible")
        sys.exit(1)

    # --ngt defaults to True; choosing hnsw or faiss-flat switches it off so
    # that exactly one backend flag is set.
    if args.hnsw or args.faiss_flat:
        args.ngt = False

    dedupe_images(args)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
