#!/usr/bin/env python
#
# pyQvarsi, Interpolate.
#
# Interpolate a case into another.
#
# Last rev: 27/07/2021
from __future__ import print_function, division

# Please do not delete this part otherwise it will not work
# you have been warned after a long weekend of debugging
import mpi4py
mpi4py.rc.recv_mprobe = False
from mpi4py import MPI

import os, argparse, numpy as np
import pyQvarsi

# Rank of this process and total number of processes in the world communicator
_MPI_WORLD = MPI.COMM_WORLD
MPI_RANK   = _MPI_WORLD.Get_rank()
MPI_SIZE   = _MPI_WORLD.Get_size()


## Argparse
# Command line interface of the tool. No defaults are set here: optional
# arguments left out by the user come back as None (False for the flags).
argpar = argparse.ArgumentParser(
	prog        = 'pyQvarsi_interpolate',
	description = 'Interpolate a case into another'
)

# Source case selection
argpar.add_argument('-c', '--case', dest='src', type=str, required=True, help='source case name')
argpar.add_argument('-v', '--variables', dest='vars', type=str, required=True, help='variables to interpolate')
argpar.add_argument('--basedir', dest='basedir', type=str, help='source directory name')
argpar.add_argument('-i', '--instant', dest='source_instant', type=int, required=True, help='source file instant')

# Chunking of the target mesh
argpar.add_argument('-n', '--nchunks', dest='nchunks', type=int, help='number of chunks')

# Rotation of the source case, given as comma separated x,y,z strings
argpar.add_argument('--rot-angle', dest='rot_angle', type=str, help='rot angles x,y,z zero by default')
argpar.add_argument('--rot-center', dest='rot_center', type=str, help='rot center x,y,z zero by default')

# Interpolation method and behaviour flags
argpar.add_argument('--method', dest='method', type=str, help='interpolation method to use')
argpar.add_argument('--read-massm', dest='massm', action='store_true', help='read the mass matrix instead of computing')
argpar.add_argument('--start-field', dest='st_field', action='store_true', help='create a start field instead of just interpolating')

# Search tuning parameters
argpar.add_argument('-f', '--fact', dest='fact', type=float, help='factor to multiply the element distance (default: 2.5)')
argpar.add_argument('--f-inc', dest='f_incr', type=float, help='factor to increase searching box size (default: 1.0)')
argpar.add_argument('--r-inc', dest='r_incr', type=float, help='factor to increase searching ball radius (default: 1.0)')
argpar.add_argument('--ball-max-iter', dest='ball_max_iter', type=int, help='maximum number of ball searches (default: 5)')
argpar.add_argument('--global-max-iter', dest='glob_max_iter', type=int, help='maximum number of global searches (default: 1)')

# Positional: target mesh
argpar.add_argument('target', type=str, help='target mesh')

# Parse inputs
args = argpar.parse_args()
# A rotation is requested whenever an angle is given. The center is optional
# and falls back to the origin (as its help text states), so leaving it out
# must not cancel the rotation.
isRotate = bool(args.rot_angle)
# Fill in defaults for every option left unset. The test is on truthiness, so
# an explicit 0 is replaced by the default as well. The store_true flags
# already default to False and need no handling here.
if not args.rot_angle:     args.rot_angle     = '0,0,0'
if not args.rot_center:    args.rot_center    = '0,0,0'
if not args.basedir:       args.basedir       = './'
if not args.nchunks:       args.nchunks       = 1
if not args.fact:          args.fact          = 2.5
if not args.f_incr:        args.f_incr        = 1.0
if not args.r_incr:        args.r_incr        = 1.0
if not args.ball_max_iter: args.ball_max_iter = 5
if not args.glob_max_iter: args.glob_max_iter = 1
# Convert the comma separated strings into their working types. Both rotation
# strings are guaranteed to be set at this point.
args.rot_angle  = np.array([float(x) for x in args.rot_angle.split(',')],np.double)
args.rot_center = np.array([float(x) for x in args.rot_center.split(',')],np.double)
# An empty variable string is left untouched so that it yields no variables
if args.vars:              args.vars          = args.vars.split(',')


## Print info in screen
# Banner lines, printed through pprint with 0 as first argument;
# only the closing line forces a flush.
for banner_line in (
	'--|',
	'--| pyQvarsi_interpolate |-- ',
	'--|',
	'--| Interpolate a case into another.',
):
	pyQvarsi.pprint(0, banner_line)
pyQvarsi.pprint(0, '--|', flush=True)


## Load source mesh and field
pyQvarsi.pprint(0, '--|')
pyQvarsi.pprint(0, '--| Opening source mesh <%s>... ' % args.src, end='', flush=True)
# The mass matrix is only read from disk when --read-massm was given;
# communications and codes are never read for the source mesh.
smesh = pyQvarsi.Mesh.read(
	args.src,
	basedir    = args.basedir,
	read_commu = False,
	read_massm = args.massm,
	read_codno = False
)
pyQvarsi.pprint(0, 'done!')
pyQvarsi.pprint(0, '--|', flush=True)

pyQvarsi.pprint(0, '--| Opening source field <%s>... ' % args.src, end='', flush=True)
# Read the requested variables at the requested instant on the source nodes
sfields, sheader = pyQvarsi.Field.read(
	args.src,
	args.vars,
	args.source_instant,
	smesh.xyz,
	basedir = args.basedir
)
pyQvarsi.pprint(0, 'done!')
pyQvarsi.pprint(0, '--|', flush=True)

# Optional rotation of the source case: the mesh is rotated about the given
# center, the field is rotated with the same angles (no center is passed).
if isRotate:
	smesh.rotate(args.rot_angle, center=args.rot_center)
	sfields.rotate(args.rot_angle)


## Load target mesh - header
pyQvarsi.pprint(0, '--|')
pyQvarsi.pprint(0, '--| Opening target mesh <%s> header... ' % args.target, end='', flush=True)
# The coordinates file name follows the "S" format when building a start
# field and the "P" format otherwise.
if args.st_field:
	tfname_fmt = pyQvarsi.io.MPIO_AUXFILE_S_FMT
else:
	tfname_fmt = pyQvarsi.io.MPIO_AUXFILE_P_FMT
tfname = tfname_fmt % (args.target, 'COORD')
# Only the header is read here; the coordinates are loaded later by chunks
header = pyQvarsi.io.AlyaMPIO_header.read(tfname)
pyQvarsi.pprint(0, 'done!', flush=True)

# Create headers for the interpolated variables
# One output header per variable, sized after the target mesh header.
headers = []
for ivar, var in enumerate(args.vars):
	var_shape = sfields[var].shape
	is_vector = len(var_shape) > 1
	if args.st_field:
		# Start field: every variable is stored as 'XFIEL', at step 0 and
		# time 0, tagged with its 1-based position in the variable list
		out_name, out_itime, out_time, out_tag = 'XFIEL', 0, 0., ivar+1
	else:
		# Plain interpolation: keep the name and the step/time of the source
		out_name, out_itime, out_time, out_tag = var, sheader.itime, sheader.time, -1
	hdr = pyQvarsi.io.AlyaMPIO_header(
		fieldname   = out_name,
		dimension   = 'VECTO' if is_vector else 'SCALA',
		association = 'NPOIN',
		dtype       = 'REAL',
		size        = '8BYTE',
		npoints     = header.npoints,
		nsub        = header.nsubd,
		sequence    = header.header['Sequence'],
		ndims       = var_shape[1] if is_vector else 1,
		itime       = out_itime,
		time        = out_time,
		tag1        = out_tag,
		tag2        = out_tag,
		ignore_err  = True
	)
	headers.append(hdr)


## Select the interpolation method
# --method is optional, so args.method is None when it is not given; fall back
# to 'fem' before lowering instead of failing on None.lower().
# The historical misspelled aliases ('neigbour') are kept for backward
# compatibility and the correctly spelled ones are accepted as well.
# Any name not listed below selects the FEM interpolation.
method_key = (args.method or 'fem').lower()
if method_key in ('nn','nearest','nearestneigbour','nearest_neigbour','nearestneighbour','nearest_neighbour'):
	interp_fun = pyQvarsi.meshing.interpolateNearestNeighbour
elif method_key in ('fem2','femnn','femnearest','femnearestneigbour','femnearestneighbour'):
	interp_fun = pyQvarsi.meshing.interpolateFEMNearestNeighbour
else:
	interp_fun = pyQvarsi.meshing.interpolateFEM

## Interpolate by chunks
for ichunk in range(args.nchunks):
	chunk_no = ichunk + 1 # 1-based chunk number, used in the messages only
	# Obtain chunk bounds - hack on the worksplit: the chunks play the role
	# of the workers, so each chunk gets its own range of target points
	row_start, row_end = pyQvarsi.utils.worksplit(0, header.npoints, ichunk, nWorkers=args.nchunks)
	nrows = row_end - row_start

	# Load target mesh by chunks - the same rows of the target coordinates
	# file are requested for the current chunk
	pyQvarsi.pprint(0, '--|')
	pyQvarsi.pprint(0, '--| Opening target mesh <%s>, chunk %d/%d... ' % (args.target, chunk_no, args.nchunks), end='', flush=True)
	target_xyz, _ = pyQvarsi.io.AlyaMPIO_readByChunk(tfname, nrows, row_start) # rows2read, rows2skip
	pyQvarsi.pprint(0, 'done!', flush=True)

	# Interpolate to destiny mesh
	pyQvarsi.pprint(0, '--|')
	pyQvarsi.pprint(0, '--| Interpolating to destiny mesh, chunk %d/%d... ' % (chunk_no, args.nchunks), end='', flush=True)
	chunk_fields = interp_fun(
		smesh, target_xyz, sfields,
		fact            = args.fact,
		f_incr          = args.f_incr,
		r_incr          = args.r_incr,
		global_max_iter = args.glob_max_iter,
		ball_max_iter   = args.ball_max_iter,
		root            = 0
	)
	pyQvarsi.pprint(0, 'done!', flush=True)

	# Sanity check: one interpolated entry is expected per target point read
	if len(chunk_fields) != nrows:
		pyQvarsi.raiseError('Reduction returned wrong number of nodes %d (%d)' % (len(chunk_fields), nrows))

	# At this point we can write the output file, one per variable; each
	# chunk is written at its own row offset
	for ivar, var in enumerate(args.vars):
		if args.st_field:
			out_name = pyQvarsi.io.MPIO_XFLFILE_S_FMT % (args.target, ivar+1, 1)
		else:
			out_name = pyQvarsi.io.MPIO_BINFILE_P_FMT % (args.target, var, args.source_instant)
		pyQvarsi.pprint(0, '--|')
		pyQvarsi.pprint(0, '--| Writing <%s>, chunk %d/%d... ' % (out_name, chunk_no, args.nchunks), end='', flush=True)
		pyQvarsi.io.AlyaMPIO_writeByChunk_serial(out_name, chunk_fields[var], headers[ivar], nrows, row_start, rank=0)
		pyQvarsi.pprint(0, 'done!', flush=True)


## Say goodbye
# Closing message, then pyQvarsi's cr_info() report
# (presumably the timing counters summary - TODO confirm)
pyQvarsi.pprint(0, '--|')
pyQvarsi.pprint(0, '--| Bye!', flush=True)
pyQvarsi.cr_info()
