#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import logging
import os
import time
import SimpleITK as sitk
import click
import gatetools as gt
import itk
from tqdm import tqdm
import gaga
import garf

# Module-level logger, named after the importing module.
logger = logging.getLogger(__name__)

# Click context settings: accept both -h and --help as help flags.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('gan_pth_filename')
@click.argument('garf_pth_filename')
@click.option('--n', '-n', default=1e5, help='Number of samples to generate')
@click.option('--output', '-o', required=True, help='Output filename.mhd ')
@click.option('--radius', '-r', default=float(180), help='Radius in mm')
@click.option('--start_angle', default=float(0), help='Starting angle (in deg)')
@click.option('--stop_angle', default=float(1), help='Stop angle (in deg)')
@click.option('--step_angle', default=float(1), help='Step angle (in deg)')
@click.option('--scale', '-s', default=float(1), help='Scale the final image by s. ')
@click.option('--debug_projection', default=False, is_flag=True,
              help='Debug: dump npy file with U,V,theta,phi,E on the plane')
@click.option('--sigma', default=False, is_flag=True,
              help='Compute and dump sigma (uncertainty) image')
@gt.add_options(gt.common_options)
def gaga_garf_generate_img(gan_pth_filename, garf_pth_filename,
                           radius, n, scale,
                           start_angle, stop_angle, step_angle,
                           output, debug_projection, sigma,
                           **kwargs):
    """
    \b
    Simulation of a SPECT image:
    - input particles are generated from a GAN (gaga-phsp)
    - detector plane use ARF (garf) to create the image

    \b
    <GAN_PTH_FILENAME>    : input GAN-PHSP PTH file (.pth)
    <GARF_PTH_FILENAME>   : input GARF PTH file (.pth)
    """
    # Implementation notes (kept as comments because click prints the docstring
    # above verbatim as the --help text):
    # - For every angle in [start_angle, stop_angle), stepping by step_angle,
    #   one image and one squared image are written next to `output`, named
    #   <base>_<angle><ext> and <base>_<angle>-Squared<ext>.
    # - With --sigma, an additional <base>_<angle>_sigma<ext> image is computed
    #   from those two files via gt.image_uncertainty_by_slice.
    # - Raises click.BadParameter when step_angle is not strictly positive.
    # - NOTE(review): --debug_projection is accepted but currently not
    #   forwarded; the "debug" entry below is hard-coded to False (see FIXME).

    # logger
    gt.logging_conf(**kwargs)

    # A zero step divides by zero in the progress-bar total, and a negative
    # step never reaches stop_angle (endless loop). Fail early, before the
    # models are loaded.
    if step_angle <= 0:
        raise click.BadParameter('step_angle must be strictly positive',
                                 param_hint='--step_angle')

    # input number of events (click parses the 1e5 default as a float)
    n = int(n)

    # load gan pth
    logger.info(f'Reading GAN-PHSP from {gan_pth_filename}')
    gan_params, G, D, optim, dtypef = gaga.load(gan_pth_filename)

    # load garf pth
    logger.info(f'Reading GARF from {garf_pth_filename}')
    garf_nn, garf_model = garf.load_nn(garf_pth_filename, verbose=False)

    # initialisation of the batch sizes (kept as floats, as before)
    batch_size = 4e6
    gan_batch_size = 2e5

    # initialisation plane
    print('FIXME : image_plane_size_mm')
    image_plane_size_mm = [576, 446]

    # initialisation garf
    garf_param = {
        'gpu_batch_size': int(4e5),  # float(gpu_batch_size))
        'size': 128,
        'spacing': 4.41806,
        'length': 99,
        'N_scale': scale,
        'N_dataset': n,
    }

    # The output name is the same for every angle: split it only once.
    b, extension = os.path.splitext(output)

    # Total for the progress bar: number of angular steps times events per angle.
    bar_n = (stop_angle - start_angle) / step_angle * n

    # The context manager closes the bar even if an angle fails midway.
    with tqdm(total=bar_n) as pbar:
        # Parameters handed to gaga.gaga_garf_generate_image; the "plane"
        # entry is (re)set for each angle inside the loop.
        p = {"gan_params": gan_params,
             "G": G,
             "batch_size": batch_size,
             "gan_batch_size": gan_batch_size,
             "image_plane_size_mm": image_plane_size_mm,
             "debug": False,  #### FIXME
             "garf_nn": garf_nn,
             "garf_model": garf_model,
             "pbar": pbar,
             "n": n,
             "garf_param": garf_param}

        # loop over angles
        angle = start_angle
        while angle < stop_angle:
            tqdm.write(f'Angle {angle} deg')

            p["plane"] = gaga.init_plane(batch_size, angle=angle, radius=radius)
            image, sq_image = gaga.gaga_garf_generate_image(p)

            # Angle tag used in every output filename of this iteration.
            tag = str(angle).zfill(5)

            # save image
            out = f'{b}_{tag}{extension}'
            sitk.WriteImage(image, out)

            # save squared image
            sout = f'{b}_{tag}-Squared{extension}'
            sitk.WriteImage(sq_image, sout)

            # compute sigma if needed
            if sigma:
                filenames = [out]
                sfilenames = [sout]
                sigout = f'{b}_{tag}_sigma{extension}'
                nevents = n
                sigma_flag = True
                threshold = 0
                print(filenames, sfilenames, nevents)
                uncertainty, _, _ = gt.image_uncertainty_by_slice(
                    filenames, sfilenames, nevents, sigma_flag, threshold)
                itk.imwrite(uncertainty, sigout)

            angle += step_angle


# --------------------------------------------------------------------------
if __name__ == "__main__":
    # Run the click command only when this file is executed as a script.
    gaga_garf_generate_img()
