#!/usr/bin/env python
import argparse
import json
import numpy as np
from PBCT import PBCT, PBCTClassifier, load_model
from PBCT.classes import DEFAULTS


def parse_args(argv=None):
    """Build the command-line parser and parse the given arguments.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    arg_parser = argparse.ArgumentParser(
        description=(
            "Fit a PBCT model to data or use a trained model to predict new"
            " results. Input files and options may be provided either with"
            " command-line options or by a json config file (see --config"
            " option)."
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # NOTE: Nested groups would be awesome here, but they don't work correctly.

    # --fit and --predict cannot be combined on the command line. The group is
    # not marked as required, so the mode may be omitted here (for instance
    # when it is supplied by the --config file instead).
    mode_group = arg_parser.add_mutually_exclusive_group()
    mode_group.add_argument(
        '--fit', action='store_true',
        help='Use input data to train a PBCT.')
    mode_group.add_argument(
        '--predict', action='store_true',
        help='Predict interaction between input instances.')

    arg_parser.add_argument(
        '--config',
        # The example uses a JSON boolean for "fit" (a quoted "true" would be
        # a string) and has no trailing comma, so it is valid JSON as shown.
        help=('Load options from json file. File example: {'
              '\n\t"path_model": "/path/to/save/model.json",'
              '\n\t"fit": true,'
              '\n\t"XX": ["/path/to/X1.csv", "/path/to/X2.csv"],'
              '\n\t"Y": "/path/to/Y.csv"'
              #'\n\t"XX_names": ["/path/to/X1_names.txt",  "/path/to/X2_names.txt"],'
              #'\n\t"XX_col_names": ["/path/to/X1_col_names.txt", "/path/to/X2_col_names.txt"],'
              '\n}.'
              ' Multiple dicts in a list will cause this script to run'
              ' multiple times.'))

    arg_parser.add_argument(
        '--XX', nargs='+',
        help=('Paths to .csv files containing rows of numerical attributes'
              ' for each axis\' instance.'))
    arg_parser.add_argument(
        '--XX_names', nargs='+',
        help=('Paths to files containing string identifiers for each instance'
              ' for each axis, being one file for each axis.'))
    arg_parser.add_argument(
        '--XX_col_names', nargs='+',
        help=('Paths to files containing string identifiers for each attribute'
              ' column, being one file for each axis.'))
    arg_parser.add_argument(
        '--Y',
        help=('If fitting the model to data, it represents the path to a .csv'
              ' file containing the known interaction matrix between rows and'
              ' columns data. If predicting, Y is the destination path for the'
              ' predicted values, formatted as an interaction matrix in the'
              ' same way described.'))
    arg_parser.add_argument(
        '--path_model', default=DEFAULTS['path_model'],
        help=('When fitting: path to the location where the model will be '
              'saved. When predicting: the saved model to be used.'))
    # type=int: without it a value typed on the command line would reach the
    # model as a string, unlike the default taken from DEFAULTS.
    arg_parser.add_argument(
        '--max_depth', type=int, default=DEFAULTS['max_depth'],
        help=('Maximum PBCT depth allowed. -1 will disable this stopping '
              'criterion.'))
    arg_parser.add_argument(
        '--min_samples_leaf', type=int, default=DEFAULTS['min_samples_leaf'],
        help=('Minimum number of sample pairs in the training set required'
              ' to arrive at each leaf.'))
    # arg_parser.add_argument(
    #     '--simple_mean', action='store_true',
    #     help=('If provided, the prototype function will always return the '
    #           'mean value over the entire sub interaction matrix of the leaf'
    #           ', not considering possible known instances.'))
    arg_parser.add_argument(
        '--verbose', '-v', action='store_true',
        help='Show more detailed output')

    # TODO:
    # arg_parser.add_argument(
    #     '--visualize', default=DEFAULTS['path_rendering'],
    #     help=('If provided, path to where a visualization of the trained t'
    #           'ree will be saved.'))

    # TODO:
    #arg_parser.add_argument(
    #    '--fromdir',
    #    help=('Read data from directory instead. In such case, filenames must be:'
    #          '\tX1, X2, Y, X1_names, X2_names, X1_col_names and X2_col_names,\n'
    #          'with any file extension. *_names files are optional.'))

    return arg_parser.parse_args(argv)


def _is_enabled(value):
    """Interpret an option value coming from the CLI or a JSON config as a flag.

    JSON configs may spell booleans as strings ("true"/"false"); a plain
    truthiness test would treat the string "false" as enabled.
    """
    if isinstance(value, str):
        return value.strip().lower() in ('1', 'true', 'yes', 'y', 'on')
    return bool(value)


def main(**kwargs):
    """CLI for using PBCTs.

    Command-line utility for training a PBCT or using a trained model to
    predict new interactions. See `parse_args()` or use --help for the
    parameters' description.

    Keyword Args:
        config: Optional path to a JSON file. A dict overrides the other
            options; a list of dicts runs this function once per entry.
        fit / predict: Flags selecting what to do; `fit` takes precedence
            if both are enabled.
        XX: Paths to the comma-separated attribute files.
        XX_col_names: Optional paths to the attribute-name files.
        Y: Interaction matrix path (input when fitting, output when
            predicting).
        path_model, max_depth, min_samples_leaf, verbose: Forwarded to the
            model as shown in the body.

    Raises:
        TypeError: If the config file holds neither a dict nor a list.
        ValueError: If no mode is enabled, or `XX` / `Y` is missing.
    """

    config_path = kwargs.get('config')
    if config_path is not None:
        # If config file is a single dict, load its options and proceed.
        # If it's a list of dicts, run this function for each of them.
        with open(config_path) as config_file:
            config = json.load(config_file)
        if isinstance(config, dict):
            kwargs.update(config)
        elif isinstance(config, list):
            # Ensure we are not caught in the top conditional again.
            kwargs['config'] = None
            for jsonkwargs in config:
                main(**(kwargs | jsonkwargs))
            return
        else:
            raise TypeError(
                'Config file must contain a JSON object or a list of objects,'
                f' got {type(config).__name__}.')

    # Validate before touching any data file so that mistakes are reported
    # immediately instead of after a (possibly long) load.
    fit = _is_enabled(kwargs.get('fit'))
    predict = _is_enabled(kwargs.get('predict'))
    if not (fit or predict):
        raise ValueError(
            'Nothing to do: enable either "fit" or "predict" '
            '(--fit / --predict).')
    if not kwargs.get('XX'):
        raise ValueError('No input data: "XX" (--XX) must list the paths to '
                         'the attribute files.')
    if not kwargs.get('Y'):
        raise ValueError('Missing "Y" (--Y): the interaction matrix path is '
                         'required.')

    print('Loading data...')
    XX = [np.loadtxt(p, delimiter=',') for p in kwargs['XX']]
    # TODO: XX_names (instance identifiers) are accepted but not used yet.
    XX_col_names = None
    if kwargs.get('XX_col_names'):
        # Names are string identifiers: loadtxt's default float dtype would
        # fail on them. ndmin=1 keeps a single-name file as a 1-D array.
        XX_col_names = [np.loadtxt(p, dtype=str, ndmin=1)
                        for p in kwargs['XX_col_names']]

    if fit:
        Y = np.loadtxt(kwargs['Y'], delimiter=',')
        Tree = PBCT(
            min_samples_leaf=kwargs['min_samples_leaf'],
            max_depth=kwargs['max_depth'],
            verbose=kwargs['verbose'],
        )

        print('Training PBCT...')
        Tree.fit(XX, Y, X_names=XX_col_names)  # FIXME: Confusing variable names.
        print('Saving model...')
        Tree.save(kwargs['path_model'])
        print('Done.')

    else:  # predict
        print('Loading model...')
        Tree = load_model(kwargs['path_model'])
        print('Predicting values...')
        predictions = Tree.predict(
            XX,
            # simple_mean=kwargs['simple_mean'],
            verbose=kwargs['verbose'],
        )
        print('Saving...')
        # NOTE(review): fmt='%d' writes every value as an integer, so any
        # fractional prediction would be truncated -- confirm that the
        # predictions are integral.
        np.savetxt(kwargs['Y'], predictions, delimiter=',', fmt='%d')
        print('Done.')


if __name__ == '__main__':
    # Script entry point: forward every parsed command-line option to main()
    # as a keyword argument.
    main(**vars(parse_args()))
