#!/usr/bin/env python

"""

Calculate masses of clusters detected by nemo

Requires 'massOptions' in nemo config file

Run this after filtering / detecting objects in a map with nemo itself

Can be used to obtain 'forced photometry', i.e., mass estimates for objects in redshiftCatalog
(for, e.g., optical stacking)

"""

import os
import sys
import numpy as np
import pylab as plt
import astropy.table as atpy
from astLib import *
from scipy import stats
from scipy import interpolate
import nemo
from nemo import catalogs
from nemo import signals
from nemo import maps
from nemo import filters
from nemo import MockSurvey
from nemo import photometry
from nemo import startUp
from nemo import completeness
from nemo import pipelines
on_rtd=os.environ.get('READTHEDOCS', None)
if on_rtd is None:
    import pyccl as ccl
import argparse
import astropy.io.fits as pyfits
import time
import yaml

#------------------------------------------------------------------------------------------------------------
def addForcedPhotometry(pathToCatalog, config, zColumnName = None, zErrColumnName = None):
    """Given the path to a catalog that contains minimal columns name, RADeg, decDeg, redshift, 
    add forced photometry columns so we can estimate SZ masses.
    
    Args:
        pathToCatalog (str): Path to the catalog to read (anything astropy's Table reader accepts).
        config: Nemo configuration object. NOTE: its parDict is modified in place - the keys
            'forcedPhotometryCatalog', 'thresholdSigma' and 'mapFilters' are overwritten.
        zColumnName (str, optional): Name of the column to use as 'redshift'. If given, it replaces
            any existing 'redshift' column.
        zErrColumnName (str, optional): Name of the column to use as 'redshiftErr'. If given, it
            replaces any existing 'redshiftErr' column.
    
    Returns:
        The table produced by pipelines.filterMapsAndMakeCatalogs, restricted to the objects that
        catalogs.crossMatch pairs with the input catalog, with 'redshift' and 'redshiftErr' columns
        added.
    
    Raises:
        Exception: If no redshift column can be identified in the catalog.
    
    """
    
    print(">>> Doing forced photometry ...")
    
    zTab=atpy.Table().read(pathToCatalog)
    
    # User-specified columns take precedence over any existing standard-named column.
    # If the user names the standard column itself there is nothing to do - removing it and then
    # renaming it onto itself would delete the very column we were asked to use.
    for userColName, stdColName in [(zColumnName, 'redshift'), (zErrColumnName, 'redshiftErr')]:
        if userColName is not None and userColName != stdColName:
            if stdColName in zTab.keys():
                zTab.remove_column(stdColName)
            zTab.rename_column(userColName, stdColName)
    
    # Fall back on guessing the redshift column from a list of commonly used names.
    # We stop at the first match - renaming a second candidate onto 'redshift' would fail, because
    # that column name would already be taken.
    if 'redshift' not in zTab.keys():
        possZCols=['z', 'Z', 'REDSHIFT', 'Redshift', 'z_cl', 'Photz']
        for p in possZCols:
            if p in zTab.keys():
                print("... assuming %s is the redshift column ..." % (p))
                zTab.rename_column(p, 'redshift')
                break
        else:
            raise Exception("Couldn't find a redshift column in %s" % (pathToCatalog))
    
    # Same again for the redshift uncertainty, except that here a missing column is not fatal
    if 'redshiftErr' not in zTab.keys():
        possZErrCols=['zErr', 'dz']
        for p in possZErrCols:
            if p in zTab.keys():
                print("... assuming %s is the redshiftErr column ..." % (p))
                zTab.rename_column(p, 'redshiftErr')
                break
    if 'redshiftErr' not in zTab.keys():
        print("... assuming redshiftErr = 0 for all objects (no suitable redshiftErr column found) ...")
        zTab.add_column(atpy.Column(np.zeros(len(zTab)), 'redshiftErr'))
        
    config.parDict['forcedPhotometryCatalog']=pathToCatalog
    
    # We need to disable any cut in SNR, otherwise we're not doing truly forced photometry
    # This will allow -ve fixed_y_c, but we'll still only get the +ve masses though
    config.parDict['thresholdSigma']=-100
    
    # Trim all filters except the reference one, then make catalog
    photFilterLabel=config.parDict['photFilter']
    config.parDict['mapFilters']=[mapFilter for mapFilter in config.parDict['mapFilters'] 
                                  if mapFilter['label'] == photFilterLabel]
    forcedTab=pipelines.filterMapsAndMakeCatalogs(config, useCachedFilteredMaps = True)

    # Need to graft the redshifts back on
    # (zMatched and forcedMatched are presumably row-aligned after cross matching - the columns
    # are copied across on that basis)
    zMatched, forcedMatched, rDeg=catalogs.crossMatch(zTab, forcedTab)
    tab=forcedMatched
    tab.add_column(zMatched['redshift'])
    tab.add_column(zMatched['redshiftErr'])
    
    return tab

#------------------------------------------------------------------------------------------------------------
def calcMass(tab, massOptions, QFit, fRelWeightsDict, mockSurvey, otherMassEstimates = None):
    """Calculates masses for cluster data in table.
    
    Args:
        tab: Table of clusters. The columns name, tileName, redshift, redshiftErr, fixed_y_c and
            fixed_err_y_c are read for each row. Mass columns (plus 'Q') are added in place.
        massOptions (dict): Scaling relation options. The keys tenToA0, B0, Mpivot, sigma_int, delta
            and rhoType are required. The keys relativisticCorrection, Ez_gamma and
            onePlusRedshift_power are filled in with defaults (True, 2, 0.0) in place if missing.
            If rescaleFactor is present, rescaleFactorErr is also needed, and re-scaled ('Cal')
            masses are written as well.
        QFit: Passed through to signals.calcMass.
        fRelWeightsDict (dict): Indexed by tileName; the entry for each row's tile is passed to 
            signals.calcMass.
        mockSurvey: Passed through to signals.calcMass. Its mdefLabel attribute must match the
            mass definition given in massOptions, and its cosmoModel attribute is used for mass
            definition conversions.
        otherMassEstimates (list, optional): Dictionaries (keys: delta, rhoType, and optionally
            concMassRelation) describing additional mass definitions to convert to.
    
    Returns:
        The input table, with mass columns added. Rows with fixed_y_c <= 0 or redshift = NaN are
        kept, with all mass columns left at zero.
    
    """
    
    if otherMassEstimates is None:
        otherMassEstimates=[]

    refMassDef=ccl.halos.MassDef(massOptions['delta'], massOptions['rhoType'])

    if 'relativisticCorrection' not in massOptions.keys():
        massOptions['relativisticCorrection']=True
        
    # Experimenting with E(z)^gamma instead of E(z)^2
    if 'Ez_gamma' not in massOptions.keys():
        massOptions['Ez_gamma']=2

    # Experimenting with arbitrary (1+z)^something
    if 'onePlusRedshift_power' not in massOptions.keys():
        massOptions['onePlusRedshift_power']=0.0
    
    applyRescale='rescaleFactor' in massOptions.keys()

    # Add all columns for all used mass definitions
    labels=['M%d%s' % (massOptions['delta'], massOptions['rhoType'][0])]
    for massDefDict in otherMassEstimates:
        labels.append('M%d%s' % (massDefDict['delta'], massDefDict['rhoType'][0]))
    for l in labels:
        colNames=['%s' % (l), '%sUncorr' % (l)]
        if applyRescale:
            colNames.append('%sCal' % (l))
        colNames=colNames+["M200m", "M200mUncorr"] # We will generalize fully later so user can choose mass defs
        for c in colNames:
            tab['%s' % (c)]=np.zeros(len(tab))
            tab['%s_errPlus' % (c)]=np.zeros(len(tab))
            tab['%s_errMinus' % (c)]=np.zeros(len(tab))
            if l == labels[0]:
                tab['Q']=np.zeros(len(tab))
    label=labels[0]
    assert(label == mockSurvey.mdefLabel)
    
    # The other mass definitions don't change from row to row, so set them up once here
    # (we read concMassRelation with a default rather than writing it into the caller's dictionaries)
    otherMassDefs=[]
    for massDefDict in otherMassEstimates:
        otherMassDefs.append(('M%d%s' % (massDefDict['delta'], massDefDict['rhoType'][0]),
                              ccl.halos.MassDef(massDefDict['delta'], massDefDict['rhoType']),
                              massDefDict.get('concMassRelation', None)))
    
    # 'config' is a module-level global that only exists when this file is run as a script - 
    # fall back on rank 0 so that this routine also works if called from elsewhere
    try:
        rank=config.rank
    except NameError:
        rank=0

    for count, row in enumerate(tab, start = 1):
        print("... rank %d; %d/%d; %s (%.3f +/- %.3f) ..." % (rank, count, len(tab), row['name'], 
                                                              row['redshift'], row['redshiftErr']))

        tileName=row['tileName']
        
        # Cuts on z, fixed_y_c for forced photometry mode (invalid objects will be listed but without a mass)
        if row['fixed_y_c'] > 0 and np.isnan(row['redshift']) == False:
            # Both mass estimates below must use identical inputs, differing only in whether the
            # mass function de-biasing correction is applied - so the keywords are defined once.
            # NOTE: fRelWeightsDict is indexed by tileName here for both calls
            commonKwargs={'tenToA0': massOptions['tenToA0'],
                          'B0': massOptions['B0'], 
                          'Mpivot': massOptions['Mpivot'], 
                          'sigma_int': massOptions['sigma_int'],
                          'Ez_gamma': massOptions['Ez_gamma'],
                          'onePlusRedshift_power': massOptions['onePlusRedshift_power'],
                          'QFit': QFit, 'mockSurvey': mockSurvey, 
                          'applyRelativisticCorrection': massOptions['relativisticCorrection'],
                          'fRelWeightsDict': fRelWeightsDict[tileName],
                          'tileName': tileName}
            # Corrected for mass function steepness
            massDict=signals.calcMass(row['fixed_y_c']*1e-4, row['fixed_err_y_c']*1e-4, 
                                      row['redshift'], row['redshiftErr'],
                                      applyMFDebiasCorrection = True, **commonKwargs)
            row['%s' % (label)]=massDict['%s' % (label)]
            row['%s_errPlus' % (label)]=massDict['%s_errPlus' % (label)]
            row['%s_errMinus' % (label)]=massDict['%s_errMinus' % (label)]
            row['Q']=massDict['Q']
            # Uncorrected for mass function steepness
            unCorrMassDict=signals.calcMass(row['fixed_y_c']*1e-4, row['fixed_err_y_c']*1e-4, 
                                            row['redshift'], row['redshiftErr'],
                                            applyMFDebiasCorrection = False, **commonKwargs)
            row['%sUncorr' % (label)]=unCorrMassDict['%s' % (label)]
            row['%sUncorr_errPlus' % (label)]=unCorrMassDict['%s_errPlus' % (label)]
            row['%sUncorr_errMinus' % (label)]=unCorrMassDict['%s_errMinus' % (label)]
            # Re-scaling (e.g., using richness-based weak-lensing mass calibration)
            # Fractional errors on the mass and on the rescale factor are added in quadrature
            if applyRescale:
                fracRescaleErr=massOptions['rescaleFactorErr']/massOptions['rescaleFactor']
                row['%sCal' % (label)]=massDict['%s' % (label)]/massOptions['rescaleFactor']
                row['%sCal_errPlus' % (label)]=np.sqrt(np.power(row['%s_errPlus' % (label)]/row['%s' % (label)], 2) + \
                                                       np.power(fracRescaleErr, 2))*row['%sCal' % (label)]
                row['%sCal_errMinus' % (label)]=np.sqrt(np.power(row['%s_errMinus' % (label)]/row['%s' % (label)], 2) + \
                                                        np.power(fracRescaleErr, 2))*row['%sCal' % (label)]
                calMassDict={label: row['%sCal' % (label)],
                             label+'_errPlus': row['%sCal_errPlus' % (label)],
                             label+'_errMinus': row['%sCal_errMinus' % (label)]}

            # CCL-based mass conversions
            # The fractional errors of the reference mass are carried over to each converted mass
            resultsList=[massDict, unCorrMassDict]
            suffixList=['', 'Uncorr']
            if applyRescale:
                resultsList.append(calMassDict)
                suffixList.append('Cal')
            for resultDict, suffix in zip(resultsList, suffixList):
                for thisLabel, thisMassDef, concMassRelation in otherMassDefs:
                    # The factors of 1e14 here suggest masses are stored in units of 1e14 MSun - TODO confirm
                    thisMass=signals.MDef1ToMDef2(resultDict[label]*1e14, row['redshift'], refMassDef, thisMassDef, mockSurvey.cosmoModel,
                                                  c_m_relation = concMassRelation)/1e14
                    row[thisLabel+suffix]=thisMass
                    row[thisLabel+suffix+'_errPlus']=(row[label+suffix+'_errPlus']/row[label+suffix])*row[thisLabel+suffix]
                    row[thisLabel+suffix+'_errMinus']=(row[label+suffix+'_errMinus']/row[label+suffix])*row[thisLabel+suffix]

    return tab

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Builds the command line argument parser for nemoMass.
    
    Returns:
        An argparse.ArgumentParser object. The only positional argument is the name of the config
        file; everything else is an optional switch.
    
    """
    
    argParser=argparse.ArgumentParser("nemoMass")
    
    # Inputs and outputs
    argParser.add_argument("configFileName", help="A .yml configuration file.")
    argParser.add_argument("-c", "--catalog", default = None, dest = "catFileName", 
                           help = "Catalog file name (.fits format).\
                        The catalog must contain at least the following columns: name, RADeg, decDeg, \
                        redshift, redshiftErr. If the catalog contains fixed_y_c, fixed_err_y_c columns,\
                        then these will be used to infer mass estimates. If not, 'forced photometry' mode \
                        will be enabled, and the fixed_y_c, fixed_err_y_c values will be extracted from the\
                        filtered maps.")
    argParser.add_argument("-o", "--output", default = None, dest = "outFileName", 
                           help = "Output catalog file name \
                        (.fits format). If not given, the name of the output catalog file will be based on\
                        either configFileName or catFileName.")
    
    # How the masses are measured
    argParser.add_argument("-Q", "--Q-source", default = 'fit', dest = "QSource", 
                           help = "Source of the Q function data - either\
                        'fit' (the 'classic' method), 'injection' (for Q based on source injection test\
                        results), or 'hybrid' (uses the 'fit' method for scales smaller than the reference\
                        filter scale, and the 'injection' method for scales larger than the reference filter\
                        scale).")
    # No type is given here, so a value given on the command line arrives as a string
    argParser.add_argument("-x", "--x-match-arcmin", default = 2.5, dest = "xMatchArcmin",
                           help = "Specifies the cross-match radius (in arcmin) to use when matching the redshift\
                        catalog onto the cluster candidate catalog.")
    argParser.add_argument("-z", "--z-column", default = None, dest = "zColumnName", 
                           help = "Specifies the name of the redshift\
                        column in the input catalog.")
    argParser.add_argument("-e", "--z-error-column", default = None, dest = "zErrColumnName", 
                           help = "Specifies the name of the \
                        redshift uncertainty column in the input catalog.")
    argParser.add_argument("-F", "--forced-photometry", action = "store_true", default = False, 
                           dest = "forcedPhotometry", help = "Perform forced photometry.\
                        This is automatically enabled if the catalog does not contain the fixed_y_c, \
                        fixed_err_y_c columns. Use this switch to force using this mode even if the \
                        catalog already contains fixed_y_c, fixed_err_y_c  columns (e.g., for doing forced\
                        photometry on one ACT map using positions of clusters found in another, deeper \
                        map).")
    
    # MPI related switches
    argParser.add_argument("-M", "--mpi", action = "store_true", default = False, dest = "MPIEnabled", 
                           help="Enable MPI. If you \
                        want to use this, run with e.g., mpiexec -np 4 nemoMass configFile.yml -M")
    argParser.add_argument("-n", "--no-strict-errors", action = "store_true", default = False, 
                           dest = "noStrictMPIExceptions",
                           help="Disable strict exception handling (applies under MPI only, i.e., must be\
                        used with the -M switch). If you use this option, you will get the full traceback\
                        when a Python Exception is triggered, but the code may not terminate. This is due\
                        to the Exception handling in mpi4py.")
    
    # Reports the program name followed by the nemo package version
    versionSuffix=' %s' % (nemo.__version__)
    argParser.add_argument("-v", "--version", action = 'version', version = '%(prog)s' + versionSuffix)

    return argParser

#------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Script entry point: reads a cluster catalog (either the nemo optimal catalog cross matched
    # with the redshift catalog named in the config, or a catalog given using -c), estimates masses 
    # using calcMass (split between processes if MPI is enabled), and writes the result to a 
    # catalog file. If the catalog looks like a mock, mass recovery stats are printed at the end.
    
    parser = makeParser()
    args = parser.parse_args()
    
    parDictFileName=args.configFileName
    catFileName=args.catFileName
    outFileName=args.outFileName
    forcedPhotometry=args.forcedPhotometry
    QSource=args.QSource
    if QSource not in ['fit', 'injection', 'hybrid']:
        raise Exception("QSource must be either 'fit', 'injection', or 'hybrid'")
    
    # Forced photometry takes the object list from the catalog given with -c - fail with a clear 
    # message here, rather than part way through when attempting to read a non-existent catalog
    if forcedPhotometry == True and catFileName is None:
        raise Exception("Forced photometry (-F) requires a catalog - give one using the -c switch.")
    
    strictMPIExceptions=not args.noStrictMPIExceptions
    
    # Load the nemo catalog and match against the z catalog
    # NOTE: This is now done using coord matching (nearest within some maximum tolerance), rather than names
    if catFileName is None:
        config=startUp.NemoConfig(parDictFileName, MPIEnabled = args.MPIEnabled, divideTilesByProcesses = False,
                                  makeOutputDirs = False, setUpMaps = False, verbose = False,
                                  strictMPIExceptions = strictMPIExceptions)
        optimalCatalogFileName=config.rootOutDir+os.path.sep+"%s_optimalCatalog.fits" % (os.path.split(config.rootOutDir)[-1])           
        nemoTab=atpy.Table().read(optimalCatalogFileName)
        zTab=atpy.Table().read(config.parDict['massOptions']['redshiftCatalog'])
        if args.xMatchArcmin != 'act':
            xMatchArcmin=float(args.xMatchArcmin)
            nemoTab, zTab, rDeg=catalogs.crossMatch(nemoTab, zTab, radiusArcmin = xMatchArcmin)
        else:
            raise Exception("Fancier cross-matching not implemented yet.")
        nemoTab['redshift']=zTab['redshift']
        if 'redshiftErr' in zTab.keys():
            nemoTab['redshiftErr']=zTab['redshiftErr']
        else:
            print("... WARNING: no redshiftErr column found in '%s' - assuming redshiftErr = 0 for all clusters ..." %  (config.parDict['massOptions']['redshiftCatalog']))
            # Float zeros (not the integer 0), so the column type matches that of a real redshiftErr column
            nemoTab['redshiftErr']=np.zeros(len(nemoTab))
        tab=nemoTab
        if outFileName is None:
            outFileName=optimalCatalogFileName.replace("_optimalCatalog.fits", "_mass.fits")

    else:
        
        # Load another catalog (e.g., a mock, for testing)
        optimalCatalogFileName=catFileName
        tab=atpy.Table().read(optimalCatalogFileName)
        if outFileName is None:
            outFileName=catFileName.replace(".fits", "_mass.fits")
        
        # Enter forced photometry mode if we can't find the columns we need
        # If we're doing forced photometry, we're going to need to set-up so we can find the filtered maps
        keysNeeded=['fixed_y_c', 'fixed_err_y_c']
        for key in keysNeeded:
            if key not in tab.keys():
                forcedPhotometry=True
        config=startUp.NemoConfig(parDictFileName, MPIEnabled = args.MPIEnabled, divideTilesByProcesses = False,
                                  setUpMaps = forcedPhotometry, writeTileInfo = False, verbose = False,
                                  strictMPIExceptions = strictMPIExceptions)
    
    # Remaining set-up
    massOptions=config.parDict['massOptions']
    Q=signals.QFit(QSource, selFnDir = config.selFnDir)
    fRelWeightsDict=signals.getFRelWeights(config)
    
    # An 'otherMassEstimates' key with no value in the config file is treated the same as if the
    # key was missing altogether
    if 'otherMassEstimates' in config.parDict.keys() and config.parDict['otherMassEstimates'] is not None:
        otherMassEstimates=config.parDict['otherMassEstimates']
    else:
        # Default to match older versions of Nemo
        otherMassEstimates=[{'delta': 200, 'rhoType': 'matter', 'concMassRelation': "Bhattacharya13"}]

    # Forced photometry (if enabled) - this replaces the table loaded above
    # NOTE: Move this up if/when we make it run under MPI
    if forcedPhotometry == True:
        tab=addForcedPhotometry(catFileName, config, args.zColumnName, args.zErrColumnName)
    
    # Set cosmological parameters for e.g. E(z) calc if these are set in .par file
    # We set them after the Q calc, because the Q calc needs to be for the fiducial cosmology
    # (OmegaM0 = 0.3, OmegaL0 = 0.7, H0 = 70 km/s/Mpc) used in object detection/filtering stage
    # Set-up the mass function stuff also
    # This is for correction of mass bias due to steep cluster mass function
    # Hence minMass here needs to be set well below the survey mass limit
    # areaDeg2 we don't care about here
    minMass=1e13
    areaDeg2=700.
    zMin=0.0
    zMax=2.0
    # H0, Om0, Ol0 used for E(z), theta500 calcs in Q - these are updated when we call create mockSurvey
    # NOTE: startUp now sets defaults for these if not given in config
    H0=massOptions['H0']
    Om0=massOptions['Om0']
    Ob0=massOptions['Ob0']
    sigma8=massOptions['sigma8']
    ns=massOptions['ns']
    mockSurvey=MockSurvey.MockSurvey(minMass, areaDeg2, zMin, zMax, H0, Om0, Ob0, sigma8, ns,
                                     rhoType = massOptions['rhoType'], delta = massOptions['delta'])
    
    # sortIndex records the original row order, so we can restore it after gathering under MPI
    tab.add_column(atpy.Column(np.arange(len(tab)), "sortIndex"))
    if config.MPIEnabled == True:
        # Each process takes a contiguous chunk of rows; the last process takes whatever is left
        # (which may be nothing, if the rows don't divide evenly)
        numRowsPerProcess=int(np.ceil(len(tab)/config.size))
        startIndex=config.rank*numRowsPerProcess
        endIndex=startIndex+numRowsPerProcess
        if config.rank == config.size-1:
            endIndex=len(tab)
        tab=tab[startIndex:endIndex]

    tab=calcMass(tab, massOptions, Q, fRelWeightsDict, mockSurvey, otherMassEstimates = otherMassEstimates)
    
    if config.MPIEnabled == True:
        tabList=config.comm.gather(tab, root = 0)
        if config.rank != 0:
            # Only the root process writes output
            assert tabList is None
            print("... MPI rank %d finished ..." % (config.rank))
            sys.exit()
        else:
            print("... gathering catalogs ...")
            tab=atpy.vstack(tabList)
    
    tab.sort('sortIndex')
    tab.remove_column('sortIndex')
    
    outDir=os.path.split(outFileName)[0]
    if outDir != '':
        os.makedirs(outDir, exist_ok = True)
    tab.meta['NEMOVER']=nemo.__version__
    tab.meta['QSOURCE']=QSource
    tab.write(outFileName, overwrite = True)
    
    # Detect if testing a mock catalog, and write some stats on recovered masses
    if 'true_M500' in tab.keys():
        # calcMass names its output columns after the mass definition (e.g., 'M500c'), so accept
        # either naming convention here rather than assuming that a bare 'M500' column exists
        massColName=None
        for candidate in ['M500', 'M500c']:
            if candidate in tab.keys() and candidate+'Uncorr' in tab.keys():
                massColName=candidate
                break
        if massColName is None:
            print("... WARNING: found true_M500 column, but no M500 mass estimates in the output catalog - skipping mock catalog mass recovery stats ...")
        else:
            # Noise sources in mocks
            if 'applyPoissonScatter' in config.parDict.keys():
                applyPoissonScatter=config.parDict['applyPoissonScatter']
            else:
                applyPoissonScatter=True
            if 'applyIntrinsicScatter' in config.parDict.keys():
                applyIntrinsicScatter=config.parDict['applyIntrinsicScatter']
            else:
                applyIntrinsicScatter=True
            if 'applyNoiseScatter' in config.parDict.keys():
                applyNoiseScatter=config.parDict['applyNoiseScatter']
            else:
                applyNoiseScatter=True
            print(">>> Mock noise sources (Poisson, intrinsic, measurement noise) = (%s, %s, %s) ..." % (applyPoissonScatter, applyIntrinsicScatter, applyNoiseScatter))
            if applyNoiseScatter == True and applyIntrinsicScatter == True:
                print("... for these options, median M500 / true_M500 = 1.000 if mass recovery is unbiased ...")
            elif applyNoiseScatter == False and applyIntrinsicScatter == False:
                print("... for these options, median M500Unc / true_M500 = 1.000 if mass recovery is unbiased ...")
            else:
                print("... for these options, both median M500 / true_M500 and median M500Unc / true_M500 will be biased ...")
            ratio=tab[massColName]/tab['true_M500']
            ratioUnc=tab[massColName+'Uncorr']/tab['true_M500']
            print(">>> Mock catalog mass recovery stats:")
            SNRCuts=[4, 5, 7, 10]
            for s in SNRCuts:
                mask=np.greater(tab['fixed_SNR'], s)
                print("--> SNR > %.1f (N = %d):" % (s, np.sum(mask))) 
                print("... mean M500 / true_M500 = %.3f +/- %.3f (stdev) +/- %.3f (sterr)" % (np.mean(ratio[mask]), 
                                                                                              np.std(ratio[mask]), 
                                                                                              np.std(ratio[mask])/np.sqrt(np.sum(mask))))
                print("... median M500 / true_M500 = %.3f" % (np.median(ratio[mask])))
                print("... mean M500Unc / true_M500 = %.3f +/- %.3f (stdev) +/- %.3f (sterr)" % (np.mean(ratioUnc[mask]), 
                                                                                              np.std(ratioUnc[mask]), 
                                                                                              np.std(ratioUnc[mask])/np.sqrt(np.sum(mask))))
                print("... median M500Unc / true_M500 = %.3f" % (np.median(ratioUnc[mask])))


