#!/usr/bin/env python
from __future__ import print_function, division
import argparse, sys
import sldp.config as config
import ypy.pretty as pretty


def main(args):
    """Regress the phenotype's summary statistics on each marginal annotation.

    For every chromosome in ``args.chroms`` this merges the reference-panel
    SNP metadata with the preprocessed summary statistics found under
    ``args.pss_chr`` and with the processed (``.R``) annotation columns,
    accumulates a per-LD-block numerator and denominator of the regression,
    collapses the blocks into ``args.jk_blocks`` chunks, and then, for each
    marginal annotation, computes an estimate, a jackknife standard error and
    a p-value.  The results table is written to
    ``<args.outfile_stem>.gwresults`` after every annotation and once more at
    the end.

    Parameters
    ----------
    args : argparse.Namespace
        Options as built by the command-line parser (after config defaults
        have been filled in).  Attributes read here include ``bfile_reg_chr``,
        ``pss_chr``, ``sannot_chr``, ``background_sannot_chr``, ``ld_blocks``,
        ``svd_stem``, ``chroms``, ``weights``, ``chi2_thresh``, ``jk_blocks``,
        ``T``, ``stat``, ``seed``, ``fastp``, ``bothp``, ``more_stats``,
        ``verbose_thresh``, ``tell_me_stories``, ``story_corr_thresh`` and
        ``outfile_stem``.

    Raises
    ------
    ValueError
        If background and marginal annotation names overlap, if a processed
        annotation contains non-finite values, if the merged annotation
        columns are not in the expected order, or if the dimension of the
        stored regression weights does not match the number of regression
        SNPs in a block.
    """
    print('initializing...')
    import os, time
    import numpy as np
    import pandas as pd
    import scipy.stats as st
    import gprim.annotation as ga
    import gprim.dataset as gd
    import ypy.memo as memo
    import sldp.weights as weights
    import sldp.chunkstats as cs
    import sldp.storyteller as storyteller

    # basic initialization
    mhc_bp = [25684587, 35455756]  # bp interval on chr6; overlapping ld blocks are dropped
    refpanel = gd.Dataset(args.bfile_reg_chr)
    # phenotype name is the last directory component of pss_chr
    pheno_name = args.pss_chr.split('/')[-2].replace('.KG3.95','')
    if args.seed is not None:
        np.random.seed(args.seed)
        print('random seed:', args.seed)

    # read in names of background annotations and marginal annotations
    # (only columns whose name contains '.R' are used)
    annots = [ga.Annotation(annot) for annot in args.sannot_chr]
    marginal_names = sum([a.names(22, RV=True) for a in annots], [])
    marginal_names = [n for n in marginal_names if '.R' in n]
    marginal_infos = pd.concat([a.info_df(args.chroms) for a in annots], axis=0)
    backgroundannots = [ga.Annotation(annot) for annot in args.background_sannot_chr]
    background_names = sum([a.names(22, True) for a in backgroundannots], [])
    background_names = [n for n in background_names if '.R' in n]
    print('background annotations:', background_names)
    print('marginal annotations:', marginal_names)
    if len(set(background_names) & set(marginal_names)) > 0:
        raise ValueError('the background annotation names and the marginal annotation '+\
                'names must be disjoint sets')
    all_names = background_names + marginal_names

    # read in ldblocks and remove ones that overlap mhc
    # (sep=r'\s+' is the documented equivalent of the deprecated delim_whitespace=True)
    ldblocks = pd.read_csv(args.ld_blocks, sep=r'\s+', header=None,
            names=['chr','start', 'end'])
    mhcblocks = (ldblocks.chr == 'chr6') & \
            (ldblocks.end > mhc_bp[0]) & \
            (ldblocks.start < mhc_bp[1])
    ldblocks = ldblocks[~mhcblocks]

    # read information about sumstats
    sumstats_info = pd.read_csv(args.pss_chr+'info', sep='\t')
    sigma2g = sumstats_info.loc[0].sigma2g
    h2g = sumstats_info.loc[0].h2g

    # read in sumstats and annots, and compute numerator and denominator of regression for
    # each ldblock. These will later be processed and aggregated
    numerators = dict(); denominators = dict()
    t0 = time.time()
    for c in args.chroms:
        print(time.time()-t0, ': loading chr', c, 'of', args.chroms)

        # get refpanel snp metadata
        snps = refpanel.bim_df(c)
        print(len(snps), 'snps in refpanel', len(snps.columns), 'columns, including metadata')

        # read sumstats
        print('reading sumstats')
        ss = pd.read_csv(args.pss_chr+str(c)+'.pss.gz', sep='\t')
        print(np.isnan(ss[args.weights]).sum(), 'sumstats nans out of', len(ss))
        snps['Winv_ahat'] = ss[args.weights]
        snps['N'] = ss.N
        # a snp counts as typed when it has a (non-null) weighted sumstat
        snps['typed'] = snps.Winv_ahat.notnull()
        if args.chi2_thresh > 0:
            print('applying chi2 threshold of', args.chi2_thresh)
            snps.typed &= (ss.Winv_ahat_I**2 * ss.N > args.chi2_thresh)
            print(snps.typed.sum(), 'typed snps left')

        # read annotations
        print('reading annotations')
        for annot in backgroundannots+annots:
            mynames = [n for n in annot.names(22, RV=True) if '.R' in n] #names of annotations
            print(time.time()-t0, ': reading annot', annot.filestem())
            print('adding', mynames)
            snps = pd.concat([snps, annot.RV_df(c)[mynames]], axis=1)
            if (~np.isfinite(snps[mynames].values)).sum() > 0:
                raise ValueError('There should be no nans in the postprocessed annotation. '+\
                        'But there are '+str((~np.isfinite(snps[mynames].values)).sum()))

        # make sure things are in the order we think they are
        if (np.array(all_names) !=
                snps.columns.values[-len(all_names):]).any():
            raise ValueError('Merged annotations are not in the right order')

        # perform computations
        for ldblock, _, meta, ind in refpanel.block_data(
                ldblocks, c, meta=snps, genos=False, verbose=0):
            if meta.typed.sum() == 0 or \
                    not os.path.exists(args.svd_stem+str(ldblock.name)+'.R.npz'):
                # no typed snps/hm3 snps in this block. set num snps to 0
                ldblocks.loc[ldblock.name, 'M_H'] = 0
                continue
            if (meta[all_names] == 0).values.all():
                # annotations are all 0 in this block
                ldblocks.loc[ldblock.name, 'M_H'] = 0
                continue
            # record the number of typed snps in this block, and start- and end- snp indices
            ldblocks.loc[ldblock.name, 'M_H'] = meta.typed.sum()
            ldblocks.loc[ldblock.name, 'snpind_begin'] = min(meta.index) # for verbose
            ldblocks.loc[ldblock.name, 'snpind_end'] = max(meta.index)+1 # for verbose

            # load regression weights and prepare for regression computation
            meta_t = meta[meta.typed.values]
            N = meta_t.N.mean()
            if args.weights == 'Winv_ahat_h' or args.weights == 'Winv_ahat_hlN':
                # these modes only need the .R file
                R = np.load(args.svd_stem+str(ldblock.name)+'.R.npz')
                R2 = None
                if len(R['U']) != len(meta):
                    raise ValueError('regression wgts dimension must match regression snps')
            elif args.weights == 'Winv_ahat_h2' or args.weights == 'Winv_ahat':
                # these modes need both the .R and the .R2 file
                R = np.load(args.svd_stem+str(ldblock.name)+'.R.npz')
                R2 = np.load(args.svd_stem+str(ldblock.name)+'.R2.npz')
                if len(R['U']) != len(meta) or len(R2['U']) != len(meta):
                    raise ValueError('regression wgts dimension must match regression snps')
            else:
                R = None; R2 = None

            # multiply ahat by the weights
            Winv_RV = weights.invert_weights(
                    R, R2, sigma2g, N, meta[all_names].values,
                    typed=meta.typed.values, mode=args.weights)

            numerators[ldblock.name] = \
                (meta_t[all_names].T.dot(
                    meta_t.Winv_ahat)).values/1e6
            denominators[ldblock.name] = meta_t[all_names].T.dot(
                    Winv_RV[meta.typed.values]).values/1e6
        memo.reset()

    # get data for jackknifing
    print('jackknifing')
    chunk_nums, chunk_denoms, loo_nums, loo_denoms, chunkinfo = cs.collapse_to_chunks(
            ldblocks,
            numerators,
            denominators,
            args.jk_blocks)

    # compute final results
    # NOTE(review): q and results are published as module globals; this looks like a
    # debugging aid but is kept so the module-level behaviour is unchanged.
    global q, results
    results = pd.DataFrame()
    for i, name in enumerate(marginal_names):
        print(i, name)
        # metadata about v (info rows are keyed by the name without its last two characters)
        sqnorm = marginal_infos.loc[name[:-2], 'sqnorm']
        supp = marginal_infos.loc[name[:-2], 'supp']

        # estimate mu and initialize results row
        k = marginal_names.index(name)
        mu = cs.get_est(sum(chunk_nums), sum(chunk_denoms), k, len(background_names))
        # DataFrame.append was removed in pandas 2.0; pd.concat works on old and new
        # versions. The first row is assigned directly to avoid concatenating with an
        # empty frame.
        new_row = pd.DataFrame([{'pheno': pheno_name, 'annot': name}])
        if results.empty:
            results = new_row
        else:
            results = pd.concat([results, new_row], ignore_index=True)

        # compute q
        q, r, mux, muy = cs.residualize(chunk_nums, chunk_denoms, len(background_names), k)


        # p-values
        if args.bothp or not args.fastp:
            p_emp, z_emp = cs.signflip(q, args.T, printmem=True, mode=args.stat)
            results.loc[i,'z'] = z_emp
            results.loc[i,'p'] = p_emp
        if args.bothp or args.fastp:
            z_fast = np.sum(q)/np.linalg.norm(q)
            p_fast = st.chi2.sf(z_fast**2, 1)
            results.loc[i, 'z_fast'] = z_fast
            results.loc[i, 'p_fast'] = p_fast
        if args.fastp and not args.bothp: # if only fastp, rename the output columns
            results.loc[i,'p'] = results.loc[i,'p_fast']
            results.loc[i,'z'] = results.loc[i,'z_fast']
            # drop the columns under the names they were actually created with
            results.drop(['p_fast','z_fast'], axis=1, inplace=True)

        # print verbose information if required
        if results.loc[i].p < args.verbose_thresh:
            fname = args.outfile_stem+'.'+pheno_name+'.'+name
            print('writing verbose results to', fname)
            chunkinfo['q'] = q
            chunkinfo['r'] = r
            chunkinfo.to_csv(fname+'.chunks', sep='\t', index=False)
            coeffs = pd.DataFrame()
            coeffs['annot'] = background_names
            coeffs['mux'] = mux
            coeffs['muy'] = muy
            coeffs.to_csv(fname+'.coeffs', sep='\t', index=False)

        # nominate interesting loci if desired
        if results.loc[i].p < args.tell_me_stories:
            storyteller.write(args.outfile_stem+'.'+name+'.loci',
                    args, name, background_names, mux, muy, results.loc[i].z,
                    corr_thresh=args.story_corr_thresh)

        # jackknife; computed before the optional statistics below because p_jk
        # needs this annotation's standard error
        se = cs.jackknife_se(mu, loo_nums, loo_denoms, k, len(background_names))

        # add more output if desired
        if args.more_stats:
            results.loc[i,'qkurtosis'] = st.kurtosis(q)
            results.loc[i,'qstd'] = np.std(q)
            results.loc[i,'p_jk'] = st.chi2.sf((mu/se)**2,1)
            results.loc[i,'sqnorm'] = sqnorm

        results.loc[i,'mu'] = mu
        results.loc[i,'se(mu)'] = se

        # add estimates of rf, h2v, and other associated quantities
        M = marginal_infos.loc[name[:-2],'M']
        results.loc[i,'h2g'] = h2g
        results.loc[i,'rf'] = mu * np.sqrt(
                sqnorm / h2g)
        results.loc[i,'h2v/h2g'] = results.loc[i].rf**2 - \
                results.loc[i,'se(mu)']**2 * sqnorm / (M*sigma2g)
        results.loc[i,'h2v'] = results.loc[i,'h2v/h2g'] * h2g
        results.loc[i,'supp(v)/M'] = supp/M
        # rewrite the results file after every annotation
        results.to_csv(args.outfile_stem + '.gwresults', sep='\t', index=False, na_rep='nan')

    print(results)
    print('writing results to', args.outfile_stem + '.gwresults')
    results.to_csv(args.outfile_stem + '.gwresults', sep='\t', index=False, na_rep='nan')
    print('done')

# preprocess any sumstats that need preprocessing
def preprocess_sumstats(args):
    """Make sure per-chromosome .pss.gz files exist, creating them if needed.

    Does nothing when ``args.pss_chr`` is already set.  Otherwise the
    directory ``<sumstats_stem>.<refpanel_name>/`` is checked for one
    ``<chrom>.pss.gz`` file per chromosome in ``args.chroms``; chromosomes
    without one are handed to ``sldp.preprocesspheno.main``.  Afterwards
    ``args`` is modified in place: ``pss_chr`` points at that directory and
    ``sumstats_stem`` is set to None.

    Raises ValueError if preprocessing is needed but ``args.config`` is None.
    """
    import os

    # preprocessed sumstats were supplied directly; nothing to do
    if args.pss_chr is not None:
        return

    pss_dir = args.sumstats_stem + '.' + args.refpanel_name + '/'

    # collect the chromosomes that have no .pss.gz file yet
    missing = []
    for chrom in args.chroms:
        if not os.path.exists(pss_dir + str(chrom) + '.pss.gz'):
            missing.append(chrom)

    if missing:
        print('Preprocessing', args.sumstats_stem+'.sumstats.gz... at', missing)
        if args.config is None:
            raise ValueError(
                    'automatic pre-processing of a sumstats file requires '
                    'specification of a config file; otherwise I dont know what '
                    'parameters to use. If you want, you can preprocess the sumstats '
                    'without a config file by running preprocesspheno manually')
        print('Using config file', args.config, 'and default options')

        # run preprocesspheno on a copy of args restricted to the missing chromosomes
        import copy
        import sldp.preprocesspheno
        pheno_args = copy.copy(args)
        pheno_args.no_M_5_50 = False
        pheno_args.set_h2g = None
        pheno_args.chroms = missing
        sldp.preprocesspheno.main(pheno_args)

    # point args at the pss-chr files, which exist at this point
    args.pss_chr = pss_dir
    args.sumstats_stem = None
    print('== finished preprocessing sumstats ==')

# preprocess any annotations that need preprocessing
def preprocess_sannots(args):
    """Preprocess any annotation in ``args.sannot_chr`` that lacks output files.

    For each annotation stem, a chromosome in ``args.chroms`` counts as
    processed only when both ``<stem><chrom>.RV.gz`` and
    ``<stem><chrom>.info`` exist.  If any chromosome is unprocessed,
    ``sldp.preprocessannot.main`` is run on a copy of ``args`` whose
    ``chroms`` is restricted to the unprocessed chromosomes and whose
    ``alpha`` is set to -1.

    Raises ValueError if preprocessing is needed but ``args.config`` is None.
    """
    import os

    for sannot in args.sannot_chr:
        unprocessed_chroms = [
                c for c in args.chroms
                if not (os.path.exists(sannot + str(c) + '.RV.gz') and
                    os.path.exists(sannot + str(c) + '.info'))
                ]
        if len(unprocessed_chroms) > 0:
            print('Preprocessing', sannot, 'at chromosomes', unprocessed_chroms)
            if args.config is None:
                raise ValueError('automatic pre-processing of an annotation '+\
                        'requires specification of a config file; otherwise I dont know what '+\
                        'parameters to use. If you want, you can preprocess the annotation '+\
                        'without a config file by running preprocessannot manually')
            print('Using config file', args.config, 'and default options')

            # run preprocessing command on a copy so the caller's args are untouched
            import sldp.preprocessannot, copy
            args_ = copy.copy(args)
            args_.alpha = -1
            args_.chroms = unprocessed_chroms
            sldp.preprocessannot.main(args_)

            # report the annotation just processed (was the undefined name `sanno`)
            print('== finished preprocessing annotation', sannot)


# Script entry point: parse options, fill in config defaults, preprocess any
# missing sumstats/annotation files, then run the analysis.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # required arguments
    parser.add_argument('--outfile-stem', required=True,
            help='Path to an output file stem.')
    # exactly one of --pss-chr / --sumstats-stem must be given
    pheno = parser.add_mutually_exclusive_group(required=True)
    pheno.add_argument('--pss-chr', default=None,
            help='Path to .pss.gz file, without chromosome number or .pss.gz extension. '+\
                    'This is the phenotype that SLDP will analyze.')
    pheno.add_argument('--sumstats-stem', default=None,
            help='Path to a .sumstats.gz file, not including ".sumstats.gz" extension. '+\
                    'SLDP will process this into a set of .pss.gz files before running.')
    parser.add_argument('--sannot-chr', nargs='+', required=True,
            help='One or more (space-delimited) paths to gzipped annot files, without '+\
                    'chromosome number or .sannot.gz/.RV.gz extension. These are the '+\
                    'annotations that SLDP will analyze against the phenotype.')

    # optional arguments
    parser.add_argument('--verbose-thresh', default=0., type=float,
            help='Print additional information about each association studied with a '+\
                    'p-value below this number. (Default is 0.) This includes: '+\
                    'the covariance in each independent block of genome (.chunks files), '+\
                    'and the coefficients required to residualize any background '+\
                        'annotations out of the other annotations being analyzed.')
    parser.add_argument('-fastp', default=False, action='store_true',
            help='Estimate p-values fast (without permutation)')
    parser.add_argument('-bothp', default=False, action='store_true',
            help='Print both fastp p-values (as p_fast) and normal p-values. '+\
                    'Takes precedence over fastp')
    # type=float so a value given on the command line can be compared with p-values
    parser.add_argument('--tell-me-stories', default=0., type=float,
            help='!!Experimental!! For associations with a p-value less than this number, '+\
                    'print information about loci that may be promising to study. '+\
                    'This will produce plots of (potentially overlapping) loci where the '+\
                    'signed LD profile is highly correlated with the GWAS signal in a '+\
                    'direction consistent with the global effect. '+\
                    'Default value is 0.')
    parser.add_argument('--story-corr-thresh', default=0.8, type=float,
            help='The threshold to use for correlation between Rv and alphahat in order '+\
                    'for a locus to be considered worthy of a story')
    parser.add_argument('-more-stats', default=False, action='store_true',
            help='Print additional statistis about q in results file')
    parser.add_argument('--T', type=int, default=1000000,
            help='number of times to sign flip for empirical p-values. Default is 10^6.')
    parser.add_argument('--jk-blocks', type=int, default=300,
            help='Number of jackknife blocks to use. Default is 300.')
    parser.add_argument('--weights', default='Winv_ahat_h',
            help='which set of regression weights to use. Default is Winv_ahat_h, '+\
                    'corresponding to weights described in Reshef et al. 2017.')
    parser.add_argument('--seed', default=None, type=int,
            help='Seed random number generator to a certain value. Off by default.')
    parser.add_argument('--stat', default='sum',
            help='*experimental* Which statistic to use for hypothesis testing. Options '+\
                    'are: sum, medrank, or thresh.')
    parser.add_argument('--chi2-thresh', default=0, type=float,
            help='only use SNPs with a chi2 above this number for the regression')
    parser.add_argument('--chroms', nargs='+', default=range(1,23), type=int,
            help='Space-delimited list of chromosomes to analyze. Default is 1..22')

    # configurable arguments
    parser.add_argument('--config', default=None,
            help='Path to a json file with values for other parameters. ' +\
                    'Values in this file will be overridden by any values passed ' +\
                    'explicitly via the command line.')
    parser.add_argument('--background-sannot-chr', nargs='+', default=[],
            help='One or more (space-delimited) paths to gzipped annot files, without '+\
                    'chromosome number or .sannot.gz extension. These are the annotations '+\
                    'that SLDP will control for.')
    parser.add_argument('--svd-stem', default=None,
            help='Path to directory in which output files will be stored. ' +\
                    'If not supplied, will be read from config file.')
    parser.add_argument('--bfile-reg-chr', default=None,
            help='Path to plink bfile of reference panel to use, not including ' +\
                    'chromosome number. This bfile should contain only regression SNPs '+\
                    '(as opposed to, e.g., all potentially causal SNPs). '+\
                    'If not supplied, will be read from config file.')
    parser.add_argument('--ld-blocks', default=None,
            help='Path to UCSC bed file containing one bed interval per LD block. If '+\
                    'not supplied, will be read from config file.')


    # echo the command line, then the fully resolved options
    print('=====')
    print(' '.join(sys.argv))
    print('=====')
    args = parser.parse_args()
    config.add_default_params(args)
    pretty.print_namespace(args)
    print('=====')

    # create any missing preprocessed inputs before running the analysis
    preprocess_sumstats(args)
    preprocess_sannots(args)

    main(args)
