#!/usr/bin/python3

import time
import sys
import numpy as np
import argparse

from pNbody import *
from pNbody import Isochrones

import matplotlib.pyplot as plt
import matplotlib.colors as plt_colors
import matplotlib.cm as plt_cm

description = "Display isochrones or colour magnitude diagrams"

# Usage examples appended to the --help output.
# Each example must be a valid shell command line: a missing space between a
# quoted argument and the next option makes the shell glue them into a single
# argument.
epilog = """
Examples:
--------
isochrones_plot -d P04O1D1E1Y247 --info
isochrones_plot -d P04O1D1E1Y247 --Age 13000 --x logTe  --y 'log(L/Lo)' --invertx 
isochrones_plot -d P04O1D1E1Y247 --Age 13000 --x G_BP-G_RP  --y G --inverty 
isochrones_plot -d P04O1D1E1Y247 --Age 13000 --MH -2
isochrones_plot -d P04O1D1E1Y247 --Age 13000 --MH -2 -1 
isochrones_plot -d P04O1D1E1Y247 --Age 13000 --addMasses --MH -2 -1 
isochrones_plot -d P04O1D1E1Y247 --Age 13000 --x G_BP-G_RP --y G --inverty --xmin -1 --xmax 3 --ymin -5 --ymax 13 --MH -2

isochrones_plot -d MIST_v1.2_vvcrit0.0_UBVRIplus --info
isochrones_plot -d MIST_v1.2_vvcrit0.0_UBVRIplus --Age 13000 --x log_Teff  --y 'log_L' --invertx
isochrones_plot -d MIST_v1.2_vvcrit0.0_UBVRIplus --Age 13000 --x Gaia_BP_EDR3-Gaia_RP_EDR3 --y Gaia_G_EDR3 --inverty 

isochrones_plot -f CMD_Euclid.dat --info
isochrones_plot -f CMD_Euclid.dat --Age 13000 --x logTe  --y logL --invertx 
isochrones_plot -f CMD_Euclid.dat --Age 13000 --x Bluemag-Redmag  --y VISmag --inverty 

"""

# Command-line interface. RawDescriptionHelpFormatter keeps the epilog
# examples laid out line by line in the --help output.
parser = argparse.ArgumentParser(description=description,
                                 epilog=epilog,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)


# ---- input / output -------------------------------------------------------

parser.add_argument("-o",
                    action="store",
                    metavar='FILENAME',
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")

parser.add_argument("-d",
                    action="store",
                    dest="directory",
                    metavar='DIRECTORY',
                    type=str,
                    default=Isochrones.database_directory,
                    help='a directory (default=%s)'%Isochrones.database_directory)

parser.add_argument("-f",
                    action="store",
                    dest="file",
                    metavar='FILENAME',
                    type=str,
                    default=None,
                    help='the name of isochrones database file')


# ---- isochrone selection --------------------------------------------------

parser.add_argument("--Age",
                    action="store",
                    dest="Age",
                    metavar='FLOAT',
                    type=float,
                    help='Isochrone Age')

# One or more metallicities, e.g. "--MH -2 -1".
# nargs='+' (rather than argparse.REMAINDER) lets other options follow --MH
# instead of being swallowed and converted with float(). Negative values are
# still accepted, because none of this parser's option strings
# (-h, -o, -d, -f, --...) looks like a negative number.
parser.add_argument("--MH",
                    action="store",
                    dest="MH",
                    metavar='FLOAT',
                    default=None,
                    type=float,
                    nargs='+',
                    help='[MH/H]')

parser.add_argument("--addMasses",
                    action="store_true",
                    dest="addMasses",
                    default=False,
                    help='Add masses')


# ---- plotted quantities ---------------------------------------------------

parser.add_argument("--x",
                    action="store",
                    dest="x",
                    metavar='STRING',
                    default='g_minus_i',
                    type=str,
                    help='x value')

parser.add_argument("--y",
                    action="store",
                    dest="y",
                    metavar='STRING',
                    default='g',
                    type=str,
                    help='y value')


# ---- axis limits and orientation ------------------------------------------

parser.add_argument("--xmin",
                    metavar='FLOAT',
                    type=float,
                    help='xmin')

parser.add_argument("--xmax",
                    metavar='FLOAT',
                    type=float,
                    help='xmax')

parser.add_argument("--ymin",
                    metavar='FLOAT',
                    type=float,
                    help='ymin')

parser.add_argument("--ymax",
                    metavar='FLOAT',
                    type=float,
                    help='ymax')

parser.add_argument("--invertx",
                    action="store_true",
                    dest="invertx",
                    default=False,
                    help='force invert x axis')

parser.add_argument("--inverty",
                    action="store_true",
                    dest="inverty",
                    default=False,
                    help='force invert y axis')


# ---- misc -----------------------------------------------------------------

parser.add_argument("--info",
                    action="store_true",
                    dest="info",
                    default=False,
                    help='get info (list of keys) and exit.')





def get_label(val):
  '''
  Return the axis label text for the plotted quantity ``val``.

  No translation is performed: the label is ``val`` itself, formatted with
  the ``%s`` operator.
  '''
  label = "%s" % val
  return label



  

def get_values(val, M, DB):
  '''
  Return the array of values to plot for the quantity ``val``.

  ``val`` may be
    - "KEY1-KEY2" : the difference of two columns, e.g. a colour index
      such as "G_BP-G_RP";
    - "Bluemag_minus_Redmag", "Bmag_minus_Rmag" or "VISmag_minus_Jmag" :
      predefined differences of the corresponding magnitude columns;
    - any other string : the name of a single column, used as-is.

  M  : the isochrone block; ``M.data[:, M.keys[key]]`` is the column
       named ``key``.
  DB : the isochrone database; every key is validated with
       ``DB.checkKeyExists`` before the column is read.

  Raises ValueError if ``val`` contains more than one "-".
  '''
  def column(key):
    # validate the key against the database before indexing the block
    DB.checkKeyExists(key)
    return M.data[:, M.keys[key]]

  # difference of two arbitrary columns, e.g. "A-B"
  if "-" in val:
    operands = val.split("-")
    if len(operands) != 2:
      # previously any operand after the second was silently ignored
      raise ValueError("cannot interpret %r: expected KEY or KEY1-KEY2" % val)
    return column(operands[0]) - column(operands[1])

  # predefined colour indices spelled with "_minus_"
  named_differences = {
    "Bluemag_minus_Redmag": ("Bluemag", "Redmag"),
    "Bmag_minus_Rmag":      ("Bmag", "Rmag"),
    "VISmag_minus_Jmag":    ("VISmag", "Jmag"),
  }
  if val in named_differences:
    key1, key2 = named_differences[val]
    return column(key1) - column(key2)

  # any other value is used directly as a column key (no conversion)
  return column(val)

  





#######################################
# MakePlot
#######################################

def MakePlot(files, opt):
  '''
  Draw the requested isochrones on the current matplotlib figure.

  files : unused; kept so existing callers such as ``MakePlot(None, opt)``
          keep working.
  opt   : the parsed command-line options. The fields read are
          directory, file, info, Age, MH, addMasses, x, y,
          xmin, xmax, ymin, ymax, invertx and inverty.

  If ``opt.info`` is set, the database keys are printed and the program
  exits. If ``opt.Age`` is set, one isochrone per [M/H] value is drawn.
  Saving or showing the figure is left to the caller.
  '''
  params = {
    "axes.labelsize": 14,
    "axes.titlesize": 18,
    "font.size": 12,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "text.usetex": False,
    "figure.subplot.left": 0.12,
    "figure.subplot.right": 0.98,
    "figure.subplot.bottom": 0.10,
    "figure.subplot.top": 0.90,
    "figure.subplot.wspace": 0.0,
    "figure.subplot.hspace": 0.0,
    "figure.figsize" : (10,8),
    "lines.markersize": 6,
    "lines.linewidth": 3.0,
  }
  # must run before the first figure is created so that figsize applies
  plt.rcParams.update(params)

  # init the isochrone database
  DB = Isochrones.Isochrones(opt.directory, default_keys=None, filename=opt.file)

  if opt.info:
    _print_keys_and_exit(DB)

  if opt.Age is not None:
    _plot_isochrones(DB, opt)

  _format_axes(opt)


def _print_keys_and_exit(DB):
  '''Print every key known to the database, then terminate the program.'''
  print("list of keys")
  print("------------")
  for key in DB.getKeys():
    print(key)
  sys.exit()


def _plot_isochrones(DB, opt):
  '''
  Plot the isochrones of age ``opt.Age`` (one per [M/H] value) and a colorbar.

  The [M/H] values are ``opt.MH`` when given, otherwise every bin of the
  database. Without ``opt.addMasses`` each curve is a line coloured by its
  [M/H] bin. With it, each point is coloured by its initial mass, using a
  single mass range shared by all isochrones.
  '''
  # plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in matplotlib 3.9;
  # pyplot.get_cmap is available in both old and new versions
  cmap = plt.get_cmap('jet')
  ax = plt.gca()

  binsMH = opt.MH if opt.MH is not None else DB.binsMH

  # one colormap entry (0..255) per database [M/H] bin
  colors = np.linspace(0, 255, len(DB.binsMH)).astype(int)

  i = DB.getAgeIndex(opt.Age)

  mass_min = np.inf
  mass_max = -np.inf
  scatters = []

  for MH in binsMH:
    j = DB.getMHIndex(MH)

    M = DB.M[i, j]
    x = get_values(opt.x, M, DB)
    y = get_values(opt.y, M, DB)

    if opt.addMasses:
      masses = M.data[:, M.keys["initial_mass"]]
      mass_min = min(masses.min(), mass_min)
      mass_max = max(masses.max(), mass_max)
      scatters.append(ax.scatter(x, y, c=masses, cmap=cmap))
    else:
      ax.plot(x, y, color=cmap(colors[j]))

  if opt.addMasses:
    if scatters:
      # each scatter() call autoscales its colours to its own masses; put all
      # of them on the common range so they match each other and the colorbar
      for sc in scatters:
        sc.set_clim(mass_min, mass_max)
      norm = plt_colors.Normalize(vmin=mass_min, vmax=mass_max)
      # ax= is required for a ScalarMappable that is not attached to any axes
      plt.colorbar(plt_cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax,
                   orientation='vertical', label='Mass')
  else:
    norm = plt_colors.Normalize(vmin=DB.binsMH[0], vmax=DB.binsMH[-1])
    plt.colorbar(plt_cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax,
                 orientation='vertical', label='log10[MH/H]')

  ax.set_title("Age = %g Myr" % opt.Age)


def _format_axes(opt):
  '''Set labels, limits and orientation of the current axes from ``opt``.'''
  ax = plt.gca()

  ax.set_xlabel(get_label(opt.x))
  ax.set_ylabel(get_label(opt.y))

  # None leaves the corresponding limit on autoscale
  ax.set_xlim(opt.xmin, opt.xmax)
  ax.set_ylim(opt.ymin, opt.ymax)

  if opt.invertx:
    ax.invert_xaxis()

  if opt.inverty:
    ax.invert_yaxis()
  

  



########################################################################
# MAIN
########################################################################


if __name__ == '__main__':
    opt = parser.parse_args()
    # resolution (dots per inch) used when the figure is written to a file
    opt.dpi = 600

    MakePlot(None, opt)

    if opt.outputfilename is not None:
        plt.savefig(opt.outputfilename, dpi=opt.dpi)
    else:
        plt.show()

