#! /usr/bin/python3.6
# -*- coding: utf-8 -*-

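"""
Compute a NenuFAR beam (SST for a single mini-array when ``--ma`` is
given, BST otherwise) weighted by the GSM 2008 sky model at a given UTC
time, on a local azimuth/elevation grid.

The result is displayed (``--store none``), saved as a PNG image
(``--store <name>.png``), or written to a FITS file together with a
``.reg`` region file marking the A-team sources and solar-system bodies.
"""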
import argparse
import numpy as np

from scipy.interpolate import interp2d
# from matplotlib import pyplot as plt
from astropy.time import Time
from astropy import units as u
from astropy import coordinates as coord

from nenupy.beam import SSTbeam, BSTbeam
from nenupy.beam.antenna import miniarrays
from nenupy.skymodel import SkyModel
import nenupy.astro as astro



# Module metadata.
__author__ = "Alan Loh"
__copyright__ = "Copyright 2018, nenupy"
__credits__ = ["Alan Loh"]
__license__ = "MIT"
__version__ = "0.0.1"
__maintainer__ = __author__
__email__ = "alan.loh@obspm.fr"
__status__ = "WIP"

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--time', type=str, help="UTC Time (e.g. 2016-09-30 14:00:00)", required=True)
    parser.add_argument('-s', '--store', type=str, help="Filename (needs to end by .fits)", required=True)
    parser.add_argument('-f', '--freq', type=float, default=50, help="Frequency in MHz", required=False)
    parser.add_argument('-a', '--azimuth', type=float, default=180, help="Azimuth in degrees", required=False)
    parser.add_argument('-e', '--elevation', type=float, default=90, help="Elevation in degrees", required=False)
    parser.add_argument('-p', '--polar', type=str, default='NW', help="Polarization (NW or NE)", required=False)
    parser.add_argument('-m', '--ma', type=int, default=None, help="Mini-Array name index (assume SST if not None)", required=False)
    args = parser.parse_args()



# Skymodel
sm = SkyModel()
sm.gsm2008(freq=args.freq)
skymodel = sm.skymodel
ramodel  = np.linspace(180., -180., skymodel.shape[1]) 
decmodel = np.linspace(-90., 90., skymodel.shape[0]) 
azimuth  = np.linspace(0., 360., skymodel.shape[1])
zenitha  = 90. - np.linspace(0., 90., skymodel.shape[0]/2)
agrid, zgrid = np.meshgrid( azimuth, zenitha )

# beam
if args.ma is not None:
    # Assume SST
    beam = SSTbeam(ma=args.ma,
        freq=args.freq,
        azana=args.azimuth,
        elana=args.elevation,
        polar=args.polar)
else:
    # Assume BST
    beam = BSTbeam(freq=args.freq,
        azana=args.azimuth,
        azdig=args.azimuth,
        elana=args.elevation,
        eldig=args.elevation,
        polar=args.polar) 
beam.getBeam()
beam = np.flipud(beam.beam)
bazi = np.linspace(0., 360., beam.shape[1])
bele = np.linspace(0., 90.,  beam.shape[0])
fb   = interp2d(bazi, bele, beam, kind='linear' )
beam = fb( np.linspace(0, 360, skymodel.shape[1]), np.linspace(0, 90, skymodel.shape[0]/2) )

# Sky rotation
frame  = coord.AltAz(obstime=Time(args.time), location=miniarrays.nenufarloc)
altaz  = coord.SkyCoord(agrid*u.deg, (90-zgrid)*u.deg, frame=frame)
radec  = altaz.transform_to(coord.FK5(equinox='J2000'))
ragrid, decgrid = radec.ra.deg, radec.dec.deg
ragrid[ragrid > 180] -= 360 # RA is within -180 and 180
xraind  = (ramodel[0] - ragrid)   / (ramodel[0]-ramodel[1])   #+ 0.5
xdecind = (decgrid - decmodel[0]) / (decmodel[1]-decmodel[0]) #+ 0.5
skysel  = skymodel[xdecind.astype(int), xraind.astype(int)]

# Sky x beam
skybeam = (beam/beam.max()) * skysel
skybeam = skybeam.compressed().reshape(skybeam.shape) # getting rid off mask array dunno why fits doesnt like

if args.store.lower() == 'none':
    import matplotlib as mpl
    import pylab as plt
    import matplotlib.ticker as mtick
    # Plot
    theta = np.linspace(0., 90., skybeam.shape[0])
    phi   = np.radians( np.linspace(0., 360., skybeam.shape[1]) )
    # ------ Plot ------ #
    # fig = plt.figure()
    fig = plt.figure(figsize=(18/2.54, 18/2.54))
    ax  = fig.add_subplot(111, projection='polar')
    normcb = mpl.colors.LogNorm(vmin=skybeam.max() * 1.e-5, vmax=skybeam.max())
    p = ax.pcolormesh(phi, theta, np.flipud(skybeam), norm=normcb, rasterized=True, cmap='bone')
    ax.grid(linestyle='-', linewidth=0.5, color='white', alpha=0.4)
    plt.setp(ax.get_yticklabels(), rotation='horizontal', color='white')
    
    g = lambda x,y: r'%d'%(90-x)
    ax.yaxis.set_major_formatter(mtick.FuncFormatter( g ))

    # Sources
    ateam = ['Vir A', 'Cyg A', 'Cas A', 'Tau A', 'Her A', 'Hyd A', 'Sun', 'Moon', 'Jupiter', 'Saturn']
    for at in ateam:
        src = astro.Source( source=at, time=Time(args.time), location=miniarrays.nenufarloc )
        altaz = src.source.transform_to(frame)
        az, el = altaz.az.deg, altaz.alt.deg
        if el > 0:
            # plt.scatter(np.radians(az), el, c='red', s=2, alpha=0.5)
            # c = ax.scatter(theta, r, c=colors, s=area, cmap='hsv', alpha=0.75)
            ax.scatter(np.radians(az), 90-el, s=150, facecolor='#d62728', edgecolor='#d62728', alpha=0.3)
            ax.text(np.radians(az), 90-el, '   '+at, color='#d62728')

    plt.title('pol={}, freq={:.2f}MHz, az={}, el={}'.format(args.polar, args.freq, args.azimuth, args.elevation))
    ax.set_ylim(0, 90)
    plt.show()
    plt.close('all')

elif args.store.lower().endswith('png'):
    import matplotlib as mpl
    import pylab as plt
    import matplotlib.ticker as mtick
    # Plot
    theta = np.linspace(0., 90., skybeam.shape[0])
    phi   = np.radians( np.linspace(0., 360., skybeam.shape[1]) )
    # ------ Plot ------ #
    # fig = plt.figure()
    fig = plt.figure(figsize=(18/2.54, 18/2.54))
    ax  = fig.add_subplot(111, projection='polar')
    normcb = mpl.colors.LogNorm(vmin=skybeam.max() * 1.e-5, vmax=skybeam.max())
    p = ax.pcolormesh(phi, theta, np.flipud(skybeam), norm=normcb, rasterized=True, cmap='bone')
    ax.grid(linestyle='-', linewidth=0.5, color='white', alpha=0.4)
    plt.setp(ax.get_yticklabels(), rotation='horizontal', color='white')
    
    g = lambda x,y: r'%d'%(90-x)
    ax.yaxis.set_major_formatter(mtick.FuncFormatter( g ))

    # Sources
    ateam = ['Vir A', 'Cyg A', 'Cas A', 'Tau A', 'Her A', 'Hyd A', 'Sun', 'Moon', 'Jupiter', 'Saturn']
    for at in ateam:
        src = astro.Source( source=at, time=Time(args.time), location=miniarrays.nenufarloc )
        altaz = src.source.transform_to(frame)
        az, el = altaz.az.deg, altaz.alt.deg
        if el > 0:
            ax.scatter(np.radians(az), 90-el, s=150, facecolor='#d62728', edgecolor='#d62728', alpha=0.3)
            ax.text(np.radians(az), 90-el, '   '+at, color='#d62728')
    ax.set_ylim(0, 90)
    plt.savefig(args.store, dpi=200, facecolor='none',
        edgecolor='none', transparent=True, bbox_inches='tight')
    plt.close('all')

else:
    from astropy import wcs
    from astropy.io import fits

    # Save in FITS
    w = wcs.WCS(naxis=2)
    w.wcs.crpix = [skybeam.shape[1]/2., 0]
    w.wcs.cdelt = np.array([ 360./skybeam.shape[1] , 90. / skybeam.shape[0] ])
    w.wcs.crval = [180., 0.]
    w.wcs.cunit = ['deg', 'deg']
    w.wcs.ctype = ['RA---CAR', 'DEC--CAR']
    w.wcs.cname = ['Azimuth', 'Elevation']
    header = w.to_header()
    hdu1 = fits.PrimaryHDU(header=header)
    hdu1.data = skybeam
    hdus = fits.HDUList([hdu1])#, hdu2, hdu3])
    try:
        hdus.writeto(args.store, overwrite=True)
    except:
        hdus.writeto(args.store, clobber=True)
    # Region file
    ateam = ['Vir A', 'Cyg A', 'Cas A', 'Tau A', 'Her A', 'Hyd A', 'Sun', 'Moon', 'Jupiter', 'Saturn']
    with open(args.store[:-4]+'reg', 'w') as rf:
        rf.write("global color=green\n")
        for at in ateam:
            src = astro.Source( source=at, time=Time(args.time), location=miniarrays.nenufarloc )
            altaz = src.source.transform_to(frame)
            az, el = altaz.az.deg, altaz.alt.deg
            px, py = w.wcs_world2pix(az, el, 1)
            rf.write("point("+str(px)+", "+str(py)+")# point=cross text={"+at+"}\n" )

