#!/usr/bin/env python
# encoding: utf-8

from klag import KafkaConsumerLag

import sys
import time
import logging
import json
import argparse


def parse_groups(value):
    """Decode the JSON text given to ``--groups``.

    Used as an argparse ``type=`` callable, so it receives the raw
    command-line string (or the string default) and returns the decoded
    JSON value.

    :param value: JSON text, e.g. ``'{"<group_id>":["<topic>"]}'``.
    :returns: The decoded JSON value.
    :raises ValueError: If ``value`` is not valid JSON.
    """
    # No ``encoding`` keyword: UTF-8 is already the default for byte strings
    # on Python 2, and the keyword raises TypeError on Python 3.9+.
    return json.loads(value)

# Command-line interface. Options are registered in the order they should
# appear in ``--help``.
parser = argparse.ArgumentParser(description='Kafka 0.8.2+ consumer monitoring.',
                                 epilog="http://www.github.com/andrewkcarter/klag")

# Where to connect.
parser.add_argument(
    '-b', '--brokers',
    metavar='LIST',
    type=str,
    default='localhost:9092',
    help="Comma separated list of Kafka brokers",
)

# Which consumer groups / topics to check.
parser.add_argument(
    '-g', '--groups',
    metavar='JSON',
    type=parse_groups,  # the string default below is decoded by this too
    default='{}',
    help="Consumer groups and list of topics for each group to check (even if dead). JSON structured as '{\"<group_id>\":[\"<topic>\"]}'",
)
parser.add_argument(
    '--groups-file',
    metavar='FILE',
    type=argparse.FileType('r'),
    help="JSON file containing consumer groups and list of topics for each group to check (even if dead). JSON structured as '{\"<group_id>\":[\"<topic>\"]}'",
)
parser.add_argument(
    '-d', '--discover',
    action='store_true',
    help="Include all active consumer groups and topics in output",
)
parser.add_argument(
    '-c', '--cache',
    action='store_true',
    help="Consumer groups that go dead (all consumers disconnect) will continue to be checked. Used with '-s'",
)

# What to report and how often.
parser.add_argument(
    '-p', '--partitions',
    action='store_true',
    help="Include partition metrics in output",
)
parser.add_argument(
    '-s', '--seconds',
    metavar='N',
    type=int,
    help="Repeat check every N seconds",
)

# Where and how to write the report.
parser.add_argument(
    '-o', '--output-file',
    metavar='FILE',
    type=argparse.FileType('a'),
    default=sys.stdout,
    help="Write output to a file",
)
parser.add_argument(
    '-f', '--format',
    metavar="FORMAT",
    choices=['default', 'json', 'json-pretty', 'json-discrete'],
    default='default',
    help="Write output using the specified format structure {default,json,json-pretty,json-discrete}",
)

# Diagnostics.
parser.add_argument(
    '--log-level',
    metavar='LEVEL',
    default='ERROR',
    help="Set the application logging level",
)
parser.add_argument(
    u'--version',
    action=u'version',
    version=u'klag 1.0.0',
)

# Parse the process arguments once; the rest of the script reads ``args``.
args = parser.parse_args()

# The level name is upper-cased so lower-case input is accepted as well.
logging.basicConfig(level=args.log_level.upper())


def format_result(format, partitions, results):
    """Render the results of a lag check as text.

    :param format: One of ``'json'``, ``'json-pretty'``, ``'json-discrete'``;
        any other value selects the human-readable table.
    :param partitions: When false, the ``'partitions'`` entry of every topic
        is removed from ``results`` (in place) before rendering, so no
        per-partition detail is emitted.
    :param results: Mapping of group name to a dict with a ``'state'`` entry
        and an optional ``'topics'`` mapping. Each topic dict has a
        ``'consumer_lag'`` entry and an optional ``'partitions'`` mapping
        whose values carry ``'offset_first'``, ``'offset_consumed'``,
        ``'offset_last'`` and ``'lag'``.
    :returns: The rendered text, without a trailing newline.
    """
    if not partitions:
        # Tolerate groups without topics and topics without partitions; the
        # rendering code below treats both keys as optional too.
        for g_data in results.values():
            for t_data in (g_data.get(u'topics') or {}).values():
                t_data.pop(u'partitions', None)

    if format == u'json':
        return json.dumps(results)
    if format == u'json-pretty':
        return json.dumps(results, indent=4)
    if format == u'json-discrete':
        return u'\n'.join(json.dumps(record)
                          for record in _discrete_records(results))
    return u'\n'.join(_table_lines(results))


def _discrete_records(results):
    """Yield one flat dict per partition, or per topic/group lacking finer detail."""
    for group, g_data in results.items():
        # A fresh dict per level keeps fields of a previous group, topic or
        # partition from leaking into records that do not have them.
        group_record = {u'group': group, u'state': g_data[u'state']}

        topics = g_data.get(u'topics') or {}
        if not topics:
            yield group_record
            continue

        for topic, t_data in topics.items():
            topic_record = dict(group_record)
            topic_record[u'topic'] = topic
            topic_record[u'consumer_lag'] = t_data[u'consumer_lag']

            topic_partitions = t_data.get(u'partitions') or {}
            if not topic_partitions:
                yield topic_record
                continue

            for partition, p_data in topic_partitions.items():
                partition_record = dict(topic_record)
                partition_record[u'partition'] = partition
                partition_record[u'offset_consumed'] = p_data[u'offset_consumed']
                partition_record[u'offset_first'] = p_data[u'offset_first']
                partition_record[u'lag'] = p_data[u'lag']
                partition_record[u'offset_last'] = p_data[u'offset_last']
                yield partition_record


def _table_lines(results):
    """Return the lines of the human-readable table, header first."""
    body = [u'{:=<80}'.format(u'')]
    # Only known after walking the data; selects which header is used.
    has_partitions = False

    if results:
        for group, g_data in results.items():
            body.append(u'{:68.66}{:>12}'.format(
                group, u'[' + g_data[u'state'].upper() + u']'))

            topics = g_data.get(u'topics') or {}
            if not topics:
                body.append(u'{:10}None'.format(u''))
                continue

            for topic, t_data in topics.items():
                body.append(u'{:10}{:58.56}{:>12}'.format(
                    u'', topic, t_data[u'consumer_lag']))

                for partition, p_data in (t_data.get(u'partitions') or {}).items():
                    has_partitions = True
                    body.append(u'{:20}{:12}{:>12}{:>12}{:>12}{:>12}'.format(
                        u'',
                        str(partition),
                        p_data[u'offset_first'],
                        p_data[u'offset_consumed'],
                        p_data[u'offset_last'],
                        p_data[u'lag']))
    else:
        body.append(u'No consumer groups')

    if has_partitions:
        header = u'{:10}{:10}{:12}{:>12}{:>12}{:>12}{:>12}'.format(
            u'Group', u'Topic', u'Partition', u'Earliest', u'Consumed', u'Latest', u'Remaining')
    else:
        header = u'{:10}{:10}{:>60}'.format(u'Group', u'Topic', u'Remaining')

    return [header] + body

# Build the group -> topics map to check: entries from --groups-file first,
# then --groups applied on top, so command-line entries replace file entries
# that share a group id.
group_topics = {}

if args.groups_file:
    # Close the file even when its content is not valid JSON. No ``encoding``
    # keyword: it is redundant on Python 2 and rejected on Python 3.9+.
    try:
        group_topics = json.load(args.groups_file)
    finally:
        args.groups_file.close()

if args.groups:
    group_topics.update(args.groups)

# Discovery is forced on when no groups were given explicitly.
discover = not group_topics or args.discover

# Created only after the inputs above were parsed successfully, so a bad
# groups file cannot leave a client behind that nobody closes.
klag = KafkaConsumerLag(args.brokers)


def _check_and_report():
    """Run one check, write its formatted output, and return the raw results."""
    results = klag.check(group_topics, discover)
    output = format_result(args.format, args.partitions, results)

    if output:
        args.output_file.write(output + '\n')

    return results


try:
    if args.seconds:
        while True:
            # Schedule from the start of the check so its duration does not
            # stretch the interval.
            next_check = time.time() + args.seconds

            results = _check_and_report()
            args.output_file.flush()

            if args.cache:
                # Remember every group/topic seen so it is still passed to the
                # next check after its consumers have gone away.
                for name, group in results.items():
                    group_topics[name] = list(group.get(u'topics') or ())

            remaining = next_check - time.time()
            if remaining > 0:
                time.sleep(remaining)
    else:
        _check_and_report()
except KeyboardInterrupt:
    # Ctrl-C is the normal way to stop the repeating mode.
    pass
finally:
    # Flush the output even if closing the client fails.
    try:
        klag.close()
    finally:
        args.output_file.flush()
