#!/usr/bin/python3
###########################################################################################
#  package:   pNbody
#  file:      ic_makeGalaxy
#  brief:     
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Yves Revaz <yves.revaz@epfl.ch>
#
# This file is part of pNbody.
###########################################################################################

import argparse
from pNbody import *
from pNbody import ic
from astropy import units as u
from astropy import constants as c
from pNbody import thermodyn, ctes, units
from pNbody import iofunc
import warnings
from pNbody.mass_models import powerSphericalCutoff as psc

# Default parameter file (YAML). It is written to a scratch file when no -p
# option is given, printed by --print-parameters, and embedded in --help.
params_yml = """
OutputUnits:
  UnitLength_in_cm:         3.085e21          # 1 kpc in centimeters
  UnitMass_in_g:            1.98848e43        # 10^10 M_sun in grams
  UnitVelocity_in_cm_per_s: 1e5               # 1 km/s in centimeters per second

Galaxy:
  m_ref:              200000         # reference mass of particles [solar mass]
  nf:                 1              # particle count multiplicative factor (used to reduce noise)
  ftype:              "swift"        # output format (swift, gh5, gadget)
  outputfilename:     "galaxy.hdf5"  # output file name
  irand:              0              # random seed
  boxsize:            1500           # boxsize [kpc]


# Parameters for disk component
Disk:
  fm:                 1             # particle mass to ref. mass ratio (set to 0 to disable the component)
  mass:               6.8e10        # component mass [solar mass]
  Hr:                 3.0           # radial scale [kpc]
  Hz:                 0.280         # vertical scale [kpc]
  rmax_to_Hr_ratio:   10            # rmax to Hr ratio
  zmax_to_Hz_ratio:   10            # zmax to Hz ratio
  toomre:             2             # Toomre stability criterion


# Parameters for bulge component
Bulge:
  fm:                 1             # particle mass to ref. mass ratio (set to 0 to disable the component)
  alpha:              1.8           # power law slope
  r_c:                1.9           # cutoff radius [kpc]
  r_1:                1.0           # radius [kpc]
  amplitude:          2.22646e8     # density factor [Msol/kpc^3]
  fr:                 10            # rmax to r_c ratio
  dR:                 0.01          # grid resolution [kpc]

# Parameters for dark halo component
Halo:
  fm:                 5             # particle mass to ref. mass ratio (set to 0 to disable the component)
  r_s:                16            # scale radius [kpc]
  M_200:              0.6438e12     # virial mass [solar mass]
  r_200:              157.174455004 # virial radius [kpc]
  fr:                 1.0           # rmax to r200 ratio
  dR:                 0.25          # grid resolution [kpc]
"""



####################################################################
# option parser
####################################################################

# Help text shown by --help. The component descriptions match the generators
# actually called below (ic.miyamoto_nagai, ic.power_spherical, ic.nfw).
description = "Generates a multi-component galaxy model"
epilog = """
The model is made out of three components:

a stellar disk (Miyamoto-Nagai disk)
a bulge (power law with cutoff)
a dark halo (NFW profile)

All components can be fine tuned using parameters provided by a parameter file (see below).

Examples:
--------
ic_makeGalaxy
ic_makeGalaxy -p params.yml
ic_makeGalaxy -p params.yml -o galaxy.hdf5
ic_makeGalaxy --print-parameters


where the file params.yml contains the following parameters:

'''
%s
'''

"""%params_yml


parser = argparse.ArgumentParser(description=description,epilog=epilog,formatter_class=argparse.RawDescriptionHelpFormatter)


parser.add_argument("-p", "--parameter-file",
                    action="store",
                    type=str,
                    dest="parameter_file",
                    required=False,
                    default=None,
                    help="Name of the parameter file")

# overrides Galaxy/outputfilename from the parameter file when given
parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")

parser.add_argument("--debug",
                    action="store_true",
                    help="debug mode: save intermediate files")

parser.add_argument("--print-parameters",
                    action="store_true",
                    help="print parameters and quit")

# NOTE: --HubbleParameter is currently only referenced by commented-out code
# in the dark matter halo section.
parser.add_argument("--HubbleParameter",
                    action="store",
                    type=float,
                    default=0.6766,
                    help="Hubble parameter")


opt = parser.parse_args()



  


#################################################################
# read parameters
#################################################################

# Without -p, dump the built-in defaults to a scratch YAML file so that the
# same reader (iofunc.ReadYaml) handles both cases.
if opt.parameter_file is not None:
    parameter_file = opt.parameter_file
else:
    parameter_file = "tmp_parameter_file.yml"
    with open(parameter_file, 'w') as f:
        f.write(params_yml)
    print("# using default parameters in %s" % parameter_file)

params = iofunc.ReadYaml(parameter_file)

# --print-parameters: echo the parameters in use, then stop
if opt.print_parameters:
    if opt.parameter_file is not None:
        with open(parameter_file, 'r') as f:
            print(f.read())
    else:
        print(params_yml)
    exit()


#################################################################
# Units
#################################################################

# unit system requested for the output file (OutputUnits section)
_output_units = params["OutputUnits"]
outputUnitLength_in_cm         = float(_output_units["UnitLength_in_cm"])
outputUnitMass_in_g            = float(_output_units["UnitMass_in_g"])
outputUnitVelocity_in_cm_per_s = float(_output_units["UnitVelocity_in_cm_per_s"])

# pNbody internal unit system (G=1)
pNbUnitLength_in_cm         = 3.085e21
pNbUnitMass_in_g            = 4.435693e44
pNbUnitVelocity_in_cm_per_s = 97824708.2699

uparams = dict(
    UnitLength_in_cm=pNbUnitLength_in_cm,
    UnitVelocity_in_cm_per_s=pNbUnitVelocity_in_cm_per_s,
    UnitMass_in_g=pNbUnitMass_in_g,
)
pNb_system_of_units = units.Set_SystemUnits_From_Params(uparams)

# multiplicative factors converting kpc, km/s and Msol into pNbody units
_pNb_length   = pNbUnitLength_in_cm * u.cm
_pNb_velocity = pNbUnitVelocity_in_cm_per_s * u.cm / u.s
_pNb_mass     = pNbUnitMass_in_g * u.g

kpctopNbUnits  = (1 / _pNb_length.to(u.kpc)).value
kmstopNbUnits  = (1 / _pNb_velocity.to(u.km / u.s)).value
MsoltopNbUnits = 1 / _pNb_mass.to(u.M_sun).value


#################################################################
# model parameters
#################################################################

_galaxy = params["Galaxy"]
m_ref          = float(_galaxy["m_ref"])   # reference particle mass [Msol]
nf             = _galaxy["nf"]             # particle count multiplicative factor (applied to reduce noise)
outputfilename = _galaxy["outputfilename"]
ftype          = _galaxy["ftype"]
irand          = int(_galaxy["irand"])

# box size: read in kpc, stored in pNbody units
boxsize        = float(_galaxy["boxsize"]) * kpctopNbUnits

# the -o command-line option takes precedence over the parameter file
if opt.outputfilename is not None:
    outputfilename = opt.outputfilename


#####################
# Stellar disk
#####################

_disk = params["Disk"]
fm_disk          = float(_disk["fm"])
M_disk           = float(_disk["mass"])      # [Msol]
Hr_disk          = float(_disk["Hr"])        # [kpc]
Hz_disk          = float(_disk["Hz"])        # [kpc]
rmax_to_Hr_ratio = float(_disk["rmax_to_Hr_ratio"])
zmax_to_Hz_ratio = float(_disk["zmax_to_Hz_ratio"])
# NOTE(review): toomre_disk is parsed but not used anywhere below in this script
toomre_disk      = float(_disk["toomre"])

# Msol / kpc -> pNbody units
M_disk  *= MsoltopNbUnits
Hr_disk *= kpctopNbUnits
Hz_disk *= kpctopNbUnits

# disk truncation, expressed in radial scale lengths / vertical scale heights
rmax_disk = rmax_to_Hr_ratio * Hr_disk
zmax_disk = zmax_to_Hz_ratio * Hz_disk

# a disabled disk (fm == 0) carries no mass
if fm_disk == 0:
    M_disk = 0
 
#####################
# Bulge
#####################

_bulge = params["Bulge"]
fm_bulge        = float(_bulge["fm"])
alpha_bulge     = float(_bulge["alpha"])
r_c_bulge       = float(_bulge["r_c"])         # [kpc]
amplitude_bulge = float(_bulge["amplitude"])   # [Msol/kpc^3]
r_1_bulge       = float(_bulge["r_1"])         # [kpc]
fr_bulge        = float(_bulge["fr"])
dR_bulge        = float(_bulge["dR"])          # [kpc]

# kpc -> pNbody lengths, Msol/kpc^3 -> pNbody densities
r_c_bulge       = r_c_bulge * kpctopNbUnits
r_1_bulge       = r_1_bulge * kpctopNbUnits
amplitude_bulge = amplitude_bulge * MsoltopNbUnits / kpctopNbUnits**3
dR_bulge        = dR_bulge * kpctopNbUnits

# outer radius of the bulge: fr_bulge times the cutoff radius
Rmax_bulge = fr_bulge * r_c_bulge

# bulge mass from psc.CumulativeMass up to Rmax_bulge (already in pNbody units)
M_bulge = psc.CumulativeMass(alpha_bulge, r_c_bulge, amplitude_bulge, r_1_bulge, Rmax_bulge)

# a disabled bulge (fm == 0) carries no mass
M_bulge = 0 if fm_bulge == 0 else M_bulge


#####################
# Dark matter halo
#####################

from astropy.cosmology import Planck18 as cosmo
from pNbody.mass_models import gen_2_slopes


fm_halo            = float(params["Halo"]["fm"])
M_200              = float(params["Halo"]["M_200"])
r_200              = float(params["Halo"]["r_200"])
r_s                = float(params["Halo"]["r_s"])
fr_halo            = float(params["Halo"]["fr"])
dR_halo            = float(params["Halo"]["dR"])

# units conversion (to pNbody units)
r_200   = r_200   * kpctopNbUnits
r_s     = r_s     * kpctopNbUnits
M_200   = M_200   * MsoltopNbUnits
dR_halo = dR_halo * kpctopNbUnits


# Concentration and NFW normalisation. The concentration is stored in c_200
# (not `c`) so that the module-level name `c` keeps referring to
# astropy.constants, imported at the top of the script.
# mu(x) = ln(1+x) - x/(1+x) is the dimensionless enclosed mass of an NFW profile.
c_200 = r_200 / r_s
log_c200_term = np.log(1 + c_200) - c_200 / (1. + c_200)
rho_0 = M_200 / (4 * np.pi * r_s**3 * log_c200_term)   # not used further in this script
#rho_c = M_200/(200*4/3*np.pi*r_200**3)
#H_0 = np.sqrt(8*np.pi*G/3 * rho_c).to(1/unit_time)

# needs H0 in pNbody units (kept for reference: generic two-slope halo model)
#H0 = ctes.HUBBLE.into(pNb_system_of_units) * opt.HubbleParameter
#G = ctes.GRAVITY.into(pNb_system_of_units)
#gen2slopes = gen_2_slopes.GEN2SLOPES(alpha=alpha_halo,beta=beta_halo,c=c_halo,M200=M_halo,G=G,H0=H0)
#gen2slopes.info()

# The halo is sampled out to fr_halo * r_200 ("rmax to r200 ratio"); fr_halo was
# previously read but ignored. Its mass is the NFW mass enclosed within that
# radius, M(<Rmax) = M_200 * mu(Rmax/r_s) / mu(c_200), so the default
# fr = 1 gives exactly M_halo = M_200.
Rs_halo    = r_s
Rmax_halo  = fr_halo * r_200
x_max_halo = Rmax_halo / r_s
M_halo = M_200 * ((np.log(1 + x_max_halo) - x_max_halo / (1. + x_max_halo)) / log_c200_term)


# a disabled halo (fm == 0) carries no mass
if fm_halo == 0:
    M_200 = 0
    M_halo = 0




# Gravitational softening length: 0.05 at the reference mass of 2e5 Msol,
# scaling as m_ref^(-1/3). m_ref must still be in solar masses here.
# NOTE(review): eps is then used directly as a pNbody length, which assumes
# the 0.05 is expressed in (approximately) kpc — confirm.
_mass_ratio = m_ref / 200000
eps = 0.05 / _mass_ratio ** (1 / 3.0)

# from now on the reference mass is in code units
m_ref *= MsoltopNbUnits


#################################################################
# parameters for the velocities
#################################################################

# options forwarded to the pNbody velocity solvers below
# (presumably the tree opening tolerance and the softening mode)
ErrTolTheta, AdaptativeSoftenning = 0.5, False


###################################
# spherical components
###################################


# grid parameters halo
stats_name_halo = "stats_halo.dmp"
grmin_halo, nr_halo = 0, 64      # grid minimal radius, number of radial bins
grmax_halo = 1.05 * Rmax_halo    # grid maximal radius (5% beyond the halo edge)
eps_halo = eps
# scale radius used by the grid bin mapping (g_halo / gm_halo)
rc_halo = Rs_halo

def g_halo(r, rc=None):
    """Map a radius onto the halo grid coordinate u = ln(1 + r/rc).

    Parameters
    ----------
    r : float or array
        Radius in pNbody units.
    rc : float, optional
        Mapping scale; defaults to the module-level ``rc_halo``.

    log1p keeps full precision for r << rc (the grid starts at r = 0),
    where log(1 + r/rc) would lose significant digits.
    """
    if rc is None:
        rc = rc_halo
    return np.log1p(r / rc)


def gm_halo(r, rc=None):
    """Inverse of ``g_halo``: map a grid coordinate u back to rc * (exp(u) - 1).

    ``rc`` defaults to the module-level ``rc_halo``; expm1 keeps full
    precision for small u.
    """
    if rc is None:
        rc = rc_halo
    return rc * np.expm1(r)


# grid parameters bulge
stats_name_bulge = "stats_bulge.dmp"
eps_bulge = eps
grmin_bulge, nr_bulge = 0, 64      # grid minimal radius, number of radial bins
grmax_bulge = 1.05 * Rmax_bulge    # grid maximal radius (5% beyond the bulge edge)

# grid bins functions

def g_bulge(r, rc=None):
    """Map a radius onto the bulge grid coordinate u = ln(1 + r/rc).

    Parameters
    ----------
    r : float or array
        Radius in pNbody units.
    rc : float, optional
        Mapping scale; defaults to the module-level ``r_c_bulge``.

    log1p keeps full precision for r << rc (the grid starts at r = 0).
    """
    if rc is None:
        rc = r_c_bulge
    return np.log1p(r / rc)


def gm_bulge(r, rc=None):
    """Inverse of ``g_bulge``: map a grid coordinate u back to rc * (exp(u) - 1).

    ``rc`` defaults to the module-level ``r_c_bulge``; expm1 keeps full
    precision for small u.
    """
    if rc is None:
        rc = r_c_bulge
    return rc * np.expm1(r)


###################################
# cylindrical components
###################################

# grid parameters disk
stats_name_disk = "stats_disk.dmp"
grmin_disk, grmax_disk = 0.0, rmax_disk         # radial extent of the grid
gzmin_disk, gzmax_disk = -zmax_disk, zmax_disk  # vertical extent, symmetric about z = 0
nr_disk, nt_disk = 32, 2                        # number of bins in r and in t
# parity of nz: even -> the potential is computed at z=0,
#               odd  -> the density   is computed at z=0
nz_disk = 64 + 1
eps_disk = eps
# scale used by the grid bin mapping (g_disk / gm_disk)
rc_disk = Hr_disk


def g_disk(r, rc=None):
    """Map a cylindrical radius onto the disk grid coordinate u = ln(1 + r/rc).

    Parameters
    ----------
    r : float or array
        Cylindrical radius in pNbody units.
    rc : float, optional
        Mapping scale; defaults to the module-level ``rc_disk``.

    log1p keeps full precision for r << rc (the grid starts at R = 0).
    """
    if rc is None:
        rc = rc_disk
    return np.log1p(r / rc)


def gm_disk(r, rc=None):
    """Inverse of ``g_disk``: map a grid coordinate u back to rc * (exp(u) - 1).

    ``rc`` defaults to the module-level ``rc_disk``; expm1 keeps full
    precision for small u.
    """
    if rc is None:
        rc = rc_disk
    return rc * np.expm1(r)


# Velocity-dispersion prescriptions for the disk, one per direction (z, r, p),
# passed as `params` to Get_Velocities_From_Cylindrical_Grid. The "name"
# strings are keys interpreted by pNbody and must be kept verbatim.
mode_sigma_z = dict(name="jeans", param=None)
mode_sigma_r = dict(name="isothropic", param=2)
mode_sigma_p = dict(name="epicyclic_approximation", param=None)
params_disk = [mode_sigma_z, mode_sigma_r, mode_sigma_p]



#################################################################
# compute the number of particles for each component
#################################################################


# reference particle mass (code units); a component uses particles of mass m * fm
m = m_ref


def _particle_count(mass, fm):
    """Return int(mass / (m * fm)), or 0 for a disabled component (fm == 0)."""
    return 0 if fm == 0 else int(mass / (m * fm))


N_disk = _particle_count(M_disk, fm_disk)
N_bulge = _particle_count(M_bulge, fm_bulge)
N_halo = _particle_count(M_halo, fm_halo)

print("N_disk  = %d" % N_disk)
print("N_bulge = %d" % N_bulge)
print("N_halo  = %d" % N_halo)
print("----------------------------")
print("N_tot   = %d" % (N_disk + N_bulge + N_halo))
print("----------------------------")

# per-particle mass of each non-empty component, reported in solar masses
for _label, _mass, _count in (("m_disk  ", M_disk, N_disk),
                              ("m_bulge ", M_bulge, N_bulge),
                              ("m_halo  ", M_halo, N_halo)):
    if _count > 0:
        print("%s= %g Msol" % (_label, (_mass / _count) / MsoltopNbUnits))
print()

# over-sample by nf (the models are reduced back after the velocities are computed)
if nf > 1:
    N_disk, N_bulge, N_halo = (int(nf * _n) for _n in (N_disk, N_bulge, N_halo))




#################################################################
# generate models
#################################################################


#####################
# disk
#####################

# Guard on the particle count rather than on the mass: a non-zero mass smaller
# than one particle mass gives N_disk == 0, and the per-particle mass below
# (M_disk / N_disk) would then divide by zero.
nb_disk = None
if N_disk > 0:
    print("generating disk...")
    nb_disk = ic.miyamoto_nagai(N_disk, Hr_disk, Hz_disk, rmax_disk, zmax_disk, irand, ftype="gh5")
    nb_disk.set_tpe("disk")
    nb_disk.mass = (M_disk / N_disk) * np.ones(nb_disk.nbody).astype(np.float32)
    nb_disk.rename("disk.dat")

    if opt.debug:
        nb_disk.write()
elif M_disk != 0.0:
    warnings.warn("disk mass is smaller than one particle mass: disk component skipped", UserWarning)

#####################
# halo
#####################

# Guard on the particle count (see the disk): N_halo == 0 with a non-zero mass
# would otherwise divide by zero when assigning particle masses.
nb_halo = None
if N_halo > 0:
    print("generating halo...")
    nb_halo = ic.nfw(N_halo, Rs_halo, Rmax_halo, dR_halo, Rs=None, irand=irand, name='nfw.hdf5', ftype='gh5', verbose=False)
    nb_halo.set_tpe("halo")
    nb_halo.mass = (M_halo / N_halo) * np.ones(nb_halo.nbody).astype(np.float32)
    nb_halo.rename("halo.dat")

    if opt.debug:
        nb_halo.write()
elif M_halo != 0.0:
    warnings.warn("halo mass is smaller than one particle mass: halo component skipped", UserWarning)


#####################
# bulge
#####################

# Guard on the particle count (see the disk): N_bulge == 0 with a non-zero mass
# would otherwise divide by zero when assigning particle masses.
# NOTE(review): unlike the disk and halo, irand is not passed here, so the
# bulge may not follow the user-supplied random seed — confirm the signature
# of ic.power_spherical.
nb_bulge = None
if N_bulge > 0:
    print("generating bulge...")
    nb_bulge = ic.power_spherical(N_bulge, alpha_bulge, r_c_bulge, amplitude_bulge, r_1_bulge, Rmax_bulge, dR_bulge, Rs=None, ftype="gh5")
    nb_bulge.set_tpe("bulge")
    nb_bulge.mass = (M_bulge / N_bulge) * np.ones(nb_bulge.nbody).astype(np.float32)
    nb_bulge.rename("bulge.dat")

    if opt.debug:
        nb_bulge.write()
elif M_bulge != 0.0:
    warnings.warn("bulge mass is smaller than one particle mass: bulge component skipped", UserWarning)




###############################################################
# merge all components
###############################################################

# Concatenate the generated components (disk, halo, bulge order) into a single
# snapshot; the velocity routines below are called on this combined snapshot.
nb = None
for _component in (nb_disk, nb_halo, nb_bulge):
    if _component is not None:
        nb = _component if nb is None else nb + _component

# Every component disabled or empty: stop with a clear message instead of
# failing later on attribute access of None.
if nb is None:
    raise SystemExit("ERROR: no particles were generated (all components are disabled or empty)")


# save particles without velocities
if opt.debug:
    nb.write("snapnf.hdf5")


###############################################################
# compute velocities
###############################################################


if nb_disk is not None:
    print("------------------------")
    print("disk velocities...")
    print("------------------------")

    # velocities of the "disk" particles from a cylindrical grid built on the
    # combined snapshot `nb`
    nb_disk, phi, stats_disk = nb.Get_Velocities_From_Cylindrical_Grid(
        select="disk",
        disk=("disk", "gas"),
        eps=eps_disk,
        nR=nr_disk,
        nz=nz_disk,
        nt=nt_disk,
        Rmax=grmax_disk,
        zmin=gzmin_disk,
        zmax=gzmax_disk,
        params=params_disk,
        Phi=None,
        g=g_disk,
        gm=gm_disk,
        ErrTolTheta=ErrTolTheta,
        AdaptativeSoftenning=AdaptativeSoftenning,
    )
    if opt.debug:
        iofunc.write_dmp(stats_name_disk, stats_disk)

    # spacing of the first radial bin and of the bins around the mid-plane
    r = stats_disk["R"]
    z = stats_disk["z"]
    dr = r[1] - r[0]
    dz = z[nz_disk // 2 + 1] - z[nz_disk // 2]

    # true division, like the bulge/halo diagnostics (floor division truncated
    # the ratio, e.g. reporting 0 eps for any spacing smaller than eps)
    print("disk : Delta R :", dr, "=", dr / eps_disk, "eps")
    print("disk : Delta z :", dz, "=", dz / eps_disk, "eps")

    # reduce if needed
    if nf > 1:
        nb_disk = nb_disk.reduc(nf, mass=True)



if nb_bulge is not None:
    print("------------------------")
    print("bulge velocities...")
    print("------------------------")

    # spherical-grid velocity solver options for the bulge
    _bulge_grid = dict(
        select="bulge",
        eps=eps_bulge,
        nr=nr_bulge,
        rmax=grmax_bulge,
        phi=None,
        g=g_bulge,
        gm=gm_bulge,
        UseTree=True,
        ErrTolTheta=ErrTolTheta,
    )
    nb_bulge, phi, stats_bulge = nb.Get_Velocities_From_Spherical_Grid(**_bulge_grid)
    if opt.debug:
        iofunc.write_dmp(stats_name_bulge, stats_bulge)

    # width of the first radial bin, also expressed in softening lengths
    r = stats_bulge["r"]
    dr = r[1] - r[0]
    print("bulge : Delta r :", dr, "=", dr / eps_bulge, "eps")

    # undo the nf over-sampling (reduc with mass=True)
    if nf > 1:
        nb_bulge = nb_bulge.reduc(nf, mass=True)


if nb_halo is not None:
    print("------------------------")
    print("halo velocities...")
    print("------------------------")

    # spherical-grid velocity solver options for the halo
    _halo_grid = dict(
        select="halo",
        eps=eps_halo,
        nr=nr_halo,
        rmax=grmax_halo,
        phi=None,
        g=g_halo,
        gm=gm_halo,
        UseTree=True,
        ErrTolTheta=ErrTolTheta,
    )
    nb_halo, phi, stats_halo = nb.Get_Velocities_From_Spherical_Grid(**_halo_grid)
    if opt.debug:
        iofunc.write_dmp(stats_name_halo, stats_halo)

    # width of the first radial bin, also expressed in softening lengths
    r = stats_halo["r"]
    dr = r[1] - r[0]
    print("halo : Delta r :", dr, "=", dr / eps_halo, "eps")

    # undo the nf over-sampling (reduc with mass=True)
    if nf > 1:
        nb_halo = nb_halo.reduc(nf, mass=True)


###############################################################
# sum the different components and save the final model
###############################################################

# Re-assemble the components, now carrying velocities, in the order
# disk, halo, bulge, skipping those that were not generated.
nb = None
for _part in (nb_disk, nb_halo, nb_bulge):
    if _part is None:
        continue
    nb = _part if nb is None else nb + _part



# reorganize components: relabel each pNbody component with the particle type
# expected by swift (halo -> halo_1, disk and bulge -> stars_1)
_swift_layout = (("halo", "halo_1"), ("disk", "stars_1"), ("bulge", "stars_1"))
_swift_parts = []
for _src_tpe, _dst_tpe in _swift_layout:
    _selection = nb.select(_src_tpe)
    _selection.set_tpe(_dst_tpe)
    _swift_parts.append(_selection)
nb1, nb2, nb3 = _swift_parts

nb = nb1 + nb2 + nb3

# convert to the requested output format (ftype, e.g. swift)
nb = nb.set_ftype(ftype=ftype)


# Attach the output unit system to the snapshot. The length unit must be set on
# `nb` (it was previously assigned to the numpy module by mistake).
nb.UnitLength_in_cm         = outputUnitLength_in_cm
nb.UnitMass_in_g            = outputUnitMass_in_g
nb.UnitVelocity_in_cm_per_s = outputUnitVelocity_in_cm_per_s
nb.Unit_time_in_cgs         = outputUnitLength_in_cm / outputUnitVelocity_in_cm_per_s

# box size and per-particle initial fields, still in pNbody units at this point
nb.boxsize = boxsize
nb.rsp_init = np.ones(nb.nbody) * eps
nb.birth_time_init = -1 * np.ones(nb.nbody)



# final unit conversion : pNbody -> user
def _pnb_to_output(values, pnb_unit_in_cgs, output_unit_in_cgs):
    """Rescale values from pNbody units to output units (multiply, then divide)."""
    return values * pnb_unit_in_cgs / output_unit_in_cgs


nb.pos      = _pnb_to_output(nb.pos, pNbUnitLength_in_cm, outputUnitLength_in_cm)
nb.vel      = _pnb_to_output(nb.vel, pNbUnitVelocity_in_cm_per_s, outputUnitVelocity_in_cm_per_s)
nb.mass     = _pnb_to_output(nb.mass, pNbUnitMass_in_g, outputUnitMass_in_g)
nb.rsp_init = _pnb_to_output(nb.rsp_init, pNbUnitLength_in_cm, outputUnitLength_in_cm)
nb.boxsize  = _pnb_to_output(nb.boxsize, pNbUnitLength_in_cm, outputUnitLength_in_cm)


# check boxsize: the model is expected to fit in a box of side `boxsize`
# centred on the origin, so the largest |coordinate| along ANY axis (not only
# x) must not exceed boxsize / 2
xmax = np.fabs(nb.pos).max()
if nb.boxsize < 2 * xmax:
    txt = "WARNING: boxsize (=%g) < 2*xmax (=%g) !" % (nb.boxsize, 2 * xmax)
    warnings.warn(txt, UserWarning)


# save model
nb.rename(outputfilename)
nb.write()

# Add the StellarParticleType attribute to the dataset
import h5py as h5
import numpy as np

nb_star = nb.select("stars")

if nb_star.nbody > 0:

    N_star = np.sum(nb_star.npart)
    star_tpe = 2  # Single population stars
    # One categorical flag per star particle, stored with an integer dtype
    # rather than float64.
    # NOTE(review): confirm the integer width expected by the reader (SWIFT).
    star_type = np.full(N_star, star_tpe, dtype=np.int32)

    # Only HDF5 outputs can be annotated; for a non-HDF5 output (presumably the
    # "gadget" ftype) opening with h5py would raise.
    if h5.is_hdf5(outputfilename):
        with h5.File(outputfilename, "r+") as f:
            f["PartType4"].create_dataset("StellarParticleType", data=star_type)
    else:
        warnings.warn("%s is not an HDF5 file: StellarParticleType not added" % outputfilename, UserWarning)
