#!/usr/bin/env python3

import numpy as np

from astropy import constants as cte
from astropy import units as u

import argparse

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib as mpl

from pNbody import *
from pNbody import ic
from pNbody.Mockimgs import lib as libMockimgs

from astropy.io import fits

  
####################################################################
# option parser
####################################################################

description="display a set of fits images"
epilog     ="""
Examples:
--------
mockimgs_sb_display_fits --add_axes --ax_unit kpc --ax_max 300 --sbmin 25 --sbmax 31.5 --colorbar --colormap Greys --colormap_reverse *.fits

mockimgs_sb_display_fits *.fits -o output.png
mockimgs_sb_display_fits *.fits 
mockimgs_sb_display_fits *.fits --sbcontours 28 30.5 --ax_max 150  --colorbar -o output.png
mockimgs_sb_display_fits *.fits --sbcontours 28 30.5 --ax_max 150  --colorbar --add-title -o output.png
mockimgs_sb_display_fits *.fits --sbcontours 28 30.5 --ax_max 150  --colorbar --colormap mycmap -o output.png
mockimgs_sb_display_fits *.fits --sbcontours 28 30.5 --ax_max 150  --colorbar --colormap mycmap --colormap_reverse -o output.png 

mockimgs_sb_display_fits *.fits --ax_unit pixels --add_axes  

mockimgs_sb_display_fits image.fits --add_axes --ax_unit kpc --ax_max 3
mockimgs_sb_display_fits image.fits --add_axes --ax_unit arcsec --ax_max 10
mockimgs_sb_display_fits image.fits --add_axes --ax_unit arcmin --ax_max 60

mockimgs_sb_display_fits image.fits --sbmin 25 --sbmax 32
mockimgs_sb_display_fits image.fits --sbcontours 28 30.5
mockimgs_sb_display_fits image.fits --colorbar



"""

# Command-line interface. The option names, destinations and defaults are
# read by find_image_extention() and by the __main__ section below.
parser = argparse.ArgumentParser(description=description,epilog=epilog,formatter_class=argparse.RawDescriptionHelpFormatter)


# FITS files to display (may be empty at parse time; checked in __main__)
parser.add_argument(action="store", 
                    dest="files", 
                    metavar='FILE', 
                    type=str,
                    default=None,
                    nargs='*',
                    help='a list of files') 

parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")  

parser.add_argument("-n",
                    action="store", 
                    dest="n", 
                    metavar='INT', 
                    type=int,
                    default=None,
                    help='number of images to display')  


# restricted to the units find_image_extention() knows how to handle, so a
# typo is reported by argparse instead of crashing later
parser.add_argument("--ax_unit",
                    action="store", 
                    metavar='STRING', 
                    type=str,
                    default='kpc',
                    choices=['kpc', 'arcsec', 'arcmin', 'pixels'],
                    help='axes units (kpc, arcsec, arcmin, pixels)') 

parser.add_argument('--ax_max',
                    action="store", 
                    metavar='FLOAT', 
                    type=float,
                    default=None,
                    help='half-width of the displayed region, in the axes units (default: whole image)')
    
parser.add_argument('--sbmin',
                    action="store", 
                    dest="sbmin", 
                    metavar='FLOAT', 
                    type=float,
                    default=25,
                    help='surface brightness minimum')

parser.add_argument('--sbmax',
                    action="store", 
                    dest="sbmax", 
                    metavar='FLOAT', 
                    type=float,
                    default=32,
                    help='surface brightness maximum')

parser.add_argument("--add_axes",
                    action="store_true", 
                    default=False,
                    help='add axes to the figure') 

parser.add_argument("--colorbar",
                    action="store_true", 
                    default=False,
                    help='add colorbar to the figure') 

parser.add_argument("--colormap",
                    action="store", 
                    default=None,
                    help='matplotlib colormap name (e.g. mycmap, tab20c, Greys, jet, binary)') 

parser.add_argument("--colormap_reverse",
                    action="store_true", 
                    default=False,
                    help='reverse colormap')                    

parser.add_argument('--sbcontours',
                    action="store", 
                    dest="sbcontours", 
                    metavar='FLOAT', 
                    type=float,
                    default=None,
                    nargs="*",
                    help='surface brightness contours')

parser.add_argument("--add-title",
                    action="store_true", 
                    dest="add_title", 
                    default=False,
                    help='add a title (automatic)') 

parser.add_argument('--title',
                    action="store", 
                    dest="title", 
                    metavar='STR', 
                    type=str,
                    default=None,
                    help='title to add')

parser.add_argument("--abs-value",
                    action="store_true", 
                    dest="abs_value",
                    default=False,
                    help='plot the absolute value of the flux') 


def mycmap_colors():
  """Return the RGBA anchor colors of the custom 'mycmap' colormap.

  Returns
  -------
  numpy.ndarray
      A (13, 4) array: four green shades, four blue shades and four orange
      shades (each group with increasing lightness), followed by white.
  """
  return np.array([[0.19215686274509805, 0.6392156862745098,  0.32941176470588235,  1.0],
                   [0.4549019607843137,  0.7686274509803922,  0.4627450980392157,   1.0],
                   [0.6313725490196078,  0.8509803921568627,  0.6078431372549019,   1.0],
                   [0.7803921568627451,  0.9137254901960784,  0.7529411764705882,   1.0],
                   [0.19215686274509805, 0.5098039215686274,  0.7411764705882353,   1.0],
                   [0.4196078431372549,  0.6823529411764706,  0.8392156862745098,   1.0],
                   [0.6196078431372549,  0.792156862745098,   0.8823529411764706,   1.0],
                   [0.7764705882352941,  0.8588235294117647,  0.9372549019607843,   1.0],
                   [0.9019607843137255,  0.3333333333333333,  0.050980392156862744, 1.0],
                   [0.9921568627450981,  0.5529411764705883,  0.23529411764705882,  1.0],
                   [0.9921568627450981,  0.6823529411764706,  0.4196078431372549,   1.0],
                   [0.9921568627450981,  0.8156862745098039,  0.6352941176470588,   1.0],
                   [1.0,                 1.0,                 1.0,                  1.0]])


def mycmap():
  """Build the custom colormap selected with ``--colormap mycmap``.

  Returns
  -------
  matplotlib.colors.LinearSegmentedColormap
      Colormap named 'mymap' with one level per anchor color of
      mycmap_colors(), so the result is a discrete 13-level map.
  """
  import matplotlib.colors as colors

  carray = mycmap_colors()
  # N equal to the number of anchors: the lookup table samples exactly the
  # anchor positions, i.e. each color appears once with no blending
  new_cmap = colors.LinearSegmentedColormap.from_list('mymap',carray,N=len(carray))
  return new_cmap

                    

def find_image_extention(opt,header):
  """Return the image extent, the displayed window and the axis labels.

  Parameters
  ----------
  opt : argparse.Namespace
      Parsed options. ``opt.ax_unit`` selects the unit ('pixels', 'kpc',
      'arcsec' or 'arcmin'); ``opt.ax_max`` is the half-width of the
      displayed window in that unit, or ``None`` for a default.
      Side effect: for 'kpc', 'arcsec' and 'arcmin', a ``None``
      ``opt.ax_max`` is overwritten with the image's right x border, so
      every later call with the same ``opt`` reuses that window.
  header : mapping
      FITS header. Reads NX, NY for 'pixels'; XMINKPC, XMAXKPC, YMINKPC,
      YMAXKPC for 'kpc'; XMIN, XMAX, YMIN, YMAX for 'arcsec' and 'arcmin'
      (divided by 60 for 'arcmin').

  Returns
  -------
  tuple
      ``(imgext_xmin, imgext_xmax, imgext_ymin, imgext_ymax,
      xmin, xmax, ymin, ymax, xlabel, ylabel)``: the image borders (used as
      the ``imshow`` extent), the axes limits and the axes labels.

  Raises
  ------
  ValueError
      If ``opt.ax_unit`` is not a supported unit.
  KeyError
      If a required header keyword is missing.
  """
  unit = opt.ax_unit

  if unit == 'pixels':
    NX = header["NX"]
    NY = header["NY"]
    # half-width of the window; defaults to half the image width
    DX = opt.ax_max if opt.ax_max is not None else NX//2

    # square window of half-width DX centred on the image centre
    return (0, NX, 0, NY,
            NX//2 - DX, NX//2 + DX, NY//2 - DX, NY//2 + DX,
            r"x [pixel]", r"y [pixel]")

  # header keywords of the image borders, and divisor converting them into
  # the requested unit (None: use the header values unchanged)
  if unit == 'kpc':
    keys, divisor = ("XMINKPC", "XMAXKPC", "YMINKPC", "YMAXKPC"), None
  elif unit == 'arcsec':
    keys, divisor = ("XMIN", "XMAX", "YMIN", "YMAX"), None
  elif unit == 'arcmin':
    keys, divisor = ("XMIN", "XMAX", "YMIN", "YMAX"), 60
  else:
    raise ValueError("unsupported axes unit '%s' (expected pixels, kpc, arcsec or arcmin)" % unit)

  if divisor is None:
    borders = [header[key] for key in keys]
  else:
    borders = [header[key]/divisor for key in keys]
  imgext_xmin, imgext_xmax, imgext_ymin, imgext_ymax = borders

  # default window: up to the right x border, stored back into opt
  if opt.ax_max is None:
    opt.ax_max = imgext_xmax

  # symmetric square window around the origin
  xmin, xmax = -opt.ax_max, opt.ax_max
  ymin, ymax = -opt.ax_max, opt.ax_max

  xlabel = r"x [%s]" % unit
  ylabel = r"y [%s]" % unit

  return imgext_xmin,imgext_xmax,imgext_ymin,imgext_ymax,xmin,xmax,ymin,ymax,xlabel,ylabel


                    
####################################################################
# main
####################################################################



def grid_side(nfiles, n=None):
  """Return the number of rows (equal to the number of columns) of the panel grid.

  Parameters
  ----------
  nfiles : int
      Number of files given on the command line.
  n : int or None
      Requested number of images (option ``-n``); when given it replaces
      ``nfiles``.

  Returns
  -------
  int
      ``floor(sqrt(count))`` with ``count`` = ``n`` if given, else
      ``nfiles``; 0 when ``count`` is not positive.
  """
  count = nfiles if n is None else n
  if count <= 0:
    return 0
  return int(np.sqrt(count))


if __name__ == '__main__':
  # used for the automatic title; imported explicitly rather than relying
  # on `from pNbody import *` to provide it
  import os

  opt = parser.parse_args()

  if not opt.files:
    parser.error("no FITS file given")

  params = {
    "axes.labelsize": 14,
    "axes.titlesize": 18,
    "font.size": 12,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "text.usetex": False,
    "figure.subplot.left": 0.1,
    "figure.subplot.right": 1.0,
    "figure.subplot.bottom": 0.05,
    "figure.subplot.top": 0.95,
    "figure.subplot.wspace": 0.02,
    "figure.subplot.hspace": 0.02,
    "figure.figsize" : (15*1.15, 15),    
    "lines.markersize": 6,
    "lines.linewidth": 3.0,
  }
  plt.rcParams.update(params)

  # square grid of nx x ny panels
  nx = grid_side(len(opt.files), opt.n)
  ny = nx
  if nx < 1:
    parser.error("the number of images to display (-n) must be at least 1")

  # the grid only has nx*ny panels: further files would index outside it
  # (or, for a single panel, be drawn on top of each other)
  files = opt.files[:nx*ny]
  if len(files) < len(opt.files):
    print("warning: displaying only %d of the %d files (%dx%d grid)" % (len(files), len(opt.files), nx, ny))

  if nx==1:  
    fig = plt.gcf()
    fig.set_size_inches(12,10)
    ax  = plt.gca()
  else:    
    fig, ax = plt.subplots(nx,ny)
    fig.set_size_inches(15*1.15, 15)

  # color map
  if opt.colormap is None:
    # default: the upper part of gist_heat
    colors2 = plt.cm.gist_heat(np.linspace(0.3, 1.0, 255))
    cmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', colors2)
  elif opt.colormap == "mycmap":
    cmap = mycmap()
  else:
    cmap = mpl.colormaps[opt.colormap]

  if opt.colormap_reverse:
    cmap = cmap.reversed()

  # loop over images of the same model but with different los
  for k,f in enumerate(files):
    
    print(f)

    # read the primary HDU; data and header are copied so that the file
    # can be closed immediately
    with fits.open(f) as hdul:
      if hdul[0].data is None:
        raise ValueError("no image data in the primary HDU of '%s'." % f)
      image = np.array(hdul[0].data)
      header = hdul[0].header.copy()

    # get image extention
    imgext_xmin,imgext_xmax,imgext_ymin,imgext_ymax,xmin,xmax,ymin,ymax,xlabel,ylabel = find_image_extention(opt,header)

    # get plot indexes (row-major position in the grid)
    ix = k//nx
    iy = k - (ix*nx)
    
    if nx==1:
      axii = ax
    else:
      axii = ax[ix,iy]
    
    axii.set_aspect('equal')

    # check if pixels contain magnitudes or fluxes
    if "UNITS" not in header:
      raise ValueError("'UNITS' not found in the fits header.")    
    units = header["UNITS"]
    
    if units=='mag/arcsec^2':
      # the image contains magnitudes; pixels set to 100 are ignored when
      # deriving the upper surface brightness limit
      if opt.sbmax is None:
        opt.sbmax = image[image < 100].max()
    
    else:
      # the image contains a flux
      if opt.abs_value:
        image = np.abs(image)
        
      if "PIXAREA" not in header:
        raise ValueError("'PIXAREA' not found in the fits header.")     
      Omega = header["PIXAREA"]
      
      # Flux to magnitude (zp=0); zero or negative fluxes give inf/nan,
      # which imshow leaves blank, so numpy's warnings are silenced
      with np.errstate(divide='ignore', invalid='ignore'):
        image = -2.5*np.log10(image/Omega) + 0

      if opt.sbmax is None:
        opt.sbmax = image.max()

    # plot the image
    im = axii.imshow(image,aspect='equal',extent=(imgext_xmin,imgext_xmax,imgext_ymin,imgext_ymax),interpolation='none',vmin=opt.sbmin,vmax=opt.sbmax,cmap=cmap)

    if opt.sbcontours is not None:
      # origin='upper' matches imshow's default orientation
      axii.contour(image,levels=opt.sbcontours,colors='k',linewidths=1,linestyles='solid',extent=(imgext_xmin,imgext_xmax,imgext_ymin,imgext_ymax),origin='upper')

    # labels
    axii.set_xlabel(xlabel)
    axii.set_ylabel(ylabel)    
    
    # limits
    axii.set_xlim(xmin,xmax)
    axii.set_ylim(ymin,ymax)

    # add axes
    if not opt.add_axes:
      axii.get_xaxis().set_visible(False)
      axii.get_yaxis().set_visible(False)
      axii.get_xaxis().set_ticks([])
      axii.get_yaxis().set_ticks([])

    # only the bottom-left panel keeps its axes
    if k != nx*(ny-1):
      axii.get_xaxis().set_visible(False)
      axii.get_yaxis().set_visible(False)
      axii.get_xaxis().set_ticks([])
      axii.get_yaxis().set_ticks([])

  # add colorbar (shared by all panels)
  if opt.colorbar:
    plt.colorbar(im,label=r"surface brightness [mag/arcsec$^2$]",ax=ax,location='right')
  
  # add a title (automatically, from the header of the last image read)
  if opt.add_title:
    filtername = os.path.basename(header["FILTER"])
    txt = r"%s : filter=%s : D=%4.1f Mpc"%(header["OBJ_NAME"],filtername,header["OBJ_DIST"])
    fig.suptitle(txt,fontsize=20)

  # add a title 
  if opt.title is not None:
    txt = r"%s"%(opt.title)
    fig.suptitle(txt,fontsize=20)

  # save or display
  if opt.outputfilename:
    plt.savefig(opt.outputfilename)
  else:
    plt.show()
    






