#!/usr/bin/env python3

###########################################################################################
#  package:   Gtools
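#  file:      sw_copyPartTypeDataset
#  brief:     copy the datasets of one PartTypeX group from an input SWIFT file into another file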
#  copyright: GPLv3
#             Copyright (C) 2019 EPFL (Ecole Polytechnique Federale de Lausanne)
#             LASTRO - Laboratory of Astrophysics of EPFL
#  author:    Yves Revaz <yves.revaz@epfl.ch>
#
# This file is part of Gtools.
###########################################################################################

# number of particle types scanned when counting particles (groups PartType0 .. PartType6)
NTYPES: int = 7

from mpi4py import MPI
import h5py
import numpy as np
import argparse


description="Copy the datasets of PartTypeX of a swift file to the ones taken from a given one."
epilog     ="""
Examples:
--------
sw_copyPartTypeDataset  -i inputfile.hdf5  file.hdf5
"""

parser = argparse.ArgumentParser(description=description,epilog=epilog,formatter_class=argparse.RawDescriptionHelpFormatter)


# file the datasets are read from; exactly one value must follow "-i"
# (nargs='?' is not used: it let a bare "-i" through as None, which then
# failed later inside h5py instead of being reported by argparse)
parser.add_argument("-i",
                    action="store",
                    dest="inputfile",
                    metavar='FILE',
                    required=True,
                    type=str,
                    help='inputfile')

# X in the name of the PartTypeX group to copy; parsed as an int because the
# group name is built with "PartType%d" (a str value would raise TypeError)
parser.add_argument("-p","--ptype",
                    action="store",
                    dest="ptype",
                    metavar='INT',
                    type=int,
                    default=0,
                    help='particle type')

# target file, modified in place
parser.add_argument(action="store",
                    dest="file",
                    metavar='FILE',
                    type=str,
                    default=None,
                    help='file to modify')


opt = parser.parse_args()




# MPI initialisation: world communicator, rank of this task, number of tasks
comm = MPI.COMM_WORLD
ThisTask, NTask = comm.Get_rank(), comm.Get_size()
Procnm = MPI.Get_processor_name()   # name of the processor running this task
Rank = ThisTask                     # alias of ThisTask



def get_particles_limits(size, ntask=None, rank=None):
    """Return the [start, end) bounds of the elements handled by one MPI task.

    ``size`` elements are split into ``ntask`` contiguous, non-overlapping
    chunks whose lengths differ by at most one. Slice an array like
    ``pos[start:end]`` to get the local part.

    Integer arithmetic is used so that the chunks always tile ``[0, size)``
    exactly. A float split such as ``int(rank * (size / ntask))`` can round
    the last ``end`` down to ``size - 1`` (``int(49 * (1 / 49)) == 0``) and
    silently drop elements.

    :param int size: total number of elements (length of the first axis)
    :param int ntask: number of tasks; defaults to the module-level ``NTask``
    :param int rank: index of the task in ``[0, ntask)``; defaults to ``ThisTask``
    :returns: (start, end)
    """
    if ntask is None:
        ntask = NTask
    if rank is None:
        rank = ThisTask
    start = rank * size // ntask
    end = (rank + 1) * size // ntask
    return start, end



def h5_load(filename,ptype,key):
  '''
  Load this task's slice of one dataset of a particle-type group.

  The dataset fd[ptype][key] is split along its first axis with
  get_particles_limits(), and only rows [idx0, idx1) are read. With more
  than one MPI task the file is opened with the "mpio" driver on
  MPI.COMM_WORLD.

  ptype = 'PartType0', 'PartType1', ...
  key   = 'Coordinates', 'InternalEnergy', 'Masses', 'ParticleIDs', 'SmoothingLength', 'Velocities', ...

  Returns (data, idx0, idx1): the rows read by this task and their
  [start, end) bounds in the full dataset.
  '''
  if NTask > 1:
    fd = h5py.File(filename,'r',driver="mpio",comm=MPI.COMM_WORLD)
  else:
    fd = h5py.File(filename,'r')

  # the context manager closes the file even if the read fails
  with fd:
    dset = fd[ptype][key]

    # rows handled by this task: [idx0, idx1)
    idx0, idx1 = get_particles_limits(dset.len())

    data = dset[idx0:idx1]

    # synchronise all tasks before the file is closed
    comm.barrier()

  return data, idx0, idx1


def h5_dump(filename,ptype,key,val,ini,end):
  '''
  Replace dataset `key` of group `ptype` in `filename` by the slices held by all tasks.

  Every task passes its own slice `val` together with its bounds [ini, end),
  as returned by h5_load(). The new dataset has sum(val.shape[0]) rows over
  all tasks, the same trailing dimensions and the same dtype as `val`. Each
  task then writes its slice at rows [ini:end). With more than one MPI task
  the file is opened with the "mpio" driver on MPI.COMM_WORLD.

  If the group does not contain a dataset `key` yet, it is simply created.

  NOTE(review): an existing dataset is deleted and recreated, so any HDF5
  attributes and storage settings (chunking, compression) attached to it
  in `filename` are not kept.

  Raises ValueError if `val` is a scalar (at least one dimension is needed).
  '''
  val = np.asarray(val)

  # validate before touching the file, so a bad input never leaves the
  # target group with the dataset already deleted
  if val.ndim == 0:
    raise ValueError("cannot dump a scalar into %s/%s: an array with at least one dimension is required"%(ptype,key))

  # total number of rows over all tasks; trailing dimensions are kept as-is,
  # so arrays of any rank are supported
  size = comm.allreduce(val.shape[0], op=MPI.SUM)
  shape = (size,) + val.shape[1:]

  if NTask > 1:
    fd = h5py.File(filename,'a',driver="mpio",comm=MPI.COMM_WORLD)
  else:
    fd = h5py.File(filename,'a')

  # the context manager closes the file even if a write fails
  with fd:
    block = fd[ptype]

    if key in block:
      del block[key]

    dset = block.create_dataset(key, shape=shape, dtype=val.dtype)
    dset[ini:end] = val

    comm.barrier()
    fd.flush()



def h5_update_npart(filename):
  '''
  Rewrite the particle counts stored in the Header group of `filename`.

  For every group PartType0 .. PartType{NTYPES-1} present in the file, the
  number of particles is the length of the first axis of its "Coordinates"
  dataset. Missing groups count as 0. The counts are written to the Header
  attributes NumPart_ThisFile and NumPart_Total; old and new values are
  printed.

  Only the master task (ThisTask == 0) opens and modifies the file; the other
  tasks return immediately without touching it.

  NOTE(review): NumPart_Total is passed as uint32, so a count above 2**32-1
  would wrap — confirm whether a high-word attribute also needs updating for
  very large snapshots.
  '''
  # only master updates npart
  if ThisTask != 0:
    return

  with h5py.File(filename,'a') as fd:
    # counts are integers: accumulate them in an integer array directly
    npart = np.zeros(NTYPES, dtype=np.uint64)

    for itype in range(NTYPES):
      ptype = "PartType%d"%itype
      if ptype in fd:
        npart[itype] = fd[ptype]["Coordinates"].shape[0]

    header = fd['Header'].attrs
    print("old npart = ",header['NumPart_ThisFile'])

    header.modify('NumPart_ThisFile',npart.astype(np.uint64))
    header.modify('NumPart_Total',   npart.astype(np.uint32))
    fd.flush()

  print("new npart = ",npart.astype(np.uint64))
    


def h5_getPartTypeKeys(filename,ptype):
  '''
  Return the names of the members of group `ptype` in `filename`.

  Only the master task reads the file; the list is then broadcast from
  rank 0, so every task gets the same names in the same order.

  The file is opened read-only: append mode ('a') would require write
  permission on the input file and would create an empty file if
  `filename` did not exist.
  '''
  keys = []

  # only master reads the file
  if ThisTask==0:
    with h5py.File(filename,'r') as fd:
      keys = list(fd[ptype].keys())

  keys = comm.bcast(keys, root=0)
  return keys


############################
#
# main
#
############################

filename1 = opt.inputfile
filename2 = opt.file


# setup ptype
# int() accepts an int or a numeric string, so "%d" works whatever type the
# command-line parser produced for --ptype
ptype = "PartType%d"%int(opt.ptype)

# get the keys attached to ptype (read on the master task, broadcast to all)
keys = h5_getPartTypeKeys(filename1,ptype)

if ThisTask==0:
  print()
  print("Running using %d tasks."%NTask)
  print("copying %s from %s"%(ptype,filename1))

# only datasets present in the input group are replaced; other datasets of
# the same group in the target file are left untouched
for key in keys:

  if ThisTask==0:
    print("  copying dataset ",key)

  # each task loads its own slice from the input file ...
  val,ini,end = h5_load(filename1,ptype=ptype,key=key)

  # ... and writes it at the same rows in the target file
  h5_dump(filename2,ptype,key,val,ini,end)


# change npart
if ThisTask==0:
  print("file %s updated."%filename2)

h5_update_npart(filename2)















