#!python
# -*- coding: utf-8 -*-

import sys
import garf
import ntpath
import time
import json
import socket
import torch
import click
import os
import re
import SimpleITK as sitk

# -----------------------------------------------------------------------------
# Click context settings: names of the flags that display the help message.
CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]}


def _read_number_of_events(stat_path):
    """Return the integer following '# NumberOfEvents = ' in a stats file."""
    # Context manager so the file handle is closed deterministically.
    with open(stat_path) as stat_file:
        content = stat_file.read()
    match = re.search(r"# NumberOfEvents = (\w+)", content)
    if match is None:
        raise click.ClickException(
            f"Cannot find '# NumberOfEvents = ' in stats file {stat_path}"
        )
    return int(match.group(1))


@click.command(context_settings=CONTEXT_SETTINGS)
@click.argument("data_root")
@click.argument("N")
@click.argument("nn_pth")
@click.argument("output_mhd")
@click.option(
    "--gpu_batch_size",
    "-b",
    default="1e6",
    help="Large value is faster, but may require too much GPU memory.",
)
@click.option("--size", "-s", default=int(128), help="Image size in pixel (def = 128).")
@click.option(
    "--spacing", "-p", default=4.41806, help="Image spacing in mm (def = 4.41806)."
)
# Author's notes on how a value for --length can be derived:
#   58+(9.525/2.0) # collimator + half crystal length in mm
#   backside_position = -36.4+5-13.3375
#   backside_thickness = 80
#   coll_l = spect_head_size/2-backside_position-backside_thickness/2
# The option is declared as float so that such fractional lengths are accepted
# (an integer default alone would make click reject non-integer values).
@click.option(
    "--length",
    "-l",
    default=99.0,
    type=float,
    help="Collimator length + half crystal length in mm (def = 99 mm).",
)
@click.option(
    "--events",
    "-e",
    default=float(1),
    help="Scale the image for N primary events (def = 1).",
)
@click.option(
    "--folder/--no-folder",
    default=False,
    help="If true, data_root is a folder, look for all arf_*.root and stats_*.txt files (rotation), N is ignored ",
)
# -----------------------------------------------------------------------------
def garf_build_arf_image_with_nn(
    data_root,
    n,
    nn_pth,
    output_mhd,
    gpu_batch_size,
    size,
    spacing,
    length,
    events,
    folder,
):
    """
    \b
    Build spect image from ARF data.root with the given NN.

    <DATA_ROOT>  : ARF root data from the simulation (a folder with --folder)\n
    <N>          : number of particles used to create the arf dataset\n
    <NN_PTH>     : filename of the neural network (.pth)\n
    <OUTPUT_MHD> : output image filename (.mhd), a -Squared.mhd image is also written

    """
    # NOTE: the docstring above is what click prints for --help; keep it short.
    #
    # Two modes:
    #   * default : data_root is one ARF file; the image and its squared
    #               counterpart are written to output_mhd / *-Squared.mhd.
    #   * --folder: data_root is a directory; every "arf*" file is paired (in
    #               sorted order) with a "stats*" file giving its number of
    #               events, and one pair of temporary images is written per
    #               ARF file inside data_root. No merged image is written here.
    output_sq_mhd = output_mhd.replace(".mhd", "-Squared.mhd")

    nn, model = garf.load_nn(nn_pth)

    param = {
        "gpu_batch_size": int(float(gpu_batch_size)),
        "size": size,
        "spacing": spacing,
        "length": length,
        "N_scale": events,
        # NOTE(review): N comes from the command line as a string here, whereas
        # folder mode stores an int -- confirm garf converts it as needed.
        "N_dataset": n,
    }

    if not folder:
        print("\nLoading dataset...")
        # Only the first returned item (the dataset) is needed.
        data, *_ = garf.load_test_dataset(data_root)
        img, sq_img = garf.build_arf_image_with_nn(nn, model, data, param)
        sitk.WriteImage(img, output_mhd)
        sitk.WriteImage(sq_img, output_sq_mhd)
        return

    print("folder", data_root)

    # os.walk descends into sub-directories, so remember the directory each
    # file was found in instead of assuming it sits directly in data_root.
    stat_paths = []
    arf_paths = []
    for root, _dirs, files in os.walk(data_root):
        for name in sorted(files):
            if name.startswith("arf"):
                arf_paths.append(os.path.join(root, name))
            if name.startswith("stats"):
                stat_paths.append(os.path.join(root, name))

    if not arf_paths:
        click.echo(f"Warning: no arf* file found in {data_root}", err=True)
    if len(stat_paths) != len(arf_paths):
        # zip() below stops at the shorter list: tell the user instead of
        # silently dropping the unpaired files.
        click.echo(
            f"Warning: found {len(arf_paths)} arf file(s) but "
            f"{len(stat_paths)} stats file(s); extra files are ignored",
            err=True,
        )

    for i, (stat_path, arf_path) in enumerate(zip(stat_paths, arf_paths)):
        # Number of particles used to create this ARF dataset.
        n_events = _read_number_of_events(stat_path)
        param["N_dataset"] = n_events

        data, *_ = garf.load_test_dataset(arf_path)

        # File names are kept exactly as before; note that str.replace acts on
        # every ".mhd" occurrence in the path, not only the final extension.
        temp = os.path.join(data_root, f"{output_mhd}_temp_{i}.mhd")
        temp_sq = temp.replace(".mhd", "-Squared.mhd")
        print(os.path.basename(arf_path), n_events, arf_path, temp)

        img, sq_img = garf.build_arf_image_with_nn(nn, model, data, param)
        sitk.WriteImage(img, temp)
        sitk.WriteImage(sq_img, temp_sq)


# -----------------------------------------------------------------------------
# Script entry point: run the click command only when this file is executed
# directly (not when it is imported as a module).
if __name__ == "__main__":
    garf_build_arf_image_with_nn()
