#!python

# Copyright (C) 2020 SPAM Contributors
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

import spam.DIC
import spam.deformation
import spam.helpers

import os
os.environ['OPENBLAS_NUM_THREADS'] = '1'

import numpy
numpy.seterr(all='ignore')
import argparse
import tifffile
import multiprocessing
try:                 multiprocessing.set_start_method('fork')
except RuntimeError: pass
import progressbar


# Define argument parser object
parser = argparse.ArgumentParser(description="spam-ldic "+spam.helpers.optionsParser.GLPv3descriptionHeader +\
                                             "This script performs Local Digital Image Correlation between a series of at least two 3D greyscale images "+\
                                             "with independent measurement points spread on a regular grid (with -ns spacing in pixels between points). "+\
                                             "Around each point a cubic subvolume of +-hws (Half-window size) is extracted and correlated."+\
                                             "\nSee for more details: https://ttk.gricad-pages.univ-grenoble-alpes.fr/spam/tutorial-02b-DIC-practice.html",
                                 formatter_class=argparse.RawTextHelpFormatter)

# Parse arguments with external helper function
args = spam.helpers.optionsParser.ldicParser(parser)

# The first image is the reference, every following one is correlated against it
if len(args.inFiles) < 2:
    print("\nldic: Did not receive enough input images... you need (at least) two to tango...")
    exit()

# Default to one worker per available CPU
if args.PROCESSES is None:
    args.PROCESSES = multiprocessing.cpu_count()

print("spam-ldic -- Current Settings:")
argsDict = vars(args)
for key in sorted(argsDict):
    print("\t{}: {}".format(key, argsDict[key]))

# Load reference image
im1 = tifffile.imread(args.inFiles[0].name)

# Detect unpadded 2D image first: promote it to a single-slice 3D image (Z, Y, X)
if len(im1.shape) == 2:
    im1 = im1[numpy.newaxis, ...]
twoD = im1.shape[0] == 1

# Optional mask of the reference image: any non-zero voxel is considered inside the mask
if args.MASK1:
    im1mask = tifffile.imread(args.MASK1.name) != 0
    if len(im1mask.shape) == 2:
        im1mask = im1mask[numpy.newaxis, ...]
else:
    im1mask = None


### Interpolation settings
# NOTE(review): `interpolator` is not read anywhere else in this script -- confirm it is still needed
if args.INTERPOLATION_ORDER == 1:
    interpolator = 'C'
else:
    interpolator = 'python'
# Override interpolator for python in 2D
if twoD:
    interpolator = 'python'

# Symmetric extra search range (in pixels) around each subvolume, as [-z, +z, -y, +y, -x, +x]
margin = [-args.MARGIN[0], args.MARGIN[0],
          -args.MARGIN[1], args.MARGIN[1],
          -args.MARGIN[2], args.MARGIN[2]]


# Set to False at the end of the first pass through the image loop below
firstCorrelation = True

# Bad to redefine this for every loop, so it's defined here once, to be called by the pool
def correlateOneNode(nodeNumber):
    """
    Correlate the subvolume around a single node of the measurement grid.

    This reads module-level globals that the main loop sets before the pool is
    created: ``PhiField`` (initial guesses), ``nodePositions``, ``im1``, ``im2``,
    ``im1mask``, ``margin``, ``twoD`` and ``args``.

    Parameters
    ----------
        nodeNumber : int
            Index of the node in ``nodePositions`` and ``PhiField``

    Returns
    -------
        Tuple of:
        - nodeNumber
        - Phi : 4x4 array; when the node cannot be correlated it is the identity
          with nan translations
        - returnStatus : status from spam.DIC.register, or the status from
          spam.DIC.getImagettes when it is not 1, or -7 if the initial Phi
          contains nans or infs
        - error : numpy.inf when the node was not correlated
        - iterations : 0 when the node was not correlated
        - deltaPhiNorm : numpy.inf when the node was not correlated
    """
    # Copy: PhiField[nodeNumber] is a view into the global field, and the displacement
    #   compensation below must not modify the shared initial guesses in place
    PhiInit = PhiField[nodeNumber].copy()

    if numpy.isfinite(PhiInit).sum() == 16:
        imagetteReturns = spam.DIC.getImagettes(im1, nodePositions[nodeNumber], args.HWS, PhiInit.copy(), im2, margin, im1mask=im1mask, minMaskCoverage=args.MASK_COVERAGE, greyThreshold=[args.GREY_LOW_THRESH, args.GREY_HIGH_THRESH], applyF='no', twoD=twoD)

        if imagetteReturns['returnStatus'] == 1:
            # compute displacement that will be taken by the getImagettes
            initialDisplacement = numpy.round(PhiInit[0:3, 3]).astype(int)
            PhiInit[0:3, -1] -= initialDisplacement

            registerReturns = spam.DIC.register(imagetteReturns['imagette1'],
                                                imagetteReturns['imagette2'],
                                                im1mask=imagetteReturns['imagette1mask'],
                                                PhiInit=PhiInit, # minus initial displacement above, which is in the search range and thus taken into account in imagette2
                                                margin=1,  # see top of this file for compensation
                                                maxIterations=args.MAX_ITERATIONS,
                                                deltaPhiMin=args.MIN_DELTA_PHI,
                                                PhiRigid=args.RIGID,
                                                updateGradient=args.UPDATE_GRADIENT,
                                                interpolationOrder=args.INTERPOLATION_ORDER,
                                                verbose=False,
                                                imShowProgress=False)
            # Add back the integer displacement that was absorbed by the imagette extraction
            goodPhi = registerReturns['Phi']
            goodPhi[0:3, -1] += initialDisplacement
            return nodeNumber, goodPhi, registerReturns['returnStatus'], registerReturns['error'], registerReturns['iterations'], registerReturns['deltaPhiNorm']

        else:
            # Imagettes could not be extracted: report getImagettes' status
            badPhi = numpy.eye(4)
            badPhi[0:3, 3] = numpy.nan
            return nodeNumber, badPhi, imagetteReturns['returnStatus'], numpy.inf, 0, numpy.inf
    else:
        ### Phi has nans or infs
        badPhi = numpy.eye(4)
        badPhi[0:3, 3] = numpy.nan
        return nodeNumber, badPhi, -7, numpy.inf, 0, numpy.inf


# Chaining Phi fields through the series reloads the TSV written by the previous correlation,
#   so it cannot work without TSV output: stop before computing anything
if args.SERIES_PHIFILE and not args.TSV and len(args.inFiles) > 2:
    print("spam-ldic: reloading the previous correlation as an initial guess requires TSV output to be enabled")
    exit()

# Path of the TSV written by the last correlation (None until one has been written)
previousPhiFile = None

# Loop over input images
for im2number in range(1, len(args.inFiles)):
    # decide on number, in input files list, of the reference image
    if args.SERIES_INCREMENTAL:
        im1number = im2number-1
    else:
        im1number = 0

    # Output file name prefix
    if args.PREFIX is None or len(args.inFiles) > 2:
        args.PREFIX = os.path.splitext(os.path.basename(args.inFiles[im1number].name))[0]+"-"+os.path.splitext(os.path.basename(args.inFiles[im2number].name))[0]

    # Initial Phi file for this correlation: the one passed on the command line (if any)...
    phiFileName = None if args.PHIFILE is None else args.PHIFILE.name
    phiFileBinRatio = args.PHIFILE_BIN_RATIO
    # ...unless this is not the first correlation and we are interested in loading the previous Phi file
    if not firstCorrelation and args.SERIES_PHIFILE:
        phiFileName = previousPhiFile
        # The previous result was measured on these same images, so it must not be rescaled
        phiFileBinRatio = 1

    print("\nCorrelating:", args.PREFIX)

    im2 = tifffile.imread(args.inFiles[im2number].name)
    if len(im2.shape) == 2:
        im2 = im2[numpy.newaxis, ...]

    assert(im1.shape == im2.shape), "\nim1 and im2 must have the same size! Exiting."
    if args.MASK1:
        assert(im1.shape == im1mask.shape), "\nim1 and im1mask must have the same size! Exiting."

    # Three cases to handle:
    #   1. phi file is reg   -> define nodes and apply reg
    #   2. phi file is field -> take everything and check NS if passed
    #   3. no phi file       -> define nodes
    if phiFileName is not None:
        PhiFromFile = spam.helpers.readCorrelationTSV(phiFileName, fieldBinRatio=phiFileBinRatio, readError=True)
        if PhiFromFile is None:
            print(f"\tFailed to read your TSV file passed with -pf {phiFileName}")
            exit()

        # If the read Phi-file has only one line -- it's a single point registration!
        if PhiFromFile['fieldCoords'].shape[0] == 1:
            PhiInit = PhiFromFile['PhiField'][0]
            print(f"\tI read a registration from a file in binning {phiFileBinRatio}")

            decomposedPhiInit = spam.deformation.decomposePhi(PhiInit)
            print("\tTranslations (px)")
            print("\t\t", decomposedPhiInit['t'])
            print("\tRotations (deg)")
            print("\t\t", decomposedPhiInit['r'])
            print("\tZoom")
            print("\t\t", decomposedPhiInit['z'])
            del decomposedPhiInit

            # Create nodes
            if args.NS is None:
                print(f"spam-ldic: You passed a registration file {phiFileName}, I need -ns to be defined")
                exit()
            nodePositions, nodesDim = spam.DIC.makeGrid(im1.shape, args.NS)
            numberOfNodes = nodePositions.shape[0]

            PhiField = spam.DIC.applyRegistrationToPoints(PhiInit.copy(),
                                                          PhiFromFile["fieldCoords"][0],
                                                          nodePositions,
                                                          applyF = args.APPLY_F,
                                                          nProcesses = args.PROCESSES,
                                                          verbose = False)

            error = numpy.zeros((numberOfNodes))
            iterations = numpy.zeros((numberOfNodes))
            returnStatus = numpy.zeros((numberOfNodes))
            deltaPhiNorm = numpy.zeros((numberOfNodes))

        else: # we have a Phi field and not a registration
            nodePositions   = PhiFromFile["fieldCoords"]
            numberOfNodes   = nodePositions.shape[0]
            nodesDim        = PhiFromFile["fieldDims"]
            # Spacing between the first two distinct coordinates along each axis
            #   (or the single coordinate itself if the axis has only one value)
            nodeSpacingFile = numpy.array([numpy.unique(nodePositions[:, i])[1] - numpy.unique(nodePositions[:, i])[0] if len(numpy.unique(nodePositions[:, i])) > 1 else numpy.unique(nodePositions[:, i])[0] for i in range(3)])
            PhiField        = PhiFromFile["PhiField"]

            # GP: Adding skip nodes option, so we can run ldic only on the diverged nodes
            if args.SKIP_NODES:
                error = PhiFromFile["error"]
                iterations = PhiFromFile["iterations"]
                returnStatus = PhiFromFile["returnStatus"]
                deltaPhiNorm = PhiFromFile["deltaPhiNorm"]
            else:
                error = numpy.zeros((numberOfNodes))
                iterations = numpy.zeros((numberOfNodes))
                returnStatus = numpy.zeros((numberOfNodes))
                deltaPhiNorm = numpy.zeros((numberOfNodes))

            # In case NS is also defined, complain, but if it's the same as the loaded data, continue
            if args.NS is not None:
                # compare them
                if not numpy.allclose(numpy.array(args.NS), nodeSpacingFile, atol=0.0):
                    print(f"spam-ldic: you passed a -ns={args.NS} which contradicts the node spacing in your Phi Field TSV of {nodeSpacingFile}")
                    print(f"\thint 1: if you pass a Phi Field TSV you don't need to also define the node spacing")
                    print(f"\thint 2: if you want to use your Phi Field TSV {phiFileName} on a finer node spacing, pass it with spam-passPhiField")
                    exit()
                else:
                    print(f"spam-ldic: passing -ns with a Phi Field TSV is not needed")

            # If it's compatible, update args.NS
            args.NS = nodeSpacingFile

    else: # No Phi file at all
        if args.NS is None:
            print("spam-ldic: I don't have a phi file or -ns defined, so don't know how to define grid...")
            exit()
        nodePositions, nodesDim = spam.DIC.makeGrid(im1.shape, args.NS)
        numberOfNodes = nodePositions.shape[0]

        # Every node starts from the identity (no initial guess)
        PhiField = numpy.zeros((numberOfNodes, 4, 4))
        for node in range(numberOfNodes):
            PhiField[node] = numpy.eye(4)

        error = numpy.zeros((numberOfNodes))
        iterations = numpy.zeros((numberOfNodes))
        returnStatus = numpy.zeros((numberOfNodes))
        deltaPhiNorm = numpy.zeros((numberOfNodes))

    # GP: Adding the skip function
    nodesToCorrelate = numpy.arange(0, numberOfNodes)

    # NOTE(review): statuses are only non-zero when read from a Phi field TSV; otherwise they are
    #   all zero here and this filter selects no node at all
    if args.SKIP_NODES:
        nodesToCorrelate = numpy.where((returnStatus == -3) | (returnStatus == -2) | (returnStatus == -1) | (returnStatus == 1))[0]


    print("\n\tStarting local dic (with {} process{})".format(args.PROCESSES, 'es' if args.PROCESSES > 1 else ''))

    widgets = [progressbar.FormatLabel(''), ' ', progressbar.Bar(), ' ', progressbar.AdaptiveETA()]
    pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(nodesToCorrelate))
    pbar.start()
    finishedNodes = 0

    # Workers read PhiField, nodePositions, im1, im2 etc. as module globals (see correlateOneNode)
    with multiprocessing.Pool(processes=args.PROCESSES) as pool:
        for returns in pool.imap_unordered(correlateOneNode, nodesToCorrelate):
            finishedNodes += 1

            # Update progress bar only for nodes with a positive return status
            if returns[2] > 0:
                widgets[0] = progressbar.FormatLabel("  it={:0>3d}  dPhiNorm={:0>6.4f}  rs={:+1d} ".format(returns[4], returns[5], returns[2]))
                pbar.update(finishedNodes)
            nodeNumber = returns[0]
            PhiField[nodeNumber]     = returns[1]
            returnStatus[nodeNumber] = returns[2]
            error[nodeNumber]        = returns[3]
            iterations[nodeNumber]   = returns[4]
            deltaPhiNorm[nodeNumber] = returns[5]

    pbar.finish()

    print("\n")



    ## Finished! Get ready for output.

    if args.TSV:
        # Make one big array for writing:
        #   First the node number,
        #   3 node positions,
        #   the top three rows of Phi (F and the displacement column), row by row
        #   error, iterations, returnStatus, deltaPhiNorm
        TSVheader = "NodeNumber\tZpos\tYpos\tXpos\tFzz\tFzy\tFzx\tZdisp\tFyz\tFyy\tFyx\tYdisp\tFxz\tFxy\tFxx\tXdisp\terror\titerations\treturnStatus\tdeltaPhiNorm"
        outMatrix = numpy.array([numpy.array(range(nodePositions.shape[0])),
                                    nodePositions[:, 0], nodePositions[:, 1], nodePositions[:, 2],
                                    PhiField[:, 0, 0],   PhiField[:, 0, 1],   PhiField[:, 0, 2],    PhiField[:, 0, 3],
                                    PhiField[:, 1, 0],   PhiField[:, 1, 1],   PhiField[:, 1, 2],    PhiField[:, 1, 3],
                                    PhiField[:, 2, 0],   PhiField[:, 2, 1],   PhiField[:, 2, 2],    PhiField[:, 2, 3],
                                    error,               iterations,          returnStatus,         deltaPhiNorm]).T

        tsvFileName = args.OUT_DIR+"/"+args.PREFIX+"-ldic.tsv"
        numpy.savetxt(tsvFileName,
                        outMatrix,
                        fmt='%.7f',
                        delimiter='\t',
                        newline='\n',
                        comments='',
                        header=TSVheader)
        ## Hold onto that name in case the next correlation reloads it (args.SERIES_PHIFILE)
        previousPhiFile = tsvFileName

    if args.TIFF:
        if nodesDim[0] != 1:
            tifffile.imwrite(args.OUT_DIR+"/"+args.PREFIX+"-ldic-Zdisp.tif",              PhiField[:, 0, -1].astype('<f4').reshape(nodesDim))
        tifffile.imwrite(args.OUT_DIR+"/"+args.PREFIX+"-ldic-Ydisp.tif",                  PhiField[:, 1, -1].astype('<f4').reshape(nodesDim))
        tifffile.imwrite(args.OUT_DIR+"/"+args.PREFIX+"-ldic-Xdisp.tif",                  PhiField[:, 2, -1].astype('<f4').reshape(nodesDim))
        tifffile.imwrite(args.OUT_DIR+"/"+args.PREFIX+"-ldic-error.tif",        error.astype('<f4').reshape(nodesDim))
        tifffile.imwrite(args.OUT_DIR+"/"+args.PREFIX+"-ldic-iterations.tif",   iterations.astype('<f4').reshape(nodesDim))
        tifffile.imwrite(args.OUT_DIR+"/"+args.PREFIX+"-ldic-returnStatus.tif", returnStatus.astype('<f4').reshape(nodesDim))
        tifffile.imwrite(args.OUT_DIR+"/"+args.PREFIX+"-ldic-deltaPhiNorm.tif", deltaPhiNorm.astype('<f4').reshape(nodesDim))

    # Collect data into VTK output
    if args.VTK:
        cellData = {}
        cellData['displacements'] = PhiField[:, :-1, 3].reshape((nodesDim[0], nodesDim[1], nodesDim[2], 3))
        cellData['error']         = error.reshape(nodesDim)
        cellData['iterations']    = iterations.reshape(nodesDim)
        cellData['returnStatus']  = returnStatus.reshape(nodesDim)
        cellData['deltaPhiNorm']  = deltaPhiNorm.reshape(nodesDim)

        cellData['error'       ][numpy.logical_not(numpy.isfinite(cellData['error']))]        = 0
        cellData['iterations'  ][numpy.logical_not(numpy.isfinite(cellData['iterations']))]   = 0
        cellData['returnStatus'][numpy.logical_not(numpy.isfinite(cellData['returnStatus']))] = 0
        cellData['deltaPhiNorm'][numpy.logical_not(numpy.isfinite(cellData['deltaPhiNorm']))] = 0

        # Overwrite nans and infs with 0, rubbish I know
        cellData['displacements'][numpy.logical_not(numpy.isfinite(cellData['displacements']))] = 0

        # This is perfect in the case where NS = 2xHWS, these cells will all be in the right place
        #   In the case of overlapping or under use of data, it should be approximately correct
        # If you insist on overlapping, then perhaps it's better to save each point as a cube glyph
        #   and actually *have* overlapping
        spam.helpers.writeStructuredVTK(origin=nodePositions[0]-args.HWS, aspectRatio=args.NS, cellData=cellData, fileName=args.OUT_DIR+"/"+args.PREFIX+"-ldic.vtk")
    firstCorrelation = False

    if args.SERIES_INCREMENTAL:
        # If in incremental mode, current deformed image is next reference image
        im1 = im2.copy()
