#! /usr/bin/python
# -*- coding: utf-8 -*-

"""Command-line Python 3 script to plot the UV coverage of a Measurement Set."""

import argparse
import os

import matplotlib.pyplot as plt
import numpy as np

# The casacore table bindings are allowed to be missing at import time: a
# warning is printed and ``table`` is bound to ``None`` so later code can test
# for availability instead of failing with a NameError.
try:
    from pyrap.tables import table
except ImportError:
    # Only a failed import is tolerated here; any other exception (including
    # KeyboardInterrupt) propagates instead of being silently swallowed.
    table = None
    print("\n\t=== WARNING: Pyrap module not found ===")

# Speed of light in vacuum in m/s (exact by SI definition); used to convert
# baseline lengths from metres to wavelengths.
SPEED_OF_LIGHT = 299792458.0


def parse_args(argv=None):
    """Parse the command-line options of the script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; ``None`` makes argparse read ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        With attributes ``msname``, ``savefile``, ``select`` and
        ``wavelength`` (an integer used as a flag: non-zero enables
        wavelength units).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--msname', type=str, help="Measurement Set", required=True)
    parser.add_argument('-f', '--savefile', type=str, help="Save file", default='UV_Coverage.pdf', required=False)
    parser.add_argument('-s', '--select', type=str, help="Any pyrap.table.query", default=None, required=False)
    parser.add_argument('-w', '--wavelength', type=int, help="Use wavelength units", default=0, required=False)
    return parser.parse_args(argv)


def read_columns(tablepath, colnames, select=None):
    """Read whole columns from a casacore table, then close the table.

    Parameters
    ----------
    tablepath : str
        Path of the table to open read-only.
    colnames : list of str
        Names of the columns to read.
    select : str, optional
        Query string passed to the table's ``query`` method; when given, the
        columns are read from the query result instead of the full table.

    Returns
    -------
    list
        One entry per requested column, in the order of ``colnames``.

    Raises
    ------
    ImportError
        If ``pyrap.tables.table`` is not available in this module.
    """
    # The module-level pyrap import is allowed to fail, so ``table`` may be
    # undefined or None here; look it up defensively to give a clear error.
    open_table = globals().get('table')
    if open_table is None:
        raise ImportError(
            "\t=== pyrap.tables is required to read '{}' ===".format(tablepath))

    tab = open_table(tablepath, ack=False, readonly=True)
    try:
        rows = tab.query(select) if select else tab
        try:
            return [rows.getcol(name) for name in colnames]
        finally:
            # The query result is a separate table object that needs its own
            # close(); the parent table is closed by the outer ``finally``.
            if rows is not tab:
                rows.close()
    finally:
        tab.close()


def uvw_to_wavelengths(uvw, data_desc_id, spw_id, chan_freq):
    """Convert UVW coordinates from metres to wavelengths.

    Every row is scaled with a single frequency: the mean channel frequency of
    the spectral window that row belongs to. Using one value per row keeps
    the result the same shape as ``uvw``.

    Parameters
    ----------
    uvw : array_like, shape (nrow, ncoord)
        Baseline coordinates in metres, one row per visibility.
    data_desc_id : array_like of int, shape (nrow,)
        Index into ``spw_id`` for every row of ``uvw``.
    spw_id : array_like of int
        Spectral-window index for every data-description entry.
    chan_freq : array_like, shape (nspw, nchan)
        Channel frequencies in Hz for every spectral window.

    Returns
    -------
    numpy.ndarray
        ``uvw`` divided by the per-row wavelength.
    """
    spw_freq = np.asarray(chan_freq, dtype=float).mean(axis=-1)
    row_freq = spw_freq[np.asarray(spw_id)[np.asarray(data_desc_id)]]
    # uvw / wavelength == uvw * frequency / c
    return np.asarray(uvw, dtype=float) * (row_freq / SPEED_OF_LIGHT)[:, np.newaxis]


def mirror_uv(uvw):
    """Return the (u, v) points together with their mirrored counterparts.

    Each row contributes both (u, v) and (-u, -v). Points with u == 0 *and*
    v == 0 (auto-correlations) are dropped; points lying on only one of the
    axes are kept.

    Parameters
    ----------
    uvw : array_like, shape (nrow, ncoord)
        Column 0 is read as u, column 1 as v.

    Returns
    -------
    tuple of numpy.ndarray
        The 1-D arrays ``(u, v)`` of the remaining points.
    """
    uvw = np.asarray(uvw)
    u = np.concatenate((-uvw[:, 0], uvw[:, 0]))
    v = np.concatenate((-uvw[:, 1], uvw[:, 1]))
    keep = (u != 0) | (v != 0)
    return u[keep], v[keep]


def plot_uv_density(u, v, savefile, in_wavelengths=False):
    """Save a log-scaled hexagonal density plot of the UV points.

    Parameters
    ----------
    u, v : numpy.ndarray
        Non-empty 1-D coordinate arrays of equal length.
    savefile : str
        Output file name handed to ``Figure.savefig``.
    in_wavelengths : bool, optional
        Label the axes in wavelengths instead of metres.
    """
    unit = '$\\lambda$' if in_wavelengths else 'm'
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        hb = ax.hexbin(u, v, gridsize=400, mincnt=1, bins='log', cmap='Blues', edgecolor='none')
        ax.margins(0)  # flagged as "must have!" by the original author
        ax.axis([u.min(), u.max(), v.min(), v.max()])
        ax.set_title('UV Coverage')
        ax.set_xlabel('u ({})'.format(unit))
        ax.set_ylabel('v ({})'.format(unit))
        cb = fig.colorbar(hb, ax=ax)
        cb.set_label('log$_{10}$(density)')
        fig.savefig(savefile)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)


def main(argv=None):
    """Plot the UV coverage of the Measurement Set named on the command line.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments; ``None`` uses ``sys.argv[1:]``.

    Raises
    ------
    IOError
        If the Measurement Set directory does not exist.
    ImportError
        If the pyrap table bindings are unavailable.
    ValueError
        If the (selected) data contain no point away from the UV origin.
    """
    args = parse_args(argv)

    if not os.path.isdir(args.msname):
        raise IOError("\t=== MS '{}' not found ===".format(args.msname))

    # ------ Read UV ------ #
    if args.wavelength:
        uvw, data_desc_id = read_columns(
            args.msname, ['UVW', 'DATA_DESC_ID'], args.select)
        # ------ Convert from m to wavelength ------ #
        (chan_freq,) = read_columns(
            os.path.join(args.msname, 'SPECTRAL_WINDOW'), ['CHAN_FREQ'])
        (spw_id,) = read_columns(
            os.path.join(args.msname, 'DATA_DESCRIPTION'), ['SPECTRAL_WINDOW_ID'])
        uvw = uvw_to_wavelengths(uvw, data_desc_id, spw_id, chan_freq)
    else:
        (uvw,) = read_columns(args.msname, ['UVW'], args.select)

    u, v = mirror_uv(uvw)
    if u.size == 0:
        raise ValueError("\t=== No cross-correlation UV points to plot ===")

    # ------ Make a density plot ------ #
    plot_uv_density(u, v, args.savefile, bool(args.wavelength))


if __name__ == "__main__":
    main()