#!/usr/bin/env python

from pyPheWAS.pyPhewasCorev2 import *
import os
import argparse
import time
import math
import matplotlib.pyplot as plt

def parse_args():
    """Define the pyPheWAS plotting command line and parse ``sys.argv``.

    Returns:
        argparse.Namespace exposing ``statfile``, ``imbalance``,
        ``thresh_type``, ``custom_thresh``, ``path`` and ``outfile``.
    """
    cli = argparse.ArgumentParser(description="pyPheWAS Plotting Tool")

    # (flag, add_argument options) in the order they appear in --help.
    option_specs = (
        ('--statfile', dict(required=True, type=str,
                            help='Name of the statistics/regressions file')),
        ('--imbalance', dict(required=True, type=str,
                             help='Whether or not to show the direction of imbalance in the plot')),
        ('--thresh_type', dict(required=True, type=str,
                               help=' the type of threshold to be used in the plot')),
        ('--custom_thresh', dict(required=False, default=None, type=float,
                                 help='Custom threshold value (float between 0 and 1)')),
        ('--path', dict(required=False, default='.', type=str,
                        help='Path to all input files and destination of output files')),
        ('--outfile', dict(required=False, default=None, type=str,
                           help='Name of the output file for the plot')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

# ---------------------------------------------------------------------------
# Retrieve and validate all arguments.
# ---------------------------------------------------------------------------


def _parse_bool(text):
    """Interpret a command-line flag value such as ``True``/``False`` as a bool.

    Used instead of ``eval()`` so that arbitrary Python passed on the command
    line is never executed, and so lower-case spellings are accepted.

    Raises:
        ValueError: if ``text`` is not a recognised boolean spelling.
    """
    lowered = text.strip().lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n'):
        return False
    raise ValueError("%s is not a valid boolean value for --imbalance (use True or False)" % text)


start = time.time()
args = parse_args()
kwargs = {'path': os.path.join(os.path.abspath(args.path), ''),
          'statfile': args.statfile,
          'show_imbalance': _parse_bool(args.imbalance),
          'custom_thresh': args.custom_thresh,
          'outfile': args.outfile,
          }
str_thresh_type = args.thresh_type

# Assert that a valid threshold type was used
assert str_thresh_type in threshold_map, "%s is not a valid threshold type" % str_thresh_type
kwargs['thresh_type'] = threshold_map[str_thresh_type]
if kwargs['thresh_type'] == 2:
    # Code 2 selects the user-supplied threshold, so it must be present and in (0, 1).
    custom_value = kwargs['custom_thresh']
    assert custom_value is not None, "--custom_thresh must be given when using a custom threshold"
    assert 0.0 < custom_value < 1.0, "%s is not a valid threshold (should be between 0.0 and 1.0)" % custom_value

# Assert that valid files were given
assert kwargs['statfile'].endswith('.csv'), "%s is not a valid statistics file (must be a .csv file)" % kwargs['statfile']

# Specify no output if output file was not given
if kwargs['outfile'] is None:
    kwargs['outfile'] = ''
else:
    assert kwargs['outfile'].endswith('.pdf'), "%s is not a valid plot file, must be a .pdf" % kwargs['outfile']

# The first line of the stats file is a comma-separated list of alternating
# key,value entries; everything after it is parsed as the regression table.
with open(kwargs['path'] + kwargs['statfile']) as ff:
    header = ff.readline().strip().split(',')
    # zip() pairs keys with values; a dangling trailing key (odd field count)
    # is ignored instead of raising IndexError.
    kwargs.update(zip(header[0::2], header[1::2]))
    regressions = pd.read_csv(ff)

# Print Arguments
display_kwargs(kwargs)

# ---------------------------------------------------------------------------
# Create plots
# ---------------------------------------------------------------------------

# Settings used below, read explicitly rather than injected into globals.
path = kwargs['path']
outfile = kwargs['outfile']
show_imbalance = kwargs['show_imbalance']
thresh_type = kwargs['thresh_type']

y = regressions['"-log(p)"']

# Turn the bracketed "[low,high]" confidence-interval strings into a float
# array of (beta, lowlim, uplim) rows for the odds-ratio plot.
try:
    bounds = regressions['Conf-interval beta'].str.strip('[]').str.split(',', expand=True)
    regressions[['lowlim', 'uplim']] = bounds
    yb = regressions[['beta', 'lowlim', 'uplim']].values.astype(float)
except (KeyError, ValueError, TypeError, AttributeError):
    # Missing or malformed columns: skip the beta plot instead of failing
    # later on an undefined ``yb``.
    print('No correlation')
    yb = None

# Check if an imbalance will be used
imbalances = get_imbalances(regressions) if show_imbalance else np.array([])

# Recover raw p-values from -log10(p). A float base keeps this valid even if
# the column has an integer dtype (numpy rejects negative integer powers).
regpvalues = np.power(10.0, -np.asarray(y, dtype=float))

# Get the threshold type
if thresh_type == 0:
    thresh = get_bon_thresh(y, 0.05)
elif thresh_type == 1:
    thresh = get_fdr_thresh(regpvalues, 0.05)
elif thresh_type == 2:
    thresh = kwargs['custom_thresh']
else:
    raise ValueError("Unsupported threshold type: %s" % str_thresh_type)
print('%s threshold: %0.5f' % (str_thresh_type, thresh))

if outfile != '':
    save = path + outfile
    # Replace only the extension; str.replace would also rewrite any '.pdf'
    # appearing earlier in the file name.
    saveb = path + os.path.splitext(outfile)[0] + '_beta.pdf'
    print("Saving plot to %s" % (save))
else:
    save = ''
    saveb = ''
    print("Displaying plot.")

log_thresh = -math.log10(thresh)
plot_data_points(y, log_thresh, save, imbalances)
if yb is not None:
    plot_odds_ratio(yb, y, log_thresh, saveb, imbalances)

if not save:
    plt.show()

# Report total wall-clock runtime, dropping leading units that are zero.
elapsed = time.time() - start
hours = math.floor(elapsed / 3600.0)
remaining = elapsed - hours * 3600
minutes = math.floor(remaining / 60)
seconds = math.floor(remaining - minutes * 60)

if hours > 0:
    time_str = '%dh:%dm:%ds' % (hours, minutes, seconds)
elif minutes > 0:
    time_str = '%dm:%ds' % (minutes, seconds)
else:
    time_str = '%ds' % seconds

print('pyPhewasPlot Complete\nRuntime: %s' % time_str)