#!/usr/bin/env python

"""

Make a map containing model objects using the info in a nemo output catalog

"""

import os
import sys
import numpy as np
import astropy.table as atpy
import astropy.io.fits as pyfits
from astLib import *
from nemo import startUp, maps, catalogs, signals
import argparse
import astropy.io.fits as pyfits
on_rtd=os.environ.get('READTHEDOCS', None)
if on_rtd is None:
    import pyccl as ccl

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Build the command line argument parser for the nemoModel script.

    Returns:
        An argparse.ArgumentParser with four positional arguments (catalog, maskFileName,
        beamFileName, outputFileName) and the optional switches that control the frequency,
        signal scaling, cluster profile, CMB and noise realizations, extra map addition,
        tiling, and MPI behaviour.

    """

    parser=argparse.ArgumentParser("nemoModel")
    parser.add_argument("catalog",
                        help = "Either the path to a Nemo FITS-table format catalog, "
                               "or 'pointsources-N' (to generate a test catalog of N point sources that "
                               "will be inserted into the map, e.g., pointsources-1000 will insert 1000 "
                               "sources). If the latter, the catalog will be written to "
                               "outputFileName_inputCatalog.fits (with the .fits extension stripped from "
                               "outputFileName). If the former, cosmological parameters may be specified "
                               "in the FITS header using the OM0, OB0, H0, SIGMA8, NS keywords (used only "
                               "for cluster models at the moment).")
    parser.add_argument("maskFileName",
                        help = "A FITS image file, containing a mask of the desired sky area. Non-zero "
                               "values in the mask are used to define tiles (typically 10 x 5 deg), which "
                               "are processed in parallel if MPI is enabled (see -M switch). The output "
                               "sky model image will have the same pixelization as the mask image.")
    parser.add_argument("beamFileName",
                        help = "A file containing the beam profile, in the standard format used by ACT.")
    parser.add_argument("outputFileName",
                        help = "The name of the output file that will contain the sky model image.")
    parser.add_argument("-f", "--frequency-GHz", dest = "obsFreqGHz", type = float, default = 150.0,
                        help = "If the nemo catalog contains SZ-selected clusters, the SZ signal will be "
                               "evaluated at the given frequency, ignoring relativistic effects "
                               "(default: 150.0).")
    parser.add_argument("-s", "--scale-signals", dest = "scale", type = float, default = 1.0,
                        help = "Scale the input y_c values of clusters in the catalog by this factor.")
    parser.add_argument("-p", "--profile", dest = "profile", type = str, default = "A10",
                        help = "For cluster models only - select the profile to use: 'A10' "
                               "(Arnaud et al. 2010 UPP models) or 'B12' (Battaglia et al. 2012 models). "
                               "The mass and redshift used to set the shape and scale of the cluster "
                               "model uses the 'template' column in the Nemo catalog, if the "
                               "'true_M500c', 'true_fixed_y_c', and 'true_Q' columns are not present. "
                               "This requires the filter labels that appear in the 'template' column to "
                               "be in the format 'Profile_M{$MASS}_z{$REDSHIFT}', with the decimal point "
                               "replaced by the letter 'p' (e.g., 'Arnaud_M2e14_z0p4', as in the example "
                               "config files for Nemo).")
    # NOTE: the seed switch is -S / --seed (there is no separate CMB-only seed option)
    parser.add_argument("-C", "--add-cmb", dest = "addCMB", action = "store_true", default = False,
                        help = "Add a realization of the cosmic microwave background to the map. Use "
                               "--seed to set the random seed to the same value if creating sets of maps "
                               "at different frequencies. Note that this is very memory intensive for "
                               "large maps, and is a serial operation.")
    parser.add_argument("-S", "--seed", dest = "seed", type = int, default = None,
                        help = "Random seed used only for generating (optional) cosmic microwave "
                               "background or source catalog realizations (i.e., the seed is not used "
                               "for noise realizations).")
    # No type is set here on purpose: the value may be a number, a number with an 'sb' suffix,
    # or a file path, and is interpreted by the caller
    parser.add_argument("-N", "--add-noise", dest = "addNoise", default = 0.0,
                        help = "If a random cosmic microwave background realization had been added, add "
                               "white noise. If a number is given, uniform per-pixel white noise at the "
                               "specified level is added. If a number followed by 'sb' is given "
                               "(e.g., 40sb), then constant surface brightness noise (per square arcmin) "
                               "is added, adjusting for varying pixel scale across the map (if present). "
                               "Otherwise, a path to an inverse variance map can be given (must have the "
                               "same pixelization as the supplied mask).")
    parser.add_argument("-k", "--lknee", dest = "lKnee", default = None, type = float,
                        help = "Must be used with -N (--add-noise) and -C (--add-cmb). If given, "
                               "1/f noise with power spectrum N_l = (1 + l/lknee)^-3 will be generated "
                               "and added to the map (see Appendix A of MacCrann et al. 2023). "
                               "Reasonable values to use are 2000 for ACT f090, and 3000 for ACT f150.")
    parser.add_argument("-A", "--add-map", dest = "addMap", default = None,
                        help = "Path to a FITS map (same pixelization as the input mask) that will be "
                               "added to the output sim map. Useful if you want to add extra "
                               "pre-computed signals to the map (e.g., Galactic dust, large scale noise "
                               "components etc.). Amplitudes can be modified with the --add-map-scaling "
                               "parameter.")
    # Parsed as float so that a non-numeric value is rejected at the command line
    parser.add_argument("--add-map-scaling", dest = "addMapScaling", type = float, default = 1.0,
                        help = "If given, multiply the map pointed to by --add-map by this factor.")
    parser.add_argument("--split-noise-test", dest = "splitNoiseTest", action = "store_true",
                        help = "If set, and -N and -C switches are used, split the map into two "
                               "sections, with the white noise level doubled in one half of the map.")
    parser.add_argument("-T", "--break-map-into-tiles", dest = "breakIntoTiles", action = "store_true",
                        default = False,
                        help = "Break large maps into tiles using the autotiler function in Nemo. "
                               "This will be turned on automatically if MPI is enabled (using -M).")
    parser.add_argument("-a", "--tcmb-alpha", dest = "TCMBAlpha", type = float, default = 0.0,
                        help = "Only applies to cluster models. Set this to a non-zero value to generate "
                               "a model where the CMB temperature varies as "
                               "T(z) = T0 * (1+z)^{1-TCMBAlpha}. Requires a 'redshift' column to be "
                               "present in the input catalog.")
    parser.add_argument("-M", "--mpi", dest = "MPIEnabled", action = "store_true", default = False,
                        help = "Enable MPI. If used, the image will be broken into a number of tiles, "
                               "with one tile per process. If you want to use this, run with e.g., "
                               "mpiexec -np 4 nemoModel args -M")
    parser.add_argument("-n", "--no-strict-errors", dest = "noStrictMPIExceptions",
                        action = "store_true", default = False,
                        help = "Disable strict exception handling (applies under MPI only, i.e., must "
                               "be used with the -M switch). If you use this option, you will get the "
                               "full traceback when a Python Exception is triggered, but the code may "
                               "not terminate. This is due to the Exception handling in mpi4py.")

    return parser

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Script entry point: paints model objects from a catalog into tiles, gathers the tiles on
    # rank 0, stitches them into one map, optionally adds CMB + noise and/or an extra map, and
    # writes the result to args.outputFileName.

    parser=makeParser()
    args=parser.parse_args()

    # Strict MPI exception handling is on unless explicitly disabled with -n
    strictMPIExceptions=not args.noStrictMPIExceptions

    # Explicit check rather than an assert, so it still runs under python -O
    if args.addMap is not None and not os.path.exists(args.addMap):
        raise FileNotFoundError("Map given with --add-map not found: %s" % (args.addMap))

    # Create a stub config (then we can use existing start-up stuff to dish out the tiles)
    print(">>> Setting up ...")
    parDict={}
    mapDict={}
    mapDict['mapFileName']=args.maskFileName
    mapDict['obsFreqGHz']=args.obsFreqGHz
    mapDict['beamFileName']=args.beamFileName
    parDict['unfilteredMaps']=[mapDict]
    parDict['mapFilters']=[]
    parDict['useTiling']=False
    parDict['reprojectToTan']=False
    if args.MPIEnabled == True or args.breakIntoTiles == True:
        parDict['useTiling']=True
        parDict['tileOverlapDeg']=1.0
        parDict['tileDefinitions']={'mask': args.maskFileName,
                                    'targetTileWidthDeg': 10.0, 'targetTileHeightDeg': 5.0}

    config=startUp.NemoConfig(parDict, MPIEnabled = args.MPIEnabled, divideTilesByProcesses = True,
                              makeOutputDirs = False, setUpMaps = True, writeTileInfo = False,
                              verbose = False, strictMPIExceptions = strictMPIExceptions)

    baseDir, fileName=os.path.split(args.outputFileName)
    if baseDir != '':
        os.makedirs(baseDir, exist_ok = True)

    # Noise (only needed on rank 0, which is the only rank that adds noise to the stitched map):
    # - if number, uniform white noise level per pixel
    # - if number with sb suffix, uniform white noise level per square arcmin
    # - if string, a path to ivar map
    if config.rank == 0:
        if isinstance(args.addNoise, str) and args.addNoise.endswith('sb'):
            addNoise=float(args.addNoise.split('sb')[0])
            noiseMode='perSquareArcmin'
        else:
            noiseMode='perPixel'
            try:
                addNoise=float(args.addNoise)
            except (TypeError, ValueError):
                # Not a number, so assumed to be an inv var map path
                with pyfits.open(args.addNoise) as img:
                    if img[0].data.ndim == 2:
                        sigma=img[0].data
                    else:
                        sigma=img[0].data[0]
                    addNoiseWCS=astWCS.WCS(img[0].header, mode = 'pyfits')
                # Convert inverse variance to sigma where the inverse variance is usable;
                # other pixels are left at their original values
                valid=sigma > 1e-7
                sigma[valid]=np.sqrt(1/sigma[valid])
                addNoise=sigma
    else:
        addNoise=None
        noiseMode=None

    if args.catalog.startswith('pointsources'):
        try:
            numSources=int(args.catalog.split("-")[-1])
        except ValueError as err:
            raise Exception("Use format pointsources-N, e.g., pointsources-100 will generate a test catalog of 100 sources.") from err
        with pyfits.open(args.maskFileName) as img:
            # Use the first HDU that holds image data. Iterating an HDUList yields HDU objects,
            # so the HDU is used directly rather than as an index back into the list.
            for hdu in img:
                data=hdu.data
                if data is not None:
                    break
            wcs=astWCS.WCS(hdu.header, mode = 'pyfits')
            if numSources > 0:
                tab=catalogs.generateRandomSourcesCatalog(data, wcs, numSources, seed = args.seed)
                tab.write(args.outputFileName.replace(".fits", "_inputCatalog.fits"), overwrite = True)
            else:
                tab=atpy.Table(names = ('RADeg', 'decDeg'))
    else:
        tab=atpy.Table().read(args.catalog)

    # Optionally override the fiducial cosmology (only useful for cluster sims)
    # All of the keywords must be present in the catalog header for the override to apply
    keywords=['OM0', 'OB0', 'H0', 'SIGMA8', 'NS']
    if all(key in tab.meta.keys() for key in keywords):
        print(">>> Using cosmology specified in header for catalog %s [only affects painted cluster sizes]" % (args.catalog))
        cosmoModel=ccl.Cosmology(Omega_c = tab.meta['OM0']-tab.meta['OB0'], Omega_b = tab.meta['OB0'],
                                 h = 0.01*tab.meta['H0'], sigma8 = tab.meta['SIGMA8'], n_s = tab.meta['NS'],
                                 transfer_function = signals.transferFunction)
    else:
        print(">>> Using fiducial cosmology")
        cosmoModel=signals.fiducialCosmoModel

    # Optional signal scaling (useful for diff alpha sims)
    if 'y_c' in tab.keys():
        tab['y_c']=tab['y_c']*args.scale

    # Build a dictionary containing the model images which we'll stitch together at the end
    print(">>> Building models in tiles ...")
    modelsDict={}
    for tileName in config.tileNames:
        print("... %s ..." % (tileName))
        tileInfo=config.tileCoordsDict[tileName]
        minX, maxX, minY, maxY=tileInfo['clippedSection']
        shape=(maxY-minY, maxX-minX)
        wcs=astWCS.WCS(tileInfo['header'], mode = 'pyfits')
        try:
            modelsDict[tileName]=maps.makeModelImage(shape, wcs, tab, args.beamFileName,
                                                     obsFreqGHz = args.obsFreqGHz,
                                                     validAreaSection = tileInfo['areaMaskInClipSection'],
                                                     TCMBAlpha = args.TCMBAlpha,
                                                     profile = args.profile, cosmoModel = cosmoModel)
        except Exception as err:
            # Chain the original error so the underlying cause stays in the traceback
            raise Exception("makeModelImage failed on tile '%s'" % (tileName)) from err

    # Gathering
    # We can't just gather as that triggers a 32-bit overflow (?)
    # So, sending one object at a time
    if config.MPIEnabled == True:
        config.comm.barrier()
        if config.rank > 0:
            print("... rank = %d sending sky model tiles ..." % (config.rank))
            config.comm.send(modelsDict, dest = 0)
        elif config.rank == 0:
            print("... gathering sky model tiles ...")
            gathered_modelsDicts=[]
            gathered_modelsDicts.append(modelsDict)
            for source in range(1, config.size):
                gathered_modelsDicts.append(config.comm.recv(source = source))
            for tileModelDict in gathered_modelsDicts:
                for tileName in tileModelDict.keys():
                    modelsDict[tileName]=tileModelDict[tileName]

    # Stitching
    print(">>> Stitching tiles ...")
    if config.rank == 0:
        d=np.zeros([config.origWCS.header['NAXIS2'], config.origWCS.header['NAXIS1']])
        wcs=config.origWCS
        for tileName in modelsDict.keys():
            print("... %s ..." % (tileName))
            minX, maxX, minY, maxY=config.tileCoordsDict[tileName]['clippedSection']
            if modelsDict[tileName] is not None:
                d[minY:maxY, minX:maxX]=d[minY:maxY, minX:maxX]+np.array(modelsDict[tileName])
        # Optionally add CMB
        if args.addCMB == True:
            # Save tSZ-only map for debugging
            astImages.saveFITS(args.outputFileName.replace(".fits", "_signalOnly.fits"), d, wcs)
            shape=(d.shape)
            # Split adding CMB and noise to enable more tests
            d=d+maps.simCMBMap(shape, wcs, noiseLevel = None, beam = args.beamFileName,
                               seed = args.seed)
            astImages.saveFITS(args.outputFileName.replace(".fits", "_signalAndCMB.fits"), d, wcs)
            # Support trimming of inv var map, if we have one that doesn't match area mask
            # NOTE: As it's noise, we don't actually care about matching RA, dec coords strictly
            # So we don't worry about +/-1 pixel offsets here
            # But we do have to deal with RA wrapping, and whether RAMax needs to increase/decrease
            if isinstance(addNoise, np.ndarray) and shape != addNoise.shape:
                RAMin, RAMax, decMin, decMax=wcs.getImageMinMaxWCSCoords()
                clipOk=False
                count=0
                lastWidth=None
                direction=1
                while clipOk is False:
                    count=count+1
                    clip=astImages.clipUsingRADecCoords(addNoise, addNoiseWCS, RAMin, RAMax, decMin, decMax)
                    # If the clip got narrower after the last RAMax change, step RAMax the other way
                    if lastWidth is not None:
                        if lastWidth > clip['data'].shape[1]:
                            direction=direction*-1
                    lastWidth=clip['data'].shape[1]
                    if clip['data'].shape[0] < d.shape[0]:
                        decMax=decMax+0.01
                    if clip['data'].shape[1] < d.shape[1]:
                        RAMax=RAMax+0.01*direction
                    if clip['data'].shape[0] >= d.shape[0] and clip['data'].shape[1] >= d.shape[1]:
                        clipOk=True
                    if count > 20:
                        raise Exception("Clipping ivar map to suitable size failed to converge quickly enough.")
                addNoise=clip['data'][:d.shape[0], :d.shape[1]]
            d=d+maps.simNoiseMap(shape, addNoise, wcs = wcs, lKnee = args.lKnee, noiseMode = noiseMode)
            # For testing abrupt noise changes
            if args.splitNoiseTest == True:
                # Use the parsed noise level (addNoise) here, not args.addNoise, which is still
                # the raw command line string whenever -N was given
                halfHeight=d.shape[0]//2
                if isinstance(addNoise, np.ndarray):
                    extraSigma=addNoise[:halfHeight]
                else:
                    extraSigma=addNoise
                d[:halfHeight]=d[:halfHeight]+np.random.normal(0, 2*extraSigma, [halfHeight, d.shape[1]])
                # Doesn't really matter about absolute scaling of this
                weights=np.ones(d.shape)*addNoise
                weights[:halfHeight]=weights[:halfHeight]*2
                weights=np.power(weights, -2)
                astImages.saveFITS(args.outputFileName.replace(".fits", ".ivar.fits"), weights, wcs)
        # Optionally add another component (e.g., large scale noise, to start with)
        if args.addMap is not None:
            with pyfits.open(args.addMap) as img:
                d=d+float(args.addMapScaling)*img[0].data
        astImages.saveFITS(args.outputFileName, d, wcs)
