#!/usr/bin/env python

"""

Make a map containing model objects using the info in a nemo output catalog

"""

import os
import sys
import numpy as np
import astropy.table as atpy
import astropy.io.fits as pyfits
from astLib import *
from nemo import startUp
from nemo import maps
from nemo import catalogs
import argparse
import astropy.io.fits as pyfits

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Build the command-line argument parser for the nemoModel script.

    The parser takes four positional arguments (catalog, maskFileName, beamFileName,
    outputFileName) plus optional switches controlling the observing frequency, signal
    scaling, cluster profile, CMB / white noise realizations, tiling and MPI behaviour.

    Returns:
        An argparse.ArgumentParser instance; call parse_args() on it to obtain the
        namespace used by the script body.

    """

    # Help strings are assembled with implicit string concatenation (rather than backslash
    # line continuations) so that source indentation does not end up inside the text.
    parser=argparse.ArgumentParser("nemoModel")
    parser.add_argument("catalog",
                        help = "Either the path to a Nemo FITS-table format catalog, "
                               "or 'pointsources-N' (to generate a test catalog of N point sources "
                               "that will be inserted into the map, e.g., pointsources-1000 will "
                               "insert 1000 sources). If the latter, the catalog will be written to "
                               "outputFileName_inputCatalog.fits (with the .fits extension stripped "
                               "from outputFileName).")
    parser.add_argument("maskFileName",
                        help = "A FITS image file, containing a mask of the desired sky area. "
                               "Non-zero values in the mask are used to define tiles (typically "
                               "10 x 5 deg), which are processed in parallel if MPI is enabled "
                               "(see -M switch). The output sky model image will have the same "
                               "pixelization as the mask image.")
    parser.add_argument("beamFileName",
                        help = "A file containing the beam profile, in the standard format used "
                               "by ACT.")
    parser.add_argument("outputFileName",
                        help = "The name of the output file that will contain the sky model image.")
    parser.add_argument("-f", "--frequency-GHz", dest = "obsFreqGHz", type = float, default = 150.0,
                        help = "If the nemo catalog contains SZ-selected clusters, the SZ signal "
                               "will be evaluated at the given frequency, ignoring relativistic "
                               "effects (default: 150.0).")
    parser.add_argument("-s", "--scale-signals", dest = "scale", type = float, default = 1.0,
                        help = "Scale the input y_c values of clusters in the catalog by this factor.")
    parser.add_argument("-p", "--profile", dest = "profile", type = str, default = "A10",
                        help = "For cluster models only - select the profile to use: 'A10' "
                               "(Arnaud et al. 2010 UPP models) or 'B12' (Battaglia et al. 2012 "
                               "models). The mass and redshift used to set the shape and scale of "
                               "the cluster model uses the 'template' column in the Nemo catalog, "
                               "if the 'true_M500c', 'true_fixed_y_c', and 'true_Q' columns are not "
                               "present. This requires the filter labels that appear in the "
                               "'template' column to be in the format "
                               "'Profile_M{$MASS}_z{$REDSHIFT}', with the decimal point replaced by "
                               "the letter 'p' (e.g., 'Arnaud_M2e14_z0p4', as in the example config "
                               "files for Nemo).")
    # The seed switch defined by this parser is -S/--seed, so that is what the help refers to
    parser.add_argument("-C", "--add-cmb", dest = "addCMB", action = "store_true", default = False,
                        help = "Add a realization of the cosmic microwave background to the map. "
                               "Use --seed to set the random seed to the same value if creating "
                               "sets of maps at different frequencies. Note that this is very "
                               "memory intensive for large maps, and is a serial operation.")
    parser.add_argument("-S", "--seed", dest = "seed", type = int, default = None,
                        help = "Random seed used only for generating (optional) cosmic microwave "
                               "background or source catalog realizations (i.e., the seed is not "
                               "used for noise realizations).")
    parser.add_argument("-N", "--add-white-noise", dest = "addWhiteNoise", type = float, default = 0.0,
                        help = "If a random cosmic microwave background realization had been "
                               "added, add uniform per-pixel white noise at the specified level.")
    parser.add_argument("--split-noise-test", dest = "splitNoiseTest", action = "store_true",
                        help = "If set, and -N and -C switches are used, split the map into two "
                               "sections, with the white noise level doubled in one half of the map.")
    parser.add_argument("-T", "--break-map-into-tiles", dest = "breakIntoTiles", action = "store_true",
                        default = False,
                        help = "Break large maps into tiles using the autotiler function in Nemo. "
                               "This will be turned on automatically if MPI is enabled (using -M).")
    parser.add_argument("-a", "--tcmb-alpha", dest = "TCMBAlpha", type = float, default = 0.0,
                        help = "Only applies to cluster models. Set this to a non-zero value to "
                               "generate a model where the CMB temperature varies as "
                               "T(z) = T0 * (1+z)^{1-TCMBAlpha}. Requires a 'redshift' column to be "
                               "present in the input catalog.")
    parser.add_argument("-M", "--mpi", dest = "MPIEnabled", action = "store_true", default = False,
                        help = "Enable MPI. If used, the image will be broken into a number of "
                               "tiles, with one tile per process. If you want to use this, run "
                               "with e.g., mpiexec -np 4 nemoModel args -M")
    parser.add_argument("-n", "--no-strict-errors", dest = "noStrictMPIExceptions",
                        action = "store_true", default = False,
                        help = "Disable strict exception handling (applies under MPI only, i.e., "
                               "must be used with the -M switch). If you use this option, you will "
                               "get the full traceback when a Python Exception is triggered, but "
                               "the code may not terminate. This is due to the Exception handling "
                               "in mpi4py.")

    return parser

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Script entry point: parse the command line, build model images tile-by-tile from the
    # catalog, collect the tiles on rank 0 (when MPI is enabled), paste them into a map with
    # the pixelization given by config.origWCS, optionally add CMB / noise, and write the result.

    parser=makeParser()
    args=parser.parse_args()

    # The -n switch turns strict MPI exception handling off
    strictMPIExceptions=not args.noStrictMPIExceptions

    # File names of the optional side products (test input catalog, inverse variance map) are
    # derived from the output file name. If that name contains no '.fits', str.replace would
    # hand back the name unchanged and the side product would then be overwritten by the sky
    # model map itself - so in that case the suffix is appended instead.
    if ".fits" in args.outputFileName:
        inputCatalogFileName=args.outputFileName.replace(".fits", "_inputCatalog.fits")
        ivarFileName=args.outputFileName.replace(".fits", ".ivar.fits")
    else:
        inputCatalogFileName=args.outputFileName+"_inputCatalog.fits"
        ivarFileName=args.outputFileName+".ivar.fits"

    # Create a stub config (then we can use existing start-up stuff to dish out the tiles)
    print(">>> Setting up ...")
    mapDict={'mapFileName': args.maskFileName,
             'obsFreqGHz': args.obsFreqGHz,
             'beamFileName': args.beamFileName}
    parDict={'unfilteredMaps': [mapDict],
             'mapFilters': []}
    if args.MPIEnabled or args.breakIntoTiles:
        # Tiling is defined automatically from the mask
        parDict['makeTileDir']=True
        parDict['tileOverlapDeg']=1.0
        parDict['tileDefLabel']='auto'
        parDict['tileDefinitions']={'mask': args.maskFileName,
                                    'targetTileWidthDeg': 10.0, 'targetTileHeightDeg': 5.0}

    config=startUp.NemoConfig(parDict, MPIEnabled = args.MPIEnabled, divideTilesByProcesses = True,
                              makeOutputDirs = False, setUpMaps = True, writeTileDir = False,
                              verbose = False, strictMPIExceptions = strictMPIExceptions)

    # Make sure the directory that will hold the output exists
    baseDir, fileName=os.path.split(args.outputFileName)
    if baseDir != '':
        os.makedirs(baseDir, exist_ok = True)

    if args.catalog.startswith('pointsources'):
        # Test mode: catalog argument is 'pointsources-N'
        # NOTE(review): under MPI every rank executes this branch (and writes the same catalog
        # file); if no seed is given the ranks may draw different catalogs - confirm whether
        # the catalog should be generated on rank 0 only and broadcast.
        try:
            numSources=int(args.catalog.split("-")[-1])
        except ValueError as err:
            raise Exception("Use format pointsources-N, e.g., pointsources-100 will generate a test catalog of 100 sources.") from err
        with pyfits.open(args.maskFileName) as img:
            # Use the first HDU that actually holds data (iterating an HDUList yields HDUs)
            for hdu in img:
                data=hdu.data
                if data is not None:
                    break
            wcs=astWCS.WCS(hdu.header, mode = 'pyfits')
            if numSources > 0:
                tab=catalogs.generateRandomSourcesCatalog(data, wcs, numSources, seed = args.seed)
                tab.write(inputCatalogFileName, overwrite = True)
            else:
                # No sources requested - an empty catalog gives an empty model
                tab=atpy.Table(names = ('RADeg', 'decDeg'))
    else:
        tab=atpy.Table.read(args.catalog)

    # Optional signal scaling (useful for diff alpha sims)
    if 'y_c' in tab.keys():
        tab['y_c']=tab['y_c']*args.scale

    # Build a dictionary containing the model images which we'll stitch together at the end
    print(">>> Building models in tiles ...")
    modelsDict={}
    for tileName in config.tileNames:
        print("... %s ..." % (tileName))
        tileInfo=config.tileCoordsDict[tileName]
        # clippedSection is ordered (minX, maxX, minY, maxY); image shape is (rows, columns)
        minX, maxX, minY, maxY=tileInfo['clippedSection']
        shape=(maxY-minY, maxX-minX)
        wcs=astWCS.WCS(tileInfo['header'], mode = 'pyfits')
        try:
            modelsDict[tileName]=maps.makeModelImage(shape, wcs, tab, args.beamFileName,
                                                     obsFreqGHz = args.obsFreqGHz,
                                                     validAreaSection = tileInfo['areaMaskInClipSection'],
                                                     TCMBAlpha = args.TCMBAlpha,
                                                     profile = args.profile)
        except Exception as err:
            # Report which tile failed, keeping the original error chained as the cause
            raise Exception("makeModelImage failed on tile '%s'" % (tileName)) from err

    # Gathering
    # We can't just gather as that triggers a 32-bit overflow (?)
    # So, sending one object at a time
    if config.MPIEnabled == True:
        config.comm.barrier()
        if config.rank > 0:
            print("... rank = %d sending sky model tiles ..." % (config.rank))
            config.comm.send(modelsDict, dest = 0)
        elif config.rank == 0:
            print("... gathering sky model tiles ...")
            # Merge the tiles from every other rank into rank 0's own dictionary
            for source in range(1, config.size):
                modelsDict.update(config.comm.recv(source = source))

    # Stitching
    print(">>> Stitching tiles ...")
    if config.rank == 0:
        wcs=config.origWCS
        d=np.zeros([wcs.header['NAXIS2'], wcs.header['NAXIS1']])
        for tileName, tileModel in modelsDict.items():
            print("... %s ..." % (tileName))
            minX, maxX, minY, maxY=config.tileCoordsDict[tileName]['clippedSection']
            # Tiles with no model image (None) contribute nothing
            if tileModel is not None:
                d[minY:maxY, minX:maxX]=d[minY:maxY, minX:maxX]+tileModel
        # Optionally add CMB
        if args.addCMB:
            d=d+maps.simCMBMap(d.shape, wcs, noiseLevel = args.addWhiteNoise,
                               beamFileName = args.beamFileName, seed = args.seed)
            # For testing abrupt noise changes
            if args.splitNoiseTest:
                halfRows=d.shape[0]//2
                d[:halfRows]=d[:halfRows]+np.random.normal(0, 2*args.addWhiteNoise, [halfRows, d.shape[1]])
                # Doesn't really matter about absolute scaling of this
                weights=np.ones(d.shape)*args.addWhiteNoise
                weights[:halfRows]=weights[:halfRows]*2
                weights=np.power(weights, -2)
                astImages.saveFITS(ivarFileName, weights, wcs)
        astImages.saveFITS(args.outputFileName, d, wcs)