#!/usr/bin/python3

import numpy as np

from astropy import constants as cte
from astropy import units as u

import argparse

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from pNbody import *
from pNbody import ic
from pNbody.Mockimgs import lib as libMockimgs


import pickle
import os
  
####################################################################
# option parser
####################################################################

description = "display surface brightness images"
epilog = """
Examples:
--------
mockimgs_sb_show_images output.pkl
mockimgs_sb_show_images output.pkl -o output.png
mockimgs_sb_show_images output.pkl -n 1 -o output.png
mockimgs_sb_show_images output.pkl --ext_kpc 100  --sbmin 25 --sbcontours 28 30.5 -o output.png
"""

# RawDescriptionHelpFormatter keeps the line breaks of the examples above.
parser = argparse.ArgumentParser(description=description,
                                 epilog=epilog,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)

# Positional input. With nargs='*' argparse never rejects a missing file:
# opt.file is then an empty list and the main block reports the error itself.
parser.add_argument(action="store",
                    dest="file",
                    metavar='FILE',
                    type=str,
                    default=None,
                    nargs='*',
                    help='pickle file holding the images to display (only the first FILE is read)')

# A ".fits" extension switches the script from saving the figure to writing
# the surface-brightness map of every displayed panel as FITS.
parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file (a .fits extension writes one FITS file "
                         "per displayed image instead of saving the figure)")

# Bounds of the colour scale (passed as vmin/vmax to imshow).
parser.add_argument('--sbmin',
                    action="store",
                    dest="sbmin",
                    metavar='FLOAT',
                    type=float,
                    default=25,
                    help='surface brightness minimum (lower limit of the colour scale)')

parser.add_argument('--sbmax',
                    action="store",
                    dest="sbmax",
                    metavar='FLOAT',
                    type=float,
                    default=32,
                    help='surface brightness maximum (upper limit of the colour scale)')

# Zero or more contour levels; None when the option is not given.
parser.add_argument('--sbcontours',
                    action="store",
                    dest="sbcontours",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    nargs="*",
                    help='surface brightness levels at which contours are drawn')

parser.add_argument("--ext_kpc",
                    action="store",
                    dest="ext_kpc",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='field of view extension in kpc (the axis limits are rescaled '
                         'so that xmax equals this value)')

# The panels form a square grid, so the count is rounded down to a square.
parser.add_argument("-n",
                    action="store",
                    dest="n",
                    metavar='INT',
                    type=int,
                    default=None,
                    help='number of images to display (rounded down to a perfect square: '
                         'the panels form an N x N grid; default: all images)')
                                             
                                        
                    
####################################################################
# main
####################################################################


def _load_images(filename):
  """Return the object pickled in *filename*.

  The main block below unpacks each element as a ``(FluxMap, params)`` pair.

  NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
  only open files produced by a trusted source.
  """
  with open(filename, "rb") as f:
    return pickle.load(f)


def _grid_side(n_requested, n_available):
  """Return the side length of the square panel grid.

  Parameters
  ----------
  n_requested : int or None
    Value of the ``-n`` option; ``None`` means "use every image".
  n_available : int
    Number of images stored in the input file.

  Returns
  -------
  int
    ``floor(sqrt(n))``, where ``n`` is *n_requested* clamped to
    *n_available* so the grid never needs more images than exist.
    Returns 0 when there is nothing to display.
  """
  n = n_available if n_requested is None else min(n_requested, n_available)
  if n <= 0:
    return 0
  return int(np.sqrt(n))


def _fits_filename(outputfilename, index, single):
  """Return the FITS file name used for panel *index*.

  A single panel is written to *outputfilename* as given. Otherwise the
  extension is replaced by a two-digit panel index followed by ``.fits``,
  e.g. ``out.fits`` -> ``out03.fits``.
  """
  if single:
    return outputfilename
  stem = os.path.splitext(outputfilename)[0]
  return "%s%02d.fits" % (stem, index)


####################################################################
# entry point: load the pickled flux maps, convert them to surface
# brightness, show them on a square grid and save the figure or FITS.
####################################################################

if __name__ == '__main__':

  # Named rc_params so it is not shadowed by the per-image params dict below.
  rc_params = {
    "axes.labelsize": 14,
    "axes.titlesize": 18,
    "font.size": 12,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "text.usetex": False,
    "figure.subplot.left": 0.1,
    "figure.subplot.right": 1.0,
    "figure.subplot.bottom": 0.05,
    "figure.subplot.top": 0.95,
    "figure.subplot.wspace": 0.02,
    "figure.subplot.hspace": 0.02,
    "figure.figsize" : (12*1.15, 12),
    "lines.markersize": 6,
    "lines.linewidth": 3.0,
  }
  plt.rcParams.update(rc_params)

  opt = parser.parse_args()

  # parser.error prints the usage line and exits with a non-zero status.
  if not opt.file:
    parser.error("need to provide one file")

  images = _load_images(opt.file[0])

  # Square grid of nx x ny panels (nx == ny), never larger than the data.
  nx = _grid_side(opt.n, len(images))
  if nx == 0:
    parser.error("no image to display (empty input file or -n < 1)")
  ny = nx
  n_panels = nx*ny

  # squeeze=False always returns a 2D array of Axes, even for one panel.
  fig, axes = plt.subplots(nx, ny, squeeze=False)

  # A ".fits" output name means "write FITS files" instead of "save figure".
  save_fits = bool(opt.outputfilename) and os.path.splitext(opt.outputfilename)[-1] == ".fits"
  if save_fits:
    from astropy.io import fits

  # gist_heat colormap with its darkest 30% cut away.
  colors2 = plt.cm.gist_heat(np.linspace(0.3, 1.0, 255))
  mymap = mcolors.LinearSegmentedColormap.from_list('my_colormap', colors2)

  # matplotlib requires increasing contour levels; an empty list means none.
  contour_levels = sorted(opt.sbcontours) if opt.sbcontours else None

  for i in range(n_panels):

    flux_map, img_params = images[i]

    # apply some filters
    flux_map = libMockimgs.FilterFluxMap(flux_map, None)

    # compute the surface brightness map from the luminosity/flux map
    image = libMockimgs.FluxToSurfaceBrightness(flux_map, img_params["object_distance"], img_params["ccd_pixel_area"])

    # image extension in kpc
    xmin = img_params["ccd_xmin_kpc"].to(u.kpc).value
    xmax = img_params["ccd_xmax_kpc"].to(u.kpc).value
    ymin = img_params["ccd_ymin_kpc"].to(u.kpc).value
    ymax = img_params["ccd_ymax_kpc"].to(u.kpc).value
    extent = (xmin, xmax, ymin, ymax)

    # panels are filled row by row
    row, col = divmod(i, nx)
    axii = axes[row, col]
    axii.set_aspect('equal')

    # NOTE(review): interpolation=None falls back to the rcParams default;
    # 'none' would disable interpolation entirely — confirm intent.
    im = axii.imshow(image, aspect='equal', extent=extent, interpolation=None,
                     vmin=opt.sbmin, vmax=opt.sbmax, cmap=mymap)

    # contour does not accept an 'aspect' kwarg; origin matches imshow's.
    if contour_levels:
      axii.contour(image, levels=contour_levels, colors='k', linewidths=1,
                   linestyles='solid', extent=extent, origin='upper')

    axii.set_xlabel(r"x [kpc]")
    axii.set_ylabel(r"y [kpc]")

    # rescale the field of view so that xmax maps to ext_kpc
    if opt.ext_kpc is not None:
      fac = opt.ext_kpc/xmax
      axii.set_xlim(xmin*fac, xmax*fac)
      axii.set_ylim(ymin*fac, ymax*fac)

    # only the bottom-left panel keeps its axes, ticks and labels
    if not (row == ny - 1 and col == 0):
      axii.get_xaxis().set_visible(False)
      axii.get_yaxis().set_visible(False)
      axii.get_xaxis().set_ticks([])
      axii.get_yaxis().set_ticks([])

    # save fits (one file per panel, or the given name when only one panel)
    if save_fits:
      filename = _fits_filename(opt.outputfilename, i, n_panels == 1)
      print("saving... %s" % filename)
      fits.PrimaryHDU(image).writeto(filename, overwrite=True)

  # All panels share vmin/vmax, so the last image serves the colorbar.
  colorbar_ax = axes[0, 0] if axes.size == 1 else axes
  plt.colorbar(im, label=r"surface brightness [mag/arcsec$^2$]", ax=colorbar_ax)

  # the title uses the metadata of the last displayed image
  txt = r"%s : %s : D=%4.1f Mpc" % (img_params["name"], img_params["filter_name"],
                                   img_params["object_distance"].to(u.Mpc).value)
  fig.suptitle(txt, fontsize=20)

  if opt.outputfilename:
    if not save_fits:
      plt.savefig(opt.outputfilename)
  else:
    plt.show()
    
  
  
  



