#!/usr/bin/python3

import numpy as np

from astropy import constants as cte
from astropy import units as u

import argparse

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from pNbody import *
from pNbody import ic
from pNbody.Mockimgs import lib as libMockimgs


import pickle
import os

from astropy.io import fits
  
####################################################################
# option parser
####################################################################

description = "plot the radial surface brightness profile of one or more fits images"
epilog = """
Examples:
--------
mockimgs_sb_profile image.fits
mockimgs_sb_profile image1.fits image2.fits
mockimgs_sb_profile image1.fits image2.fits -o output.png

mockimgs_sb_profile image.fits --add_axes --ax_unit kpc --ax_max 3
mockimgs_sb_profile image.fits --add_axes --ax_unit arcsec --ax_max 10
mockimgs_sb_profile image.fits --add_axes --ax_unit arcmin --ax_max 60
"""

# RawDescriptionHelpFormatter keeps the line breaks of the epilog examples.
parser = argparse.ArgumentParser(
    description=description,
    epilog=epilog,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

# positional: zero or more FITS images, one profile is computed per file
parser.add_argument(action="store",
                    dest="files",
                    metavar='FILE',
                    type=str,
                    default=None,
                    nargs='*',
                    help='a list of fits files')

parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="name of the output file (if omitted, the plot is shown interactively)")

parser.add_argument("--no_plot",
                    action="store_true",
                    default=False,
                    help='do not plot')

# NOTE(review): --add_axes, --fig_size, --ax_unit and --ax_max are parsed,
# but their values do not reach the profile plot (they look inherited from
# an image-display script). Confirm before relying on them.
parser.add_argument("--add_axes",
                    action="store_true",
                    default=False,
                    help='add axes to the figure')

parser.add_argument('--fig_size',
                    action="store",
                    metavar='FLOAT',
                    type=float,
                    default=6,
                    help='size of a single figure in inches')

parser.add_argument("--ax_unit",
                    action="store",
                    metavar='STRING',
                    type=str,
                    default='pixels',
                    help='axes units (pixels, kpc, arcsec, arcmin)')

parser.add_argument('--ax_max',
                    action="store",
                    metavar='FLOAT',
                    type=float,
                    default=20,
                    help='extension of the image in the axes units')
                    
                                        
                    
####################################################################
# main
####################################################################


def radial_profile(data, step=1.5, blank=100):
    """Compute the azimuthally averaged surface-brightness profile of an image.

    Pixels are grouped into circular annuli of width ``step`` pixels centred
    on pixel (row ``ny // 2``, column ``nx // 2``).  In each annulus the pixel
    values, taken to be magnitudes (mag/arcsec2), are converted to fluxes,
    averaged, and converted back to a magnitude.

    Parameters
    ----------
    data : array_like, shape (ny, nx)
        Surface-brightness image in magnitudes.
    step : float, optional
        Width of the radial annuli in pixels (default 1.5).
    blank : float, optional
        Pixels whose value is ``>= blank`` are ignored, and so are NaN pixels
        (default 100, presumably the value written for empty pixels).

    Returns
    -------
    rs : ndarray
        Mid-radius of each non-empty annulus, in pixels.
    sb : ndarray
        Mean surface brightness of each non-empty annulus, in magnitudes.
        Annuli without any valid pixel are omitted from both arrays.
    """
    data = np.asarray(data, dtype=float)
    ny, nx = data.shape

    # distance of every pixel from the image centre; the centre row uses ny
    # so that non-square images are handled correctly
    y, x = np.indices((ny, nx))
    r = np.hypot(x - nx // 2, y - ny // 2)

    # the comparison is False for NaN, so NaN pixels are dropped as well
    valid = data < blank

    rs = []
    sb = []
    edges = np.arange(0, nx, step)
    for rmin, rmax in zip(edges[:-1], edges[1:]):
        # half-open annulus [rmin, rmax): every pixel, including the centre
        # (r = 0) and pixels lying exactly on an edge, falls in exactly one bin
        sel = valid & (r >= rmin) & (r < rmax)
        if not sel.any():
            continue
        # average in flux space, not in magnitudes
        flux = np.mean(10 ** (-data[sel] / 2.5))
        rs.append(0.5 * (rmin + rmax))
        sb.append(-2.5 * np.log10(flux))

    return np.array(rs, dtype=float), np.array(sb, dtype=float)


if __name__ == '__main__':

    plt.rcParams.update({
        "axes.labelsize": 14,
        "axes.titlesize": 18,
        "font.size": 12,
        "legend.fontsize": 12,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
        "text.usetex": False,
        "lines.markersize": 6,
        "lines.linewidth": 3.0,
    })

    opt = parser.parse_args()

    # NOTE: --add_axes, --fig_size, --ax_unit and --ax_max are accepted by the
    # parser but not used here; radii are always plotted in pixels.

    plotted = False
    for fname in opt.files or []:

        # copy the pixels so they stay valid once the file is closed
        with fits.open(fname) as hdul:
            data = np.array(hdul[0].data, dtype=float)

        rs, sb = radial_profile(data)

        if sb.size == 0:
            print(fname, "no valid pixels, skipped")
            continue

        # surface brightness of the innermost non-empty annulus
        print(fname, sb[0])

        if not opt.no_plot:
            plt.plot(rs, sb)
            plotted = True

    if not opt.no_plot:
        ax = plt.gca()
        if plotted:
            ax.set_xscale("log")
            # magnitudes: brighter (smaller) values at the top. invert_yaxis()
            # toggles, so it must be called exactly once, not once per file.
            ax.invert_yaxis()
        ax.set_xlabel("Radius [pixel]")
        ax.set_ylabel("Surface brightness [mag/arcsec2]")

        if opt.outputfilename:
            if os.path.splitext(opt.outputfilename)[-1] != ".fits":
                plt.savefig(opt.outputfilename)
            else:
                print("cannot save a plot as a .fits file; nothing written to",
                      opt.outputfilename)
        else:
            plt.show()
