#!/usr/bin/env python3
import argparse
import ifaddr

import logging
import sys
import socket

from scapy.all import sniff
from rstconn import __version__
from rstconn.rstconn import send_reset, log_packet, is_packet_tcp_client_to_server, is_packet_on_tcp_conn, is_adapter_localhost


# Names of every network adapter reported by ifaddr; these are the values
# accepted by the --iface option.
# NOTE: a filter through is_adapter_localhost(adapter, localhost_ip) used to
# sit here but is disabled (localhost_ip is not defined in this script).
LOCAL_IFACES = [nic.name for nic in ifaddr.get_adapters()]

# Module-level logger used for all messages emitted by this script.
logger = logging.getLogger(__name__)


# Command-line interface definition. All options are optional at the argparse
# level; the requirement that a client or server IP is given is enforced after
# parsing.
parser = argparse.ArgumentParser(
    description=f"{__file__} kills network connections.",
    epilog=f"{__file__} arguments",
    # Raw formatter keeps the manual line breaks of the multi-line help texts.
    formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
    '--iface', '-i',
    required=False,
    choices=LOCAL_IFACES,
    # Guard against an empty adapter list: indexing [0] unconditionally would
    # raise IndexError before the parser is even built. None leaves the
    # interface choice to the sniffer.
    default=LOCAL_IFACES[0] if LOCAL_IFACES else None,
    help="Interface where to listen to"
)
parser.add_argument(
    '--server-ip', '-sip',
    required=False,
    help="IPv4 or hostname"
)
parser.add_argument(
    '--client-ip', '-cip',
    required=False,
    help="IPv4 or hostname"
)
parser.add_argument(
    '--server-port', '-p',
    type=int,
    required=False,
    help="Server port"
)
parser.add_argument(
    '--packet-count', '-pc',
    required=False,
    type=int,
    default=50,
    help="listen for a maximum of N packets"
)
parser.add_argument(
    '--seq-jitter', '-sj',
    required=False,
    type=int,
    default=0,
    help="""Set seq_jitter to be non-zero in order to prove to yourself that the
sequence number of a RST segment does indeed need to be exactly equal
to the last sequence number ACK-ed by the receiver"""
)
parser.add_argument(
    '--ignore-syn', '-is',
    required=False,
    action="store_true",
    help="if a Packet has SYN flag, not sending RST"
)
parser.add_argument(
    '--window-size', '-ws',
    required=False,
    type=int,
    default=2052,
    help="Window size"
)
parser.add_argument(
    '-d', '--debug', required=False,
    # The chosen name is later looked up as an attribute of the logging module.
    choices=('CRITICAL', 'ERROR',
             'WARNING', 'INFO', 'DEBUG'),
    default='INFO',
    help="Debug level, see python logging; defaults to INFO if omitted"
)
parser.add_argument(
    '-m', '--monitor', required=False,
    action="store_true",
    default=False,
    help="Just sniff traffic without sending RST"
)
parser.add_argument(
    '-v', '--version', required=False,
    action="store_true",
    help="Print version and exit"
)

# Parse the command line and configure logging from the --debug level name.
_args = parser.parse_args()
logging.basicConfig(
    level=getattr(logging, _args.debug),
    format='%(levelname)-2s %(message)s',
)

if _args.version:
    # sys.exit(<str>) writes the string to stderr and exits with status 1.
    # Asking for the version is a successful run, so print it to stdout and
    # exit with status 0 instead.
    print(__version__)
    sys.exit(0)

# Without at least one endpoint there is nothing to filter the traffic on.
if not _args.server_ip and not _args.client_ip:
    logger.error("You must define at least a client or a server ip.")
    sys.exit(1)

# Select the per-packet callback handed to the sniffer: plain logging in
# monitor mode, otherwise the callable produced by send_reset().
if _args.monitor:
    _func = log_packet
else:
    _func = send_reset(
        _args.iface,
        seq_jitter=_args.seq_jitter,
        ignore_syn=_args.ignore_syn,
        window_size=_args.window_size,
        verbose=_args.debug == 'DEBUG',
    )

def _resolve_host(host):
    """Resolve *host* (IPv4 address or hostname) to an IPv4 address string.

    Returns None when *host* is empty or None, so an endpoint that was not
    given on the command line stays unset. If name resolution fails, the
    error is logged and the script exits with status 1 instead of dying with
    an unhandled traceback.
    """
    if not host:
        return None
    try:
        address = socket.gethostbyname(host)
    except socket.gaierror as err:
        logger.error("Cannot resolve %s: %s", host, err)
        sys.exit(1)
    logger.info("Resolving %s to %s", host, address)
    return address


# Endpoints of the TCP connection to match; either may be None.
server_ip = _resolve_host(_args.server_ip)
client_ip = _resolve_host(_args.client_ip)


# Sniff on the selected interface, calling _func for every packet accepted by
# the connection filter, until at most --packet-count packets were captured.
logger.info("Starting sniffing for the target %s...", _func)
t = sniff(
    iface=_args.iface,
    count=_args.packet_count,
    prn=_func,
    lfilter=is_packet_on_tcp_conn(
        server_ip=server_ip,
        client_ip=client_ip,
        server_port=_args.server_port,
    ),
)

if t.res:
    # Report the number of packets actually captured: --packet-count is only
    # an upper bound, so the result may hold fewer packets (e.g. if sniffing
    # stops early).
    captured = len(t.res)
    if _args.monitor:
        # In monitor mode the callback only logs packets, so do not claim
        # that any RST was sent.
        logger.info("Captured %d packets (monitor mode, no RST sent): \n", captured)
    else:
        logger.info(
            "Sent %d SYN/ACK + RST packets in response of the following packets: \n",
            captured,
        )
    t.nsummary()
