#!/usr/bin/python3
# -*- coding: utf-8; mode: python -*-

import asyncio
import logging
import os
import shlex
import shutil
import socket
import sys
import termios
import textwrap
import time
import urllib.request
import xml.etree.ElementTree as ET
from asyncio.subprocess import DEVNULL, PIPE, STDOUT
from functools import partial
from ipaddress import IPv4Address, IPv4Interface, IPv4Network
from pathlib import Path
from pprint import pprint
from subprocess import check_output

import netifaces


# MAC-vendor (OUI) registry. A system-wide copy may be installed at
# /var/lib/ieee-data/oui.txt; this script instead keeps its own copy in the
# user's home directory and downloads it on first use.
OUI_DB = Path.home() / 'oui.txt'
# Column budget for the host/fingerprint tables. shutil.get_terminal_size()
# falls back to $COLUMNS or 80 columns when stdout is not a terminal (e.g. the
# output is piped), where os.get_terminal_size() would raise OSError at import.
TERMINAL_WIDTH = shutil.get_terminal_size().columns

class Host:
    """A neighbor found on the local network.

    Attributes:
        ip: IPv4 address as a string.
        mac: hardware address as a colon-separated string.
        name: reverse-DNS name, '' until resolved.
        osname: OS guess, 'unknown' until fingerprinted.
        ports: open ports ('' until fingerprinted, then a list of ints).
        vendor: MAC vendor name, '' until looked up.
    """

    def __init__(self, ip, mac):
        self.ip = ip
        self.mac = mac
        self.name = ''
        self.osname = 'unknown'
        self.ports = ''
        # Initialised here so the attribute always exists, even before the
        # vendor lookup has run.
        self.vendor = ''

    @property
    def oui(self):
        """First three octets of the MAC as 6 upper-case hex digits, e.g. 'AABBCC'."""
        return self.mac.replace(':', '')[:6].upper()

    def __lt__(self, other):
        # Compare numerically: as plain strings '192.168.1.10' < '192.168.1.9',
        # which would put hosts out of order when sorted.
        return IPv4Address(self.ip) < IPv4Address(other.ip)

    def __repr__(self):
        return "{} {}".format(self.ip, self.mac)


async def exec(cmd, stdout=DEVNULL):
    """Run *cmd* through the shell and return the process once it has exited.

    stderr is redirected to the same destination as *stdout* (DEVNULL by
    default, so the command runs silently). Callers inspect ``returncode``
    and, when they passed ``stdout=PIPE``, read ``stdout``.

    Caution: the process is awaited before any of its output is read. With
    ``stdout=PIPE``, a command that writes more than the OS pipe buffer can
    hold will block forever; use ``communicate()`` for such commands.
    """
    spawn = asyncio.create_subprocess_shell
    proc = await spawn(cmd, stdout=stdout, stderr=STDOUT)
    await proc.wait()
    return proc


def sync_exec(cmd):
    """Run *cmd* without a shell and return its stdout decoded as text.

    The command string is tokenised with shell-like quoting rules (shlex).
    Raises subprocess.CalledProcessError if the command exits non-zero.
    """
    argv = shlex.split(cmd)
    raw = check_output(argv)
    return raw.decode()


def get_ip():
    """Return the local IPv4 address the kernel picks for broadcast traffic.

    connect() on a UDP socket sends nothing on the wire. It only makes the
    kernel select a route and a source address, which getsockname() reports.
    """
    # The context manager closes the socket; the original leaked it.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        s.connect(('<broadcast>', 0))
        return s.getsockname()[0]


def get_iface(ip):
    """Return the local IPv4 interface that carries address *ip*.

    Every IPv4 address of every interface is checked, not only the first one
    per interface. The returned IPv4Interface (address plus prefix length) has
    an extra ``ifname`` attribute holding the interface name.

    Returns None when no interface has *ip*.
    """
    for ifname in netifaces.interfaces():
        for ipdata in netifaces.ifaddresses(ifname).get(netifaces.AF_INET, []):
            # Match first, so netmasks of unrelated addresses are never parsed.
            if ipdata.get('addr') != ip:
                continue

            # netifaces reports a dotted netmask; convert it to a prefix length.
            cidr = IPv4Network((0, ipdata['netmask'])).prefixlen
            iface = IPv4Interface("{}/{}".format(ipdata['addr'], cidr))
            iface.ifname = ifname
            return iface

    return None


async def ping(ip):
    """Ping *ip* once; return *ip* if the ping command succeeded, else None."""
    process = await exec(f"ping -c 1 {ip}")
    return ip if process.returncode == 0 else None


def get_arp_table():
    """Parse ``/usr/sbin/arp -n`` into a dict mapping IP string -> Host.

    The header line, blank lines and entries marked "(incomplete)" are
    skipped. The IP is taken from the first column and the hardware address
    from the third.
    """
    output = sync_exec('/usr/sbin/arp -n')
    table = {}
    for row in output.splitlines()[1:]:
        if '(incomplete)' in row:
            continue
        fields = row.split()
        if not fields:
            continue
        address = fields[0]
        table[address] = Host(ip=address, mac=fields[2])
    return table


def download_oui_db():
    """Download the IEEE OUI (MAC vendor) registry to OUI_DB.

    The file is fetched to a temporary ``.part`` path and moved into place
    only after the transfer succeeds. Callers only check whether OUI_DB
    exists before reusing it, so an interrupted download (e.g. C-c) must
    never leave a truncated registry behind.
    """
    OUI_URL = 'http://standards-oui.ieee.org/oui.txt'

    print(f"\r-downloading {OUI_URL}... ", end="", flush=True)
    partial_path = f"{OUI_DB}.part"
    try:
        urllib.request.urlretrieve(OUI_URL, partial_path)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial file, then re-raise.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    os.replace(partial_path, OUI_DB)


async def async_map(awaitable, iterable, limit=None):
    """Await ``awaitable(x)`` for every x in *iterable* concurrently.

    Results are returned as a list in input order. An exception raised by
    one call is returned in that call's slot instead of being raised.

    Args:
        awaitable: coroutine function taking one argument.
        iterable: inputs to map over.
        limit: maximum number of calls in flight at once. None (the default)
            starts them all together. A bound is useful when each call spawns
            a subprocess, e.g. one ping per address of a large network.

    Raises:
        ValueError: if *limit* is given and is less than 1.
    """
    if limit is None:
        call = awaitable
    else:
        if limit < 1:
            raise ValueError("limit must be a positive integer")
        semaphore = asyncio.Semaphore(limit)

        async def call(item):
            async with semaphore:
                return await awaitable(item)

    tasks = [asyncio.create_task(call(x)) for x in iterable]
    return await asyncio.gather(*tasks, return_exceptions=True)


def text_ellipsis(text, width):
    """Fit *text* into exactly *width* columns.

    Shorter text is left-aligned and padded with spaces. Longer text is cut
    and ends in '...'. If *width* is too small to hold the '...' marker
    (< 3), the text is cut without it. A negative *width* is treated as 0.
    """
    width = max(width, 0)
    if len(text) <= width:
        return text.ljust(width)
    if width < 3:
        return text[:width]
    return text[:width - 3] + '...'


def restore_echo():
    """Turn echo back on for the terminal attached to stdout.

    Does nothing when stdout has no file descriptor or is not a terminal
    (e.g. output is piped). This runs from the script's ``finally`` clause,
    where an exception would hide the real result of the scan.
    """
    try:
        fd = sys.stdout.fileno()
    except (OSError, ValueError):
        # io.UnsupportedOperation (no real fd) and closed streams land here.
        return
    if not os.isatty(fd):
        return

    attr = termios.tcgetattr(fd)
    # Index 3 is the local-modes (lflag) word, which holds the ECHO bit.
    attr[3] |= termios.ECHO
    termios.tcsetattr(fd, termios.TCSADRAIN, attr)


class Scanner:
    """Discover this machine's IPv4 neighbors and print what is known about them.

    async_run() does the following:
    - pings every address of the local interface's network;
    - lists the hosts found in the ARP table with their MAC vendor and
      reverse-DNS name;
    - when running as root with nmap installed, adds each host's open ports
      and nmap's OS guess.
    """

    def __init__(self):
        # Maps IP string -> Host; filled from the ARP table by async_run().
        self.hosts = {}
        # Running event loop, set by async_run(); add_name() resolves through it.
        self.loop = None
        ip = get_ip()
        self.iface = get_iface(ip)
        if self.iface is None:
            # Fail here with a clear message rather than an AttributeError in run().
            raise RuntimeError(f"no local network interface has address {ip}")

    async def add_name(self, ip: str):
        """Record the reverse-DNS name of *ip* on its Host, if it has one."""
        # getnameinfo() needs a full socket address. Port 443 is only a
        # placeholder; the service part of the answer is ignored.
        hostname, _service = await self.loop.getnameinfo((ip, 443))
        # When no name is found the numeric address comes back; keep '' then.
        if hostname != ip:
            self.hosts[ip].name = hostname

    def add_vendors(self):
        """Set ``vendor`` on every host from the OUI registry.

        Downloads the registry first if OUI_DB does not exist. Hosts whose
        OUI is not listed get an empty vendor string.
        """
        vendors = {h.oui: '' for h in self.hosts.values()}

        if not os.path.exists(OUI_DB):
            download_oui_db()

        # The registry contains non-ASCII vendor names, so don't rely on the
        # locale's default encoding; stream lines instead of readlines().
        with open(OUI_DB, encoding='utf-8', errors='replace') as fd:
            for line in fd:
                fields = line.split(maxsplit=3)
                if fields and fields[0] in vendors:
                    vendors[fields[0]] = fields[-1].strip()

        for h in self.hosts.values():
            h.vendor = vendors[h.oui]

    def print_hosts(self):
        """Print one line per host, sorted by IP: address, MAC, vendor, DNS name.

        The vendor column is shortened (with an ellipsis) to fit
        TERMINAL_WIDTH. The widest IP length is stored in ``self.longest_ip``.
        """
        if not self.hosts:
            self.longest_ip = 0
            return

        hosts = self.hosts.values()
        self.longest_ip = max(len(ip) for ip in self.hosts)
        longest_vendor = max(len(h.vendor) for h in hosts)
        longest_name = max(len(h.name) for h in hosts)

        mac_len = 17  # width of "aa:bb:cc:dd:ee:ff"
        # Columns used by everything but the vendor; 10 covers the " - " separators.
        other = 10 + self.longest_ip + mac_len + longest_name
        # Clamp at 0: on a narrow terminal a negative width would make the
        # format spec below raise ValueError.
        vendor_len = max(0, min(TERMINAL_WIDTH - other, longest_vendor))

        for h in sorted(hosts):
            vendor = text_ellipsis(h.vendor, width=vendor_len)
            print(f"{h.ip:>{self.longest_ip}} - {h.mac:>{mac_len}} - "
                  f"{vendor:<{vendor_len}} - {h.name}")

    async def guess_os(self, ip):
        """Fingerprint *ip* with ``nmap -O`` and store its open ports and OS guess.

        async_run() only schedules this when running as root with nmap on
        PATH. ``osname`` keeps its previous value if nmap reports no OS match.
        """
        ps = await asyncio.create_subprocess_exec(
            'nmap', '-O', '-oX', '-', ip,
            stdout=PIPE,
            stderr=DEVNULL)
        # communicate() drains stdout while waiting. Waiting first and reading
        # afterwards can deadlock once the XML report outgrows the pipe buffer.
        # stderr is discarded so nothing nmap writes there can corrupt the XML.
        output, _ = await ps.communicate()
        root = ET.fromstring(output)

        host = self.hosts[ip]
        host.ports = [
            int(port.get('portid'))
            for port in root.iter('port')
            if port.find("state[@state='open']") is not None
        ]

        osmatch = root.find('.//osmatch')
        if osmatch is not None:
            host.osname = osmatch.get('name', host.osname)

    def print_fingerprints(self):
        """Print one line per host, sorted by IP: address, open ports, OS guess.

        Also stores each host's comma-joined port list in ``str_ports``.
        """
        if not self.hosts:
            return

        hosts = self.hosts.values()
        for h in hosts:
            h.str_ports = ','.join(str(p) for p in h.ports)

        # Computed here rather than read from print_hosts()' self.longest_ip,
        # so this method works regardless of call order.
        longest_ip = max(len(ip) for ip in self.hosts)
        longest_ports = max(len(h.str_ports) for h in hosts)
        longest_os = max(len(h.osname) for h in hosts)
        other = 10 + longest_ip + longest_ports
        os_len = max(0, min(TERMINAL_WIDTH - other, longest_os))

        for h in sorted(hosts):
            osname = text_ellipsis(h.osname, width=os_len)
            print(f"{h.ip:>{longest_ip}} - {h.str_ports:<{longest_ports}} - {osname}")

    async def async_run(self):
        """Run the scan: ping sweep, host listing, then optional nmap fingerprinting."""
        # flush=True: a '\r'-terminated line is not flushed by line buffering.
        print("\n-scanning neighbors... ", end='\r', flush=True)
        self.loop = asyncio.get_running_loop()

        # Ping results are discarded: the sweep exists to populate the
        # kernel's ARP table, which is read below.
        await async_map(ping, self.iface.network.hosts())

        # Pause before reading the ARP table (presumably to let late replies
        # settle); awaited so the event loop is not blocked.
        await asyncio.sleep(6)
        self.hosts = get_arp_table()
        if not self.hosts:
            print("-no neighbors found" + 20 * ' ')
            return

        self.add_vendors()

        await async_map(self.add_name, self.hosts)
        self.print_hosts()

        if os.getuid() != 0:
            print("\n(run this as root program to perform nmap fingerprinting)")
            return

        if not shutil.which('nmap'):
            print("\n-command 'nmap' is not installed, install it for fingerprinting.")
            return

        print("\n-root mode: fingerprinting ... ", end='\r', flush=True)
        await async_map(self.guess_os, self.hosts)
        self.print_fingerprints()

    def run(self):
        """Print the local interface, then run the scan to completion."""
        print(f"local address: {self.iface.with_prefixlen} ({self.iface.ifname})")
        asyncio.run(self.async_run())



# Entry point. The guard keeps a plain import of this module from starting
# a network scan.
if __name__ == '__main__':
    try:
        Scanner().run()
    except KeyboardInterrupt:
        # Trailing spaces overwrite any '\r'-terminated progress message.
        print("C-c pressed" + 20 * ' ')
    finally:
        restore_echo()
