#!python

"""
Compute Linkage-Disequilibrium (LD) matrices and store in Zarr array format
----------------------------

This is a commandline script that facilitates the computation of LD matrices
from genotype data stored in plink BED format. The script supports various
estimators for computing the LD matrix, including windowed, shrinkage, block,
and sample-based estimators. The script outputs the computed LD matrices in
Zarr array format, which is a compressed, parallelized, and scalable format
for storing large numerical arrays.

Usage:

    python -m magenpy_ld --bfile <bed_file> --estimator <estimator> --output-dir <output_dir>

For larger genotype matrices, we recommend using `plink1.9` as a backend to compute the LD matrices.
You can do that by specifying the `--backend` parameter:

    python -m magenpy_ld --bfile <bed_file> --estimator <estimator> --output-dir <output_dir> --backend plink

"""

import os.path as osp
import argparse
import magenpy as mgp
import time
from datetime import timedelta
from magenpy.utils.system_utils import valid_url
from magenpy.GenotypeMatrix import xarrayGenotypeMatrix, plinkBEDGenotypeMatrix

# Print the start-up banner. The literal is a raw f-string (`fr`) so that the
# backslashes in the ASCII art are emitted verbatim while the package version
# and release date are still interpolated from the imported `magenpy` module.
print(fr"""
**********************************************                            
 _ __ ___   __ _  __ _  ___ _ __  _ __  _   _ 
| '_ ` _ \ / _` |/ _` |/ _ \ '_ \| '_ \| | | |
| | | | | | (_| | (_| |  __/ | | | |_) | |_| |
|_| |_| |_|\__,_|\__, |\___|_| |_| .__/ \__, |
                 |___/           |_|    |___/
Modeling and Analysis of Genetics data in python
Version: {mgp.__version__} | Release date: {mgp.__release_date__}
Author: Shadi Zabad, McGill University
**********************************************
< Compute LD matrix and output in Zarr format >
""")

# Build the command-line interface.
#
# NOTE: `choices` are given as ordered tuples rather than sets. argparse
# renders the choices in iteration order in `--help` and in error messages,
# and the iteration order of a set of strings changes from one interpreter
# run to the next (hash randomisation), which made the help text unstable.
parser = argparse.ArgumentParser(description="""
    Commandline arguments for LD matrix computation
""")

# General options:
parser.add_argument('--estimator', dest='estimator', type=str, default='windowed',
                    choices=('windowed', 'shrinkage', 'block', 'sample'),
                    help='The LD estimator (windowed, shrinkage, block, sample)')
parser.add_argument('--bfile', dest='bed_file', type=str, required=True,
                    help='The path to a plink BED file.')
parser.add_argument('--keep', dest='keep_file', type=str,
                    help='A plink-style keep file to select a subset of individuals to compute the LD matrices.')
parser.add_argument('--extract', dest='extract_file', type=str,
                    help='A plink-style extract file to select a subset of SNPs to compute the LD matrix for.')
parser.add_argument('--backend', dest='backend', type=str, default='xarray',
                    choices=('xarray', 'plink'),
                    help='The backend software used to compute the Linkage-Disequilibrium between variants.')
parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                    help='The temporary directory where we store intermediate files.')
parser.add_argument('--output-dir', dest='output_dir', type=str, required=True,
                    help='The output directory where the Zarr formatted LD matrices will be stored.')
parser.add_argument('--min-maf', dest='min_maf', type=float,
                    help='The minimum minor allele frequency for variants included in the LD matrix.')
parser.add_argument('--min-mac', dest='min_mac', type=float,
                    help='The minimum minor allele count for variants included in the LD matrix.')

# Metadata / reproducibility options:
parser.add_argument('--genome-build', dest='genome_build', type=str,
                    help='The genome build for the genotype data (recommend storing as metadata).')
parser.add_argument('--metadata', dest='metadata', type=str,
                    help='A comma-separated string with metadata keys and values. This is used to store '
                         'information about the genotype data from which the LD matrix was computed, such as '
                         'the biobank/samples, cohort characteristics (e.g. ancestry), etc. Keys and values '
                         'should be separated by =, such that inputs are in the form of: '
                         '--metadata Biobank=UKB,Ancestry=EUR,Date=April2024')

# Argument for the storage data type of the LD matrix entries:
parser.add_argument('--storage-dtype', dest='storage_dtype', type=str,
                    default='int16', help='The data type for the entries of the LD matrix.',
                    choices=('int8', 'int16', 'float32', 'float64'))

# Add arguments for the compressor:
parser.add_argument('--compressor', dest='compressor', type=str,
                    default='lz4', help='The compressor name or compression algorithm to use for the LD matrix.',
                    choices=('lz4', 'zstd', 'gzip', 'zlib'))

parser.add_argument('--compression-level', dest='compression_level', type=int,
                    default=5, help='The compression level to use for the entries of the LD matrix (1-9).')

# Options for the various LD estimators:

# For the windowed estimator:
parser.add_argument('--ld-window', dest='ld_window', type=int,
                    help='Maximum number of neighboring SNPs to consider when computing LD.')
parser.add_argument('--ld-window-kb', dest='ld_window_kb', type=float,
                    help='Maximum distance (in kilobases) between pairs of variants when computing LD.')
parser.add_argument('--ld-window-cm', dest='ld_window_cm', type=float,
                    help='Maximum distance (in centi Morgan) between pairs of variants when computing LD.')

# For the block estimator:
parser.add_argument('--ld-blocks', dest='ld_blocks', type=str,
                    help='Path to the file with the LD block boundaries, '
                         'in LDetect format (e.g. chr start stop, tab-separated)')

# For the shrinkage estimator:
parser.add_argument('--genmap-Ne', dest='genmap_ne', type=int,
                    help="The effective population size for the population from which the genetic map was derived.")
parser.add_argument('--genmap-sample-size', dest='genmap_ss', type=int,
                    help="The sample size for the dataset used to infer the genetic map.")
parser.add_argument('--shrinkage-cutoff', dest='shrink_cutoff', type=float,
                    help="The cutoff value below which we assume that the correlation between variants is zero.")

args = parser.parse_args()

# ------------------------------------------------------
# Validate that the options required by the selected estimator were supplied:

if args.estimator == 'windowed':
    # At least one of the three window definitions has to be present.
    window_settings = (args.ld_window, args.ld_window_kb, args.ld_window_cm)
    if all(setting is None for setting in window_settings):
        raise Exception("For the windowed estimator, the user must provide the window size using --ld-window or "
                        "the maximum distance in kilobases (--ld-window-kb) or centi Morgan (--ld-window-cm).")

elif args.estimator == 'block':
    blocks_location = args.ld_blocks
    if blocks_location is None:
        raise Exception("If you select the [block] LD estimator, make sure that "
                        "you also provide the ld blocks file via the --ld-blocks flag!")
    # The block boundaries may come from a local file or from a URL.
    if not (osp.isfile(blocks_location) or valid_url(blocks_location)):
        raise FileNotFoundError("The LD blocks file does not exist!")

elif args.estimator == 'shrinkage':
    if args.genmap_ne is None:
        raise Exception("If you select the [shrinkage] estimator, you need to specify the "
                        "effective population size via the --genmap-Ne flag!")
    if args.genmap_ss is None:
        raise Exception("If you select the [shrinkage] estimator, you need to specify the "
                        "sample size for the genetic map via the --genmap-sample-size flag!")

# ------------------------------------------------------
# Extract the arguments for selected estimator:

# Keyword arguments forwarded to `compute_ld` for the selected estimator.
ld_kwargs = {}

if args.estimator == 'block':
    ld_kwargs['ld_blocks_file'] = args.ld_blocks
else:
    # For each estimator: (keyword passed to compute_ld, attribute on `args`).
    # Only options that the user actually supplied are forwarded; the tuple
    # order determines the key order of `ld_kwargs`.
    optional_settings = {
        'windowed': (
            ('window_size', 'ld_window'),
            ('kb_window_size', 'ld_window_kb'),
            ('cm_window_size', 'ld_window_cm'),
        ),
        'shrinkage': (
            ('genetic_map_ne', 'genmap_ne'),
            ('genetic_map_sample_size', 'genmap_ss'),
            ('threshold', 'shrink_cutoff'),
        ),
    }

    for kwarg_name, arg_attr in optional_settings.get(args.estimator, ()):
        supplied = getattr(args, arg_attr)
        if supplied is not None:
            ld_kwargs[kwarg_name] = supplied

# ------------------------------------------------------

# Print out the parsed input commands:
# Echo the configuration back to the user before any work starts.
print("> LD estimator:", args.estimator)

print(">>> Parsed estimator characteristics:\n", ld_kwargs)

print("\n\n> Source data:")
print(">>> BED file:", args.bed_file)

# Optional filters are only reported when the user supplied them.
optional_inputs = (
    (">>> Keep samples:", args.keep_file),
    (">>> Keep variants:", args.extract_file),
    (">>> Minimum allele frequency:", args.min_maf),
    (">>> Minimum allele count:", args.min_mac),
)
for label, supplied_value in optional_inputs:
    if supplied_value is not None:
        print(label, supplied_value)

# Storage settings always have a value (they carry defaults).
storage_settings = (
    (">>> Storage data type:", args.storage_dtype),
    (">>> Compressor:", args.compressor),
    (">>> Compression level:", args.compression_level),
)
for label, setting in storage_settings:
    print(label, setting)

print("\n\n> Output:")
for label, location in ((">>> Temporary directory:", args.temp_dir),
                        (">>> Output directory:", args.output_dir)):
    print(label, location)

# ------------------------------------------------------
# Perform the computation:

print("\n\n> Processing the genotype data...")

# Select the genotype container that matches the requested backend; any
# backend other than 'xarray' falls through to the plink-based class.
genotype_matrix_cls = xarrayGenotypeMatrix if args.backend == 'xarray' else plinkBEDGenotypeMatrix

g = genotype_matrix_cls.from_file(
    args.bed_file,
    temp_dir=args.temp_dir,
    genome_build=args.genome_build,
)

# Apply the optional sample filter first, then the variant filters.
if args.keep_file is not None:
    print("> Filtering samples...")
    g.filter_samples(keep_file=args.keep_file)

if args.extract_file is not None:
    print("> Filtering variants...")
    g.filter_snps(extract_file=args.extract_file)

frequency_filter_requested = not (args.min_mac is None and args.min_maf is None)
if frequency_filter_requested:
    print("> Filtering variants by allele frequency/count...")
    g.filter_by_allele_frequency(min_maf=args.min_maf, min_mac=args.min_mac)


# Start the timer just before the LD computation itself
# (loading and filtering the genotype data is not included):
start_time = time.time()

# Storage-related settings shared by every estimator:
storage_options = {
    'dtype': args.storage_dtype,
    'compressor_name': args.compressor,
    'compression_level': args.compression_level,
}

# Compute the LD matrix and write it to the output directory:
print("> Computing the LD matrix...")
ld_mat = g.compute_ld(args.estimator, args.output_dir, **storage_options, **ld_kwargs)

# Store metadata (if provided):
#
# The metadata string has the form `key1=value1,key2=value2`. Empty entries
# (e.g. from a trailing or doubled comma) are skipped, and only the first `=`
# of an entry separates the key from the value, so values may contain `=`.
if args.metadata is not None:
    parsed_metadata = {}

    for raw_entry in args.metadata.split(','):
        entry = raw_entry.strip()
        if not entry:
            continue

        key, separator, value = entry.partition('=')
        if not separator or not key:
            # Fail with an actionable message instead of an opaque
            # tuple-unpacking error.
            raise ValueError(f"Malformed --metadata entry '{entry}': expected the form key=value.")

        parsed_metadata[key] = value

    for k, v in parsed_metadata.items():
        ld_mat.set_store_attr(k, v)


# Remove the intermediate files and directories created while processing:
g.cleanup()

print("Done!")
print("> Output directory:\n\t", args.output_dir)

# Report the wall-clock time elapsed since the LD computation started:
elapsed = timedelta(seconds=time.time() - start_time)
print('Total runtime:', elapsed)
