#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
This script allows you to sniff the Wi-Fi probe requests
passing near your wireless interface.
"""

from argparse import ArgumentParser, FileType
from csv import writer
from netaddr import EUI, NotRegisteredError
from os import geteuid
from queue import Queue, Empty
from re import compile as rcompile, match, IGNORECASE
from scapy.all import *
from sys import exit as sys_exit
from threading import Thread, Event
from time import sleep, localtime, strftime

class ProbeRequest:
    """
    A Wi-Fi probe request.

    Attributes:
        timestamp: capture time of the frame, in seconds since the epoch.
        s_mac: MAC address of the station that sent the request.
        essid: ESSID the station is probing for.
        s_mac_oui: organisation registered for the MAC address OUI, or None
            when the OUI is not registered.
    """

    def __init__(self, timestamp, s_mac, essid):
        self.timestamp = timestamp
        self.s_mac = s_mac
        self.essid = essid

        self.s_mac_oui = self.get_mac_organisation()

    def __str__(self):
        # Packet times may be Decimal subclasses (e.g. scapy's EDecimal), which
        # time.localtime() rejects on recent Pythons; float() accepts int,
        # float and Decimal alike.
        return "{timestamp} - {s_mac} ({mac_org}) -> {essid}".format(
            timestamp=strftime("%a, %d %b %Y %H:%M:%S %Z", localtime(float(self.timestamp))),
            s_mac=self.s_mac,
            mac_org=self.s_mac_oui,
            essid=self.essid
        )

    def get_mac_organisation(self):
        """
        Returns the name of the organisation registered for the OUI of the
        MAC address, or None if the OUI is not registered.
        """

        try:
            return EUI(self.s_mac).oui.registration().org
        except NotRegisteredError:
            return None

class ProbeRequestSniffer:
    """
    A Wi-Fi probe request sniffer.

    Runs two threads sharing one queue: a PacketSniffer that captures probe
    request frames on the given interface, and a ProbeRequestParser that turns
    them into ProbeRequest objects, filters them by ESSID and hands them to the
    display and storage callbacks.
    """

    # Seconds stop() waits for the sniffing thread before closing its socket.
    SNIFFER_STOP_TIMEOUT = 2.0
    # Seconds start() waits for the sniffing thread to open its socket.
    SNIFFER_START_TIMEOUT = 2.0

    def __init__(self, interface, essid_filters=None, essid_regex=None, ignore_case=False, mac_exclusions=None, mac_filters=None, display_func=lambda p: None, storage_func=lambda p: None, debug=False):
        """
        Builds (but does not start) the sniffing and parsing threads.

        Raises:
            TypeError: if display_func or storage_func is not callable.
        """
        if not callable(display_func):
            raise TypeError('The display function parameter is not a callable object')
        if not callable(storage_func):
            raise TypeError('The storage function parameter is not a callable object')

        self.new_packets = Queue()

        self.sniffer = self.PacketSniffer(
            self.new_packets,
            interface,
            mac_exclusions=mac_exclusions,
            mac_filters=mac_filters,
            debug=debug
        )

        self.parser = self.ProbeRequestParser(
            self.new_packets,
            essid_filters=essid_filters,
            essid_regex=essid_regex,
            ignore_case=ignore_case,
            display_func=display_func,
            storage_func=storage_func,
            debug=debug
        )

    def start(self):
        """
        Starts the probe request sniffer.

        This method starts the sniffing thread and the parsing thread. It then
        waits (up to SNIFFER_START_TIMEOUT seconds) for the sniffing thread to
        open its capture socket, so that an error raised while opening it
        (for instance an OSError for a bad interface) is re-raised here.
        """

        self.sniffer.start()
        self.parser.start()

        # Without this wait the check below would almost always run before the
        # sniffing thread had even tried to open its socket.
        self.sniffer.ready.wait(ProbeRequestSniffer.SNIFFER_START_TIMEOUT)

        e = self.sniffer.get_exception()

        if e is not None:
            raise e

    def stop(self):
        """
        Stops the probe request sniffer.

        This method stops the sniffing thread and waits for the parsing thread
        to finish processing the packets that are already queued.
        """

        self.sniffer.join(timeout=ProbeRequestSniffer.SNIFFER_STOP_TIMEOUT)

        # sniff() only evaluates its stop_filter on received packets, so on a
        # quiet channel the thread stays blocked; closing its socket unblocks it.
        if self.sniffer.is_alive() and self.sniffer.socket is not None:
            self.sniffer.socket.close()

        self.parser.join()

    class PacketSniffer(Thread):
        """
        A packet sniffing thread.

        Captures probe request frames matching a BPF filter and puts every
        captured packet on the shared queue.
        """

        def __init__(self, new_packets, interface, mac_exclusions=None, mac_filters=None, debug=False):
            super().__init__()

            self.daemon = True

            self.new_packets = new_packets
            self.interface = interface

            self.frame_filters = "type mgt subtype probe-req"
            self.socket = None
            self.stop_sniffer = Event()
            # Set once the capture socket has been opened or failed to open.
            self.ready = Event()
            self.exception = None

            # Empty lists are skipped: they would produce an invalid "and ()" filter.
            if mac_exclusions:
                self.frame_filters += " and not (" + self._src_hosts(mac_exclusions) + ")"

            if mac_filters:
                self.frame_filters += " and (" + self._src_hosts(mac_filters) + ")"

            if debug:
                print("[!] Frame filters: " + self.frame_filters)

        @staticmethod
        def _src_hosts(macs):
            """
            Returns a BPF expression matching frames sent by any of the given
            MAC addresses.
            """
            return " || ".join("ether src host {s_mac}".format(s_mac=station) for station in macs)

        def run(self):
            try:
                self.socket = conf.L2listen(
                    type=ETH_P_ALL,
                    iface=self.interface,
                    filter=self.frame_filters
                )
            except Exception as e:
                self.exception = e
                return
            finally:
                # Unblocks start() whether or not the socket could be opened.
                self.ready.set()

            try:
                sniff(
                    opened_socket=self.socket,
                    store=False,
                    prn=self.new_packet,
                    stop_filter=self._should_stop
                )
            except Exception as e:
                self.exception = e

        def _should_stop(self, packet):
            """
            Stop filter for sniff(): it is called with each packet, which is
            ignored; sniffing stops once join() has been requested.
            """
            return self.stop_sniffer.is_set()

        def join(self, timeout=None):
            """
            Stops the packet sniffer.
            """

            self.stop_sniffer.set()
            super().join(timeout)

        def new_packet(self, packet):
            """
            Adds a new packet to the queue to be processed.
            """

            self.new_packets.put(packet)

        def get_exception(self):
            """
            Returns the raised exception if any, otherwise returns None.
            """
            return self.exception

    class ProbeRequestParser(Thread):
        """
        A Wi-Fi probe request parsing thread.

        Takes packets from the shared queue, parses them into ProbeRequest
        objects and passes those matching the ESSID filters to the display and
        storage functions.
        """

        def __init__(self, new_packets, essid_filters=None, essid_regex=None, ignore_case=False, display_func=lambda p: None, storage_func=lambda p: None, debug=False):
            super().__init__()

            # Daemon so that an unexpected error in the main thread cannot leave
            # the process hanging on this loop; stop() still drains the queue.
            self.daemon = True

            self.new_packets = new_packets
            self.essid_filters = essid_filters
            self.display_func = display_func
            self.storage_func = storage_func

            self.stop_parser = Event()

            if debug:
                print("[!] ESSID filters: " + str(self.essid_filters))
                print("[!] ESSID regex: " + str(essid_regex))
                print("[!] Ignore case: " + str(ignore_case))

            if essid_regex is not None:
                self.essid_regex = rcompile(essid_regex, IGNORECASE if ignore_case else 0)
            else:
                self.essid_regex = None

        def run(self):
            # The parser continues to do its job even after the call of the
            # join method if the queue is not empty.
            while not self.stop_parser.is_set() or not self.new_packets.empty():
                try:
                    packet = self.new_packets.get(timeout=1)
                except Empty:
                    continue

                try:
                    self._process_packet(packet)
                finally:
                    # Every dequeued packet counts as done, including filtered ones.
                    self.new_packets.task_done()

        def _process_packet(self, packet):
            """
            Parses one packet and forwards it to the display and storage
            functions if its ESSID passes the filters.
            """
            probe_request = self.parse(packet)

            # Probe requests with an empty ESSID are skipped.
            if not probe_request.essid:
                return

            if self.essid_filters is not None and probe_request.essid not in self.essid_filters:
                return

            if self.essid_regex is not None and not self.essid_regex.match(probe_request.essid):
                return

            self.display_func(probe_request)
            self.storage_func(probe_request)

        def join(self, timeout=None):
            """
            Stops the probe request parsing thread once the queue is empty.
            """

            self.stop_parser.set()
            super().join(timeout)

        @staticmethod
        def parse(packet):
            """
            Parses the packet and returns a probe request object.

            ESSIDs are raw bytes and are not guaranteed to be valid UTF-8, so
            undecodable bytes are replaced instead of raising an error that
            would kill the parsing thread.
            """

            timestamp = packet.getlayer(RadioTap).time
            s_mac = packet.getlayer(RadioTap).addr2
            essid = packet.getlayer(Dot11ProbeReq).info.decode("utf-8", errors="replace")

            return ProbeRequest(timestamp, s_mac, essid)

def main():
    """
    Command-line entry point.

    Parses the arguments, checks for root privileges and runs the sniffer
    until it is interrupted with Ctrl+C. The sniffer threads are stopped and
    the output file is closed however the run ends.
    """
    ap = ArgumentParser(description="Wi-Fi Probe Requests Sniffer")
    essid_arguments = ap.add_mutually_exclusive_group()
    station_arguments = ap.add_mutually_exclusive_group()

    ap.add_argument("--debug", action="store_true", help="debug mode")
    essid_arguments.add_argument("-e", "--essid", nargs="+", help="ESSID of the APs to filter (space-separated list)")
    station_arguments.add_argument("--exclude", nargs="+", help="MAC addresses of the stations to exclude (space-separated list)")
    ap.add_argument("-i", "--interface", required=True, help="wireless interface to use (must be in monitor mode)")
    ap.add_argument("--ignore-case", action="store_true", help="ignore case distinctions in the regex pattern (default: false)")
    ap.add_argument("-o", "--output", type=FileType("a"), help="output file to save the captured data (CSV format)")
    essid_arguments.add_argument("-r", "--regex", help="regex to filter the ESSIDs")
    station_arguments.add_argument("-s", "--station", nargs="+", help="MAC addresses of the stations to filter (space-separated list)")

    # store_true arguments already default to False.
    args = vars(ap.parse_args())

    if geteuid() != 0:
        sys_exit("[!] You must be root")

    output = args["output"]

    if output:
        outfile = writer(output, delimiter=";")

        def write_csv(probe_req):
            outfile.writerow([probe_req.timestamp, probe_req.s_mac, probe_req.s_mac_oui, probe_req.essid])
            # Flush each row so captured data survives an abrupt termination.
            output.flush()
    else:
        def write_csv(probe_req):
            pass

    def display_probe_req(probe_req):
        print(probe_req)

    print("[*] Start sniffing probe requests...")

    # Bound before the try so the cleanup never hits an unbound name.
    sniffer = None
    exit_message = None

    try:
        sniffer = ProbeRequestSniffer(
            args["interface"],
            essid_filters=args["essid"],
            essid_regex=args["regex"],
            ignore_case=args["ignore_case"],
            mac_exclusions=args["exclude"],
            mac_filters=args["station"],
            display_func=display_probe_req,
            storage_func=write_csv,
            debug=args["debug"]
        )

        sniffer.start()

        # The work happens in the sniffer's threads; keep the main thread alive.
        while True:
            sleep(100)
    except OSError as e:
        # Not only a missing interface: e.g. a down interface also raises OSError.
        exit_message = "[!] Cannot sniff on interface {interface}: {error}".format(
            interface=args["interface"],
            error=e
        )
    except KeyboardInterrupt:
        print("[*] Stopping the threads...")
    finally:
        if sniffer is not None:
            sniffer.stop()

        if output:
            output.close()

    if exit_message is not None:
        sys_exit(exit_message)

    print("[*] Bye!")


if __name__ == "__main__":
    main()
