#!/usr/bin/env python
"""
Usage:
    train_model (--true-pos STR) (--false-pos STR) (--type STR) [--out STR] [--njobs INT] [--verbose]

Description:
    Train a model to be saved and used to filter VCFs.

Arguments:
    --true-pos STR          Path to true-positive VCF from VCFeval or comma-separated list of paths
    --false-pos STR         Path to false-positive VCF from VCFeval or comma-separated list of paths
    --type STR              SNP or INDEL

Options:
    -o, --out <STR>                 Outfile name for writing model [default: (type).filter.pickle.dat]
    -n, --njobs <INT>               Number of threads to run in parallel [default: 2]
    -h, --help                      Show this help message and exit.
    -v, --version                   Show version and exit.
    --verbose                       Log output

Examples:
    train_model --true-pos <path/to/tp/vcf(s)> --false-pos <path/to/fp/vcf(s)> --type [SNP, INDEL] --njobs 20
"""

from docopt import docopt
import extremevariantfilter as evf
from multiprocessing import Pool
from contextlib import closing
import pandas as pd
import numpy as np
import pickle
import warnings


def get_options():
    """Parse the command line arguments described in the module docstring.

    The module docstring is handed to docopt, which both validates the
    command line and supplies the declared defaults.

    Returns
    -------
    tp_vcf : str
        path to true positive vcf (or comma-separated list of paths)
    fp_vcf : str
        path to false positive vcf (or comma-separated list of paths)
    poly : str
        type of polymorphism, upper-cased. 'SNP' or 'INDEL'
    njobs : int
        number of threads used for execution
    outname : str
        name for the output model

    Raises
    ------
    ValueError
        If the value given for --njobs cannot be converted to an int.
    """

    args = docopt(__doc__, version='1.0')

    # Read training data
    tp_vcf = args['--true-pos']
    fp_vcf = args['--false-pos']

    # Normalise the case so 'snp' / 'indel' are accepted, then validate.
    # NOTE(review): Check_Type presumably rejects anything other than
    # SNP/INDEL -- its behavior is defined in extremevariantfilter.
    poly = args['--type'].upper()
    evf.Check_Type(poly)

    njobs = int(args['--njobs'])

    # The usage text declares the literal placeholder
    # "(type).filter.pickle.dat" as the default for --out; when the user did
    # not override it, substitute the actual polymorphism type.
    outname = args['--out']
    if outname == "(type).filter.pickle.dat":
        outname = poly + '.filter.pickle.dat'

    return tp_vcf, fp_vcf, poly, njobs, outname


def main():
    """Main function for train_model

    Reads the command line options, builds training tables from every input
    VCF in parallel, fits the classifier returned by ``evf.Build_Model``
    (an XGBClassifier according to the original documentation) and pickles
    the fitted model to the requested output file.
    """

    warnings.filterwarnings('ignore', category=DeprecationWarning)

    tp_vcf, fp_vcf, poly, njobs, outname = get_options()
    all_vcf = evf.Check_VCF_Paths(tp_vcf, fp_vcf)

    # One task per VCF; pool.map blocks until every table has been built,
    # so the workers can be shut down immediately afterwards.
    with closing(Pool(processes=njobs)) as pool:
        results = pool.map(evf.Get_Training_Tables, all_vcf)
        pool.terminate()

    # Each result is unpacked as an (X, Y) pair; stack the per-file pieces
    # into a single training set.
    # NOTE(review): presumably X holds features and Y labels -- confirm
    # against evf.Get_Training_Tables.
    X, Y = zip(*results)
    X_all = np.concatenate(X)
    Y_all = np.concatenate(Y)

    # model[0] is used as the display name, model[1] is the estimator.
    model = evf.Build_Model(poly, njobs)

    # Format the message *before* printing: calling .format on the result of
    # print() fails on Python 3, where print() returns None.
    print("Training {} on {}".format(model[0], all_vcf))

    model[1].fit(X_all, Y_all)
    with open(outname, "wb") as out:
        pickle.dump(model[1], out)


# Run the training workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
