#!/usr/bin/env python2
# -*- coding: utf-8 -*-

"""Run an fMRI analysis using the JDE-VEM framework


All arguments (even positional ones) are optional and default values are defined
in the beginning of the script"""

import os
import argparse
import sys
import json
import time
import datetime

import numpy as np

import pyhrf

from pyhrf.core import FmriData
from pyhrf.ui.vb_jde_analyser import JDEVEMAnalyser
from pyhrf.ui.treatment import FMRITreatment
from pyhrf.tools import format_duration


def load_contrasts_definitions(contrasts_file):
    """Read the contrasts definitions stored in a json file.

    The file is expected to hold a mapping whose values define each contrast
    as a linear combination of conditions.

    Parameters
    ----------
    contrasts_file : str
        the path to the json file (may be None when no file was given)

    Returns
    -------
    compute_contrasts : bool
        True when the file was read and its content is not empty, which tells
        the JDEVEMAnalyser to compute the contrasts
    contrasts_def : dict
        each key is a contrast and each corresponding value is the contrast
        definition; None when the file could not be opened

    """

    try:
        with open(contrasts_file) as json_stream:
            definitions = json.load(json_stream)
    except (IOError, TypeError):
        # Unreadable path (IOError) or no path at all (TypeError on None):
        # the analysis simply runs without contrasts.
        return False, None

    # An empty definition is returned as-is but disables the computation.
    return bool(definitions), definitions


def _default_dt(tr):
    """Return the default HRF time resolution for the repetition time `tr`.

    The repetition time is divided by its integer part. For a repetition time
    below one second the integer part is zero, so the divisor is clamped to 1
    (the time resolution is then the repetition time itself) instead of
    dividing by zero.
    """

    return float(tr) / max(float(np.trunc(tr)), 1.0)


def _save_processing_config(args):
    """Dump the arguments to a timestamped json file in the output directory.

    All paths are made absolute in the saved configuration. The work is done
    on a copy so that the caller's `args` namespace is left untouched.
    """

    config_save = dict(vars(args))

    # Let's canonicalize all paths
    config_save["bold_data_file"] = [
        os.path.abspath(bold_file) for bold_file in config_save["bold_data_file"]
    ]
    config_save["parcels_file"] = os.path.abspath(config_save["parcels_file"])
    config_save["onsets_file"] = os.path.abspath(config_save["onsets_file"])

    # The contrasts file is optional and may be None
    if config_save["def_contrasts_file"]:
        config_save["def_contrasts_file"] = os.path.abspath(config_save["def_contrasts_file"])

    config_save["output_dir"] = os.path.abspath(config_save["output_dir"])
    config_save_filename = "{}_processing.json".format(datetime.datetime.today()).replace(" ", "_")
    config_save_path = os.path.join(config_save["output_dir"], config_save_filename)

    with open(config_save_path, 'w') as json_file:
        json.dump(config_save, json_file, sort_keys=True, indent=4)


def main(args):
    """Run the JDE-VEM analysis described by the parsed command-line arguments.

    Parameters
    ----------
    args : argparse.Namespace
        the parsed arguments; `args.dt` is filled in with a value derived from
        `args.tr` when it was not provided

    Notes
    -----
    The script exits with status 1 when the log level cannot be set or when
    the output directory cannot be created.

    """

    if not args.dt:
        args.dt = _default_dt(args.tr)

    # The log level is first tried as given (e.g. "WARNING"); if the logger
    # rejects it, it is retried as an integer (e.g. "30").
    try:
        pyhrf.logger.setLevel(args.log_level)
    except ValueError:
        try:
            pyhrf.logger.setLevel(int(args.log_level))
        except ValueError:
            print("Can't set log level to {}".format(args.log_level))
            sys.exit(1)

    start_time = time.time()

    if not os.path.isdir(args.output_dir):
        try:
            os.makedirs(args.output_dir)
        except OSError as e:
            print("Output directory could not be created.\nError was: {}".format(e.strerror))
            sys.exit(1)

    bold_data = FmriData.from_vol_files(
        mask_file=args.parcels_file, paradigm_csv_file=args.onsets_file,
        bold_files=args.bold_data_file, tr=args.tr
    )

    compute_contrasts, contrasts_def = load_contrasts_definitions(args.def_contrasts_file)

    jde_vem_analyser = JDEVEMAnalyser(
        hrfDuration=args.hrf_duration, sigmaH=args.sigma_h, fast=True,
        computeContrast=compute_contrasts, nbClasses=2, PLOT=False,
        nItMax=args.nb_iter_max, nItMin=args.nb_iter_min, scale=False,
        beta=args.beta, estimateSigmaH=True, estimateHRF=args.estimate_hrf,
        TrueHrfFlag=False, HrfFilename='hrf.nii', estimateDrifts=True,
        hyper_prior_sigma_H=args.hrf_hyperprior, dt=args.dt, estimateBeta=True,
        contrasts=contrasts_def, simulation=False, estimateLabels=True,
        LabelsFilename=None, MFapprox=False, estimateMixtParam=True,
        constrained=False, InitVar=0.5, InitMean=2.0, MiniVemFlag=False,
        NbItMiniVem=5, zero_constraint=args.zero_constraint, drifts_type=args.drifts_type
    )

    processing_jde_vem = FMRITreatment(
        fmri_data=bold_data, analyser=jde_vem_analyser,
        output_dir=args.output_dir, make_outputs=True
    )

    if not args.parallel:
        processing_jde_vem.run()
    else:
        processing_jde_vem.run(parallel="local")

    if args.save_processing_config:
        _save_processing_config(args)

    print("\nTotal computation took: {} seconds".format(format_duration(time.time() - start_time)))

def build_parser():
    """Build the command-line parser of the script.

    Boolean options come in pairs (``--x`` / ``--no-x``) that write to the
    same destination, so the last one given on the command line wins.

    Returns
    -------
    parser : argparse.ArgumentParser
        the parser whose resulting namespace is consumed by `main`

    """

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("tr", type=float, help="Repetition time of the fMRI data")
    parser.add_argument("parcels_file", help="Nifti label file which defines the parcels")
    parser.add_argument("onsets_file", help="CSV onsets file")
    parser.add_argument("bold_data_file", nargs="*",
                        help="list of 4D-Nifti file(s) of BOLD data (one per run)")
    parser.add_argument("-c", "--def_contrasts_file",
                        help="JSON file defining the contrasts")
    parser.add_argument("-d", "--dt", type=float,
                        help=("time resolution of the HRF (if not defined here or"
                              " in the script, it is automatically computed)"))
    parser.add_argument("-l", "--hrf-duration", type=float, default=25.,
                        help="time lenght of the HRF (in seconds, default: %(default)s)")
    parser.add_argument("-o", "--output", dest="output_dir", default=os.getcwd(),
                        help=("output directory (created if needed,"
                              " default: current folder: %(default)s)"))
    parser.add_argument("--nb-iter-min", type=int, default=5,
                        help=("minimum number of iterations of the VEM algorithm"
                              " (default: %(default)s)"))
    parser.add_argument("--nb-iter-max", type=int, default=100,
                        help=("maximum number of iterations of the VEM algorithm"
                              " (default: %(default)s)"))
    parser.add_argument("--beta", type=float, default=1.0,
                        help=("intrinsic correlation parameter on latent variables"
                              " (default: %(default)s)"))
    parser.add_argument("--hrf-hyperprior", type=int, default=1000,
                        help="(default: %(default)s)")
    parser.add_argument("--sigma-h", type=float, default=0.1,
                        help="Prior variance on the HRF (default: %(default)s)")
    parser.add_argument("--estimate-hrf", action="store_true", default=True,
                        help="(default: %(default)s)")
    parser.add_argument("--no-estimate-hrf", action="store_false", dest="estimate_hrf",
                        help="explicitly disable HRFs estimation")
    parser.add_argument("--zero-constraint", action="store_true", default=True,
                        help="zero constraint on the boundaries of the HRF (default: %(default)s)")
    parser.add_argument("--no-zero-constraint", action="store_false", dest="zero_constraint",
                        help="explicitly disable zero constraint (default enabled)")
    parser.add_argument("--drifts-type", type=str, default="poly",
                        help="set the drift type, can be 'poly' or 'cos' (default: %(default)s)")
    parser.add_argument("-v", "--log-level", default="WARNING",
                        choices=("DEBUG", "10", "INFO", "20", "WARNING", "30",
                                 "ERROR", "40", "CRITICAL", "50"),
                        help="(default: %(default)s)")
    parser.add_argument("-p", "--parallel", action="store_true", default=True,
                        help="(default: %(default)s)")
    parser.add_argument("--no-parallel", action="store_false", dest="parallel",
                        help="explicitly disable parallel computation")
    # Both flags must share the destination read by main(); without the
    # explicit dest, --save-processing would set an unrelated attribute.
    parser.add_argument("--save-processing", action="store_true", default=True,
                        dest="save_processing_config",
                        help="(default: %(default)s)")
    parser.add_argument("--no-save-processing", action="store_false",
                        dest="save_processing_config",
                        help="explicitly disable saving of the processing configuration")

    return parser


if __name__ == "__main__":

    parser_args = build_parser().parse_args()
    main(parser_args)
