#!python

"""

Fit VIPRS models to GWAS summary statistics
----------------------------

This is a commandline script that enables users to perform
posterior inference for polygenic risk score models using
variational inference techniques.

The script is designed to be used with the VIPRS package and requires the following inputs:

- Linkage disequilibrium (LD) matrices (in zarr format).
- GWAS summary statistics files
    - The summary statistics files can be in any of the following formats:
        - plink1.9, plink2, cojo, magenpy, fastgwa, ssf, gwas-ssf, gwascatalog, saige, or a `custom` format.
- The output directory where to store the results of the inference.

Usage:

    # Fit a VIPRS model to GWAS summary statistics:
    python -m viprs_fit -l /path/to/ld_matrices/ -s /path/to/sumstats/ --output-dir /path/to/output/

    # Fit a VIPRS model using custom GWAS summary statistics format:
    python -m viprs_fit -l /path/to/ld_matrices/ -s /path/to/sumstats/ --output-dir /path/to/output/ \
                       --sumstats-format custom \
                       --custom-sumstats-mapper rsid=SNP,eff_allele=A1,beta=BETA

"""

import os.path as osp

def check_args(args):
    """
    Check the validity, consistency, and completeness of the commandline arguments.

    :param args: An argparse.Namespace object containing the parsed commandline arguments.

    :raises FileNotFoundError: If no LD matrices, summary statistics files, validation BED files,
    or validation phenotype file are found at the user-specified locations.
    :raises PermissionError: If the output directory or the temporary directory is not writable.
    :raises ValueError: If the model parameters or the hyperparameter search options are
    inconsistent or incomplete.
    """

    from magenpy.utils.system_utils import get_filenames, is_path_writable

    # ----------------------------------------------------------
    # Check that the inputs data files/directories are valid:
    ld_store_files = get_filenames(args.ld_dir, extension='.zgroup')
    if len(ld_store_files) < 1:
        raise FileNotFoundError(f"No valid LD matrix files were found at: {args.ld_dir}")

    sumstats_files = get_filenames(args.sumstats_path)
    if len(sumstats_files) < 1:
        raise FileNotFoundError(f"No valid summary statistics files were found at: {args.sumstats_path}")

    # Check that the output directory is valid / writable:
    if not is_path_writable(args.output_dir):
        raise PermissionError(f"Output directory ({args.output_dir}) is not writable.")

    # Check that the temporary directory is valid / writable:
    if not is_path_writable(args.temp_dir):
        raise PermissionError(f"Temporary directory ({args.temp_dir}) is not writable.")

    # Check the model parameters:
    if args.model == 'VIPRSMix':
        if args.n_components < 1:
            raise ValueError("The number of components for the VIPRSMix model must be at least 1.")

    # Check the arguments for validation/pseudo validation when performing
    # grid search or Bayesian optimization for hyperparameter tuning:

    if args.hyp_search in ('BO', 'GS'):

        has_individual_data = (args.validation_bed is not None and
                               args.validation_pheno is not None)

        # If the grid metric is validation/pseudo-validation, and genotype/phenotype files
        # are provided, then ensure that the files exist:
        if args.grid_metric in ('validation', 'pseudo_validation') and has_individual_data:

            valid_bed_files = get_filenames(args.validation_bed, extension='.bed')

            if len(valid_bed_files) < 1:
                raise FileNotFoundError(f"No BED files were identified at the "
                                        f"specified location: {args.validation_bed}")

            if not osp.isfile(args.validation_pheno):
                raise FileNotFoundError(f"No phenotype file found at {args.validation_pheno}")

        # If we're performing validation with individual-level data:
        if args.grid_metric == 'validation':

            # Check that genotype and phenotype files are provided:
            if not has_individual_data:
                # NOTE: The commandline flag for the genotype data is --validation-bfile.
                raise ValueError("To perform cross-validation, you need to provide BED files and a phenotype file "
                                 "for the validation set (use --validation-bfile and --validation-pheno).")

        elif args.grid_metric == 'pseudo_validation':

            # If we're performing pseudo cross-validation, then ensure that the user
            # provides either individual-level data or summary statistics for the validation set:
            if has_individual_data:
                pass
            elif args.validation_ld_panel is not None and args.validation_sumstats_path is not None:
                # NOTE: The parsed attribute for --validation-sumstats is `validation_sumstats_path`.
                valid_ld_files = get_filenames(args.validation_ld_panel, extension='.zgroup')
                if len(valid_ld_files) < 1:
                    raise FileNotFoundError(f"No valid LD matrix files for the "
                                            f"validation set were found at: {args.validation_ld_panel}")
                valid_sumstats_files = get_filenames(args.validation_sumstats_path)
                if len(valid_sumstats_files) < 1:
                    raise FileNotFoundError(f"No valid summary statistics files for the validation set "
                                            f"were found at: {args.validation_sumstats_path}")
            else:
                raise ValueError("To perform pseudo-validation, you need to provide either individual-level data "
                                 "or summary statistics for the validation set.")

    if args.hyp_search in ('GS', 'BMA'):
        # If the hyperparameter strategy is grid search or Bayesian model averaging,
        # then ensure that we can generate a grid, given the user's selection of inputs:

        if all([args.pi_grid is None, args.pi_steps is None,
                args.sigma_epsilon_grid is None, args.sigma_epsilon_steps is None]):
            raise ValueError("For grid search (GS) or Bayesian model averaging (BMA), you need to "
                             "provide either a grid or the number of steps for the hyperparameter "
                             "search over the hyperparameters of interest.")


def init_data(args, verbose=True):
    """
    Initialize the data loaders and parsers for the GWAS summary statistics.
    This function takes as input an argparse.Namespace object (the output of
    `argparse.ArgumentParser.parse_args`) and returns a list of dictionaries
    containing the GWADataLoaders for the training (and validation) datasets.

    :param args: An argparse.Namespace object containing the parsed commandline arguments.
    :param verbose: Whether to print status messages to the console.
    :return: A list of dictionaries containing the GWADataLoaders for the training (and validation) datasets.
    There is one dictionary per chromosome, or a single dictionary if `args.genomewide` is set.
    Each dictionary contains the following keys:
    - 'train': The GWADataLoader for the training dataset
    - 'valid': The GWADataLoader for the validation dataset (only present if the hyperparameter
    search strategy is 'BO' or 'GS' and validation inputs were provided)
    """

    import numpy as np
    from magenpy.GWADataLoader import GWADataLoader
    from magenpy.parsers.sumstats_parsers import SumstatsParser

    if verbose:
        print('\n{:-^62}\n'.format('  Reading input data  '))

    # Prepare the summary statistics parsers:
    if args.sumstats_format == 'custom':

        ss_parser = SumstatsParser(col_name_converter=args.custom_sumstats_mapper,
                                   sep=args.custom_sumstats_sep)
        ss_format = None
    else:
        ss_format = args.sumstats_format
        ss_parser = None

    if verbose:
        print("> Reading the training dataset...")

    # Construct a GWADataLoader object using LD + summary statistics:
    gdl = GWADataLoader(ld_store_files=args.ld_dir,
                        temp_dir=args.temp_dir,
                        verbose=verbose,
                        threads=args.threads)
    # Read the summary statistics file(s):
    gdl.read_summary_statistics(args.sumstats_path,
                                sumstats_format=ss_format,
                                parser=ss_parser)
    # Harmonize the data:
    gdl.harmonize_data()

    # If overall GWAS sample size is provided, set it here:
    if args.gwas_sample_size is not None:
        for ss in gdl.sumstats_table.values():
            ss.set_sample_size(args.gwas_sample_size)

    if args.genomewide:
        data_loaders = {'All': {'train': gdl}}
    else:
        # If we are not performing inference genome-wide,
        # then split the GWADataLoader object into multiple loaders,
        # one per chromosome.
        data_loaders = {c: {'train': g} for c, g in gdl.split_by_chromosome().items()}

    # ----------------------------------------------------------

    if args.hyp_search in ('BO', 'GS'):

        # Restrict the validation set to the SNPs present in the training set:
        extract_snps = np.concatenate(list(gdl.snps.values()))

        validation_gdl = None

        if args.validation_bed is not None:

            if verbose:
                print("> Reading the validation dataset...")

            # Construct the validation GWADataLoader object using individual-level data:
            validation_gdl = GWADataLoader(
                bed_files=args.validation_bed,
                keep_file=args.validation_keep,
                phenotype_file=args.validation_pheno,
                extract_snps=extract_snps,
                backend=args.backend,
                temp_dir=args.temp_dir,
                threads=args.threads
            )

            if args.grid_metric == 'pseudo_validation':
                validation_gdl.perform_gwas()

        elif args.validation_ld_panel is not None and args.validation_sumstats_path is not None:

            if verbose:
                print("> Reading the validation dataset...")

            # Construct the validation GWADataLoader object using LD + summary statistics:
            validation_gdl = GWADataLoader(ld_store_files=args.validation_ld_panel,
                                           temp_dir=args.temp_dir,
                                           verbose=verbose,
                                           threads=args.threads)

            # Prepare the validation summary statistics parsers. If no format is specified
            # for the validation summary statistics, fall back on the training format:
            valid_ss_format = args.validation_sumstats_format
            if valid_ss_format is None:
                valid_ss_format = args.sumstats_format

            if valid_ss_format == 'custom':
                valid_ss_parser = SumstatsParser(col_name_converter=args.validation_custom_sumstats_mapper,
                                                 sep=args.validation_custom_sumstats_sep)
                valid_ss_format = None
            else:
                valid_ss_parser = None

            # Read the summary statistics file(s):
            validation_gdl.read_summary_statistics(args.validation_sumstats_path,
                                                   sumstats_format=valid_ss_format,
                                                   parser=valid_ss_parser)
            # Harmonize the data:
            validation_gdl.harmonize_data()

            # Filter SNPs:
            validation_gdl.filter_snps(extract_snps)

            # If overall GWAS sample size is provided, set it here:
            if args.validation_gwas_sample_size is not None:
                for ss in validation_gdl.sumstats_table.values():
                    ss.set_sample_size(args.validation_gwas_sample_size)

        # Attach the validation data loader(s), regardless of whether they
        # were constructed from individual-level data or summary statistics:
        if validation_gdl is not None:
            if args.genomewide:
                data_loaders['All']['valid'] = validation_gdl
            else:
                for c, g in validation_gdl.split_by_chromosome().items():
                    # Skip chromosomes that are not represented in the training data:
                    if c in data_loaders:
                        data_loaders[c]['valid'] = g

    return list(data_loaders.values())


def prepare_model(args, verbose=True):
    """
    Build the (un-initialized) model constructor that will later be
    fit to the GWAS summary statistics.

    :param args: An argparse.Namespace object containing the parsed commandline arguments.
    :param verbose: Whether to print status messages to the console.
    :return: A `functools.partial` object wrapping the selected model class
    together with the keyword arguments derived from the commandline options.

    """

    search_strategy = args.hyp_search

    if verbose:
        print('\n{:-^62}\n'.format('  Model details  '))

        print("- Model:", args.model)
        strategy_names = {'GS': 'Grid search', 'BO': 'Bayesian optimization',
                          'BMA': 'Bayesian model averaging', 'EM': 'Expectation maximization'}
        print("- Hyperparameter tuning strategy:", strategy_names[search_strategy])
        if search_strategy in ('BO', 'GS'):
            print("- Model selection criterion:", args.grid_metric)

    from functools import partial
    from viprs.model.gridsearch.HyperparameterGrid import HyperparameterGrid

    from viprs.model.VIPRS import VIPRS
    from viprs.model.VIPRSMix import VIPRSMix
    from viprs.model.gridsearch.VIPRSGridSearch import VIPRSGridSearch
    from viprs.model.gridsearch.VIPRSBMA import VIPRSBMA
    from viprs.model.gridsearch.HyperparameterSearch import BayesOpt

    # Keyword arguments that every model constructor below receives:
    shared_kwargs = {
        'float_precision': args.float_precision,
        'low_memory': not args.use_symmetric_ld,
        'threads': args.threads,
    }

    if search_strategy == 'EM':
        # Plain EM: instantiate the requested model class directly.
        if args.model == 'VIPRS':
            p_model = partial(VIPRS, **shared_kwargs)
        elif args.model == 'VIPRSMix':
            p_model = partial(VIPRSMix, K=args.n_components, **shared_kwargs)

    elif search_strategy in ('BMA', 'GS'):
        # Grid-based strategies share the same hyperparameter grid definition:
        hyp_grid = HyperparameterGrid(sigma_epsilon_grid=args.sigma_epsilon_grid,
                                      sigma_epsilon_steps=args.sigma_epsilon_steps,
                                      pi_grid=args.pi_grid,
                                      pi_steps=args.pi_steps,
                                      h2_est=args.h2_est,
                                      h2_se=args.h2_se)

        grid_model = VIPRSGridSearch if search_strategy == 'GS' else VIPRSBMA

        p_model = partial(grid_model, grid=hyp_grid, **shared_kwargs)

    else:
        # Remaining strategy: Bayesian optimization wrapped around a base VIPRS model.
        p_model = partial(BayesOpt,
                          opt_params=['sigma_epsilon', 'pi'],
                          model=partial(VIPRS, **shared_kwargs),
                          criterion=args.grid_metric)

    return p_model


def fit_model(model, data_dict, args):
    """
    Given a model class (not initialized) and a dictionary containing
    the training and validation data loaders, fit the model to the data.

    NOTE: The data loaders are cleaned up and removed from `data_dict` before
    this function returns, whether or not the model fit succeeds.

    :param model: A partial function object for the model to fit to the GWAS data.
    :param data_dict: A dictionary containing the training (and validation) data loaders,
    under the keys 'train' and (optionally) 'valid'.
    :param args: An argparse.Namespace object containing the parsed commandline arguments.

    :return: A dictionary with the following keys:
    - 'ProfilerMetrics': A dictionary of runtime / memory usage metrics.
    - 'Chromosome': A scalar label for the chromosome(s) that the model was fit to.
    - 'Model parameters': The output of the model's `to_table()` method.
    - 'Hyperparameters': The output of the model's `to_theta_table()` method.
    - 'Validation': The output of the model's `to_validation_table()` method
    (only for the 'GS', 'BMA', and 'BO' strategies).
    """

    import time
    import numpy as np
    from magenpy.utils.system_utils import get_memory_usage

    # Set the random seed:
    np.random.seed(args.seed)

    try:

        # Derive a scalar chromosome label. When the loader spans multiple chromosomes
        # (e.g. genome-wide fits), join them into a single string so that the label can
        # be broadcast as a column value on tables of any length.
        chromosomes = data_dict['train'].chromosomes
        if len(chromosomes) == 1:
            chromosome = chromosomes[0]
        else:
            chromosome = ','.join(str(c) for c in chromosomes)

        result_dict = {
            'ProfilerMetrics': {},
            'Chromosome': chromosome
        }

        for d in data_dict.values():
            d.verbose = False

        # ----------------------------------------------------------

        # Get the memory usage before loading the model:
        start_mem = get_memory_usage()

        # Initialize the model:
        load_start_time = time.time()
        m = model(data_dict['train'])
        load_end_time = time.time()

        result_dict['ProfilerMetrics']['Load_time'] = round(load_end_time - load_start_time, 2)

        # ----------------------------------------------------------

        # Fit the model to data:
        fit_start_time = time.time()
        m = m.fit()
        fit_end_time = time.time()

        # ----------------------------------------------------------

        # Get memory usage after fitting the model:
        end_mem = get_memory_usage()

        # Record the profiler metrics:
        result_dict['ProfilerMetrics']['Fit_time'] = round(fit_end_time - fit_start_time, 2)
        result_dict['ProfilerMetrics']['Memory_usage_MB'] = round(end_mem - start_mem, 2)
        result_dict['ProfilerMetrics']['Total_Iterations'] = m.optim_result.iterations

        # ----------------------------------------------------------
        # Perform validation / model averaging:

        if args.hyp_search in ('GS', 'BMA', 'BO'):

            valid_start_time = time.time()

            if args.hyp_search == 'BMA':
                # If the model is a Bayesian model averaging model, then
                # we need to extract the final model from the BMA object:
                m = m.average_models()
            elif args.hyp_search == 'GS':
                # The validation loader may be absent (e.g. if no validation
                # inputs were provided), in which case None is passed instead:
                m = m.select_best_model(data_dict.get('valid'), criterion=args.grid_metric)

            valid_end_time = time.time()

            # Rounded to 2 decimal places for consistency with the other timing metrics:
            result_dict['ProfilerMetrics']['Validation time'] = round(valid_end_time - valid_start_time, 2)

            result_dict['Validation'] = m.to_validation_table()

        result_dict['Model parameters'] = m.to_table()
        result_dict['Hyperparameters'] = m.to_theta_table()

        return result_dict

    finally:
        # ----------------------------------------------------------
        # Clean up (performed even if the model fit fails, so that the
        # resources held by the data loaders are always released):
        for key in ('train', 'valid'):
            loader = data_dict.pop(key, None)
            if loader is not None:
                loader.cleanup()


def main():
    """
    The commandline entry point for fitting VIPRS models to GWAS summary statistics.

    The function parses the commandline arguments, validates them, reads the input data,
    prepares the model, fits it to each data loader (in parallel, using `--n-jobs` processes),
    and writes the resulting tables to the output directory. The output files share the prefix
    `<output_dir>/<output_prefix><model>_<hyp_search>` and use the following extensions:

    - `.fit.gz`: The model parameters.
    - `.hyp`: The hyperparameters.
    - `.validation`: The validation table (only if all fits produced one).
    - `.time`: The profiler metrics (only if `--output-profiler-metrics` is set).
    """

    import argparse
    import viprs as vp

    print(fr"""
        **********************************************
                    _____
            ___   _____(_)________ ________________
            __ | / /__  / ___  __ \__  ___/__  ___/
            __ |/ / _  /  __  /_/ /_  /    _(__  )
            _____/  /_/   _  .___/ /_/     /____/
                          /_/
        Variational Inference of Polygenic Risk Scores
        Version: {vp.__version__} | Release date: {vp.__release_date__}
        Author: Shadi Zabad, McGill University
        **********************************************
        < Fit VIPRS models to GWAS summary statistics >
    """)

    parser = argparse.ArgumentParser(description="""
        Commandline arguments for fitting VIPRS models to GWAS summary statistics
    """)

    # Required input/output parameters:
    parser.add_argument('-l', '--ld-panel', dest='ld_dir', type=str, required=True,
                        help='The path to the directory where the LD matrices are stored. '
                             'Can be a wildcard of the form ld/chr_*')
    parser.add_argument('-s', '--sumstats', dest='sumstats_path', type=str, required=True,
                        help='The summary statistics directory or file. Can be a '
                             'wildcard of the form sumstats/chr_*')

    # Output files / directories:
    # NOTE: The short flag `-o` is registered because the usage examples rely on it.
    parser.add_argument('-o', '--output-dir', dest='output_dir', type=str, required=True,
                        help='The output directory where to store the inference results.')
    parser.add_argument('--output-file-prefix', dest='output_prefix', type=str, default='',
                        help='A prefix to append to the names of the output files (optional).')
    parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                        help='The temporary directory where to store intermediate files.')

    # Optional input arguments:
    parser.add_argument('--sumstats-format', dest='sumstats_format',
                        type=str.lower, default='plink',
                        choices={'plink1.9', 'plink2', 'plink', 'cojo', 'magenpy',
                                 'fastgwa', 'ssf', 'gwas-ssf', 'gwascatalog', 'saige', 'custom'},
                        help='The format for the summary statistics file(s).')
    parser.add_argument('--custom-sumstats-mapper', dest='custom_sumstats_mapper', type=str,
                        help='A comma-separated string with column name mappings between the custom '
                             'summary statistics format and the standard format expected by magenpy/VIPRS. '
                             'Provide only mappings for column names that are different, in the form of: '
                             '--custom-sumstats-mapper rsid=SNP,eff_allele=A1,beta=BETA')
    parser.add_argument('--custom-sumstats-sep', dest='custom_sumstats_sep', type=str, default='\t',
                        help='The delimiter for the summary statistics file with custom format.')
    parser.add_argument('--gwas-sample-size', dest='gwas_sample_size', type=float,
                        help='The overall sample size for the GWAS study. This must be provided if the '
                             'sample size per-SNP is not in the summary statistics file.')

    # Arguments for the validation set (individual-level data):
    parser.add_argument('--validation-bfile', dest='validation_bed', type=str,
                        help='The BED files containing the genotype data for the validation set. '
                             'You may use a wildcard here (e.g. "data/chr_*.bed")')
    parser.add_argument('--validation-pheno', dest='validation_pheno', type=str,
                        help='A tab-separated file containing the phenotype for the validation set. '
                             'The expected format is: FID IID phenotype (no header)')
    parser.add_argument('--validation-keep', dest='validation_keep', type=str,
                        help='A plink-style keep file to select a subset of individuals for the validation set.')

    # Arguments for the validation set (summary statistics):
    parser.add_argument('--validation-ld-panel', dest='validation_ld_panel', type=str,
                        help='The path to the directory where the LD matrices for the validation set are stored. '
                             'Can be a wildcard of the form ld/chr_*')
    parser.add_argument('--validation-sumstats', dest='validation_sumstats_path', type=str,
                        help='The summary statistics directory or file for the validation set. Can be a '
                             'wildcard of the form sumstats/chr_*')
    parser.add_argument('--validation-sumstats-format', dest='validation_sumstats_format',
                        type=str.lower,
                        choices={'plink1.9', 'plink2', 'plink', 'cojo', 'magenpy',
                                 'fastgwa', 'ssf', 'gwas-ssf', 'gwascatalog', 'saige', 'custom'},
                        help='The format for the summary statistics file(s) for the validation set.')
    # For now, suppress the help message for these arguments:
    parser.add_argument('--validation-custom-sumstats-mapper',
                        dest='validation_custom_sumstats_mapper', type=str,
                        help=argparse.SUPPRESS)
    parser.add_argument('--validation-custom-sumstats-sep',
                        dest='validation_custom_sumstats_sep',
                        type=str, default='\t',
                        help=argparse.SUPPRESS)
    parser.add_argument('--validation-gwas-sample-size',
                        dest='validation_gwas_sample_size',
                        type=float,
                        help=argparse.SUPPRESS)

    # Arguments for the model type / parameters:
    parser.add_argument('-m', '--model', dest='model', type=str, default='VIPRS',
                        help='The type of PRS model to fit to the GWAS data',
                        choices={'VIPRS', 'VIPRSMix'})
    parser.add_argument('--float-precision', dest='float_precision', type=str, default='float32',
                        help='The float precision to use when fitting the model.',
                        choices={'float32', 'float64'})
    parser.add_argument('--use-symmetric-ld', dest='use_symmetric_ld', action='store_true',
                        default=False,
                        help='Use the symmetric form of the LD matrix when fitting the model.')
    parser.add_argument('--n-components', dest='n_components', type=int, default=3,
                        help='The number of non-null Gaussian mixture components to use with the VIPRSMix model '
                             '(i.e. excluding the spike component).')

    # Arguments for Hyperparameter tuning / model initialization:
    parser.add_argument('--h2-est', dest='h2_est', type=float,
                        help='The estimated heritability of the trait. If available, this value can be '
                             'used for parameter initialization or hyperparameter grid search.')
    parser.add_argument('--h2-se', dest='h2_se', type=float,
                        help='The standard error for the heritability estimate for the trait. '
                             'If available, this value can be used for parameter initialization '
                             'or hyperparameter grid search.')

    parser.add_argument('--hyp-search', dest='hyp_search', type=str, default='EM',
                        choices={'EM', 'GS', 'BO', 'BMA'},
                        help='The strategy for tuning the hyperparameters of the model. '
                             'Options are EM (Expectation-Maximization), GS (Grid search), '
                             'BO (Bayesian Optimization), and BMA (Bayesian Model Averaging).')
    parser.add_argument('--grid-metric', dest='grid_metric', type=str, default='validation',
                        help='The metric for selecting best performing model in grid search.',
                        choices={'ELBO', 'validation', 'pseudo_validation'})

    # Grid-related parameters:
    parser.add_argument('--pi-grid', dest='pi_grid', type=str,
                        help='A comma-separated grid values for the hyperparameter pi (see also --pi-steps).')
    parser.add_argument('--pi-steps', dest='pi_steps', type=int,
                        help='The number of steps for the (default) pi grid. This will create an equidistant '
                             'grid between 1/M and (M-1)/M on a log10 scale, where M is the number of SNPs.')

    parser.add_argument('--sigma-epsilon-grid', dest='sigma_epsilon_grid', type=str,
                        help='A comma-separated grid values for the hyperparameter sigma_epsilon '
                             '(see also --sigma-epsilon-steps).')
    parser.add_argument('--sigma-epsilon-steps', dest='sigma_epsilon_steps', type=int,
                        help='The number of steps (unique values) for the sigma_epsilon grid.')

    # Miscellaneous / general parameters:

    parser.add_argument('--genomewide', dest='genomewide', action='store_true', default=False,
                        help='Fit all chromosomes jointly')
    parser.add_argument('--backend', dest='backend', type=str.lower, default='xarray',
                        choices={'xarray', 'plink'},
                        help='The backend software used for computations on the genotype matrix.')
    parser.add_argument('--n-jobs', dest='n_jobs', type=int, default=1,
                        help='The number of processes to launch for the hyperparameter search (default is '
                             '1, but we recommend increasing this depending on system capacity).')
    parser.add_argument('--threads', dest='threads', type=int, default=1,
                        help='The number of threads to use in the E-Step of VIPRS.')
    parser.add_argument('--output-profiler-metrics', dest='output_profiler_metrics', action='store_true',
                        default=False, help='Output the profiler metrics that measure runtime, memory usage, etc.')

    parser.add_argument('--seed', dest='seed', type=int, default=7209,
                        help='The random seed to use for the random number generator.')

    args = parser.parse_args()

    # ----------------------------------------------------------
    # Import required modules:

    import time
    from datetime import timedelta
    import pandas as pd
    import os.path as osp
    from magenpy.utils.system_utils import makedir
    from joblib import Parallel, delayed

    # ----------------------------------------------------------
    print('{:-^62}\n'.format('  Parsed arguments  '))

    # Only report the arguments that were set to non-default values:
    for key, val in vars(args).items():
        if val is not None and val != parser.get_default(key):
            print("--", key, ":", val)

    # ----------------------------------------------------------

    # (1) Check the validity, consistency, and completeness of the commandline arguments:
    check_args(args)

    # Record start time:
    total_start_time = time.time()

    # (2) Read the data:
    data_loaders = init_data(args)

    # (3) Prepare the model:
    model = prepare_model(args)

    # (4) Fit to data:
    print('\n{:-^62}\n'.format('  Model fitting  '))

    fit_results = Parallel(n_jobs=args.n_jobs)(
        delayed(fit_model)(model, dl, args)
        for dl in data_loaders
    )

    # Record end time:
    total_end_time = time.time()

    # (5) Write the results to file:

    if len(fit_results) > 0:
        print('\n{:-^62}\n'.format(''))
        print("\n>>> Writing the inference results to:\n", args.output_dir)

        makedir(osp.abspath(args.output_dir))

        eff_tables = pd.concat([r['Model parameters'] for r in fit_results])
        hyp_tables = pd.concat([r['Hyperparameters'].assign(Chromosome=r['Chromosome']) for r in fit_results])

        # The validation table is only written if every fit produced one:
        if all('Validation' in r for r in fit_results):
            valid_tables = pd.concat([r['Validation'].assign(Chromosome=r['Chromosome']) for r in fit_results])
        else:
            valid_tables = None

        profm_table = pd.concat([pd.DataFrame(r['ProfilerMetrics'], index=[0]).assign(Chromosome=r['Chromosome'])
                                 for r in fit_results])
        profm_table['Total_WallClockTime'] = round(total_end_time - total_start_time, 2)

        output_prefix = osp.join(args.output_dir, args.output_prefix + args.model + '_' + args.hyp_search)

        if len(eff_tables) > 0:
            eff_tables.to_csv(output_prefix + '.fit.gz', sep="\t", index=False)

        if len(hyp_tables) > 0:
            hyp_tables.to_csv(output_prefix + '.hyp', sep="\t", index=False)

        if valid_tables is not None:
            valid_tables.to_csv(output_prefix + '.validation', sep="\t", index=False)

        if args.output_profiler_metrics:
            profm_table.to_csv(output_prefix + '.time', sep="\t", index=False)

    print('>>> Total Runtime:\n', timedelta(seconds=total_end_time - total_start_time))


# Run the commandline interface only when this file is executed as a script:
if __name__ == "__main__":
    main()
