#!/usr/bin/env python3

import argparse
import numpy as np
from pNbody import *
from pNbody.mass_models import hernquist
from pNbody.DF import DistributionFunction
from astropy import constants as cte
from astropy import units as u


####################################################################
# option parser
####################################################################

description="generate an hernquist model at equilibrium"
epilog     ="""
Examples:
--------
ic_hernquist -N 730000 --Rmin 1e-3 --Rmax 100 --rho0 0.1 -a 0.1  -t swift -o hernquist.hdf5
"""


def _count(text):
    """Parse a non-negative integer count given on the command line.

    Accepts plain integers (``100000``) as well as integral values written in
    float notation (``1e5``), which is how particle/bin counts are commonly
    typed. Always returns an ``int``.

    Raises:
      argparse.ArgumentTypeError: if ``text`` is not a whole, non-negative number.
    """
    try:
        value = int(text)
    except ValueError:
        try:
            as_float = float(text)
        except ValueError:
            raise argparse.ArgumentTypeError("invalid count: %r" % text) from None
        # reject 1.5, inf, nan, ... : a count must be a whole number
        if not as_float.is_integer():
            raise argparse.ArgumentTypeError("count must be a whole number: %r" % text)
        value = int(as_float)
    if value < 0:
        raise argparse.ArgumentTypeError("count must be non-negative: %r" % text)
    return value


parser = argparse.ArgumentParser(description=description,epilog=epilog,formatter_class=argparse.RawDescriptionHelpFormatter)


parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")

parser.add_argument("-t",
                    action="store",
                    type=str,
                    dest="ftype",
                    default='gh5',
                    help="Type of the output file")

parser.add_argument("--irand",
                    action="store",
                    dest="irand",
                    metavar='INT',
                    type=int,
                    default=0,
                    help='random seed')

parser.add_argument("--ptype",
                    action="store",
                    dest="ptype",
                    metavar='INT',
                    type=int,
                    default=1,
                    help='particle type')

parser.add_argument("--plot",
                    action="store_true",
                    dest="plot",
                    default=False,
                    help="plot dynamical time")

# argparse does not convert non-string defaults, so the default must already
# be an int (a literal 1e5 would leave opt.N as a float)
parser.add_argument("-N",
                    action="store",
                    dest="N",
                    metavar='INT',
                    type=_count,
                    default=100000,
                    help='total number of particles')

parser.add_argument("-a",
                    action="store",
                    type=float,
                    dest="a",
                    default=1,
                    help="a parameter")

# NOTE(review): presumably the Hernquist scale density in internal units
# (1e10 Msun / kpc^3); the "critical radius" wording in the help text should be
# confirmed against pNbody.mass_models.hernquist.
parser.add_argument("--rho0",
                    action="store",
                    type=float,
                    dest="rho0",
                    default=0.1,
                    help="density at critical radius")

# NOTE(review): opt.b is not read anywhere else in this file; kept only so
# existing command lines that pass -b keep working.
parser.add_argument("-b",
                    action="store",
                    type=float,
                    dest="b",
                    default=0.3,
                    help="b parameter")

parser.add_argument("--Rmin",
                    action="store",
                    type=float,
                    dest="Rmin",
                    default=1e-2,
                    help="Rmin")

parser.add_argument("--Rmax",
                    action="store",
                    type=float,
                    dest="Rmax",
                    default=100,
                    help="Rmax")

# bin counts are integers; _count still accepts the "1e4" spelling
parser.add_argument("--NE",
                    action="store",
                    type=_count,
                    dest="NE",
                    default=10000,
                    help="number of energy bins")

parser.add_argument("--NR",
                    action="store",
                    type=_count,
                    dest="NR",
                    default=10000,
                    help="number of radius bins")

parser.add_argument("--Ndraw",
                    action="store",
                    type=_count,
                    dest="Ndraw",
                    default=1000000,
                    help="number of particles to draw at each loop")

parser.add_argument("--mass",
                    action="store",
                    type=float,
                    dest="mass",
                    default=None,
                    help="particle mass in solar mass")

parser.add_argument("-q","--quiet",
                    action="store_true",
                    dest="quiet",
                    default=False,
                    help="quiet mode")
                    

########################################
# main
########################################

opt = parser.parse_args()

# -q/--quiet: silence every Python warning for the remainder of the run
if opt.quiet:
    import warnings
    warnings.simplefilter("ignore")
  
# Internal unit system: length in kpc, mass in 1e10 Msun, velocity in km/s;
# the time unit is derived from length / velocity.
u_Length   = 1 * u.kpc
u_Mass     = 10**10 * u.M_sun
u_Velocity = 1 * u.km / u.s
u_Time     = u_Length / u_Velocity

# The same unit system expressed in cgs; the keys match the Nbody attribute
# names they are later copied to.
u_dict = {
    "UnitLength_in_cm":         u_Length.to(u.cm).value,
    "UnitMass_in_g":            u_Mass.to(u.g).value,
    "UnitVelocity_in_cm_per_s": u_Velocity.to(u.cm / u.s).value,
    "Unit_time_in_cgs":         u_Time.to(u.s).value,
}

# gravitational constant expressed in the internal units
G = cte.G.to(u_Length**3 / (u_Mass * u_Time**2)).value

# Analytical Hernquist model profiles, parameterised by the command-line
# rho0 and a and evaluated with G in internal units.

def fctDensity(x):
    """Return the model density at radius ``x``."""
    return hernquist.Density(rho0=opt.rho0, a=opt.a, r=x, G=G)


def fctPotential(x):
    """Return the model gravitational potential at radius ``x``."""
    return hernquist.Potential(rho0=opt.rho0, a=opt.a, r=x, G=G)


def fctCumulativeMass(x):
    """Return the model mass enclosed within radius ``x``."""
    return hernquist.CumulativeMass(rho0=opt.rho0, a=opt.a, r=x, G=G)


# Mass enclosed within Rmax; used as the total mass of the sampled model.
# NOTE(review): an earlier comment claimed this is "unbounded if rmax=inf".
# The enclosed mass of a Hernquist profile is presumably finite, but evaluating
# at r=inf may still give nan numerically; confirm against
# pNbody.mass_models.hernquist before allowing an infinite Rmax.
TotalMass = hernquist.CumulativeMass(rho0=opt.rho0, a=opt.a, r=opt.Rmax, G=G)


def fctVcirc(x):
    """Return the model circular velocity at radius ``x``."""
    return hernquist.Vcirc(rho0=opt.rho0, a=opt.a, r=x, G=G)

# When a particle mass (in solar masses) is requested, derive the particle
# count from the model's total mass instead of using -N.
if opt.mass is not None:
  if opt.mass <= 0:
    raise SystemExit("--mass must be a positive particle mass in solar masses")
  # TotalMass is computed with G in internal units, so it is presumably in
  # units of u_Mass; convert to solar masses before dividing. (The previous
  # code used an undefined name `toMsol`, raising NameError whenever --mass
  # was given.)
  toMsol = u_Mass.to(u.M_sun).value
  opt.N = int(TotalMass*toMsol/opt.mass)


# Tabulate the model's distribution function between Rmin and Rmax on NR radius
# bins and NE energy bins, then draw opt.N particles from it (opt.Ndraw per loop).
DF = DistributionFunction(
    Rmin=opt.Rmin,
    Rmax=opt.Rmax,
    Rinf=np.inf,
    NR=opt.NR,
    NE=opt.NE,
    fctDensity=fctDensity,
    fctPotential=fctPotential,
    fctCumulativeMass=fctCumulativeMass,
    TotalMass=TotalMass,
    G=G,
)

# preparation steps, run one after the other in this order
for step_name in ("computeDF", "clean", "computeMaxLikelihood"):
    getattr(DF, step_name)()

DF.sample(opt.N, opt.Ndraw, irand=opt.irand)

if opt.outputfilename:

  # wrap the sampled phase-space coordinates in a new Nbody object
  nb = Nbody(
      status='new',
      p_name=opt.outputfilename,
      pos=DF.pos,
      vel=DF.vel,
      mass=DF.mass,
      ftype=opt.ftype,
  )

  # attach the internal unit system: each u_dict key is the Nbody attribute name
  for unit_attr, unit_value in u_dict.items():
    setattr(nb, unit_attr, unit_value)

  # format-specific header adjustments
  if opt.ftype == "gh5":
    nb.massarr = None
    nb.nzero = None
  elif opt.ftype == "swift":
    # NOTE(review): box of side 4*Rmax, presumably so the Rmax sphere fits once centred
    nb.boxsize = 4 * opt.Rmax

  nb.setComovingIntegrationOff()

  # tag every particle with the requested particle type
  nb.set_tpe(opt.ptype)

  nb.cvcenter()

  nb.write()


if opt.plot:
  import matplotlib.pyplot as plt
  # circular orbital period 2*pi*R/Vc evaluated on DF.R (presumably the DF's
  # radius grid), shown on log-log axes as the "dynamical time"
  radii = DF.R
  vc = fctVcirc(radii)
  period = 2 * np.pi * radii / vc
  plt.plot(radii, period)
  plt.loglog()
  plt.show()
