#!python

"""
Compute Polygenic Scores for Test Samples
----------------------------

This is a commandline script that computes polygenic scores for test samples
given effect size estimates from VIPRS. The script can work with effect sizes from
other software, as long as they're formatted in the same way as VIPRS `.fit` files.

Usage:

    python -m viprs_score -f <fit_files> --bfile <bed_files> --output-file <output_file>

    - `fit_files` is the path to the file(s) with the output parameter estimates from VIPRS.
    - `bed_files` is the BED files containing the genotype data.
    - `output_file` is the output file where to store the polygenic scores (with no extension).

"""

import os.path as osp
import argparse
import viprs as vp
from magenpy.utils.system_utils import makedir, get_filenames
from magenpy.GWADataLoader import GWADataLoader
from viprs.model.BayesPRSModel import BayesPRSModel

# Startup banner shown on every invocation. The version and release date are
# interpolated from the installed `viprs` package; the string is raw so that
# the backslashes in the ASCII art are printed literally.
banner = fr"""
        **********************************************
                    _____
            ___   _____(_)________ ________________
            __ | / /__  / ___  __ \__  ___/__  ___/
            __ |/ / _  /  __  /_/ /_  /    _(__  )
            _____/  /_/   _  .___/ /_/     /____/
                          /_/
        Variational Inference of Polygenic Risk Scores
        Version: {vp.__version__} | Release date: {vp.__release_date__}
        Author: Shadi Zabad, McGill University
        **********************************************
        < Compute Polygenic Scores for Test Samples >
"""
print(banner)

# Command-line interface definition. All options are parsed from `sys.argv`
# at the bottom of this section and exposed through the `args` namespace.
parser = argparse.ArgumentParser(description="""
    Commandline arguments for computing polygenic scores
""")

# --- Required inputs / outputs ---
parser.add_argument('-f', '--fit-files', dest='fit_files', type=str, required=True,
                    help='The path to the file(s) with the output parameter estimates from VIPRS. '
                         'You may use a wildcard here if fit files are stored per-chromosome (e.g. "prs/chr_*.fit")')
# `--bed-files` is accepted as an alias of `--bfile` because the usage text of
# this script has advertised that spelling; both store into `args.bed_files`.
parser.add_argument('--bfile', '--bed-files', dest='bed_files', type=str, required=True,
                    help='The BED files containing the genotype data. '
                         'You may use a wildcard here (e.g. "data/chr_*.bed")')
parser.add_argument('--output-file', dest='output_file', type=str, required=True,
                    help='The output file where to store the polygenic scores (with no extension).')

# --- Optional settings ---
parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                    help='The temporary directory where to store intermediate files.')

parser.add_argument('--keep', dest='keep', type=str,
                    help='A plink-style keep file to select a subset of individuals for the test set.')
parser.add_argument('--extract', dest='extract', type=str,
                    help='A plink-style extract file to select a subset of SNPs for scoring.')
# A tuple (rather than a set) keeps the order of the choices stable in the
# generated help and error messages.
parser.add_argument('--backend', dest='backend', type=str, default='xarray',
                    choices=('xarray', 'plink'),
                    help='The backend software used for computations with the genotype matrix.')
parser.add_argument('--threads', dest='threads', type=int, default=1,
                    help='The number of threads to use for computations.')
parser.add_argument('--compress', dest='compress', action='store_true', default=False,
                    help='Compress the output file')

args = parser.parse_args()

# ----------------------------------------------------------
# Section header, centred in a 62-character rule of dashes.
section_title = '  Parsed arguments  '
print('{:-^62}\n'.format(section_title))

# Echo every argument whose value is set (not None) and differs from the
# default registered on the parser.
for arg_name, arg_value in vars(args).items():
    if arg_value is None:
        continue
    if arg_value != parser.get_default(arg_name):
        print("--", arg_name, ":", arg_value)

# ----------------------------------------------------------
print('\n{:-^62}\n'.format('  Reading input data  '))

# Resolve the fit files first so that a wrong path/wildcard fails immediately,
# before the genotype data is loaded and any temporary files are created.
fit_files = get_filenames(args.fit_files, extension='.fit')

if len(fit_files) < 1:
    # Build a single message string: passing two positional arguments to an
    # OSError subclass makes Python interpret them as (errno, strerror), which
    # garbles the rendered message.
    raise FileNotFoundError(f"Did not find any parameter fit files at: {args.fit_files}")

# Load the genotype data for the test samples. `min_mac` and `min_maf` are
# passed as None -- presumably this disables the loader's allele count /
# frequency filters so all variants are kept for scoring (TODO confirm
# against the GWADataLoader documentation).
test_data = GWADataLoader(args.bed_files,
                          keep_file=args.keep,
                          extract_file=args.extract,
                          min_mac=None,
                          min_maf=None,
                          backend=args.backend,
                          temp_dir=args.temp_dir,
                          threads=args.threads)

# Wrap the test data in a PRS model and populate it with the effect size
# estimates read from the fit files.
prs_m = BayesPRSModel(test_data)
prs_m.read_inferred_parameters(fit_files)

# ----------------------------------------------------------
print('\n{:-^62}\n'.format('-'))
# Predict on the test set:
print("> Generating polygenic scores...")
prs = test_data.score(prs_m.get_posterior_mean_beta())

# Save the PRS as a table: take a copy of the per-individual table and attach
# the scores as a new `PRS` column.
ind_table = test_data.to_individual_table().copy()
ind_table['PRS'] = prs

# Clean up all the intermediate files/directories
test_data.cleanup()

# If the user wants the files to be compressed, append `.gz` to the name:
c_ext = '.gz' if args.compress else ''
output_path = args.output_file + '.prs' + c_ext

# Report the actual file being written (not only its parent directory, which
# is empty when the output prefix has no directory component).
print("\n>>> Writing the polygenic scores to:\n", output_path)

# Output the scores. Only create the parent directory when the output prefix
# actually has one; an empty directory name cannot be created.
output_dir = osp.dirname(args.output_file)
if output_dir:
    makedir(output_dir)
ind_table.to_csv(output_path, index=False, sep="\t")
