#!/usr/bin/python3
# -*- coding: utf-8 -*-

"""package scinstr-bin
author    Benoit Dubois
copyright FEMTO ENGINEERING, 2022
license   GPL v3.0+
brief     Acquire data trace from N5234A or N5230A device.
          Try to detrend trace in order to extract mode shape.
"""

# Ctrl-c closes the application
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)

import sys
import os.path as path
import logging
from pyqtgraph.parametertree import Parameter, ParameterTree
import pyqtgraph as pg
import numpy as np
import numpy.polynomial.polynomial as poly
import scipy.signal as scs
from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QVBoxLayout, \
    QHBoxLayout, QMessageBox, QFileDialog
from PyQt5.QtCore import pyqtSlot, pyqtSignal, QDir, QFileInfo

import scinstr.vna.n523x as vna

# Level threshold of the console (stderr) log handler.
CONSOLE_LOG_LEVEL: int = logging.DEBUG
# Level threshold for the log file.
FILE_LOG_LEVEL: int = logging.WARNING

# Application name; also used to build the log file name.
APP_NAME: str = "ExtractPeak"

#===============================================================================
class MyVna(vna.N523x):
    """N523x driver whose connect() retries several times, because the VNA
    sometimes refuses a connection for no apparent reason.
    """

    def connect(self, try_=3):
        """Overloaded vna.N523x method because the VNA seems to refuse
        connection without reason.
        :param try_: number of connection attempts (int).
        :returns: True as soon as the parent connect() returns True,
                  False if every attempt failed (bool).
        """
        for attempt in range(1, try_ + 1):
            if super().connect() is True:
                return True
            logging.warning("VNA connection attempt %d/%d failed",
                            attempt, try_)
        # Report the failure instead of silently continuing: the caller
        # would otherwise only notice on its first I/O with the device.
        logging.error("Unable to connect to VNA after %d attempt(s)", try_)
        return False

#===============================================================================
# Description of the parameter tree shown on the left side of the UI.
# Group/parameter names are used as lookup keys elsewhere in this file
# (e.g. p.param('Fit', 'Order')), so renaming one here requires updating
# those lookups.
PARAMS = [
    # Data sources: acquisition from the VNA, or loading from a file.
    {'name': 'Load data', 'type': 'group', 'children': [
        {'name': 'VNA', 'type': 'group', 'children': [
            {'name': 'IP', 'type': 'str'}, # no default value (vna.DEFAULT_IP left unused)
            {'name': 'Port', 'type': 'int', 'value': vna.PORT},
            {'name': 'Acquisition', 'type': 'action'},
        ]},
        {'name': 'File', 'type': 'group', 'children': [
            # NOTE(review): 'Filename' does not appear to be read anywhere in
            # this file (the file is picked through a dialog) -- confirm.
            {'name': 'Filename', 'type': 'str'},
            {'name': 'Open', 'type': 'action'},
        ]},
    ]},
    # Polynomial baseline fit: polynomial order, optional low-pass
    # filtering of the residual, and the button launching the fit.
    {'name': 'Fit', 'type': 'group', 'children': [
        {'name': 'Order', 'type': 'int', 'value': 3, 'limits': (1, 15)},
        {'name': 'Filtering', 'type': 'bool', 'value': False},
        {'name': 'Run fit', 'type': 'action'},
    ]},
    # Empty here; filled at runtime with one action per loaded trace.
    {'name': 'Plot', 'type': 'group', 'children': []},
]

class PeakUi(QMainWindow):
    """Main window of the extract-peak application.

    Left side: the parameter tree built from PARAMS (``self.p`` holds the
    parameters, ``self.ptree`` displays them).
    Right side: two stacked plots, ``self.mplot1`` and ``self.mplot2``.
    ``self.region`` is a LinearRegionItem shown on ``self.mplot1``.
    """

    def __init__(self):
        """Build the widgets and attach the region selector to the top plot.
        :returns: None
        """
        super().__init__()
        self.setWindowTitle("Fit Peak")
        central = self._central_widget()
        self.setCentralWidget(central)
        region = pg.LinearRegionItem()
        self.region = region
        region.setZValue(10)  # draw the region above the curves
        # ignoreBounds=True: the region must not take part in the
        # ViewBox auto-range computation.
        self.mplot1.addItem(region, ignoreBounds=True)

    def _central_widget(self):
        """Create the parameter tree and the two plots, then lay them out.
        :returns: central widget of UI (QWidget)
        """
        self.p = Parameter.create(name='params', type='group', children=PARAMS)
        self.ptree = ParameterTree()
        self.ptree.setParameters(self.p, showTop=False)
        self.mplot1, self.mplot2 = pg.PlotWidget(), pg.PlotWidget()

        plots = QVBoxLayout()
        for plot_widget in (self.mplot1, self.mplot2):
            plots.addWidget(plot_widget)

        layout = QHBoxLayout()
        layout.addWidget(self.ptree)
        layout.addLayout(plots)
        # Stretch only the plots so they absorb the extra horizontal space.
        layout.setStretchFactor(plots, 2)

        widget = QWidget()
        widget.setLayout(layout)
        return widget

#===============================================================================
class PeakApp(QApplication):
    """Peak application.

    Loads traces (from the VNA or from a file), fits a polynomial baseline
    on the points lying outside a user-selected region, and displays the
    fit over the data and the residual (data - fit) in a second plot.
    """

    acquire_done = pyqtSignal()
    fit_done = pyqtSignal(object, object)

    # Low-pass Butterworth filter optionally applied to the residual.
    FILTER_ORDER = 3      # filter order
    FILTER_WN = 0.005     # normalized cutoff frequency (1.0 == Nyquist)
    FILTER_PADLEN = 150   # maximum padding length used by filtfilt

    def __init__(self, args):
        """Constructor.
        :param args: command line arguments forwarded to QApplication.
        :returns: None
        """
        super().__init__(args)
        self._ui = PeakUi()
        # {trace name: array}; each array is indexed as arr[0, :] (x) and
        # arr[1, :] (y) by the methods below.
        self._data = None
        self._cdata_name = None  # name of the currently displayed trace
        self._dplot = None  # Data plot
        self._fdplot = None  # Fitted data plot
        self._pplot = None  # (Extracted) Peak plot
        params = self._ui.p
        params.param('Load data', 'VNA', 'Acquisition').sigActivated.connect(
            self.acquire_data_from_device)
        params.param('Load data', 'File', 'Open').sigActivated.connect(
            self.acquire_data_from_file)
        params.param('Fit', 'Run fit').sigActivated.connect(
            lambda: self.fit_peak(params.param('Fit', 'Order').value(),
                                  params.param('Fit', 'Filtering').value()))
        self.acquire_done.connect(self.handle_acq_data)
        self.fit_done.connect(self.display_fit)
        self._ui.show()

    @pyqtSlot()
    def acquire_data_from_device(self):
        """Acquire all measurement traces from the VNA.

        On success the traces are stored in ``self._data`` (keyed by
        measurement name) and ``acquire_done`` is emitted. On failure the
        error is logged and shown to the user, and the current data is kept.
        :returns: None
        """
        ip = self._ui.p.param('Load data', 'VNA', 'IP').value()
        port = self._ui.p.param('Load data', 'VNA', 'Port').value()
        dev = MyVna(ip, port)
        # Connecting and configuring can fail just like the transfer itself,
        # so every device call is guarded.
        try:
            dev.connect()
            dev.write("FORMAT:DATA REAL,64")
            datas = dev.get_measurements()
            data = {dev.measurement_number_to_name(idx + 1): trace
                    for idx, trace in enumerate(datas)}
        except Exception as ex:
            logging.error("Problem during acquisition: %s", str(ex))
            QMessageBox.warning(self._ui, "Acquisition problem",
                                "Problem during acquisition: {}".format(ex),
                                QMessageBox.Ok)
            return
        # NOTE(review): the connection is never explicitly closed here;
        # confirm whether vna.N523x offers a close() that should be called.
        self._data = data
        self.acquire_done.emit()

    @pyqtSlot()
    def acquire_data_from_file(self):
        """Load a trace from a text file (e.g. .s2p) chosen by the user.

        Lines starting with '!' or '#' are skipped; column 0 is used as x
        and column 1 as y.
        :returns: None
        """
        # The static getOpenFileName() returns a (filename, selected_filter)
        # tuple; filename is '' when the dialog is cancelled.
        filename, _ = QFileDialog.getOpenFileName(
            parent=self._ui,
            caption="Choose .s2p file to load",
            directory=QDir.currentPath(),
            filter="s2p files (*.s2p);;Any files (*)")
        if not filename:
            return
        try:
            # ndmin=2 keeps a 2D (columns, samples) layout after transposing,
            # even for a file holding a single data row.
            data = np.loadtxt(filename, comments=('!', '#'), ndmin=2).T
        except Exception as ex:
            logging.error("Problem when reading file: %s", str(ex))
            QMessageBox.warning(self._ui, "Acquisition problem",
                                "Problem when reading file: {}".format(ex),
                                QMessageBox.Ok)
            return
        if data.shape[0] < 2 or data.shape[1] < 1:
            msg = "file does not contain at least two data columns"
            logging.error("Problem when reading file: %s", msg)
            QMessageBox.warning(self._ui, "Acquisition problem",
                                "Problem when reading file: {}".format(msg),
                                QMessageBox.Ok)
            return
        self._data = {QFileInfo(filename).baseName(): data}
        self.acquire_done.emit()

    @pyqtSlot()
    def handle_acq_data(self):
        """Rebuild the 'Plot' group with one action per trace and display
        the first trace.
        :returns: None
        """
        plot_group = self._ui.p.param('Plot')
        plot_group.clearChildren()
        for name_ in self._data:
            child = Parameter.create(name=name_, type='action')
            child.sigActivated.connect(self.set_cdata)
            plot_group.addChild(child)
        children = plot_group.children()
        if not children:
            # Nothing to display (e.g. the device returned no measurement).
            self._cdata_name = None
            QMessageBox.warning(self._ui, "Acquisition problem",
                                "No trace found in acquired data",
                                QMessageBox.Ok)
            return
        self.set_cdata(children[0])

    @pyqtSlot(object)
    def set_cdata(self, param):
        """Select the trace named after *param* and display it.
        :param param: Parameter object
        """
        self._cdata_name = param.name()
        self.display_data()

    def _remove_plots(self, with_data):
        """Remove the fit and residual plots (and the data plot when
        *with_data* is true), resetting the removed handles to None.
        """
        if with_data and self._dplot is not None:
            self._ui.mplot1.removeItem(self._dplot)
            self._dplot = None
        if self._fdplot is not None:
            self._ui.mplot1.removeItem(self._fdplot)
            self._fdplot = None
        if self._pplot is not None:
            self._ui.mplot2.removeItem(self._pplot)
            self._pplot = None

    @pyqtSlot()
    def display_data(self):
        """Plot the current trace after clearing previous plots, and reset
        the exclusion region to the full x span of the trace.
        :returns: None
        """
        self._remove_plots(with_data=True)
        trace = self._data[self._cdata_name]
        self._dplot = self._ui.mplot1.plot(trace[0, :], trace[1, :], pen="w")
        self._ui.region.setRegion([trace[0, 0], trace[0, -1]])

    @pyqtSlot(int, bool)
    def fit_peak(self, order, filter_=False):
        """Fit a polynomial baseline on the points outside the selected
        region and emit ``fit_done(yfit, res)``.

        :param order: degree of the fitted polynomial (int).
        :param filter_: low-pass filter the residual when True (bool).
        :returns: None (results are emitted through ``fit_done``)
        """
        if self._data is None or self._cdata_name not in self._data:
            QMessageBox.information(self._ui, "No data",
                                    "Load data before running a fit")
            return
        trace = self._data[self._cdata_name]
        x, y = trace[0, :], trace[1, :]
        x_min, x_max = self._ui.region.getRegion()
        # Keep exactly the samples lying outside the excluded region.
        outside = (x < x_min) | (x > x_max)
        if not outside.any():  # the region covers the whole trace
            QMessageBox.information(self._ui,
                                    "ROI omited",
                                    "Select region to exclude before fit")
            return
        # Suppress mean to get better fitting (conditioning of the fit).
        x_mean = x.mean()
        coefs = poly.polyfit(x[outside] - x_mean, y[outside], order)
        yfit = poly.polyval(x - x_mean, coefs)
        res = y - yfit
        if filter_:
            b, a = scs.butter(self.FILTER_ORDER, self.FILTER_WN)
            # filtfilt requires the signal to be longer than padlen.
            padlen = min(self.FILTER_PADLEN, res.size - 1)
            res = scs.filtfilt(b, a, res, padlen=padlen)
        self.fit_done.emit(yfit, res)

    @pyqtSlot(object, object)
    def display_fit(self, yfit, res):
        """Plot the fitted baseline over the data and the residual below.
        :param yfit: baseline evaluated at every x of the current trace.
        :param res: residual (data - baseline), possibly filtered.
        """
        self._remove_plots(with_data=False)
        x = self._data[self._cdata_name][0, :]
        self._fdplot = self._ui.mplot1.plot(x, yfit, pen='g')
        self._pplot = self._ui.mplot2.plot(x - x.mean(), res, pen='g')

#==============================================================================
def configure_logging():
    """Configures logs.

    Attaches two handlers to the root logger:
    - a file handler writing ``~/.<APP_NAME>.log`` (truncated at each start)
      with threshold FILE_LOG_LEVEL;
    - a console (stderr) handler with threshold CONSOLE_LOG_LEVEL.
    """
    abs_log_file = path.join(path.expanduser("~"), "." + APP_NAME + ".log")
    date_fmt = "%d/%m/%Y %H:%M:%S"
    log_format = "%(asctime)s %(levelname) -8s %(filename)s " + \
                 " %(funcName)s (%(lineno)d): %(message)s"
    # Simpler format for console use.
    console_format = '%(levelname) -8s %(filename)s (%(lineno)d): %(message)s'

    file_handler = logging.FileHandler(abs_log_file, mode='w')
    file_handler.setLevel(FILE_LOG_LEVEL)
    file_handler.setFormatter(logging.Formatter(log_format, date_fmt))

    # Handler writing messages to sys.stderr.
    console = logging.StreamHandler()
    console.setLevel(CONSOLE_LOG_LEVEL)
    console.setFormatter(logging.Formatter(console_format))

    # The root logger drops records below its own level before any handler
    # sees them, so it must be as verbose as the most verbose handler;
    # otherwise CONSOLE_LOG_LEVEL below FILE_LOG_LEVEL has no effect.
    root = logging.getLogger('')
    root.setLevel(min(FILE_LOG_LEVEL, CONSOLE_LOG_LEVEL))
    root.addHandler(file_handler)
    root.addHandler(console)

#==============================================================================
def main():
    """Set up logging, start the application and exit with its return code."""
    configure_logging()
    application = PeakApp(sys.argv)
    exit_code = application.exec_()
    sys.exit(exit_code)

#==============================================================================
if __name__ == '__main__':
    main()
