#!/usr/bin/env python3
###########################################################################################
#  package:   pNbody
#  file:      ic_gen_2_slopes+plummer
#  brief:     Generates a two-component particle model based on a generalized two-slopes model and a Plummer model
#  copyright: GPLv3
#             Copyright (C) 2023 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Darwin Roduit <darwin.roduit@epfl.ch>
#
# This file is part of pNbody.
###########################################################################################


import argparse
import numpy as np
from pNbody import *
from pNbody.mass_models import plummer
from pNbody.mass_models import gen_2_slopes
from pNbody.DF import DistributionFunction
from astropy import constants as cte
from astropy import units as u
from scipy.integrate import quad


def potential_from_density(fctDensity, G):
    """Build a vectorized spherical potential from a density profile.

    For a spherically symmetric density rho(r) the potential evaluated is

        Phi(x) = -4 pi G [ (1/x) int_0^x r^2 rho(r) dr + int_x^inf r rho(r) dr ]

    Both integrals are computed numerically with ``scipy.integrate.quad``.

    Parameters
    ----------
    fctDensity : callable
        Density as a function of a scalar radius.
    G : float
        Gravitational constant, in units consistent with ``fctDensity``.

    Returns
    -------
    numpy.vectorize
        Callable mapping a radius (scalar or array-like) to the potential.
    """

    def potential(x):
        # Contribution of the shells lying outside x.
        outer = quad(lambda r: r * fctDensity(r), x, np.inf)[0]
        if x == 0:
            # The enclosed-mass term M(<x)/x would be 0/0 here. Its limit is
            # zero for any density shallower than r^-2 at the centre, so only
            # the outer integral contributes.
            return -4 * np.pi * G * outer
        # Contribution of the mass enclosed within x.
        inner = quad(lambda r: r * r * fctDensity(r), 0., x)[0]
        return -4 * np.pi * G * inner / x - 4 * np.pi * G * outer

    # Declaring the output type stops np.vectorize from running the (costly)
    # integrals an extra time on the first element just to infer it, and
    # allows empty inputs.
    return np.vectorize(potential, otypes=[float])

def cumulativemass_from_density(fctDensity):
    """Build a vectorized enclosed-mass function from a density profile.

    The mass inside radius x of a spherically symmetric density rho(r) is

        M(x) = 4 pi int_0^x r^2 rho(r) dr

    evaluated numerically with ``scipy.integrate.quad``.

    Parameters
    ----------
    fctDensity : callable
        Density as a function of a scalar radius.

    Returns
    -------
    numpy.vectorize
        Callable mapping a radius (scalar or array-like) to the enclosed mass.
    """

    def enclosed_mass(x):
        return 4 * np.pi * quad(lambda r: r * r * fctDensity(r), 0., x)[0]

    # Declaring the output type stops np.vectorize from integrating the first
    # element an extra time to infer it, and allows empty inputs.
    return np.vectorize(enclosed_mass, otypes=[float])


def exp_cutoff(x, x_cut, speed):
    """Return the exponential truncation factor exp(-(x / x_cut)**speed).

    Parameters
    ----------
    x : float or array-like
        Radius at which the factor is evaluated.
    x_cut : float
        Cutoff radius; the factor equals exp(-1) at ``x == x_cut``.
    speed : float
        Exponent applied to ``x / x_cut``; larger values give a sharper drop.
    """
    scaled_radius = x / x_cut
    return np.exp(-(scaled_radius ** speed))

####################################################################
# option parser
####################################################################

# Help texts displayed by ``--help``: ``description`` explains the model
# parameters and ``epilog`` lists ready-to-run command lines.
# The parameter names below match the command-line options (--alpha, --beta,
# --rs, ...). rho0 is documented in internal units because the script passes
# it to the model without conversion, unlike Mtot and M200.
description = """Generates a generalized 2-slopes + Plummer model at equilibrium.

The model is composed of a  generalized 2-slopes model (the dark halo) as well as a Plummer model (the stellar counterpart).

The Plummer is parametrized by :

  Mtot : the total Plummer mass, in Msol 
  a    : the scale radius, equivalent to the half light radius when projected, in kpc

The generalized 2-slope model is parametrized either by:

  rs    : the scale radius, in kpc
  rho0  : the density at the scale radius, in 10^10 Msol/kpc^3 (internal units)
  alpha : the inner slope
  beta  : the outer slope

or:

  M200  : the virial mass, in Msol
  c     : the concentration parameter
  alpha : the inner slope
  beta  : the outer slope


"""
epilog = """
Examples:
--------
ic_gen_2_slopes+plummer --alpha 1 --beta 3 --Rmin 1e-3 --Rmax 30 --r_cutoff 30  --rho0 0.007 --rs 0.5 --Mtot 1e5 -a 0.1 --ptype1 4 --ptype2 1 --mass1 1000 --mass2 1000 -t swift -o gen_2_slopes+plummer.hdf5 \n
ic_gen_2_slopes+plummer --alpha 0 --beta 3 --Rmin 1e-3 --Rmax 30 --r_cutoff 30 --rho0 0.0008 --rs 0.5 --Mtot 1e5 -a 0.1 --ptype1 4 --ptype2 1 --mass1 125 --mass2 125 -t swift -o 2s0_3_0008_05+plummer.hdf5 \n

ic_gen_2_slopes+plummer --alpha 1 --beta 3 --Rmin 1e-3 --Rmax 30 --r_cutoff 30 --M200 1e9 --c 17 --Mtot 1e5 -a 0.1 --ptype1 4 --ptype2 1 --mass1 1000 --mass2 10000 -t swift -o gen_2_slopes+plummer.hdf5 
ic_gen_2_slopes+plummer --alpha 0 --beta 3 --Rmin 1e-3 --Rmax 30 --r_cutoff 30 --M200 1e9 --c 17 --Mtot 1e5 -a 0.1 --ptype1 4 --ptype2 1 --mass1 1000 --mass2 10000 -t swift -o gen_2_slopes+plummer.hdf5 

ic_gen_2_slopes+plummer --alpha 1 --beta 3 --Rmin 1e-3 --Rmax 30 --r_cutoff 30 --M200 1e9 --c 17 --Mtot 1e5 -a 0.1 --ptype1 4 --ptype2 1 -N 1000000 -t swift -o gen_2_slopes+plummer.hdf5 
ic_gen_2_slopes+plummer --alpha 0 --beta 3 --Rmin 1e-3 --Rmax 30 --r_cutoff 30 --M200 1e9 --c 17 --Mtot 1e5 -a 0.1 --ptype1 4 --ptype2 1 -N 1000000 -t swift -o gen_2_slopes+plummer.hdf5 

"""

# Command-line interface. The raw formatter keeps the layout of the
# description and of the example command lines as written.
parser = argparse.ArgumentParser(
    description=description,
    epilog=epilog,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)

# --- output ---------------------------------------------------------------
parser.add_argument("-o",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")

parser.add_argument("-t",
                    type=str,
                    dest="ftype",
                    default='gh5',
                    help="Type of the output file")

# --- sampling control -----------------------------------------------------
parser.add_argument("--irand",
                    dest="irand",
                    metavar='INT',
                    type=int,
                    default=0,
                    help='Random seed')

parser.add_argument("--plot",
                    action="store_true",
                    dest="plot",
                    default=False,
                    help="Plot dynamical time")

# argparse does not run ``type`` on non-string defaults, so the default is
# written as an int to match what a command-line value would yield.
parser.add_argument("-N",
                    dest="N",
                    metavar='INT',
                    type=int,
                    default=100000,
                    help='Total number of particles')

# --- generalized 2-slopes (halo) parameters -------------------------------
parser.add_argument("--c",
                    type=float,
                    dest="c",
                    default=None,
                    help="concentration parameter")

parser.add_argument("--M200",
                    type=float,
                    dest="M200",
                    default=None,
                    help="virial mass in Msol")

parser.add_argument("--rho0",
                    type=float,
                    dest="rho0",
                    default=None,
                    help="Gen2slopes : density at rs")

parser.add_argument("--rs",
                    type=float,
                    dest="rs",
                    default=None,
                    help="Gen2slopes : rs parameter")

parser.add_argument("--alpha",
                    type=float,
                    dest="alpha",
                    default=1,
                    help="Gen2slopes : alpha parameter")

parser.add_argument("--beta",
                    type=float,
                    dest="beta",
                    default=3,
                    help="Gen2slopes : beta parameter")

# --- Plummer (stellar) parameters -----------------------------------------
parser.add_argument("-a",
                    type=float,
                    dest="a",
                    default=1,
                    help="Plummer a parameter")

parser.add_argument("--Mtot",
                    type=float,
                    dest="Mtot",
                    default=1,
                    help="Plummer total mass in solar mass")

# --- grids used for the distribution function -----------------------------
parser.add_argument("--Rmin",
                    type=float,
                    dest="Rmin",
                    default=1e-2,
                    help="Rmin")

parser.add_argument("--Rmax",
                    type=float,
                    dest="Rmax",
                    default=100,
                    help="Rmax")

# NE and NR stay parsed as floats so that values such as ``--NE 1e4`` keep
# being accepted on the command line.
parser.add_argument("--NE",
                    type=float,
                    dest="NE",
                    default=1e4,
                    help="number of energy bins")

parser.add_argument("--NR",
                    type=float,
                    dest="NR",
                    default=1e4,
                    help="number of radius bins")

# Integer default for the same reason as -N.
parser.add_argument("--Ndraw",
                    type=int,
                    dest="Ndraw",
                    default=1000000,
                    help="number of particles to draw at each loop")

# --- particle types and masses --------------------------------------------
parser.add_argument("--ptype1",
                    dest="ptype1",
                    metavar='INT',
                    type=int,
                    default=1,
                    help="particle type of component 1")

parser.add_argument("--ptype2",
                    dest="ptype2",
                    metavar='INT',
                    type=int,
                    default=1,
                    help="particle type of component 2")

parser.add_argument("--mass1",
                    type=float,
                    dest="mass1",
                    default=None,
                    help="particle 1 mass in solar mass")

parser.add_argument("--mass2",
                    type=float,
                    dest="mass2",
                    default=None,
                    help="particle 2 mass in solar mass")

# --- halo density truncation ----------------------------------------------
parser.add_argument("--r_cutoff",
                    type=float,
                    dest="r_cutoff",
                    default=100,
                    help="Density cutoff radius")

parser.add_argument("--power_cutoff",
                    type=float,
                    dest="power_cutoff",
                    default=3,
                    help="Power of the cutoff")

parser.add_argument("-q", "--quiet",
                    action="store_true",
                    dest="quiet",
                    default=False,
                    help="quiet mode")



########################################
# main
########################################

opt = parser.parse_args()

# Quiet mode: silence every Python warning for the rest of the run.
if opt.quiet:
    import warnings
    warnings.filterwarnings("ignore")

#########################################################
# define units
#########################################################
# Internal unit system: kpc, 10^10 Msol and km/s; the time unit follows as
# length / velocity.
u_Length = 1 * u.kpc
u_Mass = 10**10 * u.M_sun
u_Velocity = 1 * u.km / u.s
u_Time = u_Length / u_Velocity
# Value of one internal mass unit expressed in solar masses.
toMsol = u_Mass.to(u.M_sun).value

# The same units expressed in cgs.
u_dict = {
    "UnitLength_in_cm": u_Length.to(u.cm).value,
    "UnitMass_in_g": u_Mass.to(u.g).value,
    "UnitVelocity_in_cm_per_s": u_Velocity.to(u.cm / u.s).value,
    "Unit_time_in_cgs": u_Time.to(u.s).value,
}

# Gravitational constant in internal units.
G = cte.G.to(u_Length**3 / (u_Mass * u_Time**2)).value

# The Plummer mass is given in solar masses: convert it to internal units.
opt.Mtot = opt.Mtot / toMsol


#########################################################
# define the GEN2SLOPES model
#########################################################


# Two alternative parametrizations are accepted. (rho0, rs) is used whenever
# both are given; otherwise (c, M200) is tried.
if None not in (opt.rho0, opt.rs):
    gen2slopes = gen_2_slopes.GEN2SLOPES(
        alpha=opt.alpha,
        beta=opt.beta,
        rho0=opt.rho0,
        rs=opt.rs,
        G=G,
    )
elif None not in (opt.c, opt.M200):
    # The cosmology is only needed for the virial parametrization.
    from astropy.cosmology import Planck18 as cosmo
    # M200 is given in solar masses and H0 is taken from Planck18; both are
    # converted to internal units.
    M200 = (opt.M200 * u.Msun).to(u_Mass).value
    H0 = cosmo.H0.to(1 / u_Time).value
    gen2slopes = gen_2_slopes.GEN2SLOPES(
        alpha=opt.alpha,
        beta=opt.beta,
        c=opt.c,
        M200=M200,
        G=G,
        H0=H0,
    )
else:
    raise ValueError("Either (rho0 and rs) or (c and M200) must be defined.")

# Report the model parameters (reached only when a model was built).
gen2slopes.info()



#########################################################
# Define the total potential
#########################################################
def fctDensity_gen_2_slopes_cutoff(x):
    """Halo density multiplied by the exponential cutoff factor."""
    return gen2slopes.Density(x) * exp_cutoff(x, opt.r_cutoff, opt.power_cutoff)


# Enclosed mass and potential of the truncated halo, both obtained by
# numerical integration of the density (vectorized functions).
fctCumulativeMass_gen_2_slopes_cutoff = cumulativemass_from_density(
    fctDensity_gen_2_slopes_cutoff)
fctPotential_gen_2_slopes_cutoff = potential_from_density(
    fctDensity_gen_2_slopes_cutoff, G)


def fctPotential(x):
    """Total potential: truncated halo plus Plummer component."""
    halo_part = fctPotential_gen_2_slopes_cutoff(x)
    return halo_part + plummer.Potential(M=opt.Mtot, a=opt.a, r=x, G=G)


#########################################################
# Define the total mass
#########################################################

# Component 1: total Plummer mass.
TotalMass1 = plummer.TotalMass(M=opt.Mtot, a=opt.a, G=G)
# Component 2: mass of the truncated halo enclosed within Rmax, from the
# numerical integral rather than gen2slopes.CumulativeMass(opt.Rmax).
TotalMass2 = fctCumulativeMass_gen_2_slopes_cutoff(opt.Rmax)


#########################################################
# Define the mass distribution of the sub-component 1
#########################################################
# Profiles of component 1 (the Plummer model), all using the Plummer mass
# and scale radius given on the command line.
def fctDensity(x):
    """Plummer density at radius x."""
    return plummer.Density(M=opt.Mtot, a=opt.a, r=x, G=G)


def fctCumulativeMass(x):
    """Plummer mass enclosed within radius x."""
    return plummer.CumulativeMass(M=opt.Mtot, a=opt.a, r=x, G=G)


def fctVcirc(x):
    """Plummer circular velocity at radius x."""
    return plummer.Vcirc(M=opt.Mtot, a=opt.a, r=x, G=G)


TotalMass = plummer.TotalMass(M=opt.Mtot, a=opt.a, G=G)


# set the number of particles
if opt.mass1 is None:
    # No particle mass requested: split N between the two components
    # in proportion to their total masses.
    N1 = int(opt.N * TotalMass1 / (TotalMass1 + TotalMass2))
else:
    # Particle mass requested (in Msol): derive the count from the total mass.
    N1 = int(TotalMass * toMsol / opt.mass1)

if N1 == 0:
    raise ValueError("N1 is 0 !")


# Component 1 is sampled in the total (halo + Plummer) potential.
DF = DistributionFunction(
    Rmin=opt.Rmin,
    Rmax=opt.Rmax,
    Rinf=np.inf,
    NR=opt.NR,
    NE=opt.NE,
    fctDensity=fctDensity,
    fctPotential=fctPotential,
    fctCumulativeMass=fctCumulativeMass,
    TotalMass=TotalMass,
    G=G,
)

DF.computeDF()
DF.clean()
DF.computeMaxLikelihood()
DF.sample(N1, opt.Ndraw, irand=opt.irand)


#########################################################
# Create first model
#########################################################
nb1 = Nbody(
    status='new',
    p_name=opt.outputfilename,
    pos=DF.pos,
    vel=DF.vel,
    mass=DF.mass,
    ftype=opt.ftype,
)
nb1.set_tpe(opt.ptype1)







#########################################################
# Define the mass distribution of the sub-component 2
#########################################################
# Profiles of component 2 (the truncated generalized 2-slopes halo).
# No analytic cumulative mass or circular velocity is supplied.
fctDensity = fctDensity_gen_2_slopes_cutoff
fctCumulativeMass = None
fctVcirc = None
TotalMass = TotalMass2
# Component 2 is sampled in the truncated halo potential only.
# Note: according to Darwin Roduit, it works better if there is no cutoff here.
fctPotential = fctPotential_gen_2_slopes_cutoff


# set the number of particles
if opt.mass2 is None:
    # No particle mass requested: split N between the two components
    # in proportion to their total masses.
    N2 = int(opt.N * TotalMass2 / (TotalMass1 + TotalMass2))
else:
    # Particle mass requested (in Msol): derive the count from the total mass.
    N2 = int(TotalMass * toMsol / opt.mass2)

if N2 == 0:
    raise ValueError("N2 is 0 !")


DF = DistributionFunction(
    Rmin=opt.Rmin,
    Rmax=opt.Rmax,
    Rinf=np.inf,
    NR=opt.NR,
    NE=opt.NE,
    fctDensity=fctDensity,
    fctPotential=fctPotential,
    fctCumulativeMass=fctCumulativeMass,
    TotalMass=TotalMass,
    G=G,
)

DF.computeDF()
DF.clean()
DF.computeMaxLikelihood()
# NOTE(review): this uses the same seed (opt.irand) as component 1 - confirm
# that it does not correlate the two particle samples.
DF.sample(N2, opt.Ndraw, irand=opt.irand)


#########################################################
# Create second model
#########################################################
nb2 = Nbody(
    status='new',
    p_name=opt.outputfilename,
    pos=DF.pos,
    vel=DF.vel,
    mass=DF.mass,
    ftype=opt.ftype,
)
nb2.set_tpe(opt.ptype2)


# Nothing is merged or written unless an output file name was given.
if opt.outputfilename:

    # merge both components into a single model
    nb = nb1 + nb2

    # attach the unit system (cgs values computed earlier in u_dict)
    for unit_key in ("UnitLength_in_cm",
                     "UnitMass_in_g",
                     "UnitVelocity_in_cm_per_s",
                     "Unit_time_in_cgs"):
        setattr(nb, unit_key, u_dict[unit_key])

    # format-specific adjustments
    if opt.ftype == "gh5":
        nb.massarr = None
        nb.nzero = None
    elif opt.ftype == "swift":
        nb.boxsize = 4 * opt.Rmax

    nb.setComovingIntegrationOff()

    # cvcenter
    nb.cvcenter()

    # write the file
    nb.write()
