#!/usr/bin/env python

import os
import sys
import numpy as np
import pandas as pd
import argparse
from astropy.time import Time

"""
A python version of the WQED Weld function for use
with the output from phot2lc to stitch multiple
light curves together into a single file.

Author:
    Zach Vanderbosch

For a description of updates, see the 
version_history.txt file.
"""

# Layout of the phot2lc light curve files that weldlc relies on
HEADER_NROWS = 23      # Number of header rows read from the top of each file
BJD_HEADER_ROW = 11    # Row (0-indexed) of the header holding the BJD of T0
SEC_PER_DAY = 86400.0  # Seconds per day


def build_parser():
    """Return the argparse parser used for command line parsing."""
    parser = argparse.ArgumentParser(description='Provide Input & Output Filenames.')
    parser.add_argument('-f', '--infiles',type=str,nargs='*',
                        help="Input files.")
    # The output name is required so a missing -o is reported immediately
    # instead of failing inside np.savetxt after all files have been read.
    parser.add_argument('-o', '--outfile',type=str,required=True,
                        help="Output filename.")
    return parser


def read_header(filename):
    """Read the first HEADER_NROWS rows of a light curve file.

    Parameters:
        filename: Path of the light curve file.

    Returns:
        2D array of strings with one row per header line read.
    """
    # '?' is passed as the separator so that a header line (which is not
    # expected to contain '?') is kept as one string in column 0.
    return pd.read_csv(filename,header=None,nrows=HEADER_NROWS,
                       dtype=str,sep='?').values


def parse_t0_bjd(header):
    """Extract the BJD of T0 from a header returned by read_header.

    The value is the text between the first ':' and the following '#'
    on row BJD_HEADER_ROW of the header.

    Parameters:
        header: Indexable header where header[row][0] is the line text.

    Returns:
        The BJD of T0 as a float.
    """
    bjd_line = header[BJD_HEADER_ROW][0]
    return float(bjd_line.split(":")[1].split("#")[0].strip())


def load_light_curves(filenames, bjds):
    """Load the data of every file and combine it into single arrays.

    Column 1 of each file is read as time in seconds relative to that
    file's T0 and is converted to days and offset by the file's BJD.
    Columns 2 and 3 are read as flux and flux error.

    Parameters:
        filenames: Sequence of light curve file paths (at least one).
        bjds: Sequence of T0 BJDs, one per file, in the same order.

    Returns:
        Tuple (tdata, fdata, edata) of 1D arrays, with tdata in BJD days.
    """
    tparts = []
    fparts = []
    eparts = []
    for fname, bjd in zip(filenames, bjds):
        # ndmin=2 keeps a file with a single data row two-dimensional,
        # so the column indexing below works for it as well.
        data = np.loadtxt(fname, ndmin=2)
        tparts.append((data[:,1]/SEC_PER_DAY) + bjd)
        fparts.append(data[:,2])
        eparts.append(data[:,3])

    # Concatenate once rather than re-copying the arrays for every file
    return np.concatenate(tparts), np.concatenate(fparts), np.concatenate(eparts)


def sort_and_stack(tdata, fdata, edata, bjdref):
    """Reference times to bjdref and return the time-sorted light curve.

    Parameters:
        tdata: 1D array of times in BJD days.
        fdata: 1D array of fluxes.
        edata: 1D array of flux errors.
        bjdref: Reference BJD subtracted from tdata.

    Returns:
        Array of shape (N, 3) with columns time (seconds since bjdref),
        flux and flux error, sorted by time.
    """
    # Convert times to seconds referenced to bjdref
    tref = (tdata - bjdref) * SEC_PER_DAY
    ind_sort = np.argsort(tref)
    return np.column_stack((tref[ind_sort], fdata[ind_sort], edata[ind_sort]))


def build_header(first_header, date_start, bjdref, nfiles, npoints, tspan, weld_date):
    """Create the header text for the welded light curve file.

    The first three lines are copied from first_header with any leading
    or trailing '#' and space characters stripped.

    Parameters:
        first_header: Header (as returned by read_header) of the first file.
        date_start: Date string corresponding to bjdref.
        bjdref: Reference BJD of the welded light curve.
        nfiles: Number of files welded (int).
        npoints: Number of data points in the output (int).
        tspan: Time spanned by the data in days.
        weld_date: Creation date string for the output file.

    Returns:
        The header as a single newline-separated string.
    """
    lines = [
        '{}'.format(first_header[0][0].strip("# ")),
        '{}'.format(first_header[1][0].strip("# ")),
        '{}'.format(first_header[2][0].strip("# ")),
        'Date       : {:30s}# Mid-Exposure TDB Date at T0'.format(date_start),
        'BJED       : {:<30.9f}# Mid-Exposure TDB JD at T0'.format(bjdref),
        'WeldNum    : {:<30d}# Number of files welded'.format(nfiles),
        'Npoints    : {:<30d}# Number of data points'.format(npoints),
        'Tspan      : {:<30.8f}# Time spanned by data (days)'.format(tspan),
        'WeldDate   : {:30s}# File creation date'.format(weld_date),
        'Columns: BaryCorr T-mid (s), Rel. Flux, Rel. Flux Error',
    ]
    return '\n'.join(lines)


def main(argv=None):
    """Weld multiple light curve files into a single output file.

    Parameters:
        argv: Optional list of command line arguments; when None the
            arguments are taken from sys.argv.

    Exits with status 1 when fewer than two input files are provided.
    """
    args = build_parser().parse_args(argv)

    # Get list of filenames (args.infiles is None when -f is omitted)
    filenames = sorted(args.infiles or [])
    Nf = len(filenames)
    if Nf == 0:
        print('\n0 Input Files provided, weldlc cannot proceed.')
        print('Program exited.\n')
        sys.exit(1)
    elif Nf == 1:
        print('\nOnly 1 Input File provided, weldlc cannot proceed.')
        print('Program exited.\n')
        sys.exit(1)
    print('\nWelding {} Files:'.format(Nf))

    # Get the BJDs of T0 from each header
    headers = []
    bjds = []
    for f in filenames:
        print("  {}".format(f))
        header = read_header(f)
        headers.append(header)
        bjds.append(parse_t0_bjd(header))

    # Now get all of the actual data, with times in BJD days
    tdata, fdata, edata = load_light_curves(filenames, bjds)

    # Reference times (in seconds) to the minimum BJD and sort by time
    bjdref = min(bjds)
    output = sort_and_stack(tdata, fdata, edata, bjdref)

    # Get date of bjdref
    bjd_start = Time(bjdref,format='jd',scale='tdb')
    date_start = bjd_start.to_value('isot').split("T")[0]

    # Calculate the time span in days
    tspan = (output[-1,0] - output[0,0])/SEC_PER_DAY

    # Create new header for output
    lc_header = build_header(headers[0], date_start, bjdref, Nf,
                             len(output), tspan, Time.now().to_value("iso"))

    # Save the lightcurve to file
    lc_format = '%15.3f  %9.6f  %9.6f'
    lc_fname = args.outfile
    np.savetxt(lc_fname, output, fmt=lc_format, header=lc_header)

    print('Finished. Output saved as {}\n'.format(lc_fname))


if __name__ == "__main__":
    main()











