#!/usr/bin/env python3

import argparse
import logging
import os
import time

import cv2
import numpy as np

from recolonyzer import analysis, image_processing, preprocessing, utils

# Module-level logger read by main(). It stays None until the __main__ block
# at the bottom of this file configures logging and assigns it.
logger = None


def _read_grayscale(path):
    """Load *path* as a single-channel image, failing loudly if unreadable."""
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising
        raise OSError("Could not read image: {}".format(path))
    return image


# Analysis pipeline: locate the grid once, then measure every image
def main(args):
    """Analyse every image selected by the command-line arguments.

    The position of the culture grid is computed once, on the latest image
    of the sorted list. Every image is then cropped to that region,
    thresholded and measured. For each image two files are written below the
    working directory: ``Output_Data/<name>.out`` (tab-separated
    measurements) and ``Output_Images/<file name>`` (the spot mask, for a
    visual check).

    Args:
        args: Parsed command-line namespace. This function reads
            ``gridformat``, ``remove``, ``endpoint``, ``fraction`` and
            ``lightcorrection``; the namespace is also handed to
            ``utils.summarise`` and ``utils.get_directory``.

    Raises:
        OSError: If OpenCV cannot read one of the images.
    """
    # Fall back to a module logger so main() also works when this file is
    # imported rather than executed (the global is only set under __main__).
    log = logger if logger is not None else logging.getLogger(__name__)

    # Print information of inputs to users
    utils.summarise(args)

    # Get working directory and grid layout
    fdir = utils.get_directory(args)
    nrow, ncol = utils.get_grid_format(args.gridformat)

    # Create needed directories to save outputs
    docheck = preprocessing.arrange_directories(args.remove, fdir)
    # Obtain list of images to analyse
    imanalyse = preprocessing.get_images(fdir, docheck, args.endpoint)
    if len(imanalyse) == 0:
        # Nothing to do; avoid an IndexError when picking the latest image
        log.info("No more images to analyse. I'm done")
        return
    imanalyse.sort()

    # Set timer
    start_time = time.time()

    # Obtain first and last image
    latest_image = imanalyse[-1]
    earliest_image = imanalyse[0]

    # Print information to users
    log.debug("")
    log.debug("Starting analysis:")
    log.debug("Earliest image: %s", earliest_image)
    log.debug("Latest image: %s", latest_image)
    log.debug("")
    log.debug("Computing position of the grid")
    log.debug("This may take a few seconds...")
    log.debug("")

    # Get latest image to detect culture locations
    im_n = _read_grayscale(latest_image)
    _, min_loc, pat_h, pat_w = image_processing.get_position_grid(
        im_n, nrow, ncol, args.fraction)

    # Cut the original image with the size of the best pattern match.
    # min_loc is used as (x, y): index 0 selects columns, index 1 rows.
    w_right = int(min_loc[0] + pat_w)
    h_bottom = int(min_loc[1] + pat_h)
    im_ = im_n[min_loc[1]:h_bottom, min_loc[0]:w_right]

    # Find spots and agar based on an automatic (Otsu) threshold
    _, otsu_mask = cv2.threshold(
        np.array(im_, dtype=np.uint8), 0, 255,
        cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    # Locate the position of the spots and the agar into different masks:
    # agar is where the inverted threshold image is set, spots is the rest
    agar = otsu_mask > 0
    spots = ~agar

    log.debug("Analysing each of the images:")
    for file_name in imanalyse:
        base_name = os.path.basename(file_name)

        img = _read_grayscale(file_name)
        arr = img[min_loc[1]:h_bottom, min_loc[0]:w_right]

        # Compute an automatic threshold for this image
        thresh, _ = cv2.threshold(
            np.array(arr, dtype=np.uint8), 0, 255,
            cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        # If the threshold is lower than the values of the agar, set it higher
        thresh = max(thresh, np.mean(arr[agar]) + 1)

        # Create mask to detect spots in each image
        # This will be used only to compute the area of the spots
        # (builtin bool: the np.bool alias was removed from NumPy)
        mask = np.ones(arr.shape, dtype=bool)
        mask[arr < thresh] = False

        # Measure culture phenotypes
        outputs_df = analysis.measure_outputs(arr, mask, pat_h, pat_w, nrow,
                                              ncol, file_name, spots,
                                              args.lightcorrection)
        # Save outputs; the stem is taken from the file name only, so dots
        # in directory names cannot truncate it
        outputs_df.to_csv(
            os.path.join(fdir, "Output_Data",
                         base_name.split(".")[0] + ".out"),
            sep="\t",
            index=False)

        # Save mask for a visual check
        img_outputs_dir = os.path.join(fdir, "Output_Images", base_name)
        cv2.imwrite(img_outputs_dir, mask.astype(np.uint8) * 255)

        log.debug("Analysis complete for %s", base_name)
    log.debug("All analyses finished in %.2f seconds",
              time.time() - start_time)
    log.info("No more images to analyse. I'm done")


# Script entry point: build the command-line interface, configure logging
# and hand the parsed arguments to main()
if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description=
        '''Analyse timeseries of QFA images: locate cultures on plate,
        segment image into agar and cells, apply lighting correction,
        write report including cell density estimates for each location
        in each image.''')

    # Boolean switches, listed in the order they are registered
    switches = (
        ("-c", "--lightcorrection",
         '''Flag used to enable lighting correction between images.
        Default: light correction is disabled.'''),
        ("-q", "--quiet",
         '''Flag used to suppress messages printed during the analysis.
        Default: show messages.'''),
        ("-r", "--remove",
         '''Flag used to remove any output files from the directory
        before starting the analysis. It is useful to re-analyse a set of
        images that have been ananlysed in advance.
        Default: keep previous output files.'''),
        ("-e", "--endpoint",
         '''Flag used to analyse only the final image in the series.
        It is useful to test single images.
        Default: analyse all images in the directory.'''),
    )
    for short_opt, long_opt, help_text in switches:
        cli.add_argument(
            short_opt, long_opt, help=help_text, action="store_true")

    # Options that take a value
    cli.add_argument(
        "-d", "--directory",
        default=".",
        type=str,
        help='''Directory in which to search for image files that
        have not been analysed.
        Default = current directory.''')
    cli.add_argument(
        "-o", "--gridformat",
        default=["8x12"],
        nargs="+",
        type=str,
        help='''Specify rectangular grid format. Important: specify number of
        rows and number of columns, in this order (e.g. -o 8x12 or -o 8 12).
        Default = 8x12.''')
    cli.add_argument(
        "-f", "--fraction",
        default=0.8,
        type=utils.range_float,
        help='''Minimum fraction of the image that corresponds to the grid.
        Adjust if grid occupies a small part of the total image.
        Default = 0.8.''')

    args = cli.parse_args()

    # Setup logger: --quiet raises the level from DEBUG to INFO
    log_level = logging.INFO if args.quiet else logging.DEBUG
    logging.basicConfig(format="%(message)s", level=log_level)
    logger = logging.getLogger(__name__)

    main(args)
