#!/usr/bin/env python

from subprocess import run
from random import choice, seed
from os.path import expanduser, exists
from collections import defaultdict
from time import sleep
import pickle
import re
import sys
from time import time
import json

from rich.progress import track
from rich import print
from rich.panel import Panel
from rich.console import Console

import argparse
# Command-line interface.
# ArgumentParser's positional parameters are (prog, usage, description, ...),
# so the blurb must be passed by keyword: given positionally it replaces the
# auto-generated usage line shown by --help and by argument errors.
parser = argparse.ArgumentParser(
    prog="whoishome",
    description="""
    Check who is home using python and nmap!
    The storage is really just a pickled defaultdict mapping names to lists of mac addresses.
    This means it would be pretty easy to automatically generate this info from some other source too.
    """,
)
parser.add_argument("--subnet", default="10.0.0.0/24")
parser.add_argument("--sleepTime", type=int, default=15, help="How many seconds between nmap calls. Beware that nmap is a beast...")
parser.add_argument("--add", help="Add people and mac addresses like `--add Steve=11:22:33:44:55:66`. This can also be used to reassign mac addresses to a different person.")
parser.add_argument("--remove", help="Remove people like `--remove Steve` or remove MAC addresses like `--remove 11:22:33:44:55:66`")
parser.add_argument("--storageLocation", default=expanduser("~/.whosHome.pkl"), help="This could be used to have multiple `homes`")
parser.add_argument("--streamStats", help="You can specify a filename to write stats to about the comings and goings of people.")
parser.add_argument("--viewStats", action="store_true", help="Use this to see a nice visualization of the stats saved by --streamStats")
# Options that are not given on the command line are still present on `args`
# with their default (None for --add, --remove and --streamStats).
args = parser.parse_args()

def colorGen(matplotlib=False):
    """Yield an endless stream of colours produced by a random walk in RGB.

    The walk starts from (100, 150, 50). Before every yield one of six
    equally likely mutations is drawn with ``random.choice`` and applied to a
    single channel; channels always stay within 0..255. The sequence depends
    on the state of the ``random`` module, so seeding it beforehand makes the
    colours reproducible.

    Args:
        matplotlib: When true, yield ``(r, g, b)`` tuples of floats scaled to
            0..1; otherwise yield ``"rgb(r,g,b)"`` strings.
    """
    # (operation, channel index) pairs. The tuple must keep six entries so
    # that choice() draws from the same distribution.
    # NOTE(review): the three "shift" entries all target the red channel;
    # this looks like a copy/paste slip (green and blue were presumably
    # meant to be shifted too) - confirm before changing.
    mutations = (
        ("invert", 0),
        ("invert", 1),
        ("invert", 2),
        ("shift", 0),
        ("shift", 0),
        ("shift", 0),
    )
    rgb = [100, 150, 50]
    while True:
        action, slot = choice(mutations)
        if action == "invert":
            rgb[slot] = 255 - rgb[slot]
        else:
            # Wrap around so the channel remains a valid byte value.
            rgb[slot] = (rgb[slot] + 100) % 256

        red, green, blue = rgb
        if not matplotlib:
            yield f"rgb({red},{green},{blue})"
        else:
            yield (red / 255, green / 255, blue / 255)

def loadPeople():
    """Return the stored name -> MAC-address-list mapping.

    Reads the pickle at ``args.storageLocation``. When that file does not
    exist, a fresh empty ``defaultdict(list)`` is returned instead.
    """
    if not exists(args.storageLocation):
        return defaultdict(list)

    # NOTE(review): unpickling can execute arbitrary code, so the storage
    # file must only ever come from a trusted source.
    with open(args.storageLocation, "rb") as handle:
        return pickle.load(handle)

def savePeople(people):
    """Pickle ``people`` to ``args.storageLocation``, replacing any old file.

    Args:
        people: The name -> MAC-address-list mapping to persist.
    """
    with open(args.storageLocation, "wb") as sink:
        sink.write(pickle.dumps(people))


# Get MAC Addresses on the Network.
def findMacs():
    """Ping-scan ``args.subnet`` with nmap and collect the MACs it reports.

    Returns:
        A ``(macs, output)`` tuple. ``output`` is nmap's stdout converted to
        upper case; ``macs`` is the list of MAC address strings found in that
        text, in order of appearance.
    """
    # Pass an argument list rather than a shell string: the subnet comes
    # straight from the command line and must not be interpreted by a shell.
    # Splitting on whitespace still allows several targets in one value.
    p = run(["sudo", "nmap", "-sP", *args.subnet.split()], capture_output=True)
    output = p.stdout.decode("utf8", errors="replace").upper()
    # The class must be [0-9A-F]: "A-f" is the ASCII range 'A'..'f', which
    # also accepts G-Z and several punctuation characters. The text has
    # already been upper-cased, so lower-case digits need not be listed.
    macs = re.findall(r"(?:[0-9A-F]{2}:){5}[0-9A-F]{2}", output)
    return macs, output

c = Console()


def printPanels(foundStr, unfoundStr, unknownStr, extraStr):
    """Clear the console and draw the four status panels, top to bottom.

    Args:
        foundStr: Body of the "People Found" panel.
        unfoundStr: Body of the "People NOT Found" panel.
        unknownStr: Body of the "Unknown MAC Addresses" panel.
        extraStr: Body of the "Extra Info" panel.
    """
    # (body, title, border colour) for every panel, in display order.
    sections = (
        (foundStr, "People Found", "red"),
        (unfoundStr, "People NOT Found", "green"),
        (unknownStr, "Unknown MAC Addresses", "purple"),
        (extraStr, "Extra Info", "blue"),
    )
    c.clear()
    for body, heading, colour in sections:
        print(Panel(body, title=heading, border_style=colour))


def streamStats(macs):
    """Append a timestamped record of the detected MACs to the stats file.

    Does nothing when no stats file was requested with ``--streamStats``.

    Args:
        macs: The MAC address strings found by the latest scan.
    """
    # argparse always defines the attribute (None when the flag is absent),
    # so a membership test on the namespace never detects a missing file and
    # open(None) would raise TypeError. Check the value itself.
    filename = getattr(args, "streamStats", None)
    if not filename:
        return

    # One record per line: "<unix time>: <repr of the MAC list>".
    with open(filename, "a") as f:
        f.write(f"{time()}: {macs}\n")


people = loadPeople()

# --add NAME=MAC: assign a MAC address to a person, then exit.
if args.add:
    # partition() never raises on a malformed value (no "=" or several),
    # unlike unpacking the result of split("="), so bad input can be
    # reported as a normal usage error.
    name, sep, mac = args.add.partition("=")
    mac = mac.strip().upper()
    if not sep or not name or not re.fullmatch(r"(?:[0-9A-F]{2}:){5}[0-9A-F]{2}", mac):
        parser.error("--add expects NAME=MAC, e.g. `--add Steve=11:22:33:44:55:66`")

    # A MAC belongs to a single person: drop it from every list (including
    # duplicates) before assigning it, which is what makes re-assignment work.
    for owned in people.values():
        while mac in owned:
            owned.remove(mac)
    # setdefault also works if the stored mapping is a plain dict rather
    # than the defaultdict this script creates.
    people.setdefault(name, []).append(mac)

    savePeople(people)
    sys.exit()

# --remove NAME or --remove MAC: delete a person or a MAC address, then exit.
if args.remove:
    target = args.remove
    # fullmatch with an explicit hex class: the former "[0-9A-f]" range also
    # accepted G-Z and punctuation, and match() ignored trailing characters,
    # so some person names could be mistaken for MAC addresses.
    if re.fullmatch(r"(?:[0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}", target):
        target = target.upper()  # MACs are stored upper-cased
        for person in people.values():
            # Filter in place so every occurrence of the MAC is dropped.
            person[:] = [m for m in person if m != target]
    elif target in people:
        del people[target]

    savePeople(people)
    sys.exit()

# --viewStats: plot the presence history recorded via --streamStats, then exit.
if args.viewStats:
    import matplotlib.pyplot as plt

    # argparse always defines args.streamStats (None when the flag is
    # absent), so a membership test on the namespace can never detect a
    # missing file; check the value itself.
    if not args.streamStats:
        print("When using --viewStats, you must also specify a stats file to view via --streamStats")
        sys.exit()

    # Each record written by streamStats() is "<unix time>: <repr of MAC list>".
    ts, macs = [], []
    with open(args.streamStats, "r") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            stamp, mac_list = line.split(":", 1)
            ts.append(float(stamp))
            # The list was written with Python's repr (single quotes); swap
            # the quotes so it parses as JSON.
            macs.append(json.loads(mac_list.replace("'", '"')))

    if not ts:
        print("The stats file contains no records yet, so there is nothing to plot")
        sys.exit()

    def _plotStreaks(mask, label, color):
        """Plot one line segment per run of consecutive non-zero mask values."""
        streak = []
        for t, val in zip(ts, mask):
            if val != 0:
                streak.append((t, val))
            elif streak:
                plt.plot(*zip(*streak), label=label, color=color)
                label = None  # only the first segment gets a legend entry
                streak = []

        if streak:
            plt.plot(*zip(*streak), label=label, color=color)
        elif label:
            # Never seen: still add an (empty) line so the legend lists it.
            plt.plot([], [], label=label, color=color)

    # First plot streaks of known people, one row per person above the axis.
    seed(0)  # fixed seed keeps each person's colour stable between runs
    g = colorGen(matplotlib=True)
    for i, person in enumerate(sorted(people)):
        color = next(g)
        personMacs = set(people[person])
        row = len(people) + 1 - i
        mask = [row if personMacs.intersection(mac_list) else 0 for mac_list in macs]
        _plotStreaks(mask, person, color)

    # Next plot streaks for unknown MAC addresses, one row each below the axis.
    knownMacs = set().union(*people.values())
    # Sorted so rows and colours do not depend on set iteration order.
    unknownMacs = sorted({mac for mac_list in macs for mac in mac_list if mac not in knownMacs})

    # Rows start at -1: a row value of 0 means "absent" to _plotStreaks, so
    # numbering from 0 would make the first unknown MAC invisible.
    for i, mac in enumerate(unknownMacs, start=1):
        color = next(g)
        mask = [-i if mac in mac_list else 0 for mac_list in macs]
        _plotStreaks(mask, mac, color)

    plt.legend()
    plt.show()

    sys.exit()



printPanels("loading :ten_o’clock:", "please be patient :pray:", "nmap takes a long time :poop:", "sorry :broken_heart:")

# Main monitoring loop: scan, report who is (not) home, record stats, sleep.
while True:
    print("running nmap")
    macs, output = findMacs()
    print("finished running nmap")
    seen = set(macs)  # set for cheap membership tests below

    # FOUND / UNFOUND PEOPLE ############################
    seed(0)  # fixed seed keeps each person's colour stable between scans
    g = colorGen()
    foundLines, unfoundLines = [], []
    for name in sorted(people):
        color = next(g)
        if any(mac in seen for mac in people[name]):
            # Underline the addresses that were actually detected.
            shown = ", ".join(f"[underline]{m}[/underline]" if m in seen else m for m in people[name])
            foundLines.append(f"[{color}]{name:<15}[/{color}] [{shown}]\n")
        else:
            shown = ", ".join(people[name])
            unfoundLines.append(f"[{color}]{name:<15}[/{color}] [{shown}]\n")
    foundStr = "".join(foundLines)
    unfoundStr = "".join(unfoundLines)

    # UNKNOWN MAC ADDRESSES ############################
    knownMacs = set().union(*people.values())
    unknownMacs = [m for m in macs if m not in knownMacs]
    unknownStr = f"{'Unknown':<15} {unknownMacs}\n" if unknownMacs else ""

    # EXTRA INFORMATION #################################
    extraStr = f"{len(macs)} MACS DETECTED\n"
    # nmap prints no summary line when it fails to run (not installed, sudo
    # refused, ...); show a hint instead of crashing on an empty match list.
    hostsUp = re.search(r"[0-9]+ HOSTS UP", output)
    extraStr += hostsUp.group(0) if hostsUp else "HOST COUNT UNAVAILABLE (DID NMAP RUN?)"

    # ACTUAL PRINT ######################################
    printPanels(foundStr, unfoundStr, unknownStr, extraStr)

    # STREAM STATISTICS #################################
    streamStats(macs)

    # SLEEP SO WE DON'T BOMBARD NMAP ####################
    for i in track(range(args.sleepTime), description="Time till next nmap call starts"):
        sleep(1)
        print(".", end='')
