#!/usr/bin/python3
###########################################################################################
#  package:   pNbody
#  file:      ic_makeGalaxy
#  brief:     generate a multi-component galaxy model (gas disk, stellar disk, bulge, dark halo)
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Yves Revaz <yves.revaz@epfl.ch>
#
# This file is part of pNbody.
###########################################################################################

import argparse
from pNbody import *
from pNbody import ic
from astropy import units as u
from astropy import constants as c
from pNbody import thermodyn, ctes, units
from pNbody import iofunc
import warnings


# Default parameters of the model. They are written to a temporary YAML file
# when no parameter file is given (-p), printed by --print-parameters and
# embedded in the --help epilog. Masses are in solar masses, lengths in kpc.
params_yml = """
OutputUnits:
  UnitLength_in_cm:         3.085e21          # 1 kpc in centimeters
  UnitMass_in_g:            1.98848e43        # 10^10 M_sun in grams
  UnitVelocity_in_cm_per_s: 1e5               # 1 km/s in centimeters per second

Galaxy:
  m_ref:              200000         # reference mass of particles [solar mass]
  nf:                 1              # particle count multiplicative factor (used to reduce noise)
  hydro:              1              # if set to 0, turn gas to collisionless particles
  ftype:              "swift"        # output format (swift, gh5, gadget)
  outputfilename:     "galaxy.hdf5"  # output file name
  irand:              0              # random seed
  boxsize:            1500           # boxsize [kpc]

# Parameters for gas component
Gas:
  fm:                 1             # particle mass to ref. mass ratio (set to 0 to disable the component)
  mass:               0.5e10        # component mass [solar mass]
  Hr:                 4.0           # radial scale [kpc]
  Hz:                 0.3           # vertical scale [kpc]
  rmax_to_Hr_ratio:   10            # rmax to Hr ratio
  zmax_to_Hz_ratio:   3             # zmax to Hz ratio
  sigmavel:           10            # velocity dispersion [km/s] (if hydro=0)
  T:                  100           # temperature [K]            (if hydro=1)

# Parameters for disk component
Disk:
  fm:                 1             # particle mass to ref. mass ratio (set to 0 to disable the component)
  mass:               2e10          # component mass [solar mass]
  Hr:                 2.0           # radial scale [kpc]
  Hz:                 0.3           # vertical scale [kpc]
  fr:                 10            # grid rmax to Hr ratio
  fz:                 10            # grid zmax to Hz ratio
  toomre:             2             # Toomre stability criterion


# Parameters for bulge component
Bulge:
  fm:                 1             # particle mass to ref. mass ratio (set to 0 to disable the component)
  mass:               0.4e10        # component mass [solar mass]
  Hr:                 1.0           # radial scale [kpc]
  fr:                 5.0           # grid rmax to Hr ratio

# Parameters for dark halo component
Halo:
  fm:                 5             # particle mass to ref. mass ratio (set to 0 to disable the component)
  mass:               115e10        # component mass [solar mass]
  c:                  10            # concentration parameter
  alpha:              1             # inner slope
  beta:               3             # outer slope
  Rcut:               None          # cutting radius [kpc] (None: no cut)
  power_cut:          3             # power of the cut if Rcut is defined
  fr:                 1.0           # rmax to r200 ratio
  dR:                 0.25          # grid resolution [kpc]
"""



####################################################################
# option parser
####################################################################

description = "Generates a multi-component galaxy model"
epilog = """
The model is made out of four components:

a gas disk (Miyamoto-Nagai)
a stellar disk (Exponential disk)
a bulge (Plummer model)
a dark halo (generic two slope model)

All components can be fine tuned using parameters provided by a parameter file (see below).

Examples:
--------
ic_makeGalaxy
ic_makeGalaxy -p params.yml
ic_makeGalaxy -p params.yml -o galaxy.hdf5
ic_makeGalaxy --print-parameters


where the file params.yml contains the following parameters:

'''
%s
'''

"""%params_yml


# RawDescriptionHelpFormatter keeps the epilog (and the embedded default
# YAML) formatted exactly as written above.
parser = argparse.ArgumentParser(
    description=description,
    epilog=epilog,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

# YAML parameter file; when omitted the defaults of params_yml are used
parser.add_argument("-p", "--parameter-file",
                    action="store",
                    type=str,
                    dest="parameter_file",
                    required=False,
                    default=None,
                    help="Name of the parameter file")

# overrides Galaxy/outputfilename of the parameter file when given
parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")

parser.add_argument("--debug",
                    action="store_true",
                    help="debug mode: save intermediate files")

parser.add_argument("--print-parameters",
                    action="store_true",
                    help="print parameters and quit")

# dimensionless Hubble parameter; multiplies the pNbody H0 used for the halo
parser.add_argument("--HubbleParameter",
                    action="store",
                    type=float,
                    default=0.6766,
                    help="Hubble parameter")

opt = parser.parse_args()



  


#################################################################
# read parameters
#################################################################

# --print-parameters only echoes the parameters and quits. It is handled
# first so that no temporary file is created and the printed text is pure
# YAML that can be redirected into a new parameter file.
if opt.print_parameters:
    if opt.parameter_file is None:
        print(params_yml)
    else:
        with open(opt.parameter_file, 'r') as f:
            print(f.read())
    # builtin exit() is meant for the interactive shell only
    raise SystemExit(0)

# if the parameter file is not provided, dump the default values into a
# temporary file so that they go through the same YAML reader
if opt.parameter_file is None:
    parameter_file = "tmp_parameter_file.yml"
    with open(parameter_file, 'w') as f:
        f.write(params_yml)

    print("# using default parameters in %s"%parameter_file)
else:
    parameter_file = opt.parameter_file

params = iofunc.ReadYaml(parameter_file)


#################################################################
# Units
#################################################################

# Units of the written snapshot, as requested in the OutputUnits section.
_output_units = params["OutputUnits"]
outputUnitLength_in_cm           = float(_output_units["UnitLength_in_cm"])
outputUnitMass_in_g              = float(_output_units["UnitMass_in_g"])
outputUnitVelocity_in_cm_per_s   = float(_output_units["UnitVelocity_in_cm_per_s"])

# Internal units in which the model is built (chosen such that G = 1).
pNbUnitLength_in_cm         = 3.085e21
pNbUnitMass_in_g            = 4.435693e44
pNbUnitVelocity_in_cm_per_s = 97824708.2699

uparams = {
    'UnitLength_in_cm': pNbUnitLength_in_cm,
    'UnitVelocity_in_cm_per_s': pNbUnitVelocity_in_cm_per_s,
    'UnitMass_in_g': pNbUnitMass_in_g,
}
pNb_system_of_units = units.Set_SystemUnits_From_Params(uparams)

# Multiplicative factors: (kpc | km/s | Msol) -> internal units.
_pNb_length_in_kpc = (pNbUnitLength_in_cm * u.cm).to(u.kpc)
_pNb_velocity_in_kms = (pNbUnitVelocity_in_cm_per_s * u.cm / u.s).to(u.km / u.s)
kpctopNbUnits = (1 / _pNb_length_in_kpc).value
kmstopNbUnits = (1 / _pNb_velocity_in_kms).value
MsoltopNbUnits = 1 / (pNbUnitMass_in_g * u.g).to(u.M_sun).value

#################################################################
# model parameters
#################################################################


# Global model settings from the Galaxy section.
_galaxy = params["Galaxy"]
m_ref          = float(_galaxy["m_ref"])   # reference particle mass, in solar mass
nf             = _galaxy["nf"]             # oversampling factor used while computing velocities
hydro          = _galaxy["hydro"]          # 1: gas kept as SPH particles, 0: gas made collisionless
outputfilename = _galaxy["outputfilename"]
ftype          = _galaxy["ftype"]
# NOTE(review): irand is read here but the generators below pass hard-coded
# seeds (irand=0 for the disk, irand=-2 for the gas) — confirm intent.
irand          = int(_galaxy["irand"])
# boxsize is given in kpc and converted to internal length units
boxsize        = float(_galaxy["boxsize"]) * kpctopNbUnits

# the -o command line option takes precedence over the parameter file
if opt.outputfilename is not None:
    outputfilename = opt.outputfilename


#####################
# Gas disk
#####################

# Gas disk parameters (Gas section), converted to internal units on reading.
_gas = params["Gas"]
fm_gas            = float(_gas["fm"])
M_gas             = float(_gas["mass"]) * MsoltopNbUnits
Hr_gas            = float(_gas["Hr"]) * kpctopNbUnits
Hz_gas            = float(_gas["Hz"]) * kpctopNbUnits
rmax_to_Hr_ratio  = float(_gas["rmax_to_Hr_ratio"])
zmax_to_Hz_ratio  = float(_gas["zmax_to_Hz_ratio"])
sigmavel_gas      = float(_gas["sigmavel"]) * kmstopNbUnits  # km/s -> internal units
T_gas             = float(_gas["T"])  # stays in Kelvin

# fm = 0 disables the component
if fm_gas == 0:
    M_gas = 0

# to be compatible with the Plummer potential (Revaz 2004)
Hr_gas -= Hz_gas
rmax_gas = rmax_to_Hr_ratio * Hr_gas
zmax_gas = zmax_to_Hz_ratio * Hz_gas

#####################
# Stellar disk
#####################

# Stellar disk parameters (Disk section), converted to internal units on reading.
_disk = params["Disk"]
fm_disk           = float(_disk["fm"])
M_disk            = float(_disk["mass"]) * MsoltopNbUnits
Hr_disk           = float(_disk["Hr"]) * kpctopNbUnits
Hz_disk           = float(_disk["Hz"]) * kpctopNbUnits
fr_disk           = float(_disk["fr"])   # grid/sampling rmax in units of Hr_disk
fz_disk           = float(_disk["fz"])   # grid/sampling zmax in units of Hz_disk
# NOTE(review): toomre_disk is not referenced anywhere else in this script.
toomre_disk       = float(_disk["toomre"])

# fm = 0 disables the component
if fm_disk == 0:
    M_disk = 0
 
#####################
# Bulge
#####################

# Bulge parameters (Bulge section), converted to internal units on reading.
_bulge = params["Bulge"]
fm_bulge           = float(_bulge["fm"])
M_bulge            = float(_bulge["mass"]) * MsoltopNbUnits
Hr_bulge           = float(_bulge["Hr"]) * kpctopNbUnits
fr_bulge           = float(_bulge["fr"])  # sampling/grid rmax in units of Hr_bulge

# fm = 0 disables the component
if fm_bulge == 0:
    M_bulge = 0


#####################
# Dark matter halo
#####################

from astropy.cosmology import Planck18 as cosmo
from pNbody.mass_models import gen_2_slopes


# Dark halo parameters (Halo section): mass in Msol, lengths in kpc.
fm_halo            = float(params["Halo"]["fm"])
M_halo             = float(params["Halo"]["mass"])
c_halo             = float(params["Halo"]["c"])
alpha_halo         = float(params["Halo"]["alpha"])
beta_halo          = float(params["Halo"]["beta"])
fr_halo            = float(params["Halo"]["fr"])
dR_halo            = float(params["Halo"]["dR"])

# Optional cutting radius [kpc]. YAML delivers it as a number, as the string
# "None" (default parameters) or as a real null (`null`, `~`, empty value,
# which arrive here as Python None). Every non-numeric form means "no cut".
Rcut_halo = params["Halo"]["Rcut"] if "Rcut" in params["Halo"] else None
if isinstance(Rcut_halo, str) and Rcut_halo.strip().lower() in ("none", "null", ""):
    Rcut_halo = None
if Rcut_halo is not None:
    # kpc -> internal units, like every other halo length
    Rcut_halo = float(Rcut_halo) * kpctopNbUnits

# power of the cut, only used when Rcut is defined
power_cut_halo = 3
if "power_cut" in params["Halo"]:
    power_cut_halo = float(params["Halo"]["power_cut"])

# units conversion (to pNbody units)
M_halo  = M_halo  * MsoltopNbUnits
dR_halo = dR_halo * kpctopNbUnits

# the halo model needs H0 (scaled by the Hubble parameter) and G in pNbody units
H0 = ctes.HUBBLE.into(pNb_system_of_units) * opt.HubbleParameter
G = ctes.GRAVITY.into(pNb_system_of_units)

# the two-slope model derives the scale radius and r200 from M200 and c
gen2slopes = gen_2_slopes.GEN2SLOPES(alpha=alpha_halo, beta=beta_halo, c=c_halo, M200=M_halo, G=G, H0=H0)
gen2slopes.info()

Rs_halo    = gen2slopes.rs
Rmax_halo  = gen2slopes.r200 * fr_halo

# the component is disabled only after the model has been built, so that
# Rs_halo/Rmax_halo are still derived from the requested mass
if fm_halo == 0:
    M_halo = 0


# Estimation of the gravitational softening length. The softening scales like
# the mean inter-particle separation, i.e. as m_ref**(1/3). It equals 0.05 at
# the reference resolution m_ref = 2e5 Msol and grows for heavier particles
# (the previous formula divided by this factor, inverting the scaling).
# NOTE(review): 0.05 is presumably in kpc; eps is used below directly as an
# internal length (whose unit is ~1 kpc) — confirm.
eps = 0.05 * (m_ref / 200000) ** (1 / 3.0)

m_ref = m_ref * MsoltopNbUnits  # to code units


#################################################################
# parameters for the velocities
#################################################################

# Settings forwarded to the velocity computations below: ErrTolTheta is
# presumably the tree opening-angle tolerance; adaptive softening is disabled.
ErrTolTheta, AdaptativeSoftenning = 0.5, False


###################################
# spherical components
###################################


# grid parameters halo (spherical grid used to compute the halo velocities)
stats_name_halo = "stats_halo.dmp"   # file receiving the grid statistics in debug mode
grmin_halo = 0                       # grid minimal radius
grmax_halo = Rmax_halo * 1.05        # grid maximal radius, 5% beyond the sampled radius
nr_halo = 64                         # number of radial bins
eps_halo = eps                       # softening used on this grid
rc_halo = Rs_halo                    # scale of the logarithmic radial bin mapping


def g_halo(r):
    """Map a radius to the grid coordinate log(1 + r / rc_halo)."""
    scaled = r / rc_halo
    return np.log(scaled + 1.0)


def gm_halo(r):
    """Inverse of g_halo: map a grid coordinate back to a radius."""
    grown = np.exp(r)
    return rc_halo * (grown - 1.0)


# grid parameters bulge (spherical grid used to compute the bulge velocities)
stats_name_bulge = "stats_bulge.dmp"          # file receiving the grid statistics in debug mode
grmin_bulge = 0                               # grid minimal radius
grmax_bulge = Hr_bulge * fr_bulge * 1.05      # grid maximal radius, 5% beyond the sampled radius
nr_bulge = 64                                 # number of radial bins
eps_bulge = eps                               # softening used on this grid
rc_bulge = Hr_bulge                           # scale of the logarithmic radial bin mapping


def g_bulge(r):
    """Map a radius to the grid coordinate log(1 + r / rc_bulge)."""
    scaled = r / rc_bulge
    return np.log(scaled + 1.0)


def gm_bulge(r):
    """Inverse of g_bulge: map a grid coordinate back to a radius."""
    grown = np.exp(r)
    return rc_bulge * (grown - 1.0)


###################################
# cylindrical components
###################################

# grid parameters disk (cylindrical grid used to compute the disk velocities)
stats_name_disk = "stats_disk.dmp"   # file receiving the grid statistics in debug mode
grmin_disk = 0.0                     # minimal grid radius
grmax_disk = Hr_disk * fr_disk       # maximal grid radius
gzmin_disk = -Hz_disk * fz_disk      # minimal grid z
gzmax_disk = Hz_disk * fz_disk       # maximal grid z
nr_disk = 32                         # number of bins in r
nt_disk = 2                          # number of bins in t
# Number of bins in z (64 + 1). For an even value of nz, the potential is
# computed at z=0; for an odd value, the density is computed at z=0.
nz_disk = 65
eps_disk = eps                       # softening used on this grid
rc_disk = 3.0                        # scale of the logarithmic radial bin mapping (internal units)


def g_disk(r):
    """Map a radius to the grid coordinate log(1 + r / rc_disk)."""
    scaled = r / rc_disk
    return np.log(scaled + 1.0)


def gm_disk(r):
    """Inverse of g_disk: map a grid coordinate back to a radius."""
    grown = np.exp(r)
    return rc_disk * (grown - 1.0)


# velocity dispersion recipes (z, R, phi) for the stellar disk
mode_sigma_z = dict(name="jeans", param=None)
mode_sigma_r = dict(name="isothropic", param=2)
mode_sigma_p = dict(name="epicyclic_approximation", param=None)
params_disk = [mode_sigma_z, mode_sigma_r, mode_sigma_p]


# grid parameters gas (cylindrical grid used to compute the gas velocities)
stats_name_gas = "stats_gas.dmp"     # file receiving the grid statistics in debug mode
grmin_gas = 0.0                      # minimal grid radius
grmax_gas = rmax_gas * 1.05          # maximal grid radius, 5% beyond the sampled radius
gzmin_gas = -zmax_gas * 1.05         # minimal grid z
gzmax_gas = zmax_gas * 1.05          # maximal grid z
nr_gas = 32                          # number of bins in r
nt_gas = 2                           # number of bins in t
# Number of bins in z (64 + 1). For an even value of nz, the potential is
# computed at z=0; for an odd value, the density is computed at z=0.
nz_gas = 65
eps_gas = eps                        # softening used on this grid
rc_gas = 3.0                         # scale of the logarithmic radial bin mapping (internal units)


def g_gas(r):
    """Map a radius to the grid coordinate log(1 + r / rc_gas)."""
    scaled = r / rc_gas
    return np.log(scaled + 1.0)


def gm_gas(r):
    """Inverse of g_gas: map a grid coordinate back to a radius."""
    grown = np.exp(r)
    return rc_gas * (grown - 1.0)


# velocity dispersion recipes (z, R, phi) for the gas: constant radial dispersion
mode_sigma_z = dict(name="jeans", param=None)
mode_sigma_r = dict(name="constant", param=sigmavel_gas)
mode_sigma_p = dict(name="epicyclic_approximation", param=None)
params_gas = [mode_sigma_z, mode_sigma_r, mode_sigma_p]


#################################################################
# compute the number of particles for each component
#################################################################


# here we give explicitly the mass of the gas particles: every component
# uses particles of mass m * fm_<component>
m = m_ref


def _particle_count(M_comp, fm_comp, label):
    """Return the number of particles used to sample one component.

    M_comp  : component mass (internal units)
    fm_comp : particle mass to reference mass ratio; 0 disables the component
    label   : component name, used in the warning message

    When a non-zero mass is smaller than a single particle mass, 0 is
    returned and a UserWarning is issued.
    """
    if fm_comp == 0:
        return 0
    n_comp = int(M_comp / (m * fm_comp))
    if n_comp == 0 and M_comp != 0:
        warnings.warn(
            "%s: component mass is smaller than one particle mass (%g Msol); "
            "the component is disabled" % (label, (m * fm_comp) / MsoltopNbUnits),
            UserWarning,
        )
    return n_comp


N_gas = _particle_count(M_gas, fm_gas, "gas")
N_disk = _particle_count(M_disk, fm_disk, "disk")
N_bulge = _particle_count(M_bulge, fm_bulge, "bulge")
N_halo = _particle_count(M_halo, fm_halo, "halo")

# A component without particles is dropped entirely: its mass is zeroed so
# that it is not generated, otherwise the per-particle mass M/N computed
# when generating it would be undefined.
if N_gas == 0:
    M_gas = 0
if N_disk == 0:
    M_disk = 0
if N_bulge == 0:
    M_bulge = 0
if N_halo == 0:
    M_halo = 0

print("N_gas   = %d" % N_gas)
print("N_disk  = %d" % N_disk)
print("N_bulge = %d" % N_bulge)
print("N_halo  = %d" % N_halo)
print("----------------------------")
print("N_tot   = %d" % (N_gas + N_disk + N_bulge + N_halo))
print("----------------------------")


# Report the per-particle mass (in Msol) of every populated component.
print()
for _fmt, _M_comp, _N_comp in (
    ("m_gas   = %g Msol", M_gas, N_gas),
    ("m_disk  = %g Msol", M_disk, N_disk),
    ("m_bulge = %g Msol", M_bulge, N_bulge),
    ("m_halo  = %g Msol", M_halo, N_halo),
):
    if _N_comp > 0:
        print(_fmt % ((_M_comp / _N_comp) / MsoltopNbUnits))
print()


# Oversample every component by nf. The extra particles are removed again
# (reduc) once the velocities have been computed.
if nf > 1:
    N_gas, N_disk, N_bulge, N_halo = (
        int(nf * _n) for _n in (N_gas, N_disk, N_bulge, N_halo)
    )


#################################################################
# generate models
#################################################################


#####################
# exponential disk
#####################

# Positions of the exponential stellar disk (velocities are added later).
nb_disk = None
if M_disk != 0.0:
    print("generating disk...")
    _disk_rmax = fr_disk * Hr_disk
    _disk_zmax = fz_disk * Hz_disk
    # NOTE(review): the seed is hard-coded (irand=0) rather than taken from
    # the Galaxy/irand parameter — confirm intent.
    nb_disk = ic.expd(N_disk, Hr_disk, Hz_disk, _disk_rmax, _disk_zmax, irand=0, ftype="gh5")
    nb_disk.set_tpe("disk")
    # all disk particles share the same mass
    nb_disk.mass = (M_disk / N_disk) * np.ones(nb_disk.nbody).astype(np.float32)
    nb_disk.rename("disk.dat")

    if opt.debug:
        nb_disk.write()


#####################
# halo
#####################


# Positions of the generic two-slope dark halo (velocities are added later).
nb_halo = None
if M_halo != 0.0:
    print("generating halo...")

    if Rcut_halo is not None:
        # newer generator, supports an outer cutoff of power power_cut_halo
        nb_halo = ic.gen_2_slopes(N_halo, Rs_halo, alpha_halo, beta_halo, Rmax_halo, dR_halo, Rcut_halo, power_cut_halo, ftype="gh5")
    else:
        # no cutoff requested: keep the original generator
        nb_halo = ic.generic2c(N_halo, Rs_halo, alpha_halo, beta_halo, Rmax_halo, dR_halo, ftype="gh5")

    nb_halo.set_tpe("halo")
    # all halo particles share the same mass
    nb_halo.mass = (M_halo / N_halo) * np.ones(nb_halo.nbody).astype(np.float32)
    nb_halo.rename("halo.dat")
    if opt.debug:
        nb_halo.write()


#####################
# bulge
#####################

# Positions of the Plummer bulge; vel="no" since velocities are added later.
nb_bulge = None
if M_bulge != 0.0:
    print("generating bulge...")
    _bulge_rmax = fr_bulge * Hr_bulge
    # NOTE(review): the three 1's are presumably unit axis ratios — confirm.
    nb_bulge = ic.plummer(N_bulge, 1, 1, 1, Hr_bulge, _bulge_rmax, vel="no", ftype="gh5")
    nb_bulge.set_tpe("bulge")
    # all bulge particles share the same mass
    nb_bulge.mass = (M_bulge / N_bulge) * np.ones(nb_bulge.nbody).astype(np.float32)
    nb_bulge.rename("bulge.dat")
    if opt.debug:
        nb_bulge.write()

#####################
# gas disk
#####################

# Positions of the Miyamoto-Nagai gas disk (velocities are added later).
nb_gas = None
if M_gas != 0.0:
    print("generating gas...")
    # NOTE(review): the seed is hard-coded (irand=-2) rather than taken from
    # the Galaxy/irand parameter — confirm intent.
    nb_gas = ic.miyamoto_nagai(N_gas, Hr_gas, Hz_gas, rmax_gas, zmax_gas, irand=-2, ftype="gh5")
    nb_gas.set_tpe("gas")
    # all gas particles share the same mass
    nb_gas.mass = (M_gas / N_gas) * np.ones(nb_gas.nbody).astype(np.float32)
    nb_gas.rename("gas.dat")
    if opt.debug:
        nb_gas.write()


###############################################################
# merge all components
###############################################################

# Merge the generated components (still without velocities) in the order
# disk, halo, bulge, gas. The velocity computations below are called on this
# merged model.
nb = None
for _component in (nb_disk, nb_halo, nb_bulge, nb_gas):
    if _component is None:
        continue
    nb = _component if nb is None else nb + _component

# save particles without velocities
if opt.debug:
    nb.write("snapnf.hdf5")

###############################################################
# compute velocities
###############################################################


if nb_disk is not None:
    print("------------------------")
    print("disk velocities...")
    print("------------------------")

    # Velocities of the stellar disk particles, derived on a cylindrical grid
    # from the merged model. The returned nb_disk presumably holds only the
    # selected "disk" particles, now with velocities.
    nb_disk, phi, stats_disk = nb.Get_Velocities_From_Cylindrical_Grid(
        select="disk",
        disk=("disk", "gas"),
        eps=eps_disk,
        nR=nr_disk,
        nz=nz_disk,
        nt=nt_disk,
        Rmax=grmax_disk,
        zmin=gzmin_disk,
        zmax=gzmax_disk,
        params=params_disk,
        Phi=None,
        g=g_disk,
        gm=gm_disk,
        ErrTolTheta=ErrTolTheta,
        AdaptativeSoftenning=AdaptativeSoftenning,
    )
    if opt.debug:
        iofunc.write_dmp(stats_name_disk, stats_disk)

    # grid spacing (first radial bin, central z bins) in units of the softening
    r = stats_disk["R"]
    z = stats_disk["z"]
    dr = r[1] - r[0]
    dz = z[nz_disk // 2 + 1] - z[nz_disk // 2]

    # true division (was floor division), as for the other components
    print("disk : Delta R :", dr, "=", dr / eps_disk, "eps")
    print("disk : Delta z :", dz, "=", dz / eps_disk, "eps")

    # reduce if needed: remove the nf oversampling
    if nf > 1:
        nb_disk = nb_disk.reduc(nf, mass=True)

if nb_gas is not None:
    print("------------------------")
    print("gas velocities...")
    print("------------------------")
    # cylindrical-grid settings used to derive the gas velocities
    _gas_grid = dict(
        select="gas",
        disk=("disk", "gas"),
        eps=eps_gas,
        nR=nr_gas,
        nz=nz_gas,
        nt=nt_gas,
        Rmax=grmax_gas,
        zmin=gzmin_gas,
        zmax=gzmax_gas,
        params=params_gas,
        Phi=None,
        g=g_gas,
        gm=gm_gas,
        ErrTolTheta=ErrTolTheta,
        AdaptativeSoftenning=AdaptativeSoftenning,
    )
    nb_gas, phi, stats_gas = nb.Get_Velocities_From_Cylindrical_Grid(**_gas_grid)
    if opt.debug:
        iofunc.write_dmp(stats_name_gas, stats_gas)

    # grid spacing (first radial bin, central z bins) in units of the softening
    r, z = stats_gas["R"], stats_gas["z"]
    dr = r[1] - r[0]
    _mid = nz_gas // 2
    dz = z[_mid + 1] - z[_mid]
    print("gas   : Delta R :", dr, "=", dr / eps_gas, "eps")
    print("gas   : Delta z :", dz, "=", dz / eps_gas, "eps")

    # remove the nf oversampling
    if nf > 1:
        nb_gas = nb_gas.reduc(nf, mass=True)


if nb_bulge is not None:
    print("------------------------")
    print("bulge velocities...")
    print("------------------------")
    # spherical-grid settings used to derive the bulge velocities
    _bulge_grid = dict(
        select="bulge",
        eps=eps_bulge,
        nr=nr_bulge,
        rmax=grmax_bulge,
        phi=None,
        g=g_bulge,
        gm=gm_bulge,
        UseTree=True,
        ErrTolTheta=ErrTolTheta,
    )
    nb_bulge, phi, stats_bulge = nb.Get_Velocities_From_Spherical_Grid(**_bulge_grid)
    if opt.debug:
        iofunc.write_dmp(stats_name_bulge, stats_bulge)

    # first radial bin width in units of the softening
    r = stats_bulge["r"]
    dr = r[1] - r[0]
    print("bulge : Delta r :", dr, "=", dr / eps_bulge, "eps")

    # remove the nf oversampling
    if nf > 1:
        nb_bulge = nb_bulge.reduc(nf, mass=True)


if nb_halo is not None:
    print("------------------------")
    print("halo velocities...")
    print("------------------------")
    # spherical-grid settings used to derive the halo velocities
    _halo_grid = dict(
        select="halo",
        eps=eps_halo,
        nr=nr_halo,
        rmax=grmax_halo,
        phi=None,
        g=g_halo,
        gm=gm_halo,
        UseTree=True,
        ErrTolTheta=ErrTolTheta,
    )
    nb_halo, phi, stats_halo = nb.Get_Velocities_From_Spherical_Grid(**_halo_grid)
    if opt.debug:
        iofunc.write_dmp(stats_name_halo, stats_halo)

    # first radial bin width in units of the softening
    r = stats_halo["r"]
    dr = r[1] - r[0]
    print("halo : Delta r :", dr, "=", dr / eps_halo, "eps")

    # remove the nf oversampling
    if nf > 1:
        nb_halo = nb_halo.reduc(nf, mass=True)


###############################################################
# sum the different components and save the final model
###############################################################

# Rebuild the full model from the components that now carry velocities,
# in the order disk, halo, bulge, gas.
nb = None
for _component in (nb_disk, nb_halo, nb_bulge, nb_gas):
    if _component is None:
        continue
    nb = _component if nb is None else nb + _component


# Re-type the components to the particle types expected by swift: the halo
# becomes "halo_1", and the disk and bulge stars both become "stars_1".
nb1 = nb.select("halo")
nb1.set_tpe("halo_1")

nb2 = nb.select("disk")
nb2.set_tpe("stars_1")

nb3 = nb.select("bulge")
nb3.set_tpe("stars_1")

# without hydrodynamics the gas is stored as collisionless "bndry" particles
nb4 = nb.select("gas")
if hydro == 0:
    nb4.set_tpe("bndry")

nb = nb1 + nb2 + nb3 + nb4

# convert to the requested output format (ftype)
nb = nb.set_ftype(ftype=ftype)


# Attach the output system of units to the snapshot. The length unit was
# previously assigned to the numpy module (np.UnitLength_in_cm) by mistake,
# so the snapshot never received it.
nb.UnitLength_in_cm         = outputUnitLength_in_cm
nb.UnitMass_in_g            = outputUnitMass_in_g
nb.UnitVelocity_in_cm_per_s = outputUnitVelocity_in_cm_per_s
nb.Unit_time_in_cgs         = outputUnitLength_in_cm / outputUnitVelocity_in_cm_per_s

nb.boxsize = boxsize
# initial smoothing length of every particle: the softening length eps
nb.rsp_init = np.ones(nb.nbody) * eps
# birth time -1 for every particle (NOTE(review): meaning defined by the IC reader)
nb.birth_time_init = -1 * np.ones(nb.nbody)

if hydro:
    # Boltzmann constant in output units (energy = mass * velocity**2)
    _kB_cgs = c.k_B.to(u.g * u.cm ** 2 / u.s ** 2 / u.K)
    k = (_kB_cgs / outputUnitVelocity_in_cm_per_s ** 2 / outputUnitMass_in_g).value
    # proton mass in output mass units
    m_p = (c.m_p.to(u.g) / outputUnitMass_in_g).value

    gamma = 5 / 3.0     # adiabatic index
    xi = 0.76           # presumably the hydrogen mass fraction
    ionisation = 0      # neutral gas
    mu = thermodyn.MeanWeight(xi, ionisation)
    mumh = m_p * mu

    # specific internal energy u = T k / ((gamma - 1) mu m_p), same value for all particles
    _u_gas = T_gas / (gamma - 1.0) * k / mumh
    nb.u_init = _u_gas * np.ones(nb.nbody)


# Final unit conversion: pNbody internal units -> user output units.
# Each quantity is multiplied by its internal unit then divided by its output
# unit (both in cgs).
for _attr, _internal_unit, _output_unit in (
    ("pos", pNbUnitLength_in_cm, outputUnitLength_in_cm),
    ("vel", pNbUnitVelocity_in_cm_per_s, outputUnitVelocity_in_cm_per_s),
    ("mass", pNbUnitMass_in_g, outputUnitMass_in_g),
    ("rsp_init", pNbUnitLength_in_cm, outputUnitLength_in_cm),
    ("boxsize", pNbUnitLength_in_cm, outputUnitLength_in_cm),
):
    setattr(nb, _attr, getattr(nb, _attr) * _internal_unit / _output_unit)


# Check that the box can hold every particle. The test assumes the model is
# centred on the origin, so each coordinate must satisfy |x_i| <= boxsize/2.
# All three coordinates are checked, not only x.
xmax = np.max(np.fabs(nb.pos))
if nb.boxsize < 2*xmax:
    txt = "WARNING: boxsize (=%g) < 2*xmax (=%g) !"%(nb.boxsize,2*xmax)
    warnings.warn(txt, UserWarning)


# save model
nb.rename(outputfilename)
nb.write()

#%%
# Add the StellarParticleType dataset to the star particles of the output file
import h5py as h5
import numpy as np

nb_star = nb.select("stars")

if nb_star.nbody > 0:

    N_star = np.sum(nb_star.npart)
    star_tpe = 2  # Single population stars
    # integer flag, one per star particle
    star_type = np.full(N_star, star_tpe, dtype=np.int32)

    # only HDF5 outputs can be patched (e.g. not the binary gadget format)
    if h5.is_hdf5(outputfilename):
        with h5.File(outputfilename, "r+") as f:
            f["PartType4"].create_dataset("StellarParticleType", data=star_type)
    else:
        warnings.warn(
            "%s is not an HDF5 file: StellarParticleType not added" % outputfilename,
            UserWarning,
        )
