#!python

"""
This python script computes a strain field from a displacement field defined on a regular grid (e.g., from spam-ldic) using SPAM functions
Copyright (C) 2020 SPAM Contributors

This program is free software: you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the Free
Software Foundation, either version 3 of the License, or (at your option)
any later version.

This program is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
more details.

You should have received a copy of the GNU General Public License along with
this program.  If not, see <http://www.gnu.org/licenses/>.
"""

# Thread caps for OpenBLAS / OpenMP must be in the environment BEFORE anything that
# loads numpy (the spam packages below included): these libraries read the variables
# once, when they are first loaded, so setting them afterwards has no effect.
# NOTE(review): the intent appears to be one BLAS thread per worker, since this script
# parallelises through its own nProcesses workers.
import os
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'

import spam.helpers
import spam.DIC
import spam.deformation
import spam.mesh

import argparse
import tifffile
import multiprocessing
# set_start_method raises RuntimeError if a start method has already been fixed;
# in that case the existing method is kept.
try:                 multiprocessing.set_start_method('fork')
except RuntimeError: pass
import numpy
# Silence floating-point warnings (e.g. from NaN-masked points propagating through the maths)
numpy.seterr(all='ignore')



# Build the command-line interface. The parser is handed to
# spam.helpers.optionsParser.regularStrainParser, which returns the parsed arguments.
description = ("spam-regularStrain " + spam.helpers.optionsParser.GLPv3descriptionHeader
               + "This script computes different components of strain, given a regularly-spaced displacement"
                 " field like that coming from spam-ldic. Both infinitesimal and finite strain frameworks"
                 " are implemented, and TSV, VTK and TIF output are possible")
parser = argparse.ArgumentParser(description=description,
                                 formatter_class=argparse.RawTextHelpFormatter)
args = spam.helpers.optionsParser.regularStrainParser(parser)

# Default to one worker process per CPU when no process count was given
if args.PROCESSES is None:
    args.PROCESSES = multiprocessing.cpu_count()

bannerRule = "+----------------------------+"
print(bannerRule)
print("| Regular Strain Calculation |")
print(bannerRule)

# Echo every parsed option in alphabetical order of its name
print("\nCurrent Settings:")
argsDict = vars(args)
for key, value in sorted(argsDict.items()):
    print("\t{}: {}".format(key, value))

print("\nspam-regularStrain: Loading data...")
f = spam.helpers.readCorrelationTSV(args.inFile.name, readConvergence=args.MASK)

# Grid dimensions (z, y, x order; dims[0] == 1 means a 2D field) and per-node coordinates
dims        = f["fieldDims"]
fieldCoords = f["fieldCoords"]


def _axisNodeSpacing(axisCoords):
    """Return the gap between the first two distinct values of one coordinate column.

    2020-08-31 OS: if the axis has a single distinct value (e.g. z in a 2D field),
    that value itself is returned, preserving the original fallback.
    """
    uniqueCoords = numpy.unique(axisCoords)  # sorted distinct values, computed once
    if len(uniqueCoords) > 1:
        return uniqueCoords[1] - uniqueCoords[0]
    return uniqueCoords[0]


# Node spacing for each direction (z, y, x)
nodeSpacing = numpy.array([_axisNodeSpacing(fieldCoords[:, axis]) for axis in range(3)])

# Catch 2D case: a single node along z
twoD = bool(dims[0] == 1)
if twoD:
    print("spam-regularStrain: detected 2D field")

# Last column, first three rows of each node's Phi is used as the displacement vector.
# Basic slicing returns a view, so the NaN masking below also writes into f["PhiField"].
dispFlat = f["PhiField"][:,:3,-1]

# Optionally exclude points based on the return status of the correlation
if args.MASK:
    # True marks points to EXCLUDE (return status below the threshold)
    mask = f["returnStatus"] < args.RETURN_STATUS_THRESHOLD
    # mean(mask) is the excluded fraction (the previous 1-mean(mask) reported the kept fraction)
    print(f"\nspam-regularStrain: excluding points based on return threshold < {args.RETURN_STATUS_THRESHOLD} (excluded {100*numpy.mean(mask):2.1f}%)")
    dispFlat[mask] = numpy.nan

disp = dispFlat.reshape(dims[0], dims[1], dims[2], 3)

print("\nspam-regularStrain: Computing F=I+du/dx")
# Keyword arguments shared by both displacement-gradient estimators below
sharedFfieldKwargs = {"nodeSpacing": nodeSpacing, "nProcesses": args.PROCESSES, "verbose": True}
if args.Q8:
    Ffield = spam.deformation.FfieldRegularQ8(disp, **sharedFfieldKwargs)
elif args.RAW:
    # Take F straight from the file: the top-left 3x3 block of each node's Phi
    rawFflat = f["PhiField"][:, :3, :3]
    if args.MASK:
        rawFflat[mask] = numpy.nan
    Ffield = rawFflat.reshape(dims[0], dims[1], dims[2], 3, 3)
else:
    # Default: Geers estimate using neighbours within STRAIN_NEIGHBOUR_RADIUS
    Ffield = spam.deformation.FfieldRegularGeers(disp,
                                                 neighbourRadius=args.STRAIN_NEIGHBOUR_RADIUS,
                                                 **sharedFfieldKwargs)

# Compute the requested components from F
print("\nspam-regularStrain: Decomposing F into ", args.COMPONENTS)
decomposedFfield = spam.deformation.decomposeFfield(Ffield, args.COMPONENTS,
                                                    twoD=twoD, nProcesses=args.PROCESSES, verbose=True)


# Strain mode label: it names both the output base file and the TIFF suffixes
if args.Q8:
    mode = "Q4" if twoD else "Q8"
elif args.RAW:
    mode = "raw"
else:
    mode = "Geers"
fileNameBase = args.OUT_DIR + "/" + args.PREFIX + "-strain-" + mode

# Save strain fields
print("\nspam-regularStrain: Saving strain fields...")
if args.TSV:
    # Grid-shaped copy of the measurement-point coordinates (dims[0] is 1 in 2D)
    outputPositions = fieldCoords.copy().reshape(dims[0], dims[1], dims[2], 3)

    if args.Q8:
        # Q8 (Q4 in 2D) values sit at element centres, between measurement points:
        #   there is one fewer position than measurement points along each meshed axis,
        #   so strip the last node and shift by half a node spacing.
        # In 2D the single z layer is not meshed: keep it and do NOT shift z, because
        #   nodeSpacing[0] then holds the z coordinate itself rather than a spacing.
        if twoD:
            outputPositions = outputPositions[:, :-1, :-1, :]
            meshedAxes = (1, 2)
        else:
            outputPositions = outputPositions[:-1, :-1, :-1, :]
            meshedAxes = (0, 1, 2)
        for axis in meshedAxes:
            outputPositions[:, :, :, axis] += nodeSpacing[axis] / 2.0
    # Geers and "raw" values are at the measurement points themselves: positions unchanged

    # writeStrainTSV takes an Nx3 array of positions
    spam.helpers.writeStrainTSV(fileNameBase+".tsv", outputPositions.reshape(-1,3), decomposedFfield, firstColumn="StrainPointNumber")

if args.TIFF:
    tifBase = args.OUT_DIR + "/" + args.PREFIX
    # In 2D only the y and x components are written; they live at indices 1 and 2
    if twoD:
        axes, firstIndex = ['y', 'x'], 1
    else:
        axes, firstIndex = ['z', 'y', 'x'], 0

    for component in args.COMPONENTS:
        # Scalar fields: one image each
        if component in ('vol', 'dev', 'volss', 'devss'):
            tifffile.imwrite(tifBase + "-{}-{}.tif".format(component, mode),
                             decomposedFfield[component].astype('<f4'))

        # Vector fields: one image per axis
        if component in ('r', 'z'):
            for i, di in enumerate(axes):
                n = i + firstIndex
                tifffile.imwrite(tifBase + "-{}{}-{}.tif".format(component, di, mode),
                                 decomposedFfield[component][:,:,:,n].astype('<f4'))

        # Tensor fields: upper triangle (diagonal included) only
        if component in ('e', 'U'):
            for i, di in enumerate(axes):
                n = i + firstIndex
                for j in range(i, len(axes)):
                    m = j + firstIndex
                    tifffile.imwrite(tifBase + "-{}{}{}-{}.tif".format(component, di, axes[j], mode),
                                     decomposedFfield[component][:,:,:,n,m].astype('<f4'))

if args.VTK:
    # In 2D a unit z aspect is used instead of nodeSpacing[0]
    aspectRatio = [1, nodeSpacing[1], nodeSpacing[2]] if twoD else nodeSpacing

    if args.Q8:
        # Q8's centre is between measurement points, but corners fall on displacement points
        origin = fieldCoords[0]
    else:
        # For Geers (and raw) strains are at the measurement points, so cells are centred on them.
        #   As per the displacements coming out of spam-ldic this will plot nicely if 2xHWS = NS
        origin = fieldCoords[0] - numpy.array(aspectRatio) / 2.0

    # Optionally replace NaNs by 0 (in place, in decomposedFfield) before export
    if args.VTKmaskNAN:
        for component in args.COMPONENTS:
            values = decomposedFfield[component]
            values[numpy.isnan(values)] = 0.0

    cellData = {component: decomposedFfield[component] for component in args.COMPONENTS}
    spam.helpers.writeStructuredVTK(origin=origin, aspectRatio=aspectRatio, cellData=cellData, fileName=fileNameBase+".vtk")
