#!python
# -*- coding: utf-8 -*-

import os
import sys
from dd import DD
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
import click

# -----------------------------------------------------------------------------
# click context settings: accept both "-h" and "--help" for the help screen.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
def _squared_filename(filename):
    """Return the companion "-Squared" path for an image filename.

    Only the final extension is rewritten, e.g. ``out/img.mhd`` ->
    ``out/img-Squared.mhd``, so directory names are never touched and the
    result always differs from *filename*.
    """
    root, ext = os.path.splitext(filename)
    return root + "-Squared" + ext


def _write_float32_image(array, reference_img, filename):
    """Write *array* to *filename* as float32, reusing *reference_img*'s geometry."""
    out_img = sitk.GetImageFromArray(array)
    # copy origin/spacing/direction from the input image
    out_img.CopyInformation(reference_img)
    out_img = sitk.Cast(out_img, sitk.sitkFloat32)
    sitk.WriteImage(out_img, filename)


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('image_mhd')
@click.argument('events', type=float)
@click.argument('output_mhd')
def garf_scale_and_Poisson_noise(image_mhd, events, output_mhd):
    """Scale IMAGE_MHD according to EVENTS and add Poisson noise.

    Reads IMAGE_MHD and its "-Squared" companion (same name with "-Squared"
    inserted before the extension), rescales the image according to EVENTS,
    draws Poisson samples from it and writes them to OUTPUT_MHD. The squared
    image multiplied by EVENTS**2 is written to the "-Squared" companion of
    OUTPUT_MHD. Both outputs are stored as float32.
    """
    # load image and squared image
    img = sitk.ReadImage(image_mhd)
    sq_img = sitk.ReadImage(_squared_filename(image_mhd))

    # click already converts EVENTS; float() also covers direct calls to the
    # underlying callback with a string argument
    events = float(events)
    print('events', events)

    # get data in np
    data = sitk.GetArrayFromImage(img)
    sq_data = sitk.GetArrayFromImage(sq_img)
    print('(before) mean', np.mean(data) * events)
    print('(before) std', np.std(data) * events)

    # per-pixel SD = sqrt(E[x^2] - E[x]^2)
    # https://stackoverflow.com/questions/19397719/could-numpy-random-poisson-be-used-to-add-poisson-noise-to-images
    # Rounding can make the variance slightly negative: clamp it to 0 rather
    # than letting sqrt produce NaN (and a RuntimeWarning). Any NaN that is
    # still left (e.g. NaN input pixels) is replaced by 0.
    variance = np.maximum(sq_data - np.square(data), 0)
    std_data = np.nan_to_num(np.sqrt(variance), nan=0.0)

    # compute final data
    # FIXME (original author): the scaling rule is still unsettled; the
    # previous alternative was plain `data * events`.
    # NOTE(review): std_data*events is subtracted from the *unscaled* data
    # before the outer *events, so the two terms carry different powers of
    # events -- confirm this is intended.
    data = (data - std_data * events) * events

    # Poisson lambda must be non-negative
    data[data < 0] = 0

    # add noise
    noise = np.random.poisson(data, data.shape).astype(float)
    print('(after) mean', np.mean(noise))
    print('(after) std', np.std(noise))

    # write final image
    print("Write image to ", output_mhd)
    _write_float32_image(noise, img, output_mhd)

    # write final squared image (scaled by events**2)
    sq_output_mhd = _squared_filename(output_mhd)
    print("Write sq image to ", sq_output_mhd)
    _write_float32_image(sq_data * events * events, img, sq_output_mhd)


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # script entry point: click reads the command-line arguments itself
    garf_scale_and_Poisson_noise()
