#!/usr/bin/env python3
###########################################################################################
#  package:   pNbody
#  file:      orbits_launch_MW
#  brief:     Launches dwarf galaxies on orbits around the Milky Way (MW),
#             modeled by the MWPotential2014 potential.
#  copyright: GPLv3
#             Copyright (C) 2023 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Darwin Roduit <darwin.roduit@epfl.ch>
#
# This file is part of pNbody.
###########################################################################################

from pNbody.mass_models import MWPotential2014 as MW
from pNbody import Nbody
from pNbody.orbits import coordinates_utils
from pNbody.misc.scripts_utils import RawTextArgumentDefaultsHelpFormatter, store_as_array
import argparse

import numpy as np
from astropy import units as u
from astropy.constants import G as G_a
import astropy.coordinates as coord
# Select the "v4.0" defaults for astropy's Galactocentric frame; this is the
# frame targeted by the ICRS -> Galactocentric conversion in launch_on_orbits.
# The return value of set() is deliberately discarded.
_ = coord.galactocentric_frame_defaults.set("v4.0")

# %% Functions

# %%Orbit launcher


def parse_options(argv=None):
    """Parse the command-line options of the orbit launcher.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) lets argparse read
        ``sys.argv[1:]``, i.e. the behaviour of the command-line script.

    Returns
    -------
    argparse.Namespace
        The parsed options. The 3-component options (``pos_orbit``,
        ``vel_orbit``, ``velocity_direction``) go through the project's
        ``store_as_array`` action -- presumably stored as numpy arrays,
        verify against that helper.
    """
    parser = argparse.ArgumentParser(
        formatter_class=RawTextArgumentDefaultsHelpFormatter,
        description="Launches a dwarf galaxy into an orbit around the Milky Way, "
                    "modeled by the MWPotential2014 potential. Units : "
                    "kpc, 10^10 M_sun, km/s, 0.97 Gyr. The script returns the boxsize "
                    "and the orbital period.",
        # The example must match the parser below: one positional input file,
        # the output file through -o and the launch position through --pos_orbit.
        epilog="Example\n------------\n \n"
               "orbits_launch_MW nfw_plummer.hdf5 -o dwarf_launched.hdf5 "
               "--pos_orbit 0 0 90 --velocity_direction 1 0 0 \n \n")

    parser.add_argument('filename', help='File of the initial conditions')
    parser.add_argument('-o', dest='outputfilename',
                        help='Output file in which the output data are saved.')
    # Required: every code path of the launcher needs a launch position.
    parser.add_argument('--pos_orbit', dest="pos_orbit", action=store_as_array, type=float, nargs=3,
                        required=True,
                        help='Position from where the dwarf galaxy is launched')

    # The third ICRS position component is a distance modulus (it is converted
    # to a distance by the launcher), hence "mag" rather than a length unit.
    parser.add_argument('--position_coord', action="store", default="carthesian",
                        choices=["carthesian", "spherical",
                                 "cylindrical", "ICRS"],
                        help="Coordinate system of the option pos_orbit.\nNotice "
                        "that for ICRS coordinates to work the --velocity_coord "
                        "must also be set to ICRS. \nIn this case, --pos_orbit=(RA, DEC, distance_modulus) "
                        "and --vel_orbit=(pm_RA, pm_DEC, v_radial). \nUnits : (Degrees, Degrees, mag) and (mas/yr, mas/yr, km/s).")

    parser.add_argument('--boxsize', default=None, type=float,
                        help='Size of the box of the simulation. In kpc')

    parser.add_argument('--vel_orbit', action=store_as_array, default=None,
                        type=float, nargs=3, help='Velocity at which the dwarf galaxy is launched')
    parser.add_argument('--velocity_direction', action=store_as_array, type=float,
                        nargs=3, help="Direction to project the circular velocity. "
                        "Used only when --vel_orbit is not given.")
    parser.add_argument('--velocity_coord', default="carthesian",
                        choices=["carthesian", "spherical", "ICRS"],
                        help="Coordinate system of the options --vel_orbit and "
                        "--velocity_direction. \nNotice that for ICRS coordinates "
                        "to work the --position_coord must also be set to ICRS. \n"
                        "In this case, --pos_orbit=(RA, DEC, distance_modulus) "
                        "and --vel_orbit=(pm_RA, pm_DEC, v_radial). \nUnits : (Degrees, Degrees, mag) and (mas/yr, mas/yr, km/s).")

    parser.add_argument("--boxsize_multiplier", help="Value to adjust the boxsize of the simulation",
                        default=2.0, type=float)
    parser.add_argument("--T_simulation", help="Time of the simulation. When given, it "
                        "replaces the estimated orbital period returned by the script.",
                        default=None, type=float)

    return parser.parse_args(argv)


def _mw_potential_parameters():
    """Return the MWPotential2014 parameters in the positional order of ``MW.Vcirc``.

    The tuple is ``(rho_0, r_s, M_disk, a, b, alpha, r_c, amplitude, r_1,
    f_1, f_2, f_3)``, read from ``MW.MWPotential2014_parameters()``.
    """
    params = MW.MWPotential2014_parameters()
    names = ("rho_0", "r_s",                      # NFW parameters
             "M_disk", "a", "b",                  # MN (disk) parameters
             "alpha", "r_c", "amplitude", "r_1",  # PSC parameters
             "f_1", "f_2", "f_3")                 # potential contributions
    return tuple(params[name] for name in names)


def _validate_launch_options(args):
    """Reject missing or inconsistent launch options before any file is read.

    Raises
    ------
    ValueError
        If ``pos_orbit`` is missing, if ICRS is used for only one of
        ``position_coord`` / ``velocity_coord``, if ICRS is used without
        ``vel_orbit``, if ``velocity_coord`` is unknown, or if the direction
        needed to project the circular velocity is missing or zero.
    NotImplementedError
        If velocities or directions are given in spherical coordinates.
    """
    if args.pos_orbit is None:
        raise ValueError("The launch position --pos_orbit must be given.")

    # ICRS positions and velocities are converted together (the proper motions
    # need the sky position), so the two options are only valid as a pair.
    position_is_icrs = args.position_coord == "ICRS"
    velocity_is_icrs = args.velocity_coord == "ICRS"
    if position_is_icrs != velocity_is_icrs:
        raise ValueError("ICRS coordinates require both --position_coord and "
                         "--velocity_coord to be set to ICRS.")
    if position_is_icrs and args.vel_orbit is None:
        raise ValueError("ICRS coordinates require "
                         "--vel_orbit=(pm_RA, pm_DEC, v_radial).")

    if args.velocity_coord == "spherical":
        raise NotImplementedError(
            "Velocities and velocity directions given in spherical coordinates "
            "are not supported; use --velocity_coord carthesian.")
    if args.velocity_coord not in ("carthesian", "ICRS"):
        raise ValueError("The given coordinate system for the velocity \"{}\" is not valid.".format(
            args.velocity_coord))

    if args.vel_orbit is None:
        # The circular velocity will be projected onto velocity_direction.
        if args.velocity_direction is None:
            raise ValueError("--velocity_direction must be given when --vel_orbit is not.")
        if np.linalg.norm(args.velocity_direction) == 0.0:
            raise ValueError("--velocity_direction must be a non-zero vector.")


def _position_to_cartesian(args, unit_length, unit_velocity):
    """Convert ``args.pos_orbit`` to Cartesian coordinates, in place.

    For ICRS input, ``args.pos_orbit`` = (RA [deg], DEC [deg], distance
    modulus) and ``args.vel_orbit`` = (pm_RA [mas/yr], pm_DEC [mas/yr],
    v_radial [km/s]) are both replaced by Galactocentric Cartesian values
    expressed in ``unit_length`` / ``unit_velocity``.

    Raises
    ------
    ValueError
        If ``args.position_coord`` is not a supported coordinate system.
    """
    if args.position_coord == "spherical":
        args.pos_orbit = np.array(coordinates_utils.convert_sph2carth(*args.pos_orbit))
    elif args.position_coord == "cylindrical":
        args.pos_orbit = np.array(coordinates_utils.convert_cyl2carth(*args.pos_orbit))
    elif args.position_coord == "carthesian":
        pass  # already Cartesian
    elif args.position_coord == "ICRS":
        # Converts Gaia coordinates to Galactocentric Cartesian coordinates
        ra, dec, distance_modulus = args.pos_orbit
        pm_ra, pm_dec, v_radial = args.vel_orbit
        distance = (coordinates_utils.compute_distance(distance_modulus)*u.pc).to(u.kpc)
        sky = coord.SkyCoord(ra=ra*u.degree, dec=dec*u.degree,
                             distance=distance,
                             pm_ra_cosdec=pm_ra*u.mas/u.yr,
                             pm_dec=pm_dec*u.mas/u.yr,
                             radial_velocity=v_radial*u.km/u.s)
        galcen = sky.transform_to(coord.Galactocentric)
        # .value: return plain arrays in internal units, like the other branches,
        # instead of astropy Quantities.
        args.pos_orbit = galcen.data.get_xyz().to(unit_length).value
        args.vel_orbit = galcen.velocity.get_d_xyz().to(unit_velocity).value
    else:
        raise ValueError("The given coordinate system for the position \"{}\" is not valid.".format(
            args.position_coord))


def launch_on_orbits(args):
    """Place the dwarf galaxy of ``args.filename`` on an orbit in the MW potential.

    The particles are translated to the launch position and velocity, the box
    size is set, and the snapshot is written to ``args.outputfilename``.
    Positions and velocities are in the internal units of the input (SWIFT)
    file, except for ICRS inputs (see ``--position_coord``).

    ``args`` is updated in place: ``pos_orbit``/``vel_orbit`` end up as
    Cartesian internal-unit vectors, ``boxsize`` is filled in when it was
    ``None``, and ``velocity_direction`` is normalised when it is used.

    Parameters
    ----------
    args : argparse.Namespace
        Options as returned by ``parse_options``.

    Returns
    -------
    boxsize : float
        Size of the simulation box.
    T_orbit : float
        ``args.T_simulation`` when given. Otherwise, the period of a circular
        orbit at the launch radius, using the circular velocity or the norm of
        ``vel_orbit``. The value is 5.0 when ``vel_orbit`` is zero.

    Raises
    ------
    ValueError
        For missing or inconsistent coordinate options.
    NotImplementedError
        For velocities or directions given in spherical coordinates.
    """
    # Fail fast, before reading the snapshot.
    _validate_launch_options(args)

    (rho_0, r_s, M_disk, a, b, alpha, r_c, amplitude, r_1,
     f_1, f_2, f_3) = _mw_potential_parameters()

    # Sets the output precision for printed arrays
    np.set_printoptions(precision=3)

    nb = Nbody(args.filename, ftype="swift")
    nb.rename(p_name=args.outputfilename)  # sets the output filename

    # --Units of the snapshot, as astropy quantities
    unit_length = nb.UnitLength_in_cm*u.cm
    unit_mass = nb.UnitMass_in_g*u.g
    unit_velocity = nb.UnitVelocity_in_cm_per_s*u.cm/u.s
    unit_time = unit_length/unit_velocity

    # Value of G in those units - used to compute v_circ
    G = G_a.to(unit_length**3 * unit_mass**(-1) * unit_time**(-2)).value

    # --Launch position, converted to Cartesian coordinates
    _position_to_cartesian(args, unit_length, unit_velocity)

    # NOTE(review): the [0] indexing on unit_length / unit_time assumes the unit
    # attributes read by Nbody are 1-element arrays - confirm.
    print("Placing the galaxy in orbit at position x = {} * {}.".format(
        args.pos_orbit, unit_length[0].to(u.kpc)))
    nb.translate(args.pos_orbit, mode="p")

    # --Box size
    if args.boxsize is None:
        r_max = np.max(nb.rxyz())  # maximal radius, after the translation
        args.boxsize = 2*r_max*args.boxsize_multiplier
    nb.boxsize = float(args.boxsize)
    print("Size of the box : {}".format(nb.boxsize))

    # --Launch velocity and orbital period estimate
    r = np.linalg.norm(args.pos_orbit)
    if args.vel_orbit is None:
        # No velocity given: use the circular velocity at pos_orbit, projected
        # onto velocity_direction (Cartesian, checked by the validation above).
        R, _, z = coordinates_utils.convert_cart2cyl(*args.pos_orbit)
        v_circ = MW.Vcirc(rho_0, r_s, M_disk, a, b, alpha, r_c,
                          amplitude, r_1, f_1, f_2, f_3, R, z, G)
        T_orbit = 2*np.pi*r/v_circ  # orbital period on circular orbit
        print("The circular velocity at {} is v_circ = {:.2e} * {}.".format(
            args.pos_orbit, v_circ, unit_velocity.to(u.km/u.s)))
        print(
            "The orbital period is T_circ = {:.2e} * {}.".format(T_orbit, unit_time[0]))

        direction = np.asarray(args.velocity_direction, dtype=float)
        args.velocity_direction = direction/np.linalg.norm(direction)
        args.vel_orbit = args.velocity_direction*v_circ
    else:
        # Velocity given (Cartesian, or already converted from ICRS above):
        # estimate the period by assuming a circular orbit.
        v_norm = np.linalg.norm(args.vel_orbit)
        if v_norm == 0.0:
            # Arbitrary fallback period when the galaxy is launched at rest.
            T_orbit = 5.0
        else:
            T_orbit = 2*np.pi*r/v_norm

    if args.T_simulation is not None:
        T_orbit = args.T_simulation
    print(
        "The orbital period is T_orbit = {:.2e} * {}.".format(T_orbit, unit_time[0]))

    # Then translate all the particles velocities
    nb.translate(args.vel_orbit, mode="v")
    print("The galaxy is launched at velocity v = {} * {}.".format(args.vel_orbit,
          unit_velocity.to(u.km/u.s)))

    # Write everything in a file
    print("Writing the output to {}".format(args.outputfilename))
    nb.write()

    return nb.boxsize, T_orbit


if __name__ == '__main__':
    # Script entry point: parse the command line, launch the dwarf galaxy,
    # then print the box size and the orbital period it reports.
    box_size, orbit_period = launch_on_orbits(parse_options())
    print(box_size, orbit_period)
