#!/usr/bin/python3

import numpy as np
from pNbody import *
from pNbody.mass_models import plummer
from pNbody import ic
from astropy import constants as cte
from astropy import units as u
import argparse
from scipy import signal


from tqdm import tqdm

from pNbody import pychem
from pNbody import Isochrones


####################################################################
# option parser
####################################################################

description = """generate a dwarf galaxy

By default, a Plummer density profile is generated with N particles, each having a mass M0/N.
If --individual_stars (historical spelling: --inidividual_stars) is specified, particles represent
individual stars with a given age and mass drawn from a distribution.
To get the stellar masses, the Kroupa 2001 IMF is used by default. It can be changed by providing
a parameter file (--imf, executed as Python code) containing the IMF parametrisation:

params = {}
params["Mmax"] = 50.
params["Mmin"] = 0.05
params["as"] = [0.7,-0.8,-1.7,-1.3]
params["ms"] = [0.08,0.5,1.0]

where ``as`` are the slopes of the IMF in the mass intervals delimited by ``ms``, given in solar masses.

"""


epilog = """
Examples:
--------

# particles representing SSP

mockimgs_generate_dwarf  --M0 1e4 -e 0.1 -N 1000000   --minFe -4 --maxFe -2 --minAge 12500  --maxAge 13500 -o 1e4.hdf5 


# particles representing individual stars

mockimgs_generate_dwarf  --individual_stars --minFe -4 --maxFe -2 --minAge 12500  --maxAge 13500 --M0 1e4 -e 0.1  -f CMD_Euclid.dat --logL 'logL' -o 1e4.hdf5 --filters_keys VISmag Ymag Jmag --filters_names MagVIS MagY MagJ
mockimgs_generate_dwarf  --individual_stars --minFe -4 --maxFe -2 --minAge 12500  --maxAge 13500 --M0 1e4 -e 0.1  -d P04O1D1E1Y247 -o 1e4.hdf5  --filters_keys r g --filters_names Magr Magg
mockimgs_generate_dwarf  --individual_stars --minFe -4 --maxFe -2 --minAge 12500  --maxAge 13500 --M0 1e4 -e 0.1  -d P04O1D1E1Y247 -o 1e4.hdf5  --imf imf.py --filters_keys r g --filters_names Magr Magg



"""

parser = argparse.ArgumentParser(description=description,
                                 epilog=epilog,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)


# --- output -------------------------------------------------------

parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")

parser.add_argument("--ftype",
                    action="store",
                    type=str,
                    dest="ftype",
                    default="swift",
                    help="pNbody file type of the output")


# --- isochrones database ------------------------------------------

# '%(default)s' is expanded by argparse itself; pre-formatting the path with
# '%' would make argparse re-interpret any '%' contained in the path.
parser.add_argument("-d",
                    action="store",
                    dest="directory",
                    metavar='DIRECTORY',
                    type=str,
                    default=Isochrones.database_directory,
                    help='isochrones database directory (default=%(default)s)')

parser.add_argument("-f",
                    action="store",
                    dest="file",
                    metavar='FILENAME',
                    type=str,
                    default=None,
                    help='name of the isochrones database file')

parser.add_argument("--info",
                    action="store_true",
                    dest="info",
                    default=False,
                    help='print the list of keys of the isochrones database and exit '
                         '(only used with --individual_stars)')


# --- sampling -----------------------------------------------------

parser.add_argument("--seed",
                    action="store",
                    dest="seed",
                    metavar='INT',
                    type=int,
                    default=1,
                    help='random seed')

parser.add_argument("--imf",
                    action="store",
                    dest="imf",
                    metavar='STRING',
                    type=str,
                    default=None,
                    help='a Python file defining the IMF parameters dict ``params`` '
                         '(default: Kroupa 2001)')

parser.add_argument("--minAge",
                    action="store",
                    dest="minAge",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='lower bound of the (uniform) stellar age distribution [Myr]')

parser.add_argument("--maxAge",
                    action="store",
                    dest="maxAge",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='upper bound of the (uniform) stellar age distribution [Myr]')

parser.add_argument("--minFe",
                    action="store",
                    dest="minFe",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='lower bound of the (uniform) [Fe/H] distribution')

parser.add_argument("--maxFe",
                    action="store",
                    dest="maxFe",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='upper bound of the (uniform) [Fe/H] distribution')

parser.add_argument("--logL",
                    action="store",
                    dest="logL",
                    metavar='STRING',
                    default=None,
                    type=str,
                    help='database field holding log10(L); when given, the total '
                         'luminosity is reported')

parser.add_argument("--filters_keys",
                    action="store",
                    dest="filters_keys",
                    metavar='STRING',
                    type=str,
                    default=None,
                    nargs='*',
                    help='the filters keys as stored in the database')

parser.add_argument("--filters_names",
                    action="store",
                    dest="filters_names",
                    metavar='STRING',
                    type=str,
                    default=None,
                    nargs='*',
                    help='the filters names as stored in the output n-body file '
                         '(one per --filters_keys entry)')


# --- model --------------------------------------------------------

parser.add_argument("--M0",
                    action="store",
                    dest="M0",
                    metavar='FLOAT',
                    type=float,
                    default=1e5,
                    help='stellar mass to generate [Msol]')

# NOTE(review): opt.MassFactor is not read by Do(); kept for CLI compatibility.
parser.add_argument("--MassFactor",
                    action="store",
                    dest="MassFactor",
                    metavar='FLOAT',
                    type=float,
                    default=5,
                    help='Mass factor (to get the correct final luminosity)')

parser.add_argument("-e",
                    action="store",
                    dest="e",
                    metavar='FLOAT',
                    type=float,
                    default=0.5,
                    help='scale length [kpc]')

parser.add_argument("--rmax_e_ratio",
                    action="store",
                    dest="rmax_e_ratio",
                    metavar='FLOAT',
                    type=float,
                    default=10,
                    help='rmax to scale length ratio')

parser.add_argument("--do_not_compute_rsp",
                    action="store_true",
                    dest="do_not_compute_rsp",
                    default=False,
                    help="do not compute rsp (see --nngb)")

parser.add_argument("--nngb",
                    action="store",
                    type=int,
                    dest="nngb",
                    default=5,
                    help="Number of neighbouring particles to consider to compute RSP(==HSML)")

# The historical, misspelled flag is kept as an alias; dest stays
# 'inidividual_stars' so code reading opt.inidividual_stars is unaffected.
parser.add_argument("--individual_stars", "--inidividual_stars",
                    action="store_true",
                    dest="inidividual_stars",
                    default=False,
                    help="sample individual stars")

parser.add_argument("-N",
                    action="store",
                    dest="N",
                    metavar='INT',
                    type=int,
                    default=10000,
                    help='number of particles (ignored with --individual_stars)')



####################################################################
# main
####################################################################


def _default_imf_parameters():
  """Return a fresh dict with the default IMF parametrisation (Kroupa 2001).

  ``as`` are the IMF slopes in the mass intervals delimited by ``ms`` [Msol];
  ``Mmin``/``Mmax`` bound the sampled stellar masses [Msol].
  """
  return {
    "Mmax": 50.,
    "Mmin": 0.05,
    "as": [0.7, -0.8, -1.7, -1.3],
    "ms": [0.08, 0.5, 1.0],
  }


def _load_imf_parameters(filename):
  """Execute the IMF parameter file ``filename`` and return its ``params`` dict.

  The file runs in a private copy of this module's globals. It can therefore
  use names already imported here (e.g. ``np``), but it cannot overwrite
  module-level names.

  SECURITY: the file is executed as arbitrary Python code; only use trusted files.

  Raises
  ------
  ValueError
      If the file does not define ``params``.
  """
  namespace = dict(globals())
  namespace.pop("params", None)
  with open(filename) as f:
    source = f.read()
  exec(compile(source, filename, "exec"), namespace)
  params = namespace.get("params")
  if params is None:
    raise ValueError("the IMF file %s must define a dict named 'params'" % filename)
  return params


def _filter_pairs(opt):
  """Return the list of (database key, output attribute name) filter pairs.

  ``--filters_names`` defaults to ``--filters_keys`` when omitted.

  Raises
  ------
  ValueError
      If both lists are given with different lengths.
  """
  keys = list(opt.filters_keys or [])
  names = keys if opt.filters_names is None else list(opt.filters_names)
  if len(names) != len(keys):
    raise ValueError("--filters_names must have as many entries as --filters_keys (%d != %d)"
                     % (len(names), len(keys)))
  return list(zip(keys, names))


def _require_options(opt, *names):
  """Raise ValueError naming every option in ``names`` that is unset (None) in ``opt``."""
  missing = ["--%s" % name for name in names if getattr(opt, name) is None]
  if missing:
    raise ValueError("individual star sampling requires %s" % ", ".join(missing))


def _uniform_or_none(rng, low, high, size):
  """Return ``size`` uniform draws between ``low`` and ``high``, or None if a bound is None."""
  if low is None or high is None:
    return None
  return rng.uniform(low, high, size)


def _sample_individual_stars(opt, rng):
  """Draw individual stars from the IMF and look up their isochrone properties.

  With ``opt.info`` set, print the database keys and raise SystemExit.

  Returns
  -------
  mass : ndarray
      stellar masses in units of 1e10 Msol (one entry per surviving star)
  ages : ndarray
      ages [Myr]
  fes : ndarray
      metallicities drawn in [minFe, maxFe]
  mags : dict
      output attribute name -> per-star values of the matching database key

  Raises
  ------
  ValueError
      On missing age/metallicity bounds, inconsistent filter options, or when
      no star survives.
  """
  # IMF parameters: user-supplied file or the default (Kroupa 2001)
  if opt.imf:
    pychem.set_parameters(_load_imf_parameters(opt.imf))
  else:
    pychem.set_parameters(_default_imf_parameters())

  print("Reading isochrones database...")
  DB = Isochrones.Isochrones(opt.directory, default_keys=None, filename=opt.file)

  if opt.info:
    print("list of keys")
    print("------------")
    for key in DB.getKeys():
      print(key)
    raise SystemExit

  # validated only after --info so that listing keys needs no other option
  _require_options(opt, "minAge", "maxAge", "minFe", "maxFe")
  filters = _filter_pairs(opt)

  mmin = pychem.get_Mmin()
  mmax = pychem.get_Mmax()

  # number of stars between mmin and mmax per unit mass, scaled to M0
  n_per_mass = np.asarray(pychem.get_imf_N(np.array([mmin]), np.array([mmax]))).item()
  n_full = int(n_per_mass * opt.M0)

  masses_full = np.sort(pychem.imf_sampling(n_full, opt.seed))
  ages = rng.uniform(opt.minAge, opt.maxAge, n_full)  # [Myr]
  fes = rng.uniform(opt.minFe, opt.maxFe, n_full)

  # Remove the stars that already exploded, i.e. those more massive than the
  # maximal stellar mass returned by the database for their age/metallicity.
  mmax_alive = np.zeros(n_full)
  for i in range(n_full):
    mmax_alive[i] = DB.getMaxStellarMass(ages[i], fes[i])

  alive = masses_full < mmax_alive
  masses = masses_full[alive]
  ages = ages[alive]
  fes = fes[alive]
  n = masses.size

  print("%d stars" % n)
  print("%g [1e5 Msol]" % (masses.sum() / 1e5))
  if n == 0:
    raise ValueError("no star left after removing exploded stars")

  # per-star database lookups (one array per distinct filter key)
  luminosities = np.zeros(n)
  values_by_key = {key: np.zeros(n) for key, _ in filters}
  for i in tqdm(range(n)):
    if opt.logL is not None:
      luminosities[i] = 10**DB.get(ages[i], fes[i], masses[i], key=opt.logL, mode="masslininterp")
    for key, values in values_by_key.items():
      values[i] = DB.get(ages[i], fes[i], masses[i], key=key, mode="masslininterp")

  if opt.logL is not None:
    print("Ltot = %g [1e5 Lsol]" % (luminosities.sum() / 1e5))

  mags = {name: values_by_key[key] for key, name in filters}
  # output mass unit is 1e10 Msol
  return masses / 1e10, ages, fes, mags


def _sample_particles(opt, rng):
  """Generate ``opt.N`` equal-mass particles with total mass ``opt.M0`` [Msol].

  Ages [Myr] and metallicities are drawn uniformly only when both of their
  bounds are given; otherwise None is returned for them and they are not stored.

  Returns the same ``(mass, ages, fes, mags)`` tuple as
  ``_sample_individual_stars``, with ``mags`` empty.
  """
  n = opt.N
  if n <= 0:
    raise ValueError("-N must be a positive number of particles")
  mass = np.full(n, opt.M0 / n / 1e10)  # output units are 1e10 Msol
  ages = _uniform_or_none(rng, opt.minAge, opt.maxAge, n)
  fes = _uniform_or_none(rng, opt.minFe, opt.maxFe, n)
  return mass, ages, fes, {}


def Do(opt):
  """Generate a dwarf galaxy (truncated Plummer sphere of star particles).

  Parameters
  ----------
  opt : argparse.Namespace
      Options produced by this script's ``parser``.

  Returns
  -------
  The pNbody object produced by ``ic.plummer``, with masses, ages (Gyr),
  metallicities (``mh``), optional filter attributes and units set. It is
  also written to ``opt.outputfilename`` when that is given.
  """
  # A dedicated generator makes the drawn ages/metallicities reproducible
  # with --seed without touching numpy's global random state.
  rng = np.random.default_rng(opt.seed)

  if opt.inidividual_stars:
    mass, ages, fes, mags = _sample_individual_stars(opt, rng)
  else:
    mass, ages, fes, mags = _sample_particles(opt, rng)
  n = len(mass)

  #############################################
  # generate a plummer model
  #############################################

  # maximal radius of the truncated profile
  rmax = opt.e * opt.rmax_e_ratio

  # NOTE(review): irand=1 is hard-coded (presumably the position seed); --seed is not forwarded.
  nb = ic.plummer(n, 1, 1, 1, opt.e, rmax, M=1, irand=1, vel='no',
                  name=opt.outputfilename, ftype=opt.ftype)

  # compute Hsml
  if not opt.do_not_compute_rsp:
    print("Compute Rsp...")
    nb.ComputeRsp(opt.nngb)
    print("done.")

  # set particles to be stars
  md = nb.getParticleMatchingDict()
  nb.set_tpe(md["stars"])

  # the model is built with M=1; the sampled masses replace it
  nb.mass = mass

  if ages is not None:
    nb.age = ages / 1e3  # Myr -> Gyr
  if fes is not None:
    nb.mh = fes

  for name, values in mags.items():
    setattr(nb, name, values)

  # code units: kpc, 1e10 Msol, km/s
  u_Length = 1 * u.kpc
  u_Mass = 10**10 * u.M_sun
  u_Velocity = 1 * u.km / u.s
  u_Time = u_Length / u_Velocity

  nb.UnitLength_in_cm = u_Length.to(u.cm).value
  nb.UnitMass_in_g = u_Mass.to(u.g).value
  nb.UnitVelocity_in_cm_per_s = u_Velocity.to(u.cm / u.s).value
  nb.Unit_time_in_cgs = u_Time.to(u.s).value

  # non cosmological
  nb.setComovingIntegrationOff()
  nb.cosmorun = 0
  nb.time = 0

  if opt.outputfilename is not None:
    nb.rename(opt.outputfilename)
    nb.write()

  return nb



if __name__ == '__main__':
  # script entry point: parse the command line, then build and write the model
  Do(parser.parse_args())
 
