#!/usr/bin/python3
###########################################################################################
#  package:   pNbody
#  file:      plotSphericalProfile
#  brief:     
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Yves Revaz <yves.revaz@epfl.ch>
#
# This file is part of Gtools.
###########################################################################################

import os
import numpy as np
import argparse

import matplotlib.pyplot as plt
from pNbody import plot
from pNbody import *

####################################################################
# option parser
####################################################################

description = """plot different physical quantities as a function of the 3-d radius,
assuming spherical symmetry."""
epilog = """
Examples:
--------

# density profile
plotSphericalProfile -y density --forceComovingIntegrationOff --xmax 10 --rmax 10 --nr 64  --log xy  snapshot.hdf5  
plotSphericalProfile -y density --forceComovingIntegrationOff --xmax 10 --rmax 10 --nr 64  --log xy  --select stars snapshot.hdf5  

# velocity dispersion profile
plotSphericalProfile -y sigmaz --forceComovingIntegrationOff --ymin 0 --xmax 10 --rmax 10 --nr 64    snapshot.hdf5
plotSphericalProfile -y sigmaz --forceComovingIntegrationOff --ymin 0 --xmax 10 --rmax 10 --nr 64    --select stars snapshot.hdf5

# mass profile
plotSphericalProfile -y mass    --ymin 0 --xmax 50 --rmax 50 --nr 64    snapshot.hdf5

# integrated mass profile
plotSphericalProfile -y imass   --ymin 0 --xmax 50 --rmax 50 --nr 64    snapshot.hdf5

# circular velocity
plotSphericalProfile -y vcirc      --xmax 50 --rmax 50 --nr 64 --eps 0.1   snapshot.hdf5

# dynamical time
plotSphericalProfile -y tdyn      --xmax 50 --rmax 50 --nr 64 --eps 0.1   snapshot.hdf5
"""

# Quantities MakePlot knows how to compute; used as argparse choices so an
# unsupported -y value is rejected before any snapshot is read.
Y_QUANTITIES = ['density', 'mass', 'imass',
                'sigmar', 'sigmat', 'sigmatheta', 'sigmaphi', 'sigmaz', 'beta',
                'Vt', 'Vr', 'Vtheta', 'Vphi',
                'vcirc', 'tdyn']

parser = argparse.ArgumentParser(description=description, epilog=epilog,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)

# generic option groups shared by the pNbody plotting scripts
plot.add_files_options(parser)
plot.add_arguments_units(parser)
plot.add_arguments_reduc(parser)
plot.add_arguments_center(parser)
plot.add_arguments_select(parser)
plot.add_arguments_info(parser)
plot.add_comoving_options(parser)
plot.add_arguments_legend(parser)
plot.add_arguments_icshift(parser)
plot.add_arguments_cmd(parser)

# input snapshots (zero or more)
parser.add_argument(action="store",
                    dest="files",
                    metavar='FILE',
                    type=str,
                    default=None,
                    nargs='*',
                    help='a list of files')

# output file; when absent the figure is shown interactively
parser.add_argument("-o",
                    action="store",
                    type=str,
                    dest="outputfilename",
                    default=None,
                    help="Name of the output file")

# axis limits (None lets the plotting helpers choose them)
parser.add_argument('--xmin',
                    action="store",
                    dest="xmin",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='x min')

parser.add_argument('--xmax',
                    action="store",
                    dest="xmax",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='x max')

parser.add_argument('--ymin',
                    action="store",
                    dest="ymin",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='y min')

parser.add_argument('--ymax',
                    action="store",
                    dest="ymax",
                    metavar='FLOAT',
                    type=float,
                    default=None,
                    help='y max')

parser.add_argument('--log',
                    action="store",
                    dest="log",
                    metavar='STR',
                    type=str,
                    default=None,
                    help='log scale (None,x,y,xy)')

parser.add_argument('-y', '--y',
                    action="store",
                    dest="y",
                    metavar='STR',
                    type=str,
                    default='density',
                    choices=Y_QUANTITIES,
                    help='quantity to plot in the y axis (one of: %(choices)s)')

# radial binning
parser.add_argument("--rmax",
                    action="store",
                    dest="rmax",
                    type=float,
                    default=50.,
                    help="max radius of bins",
                    metavar=" FLOAT")

parser.add_argument("--nr",
                    action="store",
                    dest="nr",
                    type=int,
                    default=32,
                    help="number of bins in r",
                    metavar=" INT")

# NOTE(review): opt.eps is not read anywhere in MakePlot, although the
# vcirc/tdyn examples pass --eps; confirm whether softening was intended.
parser.add_argument("--eps",
                    action="store",
                    dest="eps",
                    type=float,
                    default=0.1,
                    help="smoothing length",
                    metavar=" FLOAT")

parser.add_argument("--colormap",
                    action="store",
                    default='jet',
                    help='matplotlib colormap name (e.g. mycmap, tab20c, Greys, jet, binary)')
                    
                    
                    

#######################################
# MakePlot
#######################################


# Quantities accepted in opt.y, grouped by how they are computed.
_KINEMATIC_QUANTITIES = ('sigmar', 'sigmat', 'sigmatheta', 'sigmaphi', 'sigmaz',
                         'beta', 'Vt', 'Vr', 'Vtheta', 'Vphi')
_SUPPORTED_QUANTITIES = (('density', 'mass', 'imass')
                         + _KINEMATIC_QUANTITIES
                         + ('vcirc', 'tdyn'))

_RADIUS_LABEL = r'$\rm{Radius}\,\left[ \rm{kpc} \right]$'

# Scale radius (in model length units) of the logarithmic radial binning.
_RC = 1


def _grid_transform(r):
    """Map a radius to the grid coordinate log(r/rc + 1)."""
    return np.log(r / _RC + 1.)


def _grid_transform_inverse(u):
    """Inverse of _grid_transform: rc * (exp(u) - 1)."""
    return _RC * (np.exp(u) - 1.)


def _new_grid(opt):
    """Return the 1-d spherical grid covering [0, opt.rmax] with opt.nr bins."""
    return libgrid.Spherical_1d_Grid(rmin=0, rmax=opt.rmax, nr=opt.nr,
                                     g=_grid_transform, gm=_grid_transform_inverse)


def _kpc_msun_myr():
    """Output unit system for radii, masses and times (kpc, Msun, Myr, K)."""
    return units.UnitSystem(
        'local', [units.Unit_kpc, units.Unit_Ms, units.Unit_Myr, units.Unit_K])


def _km_msun_s():
    """Output unit system whose velocity unit is km/s (km, Msun, s, K)."""
    return units.UnitSystem(
        'local', [units.Unit_km, units.Unit_Ms, units.Unit_s, units.Unit_K])


def _length_factor(nb):
    """Factor converting nb's local length unit into kpc."""
    return nb.localsystem_of_units.convertionFactorTo(_kpc_msun_myr().UnitLength)


def _print_physical_conversion(nb):
    """Report that comoving quantities are being converted to physical ones."""
    print("    converting to physical units (a=%5.3f h=%5.3f)" %
          (nb.atime, nb.hubbleparam))


def _load_snapshot(filename, opt):
    """Read `filename`, set its local units and apply the generic CLI options.

    Prints the unit-system and comoving-integration info, then returns the
    (possibly replaced) Nbody object.
    """
    nb = Nbody(filename, ftype=opt.ftype)

    # define local units
    unit_params = plot.apply_arguments_units(opt)
    nb.set_local_system_of_units(params=unit_params)

    # apply the generic options, in this order
    # NOTE(review): the parser never registers plot.add_arguments_display;
    # confirm apply_arguments_display copes with missing display options.
    for apply_option in (plot.apply_arguments_icshift,
                         plot.apply_arguments_comoving,
                         plot.apply_arguments_reduc,
                         plot.apply_arguments_select,
                         plot.apply_arguments_center,
                         plot.apply_arguments_cmd,
                         plot.apply_arguments_info,
                         plot.apply_arguments_display):
        nb = apply_option(nb, opt)

    print("---------------------------------------------------------")
    nb.localsystem_of_units.info()
    nb.ComovingIntegrationInfo()
    print("---------------------------------------------------------")
    return nb


def _density_profile(nb, opt):
    """Density (atom/cm^3) versus radius (kpc).

    Returns (x, y, ylabel, clean_for_log).
    """
    grid = _new_grid(opt)
    x = grid.get_r()
    y = grid.get_DensityMap(nb)

    if nb.isComovingIntegrationOn():
        _print_physical_conversion(nb)
        x = x * nb.atime / nb.hubbleparam            # length  conversion
        y = y / nb.atime**3 * nb.hubbleparam**2      # density conversion

    # density expressed as a number of proton masses per cm^3
    unit_atom = ctes.PROTONMASS.into(units.cgs) * units.Unit_g
    unit_atom.set_symbol('atom')
    atom_units = units.UnitSystem(
        'local', [units.Unit_cm, unit_atom, units.Unit_s, units.Unit_K])

    x = x * _length_factor(nb)
    y = y * nb.localsystem_of_units.convertionFactorTo(atom_units.UnitDensity)
    return x, y, r'$\rm{Density}\,\left[ \rm{atom/cm^3} \right]$', True


def _mass_profile(nb, opt, cumulative):
    """Mass per radial bin (or enclosed mass if `cumulative`) in Msun vs radius in kpc.

    Returns (x, y, ylabel, clean_for_log).
    """
    grid = _new_grid(opt)
    x = grid.get_r()
    y = grid.get_MassMap(nb)
    if cumulative:
        y = np.add.accumulate(y)

    if nb.isComovingIntegrationOn():
        _print_physical_conversion(nb)
        x = x * nb.atime / nb.hubbleparam            # length conversion
        y = y / nb.hubbleparam                       # mass conversion

    x = x * _length_factor(nb)
    # masses must be converted with the mass (not the density) unit factor
    y = y * nb.localsystem_of_units.convertionFactorTo(_kpc_msun_myr().UnitMass)
    return x, y, r'$\rm{Mass}\,\left[ M_{\odot} \right]$', True


def _kinematic_profile(nb, opt):
    """Mean velocity, velocity dispersion or anisotropy (beta) versus radius.

    nb.pos and nb.vel are modified in place. For comoving runs they are first
    converted to physical coordinates. They are then rescaled to kpc and km/s
    before the grid is evaluated, so opt.rmax is effectively in kpc here
    (assuming the grid bins on nb.pos).

    Returns (x, y, ylabel, clean_for_log).
    """
    if nb.isComovingIntegrationOn():
        _print_physical_conversion(nb)
        x0 = nb.pos
        v0 = nb.vel
        pars = {
            "Hubble": ctes.HUBBLE.into(nb.localsystem_of_units),
            "HubbleParam": nb.hubbleparam,
            "OmegaLambda": nb.omegalambda,
            "Omega0": nb.omega0}
        a = nb.atime
        Ha = cosmo.Hubble_a(a, pars=pars)
        nb.pos = x0 * a / nb.hubbleparam             # length   conversion
        nb.vel = v0 * np.sqrt(a) + x0 * Ha * a       # velocity conversion

    nb.pos = nb.pos * _length_factor(nb)
    nb.vel = nb.vel * nb.localsystem_of_units.convertionFactorTo(
        _km_msun_s().UnitVelocity)

    grid = _new_grid(opt)
    x = grid.get_r()

    def tangential_sigma():
        # sigma_t = sqrt(sigma_phi^2 + sigma_theta^2)
        sigmaphi = grid.get_SigmaValMap(nb, nb.Vphi())
        sigmatheta = grid.get_SigmaValMap(nb, nb.Vtheta())
        return np.sqrt(sigmaphi**2 + sigmatheta**2)

    def anisotropy():
        # beta = 1 - sigma_t^2 / (2 sigma_r^2)
        st = tangential_sigma()
        sr = grid.get_SigmaValMap(nb, nb.Vr())
        return 1 - ((st**2) / (2. * sr**2))

    # quantity -> (lazy computation, y label); only the selected one runs
    profiles = {
        'beta': (anisotropy, r'$\beta$'),
        'sigmaz': (lambda: grid.get_SigmaValMap(nb, nb.Vz()),
                   r'$\sigma_z\,\left[ \rm{km}/\rm{s} \right]$'),
        'sigmar': (lambda: grid.get_SigmaValMap(nb, nb.Vr()),
                   r'$\sigma_r\,\left[ \rm{km}/\rm{s} \right]$'),
        'sigmat': (tangential_sigma,
                   r'$\sigma_t\,\left[ \rm{km}/\rm{s} \right]$'),
        'sigmatheta': (lambda: grid.get_SigmaValMap(nb, nb.Vtheta()),
                       r'$\sigma_\theta\,\left[ \rm{km}/\rm{s} \right]$'),
        'sigmaphi': (lambda: grid.get_SigmaValMap(nb, nb.Vphi()),
                     r'$\sigma_\phi\,\left[ \rm{km}/\rm{s} \right]$'),
        'Vt': (lambda: grid.get_MeanValMap(nb, nb.Vt()),
               r'$V_t\,\left[ \rm{km}/\rm{s} \right]$'),
        'Vr': (lambda: grid.get_MeanValMap(nb, nb.Vr()),
               r'$V_r\,\left[ \rm{km}/\rm{s} \right]$'),
        'Vtheta': (lambda: grid.get_MeanValMap(nb, nb.Vtheta()),
                   r'$V_\theta\,\left[ \rm{km}/\rm{s} \right]$'),
        'Vphi': (lambda: grid.get_MeanValMap(nb, nb.Vphi()),
                 r'$V_\phi\,\left[ \rm{km}/\rm{s} \right]$'),
    }
    compute, ylabel = profiles[opt.y]
    # no log-axis cleaning: beta and mean velocities may be negative
    return x, compute(), ylabel, False


def _circular_velocity_local(nb, opt):
    """Return (r, vc) in local units, with vc = sqrt(G M(<r) / r)."""
    grid = _new_grid(opt)

    if nb.isComovingIntegrationOn():
        # eventually correct pos and mass from h and a
        fpos = nb.ConversionFactor(units=None, mode='pos')
        fmass = nb.ConversionFactor(units=None, mode='mass')
    else:
        fpos = 1
        fmass = 1

    r = grid.get_r() * fpos
    enclosed_mass = np.add.accumulate(grid.get_MassMap(nb)) * fmass

    # Newton theorem
    gravity = ctes.GRAVITY.into(nb.localsystem_of_units)
    vc = np.sqrt(gravity * enclosed_mass / r)
    return r, vc


def _circular_velocity_profile(nb, opt):
    """Circular velocity (km/s) versus radius (kpc)."""
    r, vc = _circular_velocity_local(nb, opt)
    x = r * _length_factor(nb)
    y = vc * nb.localsystem_of_units.convertionFactorTo(_km_msun_s().UnitVelocity)
    return x, y, r'$V_{\rm c}\,\left[ \rm{km}/\rm{s} \right]$', True


def _dynamical_time_profile(nb, opt):
    """Orbital (dynamical) time 2*pi*r/vc (Myr) versus radius (kpc)."""
    r, vc = _circular_velocity_local(nb, opt)
    x = r * _length_factor(nb)
    y = (2 * np.pi * r / vc) * nb.localsystem_of_units.convertionFactorTo(
        _kpc_msun_myr().UnitTime)
    return x, y, r'$t_{\rm dyn}\,\left[ \rm{Myr} \right]$', True


def _compute_profile(nb, opt):
    """Dispatch on opt.y; return (x, y, ylabel, clean_for_log)."""
    if opt.y == 'density':
        return _density_profile(nb, opt)
    if opt.y in ('mass', 'imass'):
        return _mass_profile(nb, opt, cumulative=(opt.y == 'imass'))
    if opt.y in _KINEMATIC_QUANTITIES:
        return _kinematic_profile(nb, opt)
    if opt.y == 'vcirc':
        return _circular_velocity_profile(nb, opt)
    if opt.y == 'tdyn':
        return _dynamical_time_profile(nb, opt)
    raise ValueError("unknown quantity '%s'" % opt.y)


def MakePlot(opt):
    """Plot a spherically averaged radial profile for every file in opt.files.

    opt.y selects the quantity (one of _SUPPORTED_QUANTITIES). Each file gives
    one curve. The figure is saved to opt.outputfilename when set, otherwise
    it is displayed with plt.show().

    Raises:
      ValueError: if opt.y is not supported or no input file is given.
    """
    # fail fast, before any snapshot is read
    if opt.y not in _SUPPORTED_QUANTITIES:
        raise ValueError("unknown quantity '%s' for -y (choose among: %s)"
                         % (opt.y, ", ".join(_SUPPORTED_QUANTITIES)))
    if not opt.files:
        raise ValueError("no input file given")

    params = {
        "axes.labelsize": 14,
        "axes.titlesize": 18,
        "font.size": 12,
        "legend.fontsize": 12,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
        "text.usetex": True,
        "figure.subplot.left": 0.15,
        "figure.subplot.right": 0.95,
        "figure.subplot.bottom": 0.15,
        "figure.subplot.top": 0.95,
        "figure.subplot.wspace": 0.02,
        "figure.subplot.hspace": 0.02,
        "figure.figsize": (8, 6),
        "lines.markersize": 6,
        "lines.linewidth": 2.0,
    }
    plt.rcParams.update(params)

    # create the plot
    fig = plt.gcf()
    fig.set_size_inches(8, 6)
    ax = plt.gca()

    # one color per file
    colors = plot.ColorList(n=len(opt.files), colormap=opt.colormap)

    datas = []
    ylabel = None
    for filename in opt.files:
        nb = _load_snapshot(filename, opt)
        x, y, ylabel, clean_for_log = _compute_profile(nb, opt)
        if clean_for_log:
            x, y = plot.CleanVectorsForLogX(x, y)
            x, y = plot.CleanVectorsForLogY(x, y)
        datas.append(plot.DataPoints(x, y, color=colors.get(),
                                     label=filename, tpe='points'))

    # plot all
    for d in datas:
        ax.plot(d.x, d.y, color=d.color)

    # set limits, then the axis (the extension is done by SetLimitsFromDataPoints)
    xmin, xmax, ymin, ymax, log = plot.SetLimitsFromDataPoints(
        datas, opt.xmin, opt.xmax, opt.ymin, opt.ymax, opt.log)
    plot.SetAxis(ax, xmin, xmax, ymin, ymax, log, extend=False)

    ax.set_xlabel(_RADIUS_LABEL)
    ax.set_ylabel(ylabel)

    if opt.legend:
        plot.LegendFromDataPoints(ax, datas, opt.legend_loc)

    # save or display
    if opt.outputfilename:
        plt.savefig(opt.outputfilename)
    else:
        plt.show()
      


#################################
# main
#################################

if __name__ == '__main__':
    # parse the command line and build the requested figure
    MakePlot(parser.parse_args())
  



