#! /usr/bin/env python3

"""

Command line tool for applying primary beam correction to WSClean images

Gets the frequency to use from the image header

"""

import os
import sys
from katbeam import JimBeam
import astropy.io.fits as pyfits
import astropy.stats as apyStats
from astLib import *
import numpy as np
import nemoCython
import argparse

#WCSKeys=['NAXIS', 'CTYPE', 'CRPIX', 'CRVAL', 'CDELT', 'CUNIT']

if __name__ == '__main__':
    # Script entry point. Steps:
    #   1. parse the command line;
    #   2. read the image and locate its STOKES and FREQ axes;
    #   3. evaluate the beam model on a grid of offsets from the image centre
    #      and divide the first image plane by it;
    #   4. blank pixels where the beam response is below the threshold and,
    #      optionally, trim the image to the remaining valid region;
    #   5. print two noise estimates and write the result next to the input.

    parser = argparse.ArgumentParser("mkat_primary_beam_correct")
    parser.add_argument("imageFileName", help="""A .fits image to correct. Assumes image has NAXIS=4,
                        with frequency in NAXIS3 (e.g., WSClean continuum images).""")
    parser.add_argument("-t", "--threshold", dest="thresh", help="""Threshold below which image pixels
                        will be set to blank values (nan). Use to remove areas where the primary beam
                        correction is large.""", default = 0.3, type = float)
    parser.add_argument("-T", "--trim", dest="trim", help="""Trim image outside valid region (set by 
                        --threshold) to reduce size.""", default = False, action = 'store_true')
    args=parser.parse_args()

    inFileName=args.imageFileName
    thresh=args.thresh
    trim=args.trim

    # The output is written to the same directory as the input, with a prefix
    # on the file name that records which options were applied
    prefix="pbcorr"
    if trim:
        prefix=prefix+"_trim"
    outDir=os.path.dirname(os.path.abspath(inFileName))
    outFileName=os.path.join(outDir, prefix+"_"+os.path.basename(inFileName))

    with pyfits.open(inFileName) as img:
        # Order depends on if e.g. CASA or WSClean image
        header=img[0].header
        polAxis=None
        freqAxis=None
        for i in range(1, 5):
            # .get() rather than [] so that an image with fewer than 4 axes
            # reaches the explicit check below instead of raising KeyError
            axisType=header.get('CTYPE%d' % (i))
            if axisType == 'STOKES':
                polAxis=i
            elif axisType == 'FREQ':
                freqAxis=i
        # Explicit checks rather than assert, which is skipped under python -O
        if polAxis is None or freqAxis is None:
            sys.exit("%s: could not find both a STOKES and a FREQ axis in CTYPE1-4 of the primary header"
                     % (inFileName))
        # Everything below works on data[0, 0], i.e. the first plane along FITS
        # axes 4 and 3, so the two non-sky axes must be the last two FITS axes
        if polAxis < 3 or freqAxis < 3:
            sys.exit("%s: STOKES and FREQ must be FITS axes 3 and 4 (found STOKES = %d, FREQ = %d)"
                     % (inFileName, polAxis, freqAxis))
        shape=img[0].data[0, 0].shape
        wcs=astWCS.WCS(img[0].header, mode = 'pyfits').copy()
        # Header value is divided by 1e6 to give MHz (i.e. it is taken to be in Hz)
        freqMHz=header['CRVAL%d' % (freqAxis)]/1e6
        wcs.updateFromHeader()
        print("Frequency = %.3f MHz" % (freqMHz))
        print("Assumed L-band") # We can change this based on freq or command line option in future
        beam=JimBeam('MKAT-AA-L-JIM-2020')
        # The beam is evaluated on maps of x, y offset (degrees) from the image centre
        RADeg, decDeg=wcs.getCentreWCSCoords()
        maxRDeg=4.0
        xDegMap, yDegMap=nemoCython.makeXYDegreesDistanceMaps(np.ones(shape, dtype = np.float64), wcs, RADeg, decDeg, maxRDeg)
        I=beam.I(xDegMap, yDegMap, freqMHz)
        # Correct first, then blank wherever the beam response is below threshold
        img[0].data[0, 0]=img[0].data[0, 0]/I
        img[0].data[0, 0][I < thresh]=np.nan
        if trim:
            # Assumes N is at the top, E at the left
            y, x=np.where(I >= thresh)
            if y.size == 0:
                # Otherwise y.min() below fails with an unhelpful ValueError
                sys.exit("%s: no pixels have primary beam response >= %s - nothing left to trim to"
                         % (inFileName, thresh))
            yMin, yMax=y.min(), y.max()
            xMin, xMax=x.min(), x.max()
            _, decMin=wcs.pix2wcs(xMax, yMin)
            _, decMax=wcs.pix2wcs(xMin, yMax)
            # RA limits are read off the row with the largest |dec|
            # NOTE(review): presumably because the RA range covered by the
            # bounding box is widest on that row - confirm
            decMinMax=np.array([decMin, decMax])
            yDecAbsMax=np.array([yMin, yMax])[np.argmax(abs(decMinMax))]
            RAMin, _=wcs.pix2wcs(xMax, yDecAbsMax)
            RAMax, _=wcs.pix2wcs(xMin, yDecAbsMax)
            clipDict=astImages.clipUsingRADecCoords(img[0].data[0, 0], wcs, RAMin, RAMax, decMin, decMax)
            # Rebuild the 4-axis data array around the clipped plane; the whole
            # header is replaced by the one carried by the clipped WCS (rather
            # than copying individual WCS keywords across)
            img[0].data=np.zeros([1, 1, clipDict['data'].shape[0], clipDict['data'].shape[1]])
            img[0].data[0, 0]=clipDict['data']
            img[0].header=clipDict['wcs'].header
        # We may as well report RMS while we're at it
        d=img[0].data[~np.isnan(img[0].data)]
        # The clipping window is centred on the mean of all unblanked pixels and
        # that centre is not updated between iterations, so compute it once
        dMean=d.mean()
        sigma=1e6
        for i in range(10):
            mask=np.logical_and(np.greater(d, dMean-3*sigma), np.less(d, dMean+3*sigma))
            sigma=np.std(d[mask])
        # The factor 1e6 converts to the printed micro-Jy
        # NOTE(review): this assumes the image values are in Jy - confirm
        print("3-sigma clipped stdev image RMS estimate = %.1f uJy" % (sigma*1e6))
        sbi=apyStats.biweight_scale(d, c = 9.0, modify_sample_size = True)
        print("Biweight scale image RMS estimate = %.1f uJy" % (sbi*1e6))
        img.writeto(outFileName, overwrite = True)
        
