#!/usr/bin/env python

"""
esmond-ps-pipe - take piped json output from bwctl and load into an esmond 
perfsonar MA. To wit:

bwctl -c lbl-pt1.es.net -s llnl-pt1.es.net -T iperf3 --parsable --verbose |& esmond-ps-pipe --url http://ps-archive.es.net --user mgoode --key api_key_string

See further documentation in README.rst or use the --help flag for a full
list of arguments.
"""

import datetime
import json
import logging
import os
import socket
import sys
import time

from optparse import OptionParser

from esmond_client.perfsonar.post import (
    EventTypeBulkPost,
    EventTypeBulkPostException,
    EventTypePost,
    EventTypePostException, 
    MetadataPost, 
    MetadataPostException, 
)

import pprint

# Shared pretty printer; RecursiveDataObject.__str__ uses it to render
# the wrapped data.
pp = pprint.PrettyPrinter(indent=4)

class EsmondPipeException(Exception):
    """
    Error type used by esmond-ps-pipe.

    The constructor argument is kept on the value attribute and str()
    of the exception is the repr() of that value.
    """
    def __init__(self, value):
        self.value = value

    def __str__(self):
        return '{0!r}'.format(self.value)

class EsmondPipeWarning(Warning):
    """Warning category for esmond-ps-pipe."""

def setup_log(log_path=None):
    """
    Build and return the "esmond-ps-pipe" logger, set to INFO level.

    log_path: directory to write esmond-ps-pipe.log into.  When it is
    empty/None a plain logging.StreamHandler is attached instead of a
    file handler.  The caller is responsible for making sure the
    directory is valid.

    Every record is formatted as 'ts=<asctime> <message>'.
    """
    if log_path:
        # it's on the caller to make sure log_path is valid.
        handler = logging.FileHandler(
            '{0}/esmond-ps-pipe.log'.format(log_path))
    else:
        handler = logging.StreamHandler()

    handler.setFormatter(logging.Formatter('ts=%(asctime)s %(message)s'))

    logger = logging.getLogger("esmond-ps-pipe")
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

    return logger

#
# Data objects
#

class RecursiveDataObject(object):
    """
    Attribute-style wrapper around parsed JSON.

    Reading an attribute looks the name up in the wrapped dict:
    nested dicts come back as RecursiveDataObject instances, lists are
    wrapped element by element, and a value stored under the key
    'time' is converted to a datetime.  A key that is not present
    reads as None.  Assigning an attribute stores the value in the
    wrapped dict.
    """
    def __init__(self, initial=None):
        """
        initial: the data to wrap.  Anything with an items attribute
        is wrapped as-is (not copied); any other value yields an
        empty object.

        Raises EsmondPipeException if any value, at any depth, can
        not be wrapped.
        """
        self.__dict__['_data'] = {}

        if hasattr(initial, 'items'):
            self.__dict__['_data'] = initial
            try:
                # Recurse through all attrs at all levels to make
                # sure the wrapping is ok so we can check the parse
                # once.
                for key in self._data.keys():
                    self._wrap(self._data.get(key), key)
            except Exception as e:
                raise EsmondPipeException(str(e))

    def _wrap(self, data, name):
        """
        Return data in its outgoing form; name is the key it was
        stored under (None for the elements of a list).
        """
        # a missing/null value stays None no matter what it is called -
        # otherwise reading an absent 'time' key would blow up in the
        # datetime conversion.
        if data is None:
            return None
        # dict => wrapper object
        if isinstance(data, dict):
            return RecursiveDataObject(data)
        # list => each element wrapped on its own, so dicts become
        # wrapper objects while scalars are passed through untouched.
        elif isinstance(data, list):
            return [ self._wrap(x, None) for x in data ]
        # modify outgoing named attributes
        elif name in ('time',):
            return self._to_datetime(data)

        return data

    def _to_datetime(self, d):
        """
        Convert an int (taken to be unixtime) or a string like
        'Thu, 12 Mar 2015 20:37:34 GMT' to a naive datetime.
        """
        if isinstance(d, int):
            # presume unixtime.  Epoch arithmetic yields the same naive
            # UTC value as utcfromtimestamp() without relying on that
            # (now deprecated) call.
            return datetime.datetime(1970, 1, 1) + datetime.timedelta(seconds=d)
        else:
            # explicit %H:%M:%S rather than %X, whose layout depends
            # on the current locale.
            return datetime.datetime.strptime(d, '%a, %d %b %Y %H:%M:%S GMT')

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails.  If that
        # happens for _data itself (instance created without
        # __init__, e.g. by copy/pickle) looking it up again below
        # would recurse without end, and answering None to special
        # method probes makes protocols think they are supported.
        if name == '_data' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        return self._wrap(self._data.get(name, None), name)

    def __setattr__(self, name, value):
        self.__dict__['_data'][name] = value

    def __str__(self):
        return pp.pformat(self.to_dict())

    def to_dict(self):
        """Return the wrapped dict itself (not a copy)."""
        return self._data

class BwctlPayload(object):
    """
    Holder for the pieces extracted from the bwctl output: the tool
    name, the source and destination addresses and the parsed JSON
    payload.
    """
    def __init__(self):
        self.tool_name = None
        self.input_source = None
        self.input_destination = None
        self.payload = None

        # filled in on first access of the measurement_agent property.
        self._measurement_agent = None

    def _candidate_families(self):
        """
        Return the address families input_destination resolves under,
        IPv6 first, as a (possibly empty) list.
        """
        families = []
        for family in (socket.AF_INET6, socket.AF_INET):
            try:
                socket.getaddrinfo(self.input_destination, None, family)
            except socket.gaierror:
                continue
            families.append(family)
        return families

    @property
    def measurement_agent(self):
        """
        Local address the host uses to reach input_destination.

        Found by connecting a datagram socket to the destination and
        reading the socket's own name; the result is cached.

        Raises EsmondPipeException if the destination can not be
        resolved or no socket could be connected to it.
        """
        if self._measurement_agent:
            return self._measurement_agent

        last_error = None

        for family in self._candidate_families():
            try:
                s = socket.socket(family, socket.SOCK_DGRAM)
            except socket.error as e:
                last_error = e
                continue

            try:
                s.connect((self.input_destination, 80))
                self._measurement_agent = s.getsockname()[0]
                return self._measurement_agent
            except socket.error as e:
                # e.g. destination has an IPv6 address but this host
                # has no IPv6 route - fall through to the next family.
                last_error = e
            finally:
                # always release the socket, connected or not.
                s.close()

        raise EsmondPipeException(
            'could not determine measurement agent for {0}: {1}'.format(
                self.input_destination, last_error))

#
# Code for all tool types
#

def id_and_extract(data, _log):
    """
    Identify the piped input and isolate the json part.

    data: iterable of lines of bwctl output.
    _log: callable taking (event, message).

    The tool name and the sender/receiver addresses are read from the
    'bwctl:' lines of the --verbose output; the JSON is whatever sits
    between the 'bwctl: start_tool:' and 'bwctl: stop_tool:' lines.

    Returns a populated BwctlPayload, or None (after logging why) if
    any metadata field or the JSON payload is missing.
    """
    result = BwctlPayload()
    json_lines = []
    in_json = False

    for line in data:
        stripped = line.strip()

        # blank lines carry nothing
        if not stripped:
            continue
        # outside of the json only the bwctl: lines are of interest
        if not (in_json or line.startswith('bwctl:')):
            continue

        # metadata from the --verbose output
        if line.startswith('bwctl: Using tool:'):
            result.tool_name = line.split()[3]
        if stripped.endswith('as the address for remote sender'):
            result.input_source = line.split()[2]
        if stripped.endswith('as the address for remote receiver'):
            result.input_destination = line.split()[2]

        # the stop token ends the scan, the start token begins
        # gathering with the line that follows it.
        if line.startswith('bwctl: stop_tool:'):
            break
        if in_json:
            json_lines.append(line)
        elif line.startswith('bwctl: start_tool:'):
            in_json = True

    # never saw the start token?
    if not in_json:
        _log('id_and_extract.error', 'did not find bwctl: start_tool token to extract JSON from output')
        _log('id_and_extract.error', 'bwctl must be run with the --verbose flag')

    # report every piece of metadata that was not picked up
    missing = [
        f for f in ('tool_name', 'input_source', 'input_destination',)
        if not getattr(result, f)
    ]
    for f in missing:
        _log('id_and_extract.error', 'could not extract {0}'.format(f))

    # parse the json
    json_block = ''.join(json_lines)
    try:
        result.payload = json.loads(json_block)
    except ValueError:
        _log('id_and_extract.error', 'could not load json_block: {0}'.format(json_block))

    # only hand back a complete payload
    if result.payload and not missing:
        return result
    return None

class PostEvents(object):
    """
    Wrapper around the EventTypeBulkPost class.

    Data points are queued with process_event_type() and sent with
    write().
    """
    def __init__(self, metadata, ts, options, _log):
        """
        metadata: object whose metadata_key identifies the record to post to
        ts: one ts used for all of the events
        options: option parser object (api_url, user, key,
            script_alias and verbose are read)
        _log: callable taking (event, message)
        """
        self.metadata = metadata
        self.ts = ts
        self.options = options
        self._log = _log

        bulk_kwargs = dict(
            username=options.user,
            api_key=options.key,
            script_alias=options.script_alias,
            metadata_key=metadata.metadata_key,
        )
        self.etb = EventTypeBulkPost(options.api_url, **bulk_kwargs)

    def process_event_type(self, event_type, val, ts):
        """Queue val for event_type at time ts; logged when verbose."""
        verbose = self.options.verbose
        if verbose:
            self._log('process_event_type.start', '{0}'.format(event_type))

        self.etb.add_data_point(event_type, ts, val)

    def write(self):
        """Post the queued data; a bulk post failure is logged, not raised."""
        try:
            self.etb.post_data()
        except EventTypeBulkPostException as err:
            self._log('write.warning', str(err))

#
# Code to process the iperf3 data
#

def _iperf3_metadata_args(o, payload):
    """
    Assemble the metadata args and the event types for the
    metadata depending on what kind of iperf3 input we got.

    o: wrapped iperf3 JSON (start.connected and start.test_start are read)
    payload: object carrying tool_name, input_source,
        input_destination and measurement_agent

    Returns a (metadata args dict, list of event type names) tuple.
    """
    start = o.start
    endpoints = start.connected[0]
    test_start = start.test_start

    mda = {
        'subject_type': 'point-to-point',
        'source': endpoints.local_host,
        'destination': endpoints.remote_host,
        'tool_name': 'bwctl/{0}'.format(payload.tool_name),
        'input_source': payload.input_source,
        'input_destination': payload.input_destination,
        'measurement_agent': payload.measurement_agent,
        'ip_transport_protocol': test_start.protocol,
        'time_duration': test_start.duration,
    }

    # always present
    event_types = [
        'throughput',
        'throughput-subintervals',
        'packet-retransmits-subintervals',
    ]

    # per-stream event types only when more than one stream was run
    if test_start.num_streams > 1:
        event_types.extend([
            'streams-packet-retransmits',
            'streams-packet-retransmits-subintervals',
            'streams-throughput',
            'streams-throughput-subintervals',
        ])

    # protocol specific event types
    protocol = mda['ip_transport_protocol']
    if protocol == 'TCP':
        event_types.append('packet-retransmits')
    if protocol == 'UDP':
        event_types.extend(
            ['packet-count-lost', 'packet-count-sent', 'packet-loss-rate'])

    return mda, event_types

def process_iperf3(o, payload, options, _log):
    """
    Process iperf3 input: post the metadata for the test, then bulk
    post the event type values pulled from the iperf3 JSON.

    o: wrapped iperf3 JSON (start, end and intervals are read)
    payload: BwctlPayload with the metadata taken from the bwctl output
    options: option parser object (api_url, user, key, script_alias
        and verbose are read)
    _log: callable taking (event, message)

    Returns None.  A failed metadata post is logged and processing
    stops without raising.

    Raises EsmondPipeException if the test protocol is neither TCP
    nor UDP.
    """
    _log('process_iperf3.start', 'begin')

    test_start = o.start.test_start
    protocol = test_start.protocol

    # Only TCP and UDP results can be mapped to event types below.
    # Refuse anything else before any metadata is written rather than
    # failing half way through on an unset throughput value.
    if protocol not in ('TCP', 'UDP'):
        raise EsmondPipeException(
            'unsupported iperf3 protocol: {0}'.format(protocol))

    multi_stream = test_start.num_streams > 1

    # metadata

    mda, event_types = _iperf3_metadata_args(o, payload)

    mp = MetadataPost(options.api_url, username=options.user,
            api_key=options.key, script_alias=options.script_alias,
            **mda)

    for et in event_types:
        mp.add_event_type(et)

    mp.add_freeform_key_value('bw-parallel-streams', test_start.num_streams)
    mp.add_freeform_key_value('bw_ignore_first_seconds', test_start.omit)

    try:
        metadata = mp.post_metadata()
    except MetadataPostException as e:
        _log('process_iperf3.error', 'MetadataPost failed, abort processing, not updating record state - ERROR: {0}'.format(str(e)))
        return None

    # event type values

    # timestamp for all event type inserts
    ts = o.start.timestamp.timesecs

    # Wrapper around event bulk post object
    post = PostEvents(metadata, ts, options, _log)

    end = o.end
    intervals = o.intervals

    # throughput - TCP reports it on the received summary, UDP on the
    # plain summary.
    if protocol == 'TCP':
        tp = end.sum_received.bits_per_second
    else:
        tp = end.sum.bits_per_second

    post.process_event_type('throughput', tp, ts)

    # throughput-subintervals
    tp_subints = [
        dict(
            start=i.sum.start,
            duration=i.sum.seconds,
            val=i.sum.bits_per_second
        )
        for i in intervals
    ]

    post.process_event_type('throughput-subintervals', tp_subints, ts)

    # packet-count-lost
    # and
    # packet-count-sent
    # and
    # packet-loss-rate
    if protocol == 'UDP':
        pl = end.sum.lost_packets
        ps = end.sum.packets
        # lost
        if pl is not None:
            post.process_event_type('packet-count-lost', pl, ts)
        # sent
        if ps is not None:
            post.process_event_type('packet-count-sent', ps, ts)
        # loss rate - needs both counts
        if pl is not None and ps is not None:
            lr = dict(numerator=pl, denominator=ps)
            post.process_event_type('packet-loss-rate', lr, ts)

    # packet-retransmits
    if protocol == 'TCP':
        prt = end.sum_sent.retransmits
        post.process_event_type('packet-retransmits', prt, ts)

    # packet-retransmits-subintervals
    pr_subints = [
        dict(
            start=i.sum.start,
            duration=i.sum.seconds,
            val=i.sum.retransmits
        )
        for i in intervals
    ]

    post.process_event_type('packet-retransmits-subintervals', pr_subints, ts)

    # streams-packet-retransmits
    # and
    # streams-packet-retransmits-subintervals
    # and
    # streams-throughput
    # and
    # streams-throughput-subintervals
    if multi_stream:
        # retrans
        sp_retrans = [ s.sender.retransmits for s in end.streams ]

        post.process_event_type('streams-packet-retransmits', sp_retrans, ts)

        # retrans-subint: one sub-list of per stream values per interval
        sp_retrans_subint = [
            [
                dict(
                    start=ii.start,
                    duration=ii.seconds,
                    retransmits=ii.retransmits
                )
                for ii in i.streams
            ]
            for i in intervals
        ]

        post.process_event_type('streams-packet-retransmits-subintervals', sp_retrans_subint, ts)

        # throughput
        s_throughput = [ s.receiver.bits_per_second for s in end.streams ]

        post.process_event_type('streams-throughput', s_throughput, ts)

        # throughput-subint: same shape as retrans-subint
        s_tput_subint = [
            [
                dict(
                    start=ii.start,
                    duration=ii.seconds,
                    throughput=ii.bits_per_second
                )
                for ii in i.streams
            ]
            for i in intervals
        ]

        post.process_event_type('streams-throughput-subintervals', s_tput_subint, ts)

    post.write()

    _log('process_iperf3.end', 'finished')


# Maps the tool name reported by bwctl to the function that processes
# that tool's output.
typemap = {
    'iperf3': process_iperf3,
}

def main():
    """
    Command line entry point: read bwctl output from stdin, extract
    the metadata and JSON, and hand the result to the handler
    registered for the tool in typemap.

    Exits with status 0 on success, 2 on bad arguments (via
    parser.error) and -1 when the input can not be extracted,
    wrapped, or processed, or when no handler exists for the tool.
    """
    usage = ' bwctl ... |& %prog [ -u USER -k API_KEY | -U ESMOND_REST_URL | -v ]'
    parser = OptionParser(usage=usage)
    parser.add_option('-U', '--url', metavar='ESMOND_REST_URL',
            type='string', dest='api_url',
            help='URL for the REST API (default=%default) - required.',
            default='http://localhost:8000')
    parser.add_option('-u', '--user', metavar='USER',
            type='string', dest='user', default='',
            help='POST interface username.')
    parser.add_option('-k', '--key', metavar='API_KEY',
            type='string', dest='key', default='',
            help='API key for POST operation.')
    parser.add_option('-s', '--script_alias', metavar='URI_PREFIX',
            type='string', dest='script_alias', default='/',
            help='Set the script_alias arg if the perfsonar API is configured to use one (default=%default which means none set).')
    parser.add_option('-v', '--verbose',
        dest='verbose', action='store_true', default=False,
        help='Verbose output.')
    # a StreamHandler created without a stream writes to stderr, so
    # that is what the help text promises.
    parser.add_option('-l', '--log_dir', metavar='DIR',
                type='string', dest='logdir', default='',
                help='Write log output to specified directory - if not set, log goes to stderr.')
    options, _ = parser.parse_args()

    log_path = None

    if options.logdir:
        log_path = os.path.normpath(options.logdir)
        if not os.path.exists(log_path):
            parser.error('{0} log path does not exist.'.format(log_path))

    log = setup_log(log_path)

    def _log(e, s):
        # one key=value style line per event, id is the current unixtime
        log.info('event={e} id={gid} {s}'.format(e=e, gid=int(time.time()), s=s))

    if not options.user or not options.key:
        # this exits with error status 2 FYI
        parser.error('both --user and --key args are required.')

    _log('main.start', 'reading input')

    data = sys.stdin.readlines()

    payload = id_and_extract(data, _log)

    if not payload:
        _log('main.fatal', 'could not extract metadata and valid json from input')
        if options.verbose:
            _log('main.debug', '{0}'.format(''.join(data)))
        _log('main.fatal', 'exiting')
        sys.exit(-1)

    try:
        o = RecursiveDataObject(payload.payload)
    except EsmondPipeException as e:
        _log('main.fatal', 'error encapsulating json - ERROR: {0}'.format(str(e)))
        _log('main.fatal', 'exiting')
        sys.exit(-1)

    # Look the handler up before calling it so that a KeyError raised
    # while processing is not mistaken for a missing handler.
    handler = typemap.get(payload.tool_name)

    if handler is None:
        _log('main.fatal', 'there is no handler for input type: {0}'.format(payload.tool_name))
        _log('main.fatal', 'exiting')
        sys.exit(-1)

    # process the encapsulated data
    try:
        handler(o, payload, options, _log)
    except EsmondPipeException as e:
        _log('main.error', 'unable to process input: {0}'.format(str(e)))
        _log('main.fatal', 'exiting')
        sys.exit(-1)

    _log('main.exit', 'success')

    sys.exit(0)

if __name__ == '__main__':
    main()