#!/usr/bin/env python

"""
phot2lc takes output from aperture photometry
pipelines and aids in the extraction of optimal 
light curves. Current features include:
    > Comparison star selection
    > Aperture size selection
    > Polynomial detrending
    > Sigma clipping
    > Manual data point removal
    > Barycentric time corrections

Author: 
    Zach Vanderbosch

For a description of updates, see the 
version_history.txt file.

Usage:
    phot2lc
    phot2lc -h|--help
    phot2lc -t|--telescope
    phot2lc -s|--source
    phot2lc -i|--image
    phot2lc -o|--object

Options:
    -h --help      Show command line options
    -t --telescope Code name for telescope used
    -s --source    Code name for photometry program used
    -i --image     Name of specific image instead of list
    -o --object    Name of object matching stars.dat entry

""" 

# Refuse to run on any interpreter older than Python 3.6
import sys
if sys.version_info[0] < 3:
    # Python 2.X
    print('ERROR: phot2lc incompatible with Python 2.X')
    print('Program exited.')
    sys.exit(1)
elif sys.version_info[:2] < (3, 6):
    # Python 3.0 - 3.5
    print('ERROR: phot2lc incompatible with Python 3.5 or earlier')
    print('Program exited.')
    sys.exit(1)

# Set the backend environment
import matplotlib as mpl
mpl.use('QT5agg')

# Import Standard Packages
import os
import argparse
import warnings
import numpy as np
import pandas as pd
from glob import glob
import PyQt5.QtCore as qtc
from itertools import combinations

# Import Astropy modules
import astropy.units as u
from astropy.io import fits
from astropy.utils import iers
from astropy.time import Time, TimeDelta
from astropy.coordinates import SkyCoord
from astropy.visualization import ZScaleInterval
from astropy.utils.exceptions import AstropyWarning

# Import matplotlib modules
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MultipleLocator
from matplotlib.widgets import Cursor, RectangleSelector

# Import Custom Functions
from phot2lc.teledat import get_telinfo
from phot2lc.photfunc import progress_bar, print_commands
from phot2lc.photfunc import get_time, pp_scat, poly_fit
from phot2lc.photfunc import div_lc, gen_compstr, calc_lsp
from phot2lc.photfunc import get_loc, window_std


#############################################################
## Read config.dat (located in the same directory as the
## phot2lc.teledat module) and use its entries as the
## defaults for the user-based parameters.

import phot2lc.teledat
config_path = os.path.dirname(os.path.realpath(phot2lc.teledat.__file__))
with open(config_path + "/config.dat") as f:
    # For every line keep only the text after the last "=" sign,
    # with the newline and surrounding whitespace removed.
    config_dat = [entry.strip("\n").split("=")[-1].strip() for entry in f]


# Entry 0: name of whoever is performing these reductions
author = config_dat[0]
# Entry 1: file containing the list of images being analyzed
list_name = config_dat[1]
# Entry 2: optional file containing initial guesses at the
# target + comp pixel locations. This is the same file used for
# IRAF ccd_hsp aperture photometry and is not needed if you
# don't have it.
tloc_name = config_dat[2]
# Entry 3: path to file containing object names + coordinates
star_dat_filename = config_dat[3]


# Entries 4-7: default arguments for argparse. The literal
# string 'None' for the image/object entries maps to None.
default_telescope = config_dat[4]
default_source = config_dat[5]
default_image = config_dat[6] if config_dat[6] != 'None' else None
default_object = config_dat[7] if config_dat[7] != 'None' else None


#############################################################
## Generate arguments for command line parsing.
## Defaults come from the config.dat entries loaded above.

parser = argparse.ArgumentParser()
parser.add_argument('-t', '--telescope', type=str, default=default_telescope,
                    help="Code name for telescope used.")
# The accepted values are the short codes validated further below
# ('hsp' and 'mae'), not the full program names.
parser.add_argument('-s', '--source', type=str, default=default_source,
                    help="Code for source of photometry: "
                         "hsp (ccd_hsp) or mae (maestro).")
# When given, this single image is used in place of the image list file.
parser.add_argument('-i', '--image', type=str, default=default_image,
                    help="Name of a single image to use instead of "
                         "the list of images.")
parser.add_argument('-o', '--object', type=str, default=default_object,
                    help="Name of object. If None, get object from FITS header.")
args = parser.parse_args()


#############################################################
# Validate the command line inputs

# Table of supported telescopes/instruments (list of dicts)
telinfo = get_telinfo()

# Codes, telescope names, and instrument names currently supported
valid_codes = [entry['code'] for entry in telinfo]
valid_telescopes = [entry['telename'] for entry in telinfo]
valid_instruments = [entry['instname'] for entry in telinfo]

# Reject unknown telescope codes, listing the supported ones
telcode = args.telescope
if telcode not in valid_codes:
    print('\nERROR! Supplied telescope code "{}" is not currently supported.'.format(telcode))
    print('Supported telescope codes:')
    for supported in zip(valid_codes,valid_telescopes,valid_instruments):
        print("   {} = {} {}".format(*supported))
    print("")
    sys.exit(1)

# First telinfo entry whose code matches the requested telescope
teldict = next(item for item in telinfo if item["code"] == telcode)


# Photometry source codes currently supported, and the program
# name that each code stands for
source_dict = {'hsp':'ccd_hsp',
               'mae':'maestro'}
valid_sources = ['hsp','mae']

# Reject unknown photometry source codes, listing the supported ones
psource = args.source
if psource not in valid_sources:
    print('\nERROR! Supplied photometry source "{}" is not currently supported.'.format(psource))
    print('Supported source codes:')
    for supported_code in valid_sources:
        print("   {} = {}".format(supported_code,source_dict[supported_code]))
    print("")
    sys.exit(1)
phot_source = source_dict[psource]


# Camera noise characteristics for the chosen telescope entry
dark = teldict['dark'] # Dark Current in ADU/s/pixel
read = teldict['read'] # Read Noise in e-/s/pixel
gain = teldict['gain'] # Gain in e-/ADU


#############################################################
## Additional Setup

# Silence Astropy warning messages
warnings.simplefilter('ignore', category=AstropyWarning)

# Clear the default matplotlib key bindings listed below so that
# key presses are left to this program's own event handlers
for keymap_name in ('home','quit','back','forward','grid',
                    'save','fullscreen','grid_minor'):
    plt.rcParams['keymap.' + keymap_name] = ''

# Remaining plot-related setup
qtc.pyqtRemoveInputHook() # Suppresses QCORE message during input
plt.style.use('dark_background') # Sets default plot style


#############################################################
## Start loading in data

# Grab all of the photometry files in the current directory
if psource == 'hsp': # hsp_nd output
    phot_names_raw = glob('runbase*')
elif psource == 'mae': # maestro output
    phot_names_raw = glob('counts*')
Nf = len(phot_names_raw)

# Nothing below can work without photometry, so stop with a clear
# message instead of failing later on an empty list.
if Nf == 0:
    print('\nERROR! No photometry files found in the current directory.\n')
    sys.exit(1)

# Grab all or just one of the FITS files
if args.image is None:
    # ndmin=1 keeps a one-line list file as a length-1 array
    # (otherwise loadtxt returns a 0-d array and len() fails)
    fits_list = np.loadtxt(list_name, dtype=str, ndmin=1)
    num_fits = len(fits_list)
else:
    fits_list = [args.image]
    num_fits = 1

# Get aperture sizes from filenames: everything after the first
# seven characters of the base name is parsed as the aperture size
ap_sizes_raw = [float(f.split("/")[-1][7:]) for f in phot_names_raw]

# Use ap_sizes to sort the phot_names list by increasing aperture size
ap_sizes = sorted(ap_sizes_raw)
phot_names = [f for _,f in sorted(zip(ap_sizes_raw,phot_names_raw))]

# Get object locations (first two columns) from the optional
# pixel-location file. This is best-effort: if the file is missing
# or unreadable, fall back to an empty array so no markers are drawn.
# ndmin=2 keeps a single-row file as shape (1,2) so that tloc[i,0]
# indexing works.
try:
    tloc = np.loadtxt(tloc_name,usecols=(0,1),ndmin=2)
except (OSError, ValueError, IndexError):
    tloc = np.array([])

# Load in the photometry for each aperture size. The first column of
# each file is dropped, so the remaining column labels start at 1.
phot_data = [pd.read_csv(f,header=None,sep=r'\s+').iloc[:,1:].astype('float64')
             for f in phot_names]

if psource == 'hsp':
    Nobj = len(phot_data[0].columns) - 1  # Number of stars (last column is sky)
elif psource == 'mae':
    Nobj = int(len(phot_data[0].columns)/2)  # Two columns per star
Ndat = len(phot_data[0])  # Number of data points


# If phot2lc has been run previously, load in the log file
log_exists = False
if os.path.exists('phot2lc_log.txt'):
    # First five lines hold the header entries; loadtxt reads the data rows
    log_header = pd.read_csv('phot2lc_log.txt',header=None,nrows=5).values
    log_data = np.loadtxt('phot2lc_log.txt')
    log_exists = True


#############################################################
## Get some basic info about the object

# Pull some basic info out of the header of the first image.
# The teldict entries give the header keywords to look up.
hdr = fits.getheader(fits_list[0])
telescope = teldict['telename']
instrument = teldict['instname']
observer = hdr[teldict['observer']]
filter_name = hdr[teldict['filter']]
if args.object is None:
    # The header keyword for the object name must be defined in
    # teledat.py. Check it BEFORE the header lookup; checking the
    # looked-up string afterwards can never trigger.
    objname_key = teldict['objname']
    if objname_key is None:
        print("Error! Object Name in teledat.py cannot be None.")
        sys.exit(1)
    obj_name = hdr[objname_key].replace(" ","")
else:
    obj_name = args.object


# Check to make sure this object has exactly one entry (matched on
# the first column) in the stars.dat file
star_dat = pd.read_csv(star_dat_filename,header=None,sep=r'\s+',dtype=str)
obj_idx = star_dat.index[star_dat.iloc[:,0] == obj_name]
if len(obj_idx) == 0:
    print("\nWARNING: No objects by name of {} found in 'stars.dat'\n".format(obj_name))
    sys.exit(1)
elif len(obj_idx) > 1:
    print("\nWARNING: Multiple entries for {} found in 'stars.dat'\n".format(obj_name))
    sys.exit(1)


#############################################################
# Show the first image along with marked targets/comps
# Get the data for the first image
if telcode == 'mcd2':
    # For this telescope code the first plane of the data array is used
    image0 = fits.getdata(fits_list[0])[0]
else:
    image0 = fits.getdata(fits_list[0])

# Get Z-Scale Normalization vmin and vmax
ZS = ZScaleInterval(nsamples=10000, contrast=0.15, max_reject=0.5,
                    min_npixels=5, krej=2.5, max_iterations=5)
vmin0,vmax0 = ZS.get_limits(image0)

# Plot image with sources marked
figsize = (9,6.2)
imfig = plt.figure(10,figsize=figsize)
gsim = GridSpec(1,1)
im = imfig.add_subplot(gsim[0])

# Set Window Name (via the figure manager; the canvas-level
# set_window_title was removed from recent matplotlib releases)
imfig.canvas.manager.set_window_title('First Image')

# Font Choices
fonta = {'fontname': 'AppleGothic',
          'weight': 'heavy'}
fontb = {'fontname': 'AppleGothic',
         'weight': 'normal'}

# Show the first image
im.imshow(image0, cmap='gray',vmin=vmin0, vmax=vmax0)

# Plot a marker for each object. The first location is labelled as
# the target, the rest as numbered comparison stars. Reshaping to
# (N,2) handles both an empty array and a single (x,y) pair.
marker_xy = np.asarray(tloc).reshape(-1,2)
for i,(xpix,ypix) in enumerate(marker_xy):
    if i == 0:
        marker_color,marker_label = 'c','Target'
    else:
        marker_color,marker_label = 'm','Comp {}'.format(i)
    im.plot(xpix,ypix,
            ls='None',marker='o',ms=15,mew=2.0,mec=marker_color,mfc='None')
    im.text(xpix,ypix+15,marker_label,ha='center',fontsize=12,**fonta)

# Make Axes invisible
im.axis("off")

# Add title (file name of the first image without its directory)
titleim = '{}'.format(fits_list[0].split("/")[-1])
im.set_title(titleim,loc='center',fontsize=14,**fontb)

# Save the figure
im_name = '{}_firstframe.png'.format(obj_name)
imfig.savefig(im_name,format='png',dpi=100,bbox_inches='tight',pad_inches=0.15)

# Add some basic interactive capabilities
def on_event_im(event):
    """Key-press handler for the non-interactive figure windows.

    Recognized keys:
        'W'      : set grid_search to False and close all figures
        'G'      : set grid_search to True and close all figures
        'Q'/'q'  : close all figures and exit without saving
        '?'      : reprint the command list

    Any other key is ignored.
    """
    global grid_search

    pressed = event.key
    if pressed == 'W':
        # Continue without the grid search
        grid_search = False
        plt.close("all")
    elif pressed == 'G':
        # Continue with the grid search
        grid_search = True
        plt.close("all")
    elif pressed in ('Q', 'q'):
        # Quit without saving the light curve
        plt.close("all")
        print('\nProgram exited. Lightcurve was not saved.\n')
        sys.exit(1)
    elif pressed == '?':
        # Reprint the command list
        print_commands()

# Connect incoming events to event handlers
cidim = imfig.canvas.mpl_connect('key_press_event', on_event_im)
#############################################################


#############################################################
# First need to generate the time stamps for data points

# If no previous log exists, generate time stamps from FITS images
if not log_exists:
    # Start time and exposure time of the first exposure
    dt_zero_start,texp = get_time(fits_list[0],teldict)

    # Mid-exposure time of the first exposure (start + texp/2)
    dt_zero_mid = dt_zero_start + TimeDelta(texp/2.0,format='sec')

    if psource == 'mae':
        # First column of the photometry file scaled by 86400
        # (presumably days -> seconds; confirm against maestro output)
        time_stamps = np.loadtxt(phot_names[0])[:,0]*86400.
    elif psource == 'hsp':
        if num_fits == 1:
            # Only one image available: assume evenly spaced points,
            # one exposure time apart. Ndat is already the number of
            # data points (an int), so it is used directly.
            time_stamps = np.arange(Ndat,dtype=float)*texp
        else:
            ## Array to store the time differences relative to
            ## the first mid-exposure time
            time_stamps = np.zeros(num_fits)
            print('')
            for i in range(num_fits):

                ## Print Progress Bar
                count1  = i+1
                action1 = 'Grabbing times from FITS...'
                progress_bar(count1, num_fits, action1)

                ## Get the start time and texp for current frame
                dt_start,texp = get_time(fits_list[i],teldict)

                ## Mid-exposure time of the current frame
                dt_mid = dt_start + TimeDelta(texp/2.0,format='sec')

                ## Calculate the time difference in seconds
                tdelta = dt_mid - dt_zero_mid
                tdelta_sec = tdelta.to_value(u.s)

                ## Store time difference in the time_stamps array
                time_stamps[i] = tdelta_sec
    print('')

# If a log does exist, use it to generate the time stamps
else:
    # First check whether the first image still exists
    # If so, get the first timestamp from there
    if os.path.exists(fits_list[0]):
        dt_zero_start,texp = get_time(fits_list[0],teldict)
        dt_zero_mid = dt_zero_start + TimeDelta(texp/2.0,format='sec')
        time_stamps = log_data[:,0]
    else:
        # Otherwise recover the mid-exposure time and the exposure
        # time from the log header entries
        dt_zero_mid_str = log_header[2][0].split("=")[-1].strip()
        texp = float(log_header[4][0].split("=")[-1].strip())
        dt_zero_mid = Time(dt_zero_mid_str, scale='utc',format='isot')
        # Reconstruct the start time, which is needed later for the
        # plot titles and would otherwise be undefined in this branch
        dt_zero_start = dt_zero_mid - TimeDelta(texp/2.0,format='sec')
        time_stamps = log_data[:,0]
    print('')


#############################################################
# Make a WQED-like plot showing all the raw photometry
figa = plt.figure(0,figsize=figsize)
gsa = GridSpec(1,1)
ax_LC = figa.add_subplot(gsa[0])

# Set Window Name (via the figure manager; the canvas-level
# set_window_title was removed from recent matplotlib releases)
figa.canvas.manager.set_window_title('Raw Photometry')

# Dark Background Color List (doubled so more curves can be colored)
dbk_clist = plt.rcParams['axes.prop_cycle'].by_key()['color']
dbk_clist += dbk_clist

# Use the 4.0 pixel aperture for this plot if it exists,
# otherwise fall back to the smallest aperture
try:
    ap4_loc = ap_sizes.index(4.0)
except ValueError:
    ap4_loc = 0
phot_ap4 = phot_data[ap4_loc]

# Column labels of the stellar fluxes (target first) and the sky
# flux used for this plot
if psource == 'hsp':
    star_cols = list(range(1,Nobj+1))
    sky_flux = phot_ap4[Nobj+1]
elif psource == 'mae':
    star_cols = list(range(1,2*Nobj,2))
    sky_flux = phot_ap4[2]

# Plot sky counts, normalized to range from 0->1
sky_norm = sky_flux / sky_flux.mean()
ax_LC.plot(time_stamps,(sky_norm - min(sky_norm))/max(sky_norm - min(sky_norm)),
           ls='None',marker='.',ms=3,mfc='#3C87FA',mec='#3C87FA')

# Pandas column-IDs for all comps
if psource == 'hsp':
    comp_ids = list(range(2,Nobj+1))
elif psource == 'mae':
    comp_ids = list(range(3,2*Nobj,2))

# Start with all comps selected, unless a phot2lc log exists, in
# which case initialize comp_select with the previously chosen comps
# (stored in the log header as comp numbers joined by "+")
if not log_exists:
    comp_select = comp_ids.copy()
else:
    prev_comps = [int(x) for x in log_header[3][0].split("=")[-1].split("+")]
    if psource == 'hsp':
        comp_select = [c+1 for c in prev_comps]
    elif psource == 'mae':
        comp_select = [2*c+1 for c in prev_comps]

# Plot Stellar Flux: each curve is scaled to 0->1 and offset
# vertically so the target is on top and the comps follow below
for i,col in enumerate(star_cols):
    norm = phot_ap4[col] / phot_ap4[col].mean()
    norm2 = (norm - min(norm))/max(norm - min(norm))
    ax_LC.plot(time_stamps,norm2 + Nobj-i+0.5,ls='None',marker='.',ms=3,
               mfc=dbk_clist[i],mec=dbk_clist[i])

# Set y limit
ax_LC.set_ylim(0,Nobj+2)

# Set tick parameters
ax_LC.tick_params(which='both',axis='both',labelleft=False,left=True,right=True)
ax_LC.tick_params(which='both',axis='y',direction='in')

# Add labels just right of the axes for each light curve shown:
# a name plus the mean counts underneath it
xdiff = ax_LC.get_xlim()[1] - ax_LC.get_xlim()[0]
xloc = ax_LC.get_xlim()[1] + 0.01*xdiff
for i,col in enumerate(star_cols):
    lc_labela = 'Target' if i == 0 else 'Comp {}'.format(i)
    lc_labelb = '{:.0f}'.format(phot_ap4[col].mean())
    lc_color = dbk_clist[i]
    ax_LC.text(xloc,float(Nobj+1-i)-0.00,lc_labela,
               fontsize=13,color=lc_color,**fonta)
    ax_LC.text(xloc,float(Nobj+1-i)-0.30,lc_labelb,
               fontsize=10,color=lc_color,**fonta)

# The sky curve sits at the bottom of the plot
lc_labela = 'Sky'
lc_labelb = '{:.0f}'.format(sky_flux.mean())
lc_color = '#3C87FA'
ax_LC.text(xloc,1.0-0.50,lc_labela,
           fontsize=13,color=lc_color,**fonta)
ax_LC.text(xloc,1.0-0.80,lc_labelb,
           fontsize=10,color=lc_color,**fonta)


# Add XY labels
ax_LC.set_xlabel('Time (s)',fontsize=12,**fontb)
ax_LC.set_ylabel('Photometric Counts',fontsize=12,**fontb)

# Add title (date and time are sliced from the ISO string of the
# first exposure's start time)
print_date = dt_zero_start.to_value('iso')[0:10]
print_time = dt_zero_start.to_value('iso')[11:19]
titlea = 'Raw Photometry for {}   UT-Date: {}'.format(obj_name,print_date)
ax_LC.set_title(titlea,loc='center',fontsize=14,**fontb)

# Connect incoming events to event handlers
cida = figa.canvas.mpl_connect('key_press_event', on_event_im)
###################################################
###################################################


###################################################
# Now to create the interactive figure where outlier
# and bad/weather points can be removed, a polynomial
# fit can be chosen, and comp stars can be chosen

figb = plt.figure(1,figsize=figsize)
gsb = GridSpec(2,3,height_ratios=[1,3],hspace=0.0,wspace=0.0)
bx_LC = figb.add_subplot(gsb[1,:])     # Divided light curve (bottom)
bx_mod = figb.add_subplot(gsb[0,0])    # Polynomial fit (top left)
bx_comp = figb.add_subplot(gsb[0,1:])  # Summed comp stars (top right)

# Set Window Name (via the figure manager; the canvas-level
# set_window_title was removed from recent matplotlib releases)
figb.canvas.manager.set_window_title('Divided Light Curve')

# Tick Parameters
bx_LC.tick_params(which='both',axis='both',right=True,top=True,direction='in')
bx_mod.tick_params(which='both',axis='both',labelbottom=False,
                   labelleft=False,left=False,bottom=False)
bx_comp.tick_params(which='both',axis='both',labelbottom=False,
                   labelleft=False,left=False,bottom=False)

# Initialize the plot
apnum = 0   # First Aperture
# If phot2lc has not been previously run, use default initialization
if not log_exists:
    deg = 0                       # No polynomial fit to start
    dele_ind = []                 # No points deleted
    keep_ind = list(range(Ndat))  # Keeping all points
# If phot2lc has been run, use the log to initialize
# the polynomial fit and the lists of deleted/kept points
# (second log column: 1 = kept, 0 = deleted)
else:
    deg = int(log_header[1][0].split("=")[-1].strip())
    dele_ind = [i for i,x in enumerate(log_data[:,1]) if x == 0]
    keep_ind = [i for i,x in enumerate(log_data[:,1]) if x == 1]
# Zero counts are replaced by 1.0 before the light curve is divided
phot_target = phot_data[apnum][1].replace(0.0,1.0).values
phot_comps = np.sum(phot_data[apnum][comp_select].replace(0.0,1.0).values,axis=1)
dlc_init,dellc_init,draw_init,mod_init = div_lc(time_stamps,phot_target,phot_comps,
                                                deg,keep_ind,dele_ind)

# Plot Light Curve
p0, = bx_LC.plot(time_stamps[keep_ind],dlc_init,ls='-',c='w',lw=0.5,
                 marker='o',ms=3,mew=0.75,mec='w',mfc='w')
# Show deleted points
p1, = bx_LC.plot(time_stamps[dele_ind],dellc_init,ls='None',
                 marker='o',ms=3,mew=0.75,mec='r',mfc='r')
# Show Polynomial Fit
p2, = bx_mod.plot(time_stamps[keep_ind],mod_init,
                  ls='-',c='r',lw=1.5,zorder=2)
p3, = bx_mod.plot(time_stamps[keep_ind],draw_init,
                  ls='None',marker='.',ms=2,mew=0,mfc='w',zorder=1)
# Show Summed Comp Star Light Curve
p4, = bx_comp.plot(time_stamps[keep_ind],phot_comps[keep_ind],
                   ls='None',marker='.',ms=3,mew=0,mfc='C0')

# Add a cursor to the divided light curve plot
cursor = Cursor(bx_LC, useblit=True, c='r', lw=0.5, ls='--')

# Set initial XY limits (infinite values are ignored for the
# divided light curve)
xlow_LC = min(time_stamps[keep_ind])
xupp_LC = max(time_stamps[keep_ind])
ylow_LC = np.nanmin(dlc_init[~np.isinf(dlc_init)])
yupp_LC = np.nanmax(dlc_init[~np.isinf(dlc_init)])
ylow_mod = np.nanmin(draw_init)
yupp_mod = np.nanmax(draw_init)
ylow_comp = np.nanmin(phot_comps[keep_ind])
yupp_comp = np.nanmax(phot_comps[keep_ind])

xdiff_LC = xupp_LC - xlow_LC
ydiff_LC = yupp_LC - ylow_LC
ydiff_mod = yupp_mod - ylow_mod
ydiff_comp = yupp_comp - ylow_comp

# Pad each axis by a fraction of its own data range
bx_LC.set_xlim(xlow_LC-0.05*xdiff_LC, xupp_LC+0.05*xdiff_LC)
bx_LC.set_ylim(ylow_LC-0.30*ydiff_LC, yupp_LC+0.30*ydiff_LC)
bx_mod.set_ylim(ylow_mod-0.30*ydiff_mod, yupp_mod+0.30*ydiff_mod)
bx_comp.set_ylim(ylow_comp-0.30*ydiff_comp, yupp_comp+0.30*ydiff_comp)


# Add XY labels
bx_LC.set_xlabel('Time (s)',fontsize=12,**fontb)
bx_LC.set_ylabel('Normalized Flux',fontsize=12,**fontb)

# Add the initial title
p2p_scat = pp_scat(dlc_init)
titleb = 'Object: {}   UT-Date: {}   UT-Time: {} \n\n'.format(
          obj_name,print_date,print_time)
bx_LC.set_title(titleb,loc='center',y=1.23,fontsize=13,**fontb)

# Add title to polynomial plot showing the poly order
title_poly = 'Poly Order = {}'.format(deg)
bx_mod.set_title(title_poly,loc='left',fontsize=10,x=0.02,y=0.80,**fontb)

# Add title to comparison star plot showing which comps are used
comp_str = gen_compstr(comp_select,psource)
title_comp = 'Comps {}'.format(comp_str)
bx_comp.set_title(title_comp,loc='left',fontsize=10,x=0.01,y=0.80,**fontb)

# Add text showing the aperture selection and P2P scatter
apscat_str =  'Aperture Radius: {:.2f} pix {:5}'.format(ap_sizes[apnum],'') + \
              'P2P Scatter: {:.2f} %'.format(p2p_scat*1e2)
tx0 = bx_LC.text(0.015,0.94,apscat_str,
                 transform=bx_LC.transAxes,fontsize=10,**fontb)

# Set some more initial parameters
num_del = len(dele_ind)                 # Tracks number of deleted points
keep_x = np.copy(time_stamps[keep_ind]) # x-values kept
keep_y = np.copy(dlc_init)              # y-values kept
dele_x = np.copy(time_stamps[dele_ind]) # x-values deleted
dele_y = np.copy(dellc_init)            # y-values deleted
show_delps = True             # Whether to plot deleted points or not
zoomed = False                # Whether the plot is zoomed in or not

###################################################
# Functions for plotting

# Function to update the plots
def update_plot(kx,ky,dx,dy,raw,mod,comp):
    """Refresh every artist, axis limit, and label of the interactive figure.

    Parameters
    ----------
    kx, ky : x/y values of the kept points (drawn with p0)
    dx, dy : x/y values of the deleted points (drawn with p1
             only when show_delps is True)
    raw    : values drawn against kx with p3 in the polynomial panel
    mod    : values drawn against kx with p2 in the polynomial panel
    comp   : values drawn against kx with p4 in the comp star panel
    """
    global bx_LC,bx_mod,bx_comp,p0,p1,p2,p3,p4
    global comp_select,show_delps,apnum,zoomed

    # Push the new data into the line artists
    for artist,yvals in ((p0,ky),(p2,mod),(p3,raw),(p4,comp)):
        artist.set_data(kx,yvals)

    # Deleted points are only drawn, and only influence the
    # axis limits, when they are switched on and non-empty
    if show_delps:
        p1.set_data(dx,dy)
        use_deleted = len(dx) > 0
    else:
        p1.set_data([],[])
        use_deleted = False

    # Data range of the divided light curve, ignoring infinities
    finite_ky = ky[~np.isinf(ky)]
    minx,maxx = min(kx),max(kx)
    miny,maxy = np.nanmin(finite_ky),np.nanmax(finite_ky)
    if use_deleted:
        finite_dy = dy[~np.isinf(dy)]
        minx,maxx = min(minx,min(dx)),max(maxx,max(dx))
        miny = min(miny,np.nanmin(finite_dy))
        maxy = max(maxy,np.nanmax(finite_dy))
    xdiff = maxx-minx
    ydiff = maxy-miny

    # Leave the view alone if the user has zoomed in
    if not zoomed:
        bx_LC.set_xlim(minx-0.05*xdiff,maxx+0.05*xdiff)
        bx_LC.set_ylim(miny-0.30*ydiff,maxy+0.30*ydiff)

    # Rescale the y-axis of the polynomial panel
    miny_mod,maxy_mod = min(raw),max(raw)
    ydiff_mod = maxy_mod - miny_mod
    bx_mod.set_ylim(miny_mod-0.3*ydiff_mod,maxy_mod+0.3*ydiff_mod)

    # Rescale the y-axis of the comp star panel
    miny_comp,maxy_comp = min(comp),max(comp)
    ydiff_comp = maxy_comp - miny_comp
    bx_comp.set_ylim(miny_comp-0.3*ydiff_comp,maxy_comp+0.3*ydiff_comp)

    # Main plot title
    titleb = 'Object: {}   UT-Date: {}\n\n'.format(obj_name,print_date)
    bx_LC.set_title(titleb,loc='right',fontsize=13,**fontb)

    # Polynomial order label
    title_poly = 'Poly Order = {}'.format(deg)
    bx_mod.set_title(title_poly,loc='left',fontsize=10,x=0.02,y=0.80,**fontb)

    # Label listing the comparison stars in use
    comp_str = gen_compstr(comp_select,psource)
    title_comp = 'Comps {}'.format(comp_str)
    bx_comp.set_title(title_comp,loc='left',fontsize=10,x=0.01,y=0.80,**fontb)

    # Text showing the aperture selection and P2P scatter
    p2p_scat = pp_scat(ky)
    apscat_str =  'Aperture Radius: {:.2f} pix {:5}'.format(ap_sizes[apnum],'') + \
                  'P2P Scatter: {:.2f} %'.format(p2p_scat*1e2)
    tx0.set_text(apscat_str)

    plt.draw()
    return

# Function to find index of nearest data point to cursor
def get_index(ax_obj,p,x,y,xe,ye):
    """Return the index of the data point nearest to a display position.

    Parameters
    ----------
    ax_obj : axes whose transData transform converts data coordinates
             into the same coordinates as (xe, ye)
    p      : unused; kept so existing callers do not need to change
    x, y   : data coordinates of the candidate points
    xe, ye : position to measure distances from (callers pass the
             event's x and y)

    Returns
    -------
    Index into x/y of the point with the smallest distance.
    """
    trans = ax_obj.transData.transform(list(zip(x,y)))
    # True Euclidean distance (square root of the summed squares)
    dist = np.hypot(trans[:,0]-xe, trans[:,1]-ye)
    index = dist.argmin()
    return index

# Function to update plot title for messages to User
def message(ax_obj,s):
    """Display the string s as the title of ax_obj and redraw the figure."""
    ax_obj.set_title(s)
    plt.draw()

# Function to switch the Rectangle garbage selector on or off
def toggle_garbage(show=True):
    """Flip the active state of the selector stored on toggle_garbage.RS.

    If show is True, a short status message is printed before
    the state is changed.
    """
    selector = toggle_garbage.RS
    activating = not selector.active
    if show:
        if activating:
            print(' Garbage Selector activated.')
        else:
            print(' Garbage Selector deactivated.')
    selector.set_active(activating)

# Function to switch the Rectangle reverse garbage selector on or off
def toggle_reverse_garbage(show=True):
    """Flip the active state of the selector stored on toggle_reverse_garbage.RS.

    If show is True, a short status message is printed before
    the state is changed.
    """
    selector = toggle_reverse_garbage.RS
    activating = not selector.active
    if show:
        if activating:
            print(' Reverse Garbage Selector activated.')
        else:
            print(' Reverse Garbage Selector deactivated.')
    selector.set_active(activating)

# Function to switch the Rectangle zoom selector on or off
def toggle_zoom(show=True):
    """Flip the active state of the selector stored on toggle_zoom.RS.

    If show is True, a short status message is printed before
    the state is changed.
    """
    selector = toggle_zoom.RS
    activating = not selector.active
    if show:
        if activating:
            print(' Zoom Selector activated.')
        else:
            print(' Zoom Selector deactivated.')
    selector.set_active(activating)

# Function to update the divided light curves
def dlc_update(kx,ky,dx,dy,degree,aper):
    """Recompute the divided light curve and redraw the figure.

    Parameters
    ----------
    kx, dx : time stamps of the kept / deleted points; each value
             is matched against time_stamps to find its index
    ky, dy : not used in the calculation (new values are computed)
    degree : polynomial degree passed on to div_lc
    aper   : index into phot_data of the aperture to use

    Returns
    -------
    klc, dlc           : first two outputs of div_lc (the new kept and
                         deleted light curve values)
    keep_ind, dele_ind : indices into time_stamps of the kept and
                         deleted points (also stored as globals)
    """
    global bx_LC,p0,p1,p2,p3,p4
    global dele_ind,keep_ind,comp_select,show_delps

    # Index of the first time stamp equal to a given value
    def first_match(tval):
        return np.where(time_stamps == tval)[0][0]

    # Translate the deleted/kept time stamps back into indices
    dele_ind = list(map(first_match,dx))
    keep_ind = list(map(first_match,kx))

    # Target and summed comparison counts for this aperture,
    # with zero counts replaced by 1.0
    aper_phot = phot_data[aper]
    phot_target = aper_phot[1].replace(0.0,1.0).values
    phot_comps = aper_phot[comp_select].replace(0.0,1.0).values.sum(axis=1)

    # New divided light curve
    klc,dlc,kraw,kmod = div_lc(time_stamps,phot_target,phot_comps,
                               degree,keep_ind,dele_ind)

    # Redraw with the new values
    update_plot(kx,klc,dx,dlc,kraw,kmod,phot_comps[keep_ind])

    return klc, dlc, keep_ind, dele_ind


# Function for Garbage Selector Rectangle to Call
def garbage_select(eclick, erelease):
    """Delete all kept points that lie inside the drawn rectangle.

    Parameters
    ----------
    eclick, erelease : press and release events of the rectangle;
                       their xdata/ydata give opposite corners

    The kept/deleted global arrays are updated, the divided light
    curve is recomputed and redrawn via dlc_update, and the garbage
    selector is toggled afterwards. Nothing happens if the box holds
    no kept points, and the deletion is refused if it would leave
    one or zero kept points.
    """
    global apnum,bx_LC,p0,p1,p2,p3,p4,num_del,deg
    global keep_x,keep_y,keep_ind,zoomed
    global dele_x,dele_y,dele_ind,show_delps

    # Find indices of data points within drawn box
    xlow,xupp = sorted([eclick.xdata,erelease.xdata])
    ylow,yupp = sorted([eclick.ydata,erelease.ydata])

    ind_box = np.where((keep_x > xlow) & (keep_x < xupp) &
                       (keep_y > ylow) & (keep_y < yupp))[0]

    if len(ind_box) == 0:
        return  # No data inside the box

    # If box would delete too many points, don't allow it
    if len(keep_x) - len(ind_box) <= 1:
        print('WARNING: Cannot delete any more points!')
        return

    # Move the boxed data points from the kept to the deleted arrays
    dele_x = np.append(dele_x,keep_x[ind_box])
    dele_y = np.append(dele_y,keep_y[ind_box])
    keep_x = np.delete(keep_x,ind_box)
    keep_y = np.delete(keep_y,ind_box)
    num_del += len(ind_box)

    # Re-sort the keep arrays by time. The sort order is computed
    # once so that x and y stay paired.
    order = keep_x.argsort()
    keep_x = keep_x[order]
    keep_y = keep_y[order]

    # Update divided light curve and plot
    keep_y,dele_y,keep_ind,dele_ind = dlc_update(keep_x,keep_y,
                                                 dele_x,dele_y,deg,apnum)

    # Deactivate the garbage selector
    toggle_garbage()

    return


# Function for Reverse Garbage Selector Rectangle to Call
def reverse_garbage_select(eclick, erelease):
    """Delete all kept points that lie outside the drawn rectangle.

    Parameters
    ----------
    eclick, erelease : press and release events of the rectangle;
                       their xdata/ydata give opposite corners

    The kept/deleted global arrays are updated, the divided light
    curve is recomputed and redrawn via dlc_update, and the reverse
    garbage selector is toggled afterwards. Nothing happens if no
    kept points lie outside the box, and the deletion is refused if
    it would leave one or zero kept points.
    """
    global apnum,bx_LC,p0,p1,p2,p3,p4,num_del,deg
    global keep_x,keep_y,keep_ind,zoomed
    global dele_x,dele_y,dele_ind,show_delps

    # Find indices of data points outside the drawn box
    xlow,xupp = sorted([eclick.xdata,erelease.xdata])
    ylow,yupp = sorted([eclick.ydata,erelease.ydata])

    ind_box = np.where((keep_x < xlow) | (keep_x > xupp) |
                       (keep_y < ylow) | (keep_y > yupp))[0]

    if len(ind_box) == 0:
        return  # No data outside the box

    # If box would delete too many points, don't allow it
    if len(keep_x) - len(ind_box) <= 1:
        print('WARNING: Cannot delete any more points!')
        return

    # Move the outside data points from the kept to the deleted arrays
    dele_x = np.append(dele_x,keep_x[ind_box])
    dele_y = np.append(dele_y,keep_y[ind_box])
    keep_x = np.delete(keep_x,ind_box)
    keep_y = np.delete(keep_y,ind_box)
    num_del += len(ind_box)

    # Re-sort the keep arrays by time. The sort order is computed
    # once so that x and y stay paired.
    order = keep_x.argsort()
    keep_x = keep_x[order]
    keep_y = keep_y[order]

    # Update divided light curve and plot
    keep_y,dele_y,keep_ind,dele_ind = dlc_update(keep_x,keep_y,
                                                 dele_x,dele_y,deg,apnum)

    # Deactivate the reverse garbage selector
    toggle_reverse_garbage()

    return

# Function for Zoom Selector Rectangle to Call
def zoom_select(eclick, erelease):
    """Zoom the divided light curve axes to the drawn rectangle.

    eclick and erelease are the press and release events of the
    rectangle; their xdata/ydata give opposite corners. The zoomed
    flag is set and the zoom selector is toggled afterwards.
    """
    global bx_LC,zoomed

    # Order the corner coordinates into (low, high) pairs
    x_range = sorted((eclick.xdata,erelease.xdata))
    y_range = sorted((eclick.ydata,erelease.ydata))

    # Apply the new view and remember that we are zoomed in
    bx_LC.set_xlim(x_range[0],x_range[1])
    bx_LC.set_ylim(y_range[0],y_range[1])
    zoomed = True
    plt.draw()

    # Deactivate the Zoom Selector
    toggle_zoom()

    return


# Function for sigma clipping data
def sigclip_data(kx,ky,dx,dy,siglow,sigupp,dw):
    """Sigma clip the kept points window by window.

    Parameters
    ----------
    kx, ky : x/y values of the currently kept points
    dx, dy : x/y values of the currently deleted points
    siglow : lower clipping threshold, in units of the window's
             standard deviation
    sigupp : upper clipping threshold, in the same units
    dw     : number of points per window

    The model values and standard deviation of each window come from
    window_std. Points outside (model - siglow*std, model + sigupp*std)
    are moved to the deleted arrays, then the divided light curve is
    recomputed and redrawn via dlc_update. Results are stored in the
    global kept/deleted arrays; nothing is returned.
    """
    global apnum,bx_LC,p0,p1,p2,p3,p4,num_del,deg
    global keep_x,keep_y,keep_ind,zoomed
    global dele_x,dele_y,dele_ind,show_delps

    # First get stats within each window (windows holds the
    # starting index of every window)
    windows = np.arange(0,len(kx),dw,dtype=int)
    polymods,win_std = window_std(kx,ky,windows,dw)

    # Iterate through windows and sigma-clip data
    good_idx = []
    bad_idx = []
    for i,w in enumerate(windows):
        for j,y in enumerate(ky[w:w+dw]):
            idx = i*dw + j  # Position of this point within kx/ky
            lowlim = polymods[i][j] - siglow*win_std[i]
            upplim = polymods[i][j] + sigupp*win_std[i]
            if lowlim < y < upplim:
                good_idx.append(idx)
            else:
                bad_idx.append(idx)

    # Split the points into the new kept and deleted arrays
    keep_x = kx[good_idx]
    keep_y = ky[good_idx]
    dele_x = np.append(dx,kx[bad_idx])
    dele_y = np.append(dy,ky[bad_idx])
    num_del += len(bad_idx)

    # Re-sort the deleted points by time. The sort order is computed
    # once so that x and y stay paired.
    order = dele_x.argsort()
    dele_x = dele_x[order]
    dele_y = dele_y[order]

    # Update divided light curve and plot
    keep_y,dele_y,keep_ind,dele_ind = dlc_update(keep_x,keep_y,
                                                 dele_x,dele_y,
                                                 deg,apnum)

    return


#####################################################

# Now for the main function which defines all of 
# the interactive options

def on_event_b(event):
    """
    Key-press handler for the interactive light curve figure.

    event = Matplotlib key press event; event.key selects the action and
            event.x/event.y give the cursor position used to pick a point

    Key bindings:
        a / A : restore the picked / all deleted point(s)
        d     : delete the picked point
        s     : toggle the display of deleted points
        f     : prompt for a new polynomial degree
        c     : prompt for the comparison stars to use
        v / w : move to the previous / next aperture size
        g / r : toggle the garbage / reverse garbage box selector
        z / Z : toggle the zoom box selector / restore the original zoom
        x     : prompt for sigma-clipping parameters and clip
        W / G : close the plots and continue without / with a grid search
        Q / q : close the plots and exit without saving
        ?     : reprint the command list
    """
    global apnum,bx_LC,p0,p1,p2,p3,p4,num_del,deg
    global keep_x,keep_y,keep_ind,comp_select,grid_search
    # dele_ind must be declared global (it was misspelled "del_ind"),
    # otherwise the updated indices never reach the module-level variable.
    global dele_x,dele_y,dele_ind,show_delps,zoomed

    key = event.key
    refresh = False  # Set when the divided light curve must be recomputed

    # Add a deleted point back to the plot
    if key == 'a':
        if num_del > 0:
            idx = get_index(bx_LC, p1, dele_x, dele_y, event.x, event.y)
            keep_x = np.append(keep_x,dele_x[idx])
            keep_y = np.append(keep_y,dele_y[idx])
            dele_x = np.delete(dele_x,idx)
            dele_y = np.delete(dele_y,idx)
            num_del -= 1

            # Re-sort the keep arrays with one shared ordering
            order = keep_x.argsort()
            keep_y = keep_y[order]
            keep_x = keep_x[order]
            refresh = True

    # Add ALL deleted points back to the plot
    elif key == 'A':
        if num_del > 0:
            print('Added all deleted points back to Light Curve.')
            keep_x = np.append(keep_x,dele_x)
            keep_y = np.append(keep_y,dele_y)
            dele_x = np.asarray([])
            dele_y = np.asarray([])
            num_del = 0

            # Re-sort the keep arrays with one shared ordering
            order = keep_x.argsort()
            keep_y = keep_y[order]
            keep_x = keep_x[order]
            refresh = True

    # Delete a single point from the light curve
    elif key == 'd':
        if len(keep_x) <= 2:
            print('WARNING! Cannot delete any more points!')
        else:
            idx = get_index(bx_LC, p0, keep_x, keep_y, event.x, event.y)
            dele_x = np.append(dele_x,keep_x[idx])
            dele_y = np.append(dele_y,keep_y[idx])
            keep_x = np.delete(keep_x,idx)
            keep_y = np.delete(keep_y,idx)
            num_del += 1
            refresh = True

    # Toggle the display of deleted points
    elif key == 's':
        show_delps = not show_delps
        refresh = True

    # Fit a Polynomial to the light curve (excludes deleted points)
    elif key == 'f':
        # First Deactivate Garbage/Zoom Boxes if Active
        _deactivate_selectors()

        # Ask for new degree value on command line
        deg_str = input('New Polynomial Degree: ')
        try:
            deg = int(deg_str)
        except ValueError:
            # Keep the previous degree rather than raising in the callback
            print('WARNING: Polynomial degree must be an integer.')
        else:
            refresh = True

    # Choose which comparison stars to use
    elif key == 'c':
        # First Deactivate Garbage/Zoom Boxes if Active
        _deactivate_selectors()

        # Ask for comp stars (1-based numbers, converted to 0-based indices)
        new_comps = input('Enter Comp Numbers (comma separated): ')
        try:
            new_comp_ind = [int(c)-1 for c in new_comps.split(",")]
        except ValueError:
            new_comp_ind = None
            print('WARNING: Comp numbers must be comma separated integers.')

        if new_comp_ind is not None:
            # The indices address the comp_ids list, so they are validated
            # against its length; negative values would otherwise wrap
            # around to the end of the list.
            n_comps = len(comp_ids)
            if any(i < 0 for i in new_comp_ind):
                print('WARNING: Comp numbers must be 1 or greater.')
            elif any(i >= n_comps for i in new_comp_ind):
                print('WARNING: Comps ID(s) exceeds number of comps.')
            elif len(new_comp_ind) > n_comps:
                print('WARNING: More Comps ID(s) Supplied than number of comps.')
            elif len(new_comp_ind) > len(set(new_comp_ind)):
                print('WARNING: Duplicate IDs supplied.')
            else:
                # Get new selection of comps
                comp_select = [comp_ids[i] for i in new_comp_ind]
                refresh = True

    # Move to previous aperture size
    elif key == 'v':
        if apnum > 0: # If greater than the minimum aperture size
            apnum -= 1
        refresh = True

    # Move to next aperture size
    elif key == 'w':
        if apnum < Nf-1: # If less than the maximum aperture size
            apnum += 1
        refresh = True

    # Draw box to delete points inside
    elif key == 'g':
        if (toggle_zoom.RS.active) | (toggle_reverse_garbage.RS.active):
            print(" Cannot activate Garbage Selector while other selectors are Active.")
        else:
            toggle_garbage()

    # Draw box to delete points outside
    elif key == 'r':
        if (toggle_zoom.RS.active) | (toggle_garbage.RS.active):
            print(" Cannot activate Reverse Garbage Selector while other selectors are Active.")
        else:
            toggle_reverse_garbage()

    # Draw box to zoom in
    elif key == 'z':
        if (toggle_garbage.RS.active) | (toggle_reverse_garbage.RS.active):
            print(" Cannot activate Zoom Selector while other selectors are Active.")
        else:
            toggle_zoom()

    # Restore zoom to original
    elif key == 'Z':
        # Include the deleted points in the limits only if they are shown
        if show_delps and len(dele_x) > 0:
            minx,maxx = (min(min(keep_x),min(dele_x)),
                         max(max(keep_x),max(dele_x)))
            miny,maxy = (min(min(keep_y),min(dele_y)),
                         max(max(keep_y),max(dele_y)))
        else:
            minx,maxx = min(keep_x),max(keep_x)
            miny,maxy = min(keep_y),max(keep_y)
        xdiff = maxx-minx
        ydiff = maxy-miny
        bx_LC.set_xlim(minx-0.05*xdiff,maxx+0.05*xdiff)
        bx_LC.set_ylim(miny-0.30*ydiff,maxy+0.30*ydiff)
        zoomed=False
        plt.draw()

    # Perform sigma-clipping
    elif key == 'x':

        print("\nPerforming Sigma-Clipping:")

        # Both sigma limits are required, so re-prompt until they parse
        sigma_low = _prompt_float('Lower Sigma: ')
        sigma_upp = _prompt_float('Upper Sigma: ')

        # Get Window Size (Optional since 25 is the default)
        try:
            window_size = int(input('Window Size [25]: '))
        except ValueError:
            window_size = 25
        # A zero or negative window cannot be used to slice the data
        if window_size < 1:
            window_size = 25

        # Call the sigma-clip function (it refreshes the plot itself)
        sigclip_data(keep_x,keep_y,dele_x,dele_y,
                     sigma_low,sigma_upp,dw=window_size)

    # Option to close the plots and continue to aperture selection
    # without performing a full comp star + ap. size grid search
    elif key == 'W':
        grid_search = False
        plt.close("all")

    # Option to close the plots and continue to aperture selection
    # while also performing a full grid search of possible comp
    # star + aperture size combinations for best combo.
    elif key == 'G':
        grid_search = True
        plt.close('all')

    # Option to close the plots and exit program
    elif key in ('Q','q'):
        plt.close("all")
        print("\nProgram Exited. Lightcurve was not saved. \n")
        sys.exit(1)

    # Option to reprint the command list
    elif key == '?':
        print_commands()

    # Update divided light curve and plot
    if refresh:
        keep_y,dele_y,keep_ind,dele_ind = dlc_update(keep_x,keep_y,
                                                     dele_x,dele_y,deg,apnum)


def _deactivate_selectors():
    """Switch off whichever garbage/reverse-garbage/zoom selectors are active."""
    for toggle in (toggle_garbage, toggle_reverse_garbage, toggle_zoom):
        if toggle.RS.active:
            toggle()


def _prompt_float(prompt):
    """Prompt on the command line until the reply parses as a float."""
    while True:
        try:
            return float(input(prompt))
        except ValueError as e:
            # Only parse failures are retried; EOF and interrupts propagate
            print(e)

# Setup the rectangle selectors, start them off as deactivated
rect_garbage = {'facecolor': '#FFFF00', 'edgecolor': '#FFFF00',
                'alpha': 0.4, 'fill': True}
rect_zoom = {'facecolor': '#6DA1F0', 'edgecolor': '#6DA1F0',
             'alpha': 0.4, 'fill': True}

# Each toggle function carries its own selector in an "RS" attribute.
# All three selectors share the same options and differ only in their
# callback and box colors.
_selector_specs = ((toggle_garbage, garbage_select, rect_garbage),
                   (toggle_reverse_garbage, reverse_garbage_select, rect_garbage),
                   (toggle_zoom, zoom_select, rect_zoom))
for _toggle, _callback, _props in _selector_specs:
    _toggle.RS = RectangleSelector(bx_LC, _callback,
                                   rectprops=_props,
                                   drawtype='box',
                                   useblit=True,
                                   button=[1],  # Left click only
                                   minspanx=5, minspany=5,
                                   spancoords='pixels',
                                   interactive=False)
for _toggle, _callback, _props in _selector_specs:
    _toggle.RS.set_active(False)

# Connect incoming events to event handlers
cidb1 = figb.canvas.mpl_connect('key_press_event', on_event_b)
plt.show()

# Disconnect Figures from interactive event handling
for _figure, _cid in ((figa, cida), (figb, cidb1), (imfig, cidim)):
    _figure.canvas.mpl_disconnect(_cid)


############################################
# After choosing a polynomial fit and removing outliers,
# display all light curves for each aperture simultaneously
# and highlight the optimal aperture choice.

# First Calculate final divided light curves and 
# calculate p2p-scatter for each aperture
def _scan_apertures(comps):
    """
    Build the divided light curve for every aperture using the given
    comparison star columns.

    comps = list of comparison star column labels to sum

    Returns (light curves, point-to-point scatter values in percent),
    one entry per aperture.
    """
    curves = []
    scatters = []
    for k in range(Nf):
        # Zero counts are swapped for 1.0 before dividing
        target = phot_data[k][1].replace(0.0,1.0).values
        comp_sum = np.sum(phot_data[k][comps].replace(0.0,1.0).values,axis=1)
        curve,_,_,_ = div_lc(time_stamps,target,comp_sum,deg,keep_ind,dele_ind)
        curves.append(curve)
        scatters.append(pp_scat(curve)*1e2)
    return curves,scatters


# Light curves and P2P scatter for the user-selected comparison stars
lc_all,pp_all = _scan_apertures(comp_select)

# Get index of minimum aperture
ind_ppmin = pp_all.index(min(pp_all))

# Generate all possible combinations of comparison stars
# (itertools tuples are converted into pure lists for Pandas)
all_combs = [list(combo)
             for size in range(1,len(comp_ids)+1)
             for combo in combinations(comp_ids,size)]
Ncomb = len(all_combs)

# Go through all possible aper/comp-star
# combinations to find the optimal one
if grid_search:
    lc_comb = []   # Stores lowest P2P-light curve for each Comp-Combo
    pp_comb = []   # Stores lowest P2P-value for each Comp-Combo
    ind_aper = []  # Stores index for aperture with lowest P2P-value
    for combo in all_combs:
        lc_aper,pp_aper = _scan_apertures(combo)

        # Save info for lowest P2P light curve
        best_aper = pp_aper.index(min(pp_aper))
        ind_aper.append(best_aper)
        lc_comb.append(lc_aper[best_aper])
        pp_comb.append(pp_aper[best_aper])

    # Get index of minimum Combination
    ind_comb_min = pp_comb.index(min(pp_comb))


# Print out a message showing P2P for User Selected Comps + Optimal Aperture
apmin_user = ap_sizes[ind_ppmin]
comp_str1 = gen_compstr(comp_select,psource)
print('\nUser Selected Comparison Stars + Optimal Aperture:')
print(f'    Aper  = {apmin_user} pix'
      f'\n    Comps = {comp_str1}'
      f'\n    P2P   = {pp_all[ind_ppmin]:.2f} %')

# Print out a message showing P2P for Optimized Comp + Aperture Combination
if not grid_search:
    print('\nFull Grid Search not performed.')
else:
    apmin_auto = ap_sizes[ind_aper[ind_comb_min]]
    comp_str2 = gen_compstr(all_combs[ind_comb_min],psource)
    print('\nFull Grid Search for Optimal Comp + Aperture Combination:')
    print(f'    Aper  = {apmin_auto} pix'
          f'\n    Comps = {comp_str2}'
          f'\n    P2P   = {pp_comb[ind_comb_min]:.2f} %')

    # Report whether the User-Selected and Automated
    # Comp/Aperture Selections agree
    same_comps = (comp_str1 == comp_str2)
    same_aper = (apmin_user == apmin_auto)
    if same_comps and same_aper:
        print('\nAperture Size and Comparison Star Selections are in Agreement.')
    elif same_aper:
        print('\nWARNING: Comparison Star Selections Do Not Agree.')
    elif same_comps:
        print('\nWARNING: Aperture Size Selections Do Not Agree.')
    else:
        print('\nWARNING: Aperture Size and Comparison Star Selections Do Not Agree.')

# Later the User will be able to decide whether to save the
# light curve reflecting their own comparison star choices
# or save the light curve with optimized selections.
# This flag decides which light curve is put into
# the saved lightcurve file.
save_user = True

############################################
# Plot light curves for all apertures
figc = plt.figure(3,figsize=figsize)
gsc = GridSpec(3,1,hspace=0.3)
cx_lc = figc.add_subplot(gsc[0])   # Light curve panel
cx_ft = figc.add_subplot(gsc[1])   # Periodogram panel
cx_pp = figc.add_subplot(gsc[2])   # P2P scatter vs. aperture panel

# Set Window Name (through the figure manager, since the canvas-level
# set_window_title shortcut is not available in newer Matplotlib releases)
figc.canvas.manager.set_window_title('Aperture Selection')

# Set xy-limits of the PP plot
ppdiff = max(pp_all) - min(pp_all)
cx_pp.set_xlim(min(ap_sizes)-0.4 ,max(ap_sizes)+0.4)
cx_pp.set_ylim(min(pp_all)-0.2*ppdiff, max(pp_all)+0.2*ppdiff)

# Plot Light Curve from Optimal Aperture (plus the grid search result)
lcplot1, = cx_lc.plot(time_stamps[keep_ind],lc_all[ind_ppmin],ls='-',lw=0.75,c='#FFFFFF88',
                      marker='o',ms=3.2,mew=0,mfc='w',label='User')
if grid_search:
    lcplot2, = cx_lc.plot(time_stamps[keep_ind],lc_comb[ind_comb_min],
                          ls='None',marker='o',ms=3.2,mec='r',mfc='None',
                          mew=0.8,label='Grid')
cx_lc.legend(ncol=2,fontsize=8)

# Plot Periodogram of Optimal Light Curve
freqarr,lsp = calc_lsp(time_stamps[keep_ind],lc_all[ind_ppmin])
siglevel = 4.0*np.nanmean(lsp)*1e2  # Four times the mean amplitude, in %
ftplot1, = cx_ft.plot(freqarr*1e6,lsp*1e2,c='w',label='User')
if grid_search:
    freqarr2,lsp2 = calc_lsp(time_stamps[keep_ind],lc_comb[ind_comb_min])
    ftplot2, = cx_ft.plot(freqarr2*1e6,lsp2*1e2,ls='--',lw=0.9,c='r',label='Grid')

cx_ft.axhline(siglevel,ls='--',lw=1.0,c='C0')
cx_ft.set_xlim(0,max(freqarr)*1e6)
cx_ft.set_ylim(0,1.2*max(lsp)*1e2)
cx_ft.text(0.70*cx_ft.get_xlim()[1],siglevel+0.1*cx_ft.get_ylim()[1],
           r'4$\langle$A$\rangle$ Sig. Level = {:.3}%'.format(siglevel), **fontb)
cx_ft.legend(ncol=2,fontsize=8)

# Plot PP-Scat versus Aperture Size
pp_majloc = MultipleLocator(0.5)
pplot, = cx_pp.plot(ap_sizes,pp_all,ls='-',lw=1,c='w',
                    marker='o',ms=4,mec='w',mfc='w')
cx_pp.xaxis.set_major_locator(pp_majloc)

# Print a special marker for the minimum scatter aperture
cx_pp.plot(ap_sizes[ind_ppmin],pp_all[ind_ppmin],ls='None',
           marker='s',ms=9,mec='c',mfc='None',mew=1.5)

# Add text and labels
cx_lc.set_xlabel('Time (s)',fontsize=9,**fontb)
cx_lc.set_ylabel('Rel. Flux (%)',fontsize=9,**fontb)
# Raw string: "\m" is not a valid escape sequence in a normal string literal
cx_ft.set_xlabel(r'Frequency ($\mu$Hz)',fontsize=9,**fontb)
cx_ft.set_ylabel('Amplitude (%)',fontsize=9,**fontb)
cx_pp.set_xlabel('Aperture Radius (pix)',fontsize=9,**fontb)
cx_pp.set_ylabel('P2P-Scatter (%)',fontsize=9,**fontb)
cx_lc.tick_params(direction='in',labelsize=8)
cx_ft.tick_params(direction='in',labelsize=8)
cx_pp.tick_params(direction='in',labelsize=8)
cx_pp.tick_params(which='major',axis='x',labelrotation=70)


# Add title
titlec = 'Optimal Aperture Selection'
cx_lc.set_title(titlec,loc='center',fontsize=14,**fontb)

# Set some more initial parameters for the interactive P2P plot
num_ppdel = 0
keep_ppx = np.copy(np.asarray(ap_sizes))
keep_ppy = np.copy(np.asarray(pp_all))
dele_ppx = np.asarray([])
dele_ppy = np.asarray([])
# Add some basic interactive capabilities
def on_event_c(event):
    """
    Key-press handler for the aperture selection figure.

    event = Matplotlib key press event; event.key selects the action and
            event.x/event.y give the cursor position used to pick a point

    Key bindings:
        A : restore all deleted points on the P2P plot
        d : delete the picked point from the P2P plot
        W : close the figure and save the USER light curve
        G : close the figure and save the grid search light curve
        Q : close the figure and exit without saving
        ? : reprint the command list
    """
    global cx_pp,pplot,save_user,grid_search
    global keep_ppx,keep_ppy,dele_ppx,dele_ppy,num_ppdel

    pressed = event.key
    redraw = False  # Set when the P2P plot needs to be refreshed

    # Add ALL deleted points back to the plot
    if pressed == 'A':
        if num_ppdel > 0:
            print('Added all deleted points back to Light Curve.')
            keep_ppx = np.append(keep_ppx,dele_ppx[:])
            keep_ppy = np.append(keep_ppy,dele_ppy[:])
            dele_ppx = np.asarray([])
            dele_ppy = np.asarray([])
            num_ppdel = 0

            # Re-sort the keep arrays with one shared ordering
            order = keep_ppx.argsort()
            keep_ppy = keep_ppy[order]
            keep_ppx = keep_ppx[order]
            redraw = True

    # Delete a single point from the P2P plot
    elif pressed == 'd':
        if len(keep_ppx) <= 2:
            print('WARNING! Cannot delete any more points!')
        else:
            idx = get_index(cx_pp, pplot, keep_ppx, keep_ppy, event.x, event.y)
            dele_ppx = np.append(dele_ppx,keep_ppx[idx])
            dele_ppy = np.append(dele_ppy,keep_ppy[idx])
            keep_ppx = np.delete(keep_ppx,idx)
            keep_ppy = np.delete(keep_ppy,idx)
            num_ppdel += 1
            redraw = True

    # Option to close the figure and Save the *USER* LightCurve
    elif pressed == 'W':
        plt.close("all")

    # Option to close the figure and Save the *AUTO* LightCurve
    elif pressed == 'G':
        if not grid_search:
            print('Grid Search was not performed, saving USER lightcurve.')
        else:
            save_user = False  # Save flag changed
        plt.close("all")

    # Option to close the figure without Saving the LightCurve
    elif pressed == 'Q':
        plt.close("all")
        print('\nProgram exited. Lightcurve was not saved.\n')
        sys.exit(1)

    # Option to reprint the command list
    elif pressed == '?':
        print_commands()

    if redraw:
        # Update the plot and rescale the Y limits to the kept points
        pplot.set_data(keep_ppx,keep_ppy)
        low,high = min(keep_ppy),max(keep_ppy)
        ppdiff = high - low
        cx_pp.set_ylim(low-0.2*ppdiff, high+0.2*ppdiff)
        plt.draw()

# Connect incoming events to event handlers
cidc = figc.canvas.mpl_connect('key_press_event', on_event_c)
plt.show()

# Disconnect Figure from interactive event handling
figc.canvas.mpl_disconnect(cidc)

############################################
# Before saving the final lightcurve, times need to
# be converted into barycentric format

# Choose appropriate telescope location
loc = get_loc(hdr,telcode)


# Define Astropy coordinate object for the target
# Get the RA & Dec from the sexagesimal fields of the star entry
obj_dat = star_dat.loc[obj_idx].values[0]
obj_ra = (float(obj_dat[1]) +
          float(obj_dat[2])/60.0 +
          float(obj_dat[3])/3600.0) * (360.0/24.0)

# The sign of a sexagesimal declination applies to the whole value, so the
# magnitudes are summed first and the sign applied afterwards. Adding
# minutes/seconds to a negative degree value would move the target north
# (e.g. -05 30 00 would become -4.5 instead of -5.5 degrees).
# NOTE(review): the sign is read from the text of the fields so that an
# entry such as "-00" keeps its sign; this only works if the fields were
# not converted to integers when stars.dat was read - confirm.
dec_negative = any(str(field).strip().startswith('-') for field in obj_dat[4:7])
dec_abs = (abs(float(obj_dat[4])) +
           abs(float(obj_dat[5]))/60.0 +
           abs(float(obj_dat[6]))/3600.0)
obj_dec = -dec_abs if dec_negative else dec_abs
tcoord = SkyCoord(obj_ra,obj_dec,unit="deg",frame="icrs")
coord_string = tcoord.to_string('hmsdms',sep=" ")
ra_string = coord_string[0:11]
de_string = coord_string[12:]

# Calculate BJD times using Astropy Time & TimeDelta objects
print('\nCalculating barycentric corrections...')
tkeep = time_stamps[keep_ind]
st = dt_zero_mid.to_value('isot')        # Mid Exposure start date/time
st_mjd = dt_zero_mid.to_value('mjd')     # Mid Exposure start MJD
print('BJD Mid-Exposure Start Time = {}'.format(st))
t0 = Time(st,scale='utc',format='isot',location=loc)  # Astropy mid-exposure Start time
td = TimeDelta(tkeep,format='sec')       # Time deltas for kept points
t = t0 + td                              # Full list of Astropy mid-exposure times
ltt_bary = t.light_travel_time(tcoord)   # Light travel time to barycenter
tbjd = t.tdb.jd + ltt_bary.jd            # Barycentric times (rescaled UTC + LTT)
tbjd_sec = (tbjd - tbjd[0]) * 86400.0    # BJD converted to seconds since reftime
bjdref = tbjd[0]                         # BJD of T0 for all timestamps


# Function to Calculate Error Bars and S/N
def calc_err(phot,aper):
    """
    Estimate the noise and signal-to-noise of the target photometry.

    phot = raw aperture photometry for all objects; the first column is
           read as the object counts and the last column as the sky counts
    aper = Aperture radius in pixels

    The module-level dark, texp, read and gain values set the dark and
    read noise terms; a NaN dark or read value drops that term.

    Returns (snr, noise).
    """

    # Number of pixels within the aperture
    n_pix = np.pi*aper**2

    # Object counts alone, and object + sky counts
    target_counts = phot[:,0]
    total_counts = target_counts + phot[:,-1]

    # Dark and read noise terms, skipped when the value is unknown (NaN)
    dark_term = 0.0 if np.isnan(dark) else dark*texp*n_pix
    read_term = 0.0 if np.isnan(read) else (read/gain)**2*n_pix

    # Combine all sources of noise
    sigma = np.sqrt(total_counts + dark_term + read_term)
    return target_counts/sigma,sigma


#################################################################
# Lastly, save the raw photometry and the normalized lightcurve
# into files.  Trying to create outputs similar to the WQED 
# .wq and .lc1 files

# Create header info for file (use save_user flag to choose appropriate info)
if save_user:
    opt_ind = ind_ppmin
    opt_ap = apmin_user
    comp_number = len(comp_select)
    comp_ids = comp_str1
    opt_pp = pp_all[ind_ppmin]
    lc_final = lc_all[ind_ppmin]
else:
    # ind_comb_min indexes the comp-star combinations, while phot_data and
    # phot_names are indexed by aperture; the aperture belonging to the
    # best combination is stored in ind_aper (the same lookup that
    # produced apmin_auto).
    opt_ind = ind_aper[ind_comb_min]
    opt_ap = apmin_auto
    comp_number = len(all_combs[ind_comb_min])
    comp_ids = comp_str2
    opt_pp = pp_comb[ind_comb_min]
    lc_final = lc_comb[ind_comb_min]

# Relative flux errors from the S/N of the chosen aperture
opt_phot = phot_data[opt_ind]
snr,_ = calc_err(opt_phot.iloc[keep_ind,:].values,opt_ap)
err = 1./snr

# Header for the raw photometry file
rawphot_header = 'Object     : {:30s}# Name of Object'.format(obj_name) + \
    '\nRA         : {:30s}# Object Right Ascension'.format(ra_string) + \
    '\nDec        : {:30s}# Object Declination'.format(de_string) + \
    '\nTelescope  : {:30s}# Name of Telescope'.format(telescope) + \
    '\nInstrument : {:30s}# Name of Instrument'.format(instrument) + \
    '\nTeleCode   : {:30s}# Teledat Code Name'.format(telcode) + \
    '\nDate       : {:30s}# Mid-Exp. UTC Start Date'.format(st.split("T")[0]) + \
    '\nUTC        : {:30s}# Mid-Exp. UTC Start Time'.format(st.split("T")[1]) + \
    '\nMJD        : {:<30.9f}# Mid Exposure MJD Start'.format(st_mjd) + \
    '\nExptime    : {:<30.6f}# Exposure Time (s)'.format(texp) + \
    '\nFilter     : {:30s}# Filter Name'.format(filter_name) + \
    '\nBJED       : {:<30.9f}# Mid Exp. Barycentric Julian Date'.format(bjdref) + \
    '\nApPhot     : {:30s}# Photometry Program'.format(source_dict[psource]) + \
    '\nOrigFile   : {:30s}# Source Photometry Filename'.format(phot_names[opt_ind]) + \
    '\nApRadius   : {:<30.2f}# Aperture Radius (pixels)'.format(opt_ap) + \
    '\nNkeep      : {:<30d}# Number of points in light curve '.format(len(keep_ind)) + \
    '\nNdelete    : {:<30d}# Number of points removed'.format(len(dele_ind)) + \
    '\nColumns: Raw T-mid (s), BaryCorr T-mid (s), Target, Comp(s), Sky'

# Header for the final lightcurve
lcfin_header = 'Object     : {:30s}# Name of Object'.format(obj_name) + \
    '\nRA         : {:30s}# Object Right Ascension'.format(ra_string) + \
    '\nDec        : {:30s}# Object Declination'.format(de_string) + \
    '\nTelescope  : {:30s}# Name of Telescope'.format(telescope) + \
    '\nInstrument : {:30s}# Name of Instrument'.format(instrument) + \
    '\nTeleCode   : {:30s}# Teledat Code Name'.format(telcode) + \
    '\nDate       : {:30s}# Mid-Exposure UTC Date at T0'.format(st.split("T")[0]) + \
    '\nUTC        : {:30s}# Mid-Exposure UTC Time at T0'.format(st.split("T")[1]) + \
    '\nMJD        : {:<30.9f}# Mid-Exposure UTC MJD at T0'.format(st_mjd) + \
    '\nExptime    : {:<30.6f}# Exposure Time (s)'.format(texp) + \
    '\nFilter     : {:30s}# Filter Name'.format(filter_name) + \
    '\nBJED       : {:<30.9f}# Mid-Exposure TDB JD at T0'.format(bjdref) + \
    '\nApPhot     : {:30s}# Photometry Program'.format(source_dict[psource]) + \
    '\nOrigFile   : {:30s}# Source Photometry Filename'.format(phot_names[opt_ind]) + \
    '\nApRadius   : {:<30.2f}# Aperture Radius (pixels)'.format(opt_ap) + \
    '\nAvgScatter : {:<30.2f}# Avg. Point-to-Point Scatter (%)'.format(opt_pp) + \
    '\nComps      : {:<30d}# Comparison stars used'.format(comp_number) + \
    '\nPolyOrder  : {:<30d}# Degree of Polynomial Division'.format(deg) + \
    '\nNkeep      : {:<30d}# Number of points in light curve '.format(len(keep_ind)) + \
    '\nNdelete    : {:<30d}# Number of points removed'.format(len(dele_ind)) + \
    '\nAuthor     : {:30s}# Author of this light curve'.format(author) + \
    '\nCreatedOn  : {:30s}# Date created'.format(Time.now().to_value("iso")) + \
    '\nColumns: Raw T-mid (s), BaryCorr T-mid (s), Rel. Flux, Rel. Flux Error'

# Header for the log file
log_header = '    OBJECT = {}'.format(obj_name) + \
             '\nPOLYNOMIAL = {}'.format(deg) + \
             '\n     DTMID = {}'.format(st) + \
             '\n     COMPS = {}'.format(comp_ids) + \
             '\n      TEXP = {:.6f}'.format(texp)

# Generate output arrays (all taken from the aperture chosen above)
Ndat_all = len(time_stamps)
Ndat_keep = len(keep_ind)
rawtimes_saved = np.reshape(tkeep,(Ndat_keep,1))
bjdtimes_saved = np.reshape(tbjd_sec,(Ndat_keep,1))
orig_times_logged = np.reshape(time_stamps,(Ndat_all,1))
del_log = np.reshape(np.zeros(Ndat_all),(Ndat_all,1))
del_log[keep_ind] += 1 # Set to 1 for points which are kept (0 means deleted)
lcerr_saved = np.reshape(err,(Ndat_keep,1))

if psource == 'hsp':
    phot_saved = opt_phot.iloc[keep_ind,:].values
    output_phot = np.concatenate((rawtimes_saved,bjdtimes_saved,phot_saved),axis=1)
elif psource == 'mae':
    # Every other column is saved as photometry, followed by column 1 as sky
    phot_saved = opt_phot.iloc[keep_ind,0:2*Nobj-1:2].values
    sky_saved = np.reshape(opt_phot.iloc[keep_ind,1].values,(Ndat_keep,1))
    output_phot = np.concatenate((rawtimes_saved,bjdtimes_saved,phot_saved,sky_saved),axis=1)
lc_saved = np.reshape(lc_final,(Ndat_keep,1))
output_lc = np.concatenate((rawtimes_saved,bjdtimes_saved,lc_saved,lcerr_saved),axis=1)
output_log = np.concatenate((orig_times_logged,del_log),axis=1)


# Create the format strings for raw photometry and final lightcurve
# (two time columns followed by Nobj+1 count columns)
rawphot_format = '%10.3f  %10.3f  ' + '  %7.0f'*(Nobj+1)
lcfin_format = '%10.3f  %10.3f  %9.6f  %9.6f'
log_format = '%10.3f  %i'

# Save the lightcurve + header info to file
lcphot_fname = '{}_{}.phot'.format(obj_name,print_date.replace("-","")).lower()
lcfin_fname = '{}_{}.lc'.format(obj_name,print_date.replace("-","")).lower()
lclog_fname = 'phot2lc_log.txt'
np.savetxt(lcphot_fname, output_phot, fmt=rawphot_format, header=rawphot_header)
np.savetxt(lcfin_fname, output_lc, fmt=lcfin_format, header=lcfin_header)
np.savetxt(lclog_fname, output_log, fmt=log_format, header=log_header)

print("\nLightcurve saved to file {}".format(lcfin_fname))
print("\nFinished! \n")


