#!/usr/bin/env python
# vim: set ts=4 sts=4 sw=4 et tw=0 fileencoding=utf-8:
#
# traflog: a simple traffic accounter
#
# Logs traffic data: source and destination addresses, bytes sent and received,
# timestamp (minute resolution).  Produces a simple MB-per-IP report for the
# last 24 hours.  Data is stored in an SQLite database, so more complex reports
# can easily be produced with plain SQL.

import argparse
import datetime
import json
import os
import socket
import sys
import time

import pcap
from sqlite3 import dbapi2 as sqlite


__author__ = "Justin Forest <hex@umonkey.net>"


FREE_HOURS = [1, 2, 3, 4, 5, 6]

ETH_FRAME_LEN = 14


class Sniffer(object):
    def __init__(self, db_path, net, netmask):
        self.db_path = db_path
        self.net = net
        self.netmask = netmask

        self.spool = {}
        self.last_ts = self.get_timestamp()

        self.db = self.connect(db_path)
        self.db_init()

    def get_timestamp(self):
        ts = int(time.time())
        ts = ts - ts % 60
        return ts

    def connect(self, dbpath):
        db = sqlite.connect(dbpath)
        return db

    def db_init(self):
        self.db_query("CREATE TABLE IF NOT EXISTS `log` (`ts` INTEGER, `mac` INTEGER, `ipv4` INTEGER, `sent` INTEGER, `recv` INTEGER, `free` INTEGER)")
        self.db_query("CREATE INDEX IF NOT EXISTS `IDX_log_ts` ON `log` (`ts`)")
        self.db_query("CREATE INDEX IF NOT EXISTS `IDX_log_mac` ON `log` (`mac`)")
        self.db_query("CREATE INDEX IF NOT EXISTS `IDX_log_ipv4` ON `log` (`ipv4`)")
        self.db_query("CREATE INDEX IF NOT EXISTS `IDX_log_free` ON `log` (`free`)")
        self.db_query("CREATE TABLE IF NOT EXISTS `labels` (`addr` TEXT, `label` TEXT)")

    def db_query(self, query, params=None):
        cur = self.db.cursor()
        cur.execute(query, params or [])

        if query.startswith("INSERT"):
            return cur.lastrowid
        else:
            return cur.rowcount

    def db_fetch(self, query, params=None):
        cur = self.db.cursor()
        cur.execute(query, params or [])

        rows = []
        while True:
            row = cur.fetchone()
            if row is None:
                break
            rows.append(row)

        cur.close()
        return rows

    def db_commit(self):
        self.db.commit()

    def parse_mac(self, pkt, offset):
        addr = 0L
        for x in range(0, 6):
            addr = addr * 256 + ord(pkt[offset + x])
        return addr

    def parse_ipv4(self, pkt, offset):
        addr = 0
        for x in range(0, 4):
            addr = addr * 256 + ord(pkt[offset + x])
        return addr

    def format_mac(self, addr):
        parts = []
        for x in range(0, 6):
            parts.append("%02x" % (addr & 255))
            addr = addr >> 8
        return ":".join(reversed(parts))

    def format_ipv4(self, addr):
        parts = []
        if addr is None:
            return "0.0.0.0"
        for x in range(0, 4):
            parts.append(str(addr & 255))
            addr = addr >> 8
        return ".".join(reversed(parts))

    def parse_packet(self, pkt):
        packet = {
            "src_mac": self.parse_mac(pkt, 0),
            "dst_mac": self.parse_mac(pkt, 6),
            "src_ipv4": None,
            "dst_ipv4": None,
            "length": len(pkt) - ETH_FRAME_LEN
        }

        if pkt[13] == '\0':
            packet["src_ipv4"] = self.parse_ipv4(pkt, 14 + 12)
            packet["dst_ipv4"] = self.parse_ipv4(pkt, 14 + 16)
            return packet

    def is_local(self, addr):
        return addr & self.netmask == self.net

    def is_free(self, ts):
        now = datetime.datetime.now()
        return now.hour in FREE_HOURS

    def log_data(self, mac, ipv4, sent, recv):
        ts = self.get_timestamp()
        if ts != self.last_ts:
            self.flush_log()
            self.spool = {}
            self.last_ts = ts

        if ipv4 not in self.spool:
            self.spool[ipv4] = {
                "mac": mac,
                "sent": 0,
                "recv": 0,
                "free": 1 if self.is_free(ts) else 0,
            }

        self.spool[ipv4]["sent"] += sent
        self.spool[ipv4]["recv"] += recv

    def flush_log(self):
        for addr, info in self.spool.items():
            self.db_query("INSERT INTO `log` (`ts`, `mac`, `ipv4`, `sent`, `recv`, `free`) VALUES (?, ?, ?, ?, ?, ?)", [self.last_ts, info["mac"], addr, info["sent"], info["recv"], info["free"]])
        self.db_commit()

    def run(self):
        try:
            _pcap = pcap.pcap(name=None, promisc=True, immediate=True, timeout_ms=50)
        except OSError, e:
            print >>sys.stderr, "Packet capture doesn't work, need root?  %s" % e
            exit(1)

        print "Logging to %s" % args.db_path

        for ts, raw_pkt in _pcap:
            p = self.parse_packet(raw_pkt)
            if p is None:
                continue

            src_local = self.is_local(p["src_ipv4"])
            dst_local = self.is_local(p["dst_ipv4"])

            sent = 0
            recv = 0
            mac = None
            ipv4 = None

            if src_local and dst_local:
                continue  # local traffic, don't count

            elif src_local:
                sent = p["length"]
                mac = p["src_mac"]
                ipv4 = p["src_ipv4"]

            elif dst_local:
                recv = p["length"]
                mac = p["dst_mac"]
                ipv4 = p["dst_ipv4"]

            self.log_data(mac, ipv4, recv, sent)

    def load_labels(self):
        rows = self.db_fetch("SELECT addr, label FROM labels")
        return {r[0]: r[1] for r in rows}

    def report(self, hours):
        labels = self.load_labels()

        since = int(time.time()) - hours * 3600
        rows = self.db_fetch("SELECT ipv4, SUM(sent) + SUM(recv) AS total, free FROM log WHERE ts >= ? GROUP BY ipv4, free HAVING total > 1048576", [since])

        stats = {}
        for addr, total, free in rows:
            if addr not in stats:
                stats[addr] = {"free": 0, "nonfree": 0}
            if free:
                stats[addr]["free"] = total
            else:
                stats[addr]["nonfree"] = total

        print "addr              nonfree      free"
        print "-----------------------------------"
        for addr, st in sorted(stats.items(), key=lambda x: x[1]["nonfree"], reverse=True):
            addr = self.format_ipv4(addr)
            print "%-15s  %8.2f  %8.2f  %s" % (addr, st["nonfree"] / 1048576, st["free"] / 1048576, labels.get(addr, "?"))


def aton(s):
    return int(socket.inet_aton(s).encode("hex"), 16)


def main(prog, *args):
    ap = argparse.ArgumentParser(
        description="traffic counter and reporter")

    db_path = os.path.expanduser("~/traflog.sqlite")

    ap.add_argument("--net", default="192.168.1.0", help="network address", metavar="IP")
    ap.add_argument("--mask", default="255.255.255.0", help="network mask", metavar="IP")
    ap.add_argument("--report", default=False, help="print usage report", action="store_true")
    ap.add_argument("--hours", type=int, default=24, help="report duration")
    ap.add_argument("db_path", nargs="?", help="SQLite database file", default=db_path)

    args = ap.parse_args()

    sniffer = Sniffer(args.db_path, aton(args.net), aton(args.mask))

    if args.report:
        sniffer.report(args.hours)
    else:
        sniffer.run()


if __name__ == "__main__":
    main(*sys.argv)
