#!/usr/bin/env python3

"""
Author: Shadi Zabad
Date: May 2022

This is a commandline script that enables users to generate
LD matrices in Zarr format from plink's `.bed` files.
"""

import os.path as osp
import argparse
from magenpy.GWASDataLoader import GWASDataLoader

# Startup banner, printed unconditionally when the script runs.
# The literal is a *raw* string because the ASCII art contains backslashes
# that do not form valid escape sequences; in a regular string literal these
# raise a SyntaxWarning on Python 3.12+ (DeprecationWarning on older
# versions). The printed text is unchanged.
print(r"""
**********************************************                            
 _ __ ___   __ _  __ _  ___ _ __  _ __  _   _ 
| '_ ` _ \ / _` |/ _` |/ _ \ '_ \| '_ \| | | |
| | | | | | (_| | (_| |  __/ | | | |_) | |_| |
|_| |_| |_|\__,_|\__, |\___|_| |_| .__/ \__, |
                 |___/           |_|    |___/
Modeling and Analysis of Genetics data in python
Version: 0.0.1 | Release date: May 2022
Author: Shadi Zabad, McGill University
**********************************************
< Compute LD matrix and output in Zarr format >
""")

# Command-line interface definition.
# `choices` are given as tuples (not sets) so that the order in which they
# appear in `--help`, the usage line and "invalid choice" errors is
# deterministic and matches the order used in the help text.
parser = argparse.ArgumentParser(description="""
Commandline arguments for LD matrix computation
""")

# --- Estimator selection and input/output locations ---
parser.add_argument('--estimator', dest='estimator', type=str, default='windowed',
                    choices=('windowed', 'shrinkage', 'block', 'sample'),
                    help='The LD estimator (windowed, shrinkage, block, sample)')
parser.add_argument('--bfile', dest='bed_file', type=str, required=True,
                    help='The path to the BED file')
parser.add_argument('--keep', dest='keep_file', type=str,
                    help='A plink-style keep file to select a subset of individuals to compute the LD matrices.')
parser.add_argument('--extract', dest='extract_file', type=str,
                    help='A plink-style extract file to select a subset of SNPs to compute the LD for.')
parser.add_argument('--backend', dest='backend', type=str, default='dask',
                    choices=('dask', 'plink'),
                    help='The backend software used to compute the LD between variants.')
parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                    help='The temporary directory where we store intermediate files.')
parser.add_argument('--output-dir', dest='output_dir', type=str, required=True,
                    help='The output directory where the Zarr formatted LD matrices will be stored.')

# --- Options for the various LD estimators ---
# Reported for the [windowed] estimator:
parser.add_argument('--cm-dist', dest='cm_dist', type=float, default=3.,
                    help='Maximum distance between a pair of SNPs in centi Morgan.')
# Required by the [block] estimator:
parser.add_argument('--ld-blocks', dest='ld_blocks', type=str,
                    help='Path to the file with the LD block boundaries, '
                         'in LDetect format (e.g. chr start stop, tab-separated)')
# Required by the [shrinkage] estimator:
parser.add_argument('--genmap-Ne', dest='genmap_ne', type=int,
                    help="The effective population size for the population from which the genetic map was derived.")
parser.add_argument('--genmap-sample-size', dest='genmap_ss', type=int,
                    help="The sample size for the dataset used to infer the genetic map.")
parser.add_argument('--shrinkage-cutoff', dest='shrink_cutoff', type=float, default=1e-5,
                    help="The cutoff value below which we assume that the correlation between variants is zero.")

args = parser.parse_args()

# ------------------------------------------------------
# Sanity checks on the parsed arguments:
# Some flags are only mandatory for a particular estimator, so they cannot be
# marked `required=True` on the parser. A missing flag is a usage error and is
# reported through `parser.error`, which prints the usage message together
# with the error to stderr and exits with status 2 (the same way argparse
# reports its own validation failures) instead of surfacing a raw traceback
# from a generic `Exception`.
if args.estimator == 'block':
    if args.ld_blocks is None:
        parser.error("If you select the [block] LD estimator, make sure that "
                     "you also provide the ld blocks file via the --ld-blocks flag!")
    # `parser.error` never returns, so `args.ld_blocks` is a path from here on.
    if not osp.isfile(args.ld_blocks):
        raise FileNotFoundError("The LD blocks file does not exist!")
elif args.estimator == 'shrinkage':
    if args.genmap_ne is None:
        parser.error("If you select the [shrinkage] estimator, you need to specify the "
                     "effective population size via the --genmap-Ne flag!")
    if args.genmap_ss is None:
        parser.error("If you select the [shrinkage] estimator, you need to specify the "
                     "sample size for the genetic map via the --genmap-sample-size flag!")
# ------------------------------------------------------

# Echo the parsed configuration back to the user before any work starts.
print(f"> LD estimator: {args.estimator}")

# Settings that are only reported for a specific estimator.
# The 'sample' estimator has no entry and therefore prints nothing here.
estimator_report = {
    'windowed': (
        f">>> Maximum distance: {args.cm_dist} cM",
    ),
    'block': (
        f">>> LD blocks file: {args.ld_blocks}",
    ),
    'shrinkage': (
        f">>> Genetic map sample size: {args.genmap_ss}",
        f">>> Genetic map effective population size: {args.genmap_ne}",
        f">>> Shrinkage cutoff: {args.shrink_cutoff}",
    ),
}
for report_line in estimator_report.get(args.estimator, ()):
    print(report_line)

# Input data; the optional sample/variant filters are listed only when given.
print("\n\n> Source data:")
print(f">>> BED file: {args.bed_file}")
optional_filters = (
    (">>> Keep samples:", args.keep_file),
    (">>> Keep variants:", args.extract_file),
)
for label, filter_path in optional_filters:
    if filter_path is not None:
        print(label, filter_path)

# Where intermediate and final results are written.
print("\n\n> Output:")
print(f">>> Temporary directory: {args.temp_dir}")
print(f">>> Output directory: {args.output_dir}")

# ------------------------------------------------------

# Translate the command-line options into GWASDataLoader keyword arguments,
# grouped by what they configure.
data_selection = {
    'keep_individuals': args.keep_file,
    'keep_snps': args.extract_file,
}
ld_settings = {
    'ld_estimator': args.estimator,
    'window_unit': "cM",  # --cm-dist is expressed in centiMorgan
    'cm_window_cutoff': args.cm_dist,
    'compute_ld': True,
    'ld_block_files': args.ld_blocks,
    'genmap_Ne': args.genmap_ne,
    'genmap_sample_size': args.genmap_ss,
    'shrinkage_cutoff': args.shrink_cutoff,
}
runtime_settings = {
    'use_plink': args.backend == 'plink',  # any other backend value means dask
    'output_dir': args.output_dir,
    'temp_dir': args.temp_dir,
}

# The loader is constructed with compute_ld=True and the output directory;
# presumably the LD matrices are computed and written as part of construction
# (NOTE(review): confirm against GWASDataLoader).
gdl = GWASDataLoader(
    args.bed_file,
    **data_selection,
    **ld_settings,
    **runtime_settings,
)

# Remove all intermediate files and directories created along the way:
gdl.cleanup()
print("Done!")
