#!/usr/bin/env python

"""

Make a FITS image mask (with PLIO_1 compression) from a SAOImage DS9 region file.

"""

import os
import sys
import numpy as np
import astropy.table as atpy
import astropy.io.fits as pyfits
from astLib import *
import nemo
from nemo import maps, catalogs
import argparse
import astropy.io.fits as pyfits
import time

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Build the command-line argument parser for the nemoMask script.

    Returns:
        An argparse.ArgumentParser that defines two positional arguments
        (regionFileName, templateMapFileName), the optional -o/--output argument
        (stored as outFileName, None when not given), and -v/--version, which
        reports nemo.__version__.

    """

    parser=argparse.ArgumentParser("nemoMask")

    # Help texts are assembled from adjacent string literals rather than backslash continuations
    # inside a single literal, so the strings do not carry the source indentation with them.
    parser.add_argument("regionFileName",
                        help = "DS9 region file used to define the mask. This must contain "
                               "polygon regions, with coordinates given in decimal degrees.")
    parser.add_argument("templateMapFileName",
                        help = "A FITS image file, from which the pixelization and "
                               "coordinate system of the output mask will be set. Usually this "
                               "would be the map that you want to apply a mask to.")
    parser.add_argument("-o", "--output", dest = "outFileName", default = None,
                        help = "The name of the file for the output mask (FITS format, PLIO_1 "
                               "compression). If not given, regionFileName will be used "
                               "(replacing the extension .reg with .fits)")
    # '%(prog)s' is left for argparse to expand; only the version number is substituted here.
    parser.add_argument("-v", "--version", action = 'version',
                        version = '%(prog)s' + ' %s' % (nemo.__version__))

    return parser

#------------------------------------------------------------------------------------------------------------
def makeDefaultOutFileName(regionFileName):
    """Derive the default output mask file name from the region file name.

    Only a trailing '.reg' extension is swapped for '.fits'. Any other name gets '.fits' appended,
    so the result is never identical to the input path (which would make the script overwrite the
    region file) and '.reg' appearing elsewhere in the path (e.g. in a directory name) is untouched.

    Args:
        regionFileName (str): Path to the DS9 region file.

    Returns:
        The path (str) to use for the output FITS mask.

    """
    regExt=".reg"
    if regionFileName.endswith(regExt):
        return regionFileName[:-len(regExt)]+".fits"
    return regionFileName+".fits"

#------------------------------------------------------------------------------------------------------------
def selectTemplateHDU(hduList):
    """Pick the HDU that defines the pixelization of the output mask.

    An HDU qualifies if its data attribute is not None and has at least two dimensions (so empty
    primary HDUs and 1d/table data are skipped). If several HDUs qualify, the last one is used,
    which is the choice the original loop made.

    Args:
        hduList: An iterable of objects with a data attribute (e.g., an opened FITS file).

    Returns:
        The selected HDU object.

    Raises:
        ValueError: If no HDU with image data is found.

    """
    selected=None
    for hdu in hduList:
        data=hdu.data
        if data is not None and getattr(data, "ndim", 0) >= 2:
            selected=hdu
    if selected is None:
        raise ValueError("no extension containing image data (2 or more dimensions) was found")
    return selected

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    parser=makeParser()
    args=parser.parse_args()

    if args.outFileName is None:
        outFileName=makeDefaultOutFileName(args.regionFileName)
    else:
        outFileName=args.outFileName

    # The WCS and shape are taken from the template while the file is still open
    with pyfits.open(args.templateMapFileName) as img:
        try:
            templateHDU=selectTemplateHDU(img)
        except ValueError as err:
            # Report an unusable template cleanly instead of failing later with a NameError
            parser.error("%s: %s" % (args.templateMapFileName, err))
        wcs=astWCS.WCS(templateHDU.header, mode = 'pyfits')
        shape=templateHDU.data.shape[-2:]

    surveyMask=maps.makeMaskFromDS9PolyRegionFile(args.regionFileName, shape, wcs)
    maps.saveFITS(outFileName, np.array(surveyMask, dtype = np.uint8), wcs,
                  compressionType = 'PLIO_1')
