#!/usr/bin/python

# Copyright (C) 2017 Michael Coughlin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

""".
Gravitational-wave Electromagnetic Optimization

This script generates an optimized list of pointings and content for
reviewing gravitational-wave skymap likelihoods.

Comments should be e-mailed to michael.coughlin@ligo.org.

"""


import os, sys, glob, optparse, shutil, warnings
import numpy as np
np.random.seed(0)

import healpy as hp

import matplotlib
#matplotlib.rc('text', usetex=True)
matplotlib.use('Agg')
matplotlib.rcParams.update({'font.size': 16})
matplotlib.rcParams['contour.negative_linestyle'] = 'solid'
import matplotlib.pyplot as plt

import gwemopt.utils, gwemopt.plotting

__author__ = "Michael Coughlin <michael.coughlin@ligo.org>"
__version__ = 1.0
__date__    = "6/17/2017"

# =============================================================================
#
#                               DEFINITIONS
#
# =============================================================================

def parse_commandline():
    """Parse the options given on the command-line.

    Returns:
        optparse.Values: the parsed options. Positional arguments are
        parsed but discarded.
    """
    # optparse expands "%prog" in the version text with str.replace, so the
    # version must be a string; passing the float __version__ directly makes
    # "--version" fail with AttributeError.
    parser = optparse.OptionParser(usage=__doc__,version=str(__version__))

    parser.add_option("-s", "--skymap", help="GW skymap.", default='../output/skymaps/GRB181222_annulus.fit')
    parser.add_option("-p", "--skymapplot", help="GW skymap plot.", default='../output/skymaps/GRB181222_annulus.pdf')

    parser.add_option("--nside",default=256,type=int)

    parser.add_option("--doCircle",  action="store_true", default=False)
    parser.add_option("--doAnnulus",  action="store_true", default=False)

    # --ra/--dec carry no type on purpose: a value given on the command line
    # stays a string, so gen_circle can accept a comma-separated list.
    parser.add_option("--ra",default=294.324)
    parser.add_option("--dec",default=-23.694)
    parser.add_option("--radius",default=51.089,type=float)
    parser.add_option("--radius_sigma",default=0.097,type=float)

    parser.add_option("--doSquare",  action="store_true", default=False)
    # Comma-separated polygon corners. The --doSquare branch of the main
    # script reads opts.ras / opts.decs, so these options must exist.
    parser.add_option("--ras",default="290.010,290.186,265.423,264.736")
    parser.add_option("--decs",default="48.528,48.912,43.214,42.455")

    parser.add_option("--doTESS",  action="store_true", default=False)
    # Earlier RA/Dec-based TESS defaults, kept for reference only:
    #   ras  = "324.5670,338.5766,19.4927,90.0042"
    #   decs = "-33.1730,-55.0789,-71.9781,-66.5647"
    parser.add_option("--lon",default="10.94,10.94,10.94,10.94")
    parser.add_option("--lat",default="-18.0,-42.0,-66.0,-89.0")

    parser.add_option("--doIntersect",  action="store_true", default=False)
    parser.add_option("-o", "--original_skymap", help="GW skymap.", default='../data/GRB181222/glg_healpix_all_bn181222841.fit')

    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="Run verbosely. (Default: False)")

    opts, args = parser.parse_args()

    # show parameters
    if opts.verbose:
        # sys.stderr.write behaves the same on Python 2 and 3, unlike the
        # Python 2-only "print >> sys.stderr" statement.
        report = [
            "",
            "running gwemopt_run...",
            "version: %s" % __version__,
            "",
            "***************** PARAMETERS ********************",
        ]
        for name, value in opts.__dict__.items():
            report.append(name + ":")
            report.append("%s" % (value,))
        report.append("")
        sys.stderr.write("\n".join(report) + "\n")

    return opts

def get_intersect(healpix,original_skymap,nside):
    """Weight a HEALPix map by an existing skymap and renormalize it.

    Args:
        healpix: HEALPix map (array) at resolution ``nside``.
        original_skymap: path of the FITS skymap to multiply with.
        nside: resolution the skymap on disk is resampled to; it must match
            the resolution of ``healpix`` for the product to be defined.

    Returns:
        The pixel-wise product of both maps, scaled so that it sums to one.

    Raises:
        ValueError: if the product has no positive total weight (the two
            maps do not overlap), since normalizing would only yield NaNs.
    """
    original_data = hp.read_map(original_skymap)
    # power=-2 makes ud_grade keep the sum of the map invariant while
    # changing its resolution.
    original_data = hp.ud_grade(original_data,nside,power=-2)

    product = healpix * original_data
    total = np.sum(product)
    # "not total > 0" also rejects a NaN total.
    if not total > 0:
        raise ValueError("Skymaps do not overlap: intersection has zero total probability")

    return product / total

def gen_annulus(ra,dec,radius,error,nside):
    """Build a normalized HEALPix map that is uniform over an annulus.

    The annulus is centred on (``ra``, ``dec``) and spans angular distances
    from ``radius - error`` to ``radius + error``; all angles are in degrees.

    Args:
        ra: right ascension of the centre (anything ``float()`` accepts).
        dec: declination of the centre (anything ``float()`` accepts).
        radius: mean angular radius of the annulus.
        error: half-width of the annulus.
        nside: HEALPix resolution of the output map.

    Returns:
        Array of length ``hp.nside2npix(nside)`` with equal weight in every
        annulus pixel, summing to one.
    """
    center_ra = float(ra)
    center_dec = float(dec)

    # healpy's default angular convention is (colatitude, longitude) in radians.
    colatitude = 0.5 * np.pi - np.deg2rad(center_dec)
    longitude = np.deg2rad(center_ra)
    center_vec = hp.ang2vec(colatitude, longitude)

    inner_disc = hp.query_disc(nside, center_vec, np.deg2rad(radius - error))
    outer_disc = hp.query_disc(nside, center_vec, np.deg2rad(radius + error))

    # Pixels in the outer disc that are not in the inner disc form the ring.
    ring = np.zeros(hp.nside2npix(nside))
    ring[np.setdiff1d(outer_disc, inner_disc)] = 1.

    return ring / np.sum(ring)

def gen_circle(ra,dec,error,nside):
    """Build a normalized HEALPix map of Gaussian-smoothed point sources.

    Args:
        ra: right ascension in degrees, either a single number (or numeric
            string) or a comma-separated string of several values.
        dec: declination in degrees; must be a comma-separated string with
            as many entries as ``ra`` when ``ra`` is one.
        error: Gaussian sigma of the smoothing kernel, in degrees.
        nside: HEALPix resolution of the output map.

    Returns:
        The smoothed map scaled to sum to one, or ``None`` (after printing a
        message) when the numbers of RAs and decs differ.
    """
    if isinstance(ra, str) and ',' in ra:
        ras = [float(x) for x in ra.split(",")]
        decs = [float(x) for x in dec.split(",")]
    else:
        ras = [float(ra)]
        decs = [float(dec)]

    # Compare the plain lists before building any array: numpy refuses to
    # create an array from ragged rows, which would hide this message
    # behind an unrelated ValueError.
    if len(ras) != len(decs):
        print("Need equal number of RAs and decs")
        return None

    n = np.zeros(hp.nside2npix(nside))
    # lonlat=True: coordinates are longitude/latitude in degrees.
    n[hp.ang2pix(nside, np.asarray(ras), np.asarray(decs), lonlat=True)] = 1.

    healpix = hp.smoothing(n, sigma=np.deg2rad(error), verbose=False)
    healpix = healpix / np.sum(healpix)

    return healpix

def gen_square(ras,decs,nside):
    """Build a normalized HEALPix map that is uniform inside a polygon.

    Args:
        ras: sequence of corner right ascensions, in degrees.
        decs: sequence of corner declinations, in degrees, one per RA.
        nside: HEALPix resolution of the output map.

    Returns:
        Array of length ``hp.nside2npix(nside)`` with equal weight in every
        pixel returned by ``hp.query_polygon`` for the corners, summing to
        one. Returns ``None`` (after printing a message) when ``ras``
        contains a comma or when the numbers of RAs and decs differ.
    """
    if ',' in ras:
        print("Use --doCircle for multiple locations")
        return None

    # zip() would silently drop unmatched corners, so reject mismatched
    # input the same way gen_circle does.
    if len(ras) != len(decs):
        print("Need equal number of RAs and decs")
        return None

    # One unit vector per corner; lonlat=True means degrees (lon, lat).
    corners = np.array([hp.ang2vec(r, d, lonlat=True) for r, d in zip(ras, decs)])
    ipix = hp.query_polygon(nside, corners)

    n = np.zeros(hp.nside2npix(nside))
    n[ipix] = 1.
    healpix = n / np.sum(n)

    return healpix

#def gen_TESS(ras,decs,nside):
def gen_TESS(lats,lons,nside):
    """Build a normalized HEALPix map covering a set of square fields.

    For every (lat, lon) pair the pixels of one square field are requested
    from ``gwemopt.utils.getSquarePixels``; the union of all fields gets
    uniform weight. The map is normalized and then rotated from ecliptic
    ('E') to celestial ('C') coordinates.

    Args:
        lats: sequence of field-centre latitudes.
        lons: sequence of field-centre longitudes, one per latitude.
        nside: HEALPix resolution of the output map.

    Returns:
        The rotated HEALPix map.
    """
    # Field size handed to getSquarePixels -- presumably the side length in
    # degrees; verify against gwemopt.utils.
    tess_fov = 24

    # Seed with an empty float array so that no fields yields no pixels.
    pixel_chunks = [np.array([])]
    for lat, lon in zip(lats, lons):
        print(lon, lat)
        field_pix, _radecs, _patch, _area = gwemopt.utils.getSquarePixels(lon, lat, tess_fov, nside)
        pixel_chunks.append(field_pix)
    # Overlapping fields share pixels; keep each index once.
    footprint = np.unique(np.hstack(pixel_chunks)).astype(int)

    coverage = np.zeros(hp.nside2npix(nside))
    coverage[footprint] = 1.
    ecliptic_map = coverage / np.sum(coverage)

    ecl_to_cel = hp.rotator.Rotator(coord=['E','C'])
    return ecl_to_cel.rotate_map(ecliptic_map)

# =============================================================================
#
#                                    MAIN
#
# =============================================================================

warnings.filterwarnings("ignore")

# Parse command line
opts = parse_commandline()

# Build the skymap with the first generator whose flag is set.
if opts.doCircle:
    # --radius doubles as the smoothing width for --doCircle.
    healpix = gen_circle(opts.ra,opts.dec,opts.radius,opts.nside)
elif opts.doSquare:
    ras = [float(x) for x in opts.ras.split(",")]
    decs = [float(x) for x in opts.decs.split(",")]
    healpix = gen_square(ras,decs,opts.nside)
elif opts.doAnnulus:
    healpix = gen_annulus(opts.ra,opts.dec,opts.radius,opts.radius_sigma,opts.nside)
elif opts.doTESS:
    lats = [float(x) for x in opts.lat.split(",")]
    lons = [float(x) for x in opts.lon.split(",")]
    healpix = gen_TESS(lats,lons,opts.nside)
else:
    print("Please specify --doCircle, --doSquare, --doTESS or --doAnnulus")
    # A missing mode is a usage error, so report failure to the caller.
    sys.exit(1)

if healpix is None:
    # The generators print their own message and return None on bad input;
    # stop here instead of handing None to write_map.
    sys.exit(1)

if opts.doIntersect:
    healpix = get_intersect(healpix,opts.original_skymap,opts.nside)

hp.fitsfunc.write_map(opts.skymap,healpix,overwrite=True)

unit='Gravitational-wave probability'

# The Agg backend selected at import time is non-interactive, so the figure
# is only written to disk.
plt.figure()
hp.mollview(healpix,title='',unit=unit,cbar=False,min=0,max=np.max(healpix))
gwemopt.plotting.add_edges()
plt.savefig(opts.skymapplot,dpi=200)
plt.close('all')

