#!/usr/bin/python
# -*- coding: utf-8 -*-

import argparse
import logging
import os

# Set OMP_NUM_THREADS to 1 before numpy is imported (the nibabel/dipy/bonndit
# imports below are expected to pull it in): the threading runtime reads this
# variable when it is initialised, so setting it afterwards would have no
# effect. NOTE(review): presumably this keeps each worker process of the
# multiprocessing-based fitting single-threaded -- confirm.
os.environ["OMP_NUM_THREADS"] = "1"

import nibabel as nib
from math import pi
from dipy.core.gradients import gradient_table
from dipy.io import read_bvals_bvecs

import bonndit as bd
from bonndit.shoredeconv import fa_guided_mask
from bonndit.io import fsl_vectors_to_worldspace, fsl_gtab_to_worldspace
from bonndit.michi import fields, dwmri


def _build_parser():
    """Create the argument parser for the mtdeconv command-line interface.

    Returns:
        argparse.ArgumentParser: Parser with the positional input folder and
        option groups for input file names, flags, SHORE options,
        deconvolution options, multiprocessing and logging.
    """
    parser = argparse.ArgumentParser(
        description='This script computes fiber orientation distribution functions (fODFs) \
        as described in "Versatile, Robust and Efficient Tractography With Constrained Higher \
        Order Tensor fODFs" by Ankele et al. (2017). It is assumed that the input data is saved \
        using FSL.', add_help=False)

    parser.add_argument('indir',
                        help='Folder containing all required input files')

    parser.add_argument('-o', '--outdir',
                        help='Folder in which the output will be saved (default: same as indir)')

    inputfiles = parser.add_argument_group(
        'Custom input filenames',
        'Specify custom names for input files. It is not recommended to '
        'change the defaults.')
    inputfiles.add_argument('-d', '--data', default='data.nii.gz',
                            help='Diffusion weighted data (default: data.nii.gz)')
    inputfiles.add_argument('-e', '--dtivecs', default='dti_V1.nii.gz',
                            help='First eigenvectors of a DTI model (default: dti_V1.nii.gz)')
    inputfiles.add_argument('-a', '--dtifa', default='dti_FA.nii.gz',
                            help='Fractional anisotropy values from a DTI model (default: dti_FA.nii.gz)')
    inputfiles.add_argument('-m', '--brainmask', default='mask.nii.gz',
                            help='Brain mask (default: mask.nii.gz)')
    inputfiles.add_argument('-W', '--wmmask', default='fast_pve_2.nii.gz',
                            help='White matter mask (default: fast_pve_2.nii.gz)')
    inputfiles.add_argument('-G', '--gmmask', default='fast_pve_1.nii.gz',
                            help='Gray matter mask (default: fast_pve_1.nii.gz)')
    inputfiles.add_argument('-F', '--csfmask', default='fast_pve_0.nii.gz',
                            help='Cerebrospinal fluid mask (default: fast_pve_0.nii.gz)')

    flags = parser.add_argument_group('flags (optional)', '')
    flags.add_argument("-h", "--help", action="help", help="Show this help message and exit")
    flags.add_argument('-v', '--verbose', action='store_true',
                       help='Activate progress bars and console logging')
    flags.add_argument('-R', '--responseonly', action='store_true',
                       help='Calculate and save only the response functions')
    flags.add_argument('-M', '--tissuemasks', action='store_true',
                       help='Output the DTI improved tissue masks (csf/gm/wm)')

    shoreopts = parser.add_argument_group('shore options (optional)', 'Optional arguments for the computation of \
    the shore response functions')
    shoreopts.add_argument('-k', '--kernel', choices=["rank1", "delta"],
                           default="rank1", type=str,
                           help='Kernel type (default: rank1)')
    shoreopts.add_argument('-r', '--order', default=4, type=int,
                           help='Order of the shore basis (default: 4)')
    shoreopts.add_argument('-z', '--zeta', default=700, type=float,
                           help='Radial scaling factor (default: 700)')
    shoreopts.add_argument('-t', '--tau', default=1 / (4 * pi ** 2),
                           type=float,
                           help='q-scaling (default: 1 / (4 * math.pi ** 2))')
    shoreopts.add_argument('-f', '--fawm', default=0.7, type=float,
                           help='White matter fractional anisotropy threshold (default: 0.7)')

    deconvopts = parser.add_argument_group('deconvolution options (optional)', '')
    deconvopts.add_argument('-C', '--constraint', choices=['hpsd', 'nonneg', 'none'], default='hpsd',
                            help='Constraint for the fODFs (default: hpsd)')

    mp_opts = parser.add_argument_group('multiprocessing (optional)', 'Configure the multiprocessing behaviour \
    (only supported for Python 3)')
    mp_opts.add_argument('-w', '--workers', default=None, type=int,
                         help='Number of cpus (default: all available cpus)')

    log_opts = parser.add_argument_group('logging (optional)', 'Configure the logging behaviour')
    log_opts.add_argument('-L', '--loglevel', choices=['INFO', 'WARNING', 'ERROR'],
                          default='INFO',
                          help='Specify the logging level for the console')

    return parser


def _setup_logging(outdir, loglevel, verbose):
    """Log to ``<outdir>/mtdeconv.log`` and, if *verbose*, also to stderr.

    Args:
        outdir: Existing folder that receives the log file.
        loglevel: One of 'INFO', 'WARNING', 'ERROR' (validated by argparse).
        verbose: When true, also attach a console handler to the root logger.
    """
    level = {'INFO': logging.INFO,
             'WARNING': logging.WARNING,
             'ERROR': logging.ERROR}[loglevel]

    # filemode='w': the log file is overwritten on every run.
    logging.basicConfig(filename=os.path.join(outdir, 'mtdeconv.log'),
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%y-%m-%d %H:%M',
                        level=level,
                        filemode='w')

    if verbose:
        # StreamHandler writes to sys.stderr by default; use a shorter format
        # for the console than for the file.
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(
            logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s'))
        logging.getLogger('').addHandler(console)


def _check_inputs_exist(indir, filenames):
    """Ensure every name in *filenames* is a regular file inside *indir*.

    Raises:
        FileNotFoundError: for the first missing file (the message is also
            logged at ERROR level before raising).
    """
    for name in filenames:
        filepath = os.path.join(indir, name)
        if not os.path.isfile(filepath):
            msg = 'No such file or directory: "{}"'.format(filepath)
            logging.error(msg)
            raise FileNotFoundError(msg)


def _strip_nifti_extension(path):
    """Remove a trailing '.gz' and then a trailing '.nii' from *path*.

    ``str.rstrip`` strips any of the given *characters*, which mangles names
    such as 'brain.nii.gz' (-> 'bra'); this removes only the exact suffixes
    ('brain.nii.gz' -> 'brain').
    """
    for suffix in ('.gz', '.nii'):
        if path.endswith(suffix):
            path = path[:-len(suffix)]
    return path


def _load_meta(data_path):
    """Return the meta object from ``dwmri.load`` for the diffusion data.

    The meta object is passed to ``fields.save_*`` when writing the NRRD
    outputs. The '.nii.gz' variant of *data_path* is tried first, then
    '.nii'; a FileNotFoundError from the second attempt propagates.
    """
    base = _strip_nifti_extension(data_path)
    try:
        _, _, meta = dwmri.load(base + '.nii.gz')
    except FileNotFoundError:
        _, _, meta = dwmri.load(base + '.nii')
    return meta


def _get_response(args, outdir, gtab, data, dti_vecs, wm_mask, gm_mask,
                  csf_mask):
    """Load ``<outdir>/response.npz`` if present, else estimate and save it.

    With ``--responseonly`` the response is always re-estimated, overwriting
    any existing file.

    Returns:
        The multi-tissue SHORE response (``bd.ShoreMultiTissueResponse``),
        either loaded from disk or freshly fitted.
    """
    response_path = os.path.join(outdir, "response.npz")
    if not args.responseonly and os.path.exists(response_path):
        fit = bd.ShoreMultiTissueResponse.load(response_path)
        logging.info('Existing response functions loaded.')
        return fit

    model = bd.ShoreMultiTissueResponseEstimator(gtab, args.order,
                                                 args.zeta, args.tau)
    fit = model.fit(data, dti_vecs, wm_mask, gm_mask, csf_mask,
                    verbose=args.verbose, cpus=args.workers)
    fit.save(response_path)
    if args.responseonly:
        logging.info('Response functions recalculated and saved.')
    else:
        logging.info('Response functions estimated and saved.')
    return fit


def main():
    """Command-line entry point of mtdeconv.

    Estimates multi-tissue SHORE response functions and, unless
    ``--responseonly`` is given, deconvolves the diffusion data into fODFs
    and WM/GM/CSF volume fractions. Inputs are read from ``indir``; outputs
    (``mtdeconv.log``, ``response.npz``, ``fodf.nrrd``,
    ``{wm,gm,csf}volume.nrrd`` and optionally the refined tissue masks) are
    written to ``--outdir``, which defaults to ``indir``.

    Raises:
        FileNotFoundError: if a required input file is missing from ``indir``.
    """
    # Local import: only needed to access the mask data array below.
    import numpy as np

    args = _build_parser().parse_args()

    indir = args.indir
    # -o/--outdir is optional. Every output path below must use ``outdir``,
    # never ``args.outdir``, which is None when the option is omitted.
    outdir = args.outdir or indir
    os.makedirs(outdir, exist_ok=True)

    _setup_logging(outdir, args.loglevel, args.verbose)

    logging.info('mtdeconv has been called with:')
    param_string = 'Order: {} Zeta: {}, Tau: {}, FAWM: {}, Constraint: {}, ' \
                   'Kernel {}'
    logging.info(param_string.format(args.order, args.zeta, args.tau,
                                     args.fawm, args.constraint, args.kernel))

    _check_inputs_exist(indir, [args.brainmask, args.data, args.dtifa,
                                args.csfmask, args.gmmask, args.wmmask,
                                args.dtivecs, 'bvals', 'bvecs'])

    # Load FA, brain mask, tissue segmentations, first eigenvectors of a
    # precalculated diffusion tensor and the diffusion weighted data.
    dti_fa = bd.load(os.path.join(indir, args.dtifa))
    dti_mask = bd.load(os.path.join(indir, args.brainmask))
    csf_mask = bd.load(os.path.join(indir, args.csfmask))
    gm_mask = bd.load(os.path.join(indir, args.gmmask))
    wm_mask = bd.load(os.path.join(indir, args.wmmask))
    dti_vecs = bd.load(os.path.join(indir, args.dtivecs))
    data = bd.load(os.path.join(indir, args.data))

    bvals, bvecs = read_bvals_bvecs(os.path.join(indir, "bvals"),
                                    os.path.join(indir, "bvecs"))
    gtab = gradient_table(bvals, bvecs)

    logging.info('Input loaded.')

    # Refine the tissue masks with FA: --fawm is the lower FA threshold for
    # white matter; GM and CSF use a fixed upper FA threshold of 0.2.
    wm_mask = fa_guided_mask(wm_mask, dti_fa, dti_mask, tissue_threshold=0.95,
                             fa_lower_thresh=args.fawm)
    gm_mask = fa_guided_mask(gm_mask, dti_fa, dti_mask, tissue_threshold=0.95,
                             fa_upper_thresh=0.2)
    csf_mask = fa_guided_mask(csf_mask, dti_fa, dti_mask,
                              tissue_threshold=0.95,
                              fa_upper_thresh=0.2)

    logging.info('Fractional anisotropy based tissue masks created.')

    # Flip sign of x-coordinate if affine determinant is positive and rotate to worldspace
    gtab = fsl_gtab_to_worldspace(gtab, data.affine)
    dti_vecs = fsl_vectors_to_worldspace(dti_vecs)
    logging.info('Rotation to worldspace finished')

    meta = _load_meta(os.path.join(indir, args.data))

    if args.tissuemasks:
        nib.save(wm_mask, os.path.join(outdir, 'wm_mask.nii.gz'))
        nib.save(gm_mask, os.path.join(outdir, 'gm_mask.nii.gz'))
        nib.save(csf_mask, os.path.join(outdir, 'csf_mask.nii.gz'))

    fit = _get_response(args, outdir, gtab, data, dti_vecs,
                        wm_mask, gm_mask, csf_mask)

    if not args.responseonly:
        out, wmout, gmout, csfout = fit.fodf(data, pos=args.constraint,
                                             mask=dti_mask, kernel=args.kernel,
                                             verbose=args.verbose,
                                             cpus=args.workers)
        logging.info('Signal deconvolved with multiple tissue response functions.')

        # nibabel's get_data() is deprecated (an error in nibabel >= 5);
        # np.asanyarray(img.dataobj) is the documented equivalent that keeps
        # the stored dtype.
        fields.save_tensor(os.path.join(outdir, "fodf.nrrd"), out,
                           mask=np.asanyarray(dti_mask.dataobj), meta=meta)
        logging.info('fODFs saved.')

        for tissue, volume in (('wm', wmout), ('gm', gmout), ('csf', csfout)):
            fields.save_scalar(os.path.join(outdir, tissue + "volume.nrrd"),
                               volume, meta)
        logging.info('Volume fractions saved.')

    logging.info('Success!')


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not when
    # this module is imported.
    main()
