#!/usr/bin/env python

from pyPheWAS.pyPhewasCorev2 import *
import os
import sys
import argparse
import time
import math
import matplotlib.pyplot as plt

def parse_args():
    """Define the pyPheWAS plotting command line and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``statfile``, ``thresh_type``,
        ``custom_thresh``, ``imbalance``, ``phewas_label``, ``path`` and
        ``outfile``. Note that ``imbalance`` has no ``type=`` converter: it is
        the bool ``True`` by default but a raw string when given on the
        command line.
    """
    # (flag, add_argument options), listed in the order shown by --help.
    option_specs = (
        ('--statfile', dict(required=True, type=str,
                            help='Name of the statistics/regressions file')),
        ('--thresh_type', dict(required=True, type=str,
                               help=' the type of threshold to be used in the plot')),
        ('--custom_thresh', dict(required=False, default=None, type=float,
                                 help='Custom threshold value (float between 0 and 1)')),
        ('--imbalance', dict(required=False, default=True,
                             help='Whether or not to show the direction of imbalance in the plot (default=True)')),
        ('--phewas_label', dict(required=False, default="plot", type=str,
                                help='Where to put PheCode labels - plot (default) or axis')),
        ('--path', dict(required=False, default='.', type=str,
                        help='Path to all input files and destination of output files')),
        ('--outfile', dict(required=False, default=None, type=str,
                           help='Name of the output file for the plot')),
    )

    cli = argparse.ArgumentParser(description="pyPheWAS Plotting Tool")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

"""
Retrieve and validate all arguments.
"""
start = time.time()
args = parse_args()
kwargs = {'path': os.path.join(os.path.abspath(args.path),''),
          'statfile': args.statfile,
          'show_imbalance': eval(args.imbalance),
          'custom_thresh':args.custom_thresh,
          'outfile':args.outfile,
          'phewas_label': args.phewas_label
}
str_thresh_type = args.thresh_type

# Assert that a valid threshold type was used
assert str_thresh_type in threshold_map.keys(), "%s is not a valid regression type" % (kwargs['thresh_type'])
kwargs['thresh_type'] = threshold_map[args.thresh_type]
if kwargs['thresh_type'] == 2:
    assert (kwargs['custom_thresh'] < 1.0) & (kwargs['custom_thresh'] > 0.0), "%s is not a valid threshold (should be between 0.0 and 1.0)" % (kwargs['custom_thresh'])

# Assert that valid files were given
assert kwargs['statfile'].endswith('.csv'), "%s is not a valid phenotype file (must be a .csv file)" % (kwargs['feature_matrix'])

assert kwargs['phewas_label'] in ["plot","axis"], "%s is not a valid PheCode label location" % (kwargs['phewas_label'])

# Specify no output if output file was not given
# if kwargs['outfile'] is None:
#     kwargs['outfile'] = ''
# else:
    # assert kwargs['outfile'].endswith('.pdf') or kwargs['outfile'].endswith('.eps'), "%s is not a valid plot file, must be a .pdf" % (kwargs['outfile'])

# The statistics file starts with one metadata line of alternating
# key,value fields; everything after it is the regression table. Both parts
# are read through the same handle, which the ``with`` closes afterwards.
with open(kwargs['path'] + kwargs['statfile']) as ff:
    header = ff.readline().strip().split(',')
    # Pair even/odd fields; zip drops a trailing unpaired key instead of
    # raising IndexError.
    # NOTE(review): header keys overwrite command-line kwargs of the same
    # name - confirm that precedence is intended.
    for key, value in zip(header[::2], header[1::2]):
        kwargs[key] = value

    # Print Arguments
    display_kwargs(kwargs)

    # --- Create plots ---

    # Expose every argument as a module-level name (thresh_type,
    # show_imbalance, outfile, path, phewas_label, ...) for the plotting
    # code further down. At module scope locals() *is* globals(); use the
    # explicit form.
    globals().update(kwargs)

    # Read in the remaining data (the pandas DataFrame)
    regressions = pd.read_csv(ff, dtype={'PheWAS Code': str})

# "Conf-interval beta" holds strings of the form "[low,high]"; split them
# into numeric 'lowlim' and 'uplim' columns.
try:
    bounds = regressions['Conf-interval beta'].str.split(',', expand=True)
    regressions[['lowlim', 'uplim']] = bounds
    # regex=False: strip the brackets as literal characters on every pandas
    # version (the default of `regex` changed in pandas 2.0).
    regressions['uplim'] = regressions.uplim.str.replace(']', '', regex=False)
    regressions['lowlim'] = regressions.lowlim.str.replace('[', '', regex=False)
    regressions = regressions.astype(dtype={'uplim': float, 'lowlim': float})
except Exception as e:
    print('Error reading regression file:', file=sys.stderr)
    print(e, file=sys.stderr)
    # Non-zero status so shells and calling scripts can detect the failure.
    sys.exit(1)

pvalues = regressions['p-val'].values

# Pick the significance threshold for the integer code in thresh_type:
# 0 -> get_bon_thresh, 1 -> get_fdr_thresh (both called with 0.05),
# 2 -> the user-supplied --custom_thresh.
if thresh_type == 0:
    thresh = get_bon_thresh(pvalues, 0.05)
elif thresh_type == 1:
    thresh = get_fdr_thresh(pvalues, 0.05)
elif thresh_type == 2:
    thresh = kwargs['custom_thresh']
else:
    # Fail with a clear message instead of a NameError on `thresh` below.
    sys.exit('Unsupported threshold type code: %r' % (thresh_type,))
print('%s threshold: %0.5f' % (str_thresh_type, thresh))

# Work out where (if anywhere) the two plots are written. Without --outfile
# the save paths stay empty and the figures are shown interactively. With
# it, the Manhattan plot is saved to path + outfile, and the odds-ratio plot
# goes alongside it with a '_beta' suffix. The extension (minus its '.')
# is passed to both plot functions as the file format.
if outfile is None:
    save = ''
    saveb = ''
    file_format = ''
else:
    save = path + outfile
    stem, ext = os.path.splitext(save)
    saveb = stem + '_beta' + ext
    file_format = ext[1:]  # drop the leading '.'
    print("Saving plot to %s" % (save))

# Both plots share the same threshold, expressed on the -log10(p) scale.
neg_log_thresh = -math.log10(thresh)
plot_manhattan(regressions, neg_log_thresh, show_imbalance, save, file_format)
plot_odds_ratio(regressions, neg_log_thresh, show_imbalance, saveb, file_format, phewas_label)

# Nothing was saved to disk, so display the figures instead.
if not save:
    plt.show()

# Report total wall-clock runtime, omitting leading units that are zero
# (seconds are always shown).
elapsed = time.time() - start
hour, remainder = divmod(elapsed, 3600.0)
minute, second = divmod(remainder, 60.0)
second = math.floor(second)

if hour > 0:
    time_str = '%dh:%dm:%ds' % (hour, minute, second)
elif minute > 0:
    time_str = '%dm:%ds' % (minute, second)
else:
    time_str = '%ds' % second

print('pyPhewasPlot Complete\nRuntime: %s' % time_str)