#!/usr/bin/env python
# -*- coding: latin-1 -*-
#
#   Copyright 2016-2021 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
import argparse
import os
import subprocess
import sys
from glob import glob

import numpy as np

import rapidtide.externaltools as tide_exttools
import rapidtide.io as tide_io
import rapidtide.workflows.parser_funcs as pf


def _get_parser():
    """
    Argument parser for atlastool

    Returns
    -------
    argparse.ArgumentParser
        Parser with two required positional NIfTI file names (input atlas,
        output atlas), output-format switches (--3d / --4d), region
        manipulation options (--split, --removeemptyregions, --RtoL,
        --maxval), resampling options (--xfm, --targetfile) and masking
        options (--maskfile, --maskthresh).
    """
    parser = argparse.ArgumentParser(
        prog="atlastool",
        description=("A tool to manipulate nifti atlas files"),
        allow_abbrev=False,
    )

    # Required arguments
    pf.addreqinputniftifile(
        parser,
        "inputtemplatename",
        addedtext="Must be either a 3D file with different regions specified by integer values, "
        "or a 4D file with nonzero values indicating that a voxel is in the region.",
    )
    pf.addreqoutputniftifile(parser, "outputtemplatename")

    # add optional arguments
    # --3d and --4d share one destination; --3d's default (False) wins, so a
    # 3D label image is produced unless --4d is given.
    parser.add_argument(
        "--3d",
        dest="volumeperregion",
        action="store_false",
        help=("Return a 3d file with regions encoded as integers"),
        default=False,
    )
    parser.add_argument(
        "--4d",
        dest="volumeperregion",
        action="store_true",
        help=("Return a 4d file with one region per volume"),
    )
    parser.add_argument(
        "--split",
        dest="dosplit",
        action="store_true",
        help=("Split regions along the midline into left and right subregions"),
        default=False,
    )
    parser.add_argument(
        "--maskthresh",
        dest="maskthresh",
        action="store",
        type=lambda x: pf.is_float(parser, x),
        metavar="THRESH",
        help=("Threshold for autogenerated mask (default is 0.25)."),
        default=0.25,
    )
    parser.add_argument(
        "--labelfile",
        dest="labelfile",
        action="store",
        type=lambda x: pf.is_valid_file(parser, x),
        metavar="FILE",
        help=("Labels for the source atlas"),
        default=None,
    )
    parser.add_argument(
        "--xfm",
        dest="xfm",
        action="store",
        type=lambda x: pf.is_valid_file(parser, x),
        metavar="FILE",
        help=("Transform file to go to the reference."),
        default=None,
    )
    parser.add_argument(
        "--targetfile",
        dest="targetfile",
        action="store",
        type=lambda x: pf.is_valid_file(parser, x),
        metavar="TARGETFILE",
        help=("Match the resolution of TARGETFILE"),
        default=None,
    )
    parser.add_argument(
        "--maskfile",
        dest="maskfile",
        action="store",
        type=lambda x: pf.is_valid_file(parser, x),
        metavar="MASK",
        help=(
            "Mask the final atlas with the 3D mask specified (if using a target file, mask must match it)."
        ),
        default=None,
    )
    parser.add_argument(
        "--removeemptyregions",
        dest="removeemptyregions",
        action="store_true",
        help=(
            "Remove regions with no voxels, so that label values are consecutive.  "
            "Adjust the label file as necessary."
        ),
        default=False,
    )
    parser.add_argument(
        "--RtoL",
        dest="LtoR",
        action="store_false",
        help=(
            "Reverse left/right assignment of labels (default is LtoR).  Change if the default output is wrong."
        ),
        default=True,
    )
    debugging = parser.add_argument_group("Debugging options")
    debugging.add_argument(
        "--debug",
        dest="debug",
        action="store_true",
        help=("Output debugging information"),
        default=False,
    )
    debugging.add_argument(
        "--maxval",
        dest="maxval",
        action="store",
        type=lambda x: pf.is_int(parser, x),
        metavar="MAXVAL",
        help=("Only process atlas values up to MAXVAL"),
        default=None,
    )

    return parser


def _runcommand(thecommand, debug=False):
    """Run an external command, echoing it first when debugging.

    Raises subprocess.CalledProcessError on a nonzero exit status, so a failed
    FSL step stops the run instead of surfacing later as a confusing read
    error on a missing output file.
    """
    if debug:
        print(f"executing: {' '.join(thecommand)}")
    subprocess.run(thecommand, check=True)


def _templatetovoxels(template_data, templatedims, maxval=None, debug=False):
    """Flatten an atlas into a (numvoxels, numregions) array.

    Parameters
    ----------
    template_data : ndarray
        Atlas data from tide_io.readfromnifti.  If templatedims[4] > 1 each
        volume is one region; otherwise voxel values are region labels.
        In the 3D case, values above ``maxval`` are zeroed in place.
    templatedims : sequence
        Dimension array; entries 1-4 are x size, y size, slices, volumes.
    maxval : int or None
        Only keep regions up to this index.
    debug : bool
        Print extra diagnostics.

    Returns
    -------
    templatevoxels : ndarray, shape (numvoxels, numregions)
    numregions : int
    """
    xsize, ysize, numslices, numtimepoints = templatedims[1:5]
    numvoxels = int(xsize) * int(ysize) * int(numslices)

    if numtimepoints > 1:
        # array is already 4d, just reshape it
        if maxval is not None:
            # never ask for more volumes than the file contains
            numtimepoints = min(maxval, numtimepoints)
            template_data = template_data[:, :, :, :numtimepoints]
        return np.reshape(template_data, (numvoxels, numtimepoints)), numtimepoints

    # file is 3d, construct a 4d array with one indicator column per label
    print("going from 3d to 4d")
    if maxval is not None:
        print(f"using a maxval of {maxval}")
        template_data[np.where(template_data > maxval)] = 0
    # Round once so the region count and per-voxel labels agree: values that
    # round to 0 belong to no region, and no label can exceed numregions.
    regionindex = np.round(np.reshape(template_data, numvoxels)).astype(np.int64)
    numregions = max(int(np.max(regionindex)), 0)
    if debug:
        print(f"numregions is {numregions}")
    templatevoxels = np.zeros((numvoxels, numregions))
    validvoxels = np.where(regionindex > 0)[0]
    print(f"{len(validvoxels)} valid voxels")
    templatevoxels[validvoxels, regionindex[validvoxels] - 1] = 1
    return templatevoxels, numregions


def _splitleftright(templatevoxels, xsize, ysize, numslices, labels, LtoR):
    """Split every region at the x midline into two regions.

    Voxels with x index > xsize // 2 go to the second block of regions and the
    rest to the first, so region r becomes regions r and r + numregions.
    Values are truncated to uint16 and only voxels whose value is exactly 1
    are kept, so this expects a binary (0/1) per-region array.  Labels, if
    given, are prefixed "L_"/"R_" ("R_"/"L_" when LtoR is False).

    Returns
    -------
    splitvoxels : ndarray of uint16, shape (numvoxels, 2 * numregions)
    labels : list of str or None
    """
    numvoxels, numregions = templatevoxels.shape
    halves = np.reshape(templatevoxels, (xsize, ysize, numslices, numregions)).astype(
        np.uint16
    )
    # tag the far half of the volume with 2 so the halves can be told apart
    halves[int(xsize // 2) + 1 :, :, :, :] *= 2
    halves = np.reshape(halves, (numvoxels, numregions))
    splitvoxels = np.concatenate((halves == 1, halves == 2), axis=1).astype(np.uint16)
    if labels is not None:
        firstprefix, secondprefix = ("L_", "R_") if LtoR else ("R_", "L_")
        labels = [firstprefix + name for name in labels] + [
            secondprefix + name for name in labels
        ]
    return splitvoxels, labels


def _removeemptyregions(templatevoxels, labels, threshold=0.5):
    """Drop region columns whose maximum value does not exceed ``threshold``.

    Returns the compacted array (only surviving columns, in order) and the
    matching subset of ``labels`` (None if ``labels`` is None).
    """
    keep = []
    for i in range(templatevoxels.shape[1]):
        if np.max(templatevoxels[:, i]) > threshold:
            keep.append(i)
        else:
            print(f"no voxels with value {i + 1} - removing.")
    keptlabels = None if labels is None else [labels[i] for i in keep]
    return templatevoxels[:, keep], keptlabels


def _resampletotarget(templatevoxels, template_hdr, xsize, ysize, numslices, args):
    """Resample every region onto the grid of ``args.targetfile``.

    Each region column is written as its own temporary 3D file next to the
    output file.  The regions are then transformed with FSL ``flirt`` or with
    ANTs via tide_exttools.antsapply, merged with ``fslmerge``, and read back.
    ``flirt`` is used with the FSL identity matrix when --xfm is not given, or
    with a ``.mat`` --xfm file; any other --xfm file goes to ANTs.  Exits with
    status 1 if FSLDIR is not set.

    Returns
    -------
    templatevoxels : ndarray of uint16, shape (numvoxels, numregions)
        Resampled regions, rounded to integers.
    template_hdr, xsize, ysize, numslices
        Header and grid size of the resampled data.
    """
    print("resampling to new resolution")
    fsldir = os.environ.get("FSLDIR")
    if fsldir is None:
        # check before fsldir is used to build any path
        print("FSL directory not found - aborting")
        sys.exit(1)
    outputdir = os.path.dirname(args.outputtemplatename)
    if args.xfm is None:
        thexfm = os.path.join(fsldir, "data", "atlases", "bin", "eye.mat")
        aligntype = "flirt"
    else:
        thexfm = args.xfm
        # splitext keeps the leading dot, so compare against ".mat".
        # NOTE(review): ANTs affines are also often saved as .mat - confirm
        # which convention --xfm files are expected to follow.
        aligntype = "flirt" if os.path.splitext(thexfm)[1] == ".mat" else "ants"
    preroot = os.path.join(outputdir, "temppre")
    postroot = os.path.join(outputdir, "temppost")
    mergecmd = os.path.join(fsldir, "bin", "fslmerge")

    # first write out one temp file per region
    template_hdr["dim"][4] = 1
    suffixes = [str(thisregion).zfill(4) for thisregion in range(templatevoxels.shape[1])]
    for thisregion, suffix in enumerate(suffixes):
        if args.debug:
            print(f"writing out temp file {thisregion}")
        tide_io.savetonifti(
            templatevoxels[:, thisregion].reshape((xsize, ysize, numslices)),
            template_hdr,
            f"{preroot}_{suffix}",
        )
    # Name the files explicitly rather than globbing, so leftovers from an
    # earlier run with more regions are never merged in.  Assumes savetonifti
    # writes "<name>.nii.gz" - TODO confirm against tide_io.
    prefiles = [f"{preroot}_{suffix}.nii.gz" for suffix in suffixes]
    postfiles = [f"{postroot}_{suffix}.nii.gz" for suffix in suffixes]

    if aligntype == "flirt":
        if args.debug:
            print(prefiles)
        _runcommand([mergecmd, "-t", preroot] + prefiles, debug=args.debug)
        thecommand = [
            os.path.join(fsldir, "bin", "flirt"),
            "-in",
            preroot,
            "-ref",
            args.targetfile,
            "-applyxfm",
            "-init",
            thexfm,
            "-out",
            postroot,
        ]
        if args.debug:
            thecommand.append("-v")
        _runcommand(thecommand, debug=args.debug)
    else:
        for thisregion, (infile, outfile) in enumerate(zip(prefiles, postfiles)):
            print(f"aligning {thisregion}")
            tide_exttools.antsapply(
                infile, args.targetfile, outfile, [thexfm], debug=args.debug
            )
        if args.debug:
            print(postfiles)
        _runcommand([mergecmd, "-t", postroot] + postfiles, debug=args.debug)

    if args.debug:
        print("reading back aligned file")
    (
        _,
        template_data,
        template_hdr,
        templatedims,
        _,
    ) = tide_io.readfromnifti(postroot)
    xsize = templatedims[1]
    ysize = templatedims[2]
    numslices = templatedims[3]
    numregions = templatedims[4]
    numvoxels = int(xsize) * int(ysize) * int(numslices)
    templatevoxels = np.around(np.reshape(template_data, (numvoxels, numregions))).astype(
        np.uint16
    )
    return templatevoxels, template_hdr, xsize, ysize, numslices


def main():
    """Command-line entry point for atlastool.

    Loads an atlas (3D integer labels, or 4D with one volume per region) and
    optionally:
    - splits regions at the x midline,
    - removes empty regions,
    - resamples to a target grid,
    - masks the result.
    It then writes either a 3D label image or a 4D one-volume-per-region
    image.  When --labelfile is given, the (possibly renamed/filtered) labels
    are written to "<output>_labels.txt".

    Raises
    ------
    ValueError
        If the label file length does not match the atlas, or the mask is not
        3D or not in the output space.
    """
    # get the command line parameters
    try:
        args = _get_parser().parse_args()
    except SystemExit:
        _get_parser().print_help()
        raise

    if args.debug:
        print(args)

    print("loading template data")
    (
        _,
        template_data,
        template_hdr,
        templatedims,
        _,
    ) = tide_io.readfromnifti(args.inputtemplatename)

    print("reshaping")
    xsize = templatedims[1]
    ysize = templatedims[2]
    numslices = templatedims[3]
    templatevoxels, numregions = _templatetovoxels(
        template_data, templatedims, maxval=args.maxval, debug=args.debug
    )

    # read in the label file, if there is one
    labels = None
    if args.labelfile is not None:
        with open(args.labelfile) as f:
            labels = f.read().splitlines()
        if args.debug:
            for label in labels:
                print(label)
        if len(labels) != numregions:
            raise ValueError("label file does not match atlas")

    # now we have a (numvoxels, numregions) array, regardless of the input
    if args.dosplit:
        print("splitting left right")
        templatevoxels, labels = _splitleftright(
            templatevoxels, xsize, ysize, numslices, labels, args.LtoR
        )

    # eliminate any missing values
    if args.removeemptyregions:
        templatevoxels, labels = _removeemptyregions(templatevoxels, labels)

    if args.targetfile is not None:
        templatevoxels, template_hdr, xsize, ysize, numslices = _resampletotarget(
            templatevoxels, template_hdr, xsize, ysize, numslices, args
        )
    numvoxels = int(xsize) * int(ysize) * int(numslices)

    # mask data
    if args.maskfile is not None:
        # Load the mask
        print("loading mask file")
        (
            _,
            mask_data,
            mask_hdr,
            maskdims,
            _,
        ) = tide_io.readfromnifti(args.maskfile)
        if maskdims[4] > 1:
            raise ValueError(f"{args.maskfile} is not 3D - exiting")
        # check the grid before reshaping, which would otherwise fail first
        if not tide_io.checkspacematch(template_hdr, mask_hdr):
            raise ValueError(
                f"Dimensions of {args.maskfile} do not match the target dimensions - exiting"
            )
        # binarize at 0.5 (threshold before any integer cast so it is honored)
        maskvoxels = np.where(mask_data.reshape(numvoxels) >= 0.5, 1, 0).astype(np.uint16)
    else:
        # a voxel is in the mask if any region exceeds --maskthresh there
        maskvoxels = np.where(
            np.max(templatevoxels, axis=1) > args.maskthresh, 1, 0
        ).astype(np.uint16)
    templatevoxels *= maskvoxels[:, np.newaxis]

    # eliminate any regions emptied by resampling or masking
    if args.removeemptyregions:
        templatevoxels, labels = _removeemptyregions(templatevoxels, labels)
    numregions = templatevoxels.shape[1]

    if args.volumeperregion:
        template_hdr["dim"][4] = numregions
        tide_io.savetonifti(
            templatevoxels.reshape((xsize, ysize, numslices, numregions)),
            template_hdr,
            args.outputtemplatename,
        )
    else:
        # label = 1-based index of the strongest region; 0 where no region is
        # present (this includes everything outside the mask)
        regionpresent = np.max(templatevoxels, axis=1) > 0
        outputvoxels = np.where(regionpresent, np.argmax(templatevoxels, axis=1) + 1, 0)
        template_hdr["dim"][4] = 1
        tide_io.savetonifti(
            outputvoxels.reshape((xsize, ysize, numslices)),
            template_hdr,
            args.outputtemplatename,
        )
    if labels is not None:
        with open(tide_io.niftisplitext(args.outputtemplatename)[0] + "_labels.txt", "w") as f:
            f.write("\n".join(labels))


if __name__ == "__main__":
    # Run the atlastool command line only when executed as a script,
    # not when this module is imported.
    main()
