#!/usr/bin/env python3
###########################################################################################
#  package:   pNbody
#  file:      pNbody_checkall
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Yves Revaz <yves.revaz@epfl.ch>
#
# This file is part of pNbody.
###########################################################################################

# The wildcard import stays first so that the explicit imports below take
# precedence over any same-named objects it may re-export.
from pNbody import *
from pNbody import ic

import filecmp
import os
import sys
import tempfile
from optparse import OptionParser

import numpy as np

# User home directory.  Use $HOME when it is set (same value as before), but
# fall back to os.path.expanduser("~") instead of crashing with a KeyError
# when the variable is missing.
HOME = os.environ.get('HOME', os.path.expanduser('~'))


def random_string(n=10):
    """Return a string of ``n`` random lowercase ASCII letters.

    Letters are drawn uniformly, with replacement, from ``a``-``z`` using
    NumPy's global random generator, so seeding ``np.random`` makes the
    result reproducible.  ``n == 0`` yields the empty string.
    """
    import string
    letters = np.array(list(string.ascii_lowercase))
    idxs = np.random.randint(0, len(letters), n)
    # Iterating a 1-D str array yields str subclasses, so join works directly.
    return "".join(letters[idxs])


# Default output file: a random 10-letter name in the system temporary
# directory (tempfile honours TMPDIR/TEMP/TMP; on a default Linux setup this
# is still /tmp).
random_filename = os.path.join(tempfile.gettempdir(), random_string())



########################################
#
# parser
#
########################################


def parse_options():
    """Parse the command line.

    Options
    -------
    -t TYPE : file format to test (default ``'gadget'``).
    -f FILE : output file name (default ``random_filename``).
    -n INT  : number of particles (default ``2**14``).

    Returns
    -------
    files : list of str
        Positional command-line arguments.
    options : optparse.Values
        Parsed options, with attributes ``ftype``, ``file`` and ``n``.
    """
    usage = "usage: %prog [options] file"
    parser = OptionParser(usage=usage)

    parser.add_option("-t",
                      action="store",
                      dest="ftype",
                      type="string",
                      default='gadget',
                      help="type of the file",
                      metavar=" TYPE")

    # random_filename is already an absolute path; joining it onto HOME (as
    # this default used to do) silently discarded HOME anyway.
    parser.add_option("-f",
                      action="store",
                      dest="file",
                      type="string",
                      default=random_filename,
                      help="output file name",
                      metavar=" FILE")

    parser.add_option("-n",
                      action="store",
                      dest="n",
                      type="int",
                      default=2**14,
                      help="number of particles",
                      metavar=" INT")

    options, args = parser.parse_args()

    return args, options


#################################
#
# main
#
#################################


files, opt = parse_options()

# Shortcuts used throughout the rest of the script.
ftype = opt.ftype
file = opt.file

# Banner announcing which file format is under test.
print("\n".join(["#" * 72, "Testing %s format" % ftype, "#" * 72]))


# Create an exponential disk and save it to `file`.
print("create an exponential disk")
nb = ic.expd(
    n=opt.n,
    Hr=3.,
    Hz=0.3,
    Rmax=20,
    Zmax=2,
    irand=1,
    name=file,
    ftype=ftype)
nb.write()

# Read it back.
nb = Nbody(file, ftype=ftype)

# Save the freshly read model under a second name, then restore the
# original name so the later sections keep referring to `file`.
file2 = file + '.2'
nb.rename(file2)
nb.write()
nb.rename(file)

# Round-trip check: reading then writing must reproduce the file byte for
# byte.  filecmp compares contents in-process: there is no shell, so a path
# given with -f may contain spaces or shell metacharacters.  It also raises
# if either file is missing, whereas the old `diff` pipe reported "diff ok"
# whenever diff itself failed (its error went to stderr).
print(72 * "#")
print("comparing %s and %s" % (file, file2))
if not filecmp.cmp(file, file2, shallow=False):
    print()
    print("Bad news : %s and %s.2 differs" % (file, file))
    # Non-zero exit status so callers (shell, CI) see the failure.
    sys.exit(1)
else:
    print("diff ok")

print(72 * "#")


# Load the default parameter and unit-parameter files.  `param`,
# PARAMETERFILE and UNITSPARAMETERFILE are not defined in this file;
# presumably they come from `from pNbody import *` -- TODO confirm.
params = param.Params(PARAMETERFILE, None)
uparams = param.Params(UNITSPARAMETERFILE, None)


#################################
#
# init functions
#
#################################

# Smoke tests: return values are discarded.  The script has no exception
# handling, so any failure aborts it with a traceback.
print("testing init functions...")
nb.init()
# NOTE(review): forces the in-memory object to the 'gadget' format
# regardless of the -t option -- presumably intentional; confirm.
nb.set_ftype(ftype='gadget')
nb.get_num()
nb.get_default_spec_vars()
nb.get_default_spec_array()
nb.set_pio('no')
# Changes the object's file name; no write() follows in this section.
nb.rename('test.dat')
nb.set_filenames('test.dat')
nb.get_ntype()
nb.get_nbody()
nb.get_nbody_tot()
nb.get_npart()
nb.get_npart_tot()
# `mpi` presumably also comes from the pNbody wildcard import.
nb.get_npart_all(nb.get_npart(), mpi.NTask)
nb.get_npart_and_npart_all(nb.get_npart())
nb.get_mxntpe()
nb.make_default_vars_global()
nb.set_npart(nb.npart)
nb.set_tpe(0)

#################################
#
# parameters functions
#
#################################

# Attach the parameter sets loaded above, then derive the local unit system.
print("testing parameters functions...")
nb.set_parameters(params)
nb.set_unitsparameters(uparams)
nb.set_local_system_of_units()

#################################
#
# info functions
#
#################################

# The info methods print diagnostics; only that they run is checked.
print("testing info functions...")
nb.info()
nb.spec_info()
nb.object_info()
nb.nodes_info()
nb.memory_info()
nb.print_filenames()

#################################
#
# list of variables functions
#
#################################

# Introspection helpers; results are discarded.
print("testing list of variables functions...")
nb.get_list_of_arrays()
nb.get_list_of_methods()
nb.get_list_of_vars()
nb.has_var('pos')
nb.has_array('pos')
nb.find_vars()


#################################
#
# check special values
#
#################################
# No progress message is printed for this check.
# NOTE(review): presumably scans the particle arrays for special values
# (e.g. NaN/inf) -- confirm against pNbody.
nb.check_arrays()

#################################
#
# read/write functions
#
#################################

# Disabled: the whole read/write section below is commented out and does
# not run.  NOTE(review): the reason is not recorded here.
#print("testing read/write functions...")
# nb.read()
# nb.open_and_read(nb.p_name[0],nb.get_read_fcts()[0])
# nb.rename('treo0020.000b')
# nb.write()
# nb.open_and_write(nb.p_name[0],nb.get_write_fcts()[0])
# nb.write_num('num.dat')
# nb.read_num('num.dat')


#################################
#
# coordinate transformation
#
#################################

# Coordinate and velocity accessors/transforms.  Return values are
# discarded; this only checks that each call runs without raising.
print("testing coordinate transformation functions...")
nb.x()
nb.y()
nb.z()
nb.rxyz()
nb.phi_xyz()
nb.theta_xyz()
nb.rxy()
nb.phi_xy()
nb.r()
nb.R()
nb.cart2sph()
nb.sph2cart()
nb.vx()
nb.vy()
nb.vz()
nb.vn()
nb.vrxyz()
nb.Vr()
nb.Vt()
nb.Vz()
nb.vel_cyl2cart()
nb.vel_cart2cyl()

#################################
#
# physical values
#
#################################

# Smoke tests of the physical-quantity methods; results are discarded.
print("testing physical values functions...")
nb.get_ns()
nb.get_mass_tot()
nb.size()
nb.cm()
nb.get_histocenter()
nb.get_histocenter2()
nb.cv()
nb.minert()
nb.x_sigma()
nb.v_sigma()
nb.dx_mean()
nb.dv_mean()
nb.Ekin()
nb.ekin()
# NOTE(review): 0.1 is presumably the softening length, matching the
# explicit eps=0.1 keyword used by TreePot/TreeAccel below -- confirm.
nb.Epot(0.1)
nb.epot(0.1)
nb.L()
nb.l()
nb.Ltot()
nb.ltot()
# Potential/acceleration evaluated at the single point (0, 0, 0).
nb.Pot([0, 0, 0], 0.1)
nb.TreePot(np.array([[0, 0, 0]], np.float32), eps=0.1)
nb.Accel([0, 0, 0], 0.1)
nb.TreeAccel(np.array([[0, 0, 0]], np.float32), eps=0.1)
nb.tork(nb.vel)
# The calls below were flagged by the author as weak checks
# (original note: "bof", French for "meh").
nb.dens()  # weak check
nb.mdens()  # weak check
nb.mr()  # weak check
nb.Mr_Spherical()  # weak check
nb.sdens()  # weak check
nb.msdens()  # weak check
nb.sigma_z()  # weak check
nb.sigma_vz()  # weak check
nb.zprof()
nb.sigma()
# Disabled checks:
#nb.histovel()
# nb.zmodes()
# nb.dmodes()

# Cylindrical-grid helpers (positional arguments as in the original calls).
nb.getRadiusInCylindricalGrid(0, 10)
nb.getAccelerationInCylindricalGrid(0.1, 0, 10)
nb.getPotentialInCylindricalGrid(0.1, 0, 10)
nb.getSurfaceDensityInCylindricalGrid(10)
nb.getNumberParticlesInCylindricalGrid(10)
nb.getRadialVelocityDispersionInCylindricalGrid(10)


#################################
#
# geometrical operations
#
#################################

# These calls act on `nb` in order; statement order matters because later
# sections keep working on the resulting (moved/rotated) model.
print("testing geometrical operations functions...")
nb.cmcenter()
nb.cvcenter()
nb.histocenter()
nb.histocenter2()
nb.hdcenter()
nb.translate([10, 0, 0])
nb.rebox()
nb.rotate(axis='y', angle=np.pi / 2)
nb.rotate(axis=[1, 1, 1], angle=np.pi / 2)
nb.align(axis=[1, 1, 1])
nb.align2()
nb.spin()


#################################
#
# selection of particles
#
#################################

print("testing selection of particles functions...")
# The result of selectc is discarded; `nb` itself is not rebound here.
nb.selectc(nb.num < 100)
# Relabel particle types: num > 10 -> type 1, otherwise type 0; then the
# particles with num > nbody - 10 -> type 2.
nb.tpe = np.where(nb.num > 10, 1, 0)
nb.tpe = np.where(nb.num > nb.nbody - 10, 2, nb.tpe)
# `nb` is rebound to successively derived objects.
nb = nb.select(1)
nb = nb.sub(2, 12)
nb = nb.reduc(2)
nb = nb.selectp([17, 19])
nb.getindex(19)

#################################
#
# add particles
#
#################################

# Two fresh copies are read from `file`, independent of the selections above.
print("testing add particles functions...")
nb1 = Nbody(file, ftype=ftype)
nb2 = Nbody(file, ftype=ftype)
# NOTE(review): append presumably extends nb1 in place -- confirm.
nb1.append(nb2)
nb = nb1 + nb2

#################################
#
# sort particles
#
#################################

print("testing sort particles functions...")
nb = nb.sort()
nb = nb.sort_type()


#################################
#
# Tree and SPH functions
#
#################################

print("testing Tree and SPH functions...")
# Start again from a fresh copy of `file`.
nb = Nbody(file, ftype=ftype)
nb.InitSphParameters()
# NOTE(review): meaning of the tree parameters 33 and 3 is not visible
# here -- see pNbody's setTreeParameters.
nb.setTreeParameters(nb.Tree, 33, 3)
nb.getTree()
nb.get_rsp_approximation()
nb.ComputeSph()
nb.ComputeDensityAndHsml()
# Interpolate the first velocity column (presumably vx) with the SPH kernel.
nb.SphEvaluate(nb.vel[:, 0])

#################################
#
# sph functions
#
#################################

# Disabled: the neighbour-count checks below are not run.  They were
# previously parked inside a bare string literal, which also still used
# Python 2 print syntax.
# print("testing  sph functions...")
# nb.weighted_numngb(1)
# nb.real_numngb(1)
# nb.usual_numngb(1)

#################################
#
# redistribution of particles
#
#################################

print("testing redistribution of particles functions...")
nb.redistribute()
# nb.ExchangeParticles()  # not tested: needs ptree


#################################
#
# specific parallel functions
#
#################################

# Gather helpers; results are discarded.
print("testing specific parallel functions...")
nb.gather_pos()
nb.gather_vel()
nb.gather_mass()
nb.gather_num()
nb.gather_vec(nb.pos)

#################################
#
# graphical operations
#
#################################

print("testing graphical operations functions...")

print()

# The reminder is printed before display() below writes the image;
# tmp.png is a relative path, i.e. it lands in the current directory.
print(72 * "#")
print("Please, check that the file tmp.png has been successfully written.")
print(72 * "#")

nb.display(size=[30, 30],save="tmp.png")
# nb.show(size=[30,30])  # disabled: crashes if the previous image has been closed

nb.Map()
nb.CombiMap()
nb.ComputeMeanMap(mode1='0')
nb.ComputeSigmaMap(mode1='0', mode2='0')
nb.ComputeMap(mode='0')
# nb.expose()  # already exercised through Map()


#################################
#
# 1d histograms
#
#################################

# All histogram calls use the same integer bin edges 0..5.
print("testing 1d histograms functions...")
nb.Histo(bins=np.array([0, 1, 2, 3, 4, 5]))
nb.CombiHisto(bins=np.array([0, 1, 2, 3, 4, 5]))
nb.ComputeMeanHisto(np.array([0, 1, 2, 3, 4, 5]), mode1='m', space='R')
nb.ComputeSigmaHisto(np.array([0, 1, 2, 3, 4, 5]),
                     mode1='m', mode2='m', space='R')
nb.ComputeHisto(np.array([0, 1, 2, 3, 4, 5]), mode='m', space='R')


############################################
#
# Routines to get velocities from positions
#
############################################

print("testing routines to get velocities from positions...")
nb.Get_Velocities_From_Spherical_Grid()
# Disabled: cylindrical-grid variant is not exercised.
#nb.Get_Velocities_From_Cylindrical_Grid()


############################################
#
# evolution routines
#
############################################

# Disabled: no evolution routine is exercised.
# nb.IntegrateUsingRK()


#################################
#
# Thermodynamic functions
#
#################################
# Not tested: the thermodynamic accessors below are listed but never called.
# (Previously parked inside a bare string literal.)
# nb.U()
# nb.Rho()
# nb.T()
# nb.MeanWeight()
# nb.Tmu()
# nb.A()
# nb.P()
# nb.Tcool()
# nb.Ne()
# nb.S()
# nb.Lum()


#################################
#
# the end
#
#################################

# The script has no exception handling, so reaching this point means every
# call above completed without raising for this format.
print("\n".join([
    "#" * 72,
    "Good News ! pNbody with format %s is working !" % ftype,
    "Use pNbody_test to check other formats.",
    "#" * 72,
]))


#################################
#
# some info
#
#################################
# A blank line, then the search paths reported by pNbody.
print("\nYou are using the following paths")
parameters.print_path()
