#!python
# -*- coding: utf-8 -*-

import sys
import garf
import ntpath
import time
import json
import socket
import torch
import click
import os
import re
import SimpleITK as sitk

# -----------------------------------------------------------------------------
# Click context settings: accept both '-h' and '--help' as help flags.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument('data_root')
@click.argument('N')
@click.argument('nn_pth')
@click.argument('output_mhd')
@click.option('--gpu_batch_size', '-b',
              default='1e6',
              help='Large value is faster, but may require too much GPU memory.')
@click.option('--size', '-s',
              default=int(128),
              help='Image size in pixel (def = 128).')
@click.option('--spacing', '-p',
              default=4.41806,
              help='Image spacing in mm (def = 4.41806).')
# Author's notes kept for reference on how the length value may be derived:
#   58+(9.525/2.0)  -> collimator + half crystal length in mm
#   backside_position = -36.4+5-13.3375
#   backside_thickness = 80
#   coll_l = spect_head_size/2-backside_position-backside_thickness/2
# The default is a float so that click accepts fractional lengths in mm
# (an int default makes click reject values such as 62.7625).
@click.option('--length', '-l',
              default=99.0,
              help='Collimator length + half crystal length in mm (def = 99 mm).')
@click.option('--events', '-e',
              default=float(1),
              help='Scale the image for N primary events (def = 1).')
@click.option('--folder/--no-folder', default=False,
              help='If true, data_root is a folder, look for all arf_*.root and stats_*.txt files (rotation), N is ignored ')
# -----------------------------------------------------------------------------
# NOTE: the docstring below is the text shown by --help. It is deliberately a
# non-raw string: click needs the real backspace character produced by '\b'
# to disable re-wrapping of the following paragraph.
#
# Parameters (filled in by click):
#   data_root      : ARF root file, or a folder when `folder` is True
#   n              : value of the N argument, forwarded as received from the
#                    command line; replaced per file when `folder` is True
#   nn_pth         : neural network file given to garf.load_nn
#   output_mhd     : output image filename; with `folder`, used as the prefix
#                    of the '<output_mhd>_temp_<i>.mhd' images in data_root
#   gpu_batch_size : string such as '1e6', converted with int(float(...))
#   size, spacing, length, events : forwarded to garf through `param`
#   folder         : process every arf*/stats* pair found under data_root
def garf_build_arf_image_with_nn(data_root,
                                 n,
                                 nn_pth,
                                 output_mhd,
                                 gpu_batch_size,
                                 size, spacing, length, events,
                                 folder):
    '''
    \b
    Build spect image from ARF data.root with the given NN.

    <DATA_ROOT>  : ARF root data from the simulation\n
    <N>          : number of particles used to create the arf dataset\n
    <NN_PTH>     : filename of the neural network (.pth)\n
    <OUTPUT_MHD> : filename of the output image (.mhd)

    '''
    nn, model = garf.load_nn(nn_pth)

    param = {
        # The batch size is given as a string (e.g. '1e6'): go through float
        # first so that scientific notation is accepted.
        'gpu_batch_size': int(float(gpu_batch_size)),
        'size': size,
        'spacing': spacing,
        'length': length,
        'N_scale': events,
        'N_dataset': n,
    }

    if not folder:
        print("\nLoading dataset...")
        # Only the first returned item is used to build the image.
        data, *_ = garf.load_test_dataset(data_root)
        garf.build_arf_image_with_nn(nn, model, data, output_mhd, param, verbose=True)
        return

    print('folder', data_root)
    arf_paths, stat_paths = _find_arf_and_stat_files(data_root)
    if len(arf_paths) != len(stat_paths):
        # zip() below stops at the shortest list: tell the user instead of
        # silently ignoring the unpaired files.
        click.echo('Warning: found {} arf file(s) but {} stats file(s); '
                   'only the first matching pairs are processed.'
                   .format(len(arf_paths), len(stat_paths)), err=True)

    # Files are paired by their (sorted) order of discovery.
    for i, (stat_path, arf_path) in enumerate(zip(stat_paths, arf_paths)):
        # Number of particles used to create this arf file
        n = _read_number_of_events(stat_path)
        param['N_dataset'] = n
        # Read the arf file
        data, *_ = garf.load_test_dataset(arf_path)
        # Build one image per arf file
        temp = os.path.join(data_root, output_mhd + '_temp_' + str(i) + '.mhd')
        print(os.path.basename(arf_path), n, arf_path, temp)
        garf.build_arf_image_with_nn(nn, model, data, temp, param, verbose=False)


def _find_arf_and_stat_files(folder_path):
    '''Return (arf_paths, stat_paths): files under folder_path whose names
    start with 'arf' and 'stats', sorted by name within each directory.'''
    arf_paths = []
    stat_paths = []
    for root, _dirs, files in os.walk(folder_path):
        for name in sorted(files):
            # Join with the directory currently walked (not folder_path) so
            # that files found in sub-folders keep a valid path.
            if name.startswith('arf'):
                arf_paths.append(os.path.join(root, name))
            if name.startswith('stats'):
                stat_paths.append(os.path.join(root, name))
    return arf_paths, stat_paths


def _read_number_of_events(stat_path):
    '''Return the integer that follows '# NumberOfEvents = ' in stat_path.'''
    with open(stat_path) as stat_file:
        content = stat_file.read()
    match = re.search(r'# NumberOfEvents = (\w+)', content)
    if match is None:
        raise click.ClickException(
            "Cannot find '# NumberOfEvents = ' in {}".format(stat_path))
    return int(match.group(1))



# -----------------------------------------------------------------------------
# Script entry point: click reads the command line and invokes the command.
if __name__ == "__main__":
    garf_build_arf_image_with_nn()
