#!/usr/bin/python3

import sys
import numpy as np
import argparse

from pNbody import *

import matplotlib.pyplot as plt

from astropy.io import ascii


# Text shown by ``--help``. The script plots against the age bins of the grid
# (one panel per metallicity bin), so the description says "age".
description = (
    "\n"
    "From a given magnitude grid, plot the magnitude as a function of stellar age, for each metallicity bin.\n"
    "Magnitudes are normalized to 1e11 Msol.\n"
)
# Written line by line so that the exact layout of the examples (including
# trailing blanks) is explicit and survives editors that strip whitespace.
epilog = (
    "\n"
    "Examples:\n"
    "--------\n"
    "isochrones_plot_magnitudes_grid HSTF475Xmag_CMD_1e6.pkl \n"
    "isochrones_plot_magnitudes_grid HSTF475Xmag_CMD_1e6.hdf5 \n"
    "\n"
)

# RawDescriptionHelpFormatter keeps the line breaks of description/epilog.
parser = argparse.ArgumentParser(description=description,epilog=epilog,formatter_class=argparse.RawDescriptionHelpFormatter)


# positional argument: the grid to plot; the reader is chosen from its extension
parser.add_argument("file",
                    action="store",
                    metavar='FILENAME',
                    type=str,
                    default=None,
                    help='magnitude grid file (.pkl or .hdf5)')

# flag: the grid values are exponentiated (10**value) instead of being
# handled as magnitudes
parser.add_argument("--luminosity",
                    action="store_true",
                    dest="luminosity",
                    default=False,
                    help='treat the grid values as log10 luminosities instead of magnitudes')


# optional output: when omitted the figure is shown interactively
parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")


# Matplotlib defaults for the figure, grouped by theme and merged into the
# single ``params`` dictionary that is handed to rcParams.

# text sizes
_font_settings = {
    "font.size": 12,
    "axes.labelsize": 14,
    "axes.titlesize": 18,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "text.usetex": True,
}

# figure size and subplot margins; zero wspace/hspace stacks the panels
# without any gap between them
_layout_settings = {
    "figure.figsize" : (10,16),
    "figure.subplot.left": 0.08,
    "figure.subplot.right": 0.98,
    "figure.subplot.bottom": 0.05,
    "figure.subplot.top": 0.99,
    "figure.subplot.wspace": 0.0,
    "figure.subplot.hspace": 0.0,
}

# curves and markers
_line_settings = {
    "lines.linewidth": 3.0,
    "lines.markersize": 6,
}

params = {**_font_settings, **_layout_settings, **_line_settings}
plt.rcParams.update(params)




def ReadMistPickleFile(filename):
  """
  Read a pickle file and return the data table with its bins.

  The file must contain three consecutive pickled objects, which are
  loaded in this order: the data table, the age bins, the metallicity bins.
  The data table is expected to have a dimension of
  len(binsAge) x len(binsFe).

  filename : path of the pickle file

  Returns the tuple (data, binsAge, binsFe).

  Raises OSError if the file cannot be opened, and whatever pickle.load
  raises (e.g. EOFError) if fewer than three objects can be read.

  NOTE: unpickling can execute arbitrary code; only use this on trusted files.
  """
  import pickle

  # the context manager closes the file even when one of the loads fails
  with open(filename,"rb") as f:
    data    = pickle.load(f)
    binsAge = pickle.load(f)
    binsFe  = pickle.load(f)

  return data,binsAge,binsFe
  
  
  

####################################################################
# main
####################################################################


opt = parser.parse_args()


# Load the grid and its age / metallicity bins; the reader is selected from
# the file extension.
if opt.file.endswith(".pkl"):
  # open the "raw data"
  data,binsAge,binsFe = ReadMistPickleFile(opt.file)

elif opt.file.endswith(".hdf5"):
  from pNbody.iofunc import SSPGrid

  SSP = SSPGrid(opt.file)
  data = SSP.read()

  binsAge   = data["Data"]["Ages"]
  binsFe    = data["Data"]["MH"]
  data      = data["Data"]["Magnitudes"]

else:
  # without this branch, data/binsAge/binsFe would be undefined and the
  # script would stop further down with a NameError
  parser.error("unsupported file format: '%s' (expected a .pkl or .hdf5 file)" % opt.file)



# normalize to 1e11 Msol
if opt.luminosity:
  # NOTE(review): this branch only exponentiates the grid; no 1e11 factor is
  # applied here, unlike in the magnitude branch -- confirm this is intended.
  data     = 10**(data)
else:
  # convert magnitude -> flux, scale the flux by 1e11, convert back
  F     = 10**(-data/2.5)*1e11
  data  = -2.5*np.log10(F)



# number of panels (one per metallicity bin)
Npanels = len(binsFe)

# create Npanels subplots; squeeze=False guarantees a 2D array of axes, so a
# grid with a single metallicity bin (Npanels == 1) is handled like any other
fig, axes = plt.subplots(Npanels, squeeze=False)
axes = axes[:,0]

# x values shared by every panel (the axis is labelled in Gyr, so the bins
# are presumably given in Myr -- TODO confirm)
ages = binsAge/1000


# loop over the plots
for i,(ax,Z1) in enumerate(zip(axes,binsFe)):

  if opt.luminosity:
    ax.set_ylabel("Luminosity [Lsol]")
  else:
    ax.set_ylabel("Magnitude")
    ax.set_xlim(1e-3,20)
    ax.set_ylim(-31,-19)

  ax.semilogx()

  # column i of the grid corresponds to metallicity bin i
  ax.plot(ages,data[:,i],color="blue",label="[M/H]=%g"%Z1)
  ax.legend(loc="upper left")

  # the panels are stacked without gaps: only the bottom one shows its x axis
  is_last = (i==Npanels-1)
  ax.get_xaxis().set_visible(is_last)
  if is_last:
    ax.set_xlabel("Ages [Gyr]")



# write the figure to a file if requested, otherwise open an interactive window
if opt.outputfilename is not None:
  plt.savefig(opt.outputfilename)
else:
  plt.show()

