#!/usr/bin/env python3

"""

Fit VIPRS models to GWAS summary statistics
----------------------------

This is a commandline script that enables users to perform
posterior inference for polygenic risk score models using
variational inference techniques.

The script is designed to be used with the VIPRS package and requires the following inputs:

- Linkage disequilibrium (LD) matrices (in zarr format).
- GWAS summary statistics files
    - The summary statistics files can be in any of the following formats:
        - plink1.9, plink2, cojo, magenpy, fastgwa, ssf, gwas-ssf, gwascatalog, saige, or a `custom` format.
- The output directory where to store the results of the inference.

Usage:

    # Fit a VIPRS model to GWAS summary statistics:
    python -m viprs_fit -l /path/to/ld_matrices/ -s /path/to/sumstats/ --output-dir /path/to/output/

    # Fit a VIPRS model using custom GWAS summary statistics format:
    python -m viprs_fit -l /path/to/ld_matrices/ -s /path/to/sumstats/ --output-dir /path/to/output/ --sumstats-format custom
                       --custom-sumstats-mapper rsid=SNP,eff_allele=A1,beta=BETA

"""

import os.path as osp
import time

# Set up the module-level logger (named after this module via `__name__`).
# No handlers or log levels are configured at this point.
import logging
logger = logging.getLogger(__name__)


def check_args(args):
    """
    Check the validity, consistency, and completeness of the commandline arguments.

    The function returns nothing; it raises on the first problem it finds.

    :param args: An argparse.Namespace object containing the parsed commandline arguments.

    :raises FileNotFoundError: If no LD matrices, summary statistics files, validation BED files,
    or validation phenotype file are found at the user-specified locations.
    :raises PermissionError: If the output or temporary directory is not writable.
    :raises ValueError: If a model parameter is out of range, or if the options required
    for the selected hyperparameter search strategy are missing.
    """

    from magenpy.utils.system_utils import get_filenames, is_path_writable

    # ----------------------------------------------------------
    # Check that the input data files/directories are valid:
    # TODO: Update this check once LD matrices are on AWS s3
    ld_store_files = get_filenames(args.ld_dir, extension='.zgroup')
    if len(ld_store_files) < 1:
        raise FileNotFoundError(f"No valid LD matrix files were found at: {args.ld_dir}")

    sumstats_files = get_filenames(args.sumstats_path)
    if len(sumstats_files) < 1:
        raise FileNotFoundError(f"No valid summary statistics files were found at: {args.sumstats_path}")

    # Check that the output directory is valid / writable:
    if not is_path_writable(args.output_dir):
        raise PermissionError(f"Output directory ({args.output_dir}) is not writable.")

    # Check that the temporary directory is valid / writable:
    if not is_path_writable(args.temp_dir):
        raise PermissionError(f"Temporary directory ({args.temp_dir}) is not writable.")

    # Check the model parameters:
    if args.model == 'VIPRSMix' and args.n_components < 1:
        raise ValueError("The number of components for the VIPRSMix model must be at least 1.")

    # Check the residual variance parameter (if provided).
    # An explicit exception is used (rather than `assert`) so that the check
    # is not stripped when Python runs with optimizations enabled (-O).
    if args.fix_sigma_epsilon is not None and not (0. < args.fix_sigma_epsilon < 1.):
        raise ValueError("The residual variance must be between 0. and 1. (exclusive).")

    # Check that lambda_min is either 'infer' or a non-negative number (if provided):
    if args.lambda_min is not None and args.lambda_min != 'infer':
        try:
            lm_min = float(args.lambda_min)
        except (TypeError, ValueError):
            raise ValueError("The lambda_min parameter must be set to "
                             "'infer' or non-negative number.") from None

        # Written as a negated `>=` so that NaN is rejected as well:
        if not lm_min >= 0.:
            raise ValueError("The lambda_min parameter must be a non-negative number.")

    # Check the arguments for validation/pseudo validation when performing
    # grid search or Bayesian optimization for hyperparameter tuning:

    if args.hyp_search in ('GS',):

        has_individual_data = (args.validation_bed is not None and
                               args.validation_pheno is not None)

        # If the grid metric is validation/pseudo-validation, and genotype/phenotype files
        # are provided, then ensure that the files exist:
        if args.grid_metric in ('validation', 'pseudo_validation') and has_individual_data:

            valid_bed_files = get_filenames(args.validation_bed, extension='.bed')

            if len(valid_bed_files) < 1:
                raise FileNotFoundError(f"No BED files were identified at the "
                                        f"specified location: {args.validation_bed}")

            if not osp.isfile(args.validation_pheno):
                raise FileNotFoundError(f"No phenotype file found at {args.validation_pheno}")

        # If we're performing validation with individual-level data:
        if args.grid_metric == 'validation':

            # Check that genotype and phenotype files are provided:
            if not has_individual_data:
                raise ValueError("To perform cross-validation, you need to provide BED files and a phenotype file "
                                 "for the validation set (use --validation-bed and --validation-pheno).")

        elif args.grid_metric == 'pseudo_validation':

            # If we're performing pseudo cross-validation, then the user may provide
            # individual-level data (already checked above), summary statistics + LD
            # for the validation set, or nothing at all (training data will be split).
            if has_individual_data:
                pass
            elif args.validation_ld_panel is not None and args.validation_sumstats_path is not None:
                valid_ld_store_files = get_filenames(args.validation_ld_panel, extension='.zgroup')
                if len(valid_ld_store_files) < 1:
                    # Report the validation LD panel path (not the training LD path):
                    raise FileNotFoundError(f"No valid LD matrix files for the "
                                            f"validation set were found at: {args.validation_ld_panel}")
                valid_sumstats_files = get_filenames(args.validation_sumstats_path)
                if len(valid_sumstats_files) < 1:
                    raise FileNotFoundError(f"No valid summary statistics files for the validation set "
                                            f"were found at: {args.validation_sumstats_path}")
            else:
                logger.info("\n> Performing pseudo-validation without external GWAS summary statistics."
                            "\nWill automatically split the training data into "
                            f"training and validation sets (training proportion: {args.prop_train}).")

    if args.hyp_search in ('GS', 'BMA'):
        # If the hyperparameter strategy is grid search or Bayesian model averaging,
        # then ensure that we can generate a grid, given the user's selection of inputs:

        grid_options = (args.pi_grid, args.pi_steps,
                        args.sigma_epsilon_grid, args.sigma_epsilon_steps)

        if all(opt is None for opt in grid_options):
            raise ValueError("For grid search (GS) or Bayesian model averaging (BMA), you need to "
                             "provide either a grid or the number of steps for the hyperparameter "
                             "search over the hyperparameters of interest. See --pi-grid, --pi-steps, "
                             "--sigma-epsilon-grid, and --sigma-epsilon-steps.")


def init_data(args):
    """
    Initialize the data loaders and parsers for the GWAS summary statistics.
    This function takes as input an argparse.Namespace object (the output of
    `argparse.ArgumentParser.parse_args`) and returns a list of dictionaries
    containing the GWADataLoaders for the training (and validation) datasets.

    If `args.genomewide` is set, the list contains a single dictionary covering all
    chromosomes; otherwise it contains one dictionary per chromosome.

    :param args: An argparse.Namespace object containing the parsed commandline arguments.
    :return: A list of dictionaries containing the GWADataLoaders for the training (and validation) datasets.
    Each dictionary contains the following keys:
    - 'train': The GWADataLoader for the training dataset
    - 'valid': The GWADataLoader for the validation dataset (if provided)
    """

    import numpy as np
    from magenpy.GWADataLoader import GWADataLoader
    from magenpy.parsers.sumstats_parsers import SumstatsParser

    logger.info('\n{:-^100}\n'.format('  Reading & harmonizing input data  '))

    # Prepare the summary statistics parsers:
    if args.sumstats_format == 'custom':

        ss_parser = SumstatsParser(col_name_converter=args.custom_sumstats_mapper,
                                   sep=args.custom_sumstats_sep)
        ss_format = None
    else:
        ss_format = args.sumstats_format
        ss_parser = None

    logger.info(">>> Reading the training dataset...")

    # Construct a GWADataLoader object using LD + summary statistics:
    gdl = GWADataLoader(ld_store_files=args.ld_dir,
                        temp_dir=args.temp_dir,
                        threads=args.threads)

    # If the user requested it (--exclude-lrld), filter long-range LD regions:
    if args.exclude_lrld:
        logger.info("> Filtering long-range LD regions...")
        for ld in gdl.ld.values():
            ld.filter_long_range_ld_regions()

    # Read the summary statistics file(s):
    gdl.read_summary_statistics(args.sumstats_path,
                                sumstats_format=ss_format,
                                parser=ss_parser)
    # Harmonize the data:
    gdl.harmonize_data()

    # If overall GWAS sample size is provided, set it here:
    if args.gwas_sample_size is not None:
        for ss in gdl.sumstats_table.values():
            ss.set_sample_size(args.gwas_sample_size)

    if args.genomewide:
        data_loaders = {'All': {'train': gdl}}
    else:
        # If we are not performing inference genome-wide,
        # then split the GWADataLoader object into multiple loaders,
        # one per chromosome.
        data_loaders = {c: {'train': g} for c, g in gdl.split_by_chromosome().items()}

    # ----------------------------------------------------------
    # If the user provides validation data, extract this data and harmonize it with the
    # training data:

    if args.hyp_search in ('GS', ) and (args.validation_bed is not None or
                                        args.validation_ld_panel is not None):

        logger.info(">>> Reading the validation dataset...")

        # Restrict the validation set to the variants retained in the training data:
        extract_snps = np.concatenate(list(gdl.snps.values()))

        if args.validation_bed is not None:

            validation_gdl = GWADataLoader(
                bed_files=args.validation_bed,
                keep_file=args.validation_keep,
                phenotype_file=args.validation_pheno,
                extract_snps=extract_snps,
                backend=args.backend,
                temp_dir=args.temp_dir,
                threads=args.threads
            )

            if args.grid_metric == 'pseudo_validation':
                validation_gdl.perform_gwas()
        else:

            # Construct the validation GWADataLoader object using LD + summary statistics:
            validation_gdl = GWADataLoader(ld_store_files=args.validation_ld_panel,
                                           temp_dir=args.temp_dir,
                                           threads=args.threads)

            # Prepare the validation summary statistics parsers:
            if args.validation_sumstats_format == 'custom':
                valid_ss_parser = SumstatsParser(col_name_converter=args.validation_custom_sumstats_mapper,
                                                 sep=args.validation_custom_sumstats_sep)
                valid_ss_format = None
            else:
                # Honor the format requested for the validation set. If none was
                # given, fall back to the format of the training summary statistics.
                valid_ss_format = args.validation_sumstats_format or args.sumstats_format
                valid_ss_parser = None

            # Read the summary statistics file(s):
            validation_gdl.read_summary_statistics(args.validation_sumstats_path,
                                                   sumstats_format=valid_ss_format,
                                                   parser=valid_ss_parser)

            # If overall GWAS sample size is provided, set it here:
            if args.validation_gwas_sample_size is not None:
                for ss in validation_gdl.sumstats_table.values():
                    ss.set_sample_size(args.validation_gwas_sample_size)

        # Filter SNPs:
        validation_gdl.filter_snps(extract_snps)
        # Harmonize the data:
        validation_gdl.harmonize_data()

        if args.genomewide:
            data_loaders['All']['valid'] = validation_gdl
        else:
            for c, g in validation_gdl.split_by_chromosome().items():
                data_loaders[c]['valid'] = g

    # Return a concrete list (as documented) rather than a dictionary view:
    return list(data_loaders.values())


def prepare_model(args):
    """
    Prepare the model object for fitting to the GWAS summary statistics.

    The returned object is a `functools.partial` wrapping the model class with all
    the keyword arguments derived from the commandline options. For the grid-based
    strategies (GS / BMA), the partial carries the hyperparameter grid under the
    `grid` keyword.

    :param args: An argparse.Namespace object containing the parsed commandline arguments.
    :return: A partial function object for the model to fit to the GWAS data.
    :raises ValueError: If the requested model or hyperparameter tuning strategy
    is not supported by this function.

    """

    # Print information about the model details:
    logger.info('\n{:-^100}\n'.format('  Model details  '))

    logger.info(f"- Model:{args.model}")
    hyp_map = {'GS': 'Grid search', 'BO': 'Bayesian optimization',
               'BMA': 'Bayesian model averaging', 'EM': 'Expectation maximization'}
    logger.info(f"- Hyperparameter tuning strategy: {hyp_map.get(args.hyp_search, args.hyp_search)}")
    if args.hyp_search in ('GS', ):
        logger.info(f"- Model selection criterion: {args.grid_metric}")

    from functools import partial
    from viprs.model import VIPRS, VIPRSMix

    if args.lambda_min is None:
        lambda_min = 0.
    elif args.lambda_min == 'infer':
        lambda_min = 'infer'
    else:
        lambda_min = float(args.lambda_min)

    # Keyword arguments shared by all the model classes constructed below:
    common_kwargs = {
        'float_precision': args.float_precision,
        'low_memory': not args.use_symmetric_ld,
        'lambda_min': lambda_min,
        'dequantize_on_the_fly': args.dequantize_on_the_fly,
        'threads': args.threads
    }

    if args.hyp_search == 'EM':

        # If the user requests fixing the residual variance to a particular value,
        # then set it here and pass to VIPRS:
        if args.fix_sigma_epsilon is not None:
            fix_params = {'sigma_epsilon': args.fix_sigma_epsilon}
        else:
            fix_params = None

        if args.model == 'VIPRS':
            return partial(VIPRS, fix_params=fix_params, **common_kwargs)

        if args.model == 'VIPRSMix':
            return partial(VIPRSMix, K=args.n_components, fix_params=fix_params, **common_kwargs)

        raise ValueError(f"Unsupported model type: {args.model}")

    if args.hyp_search in ('BMA', 'GS'):

        from viprs.model.gridsearch import HyperparameterGrid
        from viprs.model import VIPRSGrid

        grid = HyperparameterGrid(sigma_epsilon_grid=args.sigma_epsilon_grid,
                                  sigma_epsilon_steps=args.sigma_epsilon_steps,
                                  pi_grid=args.pi_grid,
                                  pi_steps=args.pi_steps,
                                  h2_est=args.h2_est,
                                  h2_se=args.h2_se)

        logger.info("- Hyperparameter grid:")
        logger.info(grid.to_table())

        # NOTE: `args.model` and `args.fix_sigma_epsilon` are not consulted in this
        # branch; the grid-based strategies always use VIPRSGrid.
        return partial(VIPRSGrid, grid=grid, **common_kwargs)

    # Any other strategy (e.g. 'BO', Bayesian optimization) has no implementation here.
    # Fail with an explicit message instead of an UnboundLocalError.
    raise ValueError(f"Unsupported hyperparameter tuning strategy: {args.hyp_search}")


def fit_model(model, data_dict, args):
    """
    Given a model class (not initialized) and a dictionary containing
    the training and validation data loaders, fit the model to the data.

    Side effects: the data loaders have their `verbose` attribute set to False,
    and after a successful fit they are cleaned up (via `.cleanup()`) and removed
    from `data_dict`.

    :param model: A partial function object for the model to fit to the GWAS data.
    :param data_dict: A dictionary containing the training (and validation) data loaders
    under the keys 'train' and (optionally) 'valid'.
    :param args: An argparse.Namespace object containing the parsed commandline arguments.

    :return: A dictionary with the keys 'ProfilerMetrics', 'Chromosome', 'Model parameters',
    'Hyperparameters' and, for the grid-based strategies (GS / BMA), 'Validation'.
    :raises Exception: If the optimization terminates with an error that does not
    qualify for the automatic retry with a non-zero `lambda_min`.
    """

    from functools import partial
    import numpy as np

    # Set the random seed:
    np.random.seed(args.seed)

    # Report a single chromosome as a scalar; otherwise keep the full collection:
    chromosome = data_dict['train'].chromosomes
    if len(chromosome) == 1:
        chromosome = chromosome[0]

    result_dict = {
        'ProfilerMetrics': {},
        'Chromosome': chromosome
    }

    for d in data_dict.values():
        d.verbose = False

    # ----------------------------------------------------------
    # If the model uses a grid (e.g. VIPRSGridSearch, VIPRSBMA), then make sure
    # to update the parameters of the grid to match the characteristics of the data
    # used in model fit:

    if args.hyp_search in ('GS', 'BMA'):

        grid = model.keywords['grid']

        if args.pi_steps is not None:
            grid.n_snps = data_dict['train'].n_snps
            grid.generate_pi_grid(steps=args.pi_steps)

        if args.lambda_min_steps is not None:

            # NOTE(review): only the first LD matrix of the training loader is used
            # here, even if the loader holds several chromosomes -- confirm intended.
            ld_mat = list(data_dict['train'].ld.values())[0]
            lambda_min = ld_mat.get_lambda_min(aggregate='min')

            grid.generate_lambda_min_grid(steps=args.lambda_min_steps, emp_lambda_min=lambda_min)

        # Rebuild the partial object with the updated grid:
        model = partial(model.func, **{**model.keywords, 'grid': grid})

    # ----------------------------------------------------------

    # Initialize the model:
    load_start_time = time.time()
    m = model(data_dict['train'])
    load_end_time = time.time()

    result_dict['ProfilerMetrics']['Load_time'] = round(load_end_time - load_start_time, 2)

    # ----------------------------------------------------------
    # If the user selected grid search (GS) but did not provide validation data,
    # then split the training data into training and validation sets:

    if args.hyp_search == 'GS' and 'valid' not in data_dict:

        logger.debug("> Splitting the training data into training and validation sets (PUMAS)...")

        split_start_time = time.time()
        m.split_gwas_sumstats(args.prop_train,
                              seed=args.seed,
                              threads=args.threads)
        split_end_time = time.time()

        result_dict['ProfilerMetrics']['PUMAS_split_time'] = round(split_end_time - split_start_time, 2)

    # ----------------------------------------------------------

    # `pathwise=False` is only passed for the grid-based models in non-pathwise mode:
    if args.hyp_search in ('GS', 'BMA') and args.grid_search_mode != 'pathwise':
        fit_func = partial(m.fit, max_iter=args.max_iter, pathwise=False)
    else:
        fit_func = partial(m.fit, max_iter=args.max_iter)

    # Fit the model to data:
    fit_start_time = time.time()

    m = fit_func()
    if m.optim_result.error_on_termination:

        # Retry once with a non-zero lambda_min, but only if the fit ended with a
        # negative `_sigma_g` and no lambda_min regularization was in use:
        if m._sigma_g < 0. and np.all(m.lambda_min == 0.):
            logger.info("> Optimization diverged. "
                        "Re-trying with setting regularization parameter lambda_min...")
            # NOTE(review): each iteration overwrites `m.lambda_min`, so only the value
            # for the last entry of `m.shapes` is retained -- confirm this is intended
            # when the model spans more than one chromosome.
            for c in m.shapes:
                m.lambda_min = m.gdl.ld[c].get_lambda_min(min_max_ratio=1e-3)
            # NOTE(review): the outcome of this second attempt is not re-checked.
            m = fit_func()

        else:
            raise Exception(m.optim_result.message)

    fit_end_time = time.time()

    # ----------------------------------------------------------

    # Record the profiler metrics:
    result_dict['ProfilerMetrics']['Fit_time'] = round(fit_end_time - fit_start_time, 2)
    result_dict['ProfilerMetrics']['Total_Iterations'] = m.optim_result.iterations

    # ----------------------------------------------------------
    # Perform validation / model averaging:

    if args.hyp_search in ('GS', 'BMA'):

        from viprs.model.gridsearch.grid_utils import (
            bayesian_model_average, select_best_model
        )

        valid_start_time = time.time()

        if args.hyp_search == 'BMA':
            # If the mode is a Bayesian model averaging model, then
            # we need to extract the final model from the BMA object:

            m = bayesian_model_average(m)
        elif args.hyp_search == 'GS':

            if 'valid' in data_dict:
                m = select_best_model(m, data_dict['valid'], criterion=args.grid_metric)
            else:
                # NOTE(review): without a validation loader the criterion is always
                # 'pseudo_validation', regardless of `args.grid_metric` -- confirm intended.
                m = select_best_model(m, None, criterion='pseudo_validation')

                # If the sumstats were split, then reset the original data and
                # then refit the selected model on the full sumstats:
                m.initialize_input_data_arrays()

                # If the number of threads was reduced previously,
                # reset it:
                if m.threads < args.threads:
                    m.threads = args.threads

                fit_start_time = time.time()
                m = fit_func()
                fit_end_time = time.time()

                # Accumulate the refit time, and re-round so that the reported
                # value does not pick up floating point artefacts from the sum:
                result_dict['ProfilerMetrics']['Fit_time'] = round(
                    result_dict['ProfilerMetrics']['Fit_time'] + (fit_end_time - fit_start_time), 2
                )

        valid_end_time = time.time()

        # Note: for GS without validation data, this interval includes the refit above.
        result_dict['ProfilerMetrics']['Validation_time'] = round(valid_end_time - valid_start_time, 2)

        result_dict['Validation'] = m.to_validation_table()

    result_dict['Model parameters'] = m.to_table()
    result_dict['Hyperparameters'] = m.to_theta_table()

    # ----------------------------------------------------------
    # Clean up:

    data_dict['train'].cleanup()
    del data_dict['train']
    if 'valid' in data_dict:
        data_dict['valid'].cleanup()
        del data_dict['valid']

    # ----------------------------------------------------------

    return result_dict


def main():

    import argparse
    import viprs as vp

    print("\n" + vp.make_ascii_logo(
        desc='< Fit VIPRS to GWAS summary statistics >',
        left_padding=10
    ) + "\n", flush=True)

    parser = argparse.ArgumentParser(description="""
        Commandline arguments for fitting VIPRS to GWAS summary statistics
    """)

    # Required input parameters:
    parser.add_argument('-l', '--ld-panel', dest='ld_dir', type=str, required=True,
                        help='The path to the directory where the LD matrices are stored. '
                             'Can be a wildcard of the form ld/chr_*')
    parser.add_argument('-s', '--sumstats', dest='sumstats_path', type=str, required=True,
                        help='The summary statistics directory or file. Can be a '
                             'wildcard of the form sumstats/chr_*')

    # Output files / directories:
    parser.add_argument('--output-dir', dest='output_dir', type=str, required=True,
                        help='The output directory where to store the inference results.')
    parser.add_argument('--output-file-prefix', dest='output_prefix', type=str, default='',
                        help='A prefix to append to the names of the output files (optional).')
    parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                        help='The temporary directory where to store intermediate files.')

    # Optional input arguments:
    parser.add_argument('--sumstats-format', dest='sumstats_format',
                        type=str.lower, default='plink',
                        choices={'plink1.9', 'plink2', 'plink', 'cojo', 'magenpy',
                                 'fastgwa', 'ssf', 'gwas-ssf', 'gwascatalog', 'saige', 'custom'},
                        help='The format for the summary statistics file(s).')
    parser.add_argument('--custom-sumstats-mapper', dest='custom_sumstats_mapper', type=str,
                        help='A comma-separated string with column name mappings between the custom '
                             'summary statistics format and the standard format expected by magenpy/VIPRS. '
                             'Provide only mappings for column names that are different, in the form of:'
                             '--custom-sumstats-mapper rsid=SNP,eff_allele=A1,beta=BETA')
    parser.add_argument('--custom-sumstats-sep', dest='custom_sumstats_sep', type=str, default='\t',
                        help='The delimiter for the summary statistics file with custom format.')
    parser.add_argument('--gwas-sample-size', dest='gwas_sample_size', type=float,
                        help='The overall sample size for the GWAS study. This must be provided if the '
                             'sample size per-SNP is not in the summary statistics file.')

    # Arguments for the validation set (individual-level data):
    parser.add_argument('--validation-bfile', dest='validation_bed', type=str,
                        help='The BED files containing the genotype data for the validation set. '
                             'You may use a wildcard here (e.g. "data/chr_*.bed")')
    parser.add_argument('--validation-pheno', dest='validation_pheno', type=str,
                        help='A tab-separated file containing the phenotype for the validation set. '
                             'The expected format is: FID IID phenotype (no header)')
    parser.add_argument('--validation-keep', dest='validation_keep', type=str,
                        help='A plink-style keep file to select a subset of individuals for the validation set.')

    # Arguments for the validation set (summary statistics):
    parser.add_argument('--validation-ld-panel', dest='validation_ld_panel', type=str,
                        help='The path to the directory where the LD matrices for the validation set are stored. '
                             'Can be a wildcard of the form ld/chr_*')
    parser.add_argument('--validation-sumstats', dest='validation_sumstats_path', type=str,
                        help='The summary statistics directory or file for the validation set. Can be a '
                             'wildcard of the form sumstats/chr_*')
    parser.add_argument('--validation-sumstats-format', dest='validation_sumstats_format',
                        type=str.lower,
                        choices={'plink1.9', 'plink2', 'plink', 'cojo', 'magenpy',
                                 'fastgwa', 'ssf', 'gwas-ssf', 'gwascatalog', 'saige', 'custom'},
                        help='The format for the summary statistics file(s) for the validation set.')
    # For now, suppress the help message for these arguments:
    parser.add_argument('--validation-custom-sumstats-mapper',
                        dest='validation_custom_sumstats_mapper', type=str,
                        help=argparse.SUPPRESS)
    parser.add_argument('--validation-custom-sumstats-sep',
                        dest='validation_custom_sumstats_sep',
                        type=str, default='\t',
                        help=argparse.SUPPRESS)
    parser.add_argument('--validation-gwas-sample-size',
                        dest='validation_gwas_sample_size',
                        type=float,
                        help=argparse.SUPPRESS)

    # Arguments for the model type / parameters:
    parser.add_argument('-m', '--model', dest='model', type=str, default='VIPRS',
                        help='The type of PRS model to fit to the GWAS data',
                        choices={'VIPRS', 'VIPRSMix'})
    parser.add_argument('--float-precision', dest='float_precision', type=str, default='float32',
                        help='The float precision to use when fitting the model.',
                        choices={'float32', 'float64'})
    parser.add_argument('--use-symmetric-ld', dest='use_symmetric_ld', action='store_true',
                        default=False,
                        help='Use the symmetric form of the LD matrix when fitting the model.')
    parser.add_argument('--dequantize-on-the-fly', dest='dequantize_on_the_fly', action='store_true',
                        default=False,
                        help='Dequantize the entries of the LD matrix on-the-fly during inference.')
    parser.add_argument('--fix-sigma-epsilon', dest='fix_sigma_epsilon', type=float,
                        help='Set the value of the residual variance hyperparameter, sigma_epsilon, '
                             'to the provided value.')
    parser.add_argument('--lambda-min', dest='lambda_min', type=str,
                        help='Set the value of the lambda_min parameter, which acts as a regularizer for '
                             'the effect sizes and compensates for noise in the LD matrix. Set to "infer" to '
                             'derive this parameter from the properties of the LD matrix itself.')
    parser.add_argument('--n-components', dest='n_components', type=int, default=3,
                        help='The number of non-null Gaussian mixture components to use with the VIPRSMix model '
                             '(i.e. excluding the spike component).')
    parser.add_argument('--max-iter', dest='max_iter', type=int, default=1000,
                        help='The maximum number of iterations to run the coordinate ascent algorithm.')

    # Arguments for Hyperparameter tuning / model initialization:
    parser.add_argument('--h2-est', dest='h2_est', type=float,
                        help='The estimated heritability of the trait. If available, this value can be '
                             'used for parameter initialization or hyperparameter grid search.')
    parser.add_argument('--h2-se', dest='h2_se', type=float,
                        help='The standard error for the heritability estimate for the trait. '
                             'If available, this value can be used for parameter initialization '
                             'or hyperparameter grid search.')

    parser.add_argument('--hyp-search', dest='hyp_search', type=str, default='EM',
                        choices={'EM', 'GS', 'BMA'},
                        help='The strategy for tuning the hyperparameters of the model. '
                             'Options are EM (Expectation-Maximization), GS (Grid search), '
                             'and BMA (Bayesian Model Averaging).')
    parser.add_argument('--grid-metric', dest='grid_metric', type=str, default='pseudo_validation',
                        help='The metric for selecting best performing model in grid search.',
                        choices={'ELBO', 'validation', 'pseudo_validation'})

    # Grid-related parameters:
    parser.add_argument('--grid-search-mode', dest='grid_search_mode',
                        type=str.lower, default='pathwise',
                        choices={'pathwise', 'independent'},
                        help='The mode for grid search. Pathwise mode updates the hyperparameters '
                             'sequentially and in a warm-start fashion, while independent mode updates each model '
                             'separately starting from same initialization.')
    parser.add_argument('--prop-train', dest='prop_train', type=float, default=0.8,
                        help='The proportion of the samples to use for training when performing cross '
                             'validation using the PUMAS procedure.')
    parser.add_argument('--pi-grid', dest='pi_grid', type=str,
                        help='A comma-separated grid values for the hyperparameter pi (see also --pi-steps).')
    parser.add_argument('--pi-steps', dest='pi_steps', type=int,
                        help='The number of steps for the (default) pi grid. This will create an equidistant '
                             'grid between 10/M and 0.2 on a log10 scale, where M is the number of variants.')

    parser.add_argument('--sigma-epsilon-grid', dest='sigma_epsilon_grid', type=str,
                        help='A comma-separated grid values for the hyperparameter sigma_epsilon '
                             '(see also --sigma-epsilon-steps).')
    parser.add_argument('--sigma-epsilon-steps', dest='sigma_epsilon_steps', type=int,
                        help='The number of steps (unique values) for the sigma_epsilon grid.')

    parser.add_argument('--lambda-min-steps', dest='lambda_min_steps', type=int,
                        help='The number of grid steps for the lambda_min grid. Lambda_min is used to compensate '
                             'for noise in the LD matrix and acts as an extra regularizer for the effect sizes.')

    # Miscellaneous / general parameters:

    parser.add_argument('--genomewide', dest='genomewide', action='store_true', default=False,
                        help='Fit all chromosomes jointly')
    parser.add_argument('--exclude-lrld', dest='exclude_lrld', action='store_true', default=False,
                        help='Exclude Long Range LD (LRLD) regions during inference. These regions can cause '
                             'numerical instabilities in some cases.')
    parser.add_argument('--backend', dest='backend', type=str.lower, default='xarray',
                        choices={'xarray', 'plink'},
                        help='The backend software used for computations on the genotype matrix.')
    parser.add_argument('--n-jobs', dest='n_jobs', type=int, default=1,
                        help='The number of processes to launch for the hyperparameter search (default is '
                             '1, but we recommend increasing this depending on system capacity).')
    parser.add_argument('--threads', dest='threads', type=int, default=1,
                        help='The number of threads to use in the E-Step of VIPRS.')
    parser.add_argument('--output-profiler-metrics', dest='output_profiler_metrics', action='store_true',
                        default=False, help='Output the profiler metrics that measure runtime, memory usage, etc.')

    # Logging / reproducibility parameters:
    parser.add_argument('--log-level', dest='log_level', type=str, default='WARNING',
                        choices={'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'},
                        help='The logging level for the console output.')
    parser.add_argument('--seed', dest='seed', type=int, default=7209,
                        help='The random seed to use for the random number generator.')

    args = parser.parse_args()

    # ----------------------------------------------------------
    # Import required modules:

    from datetime import timedelta
    import pandas as pd
    from magenpy.utils.system_utils import makedir, PeakMemoryProfiler, setup_logger
    from joblib import Parallel, delayed

    # ----------------------------------------------------------
    # Setup the logger:

    # Create the output directory:
    makedir(osp.abspath(args.output_dir))
    # Set the output prefix:
    output_prefix = osp.join(args.output_dir, args.output_prefix + args.model + '_' + args.hyp_search)

    # Clear the log file:
    open(f"{output_prefix}.log", 'w').close()

    # Set up the module loggers:
    setup_logger(modules=['viprs', 'magenpy'],
                 log_file=f"{output_prefix}.log",
                 log_level=args.log_level)

    # Set up the logger for the main module:
    setup_logger(loggers=[logger],
                 log_file=f"{output_prefix}.log",
                 log_format='%(message)s',
                 log_level=['INFO', args.log_level][logging.getLevelName(args.log_level) < logging.INFO])

    # ----------------------------------------------------------
    # Print the parsed arguments:

    logger.info('{:-^100}\n'.format('  Parsed arguments  '))

    for key, val in vars(args).items():
        if val is not None and val != parser.get_default(key):
            logger.info(f"-- {key}: {val}")

    # ----------------------------------------------------------

    # (1) Check the validity, consistency, and completeness of the commandline arguments:
    try:
        check_args(args)
    except Exception as e:
        logger.error(f"An error occurred while parsing the commandline arguments:\n{str(e)}")
        raise e

    # Record start time:
    total_start_time = time.time()

    # (2) Read the data:
    data_loaders = init_data(args)
    # Record time for data preparation:
    data_prep_time = time.time()

    # (3) Prepare the model:
    try:
        model = prepare_model(args)
    except Exception as e:
        logger.error(f"An error occurred while preparing the model:\n{str(e)}")
        raise e

    # (4) Fit to data:
    logger.info('\n{:-^100}\n'.format('  Inference  '))

    with PeakMemoryProfiler() as peak_mem:

        parallel = Parallel(n_jobs=args.n_jobs)

        with parallel:
            fit_results = parallel(
                delayed(fit_model)(model, dl, args)
                for idx, dl in enumerate(data_loaders)
            )

    # Record end time:
    total_end_time = time.time()

    # (5) Write the results to file:

    if len(fit_results) > 0:
        logger.info('\n{:-^100}\n'.format(''))
        logger.info(f"\n>>> Writing the inference results to:\n {args.output_dir}")

        makedir(osp.abspath(args.output_dir))

        eff_tables = pd.concat([r['Model parameters'] for r in fit_results])
        hyp_tables = pd.concat([r['Hyperparameters'].assign(Chromosome=r['Chromosome']) for r in fit_results])

        try:
            valid_tables = pd.concat([r['Validation'].assign(Chromosome=r['Chromosome']) for r in fit_results])
        except KeyError:
            valid_tables = None
            pass

        profm_table = pd.concat([
            pd.DataFrame(r['ProfilerMetrics'], index=[0]).assign(Chromosome=r['Chromosome'])
            for r in fit_results
        ])
        profm_table['Total_WallClockTime'] = round(total_end_time - total_start_time, 2)
        profm_table['DataPrep_Time'] = round(data_prep_time - total_start_time, 2)
        profm_table['Peak_Memory_MB'] = round(peak_mem.get_peak_memory(), 2)

        if len(eff_tables) > 0:
            eff_tables.to_csv(f"{output_prefix}.fit.gz", sep="\t", index=False)

        if len(hyp_tables) > 0:
            hyp_tables.to_csv(f"{output_prefix}.hyp", sep="\t", index=False)

        if valid_tables is not None:
            valid_tables.to_csv(f"{output_prefix}.validation", sep="\t", index=False)

        if args.output_profiler_metrics:
            profm_table.to_csv(f"{output_prefix}.prof", sep="\t", index=False)

    logger.info(f">>> Total Runtime:\n {timedelta(seconds=total_end_time - total_start_time)}")


if __name__ == '__main__':
    # Script entry point: run the command-line workflow defined in `main()`
    # only when this file is executed directly (e.g. `python -m viprs_fit`),
    # and not when it is merely imported as a module.

    main()
