#!python

# Copyright (C) 2020 SPAM Contributors
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

"""
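This script performs a point-by-point "pixel search". For each point, an imagette extracted from im1 is
compared against every position of a brute-force search over a given search range in z, y, x in im2;
each comparison computes a correlation coefficient.
Points are ranked starting from the first (top) point and processed in that order. The search for each
point is centred on a displacement estimated from already-correlated neighbouring points.

Imagettes in im1 can be defined from a guiding-points file, from a regular grid (nodeSpacing and
halfWindowSize), or from a labelled image (one imagette per label).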
"""


import os

# OpenBLAS reads OPENBLAS_NUM_THREADS once, when the library is loaded, i.e. on
# the first import of numpy (which may already happen inside the spam imports).
# It must therefore be set before the imports below; setting it afterwards has
# no effect on the current process.
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import spam.DIC  # noqa: E402
import spam.mesh  # noqa: E402
import spam.helpers  # noqa: E402

# spam.label is used directly by this script (label mode), so import it
# explicitly rather than relying on another spam module importing it.
import spam.label  # noqa: E402

import numpy  # noqa: E402
import progressbar  # noqa: E402
import argparse  # noqa: E402
import tifffile  # noqa: E402
from scipy.spatial import KDTree  # noqa: E402

# Ignore all numpy floating-point errors (divide, over, under, invalid) for the
# whole script.
numpy.seterr(all="ignore")


# Build the command-line interface. The option definitions themselves live in
# spam.helpers.optionsParser.pixelSearchPropagate, which parses the arguments.
parser = argparse.ArgumentParser(
    description="spam-pixelSearchPropagate "
    + spam.helpers.optionsParser.GLPv3descriptionHeader
    + "This script performs a pixel search from im1 to im2 propagating the motion from the top guiding point\n",
    formatter_class=argparse.RawTextHelpFormatter,
)
args = spam.helpers.optionsParser.pixelSearchPropagate(parser)

# Echo every parsed option, in alphabetical order, so the run is self-documenting.
print("Current Settings:")
argsDict = vars(args)
for settingName, settingValue in sorted(argsDict.items()):
    print("\t{}: {}".format(settingName, settingValue))

# Six-value search range, copied element by element from the command line.
# NOTE(review): presumably (zMin, zMax, yMin, yMax, xMin, xMax); confirm in optionsParser.
searchRange = numpy.array([args.SEARCH_RANGE[i] for i in range(6)])

# Neighbour-selection threshold and Gaussian weighting distance used by the
# propagation loop.
pixelSearchCCmin = args.CC_MIN
weightingDistance = args.DIST

# START_POINT_DISP packs the starting point (first three values, z y x)
# followed by its initial displacement (next three values).
startPoint = numpy.array(args.START_POINT_DISP[:3])
startPointDisplacement = numpy.array(args.START_POINT_DISP[3:6])

# Reference (im1) and deformed (im2) images. No 2D -> 3D padding is applied:
# images are used with the shape stored in the files.
im1 = tifffile.imread(args.im1.name)
im2 = tifffile.imread(args.im2.name)

# Optional masks: any non-zero voxel counts as inside the mask.
im1mask = (tifffile.imread(args.MASK1.name) != 0) if args.MASK1 else None
im2mask = (tifffile.imread(args.MASK2.name) != 0) if args.MASK2 else None

# Choose the points to correlate. There are 3 modes:
#   - guiding points read from a text file (ideally points with good texture);
#   - one point per label of a labelled image (its centre of mass);
#   - a regular grid with node spacing NS, with the starting point prepended
#     as row 0 so that the propagation starts from it.
if args.GUIDING_POINTS_FILE is not None:
    # Read the file given by the same option tested above. atleast_2d keeps a
    # one-point file as a (1, N) array instead of a flat vector.
    gp = numpy.atleast_2d(numpy.genfromtxt(args.GUIDING_POINTS_FILE.name))

elif args.LAB1 is not None:
    # Label mode: points are the labels' centres of mass; masks are not used.
    lab1 = tifffile.imread(args.LAB1.name).astype(spam.label.labelType)
    boundingBoxes = spam.label.boundingBoxes(lab1)
    centresOfMass = spam.label.centresOfMass(lab1, boundingBoxes=boundingBoxes)
    im1mask = None
    im2mask = None
    gp = centresOfMass.copy()
    # Row 0 (presumably label 0, i.e. background) is reused to hold the
    # starting point, so the propagation starts there.
    gp[0] = startPoint

else:
    # Regular grid mode; nodesDim is needed later to reshape outputs to the grid.
    gp, nodesDim = spam.DIC.makeGrid(im1.shape, args.NS)
    gp = numpy.vstack([startPoint, gp])


# Rank the points with spam.mesh.rankPoints, starting from the first point;
# the propagation loop below processes them in this ranked order.
print("\n\tRanking points")
guidingPoints, rowNumbers = spam.mesh.rankPoints(gp, neighbourRadius=args.RADIUS)
numberOfPoints = guidingPoints.shape[0]

# One 4x4 Phi per point, all starting as the identity. Only the first
# (starting) point receives the user-supplied initial displacement.
PhiField = numpy.repeat(numpy.eye(4)[numpy.newaxis, ...], numberOfPoints, axis=0)
PhiField[0, 0:3, -1] += startPointDisplacement

# Per-point outputs. error / returnStatus / deltaPhiNorm / iterations are kept
# so that the output columns are compatible with register().
pixelSearchCC = numpy.zeros(numberOfPoints, dtype=float)
error = numpy.zeros(numberOfPoints, dtype=float)
returnStatus = numpy.zeros(numberOfPoints, dtype=int)
deltaPhiNorm = numpy.ones(numberOfPoints, dtype=int)
iterations = numpy.zeros(numberOfPoints, dtype=int)

# Progress bar; its first widget is later replaced to show the latest CC.
print("\n\tStarting sequential Pixel Search")
widgets = [
    progressbar.FormatLabel(""),
    " ",
    progressbar.Bar(),
    " ",
    progressbar.AdaptiveETA(),
]
pbar = progressbar.ProgressBar(widgets=widgets, maxval=numberOfPoints)
pbar.start()


# Step 1: plain pixel search for the starting (first ranked) point, centred on
# the user-supplied displacement stored in PhiField[0].
if args.LAB1 is not None:
    # The starting point may be given as floats; round it to integer voxel
    # indices so it can index the label image.
    startVoxel = tuple(int(round(float(coordinate))) for coordinate in startPoint)
    imagetteReturnsTop = spam.label.getImagettesLabelled(
        lab1,
        lab1[startVoxel],
        PhiField[0].copy(),
        im1,
        im2,
        searchRange.copy(),
        boundingBoxes,
        centresOfMass,
        margin=args.LABEL_DILATE,
        labelDilate=args.LABEL_DILATE,
        applyF="no",
        volumeThreshold=3**3,
    )
    imagetteReturnsTop["imagette2mask"] = None
else:
    imagetteReturnsTop = spam.DIC.getImagettes(
        im1,
        guidingPoints[0],
        args.HWS,
        PhiField[0].copy(),
        im2,
        searchRange.copy(),
        im1mask=im1mask,
        im2mask=im2mask,
        minMaskCoverage=args.MASK_COVERAGE,
        greyThreshold=[args.GREY_LOW_THRESH, args.GREY_HIGH_THRESH],
        applyF="no",
    )

# returnStatus == 1 means the correlation windows were extracted (size and
# mask-coverage checks passed). Without them there is nothing to propagate
# from, so stop with a non-zero exit status.
if imagetteReturnsTop["returnStatus"] != 1:
    print("Failed to extract correlation window for starting point, exiting")
    raise SystemExit(1)

PSreturnsTop = spam.DIC.pixelSearch(
    imagetteReturnsTop["imagette1"],
    imagetteReturnsTop["imagette2"],
    imagette1mask=imagetteReturnsTop["imagette1mask"],
    imagette2mask=imagetteReturnsTop["imagette2mask"],
    returnError=True,
)

# Absolute displacement = displacement found inside the search window plus
# the offset of that window.
PhiField[0, 0:3, -1] = PSreturnsTop[0] + imagetteReturnsTop["pixelSearchOffset"]
pixelSearchCC[0] = PSreturnsTop[1]
error[0] = PSreturnsTop[2]
# Record the success like every other point (previously left at 0).
returnStatus[0] = imagetteReturnsTop["returnStatus"]

# The starting point seeds every other estimate, so a poor match is fatal.
if pixelSearchCC[0] < args.CC_MIN:
    print("CC obtained for starting point is less than threshold, not continuing")
    raise SystemExit(1)

# Step 2: correlate the remaining points one by one in ranked order. Each search
# is centred on a displacement predicted from already-correlated neighbours.
# A KD-tree over all points makes the neighbour look-ups cheap.
treeCoord = KDTree(guidingPoints)

# A non-positive radius can never grow by doubling, which would make the
# neighbour search below loop forever.
startRadius = args.RADIUS if args.RADIUS > 0 else 1.0

for point in range(1, numberOfPoints):
    # 2.1: Collect good neighbours, doubling the search radius until at least
    #      one is found. A neighbour is good when it:
    #        - has already been processed (ranked index < point). Unprocessed
    #          points still hold an identity Phi and CC 0, which the CC test
    #          alone does not reject when CC_MIN <= 0;
    #        - has a CC that is not below the threshold;
    #        - has a finite displacement (failed points are set to NaN below).
    #      Point 0 passed the CC check in step 1 and holds the pixel-search
    #      displacement, so it qualifies once the radius reaches it.
    radius = startRadius
    indices = numpy.empty(0, dtype=int)
    while indices.size == 0:
        candidates = numpy.asarray(
            treeCoord.query_ball_point(guidingPoints[point], radius), dtype=int
        )
        isGood = (
            (candidates < point)
            & numpy.logical_not(pixelSearchCC[candidates] < pixelSearchCCmin)
            & numpy.all(numpy.isfinite(PhiField[candidates, 0:3, -1]), axis=1)
        )
        indices = candidates[isGood]
        radius *= 2

    # 2.2: Initial displacement = Gaussian-weighted mean of the good neighbours'
    #      displacements, weighted by their distance to the current point.
    neighbourDisplacements = PhiField[indices, 0:3, -1]
    distances = numpy.linalg.norm(guidingPoints[point] - guidingPoints[indices], axis=1)
    weights = numpy.exp(-(distances**2) / weightingDistance**2)
    weightSum = weights.sum()
    if weightSum > 0:
        initialDisplacement = (
            numpy.sum(neighbourDisplacements * weights[:, numpy.newaxis], axis=0)
            / weightSum
        )
    else:
        # Every weight underflowed to 0 (or is NaN), e.g. because the neighbours
        # are far away compared to DIST. Use the nearest good neighbour instead
        # of producing 0/0 = NaN.
        initialDisplacement = neighbourDisplacements[numpy.argmin(distances)]

    PhiField[point, 0:3, -1] = initialDisplacement

    # 2.3: Extract the correlation windows around the predicted position.
    if args.LAB1 is not None:
        # NOTE(review): rowNumbers[point] is used as the label of ranked point
        # `point`; confirm this matches spam.mesh.rankPoints' definition of
        # rowNumbers, which is also used to reorder the outputs after the loop.
        imagetteReturns = spam.label.getImagettesLabelled(
            lab1,
            rowNumbers[point],
            PhiField[point].copy(),
            im1,
            im2,
            searchRange.copy(),
            boundingBoxes,
            centresOfMass,
            margin=args.LABEL_DILATE,
            labelDilate=args.LABEL_DILATE,
            applyF="no",
            volumeThreshold=3**3,
        )
        imagetteReturns["imagette2mask"] = None

    else:
        imagetteReturns = spam.DIC.getImagettes(
            im1,
            guidingPoints[point],
            args.HWS,
            PhiField[point].copy(),
            im2,
            searchRange.copy(),
            im1mask=im1mask,
            im2mask=im2mask,
            minMaskCoverage=args.MASK_COVERAGE,
            greyThreshold=[args.GREY_LOW_THRESH, args.GREY_HIGH_THRESH],
            applyF="no",
        )

    # 2.4: Pixel search inside the extracted windows.
    returnStatus[point] = imagetteReturns["returnStatus"]
    if imagetteReturns["returnStatus"] == 1:
        PSreturns = spam.DIC.pixelSearch(
            imagetteReturns["imagette1"],
            imagetteReturns["imagette2"],
            imagette1mask=imagetteReturns["imagette1mask"],
            imagette2mask=imagetteReturns["imagette2mask"],
            returnError=True,
        )

        PhiField[point, 0:3, -1] = PSreturns[0] + imagetteReturns["pixelSearchOffset"]
        pixelSearchCC[point] = PSreturns[1]
        error[point] = PSreturns[2]

        widgets[0] = progressbar.FormatLabel("  CC={:0>7.5f} ".format(PSreturns[1]))
    else:
        # Window extraction failed. Mark the point as unusable so it is never
        # picked as a neighbour (non-finite displacement).
        PhiField[point, 0:3, -1] = numpy.nan
        error[point] = numpy.inf

    pbar.update(point)

pbar.finish()

# Put the per-point results back in the order expected by the writers below.
# In regular-grid mode the first entry of rowNumbers is dropped first, so that
# the remaining rows can be reshaped to the grid (nodesDim).
# NOTE(review): this relies on the exact meaning of rowNumbers returned by
# spam.mesh.rankPoints. The label branch of the propagation loop uses
# rowNumbers[point] as a label; confirm that both uses agree.
regularGridMode = args.GUIDING_POINTS_FILE is None and args.LAB1 is None
if regularGridMode:
    rowNumbers = rowNumbers[1:]

(
    guidingPoints,
    PhiField,
    error,
    returnStatus,
    pixelSearchCC,
    deltaPhiNorm,
    iterations,
) = (
    field[rowNumbers]
    for field in (
        guidingPoints,
        PhiField,
        error,
        returnStatus,
        pixelSearchCC,
        deltaPhiNorm,
        iterations,
    )
)

# Regular-grid mode only: optionally write each field as a float32 TIFF
# volume shaped like the grid.
if args.GUIDING_POINTS_FILE is None and args.LAB1 is None and args.TIFF:
    outputBase = args.OUT_DIR + "/" + args.PREFIX
    tiffFields = (
        ("-Zdisp.tif", PhiField[:, 0, -1]),
        ("-Ydisp.tif", PhiField[:, 1, -1]),
        ("-Xdisp.tif", PhiField[:, 2, -1]),
        ("-CC.tif", pixelSearchCC),
        ("-returnStatus.tif", returnStatus),
    )
    for suffix, values in tiffFields:
        tifffile.imwrite(outputBase + suffix, values.astype("<f4").reshape(nodesDim))

if args.TSV:
    # One row per point: node number, position (z, y, x), the top three rows
    # of Phi flattened row by row (F entries interleaved with the displacement),
    # then the pixel-search CC and the register()-compatible columns.
    numberOfRows = guidingPoints.shape[0]
    outMatrix = numpy.column_stack(
        (
            numpy.arange(numberOfRows),
            guidingPoints[:, 0:3],
            PhiField[:, 0:3, :].reshape(numberOfRows, 12),
            pixelSearchCC,
            returnStatus,
            error,
            deltaPhiNorm,
            iterations,
        )
    )

    isLabelMode = args.LAB1 is not None
    TSVheader = (
        "Label\tZpos\tYpos\tXpos\tFzz\tFzy\tFzx\tZdisp\tFyz\tFyy\tFyx\tYdisp\tFxz\tFxy\tFxx\tXdisp\tpixelSearchCC\treturnStatus\terror\tdeltaPhiNorm\titerations"
        if isLabelMode
        else "NodeNumber\tZpos\tYpos\tXpos\tFzz\tFzy\tFzx\tZdisp\tFyz\tFyy\tFyx\tYdisp\tFxz\tFxy\tFxx\tXdisp\tpixelSearchCC\treturnStatus\terror\tdeltaPhiNorm\titerations"
    )
    if isLabelMode:
        # Row 0 (presumably label 0 / background, which only carried the
        # starting point) is written as all zeros.
        outMatrix[0, :] = 0

    numpy.savetxt(
        args.OUT_DIR + "/" + args.PREFIX + ".tsv",
        outMatrix,
        fmt="%.7f",
        delimiter="\t",
        newline="\n",
        comments="",
        header=TSVheader,
    )

# Optional VTK output: a structured grid in regular-grid mode, glyphs otherwise.
if args.VTK:
    vtkFileName = args.OUT_DIR + "/" + args.PREFIX + ".vtk"

    if args.GUIDING_POINTS_FILE is None and args.LAB1 is None:
        gridDisplacements = PhiField[:, :-1, 3].reshape(
            (nodesDim[0], nodesDim[1], nodesDim[2], 3)
        )
        # NaN / inf displacements (failed points) are written as 0.
        gridDisplacements[~numpy.isfinite(gridDisplacements)] = 0
        cellData = {
            "displacements": gridDisplacements,
            "pixelSearchCC": pixelSearchCC.reshape(nodesDim),
        }

        # Each point becomes one cell of size NS. This is exact when NS = 2xHWS;
        # with overlapping or sparse windows it is only approximate. Cube glyphs
        # would be the alternative if overlaps must be shown.
        spam.helpers.writeStructuredVTK(
            origin=guidingPoints[0] - args.HWS,
            aspectRatio=args.NS,
            cellData=cellData,
            fileName=vtkFileName,
        )

    else:
        pointDisplacements = PhiField[:, 0:3, -1]
        # NaN / inf displacements (failed points) are written as 0.
        pointDisplacements[~numpy.isfinite(pointDisplacements)] = 0

        spam.helpers.writeGlyphsVTK(
            guidingPoints,
            {
                "displacements": pointDisplacements,
                "mag(displacements)": numpy.linalg.norm(pointDisplacements, axis=1),
                "pixelSearchCC": pixelSearchCC,
            },
            fileName=vtkFileName,
        )
