#!/usr/bin/env python3

"""
Compute Polygenic Scores for Test Samples
----------------------------

This is a commandline script that computes polygenic scores for test samples
given effect size estimates from VIPRS. The script can work with effect sizes from
other software, as long as they're formatted in the same way as VIPRS `.fit` files.

Usage:

    python -m viprs_score -f <fit_files> --bfile <bed_files> --output-file <output_file>

    - `fit_files` is the path to the file(s) with the output parameter estimates from VIPRS.
    - `bed_files` is the path to the BED file(s) containing the genotype data (wildcards are allowed).
    - `output_file` is the output file where to store the polygenic scores (with no extension).

"""

# Module-level logger for this script. Only the logger object is created here;
# its log file, format and level are configured inside `main()` via `setup_logger`,
# once the output path is known from the commandline arguments.
import logging
logger = logging.getLogger(__name__)


def main():
    """
    Entry point for the polygenic scoring commandline script.

    The function performs the following steps:

        1. Parse the commandline arguments.
        2. Configure logging for `viprs`, `magenpy` and this script, writing
           to `<output_file>.log`.
        3. Locate the `.fit` files with the inferred effect sizes.
        4. Load the genotype data and compute the polygenic scores.
        5. Write the scores as a tab-separated table to `<output_file>.prs`
           (or `<output_file>.prs.gz` if `--compress` is set).

    Intermediate files created by the genotype data loader are cleaned up
    even if loading the effect sizes or scoring fails.

    :raises FileNotFoundError: If no `.fit` files are found at the path
        given by `--fit-files`.
    """

    import argparse
    import viprs as vp

    print("\n" + vp.make_ascii_logo(
        desc='< Compute Polygenic Scores for Test Samples >',
        left_padding=10
    ) + "\n", flush=True)

    parser = argparse.ArgumentParser(description="""
        Commandline arguments for computing polygenic scores
    """)

    parser.add_argument('-f', '--fit-files', dest='fit_files', type=str, required=True,
                        help='The path to the file(s) with the output parameter estimates from VIPRS. '
                             'You may use a wildcard here if fit files are stored '
                             'per-chromosome (e.g. "prs/chr_*.fit")')
    parser.add_argument('--bfile', dest='bed_files', type=str, required=True,
                        help='The BED files containing the genotype data. '
                             'You may use a wildcard here (e.g. "data/chr_*.bed")')
    parser.add_argument('--output-file', dest='output_file', type=str, required=True,
                        help='The output file where to store the polygenic scores (with no extension).')

    parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                        help='The temporary directory where to store intermediate files.')

    parser.add_argument('--keep', dest='keep', type=str,
                        help='A plink-style keep file to select a subset of individuals for the test set.')
    parser.add_argument('--extract', dest='extract', type=str,
                        help='A plink-style extract file to select a subset of SNPs for scoring.')
    # NOTE: `choices` are given as tuples (not sets) so that the order in which
    # they appear in `--help` and in error messages is the same on every run.
    parser.add_argument('--backend', dest='backend', type=str, default='xarray',
                        choices=('xarray', 'plink'),
                        help='The backend software used for computations with the genotype matrix.')
    parser.add_argument('--threads', dest='threads', type=int, default=1,
                        help='The number of threads to use for computations.')
    parser.add_argument('--compress', dest='compress', action='store_true', default=False,
                        help='Compress the output file')

    parser.add_argument('--log-level', dest='log_level', type=str, default='WARNING',
                        choices=('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'),
                        help='The logging level for the console output.')

    args = parser.parse_args()

    # ----------------------------------------------------------
    # Heavy imports are deferred until after argument parsing, so that
    # `--help` and argument errors are reported without loading them.
    import os.path as osp
    from magenpy.utils.system_utils import makedir, get_filenames, setup_logger
    from magenpy.GWADataLoader import GWADataLoader
    from viprs.model.BayesPRSModel import BayesPRSModel

    # ----------------------------------------------------------
    # Setup the logger:

    # Create the output directory. If the output prefix has no directory
    # component (e.g. `--output-file scores`), `dirname` is the empty string
    # and there is nothing to create.
    output_dir = osp.dirname(args.output_file)
    if output_dir:
        makedir(output_dir)

    # Clear the log file (truncate it if it exists, create it otherwise):
    log_file = f"{args.output_file}.log"
    with open(log_file, 'w'):
        pass

    # Set up the module loggers:
    setup_logger(modules=['viprs', 'magenpy'],
                 log_file=log_file,
                 log_level=args.log_level)

    # Set up the logger for the main module. This script reports its progress
    # at the INFO level, so its logger is never set to anything less verbose
    # than INFO; it only follows the user's choice when that is more verbose (DEBUG).
    if getattr(logging, args.log_level) < logging.INFO:
        main_log_level = args.log_level
    else:
        main_log_level = 'INFO'

    setup_logger(loggers=[logger],
                 log_file=log_file,
                 log_format='%(message)s',
                 log_level=main_log_level)

    # ----------------------------------------------------------

    logger.info('{:-^100}\n'.format('  Parsed arguments  '))

    # Only report the arguments that the user explicitly changed:
    for key, val in vars(args).items():
        if val is not None and val != parser.get_default(key):
            logger.info(f"-- {key}: {val}")

    # ----------------------------------------------------------
    logger.info('\n{:-^100}\n'.format('  Reading input data  '))

    # Resolve the effect size files first, so that a wrong path is reported
    # before any time is spent on loading the genotype data.
    fit_files = get_filenames(args.fit_files, extension='.fit')

    if len(fit_files) < 1:
        err_msg = "Did not find PRS coefficient files at:\n" + args.fit_files
        logger.error(err_msg)
        raise FileNotFoundError(err_msg)

    test_data = GWADataLoader(args.bed_files,
                              keep_file=args.keep,
                              extract_file=args.extract,
                              min_mac=None,
                              min_maf=None,
                              backend=args.backend,
                              temp_dir=args.temp_dir,
                              threads=args.threads)

    try:
        prs_m = BayesPRSModel(test_data)
        prs_m.read_inferred_parameters(fit_files)

        # ----------------------------------------------------------
        logger.info('\n{:-^100}\n'.format('  Scoring  '))

        # Predict on the test set:
        prs = test_data.score(prs_m.get_posterior_mean_beta())

        # Save the PRS as a table (copied, so that it remains usable after cleanup):
        ind_table = test_data.to_individual_table().copy()
        ind_table['PRS'] = prs
    finally:
        # Clean up all the intermediate files/directories, even if scoring failed:
        test_data.cleanup()

    # If the user wants the files to be compressed, append `.gz` to the name:
    c_ext = '.gz' if args.compress else ''
    output_path = args.output_file + '.prs' + c_ext

    logger.info(f"\n>>> Writing the polygenic scores to:\n {output_path}")

    # Output the scores (make sure the output directory still exists after the cleanup step):
    if output_dir:
        makedir(output_dir)
    ind_table.to_csv(output_path, index=False, sep="\t")


# Run the scoring pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
