#!/bin/python

# -*- coding: utf-8 -*-
"""
tritond
~~~~~~~~

Logging daemon for non-blocking triton events.

Adapted from https://github.com/rhettg/BlueOx/blob/master/bin/oxd

"""
import argparse
import errno
import sys
import logging
import signal
import struct
import os
import time
from collections import defaultdict
import msgpack
import json

import zmq
import pystatsd

from triton import nonblocking_stream
from triton.stream import get_stream
from triton import config, errors

# Environment variable that, when set, overrides the config file location.
ENV_VAR_TRITON_CONFIG_PATH = 'TRITON_CONFIG'
DEFAULT_CONFIG_PATH = '/etc/triton.yaml'

# Upper bound on queued messages, applied as the collector socket's
# high-water mark in main(), so an unreachable Kinesis cannot make us run out
# of memory. Sized from a rough worst case of ~3 KB messages fitting in about
# 10 MB.
MAX_QUEUED_MESSAGES = 3500

# Writes to Kinesis are batched: messages are collected for about this long
# before a flush. The same value is the poll() timeout of the IO loop.
POLL_LOOP_TIMEOUT_MS = 100

# Metric names for pystatsd (configuration: https://github.com/postmates/pystatsd).
# pystatsd raises nothing when no statsd server is running, so metrics are
# best-effort.
STATSD_PREFIX = 'tritond.'
STATSD_EVENTCOUNT = '{}eventcount.'.format(STATSD_PREFIX)
STATSD_LOOPTIME = '{}write_loop.timing'.format(STATSD_PREFIX)

log = logging.getLogger("triton.d")

# Lazily populated caches; see get_triton_config() and load_or_get_stream().
_triton_config = None
_streams = {}

# Version byte identifying a JSON-encoded meta frame. It is ord('{') (0x7B):
# the whole frame is parsed as JSON, so a JSON object's opening brace doubles
# as its version tag.
META_STRUCT_VERSION_JSON = ord('{')


def setup_logging(options):
    """Configure root logging on stdout with a level picked from ``-v``.

    :param options: parsed arguments; ``options.verbose`` is the list built
        by repeated ``-v`` flags. Two or more selects DEBUG, exactly one
        selects INFO, none selects WARNING.
    """
    count = len(options.verbose)
    if count > 1:
        level = logging.DEBUG
    else:
        level = logging.INFO if options.verbose else logging.WARNING

    # Logs go straight to stdout: tritond is normally run under a process
    # manager (upstart, supervisord, ...) that collects output. A log-file
    # option could be added here if anyone ever needs one.
    logging.basicConfig(
        stream=sys.stdout,
        level=level,
        format="%(asctime)s %(levelname)s:%(name)s: %(message)s",
    )


def get_triton_config():
    """Return the triton configuration, loading it on first use.

    The config path is read from the environment variable named by
    ``ENV_VAR_TRITON_CONFIG_PATH``, falling back to ``DEFAULT_CONFIG_PATH``.
    The result is cached in the module-level ``_triton_config``. A falsy
    cached value is not treated as loaded, so it is reloaded on the next call.

    :raises errors.TritonNotConfiguredError: if loading produced ``None``.
    """
    global _triton_config
    if not _triton_config:
        path = os.environ.get(ENV_VAR_TRITON_CONFIG_PATH, DEFAULT_CONFIG_PATH)
        _triton_config = config.load_config(path)

    if _triton_config is not None:
        return _triton_config
    raise errors.TritonNotConfiguredError(
        'Failed to load config for tritond')


def check_meta_version(meta):
    """Return the version byte found at the start of an event's meta frame.

    :param meta: raw meta frame (first part of the two-part zmq message).
    :returns: the integer version, which is either
        ``nonblocking_stream.META_STRUCT_VERSION`` or
        ``META_STRUCT_VERSION_JSON``.
    :raises ValueError: if ``meta`` is empty or carries an unknown version.
        The daemon's receive loop only catches ValueError while decoding,
        so every malformed frame must surface as one.
    """
    if not meta:
        # Indexing an empty frame would raise IndexError and crash the loop.
        raise ValueError('Empty meta frame')
    # Slice instead of index: on Python 3 ``meta[0]`` of bytes is an int,
    # which struct.unpack rejects; ``meta[:1]`` is a 1-byte string on both.
    value, = struct.unpack(">B", meta[:1])
    if value not in (
        nonblocking_stream.META_STRUCT_VERSION, META_STRUCT_VERSION_JSON
    ):
        raise ValueError(value)
    return value


def get_header_data(event_meta, version):
    """Decode ``(stream_name, partition_key)`` from an event's meta frame.

    :param event_meta: raw meta frame.
    :param version: version byte previously returned by
        ``check_meta_version``; selects packed-struct or JSON decoding.
    :returns: ``(stream_name, partition_key)``.
    :raises ValueError: for any frame that cannot be decoded, so the caller
        can drop bad events with a single ``except ValueError``.
    """
    if version == nonblocking_stream.META_STRUCT_VERSION:
        # See nonblocking_stream.network for how this is packed.
        try:
            _, stream_name, partition_key = struct.unpack(
                nonblocking_stream.META_STRUCT_FMT, event_meta)
        except struct.error:
            # A frame of the wrong size raises struct.error, which is not a
            # ValueError; translate it so one bad client can't kill tritond.
            raise ValueError('Cannot Parse Meta Struct')
    elif version == META_STRUCT_VERSION_JSON:
        try:
            meta_data = json.loads(event_meta)
        except Exception:
            # Deliberately broad: the frame is client-supplied, and besides
            # ValueError, pathological input can raise other errors (for
            # example, from excessive nesting).
            raise ValueError('Cannot Parse Meta JSON')
        try:
            stream_name = meta_data['stream_name']
            partition_key = meta_data['partition_key']
        except (KeyError, TypeError):
            # KeyError: a field is missing. TypeError: the JSON value is not
            # an object (list, string, number, null).
            raise ValueError('Cannot Parse Meta JSON')
    else:
        raise ValueError('Incorrect meta version')
    return stream_name, partition_key


def load_or_get_stream(stream_name):
    """Return the stream object for ``stream_name``, building it on first use.

    Streams are created with ``get_stream`` and the tritond config, then
    memoized in the module-level ``_streams`` dict. If creation raises, the
    exception propagates and nothing is cached.
    """
    if stream_name not in _streams:
        _streams[stream_name] = get_stream(stream_name, get_triton_config())
    return _streams[stream_name]


def _write_messages_to_streams(waiting_messages):
    """Send each stream's buffered batch with ``_put_many_packed``.

    :param waiting_messages: mapping of stream name to a list of
        ``{'Data': ..., 'PartitionKey': ...}`` dicts, as built by ``main``.

    Best-effort per stream:
    - A stream that is not configured has its batch dropped, with an error
      log.
    - A failed write is logged with its traceback and the remaining streams
      are still processed.
    - The per-stream event counter is incremented for every batch that
      reaches a configured stream, whether or not the write succeeded.
    """
    for stream_name, list_of_messages in waiting_messages.items():
        try:
            stream = load_or_get_stream(stream_name)
        except errors.StreamNotConfiguredError:
            log.error("Unable to get stream %s; dropping %d messages",
                      stream_name, len(list_of_messages))
            continue
        try:
            stream._put_many_packed(list_of_messages)
        except Exception:
            # Not a bare except: SystemExit/KeyboardInterrupt must propagate.
            # The stream name goes in the message because the configured log
            # format does not render ``extra`` fields.
            log.exception(
                "Tritond failed to write messages to stream %s",
                stream_name,
                extra={
                    'stream_name': stream_name,
                    'list_of_messages': list_of_messages
                }
            )
        pystatsd.increment(
            STATSD_EVENTCOUNT + stream_name,
            len(list_of_messages)
        )


def _write_messages_to_file(waiting_messages, file_obj):
    '''Debug sink: dump buffered batches to ``file_obj`` instead of Kinesis.

    For each stream, writes the msgpack-packed stream name followed by the
    raw ``Data`` payload of every message in that stream's batch. Partition
    keys are not written. The file is flushed once, after all streams.

    :param waiting_messages: mapping of stream name to a list of message
        dicts that each have a ``'Data'`` entry.
    :param file_obj: file-like object that must accept bytes (stdout or the
        file opened in ``'wb'`` mode for ``--output_file``).
    '''
    for stream_name, list_of_messages in waiting_messages.items():
        log.debug("writing debug output for stream %s", stream_name)
        file_obj.write(msgpack.packb(stream_name))
        file_obj.writelines(message['Data'] for message in list_of_messages)
    file_obj.flush()


def _build_arg_parser():
    """Return the argparse parser for tritond's command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--verbose', '-v',
        dest='verbose',
        action='append_const',
        const=True,
        default=list())
    parser.add_argument(
        '--skip-kinesis',
        dest='skip_kinesis',
        action='store_true',
        default=False,
        help="will skip sends to Kinesis; for debug purposes")
    parser.add_argument(
        '--output_file',
        dest='output_file',
        action='store',
        default=None,
        # Must be ONE string (adjacent literals, no comma): argparse formats
        # help text as a string, and a tuple here made ``--help`` crash.
        help=(
            'output file for incoming data; otherwise output to stdout. '
            'Only used with --skip-kinesis'
        )
    )
    return parser


def _flush_waiting_messages(waiting_messages, skip_kinesis, output_file):
    """Write one batch to the debug ``output_file`` or to the streams."""
    if skip_kinesis:
        _write_messages_to_file(waiting_messages, output_file)
    else:
        _write_messages_to_streams(waiting_messages)


def main():
    """Run the tritond collector until SIGTERM or SIGINT.

    Binds a ZeroMQ PULL socket on the address from ``config.get_zmq_config()``
    and reads two-part ``(meta, data)`` messages. Messages are grouped by
    stream name and flushed roughly every ``POLL_LOOP_TIMEOUT_MS`` ms:
    normally to the streams, or with ``--skip-kinesis`` to stdout or
    ``--output_file``. Anything still buffered when the loop stops is flushed
    before the socket, the zmq context and any output file are closed. Ends
    the process with ``sys.exit(0)``.
    """
    options = _build_arg_parser().parse_args()

    setup_logging(options)

    # Python 2 has no ``nonlocal``; a one-element list lets the signal
    # handler flip the flag that the loop reads.
    continue_running = [True]

    def handle_sigterm(signum, frame):
        log.info("Exiting")
        continue_running[0] = False

    signal.signal(signal.SIGTERM, handle_sigterm)
    signal.signal(signal.SIGINT, handle_sigterm)

    zmq_context = zmq.Context()
    poller = zmq.Poller()

    collect_host = "{}:{}".format(*config.get_zmq_config())
    log.info("Initializing collector port %s", collect_host)
    collector_sock = zmq_context.socket(zmq.PULL)
    collector_sock.hwm = MAX_QUEUED_MESSAGES
    collector_sock.bind("tcp://%s" % collect_host)
    poller.register(collector_sock, zmq.POLLIN)

    output_file = None
    if options.skip_kinesis:
        if options.output_file is None:
            output_file = sys.stdout
        else:
            output_file = open(options.output_file, 'wb')

    last_write = time.time()
    waiting_messages = defaultdict(list)

    log.info("Starting IO Loop")
    try:
        while continue_running[0]:
            log.debug("Poll")

            try:
                ready = dict(poller.poll(POLL_LOOP_TIMEOUT_MS))
            except (KeyboardInterrupt, SystemExit):
                continue_running[0] = False
                break
            except zmq.ZMQError as e:
                if e.errno == errno.EINTR:
                    # Interrupted by a signal. If it was SIGTERM/SIGINT, the
                    # handler cleared the flag and the loop test exits.
                    continue
                raise

            log.debug("Poller returned: %r", ready)

            if collector_sock in ready:
                try:
                    event_meta, event_data = collector_sock.recv_multipart()
                except ValueError as e:
                    # Clients can fail mid-send and corrupt these two-part
                    # messages, so the unpack above does not get two frames.
                    log.warning(
                        "Failed to recv from %r: %r", collector_sock, e)
                    continue

                try:
                    version = check_meta_version(event_meta)
                    stream_name, partition_key = get_header_data(
                        event_meta, version)
                except ValueError:
                    log.warning(
                        "Failed to decode event due to version mismatch")
                    continue

                waiting_messages[stream_name].append({
                    'Data': event_data,
                    'PartitionKey': partition_key
                })

            time_delta_ms = (time.time() - last_write) * 1000
            if waiting_messages and (
                time_delta_ms > POLL_LOOP_TIMEOUT_MS
                or not continue_running[0]
            ):
                _flush_waiting_messages(
                    waiting_messages, options.skip_kinesis, output_file)
                waiting_messages = defaultdict(list)
                last_write = time.time()
                pystatsd.timing(STATSD_LOOPTIME, time_delta_ms)

        # Leaving the loop via ``break`` or an EINTR ``continue`` skips the
        # in-loop flush; write whatever is still buffered instead of
        # dropping it.
        if waiting_messages:
            _flush_waiting_messages(
                waiting_messages, options.skip_kinesis, output_file)
    finally:
        collector_sock.close(0)
        zmq_context.term()
        if output_file is not None and output_file is not sys.stdout:
            output_file.close()

    sys.exit(0)


if __name__ == "__main__":
    # Script entry point; main() terminates the process via sys.exit().
    main()
