#!/usr/bin/env python

"""

    Interactive display of image/CMD and selection of the BCG. SDSS and DES imaging are supported.

    Copyright 2017 Matt Hilton (matt.hilton@mykolab.com)
    
    This file is part of zCluster.

    zCluster is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    zCluster is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with zCluster.  If not, see <http://www.gnu.org/licenses/>.

"""

import os
import sys
import urllib.request, urllib.parse, urllib.error
import numpy as np
import astropy.table as atpy
import astropy.io.fits as pyfits
from zCluster import *
import pylab as plt
from astLib import *
from PIL import Image
import IPython
import argparse
from matplotlib.widgets import Button
from matplotlib.widgets import RadioButtons
#plt.matplotlib.interactive(False)

#-------------------------------------------------------------------------------------------------------------
# Globals because we yanked this out of sourcery SourceBrowser

# Root directory under which the per-survey image caches are kept
cacheDir="imageCache"

# Shared display settings: size of the fetched image (arcmin) and the matplotlib figure size
configDict=dict(plotSizeArcmin = 8.0, figSize = (14, 7.5))

# Tile directory handle for DES imaging; stays None until fetchDESImage needs it
DESTiles=None

# Make sure the cache root exists before any fetch routine tries to write into it
if not os.path.exists(cacheDir):
    os.makedirs(cacheDir, exist_ok = True)

#-------------------------------------------------------------------------------------------------------------
def getDESTiles():
    """Return the module-level DESTiles object (None until fetchDESImage has created it).

    """
    tiles=DESTiles
    return tiles

#-------------------------------------------------------------------------------------------------------------
def fetchSDSSImage(obj, refetch = False):
    """Fetches the SDSS .jpg for the given image size using the casjobs webservice.

    The image is stored in the "SDSS" sub-directory of the image cache, in a file named after the
    object coordinates. It is only downloaded if it is not already cached, or if refetch is True.
    After a successful download the file is gamma-adjusted in place using the external `convert`
    command, if that is available.

    Args:
        obj: Table row or dict with keys 'RADeg' and 'decDeg' (decimal degrees).
        refetch (bool): If True, download the image even if a cached copy exists.

    Returns:
        Path to the cached .jpg, or None if the download failed.

    """

    # Local imports keep this function self-contained (neither module is imported at file level)
    import http.client
    import subprocess

    sdssCacheDir=cacheDir+os.path.sep+"SDSS"
    if os.path.exists(sdssCacheDir) == False:
        os.makedirs(sdssCacheDir, exist_ok = True)

    RADeg=obj['RADeg']
    decDeg=obj['decDeg']
    outFileName=sdssCacheDir+os.path.sep+catalogs.makeRADecString(RADeg, decDeg)+".jpg"
    print("... image path = %s ..." % (outFileName))

    if os.path.exists(outFileName) == True and refetch != True:
        return outFileName

    # Scale (arcsec/pixel) is chosen such that SDSSWidth pixels span the configured plot size
    SDSSWidth=1200.0
    SDSSScale=(configDict['plotSizeArcmin']*60.0)/SDSSWidth # 0.396127
    urlString="http://skyservice.pha.jhu.edu/DR10/ImgCutout/getjpeg.aspx?ra="+str(RADeg)+"&dec="+str(decDeg)
    urlString=urlString+"&scale="+str(SDSSScale)+"&width="+str(int(SDSSWidth))+"&height="+str(int(SDSSWidth))
    try:
        urllib.request.urlretrieve(urlString, filename = outFileName)
    except (OSError, ValueError, http.client.HTTPException):
        # URLError/HTTPError/ContentTooShortError are all OSError subclasses
        print("... WARNING: couldn't get SDSS image ...")
        print(urlString)
        # Don't leave a partial download behind - it would be mistaken for a valid cached image
        if os.path.exists(outFileName) == True:
            try:
                os.remove(outFileName)
            except OSError:
                pass
        return None

    # Brighten the image (best effort). Arguments are passed as a list so that the path never
    # goes through a shell.
    try:
        subprocess.run(["convert", "-gamma", "2", outFileName, outFileName], check = False)
    except OSError:
        print("... WARNING: couldn't run 'convert' - using SDSS image as downloaded ...")

    return outFileName

#-------------------------------------------------------------------------------------------------------------
def fetchDESImage(obj, refetch = False):
    """Fetches DES image. If the image isn't already in the cache, you'll need sourcery installed and a
    directory full of DES tile .jpg images (contact Matt).

    obj should be a table row or dict with keys 'name', 'RADeg', 'decDeg'. Nothing is fetched if a
    file already exists at the expected cache path, unless refetch is True.

    """

    global DESTiles # Created on first use, to avoid long set-up time for tile dir

    desCacheDir=os.path.sep.join([cacheDir, "DES"])
    if not os.path.exists(desCacheDir):
        os.makedirs(desCacheDir, exist_ok = True)

    name, RADeg, decDeg=obj['name'], obj['RADeg'], obj['decDeg']
    outFileName=os.path.sep.join([desCacheDir, catalogs.makeRADecString(RADeg, decDeg)+".jpg"])
    print("... image path = %s ..." % (outFileName))

    # Guard clause: cached copy is present and we weren't asked to replace it
    alreadyCached=os.path.exists(outFileName)
    if alreadyCached and not (refetch == True):
        return

    if DESTiles is None:
        from sourcery import tileDir
        DESTiles=tileDir.TileDir("DES", "DESTiles", cacheDir, sizePix = 1200)
    DESTiles.fetchImage(name, RADeg, decDeg, configDict['plotSizeArcmin'], refetch = refetch)
    
#-------------------------------------------------------------------------------------------------------------
class CMDImageCombo:
    """Combination CMD and clickable image plot in one.

    The left-hand panel shows the RGB image of the field; the upper-right panel shows the i versus
    r-i colour-magnitude diagram (CMD) of the catalogue galaxies that fall inside the image.

    Left-clicking in the image selects the BCG (stored as a one-row table in matchTab1, marked in
    yellow); right-clicking selects an optional second galaxy (matchTab2, green). Clicking a point
    in the CMD also selects the BCG. The 'Record BCG' and 'Skip' buttons close the figure, setting
    keepBCG to True or False respectively (it stays None if neither was pressed).

    """

    def __init__(self, cluster, galTab, surveyLabel, clipSizeArcmin = None, zColumn = 'z', matchTab1 = None, 
                 matchTab2 = None, plotFileName = False):
        """Needs all the info to make a clickable image plot and CMD.

        Args:
            cluster: Table row or dict with keys 'name', 'RADeg', 'decDeg', and the redshift column
                named by zColumn.
            galTab: Galaxy catalogue table; the columns 'RADeg', 'decDeg', 'zPhot', 'i' and 'r-i'
                are used here.
            surveyLabel (str): Name of the sub-directory of the image cache that holds the .jpg
                for this object (e.g., 'SDSS').
            clipSizeArcmin (float, optional): If given, the image is clipped to this size.
            zColumn (str): Name of the redshift column in cluster.
            matchTab1: Previously selected BCG (one-row table), or None.
            matchTab2: Previously selected second galaxy (one-row table), or None.
            plotFileName: Where to save the figure when 'Record BCG' is pressed. Any false value
                (None, False, empty string) means no plot is saved.

        If the cached image is missing, or isn't a 3-channel image, a warning is printed and set-up
        stops early: no figure is made, and keepBCG stays None.

        """

        self.cluster=cluster
        self.redshiftColumn=zColumn
        self.galTab=galTab
        self.matchTab1=matchTab1
        self.matchTab2=matchTab2
        self.keepBCG=None
        self.plotFileName=plotFileName
        # Defined up front so that these attributes exist even if we bail out early below
        self.imageFig=None
        self.imagePlot=None

        name=cluster['name']
        RADeg=cluster['RADeg']
        decDeg=cluster['decDeg']

        # Load data
        inJPGPath=cacheDir+os.path.sep+surveyLabel+os.path.sep+catalogs.makeRADecString(RADeg, decDeg)+".jpg"
        if os.path.exists(inJPGPath) == False:
            print("... WARNING: no image found at %s - not making plot ..." % (inJPGPath))
            return None

        with Image.open(inJPGPath) as im:
            data=np.array(im)
        if data.ndim != 3 or data.shape[2] < 3:
            # Something odd about image (e.g., greyscale) - we need R, G, B planes
            print("... WARNING: %s is not an RGB image - not making plot ..." % (inJPGPath))
            return None
        data=np.flipud(data)

        R=data[:, :, 0]
        G=data[:, :, 1]
        B=data[:, :, 2]

        # Make a WCS: TAN projection centred on the object, spanning the configured plot size
        sizeArcmin=configDict['plotSizeArcmin']
        xSizeDeg=sizeArcmin/60.0
        xSizePix=R.shape[1]
        ySizePix=R.shape[0]
        xRefPix=xSizePix/2.0
        yRefPix=ySizePix/2.0
        xOutPixScale=xSizeDeg/xSizePix
        newHead=pyfits.Header()
        newHead['NAXIS']=2
        newHead['NAXIS1']=xSizePix
        newHead['NAXIS2']=ySizePix
        newHead['CTYPE1']='RA---TAN'
        newHead['CTYPE2']='DEC--TAN'
        newHead['CRVAL1']=RADeg
        newHead['CRVAL2']=decDeg
        newHead['CRPIX1']=xRefPix+1
        newHead['CRPIX2']=yRefPix+1
        newHead['CDELT1']=-xOutPixScale
        newHead['CDELT2']=xOutPixScale    # Makes more sense to use same pix scale
        newHead['CUNIT1']='DEG'
        newHead['CUNIT2']='DEG'
        wcs=astWCS.WCS(newHead, mode='pyfits')

        # Optional zoom
        if clipSizeArcmin is not None:
            clipSizeDeg=float(clipSizeArcmin)/60.0
            RClip=astImages.clipImageSectionWCS(R, wcs, RADeg, decDeg, clipSizeDeg)
            GClip=astImages.clipImageSectionWCS(G, wcs, RADeg, decDeg, clipSizeDeg)
            BClip=astImages.clipImageSectionWCS(B, wcs, RADeg, decDeg, clipSizeDeg)
            R=RClip['data']
            G=GClip['data']
            B=BClip['data']
            wcs=RClip['wcs']

        cutLevels=[[R.min(), R.max()], [G.min(), G.max()], [B.min(), B.max()]]

        # Make image plot
        self.imageFig=plt.figure(figsize = configDict['figSize'])
        axes=[0.07,0.085,0.5,0.85]
        axesLabels="sexagesimal"
        self.imagePlot=astPlots.ImagePlot([R, G, B], wcs, cutLevels = cutLevels, title = name, axes = axes, 
                                          axesLabels = axesLabels)

        # Cut down object catalog to only things we can see in image
        inImage=np.array([self.imagePlot.wcs.coordsAreInImage(g['RADeg'], g['decDeg']) == True
                          for g in self.galTab], dtype = bool)
        self.galTab=self.galTab[inImage]

        # Add image coords to galaxies catalog (since images are TAN projected this is fine for PA calc later)
        # The reshape keeps this working when no galaxies are left after the cut above
        pixCoords=np.array(wcs.wcs2pix(self.galTab['RADeg'].data, self.galTab['decDeg'].data)).reshape(-1, 2)
        self.galTab.add_column(atpy.Column(pixCoords[:, 0], "xImage"))
        self.galTab.add_column(atpy.Column(pixCoords[:, 1], "yImage"))

        # Make CMD
        self.cmd=plt.axes([0.63, 0.4, 0.35, 0.535])
        self.drawCMD()

        self.imagePlot.draw()
        if self.matchTab1 is not None:
            self.imagePlot.addPlotObjects(np.array(self.matchTab1['RADeg']), np.array(self.matchTab1['decDeg']), 
                                          'BCG', symbol='circle', size=10.0, color='yellow')
        if self.matchTab2 is not None:
            self.imagePlot.addPlotObjects(np.array(self.matchTab2['RADeg']), np.array(self.matchTab2['decDeg']), 
                                          'galaxy2', symbol='circle', size=10.0, color='green') 

        # Buttons
        axok=plt.axes([0.63, 0.25, 0.1, 0.075])
        self.okButton=Button(axok, 'Record BCG')
        self.okButton.on_clicked(self.okClicked)

        axignore=plt.axes([0.73, 0.25, 0.1, 0.075])
        self.ignoreButton=Button(axignore, 'Skip')        
        self.ignoreButton.on_clicked(self.ignoreClicked)

        # Mode selection - sticky to galaxies in the catalog/cmd plot or 'freehand'
        axradio = plt.axes([0.63, 0.12, 0.2, 0.1])
        self.radio=RadioButtons(axradio, ['Nearest Galaxy', 'Free Selection'])

        # Activate clickiness
        self.imageFig.canvas.mpl_connect('button_press_event', self.onclick)
        self.imageFig.canvas.mpl_connect('pick_event', self.onpick)


    def okClicked(self, blah):
        """Callback for the 'Record BCG' button: flag the selection as accepted, optionally save the
        figure, and close it. The event argument (blah) is not used.

        """
        self.keepBCG=True
        # plotFileName may be None or False (the constructor default) - both mean don't save
        if self.plotFileName:
            self.imageFig.savefig(self.plotFileName) 
        plt.close()


    def ignoreClicked(self, blah):
        """Callback for the 'Skip' button: flag the selection as rejected and close the figure.
        The event argument (blah) is not used.

        """
        self.keepBCG=False
        plt.close()


    def getSelectedMode(self):
        """Return the label of the currently selected radio button ('Nearest Galaxy' or
        'Free Selection').

        """
        # Ask the widget directly, rather than inspecting the colours of its circle patches
        # (RadioButtons.circles is not available in recent matplotlib versions)
        return self.radio.value_selected


    @staticmethod
    def _indexOfSmallest(values):
        """Return the index of the smallest finite entry of values, or None if there isn't one."""
        values=np.asarray(values, dtype = float)
        if values.size == 0 or not np.isfinite(values).any():
            return None
        return int(np.nanargmin(values))


    def _closestCMDRow(self, mags, colours):
        """Return the galaxy (as a one-row table) that lies closest in the CMD to any of the given
        (i, r-i) points, or None if there is no such galaxy.

        """
        bestIndex=None
        bestDiff=None
        for m, c in zip(mags, colours):
            diff=np.sqrt((self.galTab['r-i']-c)**2 + (self.galTab['i']-m)**2)
            index=self._indexOfSmallest(diff)
            if index is None:
                continue
            # <= so that, for equally good candidates, the last one wins
            if bestDiff is None or diff[index] <= bestDiff:
                bestDiff=diff[index]
                bestIndex=index
        if bestIndex is None:
            return None
        return self.galTab[bestIndex:bestIndex+1]


    def drawCMD(self):
        """Draw the CMD, highlighting galaxies with phot-z within some z cut and the selected object in the
        image plot.

        All galaxies are plotted in blue (these are the pickable points); those with zPhot within
        0.1 of the cluster redshift are over-plotted in red; the current selections are shown as
        yellow (BCG) and green (second galaxy) circles.

        """
        plt.axes(self.cmd)
        plt.cla()
        # Use the column name given to the constructor, not a script-level global
        zCluster=self.cluster[self.redshiftColumn]
        zpMask=np.logical_and(np.greater(self.galTab['zPhot'], zCluster-0.1), 
                              np.less(self.galTab['zPhot'], zCluster+0.1))
        plt.plot(self.galTab['i'], self.galTab['r-i'], 'b.', picker = 5)
        plt.plot(self.galTab['i'][zpMask], self.galTab['r-i'][zpMask], 'r.')
        if self.matchTab1 is not None:
            plt.plot(self.matchTab1['i'], self.matchTab1['r-i'], 'yo')
        if self.matchTab2 is not None:
            plt.plot(self.matchTab2['i'], self.matchTab2['r-i'], 'go')            
        plt.ylim(-0.5, 1.5)
        plt.xlim(15, 23)
        plt.xlabel("i")
        plt.ylabel("r-i")        


    def onclick(self, event):
        """Get coords of where we clicked, update plots.

        Only left (BCG) and right (second galaxy) clicks inside the image are acted upon. In
        'Nearest Galaxy' mode the selection is the catalogue galaxy closest to the click; in
        'Free Selection' mode it is a one-row table with the clicked coordinates and zeros in all
        other columns.

        """

        # button number -> (plot object label, colour)
        selectionStyles={1: ('BCG', 'yellow'), 3: ('galaxy2', 'green')}

        if event.inaxes != self.imagePlot.axes:
            return None
        button=event.button
        if button not in selectionStyles:
            return None     # e.g., middle button or scroll wheel - nothing to do
        label, color=selectionStyles[button]

        RADeg, decDeg=self.imagePlot.wcs.pix2wcs(event.xdata, event.ydata)
        modeText=self.getSelectedMode()
        if modeText == 'Nearest Galaxy':
            rDeg=astCoords.calcAngSepDeg(RADeg, decDeg, self.galTab['RADeg'], self.galTab['decDeg'])
            index=self._indexOfSmallest(rDeg)
            if index is None:
                return None     # no galaxies in the image to snap to
            # Slice so that we always get exactly one row, even if several galaxies tie
            matchTab=self.galTab[index:index+1]
        elif modeText == 'Free Selection':
            freeTab=atpy.Table()
            for key in self.galTab.keys():
                freeTab.add_column(atpy.Column(np.zeros(1), key))
            freeTab['RADeg'][0]=RADeg
            freeTab['decDeg'][0]=decDeg
            freeTab.table_name="FreeSelection"
            matchTab=freeTab
        else:
            return None

        # Store the selection before redrawing, so that the CMD shows the new choice straight away
        if button == 1:
            self.matchTab1=matchTab
        else:
            self.matchTab2=matchTab

        self.drawCMD()
        self.imagePlot.addPlotObjects(np.array(matchTab['RADeg']), np.array(matchTab['decDeg']), 
                                      label, symbol='circle', size=10.0, color=color)
        self.imagePlot.draw()


    def onpick(self, event):
        """Select an object in the CMD and highlight it in the image plot.

        """

        # NOTE: This only works with left button (can't use right to pick two different objects)
        thispoint = event.artist
        xdata = thispoint.get_xdata()
        ydata = thispoint.get_ydata()
        ind = event.ind
        colour=ydata[ind]
        mag=xdata[ind]

        # Find closest point
        matchTab=self._closestCMDRow(mag, colour)
        if matchTab is None:
            return None
        self.matchTab1=matchTab
        self.drawCMD()
        self.imagePlot.addPlotObjects(np.array(self.matchTab1['RADeg']), np.array(self.matchTab1['decDeg']), 'BCG', 
                                      symbol='circle', size=10.0, color='yellow')        
        self.imagePlot.draw()

        
#-------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Script entry point: step through the clusters in the given catalogue, let the user pick the
    # BCG (and optionally a second galaxy) for each selected one, cache the choices on disk, and
    # write out a copy of the catalogue with the BCG columns filled in.

    parser = argparse.ArgumentParser("zClusterBCG")
    parser.add_argument("catalogFileName", help="""A .fits table with keys 'name', 'RADeg', 'decDeg', 
                        and 'z' or 'redshift'. An output table including additional columns with the
                        chosen BCG coordinates will be written as catalogFileName_incBCGPos.fits.""")
    parser.add_argument("resultsDir", help="""A directory containing zCluster results that matches 
                        against cluster names in the catalogFileName table.""")
    parser.add_argument("-S", "--survey", dest="survey", help="""Name of the imaging survey to use (SDSS or
                        DES).""", default = 'SDSS', choices = ['SDSS', 'DES'])
    parser.add_argument("-n", "--name", dest="name", help="Select/edit BCG only for the named object.")
    parser.add_argument("-s", "--size", dest="sizeArcmin", type = float, default = 6.0, 
                        help="Size of the image plot in arcmin.")
    parser.add_argument("-P", "--write-image-plots", dest="writeImagePlots", action="store_true", 
                        help="""Write out .jpg image plot to the BCGsCache/jpgs directory.""", 
                        default = False)
    args = parser.parse_args()

    catalogFileName=args.catalogFileName
    resultsDir=args.resultsDir

    # argparse has already rejected any survey name other than these
    fetchImage={'SDSS': fetchSDSSImage, 'DES': fetchDESImage}[args.survey]

    clusTab=atpy.Table.read(catalogFileName)
    if 'z' in clusTab.keys():
        zColumn='z'
    elif 'redshift' in clusTab.keys():
        zColumn='redshift'
    else:
        raise Exception("No redshift column found")

    # Add columns for BCG, an optional 2nd galaxy, and the position angle between them
    # (this is for setting up long slit observations)
    colsToAdd=['BCG_RADeg', 'BCG_decDeg', 'BCG_rMag', 'galaxy2_RADeg', 'galaxy2_decDeg', 'galaxy2_rMag', 'PA_deg']
    for col in colsToAdd:
        if col not in clusTab.keys():
            clusTab.add_column(atpy.Column(np.zeros(len(clusTab)), col))

    # If a single object was named on the command line, only that one is shown for (re-)selection.
    # We still iterate over all rows of clusTab, hoovering up previously cached BCG positions and
    # adding them to clusTab. After the loop, clusTab is written out with _incBCGPos appended to
    # the catalogFileName.
    selectionMask=np.ones(len(clusTab), dtype = bool)
    if args.name is not None:
        selectionMask[np.where(clusTab['name'] != args.name)]=False
        if selectionMask.any() == False:
            print("... WARNING: no object named '%s' found in %s ..." % (args.name, catalogFileName))

    clipSizeArcmin=args.sizeArcmin

    # Cache previously measured BCG positions
    bcgCacheDir="BCGsCache"
    jpgsCacheDir=bcgCacheDir+os.path.sep+"jpgs"
    for dirToMake in [bcgCacheDir, jpgsCacheDir]:
        if os.path.exists(dirToMake) == False:
            os.makedirs(dirToMake, exist_ok = True)

    for cluster, selected in zip(clusTab, selectionMask):
        print(">>> %s" % (cluster['name']))

        fileLabel=cluster['name'].replace(" ", "_")
        outFileName=bcgCacheDir+os.path.sep+fileLabel+"_BCG.fits"
        if args.writeImagePlots == True:
            plotFileName=jpgsCacheDir+os.path.sep+fileLabel+"_BCG.jpg"
        else:
            plotFileName=None

        # In the cached table, row 0 is the BCG and row 1 (if present) is the second galaxy
        if os.path.exists(outFileName) == True:
            print("... found previously determined BCG position ...")
            matchTab=atpy.Table.read(outFileName)
            matchTab1=matchTab[0:1]
            if len(matchTab) > 1:
                matchTab2=matchTab[1:2]
            else:
                matchTab2=None
        else:
            matchTab=None
            matchTab1=None
            matchTab2=None

        if selected:
            # The galaxy catalogue is only needed for objects that we actually display
            galTab=atpy.Table.read(resultsDir+os.path.sep+cluster['name']+os.path.sep+"galaxyCatalog_"+fileLabel+".fits")
            fetchImage(cluster)
            cmdImage=CMDImageCombo(cluster, galTab, args.survey, clipSizeArcmin = clipSizeArcmin, zColumn = zColumn, 
                                   matchTab1 = matchTab1, matchTab2 = matchTab2, plotFileName = plotFileName)

            plt.show(block = True)

            # Store BCG pos separately
            if cmdImage.keepBCG == True and cmdImage.matchTab1 is not None:
                # The second galaxy is optional, so only stack it if one was chosen
                tabsToStack=[cmdImage.matchTab1]
                if cmdImage.matchTab2 is not None:
                    tabsToStack.append(cmdImage.matchTab2)
                matchTab=atpy.vstack(tabsToStack)
                matchTab.write(outFileName, overwrite = True)
            else:
                if cmdImage.keepBCG == True:
                    print("... WARNING: no BCG was selected - nothing recorded ...")
                matchTab=None

        if matchTab is not None:
            cluster['BCG_RADeg']=matchTab['RADeg'][0]
            cluster['BCG_decDeg']=matchTab['decDeg'][0]
            cluster['BCG_rMag']=matchTab['r'][0]
            if len(matchTab) > 1:
                cluster['galaxy2_RADeg']=matchTab['RADeg'][1]
                cluster['galaxy2_decDeg']=matchTab['decDeg'][1]
                cluster['galaxy2_rMag']=matchTab['r'][1]
                # Calculate position angle (fine to use pixel coords as TAN projection)
                cluster['PA_deg']=np.degrees(np.arctan2(matchTab['yImage'][0]-matchTab['yImage'][1], 
                                                    matchTab['xImage'][0]-matchTab['xImage'][1]))

    # Write out new table
    # Built from the file name root, so that the output can never have the same path as the input
    # catalogue (which str.replace would give for any name that doesn't contain ".fits")
    outFileName=os.path.splitext(catalogFileName)[0]+"_incBCGPos.fits"
    clusTab.write(outFileName, overwrite = True)

