#!/usr/bin/env python

"""

Calculate masses of clusters detected by nemo

Requires 'massOptions' in nemo config file

Run this after filtering / detecting objects in a map with nemo itself

Can be used to obtain 'forced photometry', i.e., mass estimates for objects in redshiftCatalog
(for, e.g., optical stacking)

"""

import os
import sys
import numpy as np
import pylab as plt
import astropy.table as atpy
from astLib import *
from scipy import stats
from scipy import interpolate
import nemo
from nemo import catalogs
from nemo import signals
from nemo import maps
from nemo import filters
from nemo import MockSurvey
from nemo import photometry
from nemo import startUp
from nemo import completeness
from nemo import pipelines
on_rtd=os.environ.get('READTHEDOCS', None)
if on_rtd is None:
    import pyccl as ccl
import argparse
import astropy.io.fits as pyfits
import time
import yaml

#------------------------------------------------------------------------------------------------------------
def calcMass(tab, scalingRelation, QFit, fRelWeightsDict, mockSurvey, otherMassEstimates = None,
             inferSZProperties = False):
    """Calculates masses for cluster data in table.

    New (zero-filled) mass columns are added to ``tab`` for the reference mass definition of
    ``scalingRelation`` and for each entry in ``otherMassEstimates``; they are then filled row-by-row
    using :func:`nemo.signals.inferClusterProperties`. Rows with ``redshift`` < 0.027, with
    ``fixed_y_c`` <= 0, with NaN redshift, or for which no mass could be inferred, keep the
    zero-filled values.

    Args:
        tab: Cluster table, modified in place. Rows must provide ``name``, ``redshift``,
            ``redshiftErr``, ``tileName``, ``fixed_y_c`` and ``fixed_err_y_c`` (the y_c columns are
            multiplied by 1e-4 before being passed to the mass inference).
        scalingRelation (dict): Must contain ``delta``, ``rhoType``, ``tenToA0``, ``B0``, ``Mpivot``
            and ``sigma_int``. Optional keys: ``label`` (column name prefix), ``rescaleFactor`` (with
            ``rescaleFactorErr`` and optionally ``rescaleLabel``, default 'Cal'),
            ``relativisticCorrection`` (default True), ``Ez_gamma`` (default 2) and
            ``onePlusRedshift_power`` (default 0.0). Missing defaults are written into this dict.
        QFit: Q function object, passed through to ``inferClusterProperties``.
        fRelWeightsDict (dict): Relativistic correction weights, indexed by tile name.
        mockSurvey: MockSurvey object; its ``mdefLabel`` must match the reference mass label built
            from ``scalingRelation``, and its ``cosmoModel`` is used for mass definition conversion.
        otherMassEstimates (list of dict, optional): Additional mass definitions to convert to. Each
            dict needs ``delta`` and ``rhoType``, and may give ``concMassRelation`` (default None).
            These dicts are not modified.
        inferSZProperties (bool): If True, also add and fill the columns ``Q``, ``theta500Arcmin``,
            ``inferred_y_c``, ``inferred_Y500Arcmin2`` (and their ``_err`` counterparts).

    Returns:
        The input ``tab``, with the mass columns added. Masses are handled in units of 1e14 (they
        are multiplied by 1e14 before mass definition conversion and divided by 1e14 afterwards).

    Raises:
        Exception: If a mass column to be added already exists in ``tab``.
        AssertionError: If the reference mass label does not match ``mockSurvey.mdefLabel``.

    Note:
        Progress messages use ``config.rank``, where ``config`` is the module-level object created
        when this file is run as a script.

    """

    if otherMassEstimates is None:
        otherMassEstimates=[]

    refMassDef=ccl.halos.MassDef(scalingRelation['delta'], scalingRelation['rhoType'])

    # Defaults for optional scaling relation keys (stored in the dictionary itself)
    scalingRelation.setdefault('relativisticCorrection', True)
    # Experimenting with E(z)^gamma instead of E(z)^2
    scalingRelation.setdefault('Ez_gamma', 2)
    # Experimenting with arbitrary (1+z)^something
    scalingRelation.setdefault('onePlusRedshift_power', 0.0)

    # Extra columns for other inferred quantities
    if inferSZProperties == True:
        extraCols=['Q', 'theta500Arcmin', 'inferred_y_c', 'inferred_Y500Arcmin2']
        extraDivs=[1, 1, 1e-4, 1e-4] # Divide quantities by this (e.g, inferred_y_c in output table will be * 1e-4)
    else:
        extraCols=[]
        extraDivs=[]

    # Column name prefix and optional re-scaling set-up
    if 'label' not in scalingRelation.keys():
        prefix=''
    else:
        prefix=scalingRelation['label']+"_"
    doRescale='rescaleFactor' in scalingRelation.keys()
    rescaleLabel=None
    if doRescale == True:
        rescaleLabel=scalingRelation.get('rescaleLabel', 'Cal')

    # Add all columns for all used mass definitions
    labels=['M%d%s' % (scalingRelation['delta'], scalingRelation['rhoType'][0])]
    for massDefDict in otherMassEstimates:
        labels.append('M%d%s' % (massDefDict['delta'], massDefDict['rhoType'][0]))
    extraColsAdded=False
    for l in labels:
        colNames=[prefix+l, prefix+l+'Uncorr']
        if doRescale == True:
            colNames.append(prefix+l+rescaleLabel)
        for c in colNames:
            if c in tab.keys():
                raise Exception("Already found column named '%s' - add 'label' keys to scalingRelation dictionaries in massOptions?" % (c))
            tab[c]=np.zeros(len(tab))
            tab[c+'_errPlus']=np.zeros(len(tab))
            tab[c+'_errMinus']=np.zeros(len(tab))
            # Extra columns go straight after the first mass column (keeps the output column order)
            if extraColsAdded == False:
                for mc in extraCols:
                    tab[mc]=np.zeros(len(tab))
                    tab[mc+"_err"]=np.zeros(len(tab))
                extraColsAdded=True

    label=labels[0]
    assert(label == mockSurvey.mdefLabel)

    # Arguments that are the same for every row and for both (corrected, uncorrected) estimates
    relationKwargs={'tenToA0': scalingRelation['tenToA0'],
                    'B0': scalingRelation['B0'],
                    'Mpivot': scalingRelation['Mpivot'],
                    'sigma_int': scalingRelation['sigma_int'],
                    'Ez_gamma': scalingRelation['Ez_gamma'],
                    'onePlusRedshift_power': scalingRelation['onePlusRedshift_power'],
                    'QFit': QFit,
                    'mockSurvey': mockSurvey,
                    'applyRelativisticCorrection': scalingRelation['relativisticCorrection']}

    for count, row in enumerate(tab, start = 1):
        print("... rank %d; %d/%d; %s (%.3f +/- %.3f) ..." % (config.rank, count, len(tab), row['name'], 
                                                              row['redshift'], row['redshiftErr']))

        tileName=row['tileName']
        
        # Cuts on z, fixed_y_c for forced photometry mode (invalid objects will be listed but without a mass)
        # Added cut on z here to avoid blow-ups for extreme low-z
        if row['redshift'] < 0.027: # Ok, this is highly tuned... but is the lowest z in ACT DR6
            print("... skipping %s - redshift < 0.027 ..." % (row['name']))
            continue
        if row['fixed_y_c'] > 0 and not np.isnan(row['redshift']):
            # Both estimates must use the weights for this tile only
            rowKwargs=dict(relationKwargs)
            rowKwargs['fRelWeightsDict']=fRelWeightsDict[tileName]
            rowKwargs['tileName']=tileName

            # Corrected for mass function steepness
            massDict=signals.inferClusterProperties(row['fixed_y_c']*1e-4, row['fixed_err_y_c']*1e-4,
                                                    row['redshift'], row['redshiftErr'],
                                                    applyMFDebiasCorrection = True,
                                                    inferSZProperties = inferSZProperties,
                                                    **rowKwargs)
            if massDict is None:
                print("... skipping %s - failed to get mass (perhaps Q outside valid range for this M, z) ..." % (row['name']))
                continue
            row[prefix+label]=massDict[label]
            row[prefix+label+'_errPlus']=massDict[label+'_errPlus']
            row[prefix+label+'_errMinus']=massDict[label+'_errMinus']
            for mc, div in zip(extraCols, extraDivs):
                row[mc]=massDict[mc]/div
                row[mc+"_err"]=massDict[mc+"_err"]/div

            # Uncorrected for mass function steepness
            unCorrMassDict=signals.inferClusterProperties(row['fixed_y_c']*1e-4, row['fixed_err_y_c']*1e-4,
                                                          row['redshift'], row['redshiftErr'],
                                                          applyMFDebiasCorrection = False,
                                                          **rowKwargs)
            row[prefix+label+'Uncorr']=unCorrMassDict[label]
            row[prefix+label+'Uncorr_errPlus']=unCorrMassDict[label+'_errPlus']
            row[prefix+label+'Uncorr_errMinus']=unCorrMassDict[label+'_errMinus']

            # Re-scaling (e.g., using richness-based weak-lensing mass calibration)
            # Fractional errors on the mass and on the rescale factor are added in quadrature
            # NOTE: columns are looked up with the prefix, so this also works for labelled relations
            if doRescale == True:
                calKey=prefix+label+rescaleLabel
                fracFactorErr=scalingRelation['rescaleFactorErr']/scalingRelation['rescaleFactor']
                row[calKey]=massDict[label]/scalingRelation['rescaleFactor']
                for side in ['_errPlus', '_errMinus']:
                    fracMassErr=row[prefix+label+side]/row[prefix+label]
                    row[calKey+side]=np.sqrt(np.power(fracMassErr, 2)+np.power(fracFactorErr, 2))*row[calKey]
                calMassDict={label: row[calKey],
                             label+'_errPlus': row[calKey+'_errPlus'],
                             label+'_errMinus': row[calKey+'_errMinus']}

            # CCL-based mass conversions
            resultsList=[massDict, unCorrMassDict]
            suffixList=['', 'Uncorr']
            if doRescale == True:
                resultsList.append(calMassDict)
                suffixList.append(rescaleLabel)
            for resultDict, suffix in zip(resultsList, suffixList):
                refKey=prefix+label+suffix
                for massDefDict in otherMassEstimates:
                    thisKey=prefix+'M%d%s' % (massDefDict['delta'], massDefDict['rhoType'][0])+suffix
                    thisMassDef=ccl.halos.MassDef(massDefDict['delta'], massDefDict['rhoType'])
                    thisMass=signals.MDef1ToMDef2(resultDict[label]*1e14, row['redshift'], refMassDef, thisMassDef,
                                                  mockSurvey.cosmoModel,
                                                  c_m_relation = massDefDict.get('concMassRelation'))/1e14
                    row[thisKey]=thisMass
                    # Converted masses inherit the fractional errors of the reference mass
                    row[thisKey+'_errPlus']=(row[refKey+'_errPlus']/row[refKey])*row[thisKey]
                    row[thisKey+'_errMinus']=(row[refKey+'_errMinus']/row[refKey])*row[thisKey]

    return tab

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Builds the command line argument parser for the nemoMass script.

    Returns:
        An argparse.ArgumentParser with one positional argument (configFileName) and the optional
        switches -c, -o, -Q, -x, -z, -e, -F, -I, -m, -M, -n and -v.

    """

    argParser=argparse.ArgumentParser("nemoMass")
    addArg=argParser.add_argument

    # Input / output files
    addArg("configFileName", help="A .yml configuration file.")
    addArg("-c", "--catalog", dest="catFileName", default = None, help = "Catalog file name (.fits format).\
                        The catalog must contain at least the following columns: name, RADeg, decDeg, \
                        redshift, redshiftErr. If the catalog contains fixed_y_c, fixed_err_y_c columns,\
                        then these will be used to infer mass estimates. If not, 'forced photometry' mode \
                        will be enabled, and the fixed_y_c, fixed_err_y_c values will be extracted from the\
                        filtered maps.")
    addArg("-o", "--output", dest="outFileName", default = None, help = "Output catalog file name \
                        (.fits format). If not given, the name of the output catalog file will be based on\
                        either configFileName or catFileName.")

    # Q function, cross matching, and redshift columns
    addArg("-Q", "--Q-source", dest="QSource", default = 'fit', help = "Source of the Q function data - either\
                        'fit' (the 'classic' method), 'injection' (for Q based on source injection test\
                        results), or 'hybrid' (uses the 'fit' method for scales smaller than the reference\
                        filter scale, and the 'injection' method for scales larger than the reference filter\
                        scale).")
    addArg("-x", "--x-match-arcmin", dest="xMatchArcmin", default = 2.5,
           help = "Specifies the cross-match radius (in arcmin) to use when matching the redshift\
                        catalog onto the cluster candidate catalog.")
    addArg("-z", "--z-column", dest="zColumnName", default = None, help = "Specifies the name of the redshift\
                        column in the input catalog.")
    addArg("-e", "--z-error-column", dest="zErrColumnName", default = None, help = "Specifies the name of the \
                        redshift uncertainty column in the input catalog.")

    # Optional processing modes
    addArg("-F", "--forced-photometry", dest="forcedPhotometry", action="store_true", default = False,
           help = "Perform forced photometry.\
                        This is automatically enabled if the catalog does not contain the fixed_y_c, \
                        fixed_err_y_c columns. Use this switch to force using this mode even if the \
                        catalog already contains fixed_y_c, fixed_err_y_c  columns (e.g., for doing forced\
                        photometry on one ACT map using positions of clusters found in another, deeper \
                        map).")
    addArg("-I", "--infer-sz-properties", dest = "inferSZProperties", action="store_true", default = False,
           help = "Infer (model-dependent) SZ observables and related\
                        quantities such as angular size (theta500Arcmin), Q, central Comptonization parameter\
                        (inferred_y_c), and Y500 (inferred_Y500Arcmin2). Y500 values currently assume the Arnaud+2010\
                        Universal Pressure Profile, even if alternative GNFW parameters are given in the config file.")
    addArg("-m", "--make-model-maps", dest = "makeModelMaps", action = "store_true", default = False,
           help = "Make model maps and masks for the SZ signal\
                        corresponding to the maps given in the config file by running the nemoModel script.\
                        Output will be written to the clusterModelMaps directory under the Nemo output\
                        directory.")

    # MPI related switches
    addArg("-M", "--mpi", dest="MPIEnabled", action="store_true", default = False, help="Enable MPI. If you \
                        want to use this, run with e.g., mpiexec -np 4 nemoMass configFile.yml -M")
    addArg("-n", "--no-strict-errors", dest="noStrictMPIExceptions", action="store_true", default = False,
           help="Disable strict exception handling (applies under MPI only, i.e., must be\
                        used with the -M switch). If you use this option, you will get the full traceback\
                        when a Python Exception is triggered, but the code may not terminate. This is due\
                        to the Exception handling in mpi4py.")

    # Program name followed by the nemo package version
    versionText='%(prog)s' + (' %s' % (nemo.__version__))
    addArg("-v", "--version", action = 'version', version = versionText)

    return argParser

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    # Script entry point: load the catalog (matched against the redshift catalog if needed), estimate
    # masses for each scaling relation in massOptions, write the mass table, and optionally make model maps

    parser = makeParser()
    args = parser.parse_args()
    
    parDictFileName=args.configFileName
    catFileName=args.catFileName
    outFileName=args.outFileName
    forcedPhotometry=args.forcedPhotometry
    if forcedPhotometry == True and args.MPIEnabled == True:
        print("Aborted - cannot use forced photometry mode with MPI currently")
        sys.exit()
    QSource=args.QSource
    if QSource not in ['fit', 'injection', 'hybrid']:
        raise Exception("QSource must be either 'fit', 'injection', or 'hybrid'")
    
    strictMPIExceptions=not args.noStrictMPIExceptions
    
    # Load the nemo catalog and match against the z catalog
    # NOTE: This is now done using coord matching (nearest within some maximum tolerance), rather than names
    if catFileName is None:
        config=startUp.NemoConfig(parDictFileName, MPIEnabled = args.MPIEnabled, divideTilesByProcesses = False,
                                  makeOutputDirs = False, setUpMaps = False, verbose = False,
                                  strictMPIExceptions = strictMPIExceptions)
        optimalCatalogFileName=config.rootOutDir+os.path.sep+"%s_optimalCatalog.fits" % (os.path.split(config.rootOutDir)[-1])           
        nemoTab=atpy.Table().read(optimalCatalogFileName)
        zTab=atpy.Table().read(config.parDict['massOptions']['redshiftCatalog'])
        if args.xMatchArcmin != 'rayleigh':
            # Fixed matching radius
            xMatchArcmin=float(args.xMatchArcmin)
            nemoTab, zTab, rDeg=catalogs.crossMatch(nemoTab, zTab, radiusArcmin = xMatchArcmin)
        else:
            # Match within 2.5 arcmin first, then keep only the pairs that pass
            # catalogs.checkCrossMatchRayleigh (given the separation in arcmin and fixed_SNR)
            nemoTab, zTab, rDeg=catalogs.crossMatch(nemoTab, zTab, radiusArcmin = 2.5)
            mask=[]
            for row, zRow, r in zip(nemoTab, zTab, rDeg):
                mask.append(catalogs.checkCrossMatchRayleigh(r*60.0, row['fixed_SNR']))
            nemoTab=nemoTab[mask]
            zTab=zTab[mask]
        nemoTab['redshift']=zTab['redshift']
        if 'redshiftErr' in zTab.keys():
            nemoTab['redshiftErr']=zTab['redshiftErr']
        else:
            print("... WARNING: no redshiftErr column found in '%s' - assuming redshiftErr = 0 for all clusters ..." %  (config.parDict['massOptions']['redshiftCatalog']))
            nemoTab['redshiftErr']=0
        tab=nemoTab
        if outFileName is None:
            outFileName=optimalCatalogFileName.replace("_optimalCatalog.fits", "_mass.fits")

    else:
        
        # Load another catalog (e.g., a mock, for testing)
        optimalCatalogFileName=catFileName
        tab=atpy.Table().read(optimalCatalogFileName)
        if outFileName is None:
            outFileName=catFileName.replace(".fits", "_mass.fits")
        
        # Enter forced photometry mode if we can't find the columns we need
        # If we're doing forced photometry, we're going to need to set-up so we can find the filtered maps
        # NOTE(review): the MPI check near the top only catches an explicit -F switch, not forced
        # photometry that gets enabled here - confirm whether that combination is intended to run
        keysNeeded=['fixed_y_c', 'fixed_err_y_c']
        for key in keysNeeded:
            if key not in tab.keys():
                forcedPhotometry=True
        config=startUp.NemoConfig(parDictFileName, MPIEnabled = args.MPIEnabled, divideTilesByProcesses = False,
                                  setUpMaps = False, writeTileInfo = False, verbose = False,
                                  strictMPIExceptions = strictMPIExceptions)

    # Clean out any z = 0 things that have crept through because these blow things up
    tab=tab[tab['redshift'] > 0]
    
    # Remaining set-up
    massOptions=config.parDict['massOptions']
    Q=signals.QFit(QSource, selFnDir = config.selFnDir)
    fRelWeightsDict=signals.getFRelWeights(config)
    
    # An 'otherMassEstimates' key that is present but empty (None) is treated the same as a missing key
    if 'otherMassEstimates' in config.parDict.keys() and config.parDict['otherMassEstimates'] is not None:
        otherMassEstimates=config.parDict['otherMassEstimates']
    else:
        # Default to match older versions of Nemo
        otherMassEstimates=[]#[{'delta': 200, 'rhoType': 'matter', 'concMassRelation': "Bhattacharya13"}]

    # Forced photometry (if enabled) - modifying table in place
    # NOTE: Move this up if/when we make it run under MPI
    if forcedPhotometry == True:
        tab=photometry.addForcedPhotometry(catFileName, config, args.zColumnName, args.zErrColumnName)
    
    # Set cosmological parameters for e.g. E(z) calc if these are set in .par file
    # We set them after the Q calc, because the Q calc needs to be for the fiducial cosmology
    # (OmegaM0 = 0.3, OmegaL0 = 0.7, H0 = 70 km/s/Mpc) used in object detection/filtering stage
    # Set-up the mass function stuff also
    # This is for correction of mass bias due to steep cluster mass function
    # Hence minMass here needs to be set well below the survey mass limit
    # areaDeg2 we don't care about here
    minMass=1e13
    areaDeg2=700.
    zMin=0.0
    zMax=4.0
    # H0, Om0, Ol0 used for E(z), theta500 calcs in Q - these are updated when we call create mockSurvey
    # NOTE: startUp now sets defaults for these if not given in config
    H0=massOptions['H0']
    Om0=massOptions['Om0']
    Ob0=massOptions['Ob0']
    sigma8=massOptions['sigma8']
    ns=massOptions['ns']

    # sortIndex lets us restore the original row order after the per-process chunks are gathered
    tab.add_column(atpy.Column(np.arange(len(tab)), "sortIndex"))
    if config.MPIEnabled == True:
        numRowsPerProcess=int(np.ceil(len(tab)/config.size))
        startIndex=config.rank*numRowsPerProcess
        endIndex=startIndex+numRowsPerProcess
        if config.rank == config.size-1:
            endIndex=len(tab)
        tab=tab[startIndex:endIndex]

    # Multiple scaling relations now supported in one run
    # BUT will only infer SZ properties for the first scaling relation in the list
    # (the first relation is identified by position, not by comparing dictionary contents)
    numRelations=len(massOptions['scalingRelations'])
    for count, scalingRelation in enumerate(massOptions['scalingRelations'], start = 1):
        if count == 1:
            inferSZProperties=args.inferSZProperties
        else:
            inferSZProperties=False
        print(">>> Scaling relation %d/%d" % (count, numRelations))
        mockSurvey=MockSurvey.MockSurvey(minMass, areaDeg2, zMin, zMax, H0, Om0, Ob0, sigma8, ns,
                                     rhoType = scalingRelation['rhoType'], delta = scalingRelation['delta'])
        tab=calcMass(tab, scalingRelation, Q, fRelWeightsDict, mockSurvey, otherMassEstimates = otherMassEstimates,
                     inferSZProperties = inferSZProperties)
    
    # Under MPI, only rank 0 carries on past this point (to assemble and write the output)
    if config.MPIEnabled == True:
        tabList=config.comm.gather(tab, root = 0)
        if config.rank != 0:
            assert tabList is None
            print("... MPI rank %d finished ..." % (config.rank))
            sys.exit()
        else:
            print("... gathering catalogs ...")
            tab=atpy.vstack(tabList)
    
    tab.sort('sortIndex')
    tab.remove_column('sortIndex')
    
    outDir=os.path.split(outFileName)[0]
    if outDir != '':
        os.makedirs(outDir, exist_ok = True)
    tab.meta['NEMOVER']=nemo.__version__
    tab.meta['QSOURCE']=QSource
    # Cosmo params useful for when running with nemo-sim-kit
    tab.meta['OM0']=massOptions['Om0']
    tab.meta['OB0']=massOptions['Ob0']
    tab.meta['H0']=massOptions['H0']
    tab.meta['SIGMA8']=massOptions['sigma8']
    tab.meta['NS']=massOptions['ns']
    # We don't store scaling relation params in the header any more because we may have multiple
    # scaling relations
    tab.write(outFileName, overwrite = True)
    print("... mass table written to %s ..." % (outFileName))

    # Optionally making model images
    # These take ~2 min per map for ~7000 clusters in an ACT DR5/DR6 scale map on a single core
    if args.makeModelMaps == True:
        del MockSurvey
        print(">>> Making model cluster SZ signal maps")
        t0=time.time()
        # os.path.join keeps this relative to the current directory when outDir is '' (i.e., when
        # outFileName has no directory part), rather than pointing at the filesystem root
        modelMapsDir=os.path.join(outDir, "clusterModelMaps")
        os.makedirs(modelMapsDir, exist_ok = True)
        catFileName=outFileName
        for mapDict in config.parDict['unfilteredMaps']:
            print("... %s ..." % (mapDict['label']))
            outModelFileName=os.path.join(modelMapsDir, "clusterModelMap_%s.fits" % (mapDict['label']))
            cmd="nemoModel %s %s %s %s -f %.1f -m" % (catFileName, mapDict['surveyMask'], mapDict['beamFileName'],
                                                      outModelFileName, mapDict['obsFreqGHz'])
            os.system(cmd)
        t1=time.time()
        print("... time taken to generate model cluster SZ signal maps = %.1f sec" % (t1-t0))
