#!python

# Copyright (C) 2020 SPAM Contributors
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, see <http://www.gnu.org/licenses/>.

import sys
import os

# pyQt
from PyQt5.QtWidgets import (
    QWidget,
    QLabel,
    QPushButton,
    QComboBox,
    QGridLayout,
    QFileDialog,
    QApplication,
)
from PyQt5.QtGui import QPixmap, QIcon, QFont, QPalette
from PyQt5.QtCore import Qt

# from PyQt5.QtGui import *
# from PyQt5.QtWidgets import *

# science
import numpy
import tifffile

# spam
import spam.visual.visualClass as visual

# Selects which image is updated in the eye registration, crop and joint
# histogram steps; it is passed as ``imUpdate`` to visual.ereg,
# visual.QtCropWidget and visual.JHist by MainWindow.
#   0 -> im1 is updated with Phi
#   1 -> im2 is updated with PhiInv
#   other values: not handled by this script -- you are on your own
#   (NOTE(review): actual behaviour depends on spam.visual.visualClass).
imUpdate = 0


class MainWindow(QWidget):
    """Main window of the SPAM multimodal registration GUI.

    The window shows one step at a time in ``self.mainWindowGrid``;
    "Next Step"/"Back" buttons swap the current step widget for the
    next/previous one:

    1. welcome page with the binning selection (built in ``__init__``),
    2. eye registration (``visual.ereg``, started by ``startEreg``),
    3. cropping (``visual.QtCropWidget``),
    4. joint histogram (``visual.JHist``),
    5. phase diagram (``visual.PhaseDiagram``),
    6. final step (``visual.FinalStep``).

    Up to four files can be given on the command line
    (``im1.tif im2.tif phi.tsv crop.tsv``); any that is not an existing file is
    asked for through a file dialog when ``startEreg`` runs.
    """

    def __init__(self):
        """Build the welcome page: tutorial, summary of the command line,
        warnings about missing files and the binning selector."""
        print("[SPAM MMR] Welcome to SPAM Multimodal Registration")

        QWidget.__init__(self)
        self.mainWindowGrid = QGridLayout(self)
        # first page (welcome text + binning selection), closed in startEreg
        self.binningSelection = QWidget()
        self.mainWindowGrid.addWidget(self.binningSelection, 1, 1)
        grid = QGridLayout(self.binningSelection)

        tab, tab2 = " " * 4, " " * 12

        # number of files to load from the command line
        nLoads = len(sys.argv) - 1

        # directory of the first command line file, if any
        self.rootDirectory = (
            os.path.dirname(os.path.realpath(sys.argv[1])) if nLoads else None
        )

        # icon and logo are looked up relative to this script
        imgDirectory = os.path.join(
            os.path.abspath(os.path.dirname(__file__)), "../share/img"
        )
        iconUrl = os.path.join(imgDirectory, "icon.png")
        logoUrl = os.path.join(imgDirectory, "logo.png")

        # resolved once here so that the os.chdir() done in startEreg does not
        # change which files the command line arguments refer to
        self.absoluteArgv = [os.path.realpath(f) for f in sys.argv]

        # logo + welcome
        labelLogo = QLabel()
        labelLogo.setPixmap(QPixmap(logoUrl).scaled(150, 150))
        grid.addWidget(labelLogo, 0, 1, 1, 1)

        label = QLabel("Welcome to SPAM\nMultimodal Registration")
        label.setAlignment(Qt.AlignCenter)
        # default text colour, handed to the joint histogram widget in endCrop
        self.fontColor = label.palette().color(QPalette.WindowText).name()
        label.setStyleSheet("QLabel {font-weight: bold;}")
        grid.addWidget(label, 0, 2, 1, 2)

        # tutorial
        label = QLabel("Tutorial")
        label.setStyleSheet("QLabel {font-weight: bold; margin-top: 20px;}")
        grid.addWidget(label, 1, 1, 1, 3)
        tutorialLines = [
            "{tab}o You can load 4 files:\n{tab2}o Image 1 and Image 2 (tif, mandatory - Gradients computed in im2)\n{tab2}o Phi file and Crop file (tsv, optional)".format(
                tab=tab, tab2=tab2
            ),
            "{tab}o Phi (im1 -> im2) and Crop files are saved configurations of previous registrations.\n{tab}{tab}They are automatically adapted to the binning level.".format(
                tab=tab
            ),
            "{tab}o The first image loaded is the one that is going to be deformed (i.e. interpolated).".format(
                tab=tab
            ),
            "{tab}o You can load the files directly with the command line:\n\n{tab2}spam-mmr-graphical im1.tif im2.tif phi.tsv crop.tsv\n".format(
                tab=tab, tab2=tab2
            ),
            "{tab}o If the files are not loaded with the command line a file manager will pop up\n{tab}{tab}after choosing the binning level.".format(
                tab=tab
            ),
        ]
        grid.addWidget(QLabel("\n".join(tutorialLines)), 2, 1, 1, 3)

        # configuration (what was given on the command line)
        label = QLabel("Configuration")
        label.setStyleSheet("QLabel {font-weight: bold; margin-top: 20px;}")
        grid.addWidget(label, 3, 1, 1, 3)

        commandLine = " ".join([os.path.basename(sys.argv[0])] + sys.argv[1:])

        configLines = [
            "{tab}o You launched the command line:\n\n{tab2}{com}\n".format(
                tab=tab, tab2=tab2, com=commandLine
            )
        ]
        warningLines = []
        if self.rootDirectory is None:
            configLines.append(
                "{tab}o All your data will be saved in the directory of the first loaded image".format(
                    tab=tab
                )
            )
        else:
            configLines.append(
                "{tab}o All your data will be loaded and saved from the root directory:\n\n{tab2}{dir}\n".format(
                    tab=tab, tab2=tab2, dir=self.rootDirectory
                )
            )

        if nLoads == 0:
            configLines.append(
                "{tab}o You are not going to load anything from the command line arguments.".format(
                    tab=tab
                )
            )
            imName = ""
            print("[SPAM MMR] No file to load from command line")
        else:
            configLines.append(
                "{tab}o From the command line arguments you are going to load:".format(
                    tab=tab
                )
            )
            imName = sys.argv[1]
            print(
                "[SPAM MMR] Trying to load file(s): {file}".format(
                    file=", ".join(sys.argv[1:5])
                )
            )
            # only the first 4 arguments are used: image 1, image 2, Phi, crop
            fileLabels = ("Image 1", "Image 2", "Phi", "Crop")
            for i, (fileLabel, fileName) in enumerate(
                zip(fileLabels, sys.argv[1:5]), start=1
            ):
                configLines.append(
                    "{tab2}o {label}: {file}".format(
                        tab2=tab2, label=fileLabel, file=fileName
                    )
                )
                if not os.path.isfile(fileName):
                    warningLines.append(
                        "File {i}: {file} does not exist, you'll be asked to load it from the file manager.".format(
                            i=i, file=fileName
                        )
                    )
        if nLoads > 4:
            ignored = ", ".join(sys.argv[5:])
            warningLines.append("The file(s) {file} are ignored.".format(file=ignored))
            print("[SPAM MMR] File(s) ignored: {file}".format(file=ignored))

        grid.addWidget(QLabel("\n".join(configLines)), 4, 1, 1, 3)

        if warningLines:
            label = QLabel("\n".join(warningLines))
            label.setStyleSheet("QLabel {margin-top: 20px; color: orange;}")
            grid.addWidget(label, 5, 1, 1, 3)

        # binning selection
        labelBin = QLabel("\nSelect the binning of the images you're going to load:\n")
        grid.addWidget(labelBin, 6, 1, 1, 2)

        self.listOfBinning = QComboBox()
        bins = [1, 2, 4, 8, 16]
        self.listOfBinning.addItems([str(b) for b in bins])
        # preselect the level found in the image 1 path ("bin4", "bin-4",
        # "bin_4", ...); bins are scanned in increasing order so the last match
        # wins (e.g. "bin16" also matches "bin1" but ends up selecting 16)
        normalisedName = imName.lower().replace("-", "").replace("_", "")
        for i, b in enumerate(bins):
            if "bin{}".format(b) in normalisedName:
                print(
                    "[SPAM MMR] From {}, I'm guessing you work at bin level {}.".format(
                        imName, b
                    )
                )
                self.listOfBinning.setCurrentIndex(i)
        grid.addWidget(self.listOfBinning, 6, 3, 1, 1)

        # next step
        self.startEregButton = QPushButton("Next Step (Eye Registration)", self)
        self.startEregButton.clicked.connect(self.startEreg)
        self.mainWindowGrid.addWidget(self.startEregButton, 7, 1, 1, 1)

        # set icon
        self.setWindowIcon(QIcon(iconUrl))

    def startEreg(self):
        """Load the data at the selected binning and show the eye-registration step.

        Images 1 and 2 (TIFF), Phi and crop (TSV) are taken from the command
        line when the corresponding argument is an existing file, otherwise
        from a file dialog. A missing/unreadable image, or two images of
        different shapes, ends the program. A missing/unreadable Phi falls back
        to the identity, and a missing/unreadable crop to the whole image minus
        a one-voxel border. Phi translations and crop bounds are rescaled from
        the bin level stored in their file (``bin`` column, defaulting to the
        selected one) to the selected binning.
        """
        self.binning = int(self.listOfBinning.currentText())
        print("[SPAM MMR] Working at bin level: {}".format(self.binning))

        self.images = []

        # LOAD IMAGE 1: command line if the file exists, file manager otherwise
        if len(self.absoluteArgv) > 1 and os.path.isfile(self.absoluteArgv[1]):
            self.fileName1 = self.absoluteArgv[1]
            self.images.append(tifffile.imread(self.fileName1))
        else:
            try:
                self.fileName1 = QFileDialog.getOpenFileName(
                    None, "Open Image 1", os.getcwd()
                )[0]
                self.images.append(tifffile.imread(self.fileName1))
                # later file dialogs (they use os.getcwd()) open next to image 1
                os.chdir(os.path.dirname(os.path.realpath(self.fileName1)))
            except Exception as e:
                print("[SPAM MMR] ERROR: {}".format(e))
                print("[SPAM MMR] You need to load image 1... exiting")
                sys.exit()

        print(
            "[SPAM MMR] Loading image 1: {} of size {}".format(
                self.fileName1, self.images[0].shape
            )
        )

        # LOAD IMAGE 2: command line if the file exists, file manager otherwise
        if len(self.absoluteArgv) > 2 and os.path.isfile(self.absoluteArgv[2]):
            self.fileName2 = self.absoluteArgv[2]
            print("[SPAM MMR] Loading {} from command line".format(self.fileName2))
            self.images.append(tifffile.imread(self.fileName2))
        else:
            try:
                self.fileName2 = QFileDialog.getOpenFileName(
                    None, "Open Image 2", os.getcwd()
                )[0]
                self.images.append(tifffile.imread(self.fileName2))
            except Exception as e:
                print("[SPAM MMR] ERROR: {}".format(e))
                print("[SPAM MMR] You need to load image 2... exiting")
                sys.exit()

        print(
            "[SPAM MMR] Loading image 2: {} of size {}".format(
                self.fileName2, self.images[1].shape
            )
        )

        if self.images[0].shape != self.images[1].shape:
            print("[SPAM MMR] ERROR: images of different sizes")
            print("[SPAM MMR] The two images need to have the same size... exiting")
            sys.exit()

        # LOAD PHI: command line if the file exists, (optional) dialog otherwise
        phiTable = None
        try:
            if len(self.absoluteArgv) > 3 and os.path.isfile(self.absoluteArgv[3]):
                phiFileName = self.absoluteArgv[3]
            else:
                phiFileName = QFileDialog.getOpenFileName(
                    None, "(optional) Open Phi TSV", os.getcwd()
                )[0]
            if phiFileName:
                self.Phi, phiTable = self._readPhiTSV(phiFileName)
                print("[SPAM MMR] Loaded Phi file: {}".format(phiFileName))
            else:
                # dialog cancelled
                print("[SPAM MMR] Taking Phi as Identity Matrix")
                self.Phi = numpy.eye(4, 4)
        except Exception as e:
            self.Phi = numpy.eye(4, 4)
            print("[SPAM MMR] WARNING: {}".format(e))
            print(
                "[SPAM MMR] Could not load Phi file, so taking Phi as Identity Matrix"
            )

        # the Phi translation is expressed at the bin level of its file
        self.phiBin = int(self.binning)
        if phiTable is not None:
            try:
                self.phiBin = int(phiTable["bin"])
            except Exception as e:
                print("[SPAM MMR] WARNING: {}".format(e))
                print(
                    "[SPAM MMR] Phi file might come from another program and binning level is not recognised (default value taken)"
                )
            print("[SPAM MMR] Phi file at bin level: {}".format(self.phiBin))
        self.Phi[0:3, -1] *= self.phiBin / self.binning

        # LOAD CROP: command line if the file exists, (optional) dialog otherwise
        try:
            if len(self.absoluteArgv) > 4 and os.path.isfile(self.absoluteArgv[4]):
                cropFileName = self.absoluteArgv[4]
            else:
                cropFileName = QFileDialog.getOpenFileName(
                    None, "(optional) Open crop TSV", os.getcwd()
                )[0]
            if cropFileName:
                print("[SPAM MMR] Loading crop file: {}".format(cropFileName))
                self.crop = self._readCropTSV(cropFileName)
            else:
                # dialog cancelled
                print("[SPAM MMR] No crop file selected (taking full image)")
                self.crop = self._fullCrop()
        except Exception as e:
            print("[SPAM MMR] WARNING: {}".format(e))
            print("[SPAM MMR] Crop file not recognised (taking full image)")
            self.crop = self._fullCrop()

        # make sure the crop can actually index the image
        try:
            _ = self.images[0][self.crop]
        except Exception as e:
            print("[SPAM MMR] WARNING: {}".format(e))
            print("[SPAM MMR] Fail to crop the image (taking full image)")
            self.crop = self._fullCrop()

        # clamp the crop end of each axis to the image size along that axis
        shape = self.images[0].shape
        self.crop = tuple(
            slice(self.crop[axis].start, min(self.crop[axis].stop, shape[axis]))
            for axis in range(3)
        )
        for axis, axisName in enumerate("zyx"):
            print(
                "[SPAM MMR] Cropping image from {} < {} < {}".format(
                    self.crop[axis].start, axisName, self.crop[axis].stop
                )
            )

        self.binningSelection.close()
        self.startEregButton.close()
        self._showEreg()

    def endEreg(self):
        """Leave eye registration (keeping its Phi) and show the cropping step."""
        self.Phi = self.eregWidget.output()
        self.eregWidget.close()
        self.endEregButton.close()

        del self.eregWidget
        del self.endEregButton

        self._showCrop()

    def backEreg(self):
        """Leave cropping (keeping its crop) and go back to eye registration."""
        self.crop = self.cropWidget.output()
        self.cropWidget.close()
        self.returnEregButton.close()
        self.endCropButton.close()

        del self.cropWidget
        del self.returnEregButton
        del self.endCropButton

        self._showEreg()

    def endCrop(self):
        """Leave cropping (keeping its crop) and show the joint histogram step."""
        self.crop = self.cropWidget.output()
        self.cropWidget.close()
        self.endCropButton.close()
        self.returnEregButton.close()
        passImages = self._copyImages()
        self.JoinHist = visual.JHist(
            passImages,
            self.Phi,
            self.crop,
            self.fontColor,
            self._imageNames(),
            imUpdate,
        )
        # 2022-02-22 EA and AT: Issue #222
        del passImages

        self.mainWindowGrid.addWidget(self.JoinHist, 1, 1, 1, 2)
        self._addJHistButtons()

    def backCrop(self):
        """Leave the joint histogram and go back to the cropping step."""
        self.returnCropButton.close()
        self.JoinHist.close()
        self.endJoinHistButton.close()
        self._showCrop()

    def endJHist(self):
        """Show the phase diagram step.

        Does nothing while the joint histogram step has no Gaussian
        parameters (``JoinHist.gaussianParameters`` is empty).
        """
        if len(self.JoinHist.gaussianParameters) == 0:
            return
        self.returnCropButton.close()
        self.endJoinHistButton.close()
        gaussianParameters = self.JoinHist.gaussianParameters
        bins = self.JoinHist.BINS
        jointHistogram = self.JoinHist.hist
        # hidden rather than closed: backJHist shows it again
        self.JoinHist.hide()
        self.phasePhase = visual.PhaseDiagram(
            gaussianParameters,
            bins,
            jointHistogram,
            self._imageNames(),
        )
        # 2022-02-22 EA and AT: Dump current fitting for eventual passing into spam-mmr

        self.mainWindowGrid.addWidget(self.phasePhase, 1, 1, 1, 2)
        self._addPhaseDiagramButtons()

    def backJHist(self):
        """Leave the phase diagram and show the (hidden) joint histogram again."""
        self.phasePhase.close()
        self.returnJoinHist.close()
        self.endPhaseDiagramButton.close()
        self.JoinHist.show()
        self.mainWindowGrid.addWidget(self.JoinHist, 1, 1, 1, 2)
        self._addJHistButtons()

    def endPhaseDiagram(self):
        """Leave the phase diagram and show the final step."""
        self.endPhaseDiagramButton.close()
        self.returnJoinHist.close()
        # hidden rather than closed: backPhaseDiagram shows it again
        self.phasePhase.hide()
        passImages = self._copyImages()
        bins = self.phasePhase.BINS
        gaussianParameters = self.phasePhase.gaussianParameters
        phaseDiagram = self.phasePhase.phase
        greyLimitsCrop = self.JoinHist.greyLimitsCrop
        self.returnPhaseDiagram = QPushButton("Back (Phase Diagram)")
        self.returnPhaseDiagram.clicked.connect(self.backPhaseDiagram)
        self.finalStep = visual.FinalStep(
            passImages,
            self.crop,
            phaseDiagram,
            self.Phi,
            gaussianParameters,
            bins,
            greyLimitsCrop,
            self.binning,
            self.fileName1,
            self._imageNames(),
        )
        # 2022-02-22 EA and AT: Issue #222
        del passImages

        self.mainWindowGrid.addWidget(self.returnPhaseDiagram, 2, 1)
        self.mainWindowGrid.addWidget(self.finalStep, 1, 1, 1, 2)

    def backPhaseDiagram(self):
        """Leave the final step and show the (hidden) phase diagram again."""
        self.returnPhaseDiagram.close()
        self.finalStep.close()
        self.phasePhase.show()
        self._addPhaseDiagramButtons()

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------

    def _imageNames(self):
        """Return the base names of the two loaded image files."""
        return [os.path.basename(self.fileName1), os.path.basename(self.fileName2)]

    def _copyImages(self):
        """Return independent copies of the two loaded images for a step widget."""
        return [numpy.copy(self.images[0]), numpy.copy(self.images[1])]

    def _fullCrop(self):
        """Return the default crop: image 1 minus a one-voxel border on each axis."""
        return tuple(slice(1, self.images[0].shape[axis] - 1) for axis in range(3))

    @staticmethod
    def _readPhiTSV(fileName):
        """Read a Phi TSV file.

        The 3x3 deformation part is read from the ``Fzz``..``Fxx`` columns or,
        if those are not all present, from the ``F11``..``F33`` columns; the
        translation is read from ``Zdisp``, ``Ydisp`` and ``Xdisp``.

        Returns:
            ``(Phi, table)``: the 4x4 float Phi matrix and the table returned by
            ``numpy.genfromtxt`` (the caller reads its optional ``bin`` column).

        Raises:
            Whatever ``numpy.genfromtxt`` raises for an unreadable file, or the
            error of the last column lookup if neither convention matches.
        """
        table = numpy.genfromtxt(fileName, delimiter="\t", names=True)
        # deformation columns, row by row, for the two accepted conventions
        conventions = (
            ("Fzz", "Fzy", "Fzx", "Fyz", "Fyy", "Fyx", "Fxz", "Fxy", "Fxx"),
            ("F11", "F12", "F13", "F21", "F22", "F23", "F31", "F32", "F33"),
        )
        lastError = None
        for keys in conventions:
            try:
                F = [float(table[key]) for key in keys]
                translation = [
                    float(table[key]) for key in ("Zdisp", "Ydisp", "Xdisp")
                ]
            except Exception as e:  # columns of this convention not found
                lastError = e
                continue
            Phi = numpy.eye(4, 4)
            Phi[:3, :3] = numpy.reshape(F, (3, 3))
            Phi[:3, 3] = translation
            return Phi, table
        raise lastError

    def _readCropTSV(self, fileName):
        """Read a crop TSV file and return it as 3 slices at the current binning.

        Bounds come from the ``Zs``/``Ze``, ``Ys``/``Ye``, ``Xs``/``Xe`` columns
        and are rescaled by ``self.cropBin / self.binning``, where
        ``self.cropBin`` (set here) is the ``bin`` column, or the current
        binning if that column cannot be read.
        """
        table = numpy.genfromtxt(fileName, delimiter="\t", names=True)
        try:
            self.cropBin = int(table["bin"])
        except Exception as e:
            self.cropBin = int(self.binning)
            print("[SPAM MMR] WARNING: {}".format(e))
            print(
                "[SPAM MMR] Crop file might come from another program and binning level is not recognised (default value taken)"
            )
        print("[SPAM MMR] Crop file at bin level: {}".format(self.cropBin))

        def rescale(key):
            return int(int(table[key]) * self.cropBin / self.binning)

        return tuple(
            slice(rescale(start), rescale(stop))
            for start, stop in (("Zs", "Ze"), ("Ys", "Ye"), ("Xs", "Xe"))
        )

    def _showEreg(self):
        """Create and show the eye-registration widget and its "Next" button."""
        passImages = self._copyImages()
        self.eregWidget = visual.ereg(
            passImages,
            self.Phi,
            self._imageNames(),
            binning=self.binning,
            imUpdate=imUpdate,
        )
        # 2022-02-22 EA and AT: Issue #222
        del passImages

        self.mainWindowGrid.addWidget(self.eregWidget, 1, 1)
        self.endEregButton = QPushButton("Next Step (Cropping)", self)
        self.endEregButton.clicked.connect(self.endEreg)
        self.mainWindowGrid.addWidget(self.endEregButton, 2, 1)

    def _showCrop(self):
        """Create and show the cropping widget and its "Back"/"Next" buttons."""
        passImages = self._copyImages()
        self.cropWidget = visual.QtCropWidget(
            passImages,
            self.Phi,
            self.crop,
            self._imageNames(),
            binning=self.binning,
            imUpdate=imUpdate,
        )
        # 2022-02-22 EA and AT: Issue #222
        del passImages

        self.mainWindowGrid.addWidget(self.cropWidget, 1, 1, 1, 2)
        self.returnEregButton = QPushButton("Back (Eye Registration)")
        self.returnEregButton.clicked.connect(self.backEreg)
        self.mainWindowGrid.addWidget(self.returnEregButton, 2, 1)
        self.endCropButton = QPushButton("Next Step (Joint Histogram)")
        self.endCropButton.clicked.connect(self.endCrop)
        self.mainWindowGrid.addWidget(self.endCropButton, 2, 2)

    def _addJHistButtons(self):
        """Add the "Back"/"Next" buttons of the joint histogram step."""
        self.returnCropButton = QPushButton("Back (Cropping)")
        self.returnCropButton.clicked.connect(self.backCrop)
        self.mainWindowGrid.addWidget(self.returnCropButton, 2, 1)
        self.endJoinHistButton = QPushButton("Next Step (Phase Diagram)")
        self.endJoinHistButton.clicked.connect(self.endJHist)
        self.mainWindowGrid.addWidget(self.endJoinHistButton, 2, 2)

    def _addPhaseDiagramButtons(self):
        """Add the "Back"/"Next" buttons of the phase diagram step."""
        self.returnJoinHist = QPushButton("Back (Joint Histogram)")
        self.returnJoinHist.clicked.connect(self.backJHist)
        self.mainWindowGrid.addWidget(self.returnJoinHist, 2, 1)
        self.endPhaseDiagramButton = QPushButton("Next Step (Final Step)")
        self.endPhaseDiagramButton.clicked.connect(self.endPhaseDiagram)
        self.mainWindowGrid.addWidget(self.endPhaseDiagramButton, 2, 2)


def main():
    """Start the GUI: build the Qt application, show the main window and run
    the event loop; the process exits with the event loop's return code."""
    qtApplication = QApplication(["Multimodal Registration"])
    registrationWindow = MainWindow()
    registrationWindow.show()
    returnCode = qtApplication.exec_()
    sys.exit(returnCode)


if __name__ == "__main__":
    main()
