#!python
# -*- coding: utf-8 -*-

import sys
import numpy as np
import re
import SimpleITK as sitk
import matplotlib.pyplot as plt
import click

def dump_img(img, data, filename):
    """Write the array `data` to `filename` as a 32-bit float image.

    The array is converted to a SimpleITK image, receives the image
    information of `img` through ``CopyInformation`` and is cast to
    ``sitkFloat32`` before being written.

    Args:
        img: SimpleITK image whose information is copied to the output.
        data: array accepted by ``sitk.GetImageFromArray``.
        filename: path of the image file to write.
    """
    out = sitk.GetImageFromArray(data)
    out.CopyInformation(img)
    sitk.WriteImage(sitk.Cast(out, sitk.sitkFloat32), filename)

# -----------------------------------------------------------------------------
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])


def _read_number_of_events(stat_filename):
    """Return the ``NumberOfEvents`` value stored in a stat file.

    Every line is scanned; when several lines match, the last one wins.

    Raises:
        click.ClickException: if no ``NumberOfEvents`` line is found or if
            its value cannot be converted to a float.
    """
    events = None
    # Context manager: the file is closed even if parsing fails.
    with open(stat_filename, 'r') as fp:
        for line in fp:
            m = re.match(r'.*NumberOfEvents = (.*)', line)
            if not m:
                continue
            try:
                events = float(m.group(1))
            except ValueError:
                raise click.ClickException(
                    "Invalid NumberOfEvents value '{}' in '{}'".format(
                        m.group(1), stat_filename))
    if events is None:
        raise click.ClickException(
            "No 'NumberOfEvents = ...' line found in '{}'".format(
                stat_filename))
    return events


def _tagged_filename(filename, tag):
    """Return `filename` with `tag` inserted before its ``.mhd`` extension.

    Only the trailing extension is touched, so a ``.mhd`` appearing elsewhere
    in the path is left alone. When `filename` does not end with ``.mhd``,
    ``tag + ".mhd"`` is appended so the result never equals the input path.
    """
    ext = ".mhd"
    if filename.endswith(ext):
        return filename[:-len(ext)] + tag + ext
    return filename + tag + ext


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('image_mhd')
@click.option('--ref', required=True, help='Stat file of the reference simulation')
@click.option('--jobs', '-j', default=1.0, help='Consider the data have been produced by j jobs and will be divided by j')
@click.option('--noise/--not-noise', '-n', default=False, help='Add Poisson noise')
def garf_scale(image_mhd, ref, jobs, noise):
    """Scale IMAGE_MHD by NumberOfEvents / jobs and write it as '<name>-s.mhd'.

    NumberOfEvents is read from the --ref stat file. With --noise, the scaled
    image is replaced by a Poisson sample of itself and the difference
    (scaled minus noisy) is also written as '<name>-Poisson-noise.mhd'.
    """
    # A job count of zero (or less) would silently produce inf/nan pixels.
    if jobs <= 0:
        raise click.BadParameter('must be strictly positive',
                                 param_hint='--jobs')

    # load image
    # NOTE: scaling of the companion '-Squared.mhd' image is currently disabled.
    img = sitk.ReadImage(image_mhd)

    # load stat file of the reference simulation
    events = _read_number_of_events(ref)
    print('Read number of event', events)

    # get data in np and scale it
    data = sitk.GetArrayFromImage(img)
    data = data / jobs * events
    print('(before) mean', np.mean(data))

    # add noise
    # https://stackoverflow.com/questions/19397719/could-numpy-random-poisson-be-used-to-add-poisson-noise-to-images
    if noise:
        scaled = data
        data = np.random.poisson(scaled, scaled.shape).astype(float)
        print('(noise) mean', np.mean(data))
        noise_mhd = _tagged_filename(image_mhd, "-Poisson-noise")
        print("Write noise to ", noise_mhd)
        # The written "noise" image is the scaled image minus the noisy one.
        dump_img(img, scaled - data, noise_mhd)

    # write final image
    output_mhd = _tagged_filename(image_mhd, "-s")
    print("Write image to ", output_mhd)
    dump_img(img, data, output_mhd)


# -----------------------------------------------------------------------------
# Run the click command only when executed as a script, not on import.
if __name__ == "__main__":
    garf_scale()
