#!/usr/bin/env python3
import pyshark
import time
import logging
import os
import sys
from pathlib import Path
from pxgrid_pyshark import endpointsdb
from pxgrid_pyshark import parser

## Module-level logger; the level is forced to DEBUG as soon as this module is loaded
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
## Built-in Wireshark display filter used when the user does not supply one:
## drops IPv6 and keeps SSDP, HTTP with a non-empty User-Agent, SIP, XML, SMB browser
## and mDNS responses whose record type is 1 or 16 (presumably A and TXT records - confirm)
default_filter = '!ipv6 && (ssdp || (http && http.user_agent != "") || sip || xml || browser || (mdns && (dns.resp.type == 1 || dns.resp.type == 16)))'
## Single shared parser instance.
## NOTE: this rebinds the name 'parser', so the imported 'parser' class is no longer
## reachable under that name after this line.
parser = parser()
## Module-level packet counter. process_capture_file() assigns its own local variable
## with the same name (no 'global' statement), so it never updates this one.
currentPacket = 0

## Dispatch table: pyshark layer name -> parser method that handles that protocol.
## process_packet() looks up every layer of a packet here via layer.layer_name.
packet_callbacks = {
    'mdns': parser.parse_mdns_v7,
    'xml': parser.parse_xml,
    'sip': parser.parse_sip,
    'ssdp': parser.parse_ssdp,
    'http': parser.parse_http,
    'browser': parser.parse_smb_browser,
}

## Process network packets using global Parser instance and dictionary of supported protocols
def process_packet(packet):
    """Parse a single packet and hand any extracted data to the endpoint DB.

    When the packet's highest layer is XML (for example XML carried over
    HTTP) only the XML parser is run. Otherwise every layer of the packet is
    looked up in ``packet_callbacks`` and each matching parser is invoked.
    Parser results that are ``None`` are skipped in both paths, so
    ``endpoints.update_db_list`` only ever receives actual parser output.

    Args:
        packet: A packet object exposing ``highest_layer`` and ``layers``
            (each layer exposing ``layer_name``), as yielded by a pyshark
            capture.

    Returns:
        None. Any exception raised while processing is logged at DEBUG level
        and swallowed so one bad packet cannot abort a whole capture.
    """
    # Bound before the try block so the error log below can always reference
    # it, even when reading packet.highest_layer is the call that failed.
    highest_layer = None
    try:
        highest_layer = packet.highest_layer
        inspection_layer = str(highest_layer).split('_')[0]
        ## If XML traffic included over HTTP, match on XML parsing
        if inspection_layer == 'XML':
            parsed = parser.parse_xml(packet)
            if parsed is not None:
                endpoints.update_db_list(parsed)
            return
        for layer in packet.layers:
            callback = packet_callbacks.get(layer.layer_name)
            if callback is None:
                # No parser registered for this protocol layer
                continue
            parsed = callback(packet)
            # Same rule as the XML path: nothing parsed means nothing to store
            if parsed is not None:
                endpoints.update_db_list(parsed)
    except Exception as e:
        logger.debug(f'error processing packet details {highest_layer}: {e}')

## Process a given PCAP(NG) file with a provided PCAP filter
def process_capture_file(capture_file, capture_filter):
    """Run every packet of a capture file through ``process_packet``.

    Args:
        capture_file: Path to the local pcap/pcapng file.
        capture_filter: Wireshark display filter passed to pyshark as
            ``display_filter``.

    Returns:
        None. Progress, timing and the number of packets iterated are logged
        at DEBUG level. If ``capture_file`` does not exist this is logged and
        nothing else happens.

    Raises:
        Any exception raised while iterating the capture other than a
        per-packet ``TypeError`` propagates to the caller; the capture is
        closed first in every case.
    """
    if not Path(capture_file).exists():
        logger.debug(f'capture file not found: {capture_file}')
        return

    logger.debug(f'processing capture file: {capture_file}')
    start_time = time.perf_counter()
    capture = pyshark.FileCapture(capture_file, display_filter=capture_filter, only_summaries=False, include_raw=True, use_json=True)
    currentPacket = 0
    try:
        for packet in capture:
            ## Wrap individual packet processing within 'try' statement to avoid formatting issues crashing entire process
            try:
                process_packet(packet)
            except TypeError as e:
                logger.debug(f'Error processing packet: {capture_file}, packet # {currentPacket}: TypeError: {e}')
            currentPacket += 1
    finally:
        # Always release the capture, even when iteration itself fails
        # part-way through the file.
        capture.close()
    end_time = time.perf_counter()
    logger.debug(f'processing capture file complete: execution time: {end_time - start_time:0.6f} : {currentPacket} packets processed ##')

if __name__ == '__main__':
    ## Interactive test driver: prepare the endpoint database, prompt for a capture
    ## file and display filter, parse the file and print the resulting endpoint data.

    debugMode = True

    if debugMode:
        # One handler/formatter pair is shared by this module's logger and the
        # package loggers below, so the formatter only needs to be set once.
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(asctime)s:%(name)s:%(levelname)s:%(message)s'))
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)

        # Extend the logger function to the sub-CLASSES
        for modname in ['pxgrid_pyshark.parser', 'pxgrid_pyshark.endpointsdb', 'pxgrid_pyshark.ouidb']:
            s_logger = logging.getLogger(modname)
            s_logger.addHandler(handler)
            s_logger.setLevel(logging.DEBUG)

    ## Check if endpointsDB already exists
    logger.debug('database check - started')
    # The existence check must happen before endpointsdb() is constructed, as in
    # the original flow, so the answer reflects the state from a previous run.
    file_exists = os.path.exists('endpoint_database.db')
    # 'endpoints' has to stay a module-level name: process_packet() reads it as a global.
    endpoints = endpointsdb()
    if file_exists:
        print('## EXISTING ENDPOINT DB DETECTED! ###')
        append_DB = input('Re-use existing DB? (append new endpoint data to existing file) (yes/no): ').strip().lower()
        if append_DB in ('yes', 'y'):
            # Keep the existing data and append to it
            pass
        elif append_DB in ('no', 'n'):
            endpoints.clear_database()
            endpoints.create_database()
        else:
            print('invalid input')
            # Non-zero status: the run was aborted, it did not succeed
            sys.exit(1)
    else:
        endpoints.create_database()
    logger.debug('database check - complete')

    ### CAPTURE FILE ###
    print('#######################################')
    print('##  pxgrid-pyshark capture file test ##')
    print('#######################################')
    filename = input('Input local pcap(ng) file to be parsed: ')
    # Named capture_filter rather than 'filter' to avoid shadowing the builtin
    capture_filter = input('Input custom wireshark filter (leave blank to use built-in filter): ')
    if not capture_filter.strip():
        # Blank (or whitespace-only) input selects the built-in filter
        capture_filter = default_filter
    print('#######################################')
    print('##      Analyzing capture file       ##')
    print('#######################################')

    start_time = time.perf_counter()
    process_capture_file(filename, capture_filter)
    end_time = time.perf_counter()


    print('#######################################')
    print('##      Resulting Endpoint Data      ##')
    print('#######################################')
    endpoints.view_all_entries()
    endpoints.view_stats()
    print(f"Execution Time : {end_time - start_time:0.6f}" )
