#!/usr/bin/env python

"""

Make a FITS image mask (with PLIO_1 compression) from either a SAOImage DS9 region file (polygon regions),
or a FITS catalog of circular regions (columns RADeg, decDeg, rArcmin).

"""

import os
import sys
import numpy as np
import astropy.table as atpy
import astropy.io.fits as pyfits
from astLib import *
import nemo
from nemo import maps, catalogs
import argparse
import astropy.io.fits as pyfits
import time

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Build the command-line argument parser for the ``nemoMask`` script.

    The parser takes two positional arguments (``inputFileName`` and ``templateMapFileName``), plus the
    optional ``-o/--output``, ``-i/--invert-mask`` and ``-v/--version`` switches.

    Returns:
        The configured :class:`argparse.ArgumentParser` object.

    """

    parser=argparse.ArgumentParser("nemoMask")

    # Help strings are built from adjacent literals (rather than backslash continuations inside a single
    # literal), so that source indentation does not end up embedded in the help text.
    parser.add_argument("inputFileName",
                        help = "Either: "
                               "(i) a DS9 region file (with extension ``.reg``) used to define the mask "
                               "(this must contain only polygon regions, with coordinates given in "
                               "decimal degrees); "
                               "(ii) a FITS catalog containing the columns ``RADeg``, ``decDeg``, "
                               "``rArcmin`` (use this to make a mask of flagged circular regions, for "
                               "use with, e.g., the ``postFlags`` configuration file parameter)")
    parser.add_argument("templateMapFileName",
                        help = "A FITS image file, from which the pixelization and coordinate system of "
                               "the output mask will be set. Usually this would be the map that you want "
                               "to apply a mask to.")
    parser.add_argument("-o", "--output", dest = "outFileName", default = None,
                        help = "The name of the file for the output mask (FITS format, PLIO_1 "
                               "compression). If not given, for inputFileName (i) we replace the "
                               "extension ``.reg`` with ``.fits``; (ii) we prepend ``mask_`` to the "
                               "inputFileName.")
    # This is a plain on/off switch (store_true) - it takes no value on the command line.
    parser.add_argument("-i", "--invert-mask", dest = "invertMask", action = "store_true", default = False,
                        help = "If given, this will make the mask have value 0 inside the regions "
                               "defined by inputFileName, and 1 outside. This is the opposite of the "
                               "default behavior (value 1 inside the defined regions, 0 outside)")
    parser.add_argument("-v", "--version", action = 'version',
                        version = '%(prog)s' + ' %s' % (nemo.__version__))

    return parser

#------------------------------------------------------------------------------------------------------------
def makeDefaultOutFileName(inputFileName):
    """Return the default output mask file name for the given input file name.

    For a DS9 region file (extension ``.reg``), only the extension is swapped for ``.fits``. Any other
    input is treated as a catalog: ``mask_`` is prepended to the file name, its extension is replaced with
    ``.fits``, and the absolute path of the directory containing the input file is kept.

    Args:
        inputFileName (str): Path to the region file or catalog given on the command line.

    Returns:
        The default output file name (str).

    """
    root, extension=os.path.splitext(inputFileName)
    if extension == '.reg':
        # Swap only the trailing extension - a plain str.replace would also rewrite any other '.reg'
        # occurring in the path (e.g., a directory called 'my.regions').
        return root+".fits"
    outDir, baseName=os.path.split(os.path.abspath(inputFileName))
    return os.path.join(outDir, os.path.splitext("mask_"+baseName)[0]+".fits")

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    parser=makeParser()
    args=parser.parse_args()

    t0=time.time()

    if args.outFileName is None:
        outFileName=makeDefaultOutFileName(args.inputFileName)
    else:
        outFileName=args.outFileName

    # Polygon region mask
    if os.path.splitext(args.inputFileName)[-1] == '.reg':

        wcs=None
        shape=None
        with pyfits.open(args.templateMapFileName) as img:
            # There is no break here: if several extensions contain data, the last one sets the
            # pixelization and WCS of the output mask.
            for hdu in img:
                if hdu.data is not None:
                    wcs=astWCS.WCS(hdu.header, mode = 'pyfits')
                    shape=hdu.data.shape[-2:]
        if shape is None:
            raise Exception("Did not find any image data in %s - cannot set the pixelization of the mask."
                            % (args.templateMapFileName))

        mask=maps.makeMaskFromDS9PolyRegionFile(args.inputFileName, shape, wcs)

    # Circular flagged regions from a catalog
    else:

        # Only the shape and WCS of the template are needed, so the loaded array is dropped straight away
        areaMask, wcs=maps.chunkLoadMask(args.templateMapFileName)
        shape=areaMask.shape[-2:]
        del areaMask

        tab=atpy.Table.read(args.inputFileName)
        missingCols=[colName for colName in ('RADeg', 'decDeg', 'rArcmin') if colName not in tab.keys()]
        if missingCols:
            raise Exception("Did not find column(s) %s - these are needed for making masks from FITS catalogs."
                            % (", ".join(["'%s'" % (colName) for colName in missingCols])))

        # Paint regions
        tab=catalogs.getCatalogWithinImage(tab, shape, wcs)
        mask=np.zeros(shape, dtype = np.uint8)
        rDegMap=np.ones(shape, dtype = np.float32)*1000 # No point going 16 bit here, we get 32 back anyway
        for row in tab:
            maskRadiusArcmin=row['rArcmin']
            # The distance map is requested out to twice the mask radius (in degrees); xBounds, yBounds
            # are then used as the pixel ranges of the cut-out in which the radius test is applied.
            rDegMap, xBounds, yBounds=maps.makeDegreesDistanceMap(rDegMap, wcs,
                                                                  row['RADeg'], row['decDeg'],
                                                                  (maskRadiusArcmin*2)/60)
            yMin, yMax=yBounds[0], yBounds[1]
            xMin, xMax=xBounds[0], xBounds[1]
            selection=rDegMap[yMin:yMax, xMin:xMax] < maskRadiusArcmin/60.0
            mask[yMin:yMax, xMin:xMax][selection]=1

    # Swap 0 <-> 1 if the user asked for the regions themselves to be zeroed
    if args.invertMask:
        mask=1-mask

    # Output
    maps.saveFITS(outFileName, np.array(mask, dtype = np.uint8), wcs,
                  compressionType = 'PLIO_1')

    print("Time taken: %.1f sec" % (time.time()-t0))
