#!/usr/bin/env python3

import argparse
import numpy as np
from pNbody import *
from pNbody.mass_models import plummer
from pNbody.mass_models import nfw
from pNbody.DF import DistributionFunction
from astropy import constants as cte
from astropy import units as u



####################################################################
# option parser
####################################################################

description="generate an nfw+plummer model at equilibrium"
epilog     ="""
Examples:
--------
ic_nfw+plummer --Rmin 1e-3 --Rmax 30  --rho0 0.007 --rs 0.5   --Mtot 1e5 -a 0.1 --ptype1 4 --ptype2 1  --mass1 1000 --mass2 1000  -t swift -o nfw+plummer.hdf5
ic_nfw+plummer --Rmin 1e-3 --Rmax 100 --M200 1e11  --c  17    --Mtot 1e9 -a 1   --ptype1 4 --ptype2 1  -N 100000                  -t swift -o nfw+plummer.hdf5
ic_nfw+plummer --Rmin 1e-3 --Rmax 100 --M200 1e11  --c  17    --Mtot 1e9 -a 1   --ptype1 4 --ptype2 1  --mass1 1e4   --mass2 1e5  -t swift -o nfw+plummer.hdf5
"""

# Command-line interface of the script.
# RawDescriptionHelpFormatter prints the description and the epilog verbatim,
# which preserves the hand-aligned example commands above.
parser = argparse.ArgumentParser(description=description,
                                 epilog=epilog,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)


# ---- output and general options ----

parser.add_argument("-o",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")

parser.add_argument("-t",
                    type=str,
                    dest="ftype",
                    default='gh5',
                    help="Type of the output file")

parser.add_argument("--irand",
                    dest="irand",
                    metavar='INT',
                    type=int,
                    default=0,
                    help='random seed')

parser.add_argument("--plot",
                    action="store_true",
                    dest="plot",
                    default=False,
                    help="plot dynamical time")

# argparse applies `type` only to command-line strings and to string
# defaults: a float default such as 1e5 would reach the script as a float
# even though type=int is declared. The default is therefore a true integer.
parser.add_argument("-N",
                    dest="N",
                    metavar='INT',
                    type=int,
                    default=100000,
                    help='total number of particles')


# ---- NFW component: either (--rho0, --rs) or (--c, --M200) ----

parser.add_argument("--rs",
                    type=float,
                    dest="rs",
                    default=None,
                    help="NFW rs parameter")

parser.add_argument("--rho0",
                    type=float,
                    dest="rho0",
                    default=None,
                    help="NFW density at critical radius")

parser.add_argument("--c",
                    type=float,
                    dest="c",
                    default=None,
                    help="concentration parameter")

parser.add_argument("--M200",
                    type=float,
                    dest="M200",
                    default=None,
                    help="virial mass in Msol")


# ---- Plummer component ----

parser.add_argument("-a",
                    type=float,
                    dest="a",
                    default=1,
                    help="a parameter")

parser.add_argument("--Mtot",
                    type=float,
                    dest="Mtot",
                    default=1,
                    help="Total mass in solar mass")


# ---- radial range and distribution-function sampling ----

parser.add_argument("--Rmin",
                    type=float,
                    dest="Rmin",
                    default=1e-2,
                    help="Rmin")

parser.add_argument("--Rmax",
                    type=float,
                    dest="Rmax",
                    default=100,
                    help="Rmax")

# NOTE(review): the two bin counts below are parsed as floats, so values
# written as 1e4 are accepted on the command line. Confirm that the consumer
# of NE/NR copes with a float count before tightening them to int.
parser.add_argument("--NE",
                    type=float,
                    dest="NE",
                    default=1e4,
                    help="number of energy bins")

parser.add_argument("--NR",
                    type=float,
                    dest="NR",
                    default=1e4,
                    help="number of radius bins")

# Integer default for the same reason as -N: a float default would bypass
# the declared type=int conversion.
parser.add_argument("--Ndraw",
                    type=int,
                    dest="Ndraw",
                    default=1000000,
                    help="number of particles to draw at each loop")


# ---- particle types and particle masses of both components ----

parser.add_argument("--ptype1",
                    dest="ptype1",
                    metavar='INT',
                    type=int,
                    default=1,
                    help="particle type of component 1")

parser.add_argument("--ptype2",
                    dest="ptype2",
                    metavar='INT',
                    type=int,
                    default=1,
                    help="particle type of component 2")

parser.add_argument("--mass1",
                    type=float,
                    dest="mass1",
                    default=None,
                    help="particle 1 mass in solar mass")

parser.add_argument("--mass2",
                    type=float,
                    dest="mass2",
                    default=None,
                    help="particle 2 mass in solar mass")

parser.add_argument("-q","--quiet",
                    action="store_true",
                    dest="quiet",
                    default=False,
                    help="quiet mode")


########################################
# main
########################################

# Read the command line into the module-level namespace `opt`, which the
# rest of the script consults for every setting.
opt = parser.parse_args()

# Quiet mode installs a catch-all "ignore" entry in the warnings filters.
if opt.quiet:
  import warnings
  warnings.simplefilter("ignore")


#########################################################
# define units
#########################################################
# Internal unit system of the model: lengths in kpc, masses in 1e10 solar
# masses, velocities in km/s; the time unit follows as length/velocity.
u_Length   = 1 * u.kpc
u_Mass     = 10**10 * u.M_sun
u_Velocity = 1 * u.km / u.s
u_Time     = u_Length / u_Velocity

# number of solar masses in one internal mass unit
toMsol     = u_Mass.to(u.M_sun).value

# the same unit system expressed in cgs
u_dict = {
  "UnitLength_in_cm":         u_Length.to(u.cm).value,
  "UnitMass_in_g":            u_Mass.to(u.g).value,
  "UnitVelocity_in_cm_per_s": u_Velocity.to(u.cm/u.s).value,
  "Unit_time_in_cgs":         u_Time.to(u.s).value,
}

# gravitational constant expressed in the internal unit system
G = cte.G.to( u_Length**3/(u_Mass*u_Time**2) ).value

# the command line gives Mtot in solar masses; store it in internal mass units
opt.Mtot /= toMsol

#########################################################
# define the NFW model
#########################################################


# The NFW halo is described either by (rho0, rs) or by (c, M200).
# When both pairs are supplied, (rho0, rs) is the one that is used.
# NOTE: the assignments below rebind the name `nfw` from the imported
# pNbody.mass_models module to the model instance built from it.
if (opt.rho0 is None or opt.rs is None) and (opt.c is None or opt.M200 is None):
  raise ValueError("Either (rho0 and rs) or (c and M200) must be defined.")
elif opt.rho0 is not None and opt.rs is not None:
  nfw = nfw.NFW(rs=opt.rs, rho0=opt.rho0, G=G)
else:
  # Only this parametrisation needs a cosmology: H0 is taken from Planck18
  # and, like M200, converted to the internal unit system.
  from astropy.cosmology import Planck18 as cosmo
  H0   = cosmo.H0.to(1/u_Time).value
  M200 = (opt.M200*u.Msun).to(u_Mass).value
  nfw = nfw.NFW(M200=M200, c=opt.c, H0=H0, G=G)
  nfw.info()



#########################################################
# Define the total potential
#########################################################
def fctPotential(x):
  """Return the total potential at radius x: NFW halo plus Plummer sphere."""
  halo_part    = nfw.Potential(x)
  plummer_part = plummer.Potential(M=opt.Mtot, a=opt.a, r=x, G=G)
  return halo_part + plummer_part


#########################################################
# Define the total mass
#########################################################

# component 1: total mass of the Plummer sphere
TotalMass1         = plummer.TotalMass(G=G, a=opt.a, M=opt.Mtot)
# component 2: NFW mass enclosed within Rmax
TotalMass2         = nfw.CumulativeMass(opt.Rmax)



#########################################################
# Define the mass distribution of the sub-component 1
#########################################################
# Component 1 is the Plummer sphere; each helper maps a radius x onto the
# corresponding pNbody plummer profile.
def fctDensity(x):
  return plummer.Density(M=opt.Mtot, a=opt.a, r=x, G=G)

def fctCumulativeMass(x):
  return plummer.CumulativeMass(M=opt.Mtot, a=opt.a, r=x, G=G)

def fctVcirc(x):
  return plummer.Vcirc(M=opt.Mtot, a=opt.a, r=x, G=G)

TotalMass         = plummer.TotalMass(M=opt.Mtot, a=opt.a, G=G)

# set the number of particles:
# without an explicit particle mass, component 1 receives its mass fraction
# of the N requested particles; otherwise the count follows from the
# component mass (converted to solar masses) divided by the particle mass.
if opt.mass1 is None:
  N1 = int(opt.N * TotalMass1 / (TotalMass1 + TotalMass2))
else:
  N1 = int(TotalMass * toMsol / opt.mass1)

if not N1:
  raise ValueError("N1 is 0 !")

# The distribution function uses the density of this component but the
# potential of the full system (fctPotential).
DF = DistributionFunction(
  Rmin=opt.Rmin,
  Rmax=opt.Rmax,
  Rinf=np.inf,
  NR=opt.NR,
  NE=opt.NE,
  fctDensity=fctDensity,
  fctPotential=fctPotential,
  fctCumulativeMass=fctCumulativeMass,
  TotalMass=TotalMass,
  G=G,
)

# prepare the distribution function, then draw the N1 particles from it
DF.computeDF()
DF.clean()
DF.computeMaxLikelihood()
DF.sample(N1, opt.Ndraw, irand=opt.irand)


#########################################################
# Create first model
#########################################################
nb1 = Nbody(
  status='new',
  p_name=opt.outputfilename,
  pos=DF.pos,
  vel=DF.vel,
  mass=DF.mass,
  ftype=opt.ftype,
)
nb1.set_tpe(opt.ptype1)




#########################################################
# Define the mass distribution of the sub-component 2
#########################################################
# Component 2 is the NFW halo; each helper forwards a radius x to the
# NFW model instance.
def fctDensity(x):
  return nfw.Density(x)

def fctCumulativeMass(x):
  return nfw.CumulativeMass(x)

def fctVcirc(x):
  return nfw.Vcirc(x)

# the halo mass is measured inside Rmax (unbounded if rmax=inf)
TotalMass         = nfw.CumulativeMass(opt.Rmax)


# set the number of particles:
# without an explicit particle mass, component 2 receives its mass fraction
# of the N requested particles; otherwise the count follows from the
# component mass (converted to solar masses) divided by the particle mass.
if opt.mass2 is None:
  N2 = int(opt.N * TotalMass2 / (TotalMass1 + TotalMass2))
else:
  N2 = int(TotalMass * toMsol / opt.mass2)

if not N2:
  raise ValueError("N2 is 0 !")

# The distribution function uses the density of this component but the
# potential of the full system (fctPotential).
DF = DistributionFunction(
  Rmin=opt.Rmin,
  Rmax=opt.Rmax,
  Rinf=np.inf,
  NR=opt.NR,
  NE=opt.NE,
  fctDensity=fctDensity,
  fctPotential=fctPotential,
  fctCumulativeMass=fctCumulativeMass,
  TotalMass=TotalMass,
  G=G,
)

# prepare the distribution function, then draw the N2 particles from it
# NOTE(review): the seed is the same opt.irand that component 1 uses;
# confirm that drawing both components from one seed is intended.
DF.computeDF()
DF.clean()
DF.computeMaxLikelihood()
DF.sample(N2, opt.Ndraw, irand=opt.irand)


#########################################################
# Create second model
#########################################################
nb2 = Nbody(
  status='new',
  p_name=opt.outputfilename,
  pos=DF.pos,
  vel=DF.vel,
  mass=DF.mass,
  ftype=opt.ftype,
)
nb2.set_tpe(opt.ptype2)


# The combined model is assembled and written only when an output file
# name was given on the command line.
if opt.outputfilename:

  # merge both components into a single model
  nb = nb1 + nb2

  # add units: each attribute carries the name of the u_dict entry it copies
  for unit_name in ("UnitLength_in_cm",
                    "UnitMass_in_g",
                    "UnitVelocity_in_cm_per_s",
                    "Unit_time_in_cgs"):
    setattr(nb, unit_name, u_dict[unit_name])

  # format-specific additions
  if opt.ftype == "gh5":
    nb.massarr = None
    nb.nzero = None
  elif opt.ftype == "swift":
    # the box is four times as wide as the sampled radius
    nb.boxsize = 4*opt.Rmax

  nb.setComovingIntegrationOff()

  # apply pNbody's cvcenter() to the merged model before it is saved
  nb.cvcenter()

  # write the file
  nb.write()


