#! /usr/bin/python3.6
# -*- coding: utf-8 -*-

"""
"""
import argparse
import os

import numpy as np

from scipy.interpolate import interp2d
# from matplotlib import pyplot as plt
from astropy.time import Time
from astropy import units as u
from astropy import coordinates as coord
from astropy import wcs
from astropy.io import fits

from nenupy.beam import SSTbeam, BSTbeam
from nenupy.beam.antenna import miniarrays
from nenupy.skymodel import SkyModel
import nenupy.astro as astro



__author__ = 'Alan Loh'
__copyright__ = 'Copyright 2018, nenupy'
__credits__ = ['Alan Loh']
__license__ = 'MIT'
__version__ = '0.0.1'
__maintainer__ = 'Alan Loh'
__email__ = 'alan.loh@obspm.fr'
__status__ = 'WIP'

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--time', type=str, help="UTC Time (e.g. 2016-09-30 14:00:00)", required=True)
    parser.add_argument('-s', '--store', type=str, help="Filename (needs to end by .fits)", required=True)
    parser.add_argument('-f', '--freq', type=float, default=50, help="Frequency in MHz", required=False)
    parser.add_argument('-a', '--azimuth', type=float, default=180, help="Azimuth in degrees", required=False)
    parser.add_argument('-e', '--elevation', type=float, default=90, help="Elevation in degrees", required=False)
    parser.add_argument('-p', '--polar', type=str, default='NW', help="Polarization (NW or NE)", required=False)
    parser.add_argument('-m', '--ma', type=int, default=None, help="Mini-Array name index (assume SST if not None)", required=False)
    args = parser.parse_args()



# Skymodel
sm = SkyModel()
sm.gsm2008(freq=args.freq)
skymodel = sm.skymodel
ramodel  = np.linspace(180., -180., skymodel.shape[1]) 
decmodel = np.linspace(-90., 90., skymodel.shape[0]) 
azimuth  = np.linspace(0., 360., skymodel.shape[1])
zenitha  = 90. - np.linspace(0., 90., skymodel.shape[0]/2)
agrid, zgrid = np.meshgrid( azimuth, zenitha )

# beam
if args.ma is not None:
    # Assume SST
    beam = SSTbeam(ma=args.ma,
        freq=args.freq,
        azana=args.azimuth,
        elana=args.elevation,
        polar=args.polar)
else:
    # Assume BST
    beam = BSTbeam(freq=args.freq,
        azana=args.azimuth,
        azdig=args.azimuth,
        elana=args.elevation,
        eldig=args.elevation,
        polar=args.polar) 
beam.getBeam()
beam = np.flipud(beam.beam)
bazi = np.linspace(0., 360., beam.shape[1])
bele = np.linspace(0., 90.,  beam.shape[0])
fb   = interp2d(bazi, bele, beam, kind='linear' )
beam = fb( np.linspace(0, 360, skymodel.shape[1]), np.linspace(0, 90, skymodel.shape[0]/2) )

# Sky rotation
frame  = coord.AltAz(obstime=Time(args.time), location=miniarrays.nenufarloc)
altaz  = coord.SkyCoord(agrid*u.deg, (90-zgrid)*u.deg, frame=frame)
radec  = altaz.transform_to(coord.FK5(equinox='J2000'))
ragrid, decgrid = radec.ra.deg, radec.dec.deg
ragrid[ragrid > 180] -= 360 # RA is within -180 and 180
xraind  = (ramodel[0] - ragrid)   / (ramodel[0]-ramodel[1])   #+ 0.5
xdecind = (decgrid - decmodel[0]) / (decmodel[1]-decmodel[0]) #+ 0.5
skysel  = skymodel[xdecind.astype(int), xraind.astype(int)]

# Sky x beam
skybeam = (beam/beam.max()) * skysel
skybeam = skybeam.compressed().reshape(skybeam.shape) # getting rid off mask array dunno why fits doesnt like

if args.store.lower() == 'none':
    import matplotlib as mpl
    import pylab as plt
    import matplotlib.ticker as mtick
    # Plot
    theta = np.linspace(0., 90., skybeam.shape[0])
    phi   = np.radians( np.linspace(0., 360., skybeam.shape[1]) )
    # ------ Plot ------ #
    # fig = plt.figure()
    fig = plt.figure(figsize=(18/2.54, 18/2.54))
    ax  = fig.add_subplot(111, projection='polar')
    normcb = mpl.colors.LogNorm(vmin=skybeam.max() * 1.e-5, vmax=skybeam.max())
    p = ax.pcolormesh(phi, theta, np.flipud(skybeam), norm=normcb, rasterized=True, cmap='bone')
    ax.grid(linestyle='-', linewidth=0.5, color='white', alpha=0.4)
    plt.setp(ax.get_yticklabels(), rotation='horizontal', color='white')
    
    g = lambda x,y: r'%d'%(90-x)
    ax.yaxis.set_major_formatter(mtick.FuncFormatter( g ))

    # Sources
    ateam = ['Vir A', 'Cyg A', 'Cas A', 'Tau A', 'Her A', 'Hyd A', 'Sun', 'Moon', 'Jupiter', 'Saturn']
    for at in ateam:
        src = astro.Source( source=at, time=Time(args.time), location=miniarrays.nenufarloc )
        altaz = src.source.transform_to(frame)
        az, el = altaz.az.deg, altaz.alt.deg
        if el > 0:
            # plt.scatter(np.radians(az), el, c='red', s=2, alpha=0.5)
            # c = ax.scatter(theta, r, c=colors, s=area, cmap='hsv', alpha=0.75)
            ax.scatter(np.radians(az), 90-el, s=150, facecolor='#d62728', edgecolor='#d62728', alpha=0.3)
            ax.text(np.radians(az), 90-el, '   '+at, color='#d62728')

    plt.title('pol={}, freq={:.2f}MHz, az={}, el={}'.format(args.polar, args.freq, args.azimuth, args.elevation))

    plt.show()

elif args.store.lower().endswith('png'):
    import matplotlib as mpl
    import pylab as plt
    import matplotlib.ticker as mtick
    # Plot
    theta = np.linspace(0., 90., skybeam.shape[0])
    phi   = np.radians( np.linspace(0., 360., skybeam.shape[1]) )
    # ------ Plot ------ #
    # fig = plt.figure()
    fig = plt.figure(figsize=(18/2.54, 18/2.54))
    ax  = fig.add_subplot(111, projection='polar')
    normcb = mpl.colors.LogNorm(vmin=skybeam.max() * 1.e-5, vmax=skybeam.max())
    p = ax.pcolormesh(phi, theta, np.flipud(skybeam), norm=normcb, rasterized=True, cmap='bone')
    ax.grid(linestyle='-', linewidth=0.5, color='white', alpha=0.4)
    plt.setp(ax.get_yticklabels(), rotation='horizontal', color='white')
    
    g = lambda x,y: r'%d'%(90-x)
    ax.yaxis.set_major_formatter(mtick.FuncFormatter( g ))

    # Sources
    ateam = ['Vir A', 'Cyg A', 'Cas A', 'Tau A', 'Her A', 'Hyd A', 'Sun', 'Moon', 'Jupiter', 'Saturn']
    for at in ateam:
        src = astro.Source( source=at, time=Time(args.time), location=miniarrays.nenufarloc )
        altaz = src.source.transform_to(frame)
        az, el = altaz.az.deg, altaz.alt.deg
        if el > 0:
            ax.scatter(np.radians(az), 90-el, s=150, facecolor='#d62728', edgecolor='#d62728', alpha=0.3)
            ax.text(np.radians(az), 90-el, '   '+at, color='#d62728')

    plt.savefig(args.store, dpi=200, facecolor='none',
        edgecolor='none', transparent=True, bbox_inches='tight')

else:
    # Save in FITS
    w = wcs.WCS(naxis=2)
    w.wcs.crpix = [skybeam.shape[1]/2., 0]
    w.wcs.cdelt = np.array([ 360./skybeam.shape[1] , 90. / skybeam.shape[0] ])
    w.wcs.crval = [180., 0.]
    w.wcs.cunit = ['deg', 'deg']
    w.wcs.ctype = ['RA---CAR', 'DEC--CAR']
    w.wcs.cname = ['Azimuth', 'Elevation']
    header = w.to_header()
    hdu1 = fits.PrimaryHDU(header=header)
    hdu1.data = skybeam
    hdus = fits.HDUList([hdu1])#, hdu2, hdu3])
    hdus.writeto(args.store, overwrite=True)

    # Region file
    ateam = ['Vir A', 'Cyg A', 'Cas A', 'Tau A', 'Her A', 'Hyd A', 'Sun', 'Moon', 'Jupiter', 'Saturn']
    with open(args.store[:-4]+'reg', 'w') as rf:
        rf.write("global color=green\n")
        for at in ateam:
            src = astro.Source( source=at, time=Time(args.time), location=miniarrays.nenufarloc )
            altaz = src.source.transform_to(frame)
            az, el = altaz.az.deg, altaz.alt.deg
            px, py = w.wcs_world2pix(az, el, 1)
            rf.write("point("+str(px)+", "+str(py)+")# point=cross text={"+at+"}\n" )

