#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
    Copyright 2017 Matt Hilton (matt.hilton@mykolab.com)
    
    This file is part of zCluster.

    zCluster is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    zCluster is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with zCluster.  If not, see <http://www.gnu.org/licenses/>.

"""

import os
import sys
import argparse
import astropy.table as atpy
from astropy.coordinates import SkyCoord
from astropy.coordinates import match_coordinates_sky
import astropy.stats as apyStats
from zCluster import *
import pylab as plt
import numpy as np
from cycler import cycler
import IPython
plt.matplotlib.interactive(False)

#------------------------------------------------------------------------------------------------------------
def update_rcParams(dict=None):
    """Apply the plot style used by this script to matplotlib's rcParams (in place).

    Based on Cristobal's preferred settings.

    Args:
        dict (dict, optional): Extra rcParams key-value pairs. These are applied after the defaults
            set here, so they take precedence over them. Note that 'axes.prop_cycle' is always set
            last by this routine, so it cannot be overridden this way. (The parameter name shadows
            the builtin; it is kept so that existing keyword calls continue to work.)

    Returns:
        None

    """
    default={'xtick.top': True,
             'ytick.right': True,
             'axes.linewidth': 2,
             'axes.labelsize': 22,
             'font.size': 22,
             'font.family': 'sans-serif',
             'legend.fontsize': 18,
             'lines.linewidth': 2}
    for tick in ('xtick', 'ytick'):
        default['{0}.major.size'.format(tick)]=8
        default['{0}.minor.size'.format(tick)]=4
        default['{0}.major.width'.format(tick)]=2
        default['{0}.minor.width'.format(tick)]=2
        default['{0}.labelsize'.format(tick)]=20
        default['{0}.direction'.format(tick)]='in'
    plt.rcParams.update(default)

    # Anything given by the caller overwrites the defaults defined above
    if dict is not None:
        plt.rcParams.update(dict)

    # Colour cycle from https://github.com/mhasself/rg_friendly
    plt.rcParams['axes.prop_cycle']=cycler(color=['#2424f0','#df6f0e','#3cc03c','#d62728','#b467bd','#ac866b','#e397d9','#9f9f9f','#ecdd72','#77becf'])
        
#-------------------------------------------------------------------------------------------------------------
def parse_yRange(yRangeStr):
    """Parse a plot y-axis range given as a string in the format yMin:yMax.

    Args:
        yRangeStr (str or None): The range in the format yMin:yMax. If None or 'auto', the axis
            limits are left for matplotlib to choose.

    Returns:
        Tuple (yMin, yMax) of floats, or None if the limits should be chosen automatically.

    Raises:
        ValueError: If yRangeStr is neither 'auto' nor of the form yMin:yMax.

    """
    if yRangeStr is None or yRangeStr == 'auto':
        return None
    bits=str(yRangeStr).split(":")
    if len(bits) != 2:
        raise ValueError("y-axis range should be given in the format yMin:yMax (got '%s')" % (yRangeStr))
    return float(bits[0]), float(bits[1])

#-------------------------------------------------------------------------------------------------------------
def fit_zDebias(zIn, zOut):
    """Fit for a bias in zOut of the form zOut -> zOut + zDebias*(1 + zOut).

    A grid of 201 trial zDebias values between -0.1 and 0.1 is searched. For each trial, the median of
    (zIn - debiased zOut) is found in bins of zIn (width 0.1, from z = 0 to z = 2). The test statistic
    is the absolute value of the median of these bin medians, using only populated bins with centres
    in the range 0.1 < z < 1.3. The trial with the smallest test statistic wins (the first one found,
    in the case of ties).

    Args:
        zIn (array-like): Reference redshifts.
        zOut (array-like): Redshifts to be debiased (same length as zIn).

    Returns:
        Tuple (zDebias, bestTestStat). If there are no populated bins in the range used for the fit,
        (0.0, np.inf) is returned, i.e., no correction.

    """
    zIn=np.asarray(zIn, dtype = float)
    zOut=np.asarray(zOut, dtype = float)
    zBinEdges=np.linspace(0, 2, 21)
    zBinCentres=(zBinEdges[:-1]+zBinEdges[1:])/2
    inFitRange=np.logical_and(zBinCentres < 1.3, zBinCentres > 0.1)

    # Bin membership depends only on zIn, not on the trial correction, so work it out just once.
    # Empty bins are dropped here so that they cannot contaminate the median of the bin medians.
    fitBinMasks=[]
    for binMin, binMax, useBin in zip(zBinEdges[:-1], zBinEdges[1:], inFitRange):
        binMask=np.logical_and(zIn >= binMin, zIn < binMax)
        if useBin and binMask.sum() > 0:
            fitBinMasks.append(binMask)

    zDebias=0.0
    bestTestStat=np.inf
    if len(fitBinMasks) == 0:
        return zDebias, bestTestStat

    for test_zDebias in np.linspace(-0.1, 0.1, 201):
        resid=zIn-(zOut+test_zDebias*(1+zOut))
        testStat=abs(np.median([np.median(resid[binMask]) for binMask in fitBinMasks]))
        if testStat < bestTestStat:
            bestTestStat=testStat
            zDebias=test_zDebias

    return zDebias, bestTestStat

#-------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':

    parser = argparse.ArgumentParser("zClusterComparisonPlot")
    parser.add_argument("inCatalog", help="""A .fits table or a .csv file with columns 'name', 'RADeg',
                        'decDeg', and 'z' (or 'redshift'). The last column here is the redshift which is
                        for comparison with the zCluster output. Objects are cross matched against
                        zClusterCatalog by position (see -x).""")
    parser.add_argument("zClusterCatalog", help="""Path to a zCluster .fits table, i.e., the output of
                        the zCluster code.""")
    parser.add_argument("-x", "--x-match-radius", dest="xMatchRadiusArcsec", type = float, 
                        help = """Cross match radius (arcsec) for matching between inCatalog and zClusterCatalog
                        (default: 150).""", default = 150.0)
    parser.add_argument("-o", "--output-label", dest="outLabel", help="""Label to use for output  
                        (default: comparison_inCatalogFileName_zClusterCatalog, stripped of file extension). 
                        The code will write both a plot (.pdf) and a .fits table in the current
                        directory. Note that '.' in outLabel will be replaced with 'p' in the plot
                        file name in order to be latex friendly.""", default = None)
    parser.add_argument("-s", "--marker-size", dest="markerSize", help="Size of markers in the plot.",
                        default = 5, type = float)
    parser.add_argument("-S", "--spec-z-only", dest="specRedshiftsOnly", help="""If inCatalog contains a 
                        column named redshiftType, selects only objects with redshiftType = 'spec'.""",
                        default = False, action = "store_true")
    parser.add_argument("-t", "--title", dest="plotTitle", help="Plot title.", default = None)
    parser.add_argument("-z", "--z-min", dest="zMin", help="""Minimum redshift for plotting and overall
                        statistics (default: 0.0).""", default = 0.0, type = float)
    parser.add_argument("-Z", "--z-max", dest="zMax", help="""Maximum redshift for plotting and overall
                        statistics (default: 1.999).""", default = 1.999, type = float)
    parser.add_argument("-d", "--delta-min", dest="deltaMin", help="""Minimum delta for statistics 
                        (default: 3.0).""", default = 3.0, type = float)
    parser.add_argument("-D", "--delta-max", dest="deltaMax", help="""Maximum delta for statistics.""", 
                        default = 1e6, type = float)
    parser.add_argument("-l", "--inlabel", dest="inLabel", help="""Subscript label for input redshifts
                        (default: s).""", default = "s")
    parser.add_argument("-y", "--yrange", dest="yRangeStr", help="""Range of the plot y-axis (which
                        is the difference between the inCatalog redshift and the redshift in 
                        zClusterCatalog), given in the format yMin:yMax.""", default = "-0.55:0.55")
    parser.add_argument("-A", "--alpha", dest="alpha", help="""Alpha value transparency to use for shading
                        data points in the plots (default: 1.0).""", default = 1.0, type = float)
    parser.add_argument("-f", "--fit-bias", dest="fitBias", default = False, action = "store_true",
                        help="""If given, fit for redshift dependent bias""")
    args = parser.parse_args()

    inCatalogFileName=args.inCatalog
    outCatalogFileName=args.zClusterCatalog
    plotTitle=args.plotTitle
    deltaMin=args.deltaMin
    deltaMax=args.deltaMax
    inLabel=str(args.inLabel)
    # Parsed up front so that a malformed range fails before any work is done (None means automatic)
    yRange=parse_yRange(args.yRangeStr)
    xMatchRadiusDeg=args.xMatchRadiusArcsec/3600.0

    if inCatalogFileName.split(".")[-1] == "csv":
        inTab=atpy.Table.read(inCatalogFileName, format = 'ascii')
    else:
        inTab=atpy.Table.read(inCatalogFileName)
    outTab=atpy.Table.read(outCatalogFileName)

    # Output goes in the current directory by default, so only the base names of the inputs are used
    if args.outLabel is None:
        inBaseName=os.path.basename(inCatalogFileName).replace(".fits", "").replace(".csv", "")
        outFileName="comparison_"+inBaseName+"_vs_"+os.path.basename(outCatalogFileName)
    else:
        outFileName=args.outLabel+".fits"
    # Only the file name itself is made latex friendly - any directory part of the label is left alone
    plotDir, plotBaseName=os.path.split(outFileName.replace(".fits", ""))
    plotFileName=os.path.join(plotDir, plotBaseName.replace(".", "p")+".pdf")

    if 'z' in inTab.keys():
        zKey='z'
    elif 'redshift' in inTab.keys():
        zKey='redshift'
    else:
        raise Exception("didn't find 'z' or 'redshift' column in input table %s" % (inCatalogFileName))

    # Optionally cut down to spec-zs only (if we can)
    if args.specRedshiftsOnly:
        if 'redshiftType' not in inTab.keys():
            raise Exception("inCatalog should contain 'redshiftType' column if you want to use the -S switch.")
        inTab=inTab[inTab['redshiftType'] == 'spec']
        print(">>> Using only objects with spectroscopic redshifts ...")

    # Cross match each object in the input (spec-z) table to its nearest neighbour in the zCluster
    # results table by coordinates, because zCluster now estimates its own RADeg, decDeg position
    # (though we could just match on origRADeg, origDecDeg)
    cat1=SkyCoord(ra = inTab['RADeg'], dec = inTab['decDeg'], unit = 'deg')     # spec-z catalog
    cat2=SkyCoord(ra = outTab['RADeg'], dec = outTab['decDeg'], unit = 'deg')   # zCluster results catalog
    xIndices, rDeg, sep3d = match_coordinates_sky(cat1, cat2, nthneighbor = 1)
    # The separations are taken to be in degrees here
    mask=np.less(rDeg.value, xMatchRadiusDeg)
    cutTab=outTab[xIndices[mask]]
    inTab=inTab[mask]

    tab=atpy.Table()
    tab.add_column(inTab['name'])
    tab.add_column(cutTab['RADeg'])
    tab.add_column(cutTab['decDeg'])
    tab.add_column(atpy.Column(inTab[zKey], 'zIn'))
    tab.add_column(cutTab['z'])
    tab.add_column(atpy.Column(tab['zIn']-tab['z'], 'zDiff'))
    tab.add_column(cutTab['delta'])
    tab.add_column(cutTab['errDelta'])
    tab.add_column(cutTab['offsetArcmin'])
    tab.add_column(cutTab['offsetMpc'])
    tab=tab[tab['z'] > 0] # because filled in -99 for fails/missing clusters (e.g., not in S82)
    tab=tab[tab['delta'] != np.inf]
    if len(tab) == 0:
        raise Exception("no objects left after cross matching %s against %s (within %.1f arcsec) and removing failed redshifts"
                        % (inCatalogFileName, outCatalogFileName, args.xMatchRadiusArcsec))

    zIn=tab['zIn']
    zOut=tab['z']
    delta=tab['delta']
    diff=tab['zDiff']
    diff_over1plusz=diff/(1+tab['zIn'])

    if args.fitBias:
        print(">>> Fitting for z-dependent bias:")
        zDebias, bestTestStat=fit_zDebias(zIn, zOut)
        if not np.isfinite(bestTestStat):
            print("... WARNING: no objects in the redshift range used for the fit - no correction found")
        print("... zDebias = %.3f ; testStat = %.4f" % (zDebias, bestTestStat))
        print("... applying this correction to all that follows")
        zOut=zOut+zDebias*(1+zOut)
        diff=zIn-zOut
        diff_over1plusz=diff/(1+zIn)
        # Write out debiased zs table, alongside the zCluster catalog. Only valid redshifts are
        # corrected, so that the -99 values that flag fails/missing clusters are left untouched
        validZ=outTab['z'] > 0
        outTab['z'][validZ]=outTab['z'][validZ]+zDebias*(1+outTab['z'][validZ])
        debias_outCatalogFileName=os.path.abspath(outCatalogFileName)
        outDir=os.path.dirname(debias_outCatalogFileName)
        outTab.write(os.path.join(outDir, "debiased_"+os.path.basename(debias_outCatalogFileName)), overwrite = True)

    # Stats based on delta cut by z
    zCutsList=[[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [args.zMin, args.zMax]]
    print(">>> Stats for delta > %.2f:" % (deltaMin))
    statsDeltaMask=np.logical_and(np.greater(delta, deltaMin), np.less(delta, deltaMax))
    for zMin, zMax in zCutsList:
        zMask=np.logical_and(np.greater(zIn, zMin), np.less(zIn, zMax))
        mask=np.logical_and(zMask, statsDeltaMask)
        if mask.sum() == 0:
            continue
        print(">>> %.2f < z < %.2f (N = %d):" % (zMin, zMax, mask.sum()))
        maskedDiff_over1plusz=(zIn[mask]-zOut[mask])/(1+zIn[mask])
        biweightMaskedDiff_over1plusz=apyStats.biweight_scale(maskedDiff_over1plusz, c = 9.0, modify_sample_size = True)
        print("... median (zIn - zOut) / (1 + zIn)                  = %.3f" % (np.median(maskedDiff_over1plusz)))
        print("... mean (zIn - Zout) / (1 + zIn)                    = %.3f" % (np.mean(maskedDiff_over1plusz)))
        print("... sigma (zIn - zOut) / (1 + zIn)                   = %.3f" % (np.std(maskedDiff_over1plusz)))
        print("... biweight scale (zIn - zOut) / (1 + zIn)          = %.3f" % (biweightMaskedDiff_over1plusz))
        # Mean of a boolean array is the fraction that is True (regardless of division semantics)
        print("... fraction with (zIn - zOut) / (1 + zIn) < 0.05   = %.3f" % (np.mean(abs(maskedDiff_over1plusz) < 0.05)))

    # Default mask / stats for plotting (uses the overall redshift range given by -z, -Z)
    zMin=args.zMin
    zMax=args.zMax
    deltaMask=np.greater(delta, deltaMin)
    notMask=np.logical_not(deltaMask)
    zMask=np.logical_and(np.greater(zIn, zMin), np.less(zIn, zMax))
    mask=np.logical_and(zMask, deltaMask)
    maskedDiff_over1plusz=(zIn[mask]-zOut[mask])/(1+zIn[mask])
    medianMaskedDiff_over1plusz=np.median(maskedDiff_over1plusz)
    biweightMaskedDiff_over1plusz=apyStats.biweight_scale(maskedDiff_over1plusz, c = 9.0, modify_sample_size = True)

    # Output a table of catastrophic failures for investigation
    catMask=np.greater(abs(maskedDiff_over1plusz), 0.15)
    catTab=(tab[mask])[catMask]
    catTab.write(outFileName.replace(".fits", "_catastrophic.fits"), overwrite = True)

    # Plot
    update_rcParams()

    # Because we want dummy points that aren't transparent in the legend
    colors=[cc['color'] for cc in plt.rcParams['axes.prop_cycle']]

    # The guide lines are drawn at the values quoted in the legend: the median and the median +/- the
    # biweight scale of (zIn - zOut) / (1 + zIn), which is the quantity shown on the y-axis
    guideCentre=medianMaskedDiff_over1plusz
    guideLow=medianMaskedDiff_over1plusz-biweightMaskedDiff_over1plusz
    guideHigh=medianMaskedDiff_over1plusz+biweightMaskedDiff_over1plusz

    plt.figure(figsize=(18, 6.5))

    # Left panel: redshift residual versus input redshift
    plt.subplot(121)
    plt.subplots_adjust(left = 0.065, bottom = 0.11, right = 0.99, top = 0.97, wspace = 0.01, hspace = 0.24)

    plt.plot(zIn[deltaMask], diff_over1plusz[deltaMask], 'D', color = colors[0], ms = args.markerSize, alpha = args.alpha, mew = 0, zorder = 3)
    plt.plot(zIn[notMask], diff_over1plusz[notMask], 'D', ms = args.markerSize, color = '#888888', alpha = args.alpha, mew = 0, zorder = 1)

    # Dummy points that aren't transparent
    plt.plot([], [], 'D', color = colors[0], label = r'$\delta > %.0f$' % (deltaMin))
    plt.plot([], [], 'D', color = '#888888', label = r'$\delta < %.0f$' % (deltaMin))

    xPlotRange=np.linspace(zMin, zMax, 10)
    plt.plot(xPlotRange, [guideCentre]*len(xPlotRange), 'k--',
             label = r'$\langle \, (z_{\mathrm{%s}}$-$z_{\mathrm{c}}$)/($1+z_{\mathrm{%s}}) \, \rangle$ = %.3f' % (inLabel, inLabel, medianMaskedDiff_over1plusz))
    plt.plot(xPlotRange, [guideLow]*len(xPlotRange), 'k:',
             label = r'$\sigma_{\rm bw}$ [($z_{\mathrm{%s}}$-$z_{\mathrm{c}}$)/($1+z_{\mathrm{%s}}$)] = %.3f' % (inLabel, inLabel, biweightMaskedDiff_over1plusz))
    plt.plot(xPlotRange, [guideHigh]*len(xPlotRange), 'k:')
    plt.plot(np.linspace(zMin, zMax, 3), [0]*3, 'k-')
    plt.xlim(zMin, zMax)
    if yRange is not None:
        plt.ylim(yRange[0], yRange[1])
    plt.legend(numpoints = 1, loc = 'lower right')
    if plotTitle is not None:
        plt.figtext(0.1, 0.85, plotTitle, fontdict = {'size': 28, 'weight': 'bold'}, ha = 'left')
    # The right panel has no y-axis labels of its own, so it must use the same ticks and limits
    yticks, _=plt.yticks()
    yLimits=plt.ylim()
    plt.xlabel(r"$z_{\mathrm{%s}}$" % (inLabel))
    plt.ylabel(r"$(z_{\mathrm{%s}}-z_{\mathrm{c}})$/$(1+z_{\mathrm{%s}})$" % (inLabel, inLabel))

    # Right panel: redshift residual versus delta (no legend here, so the guide lines are unlabelled)
    plt.subplot(122)
    plt.plot(delta[mask], diff_over1plusz[mask], 'D', color = colors[0], ms = args.markerSize, alpha = args.alpha, mew = 0, zorder = 3)
    plt.plot(delta[notMask], diff_over1plusz[notMask], 'D', color = '#888888', ms = args.markerSize, alpha = args.alpha, mew = 0, zorder = 1)
    deltaPlotMax=delta.max()*1.1
    xPlotRange=np.linspace(0, deltaPlotMax, 10)
    plt.plot(xPlotRange, [guideCentre]*len(xPlotRange), 'k--')
    plt.plot(xPlotRange, [guideLow]*len(xPlotRange), 'k:')
    plt.plot(xPlotRange, [guideHigh]*len(xPlotRange), 'k:')
    plt.plot(np.linspace(0, deltaPlotMax, 3), [0]*3, 'k-')
    plt.xlim(0, deltaPlotMax)
    plt.yticks(yticks, len(yticks)*[''])
    plt.ylim(yLimits)
    plt.xlabel(r"$\delta$")

    plt.savefig(plotFileName)
    plt.close()

    # Output for more fiddling
    tab.table_name="zCluster comparison"
    tab.write(outFileName, overwrite = True)
    

    
