#!python

##############################################
#                                            #
# Executable file that counts and segments   #
# circles from blood cell images             #
#                                            #
# Author: Amine Neggazi                      #
# Email: neggazimedlamine@gmail.com          #
# Nick: nemo256                              #
#                                            #
# Please read cbc/LICENSE                    #
#                                            #
##############################################

# imports
import glob
import os
import cv2
import json
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from scipy import ndimage


# global variables
# input_shape[0] is the default chip size (input_size) used by slice() and get_sizes()
input_shape  = (188, 188, 3)
# output_shape[0] is the default step between chips (output_size) in slice()/get_sizes()
# and the tile side length used by reshape()
output_shape = (100, 100, 1)
# padding[0] is what predict()/evaluate() pass to load_data();
# padding[1] is the default padding of preprocess_data(), load_data(), slice() and get_sizes()
padding      = [200, 100]

def conv_bn(filters,
            model,
            kernel=(3, 3),
            activation='relu', 
            strides=(1, 1),
            padding='valid',
            type='normal'):
    '''
    Convolution block: convolution --> batch normalization --> activation.
    :param filters --> number of filters of the convolution
    :param model --> the input tensor the block is applied to
    :param kernel --> the kernel size (ignored for a transpose convolution)
    :param activation --> the activation applied after batch normalization
    :param strides --> the strides (ignored for a transpose convolution)
    :param padding --> convolution padding (can be valid or same)
    :param type --> 'transpose' for a transpose convolution, anything else for a normal one

    :return --> the output tensor of the block.
    '''
    layers = tf.keras.layers

    if type == 'transpose':
        # the transpose variant always uses a 2x2 kernel with a stride of 2
        convolution = layers.Conv2DTranspose(filters,
                                             kernel_size=(2, 2),
                                             strides=2,
                                             padding=padding)
    else:
        convolution = layers.Conv2D(filters,
                                    kernel_size=kernel,
                                    strides=strides,
                                    padding=padding)

    normalized = layers.BatchNormalization()(convolution(model))
    return layers.Activation(activation)(normalized)


def max_pool(input):
    '''
    Apply a 2x2 max pooling with a stride of 2.
    :param input --> the input tensor

    :return --> the pooled tensor
    '''
    pooling = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)
    return pooling(input)


def concatenate(input1, input2, crop):
    '''
    Crop the first tensor, then concatenate it with the second one.
    :param input1 --> the tensor that gets cropped (Cropping2D) before concatenation
    :param input2 --> the tensor concatenated after the cropped one
    :param crop --> the cropping argument handed to Cropping2D

    :return --> the concatenated tensor
    '''
    cropped = tf.keras.layers.Cropping2D(crop)(input1)
    return tf.keras.layers.concatenate([cropped, input2])


def do_unet():
    '''
    This is the dual output U-Net model.
    It is a custom U-Net with optimized number of layers.
    Please read model.summary()

    The model depends on the global cell_type:
        rbc        --> two outputs (mask and edge), losses weighted 0.1 / 0.9
        wbc or plt --> a single output (mask)

    :return model --> the compiled keras model (mse loss, Adam optimizer)
    :raises ValueError --> if cell_type is not rbc, wbc or plt
    '''
    if cell_type not in ('rbc', 'wbc', 'plt'):
        raise ValueError(f'Invalid blood cell type: {cell_type!r}')
    dual_output = cell_type == 'rbc'

    # NOTE: the crop sizes used in the decoder (4, 16, 40) match a 188x188 input
    inputs = tf.keras.layers.Input(input_shape)

    # encoder
    filters = 32
    encoder1 = conv_bn(3*filters, inputs)
    encoder1 = conv_bn(filters, encoder1, kernel=(1, 1))
    encoder1 = conv_bn(filters, encoder1)
    pool1 = max_pool(encoder1)

    filters *= 2
    encoder2 = conv_bn(filters, pool1)
    encoder2 = conv_bn(filters, encoder2)
    pool2 = max_pool(encoder2)

    filters *= 2
    encoder3 = conv_bn(filters, pool2)
    encoder3 = conv_bn(filters, encoder3)
    pool3 = max_pool(encoder3)

    filters *= 2
    encoder4 = conv_bn(filters, pool3)
    encoder4 = conv_bn(filters, encoder4)

    # decoder
    # integer division keeps the filter count an int (a plain / yields a float)
    filters //= 2
    decoder1 = conv_bn(filters, encoder4, type='transpose')
    decoder1 = concatenate(encoder3, decoder1, 4)
    decoder1 = conv_bn(filters, decoder1)
    decoder1 = conv_bn(filters, decoder1)

    filters //= 2
    decoder2 = conv_bn(filters, decoder1, type='transpose')
    decoder2 = concatenate(encoder2, decoder2, 16)
    decoder2 = conv_bn(filters, decoder2)
    decoder2 = conv_bn(filters, decoder2)

    filters //= 2
    decoder3 = conv_bn(filters, decoder2, type='transpose')
    decoder3 = concatenate(encoder1, decoder3, 40)
    decoder3 = conv_bn(filters, decoder3)
    decoder3 = conv_bn(filters, decoder3)

    out_mask = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='mask')(decoder3)

    if dual_output:
        out_edge = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='edge')(decoder3)
        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask, out_edge))
    else:
        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask))

    opt = tf.optimizers.Adam(learning_rate=0.0001)

    if dual_output:
        model.compile(loss='mse',
                      loss_weights=[0.1, 0.9],
                      optimizer=opt,
                      metrics=['accuracy'])
    else:
        model.compile(loss='mse',
                      optimizer=opt,
                      metrics=['accuracy'])
    return model


def load_image_list(img_files):
    '''
    Read every file of the given iterable with cv2.imread.
    :param img_files --> the image files which we want to read

    :return imgs --> the list of images, in the same order as img_files
    '''
    return [cv2.imread(image_file) for image_file in img_files]


def clahe_images(img_list):
    '''
    This is the clahe images function, which applies CLAHE (contrast limited
    adaptive histogram equalization) to the first channel of the LAB version
    of every image. The list is modified in place.
    :param img_list --> the list of BGR images to equalize

    :return img_list --> the same list, holding the equalized BGR images
    :raises ValueError --> if an entry of the list is None (e.g. an unreadable file)
    '''
    # the settings are the same for every image, so one CLAHE object is enough
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    for i, img in enumerate(img_list):
        if img is None:
            raise ValueError(f'Image at index {i} is None (could not be read)')

        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        lab[..., 0] = clahe.apply(lab[..., 0])
        img_list[i] = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
    return img_list


def preprocess_data(imgs, padding=padding[1]):
    '''
    Pad every image with a constant border on its first two axes;
    the third axis is left untouched.
    :param imgs --> the input list of images.
    :param padding --> the number of pixels added on each side.

    :return imgs --> a new list with the padded images.
    '''
    pad_width = ((padding, padding), (padding, padding), (0, 0))

    padded = []
    for img in imgs:
        padded.append(np.pad(img, pad_width, mode='constant'))
    return padded


def load_data(img_list, padding=padding[1]):
    '''
    Load the images, equalize them with clahe_images and pad them with preprocess_data.
    :param img_list --> list of input image files
    :param padding --> padding to be applied on preprocessing

    :return imgs --> the output preprocessed imgs.
    '''
    equalized = clahe_images(load_image_list(img_list))
    return preprocess_data(equalized, padding=padding)


def slice(imgs,
          padding=padding[1],
          input_size=input_shape[0],
          output_size=output_shape[0]):
    '''
    Cut every image into square chips of input_size pixels whose centers
    are output_size pixels apart, and rescale the chip values from
    0..255 to -1..1.
    :param imgs --> the input images
    :param padding --> offset of the first chip center (plus half of output_size)
    :param input_size --> the side length of each chip
    :param output_size --> the distance between two chip centers

    :return np.array --> all the image chips as one float32 array
    '''
    half_input = input_size / 2
    center_offset = padding + (output_size / 2)

    img_chips = []
    for image in imgs:
        for x in np.arange(center_offset, image.shape[0] - half_input, output_size):
            top, bottom = int(x - half_input), int(x + half_input)
            for y in np.arange(center_offset, image.shape[1] - half_input, output_size):
                left, right = int(y - half_input), int(y + half_input)

                # 0..255 --> -1..1
                chip = image[top:bottom, left:right].astype(np.float32) * 2
                chip /= 255
                chip -= 1

                img_chips.append(chip)
    return np.array(img_chips)


def normalize(img):
    '''
    Normalizes an image to the range 0..1 (min-max normalization)
    :param img --> an input image that we want normalized

    :return np.array --> an output image normalized (as a numpy array);
                         a constant image gives an array of zeros
    '''
    values = np.asarray(img)
    lowest = np.min(values)
    value_range = np.max(values) - lowest
    if value_range == 0:
        # constant image: avoid the 0 / 0 division that would give NaN
        value_range = 1
    return np.array((values - lowest) / value_range)


def get_sizes(img,
              padding=padding[1],
              input=input_shape[0],
              output=output_shape[0]):
    '''
    Count the chips per axis of the first image, to rebuild the full image output
    :param img --> the list of images (only the first one, img[0], is measured)
    :param padding --> the default padding used on the test dataset
    :param input --> the chip input size
    :param output --> the distance between two chip centers

    :return list --> a one element list holding the couple (chips along x, chips along y)
    '''
    offset = padding + (output / 2)
    first = img[0]

    chips_x = len(np.arange(offset, first.shape[0] - input / 2, output))
    chips_y = len(np.arange(offset, first.shape[1] - input / 2, output))
    return [(chips_x, chips_y)]


def reshape(img,
            size_x,
            size_y):
    '''
    Arrange the predicted chips on a (size_x, size_y) grid of tiles
    :param img --> the array of chips we want to reshape
    :param size_x --> number of tiles along the x axis
    :param size_y --> number of tiles along the y axis

    :return img (numpy array) --> the chips reshaped to (size_x, size_y, tile, tile, 1), tile being output_shape[0]
    '''
    tile = output_shape[0]
    return img.reshape((size_x, size_y, tile, tile, 1))


def concat(imgs):
    '''
    Rebuild the full image from a grid of chips: the chips of each grid row
    are joined horizontally, then the rows are stacked vertically
    :param imgs --> the 4 dimensional array of chips that we want to concatenate

    :return full_image --> the concatenation of all the provided chips (param: imgs)
    '''
    rows = []
    for im_list in imgs[:, :, :, :]:
        rows.append(cv2.hconcat(im_list))
    return cv2.vconcat(rows)


def denoise(img):
    '''
    Remove noise from an image (non-local means denoising)
    :param img --> the path of the image that we want to denoise (remove the noise)

    :return image --> the denoised output image
    '''
    # read the image
    img = cv2.imread(img)
    # the second positional argument of fastNlMeansDenoising is the output
    # array (dst), so the settings are passed by keyword
    # NOTE(review): h=23 assumed to be the intended filter strength — confirm
    return cv2.fastNlMeansDenoising(img,
                                    None,
                                    h=23,
                                    templateWindowSize=7,
                                    searchWindowSize=21)


def predict(img='Im037_0'):
    '''
    Predict (segment) blood cell images using the trained model (do_unet)
    :param img --> the image we want to predict (from the test/ directory)

    :return --> saves the predicted mask (and edge for rbc) under output_directory
                and a summary figure as sample.png
    '''
    # Check for existing predictions
    if os.path.exists(f'{output_directory}/mask.png'):
        print('Prediction already exists!')
        return

    if cell_type not in ('rbc', 'wbc', 'plt'):
        print('Invalid blood cell type!\n')
        return
    is_rbc = cell_type == 'rbc'

    test_img = sorted(glob.glob(f'data/{cell_type}/test/image/{img}.jpg'))
    test_mask = sorted(glob.glob(f'data/{cell_type}/test/mask/{img}.jpg'))
    test_edge = sorted(glob.glob(f'data/{cell_type}/test/edge/{img}.jpg')) if is_rbc else []

    if not test_img:
        print('Image does not exist!')
        return

    # initialize do_unet
    model = do_unet()

    # Check for existing weights
    weights = f'models/{model_name}.h5'
    if os.path.exists(weights):
        print('Test succeeded\n')
        model.load_weights(weights)
    else:
        print(f'Warning: {weights} not found, predicting with untrained weights')

    # load test data: load_data and slice only handle the input images
    imgs = load_data(test_img, padding=padding[0])
    img_chips = slice(
        imgs,
        padding=padding[1],
        input_size=input_shape[0],
        output_size=output_shape[0],
    )

    # the ground truth is only displayed next to the prediction,
    # it is padded like the input image (no ground truth --> empty list)
    masks = preprocess_data(load_image_list(test_mask), padding=padding[0])
    edges = preprocess_data(load_image_list(test_edge), padding=padding[0])

    # segment all image chips
    output = model.predict(img_chips)
    if is_rbc:
        new_mask_chips = np.array(output[0])
        new_edge_chips = np.array(output[1])
    else:
        new_mask_chips = np.array(output)

    # get image dimensions (number of chips per axis)
    size_x, size_y = get_sizes(imgs)[0]

    # reshape the chips to a grid, drop the channel axis and concatenate them
    # into a single image; only the last axis is squeezed so that a grid
    # with a single row or column keeps its 4 dimensions
    new_mask = concat(np.squeeze(reshape(new_mask_chips, size_x, size_y), axis=-1))
    if is_rbc:
        new_edge = concat(np.squeeze(reshape(new_edge_chips, size_x, size_y), axis=-1))

    # create output directories if they do not exist
    os.makedirs('output/', exist_ok=True)
    os.makedirs(output_directory, exist_ok=True)

    # save predicted mask and edge
    plt.imsave(f'{output_directory}/mask.png', new_mask)
    if is_rbc:
        plt.imsave(f'{output_directory}/edge.png', new_edge)
        plt.imsave(f'{output_directory}/edge_mask.png', new_mask - new_edge)

    # organize results into one figure: (subplot position, title, picture)
    columns = 3 if is_rbc else 2
    panels = [(1, 'Test image', imgs[0])]
    if masks:
        panels.append((2, 'Test mask', masks[0]))
    if is_rbc:
        if edges:
            panels.append((3, 'Test edge', edges[0]))
        panels.append((5, 'Predicted mask', new_mask))
        panels.append((6, 'Predicted edge', new_edge))
    else:
        panels.append((4, 'Predicted mask', new_mask))

    fig = plt.figure(figsize=(25, 12), dpi=80)
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for position, title, picture in panels:
        ax = fig.add_subplot(2, columns, position)
        ax.set_title(title)
        ax.imshow(picture)

    # save the figure as a sample output
    plt.savefig('sample.png')


def _target_chips(target_files):
    '''
    Private helper of evaluate: cut ground truth images into output sized chips.
    Every file is read as grayscale, padded by padding[0] and cropped around
    the same grid of chip centers that slice() uses for the input chips.
    :param target_files --> the ground truth (mask or edge) image files

    :return np.array --> the float32 chips, with a trailing channel axis
    :raises ValueError --> if a file cannot be read
    '''
    half_input = input_shape[0] / 2
    half_output = output_shape[0] / 2
    center_offset = padding[1] + half_output

    chips = []
    for target_file in target_files:
        target = cv2.imread(target_file, cv2.IMREAD_GRAYSCALE)
        if target is None:
            raise ValueError(f'Could not read image: {target_file}')
        target = np.pad(target, padding[0], mode='constant')

        for x in np.arange(center_offset, target.shape[0] - half_input, output_shape[0]):
            for y in np.arange(center_offset, target.shape[1] - half_input, output_shape[0]):
                chip = target[int(x - half_output):int(x + half_output),
                              int(y - half_output):int(y + half_output)]
                # NOTE(review): targets are scaled to 0..1 to match the sigmoid
                # outputs — confirm this is how the model was trained
                chips.append(chip.astype(np.float32)[..., np.newaxis] / 255)
    return np.array(chips)


def evaluate(model_name='mse'):
    '''
    Evaluate an already trained model on the whole test dataset
    :param model_name --> the model weights that we want to evaluate (models/<model_name>.h5)

    :return --> output the evaluation results directly to the screen.
    '''
    if cell_type not in ('rbc', 'wbc', 'plt'):
        print('Invalid blood cell type!\n')
        return

    test_img_list = sorted(glob.glob(f'data/{cell_type}/test/image/*.jpg'))
    test_mask_list = sorted(glob.glob(f'data/{cell_type}/test/mask/*.jpg'))
    if not test_img_list:
        print('Image does not exist!')
        return

    # initialize do_unet
    model = do_unet()

    # load weights (there is no train function in this file to fall back on)
    weights = f'models/{model_name}.h5'
    if not os.path.exists(weights):
        print(f'Model weights not found: {weights}')
        return
    model.load_weights(weights)

    # load test data: load_data and slice only handle the input images
    img = load_data(test_img_list, padding=padding[0])
    img_chips = slice(
        img,
        padding=padding[1],
        input_size=input_shape[0],
        output_size=output_shape[0]
    )
    mask_chips = _target_chips(test_mask_list)

    # print the evaluated accuracies
    if cell_type == 'rbc':
        test_edge_list = sorted(glob.glob(f'data/{cell_type}/test/edge/*.jpg'))
        edge_chips = _target_chips(test_edge_list)
        print(model.evaluate(img_chips, (mask_chips, edge_chips)))
    else:
        print(model.evaluate(img_chips, mask_chips))

def threshold(img='edge.png'):
    '''
    This is the threshold function, which applies an otsu threshold
    to the input image (param: img).
    For 'edge_mask.png' nothing is thresholded: the already thresholded
    edge is subtracted from the already thresholded mask instead, so
    threshold('mask.png') and threshold('edge.png') have to run first.
    :param img --> the image we want to threshold (file name inside output_directory)

    :return --> saves the output image as <output_directory>/threshold_<img>
    '''
    if not os.path.exists(output_directory + '/' + img):
        print('Image does not exist!')
        return

    # subtract if img is edge_mask
    if img == 'edge_mask.png':
        mask = cv2.imread(f'{output_directory}/threshold_mask.png')
        edge = cv2.imread(f'{output_directory}/threshold_edge.png')
        if mask is None or edge is None:
            print('Image does not exist!')
            return

        # saturating subtraction: a plain uint8 "mask - edge" wraps around
        # (0 - 255 gives 1), which turns background pixels into foreground
        image = cv2.subtract(mask, edge)
    else:
        # getting the input image
        image = cv2.imread(f'{output_directory}/{img}')

        # convert to grayscale and apply otsu's thresholding
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, image = cv2.threshold(image, 0, 255, cv2.THRESH_OTSU)

    # save the resulting thresholded image
    plt.imsave(f'{output_directory}/threshold_{img}', image, cmap='gray')
    

def hough_transform(img='edge.png'):
    '''
    This is the Circle Hough Transform function (CHT), which counts the
    circles from an input image.
    :param img --> the input image that we want to count circles from (file name inside output_directory).

    :return --> prints the number of circles found (0 if none) and, when at least
                one circle is found, saves <output_directory>/hough_transform.png
    '''
    if not os.path.exists(output_directory + '/' + img):
        print('Image does not exist!')
        return

    # getting the input image
    image = cv2.imread(f'{output_directory}/{img}')
    # convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # apply hough circles
    if cell_type == 'rbc':
        circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, 1, minDist=33, maxRadius=55, minRadius=28, param1=30, param2=20)
    elif cell_type == 'wbc' or cell_type == 'plt':
        circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, 1, minDist=51, maxRadius=120, minRadius=48, param1=70, param2=20)
    else:
        print('Invalid blood cell type!\n')
        return
    output = gray.copy()

    # HoughCircles returns None when nothing is found
    count = 0
    if circles is not None:
        # convert the (x, y) coordinates and radius of the circles to integers
        circles = np.round(circles[0, :]).astype("int")
        count = len(circles)
        # loop over the (x, y) coordinates and radius of the circles
        for (x, y, r) in circles:
            # plain python ints for the drawing functions
            x, y, r = int(x), int(y), int(r)
            # draw the circle in the output image, then draw a rectangle
            # corresponding to the center of the circle
            cv2.circle(output, (x, y), r, (0, 0, 255), 2)
            cv2.rectangle(output, (x - 5, y - 5), (x + 5, y + 5), (0, 0, 255), -1)
        # save the output image
        plt.imsave(f'{output_directory}/hough_transform.png',
                   np.hstack([gray, output]))

    # show the hough_transform results
    print(f'Hough transform: {count}')


def component_labeling(img='edge.png'):
    '''
    This is the Connected Component Labeling (CCL), which labels all the connected objects from an input image
    :param img --> the input image that we want to apply CCL to (file name inside output_directory).

    :return --> prints the number of labels returned by cv2.connectedComponents and
                saves the output image as <output_directory>/component_labeling.png
    '''
    if not os.path.exists(output_directory + '/' + img):
        print('Image does not exist!')
        return

    # getting the input image
    image = cv2.imread(f'{output_directory}/{img}')
    # convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # converting pixels with values 0-127 to 0 and the others to 255
    binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]
    # applying cv2.connectedComponents()
    # NOTE(review): num_labels presumably counts the background label too — confirm
    num_labels, labels = cv2.connectedComponents(binary)

    # map component labels to hue val, 0-179 is the hue range in OpenCV
    highest_label = np.max(labels)
    if highest_label > 0:
        label_hue = np.uint8(179*labels/highest_label)
    else:
        # only background was found: avoid dividing by zero
        label_hue = np.zeros(labels.shape, dtype=np.uint8)
    blank_ch = 255*np.ones_like(label_hue)
    output = cv2.merge([label_hue, blank_ch, blank_ch])

    # converting HSV to BGR
    output = cv2.cvtColor(output, cv2.COLOR_HSV2BGR)

    # set bg label to black
    output[label_hue==0] = 0

    # saving image after Component Labeling
    plt.imsave(f'{output_directory}/component_labeling.png',
               np.hstack([image, output]))

    # show number of labels detected
    print(f'Connected component labeling: {num_labels}')


def distance_transform(img='threshold_edge_mask.png'):
    '''
    This is the Euclidean Distance Transform function (EDT): the distance
    transform of the grayscale input image is computed, then a binary
    dilation is applied to the result.
    :param img --> the input image that we want to apply EDT to (file name inside output_directory).

    :return --> saves the output image as <output_directory>/distance_transform.png
    '''
    source = output_directory + '/' + img
    if not os.path.exists(source):
        print('Image does not exist!')
        return

    # read the input image and convert it to grayscale
    picture = np.asarray(cv2.imread(source))
    gray = cv2.cvtColor(picture, cv2.COLOR_BGR2GRAY)

    # distance transform followed by a binary dilation
    distances = ndimage.distance_transform_edt(gray)
    dilated = ndimage.binary_dilation(distances)

    # saving image after the distance transform
    plt.imsave(f'{output_directory}/distance_transform.png', dilated, cmap='gray')


if __name__ == '__main__':
    # The main entry point, which handles all the function calls
    # (later on, this will dynamically call functions according to user input)
    #
    # Cell type can be either rbc, wbc or plt, which stands for:
    #     rbc --> Red blood cells
    #     wbc --> White blood cells
    #     plt --> Platelets
    #
    # cell_type, model_name and output_directory are module globals
    # read by the functions above.
    cell_type        = 'rbc'             # rbc, wbc or plt
    model_name       = cell_type
    if cell_type not in ('rbc', 'wbc', 'plt'):
        # stop here: every step below needs a valid output_directory
        raise SystemExit('Invalid blood cell type!\n')
    output_directory = f'output/{cell_type}'

    # train('rbc')
    # evaluate(model_name='rbc')
    predict()
    threshold('mask.png')

    if cell_type == 'rbc':
        # edge_mask needs the thresholded mask and edge, so it comes last
        threshold('edge.png')
        threshold('edge_mask.png')
        distance_transform('threshold_edge_mask.png')
        hough_transform('edge.png')
    else:
        distance_transform('threshold_mask.png')
        hough_transform('mask.png')

    component_labeling('distance_transform.png')

