#!python

"""
Evaluate the predictive performance of PRS models
----------------------------

This is a commandline script that can compute various metrics to evaluate the predictive performance of
polygenic risk score (PRS) models. The script can compute metrics for both continuous and binary phenotypes.

The script requires two input files:

    - `--prs-file`: The path to the PRS file. The file should be tab-separated, with a header line naming
        the columns FID, IID and PRS, where FID and IID are the family and individual IDs, and PRS is the
        polygenic risk score.
    - `--phenotype-file`: The path to the phenotype file. The file should be tab-separated with no header, in
        the following format: FID IID phenotype, where FID and IID are the family and individual IDs, and
        phenotype is the phenotype value.

Usage:

    python -m viprs_evaluate  --prs-file /path/to/prs_file
                              --phenotype-file /path/to/phenotype_file
                              --output-file /path/to/output_file

"""

import os.path as osp
import pandas as pd
import argparse

import viprs as vp
from magenpy.utils.system_utils import makedir
from magenpy import SampleTable
from viprs.eval import eval_metric_names, eval_incremental_metrics
from viprs.eval.continuous_metrics import r2_stats


# Startup banner: ASCII-art logo plus the version and release date exposed by
# the imported `viprs` package. Kept as a raw f-string so that the backslashes
# in the logo are printed literally.
_banner = fr"""
     **************************************************
                    _____
            ___   _____(_)________ ________________
            __ | / /__  / ___  __ \__  ___/__  ___/
            __ |/ / _  /  __  /_/ /_  /    _(__  )
            _____/  /_/   _  .___/ /_/     /____/
                          /_/
        Variational Inference of Polygenic Risk Scores
        Version: {vp.__version__} | Release date: {vp.__release_date__}
        Author: Shadi Zabad, McGill University
     **************************************************
        < Evaluate the Prediction Accuracy of PRS >
"""
print(_banner)


# Command-line interface of the evaluation script. The parsed options are
# stored in the module-level `args` namespace used by the rest of the script.
parser = argparse.ArgumentParser(description="""
    Commandline arguments for evaluating polygenic score estimates
""")

# Input files:
parser.add_argument('--prs-file', dest='prs_file', type=str, required=True,
                    help='The path to the PRS file (expected format: FID IID PRS, tab-separated)')
parser.add_argument('--phenotype-file', dest='pheno_file', type=str, required=True,
                    help='The path to the phenotype file. '
                         'The expected format is: FID IID phenotype (no header), tab-separated.')
parser.add_argument('--phenotype-col', dest='pheno_col', type=int, default=2,
                    help='The column index for the phenotype in the phenotype file (0-based index).')
parser.add_argument('--phenotype-likelihood', dest='pheno_lik', type=str, default='infer',
                    # A tuple (rather than a set) keeps the order in which the choices are
                    # listed in the usage/help/error messages stable from run to run.
                    choices=('gaussian', 'binomial', 'infer'),
                    help='The phenotype likelihood ("gaussian" for continuous, "binomial" for case-control). '
                         'If not set, will be inferred automatically based on the phenotype file.')

# Output and evaluation options:
parser.add_argument('--output-file', dest='output_file', type=str, required=True,
                    help='The output file where to store the evaluation metrics (with no extension).')
parser.add_argument('--metrics', dest='metrics', type=str, nargs='+',
                    help='The evaluation metrics to compute (default: all available metrics that are '
                         'relevant for the phenotype). For a full list of supported metrics, '
                         'check the documentation.')
parser.add_argument('--covariates-file', dest='covariates_file', type=str,
                    help='A file with covariates for the samples included in the analysis. This tab-separated '
                         'file should not have a header and the first two columns should be '
                         'the FID and IID of the samples.')

args = parser.parse_args()

# ----------------------------------------------------------
# Echo back the options that were explicitly set to a non-default value.
print('{:-^62}\n'.format('  Parsed arguments  '))

for opt_name, opt_value in vars(args).items():
    # Options that were not provided are skipped...
    if opt_value is None:
        continue
    # ... and so are options that were left at their default value.
    if opt_value != parser.get_default(opt_name):
        print("--", opt_name, ":", opt_value)

# ----------------------------------------------------------
print('\n{:-^62}\n'.format('-'))

print('> Reading input data...')

sample_table = SampleTable(phenotype_likelihood=args.pheno_lik)

# Read the phenotype file (FID, IID and the requested phenotype column):
sample_table.read_phenotype_file(args.pheno_file, usecols=[0, 1, args.pheno_col])

# NOTE: Input problems are reported with explicit exceptions rather than `assert`
# statements, because assertions are stripped when Python runs with `-O`.
if sample_table.n <= 0:
    raise ValueError("No samples found in the phenotype file.")

# Read the covariates file:
if args.covariates_file is not None:
    sample_table.read_covariates_file(args.covariates_file)

    # Make sure that samples remain after reading both:
    if sample_table.n <= 0:
        raise ValueError("No samples found after merging the covariates and phenotype files.")

prs_df = pd.read_csv(args.prs_file, sep='\t')

# The rest of the script accesses these columns by name, so fail early with a clear message:
missing_cols = [col for col in ('FID', 'IID', 'PRS') if col not in prs_df.columns]
if missing_cols:
    raise ValueError("The PRS file is missing the required column(s): " + ", ".join(missing_cols) +
                     ". Expected a tab-separated file with a header containing FID, IID and PRS.")

# Find the samples that are present in both the PRS file and the phenotype data:
common_samples = prs_df.merge(sample_table.get_individual_table(), on=['FID', 'IID'])

if len(common_samples) < 1:
    raise ValueError("No common samples found in the PRS and phenotype files.")

sample_table.filter_samples(keep_samples=common_samples.IID.values)

# Re-order the PRS rows to follow the (filtered) sample table, so that the PRS vector lines up
# element-by-element with `sample_table.phenotype` regardless of the row order in the PRS file.
# A left merge preserves the order of the left keys.
# NOTE(review): assumes `sample_table.phenotype` (and the covariates) follow the row order of
# `get_individual_table()` -- confirm against magenpy.
prs_df = sample_table.get_individual_table().merge(prs_df, on=['FID', 'IID'],
                                                   how='left', indicator='_prs_match')

# `filter_samples` is given IIDs only, while the merge uses (FID, IID) pairs. Duplicated or
# inconsistent IDs would therefore leave the two tables with different / unmatched rows.
if len(prs_df) != sample_table.n or (prs_df['_prs_match'] != 'both').any():
    raise ValueError("Could not match the samples in the PRS file one-to-one with the samples in the "
                     "phenotype file. Check for duplicated or inconsistent FID/IID pairs.")

prs_df = prs_df.drop(columns='_prs_match')

# ----------------------------------------------------------

print("> Evaluating PRS model performance...")

# Use the metrics requested by the user, or fall back on a default set
# that depends on the phenotype likelihood:
if sample_table.phenotype_likelihood == 'binomial':
    metrics = args.metrics or ['AUROC', 'AUPRC', 'Nagelkerke_R2', 'Liability_R2']
else:
    metrics = args.metrics or ['Pearson_R', 'R2', 'Incremental_R2', 'Partial_Correlation']

if isinstance(metrics, str):
    metrics = [metrics]

# Validate the metric names up front, so that a typo is reported with a clear message
# (and before any metric is computed) instead of surfacing as a bare KeyError mid-loop:
unknown_metrics = [m for m in metrics if m not in eval_metric_names]
if unknown_metrics:
    raise ValueError("Unsupported evaluation metric(s): " + ", ".join(unknown_metrics) +
                     ". Supported metrics are: " + ", ".join(map(str, eval_metric_names)))

# Loop over the requested metrics and evaluate them, and store result in a dictionary:

info_dict = {'Sample size': sample_table.n}

# The PRS vector is the same for every metric, so extract it once:
prs_values = prs_df['PRS'].values

for metric in metrics:

    metric_func = eval_metric_names[metric]

    if metric in eval_incremental_metrics:
        # These metrics are computed while adjusting for covariates,
        # so they are skipped when no covariates file was provided:
        if args.covariates_file is None:
            print(f"> Warning: Skipping {metric} because covariates are required for this metric.")
            continue

        info_dict[metric] = metric_func(sample_table.phenotype,
                                        prs_values,
                                        sample_table.get_covariates())
    else:
        info_dict[metric] = metric_func(sample_table.phenotype,
                                        prs_values)

    # Compute the standard errors for R-squared metrics:
    if 'R2' in metric:
        info_dict[f'{metric}_err'] = r2_stats(info_dict[metric], sample_table.n)['SE']

# ----------------------------------------------------------

# Directory component of the output prefix. This is an empty string when the
# prefix is a bare file name (e.g. `--output-file results`).
output_dir = osp.dirname(args.output_file)

print("\n>>> Writing the evaluation results to:\n", output_dir or '.')

# Only create the directory when there is one to create: for a bare file name the
# results go to the current working directory, and asking to create '' would fail
# (os.makedirs('') raises FileNotFoundError).
if output_dir:
    makedir(output_dir)

# Write the metrics as a single-row, tab-separated table to `<output_file>.eval`:
pd.DataFrame([info_dict]).to_csv(args.output_file + ".eval", sep="\t", index=False)
