#!python


# This Python script facilitates eye-alignment, through a graphical Qt interface,
# for Discrete Digital Image Correlation using SPAM functions
# Copyright (C) 2020 SPAM Contributors
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

import spam.visual.visualClass as visual

import os
import sys
import subprocess
import spam.helpers
import spam.label
import numpy
numpy.seterr(all='ignore')
import tifffile
import argparse

from PyQt5.QtWidgets import QApplication, QWidget, QFileDialog, QGridLayout, QPushButton

# Define argument parser object
parser = argparse.ArgumentParser(description="spam-ereg-discrete "+spam.helpers.optionsParser.GLPv3descriptionHeader +\
                                             "This script facilitates eye-alignment for Discrete Digital Image Correlation two 3D greyscale images"+\
                                             " (reference and deformed configurations) and requires the input of a labelled image for the reference configuration",
                                 formatter_class=argparse.RawTextHelpFormatter)

# Parse arguments with external helper function
args = spam.helpers.optionsParser.eregDiscreteParser(parser)

print("spam-ereg-discrete: Current Settings:")
argsDict = vars(args)
for key in sorted(argsDict):
    print("\t{}: {}".format(key, argsDict[key]))

# os.path.join (rather than "/" concatenation) avoids producing a path at the
# filesystem root if OUT_DIR happens to be an empty string
outFile = os.path.join(args.OUT_DIR, args.PREFIX+"-ereg-discrete.tsv")

REFlab    = tifffile.imread(args.lab1.name)
REFlabBB  = spam.label.boundingBoxes(REFlab)
REFlabCOM = spam.label.centresOfMass(REFlab)

REFgrey   = tifffile.imread(args.im1.name)
DEFgrey   = tifffile.imread(args.im2.name)

# Some variable for nice filenames
REFstr    = os.path.basename(args.im1.name)
DEFstr    = os.path.basename(args.im2.name)

if args.PHIFILE is not None:
    DDIC = spam.helpers.readCorrelationTSV(args.PHIFILE.name, readConvergence=True)

    # Parse the raw table only once to pick up the columns read directly here
    tsvTable   = numpy.genfromtxt(args.PHIFILE.name, delimiter='\t', names=True)
    tsvColumns = tsvTable.dtype.names
    DDIC['error'] = tsvTable['error']
    # These columns may be absent from some TSV files: fall back to zeros
    for column in ['LabelDilate', 'PSCC']:
        if column in tsvColumns:
            DDIC[column] = tsvTable[column]
        else:
            DDIC[column] = numpy.zeros(DDIC['error'].shape[0])

    # 09-12-21 GP: Adding the mask for converged particles
    if args.MASK_CONV:
        print('Moving image')
        REFlabMoved = spam.label.moveLabels(REFlab,
                                            DDIC['PhiField'],
                                            returnStatus=DDIC['returnStatus'])
        DEFgrey = numpy.where(REFlabMoved > 0, 0, DEFgrey)

else:
    # No Phi file given: start every label from the identity transformation.
    # Labels are numbered from 0 (background) up to REFlab.max(), so arrays
    # indexed by label need REFlab.max() + 1 rows.
    numberOfLabels = int(REFlab.max()) + 1
    DDIC = {}
    # Needed when writing the output TSV
    DDIC['numberOfLabels'] = numberOfLabels
    DDIC['fieldCoords']    = REFlabCOM
    DDIC['PhiField']       = numpy.tile(numpy.eye(4), (numberOfLabels, 1, 1))
    DDIC['returnStatus']   = numpy.zeros(numberOfLabels, dtype=int)
    DDIC['deltaPhiNorm']   = numpy.zeros(numberOfLabels)
    DDIC['iterations']     = numpy.zeros(numberOfLabels, dtype=int)
    DDIC['error']          = numpy.zeros(numberOfLabels)
    DDIC['LabelDilate']    = numpy.zeros(numberOfLabels, dtype=int)
    DDIC['PSCC']           = numpy.zeros(numberOfLabels)

class MainWindow(QWidget):
    """
    Qt window that steps through the non-converged labels one at a time and
    lets the user eye-align each of them with ``visual.ereg``.

    The labels processed are those whose ``DDIC['returnStatus']`` is below
    ``args.RETURN_STATUS_THRESHOLD`` and whose bounding box in ``REFlabBB`` is
    non-degenerate along the first axis. Once the last label has been handled,
    the updated Phi field is written to ``outFile`` and the window closes.

    Relies on the module-level globals ``DDIC``, ``REFlab``, ``REFlabBB``,
    ``REFlabCOM``, ``REFgrey``, ``DEFgrey``, ``REFstr``, ``DEFstr``, ``args``
    and ``outFile``.
    """

    def __init__(self):
        QWidget.__init__(self)
        self.Phi = numpy.eye(4)
        self.mainWindowGrid = QGridLayout(self)

        # Created lazily by alignOneLabel(); the button is created only once
        # and then reused for every label.
        self.eregWidget = None
        self.nextLabelButton = None

        # Issue #192: only keep low-returnStatus labels that also have a real
        # bounding box (i.e., labels that are really present in the image).
        nTSV = len(DDIC['returnStatus'])
        nImage = REFlabBB.shape[0]
        # In the unlucky case that the TSV has higher-numbered labels than the
        # labelled image, those labels cannot be aligned: warn and ignore them.
        if nTSV > nImage:
            print("Warning: there are higher-numbered labels in your TSV file that are not in the labelled image, discarding them")
        # Compare both arrays over their common length so logical_and cannot
        # fail on mismatched shapes
        nLabels = min(nTSV, nImage)
        nonConvergedGrains = DDIC['returnStatus'][0:nLabels] < args.RETURN_STATUS_THRESHOLD
        presentGrains      = REFlabBB[0:nLabels, 1] > REFlabBB[0:nLabels, 0]

        print(numpy.where(nonConvergedGrains), presentGrains)
        self.nonConvergedGrains = numpy.where(numpy.logical_and(nonConvergedGrains, presentGrains))[0]

        self.N = 0 # Index into self.nonConvergedGrains of the label currently being studied
        print("Going to work on these labels:\n", self.nonConvergedGrains, "(p.s. I removed non-existent labels:", numpy.where(~presentGrains)[0][1:]," )")
        if len(self.nonConvergedGrains) > 0:
            self.labAndPhi = []
            self.labelExists = False
            self.alignOneLabel()
        else:
            print("No labels to work on")
            sys.exit()

    def alignOneLabel(self):
        """
        Show the eye-alignment widget for label ``self.nonConvergedGrains[self.N]``.

        The integer part of the label's current displacement is removed from
        (a copy of) its Phi and is instead used to shift the deformed
        subvolume, so the widget only sees the remaining fractional part;
        nextLabel() adds the integer part back. If spam.label.getLabel()
        returns None, the label is skipped through nextLabel().
        """
        nonConvergedGrain = self.nonConvergedGrains[self.N]

        print("\tGrain {}".format(nonConvergedGrain))
        print("\t\tPosition in reference image: {}".format(REFlabCOM[nonConvergedGrain]))

        # Copy, so that DDIC['PhiField'] is only modified at the output stage
        Phi = DDIC['PhiField'][nonConvergedGrain].copy()

        displacement = Phi[0:3, -1]
        displacementInt = displacement.astype(int)
        self.displacementInt = displacementInt
        # Remove the int part of displacement
        Phi[0:3, -1] -= displacementInt
        print("\t\tSubtracted this displacement:", displacementInt)

        REFgl = spam.label.getLabel(REFlab, nonConvergedGrain,
                                    boundingBoxes=REFlabBB, centresOfMass=REFlabCOM,
                                    labelDilate=args.LABEL_DILATE, margin=args.margin,
                                    maskOtherLabels=args.MASK)

        if REFgl is None:
            self.labelExists = False
            self.nextLabel()
            return

        self.labelExists = True
        # Bounding boxes are min/max pairs; +1 on each max gives exclusive slice ends
        sliceEnds = numpy.array([0, 1, 0, 1, 0, 1])
        # 2020-10-23: EA on Issue #186: using spam.helpers.slicePadded
        REFsubvol = spam.helpers.slicePadded(REFgrey, REFgl['boundingBox'] + sliceEnds)

        if args.MASK:
            # If mask asked, also flatten greylevels
            REFsubvol[REFgl['subvol'] == 0] = 0

        # The deformed subvolume is taken at the same box shifted by the integer displacement
        shift = numpy.repeat(displacementInt, 2)
        # 2020-10-23: EA on Issue #186: using spam.helpers.slicePadded
        DEFsubvol = spam.helpers.slicePadded(DEFgrey, REFgl['boundingBox'] + sliceEnds + shift)

        self.eregWidget = visual.ereg(  [REFsubvol, DEFsubvol],
                                        Phi,
                                        [f"{REFstr} - label {nonConvergedGrain}",
                                         f"{DEFstr} - label {nonConvergedGrain}"],
                                        binning=1,
                                        imUpdate=0)
        self.mainWindowGrid.addWidget(self.eregWidget, 1, 1)

        # Create the button only once instead of stacking a new one per label
        if self.nextLabelButton is None:
            self.nextLabelButton = QPushButton("Accept and move on to next grain", self)
            self.nextLabelButton.clicked.connect(self.nextLabel)
            self.mainWindowGrid.addWidget(self.nextLabelButton, 2, 1)

    def nextLabel(self):
        """
        Record the result for the current label and move on to the next one.

        For a displayed label, the Phi returned by the widget (with the integer
        displacement added back) is stored. For a skipped label, the identity
        is stored. After the last label, every stored Phi is written into
        ``DDIC['PhiField']``, the TSV is saved to ``outFile`` and the window is
        closed.
        """
        if self.labelExists:
            self.eregWidget.close()

            # Get Phi output from graphical
            PhiTmp = self.eregWidget.output()
            # Add back in int displacement
            PhiTmp[0:3, -1] += self.displacementInt
            #                       nonConvergedGrain label number, eye-Phi
            self.labAndPhi.append([self.nonConvergedGrains[self.N], PhiTmp])
            print("nextLabel: I accepted a Phi for label {}".format([self.nonConvergedGrains[self.N]]))

            # Release the finished widget (and the subvolumes it holds)
            # instead of accumulating hidden widgets in the layout
            self.mainWindowGrid.removeWidget(self.eregWidget)
            self.eregWidget.deleteLater()
            self.eregWidget = None
        else:
            print("nextLabel: I skipped label {}".format([self.nonConvergedGrains[self.N]]))
            # This grain was skipped: store the identity transformation in its place
            self.labAndPhi.append([self.nonConvergedGrains[self.N], numpy.eye(4)])

        # Move onto next grain, otherwise write and quit
        self.N += 1
        if self.N < len(self.nonConvergedGrains):
            self.alignOneLabel()
            return

        # The widgets may never have been created if every label was skipped
        if self.nextLabelButton is not None:
            self.nextLabelButton.close()
        if self.eregWidget is not None:
            self.eregWidget.close()

        print("Updating output...")
        for nonConvergedGrain, Phi in self.labAndPhi:
            DDIC['PhiField'][nonConvergedGrain] = Phi

        print("Writing output to {}...".format(outFile), end='')
        outMatrix = numpy.array([numpy.arange(DDIC['numberOfLabels']),
                                 DDIC['fieldCoords'][:, 0], DDIC['fieldCoords'][:, 1], DDIC['fieldCoords'][:, 2],
                                 DDIC['PhiField'][:, 0, 3], DDIC['PhiField'][:, 1, 3], DDIC['PhiField'][:, 2, 3],
                                 DDIC['PhiField'][:, 0, 0], DDIC['PhiField'][:, 0, 1], DDIC['PhiField'][:, 0, 2],
                                 DDIC['PhiField'][:, 1, 0], DDIC['PhiField'][:, 1, 1], DDIC['PhiField'][:, 1, 2],
                                 DDIC['PhiField'][:, 2, 0], DDIC['PhiField'][:, 2, 1], DDIC['PhiField'][:, 2, 2],
                                 DDIC['error'], DDIC['iterations'],
                                 DDIC['returnStatus'], DDIC['deltaPhiNorm'],
                                 DDIC['LabelDilate'], DDIC['PSCC']]).T

        numpy.savetxt(outFile,
                      outMatrix,
                      fmt='%.7f',
                      delimiter='\t',
                      newline='\n',
                      comments='',
                      header="Label\tZpos\tYpos\tXpos\t" +
                      "Zdisp\tYdisp\tXdisp\t" +
                      "Fzz\tFzy\tFzx\t" +
                      "Fyz\tFyy\tFyx\t" +
                      "Fxz\tFxy\tFxx\t" +
                      "error\titerations\treturnStatus\tdeltaPhiNorm\tLabelDilate\tPSCC")
        print("...done")
        self.close()

# Start the Qt application and pass the event loop's exit status back to the shell
app = QApplication(["Label Registration"])
window = MainWindow()
window.show()
sys.exit(app.exec_())
