#!/usr/bin/env python
#
# pyQvarsi, Mesh H5 partition editor.
#
# Change the partition of an hdf5 file.
#
# Last rev: 07/02/2023
from __future__ import print_function, division

import os, gc, argparse, numpy as np
import pyQvarsi


# Command-line interface: the case name (-c/--casename) is mandatory, the
# base directory is optional and the HDF5 file is the only positional argument.
argpar = argparse.ArgumentParser(prog='pyQvarsi_h5mpiopart', description='Change the partition of an hdf5 file')
argpar.add_argument('-c','--casename',required=True,type=str,help='Name of the source case for partition',dest='casestr')
argpar.add_argument('--basedir',type=str,default='./',help='base directory where to read the case data (default: ./)',dest='basedir')
argpar.add_argument('file',type=str,help='HDF5 file')

# Parse inputs
args = argpar.parse_args()
# An explicitly empty --basedir also falls back to the current directory
if not args.basedir: args.basedir = './'

# Show the tool banner; only the closing line requests a flush
for banner_line in (
    '--|',
    '--| pyQvarsi_h5mpiopart |-- ',
    '--|',
    '--| Change the partition of an hdf5 file.',
):
    pyQvarsi.pprint(0,banner_line)
pyQvarsi.pprint(0,'--|',flush=True)

# Fetch only the LNINV variable from the HDF5 file, with MPIO disabled
# (i.e., a serial read)
pyQvarsi.pprint(0,'--|')
msg = '--| Obtaining LNINV from <%s>... '%args.file
pyQvarsi.pprint(0,msg,end='',flush=True)
lninv = pyQvarsi.Field.load(
    args.file,
    vars=['LNINV'],
    mpio=False,
)['LNINV']
pyQvarsi.pprint(0,'done!')
pyQvarsi.pprint(0,'--|',flush=True)

# Read the mesh of the target case; the communications, mass matrix,
# codno, exnor and lmast readers are all switched off and ngauss is -1
pyQvarsi.pprint(0,'--|')
msg = '--| Opening <%s>... '%args.casestr
pyQvarsi.pprint(0,msg,end='',flush=True)
mesh_opts = dict(
    ngauss=-1,
    read_commu=False,
    read_massm=False,
    read_codno=False,
    read_exnor=False,
    read_lmast=False,
)
m = pyQvarsi.Mesh.read(args.casestr,basedir=args.basedir,**mesh_opts)
pyQvarsi.pprint(0,'done!')
pyQvarsi.pprint(0,'--|',flush=True)

# Find which points should be read
pyQvarsi.pprint(0,'--|')
pyQvarsi.pprint(0,'--| Computing points to be read... ',end='',flush=True)
# Positions of the file LNINV whose value also appears in m.lninv: these
# are the only entries of the file that need to be read
inods = np.where(np.isin(lninv,m.lninv))[0]
# Get rid of repeated points (keep the first position found for each
# value) and restore ascending position order
_,snods = np.unique(lninv[inods],return_index=True)
inods = np.sort(inods[snods])
# Obtain the sorting of this lninv: for every entry of m.lninv, taken in
# order, find its position inside lninv[inods]. A sorted lookup is used
# instead of the broadcast comparison m.lninv[:,None] == lninv[inods],
# which allocated a len(m.lninv) x len(inods) boolean matrix.
fnods = lninv[inods] # unique values, in file order
if fnods.size > 0:
    isort = np.argsort(fnods)
    ipos  = np.searchsorted(fnods,m.lninv,sorter=isort)
    # Values larger than every stored one fall past the end; clamp them so
    # they can be indexed and then rejected by the equality test below
    ipos  = np.minimum(ipos,fnods.size-1)
    bnods = isort[ipos]
    # Drop the entries of m.lninv that have no counterpart in the file,
    # exactly as the former np.where did
    bnods = bnods[fnods[bnods] == m.lninv]
    del isort, ipos
else:
    # Nothing selected: empty index array, as np.where returned
    bnods = np.zeros((0,),dtype=np.intp)
pyQvarsi.pprint(0,'done!')
pyQvarsi.pprint(0,'--|',flush=True)
# Free memory
del m, lninv, snods, fnods
gc.collect()
    
# Build the partition table of the case, store it inside the HDF5 file
# and only afterwards load the file restricted to the selected inods
pyQvarsi.pprint(0,'--|')
msg = '--| Reading full <%s>... '%args.file
pyQvarsi.pprint(0,msg,end='',flush=True)
p = pyQvarsi.PartitionTable.fromAlya(
    args.casestr,
    basedir=args.basedir,
)
pyQvarsi.io.h5_edit_partition_table(args.file,p)
f = pyQvarsi.Field.load(args.file,inods=inods)
# Reorder the loaded field using bnods as an index array
f.sort(as_idx=True,array=bnods)
pyQvarsi.pprint(0,'done!')
pyQvarsi.pprint(0,'--|',flush=True)

# Write the field for the case in MPIO format, leaving LNINV out
pyQvarsi.pprint(0,'--|')
pyQvarsi.pprint(0,'--| Outputting field... ',end='',flush=True)
write_opts = dict(basedir=args.basedir,exclude_vars=['LNINV'],fmt='mpio')
f.write(args.casestr,**write_opts)
pyQvarsi.pprint(0,'done!')
pyQvarsi.pprint(0,'--|',flush=True)

# Closing banner followed by the pyQvarsi cr_info() call
pyQvarsi.pprint(0,'--|')
pyQvarsi.pprint(0,'--| Bye!',flush=True)
pyQvarsi.cr_info()
