#! /usr/bin/env python
##########################################################################
# NSAp - Copyright (C) CEA, 2013 - 2018
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

# System import
from __future__ import print_function
import argparse
import os
import shutil
import json
import subprocess
from datetime import datetime
from pprint import pprint
import textwrap
from argparse import RawTextHelpFormatter

# Bredala import
try:
    import bredala
    bredala.USE_PROFILER = False
    bredala.register("pyconnectome.utils.regtools",
                     names=["flirt"])
    bredala.register("pydcmio.dcmconverter.converter",
                     names=["nii2dcm"])
except Exception:
    # Best effort: the script must keep running when bredala is not
    # installed or when the registration fails. 'Exception' is caught
    # instead of using a bare 'except' so that KeyboardInterrupt and
    # SystemExit are not silently swallowed.
    pass

# Third party import
from pyconnectome import DEFAULT_FSL_PATH
from pyconnectome.utils.regtools import flirt
from pydcmio.dcmconverter.converter import nii2dcm
import nibabel
import numpy


# Parameters to keep trace: names of the module-level variables set by this
# script ('runtime', 'inputs' and 'outputs').
# NOTE(review): presumably read by the hopla tool to collect these values
# after the run - confirm.
__hopla__ = ["runtime", "inputs", "outputs"]


# Script documentation: used (after 'textwrap.dedent') as the description of
# the command line parser built in 'get_cmd_line_args'.
# The string is not a raw string, so each trailing backslash is a line
# continuation: the example command below is rendered on a single line.
DOC = """
Resample image using FSL
~~~~~~~~~~~~~~~~~~~~~~~~

python pydcmio_deface_post \
    -i /neurospin/radiomics_pub/workspace/HERBY_iso2_deface/sub-220901/ses-10205/anat/sub-220901_ses-10205_acq-axial_run-401_T1w/sub-220901_ses-10205_acq-iso2axial_run-401_T1w_deface/sub-220901_ses-10205_acq-iso2axial_run-401_T1w_full_normfilter.nii.gz \
    -d /neurospin/radiomics_pub/workspace/HERBY_iso2_deface/sub-220901/ses-10205/anat/sub-220901_ses-10205_acq-axial_run-401_T1w/sub-220901_ses-10205_acq-iso2axial_run-401_T1w_deface/workspace/sub-220901_ses-10205_acq-iso2axial_run-401_T1w.hdr \
    -r /neurospin/radiomics_pub/HERBY/sourcedata/sub-220901/ses-10205/anat/sub-220901_ses-10205_acq-axial_run-401_T1w/sub-220901_ses-10205_acq-axial_run-401_T1w.nii.gz \
    -o /neurospin/radiomics_pub/workspace/HERBY_deface/test \
    -v 2
"""


def is_file(filearg):
    """ Type for argparse - checks that file exists but does not open.

    Parameters
    ----------
    filearg: str
        the path received from the command line.

    Returns
    -------
    filearg: str
        the input path, unchanged, when it points to an existing file.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path does not point to an existing file.
    """
    if not os.path.isfile(filearg):
        # ArgumentTypeError is the exception argparse expects from a 'type'
        # callable: its message is reported as is to the user.
        # ArgumentError requires an (argument, message) pair and cannot be
        # built from a message alone.
        raise argparse.ArgumentTypeError(
            "The file '{0}' does not exist!".format(filearg))
    return filearg


def is_directory(dirarg):
    """ Type for argparse - checks that directory exists.

    Parameters
    ----------
    dirarg: str
        the path received from the command line.

    Returns
    -------
    dirarg: str
        the input path, unchanged, when it points to an existing directory.

    Raises
    ------
    argparse.ArgumentTypeError
        if the path does not point to an existing directory.
    """
    if not os.path.isdir(dirarg):
        # ArgumentTypeError is the exception argparse expects from a 'type'
        # callable: its message is reported as is to the user.
        # ArgumentError requires an (argument, message) pair and cannot be
        # built from a message alone.
        raise argparse.ArgumentTypeError(
            "The directory '{0}' does not exist!".format(dirarg))
    return dirarg


def get_cmd_line_args():
    """ Parse the command line of the script.

    Returns
    -------
    options: dict
        mapping <argument name> -> <argument value> for every argument
        except the verbosity. When no FSL configuration file is given,
        'fsl_config' is set to DEFAULT_FSL_PATH.
    verbosity: int
        the requested verbosity level (0, 1 or 2).
    """
    parser = argparse.ArgumentParser(
        formatter_class=RawTextHelpFormatter,
        description=textwrap.dedent(DOC),
        prog="python pydcmio_deface_post")

    # Required arguments: (flags, type checker, help message)
    mandatory_specs = [
        (("-i", "--in-file"), is_file,
         "Path of defaced nifti image to be resampled."),
        (("-d", "--down-ref-file"), is_file,
         "Path of reference nifti image in the defaced input image space."),
        (("-r", "--ref-file"), is_file,
         "Path of reference nifti image to be defaced."),
        (("-o", "--outdir"), is_directory,
         "The destination folder."),
    ]
    mandatory_group = parser.add_argument_group("required arguments")
    for flags, checker, message in mandatory_specs:
        mandatory_group.add_argument(
            *flags, required=True, metavar="<path>", type=checker,
            help=message)

    # Optional arguments
    parser.add_argument("-S", "--sid", help="The subject identifier.")
    parser.add_argument("-N", "--study-name", help="The study name.")
    parser.add_argument(
        "-F", "--fsl-config", type=is_file, metavar="<path>",
        help="Path to fsl sh config file.")
    parser.add_argument(
        "-v", "--verbose", default=0, choices=[0, 1, 2], type=int,
        help="increase the verbosity level: 0 silent, [1, 2] verbose.")

    # Turn the parsed namespace into the dict consumed by the script; the
    # verbosity is returned separately
    options = vars(parser.parse_args())
    if options["fsl_config"] is None:
        options["fsl_config"] = DEFAULT_FSL_PATH
    verbosity = options.pop("verbose")

    return options, verbosity


"""
Parse the command line arguments and initialize the runtime description
that is saved with the inputs and outputs at the end of the script.
"""
inputs, verbose = get_cmd_line_args()
# Description of this run: saved with the inputs and outputs in the 'logs'
# directory at the end of the script
runtime = {
    "tool": "pydcmio_deface_post",
    "timestamp": datetime.now().isoformat()
}
# Only filled once all the processing steps have been executed
outputs = None
if verbose > 0:
    # The banner names this tool (the 'post' step), consistently with the
    # 'tool' entry of the runtime description
    print("[info] Starting deface post ...")
    print("[info] Runtime:")
    pprint(runtime)
    print("[info] Inputs:")
    pprint(inputs)


"""
Get the defacing mask
"""
if verbose > 1:
    print("[info] Get defacing mask by substracting '{0}' and '{1}'.".format(
        inputs["in_file"], inputs["down_ref_file"]))
im_deface = nibabel.load(inputs["in_file"])
im_ref = nibabel.load(inputs["down_ref_file"])
# 'get_data()' has been deprecated and then removed from nibabel:
# 'numpy.asanyarray(img.dataobj)' is the documented replacement that
# returns the same array.
array_deface = numpy.asanyarray(im_deface.dataobj).squeeze()
array_ref = numpy.asanyarray(im_ref.dataobj).squeeze()
# NOTE(review): the difference is computed in the dtype of the loaded
# arrays; with unsigned integer data it would wrap around instead of going
# negative - confirm the dtype of the expected inputs.
array_mask = numpy.abs(array_deface - array_ref)
# Keep the defaced intensities where the two images differ by more than 0.6
# and set every other voxel to zero
ind = numpy.where(array_mask > 0.6)
array_mask[...] = 0
array_mask[ind] = array_deface[ind]
im_mask = nibabel.Nifti1Image(array_mask, im_deface.affine)
mask_file = os.path.join(inputs["outdir"], "mask.nii.gz")
nibabel.save(im_mask, mask_file)


"""
Resample the defacing mask
"""
# Temporary 4x4 identity matrix written as text and given to flirt as the
# initial transformation.
# NOTE(review): with 'applyxfm' this presumably only resamples the mask on
# the grid of 'ref_file', without estimating a registration - confirm
# against the pyconnectome 'flirt' wrapper.
identity_trf = os.path.join(inputs["outdir"], "identity.trf")
numpy.savetxt(identity_trf, numpy.eye(4))
resample_mask_file = os.path.join(inputs["outdir"], "resample_mask.nii.gz")
try:
    flirt(
        in_file=mask_file,
        ref_file=inputs["ref_file"],
        init=identity_trf,
        out=resample_mask_file,
        applyxfm=True,
        interp="nearestneighbour",
        shfile=inputs["fsl_config"])
finally:
    # Do not leave the temporary transformation file in the destination
    # folder, even when flirt fails
    os.remove(identity_trf)


"""
Apply defacing to the original image
"""
im = nibabel.load(inputs["ref_file"])
im_mask = nibabel.load(resample_mask_file)
# 'get_data()' has been deprecated and then removed from nibabel:
# 'numpy.asanyarray(img.dataobj)' is the documented replacement that
# returns the same array.
array = numpy.asanyarray(im.dataobj).squeeze()
array_mask = numpy.asanyarray(im_mask.dataobj).squeeze()
# Overwrite the reference intensities with the mask ones wherever the
# resampled mask is above 0.1
ind = numpy.where(array_mask > 1e-1)
array[ind] = array_mask[ind]
im_deface = nibabel.Nifti1Image(array, im.affine)
# Build the output file name from the reference file name
basename = os.path.basename(inputs["ref_file"])
if basename.startswith("sub-"):
    if "acq-" in basename:
        # Prefix the existing acquisition label, e.g. 'acq-axial' becomes
        # 'acq-maskfaceaxial'
        pos = basename.find("acq-") + 4
        basename = basename[:pos] + "maskface" + basename[pos:]
    else:
        # NOTE(review): 'insert(-2, ...)' puts the new field before the last
        # two '_'-separated fields; this presumably assumes a name ending
        # with '<run>_<suffix>' - confirm.
        _basename = basename.split("_")
        _basename.insert(-2, "acq-maskface")
        basename = "_".join(_basename)
else:
    basename = "maskface_" + basename
deface_file = os.path.join(inputs["outdir"], basename)
nibabel.save(im_deface, deface_file)


"""
Generate a DICOM tarball
"""
# The DICOM folder and the tarball are named after the defaced file name,
# truncated at its first dot
name = basename.partition(".")[0]
dicom_workdir = os.path.join(inputs["outdir"], name)
if not os.path.isdir(dicom_workdir):
    os.mkdir(dicom_workdir)

# Convert the defaced image in the working folder
series_fnames = nii2dcm(
    nii_file=deface_file, outdir=dicom_workdir, sid=inputs["sid"],
    study_id=inputs["study_name"], debug=False)

# Archive the working folder relative to the destination folder, then
# delete it: only the tarball is kept
tarball_path = dicom_workdir + ".dicom.tar.gz"
subprocess.check_call(
    ["tar", "zcvf", tarball_path, "-C", inputs["outdir"], name])
shutil.rmtree(dicom_workdir)


"""
Update the outputs and save them and the inputs in a 'logs' directory.
"""
logdir = os.path.join(inputs["outdir"], "logs")
if not os.path.isdir(logdir):
    os.mkdir(logdir)

# Files generated by the script
outputs = dict(
    mask_file=mask_file,
    resample_mask_file=resample_mask_file,
    deface_file=deface_file)

# One JSON file per structure, named after the first item of each pair
traces = (
    ("inputs_post", inputs),
    ("outputs_post", outputs),
    ("runtime_post", runtime))
for trace_name, trace in traces:
    trace_path = os.path.join(logdir, trace_name + ".json")
    with open(trace_path, "wt") as stream:
        json.dump(trace, stream, indent=4, sort_keys=True,
                  check_circular=True)

if verbose > 1:
    print("[info] Outputs:")
    pprint(outputs)


