#!/usr/bin/env python3

import argparse
import numpy as np
from pNbody import *
from pNbody.mass_models import nfw
from pNbody.DF import DistributionFunction
from astropy import constants as cte
from astropy import units as u



####################################################################
# option parser
####################################################################

description="generate an nfw model at equilibrium"
epilog     ="""
Examples:
--------
ic_nfw       -N 730000 --Rmin 1e-3 --Rmax 100  --rho0 0.1   --rs 0.1    -t swift -o nfw.hdf5
ic_nfw       -N 730000 --Rmin 1e-3 --Rmax 100  --c    8     --M200 1e10 -t swift -o nfw.hdf5

"""

parser = argparse.ArgumentParser(description=description,epilog=epilog,formatter_class=argparse.RawDescriptionHelpFormatter)


def _count(text):
  """Parse a strictly positive integer count given on the command line.

  Scientific notation is accepted as long as the value is a whole number,
  so ``-N 1e5`` and ``-N 100000`` are equivalent.

  Raises:
    argparse.ArgumentTypeError: if ``text`` is not a positive whole number.
  """
  try:
    value = float(text)
  except ValueError:
    raise argparse.ArgumentTypeError("invalid count: %r" % (text,)) from None
  # is_integer() is False for nan/inf and for fractional values
  if not value.is_integer() or value <= 0:
    raise argparse.ArgumentTypeError("expected a positive integer, got %r" % (text,))
  return int(value)


parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")

parser.add_argument("-t",
                    action="store",
                    type=str,
                    dest="ftype",
                    default='gh5',
                    help="Type of the output file")

parser.add_argument("--irand",
                    action="store",
                    dest="irand",
                    metavar='INT',
                    type=int,
                    default=0,
                    help='random seed')

parser.add_argument("--ptype",
                    action="store",
                    dest="ptype",
                    metavar='INT',
                    type=int,
                    default=1,
                    help='particle type')

parser.add_argument("--plot",
                    action="store_true",
                    dest="plot",
                    default=False,
                    help="plot the dynamical time 2*pi*R/Vc(R)")

# Counts use _count so that defaults and parsed values are always ints
# (argparse does not apply `type` to non-string defaults such as 1e5).
parser.add_argument("-N",
                    action="store",
                    dest="N",
                    metavar='INT',
                    type=_count,
                    default=100000,
                    help='total number of particles (overridden by --mass)')

parser.add_argument("--rs",
                    action="store",
                    type=float,
                    dest="rs",
                    default=None,
                    help="NFW scale radius rs in kpc")

parser.add_argument("--rho0",
                    action="store",
                    type=float,
                    dest="rho0",
                    default=None,
                    help="NFW characteristic density rho0 in units of 1e10 Msol/kpc^3")

parser.add_argument("--c",
                    action="store",
                    type=float,
                    dest="c",
                    default=None,
                    help="concentration parameter")

parser.add_argument("--M200",
                    action="store",
                    type=float,
                    dest="M200",
                    default=None,
                    help="virial mass in Msol")

parser.add_argument("--Rmin",
                    action="store",
                    type=float,
                    dest="Rmin",
                    default=1e-2,
                    help="minimal radius of the model in kpc")

parser.add_argument("--Rmax",
                    action="store",
                    type=float,
                    dest="Rmax",
                    default=100,
                    help="maximal radius of the model in kpc")

parser.add_argument("--NE",
                    action="store",
                    type=_count,
                    dest="NE",
                    metavar='INT',
                    default=10000,
                    help="number of energy bins")

parser.add_argument("--NR",
                    action="store",
                    type=_count,
                    dest="NR",
                    metavar='INT',
                    default=10000,
                    help="number of radius bins")

parser.add_argument("--Ndraw",
                    action="store",
                    type=_count,
                    dest="Ndraw",
                    metavar='INT',
                    default=1000000,
                    help="number of particles to draw at each loop")

parser.add_argument("--mass",
                    action="store",
                    type=float,
                    dest="mass",
                    default=None,
                    help="particle mass in solar mass")

parser.add_argument("-q","--quiet",
                    action="store_true",
                    dest="quiet",
                    default=False,
                    help="quiet mode")


########################################
# main
########################################

opt = parser.parse_args()

if opt.quiet:
  # --quiet: silence every Python warning for the remainder of the run
  from warnings import filterwarnings
  filterwarnings("ignore")


# Internal unit system: kpc, 1e10 Msol and km/s; the time unit is derived
# as length / velocity.
u_Length   = 1 * u.kpc
u_Mass     = 10**10 * u.M_sun
u_Velocity = 1 * u.km / u.s
u_Time     = u_Length / u_Velocity
toMsol     = u_Mass.to(u.M_sun).value   # one internal mass unit, in Msol

# Conversion factors of the internal units to cgs, keyed by the attribute
# names later set on the output snapshot.
u_dict = {
  "UnitLength_in_cm":         u_Length.to(u.cm).value,
  "UnitMass_in_g":            u_Mass.to(u.g).value,
  "UnitVelocity_in_cm_per_s": u_Velocity.to(u.cm / u.s).value,
  "Unit_time_in_cgs":         u_Time.to(u.s).value,
}

# Gravitational constant expressed in the internal unit system.
G = cte.G.to(u_Length**3 / (u_Mass * u_Time**2)).value


# Define the NFW model. The instance is called `halo` so that it does not
# shadow the imported `nfw` module (pNbody.mass_models.nfw).
if (opt.rho0 is not None) and (opt.rs is not None):
  # direct parametrisation: characteristic density and scale radius
  halo = nfw.NFW(rho0=opt.rho0, rs=opt.rs, G=G)
elif (opt.c is not None) and (opt.M200 is not None):
  # cosmological parametrisation: concentration and M200, with H0 from Planck18
  from astropy.cosmology import Planck18 as cosmo
  M200 = (opt.M200*u.Msun).to(u_Mass).value   # Msol -> internal mass unit
  H0   = cosmo.H0.to(1/u_Time).value           # H0 in internal inverse-time unit
  halo = nfw.NFW(c=opt.c, M200=M200, G=G, H0=H0)
  halo.info()
else:
  raise ValueError("Either (rho0 and rs) or (c and M200) must be defined.")

# Analytical profile functions handed to the distribution function.
fctDensity        = halo.Density
fctPotential      = halo.Potential
fctCumulativeMass = halo.CumulativeMass
TotalMass         = halo.CumulativeMass(opt.Rmax)  # unbounded if Rmax=inf
fctVcirc          = halo.Vcirc

# Set the number of particles: when a particle mass (in Msol) is requested,
# N is derived from the model mass enclosed within Rmax and overrides -N.
if opt.mass is not None:
  if opt.mass <= 0:
    raise ValueError("The particle mass (--mass) must be strictly positive.")
  nexact = TotalMass*toMsol/opt.mass
  if not np.isfinite(nexact):
    raise ValueError("Cannot derive N from --mass: the model mass within Rmax is not finite.")
  # round rather than truncate so float noise (e.g. 99999.9999) does not
  # drop a particle
  opt.N = int(round(nexact))
  if opt.N < 1:
    raise ValueError("The particle mass (--mass) is larger than the total mass of the model.")


# Build the distribution function from the analytical profile, then sample
# particle positions and velocities from it.
DF = DistributionFunction(
  Rmin=opt.Rmin,
  Rmax=opt.Rmax,
  Rinf=np.inf,
  NR=opt.NR,
  NE=opt.NE,
  fctDensity=fctDensity,
  fctPotential=fctPotential,
  fctCumulativeMass=fctCumulativeMass,
  TotalMass=TotalMass,
  G=G,
)

DF.computeDF()
DF.clean()
DF.computeMaxLikelihood()

# draw opt.N particles, opt.Ndraw per loop, with random seed opt.irand
DF.sample(opt.N, opt.Ndraw, irand=opt.irand)

if opt.outputfilename:

  nb = Nbody(status='new', p_name=opt.outputfilename,
             pos=DF.pos, vel=DF.vel, mass=DF.mass, ftype=opt.ftype)

  # attach the cgs conversion factors of the internal unit system; the
  # u_dict keys are exactly the snapshot attribute names
  for unit_name, unit_value in u_dict.items():
    setattr(nb, unit_name, unit_value)

  # format-specific header settings
  if opt.ftype == "gh5":
    nb.massarr = None
    nb.nzero   = None
  elif opt.ftype == "swift":
    nb.boxsize = 4*opt.Rmax   # box side set to four times Rmax

  nb.setComovingIntegrationOff()

  # particle type requested with --ptype
  nb.set_tpe(opt.ptype)

  # NOTE(review): presumably recentres the particles (pNbody cvcenter) — confirm
  nb.cvcenter()

  nb.write()

if opt.plot:
  import matplotlib.pyplot as plt
  # Dynamical time, taken here as the circular orbital period
  # T(R) = 2*pi*R / Vc(R), evaluated at the radii DF.R (internal units).
  vcirc = fctVcirc(DF.R)
  T = 2*np.pi*DF.R/vcirc
  plt.plot(DF.R, T)
  plt.loglog()
  plt.xlabel("R [kpc]")
  plt.ylabel(r"$T = 2\pi R/V_c$ [kpc/(km/s)]")
  plt.show()
