#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
    Copyright 2017 Matt Hilton (matt.hilton@mykolab.com)
    
    This file is part of zCluster.

    zCluster is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    zCluster is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with zCluster.  If not, see <http://www.gnu.org/licenses/>.

"""

import os
import sys
import argparse
import glob
import zCluster
from zCluster import *
import sqlite3
import pylab as plt
import pickle
import numpy as np
from astLib import *
from scipy import interpolate
import astropy.io.fits as pyfits
import astropy.table as atpy
import datetime
import urllib.request
#import multiprocessing as mp
plt.matplotlib.interactive(False)

#------------------------------------------------------------------------------------------------------------
def update_rcParams(dict=None):
    """Apply the default matplotlib rcParams used for zCluster plots.
    
    Based on Cristobal's preferred settings.
    
    Args:
        dict (dict, optional): Extra rcParams (key: value) that are applied after the defaults, so
            they overwrite any default with the same key. The parameter name shadows the builtin,
            but is kept so that existing keyword callers continue to work.
    
    """
    # A None sentinel is used rather than a mutable {} default, which would be shared between calls
    overrides={} if dict is None else dict
    
    default={}
    for tick in ('xtick', 'ytick'):
        default['{0}.major.size'.format(tick)] = 8
        default['{0}.minor.size'.format(tick)] = 4
        default['{0}.major.width'.format(tick)] = 2
        default['{0}.minor.width'.format(tick)] = 2
        default['{0}.labelsize'.format(tick)] = 20
        default['{0}.direction'.format(tick)] = 'in'
    default['xtick.top'] = True
    default['ytick.right'] = True
    default['axes.linewidth'] = 2
    default['axes.labelsize'] = 22
    default['font.size'] = 22
    default['font.family']='sans-serif'
    default['legend.fontsize'] = 18
    default['lines.linewidth'] = 2

    for key in default:
        plt.rcParams[key] = default[key]
    # if any parameters are specified, overwrite anything previously
    # defined
    for key in overrides:
        plt.rcParams[key] = overrides[key]
        
        
#-------------------------------------------------------------------------------------------------------------
def parseClusterCatalog(fileName):
    """Reads a cluster catalog table and returns it as a list of dictionaries.
    
    Files whose name ends in the extension 'csv' are read as ascii tables; anything else is passed
    to the table reader with its default format handling (e.g., .fits).
    
    Args:
        fileName (str): Path to a table containing 'name', 'RADeg', 'decDeg' columns.
    
    Returns:
        List of dictionaries (one per table row) with keys 'name' (converted to str), 'RADeg' and
        'decDeg'.
    
    """
    
    extension=fileName.rsplit(".", 1)[-1]
    readOptions={'format': 'ascii'} if extension == "csv" else {}
    tab=atpy.Table().read(fileName, **readOptions)
    
    return [{'name': str(row['name']), 'RADeg': row['RADeg'], 'decDeg': row['decDeg']} for row in tab]


#-------------------------------------------------------------------------------------------------------------
def writePlot(objName, result, zPriorMax, plotsDir):
    """Write a .pdf plot of n(z), z using the given zCluster result dictionary.
    
    The plot is saved as plotsDir/pz_<objName>.pdf (spaces in objName are replaced by underscores).
    The plots directory is created if it doesn't exist, even when there is nothing to plot.
    
    Args:
        objName (str): Object name, used for the plot title and the output file name.
        result (dict or None): Result dictionary with keys 'pz', 'pz_z', 'zOdds' and 'z'. If None,
            no plot is made. The dictionary is not modified (normalisation is done on copies).
        zPriorMax (float or None): Maximum redshift of the prior. If given, p(z) is set to zero
            above this redshift (and below z = 0.05) and re-normalised, and the x-axis extends to
            1.5 x zPriorMax. If None, p(z) is plotted as given and the x-axis extends to the
            maximum of result['pz_z'].
        plotsDir (str): Directory in which the plot is written.
    
    """

    if os.path.exists(plotsDir) == False:
        os.makedirs(plotsDir, exist_ok = True)
        
    if result is None:
        return None
        
    # NOTE: in calculateRedshiftAndOdds, prior hasn't been applied to pz, but has been to zOdds
    zArray=result['pz_z']
    pz=result['pz']
    if zPriorMax is not None:
        prior=np.ones(pz.shape)
        prior[np.greater(zArray, zPriorMax)]=0.0
        prior[np.less(zArray, 0.05)]=0.0
        pz=pz*prior
        pz=pz/np.trapz(pz, zArray)
        xMax=zPriorMax*1.5
    else:
        # No prior given: previously zPriorMax*1.5 failed here with a TypeError
        xMax=np.max(zArray)
    
    # Plot of normalised odds and p(z)
    # Normalise a local copy, rather than overwriting the caller's result['zOdds']
    zOdds=result['zOdds']/np.trapz(result['zOdds'], zArray)
    
    plt.figure(figsize=(9, 5))
    plt.axes([0.125, 0.125, 0.845, 0.79])
        
    plt.plot(zArray, zOdds, 'k', label = '$n_{\\Delta z}(z)$')
    plt.plot(zArray, pz, 'k:', label = '$n(z)$')

    plotMax=max([pz.max(), zOdds.max()])
    plt.ylabel("$n_{\\Delta z}(z)$ | $n(z)$ (normalised)")
    plt.xlabel("$z$")
    # Vertical dashed line marking the cluster redshift
    y=np.linspace(zOdds.min(), plotMax*1.5, 4)
    plt.plot([result['z']]*len(y), y, 'k--')
    plt.ylim(0, plotMax*1.1)
    plt.xlim(0, xMax)
    
    # Faffing with labels so we can lay out plots on a grid
    values, tickLabels=plt.yticks()
    for v, t in zip(values, tickLabels):
        t.set_text('%.1f' % (v))
    plt.yticks(values, tickLabels)
    
    leg=plt.legend(loc = 'upper right', numpoints = 1) 
    leg.draw_frame(False)
    plt.title("%s (z = %.2f)" % (objName.replace("_", " "), result['z'])) 
    plt.savefig(plotsDir+os.path.sep+"pz_"+objName.replace(" ", "_")+".pdf")
    plt.close()


#-------------------------------------------------------------------------------------------------------------
def runOnCatalog(catalog, retriever, retrieverOptions, photoRedshiftEngine, outDir, zPriorMin,\
                 zPriorMax, weightsType, maxRMpc, zMethod, bckCatalogFileName, bckAreaDeg2,\
                 maskPath = None, writeGalaxyCatalogs = False, writeDensityMaps = False, writePlots = False,\
                 rank = 0, zDebias = None, fitZPOffsets = False, maxIter = 5, minBackgroundAreaMpc2 = 11,\
                 fetchAndCacheOnly = False):
    """Runs zCluster algorithm on each object in catalog.
    
    Results are stored in the sqlite database outDir/redshifts_rank<rank>.db. Objects that are already
    present in that database, or in outDir/redshifts_global.db (if it exists), are not re-measured -
    their stored results are copied into the catalog instead.
    
    Args:
        catalog (list): Object dictionaries with 'name', 'RADeg', 'decDeg' keys. Updated in place.
        retriever (callable): Called as retriever(RADeg, decDeg, optionsDict = retrieverOptions) to
            fetch the galaxy catalog. The call is repeated for as long as it returns "retry"; a
            return value of None means no galaxy catalog is available.
        retrieverOptions (dict): Options passed to the retriever (and to the background catalog parser).
        photoRedshiftEngine: Object providing calcPhotoRedshifts and calcZeroPointOffsets.
        outDir (str): Output directory (must already exist); per-object sub-directories are made here.
        zPriorMin (float): Minimum redshift of the prior, passed to the cluster redshift estimator.
        zPriorMax (float): Maximum redshift of the prior, passed to the estimator and the plots.
        weightsType (str): Radial weighting type, passed to the cluster redshift estimator.
        maxRMpc (float): Maximum radius in Mpc, passed to the cluster redshift estimator.
        zMethod (str): Redshift algorithm, passed to the cluster redshift estimator.
        bckCatalogFileName (str or None): Table to use as the background galaxy sample. Only used
            if bckAreaDeg2 is also given.
        bckAreaDeg2 (float or None): Area covered by the background catalog.
        maskPath (str, optional): FITS image (with WCS) defining the valid survey area.
        writeGalaxyCatalogs (bool, optional): Write .fits and DS9 .reg galaxy catalogs per object.
        writeDensityMaps (bool, optional): Write density map (and, if no mask given, area map) images.
        writePlots (bool, optional): Write n(z) plots into outDir/Plots.
        rank (int, optional): Process rank - selects which database file results are written into.
        zDebias (float, optional): If given, output z becomes z + zDebias*(1+z).
        fitZPOffsets (bool, optional): If True, fit zero point offsets before estimating photo-zs.
        maxIter (int, optional): Maximum number of redshift / position iterations.
        minBackgroundAreaMpc2 (float, optional): Passed to the cluster redshift estimator.
        fetchAndCacheOnly (bool, optional): If True, only call the retriever for each object.
    
    Returns:
        The input catalog, with 'z', 'delta', 'errDelta', 'RADeg', 'decDeg', 'origRADeg', 'origDecDeg',
        'offsetArcmin', 'offsetMpc', 'CS' and 'A' filled in for each processed object (objects that
        were only fetched with fetchAndCacheOnly are left untouched). Failed measurements have z = -99.
    
    """
    
    # Now using sqlite database for storage, rather than loads of .pickle files (filesystem friendly)
    # We're only storing final results in here - if the user wants to see n(z) later, they can use -n option 
    # to rerun specific clusters
    # Each rank has its own database to write into, and they all share a global database for reading
    conn=sqlite3.connect(outDir+os.path.sep+"redshifts_rank%d.db" % (rank))
    c=conn.cursor()
    c.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='redshifts'")
    matches=c.fetchall()
    if len(matches) == 0:
        c.execute("""CREATE TABLE redshifts (name text, RADeg real, decDeg real, origRADeg real, origDecDeg real, 
                  offsetArcmin real, offsetMpc real, z real, delta real, errDelta real, CS real, A real)""")
    
    # The global database for checking if we've already measured a redshift - read only
    if os.path.exists(outDir+os.path.sep+"redshifts_global.db") == True:
        connGlobal=sqlite3.connect(outDir+os.path.sep+"redshifts_global.db")
        cGlobal=connGlobal.cursor()
    else:
        connGlobal=None
        cGlobal=None
    
    # Optional area mask
    if maskPath is not None:
        with pyfits.open(maskPath) as img:
            maskMap=img[0].data
            maskWCS=astWCS.WCS(img[0].header, mode = 'pyfits')
    else:
        maskMap=None
        maskWCS=None
        
    # Optional background catalog (usually would only be used with proprietary photometric survey, e.g., SOAR)
    if bckCatalogFileName is not None and bckAreaDeg2 is not None:
        bckTab=atpy.Table().read(bckCatalogFileName)
        bckCatalog=retrievers.parseFITSPhotoTable(bckTab, fieldIDKey = 'field', optionsDict = retrieverOptions)
        photoRedshiftEngine.calcPhotoRedshifts(bckCatalog, calcMLRedshiftAndOdds = True)
    else:
        bckCatalog=[]
    
    # Column order of the SELECT used to read back cached results - keys are assigned in this order
    cachedKeys=('z', 'delta', 'errDelta', 'RADeg', 'decDeg', 'origRADeg', 'origDecDeg', 'offsetArcmin', 
                'offsetMpc', 'CS', 'A')
    
    count=0
    for obj in catalog:
        count=count+1
        print(">>> %s (%d/%d):" % (obj['name'], count, len(catalog)))
        objOutDir=outDir+os.path.sep+obj['name']
        os.makedirs(objOutDir, exist_ok = True)
                            
        # We actually query both the process database AND the global database
        c.execute("SELECT name FROM redshifts where name=?", (obj['name'],))
        matches=c.fetchall()
        if cGlobal is not None:
            cGlobal.execute("SELECT name FROM redshifts where name=?", (obj['name'],))
            globalMatches=cGlobal.fetchall()
        else:
            globalMatches=[]
        if len(matches) == 0 and len(globalMatches) == 0:
            stuff="retry"
            while stuff == "retry":
                stuff=retriever(obj['RADeg'], obj['decDeg'], optionsDict = retrieverOptions)
            if fetchAndCacheOnly == True:
                continue
            galaxyCatalog=stuff
            # Default values... will get filled if successful
            obj['origRADeg']=obj['RADeg']
            obj['origDecDeg']=obj['decDeg']
            obj['offsetArcmin']=-99
            obj['offsetMpc']=-99
            obj['z']=-99
            obj['delta']=0
            obj['errDelta']=0
            obj['CS']=-99
            obj['A']=-99
            if galaxyCatalog is not None:
                
                if fitZPOffsets == True:
                    photoRedshiftEngine.calcZeroPointOffsets(galaxyCatalog)
                photoRedshiftEngine.calcPhotoRedshifts(galaxyCatalog, calcMLRedshiftAndOdds = True)                    
                
                # Iterate on position of optical peak - estimate z, find centre, estimate z again...
                converged=False
                obj['origRADeg'], obj['origDecDeg']=obj['RADeg'], obj['decDeg']
                RADeg, decDeg=obj['RADeg'], obj['decDeg']
                numIter=0
                # Defined up front so that maxIter < 1 is treated as a failed measurement (not a NameError)
                result=None
                while converged == False and numIter < maxIter:
                    numIter=numIter+1
                    result=clusters.estimateClusterRedshift(RADeg, decDeg, galaxyCatalog, zPriorMin, 
                                                                  zPriorMax, weightsType, maxRMpc, zMethod, 
                                                                  maskMap = maskMap, maskWCS = maskWCS,
                                                                  bckCatalog = bckCatalog, bckAreaDeg2 = bckAreaDeg2,
                                                                  minBackgroundAreaMpc2 = minBackgroundAreaMpc2)
                    if result is None:
                        break
                    else:
                        # We want the offset relative to the original position, not the last iteration
                        dMapDict=clusters.makeDensityMap(obj['origRADeg'], obj['origDecDeg'], galaxyCatalog, 
                                                               result['z'], dz = 0.05)
                        result['offsetArcmin']=dMapDict['offsetArcmin']
                        result['offsetMpc']=dMapDict['offsetMpc']
                        result['CS']=dMapDict['CS']
                        result['A']=dMapDict['A']
                        newRADeg=dMapDict['cRADeg']
                        newDecDeg=dMapDict['cDecDeg']
                        if newRADeg == RADeg and newDecDeg == decDeg:
                            converged=True
                        RADeg=newRADeg
                        decDeg=newDecDeg

                # We're now updating object coords in-place to whatever we think the peak is
                # NOTE: On-the-fly z-debias gets applied here now, if active
                if result is not None:
                    obj['RADeg']=RADeg
                    obj['decDeg']=decDeg
                    obj['offsetArcmin']=result['offsetArcmin']
                    obj['offsetMpc']=result['offsetMpc']
                    obj['z']=result['z']
                    obj['CS']=result['CS']
                    obj['A']=result['A']
                    # Optional: de-bias right here (use with caution)
                    if zDebias is not None:
                        obj['z']=obj['z']+zDebias*(1+obj['z'])
                    obj['delta']=result['delta']
                    obj['errDelta']=result['errDelta']
                    print("... %s: z = %.2f; delta = %.1f; errDelta = %.1f; offsetArcmin = %.1f; offsetMpc = %.1f; CS = %.1f; A = %.1f ..." 
                          % (obj['name'], obj['z'], obj['delta'], obj['errDelta'], obj['offsetArcmin'], obj['offsetMpc'], obj['CS'], obj['A']))
                    if writeDensityMaps == True:
                        astImages.saveFITS(objOutDir+os.path.sep+"densityMap_%s.fits" % (obj['name'].replace(" ", "_")), 
                                           dMapDict['map'], dMapDict['wcs']) 
                        # Useful for debugging, but NOT if we specified a pre-made mask
                        if maskPath is None:    
                            astImages.saveFITS(objOutDir+os.path.sep+"areaMap_%s.fits" % (obj['name'].replace(" ", "_")), 
                                               result['areaMask'], result['wcs'])
                else:
                    obj['RADeg']=obj['origRADeg']
                    obj['decDeg']=obj['origDecDeg']
                    
                ## This is just to stop pickle-related fail if under MPI
                ## Really we should tidy up the area mask making stuff and call before estimateClusterRedshift
                #if obj['zClusterResult'] is not None:
                    #if 'wcs' in obj['zClusterResult'].keys():
                        #del obj['zClusterResult']['wcs']
                    #if 'areaMask' in obj['zClusterResult'].keys():
                        #del obj['zClusterResult']['areaMask']

                # Write out galaxy catalog with photo-zs in both FITS and DS9 .reg format
                if writeGalaxyCatalogs == True:
                    # Also r-i colour, useful for sanity checking of BCGs
                    for gobj in galaxyCatalog:
                        if 'r' in gobj.keys() and 'i' in gobj.keys():
                            gobj['r-i']=gobj['r']-gobj['i']
                        if 'r' in gobj.keys() and 'z' in gobj.keys():
                            gobj['r-z']=gobj['r']-gobj['z']

                    wantedKeys=['id', 'RADeg', 'decDeg', 'u', 'uErr', 'g', 'gErr', 'r', 'rErr', 'i', 
                                'iErr', 'z', 'zErr', 'Ks', 'KsErr', 'w1', 'w1Err', 'w2', 'w2Err', 'r-i', 'r-z', 'zPhot', 'odds']
                    tab=atpy.Table()
                    for key in wantedKeys:
                        # Missing values are filled with the sentinel 99; columns that would contain
                        # nothing but sentinels are left out of the table
                        arr=[]
                        allSentinelVals=True
                        for gobj in galaxyCatalog:
                            try:
                                arr.append(gobj[key])
                                allSentinelVals=False
                            except KeyError:
                                arr.append(99)
                        if allSentinelVals == False:
                            tab[key]=arr
                    # NOTE: to cut down on disk space this takes, include only galaxies within some radius
                    # Chosen one is just over 1.5 Mpc at z = 0.1
                    tab['rDeg']=astCoords.calcAngSepDeg(obj['RADeg'], obj['decDeg'], tab['RADeg'].data, tab['decDeg'].data)
                    #tab=tab[np.where(tab['rDeg'] < 14./60.0)]
                    tab.table_name="zCluster"
                    tab.write(objOutDir+os.path.sep+"galaxyCatalog_%s.fits" % (obj['name'].replace(" ", "_")), overwrite = True)
                    # Writing the region file is best-effort - but don't swallow KeyboardInterrupt / SystemExit
                    try:
                        catalogs.catalog2DS9(galaxyCatalog, 
                                                 objOutDir+os.path.sep+"galaxyCatalog_%s.reg" % (obj['name'].replace(" ", "_")), 
                                                 idKeyToUse = 'id', 
                                                 addInfo = [{'key': 'r', 'fmt': '%.3f'}, {'key': 'zPhot', 'fmt': '%.2f'}]) 
                    except Exception:
                        print("... couldn't write .reg file (missing key?) ...")

                if writePlots == True:
                    writePlot(obj['name'], result, zPriorMax, outDir+os.path.sep+"Plots")


            # Insert a row of data
            c.execute("INSERT INTO redshifts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 
                      (obj['name'], obj['RADeg'], obj['decDeg'], obj['origRADeg'], obj['origDecDeg'], 
                       obj['offsetArcmin'], obj['offsetMpc'], obj['z'], obj['delta'], obj['errDelta'], obj['CS'], obj['A']))
            conn.commit() # May want to think about how often we do this...
            
        else:
            print("... already measured photo-z ...")
            # Prefer the global result, if it exists
            # NOTE: This is dumb if multiple objects have the same name but slightly different coords!
            if len(matches) > 0:
                thisCursor=c
            if len(globalMatches) > 0:
                thisCursor=cGlobal
            thisCursor.execute("""SELECT z, delta, errDelta, RADeg, decDeg, origRADeg, origDecDeg, offsetArcmin, 
                                offsetMpc, CS, A FROM redshifts WHERE name = ?""", (obj['name'],))
            match=thisCursor.fetchone()
            # Pair each value with the column it was selected as (previously decDeg, origRADeg and
            # origDecDeg were read from the wrong columns)
            for key, value in zip(cachedKeys, match):
                obj[key]=value
    
    # All inserts have been committed above, so it is safe to release the database handles here
    conn.close()
    if connGlobal is not None:
        connGlobal.close()
 
    return catalog


#-------------------------------------------------------------------------------------------------------------
def getRoughFootprint(database):
    """Looks up an approximate rectangular footprint for the given photometric survey. Used for
    null tests.
    
    Args:
        database (str): Name of the photometric database (e.g., 'SDSSDR12', 'S82', 'CFHTLenS').
    
    Returns:
        (RAMin, RAMax, decMin, decMax) in degrees, or None (after printing a warning) if no
        footprint is defined for the database.
    
    """

    # Each entry: (database names sharing a footprint, (RAMin, RAMax, decMin, decMax))
    footprints=((('SDSSDR7', 'SDSSDR8', 'SDSSDR10', 'SDSSDR12'), (130.0, 250.0, 0.0, 60.0)),
                (('S82',), (-60.0, 60.0, -1.2, 1.2)),
                (('CFHTLenS',), (30.5, 38.5, -11.5, -2.5)))
    
    for names, box in footprints:
        if database in names:
            return box
    
    print("WARNING: no rough footprint defined for database '%s' yet" % (database))
    return None


#-------------------------------------------------------------------------------------------------------------
def makeParser():
    """Builds the command line argument parser for zCluster.
    
    Options that take a single number are converted to float by the parser itself, so that an
    invalid value is reported as a usage error. The exception is --absmag-cut, which is left as
    given because it may be either a single number or a list such as [-15,-24].
    
    Returns:
        argparse.ArgumentParser with two positional arguments (catalogFileName, database) and the
        optional arguments defined below.
    
    """
    parser=argparse.ArgumentParser("zCluster")
    parser.add_argument("catalogFileName", help="A .fits table or a .csv file with at least the columns\
                        'name', 'RADeg', 'decDeg'.")
    parser.add_argument("database", help="The photometric database to use. Options are 'SDSSDR12', 'S82'\
                        (for SDSS DR7 Stripe 82 co-add); 'CFHTLenS'; 'DESDR1', 'DESDR2',\
                        'DESY3' [experimental; requires access to proprietary DES data]; 'PS1' [experimental];\
                        'DECaLSDR8', 'DECaLSDR9'; or the path to a .fits table with columns in the\
                        appropriate format ('ID', 'RADeg', 'decDeg', and magnitude column names in the form\
                        'u_MAG_AUTO', 'u_MAGERR_AUTO' etc.).")
    parser.add_argument("-o", "--output-label", dest="outLabel", help="Label to use for outputs\
                        (default: catalogFileName_database, stripped of file extension).\
                        A redshift catalog called zCluster_outLabel.fits will be created. Cached\
                        results for each entry in the input cluster catalog and associated plots will\
                        be written into the directory outLabel/, together with a log file that\
                        records the arguments used for running zCluster.", default = None)
    parser.add_argument("-w", "--weights-type", dest="weightsType", help="Radial weighting type. Options\
                        are 'NFW', 'flat', or 'radial' (default: NFW).", default = 'NFW')
    parser.add_argument("-R", "--max-radius-Mpc", dest="maxRMpc", help="Maximum radius in Mpc within\
                        which to calculate delta statistic for each cluster (default: 0.5).", 
                        default = 0.5, type = float)
    parser.add_argument("-m", "--mask", dest="mask", help="Path to mask image (FITS format with WCS)\
                        that defines valid survey area.", default = None)
    parser.add_argument("-a", "--algorithm", dest="algorithm", help="Algorithm to use for the maximum likelihood\
                        redshift. Options are 'max' or 'odds' (default: odds).", 
                        default = 'odds')
    parser.add_argument("-i", "--max-iter", dest="maxIter", help="Maximum number of iterations for finding\
                        the cluster redshift and optical position based on projected density map (default: 1).",
                        default = 1, type = int)
    parser.add_argument("-c", "--cachedir", dest="cacheDir", default = None, help="Cache directory location\
                        (default: $HOME/.zCluster/cache). Downloaded photometric catalogs will be stored\
                        here.")
    parser.add_argument("-F", "--fetch-only-and-cache", dest="fetchAndCacheOnly",
                        help="Only fetch and cache galaxy catalogs - i.e., do not run photo-z estimation.\
                        This is useful if you are running on a compute cluster where compute nodes do\
                        not have internet access (allows a first run with -F, followed by a second run\
                        under MPI to compute the photo-zs in parallel).",
                        default = False, action = "store_true")
    parser.add_argument("-M", "--mpi", dest="MPIEnabled", action="store_true", help="Enable MPI. If you\
                        want to use this, run zCluster using something like: mpirun --np 4 zCluster ...",
                        default = False)
    parser.add_argument("-e", "--max-mag-error", dest="maxMagError", help="Maximum acceptable\
                        photometric error (in magnitudes; default: 0.25).", default = 0.25, type = float)
    parser.add_argument("-E", "--photometric-zero-point-error", dest="ZPError", type = float, 
                        help="Global photometric zero point uncertainty in magnitudes, applied to all bands\
                        (default: 0). Added in quadrature to the photometric uncertainties in the catalog.", 
                        default = 0.0)
    parser.add_argument("-f", "--fit-for-zero-point-offsets", dest="fitZPOffsets", 
                        help="If the input catalog contains a z_spec column, use those galaxies to fit\
                        for magnitude zero point offsets. These will then be applied when estimating galaxy\
                        photometric redshifts.", 
                        default = False, action = "store_true")
    parser.add_argument("-z", "--z-prior-min", dest="zPriorMin", help="Set minimum redshift of prior.",
                        default = None, type = float)
    parser.add_argument("-Z", "--z-prior-max", dest="zPriorMax", help="Set maximum redshift of prior.",
                        default = None, type = float)
    # No type given here: the value is either a single number or a list, and is parsed by the caller
    parser.add_argument("-b", "--absmag-cut", dest="absMagCut",
                        help="Set absolute (r-band) magnitude cut to use in magnitude-based prior. If a single number\
                        is given, p(z) for objects brighter than this limit will be set to 0. If a\
                        list of numbers is given (e.g., [-15,-24]), p(z) will be set to 0\
                        for objects outside of the absolute magnitude range.", default = -24.)
    parser.add_argument("-n", "--name", dest="name", help="Find photo-z of only the named cluster in the catalog.")
    parser.add_argument("-t", "--templates-directory", dest="templatesDir", help="Specify a directory containing\
                        a custom set of spectral templates.",  default = None)
    parser.add_argument("-d", "--write-density-maps", dest="writeDensityMaps", action="store_true", 
                        help="Write out a .fits image projected density map (within delta z = +/- 0.1 of the best\
                        fit redshift) for each cluster.", default = False)
    parser.add_argument("-W", "--write-galaxy-catalogs", dest="writeGalaxyCatalogs", action="store_true", 
                        help="Write out a .fits format galaxy catalog and DS9 .reg file for each cluster.",
                        default = False)
    parser.add_argument("-P", "--write-plots", dest="writePlots", action="store_true", 
                        help = "Write out a .pdf plot of n(z) for each cluster.", default = False)
    parser.add_argument("-X", "--add-extra-photometry", dest="addExtraPhoto",
                        help = "If using a user-supplied FITS galaxy photometric catalog, add in additional\
                        photometry from the given database (e.g., SDSSDR12, DECaLS - see 'database'),\
                        if available.", default = None)
    parser.add_argument("--min-background-area-Mpc2", dest="minBackgroundAreaMpc2", default = 11, type = float,
                        help="The minimum area of the background region in square Mpc. Candidate redshifts that have\
                        less than this minimum background area will be filtered out of the output catalog (default: 11,\
                        which corresponds to 50%% of the background area for the standard 3-4 Mpc annulus used).")
    parser.add_argument("-B", "--background-catalog", dest="bckCatalogFileName", help="A .fits table with columns\
                        in the appropriate format ('ID', 'RADeg', 'decDeg', and magnitude column names in the form\
                        'u_MAG_AUTO', 'u_MAGERR_AUTO' etc.) to be used as the background sample for delta estimates.\
                        If this is given, the area covered by the background catalog must be given also (-A flag)")
    parser.add_argument("-A", "--background-area-deg2", dest="bckAreaDeg2", default = None, type = float, 
                        help="The area,\
                        in square degrees, covered by the background galaxy catalog given using the -B flag.")
    parser.add_argument("-C", "--credentials-filename", dest="credentialsFileName", default = None, 
                        help = "The location of a file containing username (first line), password (second line),\
                        for login to e.g., ESO Portal (this option is only currently used for the KIDSDR3\
                        database option).")
    
    return parser
    
    
#-------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    parser=makeParser()
    args=parser.parse_args()
    
    catalogFileName=args.catalogFileName
    database=args.database
    outLabel=args.outLabel
    cacheDir=args.cacheDir
    weightsType=args.weightsType
    maxRMpc=float(args.maxRMpc)
    method=args.algorithm
    MPIEnabled=args.MPIEnabled
    maxMagError=float(args.maxMagError)
    #magsBrighterMStarCut=float(args.magsBrighterMStarCut)
    try:
        absMagCut=float(args.absMagCut)
    except:
        vals=args.absMagCut.replace("[", "").replace("]", "").split(",")
        if len(vals) != 2:
            raise Exception("If you want to give a range for --absmag-cut, write as e.g. [-15,-23]")
        absMagCut=[float(vals[0]), float(vals[1])]
        absMagCut.sort() # Bright mag comes first
    writeGalaxyCatalogs=args.writeGalaxyCatalogs
    writePlots=args.writePlots
    maskPath=args.mask
    ZPError=args.ZPError
    maxIter=args.maxIter
    
    update_rcParams()

    if outLabel == None:
        baseOutputLabel=os.path.split(catalogFileName)[-1]
        baseOutputLabel=baseOutputLabel.replace(".fits", "").replace(".csv", "")
        baseOutputLabel=baseOutputLabel+"_%s" % (database.replace(".fits", ""))
    else:
        baseOutputLabel=outLabel
        
    outDir=baseOutputLabel
    outFileName="zCluster_"+baseOutputLabel+".fits"

    if method not in ['odds', 'max']:
        raise Exception("method must be 'odds' or 'max'")
    
    if weightsType not in ['flat', 'radial', 'NFW']:
        raise Exception("weights must be 'flat', 'radial', or 'NFW'")

    # These are ONLY used if not None...
    bckCatalogFileName=args.bckCatalogFileName
    if args.bckAreaDeg2 == None and bckCatalogFileName != None:
        raise Exception("area covered by separate background galaxy catalogue must be given.")
    if args.bckAreaDeg2 != None:
        bckAreaDeg2=float(args.bckAreaDeg2)
    else:
        bckAreaDeg2=None
    
    if MPIEnabled == True:
        from mpi4py import MPI
        comm=MPI.COMM_WORLD
        size=comm.Get_size()
        rank=comm.Get_rank()
        if size == 1:
            raise Exception("if you want to use MPI, run with mpirun --np 4 zCluster ...")
    else:
        rank=0

    if rank == 0:
        if cacheDir is not None:
            os.makedirs(cacheDir, exist_ok = True)
        else:
            os.makedirs(retrievers.CACHE_DIR, exist_ok = True)

    catalog=parseClusterCatalog(catalogFileName)

    # Method for fetching catalogs
    retriever, retrieverOptions, passbandSet=retrievers.getRetriever(database, maxMagError = 0.2)
    
    # If we're not given a database, we're using a user-supplied FITS catalog
    # We can supplement that with photometry from a chosen database cross-matched onto that catalog
    if retriever is None:
        retriever=retrievers.FITSRetriever
        extraRetriever, extraOptions, extraPassbandSet=retrievers.getRetriever(args.addExtraPhoto)
        if cacheDir is not None:
            extraOptions['altCacheDir']=cacheDir
        retrieverOptions={'fileName': database, 'extraRetriever': extraRetriever, 'extraOptions': extraOptions}
    
    # Redshift priors and passband sets
    zDebias=None    # A fudge used to correct output zs for some surveys (final z = z + zDebias*(1+z))
    ZPOffsets=None
    zPriorMin=0.1
    zPriorMax=2.0
    if database == 'S82':
        zPriorMin=0.20
        zPriorMax=1.5
    elif database in ['SDSSDR7', 'SDSSDR8', 'SDSSDR10', 'SDSSDR12']:
        zPriorMin=0.05
        zPriorMax=0.8
    elif database == 'PS1':
        # Min prior here set to match 4 Mpc max search radius
        zPriorMin=0.15
        zPriorMax=0.6
    elif database in ['DESY3', 'DESDR1']:
        zPriorMin=0.05
        zPriorMax=1.5
        zDebias=0.02    # From testing against spec-zs (don't know yet why there is a mean offset)
    elif database == 'DESY3+WISE':
        zPriorMin=0.05
        zPriorMax=1.5
        zDebias=0.02    # From testing against spec-zs (don't know yet why there is a mean offset)
    elif database == 'KiDSDR4':
        zPriorMin=0.05
        zPriorMax=1.5
        zDebias=0.02    # From testing against spec-zs (don't know yet why there is a mean offset)
    elif database == 'ATLASDR4':
        zPriorMin=0.05
        zPriorMax=0.8
    elif database == 'CFHTLenS':
        zPriorMin=0.05
        zPriorMax=1.5
    elif database == 'DECaLS':
        zPriorMin=0.05
        zPriorMax=1.5
        # Zero point offsets remove bias from galaxy max likelihood photo-zs when testing on SDSS
        # But they don't get rid of bias on cluster photo-zs
        ZPOffsets=np.array([0.02271222, -0.05051711, -0.02465597, -0.00406835, 0.05406105])
        zDebias=0.02    # From testing against spec-zs (don't know yet why there is a mean offset)
    elif database == 'DL_DECaLSDR10' or database == 'DECaLSDR10':
        # WARNING: This is hacky of course, but fastest way to test...
        # ZPOffsets values below from fit with 1 million SDSS spec-zs
        zPriorMin=0.05
        zPriorMax=2.0
        ZPOffsets=np.array([0.04087833, -0.04830536, -0.03361726, -0.02293858, 0.03046254, 0.07543627])

    # Allow prior overrides from command line
    if args.zPriorMax is not None:
        zPriorMax=float(args.zPriorMax)
    if args.zPriorMin is not None:
        zPriorMin=float(args.zPriorMin)

    # Set-up output dir and log the options we're running with
    if rank == 0:
        if os.path.exists(outDir) == False:
            os.makedirs(outDir, exist_ok = True)
        if args.name != None:
            logDir=logFileName=outDir+os.path.sep+args.name
            if os.path.exists(logDir) == False:
                os.makedirs(logDir, exist_ok = True)
            logFileName=outDir+os.path.sep+args.name+os.path.sep+"zCluster.log"
        else:
            logFileName=outDir+os.path.sep+"zCluster.log"
        logFile=open(logFileName, "w")
        logFile.write("started: %s\n" % (datetime.datetime.now().isoformat()))
        for key in list(args.__dict__.keys()): 
            if key not in ['zPriorMin', 'zPriorMax']:
                logFile.write("%s: %s\n" % (key, str(args.__dict__[key])))
        logFile.write("zPriorMin: %.3f\n" % (zPriorMin))
        logFile.write("zPriorMax: %.3f\n" % (zPriorMax))
        logFile.close()
        
    if cacheDir is not None:
        retrieverOptions['altCacheDir']=cacheDir

    retrieverOptions['fetchAndCacheOnly']=args.fetchAndCacheOnly
            
    photoRedshiftEngine=PhotoRedshiftEngine.PhotoRedshiftEngine(absMagCut, passbandSet = passbandSet, 
                                                                ZPError = ZPError, ZPOffsets = ZPOffsets,
                                                                templatesDir = args.templatesDir)
    
    # Optionally only running on a single cluster
    if args.name != None:
        foundObj=False
        for objDict in catalog:
            if objDict['name'] == args.name:
                foundObj=True
                break
        catalog=[objDict]
        if foundObj == False:
            raise Exception("didn't find %s in input catalog" % (args.name))

    if MPIEnabled == True:
        # New - bit clunky but distributes more evenly
        rankCatalogs={}
        rankCounter=0
        for objDict in catalog:
            if rankCounter not in rankCatalogs:
                rankCatalogs[rankCounter]=[]
            rankCatalogs[rankCounter].append(objDict)
            rankCounter=rankCounter+1
            if rankCounter > size-1:
                rankCounter=0
        if rank in rankCatalogs.keys():
            catalog=rankCatalogs[rank]
        else:
            catalog=[]

    catalog=runOnCatalog(catalog, retriever, retrieverOptions, photoRedshiftEngine, outDir,
                         zPriorMin, zPriorMax, weightsType, maxRMpc, method, bckCatalogFileName, bckAreaDeg2,
                         maskPath = maskPath, writeGalaxyCatalogs = writeGalaxyCatalogs, 
                         writeDensityMaps = args.writeDensityMaps, writePlots = writePlots, 
                         rank = rank, zDebias = zDebias, fitZPOffsets = args.fitZPOffsets,
                         maxIter = maxIter, minBackgroundAreaMpc2 = args.minBackgroundAreaMpc2,
                         fetchAndCacheOnly = args.fetchAndCacheOnly)
    
    # If running under MPI, gather everything back together
    # Rank 0 process will continue with plots
    if MPIEnabled == True:
        if rank == 0:
            wholeCatalog=catalog
            for i in range(1, size):
                catalogPart=comm.recv(source = i, tag = 11)
                wholeCatalog=wholeCatalog+catalogPart
            catalog=wholeCatalog
        else:
            comm.send(catalog, dest = 0, tag = 11)
            sys.exit()
                
    # Makes a global .db file (might be useful) after run complete
    if os.path.exists(outDir+os.path.sep+"redshifts_global.db") == True:
        os.remove(outDir+os.path.sep+"redshifts_global.db")
    print(">>> Updating global redshifts .db file ...")
    conn=sqlite3.connect(outDir+os.path.sep+"redshifts_global.db")
    c=conn.cursor()
    c.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='redshifts'")
    matches=c.fetchall()
    if len(matches) == 0:
        c.execute("""CREATE TABLE redshifts (name text, RADeg real, decDeg real, origRADeg real, origDecDeg real,
                  offsetArcmin real, offsetMpc real, z real, delta real, errDelta real, CS real, A real)""")
    for obj in catalog:
        try:
            c.execute("INSERT INTO redshifts VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 
                      (obj['name'], obj['RADeg'], obj['decDeg'], obj['origRADeg'], obj['origDecDeg'], obj['offsetArcmin'],
                       obj['offsetMpc'], obj['z'], obj['delta'], obj['errDelta'], obj['CS'], obj['A']))
        except:
            raise Exception("Failed to insert %s (keys: %s)" % (obj['name'], obj.keys()))
    conn.commit()
    conn.close()
    
    catalogs.writeRedshiftsCatalog(catalog, outFileName)
    
        
