#!python
"""
This script filters one or more provided pcap files (not pcapng), creating
an output file containing only packets with a supplied Community ID hash.

The output file's packets retain timestamp and all packet data from the
originally supplied pcap file(s).

This is based heavily on the "community-id-pcap" script in the same
directory and retains all of its limitations and caveats at the time this
script was created.

Currently supported protocols include IP, IPv6, ICMP, ICMPv6, TCP,
UDP, SCTP.

Please note: the protocol parsing implemented in this script relies
on the dpkt module and is somewhat simplistic:

- dpkt seems to struggle with some SCTP packets, for which it fails
  to register SCTP even though its header is correctly present.

- The script doesn't try to get nested network layers (IP over IPv6,
  IP in IP, etc) right. It expects either IP or IPv6, and it expects
  a transport-layer protocol (including the ICMPs here) as the
  immediate next layer.
"""
import argparse
import gzip
import sys

import communityid

try:
    import dpkt
except ImportError:
    print('This needs the dpkt Python module')
    sys.exit(1)

from dpkt.ethernet import Ethernet #pylint: disable=import-error
from dpkt.ip import IP #pylint: disable=import-error
from dpkt.ip6 import IP6 #pylint: disable=import-error
from dpkt.icmp import ICMP #pylint: disable=import-error
from dpkt.icmp6 import ICMP6 #pylint: disable=import-error
from dpkt.tcp import TCP #pylint: disable=import-error
from dpkt.udp import UDP #pylint: disable=import-error
from dpkt.sctp import SCTP #pylint: disable=import-error

class PcapFilter(object):
    """
    Copies the packets of one pcap file whose Community ID equals a
    given value to an output writer.

    For every packet read from the input file the flow tuple is built
    from the IP/IPv6 layer and, where recognized, the transport layer.
    The tuple is hashed via the supplied Community ID calculator and
    the packet is written out, with its original timestamp, when the
    result equals the filter value.
    """
    def __init__(self, commid, pcap, commidfilter, outputwriter):
        """
        commid: object whose calc(tpl) method returns the Community ID
            of a communityid.FlowTuple.
        pcap: path of the input pcap file; a name ending in ".gz" is
            read through gzip.
        commidfilter: value compared (==) against each calc() result.
        outputwriter: object providing writepkt(pktdata, tstamp), used
            for every matching packet.
        """
        self._commid = commid
        self._pcap = pcap
        self._commidfilter = commidfilter
        self._outputwriter = outputwriter

    def process(self):
        """
        Reads the input pcap file packet by packet and hands each
        packet to _process_packet() together with the output writer.
        """
        opener = gzip.open if self._pcap.endswith('.gz') else open

        # The input is only ever read, so open it read-only. Asking for
        # write access as well would needlessly fail on read-only files.
        with opener(self._pcap, 'rb') as inhdl:
            reader = dpkt.pcap.Reader(inhdl)
            for tstamp, pktdata in reader:
                self._process_packet(tstamp, pktdata, self._outputwriter)

    def _process_packet(self, tstamp, pktdata, outputwriter):
        """
        Computes the Community ID of a single packet and writes the
        packet to outputwriter if the ID equals the configured filter.

        Packets that are neither IP nor IPv6 are skipped silently.
        """
        pkt = self._packet_parse(pktdata)

        if not pkt:
            return

        if IP in pkt:
            saddr = pkt[IP].src
            daddr = pkt[IP].dst
        elif IP6 in pkt:
            saddr = pkt[IP6].src
            daddr = pkt[IP6].dst
        else:
            return

        # _packet_parse() records at most one transport-layer entry, so
        # a single if/elif chain covers both the known protocols and
        # the address-only fallback for everything else.
        if TCP in pkt:
            tpl = communityid.FlowTuple(
                dpkt.ip.IP_PROTO_TCP, saddr, daddr,
                pkt[TCP].sport, pkt[TCP].dport)

        elif UDP in pkt:
            tpl = communityid.FlowTuple(
                dpkt.ip.IP_PROTO_UDP, saddr, daddr,
                pkt[UDP].sport, pkt[UDP].dport)

        elif SCTP in pkt:
            tpl = communityid.FlowTuple(
                dpkt.ip.IP_PROTO_SCTP, saddr, daddr,
                pkt[SCTP].sport, pkt[SCTP].dport)

        elif ICMP in pkt:
            # For the ICMPs, type and code take the place of the ports.
            tpl = communityid.FlowTuple(
                dpkt.ip.IP_PROTO_ICMP, saddr, daddr,
                pkt[ICMP].type, pkt[ICMP].code)

        elif ICMP6 in pkt:
            tpl = communityid.FlowTuple(
                dpkt.ip.IP_PROTO_ICMP6, saddr, daddr,
                pkt[ICMP6].type, pkt[ICMP6].code)

        elif IP in pkt:
            # Fallback to other IP protocols: protocol number and
            # addresses only.
            tpl = communityid.FlowTuple(pkt[IP].p, saddr, daddr)

        else:
            # Only IPv6 remains here. The next-header value has to be
            # read from the IP6 layer; the packet has no IP entry.
            tpl = communityid.FlowTuple(pkt[IP6].nxt, saddr, daddr)

        res = self._commid.calc(tpl)

        if res == self._commidfilter:
            outputwriter.writepkt(pktdata, tstamp)

    def _packet_parse(self, pktdata):
        """
        Parses the protocols in the given packet data and returns the
        resulting packet (here, as a dict indexed by the protocol layers
        in form of dpkt classes).

        The data is decoded as an Ethernet frame. The returned dict is
        empty when the frame carries neither IP nor IPv6; otherwise it
        holds the IP or IP6 layer plus, if recognized, one of ICMP,
        ICMP6, TCP, UDP or SCTP.
        """
        layer = Ethernet(pktdata)
        pkt = {}

        if isinstance(layer.data, IP):
            pkt[IP] = layer = layer.data
        elif isinstance(layer.data, IP6):
            # XXX This does not correctly skip IPv6 extension headers
            pkt[IP6] = layer = layer.data
        else:
            return pkt

        if isinstance(layer.data, ICMP):
            pkt[ICMP] = layer.data
        elif isinstance(layer.data, ICMP6):
            pkt[ICMP6] = layer.data
        elif isinstance(layer.data, TCP):
            pkt[TCP] = layer.data
        elif isinstance(layer.data, UDP):
            pkt[UDP] = layer.data
        elif isinstance(layer.data, SCTP):
            pkt[SCTP] = layer.data

        return pkt

def main():
    """
    Command-line entry point: filters the given pcap files by
    Community ID into a newly created output pcap file.

    Returns 0 on success and 2 when the output file already exists.
    """
    parser = argparse.ArgumentParser(description='Community ID pcap filtering utility')
    parser.add_argument('pcaps', metavar='PCAP', nargs='+',
                        help='PCAP packet capture files')
    parser.add_argument('--filter', metavar='FILTER', required=True,
                        help='Community ID string in base64 format to filter from input pcap file(s)')
    parser.add_argument('--output', metavar='OUTPUT', required=True,
                        help='Output pcap file to create and place matching packets into')
    parser.add_argument('--seed', type=int, default=0, metavar='NUM',
                        help='Seed value for hash operations')
    args = parser.parse_args()

    commid = communityid.CommunityID(args.seed)

    # Mode 'x' creates the file exclusively, so an existing output file
    # is never overwritten. Only the open() call sits in the try block
    # so that errors raised during filtering are not caught here.
    try:
        outhdl = open(args.output, 'xb')
    except FileExistsError:
        print('Error: output file %s already exists. Exiting.' % (args.output))
        return 2

    # Close (and thereby flush) the output file on every exit path.
    with outhdl:
        writer = dpkt.pcap.Writer(outhdl)

        for pcap in args.pcaps:
            itr = PcapFilter(commid, pcap, args.filter, writer)
            itr.process()

    return 0

if __name__ == '__main__':
    # Hand main()'s return value to the interpreter as the exit status.
    raise SystemExit(main())
