#!/usr/bin/env python

from pyPheWAS.pyPhewasCorev2 import *
import os
import time
import math
import argparse

def parse_args(argv=None):
    """Parse the command-line options for the pyPheWAS pipeline script.

    :param argv: optional list of argument strings to parse. When ``None``
        (the default) argparse reads ``sys.argv[1:]``, so existing callers
        that pass nothing behave exactly as before.
    :returns: the populated ``argparse.Namespace``.

    ``--phenotype``, ``--group`` and ``--reg_type`` are required; argparse
    prints a usage message and exits when one of them is missing.
    """
    # This script runs the lookup, model and plot stages back to back, so the
    # description names the pipeline rather than the lookup tool alone.
    parser = argparse.ArgumentParser(description="pyPheWAS Pipeline (Lookup, Model, and Plot)")

    parser.add_argument('--phenotype', required=True, type=str, help='Name of the phenotype file (e.g. icd9_data.csv)')
    parser.add_argument('--group', required=True, type=str, help ='Name of the group file (e.g. groups.csv)')
    parser.add_argument('--reg_type', required=True, type=str, help='Type of regression that you would like to use (log, lin, or dur)')
    parser.add_argument('--path', required=False, default='.', type=str, help='Path to all input files and destination of output files')
    parser.add_argument('--postfix', required=False, default=None, type=str, help='Descriptive postfix for output files (e.g. poster or ages50-60)')
    parser.add_argument('--phewas_cov', required=False, default=None, type=float, help='PheCodes to use as covariates in pyPhewasModel regression')
    parser.add_argument('--covariates', required=False, default='', type=str, help='Variables to be used as covariates')
    parser.add_argument('--response', required=False, default='', type=str, help='Variable to predict instead of genotype')
    # Kept as a string on purpose: the calling script converts it to a boolean.
    parser.add_argument('--imbalance', required=False, default="True", type=str, help='Whether or not to show the direction of imbalance in the plot')
    parser.add_argument('--thresh_type', required=False, default=None, type=str, help='Type of threshold to be used in the plot (fdr, bon, or custom)')
    parser.add_argument('--custom_thresh', required=False, default=None, type=float, help='Custom threshold value (float between 0 and 1)')

    return parser.parse_args(argv)

"""
Retrieve and validate all arguments.
"""
start = time.time()

args = parse_args()
kwargs = {'path': os.path.join(os.path.abspath(args.path),''),
          'phenotypefile': args.phenotype,
          'groupfile': args.group,
          'phewas_cov':args.phewas_cov,
          'postfix':args.postfix,
          'show_imbalance': eval(args.imbalance),
          'custom_thresh': args.custom_thresh,
          'covariates': args.covariates,
          'response': args.response
}
str_reg_type = args.reg_type
str_thresh_type = args.thresh_type

# Assert that a valid regression type was used
assert str_reg_type in regression_map.keys(), "%s is not a valid regression type" % str_reg_type
kwargs['reg_type'] = regression_map[str_reg_type]


# Assert that valid files were given
assert kwargs['phenotypefile'].endswith('.csv'), "%s is not a valid phenotype file, must be a .csv file" % (kwargs['phenotypefile'])
assert kwargs['groupfile'].endswith('.csv'), "%s is not a valid group file, must be a .csv file" % (kwargs['groupfile'])

# Assign the output file if none was assigned
if kwargs['postfix'] is None:
    if kwargs['covariates'] is not '':
        kwargs['postfix'] = kwargs['covariates'] + '_' + os.path.splitext(kwargs['groupfile'])[0]
    else:
        kwargs['postfix'] = os.path.splitext(kwargs['groupfile'])[0]
else:
    if kwargs['covariates'] is not '':
        kwargs['postfix'] = kwargs['covariates'] + '_' + os.path.splitext(kwargs['postfix'])[0]
    else:
        kwargs['postfix'] = os.path.splitext(kwargs['groupfile'])[0]

# Check phewas_cov
if kwargs['phewas_cov']:
    kwargs['phewas_cov'] = float(kwargs['phewas_cov'])

if kwargs['response'] is None:
    kwargs['response'] = ""


# Assert that a valid threshold type was used
if args.thresh_type is None:
    kwargs['thresh_type'] = ['fdr','bon']
else:
    assert str_thresh_type in threshold_map.keys(), "%s is not a valid threshold type" % (str_thresh_type)
    kwargs['thresh_type'] = [str_thresh_type]

# Print Arguments
display_kwargs(kwargs)

# Make all arguments local variables
locals().update(kwargs)


""" 
pyPhewasLookup 
"""

print("Retrieving phenotype data...")
phenotypes = get_icd_codes(path, phenotypefile, reg_type)

print("Retrieving group data...")
genotypes = get_group_file(path, groupfile)

print("Generating feature matrix...")
fm,columns = generate_feature_matrix(genotypes,phenotypes,reg_type,phewas_cov)

print("Saving feature matrices to %s" % (path + outfile))
h = ','.join(columns)

np.savetxt(path + 'agg_measures_' + outfile, fm[0],delimiter=',',header=h)
print("...")
np.savetxt(path + 'icd_age_' + outfile, fm[1],delimiter=',',header=h)
print("...")
np.savetxt(path + 'phewas_cov_' + outfile, fm[2],delimiter=',',header=h)


""" 
pyPhewasModel 
"""

print("Running PheWAS regressions...")
regressions = run_phewas(fm, genotypes, covariates, reg_type, response, phewas_cov)

reg_outfile = "regressions_" + postfix + ".csv"
print("Saving regression data to %s" % (path + reg_outfile))
header = ','.join(['str_reg_type', str_reg_type, 'group', groupfile]) + '\n'
f = open(os.sep.join([path, reg_outfile]), 'w')
f.write(header)
regressions.to_csv(f)
f.close()


""" 
pyPhewasPlot 
"""


y = regressions['"-log(p)"']

try:
    regressions[['lowlim', 'uplim']] = regressions['Conf-interval beta'].str.split(',', expand=True)
    regressions.uplim = regressions.uplim.str.replace(']', '')
    regressions.lowlim = regressions.lowlim.str.replace('[', '')
    yb = regressions[['beta', 'lowlim', 'uplim']].values
    yb = yb.astype(float)
except:
    print('No correlation')

# Check if an imbalance will be used
if show_imbalance:
    imbalances = get_imbalances(regressions)
else:
    imbalances = np.array([])

# Get the regular p-values using a numpy vectorized function
regpvalues = np.vectorize(lambda x: 10**(-x))(y)

for t in thresh_type:
    t_num = threshold_map[t]

    # Get the threshold type
    if t_num == 0:
        thresh = get_bon_thresh(y,0.05)
    elif t_num == 1:
        thresh = get_fdr_thresh(regpvalues,0.05)
    elif t_num == 2:
        thresh = kwargs['custom_thresh']
    print('%s threshold: %0.5f' % (t, thresh))

    save = path + t + '_'  + postfix + '.pdf'
    saveb = path + t + '_' + postfix + '_beta.pdf'
    print('Saving plots to %s' %save)

    plot_data_points(y, -math.log10(thresh), save, imbalances)
    plot_odds_ratio(yb, y, -math.log10(thresh), saveb, imbalances)

# Report the elapsed wall-clock time, showing only the units that are needed.
interval = time.time() - start
whole_seconds = int(math.floor(interval))
hour, remainder = divmod(whole_seconds, 3600)
minute, second = divmod(remainder, 60)

if hour > 0:
    time_str = '%dh:%dm:%ds' % (hour, minute, second)
else:
    if minute > 0:
        time_str = '%dm:%ds' % (minute, second)
    else:
        time_str = '%ds' % second

print('pyPhewasPipeline Complete\nRuntime: %s' % time_str)