#!/usr/bin/env python3
import pyshark
import time
import asyncio
import logging
import sys
import json
import netifaces
import os
import psutil
from pxgrid_pyshark import endpointsdb
from pxgrid_pyshark import parser
from pxgrid_util import WebSocketStomp
from pxgrid_util import Config
from pxgrid_util import create_override_url
from pxgrid_util import PXGridControl
from websockets.exceptions import WebSocketException
from signal import SIGINT, SIGTERM

# Module-level logger. Handlers and the DEBUG level are only attached in the
# __main__ block below, and only when verbose output is requested.
logger = logging.getLogger(__name__)

# BPF capture filter handed to capture_live_packets(): TCP ports 80/8080 and
# UDP ports 1900, 5060 and 5353, with IPv6 traffic excluded.
default_bpf_filter = "(tcp port 80 || tcp port 8080 || udp port 1900 || udp port 5060 || udp port 5353) and not ip6"
# NOTE: this rebinds the name `parser` from the callable imported from
# pxgrid_pyshark (presumably a class) to one shared instance of it; the
# imported object is no longer reachable under this name from here on.
parser = parser()
# Loop flag for the capture loop in __main__; cleared by the signal handler.
capture_running = False
# Number of capture instances that ran to completion (see capture_live_packets).
capture_count = 0

## Create dict of supported protocols and their appropriate inspection functions
## Keys are looked up by pyshark layer name (layer.layer_name) in process_packet().
packet_callbacks = {
    'sip': parser.parse_sip,
    'ssdp': parser.parse_ssdp,
    'mdns': parser.parse_mdns,
    'http': parser.parse_http,
    'xml': parser.parse_xml
}

## Return a list of processes matching 'name' (https://psutil.readthedocs.io/en/latest/)
def find_procs_by_name(name):
    '''
    Return the running processes whose name contains ``name``.

    The match is a case-sensitive substring test against the process name
    reported by psutil, not an exact comparison.

    Args:
        name: Text to look for inside each process name.

    Returns:
        A list of psutil.Process objects; each one carries the collected
        name in ``proc.info['name']``.
    '''
    matches = []
    for p in psutil.process_iter(['name']):
        proc_name = p.info['name']
        # psutil stores None when the name could not be read (access denied or
        # zombie process); skip those instead of failing with TypeError.
        if proc_name is not None and name in proc_name:
            matches.append(p)
    return matches

## Kill a process based on provided PID value (https://psutil.readthedocs.io/en/latest/)
def kill_proc_tree(pid, sig=SIGTERM, include_parent=True, timeout=None, on_terminate=None):
    '''
    Send ``sig`` to every descendant of ``pid`` and wait for them to exit.

    Args:
        pid: PID at the root of the process tree; must not be this process.
        sig: Signal to deliver to each process.
        include_parent: When true, the process ``pid`` itself is signalled too.
        timeout: Passed to psutil.wait_procs() as its timeout.
        on_terminate: Passed to psutil.wait_procs() as its callback.

    Returns:
        The ``(gone, alive)`` pair produced by psutil.wait_procs().
    '''
    assert pid != os.getpid(), "won't kill myself"
    root = psutil.Process(pid)
    targets = root.children(recursive=True)
    if include_parent:
        targets.append(root)
    for target in targets:
        try:
            target.send_signal(sig)
        except psutil.NoSuchProcess:
            # Already exited between enumeration and signalling
            continue
    gone, alive = psutil.wait_procs(
        targets,
        timeout=timeout,
        callback=on_terminate,
    )
    return (gone, alive)

## Wrap the search and kill process functions into single call
def proc_cleanup(proc_name):
    '''
    Terminate every running process whose name contains ``proc_name``.

    Each match found by find_procs_by_name() is logged and handed to
    kill_proc_tree(). A failure on one process (already gone, access denied)
    is logged and does not stop the cleanup of the remaining ones.

    Args:
        proc_name: Text to look for inside process names, e.g. 'dumpcap'.
    '''
    own_pid = os.getpid()
    for proc in find_procs_by_name(proc_name):
        pid = proc.pid
        if pid == own_pid:
            # kill_proc_tree() refuses to act on the current process
            continue
        # Use the name psutil already collected instead of the private _name attribute
        logger.warning(f'orphaned {proc.info["name"]} proc: {pid}')
        try:
            gone, _alive = kill_proc_tree(pid)
        except psutil.Error as e:
            # e.g. the process exited between lookup and kill, or permission was denied
            logger.warning(f'unable to terminate orphaned proc {pid}: {e}')
            continue
        # Check the result objects directly rather than matching on their repr text
        if any(p.pid == pid for p in gone):
            logger.warning(f'orphaned proc {pid} terminated')

## Process network packets using global Parser instance and dictionary of supported protocols
def process_packet(packet):
    '''
    Inspect one captured packet and store any parsed endpoint data.

    Only packets whose highest layer name contains an underscore and is not
    'DATA_RAW' are inspected. If the part of that name before the first
    underscore is 'XML', the packet goes to parser.parse_xml(); otherwise
    every layer with an entry in packet_callbacks is parsed. Non-None parser
    results are passed to endpoints.update_db_list().

    Any exception is logged at debug level and swallowed so that one bad
    packet cannot stop the capture.

    Args:
        packet: A pyshark packet object (must expose highest_layer and layers).
    '''
    # Bound before the try so the except handler can always reference it,
    # even when reading packet.highest_layer itself raises.
    highest_layer = None
    try:
        highest_layer = packet.highest_layer
        ## Avoids any UDP/TCP.SEGMENT reassemblies and raw UDP/TCP packets
        if '_' not in highest_layer or highest_layer == 'DATA_RAW':
            return
        inspection_layer = str(highest_layer).split('_')[0]
        ## If XML traffic included over HTTP, match on XML parsing
        if inspection_layer == 'XML':
            parsed = parser.parse_xml(packet)
            if parsed is not None:
                endpoints.update_db_list(parsed)
            return
        for layer in packet.layers:
            callback = packet_callbacks.get(layer.layer_name)
            if callback is None:
                continue
            parsed = callback(packet)
            # Same None guard as the XML branch: nothing parsed, nothing to store
            if parsed is not None:
                endpoints.update_db_list(parsed)
    except Exception as e:
        logger.debug(f'error processing packet details {highest_layer}: {e}')

def capture_live_packets(network_interface, bpf_filter, packet_count=1000,
                         output_file='/tmp/pyshark.pcapng'):
    '''
    Run one pyshark live-capture instance and feed every packet to process_packet().

    The capture is always closed and leftover 'dumpcap' processes are always
    cleaned up, including when sniffing raises; the exception is then
    re-raised to the caller. The global capture_count is incremented only
    when the instance finishes without an exception.

    Args:
        network_interface: Name of the interface to capture on.
        bpf_filter: BPF filter expression given to pyshark.
        packet_count: Number of packets to sniff before this instance stops.
        output_file: Path pyshark is told to write the capture to.
    '''
    global capture_count
    capture = pyshark.LiveCapture(
        interface=network_interface,
        bpf_filter=bpf_filter,
        include_raw=True,
        use_json=True,
        output_file=output_file)
    logger.debug(f'beginning capture instance to file: {output_file}')
    try:
        for packet in capture.sniff_continuously(packet_count=packet_count):
            try:
                process_packet(packet)
            except Exception as e:
                logger.debug(f'error processing packet {e}')
    finally:
        try:
            capture.close()
        finally:
            logger.debug('stopping capture instance')
            ## Check for any orphaned 'dumpcap' processes from pyshark still running from old instance, and terminate them
            proc_cleanup('dumpcap')
    capture_count += 1
        
async def default_service_reregister_loop(config, pxgrid, service_id, reregister_delay):
    '''
    Periodically reregister a custom pxGrid service to keep it alive.

    Every ``reregister_delay`` seconds the service is reregistered and then
    looked up again. Failures of either call are logged at debug level and
    the loop carries on with the next cycle. The loop ends quietly when the
    task is cancelled.

    Args:
        config: Configuration object; ``config.service`` names the service to look up.
        pxgrid: pxGrid control object providing service_reregister() and service_lookup().
        service_id: Identifier passed to service_reregister().
        reregister_delay: Seconds to sleep between reregistration attempts.
    '''
    try:
        while True:
            await asyncio.sleep(reregister_delay)
            # Reset each cycle so a failed attempt never reports a stale
            # (or not yet assigned) response below.
            resp = None
            try:
                resp = pxgrid.service_reregister(service_id)
                logger.debug(
                    '[default_service_reregister_loop] service reregister response %s',
                    json.dumps(resp))
            except Exception as e:
                logger.debug(
                    '[default_service_reregister_loop] failed to reregister, Exception: %s',
                    e.__str__())

            # pull service back to check
            try:
                service_lookup_response = pxgrid.service_lookup(config.service)
                if not service_lookup_response['services']:
                    logger.debug(
                        '[default_service_reregister_loop] service %s not returned by lookup',
                        config.service)
            except Exception as e:
                # A failed check must not end the keep-alive loop
                logger.debug(
                    '[default_service_reregister_loop] failed to look up service, Exception: %s',
                    e.__str__())

            if resp is not None:
                debug_text = json.dumps(resp, indent=2, sort_keys=True)
                for debug_line in debug_text.splitlines():
                    logger.debug(
                        '[default_service_reregister_loop] service_register_response %s',
                        debug_line)

    except asyncio.CancelledError:
        logger.debug('[default_service_reregister_loop] reregister loop cancelled')

async def _ws_disconnect_quietly(ws):
    '''Best-effort close of the websocket; a failure is only logged.'''
    try:
        await ws.disconnect()
    except Exception as e:
        logger.debug('[default_publish_loop] websocket disconnect failed: %s', e)


async def default_publish_loop(config, secret, pubsub_node_name, ws_url, topic):
    '''
    Publish pending endpoint records from the local DB to a pxGrid topic.

    Opens a websocket/STOMP connection, then every 5 seconds fetches the
    active entries from the global ``endpoints`` DB, sends one "UPDATE" asset
    message per row and marks the row as updated. A failure on one row is
    logged and the loop continues. On cancellation the STOMP session and the
    websocket are shut down. The websocket is also closed (best effort) when
    connecting fails or a WebSocketException ends the loop.

    Args:
        config: Configuration object (node_name, ssl_context, ws_ping_interval,
            discovery_override are read here).
        secret: Access secret for the pubsub node.
        pubsub_node_name: Name of the pubsub node to STOMP-connect to.
        ws_url: Websocket URL of the pubsub service.
        topic: Topic the asset messages are sent to.
    '''
    if config.discovery_override:
        logger.info('[default_publish_loop] overriding original URL %s', ws_url)
        ws_url = create_override_url(config, ws_url)
        logger.info('[default_publish_loop] new URL %s', ws_url)

    logger.debug('[default_publisher_loop] starting subscription to %s at %s', topic, ws_url)

    logger.debug('[default_publish_loop] opening web socket and stomp')
    ws = WebSocketStomp(
        ws_url,
        config.node_name,
        secret,
        config.ssl_context,
        # ping_interval=None)
        ping_interval=config.ws_ping_interval)

    try:
        logger.debug('[default_publish_loop] connect websocket')
        await ws.connect()
        logger.debug('[default_publish_loop] connect STOMP node %s', pubsub_node_name)
        await ws.stomp_connect(pubsub_node_name)
    except Exception as e:
        logger.debug('[default_publish_loop] failed to connect, Exception: %s', e.__str__())
        # The socket may already be open if only the STOMP connect failed
        await _ws_disconnect_quietly(ws)
        return
    try:
        # Running total of updates sent since this loop started
        count = 0
        while True:
            await asyncio.sleep(5.0)
            logger.debug('obtaining endpoints from local db to send to ISE')
            # Treat a None result like an empty one so len() below cannot raise
            results = (await endpoints.get_active_entries()) or []
            logger.debug(f'local db records pending update to ISE: {len(results)}')
            if results:
                for row in results:
                    # Column positions follow the row layout returned by the endpoints DB
                    message = {
                        "opType": "UPDATE",
                        "asset": {
                            "assetId": row[3],
                            "assetName": row[4],
                            "assetIpAddress": row[2],
                            "assetMacAddress": row[0],
                            "assetVendor": row[5],
                            "assetHwRevision": row[6],
                            "assetSwRevision": row[7],
                            "assetProtocol": row[1],
                            "assetProductId": row[8],
                            "assetSerialNumber": row[9],
                            "assetDeviceType": row[10]
                        }
                    }
                    try:
                        await ws.stomp_send(topic, json.dumps(message))
                        logger.debug(f'ISE Endpoint Updated: {row[0]}, {row[2]}')
                        count += 1
                        await endpoints.ise_endpoint_updated(row[0])
                    except Exception as e:
                        logger.debug(
                            '[default_publish_loop] Exception: %s',
                            e.__str__())
                logger.debug(f'endpoint updates sent to ISE: {str(count)}')
            logger.debug(
                '[default_publish_loop] message published to node %s, topic %s',
                pubsub_node_name,
                topic)
            sys.stdout.flush()
    except asyncio.CancelledError:
        # Cancellation is the normal stop request; fall through to the shutdown below
        pass
    except WebSocketException as e:
        logger.debug(
            '[default_publish_loop] WebSocketException: %s',
            e.__str__())
        # Skip the STOMP goodbye on a broken connection, but still release the socket
        await _ws_disconnect_quietly(ws)
        return

    logger.debug('[default_publish_loop] shutting down publisher...')
    await ws.stomp_disconnect('123')
    await asyncio.sleep(2.0)
    await ws.disconnect()

if __name__ == '__main__':
    ## Script entry point: parse CLI options, register a pxGrid service, schedule the
    ## reregistration and publish tasks, run packet captures until stopped, then
    ## print the contents and statistics of the local endpoint DB.

    ## Parse all of the CLI options provided
    config = Config()

    ## Add additional arguments to pxgrid_util Config class for pyshark functionality
    g = config.parser.add_mutually_exclusive_group(required=True)
    g.add_argument(
        '--interface',
        help='Network interface receiving traffic to be analyzed')
    ## Process all arguments to the config class
    config.parse_args()

    ## Verbose logging if configured
    if config.verbose:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(asctime)s:%(name)s:%(levelname)s:%(message)s'))
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)

        # and set for stomp and ws_stomp modules and sub-CLASSES of pxgrid_pyshark
        # (the shared handler already carries its formatter, so it is only attached here)
        for modname in ['pxgrid_util.stomp', 'pxgrid_util.ws_stomp', 'pxgrid_util.pxgrid', 'pxgrid_pyshark.parser', 'pxgrid_pyshark.endpointsdb', 'pxgrid_pyshark.ouidb']:
            s_logger = logging.getLogger(modname)
            s_logger.addHandler(handler)
            s_logger.setLevel(logging.DEBUG)

    ## Verify required attributes provided via CLI
    ## Missing mandatory options are errors, so exit non-zero like the interface checks below
    if not config.hostname:
        print("No hostname!")
        sys.exit(1)
    if not config.node_name:
        print("No nodename provided (aka. pxgrid account username)")
        sys.exit(1)
    if not config.service:
        config.config.service = 'com.cisco.endpoint.asset'
        logger.debug('using default pxgrid service: com.cisco.endpoint.asset')
    if not config.topic:
        config.config.topic = 'asset'
        logger.debug('using default pxgrid topic: asset')
    if not config.config.interface:
        print("No capture interface provided")
        sys.exit(1)
    else:
        capture_int = config.config.interface
        ints = netifaces.interfaces()
        if capture_int not in ints:
            print(f'Invalid interface name provided: {capture_int}.')
            print(f'Valid interface names are: {ints}')
            sys.exit(1)
        logger.debug(f'using capture interface = {capture_int}')

    ## NOTE: the arguments were already parsed above; they are deliberately not parsed a
    ## second time here, since re-parsing could discard the service/topic defaults
    ## that were just written into config.config.

    ## Setup pxGrid control object
    pxgrid = PXGridControl(config=config)
    ## Ensure account provided is approved in ISE UI (re-checked every 60 seconds)
    while pxgrid.account_activate()['accountState'] != 'ENABLED':
        time.sleep(60)
    ## Register a custom service
    properties = {
        'wsPubsubService': 'com.cisco.ise.pubsub',
        f'{config.topic}': f'/topic/{config.service}',
    }
    resp = pxgrid.service_register(config.service, properties)
    debug_text = json.dumps(resp, indent=2, sort_keys=True)
    for debug_line in debug_text.splitlines():
        logger.debug('[service_register_response] %s', debug_line)
    ## Setup periodic service reregistration as a task
    ## NOTE(review): the task is only scheduled here; nothing in this file runs the event
    ## loop until the final run_until_complete() - presumably pyshark drives the default
    ## loop while capturing. Confirm against the pyshark version in use.
    reregister_task = asyncio.ensure_future(
        default_service_reregister_loop(
            config,
            pxgrid,
            resp['id'],
            config.reregister_delay,
    ))

    ## Lookup service and topic details for the service we just registered
    service_lookup_response = pxgrid.service_lookup(config.service)
    slr_string = json.dumps(service_lookup_response, indent=2, sort_keys=True)
    logger.debug('service lookup response:')
    for s in slr_string.splitlines():
        logger.debug('  %s', s)
    service = service_lookup_response['services'][0]
    pubsub_service_name = service['properties']['wsPubsubService']
    try:
        topic = service['properties'][config.topic]
    except KeyError:
        logger.debug('invalid topic %s', config.topic)
        possible_topics = [
            k for k in service['properties'].keys()
            if k != 'wsPubsubService' and k != 'restBaseUrl' and k != 'restBaseURL'
        ]
        logger.debug('possible topic handles: %s', ', '.join(possible_topics))
        sys.exit(1)

    ## Lookup the pubsub service
    service_lookup_response = pxgrid.service_lookup(pubsub_service_name)

    ## Use the first pubsub service node returned (there is randomness)
    pubsub_service = service_lookup_response['services'][0]
    pubsub_node_name = pubsub_service['nodeName']
    secret = pxgrid.get_access_secret(pubsub_node_name)['secret']
    ws_url = pubsub_service['properties']['wsUrl']

    ## Setup the publishing loop
    main_task = asyncio.ensure_future(
        default_publish_loop(
            config,
            secret,
            pubsub_node_name,
            ws_url,
            topic,
    ))

    ## Setup sigint/sigterm handlers
    def signal_handlers():
        ## Cancel both tasks and ask the capture loop below to stop after the current instance
        global capture_running
        main_task.cancel()
        reregister_task.cancel()
        capture_running = False
    loop = asyncio.get_event_loop()
    loop.add_signal_handler(SIGINT, signal_handlers)
    loop.add_signal_handler(SIGTERM, signal_handlers)

    ## Create the local DB file for storing parsed packets
    logger.debug('building databases')
    endpoints = endpointsdb()
    logger.debug('building databases complete')

    ## Begin the capture on the indicated interface (replace w/ relevant interface name)
    capture_running = True

    try:
        while capture_running:
            try:
                capture_live_packets(capture_int, default_bpf_filter)
            except Exception as e:
                ## Use the module logger; logging.warning() would configure the root logger as a side effect
                logger.warning(f'error with capture instance: {e}')
    except KeyboardInterrupt:
        logger.warning('closing capture down due to keboard interrupt')
        capture_running = False
        sys.exit(0)

    ## Let the publisher finish its shutdown sequence; a failure here must not
    ## prevent the final report, but it is logged instead of silently dropped
    try:
        loop.run_until_complete(main_task)
    except (Exception, asyncio.CancelledError) as e:
        logger.debug(f'publisher task ended with: {e!r}')
    print('### FINAL OUTPUT ###')

    ## Provide output of all entries within local DB and stats for update messages
    endpoints.view_all_entries()
    endpoints.view_stats()