#!/usr/bin/env python3

"""
Evaluate the predictive performance of PRS models
----------------------------

This is a commandline script that can compute various metrics to evaluate the predictive performance of
polygenic risk score (PRS) models. The script can compute metrics for both continuous and binary phenotypes.

The script requires two input files:

    - `--prs-file`: The path to the PRS file. The file should be whitespace-delimited and have a header row
        with the columns FID, IID and PRS, where FID and IID are the family and individual IDs, and PRS is
        the polygenic risk score.
    - `--phenotype-file`: The path to the phenotype file. The file should have the following format
        (no header): FID IID phenotype, where FID and IID are the family and individual IDs, and phenotype
        is the phenotype value. Use `--phenotype-col` to select a different phenotype column.

Usage:

    python -m viprs_evaluate  --prs-file /path/to/prs_file
                              --phenotype-file /path/to/phenotype_file
                              --output-file /path/to/output_file

"""

# Module-level logger for this script. Its handlers, format and level are
# configured inside main() (via setup_logger) once the arguments are parsed.
import logging
logger = logging.getLogger(__name__)


def _require(condition, message):
    """
    Log `message` as an error and raise a `ValueError` if `condition` is false.

    Unlike an `assert` statement, this check is not removed when Python runs
    with `-O`, and the raised exception carries the error message.

    :param condition: The condition that must hold for the script to proceed.
    :param message: The error message to log and attach to the exception.
    :raises ValueError: If `condition` evaluates to False.
    """
    if not condition:
        logger.error(message)
        raise ValueError(message)


def main():
    """
    Entry point for the PRS evaluation command-line script.

    The function performs the following steps:

        1. Parse the command-line arguments.
        2. Set up logging (console + `<output_file>.log`).
        3. Read the phenotype file (and optionally covariates / keep files) into a
           `SampleTable`, then merge it with the PRS file on the `FID` and `IID` columns.
        4. Compute the requested evaluation metrics (or a default set that depends on
           the phenotype likelihood), adding standard errors for metrics with `R2` in their name.
        5. Write the metrics as a one-row, tab-separated table to `<output_file>.eval`.

    :raises ValueError: If no samples remain after reading the phenotype file, after
        applying the covariates / keep filters, or after merging with the PRS file.
    """

    import argparse
    import viprs as vp

    print("\n" + vp.make_ascii_logo(
        desc='< Evaluate Prediction Accuracy of PRS Models >',
        left_padding=10
    ) + "\n", flush=True)

    parser = argparse.ArgumentParser(description="""
        Commandline arguments for evaluating polygenic scores
    """)

    parser.add_argument('--prs-file', dest='prs_file', type=str, required=True,
                        help='The path to the PRS file (expected format: FID IID PRS, tab-separated)')
    parser.add_argument('--phenotype-file', dest='pheno_file', type=str, required=True,
                        help='The path to the phenotype file. '
                             'The expected format is: FID IID phenotype (no header), tab-separated.')
    parser.add_argument('--phenotype-col', dest='pheno_col', type=int, default=2,
                        help='The column index for the phenotype in the phenotype file (0-based index).')
    parser.add_argument('--phenotype-likelihood', dest='pheno_lik', type=str, default='infer',
                        choices={'gaussian', 'binomial', 'infer'},
                        help='The phenotype likelihood ("gaussian" for continuous, "binomial" for case-control). '
                             'If not set, will be inferred automatically based on the phenotype file.')
    parser.add_argument('--keep', dest='keep', type=str,
                        help='A plink-style keep file to select a subset of individuals for the evaluation.')
    parser.add_argument('--output-file', dest='output_file', type=str, required=True,
                        help='The output file where to store the evaluation metrics (with no extension).')
    parser.add_argument('--metrics', dest='metrics', type=str, nargs='+',
                        help='The evaluation metrics to compute (default: all available metrics that are '
                             'relevant for the phenotype). For a full list of supported metrics, '
                             'check the documentation.')
    parser.add_argument('--covariates-file', dest='covariates_file', type=str,
                        help='A file with covariates for the samples included in the analysis. This tab-separated '
                             'file should not have a header and the first two columns should be '
                             'the FID and IID of the samples.')
    parser.add_argument('--log-level', dest='log_level', type=str, default='WARNING',
                        choices={'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'},
                        help='The logging level for the console output.')

    args = parser.parse_args()

    # ----------------------------------------------------------
    import os.path as osp
    import pandas as pd
    from magenpy.utils.system_utils import makedir, setup_logger
    from magenpy import SampleTable
    from viprs.eval import eval_metric_names, eval_incremental_metrics
    from viprs.eval.eval_utils import r2_stats

    # ----------------------------------------------------------
    # Setup the logger:

    # Create the output directory. `dirname` is an empty string when the output
    # prefix is a bare file name (i.e. the current directory), in which case
    # there is nothing to create and we avoid passing an empty path to `makedir`.
    output_dir = osp.dirname(args.output_file)
    if output_dir:
        makedir(output_dir)

    # Clear the log file (opening in write mode truncates it):
    log_file = f"{args.output_file}.log"
    with open(log_file, 'w'):
        pass

    # Set up the module loggers:
    setup_logger(modules=['viprs', 'magenpy'],
                 log_file=log_file,
                 log_level=args.log_level)

    # Set up the logger for the main module. The script's own progress messages
    # are emitted at INFO level, so this logger uses INFO unless the user asked
    # for something more verbose (i.e. DEBUG).
    requested_level = logging.getLevelName(args.log_level)
    main_log_level = args.log_level if requested_level < logging.INFO else 'INFO'

    setup_logger(loggers=[logger],
                 log_file=log_file,
                 log_format='%(message)s',
                 log_level=main_log_level)

    # ----------------------------------------------------------
    logger.info('{:-^100}\n'.format('  Parsed arguments  '))

    # Only report the arguments that were explicitly set to non-default values:
    for key, val in vars(args).items():
        if val is not None and val != parser.get_default(key):
            logger.info(f"-- {key}: {val}")

    # ----------------------------------------------------------
    logger.info('\n{:-^100}\n'.format('  Reading input data  '))

    sample_table = SampleTable(phenotype_likelihood=args.pheno_lik)

    # Read the phenotype file:
    sample_table.read_phenotype_file(args.pheno_file, usecols=[0, 1, args.pheno_col])

    _require(sample_table.n > 0, "No samples found in the phenotype file.")

    # Read the covariates file:
    if args.covariates_file is not None:
        sample_table.read_covariates_file(args.covariates_file)

    if args.keep is not None:
        sample_table.filter_samples(keep_file=args.keep)

    # Make sure that samples remain after reading both:
    _require(sample_table.n > 0, "No samples found after merging the covariates and phenotype files.")

    # The PRS file is parsed as whitespace-delimited with a header row; it must
    # provide the `FID`, `IID` and `PRS` columns used below.
    prs_df = pd.read_csv(args.prs_file, sep=r'\s+')

    # Merge the PRS data with the phenotype data:
    prs_df = sample_table.get_individual_table().merge(prs_df, on=['FID', 'IID'])

    _require(len(prs_df) > 0, "No common samples found in the PRS and phenotype files.")

    # NOTE(review): the metrics below pair `sample_table.phenotype` with
    # `prs_df['PRS']` positionally. This presumably relies on the sample table
    # keeping its row order after filtering and on the PRS file having unique
    # (FID, IID) pairs -- confirm.
    sample_table.filter_samples(keep_samples=prs_df.IID.values)

    # ----------------------------------------------------------
    logger.info('\n{:-^100}\n'.format('  Evaluating PRS performance  '))

    # `--metrics` is parsed with nargs='+', so it is either a list or None;
    # fall back to the default set for the phenotype likelihood if not given.
    if sample_table.phenotype_likelihood == 'binomial':
        metrics = args.metrics or ['AUROC', 'AUPRC', 'Nagelkerke_R2', 'Liability_R2']
    else:
        metrics = args.metrics or ['Pearson_R', 'R2',
                                   'Incremental_R2', 'Partial_Correlation']

    # Loop over the requested metrics and evaluate them, and store result in a dictionary:

    info_dict = {'Sample size': sample_table.n}

    if args.covariates_file is not None:
        covariates = sample_table.get_covariates_table().drop(columns=['FID', 'IID'])
    else:
        covariates = None

    for metric in metrics:

        # If covariates are provided and the metric can be computed
        # while adjusting for covariates, then do so:
        if metric in eval_incremental_metrics and covariates is not None:
            info_dict[metric] = eval_metric_names[metric](sample_table.phenotype,
                                                          prs_df['PRS'].values,
                                                          covariates)

        else:
            info_dict[metric] = eval_metric_names[metric](sample_table.phenotype,
                                                          prs_df['PRS'].values)

        # Compute the standard errors for R-squared metrics:
        if 'R2' in metric:
            info_dict[f'{metric}_err'] = r2_stats(info_dict[metric], sample_table.n)['SE']

    # ----------------------------------------------------------

    logger.info(f"\n>>> Writing the evaluation metrics to:\n {osp.dirname(args.output_file)}")

    # The output directory was already created during logger setup above.
    pd.DataFrame([info_dict]).to_csv(args.output_file + ".eval", sep="\t", index=False)


# Run the evaluation only when this file is executed as a script (not on import).
if __name__ == '__main__':
    main()
