#!/usr/bin/env python3
import argparse
import ifaddr

import logging
import sys
import socket

from scapy.all import sniff
from rstconn import __version__
from rstconn.rstconn import (
    send_reset,
    log_packet,
    is_packet_tcp_client_to_server,
    is_packet_on_tcp_conn,
    is_adapter_localhost
)


# Names of every network adapter reported by ifaddr. They are offered as
# the accepted values for --iface, and the first entry is its default.
# NOTE: narrowing this list with is_adapter_localhost() is currently disabled.
LOCAL_IFACES = [iface.name for iface in ifaddr.get_adapters()]

# Module-level logger; handlers and level are configured after argument parsing.
logger = logging.getLogger(__name__)


# Without at least one adapter there is nothing to sniff on, and the
# ``LOCAL_IFACES[0]`` default below would fail with a bare IndexError.
# Exit with an explicit message instead.
if not LOCAL_IFACES:
    sys.exit("No network interfaces found: cannot select an interface to sniff on")

# Command-line interface. RawTextHelpFormatter keeps the hand-wrapped,
# multi-line help strings (see --seq-jitter) exactly as written.
parser = argparse.ArgumentParser(
    description=f"{__file__} kills network connections.",
    epilog=f"{__file__} arguments",
    formatter_class=argparse.RawTextHelpFormatter
)

# --- where to listen ------------------------------------------------------
parser.add_argument(
    '--iface', '-i',
    required=False,
    choices=LOCAL_IFACES,
    default=LOCAL_IFACES[0],
    help="Interface where to listen to"
)

# --- which connection to match (all optional) -----------------------------
parser.add_argument(
    '--server-ip', '-sip',
    required=False,
    help="IPv4 or hostname"
)
parser.add_argument(
    '--client-ip', '-cip',
    required=False,
    help="IPv4 or hostname"
)
parser.add_argument(
    '--server-port', '-p',
    type=int,
    required=False,
    help="Server port"
)

# --- how the reset is performed -------------------------------------------
parser.add_argument(
    '--packet-count', '-pc',
    required=False,
    type=int,
    default=50,
    help="sends N RST packets"
)
parser.add_argument(
    '--seq-jitter', '-sj',
    required=False,
    type=int,
    default=0,
    help="""Set seq_jitter to be non-zero in order to prove to yourself that the
sequence number of a RST segment does indeed need to be exactly equal
to the last sequence number ACK-ed by the receiver"""
)
parser.add_argument(
    '--ignore-syn', '-is',
    required=False,
    action="store_true",
    help="if a Packet has SYN flag, not sending RST"
)
parser.add_argument(
    '--window-size', '-ws',
    required=False,
    type=int,
    default=2052,
    help="Window size"
)

# --- diagnostics and passive modes ----------------------------------------
parser.add_argument(
    '-d', '--debug', required=False,
    choices=('CRITICAL', 'ERROR',
             'WARNING', 'INFO', 'DEBUG'),
    default='INFO',
    help="Debug level, see python logging; defaults to INFO if omitted"
)
parser.add_argument(
    '-m', '--monitor', required=False,
    action="store_true",
    default=False,
    help="Just sniff traffic, for the given filter, without sending RST"
)
parser.add_argument(
    '-ma', '--monitor-all', required=False,
    action="store_true",
    default=False,
    help="Just sniff traffic without any filter"
)
parser.add_argument(
    '-v', '--version', required=False,
    action="store_true",
    help="Print version and exit"
)

# Parse the command line and configure logging at the requested level.
_args = parser.parse_args()
logging.basicConfig(
    level=getattr(logging, _args.debug),
    format='%(levelname)-2s %(message)s',
)

if _args.version:
    # sys.exit() with a string argument writes it to stderr and exits with
    # status 1, which makes a plain version query look like a failure.
    # Print to stdout and exit successfully instead.
    print(__version__)
    sys.exit(0)

# Select the per-packet callback that is later handed to sniff() as ``prn``.
if _args.monitor or _args.monitor_all:
    # Passive modes: only log what is seen, never send anything.
    _func = log_packet
else:
    # Active mode: the value returned by send_reset(...) is used as the
    # callback. Verbose output is enabled only at DEBUG level.
    _func = send_reset(
        _args.iface,
        seq_jitter=_args.seq_jitter,
        ignore_syn=_args.ignore_syn,
        window_size=_args.window_size,
        verbose=_args.debug == 'DEBUG',
    )

def _resolve_host(host):
    """Resolve *host* (an IPv4 address or hostname) to an IPv4 address string.

    Returns None when *host* is empty or None, i.e. the option was not given.
    If the name cannot be resolved, the script terminates through
    ``parser.error`` (usage message on stderr, exit status 2) instead of
    surfacing a raw traceback.
    """
    if not host:
        return None
    try:
        resolved = socket.gethostbyname(host)
    except (socket.gaierror, UnicodeError) as exc:
        # gaierror: lookup failed; UnicodeError: name cannot be IDNA-encoded.
        parser.error(f"cannot resolve host {host!r}: {exc}")
    logger.info("Resolving %s to %s", host, resolved)
    return resolved


# Either value stays None when the matching option was omitted.
server_ip = _resolve_host(_args.server_ip)
client_ip = _resolve_host(_args.client_ip)


# Packet predicate built from whichever connection criteria were supplied.
_lfilter = is_packet_on_tcp_conn(
    server_ip=server_ip,
    client_ip=client_ip,
    server_port=_args.server_port,
)

# Keyword arguments for scapy's sniff(): interface, number of packets to
# capture, and the per-packet callback chosen above.
sniff_data = dict(
    iface=_args.iface,
    count=_args.packet_count,
    prn=_func,
)

_has_criteria = bool(server_ip or client_ip or _args.server_port)

if _args.monitor_all:
    # --monitor-all is documented as sniffing "without any filter", so the
    # predicate must not be installed even when criteria were passed.
    if _has_criteria:
        logger.warning(
            "--monitor-all given: ignoring --server-ip/--client-ip/--server-port"
        )
elif _has_criteria:
    sniff_data['lfilter'] = _lfilter


# Run the capture. sniff() blocks until it returns the captured packets,
# invoking the selected callback for each one.
logger.info(f"Starting sniffing for the target {_func}...")
t = sniff(**sniff_data)

if t.res:
    # Report the number of packets actually captured rather than the
    # requested --packet-count, which is only what was asked for.
    _captured = len(t.res)
    if _args.monitor or _args.monitor_all:
        # Monitor modes never send anything, so do not claim RSTs were sent.
        logger.info(f"Captured the following {_captured} packets: \n")
    else:
        logger.info(
            f"Sent SYN/ACK + RST packets in response to the following {_captured} packets: \n"
        )
    t.nsummary()
