#!/usr/bin/env python

import argparse
import os
import sys
from glob import glob

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from astropy.timeseries import LombScargle as ls
from scipy.signal import find_peaks

from phot2lc.photfunc import calc_lsp, prewhiten

# Select a GUI backend per host platform to get the best figure sizing:
# Qt5 on macOS ('darwin') and Windows ('win32'), Tk on Linux and on any
# other platform.
import matplotlib as mpl
mpl.use('QT5agg' if sys.platform in ('darwin', 'win32') else 'TKagg')

"""
A tool which generates a plot of the lightcurve
from an available .lc file generated by phot2lc,
or a .lc1 file generated by WQED.

In addition to the lightcurve, a periodogram
is also generated. When the -p/--prewhiten flag
is given, a pre-whitening sequence is also run
to identify significant peaks.


Author:
    Zach Vanderbosch

For a description of updates, see the 
version_history.txt file.
"""

# Generate arguments for command line parsing.
# --files defaults to an empty list so that omitting -f entirely is reported
# as "0 Input Files" instead of leaving args.files as None (which would make
# sorted(args.files) raise a TypeError).
parser = argparse.ArgumentParser(description='Provide Input & Output Filenames.')
parser.add_argument('-f', '--files', type=str, nargs='*', default=[],
                    help="Input file(s) to perform quicklook on.")
parser.add_argument('-s', '--save', action='store_true',
                    help="Whether to save the quicklook plot.")
parser.add_argument('-p', '--prewhiten', action='store_true',
                    help="Whether to perform a pre-whitening sequence.")
parser.add_argument('-l', '--lower', type=float, default=500.,
                    help="Lower frequency limit for pre-whitening search (micro-Hertz).")
parser.add_argument('-u', '--upper', type=float, default=100000.,
                    help="Upper frequency limit for pre-whitening search (micro-Hertz).")
parser.add_argument('-n', '--num', type=int, default=10,
                    help="Maximum number of pre-whitening iterations.")
parser.add_argument('-w', '--wqedlc', action='store_true',
                    help="Whether the input file(s) are from WQED.")
args = parser.parse_args()

# Record whether the inputs are WQED .lc1 files (2 columns) rather than
# phot2lc .lc files; this comes directly from the -w/--wqedlc flag.
wqed = bool(args.wqedlc)


# Check for filenames. Guard against args.files being None so that a missing
# -f option exits with the message below instead of a TypeError in sorted().
filenames = sorted(args.files or [])
Nf = len(filenames)
if not filenames:
    print('\n0 Input Files provided, quicklook cannot proceed.')
    print('Program exited.\n')
    sys.exit(1)

# Check that the wqed parameter is correct by looking at the number of
# columns in the first file: WQED files have 2 columns, while phot2lc files
# need at least 3 (time, flux and flux error are read from the first three).
# ndmin=2 keeps a single-row file two-dimensional so shape[1] is valid.
test_data = np.loadtxt(filenames[0], ndmin=2)
columns = test_data.shape[1]
if wqed and columns > 2:
    print('\nLightcurve is not of type WQED. Expected 2 columns, saw {}'.format(columns))
    print('Program exited.\n')
    sys.exit(1)
# Use logical `not`; the original `~wqed` was bitwise inversion of a bool
# (~True == -2), which only worked by accident and is deprecated in 3.12.
if not wqed and columns < 3:
    print('\nLightcurve is not of type PHOT2LC. Expected 3 columns, saw {}'.format(columns))
    print('Program exited.\n')
    sys.exit(1)



def _read_header(path, wqed):
    """Return ``(object_name, date_obs)`` parsed from the '#' header of *path*.

    Only the leading run of lines that start with '#' is scanned; reading
    stops at the first non-comment line. Values are taken from
    ``key = value  # optional comment`` style lines. Any key containing
    'Object' sets the object name. For WQED files any key containing 'Date'
    sets the date; for phot2lc files the key must contain '# Date'.
    Fields that are not found are returned as 'Unknown', so a missing
    header entry never leaves the names undefined or stale from a
    previously processed file.
    """
    object_name = 'Unknown'
    date_obs = 'Unknown'
    date_key = 'Date' if wqed else '# Date'
    with open(path) as fh:
        for line in fh:
            if not line.startswith('#'):
                break
            parts = line.split('=')
            if len(parts) < 2:
                continue  # comment line without a "key = value" pair
            key = parts[0]
            value = parts[1].split('#')[0].strip()
            if 'Object' in key:
                object_name = value
            if date_key in key:
                date_obs = value
    return object_name, date_obs


def _format_period(freq):
    """Format the period ``1/freq`` (seconds, for *freq* in Hz) with a unit.

    Periods below 10 minutes are shown in seconds, below 1 hour in minutes,
    below 1 day in hours, and anything longer in days.
    """
    period = 1. / freq
    if 600. <= period < 3600.:
        return '{:5.2f} min'.format(period / 60.)
    if 3600. <= period < 86400.:
        return '{:5.2f} hr'.format(period / 3600.)
    if period >= 86400.:
        return '{:5.2f} d'.format(period / 86400.)
    return '{:5.1f} s'.format(period)


def _pick_mono_font():
    """Return text kwargs selecting a monospace font for the peak table.

    Prefers 'Bitstream Vera Sans Mono' when matplotlib has it registered,
    otherwise falls back to 'DejaVu Sans Mono'.
    """
    available = {entry.name for entry in mpl.font_manager.fontManager.ttflist}
    preferred = 'Bitstream Vera Sans Mono'
    if preferred in available:
        return {'fontname': preferred}
    return {'fontname': 'DejaVu Sans Mono'}


# The installed font list does not change between files, so look it up once.
peak_font = _pick_mono_font() if args.prewhiten else {}

# Iterate through each file and generate a quicklook plot
for path in filenames:

    # Load data; ndmin=2 keeps a single-row file two-dimensional.
    data = np.loadtxt(path, ndmin=2)

    # Parse data. Time is plotted in hours after dividing by 3600 (so it is
    # in seconds here). WQED .lc1 files have only two columns, so there is
    # no flux-error column to read for them.
    time = data[:, 0]
    flux = data[:, 1]
    fluxerr = data[:, 2] if data.shape[1] > 2 else None

    # Get object name and observation date from the file header
    object_name, date_obs = _read_header(path, wqed)

    # Calculate the Lomb-Scargle Periodogram (farr is in Hz: it is plotted
    # as farr*1e6 on a microHz axis). The 4<A> significance threshold is 4x
    # the mean amplitude between 500 and 12000 microHz.
    farr, lsp = calc_lsp(time, flux)
    noise_band = (farr > 0.0005) & (farr < 0.012)
    sig = 4.0 * np.nanmean(lsp[noise_band])

    # Optionally pre-whiten, then extract frequency and amplitude values
    # from the LMFIT result
    Np = 0
    freq_vals, freq_errs, amp_vals, amp_errs = [], [], [], []
    if args.prewhiten:
        result, lsp_pw = prewhiten(time, flux, Npw=args.num,
                                   fmin=args.lower, fmax=args.upper)
        if result is not None:
            # Save Fit Results
            for name, param in result.params.items():
                if 'freq' in name:
                    freq_vals.append(param.value)
                    freq_errs.append(param.stderr)
                if 'amp' in name:
                    amp_vals.append(abs(param.value))
                    amp_errs.append(param.stderr)
            Np = len(freq_vals)
            # Re-derive the threshold from the pre-whitened periodogram;
            # assumes lsp_pw is on the same frequency grid as farr.
            sig = 4.0 * np.nanmean(lsp_pw[noise_band])

    ###########################################################
    # Plotting Code

    fig = plt.figure('Quicklook', figsize=(8, 6))
    gs = GridSpec(2, 1)
    ax0 = fig.add_subplot(gs[0])
    ax1 = fig.add_subplot(gs[1])

    # No error bars are drawn for WQED files (no error column).
    yerr = None if fluxerr is None else fluxerr * 1e2
    ax0.errorbar(time / 3600, flux * 1e2, yerr=yerr,
                 ls='None', marker='.', ms=5, mfc='k', mec='k',
                 ecolor='silver', elinewidth=1)
    ax1.plot(farr * 1e6, lsp * 1e2, c='k', lw=1)
    ax1.axhline(sig * 1e2, ls=':', c='C2')

    # Set XY Labels
    ax0.set_xlabel('Time (hr)', fontsize=10)
    ax1.set_xlabel(r'Frequency ($\mu$Hz)', fontsize=10)
    ax0.set_ylabel('Relative Flux (%)', fontsize=10)
    ax1.set_ylabel('Amplitude (%)', fontsize=10)

    # Set XY Limits; leave headroom above the tallest peak (below 10000
    # microHz) or above the threshold line, whichever is higher.
    amax = max(lsp[farr < 0.01]) * 1e2
    if amax < sig * 1e2:
        yupp = 1.4 * sig * 1e2
    else:
        yupp = 1.2 * amax
    ax1.set_xlim(0, max(farr) * 1e6)
    ax1.set_ylim(0, yupp)

    # Add text for 4<A> sig-threshold, placed just above the threshold line
    threshold_text = r'$4\langle$A$\rangle = {:.3f} \%$'.format(sig * 1e2)
    axis_to_data = ax1.transAxes + ax1.transData.inverted()
    data_to_axis = axis_to_data.inverted()
    points_axis = data_to_axis.transform([0, sig * 1e2 + 0.04 * yupp])
    ax1.text(0.80, points_axis[1], threshold_text,
             color='C2', fontsize=10, transform=ax1.transAxes)

    # Mark each fitted frequency and list it with its period and amplitude
    if args.prewhiten and Np > 0:
        ax1.fill_between(farr * 1e6, lsp_pw * 1e2, y2=0, color='C3', alpha=0.4, lw=0)
        for i, freq in enumerate(freq_vals):
            ax1.plot(freq * 1e6, 1.09 * amax, ls='None', marker='|', mfc='b', mec='b', mew=0.5)
            ax1.text(freq * 1e6, 1.13 * amax, '{}'.format(i + 1), ha='center', fontsize=7, color='b')

        ax1.text(0.55, 0.92, 'Significant Peaks', fontsize=8,
                 transform=ax1.transAxes, **peak_font)
        ax1.text(0.55, 0.91, '_____________________________',
                 transform=ax1.transAxes, **peak_font)
        for i, (freq, amp) in enumerate(zip(freq_vals, amp_vals)):
            yloc = 0.83 - i * 0.055
            # Raw string: '\m' is an invalid escape in a normal string.
            label = r'{}: {:7.1f} $\mu$Hz, {}, {:5.2f} %'.format(
                i + 1, freq * 1e6, _format_period(freq), amp * 1e2)
            ax1.text(0.55, yloc, label, fontsize=8,
                     transform=ax1.transAxes, **peak_font)

    # Add a title
    title = "Quicklook Plot for {} on {}".format(object_name, date_obs)
    ax0.set_title(title, fontsize=13)

    # Set Tick Parameters
    ax0.minorticks_on()
    ax1.minorticks_on()
    ax0.tick_params(which='both', axis='both', top=True, right=True, direction='in')
    ax1.tick_params(which='both', axis='both', right=True)
    ax1.tick_params(which='both', axis='y', direction='in')

    # Optionally save the figure next to the input file as
    # quicklook_<name-without-extension>.png. Splitting on the basename
    # keeps directories and leading './' out of the generated file name.
    if args.save:
        directory, base = os.path.split(path)
        stem = os.path.splitext(base)[0]
        figname = os.path.join(directory, 'quicklook_{}.png'.format(stem))
        plt.savefig(figname, dpi=300, bbox_inches='tight')

    # Show figure, then close
    plt.show()
    plt.close()

