#!/usr/bin/env python3
###########################################################################################
#  package:   pNbody
#  file:      orbits_integration_MW
#  brief:    Backward  integration of particles to find their apocenter in the MWPotential2014 potential
#  copyright: GPLv3
#             Copyright (C) 2023 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Darwin Roduit <darwin.roduit@epfl.ch>
#
# This file is part of pNbody.
###########################################################################################

import argparse
import os
from tqdm import tqdm

from astropy.constants import G as G_a
from astropy import units as u
import matplotlib.pyplot as plt
import numpy as np
# from scipy.integrate import solve_ivp

from pNbody.mass_models import MWPotential2014 as MW
from pNbody.orbits import coordinates_utils
from pNbody.orbits import integration_utils
from pNbody.misc.scripts_utils import RawTextArgumentDefaultsHelpFormatter, store_as_array
from pNbody import orbits 

# %% Some functions


def function(t, y, rho_0, r_s, M_disk, a, b, alpha, r_c, amplitude, r_1, f_1,
             f_2, f_3, G, potential_choice):
    """Right-hand side of Newton's equations of motion, dy/dt = f(t, y).

    The state vector is y = (x, y, z, v_x, v_y, v_z). The derivative returned
    is (v_x, v_y, v_z, a_x, a_y, a_z), where the acceleration is minus the
    potential gradient of the selected host model.

    potential_choice selects the gradient routine:
      - "MWPotential2014": MW.GradPot
      - "MWPotential2014_time_evolved": integration_utils.GradPot_time_evolved
    Any other value raises ValueError.
    """
    pos, vel = y[0:3], y[3:]

    if potential_choice == "MWPotential2014":
        grad = MW.GradPot(rho_0, r_s, M_disk, a, b, alpha, r_c, amplitude,
                          r_1, f_1, f_2, f_3, pos, t, G)
    elif potential_choice == "MWPotential2014_time_evolved":
        grad = integration_utils.GradPot_time_evolved(t, pos, rho_0, r_s,
                                                      M_disk, a, b, alpha,
                                                      r_c, amplitude, r_1,
                                                      f_1, f_2, f_3, G)
    else:
        raise ValueError("Potential model not implemented")

    # acceleration = -grad(Phi)
    return np.array([vel, -grad]).ravel()


def setup_plot(fig, ax, opt):
    """Label the panels and give the three projections common axis limits.

    ax[0], ax[1] and ax[2] are the x-y, x-z and y-z projections, drawn with an
    equal aspect ratio. An optional ax[3] is labelled as an r(t) panel.

    The x (resp. y) limits of the three projections are set to the union of
    their current x (resp. y) limits. When ``opt.xmax`` is not None, both
    ranges are instead set to [-opt.xmax, opt.xmax]. A legend is drawn on
    ax[1]. ``fig`` is accepted for symmetry with the callers but is not used.
    """
    projections = (ax[0], ax[1], ax[2])
    axis_labels = (
        (r"$y$ [kpc]", r"$x$ [kpc]"),
        (r"$z$ [kpc]", r"$x$ [kpc]"),
        (r"$z$ [kpc]", r"$y$ [kpc]"),
    )
    for panel, (ylabel, xlabel) in zip(projections, axis_labels):
        panel.set_aspect('equal', 'box')
        panel.set_ylabel(ylabel)
        panel.set_xlabel(xlabel)

    if len(ax) > 3:
        ax[3].set_xlabel(r"$t$ [Gyr]")
        ax[3].set_ylabel(r"$r$ [kpc]")
    plt.tight_layout()

    # Union of the current ranges of the three projections (y queried first)
    y_limits = [panel.get_ylim() for panel in projections]
    x_limits = [panel.get_xlim() for panel in projections]
    y_low = min(low for low, _ in y_limits)
    y_high = max(high for _, high in y_limits)
    x_low = min(low for low, _ in x_limits)
    x_high = max(high for _, high in x_limits)

    # A user-provided half-width overrides the automatic limits
    if opt.xmax is not None:
        x_low, x_high = -opt.xmax, opt.xmax
        y_low, y_high = -opt.xmax, opt.xmax

    for panel in projections:
        panel.set_xlim([x_low, x_high])
    for panel in projections:
        panel.set_ylim([y_low, y_high])

    ax[1].legend()


def plot_orbit(position, t, legend, fig, ax):
    """Plot one orbit in the x-y, x-z and y-z planes and, optionally, r(t).

    Parameters
    ----------
    position : array of shape (3, n_points)
        Cartesian positions along the orbit (rows x, y, z), in kpc.
    t : array of shape (n_points,) or None
        Sample times. Only used for the r(t) panel ax[3]. Pass None to skip
        that panel even when a fourth axis exists.
    legend : str
        Label attached to the curve in the x-z panel (ax[1]).
    fig, ax : matplotlib Figure and sequence of Axes
        Where to draw. When ``fig`` is None, a new 1x4 figure (num=1) is
        created and labelled, and ``ax`` is ignored.

    Returns
    -------
    fig, ax
        The figure and axes that were drawn on (new ones if ``fig`` was None).
    """
    if fig is None:  # create a new figure
        fig, ax = plt.subplots(nrows=1, ncols=4, num=1, figsize=(12, 4.1))
        for panel, xlabel, ylabel in ((ax[0], r"$x$ [kpc]", r"$y$ [kpc]"),
                                      (ax[1], r"$x$ [kpc]", r"$z$ [kpc]"),
                                      (ax[2], r"$y$ [kpc]", r"$z$ [kpc]")):
            panel.set_aspect('equal', 'box')
            panel.set_xlabel(xlabel)
            panel.set_ylabel(ylabel)
        if len(ax) > 3:
            ax[3].set_aspect('auto', 'box')
            # r is plotted against t, so t is the horizontal axis
            ax[3].set_xlabel(r"$t$ [Gyr]")
            ax[3].set_ylabel(r"$r$ [kpc]")
        plt.tight_layout()

    x, y, z = position[0, :], position[1, :], position[2, :]
    r = np.sqrt(x**2 + y**2 + z**2)

    # The r(t) panel needs both a fourth axis and the sample times
    plot_rt = len(ax) > 3 and t is not None

    # Do the plot
    ax[0].plot(x, y)
    ax[1].plot(x, z, label=legend)
    ax[2].plot(y, z)
    if plot_rt:
        ax[3].plot(t, r)

    # Mark the starting point of the orbit
    ax[0].plot(x[0], y[0], marker="x")
    ax[1].plot(x[0], z[0], marker="x")
    ax[2].plot(y[0], z[0], marker="x")
    if plot_rt:
        ax[3].plot(t[0], r[0], marker="x")

    return fig, ax


def load_data_from_file(file, coordinate):
    """Load the dwarfs data from a text file.

    The file has one dwarf per row. Columns are separated by whitespace and
    text after ``#`` is ignored. The expected columns depend on ``coordinate``:
      - coordinate == "ICRS":
        # Name  Distance modulus  RA [deg]  DEC [deg]  PM_RA [mas/yr]   PM_DEC [mas/yr]  v_rad [km/s]
      - coordinate == "carthesian":
        # Name  x [kpc]  y [kpc]  z [kpc]  v_x [km/s]  v_y [km/s]  v_z [km/s]
    Every column after the name must be numeric. Columns beyond the seventh
    are ignored.

    Parameters
    ----------
    file : str or path-like
        File readable by ``np.loadtxt``.
    coordinate : {"ICRS", "carthesian"}
        Coordinate system of the file.

    Return: dictionary of coordinates values (arrays have one entry per dwarf)
      - ICRS:
           {"coordinate", "n_dwarfs", "name", "distance_modulus", "RA", "DEC",
            "pm_RA", "pm_DEC", "v_rad"}
      - carthesian:
           {"coordinate", "n_dwarfs", "name", "x", "y", "z", "v_x", "v_y", "v_z"}

    Raises
    ------
    ValueError
        If ``coordinate`` is not supported, the file holds no data, or the
        rows have fewer than seven columns.
    """
    # Reject an unsupported coordinate system before touching the file
    if coordinate not in ("ICRS", "carthesian"):
        raise ValueError(
            "The coordinate type \"{}\" is not implemented !".format(coordinate))

    # ndmin=2 keeps a (n_dwarfs, n_columns) layout even for a single-row file
    dwarfs_data = np.loadtxt(file, comments="#", dtype=object, ndmin=2)

    if dwarfs_data.size == 0:
        raise ValueError("No data found in file \"{}\".".format(file))
    if dwarfs_data.shape[1] < 7:
        raise ValueError(
            "Expected at least 7 columns (name + 6 values) in \"{}\", "
            "found {}.".format(file, dwarfs_data.shape[1]))

    # Get the names of the dwarfs
    name = list(dwarfs_data[:, 0])

    # Remove the name column and convert to float
    values = np.array(dwarfs_data[:, 1:], dtype=float)

    # Get the number of dwarfs
    n_dwarfs = len(name)

    columns = [values[:, k] for k in range(6)]

    if coordinate == "ICRS":
        keys = ("distance_modulus", "RA", "DEC", "pm_RA", "pm_DEC", "v_rad")
    else:  # "carthesian"
        keys = ("x", "y", "z", "v_x", "v_y", "v_z")

    data_dict = {"coordinate": coordinate, "n_dwarfs": n_dwarfs, "name": name}
    data_dict.update(zip(keys, columns))
    return data_dict


def convert_data(data_dict, unit_length, unit_velocity):
    """Convert the loaded data to carthesian galactocentric coordinates.

    Parameters
    ----------
    data_dict : dict
        Dictionary returned by ``load_data_from_file``.
    unit_length, unit_velocity
        Units forwarded to ``coordinates_utils.convert_ICRS2carth`` (ICRS only).

    Returns
    -------
    initial_position, initial_velocity : ndarray, shape (3, n_dwarfs)
        Column vectors of positions and velocities (plain arrays, no units).

    Raises
    ------
    ValueError
        If ``data_dict["coordinate"]`` is neither "ICRS" nor "carthesian".
    """
    coordinate = data_dict["coordinate"]

    if coordinate == "ICRS":
        initial_position, initial_velocity = coordinates_utils.convert_ICRS2carth(
            data_dict["distance_modulus"],
            data_dict["RA"],
            data_dict["DEC"],
            data_dict["pm_RA"],
            data_dict["pm_DEC"],
            data_dict["v_rad"],
            unit_length,
            unit_velocity)
        # Strip the astropy units: the integrator works with plain arrays
        return initial_position.value, initial_velocity.value

    if coordinate == "carthesian":
        initial_position = np.array(
            [data_dict["x"], data_dict["y"], data_dict["z"]])
        initial_velocity = np.array(
            [data_dict["v_x"], data_dict["v_y"], data_dict["v_z"]])
        return initial_position, initial_velocity

    # Previously this fell through and returned None, which broke unpacking
    raise ValueError(
        "The coordinate type \"{}\" is not implemented !".format(coordinate))


# %% Main computations part
description = """This script integrates orbits in a Milky Way (MWPotential2014) potential.
If a position and a velocity are provided (--position, --velocity), a single orbit is integrated
forward in time.
If an input file is provided (-i), the script integrates the orbits of dwarf galaxies backward in
time to retrieve their apocenters. The apocenter data, the orbit plots and the r(t) plot are saved.
Then, it integrates the orbits forward to plot them.

The plots are saved under the names backward_orbits.pdf, forward_orbits.pdf and r.pdf.
The apocenter data are saved in apocenters_data.npz. The position and velocity vectors are
stored as *column vectors*. The backward and forward trajectories are stored in
backward_integration_data.npz and forward_integration_data.npz. Notice that each dwarf's data is
stored along the third dimension.

Output units are kpc, km/s.
"""


epilog = """
Examples:
---------


# integrate a single orbit from a position [kpc] and a velocity [km/s] (Bootes3)
orbits_integration_MW --position 1.62655353 6.85191783 45.24958 --velocity 9.80652717 -23.25309335 251.84365236

# use initial conditions from distance, position on the sky and proper motions
orbits_integration_MW -i dwarfs_data.txt -o output/
orbits_integration_MW -i dwarfs_data.txt -o output/ --delta_t 0.001

where the file dwarfs_data.txt contains for example:

#    Name    Distance modulus           RA         DEC       PM_RA    PM_DEC     v_rad
  Bootes1            19.11000     210.0225     14.50060      -0.39     -1.06    101.80
 Hercules            20.68000     247.7722     12.78520      -0.04     -0.34     45.00
  Tucana2            18.80000     342.9796    -58.56890       0.91     -1.27    129.10
  Bootes3            18.35000     209.3000     26.80000      -1.16     -0.88    197.50
   Draco2            16.67000     238.1740     64.57900       1.12      0.91   -342.50
   Segue1            16.80000     151.7504     16.07560      -2.06     -3.42    208.50
   Segue2            17.68000      34.8226     20.16240       1.43     -0.31    -40.20
  Tucana3            16.80000     359.1075    -59.58332      -0.08     -1.62   -101.20
  Tucana5            18.70000     354.3470    -63.26600      -0.13     -1.15    -36.20
       U1            15.00000     174.7074     31.07833      -0.87      1.15     88.60

or in carthesian coordinates (run with --coordinate_system carthesian):

# Name    x           y         z       v_x    v_y     v_z
Test1    10          12        15         1      5       1
Test2    25          78        21        10      4       5
"""


def parse_options(argv=None):
    """Build the command-line parser and parse the arguments.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of the command line (``sys.argv[1:]``).
        Useful when calling the script from Python or in tests.

    Returns
    -------
    argparse.Namespace
        The parsed options.
    """
    parser = argparse.ArgumentParser(formatter_class=RawTextArgumentDefaultsHelpFormatter,
                                     description=description,
                                     epilog=epilog)
    # TODO (translated from French): add a copy-paste of the input file to the help

    parser.add_argument('-i', dest='input_file', help="Location of the file containing "
                        "the dwarfs data, one dwarf per row. With --coordinate_system "
                        "ICRS the columns are, in this order: Name, Distance modulus, "
                        "RA [deg], DEC [deg], PM_RA [mas/yr], PM_DEC [mas/yr], "
                        "v_rad [km/s]. With carthesian: Name, x, y, z [kpc], "
                        "v_x, v_y, v_z [km/s]. Columns are separated by whitespace, "
                        "so names must not contain spaces. Text after '#' is "
                        "ignored.", default=None)

    parser.add_argument('-o', dest='output_location',
                        help='Output directory where the output data and plots of the '
                        'dwarf orbit integration are saved. With --position and '
                        '--velocity, path of the CSV file (columns t, x, y, z) where '
                        'the single orbit is written; nothing is saved if omitted.')

    parser.add_argument("--coordinate_system", default="ICRS", type=str,
                        choices=["ICRS", "carthesian"],
                        help="Coordinate system of the data in the input file.")

    parser.add_argument('--t_backward', default=7, type=float,
                        help="Time of the backward integration, in Gyr. ")

    parser.add_argument('--t_forward', default=7, type=float,
                        help="Time of the forward integration, in Gyr. ")

    parser.add_argument('--delta_t', default=1e-2, type=float,
                        help="Time interval between two output samples of the "
                        "integrated orbits (same unit as --t_backward/--t_forward).")

    parser.add_argument('--potential', default="MWPotential2014", type=str,
                        choices=["MWPotential2014",
                                 "MWPotential2014_time_evolved"],
                        help="Model of the host galaxy potential.")

    # A few option to choose the forward orbit launch point
    parser.add_argument("--forward_launch_option", default="initial",
                        type=str, choices=["initial", "apocenter", "custom"],
                        help="Starting point of the forward integration of the "
                        "dwarfs: 'initial' (present-day position and velocity) or "
                        "'apocenter' (apocenter found by the backward integration). "
                        "'custom' is no longer supported and raises an error.")

    parser.add_argument("--position", default=None, action=store_as_array,
                        type=float, nargs=3, help="Position from which a single "
                        "orbit is integrated forward (requires --velocity). "
                        "Carthesian coordinates in [kpc].")

    parser.add_argument("--velocity", default=None, action=store_as_array,
                        type=float, nargs=3, help="Velocity with which a single "
                        "orbit is integrated forward (requires --position). "
                        "Carthesian coordinates in [km/s].")

    parser.add_argument("--dynamical_friction", default=False, action='store_true',
                        help="Use dynamical friction.")

    parser.add_argument("--lnLambda", default=5, action='store', type=float,
                        help="Coulomb logarithm ln(Lambda) used for the dynamical friction.")

    parser.add_argument("--satellite_mass", default=1e10, action='store', type=float,
                        help="Satellite mass in Msol.")

    parser.add_argument("--df_core_radius", default=10, action='store', type=float,
                        help="Radius, in kpc, below which the dynamical friction vanishes.")

    parser.add_argument("--xmax", default=None, action='store', type=float,
                        help="Half-width of the plotted region in kpc: the x and y "
                        "limits of the projections are set to [-xmax, xmax].")

    return parser.parse_args(argv)


def set_orbit(potential_type, dynamical_friction, lnLambda, satellite_mass, df_core_radius):
    """Create the orbit integrator for the chosen host potential.

    Internal units are kpc for lengths, 10^10 M_sun for masses and km/s for
    velocities, so the time unit is kpc/(km/s) (about 0.98 Gyr).

    Parameters
    ----------
    potential_type : str
        Name of the host potential model, forwarded to ``orbits.orbit``.
    dynamical_friction : bool
        Whether dynamical friction is switched on.
    lnLambda : float
        Coulomb logarithm used by the dynamical friction.
    satellite_mass : float
        Satellite mass in M_sun. It is converted here to units of 10^10 M_sun.
    df_core_radius : float
        Radius (kpc) below which the dynamical friction vanishes.

    Returns
    -------
    orbits.orbit
        Integrator built with the MWPotential2014 parameters, G in internal
        units and the options above.
    """
    # Internal unit system: kpc, 1e10 M_sun, km/s -> time unit ~0.98 Gyr
    code_length = (1*u.kpc).to(u.kpc)
    code_mass = (u.M_sun*1e10).to(1e10*u.M_sun)
    code_speed = 1*u.km/u.s
    code_speed = code_speed.to(code_speed)
    code_time = (code_length/code_speed).to(u.Gyr)
    G = G_a.to(code_length**3 * code_mass**(-1) * code_time**(-2)).value

    # Dynamical-friction inputs expressed in internal units
    sat_mass_code = satellite_mass/(u.M_sun*1e10).value
    core_radius_code = df_core_radius*code_length.value

    # MW parameters, in the positional order expected by the potential
    mw_params = MW.MWPotential2014_parameters()
    mw_keys = ("rho_0", "r_s", "M_disk", "a", "b", "alpha", "r_c",
               "amplitude", "r_1", "f_1", "f_2", "f_3")
    pot_args = tuple(mw_params[key] for key in mw_keys) + (
        G, potential_type, dynamical_friction, lnLambda,
        sat_mass_code, core_radius_code)

    return orbits.orbit(potential_arguments=pot_args)

def orbit_integration(opt, orbit):
    """Integrate a single orbit forward from ``opt.position`` / ``opt.velocity``.

    The orbit is integrated from t=0 to ``opt.t_forward`` with
    ceil(t_forward / delta_t) output samples. When ``opt.output_location`` is
    set, the samples are first written as CSV with columns t, x, y, z. If that
    path is an existing directory, ``orbit.csv`` is created inside it.

    The orbit is then shown in the x-y, x-z and y-z projections and as r(t),
    using ``plot_orbit`` and ``setup_plot``.

    Parameters
    ----------
    opt : argparse.Namespace
        Parsed options. Reads position, velocity, t_forward, delta_t,
        output_location and (through setup_plot) xmax.
    orbit : orbits.orbit
        Integrator created by ``set_orbit``.
    """
    import csv

    # number of output points
    num_points = np.int32(np.ceil(opt.t_forward/opt.delta_t))

    # initial conditions y = (x, y, z, v_x, v_y, v_z)
    y_0 = np.concatenate((opt.position, opt.velocity))

    orbit.Integrate(t0=0, t1=opt.t_forward, y0=y_0, npoints=num_points)
    t = orbit.GetTimes()
    pos = orbit.GetPositions()

    # Save before plotting: plt.show() may block, or fail without a display,
    # which previously lost the data.
    if opt.output_location is not None:
        csv_path = opt.output_location
        if os.path.isdir(csv_path):
            csv_path = os.path.join(csv_path, "orbit.csv")
        with open(csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['t', 'x', 'y', 'z'])
            writer.writerows(zip(t, pos[0, :], pos[1, :], pos[2, :]))

    # Plots the orbit
    fig, ax = plt.subplots(nrows=1, ncols=4, num=1, figsize=(12, 4.1))
    plot_orbit(pos, t, "", fig, ax)
    setup_plot(fig, ax, opt)
    plt.show()
    plt.close()
    

def _integrate_orbits(orbit, pos_start, vel_start, t_start, t_end, num_points):
    """Integrate orbit i from t_start[i] to t_end[i] and stack all the samples.

    Returns (t, position, velocity) with shapes (num_points, n),
    (3, num_points, n) and (3, num_points, n), where n = len(t_start).
    """
    n_orbits = len(t_start)
    t = np.zeros((num_points, n_orbits))
    position = np.zeros((3, num_points, n_orbits))
    velocity = np.zeros((3, num_points, n_orbits))
    for i in tqdm(range(n_orbits)):
        y_0 = np.concatenate((pos_start[:, i], vel_start[:, i]))
        orbit.Integrate(t0=t_start[i], t1=t_end[i], y0=y_0, npoints=num_points)
        t[:, i] = orbit.GetTimes()
        position[:, :, i] = orbit.GetPositions()
        velocity[:, :, i] = orbit.GetVelocities()
    return t, position, velocity


def _find_apocenters(t, position, velocity):
    """Return r(t) and, per orbit, position, velocity, time and radius at max(r)."""
    r = np.linalg.norm(position, axis=0)       # (num_points, n)
    index_apo = np.argmax(r, axis=0)           # one sample index per orbit
    columns = np.arange(r.shape[1])
    pos_apo = position[:, index_apo, columns]  # (3, n)
    vel_apo = velocity[:, index_apo, columns]  # (3, n)
    t_apo = t[index_apo, columns]
    r_apo = r[index_apo, columns]
    return r, pos_apo, vel_apo, t_apo, r_apo


def dwarf_orbit_integration(input_file, output_location, t_backward, t_forward,
                            delta_t, coordinate_system, forward_launch_option, orbit, opt):
    """Integrate dwarf orbits backward to find their apocenters, then forward.

    Parameters
    ----------
    input_file : str
        Data file in the format expected by ``load_data_from_file``.
    output_location : str
        Directory receiving the .npz data and the .pdf plots (created if needed).
    t_backward, t_forward : float
        Durations of the backward and forward integrations (labelled Gyr by
        the command-line interface).
    delta_t : float
        Sampling interval. Each integration has ceil(duration / delta_t) samples.
    coordinate_system : {"ICRS", "carthesian"}
        Coordinate system of ``input_file``.
    forward_launch_option : {"initial", "apocenter"}
        Start the forward integration from the present-day state at t=0, or
        from the apocenter found by the backward integration at t=t_apo.
    orbit : orbits.orbit
        Integrator created by ``set_orbit``.
    opt : argparse.Namespace
        Command-line options, used by ``setup_plot`` (``opt.xmax``).

    Returns
    -------
    pos_apo, vel_apo : ndarray, shape (3, n_dwarfs)
        Position and velocity at the apocenter of each dwarf.
    t_apo : ndarray, shape (n_dwarfs,)
        Time of the apocenter (negative, from the backward integration).

    Raises
    ------
    ValueError
        If ``input_file`` or ``output_location`` is None, or
        ``forward_launch_option`` is not supported.
    """
    # Validate the inputs before doing any (expensive) work
    if input_file is None:
        raise ValueError("An input file (-i) is required for the dwarf orbit integration.")
    if output_location is None:
        raise ValueError("An output directory (-o) is required for the dwarf orbit integration.")
    if forward_launch_option not in ("initial", "apocenter"):
        raise ValueError(
            "forward_launch_option \"{}\" is no longer valid; use \"initial\" "
            "or \"apocenter\".".format(forward_launch_option))

    # Units: kpc and km/s (time unit kpc/(km/s), about 0.98 Gyr) --------------
    unit_length = (1*u.kpc).to(u.kpc)
    unit_velocity = 1*u.km/u.s
    unit_velocity = unit_velocity.to(unit_velocity)
    unit_time = (unit_length/unit_velocity).to(u.Gyr)
    unit_velocity = (unit_length/unit_time).to(u.cm/u.s).to(u.km/u.s)

    # Load the data
    data_dict = load_data_from_file(input_file, coordinate_system)
    name = data_dict["name"]
    n_dwarfs = data_dict["n_dwarfs"]

    # -------------------------------------------------------------------------
    print("Welcome to the dwarfs orbit integration script !")
    if not os.path.exists(output_location):
        print("Output directory does not exist. It will be created.")
        os.makedirs(output_location)

    # Get initial positions and velocities ------------------------------------
    initial_position, initial_velocity = convert_data(
        data_dict, unit_length, unit_velocity)

    # Backward integration-----------------------------------------------------
    # The integration runs towards negative times, so the velocities are used
    # (and returned) with their physical sign: no sign flip is needed.
    num_points = np.int32(np.ceil(t_backward/delta_t))
    print("Backward integration...")
    t, position, velocity = _integrate_orbits(
        orbit, initial_position, initial_velocity,
        np.zeros(n_dwarfs), np.full(n_dwarfs, -float(t_backward)), num_points)

    np.savez_compressed(os.path.join(output_location, "backward_integration_data"),
                        position=position, velocity=velocity, time=t)

    # Finds the apocenter------------------------------------------------------
    print("\nLooking for the apocenters of the orbits...")
    r, pos_apo, vel_apo, t_apo, r_apo = _find_apocenters(t, position, velocity)
    # The velocities from the negative-time integration are already physical.
    # The former "- velocity" flip reversed the orbit when launching from the
    # apocenter.

    np.savez_compressed(os.path.join(output_location, "apocenters_data"),
                        pos_apo=pos_apo, vel_apo=vel_apo, t_apo=t_apo)

    # Plots the orbits (backward)----------------------------------------------
    fig, ax = plt.subplots(nrows=1, ncols=3, num=1, figsize=(12, 4.1))
    for i in range(n_dwarfs):
        plot_orbit(position[:, :, i], None, name[i], fig, ax)
    setup_plot(fig, ax, opt)
    plt.savefig(os.path.join(output_location, "backward_orbits.pdf"), bbox_inches='tight')
    plt.close()

    # Plots r(t) and the position of the apocenter
    fig2, ax2 = plt.subplots(nrows=1, ncols=1, num=2)
    ax2.plot(t, r, label=name)
    ax2.plot(t_apo, r_apo, linestyle="", marker="x", label="Apocenters")
    ax2.set_xlabel(r"$t$ [Gyr]")
    ax2.set_ylabel(r"$r$ [kpc]")
    ax2.legend()
    plt.savefig(os.path.join(output_location, "r.pdf"), bbox_inches='tight')
    plt.close()

    # Forward integration------------------------------------------------------
    if forward_launch_option == "initial":
        t_start = np.zeros(n_dwarfs)
        pos_start = initial_position
        vel_start = initial_velocity
    else:  # "apocenter"
        t_start = t_apo
        pos_start = pos_apo
        vel_start = vel_apo

    num_points = np.int32(np.ceil(t_forward/delta_t))
    t_end = t_start + t_forward

    print("\nForward integration...")
    t, position, velocity = _integrate_orbits(
        orbit, pos_start, vel_start, t_start, t_end, num_points)

    # Saves forward integration trajectory
    np.savez_compressed(os.path.join(output_location, "forward_integration_data"),
                        position=position, velocity=velocity, time=t)

    # Plots the orbits after forward integration
    fig, ax = plt.subplots(nrows=1, ncols=3, num=3, figsize=(12, 4.1))
    for i in range(n_dwarfs):
        plot_orbit(position[:, :, i], t[:, i], name[i], fig, ax)
    setup_plot(fig, ax, opt)
    plt.savefig(os.path.join(output_location, "forward_orbits.pdf"), bbox_inches='tight')
    plt.close()
    print("End---------------------------------------------\n\n")
    return pos_apo, vel_apo, t_apo


# %% Main
if __name__ == '__main__':
    opt = parse_options()

    have_position = opt.position is not None
    have_velocity = opt.velocity is not None

    # Validate the run mode before building the integrator, so that an
    # incomplete command line fails right away with a readable message.
    if have_position != have_velocity:
        raise SystemExit("error: --position and --velocity must be given together.")
    if not have_position and (opt.input_file is None or opt.output_location is None):
        raise SystemExit("error: provide either --position and --velocity, "
                         "or an input file (-i) and an output directory (-o).")

    # set the orbit object from the potential
    orbit = set_orbit(opt.potential, opt.dynamical_friction, opt.lnLambda,
                      opt.satellite_mass, opt.df_core_radius)

    if have_position:
        # single orbit integrated forward from the given phase-space point
        orbit_integration(opt, orbit)
    else:
        # backward integration of the dwarfs to find their apocenters,
        # followed by a forward integration
        dwarf_orbit_integration(opt.input_file, opt.output_location,
                                opt.t_backward, opt.t_forward, opt.delta_t,
                                opt.coordinate_system, opt.forward_launch_option,
                                orbit, opt)

