#! /usr/bin/env python
#
# This is open-source software licensed under a BSD license.
# Please see the file LICENSE.txt for details.
#
import pickle
import math
from queue import Queue

import pandas as pd
from autobahn.twisted.component import Component, run
from autobahn.twisted.wamp import ApplicationSession
from influxdb import InfluxDBClient
from twisted.internet import threads
from twisted.internet.defer import inlineCallbacks
from twisted.logger import Logger
from twisted.internet.task import LoopingCall

# Interval (seconds, as LoopingCall.start expects) between the periodic
# "hwlogger is alive" heartbeat messages written to the log.
HEARTBEAT_DELAY: int = 15


class TelemetryWriter:
    """
    Queue batches of InfluxDB points and write them to the database.

    Batches are added with ``put`` and written, oldest first, by
    ``write_points``. The queue is thread-safe, so ``put`` and
    ``write_points`` may be called from different threads.
    """

    def __init__(
        self,
        num_retries=10,
        host="pulsar.shef.ac.uk",
        port=8086,
        database="hipercam",
    ):
        """
        Parameters
        ----------
        num_retries : int
            Maximum number of write attempts made for each batch.
        host : str
            InfluxDB server host name.
        port : int
            InfluxDB server port.
        database : str
            Name of the InfluxDB database that points are written to.
        """
        self._client = InfluxDBClient(host=host, port=port)
        self._client.switch_database(database)
        self._num_retries = num_retries
        # we use a threadsafe queue to store data points for writing
        self._data_queue = Queue()

    def put(self, points):
        """Queue a batch (list of point dicts) for a later ``write_points``."""
        self._data_queue.put(points)

    def write_points(self):
        """
        Take the oldest queued batch and write it to InfluxDB.

        Blocks until a batch is available. Up to ``num_retries`` attempts are
        made for this batch. The attempt count is per call, so an earlier
        failed batch never prevents later batches from being tried.

        Raises
        ------
        Exception
            If every attempt fails (or ``num_retries`` < 1). The batch is
            discarded and the last client error is chained as the cause.
        """
        points = self._data_queue.get()
        try:
            last_error = None
            for _ in range(self._num_retries):
                try:
                    self._client.write_points(points, time_precision="ms")
                except Exception as err:
                    last_error = err
                else:
                    return
            raise Exception("cannot write to DB, retries exhausted") from last_error
        finally:
            # the batch has been taken off the queue whether or not it was
            # written; always mark it done so Queue.join() cannot hang
            self._data_queue.task_done()


class TelemetryLogger(ApplicationSession):
    """
    WAMP session that records HiPERCAM hardware telemetry in InfluxDB.

    Each subscribed topic has a handler that turns the pickled message into
    an InfluxDB point dict with "measurement", "time" and "fields" keys.
    Points with at least one usable field are buffered. Once ``_bulk_limit``
    of them have accumulated, the batch is handed to a ``TelemetryWriter``
    and written from a worker thread.

    Log calls pass values as keyword arguments instead of interpolating them:
    twisted.logger treats the message as a format string, so braces inside
    interpolated text (e.g. JSON in a database error) would break formatting.
    """

    log = Logger()

    def onJoin(self, details):
        """Initialise buffers, subscribe to every telemetry topic, start heartbeat."""
        print("session ready")
        self.topic_callbacks = {
            "hipercam.ccd1.telemetry": self.log_ccd1_telemetry,
            "hipercam.ccd2.telemetry": self.log_ccd2_telemetry,
            "hipercam.ccd3.telemetry": self.log_ccd3_telemetry,
            "hipercam.ccd4.telemetry": self.log_ccd4_telemetry,
            "hipercam.ccd5.telemetry": self.log_ccd5_telemetry,
            "hipercam.pressure1.telemetry": self.log_ccd1_pressure_telemetry,
            "hipercam.pressure2.telemetry": self.log_ccd2_pressure_telemetry,
            "hipercam.pressure3.telemetry": self.log_ccd3_pressure_telemetry,
            "hipercam.pressure4.telemetry": self.log_ccd4_pressure_telemetry,
            "hipercam.pressure5.telemetry": self.log_ccd5_pressure_telemetry,
            "hipercam.slide.telemetry": self.log_slide_telemetry,
            "hipercam.compo.telemetry": self.log_compo_telemetry,
            "hipercam.autoguider.telemetry": self.log_ag_telemetry,
            "hipercam.rack.telemetry": self.log_rack_telemetry,
            "hipercam.gps.telemetry": self.log_gps_telemetry,
            "hipercam.chiller.telemetry": self.log_chiller_telemetry,
        }
        self._points = []  # buffered point dicts not yet handed to the writer
        self._bulk_limit = 50  # a write is triggered once we have this many points
        self._writer = TelemetryWriter()

        # subscribe to telemetry topics
        for topic, callback in self.topic_callbacks.items():
            self.subscribe(callback, topic)

        # periodic "alive" message so the log shows the process is running
        self._tick_no = 0
        self._tick_loop = LoopingCall(self._tick)
        self._tick_loop.start(HEARTBEAT_DELAY)

    def _tick(self):
        """Emit one heartbeat log message; driven by the LoopingCall in onJoin."""
        self._tick_no += 1
        self.log.info("hwlogger is alive [tick {tick_no}]", tick_no=self._tick_no)

    @inlineCallbacks
    def log_telemetry(self, point):
        """
        Buffer ``point`` and, once ``_bulk_limit`` points are buffered, write
        the batch to the database from a worker thread.

        Points with no usable field are ignored. Null fields (None, "", NaN)
        are removed from the point before buffering: InfluxDB does not accept
        NaN field values, and one bad point would fail the whole batch.
        Errors are logged, never raised; a batch that cannot be written
        (after the writer's retries) is lost.
        """
        try:
            self.log.debug("logging telemetry")
            if self.is_empty(point):
                return

            # is_empty has already run _is_not_null on every value without
            # error, so this filter cannot raise
            point["fields"] = {
                key: value
                for key, value in point["fields"].items()
                if self._is_not_null(value)
            }
            self._points.append(point)

            # keep buffering until we reach the bulk-write threshold
            if len(self._points) < self._bulk_limit:
                return

            # threshold reached: hand the batch to the writer's queue
            self.log.debug("logging telemetry to DB")
            self._writer.put(self._points)
            self._points = []
            try:
                yield threads.deferToThread(self._writer.write_points)
            except Exception as e:
                self.log.warn("cannot write to DB: {err}\n this data is lost", err=e)
            else:
                self.log.debug("data written to DB")
        except Exception as e:
            self.log.warn("{err}", err=e)

    def preprocess_telemetry(self, data):
        """
        Unpickle a telemetry message and split off its timestamp.

        Returns ``(iso_timestamp, telemetry)``. ``telemetry`` is the unpickled
        dict with its "timestamp" key removed.
        """
        # SECURITY: pickle.loads can execute arbitrary code. This is only safe
        # while every publisher on the WAMP router is trusted.
        telemetry = pickle.loads(data)
        ts = telemetry.pop("timestamp")
        # presumably an astropy Time: precision controls the decimal places
        # of seconds in .iso -- confirm against the publisher
        ts.precision = 2
        return ts.iso, telemetry

    def process_field(self, field):
        """Return ``field.value`` if the object has one, else ``field`` itself."""
        if hasattr(field, "value"):
            return field.value
        return field

    def _is_not_null(self, value):
        """
        Return True if ``value`` carries data.

        None, empty strings and NaN count as null; non-empty strings and
        other numbers count as data. Non-numeric, non-string values make
        math.isnan raise TypeError, which is_empty treats as "unusable point".
        """
        if value is None:
            return False
        if isinstance(value, str):
            return bool(value)
        return not math.isnan(value)

    def is_empty(self, point):
        """
        Return True if ``point`` has no usable field values.

        Also returns True (after logging an error) if the point cannot be
        checked, e.g. it has no "fields" dict or a field has an unsupported
        type.
        """
        try:
            # evaluate every field (no short-circuit) so an unsupported value
            # anywhere in the point is always detected here
            has_data = [
                self._is_not_null(value) for value in point["fields"].values()
            ]
        except Exception as err:
            self.log.error("cannot check if point is empty: {err}", err=err)
            return True
        return not any(has_data)

    def make_point(self, data, measurement):
        """Build an InfluxDB point dict for ``measurement`` from a raw message."""
        ts, telemetry = self.preprocess_telemetry(data)
        return {
            "measurement": measurement,
            "time": ts,
            "fields": {k: self.process_field(v) for (k, v) in telemetry.items()},
        }

    def dump_points(self):
        """
        Write the currently buffered (unwritten) points to ``dump.csv`` in
        the working directory, flattening nested dicts with json_normalize.
        """
        df = pd.json_normalize(self._points)
        df.to_csv("dump.csv", index=True)

    def log_ccd_telemetry(self, ccd, data):
        """Log CCD telemetry under measurement ``ccd``, minus its pressure field."""
        try:
            point = self.make_point(data, ccd)
            self.log.debug(
                "point made with timestamp {timestamp}", timestamp=point["time"]
            )
        except Exception as err:
            self.log.error(
                "cannot process telemetry for {ccd}: {err}", ccd=ccd, err=err
            )
        else:
            # pressure is logged from the separate pressure topics, so don't
            # log it twice; tolerate messages that carry no pressure field
            point["fields"].pop("pressure", None)
            self.log_telemetry(point)

    def log_gps_telemetry(self, data):
        """Log GPS telemetry, without its "state" field."""
        try:
            point = self.make_point(data, "gps")
        except Exception as err:
            self.log.error("cannot process GPS telemetry: {err}", err=err)
        else:
            point["fields"].pop("state", None)
            self.log_telemetry(point)

    def log_chiller_telemetry(self, data):
        """Log chiller telemetry, without its "state" field."""
        try:
            point = self.make_point(data, "chiller")
        except Exception as err:
            self.log.error("cannot process Chiller telemetry: {err}", err=err)
        else:
            point["fields"].pop("state", None)
            self.log_telemetry(point)

    def log_rack_telemetry(self, data):
        """
        Log rack telemetry: the "rack_temp*" fields (with "rack_" stripped
        from their names) plus "dewpoint" when present.
        """
        try:
            point = self.make_point(data, "rack")
        except Exception as err:
            self.log.error("cannot process rack telemetry: {err}", err=err)
        else:
            fields = point["fields"]
            rack_fields = {
                k.replace("rack_", ""): v for (k, v) in fields.items() if "rack_temp" in k
            }
            if "dewpoint" in fields:
                rack_fields["dewpoint"] = fields["dewpoint"]
            point["fields"] = rack_fields
            self.log_telemetry(point)

    def log_pressure_telemetry(self, ccd, data):
        """Log only the "pressure" field, under measurement ``ccd``."""
        try:
            point = self.make_point(data, ccd)
        except Exception as err:
            self.log.error(
                "cannot process pressure telemetry for {ccd}: {err}", ccd=ccd, err=err
            )
        else:
            # only want pressure data
            point["fields"] = {
                k: v for (k, v) in point["fields"].items() if k == "pressure"
            }
            self.log_telemetry(point)

    def log_ccd1_telemetry(self, data):
        self.log_ccd_telemetry("ccd1", data)

    def log_ccd2_telemetry(self, data):
        self.log_ccd_telemetry("ccd2", data)

    def log_ccd3_telemetry(self, data):
        self.log_ccd_telemetry("ccd3", data)

    def log_ccd4_telemetry(self, data):
        self.log_ccd_telemetry("ccd4", data)

    def log_ccd5_telemetry(self, data):
        self.log_ccd_telemetry("ccd5", data)

    def log_ccd1_pressure_telemetry(self, data):
        self.log_pressure_telemetry("ccd1", data)

    def log_ccd2_pressure_telemetry(self, data):
        self.log_pressure_telemetry("ccd2", data)

    def log_ccd3_pressure_telemetry(self, data):
        self.log_pressure_telemetry("ccd3", data)

    def log_ccd4_pressure_telemetry(self, data):
        self.log_pressure_telemetry("ccd4", data)

    def log_ccd5_pressure_telemetry(self, data):
        self.log_pressure_telemetry("ccd5", data)

    def log_ag_telemetry(self, data):
        """Log autoguider telemetry, but only while its state is "guiding"."""
        try:
            point = self.make_point(data, "autoguider")
        except Exception as err:
            self.log.error(
                "cannot process telemetry for autoguider: {err}", err=err
            )
        else:
            if point["fields"].get("state") == "guiding":
                self.log_telemetry(point)

    def log_compo_telemetry(self, data):
        """
        We don't currently log COMPO telemetry
        """
        pass

    def log_slide_telemetry(self, data):
        """
        We don't currently log slide telemetry
        """
        pass


if __name__ == "__main__":
    # Script entry point: connect a TelemetryLogger session to the WAMP router
    # (host taken from $WAMP_SERVER) and run the reactor until stopped.
    import os

    wamp_host = os.getenv("WAMP_SERVER", "192.168.1.2")
    transport_url = f"ws://{wamp_host}:8080/ws"

    logger_component = Component(
        transports=transport_url,
        realm="realm1",
        session_factory=TelemetryLogger,
    )
    run([logger_component], log_level="info")
