#!/usr/bin/python3
###########################################################################################
#  package:   pNbody
#  file:      plotSphericalRotationCurves
#  brief:     plot the circular velocity curves of a system assumed to be spherical
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Yves Revaz <yves.revaz@epfl.ch>
#
# This file is part of Gtools.
###########################################################################################

import os
import numpy as np
import argparse

import matplotlib.pyplot as plt
from pNbody import plot
from pNbody import *

####################################################################
# option parser
####################################################################

# help text shown by --help (the epilog is printed verbatim thanks to
# RawDescriptionHelpFormatter)
description = """plot the circular velocity curve for each component found, assuming the system to be spherical. The circular velocity
curve is obtained via Newton's theorem, V_c(r) = sqrt(G M(<r) / r). The total curve of all components is also drawn."""
epilog     ="""
Examples:
--------
plotSphericalRotationCurves --forceComovingIntegrationOff --xmax 10 --rmax 10 --nr 64  --log xy  snapshot.hdf5  
"""



parser = argparse.ArgumentParser(
  description=description,
  epilog=epilog,
  formatter_class=argparse.RawDescriptionHelpFormatter,
)

# generic pNbody option groups; the registration order below is also the
# order in which they are listed by --help
for _register in (
  plot.add_files_options,
  plot.add_arguments_units,
  plot.add_arguments_reduc,
  plot.add_arguments_center,
  plot.add_arguments_select,
  plot.add_arguments_info,
  plot.add_comoving_options,
  plot.add_arguments_legend,
  plot.add_arguments_icshift,
  plot.add_arguments_cmd,
):
  _register(parser)

# input snapshots (positional, zero or more)
parser.add_argument(
  dest="files",
  action="store",
  nargs='*',
  type=str,
  default=None,
  metavar='FILE',
  help='a list of files',
)

# output file; MakePlot shows the figure interactively when it is not given
parser.add_argument(
  "-o",
  dest="outputfilename",
  action="store",
  type=str,
  default=None,
  help="Name of the output file",
)

# axis limits --xmin/--xmax/--ymin/--ymax, help texts "x min", "x max", ...
for _limit in ("xmin", "xmax", "ymin", "ymax"):
  parser.add_argument(
    "--" + _limit,
    dest=_limit,
    action="store",
    type=float,
    default=None,
    metavar='FLOAT',
    help=_limit[0] + " " + _limit[1:],
  )

# drop the loop variables so they do not linger in the module namespace
del _register, _limit

# axis scaling
parser.add_argument(
  '--log',
  dest="log",
  action="store",
  type=str,
  default=None,
  metavar='STR',
  help='log scale (None,x,y,xy)',
)

# radial binning of the spherical grid
parser.add_argument(
  "--rmax",
  dest="rmax",
  action="store",
  type=float,
  default=50.,
  metavar=" FLOAT",
  help="max radius of bins",
)

parser.add_argument(
  "--nr",
  dest="nr",
  action="store",
  type=int,
  default=32,
  metavar=" INT",
  help="number of bins in r",
)

# colormap used to pick one color per component
parser.add_argument(
  "--colormap",
  action="store",
  default='jet',
  help='matplotlib colormap name (e.g. mycmap, tab20c, Greys, jet, binary)',
)
                    
                    
                    

#######################################
# MakePlot
#######################################


def _load_snapshot(filename, opt):
  """Read *filename* and apply the command line options to the model.

  Parameters
  ----------
  filename : str
    snapshot to read; its format is given by ``opt.ftype``
  opt : argparse.Namespace
    parsed command line options

  Returns
  -------
  The ``Nbody`` object after the units, icshift, comoving, reduc, select,
  center, cmd, info and display options have been applied (in that order).
  Unit and comoving information is printed as a side effect.
  """
  nb = Nbody(filename, ftype=opt.ftype)

  # define local units
  unit_params = plot.apply_arguments_units(opt)
  nb.set_local_system_of_units(params=unit_params)

  # apply options; each step works on the output of the previous one
  nb = plot.apply_arguments_icshift(nb, opt)
  nb = plot.apply_arguments_comoving(nb, opt)
  nb = plot.apply_arguments_reduc(nb, opt)
  nb = plot.apply_arguments_select(nb, opt)
  nb = plot.apply_arguments_center(nb, opt)
  nb = plot.apply_arguments_cmd(nb, opt)
  nb = plot.apply_arguments_info(nb, opt)
  # NOTE(review): the parser never calls a plot.add_arguments_display();
  # confirm apply_arguments_display() does not need options it would add.
  nb = plot.apply_arguments_display(nb, opt)

  # some info
  print("---------------------------------------------------------")
  nb.localsystem_of_units.info()
  nb.ComovingIntegrationInfo()
  print("---------------------------------------------------------")

  return nb


def _clean_for_log(x, y):
  """Pass (x, y) through the plot module's log-axis cleaning helpers (x, then y)."""
  x, y = plot.CleanVectorsForLogX(x, y)
  x, y = plot.CleanVectorsForLogY(x, y)
  return x, y


def _circular_velocity_datapoints(nb, opt):
  """Build the circular velocity curves of *nb*: one per non-empty component plus the total.

  The enclosed mass M(<r) is the cumulative sum of the mass binned on a
  linear spherical grid (``opt.rmax``, ``opt.nr``), and
  V_c(r) = sqrt(G M(<r) / r) (Newton's theorem).

  Returns
  -------
  list of plot.DataPoints with radii in kpc and velocities in km/s, labelled
  with the component index. When at least one component is non-empty, a
  last black entry labelled 'tot' holds the total curve. Because V_c**2 is
  linear in the enclosed mass, the total is sqrt(sum_i V_c,i**2). The list is
  empty when every component is empty.
  """
  # use a linear grid
  G = libgrid.Spherical_1d_Grid(rmin=0, rmax=opt.rmax, nr=opt.nr, g=None, gm=None)

  if nb.isComovingIntegrationOn():
    # eventually correct pos and mass from h and a
    fp = nb.ConversionFactor(units=None, mode='pos')
    fmm = nb.ConversionFactor(units=None, mode='mass')
  else:
    fp = 1
    fmm = 1

  # radius of the bins, in local units
  r = G.get_r() * fp

  # quantities that do not depend on the component (computed once)
  Gcte = ctes.GRAVITY.into(nb.localsystem_of_units)
  out_units_x = units.UnitSystem('local', [units.Unit_kpc, units.Unit_Ms, units.Unit_Myr, units.Unit_K])
  out_units_y = units.UnitSystem('local', [units.Unit_km, units.Unit_Ms, units.Unit_s, units.Unit_K])
  fx = nb.localsystem_of_units.convertionFactorTo(out_units_x.UnitLength)
  fy = nb.localsystem_of_units.convertionFactorTo(out_units_y.UnitVelocity)
  x = r * fx

  # get a list of color
  colors = plot.ColorList(n=len(nb.npart), colormap=opt.colormap)

  datas = []
  vc2_tot = np.zeros(len(r))

  # loop over all components
  for i, npart in enumerate(nb.npart):
    if npart == 0:
      continue

    nbsub = nb.select(i)

    # M(<r)
    M = np.add.accumulate(G.get_MassMap(nbsub)) * fmm

    # Newton theorem. The grid starts at rmin=0, so r may contain 0 and the
    # division yields inf/nan there; presumably those points are dropped by
    # the log cleaning below (x <= 0) -- TODO confirm.
    with np.errstate(divide='ignore', invalid='ignore'):
      vc2 = Gcte * M / r

    # accumulate BEFORE cleaning so every component stays aligned on r
    vc2_tot = vc2_tot + vc2

    xc, yc = _clean_for_log(x, np.sqrt(vc2) * fy)
    datas.append(plot.DataPoints(xc, yc, color=colors.get(), label=str(i), tpe='points'))

  # add the total velocity curve of this model
  if datas:
    xt, yt = _clean_for_log(x, np.sqrt(vc2_tot) * fy)
    datas.append(plot.DataPoints(xt, yt, color='k', label='tot', tpe='points'))

  return datas


def MakePlot(opt):
  """Plot the circular velocity curves of every file given on the command line.

  For each file, one curve is drawn per non-empty particle component, plus
  the total curve of that file (label 'tot'). The figure is saved to
  ``opt.outputfilename`` when given; otherwise it is shown interactively.
  When there is nothing to plot (no file, or no particles), a message is
  printed and the function returns without saving or showing anything.

  Parameters
  ----------
  opt : argparse.Namespace
    options returned by ``parser.parse_args()``
  """
  params = {
    "axes.labelsize": 14,
    "axes.titlesize": 18,
    "font.size": 12,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "text.usetex": True,
    "figure.subplot.left": 0.15,
    "figure.subplot.right": 0.95,
    "figure.subplot.bottom": 0.15,
    "figure.subplot.top": 0.95,
    "figure.subplot.wspace": 0.02,
    "figure.subplot.hspace": 0.02,
    "figure.figsize" : (8, 6),
    "lines.markersize": 6,
    "lines.linewidth": 2.0,
  }
  plt.rcParams.update(params)

  # create the plot
  fig = plt.gcf()
  fig.set_size_inches(8, 6)
  ax = plt.gca()

  # collect the curves (components + total) of every file
  datas = []
  for filename in opt.files or []:
    nb = _load_snapshot(filename, opt)
    datas.extend(_circular_velocity_datapoints(nb, opt))

  if not datas:
    print("nothing to plot: no input file or no particles found")
    return

  # plot all (each curve exactly once)
  for d in datas:
    ax.plot(d.x, d.y, color=d.color)

  # set limits
  xmin, xmax, ymin, ymax, log = plot.SetLimitsFromDataPoints(datas, opt.xmin, opt.xmax, opt.ymin, opt.ymax, opt.log)

  # set the axis (the extension is done in the previous command)
  plot.SetAxis(ax, xmin, xmax, ymin, ymax, log, extend=False)

  # labels
  ax.set_xlabel(r'$\rm{Radius}\,\left[ \rm{kpc} \right]$')
  ax.set_ylabel(r'$V_{\rm c}\,\left[ \rm{km}/\rm{s} \right]$')

  # legend
  if opt.legend:
    plot.LegendFromDataPoints(ax, datas, opt.legend_loc)

  # save or display
  if opt.outputfilename:
    plt.savefig(opt.outputfilename)
  else:
    plt.show()
      


#################################
# main
#################################

if __name__ == "__main__":
  # read the command line, then build (and save or show) the figure
  opt = parser.parse_args()
  MakePlot(opt)
  



