#!/bin/python

# -*- coding: utf-8 -*-
"""
tritond
~~~~~~~~

Logging daemon for non-blocking triton events.

Adapted from https://github.com/rhettg/BlueOx/blob/master/bin/oxd

"""
from __future__ import unicode_literals
import argparse
import errno
import sys
import logging
import signal
import struct
import os
import time
from collections import defaultdict
import msgpack
import json

import zmq
import pystatsd

from triton import nonblocking_stream
from triton.stream import get_stream
from triton import config, errors

# Environment variable that overrides the config file location, and the
# path that is used when that variable is not set.
ENV_VAR_TRITON_CONFIG_PATH = 'TRITON_CONFIG'
DEFAULT_CONFIG_PATH = '/etc/triton.yaml'

# We want to limit how many messages we'll hold in memory so that, if Kinesis
# is unavailable, we don't just run out of memory.  This value is a rough
# estimate of how many rather large (~3k) messages fit in 10 megs.  It is
# applied as the high-water mark of the collector socket.
MAX_QUEUED_MESSAGES = 3500

# Batch writes to Kinesis; collect messages for this long, then write.
# Also used as the timeout (in milliseconds) of each poll on the collector
# socket.
POLL_LOOP_TIMEOUT_MS = 100

# See https://github.com/postmates/pystatsd for statsd configuration.
# Note that pystatsd writes raise no exception if no statsd server is running.
STATSD_PREFIX = 'tritond.'
# Per-stream event counter; the stream name is appended to this prefix.
STATSD_EVENTCOUNT = STATSD_PREFIX + "eventcount."
# Timer recorded around each flush of pending events.
STATSD_LOOPTIME = STATSD_PREFIX + "write_loop.timing"

log = logging.getLogger("triton.d")

# Lazily loaded by get_triton_config(); None until the first successful load.
_triton_config = None
# Cache of stream objects keyed by stream name; filled by load_or_get_stream().
_streams = dict()

# Version byte in our meta struct for JSON meta.  0x7B is ASCII '{', i.e. the
# first byte of a meta frame that is a JSON object.
META_STRUCT_VERSION_JSON = 0x7B


def setup_logging(options):
    """Configure logging to stdout based on how many ``-v`` flags were given.

    ``options.verbose`` holds one entry per ``-v`` flag on the command line:
    none selects WARNING, one selects INFO, two or more select DEBUG.
    """
    levels_by_verbosity = (logging.WARNING, logging.INFO, logging.DEBUG)
    verbosity = min(len(options.verbose), len(levels_by_verbosity) - 1)

    # tritond is normally run under a supervisor (upstart, supervisord, ...)
    # that collects stdout for us, so logging straight to stdout is enough.
    # A destination option could be added here if anyone ever needs one.
    logging.basicConfig(
        level=levels_by_verbosity[verbosity],
        format="%(asctime)s %(levelname)s:%(name)s: %(message)s",
        stream=sys.stdout,
    )


def get_triton_config():
    """Return the triton config, loading and caching it on first use.

    The config path is read from the environment variable named by
    ENV_VAR_TRITON_CONFIG_PATH, falling back to DEFAULT_CONFIG_PATH.

    Raises:
        errors.TritonNotConfiguredError - if loading yields None.
    """
    global _triton_config

    # Fast path: a previously loaded (truthy) config is reused as-is.
    if _triton_config:
        return _triton_config

    path = os.environ.get(ENV_VAR_TRITON_CONFIG_PATH, DEFAULT_CONFIG_PATH)
    _triton_config = config.load_config(path)
    if _triton_config is None:
        raise errors.TritonNotConfiguredError(
            'Failed to load config for tritond')
    return _triton_config


def check_meta_version(meta):
    """Return the version byte that prefixes a meta frame.

    Arguments:
        meta : bytes - Raw meta frame as received from the collector socket.

    Returns:
        int - Value of the first byte of ``meta``.

    Raises:
        ValueError - if ``meta`` is empty, or its first byte is neither the
            binary struct version nor the JSON version marker.
    """
    # An empty frame must surface as ValueError: that is the only exception
    # the receive loop treats as "bad message, skip it".
    if not meta:
        raise ValueError('Empty meta frame')

    # Slice instead of indexing so we always hand struct a 1-byte string;
    # indexing bytes yields an int on Python 3.
    value, = struct.unpack(">B", meta[:1])
    if value not in (
        nonblocking_stream.META_STRUCT_VERSION, META_STRUCT_VERSION_JSON
    ):
        raise ValueError(value)
    return value


def get_header_data(event_meta, version):
    """Extract the stream name and partition key from a meta frame.

    Arguments:
        event_meta : bytes - Raw meta frame.
        version : int - Version byte as returned by check_meta_version().

    Returns:
        (stream_name, partition_key)

    Raises:
        ValueError - if the version is unknown or the frame cannot be parsed
            in the format that version implies.
    """
    if version == nonblocking_stream.META_STRUCT_VERSION:
        # See nonblocking_stream.network for how this is packed.
        try:
            _, stream_name, partition_key = struct.unpack(
                nonblocking_stream.META_STRUCT_FMT, event_meta)
        except struct.error:
            # struct.error is not a ValueError; convert it so that callers
            # that skip bad frames on ValueError also skip truncated or
            # oversized binary metas instead of crashing.
            raise ValueError('Cannot Parse Meta Struct')
    elif version == META_STRUCT_VERSION_JSON:
        try:
            meta_data = json.loads(event_meta)
            stream_name = meta_data['stream_name']
            partition_key = meta_data['partition_key']
        except Exception:
            # Covers invalid JSON as well as JSON that lacks the expected
            # keys or is not an object at all.
            raise ValueError('Cannot Parse Meta JSON')
    else:
        raise ValueError('Incorrect meta version')
    return stream_name, partition_key


def load_or_get_stream(stream_name):
    """Return the stream object for ``stream_name``, creating it on first use.

    Streams are cached in the module-level ``_streams`` dict so each one is
    built from the triton config only once.
    """
    if stream_name not in _streams:
        _streams[stream_name] = get_stream(stream_name, get_triton_config())
    return _streams[stream_name]


def _write_messages_to_streams(waiting_messages):
    for stream_name, list_of_messages in waiting_messages.items():
        try:
            stream = load_or_get_stream(stream_name)
        except errors.StreamNotConfiguredError:
            log.error("Unable to get stream {}; dropping {} messages".format(
                stream_name, len(list_of_messages)))
            continue
        try:
            stream._put_many_packed(list_of_messages)
        except:
            log.exception(
                "Tritond failed to write messages to stream",
                extra={
                    'stream_name': stream_name,
                    'list_of_messages': list_of_messages
                }
            )
        pystatsd.increment(
            STATSD_EVENTCOUNT + stream_name,
            len(list_of_messages)
        )


def _write_messages_to_file(waiting_messages, file_obj):
    '''Debug method; write msgpack binary data stream by stream
    '''
    # import pdb; pdb.set_trace()
    for stream_name, list_of_messages in waiting_messages.items():
        log.debug("writing to stdout for stream %s", stream_name)
        file_obj.write(msgpack.packb(stream_name))
        file_obj.writelines(message['Data'] for message in list_of_messages)
    file_obj.flush()


def _pending_events(waiting_messages):
    return len(waiting_messages) > 0


def _maybe_flush_events(waiting_messages, last_flush, output_file=None):
    """
        Flushes the pending events if more than POLL_LOOP_TIMEOUT_MS
        milliseconds have elapsed since the previous flush; otherwise leaves
        them untouched.

        Arguments:
            waiting_messages : dict(string, list) - Events pending publication.
            last_flush : time.time() - Timestamp of the previous flush.
            output_file : file_descriptor - File to flush to instead of Kinesis.  Optional, default = None.

        Returns:
            (time.time(), dict(string, list)) - The timestamp to use as
            last_flush on the next call, and the events still pending.
    """
    current_time = time.time()
    elapsed_ms = 1000 * (current_time - last_flush)

    if elapsed_ms > POLL_LOOP_TIMEOUT_MS:
        remaining = _flush_events(waiting_messages, output_file)
        return (current_time, remaining)

    # Too soon since the last flush: keep accumulating.
    return (last_flush, waiting_messages)


def _flush_events(waiting_messages, output_file=None):
    """
        Writes every per-stream buffer in waiting_messages to the given
        output file, or to Kinesis when no file is given, timing the write
        with statsd.

        Note - Events are currently assumed to be either written or dropped
        by this call.  The return value exists so that unpublished events can
        be handed back to the caller should that ever change.

        Arguments:
            waiting_messages : dict(string, list) - Events pending publication.
            output_file : file_descriptor - File to flush to instead of Kinesis.  Optional, default = None.

        Returns:
            dict(string, list) - A fresh, empty buffer.
    """
    # Nothing queued: skip the write (and its timer) entirely.
    if not _pending_events(waiting_messages):
        return defaultdict(list)

    with pystatsd.Timer(STATSD_LOOPTIME):
        if output_file is None:
            _write_messages_to_streams(waiting_messages)
        else:
            _write_messages_to_file(waiting_messages, output_file)

    return defaultdict(list)


def main():
    """Run the tritond collector loop until a shutdown signal arrives.

    Binds a ZMQ PULL socket on the host/port given by
    config.get_zmq_config(), reads two-part (meta, data) messages from it,
    groups them per stream and periodically flushes them, either to Kinesis
    or, with --skip-kinesis, to stdout / --output_file.

    Shutdown behaviour:
        SIGTERM - stop the loop, then flush whatever is still pending.
        SIGINT  - stop the loop without the final flush.

    Always ends by calling sys.exit(0).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--verbose', '-v',
        dest='verbose',
        action='append_const',
        const=True,
        default=list())
    parser.add_argument(
        '--skip-kinesis',
        dest='skip_kinesis',
        action='store_true',
        default=False,
        help="Skip publishing to Kinesis; for debug purposes.")
    parser.add_argument(
        '--output_file',
        dest='output_file',
        action='store',
        default=None,
        help=
            """
            Output file for incoming data; otherwise output to stdout.
            Only used in conjunction with --skip-kinesis
            """
    )

    options = parser.parse_args()
    setup_logging(options)

    output_file = None
    # Track whether we own the file handle so we only close what we opened
    # (never sys.stdout).
    opened_output_file = False
    if options.skip_kinesis:
        if options.output_file is None:
            output_file = sys.stdout
        else:
            output_file = open(options.output_file, 'wb')
            opened_output_file = True

    # One-element lists so the nested signal handlers can mutate the flags
    # in place (a plain assignment would only create a handler-local name).
    continue_running = [True]
    final_flush = [True]

    def handle_sigint(signum, frame):
        log.info("Exiting immediately.")
        continue_running[0] = False
        # Mutate in place; rebinding `final_flush` here would leave the
        # outer flag untouched and the final flush would still run.
        final_flush[0] = False

    def handle_sigterm(signum, frame):
        log.info("Exiting after all events have been flushed.")
        continue_running[0] = False

    signal.signal(signal.SIGTERM, handle_sigterm)
    signal.signal(signal.SIGINT, handle_sigint)

    zmq_context = zmq.Context()
    poller = zmq.Poller()

    collect_host = "{}:{}".format(*config.get_zmq_config())
    log.info("Initializing collector port %s", collect_host)
    collector_sock = zmq_context.socket(zmq.PULL)
    collector_sock.hwm = MAX_QUEUED_MESSAGES
    collector_sock.bind("tcp://%s" % collect_host)
    poller.register(collector_sock, zmq.POLLIN)

    last_write = time.time()
    waiting_messages = defaultdict(list)

    log.info("Starting IO Loop")
    while continue_running[0]:
        log.debug("Poll")

        try:
            ready = dict(poller.poll(POLL_LOOP_TIMEOUT_MS))
        except (KeyboardInterrupt, SystemExit):
            continue_running[0] = False
            break
        except zmq.ZMQError as e:
            if e.errno == errno.EINTR:
                # If this is from a SIGTERM, we have a handler for that and
                # the loop should exit gracefully.
                continue
            else:
                raise

        log.debug("Poller returned: %r", ready)

        if collector_sock in ready:
            try:
                event_meta, event_data = collector_sock.recv_multipart()
            except ValueError as e:
                # Sometimes clients can fail and corrupt these two-part sends.
                log.warning("Failed to recv from %r: %r", collector_sock, e)
                continue

            try:
                version = check_meta_version(event_meta)
                stream_name, partition_key = get_header_data(
                    event_meta, version)
            except (ValueError, IndexError, struct.error):
                # IndexError / struct.error cover empty or wrongly sized
                # meta frames; a malformed message from one client must not
                # take the whole daemon down.
                log.warning("Failed to decode event due to version mismatch")
                continue

            waiting_messages[stream_name].append({
                'Data': event_data,
                'PartitionKey': partition_key
            })

        last_write, waiting_messages = _maybe_flush_events(
            waiting_messages, last_write, output_file)

    collector_sock.close(0)

    try:
        if final_flush[0]:
            _flush_events(waiting_messages, output_file)
    finally:
        if opened_output_file:
            output_file.close()

    sys.exit(0)


# Start the daemon only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()
