#! /usr/bin/env python
#
# This is open-source software licensed under a BSD license.
# Please see the file LICENSE.txt for details.
#
import pickle

from autobahn.twisted.component import Component, run
from autobahn.twisted.wamp import ApplicationSession
from influxdb import InfluxDBClient
from twisted.internet import threads
from twisted.internet.defer import inlineCallbacks
from twisted.logger import Logger
from twisted.internet.task import LoopingCall

HEARTBEAT_DELAY = 15  # interval passed to LoopingCall.start: seconds between "hwlogger is alive" log lines


class TelemetryLogger(ApplicationSession):
    """
    WAMP session that subscribes to HiPERCAM telemetry topics and writes
    selected fields to the "hipercam" InfluxDB database in batches.

    Points are buffered in ``self._points``. Once ``self._bulk_limit`` points
    are buffered, the buffer is written in a worker thread. Failed writes are
    retried on later arrivals, up to ``_MAX_RETRIES`` times, before the batch
    is discarded.
    """

    log = Logger()

    # number of failed write attempts tolerated before a batch is discarded
    _MAX_RETRIES = 10

    # class-level defaults so helper methods are safe before onJoin has run
    _tick_loop = None
    _write_in_progress = False

    def onJoin(self, details):
        """
        Connect to InfluxDB, subscribe to all telemetry topics and start the
        heartbeat loop once the WAMP session has been joined.
        """
        print("session ready")
        # WAMP topic -> handler for that topic's (pickled) telemetry payload
        self.topic_callbacks = {
            "hipercam.ccd1.telemetry": self.log_ccd1_telemetry,
            "hipercam.ccd2.telemetry": self.log_ccd2_telemetry,
            "hipercam.ccd3.telemetry": self.log_ccd3_telemetry,
            "hipercam.ccd4.telemetry": self.log_ccd4_telemetry,
            "hipercam.ccd5.telemetry": self.log_ccd5_telemetry,
            "hipercam.pressure1.telemetry": self.log_ccd1_pressure_telemetry,
            "hipercam.pressure2.telemetry": self.log_ccd2_pressure_telemetry,
            "hipercam.pressure3.telemetry": self.log_ccd3_pressure_telemetry,
            "hipercam.pressure4.telemetry": self.log_ccd4_pressure_telemetry,
            "hipercam.pressure5.telemetry": self.log_ccd5_pressure_telemetry,
            "hipercam.slide.telemetry": self.log_slide_telemetry,
            "hipercam.compo.telemetry": self.log_compo_telemetry,
            "hipercam.autoguider.telemetry": self.log_ag_telemetry,
            "hipercam.rack.telemetry": self.log_rack_telemetry,
        }
        self._client = InfluxDBClient(host="pulsar.shef.ac.uk", port=8086)
        self._client.switch_database("hipercam")
        self._points = []  # list of dictionaries containing data to write
        self._bulk_limit = 10  # a write is triggered once we have this many points
        # will attempt to write data numerous times before giving up
        self._retries_remaining = self._MAX_RETRIES
        # True while a worker thread is writing; prevents overlapping writes
        self._write_in_progress = False

        # subscribe to telemetry topics
        for topic, callback in self.topic_callbacks.items():
            self.subscribe(callback, topic)

        self._tick_no = 0
        self._tick_loop = LoopingCall(self._tick)
        self._tick_loop.start(HEARTBEAT_DELAY)

    def onLeave(self, details):
        """
        Stop the heartbeat loop when the session ends, then defer to the
        base-class handling.

        Without this the LoopingCall started in onJoin would keep logging
        "alive" after the session has gone.
        """
        if self._tick_loop is not None and self._tick_loop.running:
            self._tick_loop.stop()
        return super().onLeave(details)

    def _tick(self):
        """Log a periodic heartbeat with an incrementing tick counter."""
        self._tick_no += 1
        self.log.info("hwlogger is alive [tick {tick}]", tick=self._tick_no)

    def update_db(self):
        """
        Write the currently buffered points to InfluxDB.

        This runs in a worker thread (see log_telemetry). It snapshots the
        buffer first and, on success or give-up, removes only the points it
        tried to write. Points appended by the reactor thread during the write
        are kept for the next batch.
        """
        batch = list(self._points)
        if not batch:
            return

        try:
            self.log.debug("attempting write")
            result = self._client.write_points(batch, time_precision="ms")
        except Exception as err:
            self.log.warn(
                "failed writing batch of data (will retry): {err}", err=err
            )
            result = False

        if result:
            self.log.debug("updated DB")
            del self._points[: len(batch)]
            self._retries_remaining = self._MAX_RETRIES
        elif self._retries_remaining > 0:
            # keep the batch buffered; the next arrival triggers a retry
            self._retries_remaining -= 1
        else:
            del self._points[: len(batch)]
            self._retries_remaining = self._MAX_RETRIES
            self.log.error(
                "could not write data to DB, giving up. this batch of data is lost!"
            )

    @inlineCallbacks
    def log_telemetry(self, point):
        """
        Buffer a point and, once the bulk limit is reached, write the buffer
        to the database in a worker thread.

        Only one write runs at a time. Points arriving during a write are
        buffered and included in a later batch.
        """
        self.log.debug("logging telemetry")
        self._points.append(point)

        # not enough data yet, or a write is already running
        if len(self._points) < self._bulk_limit or self._write_in_progress:
            return

        # we've reached bulk threshold, update_db
        self._write_in_progress = True
        try:
            self.log.debug("logging telemetry to DB")
            yield threads.deferToThread(self.update_db)
        except Exception as err:
            self.log.error("{err}", err=err)
        finally:
            self._write_in_progress = False

    def preprocess_telemetry(self, data):
        """
        Unpickle a telemetry payload and split off its timestamp.

        Returns a tuple ``(iso_timestamp_string, remaining_fields_dict)``.
        """
        # SECURITY NOTE: pickle.loads can execute arbitrary code embedded in
        # the payload. This is only safe if every publisher on the WAMP realm
        # is trusted.
        telemetry = pickle.loads(data)
        ts = telemetry.pop("timestamp")
        # assumes ts is an astropy Time-like object exposing .precision/.iso
        # - TODO confirm against the publisher
        ts.precision = 2
        return ts.iso, telemetry

    def process_field(self, field):
        """
        Return ``field.value`` if the object has a ``value`` attribute,
        otherwise return the field unchanged.
        """
        # presumably strips units from Quantity-like objects - verify publishers
        if hasattr(field, "value"):
            return field.value
        return field

    def make_point(self, data, measurement):
        """
        Build an InfluxDB point dict (measurement, time, fields) from a raw
        pickled telemetry payload.
        """
        point = {}
        point["measurement"] = measurement
        ts, telemetry = self.preprocess_telemetry(data)
        point["time"] = ts
        # set fields to telemetry
        point["fields"] = {k: self.process_field(v) for (k, v) in telemetry.items()}
        return point

    def log_ccd_telemetry(self, ccd, data):
        """Log CCD telemetry for ``ccd``, excluding the pressure field."""
        try:
            point = self.make_point(data, ccd)
            self.log.debug("point made with timestamp {ts}", ts=point["time"])
        except Exception as err:
            self.log.error(
                "cannot process telemetry for {ccd}: {err}", ccd=ccd, err=err
            )
        else:
            # don't log pressure twice! (it is logged by log_pressure_telemetry)
            # default avoids a KeyError if a payload lacks the field
            point["fields"].pop("pressure", None)
            self.log_telemetry(point)

    def log_rack_telemetry(self, data):
        """Log rack temperatures only, with the ``rack_`` prefix stripped."""
        try:
            point = self.make_point(data, "rack")
        except Exception as err:
            self.log.error("cannot process rack telemetry: {err}", err=err)
        else:
            # only want temperatures
            point["fields"] = {
                k.replace("rack_", ""): v
                for (k, v) in point["fields"].items()
                if "rack_temp" in k
            }
            self.log_telemetry(point)

    def log_pressure_telemetry(self, ccd, data):
        """Log only the ``pressure`` field of a CCD's pressure telemetry."""
        try:
            point = self.make_point(data, ccd)
        except Exception as err:
            self.log.error(
                "cannot process pressure telemetry for {ccd}: {err}",
                ccd=ccd,
                err=err,
            )
        else:
            # only want pressure data
            point["fields"] = {
                k: v for (k, v) in point["fields"].items() if k == "pressure"
            }
            self.log_telemetry(point)

    def log_ccd1_telemetry(self, data):
        self.log_ccd_telemetry("ccd1", data)

    def log_ccd2_telemetry(self, data):
        self.log_ccd_telemetry("ccd2", data)

    def log_ccd3_telemetry(self, data):
        self.log_ccd_telemetry("ccd3", data)

    def log_ccd4_telemetry(self, data):
        self.log_ccd_telemetry("ccd4", data)

    def log_ccd5_telemetry(self, data):
        self.log_ccd_telemetry("ccd5", data)

    def log_ccd1_pressure_telemetry(self, data):
        self.log_pressure_telemetry("ccd1", data)

    def log_ccd2_pressure_telemetry(self, data):
        self.log_pressure_telemetry("ccd2", data)

    def log_ccd3_pressure_telemetry(self, data):
        self.log_pressure_telemetry("ccd3", data)

    def log_ccd4_pressure_telemetry(self, data):
        self.log_pressure_telemetry("ccd4", data)

    def log_ccd5_pressure_telemetry(self, data):
        self.log_pressure_telemetry("ccd5", data)

    def log_ag_telemetry(self, data):
        """Log autoguider telemetry, but only while its state is "guiding"."""
        try:
            point = self.make_point(data, "autoguider")
        except Exception as err:
            self.log.error(
                "cannot process telemetry for autoguider: {err}", err=err
            )
        else:
            # .get avoids a KeyError if a payload lacks the "state" field
            if point["fields"].get("state") == "guiding":
                self.log_telemetry(point)

    def log_compo_telemetry(self, data):
        """
        We don't currently log COMPO telemetry
        """
        pass

    def log_slide_telemetry(self, data):
        """
        We don't currently log slide telemetry
        """
        pass


if __name__ == "__main__":
    import os

    # Host of the WAMP router; override with the WAMP_SERVER environment variable.
    wamp_host = os.getenv("WAMP_SERVER", "192.168.1.2")
    router_url = "ws://{}:8080/ws".format(wamp_host)

    # One component whose sessions are TelemetryLogger instances.
    telemetry_component = Component(
        transports=router_url,
        realm="realm1",
        session_factory=TelemetryLogger,
    )
    run([telemetry_component], log_level="info")
