#!/usr/bin/env python3
###########################################################################################
#  package:   pNbody
#  file:      gtree
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Yves Revaz <yves.revaz@epfl.ch>
#
# This file is part of pNbody.
###########################################################################################


from pNbody import *
from pNbody import libgrid
from pNbody import myNumeric

import time

from optparse import OptionParser


########################################
#
# parser
#
########################################


def parse_options():
    """Parse the command line and return the option values.

    Arguments are read from ``sys.argv`` by ``OptionParser.parse_args``.
    Positional arguments are accepted but ignored: the input file is only
    taken from ``-i``.

    Returns:
        The ``optparse`` values object, exposing the attributes ``ftype``,
        ``input_file``, ``eps``, ``dTime``, ``TimeEnd``, ``dOutputTime``,
        ``dStatTime`` and ``theta``.
    """

    # The input file is given through -i; positional arguments are not used.
    usage = "usage: %prog [options] -i FILE"
    parser = OptionParser(usage=usage)

    parser.add_option("-t",
                      action="store",
                      dest="ftype",
                      type="string",
                      default='gadget',
                      help="type of the file",
                      metavar=" TYPE")

    parser.add_option("-i",
                      action="store",
                      dest="input_file",
                      type="string",
                      default=None,
                      help="input file",
                      metavar=" FILE")

    parser.add_option("--eps",
                      action="store",
                      dest="eps",
                      type="float",
                      default=0.25,
                      help="softening length",
                      metavar=" FLOAT")

    parser.add_option("--dTime",
                      action="store",
                      dest="dTime",
                      type="float",
                      default=0.1,
                      help="time step",
                      metavar=" FLOAT")

    parser.add_option("--TimeEnd",
                      action="store",
                      dest="TimeEnd",
                      type="float",
                      default=10,
                      help="final time",
                      metavar=" FLOAT")

    parser.add_option("--dOutputTime",
                      action="store",
                      dest="dOutputTime",
                      type="float",
                      default=0,
                      help="time between output",
                      metavar=" FLOAT")

    parser.add_option("--dStatTime",
                      action="store",
                      dest="dStatTime",
                      type="float",
                      default=0,
                      help="time between system statistics",
                      metavar=" FLOAT")

    # This value is handed to the tree as ErrTolTheta; the help text used to
    # be a copy of the --dOutputTime description.
    parser.add_option("--theta",
                      action="store",
                      dest="theta",
                      type="float",
                      default=0.7,
                      help="tree opening angle (ErrTolTheta)",
                      metavar=" FLOAT")

    options, _ = parser.parse_args()

    return options


##############################################
# write integrals
##############################################


def WriteIntegrals(fe, t, T, U, C, P, L, I):
    """Write one row of integrals to ``fe`` and flush it.

    The row holds sixteen columns, each formatted with ``%20.10e`` and
    separated by a single space: ``t``, ``T + U``, ``T``, ``U``, then the
    first three components of ``C``, ``P``, ``L`` and ``I`` in that order.

    Args:
        fe: writable text stream providing ``write`` and ``flush``.
        t: time value written in the first column.
        T: first energy term (the caller passes the kinetic energy).
        U: second energy term (the caller passes the potential energy).
        C, P, L, I: indexable objects with at least three components.
    """

    total_energy = T + U

    # Flatten the four vectors component by component: Cx Cy Cz Px ... Iz.
    components = [vec[axis] for vec in (C, P, L, I) for axis in range(3)]

    row = (t, total_energy, T, U, *components)
    fmt = " ".join(["%20.10e"] * len(row)) + "\n"

    fe.write(fmt % row)
    # Flush after every row so the file is current while the run goes on.
    fe.flush()


##############################################
# compute energy kin
##############################################

def ComputeEnergyKin(nb):
    """Return the sum over all entries of ``0.5 * mass * (vx^2 + vy^2 + vz^2)``.

    Reads ``nb.mass`` and the first three columns of ``nb.vel``.
    """
    vx = nb.vel[:, 0]
    vy = nb.vel[:, 1]
    vz = nb.vel[:, 2]
    speed_sq = vx**2 + vy**2 + vz**2
    return sum(0.5 * nb.mass * speed_sq)


##############################################
# compute energy pot
##############################################

def ComputeEnergyPot(nb, eps):
    """Return the sum of the values produced by ``nb.Epot(eps)``.

    ``eps`` is forwarded unchanged to ``nb.Epot`` (the caller passes the
    softening length).
    """
    per_entry = nb.Epot(eps)
    return sum(per_entry)


##############################################
# compute mass center
##############################################

def ComputeMassCenter(nb):
    """Return the mass-weighted mean position as a three-component array.

    Each of the first three columns of ``nb.pos`` is weighted by ``nb.mass``,
    summed, and divided by the total mass.
    """

    total_mass = sum(nb.mass)
    return array([sum(nb.pos[:, axis] * nb.mass) / total_mass
                  for axis in range(3)])


##############################################
# compute momentum
##############################################

def ComputeMomentum(nb):
    """Return the mass-weighted sum of velocities as a three-component array.

    Each of the first three columns of ``nb.vel`` is multiplied by
    ``nb.mass`` and summed.
    """

    return array([sum(nb.vel[:, axis] * nb.mass) for axis in range(3)])

##############################################
# compute angular momentum
##############################################


def ComputeAngularMomentum(nb):
    """Return the result of ``nb.Ltot()`` unchanged.

    NOTE(review): presumably the total angular momentum vector of the
    system -- confirm against the pNbody ``Ltot`` documentation.
    """
    angular_momentum = nb.Ltot()
    return angular_momentum


##############################################
# inertial momentum
##############################################

def ComputeInertialMomentum(nb):
    """Return the result of ``nb.minert()`` unchanged.

    NOTE(review): presumably the moments of inertia of the system --
    confirm against the pNbody ``minert`` documentation.
    """
    inertia = nb.minert()
    return inertia


##########################################################################
#
#                                    MAIN
#
##########################################################################


# Run parameters, all taken from the command line.
options = parse_options()

ftype = options.ftype
input_file = options.input_file
eps = options.eps
dOutputTime = options.dOutputTime
dStatTime = options.dStatTime
TimeEnd = options.TimeEnd
dTime = options.dTime
ErrTolTheta = options.theta


# open file
print("open initial conditions")
nb = Nbody(input_file, ftype=ftype)

print("build the tree")
nb.getTree(ErrTolTheta=ErrTolTheta, force_computation=True)

# initial acceleration, needed by the first half-step of the integrator
print("compute acceleration")
nb.acc = nb.TreeAccel(nb.pos, eps)


# some init

# the integration starts from the time stored in the loaded model
Time = nb.atime

Step = 0
CPUTimeRef = time.time()
CPUTime = 0.0

# next time at which a snapshot is written, and its running index
OutputTime = 0
OutputNumber = 0

# next time at which the integrals are written
StatTime = 0

# Time only advances by dTime per iteration: with a non-positive step the
# loop below would never reach TimeEnd (and, with the default dOutputTime,
# would write a snapshot on every pass).
if Time < TimeEnd and dTime <= 0:
    raise ValueError("dTime must be > 0 (got %g)" % dTime)


#####################
# main loop
#####################

# The context manager closes integrals.dat even if a step raises.
with open("integrals.dat", 'w') as fi:

    fi.write("# t E T U Cx Cy Cz Px Py Pz Lx Ly Lz Ix Iy Iz\n")

    while (Time < TimeEnd):

        # write output
        if (Time >= OutputTime):

            outputname = 'snap_%04d' % (OutputNumber)
            print("Step %06d  writing %s" % (Step, outputname))
            nb.rename(outputname)
            nb.atime = Time
            nb.write()

            OutputNumber += 1
            # with dOutputTime == 0 (the default) this fires on every step
            OutputTime = Time + dOutputTime

        print("Step %06d  Time = %8.3f CPUTime=%8.1f" % (Step, Time, CPUTime))

        # leap-frog, first stage
        nb.vel = nb.vel + nb.acc * dTime / 2.  # vel to step n+1/2
        nb.pos = nb.pos + nb.vel * dTime  # pos to step n+1

        # make the tree (positions changed, so it is rebuilt every step)
        nb.getTree(ErrTolTheta=ErrTolTheta, force_computation=True)

        # compute acceleration at the new positions
        nb.acc = nb.TreeAccel(nb.pos, eps)

        # leap-frog, second stage
        nb.vel = nb.vel + nb.acc * dTime / 2.  # vel to step n+1

        # increment time
        Time = Time + dTime
        Step = Step + 1
        CPUTime = time.time() - CPUTimeRef

        # write stats
        if (Time >= StatTime):

            print("Compute System Statistic")

            # compute integrals
            # for the potential computation, we do not need to recompute
            # the tree

            T = ComputeEnergyKin(nb)
            U = ComputeEnergyPot(nb, eps)
            C = ComputeMassCenter(nb)
            P = ComputeMomentum(nb)
            L = ComputeAngularMomentum(nb)
            I = ComputeInertialMomentum(nb)

            WriteIntegrals(fi, Time, T, U, C, P, L, I)

            # with dStatTime == 0 (the default) this fires on every step
            StatTime = Time + dStatTime
