#!/usr/bin/env python

"""
phot2lc takes output from aperture photometry
pipelines and aids in the extraction of optimal 
light curves. Current features include:
    > Comparison star selection
    > Aperture size selection
    > Polynomial detrending
    > Sigma clipping
    > Manual data point removal
    > Barycentric time corrections

Author: 
    Zach Vanderbosch

For a description of updates, see the 
version_history.txt file.

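Usage:
    phot2lc
    phot2lc -h|--help
    phot2lc -c|--codes
    phot2lc -t|--telescope <telescope code>
    phot2lc -s|--source <source code>
    phot2lc -i|--image <image filename>
    phot2lc -o|--object <object name>
    phot2lc -b|--barycorr

Options:
    -h --help      Show command line options
    -c --codes     Print a list of available telescope codes
    -t --telescope Code name for telescope used
    -s --source    Code name for photometry program used (hsp, mae, hcm, ucm)
    -i --image     Name of a single image to use instead of the image list
    -o --object    Name of object matching stars.dat entry
    -b --barycorr  If given, do NOT perform barycentric time corrections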

""" 

# Refuse to run on Python 2.x or on Python 3 older than 3.6
# (this script uses 3.6+ syntax such as f-strings).
import sys
if sys.version_info < (3,):
    print('ERROR: phot2lc incompatible with Python 2.X')
    print('Program exited.')
    sys.exit(1)
elif sys.version_info < (3, 6):
    print('ERROR: phot2lc incompatible with Python 3.5 or earlier')
    print('Program exited.')
    sys.exit(1)

# Set the matplotlib backend based on the system platform to get the
# best figure sizing results: Qt5 on macOS and Windows, Tk otherwise.
import matplotlib as mpl
_PLATFORM_BACKENDS = {
    'darwin': 'QT5agg',   # Mac OSX
    'linux': 'TKagg',     # Linux
    'linux2': 'TKagg',    # Linux2
    'win32': 'QT5agg',    # Windows
}
mpl.use(_PLATFORM_BACKENDS.get(sys.platform, 'TKagg'))  # Tk for any other platform


# Import Standard Packages
import os
import json
import argparse
import warnings
import itertools
import numpy as np
import pandas as pd
from glob import glob
import PyQt5.QtCore as qtc
from itertools import combinations

# Import Astropy modules
import astropy.units as u
from astropy.io import fits
from astropy.utils import iers
from astropy.time import Time, TimeDelta
from astropy.coordinates import SkyCoord
from astropy.visualization import ZScaleInterval
from astropy.utils.exceptions import AstropyWarning

# Import matplotlib modules
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MultipleLocator
from matplotlib.widgets import Cursor, MultiCursor
from matplotlib.widgets import RectangleSelector

# Import Custom Functions
from phot2lc.teledat import get_telinfo
from phot2lc.photfunc import progress_bar, print_commands
from phot2lc.photfunc import get_time, pp_scat, roll_std
from phot2lc.photfunc import div_lc, gen_compstr, calc_lsp
from phot2lc.photfunc import get_loc, window_std
from phot2lc.ucm_utils import read_ucm


#############################################################
## Load in the config.dat file and use it to set the 
## defaults for user-based parameters.

import phot2lc.teledat

# config.dat sits in the installed phot2lc package directory. Each line is
# reduced to the whitespace-stripped text after its last "="; the values
# are then consumed by position below, so line order in config.dat matters.
config_path = os.path.dirname(os.path.realpath(phot2lc.teledat.__file__))
with open(config_path + "/config.dat") as f:
    config_dat = [line.strip("\n").split("=")[-1].strip() for line in f]


# Name for whoever is performing these reductions
author = config_dat[0]
# File containing the list of images being analyzed
list_name = config_dat[1]
# Optional file containing initial guesses at target + comp pixel
# locations. This is the same file used for IRAF ccd_hsp aperture
# photometry and is not needed if you don't have it.
tloc_name = config_dat[2]
# Path to file containing object names + coordinates
star_dat_filename = config_dat[3]


# Default arguments for argparse; the literal string 'None' in
# config.dat means "no default".
default_telescope = config_dat[4]
default_source = config_dat[5]
default_image = config_dat[6] if config_dat[6] != 'None' else None
default_object = config_dat[7] if config_dat[7] != 'None' else None


#############################################################
## Generate arguments for command line parsing

# Command-line interface. Defaults come from config.dat.
# The -s help lists the codes actually accepted by the source validation
# (hsp/mae/hcm/ucm), not the pipeline names.
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--codes', action='store_true',
                    help="If invoked, print list of available telescope codes.")
parser.add_argument('-t', '--telescope', type=str, default=default_telescope,
                    help="Code name for telescope used (see -c for options).")
parser.add_argument('-s', '--source', type=str, default=default_source,
                    help="Code name for photometry source: hsp (ccd_hsp), "
                         "mae (maestro), hcm (hipercam), or ucm (ultracam).")
parser.add_argument('-i', '--image', type=str, default=default_image,
                    help="Name of a single image to use instead of the image list.")
parser.add_argument('-o', '--object', type=str, default=default_object,
                    help="Name of object. If None, get object from FITS header.")
parser.add_argument('-b', '--barycorr', action='store_false',
                    help="If invoked, do NOT perform barycentric corrections.")
args = parser.parse_args()


#############################################################
# Perform checks on command line inputs

# Load list containing telescope information
telinfo = get_telinfo()

# Currently supported codes, telescopes, instruments
valid_codes = [x['code'] for x in telinfo]
valid_telescopes = [x['telename'] for x in telinfo]
valid_instruments = [x['instname'] for x in telinfo]


def _print_telescope_codes():
    """Print one "code = telescope instrument" line per supported
    telescope, followed by a blank line."""
    for code, tele, inst in zip(valid_codes, valid_telescopes, valid_instruments):
        print("   {} = {} {}".format(code, tele, inst))
    print("")


# Print out available codes if requested. This is a successful,
# informational request, so exit with status 0.
if args.codes:
    print('\nSupported telescope codes:')
    _print_telescope_codes()
    sys.exit(0)

# Check that telescope code is valid; on failure list the valid codes
# and exit with an error status.
telcode = args.telescope
if telcode not in valid_codes:
    print('\nERROR! Supplied telescope code ' + \
          '"{}" is not currently supported.'.format(telcode))
    print('Supported telescope codes:')
    _print_telescope_codes()
    sys.exit(1)
teldict = next(item for item in telinfo if item["code"] == telcode)


# Currently supported photometry source codes and the pipeline each names
valid_sources = ['hsp','mae','hcm','ucm']
source_dict = {
    'hsp':'ccd_hsp',
    'mae':'maestro',
    'hcm':'hipercam',  # Currently supports HiPERCAM pipeline output v1.2.0
    'ucm':'ultracam'   # Currently supports ULTRACAM pipeline output v19/12/2005
}

# Reject unknown photometry source codes, listing the supported ones
psource = args.source
if psource not in valid_sources:
    print(f'\nERROR! Supplied photometry source "{psource}" is not currently supported.')
    print('Supported source codes:')
    for code in valid_sources:
        print(f"   {code} = {source_dict[code]}")
    print("")
    sys.exit(1)
phot_source = source_dict[psource]

# Barycentric corrections are on unless -b/--barycorr was given
barycorr = args.barycorr

# Camera noise characteristics for the selected telescope/instrument:
#   dark -- dark current in ADU/s/pixel
#   read -- read noise, documented as e-/s/pixel
#           (NOTE(review): read noise is usually quoted per pixel per read;
#           confirm the unit in teledat)
#   gain -- gain in e-/ADU
dark, read, gain = (teldict[key] for key in ('dark', 'read', 'gain'))


#############################################################
## Additional Setup

# Suppress annoying Astropy Warning Messages
warnings.simplefilter('ignore', category=AstropyWarning)

# Disable some default matplotlib keymaps so these keys are free for
# phot2lc's own key bindings. Each keymap is listed exactly once.
plt.rcParams.update({
    'keymap.{}'.format(name): ''
    for name in ('home', 'quit', 'back', 'forward', 'grid',
                 'save', 'fullscreen', 'grid_minor')
})

# Some additional setup steps for plotting
qtc.pyqtRemoveInputHook() # Suppresses QCORE message during input
plt.style.use('dark_background') # Sets default plot style


#############################################################
## Start loading in data

# Grab all of the photometry files written by the chosen pipeline
phot_glob_patterns = {
    'hsp': 'runbase*',     # ccd_hsp output
    'mae': 'counts*',      # maestro output
    'hcm': 'output*log',   # hipercam output
    'ucm': 'output*log',   # ultracam output
}
phot_names_raw = glob(phot_glob_patterns[psource])
Nf = len(phot_names_raw)

# Every later step indexes into this list, so stop here with a clear
# message instead of failing with an IndexError further down.
if Nf == 0:
    print('\nERROR! No {} photometry files matching "{}" found in {}\n'.format(
        phot_source, phot_glob_patterns[psource], os.getcwd()))
    sys.exit(1)


# Grab all or just one of the FITS files.
# ndmin=1 keeps a one-line image list as a 1-element array; without it
# np.loadtxt squeezes the result to a 0-d array, and len()/indexing fail.
if args.image is None:
    fits_list = np.loadtxt(list_name, dtype=str, ndmin=1)
    num_fits = len(fits_list)
    if num_fits == 0:
        print('\nERROR! Image list "{}" contains no file names.\n'.format(list_name))
        sys.exit(1)
else:
    fits_list = [args.image]
    num_fits = 1


# Get aperture sizes from filenames. For ccd_hsp/maestro the radius is the
# part of the file name from its 8th character onward; HiPERCAM/ULTRACAM
# logs carry no such size, so a single placeholder of -1 is used.
if psource in ['hsp','mae']:
    ap_sizes_raw = []
    for _fname in phot_names_raw:
        _base = _fname.split("/")[-1]
        ap_sizes_raw.append(float(_base[7:]))
elif psource in ['hcm','ucm']:
    ap_sizes_raw = [-1]

# Order the apertures (and their photometry files) from smallest to largest
ap_sizes = sorted(ap_sizes_raw)
phot_names = [name for _, name in sorted(zip(ap_sizes_raw, phot_names_raw))]


# Load in the Photometry for each aperture size.
# Whitespace-delimited files are read with sep=r'\s+', the documented
# replacement for pandas' deprecated delim_whitespace=True.
if psource in ['hsp','mae']:
    # Drop the first column; keep the remaining columns as float64
    phot_data = [pd.read_csv(f,header=None,sep=r'\s+').iloc[:,1:].astype('float64')
                 for f in phot_names]
elif psource == 'hcm':

    # Column names are on the line two below the "Start of column name
    # definitions" marker; every line starting with '#' is skipped.
    skip_rows = []
    colnames = None
    start_row = None
    with open(phot_names[0]) as f:
        for i,line in enumerate(f):
            if 'Start of column name definitions' in line:
                start_row = i
            if (start_row is not None) and (i == start_row+2):
                # split() tolerates repeated spaces between names
                colnames = line.split("=")[1].strip().split()
            if line.startswith("#"):
                skip_rows.append(i)

    if colnames is None:
        print('\nERROR! No column name definitions found in {}\n'.format(phot_names[0]))
        sys.exit(1)

    df = pd.read_csv(phot_names[0],header=None,sep=r'\s+',
        names=colnames,skiprows=skip_rows)
    phot_data = [df]

elif psource == 'ucm':

    # Skip the leading block of '#' header lines, then load the data
    skip_rows = []
    with open(phot_names[0]) as f:
        for i,line in enumerate(f):
            if not line.startswith("#"):
                break
            skip_rows.append(i)
    df = pd.read_csv(phot_names[0],
        header=None, sep=r'\s+', skiprows=skip_rows)
    # 7 per-frame columns followed by 14 columns per photometry object
    Nphotobj = int((len(df.columns)-7)/14)

    # Generate column names that are consistent with HiPERCAM colnames
    baseColnames = ['name','MJD','MJDok','Exptim','CCD','mfwhm','mbeta']
    objColnames = ['naper','x','y','xm','ym','exm','eym','counts','countse',
    'sky','nsky','nrej','worst','flag']

    allobjColnames = [['{}_{}'.format(cname,x+1) for cname in objColnames] for x in range(Nphotobj)]
    colnames = baseColnames + list(itertools.chain.from_iterable(allobjColnames))
    df.columns = colnames

    phot_data = [df]



          
# Number of stars measured, derived from each pipeline's column layout
_ncols_phot = len(phot_data[0].columns)
if psource == 'hsp':
    Nobj = _ncols_phot - 1               # one column per star, last column is sky
elif psource == 'mae':
    Nobj = int(_ncols_phot/2)            # two columns per star
elif psource == 'hcm':
    Nobj = int((_ncols_phot-7)/16)       # 7 frame columns + 16 per star
elif psource == 'ucm':
    Nobj = int((_ncols_phot-7)/14)       # 7 frame columns + 14 per star
Ndat = len(phot_data[0])  # Number of data points


# If Phot2lc has been run previously, load in the Log File.
# log_header: the first 5 lines of the log as raw text. delimiter='?' is
#   presumably chosen as a character that never appears in the header, so
#   each line lands in a single column (log_header[k][0] is line k) -- confirm.
# log_data: the numeric table from the same file; later code uses column 0
#   as the time stamps and column 1 as a keep (1) / delete (0) flag.
#   NOTE(review): np.loadtxt relies on the header lines being skipped
#   (e.g. written as '#' comments); verify against the log writer.
log_exists = False
if os.path.exists('phot2lc_log.txt'):
    log_header = pd.read_csv('phot2lc_log.txt',header=None,
                              nrows=5,delimiter='?').values
    log_data = np.loadtxt('phot2lc_log.txt')
    log_exists = True


#############################################################
## Get some basic info about the object

# Pulls some basic info out of the header
telescope = teldict['telename']
instrument = teldict['instname']

if psource in ['hsp','mae','hcm']:
    hdr = fits.getheader(fits_list[0])
    observer = hdr[teldict['observer']]
    filter_name = hdr[teldict['filter']]
    if args.object is None:
        # The object-name header keyword comes from teledat and must be set
        # to read the name from the header. Check the keyword itself: the
        # str returned by .replace() below can never be None.
        if teldict['objname'] is None:
            print("Error! Object Name in teledat.py cannot be None.")
            sys.exit(1)
        obj_name = hdr[teldict['objname']].replace(" ","")
    else:
        obj_name = args.object
elif psource == 'ucm':
    _,hdr = read_ucm(fits_list[0])
    observer = hdr['Observer'][0]
    filter_name = hdr['Filter'][0]
    if args.object is None:
        obj_name = hdr['Object'][0].replace(" ","")
    else:
        obj_name = args.object

# The object name is used for the stars.dat lookup and output file names,
# so an empty name is rejected.
if not obj_name:
    print('\nError! Object name is an empty string. Please provide valid object name.\n')
    sys.exit(1)


# Get object locations (row 0 = target, later rows = comps) from the
# optional "tloc" file. If that file cannot be read, tloc is left empty
# and no object markers are drawn.
if psource in ['hsp','mae']:
    try:
        # ndmin=2 keeps a single-row file as shape (1, 2) instead of (2,),
        # so the tloc[i,0] / tloc[i,1] indexing used later still works
        tloc = np.loadtxt(tloc_name,usecols=(0,1),ndmin=2)
    except Exception:
        tloc = np.array([])

elif (psource == 'hcm'):
    # HiPERCAM aperture file (JSON); None if it cannot be read
    try:
        with open(tloc_name) as fp:
            aper_data = json.load(fp)
    except Exception:
        aper_data = None

    # Currently this only handles single-CCD Hipercam pipeline output
    hcm_first_file = fits_list[0]
    if hcm_first_file.split(".")[-1] != "hcm":
        print(f'Error! Input image {hcm_first_file} is not a .hcm file.')
        sys.exit(1)

    hcm_hdr = fits.getheader(hcm_first_file,1)
    if aper_data is None:
        Naper = 0
        tloc = np.array([])
    else:
        # NOTE(review): [1][1][1:] assumes the HiPERCAM aperture-file layout
        # (first CCD's aperture list, skipping its first element) -- confirm.
        # Positions are divided by the XBIN/YBIN binning factors.
        aper_data = aper_data[1][1][1:]
        Naper = len(aper_data)
        tloc = np.zeros((Naper,2))
        for i,ad in enumerate(aper_data):
            tloc[i,0] = ad[1]['x']/float(hcm_hdr['XBIN'])
            tloc[i,1] = ad[1]['y']/float(hcm_hdr['YBIN'])

elif psource == 'ucm':
    # ULTRACAM aperture file: keep only lines starting with 'x'
    try:
        aper_data = []
        with open(tloc_name) as fp:
            for fl in fp.readlines():
                if fl[0] == 'x':
                    aper_data.append(fl.strip("\n"))
    except Exception:
        aper_data = []

    # Currently this only handles single-CCD Ultracam pipeline output
    ucm_first_file = fits_list[0]
    if ucm_first_file.split(".")[-1] != "ucm":
        print(f'Error! Input image {ucm_first_file} is not a .ucm file.')
        sys.exit(1)

    _,ucm_hdr = read_ucm(ucm_first_file)
    # Each kept line is parsed as "<...>= x, y; <...>= xoff, yoff";
    # the object position is xy + offset.
    Naper = len(aper_data)
    tloc = np.zeros((Naper,2))
    for i,ad in enumerate(aper_data):
        xy = ad.split(";")[0].split("=")[-1].strip().split(",")
        xyoff = ad.split(";")[1].split("=")[-1].strip().split(",")
        tloc[i,0] = float(xy[0].strip()) + float(xyoff[0].strip())
        tloc[i,1] = float(xy[1].strip()) + float(xyoff[1].strip())


# Check that this object has exactly one entry in the stars.dat file
# (the first whitespace-separated column holds the object name).
# sep=r'\s+' replaces pandas' deprecated delim_whitespace=True.
star_dat = pd.read_csv(star_dat_filename,header=None,sep=r'\s+',dtype=str)
obj_idx = star_dat.index[star_dat.iloc[:,0] == obj_name]
if len(obj_idx) == 0:
    print("\nWARNING: No objects by name of '{}' found in '{}'\n".format(obj_name,star_dat_filename))
    sys.exit(1)
elif len(obj_idx) > 1:
    print("\nWARNING: Multiple entries for '{}' found in '{}'\n".format(obj_name,star_dat_filename))
    sys.exit(1)


#############################################################
# Show the first image along with marked targets/comps
# Get the pixel data for the first image
if psource == 'ucm':
    image0,_ = read_ucm(fits_list[0])
elif telcode == 'mcd2' and psource != 'hcm':
    # mcd2 (non-HiPERCAM) data are indexed with [0], i.e. the first plane
    # along the leading axis is used
    image0 = fits.getdata(fits_list[0])[0]
else:
    image0 = fits.getdata(fits_list[0])
imrows,imcols = np.shape(image0)

# Z-scale display limits for the first image
ZS = ZScaleInterval(nsamples=10000, contrast=0.15, max_reject=0.5,
                    min_npixels=5, krej=2.5, max_iterations=5)
vmin0,vmax0 = ZS.get_limits(image0)

# Figure 10: the first image with the target and comparison stars marked
figsize = (10,7)
imfig = plt.figure(10,figsize=figsize)
gsim = GridSpec(1,1)
im = imfig.add_subplot(gsim[0])
imfig.canvas.manager.set_window_title('First Image')

# Fonts: AppleGothic on macOS (better there), DejaVu Sans Mono elsewhere
_font_family = 'AppleGothic' if sys.platform == 'darwin' else 'DejaVu Sans Mono'
fonta = {'fontname': _font_family, 'weight': 'heavy'}
fontb = {'fontname': _font_family, 'weight': 'normal'}

im.imshow(image0, cmap='gray',vmin=vmin0, vmax=vmax0)

# Circle each object (cyan for the target at row 0, magenta for comps) and
# place its label 5.5% of the image height away from the marker
_label_dy = 0.055*imrows
for i in range(len(tloc)):
    _is_target = (i == 0)
    im.plot(tloc[i,0],tloc[i,1],ls='None',marker='o',ms=13,mew=2.0,
            mec=('c' if _is_target else 'm'),mfc='None')
    im.text(tloc[i,0],tloc[i,1]+_label_dy,
            ('Target' if _is_target else 'Comp {}'.format(i)),
            ha='center',fontsize=12,**fonta)

# Hide the axes, title with the image file name, and save a PNG copy
plt.axis("off")
titleim = fits_list[0].split("/")[-1]
im.set_title(titleim,loc='center',fontsize=14,**fontb)
im_name = f'{obj_name}_firstframe.png'
plt.savefig(im_name,format='png',dpi=100,bbox_inches='tight',pad_inches=0.15)

# Add some basic interactive capabilities
def on_event_im(event):
    """Handle key presses in the first-image and raw-photometry windows.

    Keys:
        'W'     -- set the global ``grid_search`` to False and close all figures.
        'G'     -- set the global ``grid_search`` to True and close all figures
                   (presumably enables a later grid search; see its readers).
        'Q'/'q' -- close all figures and exit WITHOUT saving the light curve.
        '?'     -- reprint the command list.
    Any other key is ignored.

    Parameters
    ----------
    event : matplotlib key-press event; only ``event.key`` is read.
    """
    global grid_search

    key = event.key
    if key == 'W':
        # Close the figures and continue with grid_search disabled
        grid_search = False
        plt.close("all")
    elif key == 'G':
        # Close the figures and continue with grid_search enabled
        grid_search = True
        plt.close("all")
    elif key in ('Q', 'q'):
        # Close the figures and quit without saving the light curve
        plt.close("all")
        print('\nProgram exited. Lightcurve was not saved.\n')
        sys.exit(1)
    elif key == '?':
        # Reprint the command list
        print_commands()

# Route key presses in the first-image window to on_event_im
cidim = imfig.canvas.mpl_connect(s='key_press_event', func=on_event_im)
#############################################################


#############################################################
# Generate the time stamps (in seconds) for the data points and the
# reference times dt_zero_start / dt_zero_mid of the first exposure.
# For hsp (multi-image) and hcm/ucm the stamps are measured from
# dt_zero_mid. Without a previous log they are derived from the
# photometry/FITS data; with a log they come from its first column.
if not log_exists:

    # Reference times (start and mid) of the first exposure
    if psource in ['hcm','ucm']:
        # The smallest pipeline MJD is treated as the first mid-exposure time
        texp = phot_data[0].Exptim.median()
        dt_zero_mid = Time(phot_data[0].MJD.min(),format='mjd',scale='utc')
        dt_zero_start = dt_zero_mid - TimeDelta(texp/2.0,format='sec')
    elif psource in ['hsp','mae']:
        dt_zero_start,texp = get_time(fits_list[0],teldict)
        dt_zero_mid = dt_zero_start + TimeDelta(texp/2.0,format='sec')

    # Relative time stamps in seconds
    if psource in ['hcm','ucm']:
        _mjd = phot_data[0].MJD
        time_stamps = (_mjd.values - _mjd.min()) * 86400.
    elif psource == 'mae':
        # First maestro column converted from days to seconds
        time_stamps = np.loadtxt(phot_names[0])[:,0]*86400.
    elif psource == 'hsp' and num_fits == 1:
        # Single image: frames assumed evenly spaced by the exposure time
        time_stamps = np.arange(Ndat,dtype=float)*texp
    elif psource == 'hsp':
        # One time stamp per FITS image, relative to the first mid-exposure
        time_stamps = np.zeros(num_fits)
        print('')
        action1 = 'Grabbing times from FITS...'
        for i in range(num_fits):
            count1 = i+1
            progress_bar(count1, num_fits, action1)
            dt_start,texp = get_time(fits_list[i],teldict)
            dt_mid = dt_start + TimeDelta(texp/2.0,format='sec')
            tdelta = dt_mid - dt_zero_mid
            tdelta_sec = tdelta.to_value(u.s)
            time_stamps[i] = tdelta_sec
    print('')

else:
    if psource in ['hsp','mae'] and os.path.exists(fits_list[0]):
        # First image still available: take the reference time from it
        dt_zero_start,texp = get_time(fits_list[0],teldict)
        dt_zero_mid = dt_zero_start + TimeDelta(texp/2.0,format='sec')
    else:
        # Otherwise use the mid-exposure time (header line 3) and the
        # exposure time (header line 5) stored in the log
        dt_zero_mid_str = log_header[2][0].split("=")[-1].strip()
        texp = float(log_header[4][0].split("=")[-1].strip())
        dt_zero_mid = Time(dt_zero_mid_str, scale='utc',format='isot')
        dt_zero_start = dt_zero_mid - TimeDelta(texp/2.0,format='sec')
    time_stamps = log_data[:,0]

print('')


#############################################################
# Make a WQED-like plot showing all the raw photometry
figa = plt.figure(0,figsize=figsize)
gsa = GridSpec(1,1)
ax_LC = figa.add_subplot(gsa[0])

# Set Window Name
figa.canvas.manager.set_window_title('Raw Photometry')

# Dark Background Color List. The style's color cycle is repeated (at
# least twice) so that dbk_clist[i] exists for every star even when Nobj
# exceeds the number of colors in the cycle.
dbk_clist = plt.rcParams['axes.prop_cycle'].by_key()['color']
dbk_clist = dbk_clist * max(2, -(-Nobj // max(len(dbk_clist), 1)))

# Plot sky counts, normalized to range from 0->1.
# Use the aperture with radius 4.0 if present, otherwise the first one.
try:
    ap4_loc = ap_sizes.index(4.0)
except ValueError:
    ap4_loc = 0
if psource == 'hsp':
    sky_flux = phot_data[ap4_loc][Nobj+1]
elif psource == 'mae':
    sky_flux = phot_data[ap4_loc][2]
elif psource in ['hcm','ucm']:
    # sky_1 scaled by pi*mfwhm**2 (presumably per-pixel sky summed over a
    # circle of radius mfwhm -- confirm units)
    sky_flux = phot_data[0].sky_1 * (np.pi*(phot_data[0].mfwhm**2))
sky_norm = sky_flux / sky_flux.mean()
ax_LC.plot(time_stamps,(sky_norm - min(sky_norm))/max(sky_norm - min(sky_norm)),
           ls='None',marker='.',ms=3,mfc='#3C87FA',mec='#3C87FA')

# Pandas column IDs for every comparison star (all stars after the target),
# plus the subset currently selected for the comparison sum. A previous
# phot2lc log stores the selection as "+"-separated numbers on header
# line 4, which are mapped back to column IDs here.
if psource == 'hsp':
    comp_ids = list(range(2,Nobj+1))
elif psource == 'mae':
    comp_ids = list(range(3,2*Nobj,2))
elif psource in ['hcm','ucm']:
    comp_ids = [f'counts_{x}' for x in range(2,Nobj+1)]
    count_ids = ['counts_1'] + comp_ids

if log_exists:
    _logged_comps = [int(x) for x in log_header[3][0].split("=")[-1].split("+")]
    if psource == 'hsp':
        comp_select = [n+1 for n in _logged_comps]
    elif psource == 'mae':
        comp_select = [2*n+1 for n in _logged_comps]
    elif psource in ['hcm','ucm']:
        comp_select = [f'counts_{n+1}' for n in _logged_comps]
else:
    comp_select = comp_ids.copy()

# Plot each star's flux, normalized to 0->1 and offset vertically so the
# target is on top and each following comp sits one unit lower
for i in range(Nobj):
    if psource == 'hsp':
        _star_flux = phot_data[ap4_loc][i+1]
    elif psource == 'mae':
        _star_flux = phot_data[ap4_loc][2*i+1]
    elif psource in ['hcm','ucm']:
        _star_flux = phot_data[0][f'counts_{i+1}']
    norm = _star_flux / _star_flux.mean()
    _shifted = norm - min(norm)
    norm2 = _shifted / max(_shifted)
    ax_LC.plot(time_stamps,norm2 + Nobj-i+0.5,ls='None',marker='.',ms=3,
               mfc=dbk_clist[i],mec=dbk_clist[i])

# Y range holds the sky curve at the bottom plus one band per star
ax_LC.set_ylim(0,Nobj+2)

# Ticks on both sides, pointing inward on y, no y tick labels
ax_LC.tick_params(which='both',axis='both',labelleft=False,left=True,right=True)
ax_LC.tick_params(which='both',axis='y',direction='in')

# Label every light curve just right of the axes: name on top, mean
# counts below it. Index Nobj is the sky curve.
_xlims = ax_LC.get_xlim()
xdiff = _xlims[1] - _xlims[0]
xloc = _xlims[1] + 0.01*xdiff
for i in range(Nobj+1):
    if i == 0 or i < Nobj:
        lc_labela = 'Target' if i == 0 else 'Comp {}'.format(i)
        if psource == 'hsp':
            lc_labelb = '{:.0f}'.format(phot_data[ap4_loc][i+1].mean())
        elif psource == 'mae':
            lc_labelb = '{:.0f}'.format(phot_data[ap4_loc][2*i+1].mean())
        elif psource in ['hcm','ucm']:
            lc_labelb = '{:.0f}'.format(phot_data[0][f'counts_{i+1}'].mean())
        lc_color = dbk_clist[i]
        _dy_a, _dy_b = 0.00, 0.30
    else:
        lc_labela = 'Sky'
        lc_labelb = '{:.0f}'.format(sky_flux.mean())
        lc_color = '#3C87FA'
        _dy_a, _dy_b = 0.50, 0.80
    ax_LC.text(xloc,float(Nobj+1-i)-_dy_a,lc_labela,
               fontsize=13,color=lc_color,**fonta)
    ax_LC.text(xloc,float(Nobj+1-i)-_dy_b,lc_labelb,
               fontsize=10,color=lc_color,**fonta)


# Axis labels for the raw photometry plot
ax_LC.set_xlabel('Time (s)',fontsize=12,**fontb)
ax_LC.set_ylabel('Photometric Counts',fontsize=12,**fontb)

# Title with the UT date of the first exposure's start
_start_iso = dt_zero_start.to_value('iso')
print_date = _start_iso[0:10]   # date part of the ISO string
print_time = _start_iso[11:19]  # hh:mm:ss part of the ISO string
titlea = f'Raw Photometry for {obj_name}   UT-Date: {print_date}'
ax_LC.set_title(titlea,loc='center',fontsize=14,**fontb)

# Same key bindings as the first-image window
cida = figa.canvas.mpl_connect('key_press_event', on_event_im)
###################################################


###################################################
# Now create the interactive figure, where outlier and bad-weather points
# can be removed, a polynomial fit can be chosen, and comp stars chosen.
# Layout on a 3x3 grid: polynomial panel and info text box on top, summed
# comparison stars in the middle, divided light curve at the bottom.

figb = plt.figure(1,figsize=figsize)
gsb = GridSpec(3,3,height_ratios=[1,1.2,3],hspace=0.0,wspace=0.0)
bx_LC = figb.add_subplot(gsb[2,:])     # divided light curve (bottom row)
bx_mod = figb.add_subplot(gsb[0,0])    # polynomial fit (top-left)
bx_comp = figb.add_subplot(gsb[1,:])   # summed comparison stars (middle row)
bx_text = figb.add_subplot(gsb[0,1:])  # info text box (top-right)

# Text box: no ticks or tick labels, charcoal background, white frame
bx_text.tick_params(which='both',labelbottom=False, labelleft=False,
                    left=False, bottom=False)
bx_text.set_facecolor('xkcd:charcoal')
for _side in ('bottom', 'top', 'right', 'left'):
    bx_text.spines[_side].set_color('w')

# Set Window Name
figb.canvas.manager.set_window_title('Divided Light Curve')

# Tick parameters for the data panels
bx_LC.tick_params(which='both',axis='both',right=True,top=True,direction='in')
bx_mod.tick_params(which='both',axis='both',labelbottom=False,
                   labelleft=False,left=False,bottom=False)
bx_comp.tick_params(which='both',axis='both',labelbottom=False,direction='in',
                   labelleft=False,left=False,bottom=True)

# Initial target flux and summed comparison flux from the first (smallest)
# aperture. Zero counts are replaced by 1.0, presumably to avoid division
# by zero downstream.
apnum = 0   # First Aperture
_phot_init = phot_data[apnum] if psource in ['hsp','mae'] else phot_data[0]
_target_col = 1 if psource in ['hsp','mae'] else 'counts_1'
phot_target = _phot_init[_target_col].replace(0.0,1.0).values
phot_comps = np.sum(_phot_init[comp_select].replace(0.0,1.0).values,axis=1)

# Initial detrending state.
# With a previous log: restore the polynomial settings from header line 2
# and the kept (1) / deleted (0) flags from the log's second data column.
# Without one: no polynomial fit, nothing deleted, every point kept.
if log_exists:
    _poly_field = log_header[1][0].split("=")[-1].strip()
    if len(_poly_field) < 5:
        # Short entry: parsed as just the polynomial order
        polypack = (int(_poly_field),0,3.0,3.0)
    else:
        _poly_parts = _poly_field.split(",")
        polypack = (int(_poly_parts[0]), int(_poly_parts[1]),
                    float(_poly_parts[2]), float(_poly_parts[3]))
    _flags = log_data[:,1]
    dele_ind = [k for k,flag in enumerate(_flags) if flag == 0]
    keep_ind = [k for k,flag in enumerate(_flags) if flag == 1]
else:
    polypack = (0,0,3.0,3.0)      # No polynomial fit to start
    dele_ind = []                 # No points deleted
    keep_ind = list(range(Ndat))  # Keeping all points

# Divided light curve using every point (its finite range sets the initial
# y-limits), and the version that honors any deleted points
dlc_init_all,_,_,_ = div_lc(time_stamps,phot_target,phot_comps,
                            polypack,list(range(Ndat)),[])
dlc_init,dellc_init,draw_init,mod_init = div_lc(time_stamps,phot_target,phot_comps,
                                                polypack,keep_ind,dele_ind)

# p0: kept points of the divided light curve;  p1: deleted points (red)
# p2: polynomial model;  p3: raw divided light curve used for the fit
# p4 / p5: summed comparison-star flux for kept / deleted points
_comp_marker = dict(ls='None',marker='.',ms=4,mew=0)
p0, = bx_LC.plot(time_stamps[keep_ind],dlc_init,ls='-',c='w',lw=0.5,
                 marker='o',ms=3,mew=0.75,mec='w',mfc='w')
p1, = bx_LC.plot(time_stamps[dele_ind],dellc_init,ls='None',
                 marker='o',ms=3,mew=0.75,mec='r',mfc='r')
p2, = bx_mod.plot(time_stamps[keep_ind],mod_init,ls='-',c='r',lw=1.5,zorder=2)
p3, = bx_mod.plot(time_stamps[keep_ind],draw_init,
                  ls='None',marker='.',ms=2,mew=0,mfc='w',zorder=1)
p4, = bx_comp.plot(time_stamps[keep_ind],phot_comps[keep_ind],mfc='C0',**_comp_marker)
p5, = bx_comp.plot(time_stamps[dele_ind],phot_comps[dele_ind],mfc='r',**_comp_marker)

# Shared dashed cursor across the divided light curve and comp panels
multicursor = MultiCursor(figb.canvas, (bx_LC, bx_comp),
                          useblit=True, horizOn=True,
                          c='r', lw=0.5, ls='--')

# Initial axis limits: 5% x padding and 30% y padding around the data.
# Infinite values are excluded from the divided light curve's range, and
# NaNs are ignored via nanmin/nanmax.
_finite_dlc_all = dlc_init_all[~np.isinf(dlc_init_all)]
xlow_LC, xupp_LC = min(time_stamps), max(time_stamps)
ylow_LC, yupp_LC = np.nanmin(_finite_dlc_all), np.nanmax(_finite_dlc_all)
ylow_mod, yupp_mod = np.nanmin(draw_init), np.nanmax(draw_init)
ylow_comp, yupp_comp = np.nanmin(phot_comps), np.nanmax(phot_comps)

xdiff_LC = xupp_LC - xlow_LC
ydiff_LC = yupp_LC - ylow_LC
ydiff_mod = yupp_mod - ylow_mod
ydiff_comp = yupp_comp - ylow_comp

bx_LC.set_xlim(xlow_LC-0.05*xdiff_LC, xupp_LC+0.05*xdiff_LC)
bx_LC.set_ylim(ylow_LC-0.30*ydiff_LC, yupp_LC+0.30*ydiff_LC)
bx_mod.set_ylim(ylow_mod-0.30*ydiff_mod, yupp_mod+0.30*ydiff_mod)
bx_comp.set_xlim(bx_LC.get_xlim())   # comp panel shares the light curve's x range
bx_comp.set_ylim(ylow_comp-0.30*ydiff_comp, yupp_comp+0.30*ydiff_comp)


# Axis labels for the divided light curve
bx_LC.set_xlabel('Time (s)',fontsize=12,**fontb)
bx_LC.set_ylabel('Normalized Flux',fontsize=12,**fontb)

# Point-to-point scatter of the initial divided light curve
p2p_scat = pp_scat(dlc_init)

# Object name and UT date/time of the first exposure in the text box
top_text = f"Object: {obj_name}"
bottom_text = f"UT-Date: {print_date}   UT-Time: {print_time}"
for _ypos, _txt in ((0.70, top_text), (0.30, bottom_text)):
    bx_text.text(0.5,_ypos,_txt,ha='center',va='center',
                 fontsize=14.5,transform=bx_text.transAxes,**fontb)

# Polynomial order shown on the polynomial panel
title_poly = f'Poly Order = {polypack[0]}'
bx_mod.set_title(title_poly,loc='left',fontsize=11,x=0.02,y=0.74,**fontb)

# Comparison stars in use, shown on the comp panel
comp_str = gen_compstr(comp_select,psource)
title_comp = f'Comps {comp_str}'
bx_comp.set_title(title_comp,loc='left',fontsize=11,x=0.01,y=0.75,**fontb)

# Aperture radius and P2P scatter (in percent) on the light-curve panel
apscat_str = ('Aperture Radius: {:.2f} pix {:5}'.format(ap_sizes[apnum], '')
              + 'P2P Scatter: {:.2f} %'.format(p2p_scat*1e2))
tx0 = bx_LC.text(0.015,0.92,apscat_str,
                 transform=bx_LC.transAxes,fontsize=11,**fontb)

# Interactive-editing state
num_del = len(dele_ind)                 # number of deleted points
keep_x = np.copy(time_stamps[keep_ind]) # kept x-values
keep_y = np.copy(dlc_init)              # kept y-values
dele_x = np.copy(time_stamps[dele_ind]) # deleted x-values
dele_y = np.copy(dellc_init)            # deleted y-values
show_delps = True             # whether deleted points are drawn
zoomed = False                # whether the plot is zoomed in

###################################################
# Functions for plotting

# Function to update the plots
def update_plot(kx,ky,dx,dy,raw,mod,kcomp,dcomp):
    global bx_LC,bx_mod,bx_comp,p0,p1,p2,p3,p4
    global comp_select,show_delps,apnum,zoomed

    # Update data
    p0.set_data(kx,ky)   # Update Kept points in final DLC
    p2.set_data(kx,mod)  # Update Polynomial model
    p3.set_data(kx,raw)  # Update raw DLC used for polyfit
    p4.set_data(kx,kcomp) # Update the the comparison star data

    # Update XY limits and display of deleted points for divided light curve
    if show_delps:
        p1.set_data(dx,dy)
        p5.set_data(dx,dcomp)
        if len(dx) > 0:
            minx,maxx = min(min(kx),min(dx)),max(max(kx),max(dx))
            miny = min(np.nanmin(ky[~np.isinf(ky)]),
                       np.nanmin(dy[~np.isinf(dy)]))
            maxy = max(np.nanmax(ky[~np.isinf(ky)]),
                       np.nanmax(dy[~np.isinf(dy)]))
            miny_comp = min(np.nanmin(kcomp[~np.isinf(kcomp)]),
                            np.nanmin(dcomp[~np.isinf(dcomp)]))
            maxy_comp = max(np.nanmax(kcomp[~np.isinf(kcomp)]),
                            np.nanmax(dcomp[~np.isinf(dcomp)]))
        else:
            minx,maxx = min(kx),max(kx)
            miny = np.nanmin(ky[~np.isinf(ky)])
            maxy = np.nanmax(ky[~np.isinf(ky)])
            miny_comp = np.nanmin(kcomp[~np.isinf(kcomp)])
            maxy_comp = np.nanmax(kcomp[~np.isinf(kcomp)])
    else:
        p1.set_data([],[])
        p5.set_data([],[])
        minx,maxx = min(kx),max(kx)
        miny = np.nanmin(ky[~np.isinf(ky)])
        maxy = np.nanmax(ky[~np.isinf(ky)])
        miny_comp = np.nanmin(kcomp[~np.isinf(kcomp)])
        maxy_comp = np.nanmax(kcomp[~np.isinf(kcomp)])

    xdiff = maxx-minx
    ydiff = maxy-miny
    ydiff_comp = maxy_comp-miny_comp

    if not zoomed:  # Adjust xy limits if not already zoomed in
        bx_LC.set_xlim(minx-0.05*xdiff,maxx+0.05*xdiff)
        bx_LC.set_ylim(miny-0.30*ydiff,maxy+0.30*ydiff)
        bx_comp.set_xlim(bx_LC.get_xlim())
    bx_comp.set_ylim(miny_comp-0.3*ydiff_comp,maxy_comp+0.3*ydiff_comp)

    # Update XY Limits of bx_mod axis
    miny_mod = min(raw)
    maxy_mod = max(raw)
    xdiff_mod = np.nanmax(kx) - np.nanmin(kx)
    ydiff_mod = maxy_mod - miny_mod
    bx_mod.set_xlim(np.nanmin(kx)-0.05*xdiff_mod,np.nanmax(kx)+0.05*xdiff_mod)
    bx_mod.set_ylim(miny_mod-0.3*ydiff_mod,maxy_mod+0.3*ydiff_mod)

    # Update the polynomial order title
    title_poly = 'Poly Order = {}'.format(polypack[0])
    bx_mod.set_title(title_poly,loc='left',fontsize=11,x=0.02,y=0.74,**fontb)

    # Update the comparison star title
    comp_str = gen_compstr(comp_select,psource)
    title_comp = 'Comps {}'.format(comp_str)
    bx_comp.set_title(title_comp,loc='left',fontsize=11,x=0.01,y=0.75,**fontb)

    # Add text showing the aperture selection and P2P scatter
    p2p_scat = pp_scat(ky)
    apscat_str =  'Aperture Radius: {:.2f} pix {:5}'.format(ap_sizes[apnum],'') + \
                  'P2P Scatter: {:.2f} %'.format(p2p_scat*1e2)
    tx0.set_text(apscat_str)

    plt.draw()
    return

# Function to find index of nearest data point to cursor
def get_index(ax_obj,p,x,y,xe,ye):
    """Return the index of the data point nearest to a cursor position.

    The data coordinates (x, y) are converted to display (pixel)
    coordinates with ``ax_obj.transData`` so that the distance is measured
    on screen, in the same frame as a matplotlib event's ``event.x`` /
    ``event.y``.

    Parameters
    ----------
    ax_obj : matplotlib Axes on which the points are drawn
    p : unused; kept so existing callers keep working
    x, y : equal-length, non-empty sequences of data coordinates
    xe, ye : cursor position in display coordinates

    Returns
    -------
    int : index into x / y of the closest point. Points with non-finite
          coordinates (which are not drawn) are only chosen if no
          finite point exists.
    """
    pts = np.column_stack((np.asarray(x, dtype=float), np.asarray(y, dtype=float)))
    trans = np.asarray(ax_obj.transData.transform(pts))
    # Euclidean distance on screen; NaN distances must not win argmin
    dist = np.hypot(trans[:,0]-xe, trans[:,1]-ye)
    dist = np.where(np.isfinite(dist), dist, np.inf)
    return int(dist.argmin())

# Function to update plot title for messages to User
def message(ax_obj,s):
    """Show the string ``s`` as the title of ``ax_obj`` and redraw."""
    new_title = s
    ax_obj.set_title(new_title)
    plt.draw()
    return None

# Function to activate and deactivate the Rectangle garbage selector
def toggle_garbage(show=True):
    """Flip the garbage RectangleSelector between active and inactive.

    The selector object is stored on the function attribute
    ``toggle_garbage.RS``. When ``show`` is True the new state is
    reported on stdout before the switch is made.
    """
    selector = toggle_garbage.RS
    was_active = selector.active
    if show:
        print(' Garbage Selector deactivated.' if was_active
              else ' Garbage Selector activated.')
    selector.set_active(not was_active)

# Function to activate and deactivate the Rectangle reverse-garbage selector
def toggle_reverse_garbage(show=True):
    """Flip the reverse-garbage RectangleSelector on or off.

    The selector object is stored on the function attribute
    ``toggle_reverse_garbage.RS``. When ``show`` is True the new state
    is reported on stdout before the switch is made.
    """
    selector = toggle_reverse_garbage.RS
    was_active = selector.active
    if show:
        print(' Reverse Garbage Selector deactivated.' if was_active
              else ' Reverse Garbage Selector activated.')
    selector.set_active(not was_active)

# Function to activate and deactivate the Rectangle zoom selector
def toggle_zoom(show=True):
    """Flip the zoom RectangleSelector between active and inactive.

    The selector object is stored on the function attribute
    ``toggle_zoom.RS``. When ``show`` is True the new state is reported
    on stdout before the switch is made.
    """
    selector = toggle_zoom.RS
    was_active = selector.active
    if show:
        print(' Zoom Selector deactivated.' if was_active
              else ' Zoom Selector activated.')
    selector.set_active(not was_active)

# Function to update the divided light curves
def dlc_update(kx,ky,dx,dy,polyinfo,aper):
    """Recompute the divided light curve and refresh the plot.

    Parameters
    ----------
    kx, dx : time stamps of the kept / deleted points
    ky, dy : accepted for interface compatibility; the divided fluxes are
             recomputed from the photometry, so these values are not used
    polyinfo : polynomial pack (degree, rejection iterations,
               lower sigma, upper sigma) passed on to div_lc
    aper : index into phot_data of the aperture to use

    Returns
    -------
    klc, dlc, keep_ind, dele_ind
        Divided fluxes of the kept and deleted points, and the indices of
        those points in ``time_stamps``. The index lists are also stored
        in the module-level ``keep_ind`` / ``dele_ind``.

    Raises
    ------
    ValueError
        If ``psource`` is not one of the supported photometry sources.
    """
    global bx_LC,p0,p1,p2,p3,p4
    global dele_ind,keep_ind,comp_select,show_delps

    # Get indices of kept/deleted points in the full time series
    dele_ind = [np.where(time_stamps == d)[0][0] for d in dx]
    keep_ind = [np.where(time_stamps == k)[0][0] for k in kx]

    # Target and summed comparison-star counts for the selected aperture.
    # Zero counts are replaced by 1.0 (presumably to avoid dividing by
    # zero). All sources index phot_data by aperture, matching the final
    # per-aperture light curves computed after the interactive session.
    if psource in ['hsp','mae']:
        phot_target = phot_data[aper][1].replace(0.0,1.0).values
    elif psource in ['hcm','ucm']:
        phot_target = phot_data[aper]['counts_1'].replace(0.0,1.0).values
    else:
        raise ValueError('Unsupported photometry source: {}'.format(psource))
    phot_comps = np.sum(phot_data[aper][comp_select].replace(0.0,1.0).values,axis=1)

    klc,dlc,kraw,kmod = div_lc(time_stamps,phot_target,phot_comps,
                               polyinfo,keep_ind,dele_ind)

    # Update the figure
    update_plot(kx,klc,dx,dlc,kraw,kmod,phot_comps[keep_ind],phot_comps[dele_ind])

    return klc, dlc, keep_ind, dele_ind


# Function for Garbage Selector Rectangle to Call
def garbage_select(eclick, erelease):
    """Delete every kept point inside the box drawn with the Garbage Selector.

    Callback for the garbage RectangleSelector. ``eclick`` and ``erelease``
    are the press / release mouse events, whose ``xdata``/``ydata`` give
    opposite corners of the box in data coordinates. At least two points
    are always kept: if the box would remove more, nothing is deleted and
    a warning is printed. After a deletion the selector is deactivated.
    """
    global apnum,bx_LC,p0,p1,p2,p3,p4,num_del,polypack
    global keep_x,keep_y,keep_ind,zoomed
    global dele_x,dele_y,dele_ind,show_delps

    # Box limits in data coordinates, independent of drag direction
    xlow,xupp = sorted([eclick.xdata,erelease.xdata])
    ylow,yupp = sorted([eclick.ydata,erelease.ydata])

    ind_box = np.where((keep_x > xlow) & (keep_x < xupp) &
                       (keep_y > ylow) & (keep_y < yupp))[0]

    if len(ind_box) == 0:
        return  # No data inside the box

    # If the box would delete too many points, don't allow it
    if len(keep_x) - len(ind_box) <= 1:
        print('WARNING: Cannot delete any more points!')
        return

    # Move the boxed points from the kept arrays to the deleted arrays
    dele_x = np.append(dele_x,keep_x[ind_box])
    dele_y = np.append(dele_y,keep_y[ind_box])
    keep_x = np.delete(keep_x,ind_box)
    keep_y = np.delete(keep_y,ind_box)
    num_del += len(ind_box)

    # Re-sort the kept points by time. The permutation is computed once
    # and applied to both arrays; recomputing it from the already-sorted
    # keep_x would leave keep_y in its old order.
    order = keep_x.argsort()
    keep_x = keep_x[order]
    keep_y = keep_y[order]

    # Update divided light curve and plot
    keep_y,dele_y,keep_ind,dele_ind = dlc_update(keep_x,keep_y,
                                                 dele_x,dele_y,
                                                 polypack,apnum)

    # Deactivate the garbage selector
    toggle_garbage()

    return


# Function for Reverse Garbage Selector Rectangle to Call
def reverse_garbage_select(eclick, erelease):
    """Delete every kept point outside the box drawn with the Reverse Garbage Selector.

    Callback for the reverse-garbage RectangleSelector. ``eclick`` and
    ``erelease`` are the press / release mouse events, whose
    ``xdata``/``ydata`` give opposite corners of the box in data
    coordinates. At least two points are always kept: if the deletion
    would remove more, nothing is deleted and a warning is printed.
    After a deletion the selector is deactivated.
    """
    global apnum,bx_LC,p0,p1,p2,p3,p4,num_del,polypack
    global keep_x,keep_y,keep_ind,zoomed
    global dele_x,dele_y,dele_ind,show_delps

    # Box limits in data coordinates, independent of drag direction
    xlow,xupp = sorted([eclick.xdata,erelease.xdata])
    ylow,yupp = sorted([eclick.ydata,erelease.ydata])

    ind_box = np.where((keep_x < xlow) | (keep_x > xupp) |
                       (keep_y < ylow) | (keep_y > yupp))[0]

    if len(ind_box) == 0:
        return  # No data outside the box

    # If the box would delete too many points, don't allow it
    if len(keep_x) - len(ind_box) <= 1:
        print('WARNING: Cannot delete any more points!')
        return

    # Move the points outside the box from the kept to the deleted arrays
    dele_x = np.append(dele_x,keep_x[ind_box])
    dele_y = np.append(dele_y,keep_y[ind_box])
    keep_x = np.delete(keep_x,ind_box)
    keep_y = np.delete(keep_y,ind_box)
    num_del += len(ind_box)

    # Re-sort the kept points by time. The permutation is computed once
    # and applied to both arrays; recomputing it from the already-sorted
    # keep_x would leave keep_y in its old order.
    order = keep_x.argsort()
    keep_x = keep_x[order]
    keep_y = keep_y[order]

    # Update divided light curve and plot
    keep_y,dele_y,keep_ind,dele_ind = dlc_update(keep_x,keep_y,
                                                 dele_x,dele_y,
                                                 polypack,apnum)

    # Deactivate the reverse garbage selector
    toggle_reverse_garbage()

    return

# Function for Zoom Selector Rectangle to Call
def zoom_select(eclick, erelease):
    """Zoom the light-curve axes to the box drawn with the Zoom Selector.

    The comparison-star axes follow the new x-range. The global ``zoomed``
    flag is set so update_plot() leaves these limits in place, and the
    zoom selector is switched off afterwards.
    """
    global bx_LC,zoomed

    # Box corners in data coordinates, ordered low -> high
    x_lo, x_hi = sorted([eclick.xdata, erelease.xdata])
    y_lo, y_hi = sorted([eclick.ydata, erelease.ydata])

    # Apply the new view
    bx_LC.set_xlim(x_lo, x_hi)
    bx_LC.set_ylim(y_lo, y_hi)
    bx_comp.set_xlim(x_lo, x_hi)
    zoomed = True
    plt.draw()

    # Switch the zoom selector back off
    toggle_zoom()
    return


# Function for sigma clipping data
def sigclip_data(kx,ky,dx,dy,siglow,sigupp,dw):
    """Sigma-clip the kept light-curve points in consecutive windows.

    The kept points are split into windows of ``dw`` points. For each
    window, ``window_std`` supplies per-point model values (``polymods``)
    and a window standard deviation. A point is kept only if it lies
    strictly between ``model - siglow*std`` and ``model + sigupp*std``.
    Rejected points are moved to the deleted arrays (sorted by time) and
    the light curve and plot are refreshed.

    Nothing is changed, and a warning is printed, if ``dw`` is not
    positive or if clipping would leave fewer than two kept points (the
    same limit the garbage selectors use).

    Parameters
    ----------
    kx, ky : numpy arrays of the kept time stamps / fluxes
    dx, dy : numpy arrays of the currently deleted time stamps / fluxes
    siglow, sigupp : lower / upper rejection thresholds in units of std
    dw : window size in points
    """
    global apnum,bx_LC,p0,p1,p2,p3,p4,num_del,polypack
    global keep_x,keep_y,keep_ind,zoomed
    global dele_x,dele_y,dele_ind,show_delps

    if dw < 1:
        print('WARNING: Window size must be a positive integer.')
        return

    # First get stats within each window
    windows = np.arange(0,len(kx),dw,dtype=int)
    polymods,win_std = window_std(kx,ky,windows,dw)

    # Iterate through windows and sigma-clip data
    good_idx = []
    bad_idx = []
    for i,w in enumerate(windows):
        for j,y in enumerate(ky[w:w+dw]):
            idx = w + j  # position of this point in kx / ky
            lowlim = polymods[i][j] - siglow*win_std[i]
            upplim = polymods[i][j] + sigupp*win_std[i]
            if lowlim < y < upplim:
                good_idx.append(idx)
            else:
                bad_idx.append(idx)

    # Never clip the light curve down to fewer than two points
    if len(good_idx) <= 1:
        print('WARNING: Cannot delete any more points!')
        return

    # Update the kept and deleted arrays
    keep_x = kx[good_idx]
    keep_y = ky[good_idx]
    dele_x = np.append(dx,kx[bad_idx])
    dele_y = np.append(dy,ky[bad_idx])
    num_del += len(bad_idx)

    # Re-sort the deleted points by time, applying one permutation to
    # both arrays so that dele_y stays aligned with dele_x
    order = dele_x.argsort()
    dele_x = dele_x[order]
    dele_y = dele_y[order]

    # Update divided light curve and plot
    keep_y,dele_y,keep_ind,dele_ind = dlc_update(keep_x,keep_y,
                                                 dele_x,dele_y,
                                                 polypack,apnum)

    return


#####################################################

# Now for the main function which defines all of 
# the interactive options

def _deactivate_selectors():
    """Switch off any active garbage, reverse-garbage or zoom selector."""
    if toggle_garbage.RS.active:
        toggle_garbage()
    if toggle_reverse_garbage.RS.active:
        toggle_reverse_garbage()
    if toggle_zoom.RS.active:
        toggle_zoom()


def _refresh_dlc():
    """Recompute the divided light curve from the current global state.

    Calls dlc_update() with the module-level kept/deleted arrays, the
    current polynomial pack and aperture, and stores the results back into
    the module-level keep_y, dele_y, keep_ind and dele_ind.
    """
    global keep_y,dele_y,keep_ind,dele_ind
    keep_y,dele_y,keep_ind,dele_ind = dlc_update(keep_x,keep_y,
                                                 dele_x,dele_y,
                                                 polypack,apnum)


def _prompt_float(prompt):
    """Prompt on the command line until a valid float is entered."""
    while True:
        try:
            return float(input(prompt))
        except ValueError as e:
            print(e)


def on_event_b(event):
    """Handle key presses on the interactive light-curve figure.

    Keys:
        a / A : restore the deleted point nearest the cursor / all of them
        d     : delete the kept point nearest the cursor
        s     : toggle the display of deleted points
        f / F : fit a polynomial without / with sigma-rejection iterations
        c     : choose which comparison stars to use
        v / w : move to the previous / next aperture size
        g / r : toggle the garbage / reverse-garbage box selector
        z / Z : toggle the zoom box selector / restore the full view
        x     : sigma-clip the kept points in windows
        W / G : close the plots and continue without / with a full
                comp-star + aperture grid search
        q / Q : close the plots and exit without saving
        ?     : reprint the command list

    Invalid command-line input for 'c', 'f' or 'F' prints a warning and
    leaves the current settings unchanged.
    """
    global apnum,bx_LC,p0,p1,p2,p3,p4,num_del,polypack
    global keep_x,keep_y,keep_ind,comp_select,grid_search
    global dele_x,dele_y,dele_ind,show_delps,zoomed

    key = event.key

    # Add a deleted point back to the plot
    if key == 'a':
        if num_del > 0:
            idx = get_index(bx_LC, p1, dele_x, dele_y, event.x, event.y)
            keep_x = np.append(keep_x,dele_x[idx])
            keep_y = np.append(keep_y,dele_y[idx])
            dele_x = np.delete(dele_x,idx)
            dele_y = np.delete(dele_y,idx)
            num_del -= 1

            # Re-sort the kept points by time (one permutation for both)
            order = keep_x.argsort()
            keep_x = keep_x[order]
            keep_y = keep_y[order]

            _refresh_dlc()

    # Add ALL deleted points back to the plot
    elif key == 'A':
        if num_del > 0:
            print('Added all deleted points back to Light Curve.')
            keep_x = np.append(keep_x,dele_x[:])
            keep_y = np.append(keep_y,dele_y[:])
            dele_x = np.asarray([])
            dele_y = np.asarray([])
            num_del = 0

            # Re-sort the kept points by time (one permutation for both)
            order = keep_x.argsort()
            keep_x = keep_x[order]
            keep_y = keep_y[order]

            _refresh_dlc()

    # Delete a single point from the light curve
    elif key == 'd':
        if len(keep_x) <= 2:
            print('WARNING! Cannot delete any more points!')
        else:
            idx = get_index(bx_LC, p0, keep_x, keep_y, event.x, event.y)
            dele_x = np.append(dele_x,keep_x[idx])
            dele_y = np.append(dele_y,keep_y[idx])
            keep_x = np.delete(keep_x,idx)
            keep_y = np.delete(keep_y,idx)
            num_del += 1
            _refresh_dlc()

    # Toggle the display of deleted points
    elif key == 's':
        show_delps = not show_delps
        _refresh_dlc()

    # Fit a Polynomial to the light curve WITHOUT
    # performing any sigma-rejection iterations (excludes deleted points)
    elif key == 'f':
        _deactivate_selectors()

        print("\nFitting Polynomial WITHOUT Sigma Rejections")
        deg_str = input('Polynomial Degree: ')
        try:
            deg = int(deg_str)
        except ValueError:
            print('WARNING: Invalid input, polynomial fit not updated.')
            return

        # polypack = (polynomial degree,
        #             rejection iterations,
        #             lower sigma,
        #             upper sigma)
        polypack = (deg,0,3.,3.)
        _refresh_dlc()

    # Fit a Polynomial to the light curve WITH
    # sigma-rejection iterations (excludes deleted points)
    elif key == 'F':
        _deactivate_selectors()

        print("\nFitting Polynomial WITH Sigma Rejections")
        deg_str = input('Polynomial Degree    : ')
        nrej_str = input('Rejection Iterations : ')
        siglow_str = input('Lower Sigma Threshold: ')
        sigupp_str = input('Upper Sigma Threshold: ')
        try:
            deg = int(deg_str)
            nrej = int(nrej_str)
            lower_sig = float(siglow_str)
            upper_sig = float(sigupp_str)
        except ValueError:
            print('WARNING: Invalid input, polynomial fit not updated.')
            return

        # polypack with the requested rejection settings
        polypack = (deg,nrej,lower_sig,upper_sig)
        _refresh_dlc()

    # Choose which comparison stars to use
    elif key == 'c':
        _deactivate_selectors()

        new_comps = input('Enter Comp Numbers (comma separated): ')
        try:
            # User IDs are 1-based; convert to 0-based indices
            new_comp_ind = [int(c)-1 for c in new_comps.split(",")]
        except ValueError:
            print('WARNING: Comp IDs must be comma-separated integers.')
            return

        # Validate the IDs. A 0-based index of -1 (user ID 0) would
        # otherwise silently select the last comparison star.
        if any(i > Nobj-1 or i >= len(comp_ids) for i in new_comp_ind):
            print('WARNING: Comps ID(s) exceeds number of comps.')
        elif any(i < 0 for i in new_comp_ind):
            print('WARNING: Comp IDs must be 1 or greater.')
        elif len(new_comp_ind) > len(comp_ids):
            print('WARNING: More Comps ID(s) Supplied than number of comps.')
        elif len(new_comp_ind) > len(set(new_comp_ind)):
            print('WARNING: Duplicate IDs supplied.')
        else:
            comp_select = [comp_ids[i] for i in new_comp_ind]
            _refresh_dlc()

    # Move to previous aperture size
    elif key == 'v':
        if apnum > 0: # If greater than the minimum aperture size
            apnum -= 1
        _refresh_dlc()

    # Move to next aperture size
    elif key == 'w':
        if apnum < Nf-1: # If less than the maximum aperture size
            apnum += 1
        _refresh_dlc()

    # Draw box to delete points inside
    elif key == 'g':
        if (toggle_zoom.RS.active) | (toggle_reverse_garbage.RS.active):
            print(" Cannot activate Garbage Selector while other selectors are Active.")
        else:
            toggle_garbage()

    # Draw box to delete points outside
    elif key == 'r':
        if (toggle_zoom.RS.active) | (toggle_garbage.RS.active):
            print(" Cannot activate Reverse Garbage Selector while other selectors are Active.")
        else:
            toggle_reverse_garbage()

    # Draw box to zoom in
    elif key == 'z':
        if (toggle_garbage.RS.active) | (toggle_reverse_garbage.RS.active):
            print(" Cannot activate Zoom Selector while other selectors are Active.")
        else:
            toggle_zoom()

    # Restore zoom to original
    elif key == 'Z':
        # Update XY limits and display of deleted points
        if show_delps:
            if len(dele_x) > 0:
                minx,maxx = (min(min(keep_x),min(dele_x)),
                             max(max(keep_x),max(dele_x)))
                miny,maxy = (min(min(keep_y),min(dele_y)),
                             max(max(keep_y),max(dele_y)))
            else:
                minx,maxx = min(keep_x),max(keep_x)
                miny,maxy = min(keep_y),max(keep_y)
        else:
            minx,maxx = min(keep_x),max(keep_x)
            miny,maxy = min(keep_y),max(keep_y)
        xdiff = maxx-minx
        ydiff = maxy-miny
        bx_LC.set_xlim(minx-0.05*xdiff,maxx+0.05*xdiff)
        bx_LC.set_ylim(miny-0.30*ydiff,maxy+0.30*ydiff)
        bx_comp.set_xlim(bx_LC.get_xlim())
        zoomed=False
        plt.draw()

    # Perform sigma-clipping
    elif key == 'x':
        print("\nPerforming Sigma-Clipping:")

        # Re-prompt until valid numbers are given (only ValueError is
        # retried, so e.g. EOF on stdin does not loop forever)
        sigma_low = _prompt_float('Lower Sigma: ')
        sigma_upp = _prompt_float('Upper Sigma: ')

        # Get Window Size (Optional since 25 is the default)
        try:
            window_size = int(input('Window Size [25]: '))
        except ValueError:
            window_size = 25
        if window_size < 1:
            # A non-positive window cannot partition the data
            window_size = 25

        # Call the sigma-clip function
        sigclip_data(keep_x,keep_y,dele_x,dele_y,
                     sigma_low,sigma_upp,dw=window_size)

    # Option to close the plots and continue to aperture selection
    # without performing a full comp star + ap. size grid search
    elif key == 'W':
        grid_search = False
        plt.close("all")

    # Option to close the plots and continue to aperture selection
    # while also performing a full grid search of possible comp
    # star + aperture size combinations for best combo.
    elif key == 'G':
        grid_search = True
        plt.close('all')

    # Option to close the plots and exit program
    elif key in ('Q','q'):
        plt.close("all")
        print("\nProgram Exited. Lightcurve was not saved. \n")
        sys.exit(1)

    # Option to reprint the command list
    elif key == '?':
        print_commands()

# Setup the rectangle selectors, start them off as deactivated.
# Each selector is stored as the RS attribute of its toggle_* function,
# which is how the toggle helpers and the key handler reach it.
rect_garbage = dict(facecolor='#FFFF00',edgecolor = '#FFFF00',
                    alpha=0.4, fill=True)
rect_zoom = dict(facecolor='#6DA1F0',edgecolor = '#6DA1F0',
                 alpha=0.4, fill=True)
toggle_garbage.RS = RectangleSelector(bx_LC, garbage_select, 
                                       props=rect_garbage,
                                       useblit=True,
                                       button=[1],  # Left click only
                                       minspanx=5, minspany=5,
                                       spancoords='pixels',
                                       interactive=False)

toggle_reverse_garbage.RS = RectangleSelector(bx_LC, reverse_garbage_select, 
                                       props=rect_garbage,
                                       useblit=True,
                                       button=[1],  # Left click only
                                       minspanx=5, minspany=5,
                                       spancoords='pixels',
                                       interactive=False)

toggle_zoom.RS = RectangleSelector(bx_LC, zoom_select, 
                                       props=rect_zoom,
                                       useblit=True,
                                       button=[1],  # Left click only
                                       minspanx=5, minspany=5,
                                       spancoords='pixels',
                                       interactive=False)
toggle_garbage.RS.set_active(False)
toggle_reverse_garbage.RS.set_active(False)
toggle_zoom.RS.set_active(False)

# Connect incoming events to event handlers
cidb1 = figb.canvas.mpl_connect('key_press_event', on_event_b)
# NOTE(review): plt.show() is presumably blocking here (non-interactive
# mode), so the disconnects below run only after the figures are closed
# (e.g. via the 'W', 'G' or 'q' keys in on_event_b).
plt.show()

# Disconnect Figures from interactive event handling
figa.canvas.mpl_disconnect(cida)
figb.canvas.mpl_disconnect(cidb1)
imfig.canvas.mpl_disconnect(cidim)


############################################
# After choosing a polynomial fit and removing outliers,
# display all light curves for each aperture simultaneously
# and highlight the optimal aperture choice.

# First calculate the final divided light curve and its
# point-to-point scatter for each aperture, using the user's
# comparison stars, polynomial settings and kept/deleted points.
lc_all = []   # divided light curve for each aperture
pp_all = []   # P2P scatter (%) for each aperture
for i in range(Nf):
    # Calculate divided LC and point-2-point scatter. Zero counts are
    # replaced by 1.0 (presumably to avoid division by zero in div_lc).
    if psource in ['hsp','mae']:
        phot_t = phot_data[i][1].replace(0.0,1.0).values
    elif psource in ['hcm','ucm']:
        phot_t = phot_data[i]['counts_1'].replace(0.0,1.0).values
    phot_c = np.sum(phot_data[i][comp_select].replace(0.0,1.0).values,axis=1)
    lc,_,_,_ = div_lc(time_stamps,phot_t,phot_c,polypack,keep_ind,dele_ind)
    pp = pp_scat(lc)
    lc_all.append(lc)
    pp_all.append(pp*1e2)

# Get index of the aperture with the minimum P2P scatter
ind_ppmin = pp_all.index(min(pp_all))

# Generate all possible combinations of comparison stars: every
# non-empty subset of comp_ids, from single stars up to all of them
# (2**len(comp_ids) - 1 combinations in total).
all_combs = []
for i in range(len(comp_ids)):
    new_comb = list(combinations(comp_ids,i+1))
    # Have to convert the itertools tuples into pure lists for Pandas
    # (pandas treats a tuple as a single column label, a list as a
    # selection of several columns)
    for nc in new_comb:
        nc_new = list(nc)
        all_combs.append(nc_new)
Ncomb = len(all_combs)

# Go through all possible aper/comp-star combinations to find the
# optimal one. This calls div_lc about Ncomb*Nf times, so it can be
# slow when many comparison stars are available.
if grid_search:
    lc_comb = []   # Stores lowest P2P-light curve for each Comp-Combo
    pp_comb = []   # Stores lowest P2P-value for each Comp-Combo
    ind_aper = []  # Stores index for aperture with lowest P2P-value
    for i in range(Ncomb):
        lc_aper = []
        pp_aper = []
        for j in range(Nf):
            if psource in ['hsp','mae']:
                phot_t = phot_data[j][1].replace(0.0,1.0).values
            elif psource in ['hcm','ucm']:
                phot_t = phot_data[j]['counts_1'].replace(0.0,1.0).values
            phot_c = np.sum(phot_data[j][all_combs[i]].replace(0.0,1.0).values,axis=1)
            lc,_,_,_ = div_lc(time_stamps,phot_t,phot_c,polypack,keep_ind,dele_ind)
            pp = pp_scat(lc)
            lc_aper.append(lc)
            pp_aper.append(pp*1e2)

        # Save info for lowest P2P light curve
        ind_aper.append(pp_aper.index(min(pp_aper)))
        lc_comb.append(lc_aper[ind_aper[i]])
        pp_comb.append(pp_aper[ind_aper[i]])

    # Get index of minimum Combination
    ind_comb_min = pp_comb.index(min(pp_comb))

    # Calculate the P2P scatter for light curves using
    # the optimal selection of comparison stars
    pp_all_grid = []
    for i in range(Nf):
        if psource in ['hsp','mae']:
            phot_t = phot_data[i][1].replace(0.0,1.0).values
        elif psource in ['hcm','ucm']:
            phot_t = phot_data[i]['counts_1'].replace(0.0,1.0).values
        phot_c = np.sum(phot_data[i][all_combs[ind_comb_min]].replace(0.0,1.0).values,axis=1)
        lc,_,_,_ = div_lc(time_stamps,phot_t,phot_c,polypack,keep_ind,dele_ind)
        pp = pp_scat(lc)
        pp_all_grid.append(pp*1e2)

    # Get index of minimum aperture
    ind_ppmin_grid = pp_all_grid.index(min(pp_all_grid))


# Print out a message showing P2P for User Selected Comps + Optimal Aperture
apmin_user = ap_sizes[ind_ppmin]
comp_str1 = gen_compstr(comp_select,psource)
print('\nUser Selected Comparison Stars + Optimal Aperture:')
print('    Aper  = {} pix'.format(apmin_user) + \
    '\n    Comps = {}'.format(comp_str1) + \
    '\n    P2P   = {:.2f} %'.format(pp_all[ind_ppmin]))

# Print out a message showing P2P for Optimized Comp + Aperture Combination
if grid_search:
    apmin_auto = ap_sizes[ind_aper[ind_comb_min]]
    comp_str2 = gen_compstr(all_combs[ind_comb_min],psource)
    print('\nFull Grid Search for Optimal Comp + Aperture Combination:')
    print('    Aper  = {} pix'.format(apmin_auto) + \
        '\n    Comps = {}'.format(comp_str2) + \
        '\n    P2P   = {:.2f} %'.format(pp_comb[ind_comb_min]))

    # Print out some warning messages if the User-Selected vs.
    # Automated Comp or Aperture Selections do not agree
    # (comps are compared via their gen_compstr strings)
    if (comp_str1 != comp_str2) & (apmin_user == apmin_auto):
        print('\nWARNING: Comparison Star Selections Do Not Agree.')
    elif (comp_str1 == comp_str2) & (apmin_user != apmin_auto):
        print('\nWARNING: Aperture Size Selections Do Not Agree.')
    elif (comp_str1 != comp_str2) & (apmin_user != apmin_auto):
        print('\nWARNING: Aperture Size and Comparison Star Selections Do Not Agree.')
    else:
        print('\nAperture Size and Comparison Star Selections are in Agreement.')
else:
    print('\nFull Grid Search not performed.')

# Later the User will be able to decide whether to save the
# light curve reflecting their own comparison star choices
# or save the light curve with optimized selections.
# This flag decides which light curve to put into
# the saved lightcurve file (True = the user's own selections).
save_user = True

############################################
# Plot light curves for all apertures. Three stacked panels:
# light curve (cx_lc), periodogram (cx_ft), P2P vs. aperture (cx_pp).
figc = plt.figure(3,figsize=figsize)
gsc = GridSpec(3,1,hspace=0.3)
cx_lc = figc.add_subplot(gsc[0])
cx_ft = figc.add_subplot(gsc[1])
cx_pp = figc.add_subplot(gsc[2])

# Set Window Name
figc.canvas.manager.set_window_title('Aperture Selection')

# Set xy-limits of the PP plot: pad by 20% of the P2P range, or by a
# fixed 0.2 when all P2P values are identical (zero range)
if grid_search:
    pplow = min(pp_all+pp_all_grid)
    ppupp = max(pp_all+pp_all_grid)
else:
    pplow = min(pp_all)
    ppupp = max(pp_all)
ppdiff = ppupp - pplow
cx_pp.set_xlim(min(ap_sizes)-0.4 ,max(ap_sizes)+0.4)
if ppdiff > 0:
    cx_pp.set_ylim(pplow-0.2*ppdiff, ppupp+0.2*ppdiff)
else:
    cx_pp.set_ylim(pplow-0.2, pplow+0.2)


# Plot Light Curve from Optimal Aperture: user comp selection (white)
# and, when a grid search was run, the grid-search optimum (red)
lcplot1, = cx_lc.plot(time_stamps[keep_ind],lc_all[ind_ppmin],ls='-',lw=0.75,c='#FFFFFF88',
                      marker='o',ms=3.2,mew=0,mfc='w',label='User')
if grid_search:
    lcplot2, = cx_lc.plot(time_stamps[keep_ind],lc_comb[ind_comb_min],
                          ls='None',marker='.',ms=3.0,mec='r',mfc='r',
                          mew=0.8,label='Grid')
cx_lc.legend(ncol=2,fontsize=8,loc='upper right')

# Plot Periodogram of Optimal Light Curve.
# siglevel is 4x the mean periodogram amplitude (the "4<A>" level quoted
# in the text below). Amplitudes are scaled by 1e2 to percent; the 1e6
# frequency scaling presumably converts Hz to microHz (see axis label).
freqarr,lsp = calc_lsp(time_stamps[keep_ind],lc_all[ind_ppmin])
siglevel = 4.0*np.nanmean(lsp)*1e2
ftplot1, = cx_ft.plot(freqarr*1e6,lsp*1e2,c='w',label='User')
if grid_search:
    freqarr2,lsp2 = calc_lsp(time_stamps[keep_ind],lc_comb[ind_comb_min])
    siglevel2 = 4.0*np.nanmean(lsp2)*1e2
    ftplot2, = cx_ft.plot(freqarr2*1e6,lsp2*1e2,ls='-',lw=0.8,c='r',label='Grid')

# Draw the significance line(s) and annotate them. The y-limit is 1.2x
# the highest peak when that peak exceeds 1.58x the significance level,
# otherwise 1.9x the significance level (leaving room for the text).
# "Max S/N" is the highest peak divided by the mean amplitude.
if grid_search:
    cx_ft.axhline(siglevel,ls='-.',lw=1.0,c='w')
    cx_ft.axhline(siglevel2,ls='-.',lw=1.0,c='r')
    cx_ft.set_xlim(0,max(freqarr)*1e6)
    lsp_levels = [max(lsp)*1e2,max(lsp2)*1e2]
    sig_levels = [siglevel,siglevel2]
    if max(lsp_levels) > 1.58*max(sig_levels):
        max_height = 1.2*max(lsp_levels)
    else:
        max_height = 1.9*max(sig_levels)
    cx_ft.set_ylim(0,max_height)
    sig_text  = r'4$\langle$A$\rangle$ = {:6.3f}%,  Max S/N = {:6.2f}'.format(
                siglevel,4.0*max(lsp)*1e2/siglevel)
    sig_text2 = r'4$\langle$A$\rangle$ = {:6.3f}%,  Max S/N = {:6.2f}'.format(
                siglevel2,4.0*max(lsp2)*1e2/siglevel2)
    cx_ft.text(0.63*cx_ft.get_xlim()[1],siglevel+0.23*cx_ft.get_ylim()[1],
               sig_text, fontsize=9, **fontb)
    cx_ft.text(0.63*cx_ft.get_xlim()[1],siglevel+0.10*cx_ft.get_ylim()[1],
               sig_text2, color='r', fontsize=9, **fontb)
    cx_ft.legend(ncol=2,fontsize=8,loc='upper right')
else:
    cx_ft.axhline(siglevel,ls='--',lw=1.0,c='C0')
    cx_ft.set_xlim(0,max(freqarr)*1e6)
    if max(lsp)*1e2 > 1.58*siglevel:
        max_height = 1.2*max(lsp)*1e2
    else:
        max_height = 1.9*siglevel
    cx_ft.set_ylim(0,max_height)
    sig_text = r'4$\langle$A$\rangle$ = {:6.3f}%,  Max S/N = {:6.2f}'.format(
               siglevel,4.0*max(lsp)*1e2/siglevel)
    cx_ft.text(0.63*cx_ft.get_xlim()[1],siglevel+0.10*cx_ft.get_ylim()[1],
               sig_text, fontsize=9, **fontb)
    cx_ft.legend(ncol=2,fontsize=8,loc='upper right')
    
# Plot PP-Scat versus Aperture Size (user comps; grid-search comps too
# when a grid search was run)
pp_majloc = MultipleLocator(0.5)   # major x-ticks every 0.5 pix
pp_label = "User, Comps {}, Min. P2P = {:.3f} %".format(comp_str1,min(pp_all))
pplot, = cx_pp.plot(ap_sizes,pp_all,ls='-',lw=1.0,c='#FFFFFFAA',
                    marker='o',ms=4,mec='w',mfc='w',label=pp_label)
if grid_search:
    ppgrid_label = "Grid, Comps {}, Min. P2P = {:.3f} %".format(comp_str2,min(pp_all_grid))
    pplot_grid, = cx_pp.plot(ap_sizes,pp_all_grid,ls='--',lw=0.75,c='r',
                             marker='o',ms=5,mec='r',mfc='None',label=ppgrid_label)
cx_pp.xaxis.set_major_locator(pp_majloc)

# Print a special marker for the minimum scatter aperture
# (cyan square: user comps; 'C1' circle: grid-search comps)
cx_pp.plot(ap_sizes[ind_ppmin],pp_all[ind_ppmin],ls='None',
           marker='s',ms=9,mec='c',mfc='None',mew=1.5)
if grid_search:
    cx_pp.plot(ap_sizes[ind_ppmin_grid],pp_all_grid[ind_ppmin_grid],ls='None',
               marker='o',ms=9,mec='C1',mfc='None',mew=1.5)
cx_pp.legend(ncol=1,fontsize=8,loc='upper right')


# Add text and labels
cx_lc.set_xlabel('Time (s)',fontsize=9,**fontb)
cx_lc.set_ylabel('Rel. Flux (%)',fontsize=9,**fontb)
# Raw string: '\m' is an invalid escape sequence in a normal string
# literal (DeprecationWarning; SyntaxWarning from Python 3.12). The
# resulting text is unchanged.
cx_ft.set_xlabel(r'Frequency ($\mu$Hz)',fontsize=9,**fontb)
cx_ft.set_ylabel('Amplitude (%)',fontsize=9,**fontb)
cx_pp.set_xlabel('Aperture Radius (pix)',fontsize=9,**fontb)
cx_pp.set_ylabel('Avg. P2P-Scatter (%)',fontsize=9,**fontb)
cx_lc.tick_params(direction='in',labelsize=8)
cx_ft.tick_params(direction='in',labelsize=8)
cx_pp.tick_params(direction='in',labelsize=8)
cx_pp.tick_params(which='major',axis='x',labelrotation=70)


# Add title
titlec = '{} Optimal Aperture Selection   UT-Date: {}'.format(obj_name,print_date)
cx_lc.set_title(titlec,loc='center',fontsize=14,**fontb)

# Set some more initial parameters for the interactive P2P plot: the
# kept/deleted aperture sizes and P2P values (and the grid-search P2P
# values when a grid search was run), modified by on_event_c below.
num_ppdel = 0
keep_ppx = np.copy(np.asarray(ap_sizes))
keep_ppy = np.copy(np.asarray(pp_all))
if grid_search: 
    keep_ppy_grid = np.copy(np.asarray(pp_all_grid)) 
dele_ppx = np.asarray([]) 
dele_ppy = np.asarray([])
if grid_search:
    dele_ppy_grid = np.asarray([])

# Add some basic interactive capabilities
def on_event_c(event):
    """
    Keyboard handler for the aperture-selection figure (connected to figc).

    Keys:
        A     restore every deleted aperture point
        d     delete one aperture point, chosen by get_index() from the
              cursor position; the matching grid-search value is also
              removed
        W     close all figures; save_user is left unchanged
        G     close all figures; if a grid search was run, set
              save_user = False so the grid-search light curve is saved
        Q/q   close all figures and exit without saving anything
        ?     reprint the command list via print_commands()

    Mutates the module-level keep_*/dele_* arrays, num_ppdel and
    save_user, and redraws the P2P-scatter curve(s).
    """
    global cx_pp,pplot,save_user,grid_search
    global keep_ppx,keep_ppy,keep_ppy_grid
    global dele_ppx,dele_ppy,dele_ppy_grid,num_ppdel

    # Add ALL deleted points back to the plot
    if event.key == 'A':
        if num_ppdel > 0:
            print('Added all deleted points back to Light Curve.')
            keep_ppx = np.append(keep_ppx,dele_ppx[:])
            keep_ppy = np.append(keep_ppy,dele_ppy[:])
            dele_ppx = np.asarray([])
            dele_ppy = np.asarray([])
            if grid_search:
                keep_ppy_grid = np.append(keep_ppy_grid,dele_ppy_grid[:])
                dele_ppy_grid = np.asarray([])
            num_ppdel = 0

            # Re-sort the keep arrays by aperture radius. keep_ppx is
            # sorted last so that its argsort still matches the y arrays.
            keep_ppy = keep_ppy[keep_ppx.argsort()]
            if grid_search:
                keep_ppy_grid = keep_ppy_grid[keep_ppx.argsort()]
            keep_ppx = keep_ppx[keep_ppx.argsort()]

            # Update the plot
            pplot.set_data(keep_ppx,keep_ppy)
            if grid_search:
                pplot_grid.set_data(keep_ppx,keep_ppy_grid)

            # Update Y limits (20% padding around the plotted data range)
            if grid_search:
                pplow = min(np.concatenate((keep_ppy,keep_ppy_grid)))
                ppupp = max(np.concatenate((keep_ppy,keep_ppy_grid)))
            else:
                pplow = min(keep_ppy)
                ppupp = max(keep_ppy)
            ppdiff = ppupp - pplow
            cx_pp.set_ylim(pplow-0.2*ppdiff, ppupp+0.2*ppdiff)
            plt.draw()

    # Delete a single aperture point from the P2P-scatter curve(s)
    if event.key == 'd':
        # Always leave at least two points on the curve
        if len(keep_ppx) <= 2:
            print('WARNING! Cannot delete any more points!')
        else:
            xplot,yplot = event.x,event.y
            idx = get_index(cx_pp, pplot, keep_ppx, keep_ppy, xplot, yplot)
            dele_ppx = np.append(dele_ppx,keep_ppx[idx])
            dele_ppy = np.append(dele_ppy,keep_ppy[idx])
            keep_ppx = np.delete(keep_ppx,idx)
            keep_ppy = np.delete(keep_ppy,idx)
            if grid_search:
                # Same index: both curves share the aperture-radius x axis
                dele_ppy_grid = np.append(dele_ppy_grid,keep_ppy_grid[idx])
                keep_ppy_grid = np.delete(keep_ppy_grid,idx)
            num_ppdel += 1

            # Update the plot
            pplot.set_data(keep_ppx,keep_ppy)
            if grid_search:
                pplot_grid.set_data(keep_ppx,keep_ppy_grid)

            # Update Y limits (20% padding around the plotted data range)
            if grid_search:
                pplow = min(np.concatenate((keep_ppy,keep_ppy_grid)))
                ppupp = max(np.concatenate((keep_ppy,keep_ppy_grid)))
            else:
                pplow = min(keep_ppy)
                ppupp = max(keep_ppy)
            ppdiff = ppupp - pplow
            cx_pp.set_ylim(pplow-0.2*ppdiff, ppupp+0.2*ppdiff)
            plt.draw()

    # Option to close the figure and Save the *USER* LightCurve
    if event.key == 'W':
        plt.close("all")

    # Option to close the figure and Save the *GRID-SEARCH* LightCurve
    # (falls back to the USER light curve when no grid search was run)
    if event.key == 'G':
        if grid_search:
            save_user=False  # Save flag changed
        else:
            print('Grid Search was not performed, saving USER lightcurve.')
        plt.close("all")

    # Option to close the figure WITHOUT saving the LightCurve
    if (event.key == 'Q') or (event.key == 'q'):
        plt.close("all")
        print('\nProgram exited. Lightcurve was not saved.\n')
        sys.exit(1)

    # Option to reprint the command list
    if event.key == '?':
        print_commands()

# Connect incoming events to event handlers
cidc = figc.canvas.mpl_connect('key_press_event', on_event_c)
# Presumably returns once the figures are closed; on_event_c closes them
# on the W, G and Q/q keys.
plt.show()

# Disconnect Figure from interactive event handling
figc.canvas.mpl_disconnect(cidc)

############################################
# Before saving the final lightcurve, times need to 
# be converted into barycentric format

# Choose appropriate telescope location. get_loc is defined elsewhere and
# presumably returns an astropy EarthLocation; it is used below as the
# Time `location` for the light-travel-time correction.
loc = get_loc(hdr,telcode)


# Define Astropy coordinate object for the target.
# Get the RA & Dec from the stars.dat row: columns 1-3 hold RA in
# hours/minutes/seconds and columns 4-6 hold Dec in degrees/arcmin/arcsec.
obj_dat = star_dat.loc[obj_idx].values[0]
obj_ra = (float(obj_dat[1]) +
          float(obj_dat[2])/60.0 +
          float(obj_dat[3])/3600.0) * (360.0/24.0)

# The declination sign is carried only by the degrees field. Use copysign
# so that a zero-degree field keeps its sign: "-00" (-0.0) gives a
# negative declination and "00"/"+00" a positive one. A plain "> 0" test
# would send every 0 <= Dec < +1 deg target to the negative branch.
dec_deg = float(obj_dat[4])
dec_sign = np.copysign(1.0, dec_deg)
obj_dec = dec_sign * (abs(dec_deg) +
                      float(obj_dat[5])/60.0 +
                      float(obj_dat[6])/3600.0)
tcoord = SkyCoord(obj_ra,obj_dec,unit="deg",frame="icrs")

# Build "HH MM SS.ss" and "+DD MM SS.ss" strings for the file headers.
# Astropy picks a value-dependent number of decimals by default, so fixed
# character positions are unreliable. Use an explicit precision and split
# the six space-separated fields instead.
coord_fields = tcoord.to_string('hmsdms',sep=" ",precision=2).split()
ra_string = " ".join(coord_fields[:3])
de_string = " ".join(coord_fields[3:])

# Calculate BJD times using Astropy Time & TimeDelta objects
tkeep = time_stamps[keep_ind]              # All kept timestamps (seconds since T0)
st = dt_zero_mid.to_value('isot')          # UTC date-time at mid-exposure T0
st_mjd = dt_zero_mid.to_value('mjd')       # UTC MJD at mid-exposure T0
t0 = Time(st,scale='utc',format='isot',    # Astropy mid-exposure Start time
          location=loc)                    
td = TimeDelta(tkeep,format='sec')         # Time deltas for kept points
t = t0 + td                                # Full list of Astropy mid-exposure times

# tbjd, bjdref and tbjd_sec are only defined when barycorr is True; the
# later uses of them below are guarded by the same flag.
if barycorr:
    print('\nCalculating barycentric corrections...')
    ltt_bary_t0 = t0.light_travel_time(tcoord) # Light travel time to barycenter for T0
    ltt_bary = t.light_travel_time(tcoord)     # Light travel time to barycenter for all kept times
    tbjd = t.tdb.jd + ltt_bary.jd              # Barycentric times (TDB JD + LTT, in days)
    bjdref = t0.tdb.jd + ltt_bary_t0           # BJD at mid-exposure T0
    # NOTE(review): tbjd and bjdref hold full JDs (~2.4e6 d) in float64,
    # whose resolution is ~5e-10 d (~40 us). The difference below inherits
    # that rounding. If sub-millisecond timing matters, compute (t - t0)
    # and (ltt_bary - ltt_bary_t0) as TimeDelta objects first.
    tbjd_sec = (tbjd - bjdref.value) * 86400.0 # BJD converted to seconds since reftime
else:
    print('\nBarycentric corrections skipped...')

# Function to Calculate Error Bars and S/N
def calc_err(phot,aper):
    """
    Estimate the target's per-point S/N and normalized flux uncertainty.

    phot = raw aperture photometry, one row per exposure. Column 0 is the
           target, the last column is the sky, and every column in
           between is a comparison star.
    aper = Aperture radius in pixels

    Returns (snr, norm_noise):
        snr        target counts divided by the total noise
        norm_noise the total noise divided by the summed comparison flux
                   and by the mean target/comparison ratio, i.e. scaled
                   like the relative light curve

    Reads the module-level CCD parameters dark, read, gain and texp. A NaN
    dark or read value drops the corresponding noise term.
    """
    ap_area = np.pi * aper**2

    target = phot[:, 0]
    sky = phot[:, -1]
    comp_total = np.sum(phot[:, 1:-1], axis=1)
    avg_ratio = np.nanmean(target / comp_total)

    # CCD noise terms; each is zero when its parameter is not available
    dark_term = 0.0 if np.isnan(dark) else dark * texp * ap_area
    read_term = 0.0 if np.isnan(read) else (read / gain)**2 * ap_area

    # Total noise: target + sky counts plus the CCD terms
    sigma = np.sqrt((target + sky) + dark_term + read_term)
    return target / sigma, sigma / comp_total / avg_ratio

def calc_err_hcm(phot,photerr):
    """
    Normalized flux uncertainty for photometry that already supplies a
    per-point target flux error (used for psource 'hcm'/'ucm').

    phot = raw aperture photometry for all objects, one row per exposure.
           Column 0 is the target and every remaining column is a
           comparison star; unlike calc_err's input, there is no sky column.
    photerr = Raw flux error of target

    Returns the target error divided by the summed comparison flux and by
    the mean target/comparison ratio, i.e. scaled like the relative flux.
    """

    object_phot = phot[:,0]
    # There is no trailing sky column here, so every column after the
    # target is a comparison star. Slicing with 1:-1 would drop the last
    # comparison, and with a single comparison the sum would be zero and
    # every error NaN.
    comp_phot = np.sum(phot[:,1:],axis=1)
    mean_ratio = np.nanmean(object_phot / comp_phot)

    norm_noise = photerr / comp_phot / mean_ratio

    return norm_noise



#################################################################
# Lastly, save the raw photometry and the normalized lightcurve
# into files.  Trying to create outputs similar to the WQED
# .wq and .lc1 files

# Create header info for file (use save_user flag to choose appropriate info).
# save_user is True unless the user pressed 'G' after a grid search. In
# that case the best grid-search comparison combination is used instead.
# (snr is computed by calc_err but not used in the rest of this section.)
if save_user:
    opt_ind = ind_ppmin
    opt_ap = apmin_user
    comp_number = len(comp_select)
    comp_ids = comp_str1
    opt_pp = pp_all[ind_ppmin]
    if psource in ['hsp','mae']:
        # Target column, selected comparisons, then sky (calc_err expects
        # the sky in the last column)
        phot_index = [1] + comp_select + [Nobj+1]
        snr,err = calc_err(phot_data[ind_ppmin][phot_index].iloc[keep_ind,:].values,opt_ap)
        lc_std = roll_std(lc_all[ind_ppmin],20)
    elif psource in ['hcm','ucm']:
        phot_index = ['counts_1'] + comp_select
        err = calc_err_hcm(phot_data[0][phot_index].iloc[keep_ind,:].values,
            phot_data[0]['countse_1'].iloc[keep_ind].values)

else:
    opt_ind = ind_aper[ind_comb_min]
    opt_ap = apmin_auto
    comp_number = len(all_combs[ind_comb_min])
    comp_ids = comp_str2
    opt_pp = pp_comb[ind_comb_min]
    if psource in ['hsp','mae']:
        phot_index = [1] + all_combs[ind_comb_min] + [Nobj+1]
        # Select only the target, the chosen comparison combination and the
        # sky, as in the user branch. Passing every column would make
        # calc_err sum the wrong set of comparison stars.
        snr,err = calc_err(phot_data[ind_aper[ind_comb_min]][phot_index].iloc[keep_ind,:].values,opt_ap)
        lc_std = roll_std(lc_comb[ind_comb_min],20)
    elif psource in ['hcm','ucm']:
        phot_index = ['counts_1'] + all_combs[ind_comb_min]
        err = calc_err_hcm(phot_data[0][phot_index].iloc[keep_ind,:].values,
            phot_data[0]['countse_1'].iloc[keep_ind].values)

# To try and generate flux uncertainties that are not over
# or under estimated, a correction factor is applied here so
# that the average uncertainty value matches the standard
# deviation (SD) of the light curve. The SD is calculated by
# averaging all of the SD values calculated within a sliding 
# window of width 20 data points. This helps to reduce the 
# impact of photometric variability on the SD measurement.
# The rescaling is only applied for 'hsp'/'mae'; 'hcm'/'ucm' errors are kept
# as computed. lc_std comes from roll_std (defined elsewhere), which is
# presumably the averaged rolling SD described above -- confirm.
if psource in ['hsp','mae']:
    avg_err = np.nanmean(err)
    err_corrfac = lc_std / avg_err
    err*=err_corrfac


# Define the BJED header entry beforehand:
# (bjdref only exists when barycentric corrections were computed)
if barycorr:
    bjed_entry = '\nBJED       = {:<30.9f}# Mid Exposure BJD-TDB at T0'.format(bjdref.value)
else:
    bjed_entry = '\nBJED       = {:30s}# Mid Exposure BJD-TDB at T0'.format('None')

# Header for the raw photometry file. Each line is "Key = value" with the
# value padded to 30 characters, followed by a "#" description.
# np.savetxt writes these as commented header lines.
rawphot_header = 'Object     = {:30s}# Name of Object'.format(obj_name) + \
    '\nRA         = {:30s}# Object Right Ascension'.format(ra_string) + \
    '\nDec        = {:30s}# Object Declination'.format(de_string) + \
    '\nTelescope  = {:30s}# Name of Telescope'.format(telescope) + \
    '\nInstrument = {:30s}# Name of Instrument'.format(instrument) + \
    '\nTeleCode   = {:30s}# Teledat Code Name'.format(telcode) + \
    '\nObserver   = {:30s}# Name of Observer'.format(observer) + \
    '\nDate       = {:30s}# Mid-Exposure UTC Date at T0'.format(st.split("T")[0]) + \
    '\nTime       = {:30s}# Mid-Exposure UTC Time at T0'.format(st.split("T")[1]) + \
    '\nMJD        = {:<30.9f}# Mid-Exposure UTC MJD at T0'.format(st_mjd) + \
    '\nExptime    = {:<30.6f}# Exposure Time (s)'.format(texp) + \
    '\nFilter     = {:30s}# Filter Name'.format(filter_name) + \
    '{}'.format(bjed_entry) + \
    '\nBarycorr   = {:30s}# Barycentric Corrections Applied?'.format(str(barycorr)) + \
    '\nApPhot     = {:30s}# Photometry Program'.format(source_dict[psource]) + \
    '\nOrigFile   = {:30s}# Source Photometry Filename'.format(phot_names[opt_ind]) + \
    '\nApRadius   = {:<30.2f}# Aperture Radius (pixels)'.format(opt_ap) + \
    '\nNkeep      = {:<30d}# Number of points in light curve '.format(len(keep_ind)) + \
    '\nNdelete    = {:<30d}# Number of points removed'.format(len(dele_ind)) + \
    '\nColumns: Time since T0 (s), Target, Comparisons, Sky' 

# Header for the final lightcurve (same layout, plus the comparison-star,
# scatter, polynomial and authorship entries)
lcfin_header = 'Object     = {:30s}# Name of Object'.format(obj_name) + \
    '\nRA         = {:30s}# Object Right Ascension'.format(ra_string) + \
    '\nDec        = {:30s}# Object Declination'.format(de_string) + \
    '\nTelescope  = {:30s}# Name of Telescope'.format(telescope) + \
    '\nInstrument = {:30s}# Name of Instrument'.format(instrument) + \
    '\nTeleCode   = {:30s}# Teledat Code Name'.format(telcode) + \
    '\nObserver   = {:30s}# Name of Observer'.format(observer) + \
    '\nDate       = {:30s}# Mid-Exposure UTC Date at T0'.format(st.split("T")[0]) + \
    '\nTime       = {:30s}# Mid-Exposure UTC Time at T0'.format(st.split("T")[1]) + \
    '\nMJD        = {:<30.9f}# Mid-Exposure UTC MJD at T0'.format(st_mjd) + \
    '\nExptime    = {:<30.6f}# Exposure Time (s)'.format(texp) + \
    '\nFilter     = {:30s}# Filter Name'.format(filter_name) + \
    '{}'.format(bjed_entry) + \
    '\nBarycorr   = {:30s}# Barycentric Corrections Applied?'.format(str(barycorr)) + \
    '\nApPhot     = {:30s}# Photometry Program'.format(source_dict[psource]) + \
    '\nOrigFile   = {:30s}# Source Photometry Filename'.format(phot_names[opt_ind]) + \
    '\nApRadius   = {:<30.2f}# Aperture Radius (pixels)'.format(opt_ap) + \
    '\nAvgScatter = {:<30.2f}# Avg. Point-to-Point Scatter (%)'.format(opt_pp) + \
    '\nNComps     = {:<30d}# Number of Comparison stars used'.format(comp_number) + \
    '\nCompIDs    = {:30s}# IDs of Comparison stars used'.format(comp_ids.replace("+",",")) + \
    '\nPolyOrder  = {:<30d}# Degree of Polynomial Division'.format(polypack[0]) + \
    '\nNkeep      = {:<30d}# Number of points in light curve '.format(len(keep_ind)) + \
    '\nNdelete    = {:<30d}# Number of points removed'.format(len(dele_ind)) + \
    '\nAuthor     = {:30s}# Author of this light curve'.format(author) + \
    '\nCreatedOn  = {:30s}# Date created'.format(Time.now().to_value("iso")) + \
    '\nColumns: Time since T0 (s), Relative Flux, Relative Flux Error' 

# Header for the phot2lc log file (object, polynomial settings, T0,
# comparison IDs and exposure time)
log_header = '    OBJECT = {}'.format(obj_name) + \
             '\nPOLYNOMIAL = {},{},{:.2f},{:.2f}'.format(polypack[0],polypack[1],
                                                         polypack[2],polypack[3]) + \
             '\n     DTMID = {}'.format(st) + \
             '\n     COMPS = {}'.format(comp_ids) + \
             '\n      TEXP = {:.6f}'.format(texp)

# Generate output arrays (use save_user flag to choose appropriate lightcurve).
# The USER and GRID choices share all of the raw-photometry logic: it only
# depends on the aperture index opt_ind chosen above. Only the normalized
# light curve itself differs between them.
Ndat_all = len(time_stamps)
Ndat_keep = len(keep_ind)
if barycorr:
    bjdtimes_saved = np.reshape(tbjd_sec,(Ndat_keep,1))
    times_saved = bjdtimes_saved
else:
    rawtimes_saved = np.reshape(tkeep,(Ndat_keep,1))
    times_saved = rawtimes_saved
orig_times_logged = np.reshape(time_stamps,(Ndat_all,1))
del_log = np.reshape(np.zeros(Ndat_all),(Ndat_all,1))
del_log[keep_ind] += 1 # Set to 1 for points which are kept (0 means deleted)
lcerr_saved = np.reshape(err,(Ndat_keep,1))

# Raw photometry: a time column followed by the photometry columns for
# the chosen aperture, plus a separate sky column for 'mae' and 'hcm'/'ucm'.
if psource == 'hsp':
    phot_saved = phot_data[opt_ind].iloc[keep_ind,:].values
    output_phot = np.concatenate((times_saved,phot_saved),axis=1)
elif psource == 'mae':
    phot_saved = phot_data[opt_ind].iloc[keep_ind,0:2*Nobj-1:2].values
    sky_saved = np.reshape(phot_data[opt_ind].iloc[keep_ind,1].values,(Ndat_keep,1))
    output_phot = np.concatenate((times_saved,phot_saved,sky_saved),axis=1)
elif psource in ['hcm','ucm']:
    phot_saved = phot_data[0][count_ids].iloc[keep_ind,:].values
    sky_saved = np.reshape(sky_flux[keep_ind].values,(Ndat_keep,1))
    output_phot = np.concatenate((times_saved,phot_saved,sky_saved),axis=1)

# Normalized light curve: time, relative flux, relative flux error
if save_user:
    lc_saved = np.reshape(lc_all[ind_ppmin],(Ndat_keep,1))
else:
    lc_saved = np.reshape(lc_comb[ind_comb_min],(Ndat_keep,1))
output_lc = np.concatenate((times_saved,lc_saved,lcerr_saved),axis=1)

# Log: every original timestamp with a kept(1)/deleted(0) flag
output_log = np.concatenate((orig_times_logged,del_log),axis=1)


# Create the format strings for raw photometry and final lightcurve  
# Raw photometry: time, then Nobj+1 integer-formatted columns (presumably
# Nobj objects plus the sky -- confirm this matches every psource layout).
rawphot_format = '%10.3f  '
for i in range(Nobj+1):
    rawphot_format += '  %7.0f'
lcfin_format = '%10.3f  %9.6f  %9.6f'  # time, relative flux, flux error
log_format = '%10.3f  %i'              # original time, kept(1)/deleted(0)

# Save the lightcurve + header info to file.
# Data files are named <obj_name>_<print_date without dashes>.phot/.lc,
# lower-cased. The log file name is fixed, so np.savetxt overwrites it
# on every run.
lcphot_fname = '{}_{}.phot'.format(obj_name,print_date.replace("-","")).lower()
lcfin_fname = '{}_{}.lc'.format(obj_name,print_date.replace("-","")).lower()
lclog_fname = 'phot2lc_log.txt'
np.savetxt(lcphot_fname, output_phot, fmt=rawphot_format, header=rawphot_header)
np.savetxt(lcfin_fname, output_lc, fmt=lcfin_format, header=lcfin_header)
np.savetxt(lclog_fname, output_log, fmt=log_format, header=log_header)

print("\nLightcurve saved to file {}".format(lcfin_fname))
print("\nFinished! \n")


