#!python

"""
Simulate Polygenic Traits using Complex Genetic Architectures
----------------------------

This is a commandline script that facilitates the simulation of complex traits
using a linear additive model with heterogeneous genetic architectures. The script
supports simulating phenotypes with different heritabilities, levels of polygenicity,
and genetic architectures. The script outputs the simulated phenotypes in a tabular
format that can be used for downstream analyses.

The script can simulate both quantitative and case-control traits. For case-control
traits, the prevalence of cases in the population can be specified with `--prevalence`.

The script requires access to genotype data in PLINK BED format.

Usage:

    python -m magenpy_simulate --bfile <bed_files> --h2 <h2> --prop-causal <p> --output-file <output_file>

"""

import os.path as osp
import numpy as np
import magenpy as mgp
import time
from datetime import timedelta
import warnings
from magenpy.simulation.PhenotypeSimulator import PhenotypeSimulator
from magenpy.utils.system_utils import makedir, get_filenames
import argparse


# Startup banner, shown on every invocation. It is a raw f-string so that the
# backslashes in the ASCII art are kept literally while the package version
# and release date are still interpolated.
_banner = fr"""
**********************************************                            
 _ __ ___   __ _  __ _  ___ _ __  _ __  _   _ 
| '_ ` _ \ / _` |/ _` |/ _ \ '_ \| '_ \| | | |
| | | | | | (_| | (_| |  __/ | | | |_) | |_| |
|_| |_| |_|\__,_|\__, |\___|_| |_| .__/ \__, |
                 |___/           |_|    |___/
Modeling and Analysis of Genetics data in python
Version: {mgp.__version__} | Release date: {mgp.__release_date__}
Author: Shadi Zabad, McGill University
**********************************************
< Simulate complex quantitative or case-control traits >
"""
print(_banner)

# --------- Options ---------


def _unit_interval(value):
    """
    argparse `type` converter for options documented as lying between 0 and 1.

    :param value: The raw command-line string.
    :return: The parsed float.
    :raises argparse.ArgumentTypeError: If `value` is not a number or falls
    outside the closed interval [0, 1]. argparse turns this into a regular
    usage error that names the offending option.
    """
    try:
        number = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {value!r}") from None

    # Written as a negated chain so that NaN is rejected as well.
    if not 0. <= number <= 1.:
        raise argparse.ArgumentTypeError(f"{value!r} is not in the range [0, 1]")

    return number


parser = argparse.ArgumentParser(description="""
    Commandline arguments for the complex trait simulator
""")

# Input data and output locations:
parser.add_argument('--bfile', dest='bed_file', type=str, required=True,
                    help='The BED files containing the genotype data. '
                         'You may use a wildcard here (e.g. "data/chr_*.bed")')
parser.add_argument('--keep', dest='keep_file', type=str,
                    help='A plink-style keep file to select a subset of individuals for simulation.')
parser.add_argument('--extract', dest='extract_file', type=str,
                    help='A plink-style extract file to select a subset of SNPs for simulation.')
# `choices` are given as tuples (not sets) so that their order in the usage,
# help and error messages is the same on every run.
parser.add_argument('--backend', dest='backend', type=str, default='xarray',
                    choices=('xarray', 'plink'),
                    help='The backend software used for the computation.')
parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                    help='The temporary directory where we store intermediate files.')
parser.add_argument('--output-file', dest='output_file', type=str, required=True,
                    help='The path where the simulated phenotype will be stored '
                         '(no extension needed).')
parser.add_argument('--output-simulated-beta', dest='output_sim_beta',
                    action='store_true', default=False,
                    help='Output a table with the true simulated effect size for each variant.')
parser.add_argument('--min-maf', dest='min_maf', type=float,
                    help='The minimum minor allele frequency for variants included in the simulation.')
parser.add_argument('--min-mac', dest='min_mac', type=int,
                    help='The minimum minor allele count for variants included in the simulation.')

# Simulation parameters:
parser.add_argument('--h2', dest='h2', type=_unit_interval, required=True,
                    help='Trait heritability. Ranges between 0. and 1., inclusive.')
parser.add_argument('--mix-prop', dest='mix_prop', type=str,
                    help='Mixing proportions for the mixture density (comma separated). For example, '
                         'for the spike-and-slab mixture density, with the proportion of causal variants '
                         'set to 0.1, you can specify: "--mix-prop 0.9,0.1 --var-mult 0,1".')
parser.add_argument('--prop-causal', '-p', dest='prop_causal', type=_unit_interval,
                    help='The proportion of causal variants in the simulation. See --mix-prop for '
                         'more complex architectures specification.')
parser.add_argument('--var-mult', '-d', dest='var_mult', type=str,
                    help='Multipliers on the variance for each mixture component.')
parser.add_argument('--phenotype-likelihood', dest='likelihood', type=str, default='gaussian',
                    choices=('gaussian', 'binomial'),
                    help='The likelihood for the simulated trait: '
                         'gaussian (e.g. quantitative) or binomial (e.g. case-control).')
parser.add_argument('--prevalence', dest='prevalence', type=_unit_interval,
                    help='The prevalence of cases (or proportion of positives) for binary traits. '
                         'Ranges between 0. and 1.')

parser.add_argument('--seed', dest='seed', type=int,
                    help='The random seed to use for the random number generator.')

args = parser.parse_args()

# ------------------------------------------------------
# Validate the parsed inputs:

# Look up the BED file(s) for the --bfile value and stop early, with a message
# that echoes the user-supplied value, when nothing is found.
bed_file = get_filenames(args.bed_file, extension=".bed")
if len(bed_file) == 0:
    raise FileNotFoundError(f"No BED files were identified at the specified location: {args.bed_file}")


# Translate the architecture options into the mixing proportions (`pi`) and the
# per-component variance multipliers (`d`) handed to the simulator.
# Precedence: --prop-causal, then --mix-prop/--var-mult, otherwise every
# variant is treated as causal.
if args.prop_causal is not None:
    # --prop-causal wins; tell the user when other architecture options
    # are being dropped instead of ignoring them silently.
    ignored_flags = [flag for flag, supplied in (("--mix-prop", args.mix_prop),
                                                 ("--var-mult", args.var_mult))
                     if supplied is not None]
    if ignored_flags:
        warnings.warn(f"--prop-causal was specified; ignoring {' and '.join(ignored_flags)}.")

    # Two components: a null component (variance multiplier 0) and a causal one.
    pi = [1. - args.prop_causal, args.prop_causal]
    d = [0., 1.]
elif args.mix_prop is not None:
    if not args.var_mult:
        raise ValueError("Specifying mixing proportions without variance multipliers is not permitted.")

    # Both options are comma-separated lists of floats.
    pi = list(map(float, args.mix_prop.split(",")))
    d = list(map(float, args.var_mult.split(",")))
else:
    if args.var_mult is not None:
        # Multipliers are only meaningful together with --mix-prop.
        warnings.warn("--var-mult was specified without --mix-prop; ignoring it.")

    warnings.warn("Mixing proportions not specified. Assuming an infinitesimal architecture "
                  "where all variants are causal!")
    pi = [0., 1.]
    d = [0., 1.]

# Each mixture component needs exactly one proportion and one multiplier.
if len(pi) != len(d):
    raise ValueError("The multipliers and mixing proportions must be of the same length!")

# ------------------------------------------------------
# Echo the effective configuration back to the user:
print(f"> Simulating complex trait with {args.likelihood} likelihood...")
print(">>> Heritability:", args.h2)
print(">>> Mixing proportions:", pi)
print(">>> Variance multipliers:", d)
# The prevalence is only reported for binary traits, and only when given.
if args.prevalence is not None and args.likelihood == 'binomial':
    print(">>> Prevalence:", args.prevalence)

print("\n\n> Source data:")
print(">>> BED files:", args.bed_file)

# Optional sample/variant filters are listed only when the user supplied them.
optional_filters = (
    ("Keep samples", args.keep_file),
    ("Keep variants", args.extract_file),
    ("Minimum allele frequency", args.min_maf),
    ("Minimum allele count", args.min_mac),
)
for label, value in optional_filters:
    if value is not None:
        print(f">>> {label}:", value)

print("\n\n> Output:")
print(">>> Temporary directory:", args.temp_dir)
print(">>> Output file:", args.output_file)

# ------------------------------------------------------

# Record start time (used for the total runtime reported at the end):
start_time = time.time()

# Seed numpy's global random number generator when a seed was requested:
if args.seed is not None:
    np.random.seed(args.seed)

# Collect the keyword arguments for the simulator:
simulator_kwargs = dict(keep_file=args.keep_file,
                        extract_file=args.extract_file,
                        phenotype_likelihood=args.likelihood,
                        h2=args.h2,
                        pi=pi,
                        d=d,
                        min_maf=args.min_maf,
                        min_mac=args.min_mac,
                        backend=args.backend,
                        temp_dir=args.temp_dir)

# Forward --prevalence so that the value the user typed is not silently dropped.
# It is only passed when supplied, so whatever default the simulator uses is
# left untouched otherwise.
# NOTE(review): assumes PhenotypeSimulator accepts a `prevalence` keyword --
# confirm against its constructor signature.
if args.prevalence is not None:
    simulator_kwargs['prevalence'] = args.prevalence

# Construct the PhenotypeSimulator object:
gs = PhenotypeSimulator(bed_file, **simulator_kwargs)

print("> Simulating phenotype...")
try:
    gs.simulate(reset_beta=True, perform_gwas=False)
    pheno_table = gs.to_phenotype_table()

    # Write the simulated phenotypes to file.
    # `--output-file` may be a bare prefix such as "sim1", in which case
    # `dirname` returns '' and there is no directory to create.
    output_dir = osp.dirname(args.output_file)
    if output_dir:
        makedir(output_dir)
    pheno_table.to_csv(args.output_file + '.SimPheno', sep="\t",
                       index=False, header=False)

    if args.output_sim_beta:
        # Output the simulated effect sizes:
        sim_effects = gs.to_true_beta_table()
        sim_effects.to_csv(args.output_file + ".SimEffect", sep="\t",
                           index=False)
finally:
    # Run the simulator's cleanup even when simulation or writing fails.
    gs.cleanup()

print("Done!")
print("> Output file(s):\n\t", args.output_file + '.SimPheno')
if args.output_sim_beta:
    print("\t", args.output_file + ".SimEffect")
# Record the end time:
end_time = time.time()
print('Total runtime:', timedelta(seconds=end_time - start_time))
