#!/usr/bin/env python
#
# pyQvarsi, Ensight to MPIO.
#
# Converter for EnsightGold to MPIO.
#
# Last rev: 24/09/2021
from __future__ import print_function, division

# Do not remove the two lines below, otherwise the tool will not work:
# recv_mprobe has to be disabled in mpi4py.rc before MPI is used
# (learned after a long weekend of debugging).
import mpi4py
mpi4py.rc.recv_mprobe = False

import os, argparse, numpy as np
import pyQvarsi


## Command line interface
argpar = argparse.ArgumentParser(
	prog='pyQvarsi_ensi2mpio',
	description='Convert EnsightGold to MPIO',
)
# Source (EnsightGold) case name
argpar.add_argument(
	'-c1','--casename1',
	required=True,type=str,dest='casestr1',
	help='Name of the case for EnsightGold',
)
# Comma separated list of variables to convert
argpar.add_argument(
	'-v','--variables',
	required=True,type=str,dest='vars',
	help='list of variables to include from MPIO files given as a list separated by comma (e.g., VELOC,PRESS)',
)
# Instants to convert, given as "start,stop" (stop is excluded)
argpar.add_argument(
	'-r','--range',
	type=str,dest='range',
	help='simulation range to write (default: 0,1)',
)
# Directory holding the source case
argpar.add_argument(
	'--basedir1',
	type=str,dest='basedir1',
	help='base directory where to open the data (default: ./)',
)
# Target (MPIO) case name
argpar.add_argument(
	'-c2','--casename2',
	required=True,type=str,dest='casestr2',
	help='Name of the case for MPIO',
)
# Directory holding the target case
argpar.add_argument(
	'--basedir2',
	type=str,dest='basedir2',
	help='base directory where to output the data (default: ./)',
)
# Optional file used to store the ordering array
argpar.add_argument(
	'--order-array',
	type=str,dest='lninv',
	help='path to the ordering array, if it does not exist it will be created',
)

# Parse the inputs and fill in the defaults
args = argpar.parse_args()
# Without an explicit range only instant zero is converted
range_bounds  = (args.range or '0,1').split(',')
args.range    = list(range(int(range_bounds[0]),int(range_bounds[1])))
args.basedir1 = args.basedir1 or './'
args.basedir2 = args.basedir2 or './'
# Turn the comma separated variables into a list
if args.vars:
	args.vars = args.vars.split(',')


# Print the tool banner on screen (first argument of pprint is 0 throughout this script)
for banner_line in ('--|','--| pyQvarsi_ensi2mpio |-- ','--|','--| Convert EnsightGold to MPIO.'):
	pyQvarsi.pprint(0,banner_line)
# Closing line of the banner, flushing the output
pyQvarsi.pprint(0,'--|',flush=True)


## Load source mesh in serial
#  The EnsightGold mesh is only read where is_rank_or_serial(0) holds;
#  every other rank keeps source_mesh as None.
pyQvarsi.pprint(0,'--|')
# Report the case that is actually being read (casestr1, the EnsightGold one)
pyQvarsi.pprint(0,'--| Loading source mesh <%s>... '%args.casestr1,end='',flush=True)
if pyQvarsi.utils.is_rank_or_serial(0):
	source_mesh = pyQvarsi.MeshAlya.read(args.casestr1,basedir=args.basedir1,ngauss=1,fmt='ensi')
else:
	source_mesh = None
pyQvarsi.pprint(0,'done!')
pyQvarsi.pprint(0,'--|',flush=True)


## Load target mesh and reduce LNINV in order
#  The MPIO mesh is read by every rank; the optional arrays listed
#  below are explicitly not requested.
pyQvarsi.pprint(0,'--|')
pyQvarsi.pprint(0,'--| Loading target mesh <%s>... '%args.casestr2,end='',flush=True)
skipped_arrays = dict(read_commu=False,read_massm=False,read_codno=False,read_exnor=False)
target_mesh = pyQvarsi.MeshAlya.read(
	args.casestr2,
	basedir=args.basedir2,
	ngauss=1,
	fmt='mpio',
	**skipped_arrays
)
pyQvarsi.pprint(0,'done!')
pyQvarsi.pprint(0,'--|',flush=True)


## Find the order of 1 that gives the order of 2 to be used in
#  sorting the fields of 1 so that the results are not scrambled.
#
#  The resulting array (lninv) is only defined where is_rank_or_serial(0)
#  holds. When --order-array is given, the array is loaded from that file
#  if it exists, otherwise it is computed and stored there.
pyQvarsi.pprint(0,'--|')
pyQvarsi.pprint(0,'--| Computing sorting strategy... ',end='',flush=True)
# Gather the node positions of the target mesh in rank 0 (called by all ranks)
xyz2  = pyQvarsi.utils.mpi_gather(target_mesh.xyz,root=0)
# Only rank 0 computes this, it is the only one holding the source mesh
if pyQvarsi.utils.is_rank_or_serial(0):
	lninv_file = args.lninv
	# np.save appends '.npy' to names lacking it, so a previous run given
	# 'name' wrote 'name.npy'; look for that file unless 'name' itself exists
	if lninv_file and not lninv_file.endswith('.npy') and not os.path.isfile(lninv_file):
		lninv_file += '.npy'
	if lninv_file and os.path.isfile(lninv_file):
		# Reuse the stored ordering, nothing to recompute nor to rewrite
		lninv = np.load(lninv_file)
	else:
		# Drop repeated nodes of the source mesh before matching both meshes
		xyz_clean = np.unique(source_mesh.xyz,axis=0)
		lninv = pyQvarsi.math.reorder1to2(xyz2,xyz_clean)
		# Store lninv for later runs when a path has been provided
		if lninv_file:
			np.save(lninv_file,lninv)
pyQvarsi.pprint(0,'done!')
pyQvarsi.pprint(0,'--|',flush=True)


## Reading partition tables
#  Reads the partition table of the target case and uses it to
#  distribute the source nodal positions from rank 0.
pyQvarsi.pprint(0,'--|')
pyQvarsi.pprint(0,'--| Reading partition table of <%s>... '%args.casestr2,end='',flush=True)
# Partition table of the target (MPIO) case
ptable = pyQvarsi.PartitionTable.fromAlya(args.casestr2,basedir=args.basedir2)
# Only the rank holding the source mesh provides the positions to scatter
# NOTE(review): these are the source mesh positions as read, not reordered
# with lninv - confirm they match the node numbering of the partition table.
if pyQvarsi.utils.is_rank_or_serial(0):
	xyz = source_mesh.xyz
else:
	xyz = None
# Scatter nodal positions among the processors
xyz = pyQvarsi.utils.mpi_scatterp(xyz,ptable.Ids,ptable.Points,root=0)
pyQvarsi.pprint(0,'done!')
pyQvarsi.pprint(0,'--|',flush=True)


## Looping requested instants
#  For every instant: the EnsightGold field is read, cleaned and reordered
#  where is_rank_or_serial(0) holds, then its time is broadcast and its
#  variables are scattered with the partition table, and finally all ranks
#  write it in MPIO format.
for instant in args.range:
	pyQvarsi.pprint(0,'--|')
	pyQvarsi.pprint(0,'--| Outputting MPIO format of <%s> instant %d... '%(args.casestr1,instant),end='',flush=True)
	if pyQvarsi.utils.is_rank_or_serial(0):
		# Load ensight in serial mode
		f,h = pyQvarsi.FieldAlya.read(args.casestr1,args.vars,instant,source_mesh.xyz,basedir=args.basedir1,fmt='ensi')
		# Clean BC's on field f, which can be identified as repeated nodes
		f.clean()
		# Sort for destiny mesh - This will expand or contract f according to the new size
		f.sort(array=lninv,as_idx=True)
		# Replace the positions by the ones scattered to this rank
		f.xyz = xyz
		# The header is only available here, so the time must come from this rank
		time  = h.time
	else:
		# Create output field with empty variables to receive the scattered data
		f = pyQvarsi.FieldAlya(xyz=xyz)
		for var in args.vars: f[var] = None
		time = None
	# Broadcast time to all processors; the value is taken from root in
	# parallel runs too, not only in serial ones
	time = pyQvarsi.utils.mpi_bcast(time,root=0)
	# Scatter field among the processors
	for var in args.vars:
		f[var] = pyQvarsi.utils.mpi_scatterp(f[var],ptable.Ids,ptable.Points,root=0)
	# Store Field in MPIO format
	f.write(args.casestr2,instant,time,basedir=args.basedir2,fmt='mpio')
	pyQvarsi.pprint(0,'done!')
pyQvarsi.pprint(0,'--|',flush=True)


## Say goodbye
#  Closing message followed by the report printed by cr_info.
farewell = '--| Bye!'
pyQvarsi.pprint(0,'--|')
pyQvarsi.pprint(0,farewell,flush=True)
pyQvarsi.cr_info()