#!python


# This python script performs Discrete Digital Image Correlation using SPAM functions
# Copyright (C) 2020 SPAM Contributors
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

import os

# OpenBLAS reads OPENBLAS_NUM_THREADS only when the library is loaded (i.e. when
# numpy is first imported), so it must be set BEFORE importing numpy; setting it
# afterwards has no effect. One BLAS thread per process, presumably to avoid
# oversubscribing the CPUs since parallelism comes from the multiprocessing pool.
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import argparse
import multiprocessing

# The worker function reads the module-level globals (images, labels, PhiField),
# so workers are expected to inherit them through "fork".
# RuntimeError means a start method was already set; keep it.
try:
    multiprocessing.set_start_method("fork")
except RuntimeError:
    pass

import numpy
import progressbar
import spam.deformation
import spam.DIC
import spam.helpers
import spam.label
import tifffile

numpy.seterr(all="ignore")


# Build the command-line interface; arguments are parsed by spam's ddic parser helper
ddicDescription = (
    "spam-ddic "
    + spam.helpers.optionsParser.GLPv3descriptionHeader
    + "This script performs Discrete Digital Image Correlation script between two 3D greyscale images"
    + " (reference and deformed configurations) and requires the input of a labelled image for the reference configuration"
    + "\nSee for more details: https://ttk.gricad-pages.univ-grenoble-alpes.fr/spam/tutorial-04-discreteDIC.html"
)
parser = argparse.ArgumentParser(description=ddicDescription, formatter_class=argparse.RawTextHelpFormatter)

args = spam.helpers.optionsParser.ddicParser(parser)

# Without an explicit process count, use every available CPU
if args.PROCESSES is None:
    args.PROCESSES = multiprocessing.cpu_count()

# Echo every parsed option in alphabetical order
print("spam-ddic -- Current Settings:")
argsDict = vars(args)
for optionName in sorted(argsDict):
    print(f"\t{optionName}: {argsDict[optionName]}")


# Read the reference image, its labelled image (cast to spam's label type) and the deformed image
print("\nspam-ddic: Loading Data...", end="")
im1 = tifffile.imread(args.im1.name)
lab1 = tifffile.imread(args.lab1.name)
lab1 = lab1.astype(spam.label.labelType)
im2 = tifffile.imread(args.im2.name)
print("done.")

# Detect an unpadded 2D image first (decided from im1 only). Every 2D array then
# gets a leading axis of length 1 so that the rest of the script works in 3D.
twoD = im1.ndim == 2
if twoD:
    im1 = im1[numpy.newaxis, ...]
if lab1.ndim == 2:
    lab1 = lab1[numpy.newaxis, ...]
if im2.ndim == 2:
    im2 = im2[numpy.newaxis, ...]

# Explicit checks rather than assert statements: asserts are removed when Python
# runs with -O, which would let mismatched inputs through silently.
if im1.shape != im2.shape:
    raise SystemExit("\nim1 and im2 must have the same size! Exiting.")
if im1.shape != lab1.shape:
    raise SystemExit("\nim1 and lab1 must have the same size! Exiting.")

###############################################################
# Analyse labelled volume in state 01 in order to get bounding
# boxes and centres of mass for correlation
###############################################################
# Labels run from 0 to the highest value present, so the count includes label 0
highestLabel = lab1.max()
numberOfLabels = (highestLabel + 1).astype("u4")

print(f"spam-ddic: Number of labels = {numberOfLabels - 1}\n")

print("spam-ddic: Calculating Bounding Boxes and Centres of Mass of all labels.")
# Bounding boxes are computed once and reused for the centres of mass
boundingBoxes = spam.label.boundingBoxes(lab1)
centresOfMass = spam.label.centresOfMass(lab1, boundingBoxes=boundingBoxes)
print("\n")

###############################################################
# Set up kinematics array
###############################################################
# One 4x4 Phi and one set of convergence indicators per label (label 0 included)
PhiField = numpy.zeros((numberOfLabels, 4, 4), dtype="<f4")
PSCC = numpy.zeros((numberOfLabels), dtype="<f4")
error = numpy.zeros((numberOfLabels), dtype="<f4")
iterations = numpy.zeros((numberOfLabels), dtype="<u2")
returnStatus = numpy.zeros((numberOfLabels), dtype="<i2")
deltaPhiNorm = numpy.zeros((numberOfLabels), dtype="<f4")
# Signed on purpose: a negative LABEL_DILATE (erosion, if spam.label.getImagettesLabelled
# accepts it -- NOTE(review): confirm) would wrap around or raise in an unsigned array.
labelDilateList = numpy.zeros((numberOfLabels), dtype="<i2")

# Initialise field of Phis with the identity matrix (broadcast over all labels)
PhiField[:] = numpy.eye(4)
# define empty rigid displacements for registration:
# if args.REGSUB: rigidDisp = numpy.zeros((numberOfLabels, 3))


# Option 2 - load previous DVC
#################################
if args.PHIFILE is not None:
    PhiFromFile = spam.helpers.readCorrelationTSV(
        args.PHIFILE.name,
        fieldBinRatio=args.PHIFILE_BIN_RATIO,
        readConvergence=True,
        readError=True,
        readLabelDilate=True,
        readPixelSearchCC=True,
    )

    # If the read Phi-file has only one line -- it's a single point registration!
    if PhiFromFile["fieldCoords"].shape[0] == 1:
        PhiInit = PhiFromFile["PhiField"][0]
        print("\tI read a registration from a file in binning {}".format(args.PHIFILE_BIN_RATIO))

        decomposedPhiInit = spam.deformation.decomposePhi(PhiInit)
        print("\tTranslations (px)")
        print("\t\t", decomposedPhiInit["t"])
        print("\tRotations (deg)")
        print("\t\t", decomposedPhiInit["r"])
        print("\tZoom")
        print("\t\t", decomposedPhiInit["z"])

        # Spread the single registration over the centre of mass of every label
        PhiField = spam.DIC.applyRegistrationToPoints(
            PhiInit.copy(),
            PhiFromFile["fieldCoords"][0],
            centresOfMass,
            applyF=args.APPLY_F,
            nProcesses=args.PROCESSES,
            verbose=False,
        )

    # If the read Phi-file contains multiple lines it's a Phi field (one Phi per label)!
    else:
        # print("spam-ddic: Assuming loaded PhiFile is coherent with the current run (i.e., labels are the same).")
        PhiField = PhiFromFile["PhiField"]

        # Check that the node positions are approximately those of the labelled image above.
        positionTolerance = 0.1
        fileCoords = PhiFromFile["fieldCoords"]
        # Compare shapes first: with a different number of labels numpy.allclose would
        # stop with a broadcasting error instead of the explanatory message below.
        # equal_nan=True: centresOfMass can contain NaNs (they are explicitly zeroed
        # before the VTK output), and identical NaNs read back from a file must match.
        positionsMatch = fileCoords.shape == centresOfMass.shape and numpy.allclose(
            fileCoords, centresOfMass, atol=positionTolerance, equal_nan=True
        )
        if not positionsMatch:
            print(fileCoords)
            print(centresOfMass)
            print(
                f"spam-ddic: Input PhiField positions from {args.PHIFILE.name} are not within {positionTolerance}px of the centre of mass of the labels from {args.lab1.name}, this seems dangerous."
            )
            print("\tplease consider using spam-passPhiField to apply your PhiField to a new labelled image")
            # Non-zero exit status so calling scripts can detect the failure (exit() returned 0)
            raise SystemExit(1)

        if args.SKIP_PARTICLES:
            # Read the previous result for all grains -- new grains will be overwritten
            returnStatus = PhiFromFile["returnStatus"]
            iterations = PhiFromFile["iterations"]
            deltaPhiNorm = PhiFromFile["deltaPhiNorm"]
            labelDilateList = PhiFromFile["LabelDilate"]
            error = PhiFromFile["error"]
            PSCC = PhiFromFile["pixelSearchCC"]


def _failedLabelResult(label, labelDilate):
    """
    Build the result tuple reported for a label that could not be correlated.

    Phi is the identity with NaN translations, returnStatus is -7, error and
    deltaPhiNorm are infinite and the iteration count is 0.
    """
    badPhi = numpy.eye(4)
    badPhi[0:3, 3] = numpy.nan
    return label, badPhi, -7, numpy.inf, 0, numpy.inf, labelDilate


def _centralSlice(volume):
    """Return the central slice of a 3D array along its first (Z) axis."""
    return volume[int(volume.shape[0] - 1) // 2, :, :]


def correlateOneLabel(label):
    """
    Register a single label of ``lab1`` between ``im1`` and ``im2``.

    Executed by the worker processes of the multiprocessing pool. It reads the
    module-level ``args``, ``PhiField``, ``lab1``, ``im1``, ``im2``,
    ``boundingBoxes``, ``centresOfMass`` and ``twoD``.

    Parameters
    ----------
    label : int
        Label (value in ``lab1``) to correlate.

    Returns
    -------
    tuple
        ``(label, Phi, returnStatus, error, iterations, deltaPhiNorm, labelDilate)``.
        If the initial Phi is not finite, or the imagettes could not be extracted
        (imagette returnStatus different from 1), Phi has NaN translations and
        returnStatus is -7.
    """
    # label, labelDilateCurrent = q.get()

    # WARNING HACK BAD FIXME: the same global dilation is used for every label
    labelDilateCurrent = args.LABEL_DILATE

    if args.DEBUG:
        print("\n\n\nWorking on label:", label, "\n")
        print("Position (ZYX):", centresOfMass[label])

    # A non-finite Phi is duff: give up before attempting to extract imagettes with it
    if not numpy.all(numpy.isfinite(PhiField[label])):
        return _failedLabelResult(label, labelDilateCurrent)

    initialDisplacement = numpy.round(PhiField[label][0:3, 3]).astype(int)

    imagetteReturns = spam.label.getImagettesLabelled(
        lab1,
        label,
        PhiField[label],
        im1,
        im2,
        [0, 0, 0, 0, 0, 0],  # Search range, don't worry about it
        boundingBoxes,
        centresOfMass,
        margin=args.MARGIN,
        labelDilate=labelDilateCurrent,
        maskOtherLabels=args.MASK_OTHERS,
        applyF="no",
        volumeThreshold=args.VOLUME_THRESHOLD,
    )

    # In case the label is missing or could not be extracted
    if imagetteReturns["returnStatus"] != 1:
        return _failedLabelResult(label, labelDilateCurrent)

    # 2D mode: keep the central Z-slice of each imagette. Done only after the
    # status check, since a failed extraction may not return usable imagettes.
    if twoD:
        for key in ("imagette1", "imagette2", "imagette1mask"):
            imagetteReturns[key] = _centralSlice(imagetteReturns[key])

    # Remove int() part of displacement since it's already used to extract imagette2
    PhiTemp = PhiField[label].copy()
    PhiTemp[0:3, -1] -= initialDisplacement
    if args.DEBUG:
        print("Starting lk iterations with Phi - int(disp):\n", PhiTemp)
        print("\nStarting lk iterations with int(disp):\n", initialDisplacement)

    registerReturns = spam.DIC.registerMultiscale(
        imagetteReturns["imagette1"],
        imagetteReturns["imagette2"],
        args.MULTISCALE_BINNING,
        im1mask=imagetteReturns["imagette1mask"],
        margin=1,
        PhiInit=PhiTemp,
        PhiRigid=args.CORRELATE_RIGID,
        updateGradient=args.UPDATE_GRADIENT,
        maxIterations=args.MAX_ITERATIONS,
        deltaPhiMin=args.MIN_PHI_CHANGE,
        interpolationOrder=args.INTERPOLATION_ORDER,
        verbose=args.DEBUG,
        imShowProgress=args.DEBUG,
    )
    # Put the integer displacement back so Phi is expressed in full-image terms
    goodPhi = registerReturns["Phi"]
    goodPhi[0:3, -1] += initialDisplacement
    return (
        label,
        goodPhi,
        registerReturns["returnStatus"],
        registerReturns["error"],
        registerReturns["iterations"],
        registerReturns["deltaPhiNorm"],
        labelDilateCurrent,
    )


# Add labels to a queue -- mostly useful for MPI
# q = queue.Queue()
# Label 0 (presumably the background) is never correlated
labelsToCorrelate = numpy.arange(1, numberOfLabels)

if args.SKIP_PARTICLES:
    # Skip labels already marked as converged (returnStatus == 2) in the loaded
    # PhiFile. Filtering by label value means the removal of label 0 no longer
    # depends on its own returnStatus: deleting "position 0" after a first
    # deletion would drop a real label whenever label 0 itself had status 2.
    labelsToCorrelate = labelsToCorrelate[returnStatus[labelsToCorrelate] != 2]

print("\n\tStarting Discrete DIC (with {} process{})".format(args.PROCESSES, "es" if args.PROCESSES > 1 else ""))
widgets = [
    progressbar.FormatLabel(""),
    " ",
    progressbar.Bar(),
    " ",
    progressbar.AdaptiveETA(),
]
pbar = progressbar.ProgressBar(widgets=widgets, maxval=len(labelsToCorrelate))
pbar.start()

finishedLabels = 0

with multiprocessing.Pool(processes=args.PROCESSES) as pool:
    # Results arrive in completion order; each one carries its own label
    for returns in pool.imap_unordered(correlateOneLabel, labelsToCorrelate):
        finishedLabels += 1

        # Update progress bar if point is not skipped
        if returns[2] > 0:
            widgets[0] = progressbar.FormatLabel("  it={:0>3d}  dPhiNorm={:0>6.4f}  rs={:+1d} ".format(returns[4], returns[5], returns[2]))
            pbar.update(finishedLabels)
        label = returns[0]
        PhiField[label] = returns[1]
        returnStatus[label] = returns[2]
        error[label] = returns[3]
        iterations[label] = returns[4]
        deltaPhiNorm[label] = returns[5]
        labelDilateList[label] = returns[6]

pbar.finish()


print("\n")


# Redundant output for VTK visualisation: norm of each label's translation
magDisp = numpy.array(
    [numpy.linalg.norm(PhiField[lab][0:3, -1]) for lab in range(numberOfLabels)],
    dtype=float,
)

# Finished! Get ready for output.
# if args.REGSUB:
# print("\n\tFinished correlations. Subtracting rigid-body motion from displacements of each particle")
# PhiFieldMinusRigid = PhiField.copy()
# magDispRegsub = numpy.zeros(numberOfLabels)
# for label in range(numberOfLabels):
# PhiFieldMinusRigid[label][0:3,-1] -= rigidDisp[label]
# magDispRegsub[label] = numpy.linalg.norm(PhiFieldMinusRigid[label][0:3,-1])

# One row per label: label, position (ZYX), displacement (ZYX), the 3x3 F block
# in row-major order, then the correlation indicators.
displacementColumns = [PhiField[:, row, 3] for row in range(3)]
FColumns = [PhiField[:, row, col] for row in range(3) for col in range(3)]
outMatrix = numpy.array(
    [numpy.array(range(numberOfLabels))]
    + [centresOfMass[:, axis] for axis in range(3)]
    + displacementColumns
    + FColumns
    + [PSCC, error, iterations, returnStatus, deltaPhiNorm, labelDilateList]
).T

tsvHeader = "Label\tZpos\tYpos\tXpos\t" + "Zdisp\tYdisp\tXdisp\t" + "Fzz\tFzy\tFzx\t" + "Fyz\tFyy\tFyx\t" + "Fxz\tFxy\tFxx\t" + "PSCC\terror\titerations\treturnStatus\tdeltaPhiNorm\tLabelDilate"
numpy.savetxt(
    args.OUT_DIR + "/" + args.PREFIX + "-ddic.tsv",
    outMatrix,
    fmt="%.7f",
    delimiter="\t",
    newline="\n",
    comments="",
    header=tsvHeader,
)

# Prepare VTK outputs with no nans
def _zeroNans(values):
    """Return a copy of *values* in which every NaN is replaced by 0."""
    cleaned = values.copy()
    cleaned[numpy.isnan(cleaned)] = 0.0
    return cleaned


dispField = PhiField[:, 0:3, -1]
dispFieldNoNans = _zeroNans(dispField)
magDispNoNans = _zeroNans(magDisp)
centresOfMassNoNans = _zeroNans(centresOfMass)

VTKglyphDict = {
    "displacements": dispFieldNoNans,
    "mag(displacements)": magDispNoNans,
    "returnStatus": returnStatus,
}

# if regsub add a line to VTK output and also save separate TSV file
# if args.REGSUB:
# VTKglyphDict['displacements-regsub'] = PhiFieldMinusRigid[:, 0:3, -1]
# VTKglyphDict['mag(displacements-regsub)'] = magDispRegsub

# outMatrix = numpy.array([numpy.array(range(numberOfLabels)),
# centresOfMass[:, 0], centresOfMass[:, 1], centresOfMass[:, 2],
# PhiFieldMinusRigid[:, 0, 3], PhiFieldMinusRigid[:, 1, 3], PhiFieldMinusRigid[:, 2, 3],
# PhiFieldMinusRigid[:, 0, 0], PhiFieldMinusRigid[:, 0, 1], PhiFieldMinusRigid[:, 0, 2],
# PhiFieldMinusRigid[:, 1, 0], PhiFieldMinusRigid[:, 1, 1], PhiFieldMinusRigid[:, 1, 2],
# PhiFieldMinusRigid[:, 2, 0], PhiFieldMinusRigid[:, 2, 1], PhiFieldMinusRigid[:, 2, 2],
# PSCC,
# error, iterations, returnStatus, deltaPhiNorm,
# labelDilateList]).T

# numpy.savetxt(args.OUT_DIR + "/" + args.PREFIX + "-discreteDVC-regsub.tsv",
# outMatrix,
# fmt='%.7f',
# delimiter='\t',
# newline='\n',
# comments='',
# header="Label\tZpos\tYpos\tXpos\t" +
# "Zdisp\tYdisp\tXdisp\t" +
# "Fzz\tFzy\tFzx\t" +
# "Fyz\tFyy\tFyx\t" +
# "Fxz\tFxy\tFxx\t" +
# "PSCC\terror\titerations\treturnStatus\tdeltaPhiNorm\tLabelDilate")

# Glyphs at every (NaN-free) centre of mass, carrying the fields collected above
vtkFileName = args.OUT_DIR + "/" + args.PREFIX + "-ddic.vtk"
spam.helpers.writeGlyphsVTK(centresOfMassNoNans, VTKglyphDict, fileName=vtkFileName)
