#!python

##############################################
#                                            #
# Executable file that counts and segments   #
# circles from blood cell images             #
#                                            #
# Author: Amine Neggazi                      #
# Email: neggazimedlamine@gmail/com          #
# Nick: nemo256                              #
#                                            #
# Please read cbc/LICENSE                    #
#                                            #
##############################################

# imports
import glob
import os
import cv2
import json
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from scipy import ndimage


# global variables
# shape of one image chip fed to the network; input_shape[0] is the default
# chip side length (input_size) of the chip generators below
input_shape  = (188, 188, 3)
# shape of one mask/edge chip; output_shape[0] is the default output_size
# and the stride used when tiling an image into chips
output_shape = (100, 100, 1)
# padding sizes in pixels: padding[1] is the default of every `padding`
# parameter below, padding[0] is passed explicitly by generator(),
# predict() and evaluate()
padding      = [200, 100]

def conv_bn(filters,
            model,
            kernel=(3, 3),
            activation='relu',
            strides=(1, 1),
            padding='valid',
            type='normal'):
    '''
    Apply one convolution followed by batch normalization and an activation.
    :param filters --> number of filters of the convolution
    :param model --> the input tensor the layers are applied to
    :param kernel --> the kernel size (transpose convolutions always use (2, 2))
    :param activation --> the activation applied after batch normalization
    :param strides --> the strides (transpose convolutions always use 2)
    :param padding --> model padding (can be valid or same)
    :param type --> 'transpose' selects Conv2DTranspose, anything else Conv2D

    :return --> the output tensor of the activation layer.
    '''
    layers = tf.keras.layers

    if type == 'transpose':
        # kernel and strides are fixed for the transpose convolution,
        # whatever the caller passed
        convolution = layers.Conv2DTranspose(filters,
                                             kernel_size=(2, 2),
                                             strides=2,
                                             padding=padding)
    else:
        convolution = layers.Conv2D(filters,
                                    kernel_size=kernel,
                                    strides=strides,
                                    padding=padding)

    output = convolution(model)
    output = layers.BatchNormalization()(output)
    return layers.Activation(activation)(output)


def max_pool(input):
    '''
    Apply a 2x2 max pooling layer with a stride of 2 to the input tensor.
    :param input --> the tensor to pool

    :return --> the pooled tensor
    '''
    pooling = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)
    return pooling(input)


def concatenate(input1, input2, crop):
    '''
    Crop the first tensor and concatenate it with the second one.
    :param input1 --> the tensor that gets cropped before concatenation
    :param input2 --> the tensor that is concatenated as is
    :param crop --> the cropping argument handed to Cropping2D

    :return --> the concatenation of the cropped input1 and input2
    '''
    cropped = tf.keras.layers.Cropping2D(crop)(input1)
    return tf.keras.layers.concatenate([cropped, input2])


def get_callbacks(name):
    '''
    Build the callback list used while fitting the model.
    It holds a single checkpoint callback which only keeps the best
    weights, written to models/<name>.h5
    :param name --> the input model name

    :return list --> the list of keras callbacks
    '''
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=f'models/{name}.h5',
        save_best_only=True,
        save_weights_only=True,
        verbose=1,
    )
    return [checkpoint]


def do_unet():
    '''
    This is the dual output U-Net model.
    It is a custom U-Net with optimized number of layers: three encoder
    stages with max pooling, a bottleneck and three decoder stages built
    from transpose convolutions and cropped skip connections.
    The module level cell_type selects the outputs: 'rbc' gets two sigmoid
    outputs ('mask' and 'edge'), 'wbc' and 'plt' only get 'mask'.
    Please read model.summary()

    :return model --> the compiled keras model
    :raises ValueError --> if cell_type is not one of rbc, wbc or plt
    '''
    if cell_type not in ('rbc', 'wbc', 'plt'):
        raise ValueError(f'Invalid blood cell type: {cell_type!r}')

    # NOTE: the crop sizes of the skip connections below (4, 16, 40) are
    # only valid for a 188x188 input
    inputs = tf.keras.layers.Input(input_shape)

    # encoder
    filters = 32
    encoder1 = conv_bn(3*filters, inputs)
    encoder1 = conv_bn(filters, encoder1, kernel=(1, 1))
    encoder1 = conv_bn(filters, encoder1)
    pool1 = max_pool(encoder1)

    filters *= 2
    encoder2 = conv_bn(filters, pool1)
    encoder2 = conv_bn(filters, encoder2)
    pool2 = max_pool(encoder2)

    filters *= 2
    encoder3 = conv_bn(filters, pool2)
    encoder3 = conv_bn(filters, encoder3)
    pool3 = max_pool(encoder3)

    filters *= 2
    encoder4 = conv_bn(filters, pool3)
    encoder4 = conv_bn(filters, encoder4)

    # decoder (integer division keeps the filter count an int)
    filters //= 2
    decoder1 = conv_bn(filters, encoder4, type='transpose')
    decoder1 = concatenate(encoder3, decoder1, 4)
    decoder1 = conv_bn(filters, decoder1)
    decoder1 = conv_bn(filters, decoder1)

    filters //= 2
    decoder2 = conv_bn(filters, decoder1, type='transpose')
    decoder2 = concatenate(encoder2, decoder2, 16)
    decoder2 = conv_bn(filters, decoder2)
    decoder2 = conv_bn(filters, decoder2)

    filters //= 2
    decoder3 = conv_bn(filters, decoder2, type='transpose')
    decoder3 = concatenate(encoder1, decoder3, 40)
    decoder3 = conv_bn(filters, decoder3)
    decoder3 = conv_bn(filters, decoder3)

    out_mask = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='mask')(decoder3)

    opt = tf.optimizers.Adam(learning_rate=0.0001)

    if cell_type == 'rbc':
        out_edge = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid', name='edge')(decoder3)
        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask, out_edge))
        # the edge output weighs nine times more than the mask output
        model.compile(loss='mse',
                      loss_weights=[0.1, 0.9],
                      optimizer=opt,
                      metrics='accuracy')
    else:
        model = tf.keras.models.Model(inputs=inputs, outputs=(out_mask))
        model.compile(loss='mse',
                      optimizer=opt,
                      metrics='accuracy')

    return model


def load_image_list(img_files, gray=False):
    '''
    Read every file of img_files with OpenCV.
    :param img_files --> the input image files which we want to read
    :param gray --> when set, each file is read as grayscale and binarized
                    with a fixed threshold of 127 (used for masks and edges)

    :return imgs --> the list of images that we read, in input order
    '''
    if not gray:
        return [cv2.imread(path) for path in img_files]

    binary_imgs = []
    for path in img_files:
        grayscale = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.threshold returns (used threshold, image); only keep the image
        _, binary = cv2.threshold(grayscale, 127, 255, cv2.THRESH_BINARY)
        binary_imgs.append(binary)
    return binary_imgs


def clahe_images(img_list):
    '''
    This is the clahe images function, which applies CLAHE (contrast
    limited adaptive histogram equalization) to the lightness channel
    of every image of the input list.
    The list is modified in place and returned.
    :param img_list --> the input list of BGR images

    :return img_list --> the same list, holding the equalized images
    '''
    # the CLAHE operator does not depend on the image, build it only once
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    for i, img in enumerate(img_list):
        # equalize the L channel only so that the colors are kept
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        lab[..., 0] = clahe.apply(lab[..., 0])
        img_list[i] = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
    return img_list


def preprocess_data(imgs, mask, edge=None, padding=padding[1]):
    '''
    Zero pad the input images, masks and edges (if there are any)
    on their first two axes.
    :param imgs --> the input list of images (three axes each).
    :param mask --> the input list of masks (two axes each).
    :param edge --> the input list of edges (two axes each), or None.
    :param padding --> the number of pixels added on each side.

    :return tuple(imgs, mask, edge if exists) --> output images, masks and edges with padding added.
    '''
    flat_pad = ((padding, padding), (padding, padding))
    # the channel axis of the images is left untouched
    color_pad = flat_pad + ((0, 0),)

    padded_imgs = [np.pad(picture, color_pad, mode='constant') for picture in imgs]
    padded_mask = [np.pad(single_mask, flat_pad, mode='constant') for single_mask in mask]

    if edge is None:
        return padded_imgs, padded_mask

    padded_edge = [np.pad(single_edge, flat_pad, mode='constant') for single_edge in edge]
    return padded_imgs, padded_mask, padded_edge


def load_data(img_list, mask_list, edge_list=None, padding=padding[1]):
    '''
    Read the images, masks and edges from disk and preprocess them:
    images go through clahe_images, masks and edges are read as binary
    grayscale images, and everything is padded by preprocess_data.
    :param img_list --> list of input image files
    :param mask_list --> list of input mask files
    :param edge_list --> list of input edge files (None or empty means no edges)
    :param padding --> padding to be applied on preprocessing

    :return tuple(imgs, masks and edges if exists) --> the output preprocessed imgs, masks and edges.
    '''
    images = clahe_images(load_image_list(img_list))
    masks = load_image_list(mask_list, gray=True)

    # an empty edge list is treated like a missing one
    edges = load_image_list(edge_list, gray=True) if edge_list else None

    return preprocess_data(images, masks, edges, padding=padding)


def aug_lum(image, factor=None):
    '''
    Scale the luminosity (the V channel of the HSV representation)
    of a BGR image.
    :param image --> the input image we want to augment
    :param factor --> the luminosity multiplier (default is 0.5 + a random number from [0, 1))

    :return image --> the output luminosity augmented image
    '''
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.float64)

    lum_offset = factor if factor is not None else 0.5 + np.random.uniform()

    # scale the value channel and cap it at 255 before going back to uint8
    hsv[..., 2] = np.minimum(hsv[..., 2] * lum_offset, 255)

    return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)


def aug_img(image):
    '''
    Randomly scale the hue, saturation and luminosity of a BGR image.
    The multipliers are drawn from [0.8, 1.2) for the hue and from
    [0.5, 1.5) for the saturation and the luminosity.
    :param image --> the input image we want to augment

    :return image --> the output colors augmented image
    '''
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.float64)

    # one random multiplier per channel, drawn in the order H, S, V
    hue_factor = 0.8 + 0.4*np.random.uniform()
    sat_factor = 0.5 + np.random.uniform()
    lum_factor = 0.5 + np.random.uniform()

    for channel, factor in enumerate((hue_factor, sat_factor, lum_factor)):
        # scale the channel and cap it at 255 before going back to uint8
        hsv[..., channel] = np.minimum(hsv[..., channel] * factor, 255)

    return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)


def train_generator(imgs, mask, edge=None,
                    scale_range=None,
                    padding=padding[1],
                    input_size=input_shape[0],
                    output_size=output_shape[0],
                    skip_empty=False):
    '''
    This is the train generator function, which endlessly generates
    randomly located and randomly augmented training chips.
    :param imgs --> the input images
    :param mask --> the input masks
    :param edge --> the input edges if there are any (red blood cells only)
    :param scale_range --> the relative scale jitter s: each chip is cut with a
                           random scale from [1 - s, 1 + s) and resized back
                           (no rescaling if None)
    :param padding --> the padding which was applied to each image
    :param input_size --> the side length of the yielded image chips
    :param output_size --> the side length of the yielded mask and edge chips
    :param skip_empty --> filter chips on their mask content against a random
                          flag drawn for each yielded chip (no filtering if not set)

    :return chips --> yields (image chip, mask chip) or
                      (image chip, (mask chip, edge chip)) each time it gets executed (called)
    '''
    if scale_range is not None:
        scale_range = [1 - scale_range, 1 + scale_range]

    # minimal distance between a chip center and the image border;
    # kept an integer because it is used as a bound for randint
    center_offset = padding + (output_size // 2)

    while True:
        # select which type of cell to return
        chip_type = np.random.choice([True, False])

        while True:
            # pick random image
            i = np.random.randint(len(imgs))

            # pick random central location in the image
            x = np.random.randint(center_offset, imgs[i].shape[0] - center_offset)
            y = np.random.randint(center_offset, imgs[i].shape[1] - center_offset)

            # scale the box randomly within [scale_range[0], scale_range[1])
            scale = 1
            if scale_range is not None:
                scale = scale_range[0] + ((scale_range[1] - scale_range[0]) * np.random.random())

            # find the edges of a box around the image chip and the mask chip
            chip_x_l = int(x - ((input_size / 2) * scale))
            chip_x_r = int(x + ((input_size / 2) * scale))
            chip_y_l = int(y - ((input_size / 2) * scale))
            chip_y_r = int(y + ((input_size / 2) * scale))

            mask_x_l = int(x - ((output_size / 2) * scale))
            mask_x_r = int(x + ((output_size / 2) * scale))
            mask_y_l = int(y - ((output_size / 2) * scale))
            mask_y_r = int(y + ((output_size / 2) * scale))

            # take a slice of the image and mask accordingly
            temp_chip = imgs[i][chip_x_l:chip_x_r, chip_y_l:chip_y_r]
            temp_mask = mask[i][mask_x_l:mask_x_r, mask_y_l:mask_y_r]
            if edge is not None:
                temp_edge = edge[i][mask_x_l:mask_x_r, mask_y_l:mask_y_r]

            # pick another location when the chip content (more than 5 mask
            # pixels or not) equals the random flag
            if skip_empty and ((temp_mask > 0).sum() > 5) == chip_type:
                continue

            # resize the image chip back to input_size and the mask chip to output_size
            temp_chip = cv2.resize(temp_chip,
                                   (input_size, input_size),
                                   interpolation=cv2.INTER_CUBIC)
            temp_mask = cv2.resize(temp_mask,
                                   (output_size, output_size),
                                   interpolation=cv2.INTER_NEAREST)
            if edge is not None:
                temp_edge = cv2.resize(temp_edge,
                                       (output_size, output_size),
                                       interpolation=cv2.INTER_NEAREST)

            # randomly rotate by a multiple of 90 degrees
            rot = np.random.randint(4)
            temp_chip = np.rot90(temp_chip, k=rot, axes=(0, 1))
            temp_mask = np.rot90(temp_mask, k=rot, axes=(0, 1))
            if edge is not None:
                temp_edge = np.rot90(temp_edge, k=rot, axes=(0, 1))

            # randomly flip
            if np.random.random() > 0.5:
                temp_chip = np.flip(temp_chip, axis=1)
                temp_mask = np.flip(temp_mask, axis=1)
                if edge is not None:
                    temp_edge = np.flip(temp_edge, axis=1)

            # randomly luminosity augment
            temp_chip = aug_lum(temp_chip)

            # randomly augment chip
            temp_chip = aug_img(temp_chip)

            # rescale the image from [0, 255] to [-1, 1]
            temp_chip = temp_chip.astype(np.float32) * 2
            temp_chip /= 255
            temp_chip -= 1

            # masks and edges are binarized and get a trailing channel axis
            if edge is not None:
                yield temp_chip, ((temp_mask > 0).astype(float)[..., np.newaxis],
                                  (temp_edge > 0).astype(float)[..., np.newaxis])
            else:
                yield temp_chip, ((temp_mask > 0).astype(float)[..., np.newaxis])
            # leave the retry loop and draw a new chip_type
            break


def test_chips(imgs, mask,
               edge=None,
               padding=padding[1],
               input_size=input_shape[0],
               output_size=output_shape[0]):
    '''
    This is the test chips function, which generates the test dataset by
    walking over every image on a regular grid with a step of output_size.
    :param imgs --> the input images
    :param mask --> the input masks
    :param edge --> the input edges if there are any (red blood cells only)
    :param padding --> the padding which was applied to each image
    :param input_size --> the side length of the yielded image chips
    :param output_size --> the side length of the yielded mask and edge chips

    :return chips --> yields (image chip, mask chip) or
                      (image chip, (mask chip, edge chip)) for every grid position
    '''
    center_offset = padding + (output_size / 2)
    for i, _ in enumerate(imgs):
        for x in np.arange(center_offset, imgs[i].shape[0] - input_size / 2, output_size):
            for y in np.arange(center_offset, imgs[i].shape[1] - input_size / 2, output_size):
                chip_x_l = int(x - (input_size / 2))
                chip_x_r = int(x + (input_size / 2))
                chip_y_l = int(y - (input_size / 2))
                chip_y_r = int(y + (input_size / 2))

                mask_x_l = int(x - (output_size / 2))
                mask_x_r = int(x + (output_size / 2))
                mask_y_l = int(y - (output_size / 2))
                mask_y_r = int(y + (output_size / 2))

                temp_chip = imgs[i][chip_x_l:chip_x_r, chip_y_l:chip_y_r]
                temp_mask = mask[i][mask_x_l:mask_x_r, mask_y_l:mask_y_r]
                if edge is not None:
                    temp_edge = edge[i][mask_x_l:mask_x_r, mask_y_l:mask_y_r]

                # rescale the image from [0, 255] to [-1, 1]
                temp_chip = temp_chip.astype(np.float32) * 2
                temp_chip /= 255
                temp_chip -= 1

                # no break here: every column of the grid has to be yielded,
                # not only the first chip of each row
                if edge is not None:
                    yield temp_chip, ((temp_mask > 0).astype(float)[..., np.newaxis],
                                      (temp_edge > 0).astype(float)[..., np.newaxis])
                else:
                    yield temp_chip, ((temp_mask > 0).astype(float)[..., np.newaxis])


def slice(imgs, mask,
          edge=None,
          padding=padding[1],
          input_size=input_shape[0],
          output_size=output_shape[0]):
    '''
    Cut every image into chips on a regular grid with a step of output_size
    and collect them, together with the matching mask (and edge) chips.
    Image chips are rescaled from [0, 255] to [-1, 1]; mask and edge chips
    are binarized and get a trailing channel axis.
    :param imgs --> the input images
    :param mask --> the input masks
    :param edge --> the input edges if there are any (red blood cells only)
    :param padding --> the padding which was applied to each image
    :param input_size --> the side length of the image chips
    :param output_size --> the side length of the mask and edge chips

    :return tuple of arrays --> (image chips, mask chips) or
                                (image chips, mask chips, edge chips)
    '''
    with_edge = edge is not None
    half_in = input_size / 2
    half_out = output_size / 2
    first_center = padding + half_out

    collected_imgs = []
    collected_masks = []
    collected_edges = []

    for index, image in enumerate(imgs):
        centers_x = np.arange(first_center, image.shape[0] - half_in, output_size)
        centers_y = np.arange(first_center, image.shape[1] - half_in, output_size)

        for x in centers_x:
            # row bounds of the image chip and of the mask chip
            in_rows = (int(x - half_in), int(x + half_in))
            out_rows = (int(x - half_out), int(x + half_out))

            for y in centers_y:
                in_cols = (int(y - half_in), int(y + half_in))
                out_cols = (int(y - half_out), int(y + half_out))

                image_chip = image[in_rows[0]:in_rows[1], in_cols[0]:in_cols[1]]
                image_chip = image_chip.astype(np.float32) * 2 / 255 - 1
                collected_imgs.append(image_chip)

                mask_chip = mask[index][out_rows[0]:out_rows[1], out_cols[0]:out_cols[1]]
                collected_masks.append((mask_chip > 0).astype(float)[..., np.newaxis])

                if with_edge:
                    edge_chip = edge[index][out_rows[0]:out_rows[1], out_cols[0]:out_cols[1]]
                    collected_edges.append((edge_chip > 0).astype(float)[..., np.newaxis])

    if with_edge:
        return np.array(collected_imgs), np.array(collected_masks), np.array(collected_edges)

    return np.array(collected_imgs), np.array(collected_masks)


def generator(img_list, mask_list, edge_list=None, type='train'):
    '''
    Load the given files and wrap the matching chip generator
    (train_generator or test_chips) into a tensorflow dataset.
    The module level cell_type decides whether edges are loaded and
    whether the dataset has one (mask) or two (mask, edge) targets.
    :param img_list --> the input list of images
    :param mask_list --> the input list of masks
    :param edge_list --> the input list of edges if there are any
    :param type --> can be either train or test, used to determine which generator function is to be called

    :return tensorflow dataset --> the output generated functions fed to tensorflow
    '''
    is_rbc = cell_type == 'rbc'
    is_single_output = cell_type == 'wbc' or cell_type == 'plt'

    if is_rbc:
        img, mask, edge = load_data(img_list, mask_list, edge_list)
    elif is_single_output:
        edge = None
        img, mask = load_data(img_list, mask_list)

    def gen():
        # both chip generators take the very same arguments
        chip_kwargs = dict(padding=padding[0],
                           input_size=input_shape[0],
                           output_size=output_shape[0])
        if type == 'train':
            return train_generator(img, mask, edge, **chip_kwargs)
        elif type == 'test':
            return test_chips(img, mask, edge, **chip_kwargs)

    # load the dataset to tensorflow
    if is_rbc:
        output_types = (tf.float64, ((tf.float64), (tf.float64)))
        output_shapes = (input_shape, (output_shape, output_shape))
        return tf.data.Dataset.from_generator(gen, output_types, output_shapes)
    elif is_single_output:
        output_types = (tf.float64, (tf.float64))
        output_shapes = (input_shape, (output_shape))
        return tf.data.Dataset.from_generator(gen, output_types, output_shapes)


# functions
def train(model_name='mse', epochs=50):
    '''
    Train the do_unet model for the module level cell_type.
    Existing weights under models/<model_name>.h5 are loaded first,
    so that a previous training gets continued.
    :param model_name --> model weights that we want saved
    :param epochs --> how many epochs we want the model to be trained

    :return --> saves the model weights under <model_name>.h5 with
    its respective history file <model_name>_history.npy
    '''
    # only red blood cells come with edge images
    if cell_type == 'rbc':
        train_edge_list = sorted(glob.glob(f'data/{cell_type}/train/edge/*.jpg'))
        test_edge_list = sorted(glob.glob(f'data/{cell_type}/test/edge/*.jpg'))
    elif cell_type == 'wbc' or cell_type == 'plt':
        train_edge_list = test_edge_list = None
    else:
        print('Invalid blood cell type!\n')
        return

    train_img_list = sorted(glob.glob(f'data/{cell_type}/train/image/*.jpg'))
    train_mask_list = sorted(glob.glob(f'data/{cell_type}/train/mask/*.jpg'))
    test_img_list = sorted(glob.glob(f'data/{cell_type}/test/image/*.jpg'))
    test_mask_list = sorted(glob.glob(f'data/{cell_type}/test/mask/*.jpg'))

    # loading train dataset and test datasets
    train_dataset = generator(train_img_list, train_mask_list, train_edge_list, type='train')
    test_dataset = generator(test_img_list, test_mask_list, test_edge_list, type='test')

    # initializing the do_unet model
    model = do_unet()

    # create models directory if it does not exist
    models_directory_exists = os.path.exists('models/')
    if not models_directory_exists:
        os.makedirs('models/')

    # Check for existing weights
    weights_exist = os.path.exists(f'models/{model_name}.h5')
    if weights_exist:
        model.load_weights(f'models/{model_name}.h5')

    # fitting the model on batches of 8 chips
    train_batches = train_dataset.batch(8)
    test_batches = test_dataset.batch(8)
    history = model.fit(train_batches,
                        validation_data=test_batches,
                        epochs=epochs,
                        steps_per_epoch=125,
                        max_queue_size=16,
                        use_multiprocessing=True,
                        workers=8,
                        verbose=1,
                        callbacks=get_callbacks(model_name))

    # save the history
    np.save(f'models/{model_name}_history.npy', history.history)


def normalize(img):
    '''
    Normalizes an image to the range [0, 1] (min-max normalization)
    :param img --> an input image that we want normalized

    :return np.array --> an output image normalized (as a numpy array);
    an image whose values are all equal is mapped to zeros
    '''
    img = np.asarray(img)
    lowest = np.min(img)
    value_range = np.max(img) - lowest

    # a constant image would otherwise divide by zero and produce NaNs
    if value_range == 0:
        return np.zeros(img.shape, dtype=np.float64)

    return np.array((img - lowest) / value_range)


def get_sizes(img,
              padding=padding[1],
              input=input_shape[0],
              output=output_shape[0]):
    '''
    Count the chips a full image is cut into along each axis,
    which is needed to rebuild the full image output.
    Only the first image of the list is looked at.
    :param img --> the list of input images
    :param padding --> the default padding used on the test dataset
    :param input --> the side length of the image chips
    :param output --> the side length of the output chips (the grid step)

    :return couple --> a list holding one couple (chips along x, chips along y)
    '''
    first_image = img[0]
    offset = padding + (output / 2)
    half_input = input / 2

    centers_x = np.arange(offset, first_image.shape[0] - half_input, output)
    centers_y = np.arange(offset, first_image.shape[1] - half_input, output)

    return [(len(centers_x), len(centers_y))]


def reshape(img,
            size_x,
            size_y):
    '''
    Arrange a flat array of output chips as a (size_x, size_y) grid of chips
    :param img --> the array of chips we want to reshape
    :param size_x --> the number of chips along the x axis
    :param size_y --> the number of chips along the y axis

    :return img (numpy array) --> the chips reshaped to (size_x, size_y, side, side, 1), where side is output_shape[0]
    '''
    chip_side = output_shape[0]
    grid_shape = (size_x, size_y, chip_side, chip_side, 1)
    return img.reshape(grid_shape)


def concat(imgs):
    '''
    Rebuild the full image from a grid of output image chips
    :param imgs --> the grid of chips that we want to concatenate

    :return full_image --> the concatenation of all the provided chips (param: imgs)
    '''
    # every row of chips is joined horizontally first ...
    rows = []
    for chip_row in imgs[:, :, :, :]:
        rows.append(cv2.hconcat(chip_row))

    # ... then the rows are stacked vertically
    return cv2.vconcat(rows)


def denoise(img):
    '''
    Remove noise from an image with non-local means denoising
    :param img --> the path of the input image that we want to denoise (remove the noise)

    :return image --> the denoised output image
    :raises FileNotFoundError --> if the image can not be read
    '''
    # read the image (cv2.imread returns None instead of raising)
    image = cv2.imread(img)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {img}')

    # the second positional argument is the destination image, so the
    # filter parameters are passed by keyword
    return cv2.fastNlMeansDenoising(image,
                                    None,
                                    h=23,
                                    templateWindowSize=7,
                                    searchWindowSize=21)


def predict(img='Im037_0'):
    '''
    Predict (segment) blood cell images using the trained model (do_unet)
    The module level cell_type, model_name and output_directory are used
    to find the test data, the weights and the place to save the results.
    :param img --> the image we want to predict (from the test/ directory)

    :return --> saves the predicted (segmented blood cell image) under the
    output directory and an overview figure as sample.png
    '''
    # Check for existing predictions
    if os.path.exists(f'{output_directory}/mask.png'):
        print('Prediction already exists!')
        return

    is_rbc = cell_type == 'rbc'
    if not is_rbc and cell_type not in ('wbc', 'plt'):
        print('Invalid blood cell type!\n')
        return

    test_img = sorted(glob.glob(f'data/{cell_type}/test/image/{img}.jpg'))
    test_mask = sorted(glob.glob(f'data/{cell_type}/test/mask/{img}.jpg'))
    # only red blood cells come with edge images
    test_edge = sorted(glob.glob(f'data/{cell_type}/test/edge/{img}.jpg')) if is_rbc else None

    # without input files there is nothing to slice and predict
    if not test_img or not test_mask or (is_rbc and not test_edge):
        print('Image does not exist!')
        return

    # initialize do_unet
    model = do_unet()

    # Check for existing weights
    if os.path.exists(f'models/{model_name}.h5'):
        model.load_weights(f'models/{model_name}.h5')

    # load test data and slice it into chips
    if is_rbc:
        imgs, mask, edge = load_data(test_img, test_mask, test_edge, padding=padding[0])
        img_chips = slice(
            imgs,
            mask,
            edge,
            padding=padding[1],
            input_size=input_shape[0],
            output_size=output_shape[0],
        )[0]
    else:
        edge = None
        imgs, mask = load_data(test_img, test_mask, padding=padding[0])
        img_chips = slice(
            imgs,
            mask,
            padding=padding[1],
            input_size=input_shape[0],
            output_size=output_shape[0],
        )[0]

    # segment all image chips
    output = model.predict(img_chips)
    if is_rbc:
        new_mask_chips = np.array(output[0])
        new_edge_chips = np.array(output[1])
    else:
        new_mask_chips = np.array(output)

    # get the number of chips along each axis
    size_x, size_y = get_sizes(imgs)[0]

    # arrange the chips as a grid and drop the channel axis only;
    # a plain squeeze would also drop a grid axis of length one
    new_mask = concat(np.squeeze(reshape(new_mask_chips, size_x, size_y), axis=-1))
    if is_rbc:
        new_edge = concat(np.squeeze(reshape(new_edge_chips, size_x, size_y), axis=-1))

    # create output directories if they do not exist
    os.makedirs('output/', exist_ok=True)
    os.makedirs(output_directory, exist_ok=True)

    # save predicted mask and edge
    plt.imsave(f'{output_directory}/mask.png', new_mask)
    if is_rbc:
        plt.imsave(f'{output_directory}/edge.png', new_edge)
        plt.imsave(f'{output_directory}/edge_mask.png', new_mask - new_edge)

    # organize results into one figure: (subplot position, title, picture)
    panels = [(1, 'Test image', imgs[0]),
              (2, 'Test mask', mask[0])]
    if is_rbc:
        columns = 3
        panels += [(3, 'Test edge', edge[0]),
                   (5, 'Predicted mask', new_mask),
                   (6, 'Predicted edge', new_edge)]
    else:
        columns = 2
        panels += [(4, 'Predicted mask', new_mask)]

    fig = plt.figure(figsize=(25, 12), dpi=80)
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for position, title, picture in panels:
        ax = fig.add_subplot(2, columns, position)
        ax.set_title(title)
        ax.imshow(picture)

    # save the figure as a sample output and release it
    fig.savefig('sample.png')
    plt.close(fig)


def evaluate(model_name='mse'):
    '''
    Evaluate an already trained model on the whole test dataset of the
    module level cell_type. If there are no weights for model_name yet,
    the model gets trained first.
    :param model_name --> the model weights that we want to evaluate

    :return --> output the evaluated model weights directly to the screen.
    '''
    test_img_list = sorted(glob.glob(f'data/{cell_type}/test/image/*.jpg'))
    test_mask_list = sorted(glob.glob(f'data/{cell_type}/test/mask/*.jpg'))
    if cell_type == 'rbc':
        test_edge_list = sorted(glob.glob(f'data/{cell_type}/test/edge/*.jpg'))
    elif cell_type == 'wbc' or cell_type == 'plt':
        test_edge_list = None
    else:
        print('Invalid blood cell type!\n')
        return

    # initialize do_unet
    model = do_unet()

    # train first if there are no weights yet, then load the weights:
    # train() works on its own model instance
    weights_file = f'models/{model_name}.h5'
    if not os.path.exists(weights_file):
        train(model_name)
    if os.path.exists(weights_file):
        model.load_weights(weights_file)

    # load test data
    if cell_type == 'rbc':
        img, mask, edge = load_data(test_img_list, test_mask_list, test_edge_list, padding=padding[0])

        img_chips, mask_chips, edge_chips = slice(
            img,
            mask,
            edge,
            padding=padding[1],
            input_size=input_shape[0],
            output_size=output_shape[0]
        )
    else:
        img, mask = load_data(test_img_list, test_mask_list, padding=padding[0])

        img_chips, mask_chips = slice(
            img,
            mask,
            padding=padding[1],
            input_size=input_shape[0],
            output_size=output_shape[0]
        )

    # print the evaluated accuracies
    if cell_type == 'rbc':
        print(model.evaluate(img_chips, (mask_chips, edge_chips)))
    else:
        print(model.evaluate(img_chips, (mask_chips)))


def threshold(img='edge.png'):
    '''
    This is the threshold function, which applies an otsu threshold
    to the input image (param: img)
    For 'edge_mask.png' no threshold is computed: the result is the
    already thresholded mask minus the already thresholded edge.
    :param img --> the file name (inside the output directory) of the image we want to threshold

    :return --> saves the output thresholded image under the output directory as threshold_<img>
    '''
    if not os.path.exists(output_directory + '/' + img):
        print('Image does not exist!')
        return

    # subtract if img is edge_mask
    if img == 'edge_mask.png':
        mask = cv2.imread(f'{output_directory}/threshold_mask.png')
        edge = cv2.imread(f'{output_directory}/threshold_edge.png')

        # both thresholded images have to be there already
        if mask is None or edge is None:
            print('Image does not exist!')
            return

        # subtract mask - edge; cv2.subtract saturates at 0 whereas the
        # plain uint8 difference would wrap around
        image = cv2.subtract(mask, edge)
    else:
        # getting the input image
        image = cv2.imread(f'{output_directory}/{img}')

        # convert to grayscale and apply otsu's thresholding
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, image = cv2.threshold(image, 0, 255, cv2.THRESH_OTSU)

    # save the resulting thresholded image
    plt.imsave(f'{output_directory}/threshold_{img}', image, cmap='gray')
    

def hough_transform(img='edge.png'):
    '''
    This is the Circle Hough Transform function (CHT), which counts the
    circles from an input image.
    The HoughCircles parameters are chosen according to the global cell_type.
    :param img --> file name (inside output_directory) of the image that we want to count circles from.

    :return --> None; prints the number of detected circles and, when at least
                one circle is found, saves the output image under
                <output_directory>/hough_transform.png
    '''
    if not os.path.exists(output_directory + '/' + img):
        print('Image does not exist!')
        return

    # getting the input image
    image = cv2.imread(f'{output_directory}/{img}')
    # convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # apply hough circles
    if cell_type == 'rbc':
        circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, 1, minDist=33, maxRadius=55, minRadius=28, param1=30, param2=20)
    elif cell_type == 'wbc' or cell_type == 'plt':
        circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, 1, minDist=51, maxRadius=120, minRadius=48, param1=70, param2=20)
    output = gray.copy()

    # HoughCircles yields None when nothing is detected, so the count is
    # tracked separately instead of calling len() on the raw result
    count = 0

    # ensure at least some circles were found
    if circles is not None:
        # convert the (x, y) coordinates and radius of the circles to integers
        circles = np.round(circles[0, :]).astype("int")
        count = len(circles)
        # loop over the (x, y) coordinates and radius of the circles
        for (x, y, r) in circles:
            # draw the circle in the output image, then draw a rectangle
            # corresponding to the center of the circle
            # NOTE(review): output is single-channel, so the (0, 0, 255)
            # colour presumably renders as intensity 0 -- confirm
            cv2.circle(output, (x, y), r, (0, 0, 255), 2)
            cv2.rectangle(output, (x - 5, y - 5), (x + 5, y + 5), (0, 0, 255), -1)
        # save the output image (input and annotated copy side by side)
        plt.imsave(f'{output_directory}/hough_transform.png',
                   np.hstack([gray, output]))

    # show the hough_transform results
    print(f'Hough transform: {count}')


def component_labeling(img='edge.png'):
    '''
    This is the Connected Component Labeling (CCL), which labels all the connected objects from an input image
    :param img --> file name (inside output_directory) of the image that we want to apply CCL to.

    :return --> None; prints the number of labels and saves the output image
                (input and colored labels side by side) under
                <output_directory>/component_labeling.png
    '''
    if not os.path.exists(output_directory + '/' + img):
        print('Image does not exist!')
        return

    # getting the input image
    image = cv2.imread(f'{output_directory}/{img}')
    # convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # binarize: pixels above 127 become 255, all others become 0
    binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]
    # applying cv2.connectedComponents()
    num_labels, labels = cv2.connectedComponents(binary)

    # map component labels to hue val, 0-179 is the hue range in OpenCV
    max_label = np.max(labels)
    if max_label > 0:
        label_hue = np.uint8(179*labels/max_label)
    else:
        # no foreground component at all: avoid dividing by zero
        label_hue = np.zeros(labels.shape, dtype=np.uint8)
    blank_ch = 255*np.ones_like(label_hue)
    output = cv2.merge([label_hue, blank_ch, blank_ch])

    # converting cvt to BGR
    output = cv2.cvtColor(output, cv2.COLOR_HSV2BGR)

    # set bg label to black; the mask is built from the labels themselves
    # because the truncated hue is also 0 for low-numbered components once
    # there are more than 179 labels
    output[labels == 0] = 0

    # saving image after Component Labeling
    plt.imsave(f'{output_directory}/component_labeling.png',
               np.hstack([image, output]))

    # show number of labels detected
    # NOTE: per the OpenCV documentation this total includes label 0 (the
    # background), so the number of objects is num_labels - 1
    print(f'Connected component labeling: {num_labels}')


def distance_transform(img='threshold_edge_mask.png'):
    '''
    This is the Euclidean Distance Transform function (EDT), which runs the
    distance transform followed by a binary dilation on an input image.
    :param img --> file name (inside output_directory) of the image that we want to apply EDT to.

    :return --> None; saves the output image under <output_directory>/distance_transform.png
    '''
    source_path = f'{output_directory}/{img}'
    if not os.path.exists(source_path):
        print('Image does not exist!')
        return

    # read the image as a numpy array and reduce it to a single gray channel
    bgr = np.asarray(cv2.imread(source_path))
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)

    # distance of every non-zero pixel to the nearest zero pixel
    distances = ndimage.distance_transform_edt(gray)
    # binary_dilation operates on the non-zero elements of its input, so the
    # result is a dilated binary mask rather than the distance values
    dilated = ndimage.binary_dilation(distances)

    # saving image after the distance transform and dilation
    plt.imsave(f'{output_directory}/distance_transform.png', dilated, cmap='gray')


if __name__ == '__main__':
    '''
    The main function, which handles all the function calls
    (later on, this will dynamically call functions according to user input)

    Cell type can be either rbc, wbc or plt
    which stands for:
        rbc --> Red blood cells
        wbc --> White blood cells
        plt --> Platelets

    cell_type, model_name and output_directory are set here at module level.
    '''
    cell_type        = 'rbc'             # rbc, wbc or plt
    model_name       = cell_type
    if cell_type in ('rbc', 'wbc', 'plt'):
        output_directory = f'output/{cell_type}'
    else:
        print('Invalid blood cell type!\n')
        # stop here: continuing would run the pipeline with
        # output_directory undefined
        raise SystemExit(1)

    # train('rbc')
    # evaluate(model_name='rbc')
    predict()
    threshold('mask.png')

    if cell_type == 'rbc':
        # rbc: threshold the edge image and the mask/edge difference,
        # then count circles on the edge image
        threshold('edge.png')
        threshold('edge_mask.png')
        distance_transform('threshold_edge_mask.png')
        hough_transform('edge.png')
    else:
        # wbc / plt: only the mask image is processed
        distance_transform('threshold_mask.png')
        hough_transform('mask.png')

    component_labeling('distance_transform.png')

