#!/bin/python3

# Forward stdin to a subcommand. If EOF is read from stdin, or if
# stdin/stdout/stderr are closed or hung up, then give the command "timeout"
# seconds to complete before it is killed.
#
# The command is run in a separate process group. This is mostly to simplify
# killing the set of processes (if well-behaving). You can configure that with
# the --setpgrp switch.

# usage: stdin-killer [-h] [--timeout TIMEOUT] [--debug DEBUG] [--signal SIGNAL] [--verbose] [--setpgrp {no,self,child}] command [arguments ...]
#
# wait for stdin EOF then kill forked subcommand
#
# positional arguments:
#   command            command to execute
#   arguments          arguments to command
#
# options:
#   -h, --help         show this help message and exit
#   --timeout TIMEOUT  time to wait for forked subcommand to willing terminate
#   --debug DEBUG      debug file
#   --signal SIGNAL    signal to send
#   --verbose          increase debugging
#   --setpgrp {no,self,child}
#                         create process group


import argparse
import fcntl
import logging
import os
import select
import signal
import struct
import subprocess
import sys
import time

NAME = "stdin-killer"

log = logging.getLogger(NAME)
PAGE_SIZE = 4096

POLL_HANGUP = select.POLLHUP | (select.POLLRDHUP if hasattr(select, 'POLLRDHUP') else 0) | select.POLLERR


def handle_event(poll, buffer, fd, event, p):
    """Process a single (fd, event) pair reported by the poll loop.

    Args:
        poll: the select.poll object the descriptors are registered with;
            registrations are adjusted here as the forwarding state changes.
        buffer: bytearray holding input read from stdin that has not yet been
            written to the command. It is modified in place.
        fd: file descriptor the event was reported for.
        event: poll event bitmask reported for fd.
        p: subprocess.Popen object of the command; p.stdin is closed and set
            to None once nothing more can be forwarded.

    Returns:
        True when the event means the command is done or expected to finish
        (it was reaped, stdin hit EOF or hung up, stdout/stderr hung up, or
        the pipe to the command broke); False otherwise.

    Relies on the module-level names sigfdr, log, PAGE_SIZE and POLL_HANGUP.
    """
    if sigfdr == fd:
        # One byte holding the signal number is read from the wakeup pipe.
        b = os.read(sigfdr, 1)
        (signum,) = struct.unpack("B", b)
        log.debug("got signal %d", signum)
        try:
            p.wait(timeout=0)
            return True
        except subprocess.TimeoutExpired:
            # the command is still running
            pass
    elif 0 == fd:
        if event & POLL_HANGUP:
            # NOTE(review): hangup is tested before POLLIN, so input that is
            # still unread when both are reported is not forwarded -- confirm
            # this is intended.
            log.debug("peer closed connection, waiting for process exit")
            poll.unregister(0)
            sys.stdin.close()
            if not buffer and p.stdin is not None:
                p.stdin.close()
                p.stdin = None
            return True
        elif event & select.POLLIN:
            b = os.read(0, PAGE_SIZE)
            if b == b"":
                log.debug("read EOF")
                poll.unregister(0)
                sys.stdin.close()
                # p.stdin may already be None (broken pipe); also drop the
                # closed handle so later p.stdin.fileno() checks cannot raise.
                if not buffer and p.stdin is not None:
                    p.stdin.close()
                    p.stdin = None
                return True
            if p.stdin is not None:
                buffer += b
                # ignore further POLLIN until buffer is written to p.stdin
                poll.register(0, POLL_HANGUP)
                poll.register(p.stdin.fileno(), select.POLLOUT)
    elif p.stdin is not None and p.stdin.fileno() == fd:
        assert event & select.POLLOUT
        b = buffer[:PAGE_SIZE]
        log.debug("sending %d bytes to process", len(b))
        try:
            n = p.stdin.write(b)
            p.stdin.flush()
        except BrokenPipeError:
            log.debug("got SIGPIPE")
            poll.unregister(fd)
            try:
                # close() flushes again and may re-raise for the same reason
                p.stdin.close()
            except BrokenPipeError:
                pass
            p.stdin = None
            return True
        except BlockingIOError:
            poll.register(fd, select.POLLOUT | POLL_HANGUP)
        else:
            log.debug("wrote %d bytes", n)
            # Drain in place: rebinding the name would leave the caller's
            # bytearray untouched and the same bytes would be sent again.
            del buffer[:n]
            if not buffer:
                if sys.stdin.closed:
                    # Our stdin went away while data was pending: nothing more
                    # will arrive, so pass EOF on to the command instead of
                    # polling the closed descriptor 0 again.
                    poll.unregister(fd)
                    p.stdin.close()
                    p.stdin = None
                else:
                    poll.register(0, select.POLLIN | POLL_HANGUP)
                    poll.unregister(fd)
            # otherwise stay registered for POLLOUT until the buffer is empty
    elif 1 == fd:
        assert event & POLL_HANGUP
        log.debug("stdout pipe has closed")
        poll.unregister(1)
        return True
    elif 2 == fd:
        assert event & POLL_HANGUP
        log.debug("stderr pipe has closed")
        poll.unregister(2)
        return True
    else:
        assert False
    return False


def listen_for_events(sigfdr, p, timeout):
    """Run the poll loop until the command exits or the grace period ends.

    Args:
        sigfdr: read end of the signal wakeup pipe; polled for POLLIN.
        p: subprocess.Popen object of the command.
        timeout: seconds to keep waiting for the command after the first
            event for which handle_event() returned True.

    Returns:
        None. The function returns either when p.returncode has been set or
        when the grace period has elapsed; callers tell the two apart by
        inspecting p.returncode.
    """
    poll = select.poll()
    # listen for data on stdin
    poll.register(0, select.POLLIN | POLL_HANGUP)
    # listen for stdout/stderr to be closed, if they are closed then my parent
    # is gone and I should expire the command and myself.
    poll.register(1, POLL_HANGUP)
    poll.register(2, POLL_HANGUP)
    # for SIGCHLD
    poll.register(sigfdr, select.POLLIN)
    buffer = bytearray()
    # monotonic timestamp of the first terminating event; 0.0 means "none yet"
    expired = 0.0
    while True:
        if expired > 0.0:
            # only wait for what is left of the grace period
            since = time.monotonic() - expired
            wait = int((timeout - since) * 1000.0)
            if wait <= 0:
                return
        else:
            wait = 5000
        log.debug("polling for %d milliseconds", wait)
        events = poll.poll(wait)
        for fd, event in events:
            log.debug("event: (%d, %d)", fd, event)
            if handle_event(poll, buffer, fd, event, p):
                if p.returncode is not None:
                    return
                if expired == 0.0:
                    expired = time.monotonic()
                    # report the grace period actually in use (the parameter),
                    # not a global that only exists when run as a script
                    log.info(
                        "expiration expected; waiting %d seconds for command to complete",
                        timeout,
                    )


if __name__ == "__main__":
    # Script entry point: set up signal plumbing, parse arguments, spawn the
    # command, run the poll loop, then mirror the command's fate (exit status
    # or terminating signal) as our own.
    #
    # sigfdr and NS are bound at module level here and are read as globals by
    # the functions above.

    # Ignore SIGPIPE so writing to a closed pipe raises BrokenPipeError
    # (caught in handle_event) rather than terminating this process.
    signal.signal(signal.SIGPIPE, signal.SIG_IGN)
    # Pipe for signal notification: the write end is given to
    # signal.set_wakeup_fd() below, the read end is polled in
    # listen_for_events().
    try:
        (sigfdr, sigfdw) = os.pipe2(os.O_NONBLOCK | os.O_CLOEXEC)
    except AttributeError:
        # pipe2 is only available on "some flavors of Unix"
        # https://docs.python.org/3.10/library/os.html?highlight=pipe2#os.pipe2
        pipe_ends = os.pipe()
        for fd in pipe_ends:
            flags = fcntl.fcntl(fd, fcntl.F_GETFL)
            # NOTE(review): O_CLOEXEC is a descriptor flag (F_SETFD /
            # FD_CLOEXEC), not a file status flag, so F_SETFL presumably
            # ignores it here -- confirm; os.pipe() descriptors are documented
            # as non-inheritable anyway.
            fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK | os.O_CLOEXEC)
        (sigfdr, sigfdw) = pipe_ends

    signal.set_wakeup_fd(sigfdw)

    # The handler body is intentionally empty; SIGCHLD is observed through the
    # wakeup fd. Presumably a Python-level handler must be installed for the
    # signal to be reported there -- see the signal.set_wakeup_fd docs.
    def do_nothing(signum, frame):
        pass

    signal.signal(signal.SIGCHLD, do_nothing)

    P = argparse.ArgumentParser(
        description="wait for stdin EOF then kill forked subcommand"
    )
    P.add_argument(
        "--timeout",
        action="store",
        default=5,
        help="time to wait for forked subcommand to willing terminate",
        type=int,
    )
    P.add_argument("--debug", action="store", help="debug file", type=str)
    P.add_argument(
        "--signal",
        action="store",
        help="signal to send",
        type=int,
        default=signal.SIGKILL,
    )
    P.add_argument("--verbose", action="store_true", help="increase debugging")
    P.add_argument(
        "--setpgrp",
        action="store",
        choices=["no", "self", "child"],
        default="self",
        help="create process group",
    )
    P.add_argument(
        "cmd", metavar="command", type=str, nargs=1, help="command to execute"
    )
    P.add_argument(
        "args", metavar="arguments", type=str, nargs="*", help="arguments to command"
    )
    NS = P.parse_args()

    # Log to the --debug file when given, otherwise to stderr; --verbose
    # lowers the threshold from INFO to DEBUG.
    logargs = {}
    if NS.debug is not None:
        logargs["filename"] = NS.debug
    else:
        logargs["stream"] = sys.stderr
    if NS.verbose:
        logargs["level"] = logging.DEBUG
    else:
        logargs["level"] = logging.INFO
    logargs["format"] = f"%(asctime)s {NAME} %(levelname)s: %(message)s"
    logargs["datefmt"] = "%Y-%m-%dT%H:%M:%S"
    logging.basicConfig(**logargs)

    # NS.cmd is a one-element list (nargs=1), so this is the full argv.
    cargs = NS.cmd + NS.args
    popen_kwargs = {
        "stdin": subprocess.PIPE,
    }

    # pgrp selects what gets signalled on timeout:
    #   a process group id -> os.killpg() on that group
    #   0                  -> only the command itself (os.kill on p.pid)
    #   None               -> placeholder, replaced by p.pid after the spawn
    if NS.setpgrp == "self":
        # become a process group leader unless we already are one
        pgrp = os.getpgrp()
        if pgrp != os.getpid():
            os.setpgrp()
            pgrp = os.getpgrp()
    elif NS.setpgrp == "child":
        # the command calls setpgrp() itself before exec
        popen_kwargs["preexec_fn"] = os.setpgrp
        pgrp = None
    elif NS.setpgrp == "no":
        pgrp = 0
    else:
        assert False

    log.debug("executing %s", cargs)
    p = subprocess.Popen(cargs, **popen_kwargs)
    if pgrp is None:
        pgrp = p.pid
    # Make the pipe to the command non-blocking so forwarding input cannot
    # stall the poll loop.
    flags = fcntl.fcntl(p.stdin.fileno(), fcntl.F_GETFL)
    fcntl.fcntl(p.stdin.fileno(), fcntl.F_SETFL, flags | os.O_NONBLOCK)

    listen_for_events(sigfdr, p, NS.timeout)

    # returncode is still None when the loop gave up waiting for the command.
    if p.returncode is None:
        log.error("timeout expired: sending signal %d to command and myself", NS.signal)
        if pgrp == 0:
            os.kill(p.pid, NS.signal)
        else:
            os.killpg(pgrp, NS.signal)  # should kill me too
        os.kill(os.getpid(), NS.signal)  # to exit abnormally with same signal
        log.error("signal did not cause termination, sending myself SIGKILL")
        os.kill(os.getpid(), signal.SIGKILL)  # failsafe
    rc = p.returncode
    log.debug("rc = %d", rc)
    assert rc is not None
    # A negative returncode is treated as "terminated by signal -rc".
    if rc < 0:
        log.error("command terminated with signal %d: sending same signal to myself!", -rc)
        os.kill(os.getpid(), -rc)  # kill myself with the same signal
        log.error("signal did not cause termination, sending myself SIGKILL")
        os.kill(os.getpid(), signal.SIGKILL)  # failsafe
    else:
        log.info("command exited with status %d: exiting normally with same code!", rc)
        sys.exit(rc)
