#!/bin/python

# -*- coding: utf-8 -*-
"""
tritond
~~~~~~~~

Logging daemon for non-blocking triton events.

Adapted from https://github.com/rhettg/BlueOx/blob/master/bin/oxd

"""
import argparse
import errno
import sys
import logging
import signal
import struct
import os
import time
from collections import defaultdict
import msgpack

import zmq

from triton import nonblocking_stream
from triton.stream import get_stream
from triton import config, errors

# Name of the environment variable that may override the config file location.
ENV_VAR_TRITON_CONFIG_PATH = 'TRITON_CONFIG'
# Config file used when the environment variable above is not set.
DEFAULT_CONFIG_PATH = '/etc/triton.yaml'

# Batch writes to Kinesis; collect messages for this long then write.
# The same value (in milliseconds) is used as the zmq poll timeout in main(),
# so the loop wakes up at least this often to check for a pending flush.
POLL_LOOP_TIMEOUT_MS = 100

log = logging.getLogger("triton.d")

# Lazily-loaded configuration, populated by get_triton_config().
_triton_config = None
# Cache of stream objects keyed by stream name; see load_or_get_stream().
_streams = dict()


def setup_logging(options):
    """Configure root logging according to how many ``-v`` flags were given.

    ``options.verbose`` is a list with one entry per ``-v`` flag: no flags
    selects WARNING, one selects INFO, and two or more select DEBUG.
    """
    # Indexed by the number of -v flags, capped at the last (noisiest) entry.
    levels_by_verbosity = (logging.WARNING, logging.INFO, logging.DEBUG)
    verbosity = min(len(options.verbose), len(levels_by_verbosity) - 1)

    # Everything goes straight to stdout: tritond is normally started by a
    # supervisor (upstart, supervisord, ...) that collects the output. An
    # option for another destination could be added if someone needs it.
    logging.basicConfig(
        level=levels_by_verbosity[verbosity],
        format="%(asctime)s %(levelname)s:%(name)s: %(message)s",
        stream=sys.stdout,
    )


def get_triton_config():
    """Return the triton configuration, loading it on first use.

    The path comes from the ``TRITON_CONFIG`` environment variable, falling
    back to ``DEFAULT_CONFIG_PATH``. The loaded value is cached in the
    module-level ``_triton_config``; any falsy cached value triggers another
    load attempt on the next call.

    Raises:
        errors.TritonNotConfiguredError: if the loader returned ``None``.
    """
    global _triton_config
    if not _triton_config:
        path = os.getenv(ENV_VAR_TRITON_CONFIG_PATH, DEFAULT_CONFIG_PATH)
        _triton_config = config.load_config(path)

    if _triton_config is not None:
        return _triton_config

    raise errors.TritonNotConfiguredError(
        'Failed to load config for tritond')


def load_or_get_stream(stream_name):
    """Return the stream object for ``stream_name``, creating it if needed.

    Streams are built with ``get_stream`` using the triton config and then
    kept in the module-level ``_streams`` cache, so each name is only
    constructed once. Any exception from ``get_stream`` or from loading the
    config propagates to the caller and nothing is cached.
    """
    if stream_name in _streams:
        return _streams[stream_name]

    created = get_stream(stream_name, get_triton_config())
    _streams[stream_name] = created
    return created


def _write_messages_to_streams(waiting_messages):
    for stream_name, list_of_messages in waiting_messages.items():
        try:
            stream = load_or_get_stream(stream_name)
        except errors.StreamNotConfiguredError:
            log.error("Unable to get stream {}; dropping {} messages".format(
                stream_name, len(list_of_messages)))
            continue
        stream._put_many_packed(list_of_messages)


def _write_messages_to_file(waiting_messages, file_obj):
    '''Debug method; write msgpack binary data stream by stream
    '''
    # import pdb; pdb.set_trace()
    for stream_name, list_of_messages in waiting_messages.items():
        log.debug("writing to stdout for stream %s", stream_name)
        file_obj.write(msgpack.packb(stream_name))
        file_obj.writelines(message['Data'] for message in list_of_messages)
    file_obj.flush()


def _build_arg_parser():
    """Build the command-line parser for tritond."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--verbose', '-v',
        dest='verbose',
        action='append_const',
        const=True,
        default=list())
    parser.add_argument(
        '--skip-kinesis',
        dest='skip_kinesis',
        action='store_true',
        default=False,
        help="will skip sends to Kinesis; for debug purposes")
    parser.add_argument(
        '--output_file',
        dest='output_file',
        action='store',
        default=None,
        # This must be one string (implicit concatenation, no comma):
        # argparse %-formats the help text, so a tuple breaks `--help`.
        help=(
            'output file for incoming data; otherwise output to stdout. '
            'Only used with --skip-kinesis'
        )
    )
    return parser


def _receive_event(collector_sock):
    """Read and decode one event from the collector socket.

    Returns a ``(stream_name, message)`` tuple, where ``message`` is the dict
    to queue for that stream, or ``None`` when the event was malformed and
    has been dropped (a warning is logged in that case).
    """
    try:
        event_meta, event_data = collector_sock.recv_multipart()
    except ValueError as e:
        # Sometimes clients can fail and corrupt these two-part sends.
        log.warning("Failed to recv from %r: %r", collector_sock, e)
        return None

    try:
        nonblocking_stream.check_meta_version(event_meta)
    except ValueError:
        log.warning("Failed to decode event due to version mismatch")
        return None

    # See nonblocking_stream.network for how this is packed.
    try:
        _, stream_name, partition_key = struct.unpack(
            nonblocking_stream.META_STRUCT_FMT, event_meta)
    except struct.error as e:
        # A meta frame of the wrong size must not take the daemon down.
        log.warning("Failed to unpack event metadata: %r", e)
        return None

    return stream_name, {'Data': event_data, 'PartitionKey': partition_key}


def _flush_messages(waiting_messages, options, output_file):
    """Write the queued messages to the debug file or to their streams."""
    if options.skip_kinesis:
        _write_messages_to_file(waiting_messages, output_file)
    else:
        _write_messages_to_streams(waiting_messages)


def main():
    """Run the tritond collector loop until SIGTERM/SIGINT is received.

    Binds a zmq PULL socket on the configured host/port, queues incoming
    events per stream, and flushes the queue roughly every
    ``POLL_LOOP_TIMEOUT_MS`` milliseconds. With ``--skip-kinesis`` the events
    are written to ``--output_file`` (or stdout) instead of the streams.
    Messages still queued when the loop stops are flushed before the process
    exits with status 0.
    """
    options = _build_arg_parser().parse_args()

    setup_logging(options)

    # A one-element list so the nested signal handler can flip the flag.
    continue_running = [True]

    def handle_sigterm(signum, frame):
        log.info("Exiting")
        continue_running[0] = False

    signal.signal(signal.SIGTERM, handle_sigterm)
    signal.signal(signal.SIGINT, handle_sigterm)

    zmq_context = zmq.Context()
    poller = zmq.Poller()

    collect_host = "{}:{}".format(*config.get_zmq_config())
    log.info("Initializing collector port %s", collect_host)
    collector_sock = zmq_context.socket(zmq.PULL)
    collector_sock.bind("tcp://%s" % collect_host)
    poller.register(collector_sock, zmq.POLLIN)

    output_file = None
    owns_output_file = False
    if options.skip_kinesis:
        if options.output_file is None:
            # The payloads are binary; on Python 3 the byte-oriented stream
            # lives at sys.stdout.buffer, on Python 2 stdout itself is fine.
            output_file = getattr(sys.stdout, 'buffer', sys.stdout)
        else:
            output_file = open(options.output_file, 'wb')
            owns_output_file = True

    last_write = time.time()
    waiting_messages = defaultdict(list)

    try:
        log.info("Starting IO Loop")
        while continue_running[0]:
            log.debug("Poll")

            try:
                ready = dict(poller.poll(POLL_LOOP_TIMEOUT_MS))
            except (KeyboardInterrupt, SystemExit):
                continue_running[0] = False
                break
            except zmq.ZMQError as e:
                if e.errno == errno.EINTR:
                    # If this is from a SIGTERM, we have a handler for that
                    # and the loop should exit gracefully.
                    continue
                raise

            log.debug("Poller returned: %r", ready)

            if collector_sock in ready:
                event = _receive_event(collector_sock)
                if event is not None:
                    stream_name, message = event
                    waiting_messages[stream_name].append(message)

            # Checked on every iteration, including ones where a malformed
            # event was dropped, so good messages are not held back.
            elapsed_ms = (time.time() - last_write) * 1000
            flush_due = (
                elapsed_ms > POLL_LOOP_TIMEOUT_MS or not continue_running[0])
            if waiting_messages and flush_due:
                _flush_messages(waiting_messages, options, output_file)
                waiting_messages = defaultdict(list)
                last_write = time.time()

        # The loop can stop between flushes (e.g. a signal interrupting the
        # poll); write whatever is still queued instead of dropping it.
        if waiting_messages:
            _flush_messages(waiting_messages, options, output_file)
    finally:
        collector_sock.close(0)
        if owns_output_file:
            output_file.close()

    sys.exit(0)


# Start the daemon only when this file is executed as a script.
if __name__ == '__main__':
    main()
