#!/usr/bin/env python

"""
Fetch photometry for galaxy fields from photometric databases at given position and radius.
"""

import os
import sys
import argparse
import numpy as np
import astropy.table as atpy
import astropy.io.fits as pyfits
import astropy.stats as apyStats
from astLib import *
from zCluster import retrievers
from zCluster import PhotoRedshiftEngine
from zCluster import clusters
from zCluster import catalogs
import pickle
import urllib
import time

#------------------------------------------------------------------------------------------------------------
def makeParser():
    """Build the command-line argument parser for the zField script.

    Returns:
        An argparse.ArgumentParser with two positional arguments (RADeg, decDeg) and the optional
        switches defined below. All numeric options are converted to float/int by argparse itself,
        so malformed values are reported as usage errors rather than surfacing later in the script.

    """
    parser = argparse.ArgumentParser("zField")
    parser.add_argument("RADeg", type = float, help="RA (in decimal degrees) of the centre of the field to fetch.")
    parser.add_argument("decDeg", type = float, help="Dec (in decimal degrees) of the centre of the field to fetch.")
    parser.add_argument("-r", "--max-radius-deg", dest="maxRDeg", help="Radius in degrees around the given\
                        RA, Dec. coordinates within which galaxy photo-zs will be estimated (default = 0.2).",
                        type = float, default = 0.2)
    parser.add_argument("-D", "--database", help="The photometric database to use. Options are 'SDSSDR12', 'S82'\
                        (for SDSS DR7 Stripe 82 co-add); 'CFHTLenS'; 'DESDR1', 'DESDR2',\
                        'DESY3' [experimental; requires access to proprietary DES data]; 'PS1' [experimental];\
                        'DECaLSDR8', 'DECaLSDR9'; or the path to a .fits table with columns in the\
                        appropriate format ('ID', 'RADeg', 'decDeg', and magnitude column names in the form\
                        'u_MAG_AUTO', 'u_MAGERR_AUTO' etc.).", default = 'DECaLSDR8')
    parser.add_argument("-c", "--cachedir", dest="cacheDir", default = None, help="Cache directory location\
                        (default: $HOME/.zCluster/cache). Downloaded photometric catalogs will be stored\
                        here.")
    parser.add_argument("-o", "--output-directory", dest="outDir", help="Name of the output directory (default: zFieldOutput).",
                        default = "zFieldOutput")
    parser.add_argument("-l", "--label", dest="label", help="Optional label to use for output files (default\
                        is to use the given RA, Dec coordinates in the output file names).", default = None)
    parser.add_argument("-t", "--templates-directory", dest="templatesDir", help="Specify a directory containing\
                        a custom set of spectral templates.", default = None)
    # type = float so that a non-numeric value is rejected by argparse with a usage message
    parser.add_argument("-e", "--max-mag-error", dest="maxMagError", help="Maximum acceptable\
                        photometric error (in magnitudes; default: 0.25).",
                        type = float, default = 0.25)
    parser.add_argument("-E", "--photometric-zero-point-error", dest="ZPError", type = float,
                        help="Global photometric zero point uncertainty in magnitudes, applied to all bands\
                        (default: 0). Added in quadrature to the photometric uncertainties in the catalog.",
                        default = 0.0)
    parser.add_argument("-f", "--fit-for-zero-point-offsets", dest="fitZPOffsets",
                        help="If the input catalog contains a z_spec column, use those galaxies to fit\
                        for magnitude zero point offsets. These will then be applied when estimating galaxy\
                        photometric redshifts.",
                        default = False, action = "store_true")
    # type = float for the same reason as -e above
    parser.add_argument("-b", "--brighter-absmag-cut", dest="absMagCut", help="Set bright absolute magnitude cut.",
                        type = float, default = -24.)
    parser.add_argument("-d", "--write-density-maps", dest="writeDensityMaps", action="store_true",
                        help="Write out a .fits image projected density map (within --delta-z of the value\
                        specified using --redshift).", default = False)
    parser.add_argument("-z", "--redshift", dest="z", help="Redshift at which to make the density map.",
                        type = float, default = None)
    parser.add_argument("-dz", "--delta-z", dest="dz", help="When making a projected density map, galaxy\
                        p(z) distributions are integrated within +/- dz of the given redshift (default: 0.1).",
                        type = float, default = 0.1)
    parser.add_argument("-DS9", "--write-ds9-regions", dest="writeDS9regions", action="store_true",
                        help="Write out a DS9 .reg file corresponding to the fetched photometric catalog.", default = False)
    parser.add_argument("-ErrMap", "--make-error-map", dest="ErrMap", action="store_true",
                        help="Estimate and output uncertainties on the projected density map and associated\
                        summary statistics (warning: extremely slow).", default = False)
    parser.add_argument("-MC", "--monte-carlo-samples", dest="MC",
                        help="Number of Monte Carlo samples to use when estimating uncertainties on the\
                        projected galaxy density map", type = int, default = 1000)
    parser.add_argument("-S", "--stellar-mass-model-dir", dest="stellarMassModelDir", default = None,
                        help = "If given, estimate stellar masses for galaxies using the (for now) BC03-format\
                        stellar population models in the given directory. If -z (--redshift) is given, stellar\
                        masses will be estimated at the given redshift. Otherwise, the stellar masses will be\
                        estimated at the maximum likelihood redshift of each galaxy (although that isn't\
                        actually implemented yet).")
    parser.add_argument("-p", "--pickle-galaxy-catalog", dest="pickleGalaxyCatalog", default = False,
                        action = "store_true", help = "If given, pickle the galaxy catalog after estimating\
                        photo-zs. Use this if you want to e.g. read p(z) for each galaxy from disk for some\
                        other analysis.")

    return parser

#------------------------------------------------------------------------------------------------------------
def columnValues(galaxyCatalog, key, fillValue = 99):
    """Collect the value stored under key for every object in galaxyCatalog.

    Args:
        galaxyCatalog: Iterable of objects that support item look-up by key.
        key: The key to extract from each object.
        fillValue: Value used for objects where looking up key fails (default: 99).

    Returns:
        A list with one entry per catalog object, in catalog order.

    """
    values=[]
    for gobj in galaxyCatalog:
        try:
            values.append(gobj[key])
        except (KeyError, IndexError, ValueError):
            # Best effort: a missing quantity is written out as the fill value
            values.append(fillValue)
    return values

#------------------------------------------------------------------------------------------------------------
def snapshotMagnitudes(galaxyCatalog, bands):
    """Record the current magnitudes of every object, for use as the reference in Monte Carlo resampling.

    Args:
        galaxyCatalog: Sequence of dictionary-like objects (must support .get and item look-up).
        bands: List of keys; entries ending in "Err" are treated as uncertainties and skipped.

    Returns:
        A list (one dictionary per object, in catalog order) mapping band name to magnitude. Bands that
        are missing, falsy, or equal to 99 in an object are left out for that object.

    """
    magBands=[b for b in bands if not b.endswith("Err")]
    return [{b: gobj[b] for b in magBands if gobj.get(b, None) and gobj[b] != 99}
            for gobj in galaxyCatalog]

#------------------------------------------------------------------------------------------------------------
def perturbMagnitudes(galaxyCatalog, originalMags):
    """Overwrite the magnitudes in galaxyCatalog with a noisy realisation of originalMags.

    Each magnitude is set to its value in originalMags plus zero-mean Gaussian noise with standard
    deviation taken from the corresponding '<band>Err' entry of the object. Because the noise is always
    added to the recorded reference values, repeated calls do not accumulate noise.

    Args:
        galaxyCatalog: Sequence of dictionary-like objects, modified in place.
        originalMags: Output of snapshotMagnitudes for the same catalog.

    """
    for gobj, mags in zip(galaxyCatalog, originalMags):
        for band, mag in mags.items():
            gobj[band]=mag+np.random.normal(0, gobj[band+'Err'])

#------------------------------------------------------------------------------------------------------------
start=time.time()
if __name__ == '__main__':

    parser=makeParser()
    args=parser.parse_args()

    RADeg=args.RADeg
    decDeg=args.decDeg
    maxRDeg=float(args.maxRDeg)
    ZPError=args.ZPError
    templatesDir=args.templatesDir
    absMagCut=float(args.absMagCut)
    database=args.database
    maxMagError=float(args.maxMagError)
    fitZPOffsets=args.fitZPOffsets
    cacheDir=args.cacheDir
    writeDensityMaps=args.writeDensityMaps
    z=args.z
    dz=args.dz
    ErrMap=args.ErrMap
    MC=args.MC

    # Fail fast on inconsistent options, before any catalog is fetched or photo-z is computed
    if writeDensityMaps and z is None:
        raise Exception("Redshift needs to be specified using -z when using the -d option")
    if ErrMap and (not writeDensityMaps or z is None):
        # The error map is written with the WCS of, and reported against, the map made by -d
        raise Exception("The -ErrMap option needs the density map made by -d: use -ErrMap together with -d and -z")

    # Memory saving - we don't care about carrying around p(z) if we're not writing density maps
    returnPZ=bool(args.writeDensityMaps or args.pickleGalaxyCatalog)

    # Set-up cache
    if cacheDir is not None:
        if not os.path.exists(cacheDir):
            os.makedirs(cacheDir, exist_ok = True)
    elif not os.path.exists(retrievers.CACHE_DIR):
        os.makedirs(retrievers.CACHE_DIR, exist_ok = True)

    # Optional output label
    if args.label is not None:
        outputLabel=args.label
    else:
        outputLabel="%.3f_%.3f" % (RADeg, decDeg)

    # Set-up where we will write output
    os.makedirs(args.outDir, exist_ok = True)
    outFileName=os.path.join(args.outDir, "zField_catalog_%s.fits" % (outputLabel))

    # Method for fetching catalogs - honour the -e option rather than a fixed magnitude error cut
    retriever, retrieverOptions, passbandSet=retrievers.getRetriever(database, maxMagError = maxMagError)

    ZPOffsets=None
    # NOTE(review): 'DECaLS' is not one of the documented -D values ('DECaLSDR8', 'DECaLSDR9'), so these
    # offsets are never applied for the documented options - confirm whether that is intended
    if database == 'DECaLS':
        # Zero point offsets remove bias from galaxy max likelihood photo-zs when testing on SDSS
        ZPOffsets=np.array([0.02271222, -0.05051711, -0.02465597, -0.00406835, 0.05406105])

    if cacheDir is not None:
        retrieverOptions['altCacheDir']=cacheDir

    galaxyCatalog=retriever(RADeg, decDeg, halfBoxSizeDeg = maxRDeg, optionsDict = retrieverOptions)

    if galaxyCatalog is None:
        print("No galaxies found at the given position.")
        sys.exit()

    photoRedshiftEngine=PhotoRedshiftEngine.PhotoRedshiftEngine(absMagCut, ZPError = ZPError,
                                                                passbandSet = passbandSet,
                                                                templatesDir = templatesDir,
                                                                ZPOffsets = ZPOffsets)
    # Magnitude keys followed by the matching uncertainty keys
    bands=photoRedshiftEngine.bands
    bands=bands+[b+"Err" for b in bands]

    if fitZPOffsets:
        photoRedshiftEngine.calcZeroPointOffsets(galaxyCatalog)
    photoRedshiftEngine.calcPhotoRedshifts(galaxyCatalog, calcMLRedshiftAndOdds = True, returnPZ = returnPZ)

    # Optional stellar mass estimates
    if args.stellarMassModelDir is not None:
        photoRedshiftEngine.estimateStellarMasses(galaxyCatalog, args.stellarMassModelDir, z = z)

    # Dump pickle file for p(z), optional
    if args.pickleGalaxyCatalog:
        with open(os.path.join(args.outDir, "galaxyCatalog.pkl"), "wb") as pickleFile:
            pickle.dump(galaxyCatalog, pickleFile)

    # Write catalog and region file
    wantedKeys=['id', 'RADeg', 'decDeg', 'zPhot', 'odds']
    if args.stellarMassModelDir is not None:
        wantedKeys.append('log10StellarMass')
    wantedKeys=wantedKeys+bands
    tab=atpy.Table()
    for key in wantedKeys:
        tab.add_column(atpy.Column(np.array(columnValues(galaxyCatalog, key)), key))
    tab.write(outFileName, overwrite = True)

    if args.writeDS9regions:
        catalogs.catalog2DS9(galaxyCatalog, os.path.join(args.outDir, "%s_%s.reg" % (database, outputLabel)),
                             idKeyToUse = 'id', addInfo = [{'key': 'r', 'fmt': '%.3f'}, {'key': 'zPhot', 'fmt': '%.2f'}])

    if writeDensityMaps:
        # Side length of the map at redshift z, for a field of half-width maxRDeg
        sizeMpc=astCalc.da(z)*np.tan(np.radians(maxRDeg))*2
        dMapDict=clusters.makeDensityMap(RADeg, decDeg, galaxyCatalog, z, dz, sizeMpc = sizeMpc,
                                         outFITSFileName = os.path.join(args.outDir, "densityMap_z%.2f_%s.fits" % (z, outputLabel)),
                                         outPlotFileName = os.path.join(args.outDir, "densityMap_z%.2f_%s.png" % (z, outputLabel)))

    # Estimate errors for density map, CS and A
    if ErrMap:
        print(">>> Estimating galaxy density error map ...")
        AErr=[]
        CSErr=[]
        MapErr=[]
        # Every realisation is drawn around the measured photometry; perturbing the catalog cumulatively
        # would make the noise grow with the number of samples
        originalMags=snapshotMagnitudes(galaxyCatalog, bands)
        sizeMpc=astCalc.da(z)*np.tan(np.radians(maxRDeg))*2
        for monteCarlo in range(MC):
            perturbMagnitudes(galaxyCatalog, originalMags) # Add Gaussian noise to photometry of each galaxy
            if fitZPOffsets:
                photoRedshiftEngine.calcZeroPointOffsets(galaxyCatalog)
            photoRedshiftEngine.calcPhotoRedshifts(galaxyCatalog, calcMLRedshiftAndOdds = True) # Recompute photo-z's
            dMapDictwerr=clusters.makeDensityMap(RADeg, decDeg, galaxyCatalog, z, dz, sizeMpc = sizeMpc) # Make new density maps with added noise
            CSErr.append(dMapDictwerr['CS'])
            AErr.append(dMapDictwerr['A'])
            MapErr.append(dMapDictwerr['map'])
        AError=np.percentile(AErr, 68.3) # 68th percentile as error
        CSError=np.percentile(CSErr, 68.3)
        MapError=np.percentile(MapErr, 68.3, axis=0)
        astImages.saveFITS(os.path.join(args.outDir, "errorMap_z%.2f_%s.fits" % (z, outputLabel)),
                           MapError, dMapDict['wcs'])
        print(">>> CS:\t"+str(dMapDict['CS'])+" +- "+str(CSError)+"\n A:\t"+str(dMapDict['A'])+" +- "+str(AError))

    # We may as well print some stats
    print(">>> Summary:")
    print("    N = %d" %  (len(tab['zPhot'])))
    print(">>> Total time: %.1f sec" % (time.time()-start))
