#!/usr/bin/env python

"""

Make a map containing model objects using the info in a nemo output catalog

"""

import os
import sys
import numpy as np
import astropy.table as atpy
import astropy.io.fits as pyfits
from astLib import *
import nemo
from nemo import startUp, maps, catalogs, signals
import argparse
import astropy.io.fits as pyfits
import time
on_rtd=os.environ.get('READTHEDOCS', None)
if on_rtd is None:
    import pyccl as ccl

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Build the command-line argument parser for the nemoModel script.

    Positional arguments are the input catalog (or 'pointsources-N'), the mask image, the beam file,
    and the output file name; the optional switches control the observing frequency, signal scaling,
    cluster profile, CMB / noise / extra-map components, and MPI behavior.

    Returns:
        An argparse.ArgumentParser instance; call its parse_args() method to get the options.

    """

    parser=argparse.ArgumentParser("nemoModel")
    parser.add_argument("catalog", help = "Either the path to a Nemo FITS-table format catalog,\
                        or 'pointsources-N' (to generate a test catalog of N point sources that will be\
                        inserted into the map, e.g., pointsources-1000 will insert 1000 sources). If\
                        the latter, the catalog will be written to outputFileName_inputCatalog.fits\
                        (with the .fits extension stripped from outputFileName). If the former,\
                        cosmological parameters may be specified in the FITS header using the\
                        OM0, OB0, H0, SIGMA8, NS keywords (used only for cluster models at the\
                        moment).")
    parser.add_argument("maskFileName", help = "A FITS image file, containing a mask of the desired sky\
                        area. Non-zero values in the mask are used to define tiles (typically 10 x 5 deg),\
                        which are processed in parallel if MPI is enabled (see -M switch). The output sky\
                        model image will have the same pixelization as the mask image.")
    parser.add_argument("beamFileName", help = "A file containing the beam profile, in the standard format\
                        used by ACT.")
    parser.add_argument("outputFileName", help = "The name of the output file that will contain the sky\
                        model image.")
    parser.add_argument("-f", "--frequency-GHz", dest = "obsFreqGHz", type = float, default = 150.0,
                        help = "If the nemo catalog contains SZ-selected clusters, the SZ signal will be\
                        evaluated at the given frequency, ignoring relativistic effects (default: 150.0).")
    parser.add_argument("-s", "--scale-signals", dest = "scale", type = float, default = 1.0,
                        help = "Scale the input y_c values of clusters in the catalog by this factor.")
    parser.add_argument("-p", "--profile", dest = "profile", type = str, default = "A10",
                        help = "For cluster models only - select the profile to use: 'A10'\
                        (Arnaud et al. 2010 UPP models) or 'B12' (Battaglia et al. 2012 models). The mass and\
                        redshift used to set the shape and scale of the cluster model uses the 'template'\
                        column in the Nemo catalog, if the 'true_M500c', 'true_fixed_y_c', and 'true_Q' columns\
                        are not present. This requires the filter labels that appear in the 'template'\
                        column to be in the format 'Profile_M{$MASS}_z{$REDSHIFT}', with the decimal point\
                        replaced by the letter 'p' (e.g., 'Arnaud_M2e14_z0p4', as in the example config\
                        files for Nemo).")
    # NOTE: the seed switch defined below is -S / --seed (there is no separate CMB-only seed option)
    parser.add_argument("-C", "--add-cmb", dest = "addCMB", action = "store_true", default = False,
                        help = "Add a realization of the cosmic microwave background to the map. Use\
                        -S (--seed) to set the random seed to the same value if creating sets of maps at\
                        different frequencies. Note that this is very memory intensive for large maps,\
                        and is a serial operation.")
    parser.add_argument("-S", "--seed", dest = "seed", type = int, default = None,
                        help = "Random seed used only for generating (optional) cosmic microwave\
                        background or source catalog realizations (i.e., the seed is not used for\
                        noise realizations).")
    # No type is given here on purpose: the value may be a number, a number with an 'sb' suffix,
    # or a path - it is interpreted by the main script
    parser.add_argument("-N", "--add-noise", dest = "addNoise", default = 0.0,
                        help = "If a random cosmic microwave background realization had been added, add\
                        white noise. If a number is given, uniform per-pixel white noise at the specified\
                        level is added. If a number followed by 'sb' is given (e.g., 40sb), then constant\
                        surface brightness noise (per square arcmin) is added, adjusting for varying pixel\
                        scale across the map (if present). Otherwise, a path to an inverse variance map\
                        can be given (must have the same pixelization as the supplied mask).")
    parser.add_argument("-k", "--lknee", dest = "lKnee", default = None, type = float,
                        help = "Must be used with -N (--add-noise) and -C (--add-cmb). If given,\
                        1/f noise with power spectrum N_l = (1 + l/lknee)^-3) will be generated and added\
                        to the map (see Appendix A of MacCrann et al. 2023). Reasonable values to use\
                        are 2000 for ACT f090, and 3000 for ACT f150.")
    parser.add_argument("-A", "--add-map", dest = "addMap", default = None,
                        help = "Path to a FITS map (same pixelization as the input mask) that will be\
                        added to the output sim map. Useful if you want to add extra pre-computed signals\
                        to the map (e.g., Galactic dust, large scale noise components etc.). Amplitudes\
                        can be modified with the --add-map-scaling parameter.")
    # type = float makes argparse reject a non-numeric scaling immediately, rather than the script
    # failing only when the scaling is finally applied
    parser.add_argument("--add-map-scaling", dest = "addMapScaling", type = float, default = 1.0,
                        help = "If given, multiply the map pointed to by --add-map by this factor.")
    parser.add_argument("-a", "--tcmb-alpha", dest = "TCMBAlpha", type = float, default = 0.0,
                        help = "Only applies to cluster models. Set this to a non-zero value to generate\
                        a model where the CMB temperature varies as T(z) = T0 * (1+z)^{1-TCMBAlpha}.\
                        Requires a 'redshift' column to be present in the input catalog.")
    parser.add_argument("-M", "--mpi", dest="MPIEnabled", action="store_true", help="Enable MPI. If used,\
                        the catalog to be painted will be divided amongst processes. If you\
                        want to use this, run with e.g., mpiexec -np 4 nemoModel args -M",
                        default = False)
    parser.add_argument("-n", "--no-strict-errors", dest="noStrictMPIExceptions", action="store_true",
                        help="Disable strict exception handling (applies under MPI only, i.e., must be\
                        used with the -M switch). If you use this option, you will get the full traceback\
                        when a Python Exception is triggered, but the code may not terminate. This is due\
                        to the Exception handling in mpi4py.", default = False)
    parser.add_argument("-v", "--version", action = 'version', version = '%(prog)s' + ' %s' % (nemo.__version__))

    return parser

#------------------------------------------------------------------------------------------------------------
def parseNoiseSpec(noiseSpec):
    """Interpret the value given with the -N (--add-noise) switch.

    Args:
        noiseSpec: Either a number, a string holding a number, a string holding a number followed by
            'sb' (e.g., '40sb'), or any other string (which the caller treats as a file path).

    Returns:
        A tuple (level, noiseMode). For a plain number, this is (float, 'perPixel'); for a number with
        the 'sb' suffix it is (float, 'perSquareArcmin'). If noiseSpec is a string that cannot be
        converted to a number, (None, 'perPixel') is returned.

    Raises:
        ValueError: If noiseSpec ends with 'sb' but the part before the suffix is not a number.

    """
    if isinstance(noiseSpec, str) and noiseSpec.endswith('sb'):
        return float(noiseSpec[:-2]), 'perSquareArcmin'
    try:
        return float(noiseSpec), 'perPixel'
    except ValueError:
        return None, 'perPixel'

#------------------------------------------------------------------------------------------------------------
def getRowRangeForRank(numRows, size, rank):
    """Return the (startIndex, endIndex) slice of a numRows-long table handled by the given rank.

    Rows are shared out in contiguous chunks of ceil(numRows/size) rows. Both indices are clipped to
    numRows, so ranks beyond the end of the table get an empty range.

    Args:
        numRows: Total number of rows to be divided.
        size: Number of processes.
        rank: Rank of this process (0 <= rank < size).

    Returns:
        Tuple (startIndex, endIndex), suitable for use as table[startIndex:endIndex].

    """
    numRowsPerProcess=int(np.ceil(numRows/size))
    startIndex=min(rank*numRowsPerProcess, numRows)
    endIndex=min(startIndex+numRowsPerProcess, numRows)
    return startIndex, endIndex

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    parser=makeParser()
    args=parser.parse_args()

    strictMPIExceptions=not args.noStrictMPIExceptions

    # Explicit check rather than an assert, as asserts are removed when Python runs with -O
    if args.addMap is not None and os.path.exists(args.addMap) == False:
        raise FileNotFoundError("Map given with --add-map not found: %s" % (args.addMap))

    # Only the shape and WCS of the mask are needed here
    areaMask, wcs=maps.chunkLoadMask(args.maskFileName)
    shape=areaMask.shape
    del areaMask

    baseDir=os.path.dirname(args.outputFileName)
    if baseDir != '':
        os.makedirs(baseDir, exist_ok = True)

    if args.MPIEnabled == True:
        from mpi4py import MPI
        # This is needed to get MPI to abort if one process crashes (due to mpi4py error handling)
        # If this part is disabled, we get nice exceptions, but the program will never finish if a process dies
        if strictMPIExceptions == True:
            sys_excepthook=sys.excepthook
            def mpi_excepthook(excType, excValue, tb):
                # The Abort is in a finally clause so that it is reached even if reporting fails -
                # otherwise the remaining processes could be left waiting
                try:
                    sys_excepthook(excType, excValue, tb)
                    # Exceptions raised without arguments have an empty args tuple
                    if excValue.args:
                        message=excValue.args[0]
                    else:
                        message=excType.__name__
                    print("Exception: %s" % (message,))
                    sys.stdout.flush()
                finally:
                    MPI.COMM_WORLD.Abort(1)
            sys.excepthook=mpi_excepthook
        comm=MPI.COMM_WORLD
        size=comm.Get_size()
        rank=comm.Get_rank()
        if size == 1:
            raise Exception("If you want to use MPI, run with e.g., mpiexec -np 4 nemoModel ... -M")
    else:
        rank=0
        comm=None
        size=1

    # Noise:
    # - if number, uniform white noise level per pixel
    # - if number with sb suffix, uniform white noise level per square arcmin
    # - if string, a path to ivar map
    # Only rank 0 assembles the final map, so only rank 0 needs the noise information
    addNoise=None
    noiseMode=None
    addNoiseWCS=None
    if rank == 0:
        addNoise, noiseMode=parseNoiseSpec(args.addNoise)
        if addNoise is None:
            # Assumed inv var map path
            with pyfits.open(args.addNoise) as img:
                if img[0].data.ndim == 2:
                    sigma=img[0].data
                else:
                    sigma=img[0].data[0]
                addNoiseWCS=astWCS.WCS(img[0].header, mode = 'pyfits')
                # Convert inverse variance to sigma; pixels at or below the threshold are left as they are
                valid=sigma > 1e-7
                sigma[valid]=np.sqrt(1/sigma[valid])
            addNoise=sigma

    if args.catalog.startswith('pointsources'):
        try:
            numSources=int(args.catalog.split("-")[-1])
        except ValueError:
            raise Exception("Use format pointsources-N, e.g., pointsources-100 will generate a test catalog of 100 sources.")
        with pyfits.open(args.maskFileName) as img:
            # Use the first HDU that holds data (iterating over an HDUList gives the HDU objects)
            for hdu in img:
                data=hdu.data
                if data is not None:
                    break
            wcs=astWCS.WCS(hdu.header, mode = 'pyfits')
            if numSources > 0:
                # Generate (and write) the catalog once, on rank 0, and share it with the other
                # processes - this way all ranks paint slices of the same catalog even when no
                # seed is given, and only one process writes the catalog file
                if rank == 0:
                    tab=catalogs.generateRandomSourcesCatalog(data, wcs, numSources, seed = args.seed)
                    tab.write(args.outputFileName.replace(".fits", "_inputCatalog.fits"), overwrite = True)
                else:
                    tab=None
                if args.MPIEnabled == True:
                    tab=comm.bcast(tab, root = 0)
            else:
                tab=atpy.Table(names = ('RADeg', 'decDeg'))
    else:
        tab=atpy.Table.read(args.catalog)

    # Optionally override the fiducial cosmology (only useful for cluster sims)
    # All of the keywords must be present in the catalog header for this to happen
    keywords=['OM0', 'OB0', 'H0', 'SIGMA8', 'NS']
    if all(key in tab.meta for key in keywords):
        if rank == 0: print(">>> Using cosmology specified in header for catalog %s [only affects painted cluster sizes]" % (args.catalog))
        cosmoModel=ccl.Cosmology(Omega_c = tab.meta['OM0']-tab.meta['OB0'], Omega_b = tab.meta['OB0'],
                                 h = 0.01*tab.meta['H0'], sigma8 = tab.meta['SIGMA8'], n_s = tab.meta['NS'],
                                 transfer_function = signals.transferFunction)
    else:
        if rank == 0: print(">>> Using fiducial cosmology")
        cosmoModel=signals.fiducialCosmoModel

    # Optional signal scaling (useful for diff alpha sims)
    if 'y_c' in tab.keys():
        tab['y_c']=tab['y_c']*args.scale

    # Divide models to be painted among processes
    tab=catalogs.getCatalogWithinImage(tab, shape, wcs)
    if args.MPIEnabled == True:
        startIndex, endIndex=getRowRangeForRank(len(tab), size, rank)
        tab=tab[startIndex:endIndex]

    # Each process paints its own share of the catalog into a full-size image
    if rank == 0: print(">>> Building models ...")
    t0=time.time()
    modelImage=maps.makeModelImage(shape, wcs, tab, args.beamFileName,
                                   obsFreqGHz = args.obsFreqGHz,
                                   TCMBAlpha = args.TCMBAlpha,
                                   profile = args.profile, cosmoModel = cosmoModel,
                                   reportTimingInfo = False)
    t1=time.time()
    print("... rank %d image complete (took %.3f sec)" % (rank, t1-t0))

    # Sum the per-process images on rank 0
    if args.MPIEnabled == True:
        if rank > 0:
            print("... rank = %d sending sky model image" % (rank))
            comm.send(modelImage, dest = 0)
            del modelImage
        else:
            print("... gathering sky model images")
            for source in range(1, size):
                recModelImage=comm.recv(source = source)
                if recModelImage is not None:
                    # Guard against rank 0 itself having no image to add to
                    if modelImage is None:
                        modelImage=recModelImage
                    else:
                        modelImage=modelImage+recModelImage
                del recModelImage

    # Assembly with optional CMB and noise
    if rank == 0:
        print(">>> Assembling final model image ...")
        if modelImage is None:  # e.g., if we run with pointsources-0 to just get CMB and noise
            d=np.zeros(shape)
        else:
            d=modelImage
        # Optionally add CMB (noise is only added if the CMB is added)
        if args.addCMB == True:
            # Save tSZ-only map for debugging
            astImages.saveFITS(args.outputFileName.replace(".fits", "_signalOnly.fits"), d, wcs)
            shape=(d.shape)
            # Split adding CMB and noise to enable more tests
            d=d+maps.simCMBMap(shape, wcs, noiseLevel = None, beam = args.beamFileName,
                               seed = args.seed)
            astImages.saveFITS(args.outputFileName.replace(".fits", "_signalAndCMB.fits"), d, wcs)
            # Support trimming of inv var map, if we have one that doesn't match area mask
            # NOTE: As it's noise, we don't actually care about matching RA, dec coords strictly
            # So we don't worry about +/-1 pixel offsets here
            # But we do have to deal with RA wrapping, and whether RAMax needs to increase/decrease
            if isinstance(addNoise, np.ndarray) and shape != addNoise.shape:
                RAMin, RAMax, decMin, decMax=wcs.getImageMinMaxWCSCoords()
                clipOk=False
                count=0
                lastWidth=None
                direction=1
                while clipOk is False:
                    count=count+1
                    clip=astImages.clipUsingRADecCoords(addNoise, addNoiseWCS, RAMin, RAMax, decMin, decMax)
                    # If the last RAMax adjustment made the clip narrower, reverse the direction
                    if lastWidth is not None:
                        if lastWidth > clip['data'].shape[1]:
                            direction=direction*-1
                    lastWidth=clip['data'].shape[1]
                    # Grow the clip region in small steps until it is at least as big as the map
                    if clip['data'].shape[0] < d.shape[0]:
                        decMax=decMax+0.01
                    if clip['data'].shape[1] < d.shape[1]:
                        RAMax=RAMax+0.01*direction
                    if clip['data'].shape[0] >= d.shape[0] and clip['data'].shape[1] >= d.shape[1]:
                        clipOk=True
                    if count > 20:
                        raise Exception("Clipping ivar map to suitable size failed to converge quickly enough.")
                addNoise=clip['data'][:d.shape[0], :d.shape[1]]
            d=d+maps.simNoiseMap(shape, addNoise, wcs = wcs, lKnee = args.lKnee, noiseMode = noiseMode)
        # Optionally add another component (e.g., large scale noise, to start with)
        if args.addMap is not None:
            with pyfits.open(args.addMap) as img:
                d=d+float(args.addMapScaling)*img[0].data
        astImages.saveFITS(args.outputFileName, d, wcs)
        print("... finished")
