#!python

# This script filters and corrects Phi fields (gridded or discrete)
# produced by SPAM Digital Image Correlation tools
# Copyright (C) 2020 SPAM Contributors
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

"""
This script manipulates a Phi field:

Gridded Phi values:
  - spam-pixelSearch
  - spam-pixelSearchPropagate
  - spam-ldic

Phi values defined at label centres:
  - spam-pixelSearch
  - spam-pixelSearchPropagate
  - spam-ddic


This script allows you to:
  - correct bad points inside a PhiField based on return status (RS) or pixel-search correlation coefficient (CC)
  - correct incoherent points inside a PhiField based on local quadratic coherency (LQC)
  - apply a median filter to the PhiField

Outputs are:
  - TSV files
  - (optional) VTK files for visualisation
  - (optional) TIF files in the case of gridded data
"""

import spam.DIC
import spam.deformation
import spam.helpers

# import spam.mesh
import spam.label

import os

os.environ["OPENBLAS_NUM_THREADS"] = "1"

import numpy

numpy.seterr(all="ignore")
import multiprocessing
try:                 multiprocessing.set_start_method('fork')
except RuntimeError: pass
import scipy.spatial
import scipy.ndimage
import progressbar
import argparse
import tifffile

# Numerical tolerance (not referenced by the code visible in this script; kept
# for compatibility)
tol = 1e-6

# Build the command-line parser; the individual options are declared by
# spam.helpers.optionsParser.filterPhiField, which also parses them.
description = (
    "spam-filterPhiField "
    + spam.helpers.optionsParser.GLPv3descriptionHeader
    + "This script process Phi fields by\n"
    + "correcting bad or incoherent points or filtering"
)
parser = argparse.ArgumentParser(
    description=description,
    formatter_class=argparse.RawTextHelpFormatter,
)
args = spam.helpers.optionsParser.filterPhiField(parser)

# Default to one worker process per available CPU
if args.PROCESSES is None:
    args.PROCESSES = multiprocessing.cpu_count()

# Echo every parsed option, alphabetically, so each run is self-documenting
print("spam-filterPhiField -- Current Settings:")
argsDict = vars(args)
for optionName, optionValue in sorted(argsDict.items()):
    print(f"\t{optionName}: {optionValue}")

###############################################################
### Step 1 (mandatory) read input Phi File
###############################################################
PhiFromFile = spam.helpers.readCorrelationTSV(
    args.PHIFILE.name, readConvergence=True, readPixelSearchCC=True, readError=True
)
# A None result is treated as a read failure
if PhiFromFile is None:
    print(f"\tFailed to read your TSV file passed with -pf {args.PHIFILE.name}")
    # Non-zero exit status so calling shells / pipelines see the failure
    # (a bare exit() would report success)
    raise SystemExit(1)

# A single-line file is a single-point registration, not a field:
# there is nothing to filter or correct
if PhiFromFile["fieldCoords"].shape[0] == 1:
    print(
        f"\tYour TSV passed with -pf {args.PHIFILE.name} is single line file (a registration). A field is required"
    )
    raise SystemExit(1)

# A non-zero label count means Phis are defined at label centres (discrete
# field); otherwise the field is treated as gridded
discrete = bool(PhiFromFile["numberOfLabels"] != 0)
grid = not discrete

###############################################################
### Input Phi file is a Phi FIELD
###############################################################
# Per-point input arrays as read from the TSV
inputNodesDim = PhiFromFile["fieldDims"]
inputNodePositions = PhiFromFile["fieldCoords"]
inputPhiField = PhiFromFile["PhiField"]
# Translation column of each Phi: a view into inputPhiField, not a copy
inputDisplacements = inputPhiField[:, 0:3, -1]
inputReturnStatus = PhiFromFile["returnStatus"]
inputPixelSearchCC = PhiFromFile["pixelSearchCC"]
inputDeltaPhiNorm = PhiFromFile["deltaPhiNorm"]
inputIterations = PhiFromFile["iterations"]
inputError = PhiFromFile["error"]

nPoints = inputNodePositions.shape[0]

### Point masks, all False until a selection criterion fills them
inputGood = numpy.zeros(nPoints, dtype=bool)
inputBad = numpy.zeros(nPoints, dtype=bool)
inputIgnore = numpy.zeros(nPoints, dtype=bool)

# Output arrays; points that are never written later keep these defaults
outputPhiField = numpy.zeros((nPoints, 4, 4))
outputReturnStatus = numpy.ones(nPoints, dtype=float)
outputDeltaPhiNorm = numpy.full(nPoints, 100.0)
outputIterations = numpy.zeros(nPoints, dtype=float)
outputError = numpy.full(nPoints, 100.0)
outputPixelSearchCC = numpy.zeros(nPoints, dtype=float)
outputPixelSearchCC = numpy.zeros((inputNodePositions.shape[0]), dtype=float)
# Neighbourhood used by interpolation / coherency: either a radius or a number
# of neighbours. If both are given, the radius is kept.
if args.NEIGHBOUR_RADIUS is not None and args.NUMBER_OF_NEIGHBOURS is not None:
    print(
        "Both number of neighbours and neighbour radius are set, I'm taking the radius and ignoring the number of neighbours"
    )
    args.NUMBER_OF_NEIGHBOURS = None

if args.NEIGHBOUR_RADIUS is None and args.NUMBER_OF_NEIGHBOURS is None:
    if grid:
        # Gridded input field: the spacing along an axis is the gap between its
        # first two distinct coordinates. Axes with a single coordinate (e.g. Z
        # in a 2D field) have no spacing, so they are left out of the mean
        # rather than contributing their coordinate value as a fake spacing.
        uniqueCoordinates = [numpy.unique(inputNodePositions[:, i]) for i in range(3)]
        nodeSpacing = numpy.array(
            [coords[1] - coords[0] for coords in uniqueCoordinates if len(coords) > 1]
        )
        if nodeSpacing.size == 0:
            # Degenerate field (every node at the same position): fall back to
            # the coordinates themselves so the mean below is still defined
            nodeSpacing = numpy.array([coords[0] for coords in uniqueCoordinates])
        args.NEIGHBOUR_RADIUS = 4 * int(numpy.mean(nodeSpacing))
        print(
            f"Neither number of neighbours nor neighbour distance set, using default distance of 4*mean(nodeSpacing) = {args.NEIGHBOUR_RADIUS}"
        )
    else:
        # Discrete input field
        args.NUMBER_OF_NEIGHBOURS = 27
        print(
            "Neither number of neighbours nor neighbour distance set, using default 27 neighbours"
        )


###############################################################
### Define IGNORE points:
###############################################################
# With masking on, points whose return status is below -4 are excluded from
# the good/bad selection below
if args.MASK:
    inputIgnore = inputReturnStatus < -4

###############################################################
### Apply threshold to select good and bad points
###############################################################
# Only one criterion is applied, by priority: return status (-srs), then
# pixel-search CC (-scc), then local quadratic coherency (-slqc)
selectable = numpy.logical_not(inputIgnore)

if args.SRS:
    print(f"\n\nSelecting bad points as Return Status <= {args.SRST}")
    inputGood = (inputReturnStatus > args.SRST) & selectable
    inputBad = (inputReturnStatus <= args.SRST) & selectable
    for requested, flag in ((args.SLQC, "-slqc"), (args.SCC, "-scc")):
        if requested:
            print(f"\tYou passed {flag} but you can only have one selection at a time")

elif args.SCC:
    print(f"\n\nSelecting bad points with Pixel Search CC <= {args.SCCT}")
    inputGood = (inputPixelSearchCC > args.SCCT) & selectable
    inputBad = (inputPixelSearchCC <= args.SCCT) & selectable
    if args.SLQC:
        print("\tYou passed -slqc but you can only have one selection at a time")

elif args.SLQC:
    print("\n\nCalculate coherency")
    LQC = spam.DIC.estimateLocalQuadraticCoherency(
        inputNodePositions[selectable],
        inputDisplacements[selectable],
        neighbourRadius=args.NEIGHBOUR_RADIUS,
        nNeighbours=args.NUMBER_OF_NEIGHBOURS,
        nProcesses=args.PROCESSES,
        verbose=True,
    )
    # LQC below 0.1 marks a point as good, anything else as bad
    inputGood[selectable] = LQC < 0.1
    inputBad[selectable] = LQC >= 0.1


###############################################################
### Copy over the values for good AND ignore to output
###############################################################
# Good and ignored points are passed through unchanged; bad points keep the
# output defaults unless a correction below overwrites them
gandi = numpy.logical_or(inputGood, inputIgnore)

outputPhiField[gandi] = inputPhiField[gandi]
outputReturnStatus[gandi] = inputReturnStatus[gandi]
outputDeltaPhiNorm[gandi] = inputDeltaPhiNorm[gandi]
outputIterations[gandi] = inputIterations[gandi]
outputError[gandi] = inputError[gandi]
outputPixelSearchCC[gandi] = inputPixelSearchCC[gandi]

correctionRequested = (args.CINT + args.CLQF) > 0
nBad = numpy.count_nonzero(inputBad)
if correctionRequested and nBad == 0:
    print("No points to correct, exiting")
    exit()
elif correctionRequested:
    # Share of bad points among all selected (good + bad) points.
    # nSelected >= nBad > 0 here, so the division is always defined.
    nSelected = numpy.count_nonzero(inputGood) + nBad
    print(f"\n\nCorrecting {nBad} points ({100 * nBad / nSelected:03.1f}%)")

###############################################################
### Correct those bad points
###############################################################
# At most one correction mode is applied; -cint takes priority over -clqf
if args.CINT:
    print(f"\n\nCorrection based on local interpolation (filterF = {args.FILTER_F})")
    # Phi of each bad point is interpolated from the good points; how F is
    # handled is controlled by interpolateF
    PhiFieldCorrected = spam.DIC.interpolatePhiField(
        inputNodePositions[inputGood],
        inputPhiField[inputGood],
        inputNodePositions[inputBad],
        nNeighbours=args.NUMBER_OF_NEIGHBOURS,
        neighbourRadius=args.NEIGHBOUR_RADIUS,
        interpolateF=args.FILTER_F,
        nProcesses=args.PROCESSES,
        verbose=True,
    )
    outputPhiField[inputBad] = PhiFieldCorrected
    outputReturnStatus[inputBad] = 1
    if args.CLQF:
        print(
            "\tYou asked to correct with local QC fitting with -clqf, but only one correciton mode is supported"
        )

elif args.CLQF:
    if args.FILTER_F != "no":
        print(
            "WARNING: non-displacement quadratic coherency correction not implemented, only doing displacements, and returning F=eye(3)\n"
        )

    print("\n\nCorrection based on local quadratic coherency")
    # Displacement of each bad point from a quadratic fit over the good points
    dispLQC = spam.DIC.estimateDisplacementFromQuadraticFit(
        inputNodePositions[inputGood],
        inputDisplacements[inputGood],
        inputNodePositions[inputBad],
        neighbourRadius=args.NEIGHBOUR_RADIUS,
        nNeighbours=args.NUMBER_OF_NEIGHBOURS,
        nProcesses=args.PROCESSES,
        verbose=True,
    )
    # Start from an identity Phi so F = eye(3) AND the homogeneous last row is
    # [0, 0, 0, 1] (it would otherwise stay all-zero from the output
    # initialisation), then write the fitted displacements.
    outputPhiField[inputBad] = numpy.eye(4)
    outputPhiField[inputBad, 0:3, -1] = dispLQC
    outputReturnStatus[inputBad] = 1


if args.FILTER_MEDIAN:
    if discrete:
        print(
            "Median filter for discrete mode not implemented... does it even make sense?"
        )
    else:
        # Median-filter EVERY grid point, not only the bad ones.
        # NOTE(review): the filter reads inputPhiField (the raw input), and its
        # results overwrite outputPhiField, so corrections applied above are
        # discarded for the filtered components -- confirm this is intended.
        print("\nApplying median filter...")
        filterPointsRadius = int(args.FILTER_MEDIAN_RADIUS)
        windowSize = 2 * filterPointsRadius + 1

        # Ignored points become NaN so numpy.nanmedian skips them
        if args.MASK:
            inputPhiField[inputIgnore] = numpy.nan

        if args.FILTER_F == "rigid":
            print("Rigid mode not well defined for overall median filtering, exiting")
            exit()

        if args.FILTER_F == "all":
            # Filter the nine F components, column by column over the 3x3 block
            print("Filtering F components...")
            componentsF = [(row, col) for col in range(3) for row in range(3)]
            for step, (row, col) in enumerate(componentsF, start=1):
                print(f"\t{step}/9")
                outputPhiField[:, row, col] = scipy.ndimage.generic_filter(
                    inputPhiField[:, row, col].reshape(inputNodesDim),
                    numpy.nanmedian,
                    size=windowSize,
                ).ravel()

        if args.FILTER_F == "no":
            # F is not filtered: every Phi is reset to identity before its
            # translation is written below
            outputPhiField[:] = numpy.eye(4)

        # Translation column (Z, Y, X displacements)
        print("Filtering displacements...")
        for row in range(3):
            print(f"\t{row + 1}/3")
            outputPhiField[:, row, -1] = scipy.ndimage.generic_filter(
                inputPhiField[:, row, -1].reshape(inputNodesDim),
                numpy.nanmedian,
                size=windowSize,
            ).ravel()

        if args.MASK:
            outputPhiField[inputIgnore] = numpy.nan


# Outputs are:
# - TSV files
# - (optional) VTK files for visualisation
# - (optional) TIF files in the case of gridded data
if args.TSV:
    # The first column names the label (discrete) or the node (grid); all
    # other columns are shared by both layouts
    firstColumnName = "Label" if discrete else "NodeNumber"
    TSVheader = (
        firstColumnName
        + "\tZpos\tYpos\tXpos\tFzz\tFzy\tFzx\tZdisp\tFyz\tFyy\tFyx\tYdisp\tFxz\tFxy\tFxx\tXdisp\tpixelSearchCC\treturnStatus\terror\tdeltaPhiNorm\titerations"
    )

    columns = [numpy.arange(inputNodePositions.shape[0])]
    columns += [inputNodePositions[:, axis] for axis in range(3)]
    # Rows 0-2 of each Phi, row by row: three F components then the displacement
    columns += [outputPhiField[:, row, col] for row in range(3) for col in range(4)]
    columns += [
        outputPixelSearchCC,
        outputReturnStatus,
        outputError,
        outputDeltaPhiNorm,
        outputIterations,
    ]
    outMatrix = numpy.array(columns).T

    numpy.savetxt(
        f"{args.OUT_DIR}/{args.PREFIX}.tsv",
        outMatrix,
        fmt="%.7f",
        delimiter="\t",
        newline="\n",
        comments="",
        header=TSVheader,
    )

if args.TIFF:
    if grid:
        # One float32 TIFF per displacement component. Z is skipped when the
        # grid has a single node along Z (2D field).
        for axis, axisName in enumerate(("Z", "Y", "X")):
            if axis == 0 and inputNodesDim[0] == 1:
                continue
            tifffile.imwrite(
                f"{args.OUT_DIR}/{args.PREFIX}-{axisName}disp.tif",
                outputPhiField[:, axis, -1].astype("<f4").reshape(inputNodesDim),
            )
    # Discrete fields produce no TIFF output (automatic relabelling of grains
    # could be added here)


# Collect data into VTK output
if args.VTK:
    if grid:
        # Displacement vectors arranged on the (Z, Y, X) node grid
        gridDisplacements = outputPhiField[:, :-1, 3].reshape(
            (inputNodesDim[0], inputNodesDim[1], inputNodesDim[2], 3)
        )
        # Non-finite values are overwritten with 0 (crude, but keeps the VTK usable)
        gridDisplacements[~numpy.isfinite(gridDisplacements)] = 0
        cellData = {"displacements": gridDisplacements}

        # Each node becomes one cell. This is exact when node spacing equals
        # twice the half-window size and approximate otherwise (overlapping or
        # under-used windows).
        # HACK: assume the half-window size is half the node spacing. For an
        # axis with a single coordinate, that coordinate is used as its spacing.
        uniqueCoordinates = [numpy.unique(inputNodePositions[:, axis]) for axis in range(3)]
        nodeSpacing = numpy.array(
            [
                coords[1] - coords[0] if len(coords) > 1 else coords[0]
                for coords in uniqueCoordinates
            ]
        )
        HWS = nodeSpacing / 2
        spam.helpers.writeStructuredVTK(
            origin=inputNodePositions[0] - HWS,
            aspectRatio=nodeSpacing,
            cellData=cellData,
            fileName=f"{args.OUT_DIR}/{args.PREFIX}.vtk",
        )

    else:
        # One glyph per label centre. The slice is a view, so cleaning
        # non-finite values here also cleans outputPhiField in place.
        disp = outputPhiField[:, 0:3, -1]
        disp[~numpy.isfinite(disp)] = 0

        VTKglyphDict = {
            "displacements": outputPhiField[:, 0:3, -1],
            "mag(displacements)": numpy.linalg.norm(disp, axis=1),
        }

        spam.helpers.writeGlyphsVTK(
            inputNodePositions,
            VTKglyphDict,
            fileName=f"{args.OUT_DIR}/{args.PREFIX}.vtk",
        )
