#!python

# Copyright (C) 2020 SPAM Contributors
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

"""
This script performs a point-by-point "pixel search" which computes a correlation coefficient
of an imagette extracted in im1 to a brute-force search in a given search range in z, y, x in image 2.

Imagettes in im1 can either be defined with a nodeSpacing and a halfWindowSize or a labelled image.
"""


import spam.DIC
import spam.deformation
import spam.helpers

# import spam.mesh
import spam.label

import os

import numpy

import multiprocessing
# Use the "fork" start method so that worker processes inherit the module-level
# state (images, masks, PhiField, args, ...) that pixelSearchOneNode reads as globals.
# set_start_method raises RuntimeError when a start method has already been set;
# in that case the already-selected method is kept.
try:                 multiprocessing.set_start_method('fork')
except RuntimeError: pass
import progressbar
import argparse
import tifffile

os.environ["OPENBLAS_NUM_THREADS"] = "1"
numpy.seterr(all="ignore")


# Define argument parser object
parser = argparse.ArgumentParser(
    description="spam-pixelSearch "
    + spam.helpers.optionsParser.GLPv3descriptionHeader
    + "This script performs a pixel search from im1 to im2\n",
    formatter_class=argparse.RawTextHelpFormatter,
)

# Parse arguments with external helper function
args = spam.helpers.optionsParser.pixelSearch(parser)

if args.PROCESSES is None:
    args.PROCESSES = multiprocessing.cpu_count()

print("spam-pixelSearch -- Current Settings:")
argsDict = vars(args)
for key in sorted(argsDict):
    print("\t{}: {}".format(key, argsDict[key]))

# Search range as a 6-element array, built from the first six SEARCH_RANGE values
# (presumably min/max pairs in z, y, x -- TODO confirm ordering against the option parser)
searchRange = numpy.array([args.SEARCH_RANGE[i] for i in range(6)])

# Reference image: a 2D image is promoted to a single-slice 3D volume
im1 = tifffile.imread(args.im1.name)
if im1.ndim == 2:
    im1 = im1[numpy.newaxis, ...]
# A single slice along z means the whole run is 2D
twoD = im1.shape[0] == 1

# Deformed image, with the same 2D -> 3D promotion
im2 = tifffile.imread(args.im2.name)
if im2.ndim == 2:
    im2 = im2[numpy.newaxis, ...]

###############################################################
# Main switch for LAB or GRID mode
###############################################################
# Label mode: imagettes are defined by the labels of a labelled image
if args.LAB1 is not None:
    lab1 = tifffile.imread(args.LAB1.name).astype(spam.label.labelType)
    # Promote a 2D label image to a single-slice 3D volume *before* computing
    # bounding boxes and centres of mass, so they are expressed in the same
    # 3D (z, y, x) frame as im1 and im2. Testing the label image's own rank
    # (rather than twoD) avoids adding a spurious fourth axis to a label image
    # that is already stored with shape (1, ny, nx).
    if len(lab1.shape) == 2:
        lab1 = lab1[numpy.newaxis, ...]
    boundingBoxes = spam.label.boundingBoxes(lab1)
    nodePositions = spam.label.centresOfMass(lab1, boundingBoxes=boundingBoxes)
    numberOfNodes = nodePositions.shape[0]
    # Masks are not used in label mode
    im1mask = None
    im2mask = None

# Otherwise we are in node spacing and half-window size mode; optional masks
# (any non-zero voxel is "in") are promoted to 3D like the images
else:
    if args.MASK1 is not None:
        im1mask = tifffile.imread(args.MASK1.name) != 0
        if len(im1mask.shape) == 2:
            im1mask = im1mask[numpy.newaxis, ...]
    else:
        im1mask = None

    if args.MASK2 is not None:
        im2mask = tifffile.imread(args.MASK2.name) != 0
        if len(im2mask.shape) == 2:
            im2mask = im2mask[numpy.newaxis, ...]
    else:
        im2mask = None

# Three cases to handle:
#   1. Phi file is a registration (one line) -> define nodes and apply the registration to them
#   2. Phi file is a field (several lines)   -> take everything and check -ns if it was also passed
#   3. no Phi file                           -> define nodes and start from the identity
# Error paths raise SystemExit(1) so that a failed run returns a non-zero exit status
# (the site-provided exit() is meant for the interactive interpreter and exits with status 0).
if args.PHIFILE is not None:
    PhiFromFile = spam.helpers.readCorrelationTSV(
        args.PHIFILE.name, fieldBinRatio=args.PHIFILE_BIN_RATIO
    )
    if PhiFromFile is None:
        print(f"\tFailed to read your TSV file passed with -pf {args.PHIFILE.name}")
        raise SystemExit(1)

    # If the read Phi-file has only one line -- it's a single point registration!
    if PhiFromFile["fieldCoords"].shape[0] == 1:
        PhiInit = PhiFromFile["PhiField"][0]
        print(
            f"\tI read a registration from a file in binning {args.PHIFILE_BIN_RATIO}"
        )

        decomposedPhiInit = spam.deformation.decomposePhi(PhiInit)
        print("\tTranslations (px)")
        print("\t\t", decomposedPhiInit["t"])
        print("\tRotations (deg)")
        print("\t\t", decomposedPhiInit["r"])
        print("\tZoom")
        print("\t\t", decomposedPhiInit["z"])
        del decomposedPhiInit

        if args.LAB1 is None:
            # Create nodes if in regular mode, in label mode these are already defined
            if args.NS is None:
                print(
                    f"spam-pixelSearch: You passed a registration file {args.PHIFILE.name}, I need -ns to be defined"
                )
                raise SystemExit(1)
            nodePositions, nodesDim = spam.DIC.makeGrid(im1.shape, args.NS)
            numberOfNodes = nodePositions.shape[0]

        # Evaluate the registration, using the position stored in the file as its centre, at every node
        PhiField = spam.DIC.applyRegistrationToPoints(
            PhiInit,
            PhiFromFile["fieldCoords"][0],
            nodePositions,
            applyF=args.APPLY_F,
            nProcesses=args.PROCESSES,
            verbose=False,
        )

    # If the read Phi-file contains multiple lines it's a Phi field!
    else:
        nodePositionsFile = PhiFromFile["fieldCoords"]
        numberOfNodes = nodePositionsFile.shape[0]

        # Node spacing per axis: gap between the first two distinct coordinates.
        # NOTE(review): when an axis has a single coordinate value, that value itself
        # is used (a position, not a spacing) -- kept as-is, confirm this is intended.
        nodeSpacingFile = []
        for axis in range(3):
            uniqueCoordinates = numpy.unique(nodePositionsFile[:, axis])
            if len(uniqueCoordinates) > 1:
                nodeSpacingFile.append(uniqueCoordinates[1] - uniqueCoordinates[0])
            else:
                nodeSpacingFile.append(uniqueCoordinates[0])
        nodeSpacingFile = numpy.array(nodeSpacingFile)

        PhiField = PhiFromFile["PhiField"]
        nodesDim = PhiFromFile["fieldDims"]

        # different checks to be done for lab and grid:
        if args.LAB1 is None:
            # In case NS is also defined, complain, but if it's the same as the loaded data, continue
            if args.NS is not None:
                if not numpy.allclose(numpy.array(args.NS), nodeSpacingFile, atol=0.0):
                    print(
                        f"spam-pixelSearch: you passed a -ns={args.NS} which contradicts the node spacing in your Phi Field TSV of {nodeSpacingFile}"
                    )
                    print(
                        "\thint 1: if you pass a Phi Field TSV you don't need to also define the node spacing"
                    )
                    print(
                        f"\thint 2: if you want to use your Phi Field TSV {args.PHIFILE.name} on a finer node spacing, pass it with spam-passPhiField"
                    )
                    raise SystemExit(1)
                print(
                    "spam-pixelSearch: passing -ns with a Phi Field TSV is not needed"
                )
            else:
                # No -ns given: adopt the spacing found in the file
                args.NS = nodeSpacingFile
            nodePositions = nodePositionsFile
        else:
            # Lab phi-field consistency check. Compare shapes first: numpy.allclose
            # raises a broadcasting ValueError (instead of returning False) when the
            # file and the labelled image give different numbers of positions.
            if nodePositionsFile.shape != nodePositions.shape:
                print(
                    f"spam-pixelSearch: Input PhiField from {args.PHIFILE.name} has {nodePositionsFile.shape[0]} positions, but the labelled image {args.LAB1.name} gives {nodePositions.shape[0]}."
                )
                print(
                    "\tplease consider using spam-passPhiField to apply your PhiField to a new labelled image"
                )
                raise SystemExit(1)
            if not numpy.allclose(nodePositionsFile, nodePositions, atol=1.0):
                print(
                    f"spam-pixelSearch: Input PhiField positions from {args.PHIFILE.name} are not within 1px of the centre of mass of the labels from {args.LAB1.name}, this seems dangerous."
                )
                print(
                    "\tplease consider using spam-passPhiField to apply your PhiField to a new labelled image"
                )
                raise SystemExit(1)

else:  # no Phi file
    if args.LAB1 is None:
        if args.NS is None:
            print(
                "spam-pixelSearch: You're in regular grid mode, but no -ns is set and no Phi Field TSV has been passed, exiting."
            )
            raise SystemExit(1)
        nodePositions, nodesDim = spam.DIC.makeGrid(im1.shape, args.NS)
        numberOfNodes = nodePositions.shape[0]

    # Every node starts from the identity transformation
    PhiField = numpy.tile(numpy.eye(4), (numberOfNodes, 1, 1))


def pixelSearchOneNode(nodeNumber):
    """
    Pixel-search a single node (grid point or label) of im1 in im2.

    Designed to be mapped over node numbers by a multiprocessing pool: apart from the
    node number, every input is read from module-level variables (args, im1, im2,
    PhiField, nodePositions, searchRange, plus lab1/boundingBoxes in label mode or
    im1mask/im2mask/twoD in grid mode).

    Parameters
    ----------
        nodeNumber : int
            node (or label) number to work on

    Returns
    -------
        Tuple with:
            - nodeNumber (needed to write the result in the right place)
            - displacement vector: pixel-search result plus the imagette offset,
              or three NaNs if the imagettes could not be extracted
            - NCC value (0.0 when the imagettes could not be extracted)
            - error value (inf when the imagettes could not be extracted)
            - returnStatus of the imagette extraction (1 means success)
    """
    # Pass copies so the callees cannot modify the module-level arrays
    nodePhi = PhiField[nodeNumber].copy()
    nodeSearchRange = searchRange.copy()

    if args.LAB1 is None:
        # Grid mode: imagette around the node position, optionally masked
        imagettes = spam.DIC.getImagettes(
            im1,
            nodePositions[nodeNumber],
            args.HWS,
            nodePhi,
            im2,
            nodeSearchRange,
            im1mask=im1mask,
            im2mask=im2mask,
            minMaskCoverage=args.MASK_COVERAGE,
            greyThreshold=[args.GREY_LOW_THRESH, args.GREY_HIGH_THRESH],
            applyF=args.APPLY_F,
            twoD=twoD,
        )
    else:
        # Label mode: imagette defined by the (optionally dilated) label
        imagettes = spam.label.getImagettesLabelled(
            lab1,
            nodeNumber,
            nodePhi,
            im1,
            im2,
            nodeSearchRange,
            boundingBoxes,
            nodePositions,
            margin=args.LABEL_DILATE,
            labelDilate=args.LABEL_DILATE,
            applyF=args.APPLY_F,
            volumeThreshold=args.LABEL_VOLUME_THRESHOLD,
        )
        # No second-image mask is used in label mode
        imagettes["imagette2mask"] = None

    status = imagettes["returnStatus"]

    # Extraction failed (e.g. size or mask-coverage check): report a null result
    if status != 1:
        return (
            nodeNumber,
            numpy.array([numpy.nan] * 3),
            0.0,
            numpy.inf,
            status,
        )

    searchResult = spam.DIC.pixelSearch(
        imagettes["imagette1"],
        imagettes["imagette2"],
        imagette1mask=imagettes["imagette1mask"],
        imagette2mask=imagettes["imagette2mask"],
        returnError=True,
    )
    # The search result is relative to the imagette; shift it by the extraction offset
    displacement = searchResult[0] + imagettes["pixelSearchOffset"]
    return (nodeNumber, displacement, searchResult[1], searchResult[2], status)


# Per-node result arrays. error, deltaPhiNorm and iterations are kept so the
# output has the same columns as register() results (deltaPhiNorm and
# iterations are placeholders filled with ones).
pixelSearchCC = numpy.zeros((numberOfNodes), dtype=float)
error = numpy.zeros((numberOfNodes), dtype=float)
returnStatus = numpy.ones((numberOfNodes), dtype=int)
deltaPhiNorm = numpy.ones((numberOfNodes), dtype=int)
iterations = numpy.ones((numberOfNodes), dtype=int)

# In label mode label 0 (presumably the background) is not searched: it is flagged
# with returnStatus 0 and counted as already finished for the progress bar.
if args.LAB1 is not None:
    firstNode = 1
    finishedNodes = 1
    returnStatus[0] = 0
else:
    firstNode = 0
    finishedNodes = 0


print(
    "\n\tStarting Pixel search (with {} process{})".format(
        args.PROCESSES, "es" if args.PROCESSES > 1 else ""
    )
)

widgets = [
    progressbar.FormatLabel(""),
    " ",
    progressbar.Bar(),
    " ",
    progressbar.AdaptiveETA(),
]
pbar = progressbar.ProgressBar(widgets=widgets, maxval=numberOfNodes)
pbar.start()
# NOTE: finishedNodes is intentionally not reset here; it was initialised above
# according to the mode (1 in label mode to account for the skipped label 0).

with multiprocessing.Pool(processes=args.PROCESSES) as pool:
    for node, displacement, cc, nodeError, status in pool.imap_unordered(
        pixelSearchOneNode, range(firstNode, numberOfNodes)
    ):
        finishedNodes += 1

        # Update progress bar if point is not skipped
        if status > 0:
            widgets[0] = progressbar.FormatLabel("  CC={:0>7.5f} ".format(cc))
            pbar.update(finishedNodes)

        # Store the translation and diagnostics for this node
        PhiField[node, 0:3, -1] = displacement
        pixelSearchCC[node] = cc
        error[node] = nodeError
        returnStatus[node] = status
    pool.close()
    pool.join()

pbar.finish()

print("\n")

print("\n")


if args.TSV:
    # One row per node: node/label number, position (z, y, x), the top three rows
    # of Phi (F and displacement, row by row), then the per-node diagnostics.
    TSVheader = (
        "Label\tZpos\tYpos\tXpos\tFzz\tFzy\tFzx\tZdisp\tFyz\tFyy\tFyx\tYdisp\tFxz\tFxy\tFxx\tXdisp\tpixelSearchCC\treturnStatus\terror\tdeltaPhiNorm\titerations"
        if args.LAB1 is not None
        else "NodeNumber\tZpos\tYpos\tXpos\tFzz\tFzy\tFzx\tZdisp\tFyz\tFyy\tFyx\tYdisp\tFxz\tFxy\tFxx\tXdisp\tpixelSearchCC\treturnStatus\terror\tdeltaPhiNorm\titerations"
    )

    columns = [numpy.array(range(nodePositions.shape[0]))]
    columns.extend(nodePositions[:, axis] for axis in range(3))
    columns.extend(PhiField[:, row, col] for row in range(3) for col in range(4))
    columns.extend([pixelSearchCC, returnStatus, error, deltaPhiNorm, iterations])
    outMatrix = numpy.array(columns).T

    tsvPath = args.OUT_DIR + "/" + args.PREFIX + ".tsv"
    numpy.savetxt(
        tsvPath,
        outMatrix,
        fmt="%.7f",
        delimiter="\t",
        newline="\n",
        comments="",
        header=TSVheader,
    )

if args.TIFF:
    if args.LAB1 is not None:
        # Think about relabelling grains here automatically?
        pass
    else:
        # Grid mode: one float32 TIFF per quantity, shaped like the node grid
        tiffOutputs = [
            ("-Zdisp.tif", PhiField[:, 0, -1]),
            ("-Ydisp.tif", PhiField[:, 1, -1]),
            ("-Xdisp.tif", PhiField[:, 2, -1]),
            ("-CC.tif", pixelSearchCC),
            ("-returnStatus.tif", returnStatus),
        ]
        # With a single node plane in z, the Z displacement map is not written
        if nodesDim[0] == 1:
            tiffOutputs = tiffOutputs[1:]
        for suffix, values in tiffOutputs:
            tifffile.imwrite(
                args.OUT_DIR + "/" + args.PREFIX + suffix,
                values.astype("<f4").reshape(nodesDim),
            )


# Collect data into VTK output
if args.VTK and args.LAB1 is None:
    # Grid mode: one structured-grid cell per node.
    # Copy the displacements: reshaping a slice of PhiField can return a view onto it,
    # and the cleaning below must not modify PhiField itself.
    displacements = (
        PhiField[:, :-1, 3]
        .reshape((nodesDim[0], nodesDim[1], nodesDim[2], 3))
        .copy()
    )
    # Overwrite NaNs and infs (e.g. nodes whose imagettes could not be extracted) with 0
    displacements[numpy.logical_not(numpy.isfinite(displacements))] = 0

    cellData = {
        "displacements": displacements,
        "pixelSearchCC": pixelSearchCC.reshape(nodesDim),
    }

    # This is perfect in the case where NS = 2xHWS, these cells will all be in the right place
    #   In the case of overlapping or under-use of data, it should be approximately correct
    # If you insist on overlapping, then perhaps it's better to save each point as a cube glyph
    #   and actually *have* overlapping
    spam.helpers.writeStructuredVTK(
        origin=nodePositions[0] - args.HWS,
        aspectRatio=args.NS,
        cellData=cellData,
        fileName=args.OUT_DIR + "/" + args.PREFIX + ".vtk",
    )

elif args.VTK and args.LAB1 is not None:
    # Label mode: one glyph per label position, with the displacement magnitude
    # added as a redundant field for convenient visualisation
    labelDisplacements = PhiField[:, 0:3, -1]
    VTKglyphDict = {
        "displacements": labelDisplacements,
        "mag(displacements)": numpy.linalg.norm(labelDisplacements, axis=1),
        "pixelSearchCC": pixelSearchCC,
    }

    spam.helpers.writeGlyphsVTK(
        nodePositions, VTKglyphDict, fileName=args.OUT_DIR + "/" + args.PREFIX + ".vtk"
    )
