#!python

# This script deforms an image according to an input deformation field using SPAM functions
# Copyright (C) 2020 SPAM Contributors
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

"""
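This script deforms an image according to measured kinematics, which is useful e.g. for generating deformed images to calculate a residual field.

If the input TSV contains a single Phi (a registration), that Phi is applied to the whole image.
Otherwise the Phi field is applied in one of two ways:

  - with the `-tet` option, correlation points are meshed with tetrahedra and deformed with their displacements;
    this is fast, but the interpolation of displacements is approximate;
  - without it, the more accurate per-pixel interpolation of `spam.DIC.applyPhiField` is used, which is slow for large images.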
"""

import spam.helpers
import spam.mesh
import spam.label
import spam.DIC
import spam.deformation

import numpy

# Silence numpy floating-point error reporting (divide, overflow, underflow,
# invalid) for the whole script: such operations yield inf/nan silently.
numpy.seterr(all="ignore")
import argparse
import tifffile
import multiprocessing
try:
    # Prefer the 'fork' start method for worker processes. Keep the current
    # setting if a start method was already chosen (RuntimeError) or if 'fork'
    # is not available on this platform, e.g. Windows (ValueError).
    multiprocessing.set_start_method("fork")
except (RuntimeError, ValueError):
    pass

import os  # used to build the output path portably

# Define argument parser object
parser = argparse.ArgumentParser(
    description="spam-deformImage "
    + spam.helpers.optionsParser.GLPv3descriptionHeader
    + "This deforms our input image according to some measured kinematics.\n"
    + "If a registration is given, it is wholly applied, otherwise if a displacement "
    + "field is given, it is triangulated and the displacements are applied",
    formatter_class=argparse.RawTextHelpFormatter,
)

# Parse arguments with external helper function
args = spam.helpers.optionsParser.deformImageParser(parser)

# Default to one worker process per CPU
if args.PROCESSES is None:
    args.PROCESSES = multiprocessing.cpu_count()

print("spam-deformImage: Current Settings:")
argsDict = vars(args)
for key in sorted(argsDict):
    print("\t{}: {}".format(key, argsDict[key]))
print()

# Read the correlation TSV; below we use its point coordinates ("fieldCoords"),
# per-point Phi ("PhiField") and "returnStatus"
TSV = spam.helpers.readCorrelationTSV(
    args.PHIFILE.name, fieldBinRatio=args.PHIFILE_BIN_RATIO, readConvergence=True
)

im = tifffile.imread(args.inFile.name)

# Detect unpadded 2D image: give it a leading axis of size 1 so it is handled as 3D
if im.ndim == 2:
    im = im[numpy.newaxis, ...]
    twoD = True
    if args.MESH_TRANSFORMATION:
        print(
            "\nspam-deformImage: the -tet option is only implemented for a 3D image. Forcing per-pixel interpolation."
        )
        args.MESH_TRANSFORMATION = False
else:
    twoD = False

if TSV["PhiField"].shape[0] == 1:
    # A single Phi is a registration: apply it about its own centre, fieldCoords[0]
    print(f"\nRegistration mode, applying Phi at {TSV['fieldCoords'][0]}")
    args.PREFIX += "-reg-def"

    Phi = TSV["PhiField"][0]
    if args.RIGID:
        # Rebuild Phi from its translation and rotation only
        PhiDecomposed = spam.deformation.decomposePhi(Phi)
        Phi = spam.deformation.computePhi(
            {"t": PhiDecomposed["t"], "r": PhiDecomposed["r"]}
        )
        print("Using only rigid part of the registration")
    if twoD:
        imdef = spam.DIC.applyPhiPython(im, Phi=Phi, PhiCentre=TSV["fieldCoords"][0])[0]
    else:
        imdef = spam.DIC.applyPhi(im, Phi=Phi, PhiCentre=TSV["fieldCoords"][0])

else:
    ### BIG SWITCH between meshTransformation and per-pixel displacement
    print("\nIn PhiField mode.")
    # Accept points based on return stat
    mask = TSV["returnStatus"] >= args.RETURN_STATUS_THRESHOLD
    print(
        f"\nspam-deformImage: excluding points based on return threshold < {args.RETURN_STATUS_THRESHOLD} (excluded {100*(1-numpy.mean(mask)):2.1f}%)"
    )

    if args.RADIUS is not None:
        # Also exclude points whose distance from the image centre in the
        # (y, x) plane exceeds RADIUS. Non-in-place subtraction so that
        # integer-typed coordinates are promoted to float instead of failing.
        y = TSV["fieldCoords"][:, 1] - (im.shape[1] - 1) / 2.0
        x = TSV["fieldCoords"][:, 2] - (im.shape[2] - 1) / 2.0
        r = numpy.hypot(x, y)
        mask[r > args.RADIUS] = False

    # Keep only the accepted points and their displacements (translation part of Phi)
    points = TSV["fieldCoords"][mask]
    disp = TSV["PhiField"][mask][:, 0:3, -1]
    print("\tnPoints = ", points.shape[0])

    if args.MESH_TRANSFORMATION:
        args.PREFIX += "-tetMesh-def"

        # 2019-12-10 EA and OS: triangulate in the deformed configuration
        conn = spam.mesh.triangulate(points + disp, alpha=args.MESH_ALPHA)
        print("\tnTets = ", conn.shape[0])

        # Let's make the tet image here, in case we want to recycle it for the cgs
        imTetLabel = spam.label.labelTetrahedra(
            im.shape, points + disp, conn, nThreads=args.PROCESSES
        )

        print("Interpolating image... ", end="")
        # 2019-12-10 EA and OS: look up pixels, remember im is the reference configuration that we are deforming
        imdef = spam.DIC.applyMeshTransformation(
            im, points, conn, disp, imTetLabel=imTetLabel, nThreads=args.PROCESSES
        )
        print("done")

        if args.CORRECT_GREY_FOR_STRAIN:
            print(
                "Correcting greyvalues for strain, assuming that vacuum greylevel = 0.0... ",
                end="",
            )
            # We're going to pre-deform the greylevels using the tetLabel as a mask
            volumesRef = spam.mesh.tetVolumes(points, conn)
            volumesDef = spam.mesh.tetVolumes(points + disp, conn)
            volStrain = volumesDef / volumesRef
            volStrain[volumesRef == 0] = 0.0
            correction = spam.label.convertLabelToFloat(imTetLabel, volStrain)
            # Degenerate tets were given volStrain = 0 above, and pixels outside
            # the mesh may also get 0: dividing there would write inf/nan into
            # imdef, so only correct pixels with a non-zero factor
            nonZero = correction != 0
            imdef[nonZero] /= correction[nonZero]
            del correction, nonZero
            print("done")
    else:
        ### "Exact mode"
        print("Per-pixel displacement interpolation mode")
        if args.INTERPOLATE_DISPLACEMENTS:
            args.PREFIX += "-disp"
        args.PREFIX += "-def"
        if args.MASK2 is not None:
            print("\tLoading im2 mask")
            imMaskDef = tifffile.imread(args.MASK2.name) > 0
            if imMaskDef.ndim == 2:
                imMaskDef = imMaskDef[numpy.newaxis, ...]
        else:
            imMaskDef = None

        imdef = spam.DIC.applyPhiField(
            im,
            points,
            TSV["PhiField"][mask],
            imMaskDef=imMaskDef,
            nNeighbours=args.NUMBER_OF_NEIGHBOURS,
            interpolationOrder=args.INTERPOLATION_ORDER,
            nProcesses=args.PROCESSES,
            displacementMode="interpolate"
            if args.INTERPOLATE_DISPLACEMENTS
            else "applyPhi",
            verbose=True,
        )

# Build the output path once; os.path.join avoids writing to "/" when OUT_DIR is empty
outFile = os.path.join(args.OUT_DIR, args.PREFIX + ".tif")
print("Saving deformed image:\n\t{}".format(outFile))
# NOTE(review): astype() neither rounds nor clips, so float values outside an
# integer input dtype's range are not saturated -- confirm this is acceptable.
tifffile.imwrite(outFile, imdef.astype(im.dtype))
