import os, sys, glob, optparse, shutil, warnings
import numpy as np
np.random.seed(0)
import subprocess

import healpy as hp
#import ephem
import astropy, astroplan
from astropy import table

import gwemopt.utils, gwemopt.gracedb
import gwemopt.rankedTilesGenerator, gwemopt.waw
import gwemopt.lightcurve, gwemopt.coverage
import gwemopt.efficiency, gwemopt.plotting
import gwemopt.moc, gwemopt.catalog
import gwemopt.tiles, gwemopt.segments
#import gwemopt.tiles, gwemopt.segments_astroplan

import gwemopt.footprint, gwemopt.transients
import gwemopt.quadrants, gwemopt.mapsplit

import matplotlib.pyplot as plt
import matplotlib.patches
# Headless session (no DISPLAY set): select the non-interactive Agg backend.
# The ``warn`` keyword of matplotlib.use() was removed in Matplotlib 3.3 and
# raises TypeError there, so it is not passed.
if not os.getenv("DISPLAY"):
    import matplotlib
    matplotlib.use("agg")

__author__ = "Michael Coughlin <michael.coughlin@ligo.org>"
__version__ = 1.0
__date__    = "6/17/2017"

# =============================================================================
#
#                               DEFINITIONS
#
# =============================================================================

def parse_commandline():
    """@Parse the options given on the command-line.

        @return
        optparse.Values object holding every option (defaults applied)
        """
    # optparse expands "%prog" in the version text with str.replace(), so the
    # numeric __version__ has to be handed over as a string.
    parser = optparse.OptionParser(usage=__doc__, version=str(__version__))

    parser.add_option("-c", "--configDirectory", help="GW-EM config file directory.", default="../config/")
    parser.add_option("-s", "--skymap", help="GW skymap.", default='../output/skymaps/G268556.fits')
    parser.add_option("-g", "--gpstime", help="GPS time.", default=1167559936.0, type=float)

    parser.add_option("-o", "--outputDir", help="output directory", default="../output")
    parser.add_option("--tilingDir", help="tiling directory", default="../tiling")
    parser.add_option("--scheduleType", help="schedule type", default="greedy")
    parser.add_option("--timeallocationType", help="time allocation type", default="powerlaw")
    parser.add_option("--nside", default=256, type=int)
    parser.add_option("--tilesType", help="tiling type", default="moc")
    parser.add_option("--doSchedule", action="store_true", default=False)
    parser.add_option("--doTiles", action="store_true", default=False)
    parser.add_option("--iterativeOverlap", default=0.0, type=float)
    parser.add_option("--doCalcTiles", action="store_true", default=False)
    parser.add_option("--doAlternatingFilters", action="store_true", default=False)
    parser.add_option("--doBalanceExposure", action="store_true", default=False)
    parser.add_option("--filters", default="r,g,r")
    parser.add_option("--doMovie_supersched", action="store_true", default=False)
    parser.add_option("--powerlaw_cl", default=0.9, type=float)
    parser.add_option("--powerlaw_n", default=1.0, type=float)
    parser.add_option("--powerlaw_dist_exp", default=0, type=float)
    parser.add_option("--Tobs", default="0.0,1.0")
    parser.add_option("--Tobs_split", default=2, type=int)
    parser.add_option("--doSkymap", action="store_true", default=False)
    parser.add_option("--doEfficiency", action="store_true", default=False)
    parser.add_option("--exposuretimes", default="30.0,30.0,30.0")

    parser.add_option("--doMinimalTiling", action="store_true", default=False)
    parser.add_option("--doReferences", action="store_true", default=False)
    parser.add_option("--doPlots", action="store_true", default=False)
    parser.add_option("-t", "--alltelescopes", help="Telescope names for each scheduling round.",
                      default='ZTF+DECam')
    parser.add_option("--doChipGaps", action="store_true", default=False)
    parser.add_option("--doParallel", action="store_true", default=False)
    parser.add_option("--doSingleExposure", action="store_true", default=False)

    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="Run verbosely. (Default: False)")

    opts, args = parser.parse_args()

    # show parameters (on stderr, so that stdout stays free for regular output)
    if opts.verbose:
        print("", file=sys.stderr)
        print("running gwemopt_run...", file=sys.stderr)
        print("version: %s"%__version__, file=sys.stderr)
        print("", file=sys.stderr)
        print("***************** PARAMETERS ********************", file=sys.stderr)
        for name, value in vars(opts).items():
            print(name+":", file=sys.stderr)
            print(value, file=sys.stderr)
        print("", file=sys.stderr)

    return opts

def params_struct(opts,schedround):
    """@Creates gwemopt params structure
        @param opts
        gwemopt command line options
        @param schedround
        comma-separated names of the telescopes taking part in this round
        @return
        newly built params dictionary; params["config"] holds one entry per
        telescope of the round for which a <telescope>.config file was found
        """
    telescopes = schedround.split(",")

    # Build a fresh dictionary on every call instead of mutating a
    # module-level ``params`` that may not even exist yet.
    params = {}
    params["config"] = {}
    configFiles = glob.glob("%s/*.config"%opts.configDirectory)
    for configFile in configFiles:
        telescope = os.path.basename(configFile).replace(".config","")
        if telescope not in telescopes: continue
        config = gwemopt.utils.readParamsFromFile(configFile)
        params["config"][telescope] = config
        config["telescope"] = telescope
        if opts.doSingleExposure:
            # builtin float: the np.float alias was removed in NumPy 1.24
            exposuretime = np.array(opts.exposuretimes.split(","),dtype=float)[0]

            # shift the limiting magnitude according to the exposure-time ratio
            nmag = np.log(exposuretime/config["exposuretime"]) / np.log(2.5)
            config["magnitude"] = config["magnitude"] + nmag
            config["exposuretime"] = exposuretime
        if "tesselationFile" in config:
            if not os.path.isfile(config["tesselationFile"]):
                # the tesselation file is missing: generate it for this field-of-view shape
                if config["FOV_type"] == "circle":
                    gwemopt.tiles.tesselation_spiral(config)
                elif config["FOV_type"] == "square":
                    gwemopt.tiles.tesselation_packing(config)
            config["tesselation"] = np.loadtxt(config["tesselationFile"],usecols=(0,1,2),comments='%')

        if "referenceFile" in config:
            refs = table.unique(table.Table.read(
                config["referenceFile"],
                format='ascii', data_start=2, data_end=-1)['field', 'fid'])
            # field -> list of filter ids that have a reference image
            reference_images =\
                {group[0]['field']: group['fid'].astype(int).tolist()
                for group in refs.group_by('field').groups}
            # known filter ids become filter names; unknown ids are kept as-is
            reference_images_map = {1: 'g', 2: 'r', 3: 'i'}
            for key in reference_images:
                reference_images[key] = [reference_images_map.get(n, n)
                                         for n in reference_images[key]]
            config["reference_images"] = reference_images

        location = astropy.coordinates.EarthLocation(config["longitude"],config["latitude"],config["elevation"])
        observer = astroplan.Observer(location=location)
        config["observer"] = observer

    params["alltelescopes"] = opts.alltelescopes.split("+")
    params["doSuperSched"] = True
    params["doMovie_supersched"] = opts.doMovie_supersched
    params["Tobs_split"] = opts.Tobs_split
    params["skymap"] = opts.skymap
    params["tilesType"] = opts.tilesType
    params["gpstime"] = opts.gpstime
    params["scheduleType"] = opts.scheduleType
    # was written with ":" (a bare annotation), which never stored the value
    params["timeallocationType"] = opts.timeallocationType
    params["telescopes"] = telescopes
    params["outputDir"] = opts.outputDir
    params["tilingDir"] = opts.tilingDir
    params["iterativeOverlap"] = opts.iterativeOverlap
    params["doReferences"] = opts.doReferences
    params["doMinimalTiling"] = opts.doMinimalTiling
    params["nside"] = opts.nside
    params["powerlaw_cl"] = opts.powerlaw_cl
    params["doPlots"] = opts.doPlots
    params["powerlaw_n"] = opts.powerlaw_n
    params["powerlaw_dist_exp"] = opts.powerlaw_dist_exp
    params["doChipGaps"] = opts.doChipGaps
    params["Tobs_all"] = np.array(opts.Tobs.split(","),dtype=float)
    params["doParallel"] = opts.doParallel
    params["doCalcTiles"] = opts.doCalcTiles
    params["doSingleExposure"] = opts.doSingleExposure
    params["exposuretimes"] = np.array(opts.exposuretimes.split(","),dtype=float)
    params["filters"] = opts.filters.split(",")
    params["doAlternatingFilters"] = opts.doAlternatingFilters
    params["doBalanceExposure"] = opts.doBalanceExposure

    return params

# =============================================================================
#
#                                    MAIN
#
# =============================================================================

warnings.filterwarnings("ignore")

# Parse command line
opts = parse_commandline()
# "+" separates the scheduling rounds; each round may list several telescopes
alltelescopes = opts.alltelescopes.split("+")
os.makedirs(opts.outputDir, exist_ok=True)
Tobs_split = opts.Tobs_split
# builtin float: the np.float alias was removed in NumPy 1.24
Tobs = np.array(opts.Tobs.split(","),dtype=float)
# Tobs_split+1 equally spaced edges, i.e. one [start, end] window per round
Tobs = np.linspace(Tobs[0],Tobs[1],Tobs_split+1)
failed_rounds = []  # indices of rounds for which no segments could be found
if len(alltelescopes) != Tobs_split:
    raise ValueError(f'Number of rounds in alltelescopes ({len(alltelescopes)}) does not match number of rounds of Tobs ({Tobs_split}). Change --Tobs_split or the number of rounds in --alltelescopes.')

def _round_dir(parentdir, tobs_start, tobs_end):
    """Return the output directory used for the round covering [tobs_start, tobs_end]."""
    return os.path.join(parentdir, f'{tobs_start:.2}_to_{tobs_end:.2}_Tobs')


def _rebuild_tesselation(params, tile_structs):
    """Replace every telescope's tesselation by [field id, ra, dec] rows of its tiles."""
    for telescope in params["telescopes"]:
        tiles_struct = tile_structs[telescope]
        rows = [[index, tile["ra"], tile["dec"]] for index, tile in tiles_struct.items()]
        tesselation = np.empty((0,3))
        if rows:
            # a single append for all rows instead of one array copy per tile
            tesselation = np.append(tesselation, rows, axis=0)
        params["config"][telescope]["tesselation"] = tesselation


def _record_covered_field_ids(params, round_index, round_telescopes, round_dir):
    """Store the field ids flagged as observed in one round's schedule files.

    For every telescope in the comma-separated ``round_telescopes`` this fills
    params["covered_field_ids"][telescope][round_index] from the file
    ``schedule_<telescope>.dat`` in ``round_dir``. A file that cannot be
    opened leaves an empty list behind.
    """
    if not os.path.isdir(round_dir): os.mkdir(round_dir)
    for telescope in round_telescopes.split(","):
        observed = []
        params["covered_field_ids"].setdefault(telescope, {})[round_index] = observed
        data_file = os.path.join(round_dir, f'schedule_{telescope}.dat')
        try:
            with open(data_file, "r") as f: #read in field_ids and whether observed
                for line in f:
                    data = line.split(' ')
                    field_id = int(data[0])
                    if int(data[8]) == 1:
                        observed.append(field_id) #get list of observed tiles
        except IOError:
            print("sched file could not be opened")


params = {}
params["coverage_structs"] = {} #for erase_observed_tiles in coverage.py
params["covered_field_ids"] = {}
params["tile_structs"] = {} #for doMovie_supersched
# Default output root, so that the stages after the loop can still locate it
# when every round fails before reaching the assignment inside the loop.
parentdir = opts.outputDir
for ii,schedround in enumerate(alltelescopes):
    # saves field ids, tile_structs, coverage_structs before call to params_struct()
    covered_field_ids_hold = params["covered_field_ids"]
    coverage_structs_hold = params["coverage_structs"]
    tile_structs_hold = params["tile_structs"]
    params = params_struct(opts,schedround)
    params["coverage_structs"] = coverage_structs_hold
    params["covered_field_ids"] = covered_field_ids_hold
    params["tile_structs"] = tile_structs_hold
    gwemopt.utils.params_checker(params)
    # builtin float: the np.float alias was removed in NumPy 1.24
    params["Tobs"] = np.array([Tobs[ii],Tobs[ii+1]],dtype=float)
    # empty placeholders; they stay in place when the round fails or is not scheduled
    params["coverage_structs"][f'coverage_struct_{ii}'] = {}
    params["tile_structs"][f'tile_structs_{ii}'] = {} # to find out what round we are in during later stages
    # --doSkymap requires no work here: params["skymap"] and params["gpstime"]
    # are already filled in by params_struct().
    has_previous_round = ii>0 and ii-1 not in failed_rounds
    try:
        params = gwemopt.segments.get_telescope_segments(params)
    except ValueError as e:
        print(f'Could not find segments for round {ii}. ',e)
        failed_rounds.append(ii)
        # The final coverage plot looks up the covered field ids of every
        # successful round, so read those of the previous round before
        # skipping this one.
        if has_previous_round:
            _record_covered_field_ids(params, ii-1, alltelescopes[ii-1],
                                      _round_dir(parentdir, Tobs[ii-1], Tobs[ii]))
        continue

    parentdir = params["outputDir"]
    params["outputDir"] = _round_dir(parentdir, Tobs[ii], Tobs[ii+1])
    os.makedirs(params["outputDir"], exist_ok=True)
    map_struct = gwemopt.utils.read_skymap(params, is3D=False)

    if opts.doPlots:
        print("Plotting skymap...")
        gwemopt.plotting.skymap(params,map_struct)

    if opts.doTiles:
        if params["tilesType"] == "moc":
            print("Generating MOC struct...")
            moc_structs = gwemopt.moc.create_moc(params, map_struct=map_struct)
            tile_structs = gwemopt.tiles.moc(params, map_struct, moc_structs)

        elif params["tilesType"] == "ranked":
            print("Generating ranked struct...")
            moc_structs = gwemopt.rankedTilesGenerator.create_ranked(params, map_struct)
            tile_structs = gwemopt.tiles.moc(params, map_struct, moc_structs)

        elif params["tilesType"] == "hierarchical":
            print("Generating hierarchical struct...")
            tile_structs = gwemopt.tiles.hierarchical(params, map_struct)
            _rebuild_tesselation(params, tile_structs)
            params["Ntiles"] = [len(tile_structs[telescope]) for telescope in params["telescopes"]]

        elif params["tilesType"] == "greedy":
            print("Generating greedy struct...")
            tile_structs = gwemopt.tiles.greedy(params, map_struct)
            _rebuild_tesselation(params, tile_structs)

        elif params["tilesType"] == "galaxy":
            print("Generating galaxy struct...")
            # NOTE(review): catalog_struct is not defined anywhere in this
            # script, so this branch raises NameError as written. A catalog
            # loading step presumably has to run first - confirm and add it.
            tile_structs = gwemopt.tiles.galaxy(params, map_struct, catalog_struct)
            _rebuild_tesselation(params, tile_structs)

        else:
            print("Need tilesType to be moc, greedy, hierarchical, ranked or galaxy")
            sys.exit(0)

        # This used to sit inside the else branch above, after the exit call,
        # where it could never run.
        if opts.doPlots:
            print("Plotting tiles struct...")
            gwemopt.plotting.tiles(params, map_struct, tile_structs)

    if has_previous_round: #reads in field ids of tiles that were covered in previous round and stores them per telescope
        _record_covered_field_ids(params, ii-1, alltelescopes[ii-1],
                                  _round_dir(parentdir, Tobs[ii-1], Tobs[ii]))

    if opts.doSchedule:
        if opts.doTiles:
            print("Generating coverage...")
            tile_structs,coverage_struct = gwemopt.coverage.timeallocation(params, map_struct, tile_structs)
        else:
            print("Need to enable --doTiles to use --doSchedule")
            sys.exit(0)

        # The parser defines no --doCoverage option, so a coverage struct
        # only ever exists when --doSchedule is given.
        # NOTE(review): gwemopt.scheduler is not imported explicitly in this
        # file; presumably another gwemopt import loads it - confirm.
        print("Summary of coverage...")
        gwemopt.scheduler.summary(params,map_struct,coverage_struct)

        if opts.doPlots:
            print("Plotting coverage...")
            gwemopt.plotting.coverage(params, map_struct, coverage_struct)

    # Only results that were actually computed replace the empty placeholders.
    if opts.doTiles:
        params["tile_structs"][f'tile_structs_{ii}'] = tile_structs
    if opts.doSchedule:
        params["coverage_structs"][f'coverage_struct_{ii}'] = coverage_struct

# Collect the covered field ids of the last round (skipped when that round failed).

if ii not in failed_rounds:
    prevfile = os.path.join(parentdir, f'{Tobs[-2]:.2}_to_{Tobs[-1]:.2}_Tobs')
    if not os.path.isdir(prevfile):
        os.mkdir(prevfile)
    prevtelescopes = alltelescopes[-1].split(",")
    for prevtelescope in prevtelescopes:
        # one {round index: [field ids]} dictionary per telescope
        rounds_of_telescope = params["covered_field_ids"].setdefault(prevtelescope, {})
        observed_ids = []
        rounds_of_telescope[ii] = observed_ids
        data_file = os.path.join(prevfile, f'schedule_{prevtelescope}.dat')
        try:
            with open(data_file, "r") as f:
                for line in f:
                    columns = line.split(' ')
                    field_id = int(columns[0])
                    # keep the field only when column 8 equals 1 (observed)
                    if int(columns[8]) == 1:
                        observed_ids.append(field_id)
        except IOError:
            print("sched file could not be opened")



if params["doMovie_supersched"]: #creates mpeg and gif animations of total coverage

    moviedir = os.path.join(parentdir,"movie")
    moviefiles = os.path.join(moviedir,"coverage-%04d.png")
    for moviename in ("coverage.mpg", "coverage.gif"):
        filename = os.path.join(parentdir,moviename)
        # Argument list instead of a shell string: output paths containing
        # spaces or shell metacharacters are passed to ffmpeg unchanged.
        ffmpeg_command = ['ffmpeg', '-an', '-y', '-r', '20', '-i', moviefiles, '-b:v', '5000k', filename]
        try:
            subprocess.call(ffmpeg_command)
        except OSError as err:
            # stay best-effort (e.g. ffmpeg not installed), like the shell call did
            print(f'Could not run ffmpeg: {err}')
    # remove the individual frames once the animations have been attempted
    for frame in glob.glob(os.path.join(moviedir,"*.png")):
        try:
            os.remove(frame)
        except OSError as err:
            print(f'Could not remove {frame}: {err}')


# -----------------------------------------------------------------------------
#     PLOT COVERAGE FOR ALL TOBS
# -----------------------------------------------------------------------------

# Sets the baseline reference to map other field IDs to: the first round whose
# coverage struct has an entry at data[2,2].
for w in range(len(params["coverage_structs"])):
    coverage_struct_0 = params["coverage_structs"][f'coverage_struct_{w}']
    try:
        mjd_min = coverage_struct_0["data"][2,2]
    except (KeyError, IndexError, TypeError):
        continue
    break
else:
    # Without a reference round nothing below can work; stop with a clear
    # message instead of failing later on an undefined name.
    sys.exit("No round produced a usable coverage struct; cannot plot the full coverage.")

mjd_max = 0
og_uncovered = []
plotName = os.path.join(parentdir,'fullcoverage.pdf')
fig = plt.figure(figsize=(10,8))
ax = plt.gca()
alltelescopes = params["alltelescopes"]
for ii,telescopes in enumerate(alltelescopes):
    alltelescopes[ii] = telescopes.split(",")
ref_telescope = alltelescopes[w][0] # reference telescope: first telescope of the baseline round

last_coverage_struct = coverage_struct_0
for ii in range(len(alltelescopes)):
    coverage_struct = params["coverage_structs"][f'coverage_struct_{ii}']
    if not coverage_struct: continue
    last_coverage_struct = coverage_struct
    idx = np.isfinite(coverage_struct["data"][:,2])
    try:
        round_mjd_min = np.min(coverage_struct["data"][idx,2])
    except ValueError: # no finite entry in column 2 for this round
        continue
    round_mjd_max = np.max(coverage_struct["data"][idx,2]*1.000005)
    if round_mjd_min<mjd_min: mjd_min = round_mjd_min
    if round_mjd_max>mjd_max: mjd_max = round_mjd_max
# Start value for the running minimum, taken from the last non-empty coverage
# struct (an empty struct of a failed last round has no "data" entry).
min_field_id = last_coverage_struct["data"][2][5]
max_field_id = 0

# every rectangle has the same size
width = (mjd_max-mjd_min)/14
height = 20

for ii in range(len(coverage_struct_0["ipix"])):
    if coverage_struct_0["telescope"][ii] != ref_telescope: #only does for reference telescope
        continue
    data = coverage_struct_0["data"][ii,:]
    x,y = data[2],data[5]
    patch = matplotlib.patches.Rectangle(xy=(x,y),width=width,height=height)
    if data[5] in params["covered_field_ids"][ref_telescope][w]: #makes sure field id was covered by telescope in that round then sets to green
        patch.set_fc('green')
    else:
        patch.set_fc('red')
        og_uncovered.append(data[5])

    if data[5]<min_field_id: min_field_id = data[5]
    if data[5]>max_field_id: max_field_id = data[5]
    patch.set_ec('black')
    ax.add_patch(patch)

#map field ids of other telescopes to the reference telescope

for jj,telescopes in enumerate(alltelescopes):
    tile_structs = params["tile_structs"][f'tile_structs_{jj}']
    coverage_struct = params["coverage_structs"][f'coverage_struct_{jj}']
    if not coverage_struct: continue #skips if coverage struct is empty
    for telescope in telescopes:
        if telescope == ref_telescope and jj==w: continue
        elif telescope == ref_telescope: #assigns mapped field id to original field id if same telescope
            for ii in range(len(coverage_struct["ipix"])):
                if coverage_struct["telescope"][ii] != telescope: continue
                coverage_struct[f'mapped_field_id_{ii}'] = coverage_struct["data"][ii,5]
            continue
        tile_struct = tile_structs[telescope]
        tile_struct = gwemopt.utils.check_overlapping_tiles(params,tile_struct,coverage_struct_0)#appends field id corresponding to first coverage_struct to the tile_struct
        for key in tile_struct.keys():
            if 'epochs' not in tile_struct[key]: continue
            epochs = tile_struct[key]["epochs"]
            for i in range(len(epochs)):
                field_id = epochs[i,5]
                epochs_telescope = tile_struct[key]["epochs_telescope"][i]
                if epochs_telescope != ref_telescope: break #only continues if field id is from first telescope
                for ii in range(len(coverage_struct["ipix"])): #looks through coverage_struct to find corresponding index
                    if coverage_struct["telescope"][ii] != telescope:
                        continue
                    data = coverage_struct["data"][ii,:]
                    if data[5] == int(key) and f'mapped_field_id_{ii}' not in coverage_struct: #finds key with the corresponding field id
                        coverage_struct[f'mapped_field_id_{ii}'] = field_id

#plots rest of the points for remaining rounds
for jj in range(len(alltelescopes)):
    coverage_struct = params["coverage_structs"][f'coverage_struct_{jj}']
    if not coverage_struct: continue #skips if coverage struct is empty

    for ii in range(len(coverage_struct["ipix"])):
        if coverage_struct["telescope"][ii] == ref_telescope and jj==w: continue
        data = coverage_struct["data"][ii,:]

        if f'mapped_field_id_{ii}' in coverage_struct:
            mapped_field_id = coverage_struct[f'mapped_field_id_{ii}']
        else: continue

        x,y = data[2],mapped_field_id

        patch = matplotlib.patches.Rectangle(xy=(x,y),width=width,height=height)
        observed = data[5] in params["covered_field_ids"][coverage_struct["telescope"][ii]][jj]
        if observed and mapped_field_id in og_uncovered: #tile was previously unobserved and is now reobserved
            patch.set_fc('blue')
        elif observed: #sets color to green if observed
            patch.set_fc('green')
        else:
            patch.set_fc('red')
            og_uncovered.append(mapped_field_id)

        if mapped_field_id<min_field_id: min_field_id = mapped_field_id
        if mapped_field_id>max_field_id: max_field_id = mapped_field_id
        patch.set_ec('black')
        ax.add_patch(patch)


plt.xlabel("Time [MJD]")
plt.ylabel("Tile Number")
plt.xlim([mjd_min,mjd_max])
plt.ylim([min_field_id-50,max_field_id+50])
# Save before showing: with an interactive backend the figure is gone once
# the window opened by show() has been closed, and the file would be blank.
plt.savefig(plotName,dpi=200)
plt.show()
plt.close('all')
