#!/usr/bin/env python

import os
import sys
import numpy as np
import pandas as pd
import argparse
from astropy.time import Time

"""
A Python version of the WQED Weld function for use
with the output from phot2lc to stitch multiple
light curves together into a single file.

Author:
    Zach Vanderbosch

For a description of updates, see the 
version_history.txt file.
"""

# Generate arguments for command line parsing.
# --infiles defaults to an empty list so that omitting -f gives an empty
# file list instead of None (which cannot be sorted or counted).
# --outfile is required because the welded light curve cannot be saved
# without a filename; argparse now reports this before any work is done.
parser = argparse.ArgumentParser(description='Provide Input & Output Filenames.')
parser.add_argument('-f', '--infiles',type=str,nargs='*',default=[],
                    help="Input files.")
parser.add_argument('-b', '--bypass',action='store_true',
                    help="Bypass the barycentric correction check and weld LC's anyways.")
parser.add_argument('-o', '--outfile',type=str,required=True,
                    help="Output filename.")
args = parser.parse_args()


# Get list of filenames.
# args.infiles can be None when -f is omitted, so fall back to an empty
# list; that case is then reported by the Nf == 0 branch below rather
# than raising a TypeError inside sorted().
filenames = sorted(args.infiles or [])
Nf = len(filenames)

# Welding needs at least two light curves
if Nf == 0:
    print('\n0 Input Files provided, weldlc cannot proceed.')
    print('Program exited.\n')
    sys.exit(1)
elif Nf == 1:
    print('\nOnly 1 Input File provided, weldlc cannot proceed.')
    print('Program exited.\n')
    sys.exit(1)
else:
    print('\nWelding {} Files:'.format(Nf))


# Read the first 24 rows of every input file as raw text.
# '*' is used as the separator, presumably because it does not occur in
# the header text, so each row is kept as one string (later code reads
# element [0] of each row).
headers = []
for fname in filenames:
    print("  {}".format(fname))
    header_frame = pd.read_csv(fname, header=None, nrows=24, dtype=str, sep='*')
    headers.append(header_frame.values)


# Check whether every input light curve has been barycentric corrected.
# The flag is the text between "=" and "#" on header row 12.
corrs = [
    h[12][0].split("=")[1].split("#")[0].strip() == "True"
    for h in headers
]

# Set "corrected" from the check above. With --bypass the check result
# is ignored and the light curves are treated as uncorrected.
if args.bypass:
    print("Barycentric correction check BYPASSED")
    print("Assuming LCs are not all barycentric corrected.")
    corrected = False
elif all(corrs):
    print("Barycentric Corrections PASSED")
    corrected = True
else:
    print("Barycentric Corrections FAILED")
    print("Input light curves have not ALL been barycentric corrected.")
    print("If you wish to WELD anyways, use the -b (--bypass) option.")
    print("\nProgram Exited\n")
    sys.exit(1)


# Get the T0 of every light curve from its header: the BJD (header
# row 11) when the inputs are barycentric corrected, otherwise the
# MJD (header row 8). The value is the text between "=" and "#".
t0_row = 11 if corrected else 8
t0 = [
    float(h[t0_row][0].split("=")[1].split("#")[0].strip())
    for h in headers
]
        

# Now get all of the actual data.
# ndmin=2 keeps a file holding a single data row two-dimensional, so the
# column indexing below works for it too (a plain loadtxt call returns a
# 1-D array in that case and the [:, 0] index fails).
light_curves = [np.loadtxt(f, ndmin=2) for f in filenames]

# Column 0 is time in seconds relative to that file's T0: convert it to
# days and add T0. Columns 1 and 2 are taken as-is. Each output array is
# built with a single concatenate instead of growing it file by file.
tdata = np.concatenate(
    [(lc[:, 0] / 86400.0) + t0_i for lc, t0_i in zip(light_curves, t0)]
)
fdata = np.concatenate([lc[:, 1] for lc in light_curves])
edata = np.concatenate([lc[:, 2] for lc in light_curves])

# Re-reference all times, in seconds, to the earliest T0
ref_time = min(t0)
tref = (tdata - ref_time) * 86400.0

# Calendar date and clock time of the reference T0: a TDB Julian date
# when barycentric corrected, otherwise a UTC modified Julian date.
time_format, time_scale = ('jd', 'tdb') if corrected else ('mjd', 'utc')
tstart = Time(ref_time, format=time_format, scale=time_scale)
isot_parts = tstart.to_value('isot').split("T")
date_start = isot_parts[0]
time_start = isot_parts[-1]

# Sort every column by time and turn each into an (Nv, 1) column vector
Nv = len(tdata)
ind_sort = np.argsort(tref)
tsorted = tref[ind_sort].reshape(Nv, 1)
fsorted = fdata[ind_sort].reshape(Nv, 1)
esorted = edata[ind_sort].reshape(Nv, 1)

# Time span of the welded data in days (times are in seconds)
tspan = (tsorted[-1, 0] - tsorted[0, 0]) / 86400.

# Place the three columns side by side to form the output array
output = np.hstack((tsorted, fsorted, esorted))

# Create new header for output.
# Only the time-scale wording, the T0 keyword and the BaryCorr flag
# depend on whether the times are barycentric corrected; every other
# line is shared.
if corrected:
    scale_name, t0_key, t0_kind, bary_flag = 'TDB', 'BJED', 'JD', 'True'
else:
    scale_name, t0_key, t0_kind, bary_flag = 'UTC', 'MJD', 'MJD', 'False'

header_lines = [
    # The first three header rows are copied from the first input file
    headers[0][0][0].strip("# "),
    headers[0][1][0].strip("# "),
    headers[0][2][0].strip("# "),
    'Date       = {:30s}# Mid-Exposure {} Date at T0'.format(date_start, scale_name),
    'Time       = {:30s}# Mid-Exposure {} Time at T0'.format(time_start, scale_name),
    # The keyword is padded to 11 characters to line up with the others
    '{:11s}= {:<30.9f}# Mid-Exposure {} {} at T0'.format(
        t0_key, ref_time, scale_name, t0_kind),
    'BaryCorr   = {:30s}# Whether the times are barycentric corrected'.format(bary_flag),
    'WeldNum    = {:<30d}# Number of files welded'.format(Nf),
    'Npoints    = {:<30d}# Number of data points'.format(len(tsorted)),
    'Tspan      = {:<30.8f}# Time spanned by data (days)'.format(tspan),
    'WeldDate   = {:30s}# File creation date'.format(Time.now().to_value("iso")),
    'Columns    = T-mid (s), Rel. Flux, Rel. Flux Error',
]
lc_header = '\n'.join(header_lines)


# Save the lightcurve to file.
# If no output filename was given, args.outfile is None (or empty) and
# np.savetxt would fail with an unhelpful error, so stop with a clear
# message in the same style as the other exit paths of this script.
lc_format = '%15.3f  %9.6f  %9.6f'
lc_fname = args.outfile
if not lc_fname:
    print('\nNo Output Filename provided, weldlc cannot proceed.')
    print('Program exited.\n')
    sys.exit(1)
np.savetxt(lc_fname, output, fmt=lc_format, header=lc_header)

print('Finished. Output saved as {}\n'.format(lc_fname))











