#!/usr/bin/env python3

import numpy as np
from pNbody import *
from pNbody import iofunc
from pNbody.mass_models import plummer
from pNbody import ic
from astropy import constants as cte
from astropy import units as u
import argparse
import pickle
import sys

from pNbody import Isochrones
from pNbody import pychem

from tqdm import tqdm

####################################################################
# option parser
####################################################################

# Help texts shown by the argument parser (RawDescriptionHelpFormatter keeps
# the line breaks below verbatim).
description="""Generate a grid of magnitudes and/or luminosities of a stellar population.
The grid is a 2D matrix discretized in ages and [M/H].
It is stored either as an HDF5 file (output name ending with .hdf5) or as a pickle file (.pkl).
The script takes as input either an isochrone database directory (e.g. MIST) or an ascii file
from the CMD tool (http://stev.oapd.inaf.it/cmd/).
By default, the IMF is the Kroupa 2001 one. It can be changed by providing (with --imf) a Python
parameter file that defines a ``params`` dictionary containing the IMF parametrisation:

params = {}
params["Mmax"] = 50.
params["Mmin"] = 0.05
params["as"] = [0.7,-0.8,-1.7,-1.3]
params["ms"] = [0.08,0.5,1.0]


where ``as`` are the slopes of the IMF in the mass intervals delimited by ``ms``, given in solar masses.
"""

epilog     ="""
Examples:
--------

For an ascii CMD file e.g. CMD_Euclid.dat:

isochrones_generate_magnitudes_grid  --rebinAges --M0 1e6 --filter_key VISmag  -f CMD_Euclid.dat -o EuclidVISmag_CMD_1e6.hdf5
isochrones_generate_magnitudes_grid  --rebinAges --M0 1e6 --filter_key VISmag  -f CMD_Euclid.dat -o EuclidVISmag_CMD_1e6.hdf5   --imf imf.py
isochrones_generate_magnitudes_grid  --rebinAges --M0 1e6 --filter_key G       -d P04O1D1E1Y247  -o GAIA_G_BastI_1e6.hdf5
isochrones_generate_magnitudes_grid  --rebinAges --M0 1e6 --filter_key G  --luminosity_key 'log(L/Lo)'  -d P04O1D1E1Y247  -o GAIA_G_BastI_1e6.hdf5
"""

parser = argparse.ArgumentParser(description=description,epilog=epilog,formatter_class=argparse.RawDescriptionHelpFormatter)

# NOTE: argparse %-formats every help string (e.g. %(default)s), so a literal
# percent sign in a help text must be written as %%.

parser.add_argument("-d",
                    action="store",
                    dest="directory",
                    metavar='DIRECTORY',
                    type=str,
                    default=Isochrones.database_directory,
                    help='the isochrone database directory (default=%(default)s)')

parser.add_argument("-f",
                    action="store",
                    dest="file",
                    metavar='FILENAME',
                    type=str,
                    default=None,
                    help='the name of the isochrone database file (e.g. an ascii file from the CMD tool)')


parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="name of the output file; its extension (.hdf5 or .pkl) selects the format. If omitted, nothing is written")


parser.add_argument("--filter_key",
                    action="store",
                    dest="filter_key",
                    metavar='STRING',
                    type=str,
                    default=None,
                    help='the filter key as stored in the database (magnitudes are computed only if given)')

parser.add_argument("--luminosity_key",
                    action="store",
                    dest="luminosity_key",
                    metavar='STRING',
                    type=str,
                    default=None,
                    help='the luminosity key as stored in the database, e.g. log(L/Lo) (luminosities are computed only if given)')

parser.add_argument("--interpolation_mode",
                    action="store",
                    dest="interpolation_mode",
                    metavar='STRING',
                    type=str,
                    default="masslininterp",
                    help='the interpolation mode used to query the database, e.g. nearest or lininterp (default=%(default)s)')

parser.add_argument("--M0",
                    action="store",
                    dest="M0",
                    metavar='FLOAT',
                    type=float,
                    default=1e5,
                    help='the mass of the stellar population assumed to compute the magnitudes (default=%(default)g). Values lower than 1e6 may lead to stochastic fluctuations')

parser.add_argument("--rebinAges",
                    action="store_true",
                    dest="rebinAges",
                    default=False,
                    help='use a hard-coded binning [10**(np.linspace(0.6,4.1,71))] for the ages instead of the binning provided by the database')

parser.add_argument("--seed",
                    action="store",
                    dest="seed",
                    metavar='INT',
                    type=int,
                    default=1,
                    help='random seed used to sample the IMF (default=%(default)s)')

parser.add_argument("--imf",
                    action="store",
                    dest="imf",
                    metavar='STRING',
                    type=str,
                    default=None,
                    help='a Python file defining the IMF parameters in a dictionary named params (default: Kroupa 2001)')

parser.add_argument("--info",
                    action="store_true",
                    dest="info",
                    default=False,
                    help='print the list of keys available in the database and exit')
                    
                    
####################################################################
# main
####################################################################

# The full command line, passed to the HDF5 writer so the output records how
# it was produced.
cmdline = " ".join(sys.argv)


# Output file extensions understood by _write_output.
_OUTPUT_EXTENSIONS = (".hdf5", ".pkl")


def _default_imf_params():
  """Return a new dict holding the default IMF parametrisation (Kroupa 2001).

  A fresh dict is built on every call so no caller can mutate a shared default.
  """
  return {
    "Mmax": 50.,
    "Mmin": 0.05,
    "as": [0.7, -0.8, -1.7, -1.3],
    "ms": [0.08, 0.5, 1.0],
  }


def _load_imf_params(filename):
  """Execute the IMF parameter file ``filename`` and return the ``params`` dict it defines.

  The file runs in a *copy* of this module's global namespace: it can use the
  names imported by this script (e.g. ``np``) but cannot overwrite any of them.

  Raises:
    ValueError: if the file does not define ``params``.
  """
  namespace = dict(globals())
  namespace.pop("params", None)  # make sure a missing definition is detected
  with open(filename) as f:
    source = f.read()
  # SECURITY: the file is executed as arbitrary Python code; only use trusted files.
  exec(compile(source, filename, "exec"), namespace)
  try:
    return namespace["params"]
  except KeyError:
    raise ValueError("the IMF file '%s' must define a dictionary named 'params'" % filename) from None


def _write_output(filename, binsAge, binsMH, Mags, Lums, M0MfRatio):
  """Write the grids to ``filename``; the format is chosen from its extension (.hdf5 or .pkl).

  Raises:
    ValueError: if the extension is not one of ``_OUTPUT_EXTENSIONS``.
  """
  if filename.endswith(".hdf5"):
    sspGrid = iofunc.SSPGrid(filename)
    sspGrid.write(binsAge,binsMH,magnitudes=Mags,luminosities=Lums,minimfin=M0MfRatio,cmdline=cmdline)

  elif filename.endswith(".pkl"):
    if Lums is not None:
      print("warning: luminosities are not stored in the pickle format")
    # Objects are dumped sequentially; readers must pickle.load them back in
    # this exact order: Mags, binsAge, binsMH, M0MfRatio.
    with open(filename, 'wb') as f:
      for obj in (Mags, binsAge, binsMH, M0MfRatio):
        pickle.dump(obj, f)

  else:
    raise ValueError("unsupported output file extension for '%s' (expected one of %s)"
                     % (filename, ", ".join(_OUTPUT_EXTENSIONS)))


def Do(opt):
  """Compute the magnitude and/or luminosity grid of a stellar population and optionally write it.

  Steps: set the IMF, open the isochrone database, draw once from the IMF the
  number of stars expected in a population of mass ``opt.M0``, then, for every
  (age, [M/H]) bin, keep the drawn stars below the database maximal stellar
  mass and integrate their fluxes and/or luminosities per unit of ``opt.M0``.

  Args:
    opt: the ``argparse.Namespace`` returned by ``parser.parse_args()``.

  Raises:
    ValueError: if ``opt.outputfilename`` has an unsupported extension, or if
      the ``--imf`` file does not define ``params``.
  """

  ##################################
  # IMF parameters
  ##################################

  if opt.imf:
    pychem.set_parameters(_load_imf_params(opt.imf))
  else:
    pychem.set_parameters(_default_imf_params())

  ##################################
  # init the isochrone database
  ##################################
  DB = Isochrones.Isochrones(opt.directory,default_keys=None,filename=opt.file)

  if opt.info:
    print("list of keys")
    print("------------")
    for key in DB.getKeys():
      print(key)
    return

  # check the filter and luminosity keys
  for key in (opt.filter_key, opt.luminosity_key):
    if key is not None:
      DB.checkKeyExists(key)

  # Fail before the (expensive) main loop rather than silently writing nothing.
  if opt.outputfilename is not None and not opt.outputfilename.endswith(_OUTPUT_EXTENSIONS):
    raise ValueError("unsupported output file extension for '%s' (expected one of %s)"
                     % (opt.outputfilename, ", ".join(_OUTPUT_EXTENSIONS)))

  #############################################
  # generate a given amount of stellar mass
  #############################################

  mmax = pychem.get_Mmax()
  mmin = pychem.get_Mmin()

  # number of stars between mmin and mmax in a population of mass M0
  N = pychem.get_imf_N(np.array([mmin]),np.array([mmax]))*opt.M0
  # get_imf_N is called with 1-element arrays (presumably returning one too):
  # extract the single value explicitly, since int() on an ndim>0 array is
  # deprecated since NumPy 1.25.
  MassesFullIMF = pychem.imf_sampling(int(np.ravel(N)[0]),opt.seed)
  MassesFullIMF.sort()

  #######################################
  # Set the binning (Age in Myr)
  #######################################

  binsMH  = DB.binsMH
  binsAge = DB.binsAge

  # use specific Age binning
  if opt.rebinAges:
    binsAge = 10**(np.linspace(0.6,4.1,71))

  print("# of bins in Age : %d"%len(binsAge),binsAge)
  print("# of bins in MH  : %d"%len(binsMH),binsMH)

  #######################################
  # Main loop
  #######################################

  shape = (len(binsAge),len(binsMH))
  Mags = np.zeros(shape) if opt.filter_key is not None else None
  Lums = np.zeros(shape) if opt.luminosity_key is not None else None
  M0MfRatio = np.zeros(shape)

  for i, Age in enumerate(tqdm(binsAge)):
    for j, MH in enumerate(binsMH):

      # keep only the drawn stars below the maximal stellar mass for this Age and MH
      Mmaxs = DB.getMaxStellarMass(Age,MH)
      Masses = MassesFullIMF[MassesFullIMF < Mmaxs]

      if Mags is not None:
        Ms = DB.get(Age,MH,Masses,key=opt.filter_key,mode=opt.interpolation_mode)
        # sum the fluxes, normalise per unit of initial mass, back to magnitudes
        # (an empty selection gives a zero flux, hence an infinite magnitude)
        Ftot = (10**(-Ms/2.5)).sum()/opt.M0
        Mags[i,j] = -2.5*np.log10(Ftot)

      if Lums is not None:
        # the database stores log10 luminosities
        Ls = DB.get(Age,MH,Masses,key=opt.luminosity_key,mode=opt.interpolation_mode)
        Ltot = (10**Ls).sum()/opt.M0
        Lums[i,j] = np.log10(Ltot)

      # ratio of the initial population mass to the mass of the kept stars
      M0MfRatio[i,j] = opt.M0/Masses.sum()

  #######################################
  # output
  #######################################

  if opt.outputfilename is not None:
    _write_output(opt.outputfilename,binsAge,binsMH,Mags,Lums,M0MfRatio)
  

    

if __name__ == "__main__":
  # Script entry point: parse the command-line options, then build the grid.
  options = parser.parse_args()
  Do(options)
 
